Skip to main content

sift/
query.rs

1use crate::embed::Embedder;
2use crate::index::CodeIndex;
3use serde::Serialize;
4use std::path::Path;
5
6#[derive(Debug, Serialize)]
7pub struct QueryResult {
8    #[serde(rename = "type")]
9    pub result_type: &'static str,
10    pub name: String,
11    pub kind: String,
12    pub file: String,
13    pub line: usize,
14    pub end_line: usize,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub score: Option<f64>,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub doc: Option<String>,
19}
20
21#[derive(Debug, Serialize)]
22pub struct CallResult {
23    #[serde(rename = "type")]
24    pub result_type: &'static str,
25    pub caller: String,
26    pub callee: String,
27    pub file: String,
28    pub line: usize,
29}
30
31#[derive(Debug, Serialize)]
32pub struct FileResult {
33    #[serde(rename = "type")]
34    pub result_type: &'static str,
35    pub file: String,
36    pub symbols: Vec<String>,
37}
38
39#[derive(Debug, Serialize)]
40pub struct ImportResult {
41    #[serde(rename = "type")]
42    pub result_type: &'static str,
43    pub file: String,
44    pub symbol: String,
45    pub resolved: bool,
46    pub resolved_file: Option<String>,
47    pub resolved_line: Option<usize>,
48    pub resolved_kind: Option<String>,
49}
50
51#[derive(Debug, Serialize)]
52pub struct ImporterResult {
53    #[serde(rename = "type")]
54    pub result_type: &'static str,
55    pub symbol: String,
56    pub importer_file: String,
57    pub import_name: String,
58}
59
60#[derive(Debug, Serialize)]
61#[serde(untagged)]
62pub enum OutputRow {
63    Query(QueryResult),
64    Call(CallResult),
65    File(FileResult),
66    Simple(SimpleResult),
67    Import(ImportResult),
68    Importer(ImporterResult),
69}
70
71#[derive(Debug, Serialize)]
72pub struct SimpleResult {
73    #[serde(rename = "type")]
74    pub result_type: String,
75    pub value: String,
76}
77
78pub struct QueryEngine<'a> {
79    index: &'a CodeIndex,
80    embedder: Option<Box<dyn Embedder + 'a>>,
81}
82
83impl<'a> QueryEngine<'a> {
84    pub fn new(index: &'a CodeIndex) -> Self {
85        Self { index, embedder: None }
86    }
87
88    pub fn with_embedder(index: &'a CodeIndex, embedder: Box<dyn Embedder + 'a>) -> Self {
89        Self { index, embedder: Some(embedder) }
90    }
91
92    pub fn execute(&self, query: &str) -> Vec<OutputRow> {
93        let query = query.trim();
94        let (cmd, arg) = query
95            .split_once(' ')
96            .map(|(c, a)| (c, a.trim()))
97            .unwrap_or((query, ""));
98        match (cmd, arg) {
99            ("define", a) => self.cmd_define(a),
100            ("calls", a) => self.cmd_calls(a),
101            ("callees", a) => self.cmd_callees(a),
102            ("implements", a) => self.cmd_implements(a),
103            ("imports", a) => self.cmd_imports(a),
104            ("importers", a) => self.cmd_importers(a),
105            ("file", a) => self.cmd_file(a),
106            ("symbols", a) if a.starts_with("matching ") => {
107                self.cmd_symbols_matching(a.strip_prefix("matching ").unwrap_or("").trim())
108            }
109            ("semantic", a) => self.cmd_semantic(a),
110            ("files", "") => self.cmd_files(),
111            _ => self.cmd_define(query),
112        }
113    }
114
115    fn rel(&self, path: &Path) -> String {
116        self.index.relative_path(path)
117    }
118
119    fn cmd_define(&self, name: &str) -> Vec<OutputRow> {
120        self.index
121            .find_symbols_by_name(name)
122            .into_iter()
123            .map(|s| {
124                OutputRow::Query(QueryResult {
125                    result_type: "definition",
126                    name: s.name.clone(),
127                    kind: format!("{:?}", s.kind).to_lowercase(),
128                    file: self.rel(&s.file),
129                    line: s.line,
130                    end_line: s.end_line,
131                    score: None,
132                    doc: s.doc.clone(),
133                })
134            })
135            .collect()
136    }
137
138    fn cmd_calls(&self, name: &str) -> Vec<OutputRow> {
139        self.index
140            .find_calls_to(name)
141            .into_iter()
142            .map(|c| {
143                OutputRow::Call(CallResult {
144                    result_type: "call",
145                    caller: c.caller_name.clone(),
146                    callee: c.callee_name.clone(),
147                    file: self.rel(&c.caller_file),
148                    line: c.caller_line,
149                })
150            })
151            .collect()
152    }
153
154    fn cmd_callees(&self, name: &str) -> Vec<OutputRow> {
155        self.index
156            .find_calls_by(name)
157            .into_iter()
158            .map(|c| {
159                OutputRow::Call(CallResult {
160                    result_type: "callee",
161                    caller: c.caller_name.clone(),
162                    callee: c.callee_name.clone(),
163                    file: self.rel(&c.caller_file),
164                    line: c.caller_line,
165                })
166            })
167            .collect()
168    }
169
170    fn cmd_implements(&self, name: &str) -> Vec<OutputRow> {
171        self.index
172            .find_implementations(name)
173            .into_iter()
174            .map(|s| {
175                OutputRow::Query(QueryResult {
176                    result_type: "implementation",
177                    name: s.name.clone(),
178                    kind: format!("{:?}", s.kind).to_lowercase(),
179                    file: self.rel(&s.file),
180                    line: s.line,
181                    end_line: s.end_line,
182                    score: None,
183                    doc: s.doc.clone(),
184                })
185            })
186            .collect()
187    }
188
189    fn cmd_imports(&self, path: &str) -> Vec<OutputRow> {
190        let query = Path::new(path);
191        let matched: Vec<_> = self
192            .index
193            .files
194            .iter()
195            .filter(|f| self.rel(f) == path || f.ends_with(query))
196            .collect();
197        let mut rows = Vec::new();
198        for f in matched {
199            for imp in self.index.find_imports_in_file(f) {
200                rows.push(OutputRow::Import(ImportResult {
201                    result_type: "import",
202                    file: self.rel(f),
203                    symbol: imp.symbol_name.clone(),
204                    resolved: imp.resolved_to.is_some(),
205                    resolved_file: imp.resolved_file.as_ref().map(|p| self.rel(p)),
206                    resolved_line: imp.resolved_line,
207                    resolved_kind: imp.resolved_kind.clone(),
208                }));
209            }
210        }
211        rows
212    }
213
214    fn cmd_importers(&self, name: &str) -> Vec<OutputRow> {
215        self.index
216            .find_importers_of(name)
217            .into_iter()
218            .map(|imp| OutputRow::Importer(ImporterResult {
219                result_type: "importer",
220                symbol: name.to_string(),
221                importer_file: self.rel(&imp.file),
222                import_name: imp.symbol_name.clone(),
223            }))
224            .collect()
225    }
226
227    fn cmd_file(&self, path: &str) -> Vec<OutputRow> {
228        let query = Path::new(path);
229        let matched: Vec<_> = self
230            .index
231            .files
232            .iter()
233            .filter(|f| {
234                self.rel(f) == path || f.ends_with(query)
235            })
236            .collect();
237
238        if matched.is_empty() {
239            return vec![];
240        }
241
242        let mut rows = Vec::new();
243        for f in matched {
244            let syms = self.index.find_symbols_in_file(f);
245            rows.push(OutputRow::File(FileResult {
246                result_type: "file",
247                file: self.rel(f),
248                symbols: syms.into_iter().map(|s| s.name.clone()).collect(),
249            }));
250        }
251        rows
252    }
253
254    fn cmd_symbols_matching(&self, pattern: &str) -> Vec<OutputRow> {
255        self.index
256            .find_symbols_by_pattern(pattern)
257            .into_iter()
258            .map(|s| {
259                OutputRow::Query(QueryResult {
260                    result_type: "definition",
261                    name: s.name.clone(),
262                    kind: format!("{:?}", s.kind).to_lowercase(),
263                    file: self.rel(&s.file),
264                    line: s.line,
265                    end_line: s.end_line,
266                    score: None,
267                    doc: s.doc.clone(),
268                })
269            })
270            .collect()
271    }
272
273    fn cmd_semantic(&self, query_text: &str) -> Vec<OutputRow> {
274        let Some(embedder) = &self.embedder else {
275            return vec![];
276        };
277        let has_embeddings = self.index.symbols.iter().any(|s| s.embedding.is_some());
278        if !has_embeddings {
279            return vec![];
280        }
281        let Ok(embeddings) = embedder.embed(&[query_text]) else {
282            return vec![];
283        };
284        let Some(query_embed) = embeddings.into_iter().next() else {
285            return vec![];
286        };
287        self.index
288            .semantic_search(&query_embed, 10)
289            .into_iter()
290            .map(|(score, s)| {
291                OutputRow::Query(QueryResult {
292                    result_type: "semantic",
293                    name: s.name.clone(),
294                    kind: format!("{:?}", s.kind).to_lowercase(),
295                    file: self.rel(&s.file),
296                    line: s.line,
297                    end_line: s.end_line,
298                    score: Some(score),
299                    doc: s.doc.clone(),
300                })
301            })
302            .collect()
303    }
304
305    fn cmd_files(&self) -> Vec<OutputRow> {
306        self.index
307            .files
308            .iter()
309            .map(|f| {
310                OutputRow::Simple(SimpleResult {
311                    result_type: "file".to_string(),
312                    value: self.rel(f),
313                })
314            })
315            .collect()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::CodeIndex;
    use crate::parser::{DefKind, ParsedDef, ParsedFile, ParsedImport, ParsedRef, RefKind};
    use std::path::PathBuf;

    /// Builds the shared fixture: three parsed Rust files under `/root/src`.
    ///
    /// * `main.rs` defines `main` and `run`, references `run` and `helper` as
    ///   calls, and imports `HashMap`.
    /// * `helper.rs` defines `helper`.
    /// * `collections.rs` defines the struct `HashMap`.
    fn make_index() -> CodeIndex {
        let main_rs = ParsedFile {
            path: PathBuf::from("/root/src/main.rs"),
            language: crate::parser::LanguageId::Rust,
            definitions: vec![
                ParsedDef {
                    name: "main".into(),
                    kind: DefKind::Function,
                    start_line: 1,
                    end_line: 10,
                    doc: None,
                },
                ParsedDef {
                    name: "run".into(),
                    kind: DefKind::Function,
                    start_line: 12,
                    end_line: 20,
                    doc: None,
                },
            ],
            references: vec![
                ParsedRef { name: "run".into(), kind: RefKind::Call, line: 5 },
                ParsedRef { name: "helper".into(), kind: RefKind::Call, line: 6 },
            ],
            imports: vec![ParsedImport { name: "HashMap".into() }],
        };
        let helper_rs = ParsedFile {
            path: PathBuf::from("/root/src/helper.rs"),
            language: crate::parser::LanguageId::Rust,
            definitions: vec![ParsedDef {
                name: "helper".into(),
                kind: DefKind::Function,
                start_line: 1,
                end_line: 3,
                doc: None,
            }],
            references: vec![],
            imports: vec![],
        };
        let collections_rs = ParsedFile {
            path: PathBuf::from("/root/src/collections.rs"),
            language: crate::parser::LanguageId::Rust,
            definitions: vec![ParsedDef {
                name: "HashMap".into(),
                kind: DefKind::Struct,
                start_line: 10,
                end_line: 50,
                doc: None,
            }],
            references: vec![],
            imports: vec![],
        };
        let files = vec![main_rs, helper_rs, collections_rs];
        CodeIndex::build(files, Path::new("/root"), None)
    }

    /// Runs `query` on an embedder-less engine over a fresh fixture index.
    fn run(query: &str) -> Vec<OutputRow> {
        let index = make_index();
        let engine = QueryEngine::new(&index);
        engine.execute(query)
    }

    #[test]
    fn test_define_query() {
        let results = run("define main");
        assert_eq!(results.len(), 1);
        let OutputRow::Query(r) = &results[0] else {
            panic!("expected Query result");
        };
        assert_eq!(r.name, "main");
        assert_eq!(r.file, "src/main.rs");
    }

    #[test]
    fn test_define_missing() {
        assert!(run("define nonexistent").is_empty());
    }

    #[test]
    fn test_calls_query() {
        let results = run("calls helper");
        assert_eq!(results.len(), 1);
        let OutputRow::Call(r) = &results[0] else {
            panic!("expected Call result");
        };
        assert_eq!(r.callee, "helper");
        assert_eq!(r.caller, "main");
    }

    #[test]
    fn test_callees_query() {
        let results = run("callees main");
        assert_eq!(results.len(), 2);
        // Rows of any other variant collapse to "" and so cannot satisfy the checks below.
        let callees: Vec<&str> = results
            .iter()
            .map(|row| match row {
                OutputRow::Call(c) => c.callee.as_str(),
                _ => "",
            })
            .collect();
        assert!(callees.contains(&"run"));
        assert!(callees.contains(&"helper"));
    }

    #[test]
    fn test_files_query() {
        let results = run("files");
        assert_eq!(results.len(), 3);
        // Rows of any other variant collapse to "" and so cannot satisfy the checks below.
        let files: Vec<&str> = results
            .iter()
            .map(|row| match row {
                OutputRow::Simple(s) => s.value.as_str(),
                _ => "",
            })
            .collect();
        assert!(files.contains(&"src/main.rs"));
        assert!(files.contains(&"src/helper.rs"));
        assert!(files.contains(&"src/collections.rs"));
    }

    #[test]
    fn test_file_query() {
        let results = run("file src/main.rs");
        assert_eq!(results.len(), 1);
        let OutputRow::File(r) = &results[0] else {
            panic!("expected File result");
        };
        assert_eq!(r.file, "src/main.rs");
        assert!(r.symbols.iter().any(|s| s == "main"));
    }

    #[test]
    fn test_file_query_partial_path() {
        // A trailing path component alone is enough to select the file.
        assert_eq!(run("file main.rs").len(), 1);
    }

    #[test]
    fn test_symbols_matching() {
        assert_eq!(run("symbols matching run").len(), 1);
    }

    #[test]
    fn test_bare_name_fallback() {
        // No command word: the query is looked up as a symbol name.
        let results = run("main");
        assert_eq!(results.len(), 1);
        let OutputRow::Query(r) = &results[0] else {
            panic!("expected Query result");
        };
        assert_eq!(r.name, "main");
    }

    #[test]
    fn test_implements_query() {
        assert!(run("implements nonexistent").is_empty());
    }

    #[test]
    fn test_empty_query() {
        assert!(run("").is_empty());
    }

    #[test]
    fn test_imports_query() {
        let results = run("imports src/main.rs");
        assert_eq!(results.len(), 1);
        let OutputRow::Import(r) = &results[0] else {
            panic!("expected Import result");
        };
        assert_eq!(r.symbol, "HashMap");
        assert!(r.resolved);
        assert_eq!(r.resolved_file.as_deref(), Some("src/collections.rs"));
        assert_eq!(r.resolved_kind.as_deref(), Some("struct"));
    }

    #[test]
    fn test_importers_query() {
        let results = run("importers HashMap");
        assert_eq!(results.len(), 1);
        let OutputRow::Importer(r) = &results[0] else {
            panic!("expected Importer result");
        };
        assert_eq!(r.symbol, "HashMap");
        assert_eq!(r.importer_file, "src/main.rs");
    }

    #[test]
    fn test_imports_query_unresolved() {
        // Despite the name, this covers a file with no imports at all:
        // helper.rs declares none, so the query yields no rows.
        assert!(run("imports src/helper.rs").is_empty());
    }

    #[test]
    fn test_importers_query_no_results() {
        assert!(run("importers nonexistent").is_empty());
    }
}