Skip to main content

forge_core/graph/
queries.rs

//! Graph queries for caller resolution, reference resolution, and k-hop
//! impact analysis using sqlitegraph.
//!
//! This module implements graph traversal using sqlitegraph's high-level API.
4
5use crate::error::{ForgeError, Result};
6use crate::types::{Reference, SymbolId, ReferenceKind, Location};
7use std::path::Path;
8
/// Graph query engine using sqlitegraph.
///
/// The engine only remembers where the database lives; it holds no open
/// handle, so it is cheap to clone and to print for diagnostics.
#[derive(Debug, Clone)]
pub struct GraphQueryEngine {
    /// Filesystem path of the sqlitegraph database this engine queries.
    db_path: std::path::PathBuf,
}
13
14impl GraphQueryEngine {
15    /// Create a new query engine for the given database path
16    pub fn new(db_path: &Path) -> Self {
17        Self {
18            db_path: db_path.to_path_buf(),
19        }
20    }
21
22    /// Find all callers of a symbol by name
23    pub fn find_callers(&self, symbol_name: &str) -> Result<Vec<Reference>> {
24        use sqlitegraph::{open_graph, GraphConfig, snapshot::SnapshotId};
25        
26        let config = GraphConfig::sqlite();
27        let backend = open_graph(&self.db_path, &config)
28            .map_err(|e| ForgeError::DatabaseError(format!("Failed to open graph: {}", e)))?;
29        
30        let target_id = match self.find_symbol_id(&*backend, symbol_name)? {
31            Some(id) => id,
32            None => return Ok(Vec::new()),
33        };
34        
35        let snapshot = SnapshotId::current();
36        let caller_ids = backend.fetch_incoming(target_id)
37            .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?;
38        
39        let mut refs = Vec::new();
40        for caller_id in caller_ids {
41            if let Ok(node) = backend.get_node(snapshot, caller_id) {
42                refs.push(Reference {
43                    from: SymbolId(caller_id),
44                    to: SymbolId(target_id),
45                    kind: ReferenceKind::Call,
46                    location: Location {
47                        file_path: std::path::PathBuf::from(node.file_path.unwrap_or_default()),
48                        byte_start: 0,
49                        byte_end: 0,
50                        line_number: 0,
51                    },
52                });
53            }
54        }
55        
56        Ok(refs)
57    }
58
59    /// Find all references to a symbol
60    pub fn find_references(&self, symbol_name: &str) -> Result<Vec<Reference>> {
61        use sqlitegraph::{open_graph, GraphConfig, snapshot::SnapshotId};
62        
63        let config = GraphConfig::sqlite();
64        let backend = open_graph(&self.db_path, &config)
65            .map_err(|e| ForgeError::DatabaseError(format!("Failed to open graph: {}", e)))?;
66        
67        let target_id = match self.find_symbol_id(&*backend, symbol_name)? {
68            Some(id) => id,
69            None => return Ok(Vec::new()),
70        };
71        
72        let snapshot = SnapshotId::current();
73        let neighbor_ids = backend.fetch_incoming(target_id)
74            .map_err(|e| ForgeError::DatabaseError(format!("Query failed: {}", e)))?;
75        
76        let mut refs = Vec::new();
77        for neighbor_id in neighbor_ids {
78            if let Ok(node) = backend.get_node(snapshot, neighbor_id) {
79                refs.push(Reference {
80                    from: SymbolId(neighbor_id),
81                    to: SymbolId(target_id),
82                    kind: ReferenceKind::TypeReference,
83                    location: Location {
84                        file_path: std::path::PathBuf::from(node.file_path.unwrap_or_default()),
85                        byte_start: 0,
86                        byte_end: 0,
87                        line_number: 0,
88                    },
89                });
90            }
91        }
92        
93        Ok(refs)
94    }
95
96    /// Find symbol ID by name
97    fn find_symbol_id(&self, backend: &dyn sqlitegraph::GraphBackend, symbol_name: &str) -> Result<Option<i64>> {
98        use sqlitegraph::snapshot::SnapshotId;
99        
100        let snapshot = SnapshotId::current();
101        let ids = backend.entity_ids()
102            .map_err(|e| ForgeError::DatabaseError(format!("Failed to list entities: {}", e)))?;
103        
104        for id in ids {
105            if let Ok(node) = backend.get_node(snapshot, id) {
106                if node.name == symbol_name {
107                    return Ok(Some(id));
108                }
109            }
110        }
111        
112        Ok(None)
113    }
114
115    /// K-hop traversal to find impacted symbols
116    pub fn find_impacted_symbols(
117        &self, 
118        start_symbol: &str, 
119        max_hops: u32
120    ) -> Result<Vec<ImpactedSymbol>> {
121        use sqlitegraph::{open_graph, GraphConfig, snapshot::SnapshotId, backend::BackendDirection};
122        
123        let config = GraphConfig::sqlite();
124        let backend = open_graph(&self.db_path, &config)
125            .map_err(|e| ForgeError::DatabaseError(format!("Failed to open graph: {}", e)))?;
126        
127        let start_id = match self.find_symbol_id(&*backend, start_symbol)? {
128            Some(id) => id,
129            None => return Ok(Vec::new()),
130        };
131        
132        let snapshot = SnapshotId::current();
133        let impacted_ids = backend.k_hop(snapshot, start_id, max_hops, BackendDirection::Outgoing)
134            .map_err(|e| ForgeError::DatabaseError(format!("k-hop query failed: {}", e)))?;
135        
136        let mut results = Vec::new();
137        for id in impacted_ids {
138            if id == start_id {
139                continue;
140            }
141            
142            if let Ok(node) = backend.get_node(snapshot, id) {
143                results.push(ImpactedSymbol {
144                    symbol_id: id,
145                    name: node.name,
146                    kind: node.kind,
147                    file_path: node.file_path.unwrap_or_default(),
148                    hop_distance: 1,
149                    edge_type: "transitive".to_string(),
150                });
151            }
152        }
153        
154        Ok(results)
155    }
156}
157
/// Impacted symbol from k-hop analysis.
///
/// Plain data describing one node reached by
/// `GraphQueryEngine::find_impacted_symbols`. Values compare field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImpactedSymbol {
    /// Graph node id of the impacted symbol.
    pub symbol_id: i64,
    /// Symbol name as stored on the graph node.
    pub name: String,
    /// Symbol kind as stored on the graph node.
    pub kind: String,
    /// File path stored on the graph node; empty when the node has none.
    pub file_path: String,
    /// Hop distance from the start symbol, as recorded by the query that
    /// produced this entry.
    pub hop_distance: u32,
    /// Label for how the symbol is connected to the start symbol;
    /// `find_impacted_symbols` sets this to `"transitive"`.
    pub edge_type: String,
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // The query methods need a sqlitegraph database with its own schema
    // already in place, so only the construction API is covered here.

    #[test]
    fn test_query_engine_creation() {
        let scratch = tempdir().unwrap();
        let expected = scratch.path().join("test.db");

        let engine = GraphQueryEngine::new(&expected);

        // The engine must store exactly the path it was given.
        assert_eq!(engine.db_path, expected);
    }
}
185}