1use crate::graph::CodeGraph;
6
7use super::{Evidence, Grounded, GroundingResult};
8
/// Common English function words (articles, auxiliaries, prepositions, pronouns, ...)
/// that are never reported as code references, whether they appear as bare tokens or
/// inside backticks. Every entry is lowercase ASCII; lookups go through `is_stop_word`.
const STOP_WORDS: &[&str] = &[
    "the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
    "have", "has", "had", "do", "does", "did", "will", "would", "shall", "should",
    "may", "might", "must", "can", "could", "to", "of", "in", "for", "on",
    "with", "at", "by", "from", "as", "into", "about", "between", "through", "during",
    "before", "after", "above", "below", "up", "down", "out", "off", "over", "under",
    "again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
    "all", "each", "every", "both", "few", "more", "most", "other", "some", "such",
    "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very",
    "just", "because", "but", "and", "or", "if", "while", "that", "this", "these",
    "those", "it", "its", "my", "your", "his", "her", "our", "their", "what",
    "which", "who", "whom", "we", "you", "he", "she", "they", "me", "him",
    "us", "them", "i",
];
22
/// Returns `true` if `s` is a `snake_case` identifier.
///
/// Requirements: starts with an ASCII lowercase letter, uses only ASCII lowercase
/// letters, digits and `_`, and contains at least one underscore with no empty piece
/// between underscores (so no leading, trailing or doubled `_`). A single word such as
/// `process` is rejected because it is indistinguishable from ordinary prose.
fn is_snake_case(s: &str) -> bool {
    let starts_lower = s.starts_with(|c: char| c.is_ascii_lowercase());
    let only_allowed_chars = s.chars().all(|c| matches!(c, 'a'..='z' | '0'..='9' | '_'));
    // Splitting on `_` yields an empty piece exactly when an underscore is leading,
    // trailing, or doubled.
    let well_delimited = s.contains('_') && s.split('_').all(|piece| !piece.is_empty());

    starts_lower && only_allowed_chars && well_delimited
}
60
/// Returns `true` if `s` looks like an upper-first camel-case (PascalCase) identifier
/// such as `CodeGraph`.
///
/// Requirements: starts with an ASCII uppercase letter, consists only of ASCII letters
/// and digits, and has at least two uppercase letters and at least one lowercase letter.
/// Single-hump words (`Code`) and all-caps words (`CODEGRAPH`) are rejected.
fn is_camel_case(s: &str) -> bool {
    if !s.starts_with(|c: char| c.is_ascii_uppercase()) {
        return false;
    }

    let mut uppercase_seen = 0usize;
    let mut lowercase_seen = false;
    for c in s.chars() {
        match c {
            'A'..='Z' => uppercase_seen += 1,
            'a'..='z' => lowercase_seen = true,
            '0'..='9' => {}
            _ => return false,
        }
    }

    // Two uppercase letters plus a lowercase one also implies a length of at least 3.
    uppercase_seen >= 2 && lowercase_seen
}
90
/// Returns `true` if `s` is a `SCREAMING_SNAKE_CASE` identifier such as
/// `MAX_EDGES_PER_UNIT`.
///
/// The string must split on `_` into at least two non-empty segments made only of
/// ASCII uppercase letters and digits, and the first segment must begin with an
/// uppercase letter. Leading, trailing or doubled underscores are therefore rejected,
/// as is a single word without any underscore.
fn is_screaming_case(s: &str) -> bool {
    let segments: Vec<&str> = s.split('_').collect();

    // Fewer than two segments means there was no underscore at all.
    if segments.len() < 2 {
        return false;
    }

    let first_is_upper = segments[0].starts_with(|c: char| c.is_ascii_uppercase());
    let segments_valid = segments.iter().all(|segment| {
        !segment.is_empty()
            && segment
                .chars()
                .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
    });

    first_is_upper && segments_valid
}
125
126fn is_stop_word(word: &str) -> bool {
128 STOP_WORDS.contains(&word.to_lowercase().as_str())
129}
130
131pub fn extract_code_references(claim: &str) -> Vec<String> {
143 let mut refs: Vec<String> = Vec::new();
144
145 let mut in_backtick = false;
147 let mut buf = String::new();
148 for ch in claim.chars() {
149 if ch == '`' {
150 if in_backtick {
151 let trimmed = buf.trim().to_string();
152 if !trimmed.is_empty() && !is_stop_word(&trimmed) {
153 refs.push(trimmed);
154 }
155 buf.clear();
156 }
157 in_backtick = !in_backtick;
158 } else if in_backtick {
159 buf.push(ch);
160 }
161 }
162
163 let tokens: Vec<&str> = claim
166 .split(|c: char| !c.is_ascii_alphanumeric() && c != '_')
167 .filter(|t| !t.is_empty())
168 .collect();
169
170 for token in &tokens {
171 if is_stop_word(token) {
172 continue;
173 }
174 if is_snake_case(token) || is_camel_case(token) || is_screaming_case(token) {
175 let s = (*token).to_string();
176 if !refs.contains(&s) {
177 refs.push(s);
178 }
179 }
180 }
181
182 refs
183}
184
/// Computes the Levenshtein edit distance between `a` and `b`, counted in `char`s
/// (insertions, deletions and substitutions each cost 1).
///
/// Keeps a single rolling row of the dynamic-programming table, so extra memory is
/// proportional to the length of `b`. Empty inputs need no special case: the initial
/// row already encodes the distance from an empty prefix.
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // row[j] = distance between the processed prefix of `source` and target[..j].
    let mut row: Vec<usize> = (0..=target.len()).collect();

    for (i, &s) in source.iter().enumerate() {
        // `diagonal` holds the previous row's value at column j before it is overwritten.
        let mut diagonal = row[0];
        row[0] = i + 1;
        for (j, &t) in target.iter().enumerate() {
            let above = row[j + 1];
            let substitute = diagonal + usize::from(s != t);
            let delete = above + 1;
            let insert = row[j] + 1;
            row[j + 1] = substitute.min(delete).min(insert);
            diagonal = above;
        }
    }

    row[target.len()]
}
225
226pub struct GroundingEngine<'g> {
246 graph: &'g CodeGraph,
247}
248
249impl<'g> GroundingEngine<'g> {
250 pub fn new(graph: &'g CodeGraph) -> Self {
252 Self { graph }
253 }
254
255 fn evidence_from_unit(unit: &crate::types::CodeUnit) -> Evidence {
257 Evidence {
258 node_id: unit.id,
259 node_type: unit.unit_type.label().to_string(),
260 name: unit.name.clone(),
261 file_path: unit.file_path.display().to_string(),
262 line_number: Some(unit.span.start_line),
263 snippet: unit.signature.clone(),
264 }
265 }
266}
267
268impl<'g> Grounded for GroundingEngine<'g> {
269 fn ground_claim(&self, claim: &str) -> GroundingResult {
270 let refs = extract_code_references(claim);
271
272 if refs.is_empty() {
274 return GroundingResult::Ungrounded {
275 claim: claim.to_string(),
276 suggestions: Vec::new(),
277 };
278 }
279
280 let mut all_evidence: Vec<Evidence> = Vec::new();
281 let mut supported: Vec<String> = Vec::new();
282 let mut unsupported: Vec<String> = Vec::new();
283
284 for reference in &refs {
285 let evidence = self.find_evidence(reference);
286 if evidence.is_empty() {
287 unsupported.push(reference.clone());
288 } else {
289 supported.push(reference.clone());
290 all_evidence.extend(evidence);
291 }
292 }
293
294 if unsupported.is_empty() {
295 let confidence = 1.0_f32; GroundingResult::Verified {
298 evidence: all_evidence,
299 confidence,
300 }
301 } else if supported.is_empty() {
302 let mut suggestions: Vec<String> = Vec::new();
304 for u in &unsupported {
305 suggestions.extend(self.suggest_similar(u, 3));
306 }
307 suggestions.sort();
309 suggestions.dedup();
310 GroundingResult::Ungrounded {
311 claim: claim.to_string(),
312 suggestions,
313 }
314 } else {
315 let mut suggestions: Vec<String> = Vec::new();
317 for u in &unsupported {
318 suggestions.extend(self.suggest_similar(u, 3));
319 }
320 suggestions.sort();
321 suggestions.dedup();
322 GroundingResult::Partial {
323 supported,
324 unsupported,
325 suggestions,
326 }
327 }
328 }
329
330 fn find_evidence(&self, name: &str) -> Vec<Evidence> {
331 let mut results: Vec<Evidence> = Vec::new();
332
333 for unit in self.graph.units() {
335 if unit.name == name {
336 results.push(Self::evidence_from_unit(unit));
337 }
338 }
339
340 if results.is_empty() {
342 for unit in self.graph.units() {
343 if unit.qualified_name.contains(name) {
344 results.push(Self::evidence_from_unit(unit));
345 }
346 }
347 }
348
349 if results.is_empty() {
351 let lower = name.to_lowercase();
352 for unit in self.graph.units() {
353 if unit.name.to_lowercase() == lower {
354 results.push(Self::evidence_from_unit(unit));
355 }
356 }
357 }
358
359 results
360 }
361
362 fn suggest_similar(&self, name: &str, limit: usize) -> Vec<String> {
363 let lower = name.to_lowercase();
364 let threshold = name.len() / 2;
365
366 let mut candidates: Vec<(String, usize)> = Vec::new();
367
368 for unit in self.graph.units() {
369 let unit_lower = unit.name.to_lowercase();
370
371 if unit_lower.starts_with(&lower) || lower.starts_with(&unit_lower) {
373 if !candidates.iter().any(|(n, _)| *n == unit.name) {
374 candidates.push((unit.name.clone(), 0));
375 }
376 continue;
377 }
378
379 let dist = levenshtein(&lower, &unit_lower);
381 if dist <= threshold && dist > 0 && !candidates.iter().any(|(n, _)| *n == unit.name) {
382 candidates.push((unit.name.clone(), dist));
383 }
384 }
385
386 candidates.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
388
389 candidates
390 .into_iter()
391 .take(limit)
392 .map(|(name, _)| name)
393 .collect()
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Builds a `CodeUnit` from string slices so the fixture table stays compact.
    fn unit(
        unit_type: CodeUnitType,
        language: Language,
        name: &str,
        qualified_name: &str,
        file: &str,
        span: Span,
    ) -> CodeUnit {
        CodeUnit::new(
            unit_type,
            language,
            name.to_string(),
            qualified_name.to_string(),
            PathBuf::from(file),
            span,
        )
    }

    /// Fixture graph with five units: Python functions `process_payment` and
    /// `validate_amount`, and Rust units `CodeGraph`, `add_unit` and
    /// `MAX_EDGES_PER_UNIT`, inserted in that order:
    /// process_payment, CodeGraph, add_unit, MAX_EDGES_PER_UNIT, validate_amount.
    fn test_graph() -> CodeGraph {
        let mut graph = CodeGraph::with_default_dimension();

        let fixtures = vec![
            unit(
                CodeUnitType::Function,
                Language::Python,
                "process_payment",
                "payments.stripe.process_payment",
                "src/payments/stripe.py",
                Span::new(10, 0, 30, 0),
            ),
            unit(
                CodeUnitType::Type,
                Language::Rust,
                "CodeGraph",
                "crate::graph::CodeGraph",
                "src/graph/code_graph.rs",
                Span::new(17, 0, 250, 0),
            ),
            unit(
                CodeUnitType::Function,
                Language::Rust,
                "add_unit",
                "crate::graph::CodeGraph::add_unit",
                "src/graph/code_graph.rs",
                Span::new(58, 0, 64, 0),
            ),
            unit(
                CodeUnitType::Config,
                Language::Rust,
                "MAX_EDGES_PER_UNIT",
                "crate::types::MAX_EDGES_PER_UNIT",
                "src/types/mod.rs",
                Span::new(40, 0, 40, 0),
            ),
            unit(
                CodeUnitType::Function,
                Language::Python,
                "validate_amount",
                "payments.utils.validate_amount",
                "src/payments/utils.py",
                Span::new(5, 0, 15, 0),
            ),
        ];

        for code_unit in fixtures {
            graph.add_unit(code_unit);
        }
        graph
    }

    /// True when `refs` contains `expected`.
    fn mentions(refs: &[String], expected: &str) -> bool {
        refs.iter().any(|r| r == expected)
    }

    // --- extract_code_references -------------------------------------------------

    #[test]
    fn extract_snake_case_refs() {
        let refs = extract_code_references("The process_payment function validates the amount");
        assert!(mentions(&refs, "process_payment"));
    }

    #[test]
    fn extract_camel_case_refs() {
        let refs = extract_code_references("The CodeGraph struct holds all units");
        assert!(mentions(&refs, "CodeGraph"));
    }

    #[test]
    fn extract_screaming_case_refs() {
        let refs = extract_code_references("The constant MAX_EDGES_PER_UNIT limits the edge count");
        assert!(mentions(&refs, "MAX_EDGES_PER_UNIT"));
    }

    #[test]
    fn extract_backtick_refs() {
        let refs = extract_code_references("Call `add_unit` to insert a node");
        assert!(mentions(&refs, "add_unit"));
    }

    #[test]
    fn extract_mixed_refs() {
        let refs = extract_code_references(
            "The `process_payment` function in CodeGraph uses MAX_EDGES_PER_UNIT",
        );
        for expected in &["process_payment", "CodeGraph", "MAX_EDGES_PER_UNIT"] {
            assert!(mentions(&refs, expected), "missing {} in {:?}", expected, refs);
        }
    }

    #[test]
    fn extract_filters_stop_words() {
        assert!(extract_code_references("the is a an in on").is_empty());
    }

    #[test]
    fn extract_no_duplicates() {
        let refs = extract_code_references(
            "`process_payment` calls process_payment to handle the process_payment flow",
        );
        let occurrences = refs.iter().filter(|r| r.as_str() == "process_payment").count();
        assert_eq!(occurrences, 1);
    }

    // --- ground_claim ------------------------------------------------------------

    #[test]
    fn ground_verified_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        match engine.ground_claim("The process_payment function exists") {
            GroundingResult::Verified {
                evidence,
                confidence,
            } => {
                assert!(!evidence.is_empty());
                assert!(confidence > 0.0);
                assert_eq!(evidence[0].name, "process_payment");
            }
            other => panic!("Expected Verified, got {:?}", other),
        }
    }

    #[test]
    fn ground_ungrounded_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        match engine.ground_claim("The send_invoice function sends emails") {
            GroundingResult::Ungrounded { claim, .. } => {
                assert!(claim.contains("send_invoice"));
            }
            other => panic!("Expected Ungrounded, got {:?}", other),
        }
    }

    #[test]
    fn ground_partial_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        match engine.ground_claim("process_payment calls send_notification after success") {
            GroundingResult::Partial {
                supported,
                unsupported,
                ..
            } => {
                assert!(mentions(&supported, "process_payment"));
                assert!(mentions(&unsupported, "send_notification"));
            }
            other => panic!("Expected Partial, got {:?}", other),
        }
    }

    #[test]
    fn ground_no_refs_is_ungrounded() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let result = engine.ground_claim("This is a normal English sentence.");
        assert!(matches!(result, GroundingResult::Ungrounded { .. }));
    }

    // --- find_evidence -----------------------------------------------------------

    #[test]
    fn find_evidence_exact_name() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let evidence = engine.find_evidence("add_unit");
        assert_eq!(evidence.len(), 1);
        let first = &evidence[0];
        assert_eq!(first.name, "add_unit");
        assert_eq!(first.node_type, "function");
    }

    #[test]
    fn find_evidence_qualified_fallback() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        // "stripe" is only a module segment of the qualified name.
        let evidence = engine.find_evidence("stripe");
        assert!(!evidence.is_empty());
        assert_eq!(evidence[0].name, "process_payment");
    }

    #[test]
    fn find_evidence_case_insensitive_fallback() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let evidence = engine.find_evidence("codegraph");
        assert!(!evidence.is_empty());
        assert_eq!(evidence[0].name, "CodeGraph");
    }

    #[test]
    fn find_evidence_nonexistent() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        assert!(engine.find_evidence("nonexistent_function").is_empty());
    }

    // --- suggest_similar ---------------------------------------------------------

    #[test]
    fn suggest_similar_typo() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("process_paymnt", 5);
        assert!(
            mentions(&suggestions, "process_payment"),
            "Expected process_payment in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_prefix() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("add", 5);
        assert!(
            mentions(&suggestions, "add_unit"),
            "Expected add_unit in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_respects_limit() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        assert!(engine.suggest_similar("a", 2).len() <= 2);
    }

    // --- levenshtein -------------------------------------------------------------

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn levenshtein_one_edit() {
        assert_eq!(levenshtein("kitten", "sitten"), 1);
    }

    #[test]
    fn levenshtein_full_diff() {
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn levenshtein_empty() {
        let cases: &[(&str, &str, usize)] = &[("", "hello", 5), ("hello", "", 5), ("", "", 0)];
        for &(a, b, expected) in cases {
            assert_eq!(levenshtein(a, b), expected, "levenshtein({:?}, {:?})", a, b);
        }
    }

    // --- identifier-shape predicates ---------------------------------------------

    #[test]
    fn test_is_snake_case() {
        for accepted in &["process_payment", "add_unit", "a_b"] {
            assert!(is_snake_case(accepted), "{} should be snake_case", accepted);
        }
        for rejected in &[
            "process",
            "ProcessPayment",
            "_leading",
            "trailing_",
            "double__under",
        ] {
            assert!(!is_snake_case(rejected), "{} should not be snake_case", rejected);
        }
    }

    #[test]
    fn test_is_camel_case() {
        for accepted in &["CodeGraph", "GroundingEngine", "MyType2"] {
            assert!(is_camel_case(accepted), "{} should be camel case", accepted);
        }
        for rejected in &["codegraph", "CODEGRAPH", "A", "Code"] {
            assert!(!is_camel_case(rejected), "{} should not be camel case", rejected);
        }
    }

    #[test]
    fn test_is_screaming_case() {
        for accepted in &["MAX_EDGES_PER_UNIT", "API_KEY"] {
            assert!(is_screaming_case(accepted), "{} should be screaming case", accepted);
        }
        for rejected in &["max_edges", "NOUNDERSCORES", "_LEADING", "TRAILING_"] {
            assert!(
                !is_screaming_case(rejected),
                "{} should not be screaming case",
                rejected
            );
        }
    }
}