Skip to main content

agentic_codebase/index/
semantic_search.rs

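//! Semantic Search Enhancement — Invention 9.
//!
//! Natural-language code search that understands intent, not just keywords.
//! Queries are classified by intent and reduced to keywords. Code units are then
//! ranked by keyword matches on their names and qualified names, plus a bonus
//! when the unit's kind matches the intent. The wrapped `EmbeddingIndex` powers
//! embedding-similarity lookups between units.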
6
7use serde::{Deserialize, Serialize};
8
9use crate::graph::CodeGraph;
10use crate::index::embedding_index::{EmbeddingIndex, EmbeddingMatch};
11use crate::types::CodeUnitType;
12
13// ── Types ────────────────────────────────────────────────────────────────────
14
15/// Intent behind a semantic query.
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum QueryIntent {
18    /// Looking for a function/method definition.
19    FindFunction,
20    /// Looking for a type/struct/class definition.
21    FindType,
22    /// Looking for usages / call sites.
23    FindUsage,
24    /// Looking for implementations of a concept.
25    FindImplementation,
26    /// Looking for tests.
27    FindTest,
28    /// General text search.
29    General,
30}
31
32impl QueryIntent {
33    /// Classify intent from a natural-language query.
34    pub fn classify(query: &str) -> Self {
35        let q = query.to_lowercase();
36        if q.contains("test") || q.contains("spec") || q.starts_with("how is") {
37            return Self::FindTest;
38        }
39        if q.contains("function")
40            || q.contains("method")
41            || q.contains("fn ")
42            || q.starts_with("def ")
43        {
44            return Self::FindFunction;
45        }
46        if q.contains("type")
47            || q.contains("struct")
48            || q.contains("class")
49            || q.contains("enum")
50            || q.contains("interface")
51        {
52            return Self::FindType;
53        }
54        if q.contains("usage")
55            || q.contains("call")
56            || q.contains("who uses")
57            || q.contains("where is")
58        {
59            return Self::FindUsage;
60        }
61        if q.contains("implement") || q.contains("how does") || q.contains("logic for") {
62            return Self::FindImplementation;
63        }
64        Self::General
65    }
66
67    /// Label for display.
68    pub fn label(&self) -> &str {
69        match self {
70            Self::FindFunction => "find_function",
71            Self::FindType => "find_type",
72            Self::FindUsage => "find_usage",
73            Self::FindImplementation => "find_implementation",
74            Self::FindTest => "find_test",
75            Self::General => "general",
76        }
77    }
78}
79
80/// Scope restriction for a search.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum SearchScope {
83    /// Search the entire codebase.
84    All,
85    /// Restrict to a specific module path prefix.
86    Module(String),
87    /// Restrict to a specific file.
88    File(String),
89    /// Restrict to a specific code unit type.
90    UnitType(CodeUnitType),
91}
92
93/// A semantic search query with parsed intent and scope.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct SemanticQuery {
96    /// Original query string.
97    pub raw: String,
98    /// Classified intent.
99    pub intent: QueryIntent,
100    /// Extracted keywords (lowercase).
101    pub keywords: Vec<String>,
102    /// Scope restriction.
103    pub scope: SearchScope,
104}
105
106/// A single match from semantic search.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SemanticMatch {
109    /// Unit ID.
110    pub unit_id: u64,
111    /// Unit name.
112    pub name: String,
113    /// Qualified name.
114    pub qualified_name: String,
115    /// Type label.
116    pub unit_type: String,
117    /// File path.
118    pub file_path: String,
119    /// Combined relevance score (0.0–1.0).
120    pub relevance: f64,
121    /// Why this matched.
122    pub explanation: String,
123}
124
125/// Full result of a semantic search.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct SemanticSearchResult {
128    /// The parsed query.
129    pub query: SemanticQuery,
130    /// Ranked matches.
131    pub matches: Vec<SemanticMatch>,
132    /// Total candidates scanned.
133    pub candidates_scanned: usize,
134}
135
136// ── SemanticSearchEngine ─────────────────────────────────────────────────────
137
138/// Enhanced semantic search engine wrapping `EmbeddingIndex`.
139pub struct SemanticSearchEngine<'g> {
140    graph: &'g CodeGraph,
141    embedding_index: EmbeddingIndex,
142}
143
144impl<'g> SemanticSearchEngine<'g> {
145    pub fn new(graph: &'g CodeGraph) -> Self {
146        let embedding_index = EmbeddingIndex::build(graph);
147        Self {
148            graph,
149            embedding_index,
150        }
151    }
152
153    /// Parse a natural-language query into a structured `SemanticQuery`.
154    pub fn parse_query(&self, raw: &str) -> SemanticQuery {
155        let intent = QueryIntent::classify(raw);
156        let keywords = extract_keywords(raw);
157        let scope = self.infer_scope(raw);
158
159        SemanticQuery {
160            raw: raw.to_string(),
161            intent,
162            keywords,
163            scope,
164        }
165    }
166
167    /// Perform a semantic search.
168    pub fn search(&self, raw_query: &str, top_k: usize) -> SemanticSearchResult {
169        let query = self.parse_query(raw_query);
170        let candidates_scanned = self.graph.unit_count();
171
172        // Keyword-based scoring across all units
173        let mut scored: Vec<SemanticMatch> = Vec::new();
174
175        for unit in self.graph.units() {
176            // Apply scope filtering
177            match &query.scope {
178                SearchScope::All => {}
179                SearchScope::Module(prefix) => {
180                    if !unit.qualified_name.starts_with(prefix.as_str()) {
181                        continue;
182                    }
183                }
184                SearchScope::File(path) => {
185                    if unit.file_path.display().to_string() != *path {
186                        continue;
187                    }
188                }
189                SearchScope::UnitType(ut) => {
190                    if unit.unit_type != *ut {
191                        continue;
192                    }
193                }
194            }
195
196            // Apply intent filtering
197            let intent_bonus = match query.intent {
198                QueryIntent::FindFunction => {
199                    if unit.unit_type == CodeUnitType::Function {
200                        0.15
201                    } else {
202                        0.0
203                    }
204                }
205                QueryIntent::FindType => {
206                    if unit.unit_type == CodeUnitType::Type {
207                        0.15
208                    } else {
209                        0.0
210                    }
211                }
212                QueryIntent::FindTest => {
213                    if unit.unit_type == CodeUnitType::Test {
214                        0.15
215                    } else {
216                        0.0
217                    }
218                }
219                _ => 0.0,
220            };
221
222            // Keyword scoring
223            let name_lower = unit.name.to_lowercase();
224            let qname_lower = unit.qualified_name.to_lowercase();
225
226            let mut keyword_score: f64 = 0.0;
227            let mut matched_keywords = Vec::new();
228
229            for kw in &query.keywords {
230                if name_lower == *kw {
231                    keyword_score += 0.5;
232                    matched_keywords.push(format!("exact name match '{}'", kw));
233                } else if name_lower.contains(kw.as_str()) {
234                    keyword_score += 0.3;
235                    matched_keywords.push(format!("name contains '{}'", kw));
236                } else if qname_lower.contains(kw.as_str()) {
237                    keyword_score += 0.15;
238                    matched_keywords.push(format!("qualified name contains '{}'", kw));
239                }
240            }
241
242            let total_score = (keyword_score + intent_bonus).min(1.0_f64);
243
244            if total_score > 0.1 {
245                let explanation = if matched_keywords.is_empty() {
246                    format!("Intent match: {}", query.intent.label())
247                } else {
248                    matched_keywords.join("; ")
249                };
250
251                scored.push(SemanticMatch {
252                    unit_id: unit.id,
253                    name: unit.name.clone(),
254                    qualified_name: unit.qualified_name.clone(),
255                    unit_type: unit.unit_type.label().to_string(),
256                    file_path: unit.file_path.display().to_string(),
257                    relevance: total_score,
258                    explanation,
259                });
260            }
261        }
262
263        // Sort by relevance descending
264        scored.sort_by(|a, b| {
265            b.relevance
266                .partial_cmp(&a.relevance)
267                .unwrap_or(std::cmp::Ordering::Equal)
268        });
269        scored.truncate(top_k);
270
271        SemanticSearchResult {
272            query,
273            matches: scored,
274            candidates_scanned,
275        }
276    }
277
278    /// Find units similar to a given unit by embedding similarity.
279    pub fn find_similar(&self, unit_id: u64, top_k: usize) -> Vec<SemanticMatch> {
280        let unit = match self.graph.get_unit(unit_id) {
281            Some(u) => u,
282            None => return Vec::new(),
283        };
284
285        let embedding_matches: Vec<EmbeddingMatch> =
286            self.embedding_index
287                .search(&unit.feature_vec, top_k + 1, 0.0);
288
289        embedding_matches
290            .into_iter()
291            .filter(|m| m.unit_id != unit_id)
292            .take(top_k)
293            .filter_map(|m| {
294                self.graph.get_unit(m.unit_id).map(|u| SemanticMatch {
295                    unit_id: u.id,
296                    name: u.name.clone(),
297                    qualified_name: u.qualified_name.clone(),
298                    unit_type: u.unit_type.label().to_string(),
299                    file_path: u.file_path.display().to_string(),
300                    relevance: m.score as f64,
301                    explanation: format!("Embedding similarity: {:.3}", m.score),
302                })
303            })
304            .collect()
305    }
306
307    /// Explain why a unit matched a query.
308    pub fn explain_match(&self, unit_id: u64, raw_query: &str) -> Option<String> {
309        let unit = self.graph.get_unit(unit_id)?;
310        let query = self.parse_query(raw_query);
311
312        let mut reasons = Vec::new();
313
314        for kw in &query.keywords {
315            let name_lower = unit.name.to_lowercase();
316            if name_lower.contains(kw.as_str()) {
317                reasons.push(format!("Name contains keyword '{}'", kw));
318            }
319            let qname_lower = unit.qualified_name.to_lowercase();
320            if qname_lower.contains(kw.as_str()) && !name_lower.contains(kw.as_str()) {
321                reasons.push(format!("Qualified name contains keyword '{}'", kw));
322            }
323        }
324
325        match query.intent {
326            QueryIntent::FindFunction if unit.unit_type == CodeUnitType::Function => {
327                reasons.push("Matches intent: looking for functions".to_string());
328            }
329            QueryIntent::FindType if unit.unit_type == CodeUnitType::Type => {
330                reasons.push("Matches intent: looking for types".to_string());
331            }
332            QueryIntent::FindTest if unit.unit_type == CodeUnitType::Test => {
333                reasons.push("Matches intent: looking for tests".to_string());
334            }
335            _ => {}
336        }
337
338        if reasons.is_empty() {
339            Some("No direct match found".to_string())
340        } else {
341            Some(reasons.join("; "))
342        }
343    }
344
345    // ── Internal ─────────────────────────────────────────────────────────
346
347    fn infer_scope(&self, query: &str) -> SearchScope {
348        let q = query.to_lowercase();
349        // Check for explicit file references
350        if q.contains(".rs") || q.contains(".py") || q.contains(".ts") || q.contains(".js") {
351            // Try to extract a file path
352            for word in query.split_whitespace() {
353                if word.contains('.') && !word.starts_with('.') {
354                    return SearchScope::File(word.to_string());
355                }
356            }
357        }
358        // Check for module references
359        if q.contains("in module ") || q.contains("in mod ") {
360            if let Some(rest) = q
361                .split("in module ")
362                .nth(1)
363                .or_else(|| q.split("in mod ").nth(1))
364            {
365                let module = rest.split_whitespace().next().unwrap_or("");
366                if !module.is_empty() {
367                    return SearchScope::Module(module.to_string());
368                }
369            }
370        }
371        SearchScope::All
372    }
373}
374
375// ── Helpers ──────────────────────────────────────────────────────────────────
376
/// Extract meaningful keywords from a query string.
///
/// The query is lowercased and split on every character that is neither
/// alphanumeric nor `_`, so `process_payment` stays one token. The following
/// are then dropped:
/// - tokens shorter than two bytes;
/// - stop words;
/// - repeats (the first occurrence keeps its position).
///
/// The stop words cover:
/// - English filler and pronouns;
/// - intent vocabulary, singular and plural;
/// - scope markers such as "module"/"mod";
/// - bare source-file extensions left over when a path like `billing.rs` is
///   split.
///
/// Without these, fragments like "me" or "rs" would substring-match unrelated
/// unit names.
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        // English filler.
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
        "had", "do", "does", "did", "will", "would", "could", "should", "may", "might",
        "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on",
        "with", "at", "by", "from", "as", "into", "through", "during", "before", "after",
        "above", "below", "between", "out", "off", "over", "under", "again", "further",
        "then", "once", "here", "there", "when", "where", "why", "how", "all", "each",
        "every", "both", "few", "more", "most", "other", "some", "such", "no", "nor", "not",
        "only", "own", "same", "so", "than", "too", "very", "just", "because", "but", "and",
        "or", "if", "while", "that", "this", "what", "which", "who", "whom",
        // Pronouns ("show me ...") that would otherwise match as 2-letter substrings.
        "me", "my", "it", "its", "we", "us", "our", "you", "your",
        // Search verbs.
        "find", "search", "look", "show", "get",
        // Intent vocabulary, singular and plural.
        "function", "functions", "method", "methods", "fn", "def",
        "type", "types", "struct", "structs", "class", "classes", "enum", "enums",
        "interface", "interfaces",
        "test", "tests", "testing", "spec", "specs",
        "usage", "usages", "call", "calls", "called", "caller", "callers", "uses",
        "implement", "implements", "implemented", "implementation",
        // Scope markers ("in module x") and extensions split off file names.
        "module", "mod", "rs", "py", "ts", "js",
    ];

    let stop_set: std::collections::HashSet<&str> = STOP_WORDS.iter().copied().collect();
    let lowered = query.to_lowercase();
    let mut seen = std::collections::HashSet::new();

    lowered
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        .filter(|word| word.len() >= 2 && !stop_set.contains(word))
        // Keep only the first occurrence so repeated words don't inflate scores.
        .filter(|word| seen.insert(*word))
        .map(str::to_string)
        .collect()
}
498
499// ── Tests ────────────────────────────────────────────────────────────────────
500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Fixture graph with three Rust units: a billing function, a billing type,
    /// and a test that lives in a separate file.
    fn test_graph() -> CodeGraph {
        let mut graph = CodeGraph::with_default_dimension();
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "process_payment".to_string(),
            "billing::process_payment".to_string(),
            PathBuf::from("src/billing.rs"),
            Span::new(1, 0, 20, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Type,
            Language::Rust,
            "PaymentResult".to_string(),
            "billing::PaymentResult".to_string(),
            PathBuf::from("src/billing.rs"),
            Span::new(21, 0, 30, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Test,
            Language::Rust,
            "test_payment".to_string(),
            "tests::test_payment".to_string(),
            PathBuf::from("tests/billing_test.rs"),
            Span::new(1, 0, 15, 0),
        ));
        graph
    }

    #[test]
    fn classify_intent() {
        assert_eq!(
            QueryIntent::classify("find function process_payment"),
            QueryIntent::FindFunction
        );
        assert_eq!(
            QueryIntent::classify("show me the struct User"),
            QueryIntent::FindType
        );
        assert_eq!(
            QueryIntent::classify("test for payment"),
            QueryIntent::FindTest
        );
        assert_eq!(
            QueryIntent::classify("who uses process_payment"),
            QueryIntent::FindUsage
        );
        assert_eq!(
            QueryIntent::classify("how does billing implement refunds"),
            QueryIntent::FindImplementation
        );
        assert_eq!(
            QueryIntent::classify("payment processing"),
            QueryIntent::General
        );
    }

    #[test]
    fn intent_labels() {
        assert_eq!(QueryIntent::FindFunction.label(), "find_function");
        assert_eq!(QueryIntent::FindUsage.label(), "find_usage");
        assert_eq!(QueryIntent::General.label(), "general");
    }

    #[test]
    fn keyword_search() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let result = engine.search("payment", 10);
        assert_eq!(result.candidates_scanned, 3);
        let names: Vec<&str> = result.matches.iter().map(|m| m.name.as_str()).collect();
        assert!(names.contains(&"process_payment"));
        assert!(names.contains(&"PaymentResult"));
    }

    #[test]
    fn search_respects_top_k() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        assert_eq!(engine.search("payment", 1).matches.len(), 1);
    }

    #[test]
    fn intent_boosts_correct_type() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let result = engine.search("function payment", 10);
        // Function intent must rank process_payment above PaymentResult.
        assert!(result.matches.len() >= 2, "expected several payment matches");
        assert_eq!(result.matches[0].name, "process_payment");
        assert_eq!(result.matches[0].unit_type, "function");
    }

    #[test]
    fn file_scope_restricts_results() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let result = engine.search("payment in src/billing.rs", 10);
        assert!(matches!(
            &result.query.scope,
            SearchScope::File(p) if p == "src/billing.rs"
        ));
        assert_eq!(result.matches.len(), 2);
        assert!(result
            .matches
            .iter()
            .all(|m| m.file_path == "src/billing.rs"));
    }

    #[test]
    fn module_scope_restricts_results() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let result = engine.search("payment in module billing", 10);
        assert!(matches!(
            &result.query.scope,
            SearchScope::Module(m) if m == "billing"
        ));
        assert_eq!(result.matches.len(), 2);
        assert!(result
            .matches
            .iter()
            .all(|m| m.qualified_name.starts_with("billing::")));
    }

    #[test]
    fn explain_match_works() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let explanation = engine.explain_match(0, "payment");
        assert!(explanation.is_some());
        assert!(explanation.unwrap().contains("payment"));
    }

    #[test]
    fn unknown_unit_ids_yield_nothing() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        assert!(engine.explain_match(999, "payment").is_none());
        assert!(engine.find_similar(999, 5).is_empty());
    }
}