Skip to main content

forgekit_core/knowledge/
traversal.rs

1use crate::error::{ForgeError, Result};
2use crate::knowledge::types::{self, Direction, GraphNode};
3use crate::knowledge::KnowledgeGraph;
4
5impl KnowledgeGraph {
6    pub fn add_edge(
7        &self,
8        from: i64,
9        to: i64,
10        edge_type: &str,
11        data: serde_json::Value,
12    ) -> Result<i64> {
13        let spec = sqlitegraph::backend::EdgeSpec {
14            from,
15            to,
16            edge_type: edge_type.to_string(),
17            data,
18        };
19        self.backend
20            .insert_edge(spec)
21            .map_err(|e| ForgeError::DatabaseError(format!("Insert edge failed: {}", e)))
22    }
23
24    pub fn add_correlation(&self, from: i64, to: i64, confidence: f64, agent: &str) -> Result<()> {
25        let data = serde_json::json!({"confidence": confidence, "agent": agent,});
26        self.add_edge(from, to, types::edge::CORRELATES, data.clone())?;
27        self.add_edge(to, from, types::edge::CORRELATES, data)?;
28        Ok(())
29    }
30
31    pub fn callers_of(&self, symbol_id: i64, max_depth: u32) -> Result<Vec<GraphNode>> {
32        self.bfs_by_edge_type(
33            symbol_id,
34            types::edge::CALLS,
35            Direction::Incoming,
36            max_depth,
37        )
38    }
39
40    pub fn callees_of(&self, symbol_id: i64, max_depth: u32) -> Result<Vec<GraphNode>> {
41        self.bfs_by_edge_type(
42            symbol_id,
43            types::edge::CALLS,
44            Direction::Outgoing,
45            max_depth,
46        )
47    }
48
49    pub fn correlated(&self, node_id: i64) -> Result<Vec<GraphNode>> {
50        let incoming = self.neighbors(node_id, types::edge::CORRELATES, Direction::Incoming)?;
51        let outgoing = self.neighbors(node_id, types::edge::CORRELATES, Direction::Outgoing)?;
52        let mut seen = std::collections::HashSet::new();
53        let mut results = Vec::new();
54        for node in incoming.into_iter().chain(outgoing) {
55            if seen.insert(node.id) {
56                results.push(node);
57            }
58        }
59        Ok(results)
60    }
61
62    pub fn affected_by(&self, symbol_id: i64, depth: u32) -> Result<Vec<GraphNode>> {
63        self.bfs_by_edge_type(symbol_id, types::edge::AFFECTS, Direction::Incoming, depth)
64    }
65
66    fn bfs_by_edge_type(
67        &self,
68        start: i64,
69        edge_type: &str,
70        direction: Direction,
71        max_depth: u32,
72    ) -> Result<Vec<GraphNode>> {
73        let snap = Self::snapshot();
74        let dir = match direction {
75            Direction::Incoming => sqlitegraph::backend::BackendDirection::Incoming,
76            Direction::Outgoing => sqlitegraph::backend::BackendDirection::Outgoing,
77        };
78        let mut visited = std::collections::HashSet::new();
79        visited.insert(start);
80        let mut results = Vec::new();
81        let mut frontier: Vec<i64> = vec![start];
82
83        for _ in 0..max_depth {
84            let mut next_frontier = Vec::new();
85            for node_id in &frontier {
86                let query = sqlitegraph::backend::NeighborQuery {
87                    direction: dir,
88                    edge_type: Some(edge_type.to_string()),
89                };
90                if let Ok(neighbor_ids) = self.backend.neighbors(snap, *node_id, query) {
91                    for nid in neighbor_ids {
92                        if visited.insert(nid) {
93                            next_frontier.push(nid);
94                            if let Ok(entity) = self.backend.get_node(snap, nid) {
95                                results.push(GraphNode {
96                                    id: nid,
97                                    kind: entity.kind,
98                                    name: entity.name,
99                                    file_path: entity.file_path,
100                                    data: entity.data,
101                                });
102                            }
103                        }
104                    }
105                }
106            }
107            frontier = next_frontier;
108        }
109        Ok(results)
110    }
111
112    pub fn neighbors(
113        &self,
114        node_id: i64,
115        edge_type: &str,
116        direction: Direction,
117    ) -> Result<Vec<GraphNode>> {
118        let snap = Self::snapshot();
119        let dir = match direction {
120            Direction::Incoming => sqlitegraph::backend::BackendDirection::Incoming,
121            Direction::Outgoing => sqlitegraph::backend::BackendDirection::Outgoing,
122        };
123        let query = sqlitegraph::backend::NeighborQuery {
124            direction: dir,
125            edge_type: Some(edge_type.to_string()),
126        };
127        let neighbor_ids = self
128            .backend
129            .neighbors(snap, node_id, query)
130            .map_err(|e| ForgeError::DatabaseError(format!("Neighbor query failed: {}", e)))?;
131
132        let mut results = Vec::new();
133        for nid in neighbor_ids {
134            if let Ok(entity) = self.backend.get_node(snap, nid) {
135                results.push(GraphNode {
136                    id: nid,
137                    kind: entity.kind,
138                    name: entity.name,
139                    file_path: entity.file_path,
140                    data: entity.data,
141                });
142            }
143        }
144        Ok(results)
145    }
146
147    pub fn shortest_path(&self, from: i64, to: i64) -> Result<Option<Vec<i64>>> {
148        self.backend
149            .shortest_path(Self::snapshot(), from, to)
150            .map_err(|e| ForgeError::DatabaseError(format!("Shortest path failed: {}", e)))
151    }
152
153    pub fn reachability(&self, from: i64) -> Result<Vec<i64>> {
154        self.backend
155            .bfs(Self::snapshot(), from, 100)
156            .map_err(|e| ForgeError::DatabaseError(format!("Reachability failed: {}", e)))
157    }
158
159    pub fn k_hop(&self, from: i64, depth: u32, direction: Direction) -> Result<Vec<i64>> {
160        let dir = match direction {
161            Direction::Incoming => sqlitegraph::backend::BackendDirection::Incoming,
162            Direction::Outgoing => sqlitegraph::backend::BackendDirection::Outgoing,
163        };
164        self.backend
165            .k_hop(Self::snapshot(), from, depth, dir)
166            .map_err(|e| ForgeError::DatabaseError(format!("K-hop failed: {}", e)))
167    }
168}
169
#[cfg(test)]
mod tests {
    use crate::knowledge::{open_kg, Direction, SourceSpan};

