1use crate::graph::CodeGraph;
6
7use super::{Evidence, Grounded, GroundingResult};
8
/// Common English words that are never treated as code references.
///
/// Every entry is lowercase; lookups are expected to lowercase their input
/// before comparing. The grouping below is only a rough reading aid.
const STOP_WORDS: &[&str] = &[
    // Articles.
    "the", "a", "an",
    // Forms of "to be".
    "is", "are", "was", "were", "be", "been", "being",
    // Forms of "to have" and "to do".
    "have", "has", "had", "do", "does", "did",
    // Modal verbs.
    "will", "would", "shall", "should", "may", "might", "must", "can", "could",
    // Prepositions and particles.
    "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "about",
    "between", "through", "during", "before", "after", "above", "below", "up", "down",
    "out", "off", "over", "under",
    // Adverbs of time, place and manner.
    "again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
    // Quantifiers.
    "all", "each", "every", "both", "few", "more", "most", "other", "some", "such",
    // Negation, comparison and degree.
    "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
    // Conjunctions.
    "because", "but", "and", "or", "if", "while",
    // Demonstratives.
    "that", "this", "these", "those",
    // Pronouns, possessives and interrogatives.
    "it", "its", "my", "your", "his", "her", "our", "their", "what", "which", "who",
    "whom", "we", "you", "he", "she", "they", "me", "him", "us", "them", "i",
];
22
/// Returns `true` when `s` looks like a multi-word `snake_case` identifier.
///
/// The identifier must start with an ASCII lowercase letter, contain at least
/// one underscore, and consist of non-empty segments made only of ASCII
/// lowercase letters and digits. Single words such as `process` are rejected.
fn is_snake_case(s: &str) -> bool {
    // Also rejects the empty string, which has no first character.
    let starts_lowercase = s
        .chars()
        .next()
        .map_or(false, |first| first.is_ascii_lowercase());

    starts_lowercase
        && s.contains('_')
        // An empty segment means a leading, trailing or doubled underscore.
        && s.split('_').all(|segment| {
            !segment.is_empty()
                && segment
                    .chars()
                    .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit())
        })
}
60
/// Returns `true` when `s` looks like a multi-word `CamelCase` identifier.
///
/// The identifier must be ASCII alphanumeric, start with an uppercase letter,
/// contain at least two uppercase letters and at least one lowercase letter.
/// Plain capitalised words (`Code`) and all-caps words (`CODEGRAPH`) are
/// rejected.
fn is_camel_case(s: &str) -> bool {
    // Also rejects the empty string, which has no first character.
    let starts_uppercase = s
        .chars()
        .next()
        .map_or(false, |first| first.is_ascii_uppercase());
    if !starts_uppercase || !s.chars().all(|c| c.is_ascii_alphanumeric()) {
        return false;
    }

    // The first character is uppercase, so any lowercase letter found here
    // necessarily appears after it.
    let uppercase_letters = s.chars().filter(char::is_ascii_uppercase).count();
    let has_lowercase = s.chars().any(|c| c.is_ascii_lowercase());

    uppercase_letters >= 2 && has_lowercase
}
90
/// Returns `true` when `s` looks like a multi-word `SCREAMING_SNAKE_CASE`
/// identifier.
///
/// The identifier must start with an ASCII uppercase letter, contain at least
/// one underscore, and consist of non-empty segments made only of ASCII
/// uppercase letters and digits. Single words such as `NOUNDERSCORES` are
/// rejected.
fn is_screaming_case(s: &str) -> bool {
    // Also rejects the empty string, which has no first character.
    let starts_uppercase = s
        .chars()
        .next()
        .map_or(false, |first| first.is_ascii_uppercase());

    starts_uppercase
        && s.contains('_')
        // An empty segment means a leading, trailing or doubled underscore.
        && s.split('_').all(|segment| {
            !segment.is_empty()
                && segment
                    .chars()
                    .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
        })
}
125
126fn is_stop_word(word: &str) -> bool {
128 STOP_WORDS.contains(&word.to_lowercase().as_str())
129}
130
131pub fn extract_code_references(claim: &str) -> Vec<String> {
143 let mut refs: Vec<String> = Vec::new();
144
145 let mut in_backtick = false;
147 let mut buf = String::new();
148 for ch in claim.chars() {
149 if ch == '`' {
150 if in_backtick {
151 let trimmed = buf.trim().to_string();
152 if !trimmed.is_empty() && !is_stop_word(&trimmed) {
153 refs.push(trimmed);
154 }
155 buf.clear();
156 }
157 in_backtick = !in_backtick;
158 } else if in_backtick {
159 buf.push(ch);
160 }
161 }
162
163 let tokens: Vec<&str> = claim
166 .split(|c: char| !c.is_ascii_alphanumeric() && c != '_')
167 .filter(|t| !t.is_empty())
168 .collect();
169
170 for token in &tokens {
171 if is_stop_word(token) {
172 continue;
173 }
174 if is_snake_case(token) || is_camel_case(token) || is_screaming_case(token) {
175 let s = (*token).to_string();
176 if !refs.contains(&s) {
177 refs.push(s);
178 }
179 }
180 }
181
182 refs
183}
184
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is counted in `char`s (Unicode scalar values), not bytes, and
/// is the minimum number of single-character insertions, deletions and
/// substitutions that turn `a` into `b`. Only two rows of the dynamic
/// programming table are kept in memory.
fn levenshtein(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();

    // Against an empty string the distance is the other string's length.
    if source.is_empty() {
        return target.len();
    }
    if target.is_empty() {
        return source.len();
    }

    let width = target.len();
    // `previous[j]` is the distance between the processed prefix of `source`
    // and the first `j` characters of `target`.
    let mut previous: Vec<usize> = (0..=width).collect();
    let mut current: Vec<usize> = vec![0; width + 1];

    for (row, &source_char) in source.iter().enumerate() {
        current[0] = row + 1;
        for (col, &target_char) in target.iter().enumerate() {
            let deletion = previous[col + 1] + 1;
            let insertion = current[col] + 1;
            let substitution = previous[col] + usize::from(source_char != target_char);
            current[col + 1] = deletion.min(insertion).min(substitution);
        }
        std::mem::swap(&mut previous, &mut current);
    }

    // After the final swap the finished row lives in `previous`.
    previous[width]
}
225
226pub struct GroundingEngine<'g> {
246 graph: &'g CodeGraph,
247}
248
249impl<'g> GroundingEngine<'g> {
250 pub fn new(graph: &'g CodeGraph) -> Self {
252 Self { graph }
253 }
254
255 fn evidence_from_unit(unit: &crate::types::CodeUnit) -> Evidence {
257 Evidence {
258 node_id: unit.id,
259 node_type: unit.unit_type.label().to_string(),
260 name: unit.name.clone(),
261 file_path: unit.file_path.display().to_string(),
262 line_number: Some(unit.span.start_line),
263 snippet: unit.signature.clone(),
264 }
265 }
266}
267
268impl<'g> Grounded for GroundingEngine<'g> {
269 fn ground_claim(&self, claim: &str) -> GroundingResult {
270 let refs = extract_code_references(claim);
271
272 if refs.is_empty() {
274 return GroundingResult::Ungrounded {
275 claim: claim.to_string(),
276 suggestions: Vec::new(),
277 };
278 }
279
280 let mut all_evidence: Vec<Evidence> = Vec::new();
281 let mut supported: Vec<String> = Vec::new();
282 let mut unsupported: Vec<String> = Vec::new();
283
284 for reference in &refs {
285 let evidence = self.find_evidence(reference);
286 if evidence.is_empty() {
287 unsupported.push(reference.clone());
288 } else {
289 supported.push(reference.clone());
290 all_evidence.extend(evidence);
291 }
292 }
293
294 if unsupported.is_empty() {
295 let confidence = 1.0_f32; GroundingResult::Verified {
298 evidence: all_evidence,
299 confidence,
300 }
301 } else if supported.is_empty() {
302 let mut suggestions: Vec<String> = Vec::new();
304 for u in &unsupported {
305 suggestions.extend(self.suggest_similar(u, 3));
306 }
307 suggestions.sort();
309 suggestions.dedup();
310 GroundingResult::Ungrounded {
311 claim: claim.to_string(),
312 suggestions,
313 }
314 } else {
315 let mut suggestions: Vec<String> = Vec::new();
317 for u in &unsupported {
318 suggestions.extend(self.suggest_similar(u, 3));
319 }
320 suggestions.sort();
321 suggestions.dedup();
322 GroundingResult::Partial {
323 supported,
324 unsupported,
325 suggestions,
326 }
327 }
328 }
329
330 fn find_evidence(&self, name: &str) -> Vec<Evidence> {
331 let mut results: Vec<Evidence> = Vec::new();
332
333 for unit in self.graph.units() {
335 if unit.name == name {
336 results.push(Self::evidence_from_unit(unit));
337 }
338 }
339
340 if results.is_empty() {
342 for unit in self.graph.units() {
343 if unit.qualified_name.contains(name) {
344 results.push(Self::evidence_from_unit(unit));
345 }
346 }
347 }
348
349 if results.is_empty() {
351 let lower = name.to_lowercase();
352 for unit in self.graph.units() {
353 if unit.name.to_lowercase() == lower {
354 results.push(Self::evidence_from_unit(unit));
355 }
356 }
357 }
358
359 results
360 }
361
362 fn suggest_similar(&self, name: &str, limit: usize) -> Vec<String> {
363 let lower = name.to_lowercase();
364 let threshold = name.len() / 2;
365
366 let mut candidates: Vec<(String, usize)> = Vec::new();
367
368 for unit in self.graph.units() {
369 let unit_lower = unit.name.to_lowercase();
370
371 if unit_lower.starts_with(&lower) || lower.starts_with(&unit_lower) {
373 if !candidates.iter().any(|(n, _)| *n == unit.name) {
374 candidates.push((unit.name.clone(), 0));
375 }
376 continue;
377 }
378
379 let dist = levenshtein(&lower, &unit_lower);
381 if dist <= threshold && dist > 0 {
382 if !candidates.iter().any(|(n, _)| *n == unit.name) {
383 candidates.push((unit.name.clone(), dist));
384 }
385 }
386 }
387
388 candidates.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
390
391 candidates
392 .into_iter()
393 .take(limit)
394 .map(|(name, _)| name)
395 .collect()
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Builds a small graph with five units covering the snake_case,
    /// CamelCase and SCREAMING_CASE naming styles.
    fn test_graph() -> CodeGraph {
        // (unit type, language, name, qualified name, file, (start line, end line))
        let fixtures = [
            (
                CodeUnitType::Function,
                Language::Python,
                "process_payment",
                "payments.stripe.process_payment",
                "src/payments/stripe.py",
                (10, 30),
            ),
            (
                CodeUnitType::Type,
                Language::Rust,
                "CodeGraph",
                "crate::graph::CodeGraph",
                "src/graph/code_graph.rs",
                (17, 250),
            ),
            (
                CodeUnitType::Function,
                Language::Rust,
                "add_unit",
                "crate::graph::CodeGraph::add_unit",
                "src/graph/code_graph.rs",
                (58, 64),
            ),
            (
                CodeUnitType::Config,
                Language::Rust,
                "MAX_EDGES_PER_UNIT",
                "crate::types::MAX_EDGES_PER_UNIT",
                "src/types/mod.rs",
                (40, 40),
            ),
            (
                CodeUnitType::Function,
                Language::Python,
                "validate_amount",
                "payments.utils.validate_amount",
                "src/payments/utils.py",
                (5, 15),
            ),
        ];

        let mut graph = CodeGraph::with_default_dimension();
        for (unit_type, language, name, qualified_name, file, (start_line, end_line)) in fixtures {
            graph.add_unit(CodeUnit::new(
                unit_type,
                language,
                name.to_string(),
                qualified_name.to_string(),
                PathBuf::from(file),
                Span::new(start_line, 0, end_line, 0),
            ));
        }
        graph
    }

    /// Returns `true` when `refs` contains exactly the string `expected`.
    fn has_ref(refs: &[String], expected: &str) -> bool {
        refs.iter().any(|r| r == expected)
    }

    // ---- extract_code_references ----

    #[test]
    fn extract_snake_case_refs() {
        let refs = extract_code_references("The process_payment function validates the amount");
        assert!(has_ref(&refs, "process_payment"));
    }

    #[test]
    fn extract_camel_case_refs() {
        let refs = extract_code_references("The CodeGraph struct holds all units");
        assert!(has_ref(&refs, "CodeGraph"));
    }

    #[test]
    fn extract_screaming_case_refs() {
        let refs =
            extract_code_references("The constant MAX_EDGES_PER_UNIT limits the edge count");
        assert!(has_ref(&refs, "MAX_EDGES_PER_UNIT"));
    }

    #[test]
    fn extract_backtick_refs() {
        let refs = extract_code_references("Call `add_unit` to insert a node");
        assert!(has_ref(&refs, "add_unit"));
    }

    #[test]
    fn extract_mixed_refs() {
        let refs = extract_code_references(
            "The `process_payment` function in CodeGraph uses MAX_EDGES_PER_UNIT",
        );
        assert!(has_ref(&refs, "process_payment"));
        assert!(has_ref(&refs, "CodeGraph"));
        assert!(has_ref(&refs, "MAX_EDGES_PER_UNIT"));
    }

    #[test]
    fn extract_filters_stop_words() {
        let refs = extract_code_references("the is a an in on");
        assert!(refs.is_empty());
    }

    #[test]
    fn extract_no_duplicates() {
        let refs = extract_code_references(
            "`process_payment` calls process_payment to handle the process_payment flow",
        );
        let occurrences = refs
            .iter()
            .filter(|r| r.as_str() == "process_payment")
            .count();
        assert_eq!(occurrences, 1);
    }

    // ---- ground_claim ----

    #[test]
    fn ground_verified_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        match engine.ground_claim("The process_payment function exists") {
            GroundingResult::Verified {
                evidence,
                confidence,
            } => {
                assert!(!evidence.is_empty());
                assert!(confidence > 0.0);
                assert_eq!(evidence[0].name, "process_payment");
            }
            other => panic!("Expected Verified, got {:?}", other),
        }
    }

    #[test]
    fn ground_ungrounded_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        match engine.ground_claim("The send_invoice function sends emails") {
            GroundingResult::Ungrounded { claim, .. } => {
                assert!(claim.contains("send_invoice"));
            }
            other => panic!("Expected Ungrounded, got {:?}", other),
        }
    }

    #[test]
    fn ground_partial_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        match engine.ground_claim("process_payment calls send_notification after success") {
            GroundingResult::Partial {
                supported,
                unsupported,
                ..
            } => {
                assert!(has_ref(&supported, "process_payment"));
                assert!(has_ref(&unsupported, "send_notification"));
            }
            other => panic!("Expected Partial, got {:?}", other),
        }
    }

    #[test]
    fn ground_no_refs_is_ungrounded() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let result = engine.ground_claim("This is a normal English sentence.");
        assert!(matches!(result, GroundingResult::Ungrounded { .. }));
    }

    // ---- find_evidence ----

    #[test]
    fn find_evidence_exact_name() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let found = engine.find_evidence("add_unit");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "add_unit");
        assert_eq!(found[0].node_type, "function");
    }

    #[test]
    fn find_evidence_qualified_fallback() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        // "stripe" is not a unit name; it only occurs inside a qualified name.
        let found = engine.find_evidence("stripe");
        assert!(!found.is_empty());
        assert_eq!(found[0].name, "process_payment");
    }

    #[test]
    fn find_evidence_case_insensitive_fallback() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let found = engine.find_evidence("codegraph");
        assert!(!found.is_empty());
        assert_eq!(found[0].name, "CodeGraph");
    }

    #[test]
    fn find_evidence_nonexistent() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let found = engine.find_evidence("nonexistent_function");
        assert!(found.is_empty());
    }

    // ---- suggest_similar ----

    #[test]
    fn suggest_similar_typo() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("process_paymnt", 5);
        assert!(
            has_ref(&suggestions, "process_payment"),
            "Expected process_payment in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_prefix() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("add", 5);
        assert!(
            has_ref(&suggestions, "add_unit"),
            "Expected add_unit in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_respects_limit() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("a", 2);
        assert!(suggestions.len() <= 2);
    }

    // ---- levenshtein ----

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn levenshtein_one_edit() {
        assert_eq!(levenshtein("kitten", "sitten"), 1);
    }

    #[test]
    fn levenshtein_full_diff() {
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn levenshtein_empty() {
        assert_eq!(levenshtein("", "hello"), 5);
        assert_eq!(levenshtein("hello", ""), 5);
        assert_eq!(levenshtein("", ""), 0);
    }

    // ---- identifier-shape predicates ----

    #[test]
    fn test_is_snake_case() {
        assert!(is_snake_case("process_payment"));
        assert!(is_snake_case("add_unit"));
        assert!(is_snake_case("a_b"));
        // A single word has no underscore.
        assert!(!is_snake_case("process"));
        // Uppercase letters are not allowed.
        assert!(!is_snake_case("ProcessPayment"));
        assert!(!is_snake_case("_leading"));
        assert!(!is_snake_case("trailing_"));
        assert!(!is_snake_case("double__under"));
    }

    #[test]
    fn test_is_camel_case() {
        assert!(is_camel_case("CodeGraph"));
        assert!(is_camel_case("GroundingEngine"));
        assert!(is_camel_case("MyType2"));
        // No uppercase letters at all.
        assert!(!is_camel_case("codegraph"));
        // No lowercase letters at all.
        assert!(!is_camel_case("CODEGRAPH"));
        // Too short.
        assert!(!is_camel_case("A"));
        // Only one uppercase letter.
        assert!(!is_camel_case("Code"));
    }

    #[test]
    fn test_is_screaming_case() {
        assert!(is_screaming_case("MAX_EDGES_PER_UNIT"));
        assert!(is_screaming_case("API_KEY"));
        // Lowercase letters are not allowed.
        assert!(!is_screaming_case("max_edges"));
        // A single word has no underscore.
        assert!(!is_screaming_case("NOUNDERSCORES"));
        assert!(!is_screaming_case("_LEADING"));
        assert!(!is_screaming_case("TRAILING_"));
    }
}