Skip to main content

forgekit_core/search/
mod.rs

1//! Search module - Semantic code search via llmgrep
2//!
//! This module provides semantic code search by delegating to `llmgrep::forge`
//! convenience functions. When the index database file does not exist, or an
//! llmgrep query fails, pattern and semantic search fall back to scanning
//! source files on disk.
6
7use crate::error::{ForgeError, Result as ForgeResult};
8use crate::storage::UnifiedGraphStore;
9use crate::types::{Language, Location, Symbol, SymbolId, SymbolKind};
10use std::path::PathBuf;
11use std::sync::Arc;
12
13/// Search module for semantic code queries.
14pub struct SearchModule {
15    store: Arc<UnifiedGraphStore>,
16}
17
18impl SearchModule {
19    /// Create a new SearchModule.
20    pub fn new(store: Arc<UnifiedGraphStore>) -> Self {
21        Self { store }
22    }
23
24    /// Indexes the codebase for search.
25    ///
26    /// llmgrep reads magellan's DB directly, so this is a no-op.
27    /// The graph module's `index()` populates the shared DB.
28    pub async fn index(&self) -> ForgeResult<()> {
29        Ok(())
30    }
31
32    /// Pattern-based search using regex.
33    ///
34    /// With llmgrep: delegates to `llmgrep::forge::search_symbols_regex`.
35    /// Without: scans source files recursively with regex.
36    pub async fn pattern_search(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
37        let db_path = self.store.db_path.clone();
38        if db_path.exists() {
39            if let Ok(results) = self.search_via_llmgrep(pattern, true).await {
40                return Ok(results);
41            }
42        }
43
44        self.pattern_search_via_files(pattern).await
45    }
46
47    /// Pattern-based search (alias for `pattern_search`).
48    pub async fn pattern(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
49        self.pattern_search(pattern).await
50    }
51
52    /// Semantic search using natural language.
53    ///
54    /// With llmgrep: delegates to `llmgrep::forge::search_symbols`.
55    /// Without: splits query into keywords and scans files.
56    pub async fn semantic_search(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
57        if query.trim().is_empty() {
58            return Ok(Vec::new());
59        }
60
61        let db_path = self.store.db_path.clone();
62        if db_path.exists() {
63            if let Ok(results) = self.search_via_llmgrep(query, false).await {
64                return Ok(results);
65            }
66        }
67
68        self.semantic_search_via_files(query).await
69    }
70
71    /// Semantic search (alias for `semantic_search`).
72    pub async fn semantic(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
73        self.semantic_search(query).await
74    }
75
76    /// Find a specific symbol by name.
77    pub async fn symbol_by_name(&self, name: &str) -> ForgeResult<Option<Symbol>> {
78        let symbols = self.pattern_search(name).await?;
79        Ok(symbols.into_iter().find(|s| s.name == Arc::from(name)))
80    }
81
82    /// Find all symbols of a specific kind.
83    pub async fn symbols_by_kind(&self, kind: SymbolKind) -> ForgeResult<Vec<Symbol>> {
84        let all_symbols = self
85            .store
86            .get_all_symbols()
87            .await
88            .map_err(|e| ForgeError::DatabaseError(format!("Kind search failed: {}", e)))?;
89
90        Ok(all_symbols.into_iter().filter(|s| s.kind == kind).collect())
91    }
92
93    /// Find all references to a symbol.
94    pub async fn references(&self, symbol_name: &str, limit: usize) -> ForgeResult<Vec<Symbol>> {
95        let db_path = self.store.db_path.clone();
96        if !db_path.exists() {
97            return Ok(Vec::new());
98        }
99        llmgrep::forge::search_references(symbol_name, &db_path, limit)
100            .map(|refs| {
101                refs.into_iter()
102                    .map(|r| Symbol {
103                        id: SymbolId(0),
104                        name: Arc::from(r.referenced_symbol.clone()),
105                        fully_qualified_name: Arc::from(r.referenced_symbol),
106                        kind: SymbolKind::Function,
107                        language: Language::Unknown("unknown".to_string()),
108                        location: Location {
109                            file_path: PathBuf::from(&r.span.file_path),
110                            byte_start: r.span.byte_start as u32,
111                            byte_end: r.span.byte_end as u32,
112                            line_number: r.span.start_line as usize,
113                        },
114                        parent_id: None,
115                        metadata: serde_json::Value::Null,
116                    })
117                    .collect()
118            })
119            .map_err(|e| ForgeError::DatabaseError(format!("Reference search failed: {}", e)))
120    }
121
122    /// Find all calls involving a symbol.
123    pub async fn calls(&self, symbol_name: &str, limit: usize) -> ForgeResult<Vec<Symbol>> {
124        let db_path = self.store.db_path.clone();
125        if !db_path.exists() {
126            return Ok(Vec::new());
127        }
128        llmgrep::forge::search_calls(symbol_name, &db_path, limit)
129            .map(|calls| {
130                calls
131                    .into_iter()
132                    .map(|c| Symbol {
133                        id: SymbolId(0),
134                        name: Arc::from(c.caller.clone()),
135                        fully_qualified_name: Arc::from(c.caller.clone()),
136                        kind: SymbolKind::Function,
137                        language: Language::Unknown("unknown".to_string()),
138                        location: Location {
139                            file_path: PathBuf::from(&c.span.file_path),
140                            byte_start: c.span.byte_start as u32,
141                            byte_end: c.span.byte_end as u32,
142                            line_number: c.span.start_line as usize,
143                        },
144                        parent_id: None,
145                        metadata: serde_json::Value::Null,
146                    })
147                    .collect()
148            })
149            .map_err(|e| ForgeError::DatabaseError(format!("Call search failed: {}", e)))
150    }
151
152    /// Lookup a symbol by fully-qualified name.
153    pub async fn lookup(&self, fqn: &str) -> ForgeResult<Option<Symbol>> {
154        let db_path = self.store.db_path.clone();
155        if !db_path.exists() {
156            return Ok(None);
157        }
158        llmgrep::forge::lookup_symbol(fqn, &db_path)
159            .map(|m| Some(llmgrep_match_to_symbol(m)))
160            .map_err(|e| ForgeError::DatabaseError(format!("Lookup failed: {}", e)))
161    }
162
163    // -- llmgrep-backed search --
164
165    async fn search_via_llmgrep(&self, query: &str, use_regex: bool) -> ForgeResult<Vec<Symbol>> {
166        let db_path = self.store.db_path.clone();
167
168        let result = if use_regex {
169            llmgrep::forge::search_symbols_regex(query, &db_path, 50)
170        } else {
171            llmgrep::forge::search_symbols(query, &db_path, 50)
172        };
173
174        result
175            .map(|matches| matches.into_iter().map(llmgrep_match_to_symbol).collect())
176            .map_err(|e| ForgeError::DatabaseError(format!("llmgrep search failed: {}", e)))
177    }
178
179    // -- File-based fallback search --
180
181    async fn pattern_search_via_files(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
182        use regex::Regex;
183
184        let regex = Regex::new(pattern)
185            .map_err(|e| ForgeError::DatabaseError(format!("Invalid regex pattern: {}", e)))?;
186
187        let mut results = Vec::new();
188        let mut files = Vec::new();
189        collect_source_files(&self.store.codebase_path, &mut files).await;
190
191        for path in files {
192            if let Ok(content) = tokio::fs::read_to_string(&path).await {
193                for (line_num, line) in content.lines().enumerate() {
194                    if regex.is_match(line) {
195                        let symbol_name = extract_symbol_from_line(line);
196                        let relative_path = path
197                            .strip_prefix(&self.store.codebase_path)
198                            .unwrap_or(&path);
199                        results.push(Symbol {
200                            id: SymbolId(0),
201                            name: Arc::from(symbol_name.clone()),
202                            fully_qualified_name: Arc::from(symbol_name),
203                            kind: SymbolKind::Function,
204                            language: Language::Rust,
205                            location: Location {
206                                file_path: relative_path.to_path_buf(),
207                                byte_start: 0,
208                                byte_end: line.len() as u32,
209                                line_number: line_num + 1,
210                            },
211                            parent_id: None,
212                            metadata: serde_json::Value::Null,
213                        });
214                    }
215                }
216            }
217        }
218
219        Ok(results)
220    }
221
222    async fn semantic_search_via_files(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
223        let keywords: Vec<&str> = query
224            .split_whitespace()
225            .filter(|w| w.len() >= 3)
226            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
227            .filter(|w| !w.is_empty())
228            .collect();
229
230        if keywords.is_empty() {
231            return Ok(Vec::new());
232        }
233
234        let mut results = Vec::new();
235        let mut files = Vec::new();
236        collect_source_files(&self.store.codebase_path, &mut files).await;
237
238        for path in files {
239            let Ok(content) = tokio::fs::read_to_string(&path).await else {
240                continue;
241            };
242            for (line_num, line) in content.lines().enumerate() {
243                let name = extract_symbol_from_line(line);
244                if name.is_empty() || name == "fn" {
245                    continue;
246                }
247                let name_lower = name.to_lowercase();
248                let matches_keyword = keywords.iter().any(|kw| {
249                    let kw_lower = kw.to_lowercase();
250                    name_lower.contains(&kw_lower) || kw_lower.contains(&name_lower)
251                });
252                if matches_keyword {
253                    let relative_path = path
254                        .strip_prefix(&self.store.codebase_path)
255                        .unwrap_or(&path);
256                    results.push(Symbol {
257                        id: SymbolId(0),
258                        name: Arc::from(name.clone()),
259                        fully_qualified_name: Arc::from(name.clone()),
260                        kind: if line.contains("struct ") {
261                            SymbolKind::Struct
262                        } else {
263                            SymbolKind::Function
264                        },
265                        language: Language::Rust,
266                        location: Location {
267                            file_path: relative_path.to_path_buf(),
268                            byte_start: 0,
269                            byte_end: line.len() as u32,
270                            line_number: line_num + 1,
271                        },
272                        parent_id: None,
273                        metadata: serde_json::Value::Null,
274                    });
275                }
276            }
277        }
278
279        let mut seen = std::collections::HashSet::new();
280        results.retain(|s| seen.insert(s.name.clone()));
281
282        Ok(results)
283    }
284}
285
286async fn collect_source_files(dir: &std::path::Path, files: &mut Vec<PathBuf>) {
287    let Ok(mut entries) = tokio::fs::read_dir(dir).await else {
288        return;
289    };
290    while let Ok(Some(entry)) = entries.next_entry().await {
291        let path = entry.path();
292        if path.is_dir() {
293            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
294                if matches!(
295                    name,
296                    "target" | ".git" | ".forge" | ".magellan" | "node_modules"
297                ) {
298                    continue;
299                }
300            }
301            Box::pin(collect_source_files(&path, files)).await;
302        } else if path.is_file()
303            && path
304                .extension()
305                .map(|e| {
306                    matches!(
307                        e.to_str(),
308                        Some("rs" | "py" | "ts" | "js" | "go" | "java" | "c" | "cpp")
309                    )
310                })
311                .unwrap_or(false)
312        {
313            files.push(path);
314        }
315    }
316}
317
318fn llmgrep_match_to_symbol(m: llmgrep::output::SymbolMatch) -> Symbol {
319    let kind = map_llmgrep_kind(&m.kind);
320    let language = m
321        .language
322        .as_deref()
323        .map(map_llmgrep_language)
324        .unwrap_or(Language::Unknown("unknown".to_string()));
325    let fqn: Arc<str> = Arc::from(m.fqn.clone().unwrap_or_else(|| m.name.clone()));
326
327    Symbol {
328        id: SymbolId(0),
329        name: Arc::from(m.name),
330        fully_qualified_name: fqn,
331        kind,
332        language,
333        location: Location {
334            file_path: PathBuf::from(&m.span.file_path),
335            byte_start: m.span.byte_start as u32,
336            byte_end: m.span.byte_end as u32,
337            line_number: m.span.start_line as usize,
338        },
339        parent_id: None,
340        metadata: serde_json::Value::Null,
341    }
342}
343
344fn map_llmgrep_kind(kind: &str) -> SymbolKind {
345    match kind {
346        "function_item" | "function" => SymbolKind::Function,
347        "method_item" | "method" | "impl_item" => SymbolKind::Method,
348        "struct_item" | "struct" | "class" => SymbolKind::Struct,
349        "trait_item" | "trait" | "interface" => SymbolKind::Trait,
350        "enum_item" | "enum" => SymbolKind::Enum,
351        "mod_item" | "module" | "namespace" => SymbolKind::Module,
352        "type_item" | "type_alias" => SymbolKind::TypeAlias,
353        "const_item" | "constant" => SymbolKind::Constant,
354        "field" | "property" => SymbolKind::Field,
355        _ => SymbolKind::Function,
356    }
357}
358
359fn map_llmgrep_language(lang: &str) -> Language {
360    match lang {
361        "rust" => Language::Rust,
362        "python" => Language::Python,
363        "c" => Language::C,
364        "cpp" | "c++" => Language::Cpp,
365        "java" => Language::Java,
366        "javascript" | "js" => Language::JavaScript,
367        "typescript" | "ts" => Language::TypeScript,
368        "go" => Language::Go,
369        _ => Language::Unknown(lang.to_string()),
370    }
371}
372
/// Guesses the symbol name declared on a single line of source text.
///
/// If the line contains the keyword `fn` followed by a space, the identifier
/// after it is returned. The keyword must start a word, so the `fn ` inside
/// an identifier such as `defn` is ignored. The name ends at the first
/// character that is neither alphanumeric nor `_`, which keeps generic
/// parameters, argument lists and trailing punctuation out of the result; a
/// name that runs to the end of the line is returned whole.
///
/// When no such function name is found, the first whitespace-separated token
/// of the trimmed line is returned, or an empty string for a blank line.
fn extract_symbol_from_line(line: &str) -> String {
    let line = line.trim();
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    for (pos, _) in line.match_indices("fn ") {
        // Reject `fn` that is only the tail of a longer identifier.
        let starts_word = line[..pos]
            .chars()
            .next_back()
            .map_or(true, |prev| !is_ident_char(prev));
        if !starts_word {
            continue;
        }

        let rest = line[pos + 3..].trim_start();
        let end = rest
            .find(|c: char| !is_ident_char(c))
            .unwrap_or(rest.len());
        // An empty name (e.g. the type `fn (i32)`) is not a declaration.
        if end > 0 {
            return rest[..end].to_string();
        }
    }

    line.split_whitespace().next().unwrap_or("").to_string()
}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::BackendKind;

    /// Opens a SQLite-backed store inside a fresh temporary directory.
    ///
    /// The directory guard is returned with the store so that the caller
    /// keeps the directory alive for the duration of the test.
    async fn open_temp_store() -> (tempfile::TempDir, Arc<UnifiedGraphStore>) {
        let dir = tempfile::tempdir().unwrap();
        let store = UnifiedGraphStore::open(dir.path(), BackendKind::SQLite)
            .await
            .unwrap();
        (dir, Arc::new(store))
    }

    /// Constructing the module from a shared store must succeed.
    #[tokio::test]
    async fn test_search_module_creation() {
        let (_dir, store) = open_temp_store().await;
        let _search = SearchModule::new(Arc::clone(&store));
    }

    /// A pattern search over an empty codebase finds nothing.
    #[tokio::test]
    async fn test_pattern_search_empty() {
        let (_dir, store) = open_temp_store().await;
        let search = SearchModule::new(store);

        let results = search.pattern_search("nonexistent").await.unwrap();
        assert_eq!(results.len(), 0);
    }

    /// Looking up an unknown name yields `None` rather than an error.
    #[tokio::test]
    async fn test_symbol_by_name_not_found() {
        let (_dir, store) = open_temp_store().await;
        let search = SearchModule::new(store);

        let result = search.symbol_by_name("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    /// Filtering by kind on an empty store returns an empty list.
    #[tokio::test]
    async fn test_symbols_by_kind() {
        let (_dir, store) = open_temp_store().await;
        let search = SearchModule::new(store);

        let functions = search.symbols_by_kind(SymbolKind::Function).await.unwrap();
        assert!(functions.is_empty());
    }

    /// Function names are extracted with and without a visibility prefix.
    #[test]
    fn test_extract_symbol_from_line() {
        assert_eq!(
            extract_symbol_from_line("pub fn add(a: i32) -> i32 {"),
            "add"
        );
        assert_eq!(extract_symbol_from_line("fn hello() {"), "hello");
    }
}