1use std::path::{Path, PathBuf};
2
3use gitcortex_core::{
4 error::{GitCortexError, Result},
5 graph::{Edge, GraphDiff, Node, NodeId, NodeMetadata, Span},
6 schema::{EdgeKind, NodeKind, Visibility},
7 store::GraphStore,
8};
9use kuzu::{Connection, Database, SystemConfig, Value};
10
11use crate::{branch, schema as db_schema};
12
13pub struct KuzuGraphStore {
21 db: Database,
22 repo_id: String,
23}
24
25impl KuzuGraphStore {
26 pub fn open(repo_root: &Path) -> Result<Self> {
28 let repo_id = branch::repo_id(repo_root);
29 let db_path = branch::db_path(&repo_id);
30
31 if let Some(parent) = db_path.parent() {
32 std::fs::create_dir_all(parent)?;
33 }
34
35 let db = Database::new(&db_path, SystemConfig::default())
36 .map_err(|e| GitCortexError::Store(format!("open db: {e}")))?;
37
38 Ok(Self { db, repo_id })
39 }
40
41 fn conn(&self) -> Result<Connection<'_>> {
44 Connection::new(&self.db)
45 .map_err(|e| GitCortexError::Store(format!("open connection: {e}")))
46 }
47
48 fn ensure_branch(&self, branch: &str) -> Result<()> {
49 let mut conn = self.conn()?;
50 db_schema::ensure_branch(&mut conn, branch)
51 }
52}
53
54impl GraphStore for KuzuGraphStore {
57 fn apply_diff(&mut self, branch: &str, diff: &GraphDiff) -> Result<()> {
60 if diff.is_empty() {
61 return Ok(());
62 }
63
64 self.ensure_branch(branch)?;
65 let nt = db_schema::node_table(branch);
66 let et = db_schema::edge_table(branch);
67 let conn = self.conn()?;
68
69 conn.query("BEGIN TRANSACTION")
73 .map_err(|e| GitCortexError::Store(format!("begin transaction: {e}")))?;
74
75 for file in &diff.removed_files {
77 let file_str = esc(file.to_string_lossy().as_ref());
78 conn.query(&format!(
79 "MATCH (n:{nt}) WHERE n.file = '{file_str}' DETACH DELETE n"
80 ))
81 .map_err(|e| GitCortexError::Store(format!("delete file nodes: {e}")))?;
82 }
83
84 for id in &diff.removed_node_ids {
86 let id_str = esc(&id.as_str());
87 conn.query(&format!(
88 "MATCH (n:{nt}) WHERE n.id = '{id_str}' DETACH DELETE n"
89 ))
90 .map_err(|e| GitCortexError::Store(format!("delete node: {e}")))?;
91 }
92
93 for (src, dst, kind) in &diff.removed_edges {
95 let s = esc(&src.as_str());
96 let d = esc(&dst.as_str());
97 let k = esc(&kind.to_string());
98 conn.query(&format!(
99 "MATCH (s:{nt})-[e:{et}]->(d:{nt}) \
100 WHERE s.id = '{s}' AND d.id = '{d}' AND e.kind = '{k}' \
101 DELETE e"
102 ))
103 .map_err(|e| GitCortexError::Store(format!("delete edge: {e}")))?;
104 }
105
106 for node in &diff.added_nodes {
108 let id = esc(&node.id.as_str());
109 let kind = esc(&node.kind.to_string());
110 let name = esc(&node.name);
111 let qname = esc(&node.qualified_name);
112 let file = esc(node.file.to_string_lossy().as_ref());
113 let sl = node.span.start_line as i64;
114 let el = node.span.end_line as i64;
115 let loc = node.metadata.loc as i64;
116 let vis = esc(&vis_str(&node.metadata.visibility));
117 let is_async = node.metadata.is_async;
118 let is_unsafe = node.metadata.is_unsafe;
119
120 conn.query(&format!(
121 "CREATE (:{nt} {{\
122 id: '{id}', kind: '{kind}', name: '{name}', \
123 qualified_name: '{qname}', file: '{file}', \
124 start_line: {sl}, end_line: {el}, loc: {loc}, \
125 visibility: '{vis}', is_async: {is_async}, is_unsafe: {is_unsafe}\
126 }})"
127 ))
128 .map_err(|e| GitCortexError::Store(format!("insert node '{name}': {e}")))?;
129 }
130
131 conn.query("COMMIT")
133 .map_err(|e| GitCortexError::Store(format!("commit nodes: {e}")))?;
134
135 conn.query("BEGIN TRANSACTION")
136 .map_err(|e| GitCortexError::Store(format!("begin edge transaction: {e}")))?;
137
138 for edge in &diff.added_edges {
141 let s = esc(&edge.src.as_str());
142 let d = esc(&edge.dst.as_str());
143 let k = esc(&edge.kind.to_string());
144
145 conn.query(&format!(
146 "MATCH (s:{nt} {{id: '{s}'}}), (d:{nt} {{id: '{d}'}}) \
147 CREATE (s)-[:{et} {{kind: '{k}'}}]->(d)"
148 ))
149 .map_err(|e| GitCortexError::Store(format!("insert edge: {e}")))?;
150 }
151
152 for (caller_id, callee_name) in &diff.deferred_calls {
158 let caller = esc(&caller_id.as_str());
159 let callee = esc(callee_name);
160 conn.query(&format!(
161 "MATCH (caller:{nt} {{id: '{caller}'}}), (callee:{nt}) \
162 WHERE callee.name = '{callee}' \
163 AND (callee.kind = 'function' OR callee.kind = 'method') \
164 CREATE (caller)-[:{et} {{kind: 'calls'}}]->(callee)"
165 ))
166 .map_err(|e| GitCortexError::Store(format!("deferred call '{callee_name}': {e}")))?;
167 }
168
169 for (fn_id, type_name) in &diff.deferred_uses {
170 let fn_esc = esc(&fn_id.as_str());
171 let ty = esc(type_name);
172 conn.query(&format!(
173 "MATCH (fn_node:{nt} {{id: '{fn_esc}'}}), (ty:{nt}) \
174 WHERE ty.name = '{ty}' \
175 AND (ty.kind = 'struct' OR ty.kind = 'enum' \
176 OR ty.kind = 'trait' OR ty.kind = 'type_alias') \
177 CREATE (fn_node)-[:{et} {{kind: 'uses'}}]->(ty)"
178 ))
179 .map_err(|e| GitCortexError::Store(format!("deferred use '{type_name}': {e}")))?;
180 }
181
182 for (struct_id, trait_name) in &diff.deferred_implements {
183 let s = esc(&struct_id.as_str());
184 let t = esc(trait_name);
185 conn.query(&format!(
186 "MATCH (st:{nt} {{id: '{s}'}}), (tr:{nt}) \
187 WHERE tr.name = '{t}' AND tr.kind = 'trait' \
188 CREATE (st)-[:{et} {{kind: 'implements'}}]->(tr)"
189 ))
190 .map_err(|e| GitCortexError::Store(format!("deferred impl '{trait_name}': {e}")))?;
191 }
192
193 conn.query("COMMIT")
194 .map_err(|e| GitCortexError::Store(format!("commit edges: {e}")))?;
195
196 Ok(())
197 }
198
199 fn lookup_symbol(&self, branch: &str, name: &str) -> Result<Vec<Node>> {
202 self.ensure_branch(branch)?;
203 let nt = db_schema::node_table(branch);
204 let name_esc = esc(name);
205 let conn = self.conn()?;
206
207 let mut result = conn
208 .query(&format!(
209 "MATCH (n:{nt}) WHERE n.name = '{name_esc}' \
210 RETURN {NODE_COLS}"
211 ))
212 .map_err(|e| GitCortexError::Store(e.to_string()))?;
213
214 rows_to_nodes(&mut result)
215 }
216
217 fn find_callers(&self, branch: &str, function_name: &str) -> Result<Vec<Node>> {
218 self.ensure_branch(branch)?;
219 let nt = db_schema::node_table(branch);
220 let et = db_schema::edge_table(branch);
221 let name_esc = esc(function_name);
222 let conn = self.conn()?;
223
224 let mut result = conn
225 .query(&format!(
226 "MATCH (caller:{nt})-[e:{et} {{kind: 'calls'}}]->(callee:{nt}) \
227 WHERE callee.name = '{name_esc}' \
228 RETURN caller.id, caller.kind, caller.name, caller.qualified_name, \
229 caller.file, caller.start_line, caller.end_line, caller.loc, \
230 caller.visibility, caller.is_async, caller.is_unsafe"
231 ))
232 .map_err(|e| GitCortexError::Store(e.to_string()))?;
233
234 rows_to_nodes(&mut result)
235 }
236
237 fn list_definitions(&self, branch: &str, file: &Path) -> Result<Vec<Node>> {
238 self.ensure_branch(branch)?;
239 let nt = db_schema::node_table(branch);
240 let file_esc = esc(file.to_string_lossy().as_ref());
241 let conn = self.conn()?;
242
243 let mut result = conn
244 .query(&format!(
245 "MATCH (n:{nt}) WHERE n.file = '{file_esc}' \
246 RETURN {NODE_COLS} ORDER BY n.start_line"
247 ))
248 .map_err(|e| GitCortexError::Store(e.to_string()))?;
249
250 rows_to_nodes(&mut result)
251 }
252
253 fn branch_diff(&self, from: &str, to: &str) -> Result<GraphDiff> {
254 self.ensure_branch(from)?;
255 self.ensure_branch(to)?;
256
257 let from_nt = db_schema::node_table(from);
258 let to_nt = db_schema::node_table(to);
259 let mut conn = self.conn()?;
260
261 let from_ids = collect_ids(&mut conn, &from_nt)?;
263 let to_ids = collect_ids(&mut conn, &to_nt)?;
264
265 let added_ids: Vec<&String> = to_ids.iter().filter(|id| !from_ids.contains(*id)).collect();
267
268 let removed_ids: Vec<&String> =
270 from_ids.iter().filter(|id| !to_ids.contains(*id)).collect();
271
272 let mut diff = GraphDiff::default();
273
274 for id in added_ids {
275 let id_esc = esc(id);
276 let mut r = conn
277 .query(&format!(
278 "MATCH (n:{to_nt}) WHERE n.id = '{id_esc}' RETURN {NODE_COLS}"
279 ))
280 .map_err(|e| GitCortexError::Store(e.to_string()))?;
281 diff.added_nodes.extend(rows_to_nodes(&mut r)?);
282 }
283
284 for id in removed_ids {
285 if let Ok(node_id) = NodeId::try_from(id.as_str()) {
286 diff.removed_node_ids.push(node_id);
287 }
288 }
289
290 Ok(diff)
291 }
292
293 fn list_all_nodes(&self, branch: &str) -> Result<Vec<Node>> {
294 self.ensure_branch(branch)?;
295 let nt = db_schema::node_table(branch);
296 let conn = self.conn()?;
297 let mut result = conn
298 .query(&format!("MATCH (n:{nt}) RETURN {NODE_COLS}"))
299 .map_err(|e| GitCortexError::Store(e.to_string()))?;
300 rows_to_nodes(&mut result)
301 }
302
303 fn list_all_edges(&self, branch: &str) -> Result<Vec<Edge>> {
304 self.ensure_branch(branch)?;
305 let nt = db_schema::node_table(branch);
306 let et = db_schema::edge_table(branch);
307 let conn = self.conn()?;
308 let result = conn
309 .query(&format!(
310 "MATCH (s:{nt})-[e:{et}]->(d:{nt}) RETURN s.id, d.id, e.kind"
311 ))
312 .map_err(|e| GitCortexError::Store(e.to_string()))?;
313
314 let mut out = Vec::new();
315 for row in result {
316 let src_str = str_val(&row[0])?;
317 let dst_str = str_val(&row[1])?;
318 let kind_str = str_val(&row[2])?;
319 out.push(Edge {
320 src: NodeId::try_from(src_str.as_str())
321 .map_err(|e| GitCortexError::Store(format!("bad src id: {e}")))?,
322 dst: NodeId::try_from(dst_str.as_str())
323 .map_err(|e| GitCortexError::Store(format!("bad dst id: {e}")))?,
324 kind: edge_kind_from_str(&kind_str),
325 });
326 }
327 Ok(out)
328 }
329
330 fn last_indexed_sha(&self, branch_name: &str) -> Result<Option<String>> {
333 branch::read_last_sha(&self.repo_id, branch_name)
334 }
335
336 fn set_last_indexed_sha(&mut self, branch_name: &str, sha: &str) -> Result<()> {
337 branch::write_last_sha(&self.repo_id, branch_name, sha)
338 }
339}
340
/// Column projection used by every query that materialises a [`Node`].
///
/// The query must bind the node as `n`. The column order is the order in
/// which `row_to_node` reads them.
const NODE_COLS: &str = concat!(
    "n.id, n.kind, n.name, n.qualified_name, n.file, ",
    "n.start_line, n.end_line, n.loc, n.visibility, n.is_async, n.is_unsafe"
);
347
348fn rows_to_nodes(result: &mut kuzu::QueryResult) -> Result<Vec<Node>> {
349 let mut nodes = Vec::new();
350 for row in result.by_ref() {
351 nodes.push(row_to_node(row)?);
352 }
353 Ok(nodes)
354}
355
356fn row_to_node(row: Vec<Value>) -> Result<Node> {
357 if row.len() < 11 {
358 return Err(GitCortexError::Store(format!(
359 "expected 11 columns, got {}",
360 row.len()
361 )));
362 }
363 let id_str = str_val(&row[0])?;
364 let kind = kind_from_str(&str_val(&row[1])?);
365 let name = str_val(&row[2])?;
366 let qualified_name = str_val(&row[3])?;
367 let file = PathBuf::from(str_val(&row[4])?);
368 let start_line = i64_val(&row[5])? as u32;
369 let end_line = i64_val(&row[6])? as u32;
370 let loc = i64_val(&row[7])? as u32;
371 let visibility = vis_from_str(&str_val(&row[8])?);
372 let is_async = bool_val(&row[9])?;
373 let is_unsafe = bool_val(&row[10])?;
374
375 Ok(Node {
376 id: NodeId::try_from(id_str.as_str())
377 .map_err(|e| GitCortexError::Store(format!("bad node id: {e}")))?,
378 kind,
379 name,
380 qualified_name,
381 file,
382 span: Span {
383 start_line,
384 end_line,
385 },
386 metadata: NodeMetadata {
387 loc,
388 visibility,
389 is_async,
390 is_unsafe,
391 ..Default::default()
392 },
393 })
394}
395
396fn collect_ids(conn: &mut Connection, table: &str) -> Result<Vec<String>> {
397 let result = conn
398 .query(&format!("MATCH (n:{table}) RETURN n.id"))
399 .map_err(|e| GitCortexError::Store(e.to_string()))?;
400
401 let mut ids = Vec::new();
402 for row in result {
403 ids.push(str_val(&row[0])?);
404 }
405 Ok(ids)
406}
407
408fn str_val(v: &Value) -> Result<String> {
411 match v {
412 Value::String(s) => Ok(s.clone()),
413 other => Err(GitCortexError::Store(format!(
414 "expected String, got {other:?}"
415 ))),
416 }
417}
418
419fn i64_val(v: &Value) -> Result<i64> {
420 match v {
421 Value::Int64(n) => Ok(*n),
422 Value::Int32(n) => Ok(*n as i64),
423 other => Err(GitCortexError::Store(format!(
424 "expected Int64, got {other:?}"
425 ))),
426 }
427}
428
429fn bool_val(v: &Value) -> Result<bool> {
430 match v {
431 Value::Bool(b) => Ok(*b),
432 other => Err(GitCortexError::Store(format!(
433 "expected Bool, got {other:?}"
434 ))),
435 }
436}
437
438fn kind_from_str(s: &str) -> NodeKind {
441 match s {
442 "folder" => NodeKind::Folder,
443 "file" => NodeKind::File,
444 "module" => NodeKind::Module,
445 "struct" => NodeKind::Struct,
446 "enum" => NodeKind::Enum,
447 "trait" => NodeKind::Trait,
448 "type_alias" => NodeKind::TypeAlias,
449 "function" => NodeKind::Function,
450 "method" => NodeKind::Method,
451 "constant" => NodeKind::Constant,
452 "macro" => NodeKind::Macro,
453 _ => NodeKind::Function,
454 }
455}
456
457fn edge_kind_from_str(s: &str) -> EdgeKind {
458 match s {
459 "calls" => EdgeKind::Calls,
460 "implements" => EdgeKind::Implements,
461 "uses" => EdgeKind::Uses,
462 "imports" => EdgeKind::Imports,
463 _ => EdgeKind::Contains,
464 }
465}
466
467fn vis_str(v: &Visibility) -> String {
468 match v {
469 Visibility::Pub => "pub".into(),
470 Visibility::PubCrate => "pub_crate".into(),
471 Visibility::Private => "private".into(),
472 }
473}
474
475fn vis_from_str(s: &str) -> Visibility {
476 match s {
477 "pub" => Visibility::Pub,
478 "pub_crate" => Visibility::PubCrate,
479 _ => Visibility::Private,
480 }
481}
482
/// Escapes `s` for use inside a single-quoted Cypher string literal.
///
/// Each backslash becomes `\\` and each single quote becomes `\'`. No other
/// characters are altered.
fn esc(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '\'') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}