    // Inserts a Rust "Function" symbol located in `f.rs` at `$line` and
    // evaluates to the id returned by `add_symbol`.
    //
    // `$label` is forwarded as the third positional argument of `add_symbol`;
    // its meaning is defined by that method, not by these tests.
    //
    // A macro is used instead of a helper fn so the fixture does not have to
    // spell out the parameter types of `add_symbol` / `SourceSpan::new`.
    macro_rules! add_function {
        ($kg:expr, $name:expr, $label:expr, $line:expr) => {
            $kg.add_symbol(
                $name,
                "Function",
                $label,
                &SourceSpan::new("f.rs", $line, 0, 10),
                "Rust",
                None,
            )
            .expect("invariant: fresh graph accepts inserts")
        };
    }

    // Inserts a "calls" edge with an empty payload.
    macro_rules! add_call {
        ($kg:expr, $from:expr, $to:expr) => {
            $kg.add_edge($from, $to, "calls", serde_json::json!({}))
                .expect("invariant: fresh graph accepts edge inserts")
        };
    }

    // Builds the call chain a -> b -> c and evaluates to `(a, b, c)`.
    macro_rules! call_chain {
        ($kg:expr) => {{
            let a = add_function!($kg, "a", "a", 1);
            let b = add_function!($kg, "b", "b", 2);
            let c = add_function!($kg, "c", "c", 3);
            add_call!($kg, a, b);
            add_call!($kg, b, c);
            (a, b, c)
        }};
    }

    #[test]
    fn test_add_edge() {
        let (_temp, kg) = open_kg();
        let a = add_function!(kg, "func_a", "a", 1);
        let b = add_function!(kg, "func_b", "b", 2);

        let edge_id = kg
            .add_edge(a, b, "calls", serde_json::json!({"location_line": 5}))
            .expect("invariant: fresh graph accepts edge inserts");
        assert!(edge_id > 0);
    }

