1use crate::error::{ForgeError, Result as ForgeResult};
8use crate::storage::UnifiedGraphStore;
9use crate::types::{Language, Location, Symbol, SymbolId, SymbolKind};
10use std::path::PathBuf;
11use std::sync::Arc;
12
13pub struct SearchModule {
15 store: Arc<UnifiedGraphStore>,
16}
17
18impl SearchModule {
19 pub fn new(store: Arc<UnifiedGraphStore>) -> Self {
21 Self { store }
22 }
23
24 pub async fn index(&self) -> ForgeResult<()> {
29 Ok(())
30 }
31
32 pub async fn pattern_search(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
37 let db_path = self.store.db_path.clone();
38 if db_path.exists() {
39 if let Ok(results) = self.search_via_llmgrep(pattern, true).await {
40 return Ok(results);
41 }
42 }
43
44 self.pattern_search_via_files(pattern).await
45 }
46
47 pub async fn pattern(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
49 self.pattern_search(pattern).await
50 }
51
52 pub async fn semantic_search(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
57 if query.trim().is_empty() {
58 return Ok(Vec::new());
59 }
60
61 let db_path = self.store.db_path.clone();
62 if db_path.exists() {
63 if let Ok(results) = self.search_via_llmgrep(query, false).await {
64 return Ok(results);
65 }
66 }
67
68 self.semantic_search_via_files(query).await
69 }
70
71 pub async fn semantic(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
73 self.semantic_search(query).await
74 }
75
76 pub async fn symbol_by_name(&self, name: &str) -> ForgeResult<Option<Symbol>> {
78 let symbols = self.pattern_search(name).await?;
79 Ok(symbols.into_iter().find(|s| s.name == Arc::from(name)))
80 }
81
82 pub async fn symbols_by_kind(&self, kind: SymbolKind) -> ForgeResult<Vec<Symbol>> {
84 let all_symbols = self
85 .store
86 .get_all_symbols()
87 .await
88 .map_err(|e| ForgeError::DatabaseError(format!("Kind search failed: {}", e)))?;
89
90 Ok(all_symbols.into_iter().filter(|s| s.kind == kind).collect())
91 }
92
93 pub async fn references(&self, symbol_name: &str, limit: usize) -> ForgeResult<Vec<Symbol>> {
95 let db_path = self.store.db_path.clone();
96 if !db_path.exists() {
97 return Ok(Vec::new());
98 }
99 llmgrep::forge::search_references(symbol_name, &db_path, limit)
100 .map(|refs| {
101 refs.into_iter()
102 .map(|r| Symbol {
103 id: SymbolId(0),
104 name: Arc::from(r.referenced_symbol.clone()),
105 fully_qualified_name: Arc::from(r.referenced_symbol),
106 kind: SymbolKind::Function,
107 language: Language::Unknown("unknown".to_string()),
108 location: Location {
109 file_path: PathBuf::from(&r.span.file_path),
110 byte_start: r.span.byte_start as u32,
111 byte_end: r.span.byte_end as u32,
112 line_number: r.span.start_line as usize,
113 },
114 parent_id: None,
115 metadata: serde_json::Value::Null,
116 })
117 .collect()
118 })
119 .map_err(|e| ForgeError::DatabaseError(format!("Reference search failed: {}", e)))
120 }
121
122 pub async fn calls(&self, symbol_name: &str, limit: usize) -> ForgeResult<Vec<Symbol>> {
124 let db_path = self.store.db_path.clone();
125 if !db_path.exists() {
126 return Ok(Vec::new());
127 }
128 llmgrep::forge::search_calls(symbol_name, &db_path, limit)
129 .map(|calls| {
130 calls
131 .into_iter()
132 .map(|c| Symbol {
133 id: SymbolId(0),
134 name: Arc::from(c.caller.clone()),
135 fully_qualified_name: Arc::from(c.caller.clone()),
136 kind: SymbolKind::Function,
137 language: Language::Unknown("unknown".to_string()),
138 location: Location {
139 file_path: PathBuf::from(&c.span.file_path),
140 byte_start: c.span.byte_start as u32,
141 byte_end: c.span.byte_end as u32,
142 line_number: c.span.start_line as usize,
143 },
144 parent_id: None,
145 metadata: serde_json::Value::Null,
146 })
147 .collect()
148 })
149 .map_err(|e| ForgeError::DatabaseError(format!("Call search failed: {}", e)))
150 }
151
152 pub async fn lookup(&self, fqn: &str) -> ForgeResult<Option<Symbol>> {
154 let db_path = self.store.db_path.clone();
155 if !db_path.exists() {
156 return Ok(None);
157 }
158 llmgrep::forge::lookup_symbol(fqn, &db_path)
159 .map(|m| Some(llmgrep_match_to_symbol(m)))
160 .map_err(|e| ForgeError::DatabaseError(format!("Lookup failed: {}", e)))
161 }
162
163 async fn search_via_llmgrep(&self, query: &str, use_regex: bool) -> ForgeResult<Vec<Symbol>> {
166 let db_path = self.store.db_path.clone();
167
168 let result = if use_regex {
169 llmgrep::forge::search_symbols_regex(query, &db_path, 50)
170 } else {
171 llmgrep::forge::search_symbols(query, &db_path, 50)
172 };
173
174 result
175 .map(|matches| matches.into_iter().map(llmgrep_match_to_symbol).collect())
176 .map_err(|e| ForgeError::DatabaseError(format!("llmgrep search failed: {}", e)))
177 }
178
179 async fn pattern_search_via_files(&self, pattern: &str) -> ForgeResult<Vec<Symbol>> {
182 use regex::Regex;
183
184 let regex = Regex::new(pattern)
185 .map_err(|e| ForgeError::DatabaseError(format!("Invalid regex pattern: {}", e)))?;
186
187 let mut results = Vec::new();
188 let mut files = Vec::new();
189 collect_source_files(&self.store.codebase_path, &mut files).await;
190
191 for path in files {
192 if let Ok(content) = tokio::fs::read_to_string(&path).await {
193 for (line_num, line) in content.lines().enumerate() {
194 if regex.is_match(line) {
195 let symbol_name = extract_symbol_from_line(line);
196 let relative_path = path
197 .strip_prefix(&self.store.codebase_path)
198 .unwrap_or(&path);
199 results.push(Symbol {
200 id: SymbolId(0),
201 name: Arc::from(symbol_name.clone()),
202 fully_qualified_name: Arc::from(symbol_name),
203 kind: SymbolKind::Function,
204 language: Language::Rust,
205 location: Location {
206 file_path: relative_path.to_path_buf(),
207 byte_start: 0,
208 byte_end: line.len() as u32,
209 line_number: line_num + 1,
210 },
211 parent_id: None,
212 metadata: serde_json::Value::Null,
213 });
214 }
215 }
216 }
217 }
218
219 Ok(results)
220 }
221
222 async fn semantic_search_via_files(&self, query: &str) -> ForgeResult<Vec<Symbol>> {
223 let keywords: Vec<&str> = query
224 .split_whitespace()
225 .filter(|w| w.len() >= 3)
226 .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
227 .filter(|w| !w.is_empty())
228 .collect();
229
230 if keywords.is_empty() {
231 return Ok(Vec::new());
232 }
233
234 let mut results = Vec::new();
235 let mut files = Vec::new();
236 collect_source_files(&self.store.codebase_path, &mut files).await;
237
238 for path in files {
239 let Ok(content) = tokio::fs::read_to_string(&path).await else {
240 continue;
241 };
242 for (line_num, line) in content.lines().enumerate() {
243 let name = extract_symbol_from_line(line);
244 if name.is_empty() || name == "fn" {
245 continue;
246 }
247 let name_lower = name.to_lowercase();
248 let matches_keyword = keywords.iter().any(|kw| {
249 let kw_lower = kw.to_lowercase();
250 name_lower.contains(&kw_lower) || kw_lower.contains(&name_lower)
251 });
252 if matches_keyword {
253 let relative_path = path
254 .strip_prefix(&self.store.codebase_path)
255 .unwrap_or(&path);
256 results.push(Symbol {
257 id: SymbolId(0),
258 name: Arc::from(name.clone()),
259 fully_qualified_name: Arc::from(name.clone()),
260 kind: if line.contains("struct ") {
261 SymbolKind::Struct
262 } else {
263 SymbolKind::Function
264 },
265 language: Language::Rust,
266 location: Location {
267 file_path: relative_path.to_path_buf(),
268 byte_start: 0,
269 byte_end: line.len() as u32,
270 line_number: line_num + 1,
271 },
272 parent_id: None,
273 metadata: serde_json::Value::Null,
274 });
275 }
276 }
277 }
278
279 let mut seen = std::collections::HashSet::new();
280 results.retain(|s| seen.insert(s.name.clone()));
281
282 Ok(results)
283 }
284}
285
286async fn collect_source_files(dir: &std::path::Path, files: &mut Vec<PathBuf>) {
287 let Ok(mut entries) = tokio::fs::read_dir(dir).await else {
288 return;
289 };
290 while let Ok(Some(entry)) = entries.next_entry().await {
291 let path = entry.path();
292 if path.is_dir() {
293 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
294 if matches!(
295 name,
296 "target" | ".git" | ".forge" | ".magellan" | "node_modules"
297 ) {
298 continue;
299 }
300 }
301 Box::pin(collect_source_files(&path, files)).await;
302 } else if path.is_file()
303 && path
304 .extension()
305 .map(|e| {
306 matches!(
307 e.to_str(),
308 Some("rs" | "py" | "ts" | "js" | "go" | "java" | "c" | "cpp")
309 )
310 })
311 .unwrap_or(false)
312 {
313 files.push(path);
314 }
315 }
316}
317
318fn llmgrep_match_to_symbol(m: llmgrep::output::SymbolMatch) -> Symbol {
319 let kind = map_llmgrep_kind(&m.kind);
320 let language = m
321 .language
322 .as_deref()
323 .map(map_llmgrep_language)
324 .unwrap_or(Language::Unknown("unknown".to_string()));
325 let fqn: Arc<str> = Arc::from(m.fqn.clone().unwrap_or_else(|| m.name.clone()));
326
327 Symbol {
328 id: SymbolId(0),
329 name: Arc::from(m.name),
330 fully_qualified_name: fqn,
331 kind,
332 language,
333 location: Location {
334 file_path: PathBuf::from(&m.span.file_path),
335 byte_start: m.span.byte_start as u32,
336 byte_end: m.span.byte_end as u32,
337 line_number: m.span.start_line as usize,
338 },
339 parent_id: None,
340 metadata: serde_json::Value::Null,
341 }
342}
343
344fn map_llmgrep_kind(kind: &str) -> SymbolKind {
345 match kind {
346 "function_item" | "function" => SymbolKind::Function,
347 "method_item" | "method" | "impl_item" => SymbolKind::Method,
348 "struct_item" | "struct" | "class" => SymbolKind::Struct,
349 "trait_item" | "trait" | "interface" => SymbolKind::Trait,
350 "enum_item" | "enum" => SymbolKind::Enum,
351 "mod_item" | "module" | "namespace" => SymbolKind::Module,
352 "type_item" | "type_alias" => SymbolKind::TypeAlias,
353 "const_item" | "constant" => SymbolKind::Constant,
354 "field" | "property" => SymbolKind::Field,
355 _ => SymbolKind::Function,
356 }
357}
358
359fn map_llmgrep_language(lang: &str) -> Language {
360 match lang {
361 "rust" => Language::Rust,
362 "python" => Language::Python,
363 "c" => Language::C,
364 "cpp" | "c++" => Language::Cpp,
365 "java" => Language::Java,
366 "javascript" | "js" => Language::JavaScript,
367 "typescript" | "ts" => Language::TypeScript,
368 "go" => Language::Go,
369 _ => Language::Unknown(lang.to_string()),
370 }
371}
372
/// Heuristically extracts the symbol name declared on a source line.
///
/// The line is split on whitespace. If a declaration keyword (`fn`, `struct`, `enum`,
/// `trait`, `def`, `class`, `func`, `function`) is followed by a word starting with
/// identifier characters, that leading identifier is returned. This means generics
/// and parameter lists are dropped: `pub fn parse<T>(..)` gives `parse`, and
/// `pub struct Foo {` gives `Foo`. Keywords are only matched as whole words, so
/// `has_fn = ..` is not mistaken for a declaration.
///
/// Otherwise the first whitespace-separated word is returned, or `""` for a blank line.
fn extract_symbol_from_line(line: &str) -> String {
    const DECL_KEYWORDS: &[&str] = &[
        "fn", "struct", "enum", "trait", "def", "class", "func", "function",
    ];
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    let words: Vec<&str> = line.split_whitespace().collect();
    let declared = words.windows(2).find_map(|pair| {
        if !DECL_KEYWORDS.contains(&pair[0]) {
            return None;
        }
        let candidate = pair[1];
        let end = candidate
            .find(|c: char| !is_ident_char(c))
            .unwrap_or(candidate.len());
        let ident = &candidate[..end];
        (!ident.is_empty()).then_some(ident)
    });

    declared
        .or_else(|| words.first().copied())
        .unwrap_or("")
        .to_string()
}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::BackendKind;

    /// Opens a SQLite-backed store in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so the caller keeps the
    /// directory alive for the whole test.
    async fn temp_store() -> (tempfile::TempDir, Arc<UnifiedGraphStore>) {
        let dir = tempfile::tempdir().unwrap();
        let store = UnifiedGraphStore::open(dir.path(), BackendKind::SQLite)
            .await
            .unwrap();
        (dir, Arc::new(store))
    }

    #[tokio::test]
    async fn test_search_module_creation() {
        let (_dir, store) = temp_store().await;
        let _search = SearchModule::new(Arc::clone(&store));
    }

    #[tokio::test]
    async fn test_pattern_search_empty() {
        let (_dir, store) = temp_store().await;
        let search = SearchModule::new(store);

        let found = search.pattern_search("nonexistent").await.unwrap();
        assert_eq!(found.len(), 0);
    }

    #[tokio::test]
    async fn test_symbol_by_name_not_found() {
        let (_dir, store) = temp_store().await;
        let search = SearchModule::new(store);

        let found = search.symbol_by_name("nonexistent").await.unwrap();
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn test_symbols_by_kind() {
        let (_dir, store) = temp_store().await;
        let search = SearchModule::new(store);

        let functions = search.symbols_by_kind(SymbolKind::Function).await.unwrap();
        assert!(functions.is_empty());
    }

    #[test]
    fn test_extract_symbol_from_line() {
        let cases = [("pub fn add(a: i32) -> i32 {", "add"), ("fn hello() {", "hello")];
        for (line, expected) in cases {
            assert_eq!(extract_symbol_from_line(line), expected);
        }
    }
}