Skip to main content

gitcortex_store/
kuzu.rs

1use std::path::{Path, PathBuf};
2
3use gitcortex_core::{
4    error::{GitCortexError, Result},
5    graph::{Edge, GraphDiff, Node, NodeId, NodeMetadata, Span},
6    schema::{EdgeKind, NodeKind, Visibility},
7    store::GraphStore,
8};
9use kuzu::{Connection, Database, SystemConfig, Value};
10
11use crate::{branch, schema as db_schema};
12
13// ── KuzuGraphStore ────────────────────────────────────────────────────────────
14
/// Local KuzuDB-backed implementation of [`GraphStore`].
///
/// One database file per repo (`graph.kuzu`), with per-branch node/edge tables
/// inside it. A fresh `Connection` is created for each operation so we avoid
/// the self-referential lifetime that `Mutex<Connection<'db>>` would require.
pub struct KuzuGraphStore {
    /// Database handle that every per-operation `Connection` borrows from.
    db: Database,
    /// Value returned by `branch::repo_id` for the repo root; used to derive
    /// the database path and to read/write the last-indexed SHA.
    repo_id: String,
}
24
25impl KuzuGraphStore {
26    /// Open (or create) the graph database for the repo at `repo_root`.
27    pub fn open(repo_root: &Path) -> Result<Self> {
28        let repo_id = branch::repo_id(repo_root);
29        let db_path = branch::db_path(&repo_id);
30
31        if let Some(parent) = db_path.parent() {
32            std::fs::create_dir_all(parent)?;
33        }
34
35        let db = Database::new(&db_path, SystemConfig::default())
36            .map_err(|e| GitCortexError::Store(format!("open db: {e}")))?;
37
38        Ok(Self { db, repo_id })
39    }
40
41    // ── Private helpers ───────────────────────────────────────────────────────
42
43    fn conn(&self) -> Result<Connection<'_>> {
44        Connection::new(&self.db)
45            .map_err(|e| GitCortexError::Store(format!("open connection: {e}")))
46    }
47
48    fn ensure_branch(&self, branch: &str) -> Result<()> {
49        let mut conn = self.conn()?;
50        db_schema::ensure_branch(&mut conn, branch)
51    }
52}
53
54// ── GraphStore impl ───────────────────────────────────────────────────────────
55
56impl GraphStore for KuzuGraphStore {
57    // ── Write path ────────────────────────────────────────────────────────────
58
59    fn apply_diff(&mut self, branch: &str, diff: &GraphDiff) -> Result<()> {
60        if diff.is_empty() {
61            return Ok(());
62        }
63
64        self.ensure_branch(branch)?;
65        let nt = db_schema::node_table(branch);
66        let et = db_schema::edge_table(branch);
67        let conn = self.conn()?;
68
69        // Use explicit transactions so Phase 1 (node inserts) is committed and
70        // visible before Phase 2 (edge MATCHes) begins — required for KuzuDB's
71        // MVCC snapshot isolation to work correctly.
72        conn.query("BEGIN TRANSACTION")
73            .map_err(|e| GitCortexError::Store(format!("begin transaction: {e}")))?;
74
75        // 1. Remove all nodes (and their edges) for deleted/replaced files.
76        for file in &diff.removed_files {
77            let file_str = esc(file.to_string_lossy().as_ref());
78            conn.query(&format!(
79                "MATCH (n:{nt}) WHERE n.file = '{file_str}' DETACH DELETE n"
80            ))
81            .map_err(|e| GitCortexError::Store(format!("delete file nodes: {e}")))?;
82        }
83
84        // 2. Remove explicit node IDs.
85        for id in &diff.removed_node_ids {
86            let id_str = esc(&id.as_str());
87            conn.query(&format!(
88                "MATCH (n:{nt}) WHERE n.id = '{id_str}' DETACH DELETE n"
89            ))
90            .map_err(|e| GitCortexError::Store(format!("delete node: {e}")))?;
91        }
92
93        // 3. Remove explicit edges.
94        for (src, dst, kind) in &diff.removed_edges {
95            let s = esc(&src.as_str());
96            let d = esc(&dst.as_str());
97            let k = esc(&kind.to_string());
98            conn.query(&format!(
99                "MATCH (s:{nt})-[e:{et}]->(d:{nt}) \
100                 WHERE s.id = '{s}' AND d.id = '{d}' AND e.kind = '{k}' \
101                 DELETE e"
102            ))
103            .map_err(|e| GitCortexError::Store(format!("delete edge: {e}")))?;
104        }
105
106        // 4. Insert new nodes.
107        for node in &diff.added_nodes {
108            let id = esc(&node.id.as_str());
109            let kind = esc(&node.kind.to_string());
110            let name = esc(&node.name);
111            let qname = esc(&node.qualified_name);
112            let file = esc(node.file.to_string_lossy().as_ref());
113            let sl = node.span.start_line as i64;
114            let el = node.span.end_line as i64;
115            let loc = node.metadata.loc as i64;
116            let vis = esc(&vis_str(&node.metadata.visibility));
117            let is_async = node.metadata.is_async;
118            let is_unsafe = node.metadata.is_unsafe;
119
120            conn.query(&format!(
121                "CREATE (:{nt} {{\
122                    id: '{id}', kind: '{kind}', name: '{name}', \
123                    qualified_name: '{qname}', file: '{file}', \
124                    start_line: {sl}, end_line: {el}, loc: {loc}, \
125                    visibility: '{vis}', is_async: {is_async}, is_unsafe: {is_unsafe}\
126                }})"
127            ))
128            .map_err(|e| GitCortexError::Store(format!("insert node '{name}': {e}")))?;
129        }
130
131        // Commit node inserts so the edge MATCH queries in steps 5–6 see them.
132        conn.query("COMMIT")
133            .map_err(|e| GitCortexError::Store(format!("commit nodes: {e}")))?;
134
135        conn.query("BEGIN TRANSACTION")
136            .map_err(|e| GitCortexError::Store(format!("begin edge transaction: {e}")))?;
137
138        // 5. Insert new edges. If either endpoint is absent (cross-file), MATCH
139        //    yields no rows and the CREATE is silently skipped — correct behaviour.
140        for edge in &diff.added_edges {
141            let s = esc(&edge.src.as_str());
142            let d = esc(&edge.dst.as_str());
143            let k = esc(&edge.kind.to_string());
144
145            conn.query(&format!(
146                "MATCH (s:{nt} {{id: '{s}'}}), (d:{nt} {{id: '{d}'}}) \
147                 CREATE (s)-[:{et} {{kind: '{k}'}}]->(d)"
148            ))
149            .map_err(|e| GitCortexError::Store(format!("insert edge: {e}")))?;
150        }
151
152        // 6. Resolve cross-file deferred edges against the full store.
153        //    The diff-local pass couldn't find these callees/types because they
154        //    live in unchanged files. We match by name here — best-effort without
155        //    full type inference, filtered to the correct node kinds to reduce noise.
156
157        for (caller_id, callee_name) in &diff.deferred_calls {
158            let caller = esc(&caller_id.as_str());
159            let callee = esc(callee_name);
160            conn.query(&format!(
161                "MATCH (caller:{nt} {{id: '{caller}'}}), (callee:{nt}) \
162                 WHERE callee.name = '{callee}' \
163                 AND (callee.kind = 'function' OR callee.kind = 'method') \
164                 CREATE (caller)-[:{et} {{kind: 'calls'}}]->(callee)"
165            ))
166            .map_err(|e| GitCortexError::Store(format!("deferred call '{callee_name}': {e}")))?;
167        }
168
169        for (fn_id, type_name) in &diff.deferred_uses {
170            let fn_esc = esc(&fn_id.as_str());
171            let ty = esc(type_name);
172            conn.query(&format!(
173                "MATCH (fn_node:{nt} {{id: '{fn_esc}'}}), (ty:{nt}) \
174                 WHERE ty.name = '{ty}' \
175                 AND (ty.kind = 'struct' OR ty.kind = 'enum' \
176                      OR ty.kind = 'trait' OR ty.kind = 'type_alias') \
177                 CREATE (fn_node)-[:{et} {{kind: 'uses'}}]->(ty)"
178            ))
179            .map_err(|e| GitCortexError::Store(format!("deferred use '{type_name}': {e}")))?;
180        }
181
182        for (struct_id, trait_name) in &diff.deferred_implements {
183            let s = esc(&struct_id.as_str());
184            let t = esc(trait_name);
185            conn.query(&format!(
186                "MATCH (st:{nt} {{id: '{s}'}}), (tr:{nt}) \
187                 WHERE tr.name = '{t}' AND tr.kind = 'trait' \
188                 CREATE (st)-[:{et} {{kind: 'implements'}}]->(tr)"
189            ))
190            .map_err(|e| GitCortexError::Store(format!("deferred impl '{trait_name}': {e}")))?;
191        }
192
193        conn.query("COMMIT")
194            .map_err(|e| GitCortexError::Store(format!("commit edges: {e}")))?;
195
196        Ok(())
197    }
198
199    // ── Read path ─────────────────────────────────────────────────────────────
200
201    fn lookup_symbol(&self, branch: &str, name: &str) -> Result<Vec<Node>> {
202        self.ensure_branch(branch)?;
203        let nt = db_schema::node_table(branch);
204        let name_esc = esc(name);
205        let conn = self.conn()?;
206
207        let mut result = conn
208            .query(&format!(
209                "MATCH (n:{nt}) WHERE n.name = '{name_esc}' \
210                 RETURN {NODE_COLS}"
211            ))
212            .map_err(|e| GitCortexError::Store(e.to_string()))?;
213
214        rows_to_nodes(&mut result)
215    }
216
217    fn find_callers(&self, branch: &str, function_name: &str) -> Result<Vec<Node>> {
218        self.ensure_branch(branch)?;
219        let nt = db_schema::node_table(branch);
220        let et = db_schema::edge_table(branch);
221        let name_esc = esc(function_name);
222        let conn = self.conn()?;
223
224        let mut result = conn
225            .query(&format!(
226                "MATCH (caller:{nt})-[e:{et} {{kind: 'calls'}}]->(callee:{nt}) \
227                 WHERE callee.name = '{name_esc}' \
228                 RETURN caller.id, caller.kind, caller.name, caller.qualified_name, \
229                        caller.file, caller.start_line, caller.end_line, caller.loc, \
230                        caller.visibility, caller.is_async, caller.is_unsafe"
231            ))
232            .map_err(|e| GitCortexError::Store(e.to_string()))?;
233
234        rows_to_nodes(&mut result)
235    }
236
237    fn list_definitions(&self, branch: &str, file: &Path) -> Result<Vec<Node>> {
238        self.ensure_branch(branch)?;
239        let nt = db_schema::node_table(branch);
240        let file_esc = esc(file.to_string_lossy().as_ref());
241        let conn = self.conn()?;
242
243        let mut result = conn
244            .query(&format!(
245                "MATCH (n:{nt}) WHERE n.file = '{file_esc}' \
246                 RETURN {NODE_COLS} ORDER BY n.start_line"
247            ))
248            .map_err(|e| GitCortexError::Store(e.to_string()))?;
249
250        rows_to_nodes(&mut result)
251    }
252
253    fn branch_diff(&self, from: &str, to: &str) -> Result<GraphDiff> {
254        self.ensure_branch(from)?;
255        self.ensure_branch(to)?;
256
257        let from_nt = db_schema::node_table(from);
258        let to_nt = db_schema::node_table(to);
259        let mut conn = self.conn()?;
260
261        // Collect node IDs from each branch.
262        let from_ids = collect_ids(&mut conn, &from_nt)?;
263        let to_ids = collect_ids(&mut conn, &to_nt)?;
264
265        // Nodes in `to` but not in `from` → added.
266        let added_ids: Vec<&String> = to_ids.iter().filter(|id| !from_ids.contains(*id)).collect();
267
268        // Nodes in `from` but not in `to` → removed.
269        let removed_ids: Vec<&String> =
270            from_ids.iter().filter(|id| !to_ids.contains(*id)).collect();
271
272        let mut diff = GraphDiff::default();
273
274        for id in added_ids {
275            let id_esc = esc(id);
276            let mut r = conn
277                .query(&format!(
278                    "MATCH (n:{to_nt}) WHERE n.id = '{id_esc}' RETURN {NODE_COLS}"
279                ))
280                .map_err(|e| GitCortexError::Store(e.to_string()))?;
281            diff.added_nodes.extend(rows_to_nodes(&mut r)?);
282        }
283
284        for id in removed_ids {
285            if let Ok(node_id) = NodeId::try_from(id.as_str()) {
286                diff.removed_node_ids.push(node_id);
287            }
288        }
289
290        Ok(diff)
291    }
292
293    fn list_all_nodes(&self, branch: &str) -> Result<Vec<Node>> {
294        self.ensure_branch(branch)?;
295        let nt = db_schema::node_table(branch);
296        let conn = self.conn()?;
297        let mut result = conn
298            .query(&format!("MATCH (n:{nt}) RETURN {NODE_COLS}"))
299            .map_err(|e| GitCortexError::Store(e.to_string()))?;
300        rows_to_nodes(&mut result)
301    }
302
303    fn list_all_edges(&self, branch: &str) -> Result<Vec<Edge>> {
304        self.ensure_branch(branch)?;
305        let nt = db_schema::node_table(branch);
306        let et = db_schema::edge_table(branch);
307        let conn = self.conn()?;
308        let result = conn
309            .query(&format!(
310                "MATCH (s:{nt})-[e:{et}]->(d:{nt}) RETURN s.id, d.id, e.kind"
311            ))
312            .map_err(|e| GitCortexError::Store(e.to_string()))?;
313
314        let mut out = Vec::new();
315        for row in result {
316            let src_str = str_val(&row[0])?;
317            let dst_str = str_val(&row[1])?;
318            let kind_str = str_val(&row[2])?;
319            out.push(Edge {
320                src: NodeId::try_from(src_str.as_str())
321                    .map_err(|e| GitCortexError::Store(format!("bad src id: {e}")))?,
322                dst: NodeId::try_from(dst_str.as_str())
323                    .map_err(|e| GitCortexError::Store(format!("bad dst id: {e}")))?,
324                kind: edge_kind_from_str(&kind_str),
325            });
326        }
327        Ok(out)
328    }
329
330    // ── Indexing state ────────────────────────────────────────────────────────
331
    /// Returns whatever `branch::read_last_sha` reports for this store's
    /// `repo_id` and `branch_name`.
    ///
    /// This state is read through the `branch` module, not through the graph
    /// database connection.
    fn last_indexed_sha(&self, branch_name: &str) -> Result<Option<String>> {
        branch::read_last_sha(&self.repo_id, branch_name)
    }
335
    /// Records `sha` for `branch_name` by delegating to
    /// `branch::write_last_sha` with this store's `repo_id`.
    ///
    /// This state is written through the `branch` module, not through the
    /// graph database connection.
    fn set_last_indexed_sha(&mut self, branch_name: &str, sha: &str) -> Result<()> {
        branch::write_last_sha(&self.repo_id, branch_name, sha)
    }
339}
340
341// ── Query helpers ─────────────────────────────────────────────────────────────
342
/// Fixed column projection for node-returning queries that bind the node as
/// `n`. The order must match the column indices read by `row_to_node()`.
///
/// `find_callers` binds the node as `caller` and therefore spells out the same
/// eleven columns itself; keep the two lists in sync when adding a column.
const NODE_COLS: &str = "n.id, n.kind, n.name, n.qualified_name, n.file, \
     n.start_line, n.end_line, n.loc, n.visibility, n.is_async, n.is_unsafe";