    #[test]
    fn test_add_correlation_bidirectional() {
        let (_temp, kg) = open_kg();
        let sym = add_function!(kg, "my_func", "a", 1);
        let disc = kg
            .add_discovery("claude1", "Symbol", "my_func", serde_json::json!({}))
            .expect("invariant: fresh graph accepts inserts");

        kg.add_correlation(disc, sym, 0.95, "claude1")
            .expect("invariant: fresh graph accepts edge inserts");

        let outgoing = kg
            .neighbors(disc, "correlates", Direction::Outgoing)
            .expect("invariant: traversal on known graph succeeds");
        assert_eq!(outgoing.len(), 1);
        assert_eq!(outgoing[0].name, "my_func");

        let incoming = kg
            .neighbors(sym, "correlates", Direction::Incoming)
            .expect("invariant: traversal on known graph succeeds");
        assert_eq!(incoming.len(), 1);
        assert_eq!(incoming[0].name, "my_func");
    }

    #[test]
    fn test_callers_of() {
        let (_temp, kg) = open_kg();
        let target = add_function!(kg, "target_func", "t", 1);
        let caller_a = add_function!(kg, "caller_a", "a", 5);
        let caller_b = add_function!(kg, "caller_b", "b", 10);

        add_call!(kg, caller_a, target);
        add_call!(kg, caller_b, target);

        let callers = kg
            .callers_of(target, 1)
            .expect("invariant: traversal on known graph succeeds");
        assert_eq!(callers.len(), 2);
    }

    #[test]
    fn test_callees_of() {
        let (_temp, kg) = open_kg();
        let func = add_function!(kg, "func", "f", 1);
        let callee_a = add_function!(kg, "callee_a", "a", 5);
        let callee_b = add_function!(kg, "callee_b", "b", 10);

        add_call!(kg, func, callee_a);
        add_call!(kg, func, callee_b);

        let callees = kg
            .callees_of(func, 1)
            .expect("invariant: traversal on known graph succeeds");
        assert_eq!(callees.len(), 2);
    }

    #[test]
    fn test_correlated_nodes() {
        let (_temp, kg) = open_kg();
        let sym = add_function!(kg, "my_func", "a", 1);
        let disc1 = kg
            .add_discovery("agent1", "Symbol", "my_func", serde_json::json!({}))
            .expect("invariant: fresh graph accepts inserts");
        let disc2 = kg
            .add_discovery("agent2", "CFG", "my_func", serde_json::json!({}))
            .expect("invariant: fresh graph accepts inserts");

        kg.add_correlation(disc1, sym, 0.9, "agent1")
            .expect("invariant: fresh graph accepts edge inserts");
        kg.add_correlation(disc2, sym, 0.8, "agent2")
            .expect("invariant: fresh graph accepts edge inserts");

        // Each correlation is stored in both directions; `correlated` must
        // still report every partner exactly once.
        let correlated = kg
            .correlated(sym)
            .expect("invariant: traversal on known graph succeeds");
        assert_eq!(correlated.len(), 2);
    }

    #[test]
    fn test_affected_by() {
        let (_temp, kg) = open_kg();
        let sym = add_function!(kg, "process_payment", "a", 1);
        let issue = kg
            .add_issue("high", "race condition", None)
            .expect("invariant: fresh graph accepts inserts");

        kg.add_edge(issue, sym, "affects", serde_json::json!({}))
            .expect("invariant: fresh graph accepts edge inserts");

        let affected = kg
            .affected_by(sym, 1)
            .expect("invariant: traversal on known graph succeeds");
        assert_eq!(affected.len(), 1);
    }

    #[test]
    fn test_shortest_path() {
        let (_temp, kg) = open_kg();
        let (a, _b, c) = call_chain!(kg);

        let path = kg
            .shortest_path(a, c)
            .expect("invariant: algorithm on valid graph succeeds");
        assert!(path.is_some());
        let path = path.expect("invariant: connected nodes have a path");
        assert!(path.contains(&a));
        assert!(path.contains(&c));
    }

    #[test]
    fn test_reachability() {
        let (_temp, kg) = open_kg();
        let (a, b, c) = call_chain!(kg);

        let reachable = kg
            .reachability(a)
            .expect("invariant: algorithm on valid graph succeeds");
        assert!(reachable.contains(&b));
        assert!(reachable.contains(&c));
    }

    #[test]
    fn test_k_hop() {
        let (_temp, kg) = open_kg();
        let (a, b, c) = call_chain!(kg);

        let hop1 = kg
            .k_hop(a, 1, Direction::Outgoing)
            .expect("invariant: algorithm on valid graph succeeds");
        assert!(hop1.contains(&b));

        let hop2 = kg
            .k_hop(a, 2, Direction::Outgoing)
            .expect("invariant: algorithm on valid graph succeeds");
        assert!(hop2.contains(&c));
    }
}