// forge_core/graph/mod.rs

//! Graph module - Symbol and reference queries using sqlitegraph.
//!
//! This module provides access to the code graph for querying symbols,
//! finding references, and running graph algorithms.
5
6pub mod queries;
7
8use std::sync::Arc;
9use std::collections::{HashMap, HashSet, VecDeque};
10use crate::storage::UnifiedGraphStore;
11use crate::error::Result;
12use crate::types::{Symbol, SymbolId, Reference, Cycle, ReferenceKind};
13use queries::GraphQueryEngine;
14
15/// Graph module for symbol and reference queries.
16///
17/// # Examples
18///
19/// See crate-level documentation for usage examples.
20#[derive(Clone)]
21pub struct GraphModule {
22    store: Arc<UnifiedGraphStore>,
23}
24
25impl GraphModule {
26    pub(crate) fn new(store: Arc<UnifiedGraphStore>) -> Self {
27        Self { store }
28    }
29    
30    /// Get the underlying store for advanced operations
31    pub fn store(&self) -> &UnifiedGraphStore {
32        &self.store
33    }
34
35    /// Finds symbols by name.
36    ///
37    /// # Arguments
38    ///
39    /// * `name` - The symbol name to search for
40    ///
41    /// # Returns
42    ///
43    /// A vector of matching symbols
44    pub async fn find_symbol(&self, name: &str) -> Result<Vec<Symbol>> {
45        #[cfg(feature = "magellan")]
46        {
47            use magellan::CodeGraph;
48            
49            let codebase_path = &self.store.codebase_path;
50            let db_path = codebase_path.join(".forge").join("graph.db");
51            
52            // Open the magellan graph
53            let mut graph = CodeGraph::open(&db_path)
54                .map_err(|e| crate::error::ForgeError::DatabaseError(
55                    format!("Failed to open magellan graph: {}", e)
56                ))?;
57            
58            // Query all symbols and filter by name
59            // For now, we scan all files and their symbols
60            let mut results = Vec::new();
61            let file_nodes = graph.all_file_nodes()
62                .map_err(|e| crate::error::ForgeError::DatabaseError(
63                    format!("Failed to get file nodes: {}", e)
64                ))?;
65            
66            for (file_path, _file_node) in file_nodes {
67                let symbols = graph.symbols_in_file(&file_path)
68                    .map_err(|e| crate::error::ForgeError::DatabaseError(
69                        format!("Failed to get symbols: {}", e)
70                    ))?;
71                
72                for sym in symbols {
73                    if let Some(ref sym_name) = sym.name {
74                        if sym_name.contains(name) {
75                            use crate::types::SymbolKind;
76                            let kind = match sym.kind {
77                                magellan::SymbolKind::Function => SymbolKind::Function,
78                                magellan::SymbolKind::Method => SymbolKind::Method,
79                                magellan::SymbolKind::Class => SymbolKind::Struct,
80                                magellan::SymbolKind::Interface => SymbolKind::Trait,
81                                magellan::SymbolKind::Enum => SymbolKind::Enum,
82                                magellan::SymbolKind::Module => SymbolKind::Module,
83                                magellan::SymbolKind::TypeAlias => SymbolKind::TypeAlias,
84                                magellan::SymbolKind::Union => SymbolKind::Enum,
85                                magellan::SymbolKind::Namespace => SymbolKind::Module,
86                                magellan::SymbolKind::Unknown => SymbolKind::Function,
87                            };
88
89                            results.push(Symbol {
90                                id: SymbolId(0), // magellan uses different ID system
91                                name: sym_name.clone(),
92                                fully_qualified_name: sym.fqn.clone().unwrap_or_else(|| sym_name.clone()),
93                                kind,
94                                language: map_magellan_language(&sym.file_path),
95                                location: crate::types::Location {
96                                    file_path: sym.file_path.clone(),
97                                    byte_start: sym.byte_start as u32,
98                                    byte_end: sym.byte_end as u32,
99                                    line_number: sym.start_line,
100                                },
101                                parent_id: None,
102                                metadata: serde_json::Value::Null,
103                            });
104                        }
105                    }
106                }
107            }
108            
109            Ok(results)
110        }
111        
112        #[cfg(not(feature = "magellan"))]
113        {
114            self.store.query_symbols(name).await
115        }
116    }
117
118    /// Finds a symbol by its stable ID.
119    ///
120    /// # Arguments
121    ///
122    /// * `id` - The symbol identifier
123    ///
124    /// # Returns
125    ///
126    /// The symbol with the given ID
127    pub async fn find_symbol_by_id(&self, id: SymbolId) -> Result<Symbol> {
128        self.store.get_symbol(id).await
129    }
130
131    /// Finds all callers of a symbol.
132    ///
133    /// # Arguments
134    ///
135    /// * `name` - The symbol name
136    ///
137    /// # Returns
138    ///
139    /// A vector of references that call this symbol
140    pub async fn callers_of(&self, name: &str) -> Result<Vec<Reference>> {
141        // Use the SQL-based query engine for real graph traversal
142        let db_path = self.store.db_path.join("graph.db");
143        
144        if !db_path.exists() {
145            // Fall back to file search if no graph database
146            return self.search_callers_in_files(name).await;
147        }
148        
149        let engine = GraphQueryEngine::new(&db_path);
150        engine.find_callers(name)
151    }
152    
153    /// Fallback: Search for callers in source files directly
154    async fn search_callers_in_files(&self, name: &str) -> Result<Vec<Reference>> {
155        use tokio::fs;
156        use regex::Regex;
157        
158        let mut callers = Vec::new();
159        let pattern = Regex::new(&format!(r"\b{}\s*\(", regex::escape(name)))
160            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Invalid regex: {}", e)))?;
161        
162        let mut entries = fs::read_dir(&self.store.codebase_path).await
163            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Failed to read codebase: {}", e)))?;
164        
165        while let Some(entry) = entries.next_entry().await
166            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))? 
167        {
168            let path = entry.path();
169            if path.extension().map(|e| e == "rs").unwrap_or(false) {
170                if let Ok(content) = fs::read_to_string(&path).await {
171                    for (line_num, line) in content.lines().enumerate() {
172                        if pattern.is_match(line) && !line.trim().starts_with("fn ") {
173                            callers.push(Reference {
174                                from: SymbolId(0),
175                                to: SymbolId(0),
176                                kind: ReferenceKind::Call,
177                                location: crate::types::Location {
178                                    file_path: path.clone(),
179                                    byte_start: 0,
180                                    byte_end: line.len() as u32,
181                                    line_number: line_num + 1,
182                                },
183                            });
184                        }
185                    }
186                }
187            }
188        }
189        
190        Ok(callers)
191    }
192
193    /// Finds all references to a symbol.
194    ///
195    /// # Arguments
196    ///
197    /// * `name` - The symbol name
198    ///
199    /// # Returns
200    ///
201    /// A vector of all references (calls, uses, type refs).
202    /// Uses SQL-based graph queries for accurate cross-file reference resolution.
203    pub async fn references(&self, name: &str) -> Result<Vec<Reference>> {
204        // Use the SQL-based query engine for real graph traversal
205        let db_path = self.store.db_path.join("graph.db");
206        
207        if !db_path.exists() {
208            // Fall back to file search if no graph database
209            return self.search_references_in_files(name).await;
210        }
211        
212        let engine = GraphQueryEngine::new(&db_path);
213        let mut refs = engine.find_references(name)?;
214        
215        // Remove duplicates based on location
216        let mut seen = std::collections::HashSet::new();
217        refs.retain(|r| {
218            let key = (r.location.file_path.clone(), r.location.line_number);
219            seen.insert(key)
220        });
221        
222        Ok(refs)
223    }
224    
225    /// Fallback: Search for references in source files directly
226    async fn search_references_in_files(&self, name: &str) -> Result<Vec<Reference>> {
227        use tokio::fs;
228        
229        let mut refs = Vec::new();
230        let name_lower = name.to_lowercase();
231        
232        let mut entries = fs::read_dir(&self.store.codebase_path).await
233            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Failed to read codebase: {}", e)))?;
234        
235        while let Some(entry) = entries.next_entry().await
236            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))? 
237        {
238            let path = entry.path();
239            if path.extension().map(|e| e == "rs").unwrap_or(false) {
240                if let Ok(content) = fs::read_to_string(&path).await {
241                    for (line_num, line) in content.lines().enumerate() {
242                        if line.to_lowercase().contains(&name_lower) {
243                            refs.push(Reference {
244                                from: SymbolId(0),
245                                to: SymbolId(0),
246                                kind: ReferenceKind::TypeReference,
247                                location: crate::types::Location {
248                                    file_path: path.clone(),
249                                    byte_start: 0,
250                                    byte_end: line.len() as u32,
251                                    line_number: line_num + 1,
252                                },
253                            });
254                        }
255                    }
256                }
257            }
258        }
259        
260        Ok(refs)
261    }
262
263    /// Finds all symbols reachable from a given symbol.
264    ///
265    /// Uses BFS traversal to find all symbols that can be reached
266    /// from the starting symbol through the call graph.
267    ///
268    /// # Arguments
269    ///
270    /// * `id` - The starting symbol ID
271    ///
272    /// # Returns
273    ///
274    /// A vector of reachable symbol IDs
275    pub async fn reachable_from(&self, id: SymbolId) -> Result<Vec<SymbolId>> {
276        // Build adjacency list for BFS
277        let mut adjacency: HashMap<SymbolId, Vec<SymbolId>> = HashMap::new();
278
279        // Query references to build the graph
280        let refs = self.store.query_references(id).await?;
281        for reference in &refs {
282            adjacency.entry(reference.from)
283                .or_insert_with(Vec::new)
284                .push(reference.to);
285        }
286
287        // BFS from the starting node
288        let mut visited = HashSet::new();
289        let mut queue = VecDeque::new();
290        let mut reachable = Vec::new();
291
292        queue.push_back(id);
293        visited.insert(id);
294
295        while let Some(current) = queue.pop_front() {
296            if let Some(neighbors) = adjacency.get(&current) {
297                for &neighbor in neighbors {
298                    if visited.insert(neighbor) {
299                        queue.push_back(neighbor);
300                        reachable.push(neighbor);
301                    }
302                }
303            }
304        }
305
306        Ok(reachable)
307    }
308
309    /// Detects cycles in the call graph.
310    ///
311    /// Uses DFS-based cycle detection to find all strongly connected
312    /// components (cycles) in the call graph.
313    ///
314    /// # Returns
315    ///
316    /// A vector of detected cycles
317    pub async fn cycles(&self) -> Result<Vec<Cycle>> {
318        // For now, return empty as we need full graph traversal
319        // This will be implemented when we integrate sqlitegraph cycles API
320        // or implement Tarjan's SCC algorithm ourselves
321        Ok(Vec::new())
322    }
323
324    /// Returns the number of symbols in the graph.
325    pub async fn symbol_count(&self) -> Result<usize> {
326        self.store.symbol_count().await
327    }
328    
329    /// Analyze the impact of changing a symbol.
330    ///
331    /// Performs k-hop traversal to find all symbols that would be affected
332    /// by modifying the given symbol.
333    ///
334    /// # Arguments
335    ///
336    /// * `symbol_name` - The name of the symbol to analyze
337    /// * `max_hops` - Maximum traversal depth (default: 2)
338    ///
339    /// # Returns
340    ///
341    /// A vector of impacted symbols with their hop distance from the target
342    pub async fn impact_analysis(&self, symbol_name: &str, max_hops: Option<u32>) -> Result<Vec<queries::ImpactedSymbol>> {
343        let db_path = self.store.db_path.join("graph.db");
344        
345        if !db_path.exists() {
346            return Ok(Vec::new());
347        }
348        
349        let engine = GraphQueryEngine::new(&db_path);
350        let hops = max_hops.unwrap_or(2);
351        engine.find_impacted_symbols(symbol_name, hops)
352    }
353    
354    /// Indexes the codebase using magellan.
355    ///
356    /// This runs the magellan indexer to extract symbols and references
357    /// from the codebase and populate the graph database.
358    ///
359    /// For Native V3 backend, also indexes cross-file references using
360    /// sqlitegraph directly (a capability SQLite doesn't support).
361    ///
362    /// # Returns
363    ///
364    /// Ok(()) on success, or an error if indexing fails.
365    pub async fn index(&self) -> Result<()> {
366        #[cfg(feature = "magellan")]
367        {
368            use magellan::CodeGraph;
369            use std::path::Path;
370            
371            
372            let codebase_path = &self.store.codebase_path;
373            // Magellan only supports SQLite, so we always use the SQLite db path
374            let db_path = codebase_path.join(".forge").join("graph.db");
375            
376            // Open or create the magellan code graph
377            let mut graph = CodeGraph::open(&db_path)
378                .map_err(|e| crate::error::ForgeError::DatabaseError(
379                    format!("Failed to open magellan graph: {}", e)
380                ))?;
381            
382            // Scan the directory and index all files
383            let count = graph.scan_directory(Path::new(codebase_path), None)
384                .map_err(|e| crate::error::ForgeError::DatabaseError(
385                    format!("Failed to scan directory: {}", e)
386                ))?;
387            
388            tracing::info!("Indexed {} symbols from {}", count, codebase_path.display());
389            
390            // Also index references and calls for each Rust file recursively
391            Self::index_references_recursive(&mut graph, codebase_path, codebase_path).await?;
392            
393            // For Native V3 backend, also index cross-file references
394            // This is a capability that Native V3 enables over SQLite
395            if self.store.backend_kind == crate::storage::BackendKind::NativeV3 {
396                let cross_file_refs = self.store.index_cross_file_references().await?;
397                tracing::info!("Indexed {} cross-file references (Native V3 only)", cross_file_refs);
398            }
399            
400            Ok(())
401        }
402        
403        #[cfg(not(feature = "magellan"))]
404        {
405            tracing::warn!("magellan feature not enabled, skipping indexing");
406            Ok(())
407        }
408    }
409    
410    #[cfg(feature = "magellan")]
411    async fn index_references_recursive(
412        graph: &mut magellan::CodeGraph,
413        codebase_path: &std::path::Path,
414        current_dir: &std::path::Path,
415    ) -> Result<()> {
416        use tokio::fs;
417        
418        let mut entries = fs::read_dir(current_dir).await
419            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
420        
421        while let Some(entry) = entries.next_entry().await
422            .map_err(|e| crate::error::ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))? 
423        {
424            let path = entry.path();
425            if path.is_dir() {
426                // Recurse into subdirectories
427                Box::pin(Self::index_references_recursive(graph, codebase_path, &path)).await?;
428            } else if path.is_file() && path.extension().map(|e| e == "rs").unwrap_or(false) {
429                // Get relative path from codebase root
430                let relative_path = path.strip_prefix(codebase_path)
431                    .unwrap_or(&path)
432                    .to_string_lossy();
433                
434                if let Ok(source) = fs::read_to_string(&path).await {
435                    // Index references using relative path
436                    let _ = graph.index_references(&relative_path, source.as_bytes());
437                    // Index calls using relative path
438                    let _ = graph.index_calls(&relative_path, source.as_bytes());
439                }
440            }
441        }
442        
443        Ok(())
444    }
445}
446
/// Maps a file's extension to the corresponding forge `Language`.
///
/// Paths without an extension, with a non-UTF-8 extension, or with an
/// extension not listed below map to `Language::Unknown("other")`.
#[cfg(feature = "magellan")]
fn map_magellan_language(file_path: &std::path::Path) -> crate::types::Language {
    use crate::types::Language;

    let extension = match file_path.extension().and_then(|ext| ext.to_str()) {
        Some(extension) => extension,
        None => return Language::Unknown("other".to_string()),
    };

    match extension {
        "rs" => Language::Rust,
        "py" => Language::Python,
        "c" => Language::C,
        "cpp" | "cc" | "cxx" => Language::Cpp,
        "java" => Language::Java,
        "js" => Language::JavaScript,
        "ts" => Language::TypeScript,
        "go" => Language::Go,
        _ => Language::Unknown("other".to_string()),
    }
}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::BackendKind;

    /// Opens a SQLite-backed store inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned with the store so the directory stays
    /// alive for as long as the calling test holds on to it.
    async fn open_temp_store() -> (tempfile::TempDir, Arc<UnifiedGraphStore>) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = UnifiedGraphStore::open(temp_dir.path(), BackendKind::SQLite)
            .await
            .unwrap();
        (temp_dir, Arc::new(store))
    }

    /// A new module must point at the same database as the store it wraps.
    #[tokio::test]
    async fn test_graph_module_creation() {
        let (_temp_dir, store) = open_temp_store().await;

        let module = GraphModule::new(store.clone());
        assert_eq!(module.store.db_path(), store.db_path());
    }

    /// Searching an empty graph yields no symbols.
    #[tokio::test]
    async fn test_find_symbol_empty() {
        let (_temp_dir, store) = open_temp_store().await;

        let module = GraphModule::new(store);
        let symbols = module.find_symbol("nonexistent").await.unwrap();
        assert_eq!(symbols.len(), 0);
    }

    /// Looking up an unknown ID is an error.
    #[tokio::test]
    async fn test_find_symbol_by_id_not_found() {
        let (_temp_dir, store) = open_temp_store().await;

        let module = GraphModule::new(store);
        let result = module.find_symbol_by_id(SymbolId(999)).await;
        assert!(result.is_err());
    }

    /// An unknown symbol has no callers.
    #[tokio::test]
    async fn test_callers_of_empty() {
        let (_temp_dir, store) = open_temp_store().await;

        let module = GraphModule::new(store);
        let callers = module.callers_of("nonexistent").await.unwrap();
        assert_eq!(callers.len(), 0);
    }
}