1use serde::{Deserialize, Serialize};
8
9use crate::graph::CodeGraph;
10use crate::index::embedding_index::{EmbeddingIndex, EmbeddingMatch};
11use crate::types::CodeUnitType;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
17pub enum QueryIntent {
18 FindFunction,
20 FindType,
22 FindUsage,
24 FindImplementation,
26 FindTest,
28 General,
30}
31
32impl QueryIntent {
33 pub fn classify(query: &str) -> Self {
35 let q = query.to_lowercase();
36 if q.contains("test") || q.contains("spec") || q.starts_with("how is") {
37 return Self::FindTest;
38 }
39 if q.contains("function")
40 || q.contains("method")
41 || q.contains("fn ")
42 || q.starts_with("def ")
43 {
44 return Self::FindFunction;
45 }
46 if q.contains("type")
47 || q.contains("struct")
48 || q.contains("class")
49 || q.contains("enum")
50 || q.contains("interface")
51 {
52 return Self::FindType;
53 }
54 if q.contains("usage")
55 || q.contains("call")
56 || q.contains("who uses")
57 || q.contains("where is")
58 {
59 return Self::FindUsage;
60 }
61 if q.contains("implement") || q.contains("how does") || q.contains("logic for") {
62 return Self::FindImplementation;
63 }
64 Self::General
65 }
66
67 pub fn label(&self) -> &str {
69 match self {
70 Self::FindFunction => "find_function",
71 Self::FindType => "find_type",
72 Self::FindUsage => "find_usage",
73 Self::FindImplementation => "find_implementation",
74 Self::FindTest => "find_test",
75 Self::General => "general",
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum SearchScope {
83 All,
85 Module(String),
87 File(String),
89 UnitType(CodeUnitType),
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct SemanticQuery {
96 pub raw: String,
98 pub intent: QueryIntent,
100 pub keywords: Vec<String>,
102 pub scope: SearchScope,
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct SemanticMatch {
109 pub unit_id: u64,
111 pub name: String,
113 pub qualified_name: String,
115 pub unit_type: String,
117 pub file_path: String,
119 pub relevance: f64,
121 pub explanation: String,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct SemanticSearchResult {
128 pub query: SemanticQuery,
130 pub matches: Vec<SemanticMatch>,
132 pub candidates_scanned: usize,
134}
135
136pub struct SemanticSearchEngine<'g> {
140 graph: &'g CodeGraph,
141 embedding_index: EmbeddingIndex,
142}
143
144impl<'g> SemanticSearchEngine<'g> {
145 pub fn new(graph: &'g CodeGraph) -> Self {
146 let embedding_index = EmbeddingIndex::build(graph);
147 Self {
148 graph,
149 embedding_index,
150 }
151 }
152
153 pub fn parse_query(&self, raw: &str) -> SemanticQuery {
155 let intent = QueryIntent::classify(raw);
156 let keywords = extract_keywords(raw);
157 let scope = self.infer_scope(raw);
158
159 SemanticQuery {
160 raw: raw.to_string(),
161 intent,
162 keywords,
163 scope,
164 }
165 }
166
167 pub fn search(&self, raw_query: &str, top_k: usize) -> SemanticSearchResult {
169 let query = self.parse_query(raw_query);
170 let candidates_scanned = self.graph.unit_count();
171
172 let mut scored: Vec<SemanticMatch> = Vec::new();
174
175 for unit in self.graph.units() {
176 match &query.scope {
178 SearchScope::All => {}
179 SearchScope::Module(prefix) => {
180 if !unit.qualified_name.starts_with(prefix.as_str()) {
181 continue;
182 }
183 }
184 SearchScope::File(path) => {
185 if unit.file_path.display().to_string() != *path {
186 continue;
187 }
188 }
189 SearchScope::UnitType(ut) => {
190 if unit.unit_type != *ut {
191 continue;
192 }
193 }
194 }
195
196 let intent_bonus = match query.intent {
198 QueryIntent::FindFunction => {
199 if unit.unit_type == CodeUnitType::Function {
200 0.15
201 } else {
202 0.0
203 }
204 }
205 QueryIntent::FindType => {
206 if unit.unit_type == CodeUnitType::Type {
207 0.15
208 } else {
209 0.0
210 }
211 }
212 QueryIntent::FindTest => {
213 if unit.unit_type == CodeUnitType::Test {
214 0.15
215 } else {
216 0.0
217 }
218 }
219 _ => 0.0,
220 };
221
222 let name_lower = unit.name.to_lowercase();
224 let qname_lower = unit.qualified_name.to_lowercase();
225
226 let mut keyword_score: f64 = 0.0;
227 let mut matched_keywords = Vec::new();
228
229 for kw in &query.keywords {
230 if name_lower == *kw {
231 keyword_score += 0.5;
232 matched_keywords.push(format!("exact name match '{}'", kw));
233 } else if name_lower.contains(kw.as_str()) {
234 keyword_score += 0.3;
235 matched_keywords.push(format!("name contains '{}'", kw));
236 } else if qname_lower.contains(kw.as_str()) {
237 keyword_score += 0.15;
238 matched_keywords.push(format!("qualified name contains '{}'", kw));
239 }
240 }
241
242 let total_score = (keyword_score + intent_bonus).min(1.0_f64);
243
244 if total_score > 0.1 {
245 let explanation = if matched_keywords.is_empty() {
246 format!("Intent match: {}", query.intent.label())
247 } else {
248 matched_keywords.join("; ")
249 };
250
251 scored.push(SemanticMatch {
252 unit_id: unit.id,
253 name: unit.name.clone(),
254 qualified_name: unit.qualified_name.clone(),
255 unit_type: unit.unit_type.label().to_string(),
256 file_path: unit.file_path.display().to_string(),
257 relevance: total_score,
258 explanation,
259 });
260 }
261 }
262
263 scored.sort_by(|a, b| {
265 b.relevance
266 .partial_cmp(&a.relevance)
267 .unwrap_or(std::cmp::Ordering::Equal)
268 });
269 scored.truncate(top_k);
270
271 SemanticSearchResult {
272 query,
273 matches: scored,
274 candidates_scanned,
275 }
276 }
277
278 pub fn find_similar(&self, unit_id: u64, top_k: usize) -> Vec<SemanticMatch> {
280 let unit = match self.graph.get_unit(unit_id) {
281 Some(u) => u,
282 None => return Vec::new(),
283 };
284
285 let embedding_matches: Vec<EmbeddingMatch> =
286 self.embedding_index
287 .search(&unit.feature_vec, top_k + 1, 0.0);
288
289 embedding_matches
290 .into_iter()
291 .filter(|m| m.unit_id != unit_id)
292 .take(top_k)
293 .filter_map(|m| {
294 self.graph.get_unit(m.unit_id).map(|u| SemanticMatch {
295 unit_id: u.id,
296 name: u.name.clone(),
297 qualified_name: u.qualified_name.clone(),
298 unit_type: u.unit_type.label().to_string(),
299 file_path: u.file_path.display().to_string(),
300 relevance: m.score as f64,
301 explanation: format!("Embedding similarity: {:.3}", m.score),
302 })
303 })
304 .collect()
305 }
306
307 pub fn explain_match(&self, unit_id: u64, raw_query: &str) -> Option<String> {
309 let unit = self.graph.get_unit(unit_id)?;
310 let query = self.parse_query(raw_query);
311
312 let mut reasons = Vec::new();
313
314 for kw in &query.keywords {
315 let name_lower = unit.name.to_lowercase();
316 if name_lower.contains(kw.as_str()) {
317 reasons.push(format!("Name contains keyword '{}'", kw));
318 }
319 let qname_lower = unit.qualified_name.to_lowercase();
320 if qname_lower.contains(kw.as_str()) && !name_lower.contains(kw.as_str()) {
321 reasons.push(format!("Qualified name contains keyword '{}'", kw));
322 }
323 }
324
325 match query.intent {
326 QueryIntent::FindFunction if unit.unit_type == CodeUnitType::Function => {
327 reasons.push("Matches intent: looking for functions".to_string());
328 }
329 QueryIntent::FindType if unit.unit_type == CodeUnitType::Type => {
330 reasons.push("Matches intent: looking for types".to_string());
331 }
332 QueryIntent::FindTest if unit.unit_type == CodeUnitType::Test => {
333 reasons.push("Matches intent: looking for tests".to_string());
334 }
335 _ => {}
336 }
337
338 if reasons.is_empty() {
339 Some("No direct match found".to_string())
340 } else {
341 Some(reasons.join("; "))
342 }
343 }
344
345 fn infer_scope(&self, query: &str) -> SearchScope {
348 let q = query.to_lowercase();
349 if q.contains(".rs") || q.contains(".py") || q.contains(".ts") || q.contains(".js") {
351 for word in query.split_whitespace() {
353 if word.contains('.') && !word.starts_with('.') {
354 return SearchScope::File(word.to_string());
355 }
356 }
357 }
358 if q.contains("in module ") || q.contains("in mod ") {
360 if let Some(rest) = q
361 .split("in module ")
362 .nth(1)
363 .or_else(|| q.split("in mod ").nth(1))
364 {
365 let module = rest.split_whitespace().next().unwrap_or("");
366 if !module.is_empty() {
367 return SearchScope::Module(module.to_string());
368 }
369 }
370 }
371 SearchScope::All
372 }
373}
374
/// Extracts the searchable keywords from a query.
///
/// The query is lowercased and split on every character that is neither
/// alphanumeric nor `_`. A token is kept when it is at least two bytes long
/// and is not a stop word (common English function words plus the query
/// vocabulary such as "find", "function" or "test").
///
/// Each distinct keyword is returned once, in order of first appearance, so a
/// word repeated in the query cannot be scored several times by the caller.
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
        "have", "has", "had", "do", "does", "did", "will", "would", "could", "should",
        "may", "might", "shall", "can", "need", "dare", "ought", "used", "to", "of",
        "in", "for", "on", "with", "at", "by", "from", "as", "into", "through",
        "during", "before", "after", "above", "below", "between", "out", "off", "over", "under",
        "again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
        "all", "each", "every", "both", "few", "more", "most", "other", "some", "such",
        "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very",
        "just", "because", "but", "and", "or", "if", "while", "that", "this", "what",
        "which", "who", "whom", "find", "search", "look", "show", "get", "function", "method",
        "type", "struct", "class", "enum", "test", "usage", "implement", "call",
    ];

    // Bound before `seen` so the borrowed tokens outlive the set that holds them.
    let lowered = query.to_lowercase();
    let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();

    lowered
        .split(|c: char| !c.is_alphanumeric() && c != '_')
        // `insert` returns false for a token already seen, which drops duplicates
        // while preserving first-occurrence order.
        .filter(|w| w.len() >= 2 && !STOP_WORDS.contains(w) && seen.insert(*w))
        .map(|w| w.to_string())
        .collect()
}
498
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Builds a small fixture graph with one function, one type and one test,
    /// all of which have "payment" in their names.
    fn test_graph() -> CodeGraph {
        let mut graph = CodeGraph::with_default_dimension();
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "process_payment".to_string(),
            "billing::process_payment".to_string(),
            PathBuf::from("src/billing.rs"),
            Span::new(1, 0, 20, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Type,
            Language::Rust,
            "PaymentResult".to_string(),
            "billing::PaymentResult".to_string(),
            PathBuf::from("src/billing.rs"),
            Span::new(21, 0, 30, 0),
        ));
        graph.add_unit(CodeUnit::new(
            CodeUnitType::Test,
            Language::Rust,
            "test_payment".to_string(),
            "tests::test_payment".to_string(),
            PathBuf::from("tests/billing_test.rs"),
            Span::new(1, 0, 15, 0),
        ));
        graph
    }

    #[test]
    fn classify_intent() {
        assert_eq!(
            QueryIntent::classify("find function process_payment"),
            QueryIntent::FindFunction
        );
        assert_eq!(
            QueryIntent::classify("show me the struct User"),
            QueryIntent::FindType
        );
        assert_eq!(
            QueryIntent::classify("test for payment"),
            QueryIntent::FindTest
        );
        assert_eq!(
            QueryIntent::classify("who uses PaymentResult"),
            QueryIntent::FindUsage
        );
        assert_eq!(
            QueryIntent::classify("how does billing work"),
            QueryIntent::FindImplementation
        );
        assert_eq!(
            QueryIntent::classify("payment processing"),
            QueryIntent::General
        );
    }

    #[test]
    fn keyword_search() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let result = engine.search("payment", 10);
        assert!(result.matches.len() >= 2);
    }

    #[test]
    fn intent_boosts_correct_type() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let result = engine.search("function payment", 10);
        // Assert unconditionally: every fixture unit contains "payment", so an
        // empty or single-element result would itself be a failure.
        assert!(
            result.matches.len() >= 2,
            "expected at least two matches, got {}",
            result.matches.len()
        );
        assert_eq!(result.matches[0].unit_type, "function");
    }

    #[test]
    fn explain_match_works() {
        let graph = test_graph();
        let engine = SemanticSearchEngine::new(&graph);
        let explanation = engine.explain_match(0, "payment");
        assert!(explanation.is_some());
        assert!(explanation.unwrap().contains("payment"));
    }

    #[test]
    fn extract_keywords_drops_stop_words() {
        assert_eq!(
            extract_keywords("find the function process_payment"),
            vec!["process_payment"]
        );
        assert!(extract_keywords("").is_empty());
    }
}
583}