Skip to main content

forge_core/search/
mod.rs

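//! Search module - code search over the project's Rust sources.
//!
//! Searches scan `.rs` files under the store's codebase root directly, using regex and
//! keyword matching; `symbols_by_kind` reads the store's symbol table instead. Semantic
//! search via llmgrep/magellan databases is not used here; only a magellan kind-mapping
//! helper exists, behind the `magellan` feature.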
5
6use std::sync::Arc;
7use crate::storage::UnifiedGraphStore;
8use crate::error::{ForgeError, Result as ForgeResult};
9use crate::types::{Symbol, SymbolKind, Language, Location};
10
11/// Search module for semantic code queries.
12pub struct SearchModule {
13    store: Arc<UnifiedGraphStore>,
14}
15
16impl SearchModule {
17    /// Create a new SearchModule.
18    pub fn new(store: Arc<UnifiedGraphStore>) -> Self {
19        Self { store }
20    }
21
22    /// Indexes the codebase for search.
23    ///
24    /// This is a no-op for the current implementation which scans files directly.
25    /// In a future version with embedding-based search, this would build the index.
26    pub async fn index(&self) -> ForgeResult<()> {
27        // Current implementation scans files directly, no indexing needed
28        Ok(())
29    }
30
31    /// Pattern-based search using regex (async).
32    ///
33    /// Scans source files for patterns like "fn \w+\(" and returns matching symbols.
34    pub async fn pattern_search(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
35        use regex::Regex;
36        
37        
38        // Compile the regex pattern
39        let regex = Regex::new(pattern)
40            .map_err(|e| ForgeError::DatabaseError(format!("Invalid regex pattern: {}", e)))?;
41        
42        let mut results = Vec::new();
43        
44        // Scan source files recursively
45        Self::search_files_recursive(
46            &self.store.codebase_path,
47            &self.store.codebase_path,
48            &regex,
49            &mut results,
50        ).await?;
51        
52        Ok(results)
53    }
54    
55    /// Recursively search files for pattern matches
56    async fn search_files_recursive(
57        root: &std::path::Path,
58        dir: &std::path::Path,
59        regex: &regex::Regex,
60        results: &mut Vec<Symbol>,
61    ) -> ForgeResult<()> {
62        use tokio::fs;
63        
64        let mut entries = fs::read_dir(dir).await
65            .map_err(|e| ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
66        
67        while let Some(entry) = entries.next_entry().await
68            .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))? 
69        {
70            let path = entry.path();
71            if path.is_dir() {
72                // Recurse into subdirectories
73                Box::pin(Self::search_files_recursive(root, &path, regex, results)).await?;
74            } else if path.is_file() && path.extension().map(|e| e == "rs").unwrap_or(false) {
75                // Read and search Rust files
76                if let Ok(content) = fs::read_to_string(&path).await {
77                    for (line_num, line) in content.lines().enumerate() {
78                        if regex.is_match(line) {
79                            // Extract symbol name from the matched line
80                            let symbol_name = extract_symbol_from_line(line);
81                            let relative_path = path.strip_prefix(root).unwrap_or(&path);
82                            
83                            results.push(Symbol {
84                                id: crate::types::SymbolId(0),
85                                name: symbol_name.clone(),
86                                fully_qualified_name: symbol_name,
87                                kind: SymbolKind::Function, // Assume function for fn patterns
88                                language: Language::Rust,
89                                location: Location {
90                                    file_path: relative_path.to_path_buf(),
91                                    byte_start: 0,
92                                    byte_end: line.len() as u32,
93                                    line_number: line_num + 1,
94                                },
95                                parent_id: None,
96                                metadata: serde_json::Value::Null,
97                            });
98                        }
99                    }
100                }
101            }
102        }
103        
104        Ok(())
105    }
106
107    /// Pattern-based search (alias for `pattern_search`).
108    pub async fn pattern(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
109        self.pattern_search(pattern).await
110    }
111
112    /// Semantic search using natural language (async).
113    ///
114    /// Note: True semantic search would require embedding generation.
115    /// This implementation uses keyword matching on symbol names, with
116    /// substring matching for partial word matches.
117    pub async fn semantic_search(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
118        // Extract keywords from the query
119        let keywords: Vec<&str> = query
120            .split_whitespace()
121            .filter(|w| w.len() >= 3) // Consider words 3+ chars
122            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
123            .filter(|w| !w.is_empty())
124            .collect();
125        
126        if keywords.is_empty() {
127            return Ok(Vec::new());
128        }
129        
130        // First try exact pattern search
131        let mut all_results = Vec::new();
132        for keyword in &keywords {
133            let matches = self.pattern_search(keyword).await?;
134            all_results.extend(matches);
135        }
136        
137        // Also scan files for keywords that might match as substrings
138        // This handles cases like "addition" matching "add"
139        self.scan_for_substring_matches(&keywords, &mut all_results).await?;
140        
141        // Remove duplicates (by name)
142        let mut seen = std::collections::HashSet::new();
143        all_results.retain(|s| seen.insert(s.name.clone()));
144        
145        Ok(all_results)
146    }
147    
148    /// Scan files for symbols that contain keyword substrings
149    async fn scan_for_substring_matches(
150        &self,
151        keywords: &[&str],
152        results: &mut Vec<Symbol>,
153    ) -> ForgeResult<()> {
154        use tokio::fs;
155        
156        let codebase_path = &self.store.codebase_path;
157        let mut entries = fs::read_dir(codebase_path).await
158            .map_err(|e| ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
159        
160        while let Some(entry) = entries.next_entry().await
161            .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))? 
162        {
163            let path = entry.path();
164            if path.is_dir() {
165                // Recurse (simplified - in production use walkdir)
166                if let Ok(sub_entries) = fs::read_dir(&path).await {
167                    let mut sub_entries = sub_entries;
168                    while let Ok(Some(sub_entry)) = sub_entries.next_entry().await {
169                        let sub_path = sub_entry.path();
170                        if sub_path.is_file() && sub_path.extension().map(|e| e == "rs").unwrap_or(false) {
171                            Self::check_file_for_submatches(&sub_path, keywords, results, codebase_path).await?;
172                        }
173                    }
174                }
175            } else if path.is_file() && path.extension().map(|e| e == "rs").unwrap_or(false) {
176                Self::check_file_for_submatches(&path, keywords, results, codebase_path).await?;
177            }
178        }
179        
180        Ok(())
181    }
182    
183    async fn check_file_for_submatches(
184        path: &std::path::Path,
185        keywords: &[&str],
186        results: &mut Vec<Symbol>,
187        root: &std::path::Path,
188    ) -> ForgeResult<()> {
189        use tokio::fs;
190        
191        if let Ok(content) = fs::read_to_string(path).await {
192            for (line_num, line) in content.lines().enumerate() {
193                // Look for function definitions
194                if line.contains("fn ") {
195                    let fn_name = extract_symbol_from_line(line);
196                    // Check if any keyword is a substring of this function name
197                    // or if function name is a substring of any keyword
198                    for keyword in keywords {
199                        if fn_name.contains(keyword) || keyword.contains(&fn_name) {
200                            if !fn_name.is_empty() && fn_name != "fn" {
201                                let relative_path = path.strip_prefix(root).unwrap_or(path);
202                                results.push(Symbol {
203                                    id: crate::types::SymbolId(0),
204                                    name: fn_name.clone(),
205                                    fully_qualified_name: fn_name,
206                                    kind: SymbolKind::Function,
207                                    language: Language::Rust,
208                                    location: Location {
209                                        file_path: relative_path.to_path_buf(),
210                                        byte_start: 0,
211                                        byte_end: line.len() as u32,
212                                        line_number: line_num + 1,
213                                    },
214                                    parent_id: None,
215                                    metadata: serde_json::Value::Null,
216                                });
217                            }
218                            break;
219                        }
220                    }
221                }
222                
223                // Also look for struct definitions (for "calculator" -> "Calculator")
224                if line.contains("struct ") {
225                    let struct_name = extract_struct_from_line(line);
226                    for keyword in keywords {
227                        let keyword_lower = keyword.to_lowercase();
228                        let struct_lower = struct_name.to_lowercase();
229                        if struct_lower.contains(&keyword_lower) || keyword_lower.contains(&struct_lower) {
230                            if !struct_name.is_empty() {
231                                let relative_path = path.strip_prefix(root).unwrap_or(path);
232                                results.push(Symbol {
233                                    id: crate::types::SymbolId(0),
234                                    name: struct_name.clone(),
235                                    fully_qualified_name: struct_name,
236                                    kind: SymbolKind::Struct,
237                                    language: Language::Rust,
238                                    location: Location {
239                                        file_path: relative_path.to_path_buf(),
240                                        byte_start: 0,
241                                        byte_end: line.len() as u32,
242                                        line_number: line_num + 1,
243                                    },
244                                    parent_id: None,
245                                    metadata: serde_json::Value::Null,
246                                });
247                            }
248                            break;
249                        }
250                    }
251                }
252            }
253        }
254        
255        Ok(())
256    }
257
258    /// Semantic search (alias for `semantic_search`).
259    pub async fn semantic(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
260        self.semantic_search(query).await
261    }
262
263    /// Find a specific symbol by name (async).
264    pub async fn symbol_by_name(&self, name: &str) -> ForgeResult<Option<Symbol>> {
265        let symbols = self.pattern_search(name).await?;
266        // Return first exact match or None
267        Ok(symbols.into_iter().find(|s| s.name == name))
268    }
269
270    /// Find all symbols of a specific kind (async).
271    pub async fn symbols_by_kind(&self, kind: SymbolKind) -> ForgeResult<Vec<Symbol>> {
272        // Query all symbols and filter by kind
273        let all_symbols = self.store.get_all_symbols().await
274            .map_err(|e| ForgeError::DatabaseError(format!("Kind search failed: {}", e)))?;
275
276        let filtered: Vec<Symbol> = all_symbols
277            .into_iter()
278            .filter(|s| s.kind == kind)
279            .collect();
280
281        Ok(filtered)
282    }
283}
284
/// Translates a magellan symbol kind into forge's [`SymbolKind`].
///
/// magellan distinguishes more kinds than forge, so several collapse together:
/// `Class` becomes `Struct`, `Interface` becomes `Trait`, `Union` becomes `Enum`,
/// `Namespace` becomes `Module`, and `Unknown` falls back to `Function`.
#[cfg(feature = "magellan")]
#[expect(dead_code)] // Helper for magellan integration
fn map_magellan_kind(kind: &magellan::SymbolKind) -> SymbolKind {
    use magellan::SymbolKind as Mk;

    match kind {
        // NOTE(review): `Unknown` -> `Function` is a fallback, not a real equivalence.
        Mk::Function | Mk::Unknown => SymbolKind::Function,
        Mk::Method => SymbolKind::Method,
        Mk::Class => SymbolKind::Struct,
        Mk::Interface => SymbolKind::Trait,
        Mk::Enum | Mk::Union => SymbolKind::Enum,
        Mk::Module | Mk::Namespace => SymbolKind::Module,
        Mk::TypeAlias => SymbolKind::TypeAlias,
    }
}
304
/// Extracts the function name declared on a source line.
///
/// e.g. `"pub fn add(a: i32) -> i32 {"` -> `"add"` and `"fn parse<T>(s: &str)"` ->
/// `"parse"`. If there is no standalone `fn ` token (`"let myfn = 1;"` does not count),
/// the result falls back to the line's first whitespace-separated word, or `""` for a
/// blank line.
fn extract_symbol_from_line(line: &str) -> String {
    let line = line.trim();
    ident_after_keyword(line, "fn ")
        .or_else(|| line.split_whitespace().next())
        .unwrap_or("")
        .to_string()
}

/// Extracts the struct name declared on a source line.
///
/// e.g. `"pub struct Calculator {"` -> `"Calculator"` and `"struct Wrapper<T>(T);"` ->
/// `"Wrapper"`. Words that merely end in `struct` (`construct`, `destruct`) are not
/// treated as the keyword. Lines without a declaration fall back to their first
/// whitespace-separated word, or `""` for a blank line.
fn extract_struct_from_line(line: &str) -> String {
    let line = line.trim();
    ident_after_keyword(line, "struct ")
        .or_else(|| line.split_whitespace().next())
        .unwrap_or("")
        .to_string()
}

/// Returns the identifier that follows the first standalone occurrence of `keyword`.
///
/// `keyword` is a token such as `"fn "` or `"struct "`. An occurrence is standalone
/// when the character before it cannot be part of an identifier. The identifier may
/// carry a raw `r#` prefix. It ends at the first character that is not alphanumeric
/// or `_`, so generic parameters and argument lists are excluded. Occurrences not
/// followed by an identifier are skipped; `None` means no occurrence qualified.
fn ident_after_keyword<'a>(line: &'a str, keyword: &str) -> Option<&'a str> {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    let mut cursor = 0;

    while let Some(offset) = line[cursor..].find(keyword) {
        let start = cursor + offset;
        cursor = start + keyword.len();

        // Reject matches that are the tail of a longer word ("myfn ", "construct ").
        if matches!(line[..start].chars().next_back(), Some(c) if is_ident_char(c)) {
            continue;
        }

        let rest = line[cursor..].trim_start();
        let (prefix_len, body) = match rest.strip_prefix("r#") {
            Some(body) => (2, body),
            None => (0, rest),
        };
        let ident_len = body.find(|c: char| !is_ident_char(c)).unwrap_or(body.len());
        if ident_len > 0 {
            return Some(&rest[..prefix_len + ident_len]);
        }
    }

    None
}
338
/// Matches `text` against a glob `pattern` in which `*` stands for any run of
/// characters (possibly empty); all other characters must match literally.
///
/// A pattern without `*` must equal `text` exactly. A pattern that does not start
/// with `*` is anchored at the start of `text`; one that does not end with `*` is
/// anchored at the end.
#[expect(dead_code)] // Helper for pattern matching
fn glob_match(pattern: &str, text: &str) -> bool {
    let segments: Vec<&str> = pattern.split('*').collect();
    if segments.len() == 1 {
        // No wildcard at all: plain equality.
        return pattern == text;
    }

    let last = segments.len() - 1;
    let mut rest = text;

    for (idx, segment) in segments.into_iter().enumerate() {
        if segment.is_empty() {
            continue;
        }

        if idx == 0 {
            // A non-empty leading segment means the pattern does not start with `*`.
            match rest.strip_prefix(segment) {
                Some(tail) => rest = tail,
                None => return false,
            }
        } else if idx == last {
            // A non-empty trailing segment means the pattern does not end with `*`.
            return rest.ends_with(segment);
        } else {
            // Interior segments take their leftmost occurrence, leaving the most text
            // for the segments after them.
            match rest.find(segment) {
                Some(pos) => rest = &rest[pos + segment.len()..],
                None => return false,
            }
        }
    }

    true
}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::BackendKind;

    /// Opens a SQLite-backed store rooted at a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the module so the directory outlives the
    /// store; callers must bind it to a named variable (not `_`).
    async fn open_search() -> (tempfile::TempDir, SearchModule) {
        let temp_dir = tempfile::tempdir().unwrap();
        let store = Arc::new(
            UnifiedGraphStore::open(temp_dir.path(), BackendKind::SQLite)
                .await
                .unwrap(),
        );
        (temp_dir, SearchModule::new(store))
    }

    #[tokio::test]
    async fn test_search_module_creation() {
        let (_temp_dir, _search) = open_search().await;
    }

    #[tokio::test]
    async fn test_index_is_noop() {
        let (_temp_dir, search) = open_search().await;
        search.index().await.unwrap();
    }

    #[tokio::test]
    async fn test_pattern_search_empty() {
        let (_temp_dir, search) = open_search().await;

        let results = search.pattern_search("nonexistent").await.unwrap();
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_semantic_search_without_keywords() {
        let (_temp_dir, search) = open_search().await;

        // Every word is shorter than 3 characters, so no keyword survives filtering.
        assert!(search.semantic_search("a an of").await.unwrap().is_empty());
        assert!(search.semantic_search("").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_symbol_by_name_not_found() {
        let (_temp_dir, search) = open_search().await;

        let result = search.symbol_by_name("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_symbols_by_kind() {
        let (_temp_dir, search) = open_search().await;

        let functions = search.symbols_by_kind(SymbolKind::Function).await.unwrap();
        // Empty since no symbols inserted yet
        assert!(functions.is_empty());
    }

    #[test]
    fn test_extract_symbol_from_line() {
        assert_eq!(extract_symbol_from_line("pub fn add(a: i32) -> i32 {"), "add");
        assert_eq!(extract_symbol_from_line("    fn main() {"), "main");
        assert_eq!(extract_symbol_from_line("let x = 5;"), "let");
        assert_eq!(extract_symbol_from_line(""), "");
    }

    #[test]
    fn test_extract_struct_from_line() {
        assert_eq!(extract_struct_from_line("pub struct Calculator {"), "Calculator");
        assert_eq!(extract_struct_from_line("struct Unit;"), "Unit");
        assert_eq!(extract_struct_from_line("let x = 5;"), "let");
    }
}