347
348fn rows_to_nodes(result: &mut kuzu::QueryResult) -> Result<Vec<Node>> {
349    let mut nodes = Vec::new();
350    for row in result.by_ref() {
351        nodes.push(row_to_node(row)?);
352    }
353    Ok(nodes)
354}
355
356fn row_to_node(row: Vec<Value>) -> Result<Node> {
357    if row.len() < 11 {
358        return Err(GitCortexError::Store(format!(
359            "expected 11 columns, got {}",
360            row.len()
361        )));
362    }
363    let id_str = str_val(&row[0])?;
364    let kind = kind_from_str(&str_val(&row[1])?);
365    let name = str_val(&row[2])?;
366    let qualified_name = str_val(&row[3])?;
367    let file = PathBuf::from(str_val(&row[4])?);
368    let start_line = i64_val(&row[5])? as u32;
369    let end_line = i64_val(&row[6])? as u32;
370    let loc = i64_val(&row[7])? as u32;
371    let visibility = vis_from_str(&str_val(&row[8])?);
372    let is_async = bool_val(&row[9])?;
373    let is_unsafe = bool_val(&row[10])?;
374
375    Ok(Node {
376        id: NodeId::try_from(id_str.as_str())
377            .map_err(|e| GitCortexError::Store(format!("bad node id: {e}")))?,
378        kind,
379        name,
380        qualified_name,
381        file,
382        span: Span {
383            start_line,
384            end_line,
385        },
386        metadata: NodeMetadata {
387            loc,
388            visibility,
389            is_async,
390            is_unsafe,
391            ..Default::default()
392        },
393    })
394}
395
396fn collect_ids(conn: &mut Connection, table: &str) -> Result<Vec<String>> {
397    let result = conn
398        .query(&format!("MATCH (n:{table}) RETURN n.id"))
399        .map_err(|e| GitCortexError::Store(e.to_string()))?;
400
401    let mut ids = Vec::new();
402    for row in result {
403        ids.push(str_val(&row[0])?);
404    }
405    Ok(ids)
406}
407
408// ── Value extraction ──────────────────────────────────────────────────────────
409
410fn str_val(v: &Value) -> Result<String> {
411    match v {
412        Value::String(s) => Ok(s.clone()),
413        other => Err(GitCortexError::Store(format!(
414            "expected String, got {other:?}"
415        ))),
416    }
417}
418
419fn i64_val(v: &Value) -> Result<i64> {
420    match v {
421        Value::Int64(n) => Ok(*n),
422        Value::Int32(n) => Ok(*n as i64),
423        other => Err(GitCortexError::Store(format!(
424            "expected Int64, got {other:?}"
425        ))),
426    }
427}
428
429fn bool_val(v: &Value) -> Result<bool> {
430    match v {
431        Value::Bool(b) => Ok(*b),
432        other => Err(GitCortexError::Store(format!(
433            "expected Bool, got {other:?}"
434        ))),
435    }
436}
437
438// ── Enum conversions ──────────────────────────────────────────────────────────
439
/// Maps a stored kind label back to its [`NodeKind`].
///
/// NOTE(review): any unrecognised label silently becomes
/// `NodeKind::Function`, so a typo or a newly added kind would be read back
/// as a function instead of being reported.
fn kind_from_str(s: &str) -> NodeKind {
    match s {
        "folder" => NodeKind::Folder,
        "file" => NodeKind::File,
        "module" => NodeKind::Module,
        "struct" => NodeKind::Struct,
        "enum" => NodeKind::Enum,
        "trait" => NodeKind::Trait,
        "type_alias" => NodeKind::TypeAlias,
        "function" => NodeKind::Function,
        "method" => NodeKind::Method,
        "constant" => NodeKind::Constant,
        "macro" => NodeKind::Macro,
        // Fallback for unknown labels.
        _ => NodeKind::Function,
    }
}
456
/// Maps a stored edge label back to its [`EdgeKind`].
///
/// `EdgeKind::Contains` has no explicit arm: it is produced by the fallback,
/// which also swallows any unrecognised label.
fn edge_kind_from_str(s: &str) -> EdgeKind {
    match s {
        "calls" => EdgeKind::Calls,
        "implements" => EdgeKind::Implements,
        "uses" => EdgeKind::Uses,
        "imports" => EdgeKind::Imports,
        // Everything else, including unknown labels, is treated as `Contains`.
        _ => EdgeKind::Contains,
    }
}
466
467fn vis_str(v: &Visibility) -> String {
468    match v {
469        Visibility::Pub => "pub".into(),
470        Visibility::PubCrate => "pub_crate".into(),
471        Visibility::Private => "private".into(),
472    }
473}
474
/// Maps a stored visibility label back to its [`Visibility`].
///
/// Only `"pub"` and `"pub_crate"` are matched explicitly; every other input,
/// including `"private"` and unknown labels, yields `Visibility::Private`.
fn vis_from_str(s: &str) -> Visibility {
    match s {
        "pub" => Visibility::Pub,
        "pub_crate" => Visibility::PubCrate,
        // Fallback: treat anything else as private.
        _ => Visibility::Private,
    }
}
482
483// ── String escaping ───────────────────────────────────────────────────────────
484
/// Escapes `s` so it can be embedded in a single-quoted Cypher string literal.
///
/// Each backslash and each single quote is prefixed with a backslash
/// (`\` → `\\`, `'` → `\'`); all other characters are copied unchanged.
fn esc(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '\'') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}