Skip to main content

agentic_codebase/grounding/
engine.rs

//! Grounding engine — verifies code claims against the [`CodeGraph`].
//!
//! [`CodeGraph`]: crate::graph::CodeGraph
4
5use crate::graph::CodeGraph;
6
7use super::{Evidence, Grounded, GroundingResult};
8
// ── Common English stop-words to filter from reference extraction ────────────
/// Common English function words that are never treated as code references.
///
/// Every entry is lower-case; `is_stop_word` lower-cases its input before
/// looking it up here, so new entries must be lower-case as well.
const STOP_WORDS: &[&str] = &[
    // articles
    "the", "a", "an",
    // forms of "to be"
    "is", "are", "was", "were", "be", "been", "being",
    // auxiliaries
    "have", "has", "had", "do", "does", "did",
    // modals
    "will", "would", "shall", "should", "may", "might", "must", "can", "could",
    // prepositions and particles
    "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "about",
    "between", "through", "during", "before", "after", "above", "below",
    "up", "down", "out", "off", "over", "under",
    // adverbs and question words
    "again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
    // quantifiers, negations, and degree words
    "all", "each", "every", "both", "few", "more", "most", "other", "some", "such",
    "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
    // conjunctions
    "because", "but", "and", "or", "if", "while",
    // demonstratives
    "that", "this", "these", "those",
    // pronouns and determiners
    "it", "its", "my", "your", "his", "her", "our", "their",
    "what", "which", "who", "whom",
    "we", "you", "he", "she", "they", "me", "him", "us", "them", "i",
];
22
// ── Pattern detection helpers (no regex crate) ──────────────────────────────
24
/// Checks whether `s` is a lower-case, underscore-separated identifier.
///
/// Accepted shape: `[a-z][a-z0-9]*(_[a-z0-9]+)+` — two or more non-empty
/// segments of ASCII lower-case letters and digits, joined by single
/// underscores, where the very first character is a letter.
fn is_snake_case(s: &str) -> bool {
    // An empty string, or one starting with a digit or underscore, is rejected.
    let starts_with_letter = s.starts_with(|c: char| c.is_ascii_lowercase());

    // Splitting on '_' produces an empty piece for a leading, trailing, or
    // doubled underscore, so demanding non-empty pieces rejects all three.
    let segments_ok = s.split('_').all(|segment| {
        !segment.is_empty()
            && segment
                .bytes()
                .all(|b| b.is_ascii_lowercase() || b.is_ascii_digit())
    });

    starts_with_letter && s.contains('_') && segments_ok
}
60
/// Checks whether `s` looks like a `CamelCase` identifier.
///
/// A token is accepted when all of the following hold:
/// - every character is ASCII alphanumeric,
/// - the first character is an upper-case letter,
/// - at least one later character is upper-case (this is what separates
///   `CodeGraph` from an ordinary capitalized word such as `Code`),
/// - at least one later character is lower-case (rejecting all-caps words).
///
/// This is deliberately looser than `[A-Z][a-z]+([A-Z][a-z]+)+`: digits and
/// runs of capitals are allowed, so `MyType2` and `HTTPServer` both qualify.
fn is_camel_case(s: &str) -> bool {
    // Working on bytes is safe here: any non-ASCII byte fails the
    // alphanumeric test below, exactly as a non-ASCII `char` would.
    match s.as_bytes().split_first() {
        Some((first, rest)) if first.is_ascii_uppercase() => {
            rest.iter().all(u8::is_ascii_alphanumeric)
                && rest.iter().any(u8::is_ascii_uppercase)
                && rest.iter().any(u8::is_ascii_lowercase)
        }
        _ => false,
    }
}
90
/// Checks whether `s` is an upper-case, underscore-separated identifier.
///
/// Accepted shape: `[A-Z][A-Z0-9]*(_[A-Z0-9]+)+` — two or more non-empty
/// segments of ASCII upper-case letters and digits, joined by single
/// underscores, where the very first character is a letter.
fn is_screaming_case(s: &str) -> bool {
    // An empty string, or one starting with a digit or underscore, is rejected.
    let starts_with_letter = s.starts_with(|c: char| c.is_ascii_uppercase());

    // A leading, trailing, or doubled underscore shows up as an empty piece
    // after splitting, so one emptiness test covers all three.
    let segments_ok = s.split('_').all(|segment| {
        !segment.is_empty()
            && segment
                .bytes()
                .all(|b| b.is_ascii_uppercase() || b.is_ascii_digit())
    });

    starts_with_letter && s.contains('_') && segments_ok
}
125
126/// Returns `true` if `word` is a common English stop-word.
127fn is_stop_word(word: &str) -> bool {
128    STOP_WORDS.contains(&word.to_lowercase().as_str())
129}
130
// ── Reference extraction ─────────────────────────────────────────────────────
132
133/// Extract potential code identifiers from a natural-language claim.
134///
135/// Looks for:
136/// - Backtick-quoted identifiers (e.g. `` `foo_bar` ``)
137/// - `snake_case` tokens
138/// - `CamelCase` tokens
139/// - `SCREAMING_CASE` tokens
140///
141/// Common English stop-words are filtered out.
142pub fn extract_code_references(claim: &str) -> Vec<String> {
143    let mut refs: Vec<String> = Vec::new();
144
145    // 1. Extract backtick-quoted identifiers
146    let mut in_backtick = false;
147    let mut buf = String::new();
148    for ch in claim.chars() {
149        if ch == '`' {
150            if in_backtick {
151                let trimmed = buf.trim().to_string();
152                if !trimmed.is_empty() && !is_stop_word(&trimmed) {
153                    refs.push(trimmed);
154                }
155                buf.clear();
156            }
157            in_backtick = !in_backtick;
158        } else if in_backtick {
159            buf.push(ch);
160        }
161    }
162
163    // 2. Tokenize remaining text and check patterns
164    // Split on anything that isn't alphanumeric or underscore
165    let tokens: Vec<&str> = claim
166        .split(|c: char| !c.is_ascii_alphanumeric() && c != '_')
167        .filter(|t| !t.is_empty())
168        .collect();
169
170    for token in &tokens {
171        if is_stop_word(token) {
172            continue;
173        }
174        if is_snake_case(token) || is_camel_case(token) || is_screaming_case(token) {
175            let s = (*token).to_string();
176            if !refs.contains(&s) {
177                refs.push(s);
178            }
179        }
180    }
181
182    refs
183}
184
// ── Levenshtein edit distance ────────────────────────────────────────────────
186
/// Compute the Levenshtein edit distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions, and substitutions needed to turn `a` into `b`. Characters are
/// Unicode scalar values (`char`), not bytes.
///
/// Uses the standard iterative dynamic-programming approach, keeping only two
/// rows of the table. The rows are indexed by the *shorter* input, so the
/// working storage is O(min(m, n)) on top of the two decoded character
/// buffers; time is O(m · n).
fn levenshtein(a: &str, b: &str) -> usize {
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();

    // Edit distance is symmetric, so the inputs may be swapped freely. Putting
    // the shorter one on the row axis keeps both rows as small as possible.
    let (long, short) = if a_chars.len() >= b_chars.len() {
        (&a_chars, &b_chars)
    } else {
        (&b_chars, &a_chars)
    };

    // Turning anything into the empty string costs one deletion per character.
    // This also covers the case where both inputs are empty.
    if short.is_empty() {
        return long.len();
    }

    let width = short.len();
    // `prev[j]` is the distance between the processed prefix of `long` and the
    // first `j` characters of `short`; initially the processed prefix is empty.
    let mut prev: Vec<usize> = (0..=width).collect();
    let mut curr: Vec<usize> = vec![0; width + 1];

    for (i, &long_ch) in long.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &short_ch) in short.iter().enumerate() {
            let substitution_cost = usize::from(long_ch != short_ch);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + substitution_cost); // substitution
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last completed row lives in `prev`.
    prev[width]
}
225
// ── GroundingEngine ──────────────────────────────────────────────────────────
227
228/// Engine that verifies code claims against a [`CodeGraph`].
229///
230/// Wraps a reference to the code graph and implements the [`Grounded`] trait
231/// to provide anti-hallucination checks.
232///
233/// # Examples
234///
235/// ```ignore
236/// let engine = GroundingEngine::new(&graph);
237/// match engine.ground_claim("process_payment validates the Decimal amount") {
238///     GroundingResult::Verified { evidence, confidence } => { /* all good */ }
239///     GroundingResult::Partial { unsupported, .. } => { /* some unknown refs */ }
240///     GroundingResult::Ungrounded { claim, suggestions } => { /* hallucination */ }
241/// }
242/// ```
243///
244/// [`CodeGraph`]: crate::graph::CodeGraph
245pub struct GroundingEngine<'g> {
246    graph: &'g CodeGraph,
247}
248
249impl<'g> GroundingEngine<'g> {
250    /// Create a new grounding engine backed by the given code graph.
251    pub fn new(graph: &'g CodeGraph) -> Self {
252        Self { graph }
253    }
254
255    /// Build an [`Evidence`] record from a [`CodeUnit`][crate::types::CodeUnit].
256    fn evidence_from_unit(unit: &crate::types::CodeUnit) -> Evidence {
257        Evidence {
258            node_id: unit.id,
259            node_type: unit.unit_type.label().to_string(),
260            name: unit.name.clone(),
261            file_path: unit.file_path.display().to_string(),
262            line_number: Some(unit.span.start_line),
263            snippet: unit.signature.clone(),
264        }
265    }
266}
267
268impl<'g> Grounded for GroundingEngine<'g> {
269    fn ground_claim(&self, claim: &str) -> GroundingResult {
270        let refs = extract_code_references(claim);
271
272        // No identifiable code references — treat as ungrounded.
273        if refs.is_empty() {
274            return GroundingResult::Ungrounded {
275                claim: claim.to_string(),
276                suggestions: Vec::new(),
277            };
278        }
279
280        let mut all_evidence: Vec<Evidence> = Vec::new();
281        let mut supported: Vec<String> = Vec::new();
282        let mut unsupported: Vec<String> = Vec::new();
283
284        for reference in &refs {
285            let evidence = self.find_evidence(reference);
286            if evidence.is_empty() {
287                unsupported.push(reference.clone());
288            } else {
289                supported.push(reference.clone());
290                all_evidence.extend(evidence);
291            }
292        }
293
294        if unsupported.is_empty() {
295            // All references verified.
296            let confidence = 1.0_f32; // all matched
297            GroundingResult::Verified {
298                evidence: all_evidence,
299                confidence,
300            }
301        } else if supported.is_empty() {
302            // Nothing matched — potential hallucination.
303            let mut suggestions: Vec<String> = Vec::new();
304            for u in &unsupported {
305                suggestions.extend(self.suggest_similar(u, 3));
306            }
307            // Deduplicate suggestions
308            suggestions.sort();
309            suggestions.dedup();
310            GroundingResult::Ungrounded {
311                claim: claim.to_string(),
312                suggestions,
313            }
314        } else {
315            // Partial match.
316            let mut suggestions: Vec<String> = Vec::new();
317            for u in &unsupported {
318                suggestions.extend(self.suggest_similar(u, 3));
319            }
320            suggestions.sort();
321            suggestions.dedup();
322            GroundingResult::Partial {
323                supported,
324                unsupported,
325                suggestions,
326            }
327        }
328    }
329
330    fn find_evidence(&self, name: &str) -> Vec<Evidence> {
331        let mut results: Vec<Evidence> = Vec::new();
332
333        // 1. Exact match on simple name
334        for unit in self.graph.units() {
335            if unit.name == name {
336                results.push(Self::evidence_from_unit(unit));
337            }
338        }
339
340        // 2. If no exact match, try qualified_name contains
341        if results.is_empty() {
342            for unit in self.graph.units() {
343                if unit.qualified_name.contains(name) {
344                    results.push(Self::evidence_from_unit(unit));
345                }
346            }
347        }
348
349        // 3. If still empty, try case-insensitive exact match on name
350        if results.is_empty() {
351            let lower = name.to_lowercase();
352            for unit in self.graph.units() {
353                if unit.name.to_lowercase() == lower {
354                    results.push(Self::evidence_from_unit(unit));
355                }
356            }
357        }
358
359        results
360    }
361
362    fn suggest_similar(&self, name: &str, limit: usize) -> Vec<String> {
363        let lower = name.to_lowercase();
364        let threshold = name.len() / 2;
365
366        let mut candidates: Vec<(String, usize)> = Vec::new();
367
368        for unit in self.graph.units() {
369            let unit_lower = unit.name.to_lowercase();
370
371            // Prefix match — always include with distance 0
372            if unit_lower.starts_with(&lower) || lower.starts_with(&unit_lower) {
373                if !candidates.iter().any(|(n, _)| *n == unit.name) {
374                    candidates.push((unit.name.clone(), 0));
375                }
376                continue;
377            }
378
379            // Edit distance
380            let dist = levenshtein(&lower, &unit_lower);
381            if dist <= threshold && dist > 0 {
382                if !candidates.iter().any(|(n, _)| *n == unit.name) {
383                    candidates.push((unit.name.clone(), dist));
384                }
385            }
386        }
387
388        // Sort by distance (ascending), then alphabetically for ties
389        candidates.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
390
391        candidates
392            .into_iter()
393            .take(limit)
394            .map(|(name, _)| name)
395            .collect()
396    }
397}
398
// ── Tests ────────────────────────────────────────────────────────────────────
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    // ── fixtures ─────────────────────────────────────────────────────────

    /// Builds the five-unit graph shared by the engine tests: three
    /// functions, one type, and one config constant.
    fn test_graph() -> CodeGraph {
        let fixtures = vec![
            (
                CodeUnitType::Function,
                Language::Python,
                "process_payment",
                "payments.stripe.process_payment",
                "src/payments/stripe.py",
                Span::new(10, 0, 30, 0),
            ),
            (
                CodeUnitType::Type,
                Language::Rust,
                "CodeGraph",
                "crate::graph::CodeGraph",
                "src/graph/code_graph.rs",
                Span::new(17, 0, 250, 0),
            ),
            (
                CodeUnitType::Function,
                Language::Rust,
                "add_unit",
                "crate::graph::CodeGraph::add_unit",
                "src/graph/code_graph.rs",
                Span::new(58, 0, 64, 0),
            ),
            (
                CodeUnitType::Config,
                Language::Rust,
                "MAX_EDGES_PER_UNIT",
                "crate::types::MAX_EDGES_PER_UNIT",
                "src/types/mod.rs",
                Span::new(40, 0, 40, 0),
            ),
            (
                CodeUnitType::Function,
                Language::Python,
                "validate_amount",
                "payments.utils.validate_amount",
                "src/payments/utils.py",
                Span::new(5, 0, 15, 0),
            ),
        ];

        let mut graph = CodeGraph::with_default_dimension();
        for (unit_type, language, name, qualified_name, path, span) in fixtures {
            graph.add_unit(CodeUnit::new(
                unit_type,
                language,
                name.to_string(),
                qualified_name.to_string(),
                PathBuf::from(path),
                span,
            ));
        }
        graph
    }

    /// Runs `check` against an engine backed by the fixture graph.
    fn with_engine<T>(check: impl FnOnce(&GroundingEngine<'_>) -> T) -> T {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);
        check(&engine)
    }

    /// True when `wanted` is one of the strings in `items`.
    fn mentions(items: &[String], wanted: &str) -> bool {
        items.iter().any(|item| item == wanted)
    }

    // ── extract_code_references ──────────────────────────────────────────

    #[test]
    fn extract_snake_case_refs() {
        let found = extract_code_references("The process_payment function validates the amount");
        assert!(mentions(&found, "process_payment"));
    }

    #[test]
    fn extract_camel_case_refs() {
        let found = extract_code_references("The CodeGraph struct holds all units");
        assert!(mentions(&found, "CodeGraph"));
    }

    #[test]
    fn extract_screaming_case_refs() {
        let found =
            extract_code_references("The constant MAX_EDGES_PER_UNIT limits the edge count");
        assert!(mentions(&found, "MAX_EDGES_PER_UNIT"));
    }

    #[test]
    fn extract_backtick_refs() {
        let found = extract_code_references("Call `add_unit` to insert a node");
        assert!(mentions(&found, "add_unit"));
    }

    #[test]
    fn extract_mixed_refs() {
        let found = extract_code_references(
            "The `process_payment` function in CodeGraph uses MAX_EDGES_PER_UNIT",
        );
        assert!(mentions(&found, "process_payment"));
        assert!(mentions(&found, "CodeGraph"));
        assert!(mentions(&found, "MAX_EDGES_PER_UNIT"));
    }

    #[test]
    fn extract_filters_stop_words() {
        assert!(extract_code_references("the is a an in on").is_empty());
    }

    #[test]
    fn extract_no_duplicates() {
        let found = extract_code_references(
            "`process_payment` calls process_payment to handle the process_payment flow",
        );
        let occurrences = found
            .iter()
            .filter(|item| item.as_str() == "process_payment")
            .count();
        assert_eq!(occurrences, 1);
    }

    // ── ground_claim ─────────────────────────────────────────────────────

    #[test]
    fn ground_verified_claim() {
        let outcome =
            with_engine(|engine| engine.ground_claim("The process_payment function exists"));
        match outcome {
            GroundingResult::Verified { evidence, confidence } => {
                assert!(!evidence.is_empty());
                assert!(confidence > 0.0);
                assert_eq!(evidence[0].name, "process_payment");
            }
            other => panic!("Expected Verified, got {:?}", other),
        }
    }

    #[test]
    fn ground_ungrounded_claim() {
        let outcome =
            with_engine(|engine| engine.ground_claim("The send_invoice function sends emails"));
        match outcome {
            GroundingResult::Ungrounded { claim, .. } => {
                assert!(claim.contains("send_invoice"));
            }
            other => panic!("Expected Ungrounded, got {:?}", other),
        }
    }

    #[test]
    fn ground_partial_claim() {
        let outcome = with_engine(|engine| {
            engine.ground_claim("process_payment calls send_notification after success")
        });
        match outcome {
            GroundingResult::Partial {
                supported,
                unsupported,
                ..
            } => {
                assert!(mentions(&supported, "process_payment"));
                assert!(mentions(&unsupported, "send_notification"));
            }
            other => panic!("Expected Partial, got {:?}", other),
        }
    }

    #[test]
    fn ground_no_refs_is_ungrounded() {
        let outcome =
            with_engine(|engine| engine.ground_claim("This is a normal English sentence."));
        assert!(matches!(outcome, GroundingResult::Ungrounded { .. }));
    }

    // ── find_evidence ────────────────────────────────────────────────────

    #[test]
    fn find_evidence_exact_name() {
        let hits = with_engine(|engine| engine.find_evidence("add_unit"));
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].name, "add_unit");
        assert_eq!(hits[0].node_type, "function");
    }

    #[test]
    fn find_evidence_qualified_fallback() {
        // "stripe" appears only in the qualified name of process_payment.
        let hits = with_engine(|engine| engine.find_evidence("stripe"));
        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "process_payment");
    }

    #[test]
    fn find_evidence_case_insensitive_fallback() {
        let hits = with_engine(|engine| engine.find_evidence("codegraph"));
        assert!(!hits.is_empty());
        assert_eq!(hits[0].name, "CodeGraph");
    }

    #[test]
    fn find_evidence_nonexistent() {
        let hits = with_engine(|engine| engine.find_evidence("nonexistent_function"));
        assert!(hits.is_empty());
    }

    // ── suggest_similar ──────────────────────────────────────────────────

    #[test]
    fn suggest_similar_typo() {
        let suggestions = with_engine(|engine| engine.suggest_similar("process_paymnt", 5));
        assert!(
            mentions(&suggestions, "process_payment"),
            "Expected process_payment in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_prefix() {
        let suggestions = with_engine(|engine| engine.suggest_similar("add", 5));
        assert!(
            mentions(&suggestions, "add_unit"),
            "Expected add_unit in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_respects_limit() {
        let suggestions = with_engine(|engine| engine.suggest_similar("a", 2));
        assert!(suggestions.len() <= 2);
    }

    // ── levenshtein ──────────────────────────────────────────────────────

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn levenshtein_one_edit() {
        assert_eq!(levenshtein("kitten", "sitten"), 1);
    }

    #[test]
    fn levenshtein_full_diff() {
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn levenshtein_empty() {
        assert_eq!(levenshtein("", "hello"), 5);
        assert_eq!(levenshtein("hello", ""), 5);
        assert_eq!(levenshtein("", ""), 0);
    }

    // ── pattern detection helpers ────────────────────────────────────────
    //
    // Each table pairs an input with the expected verdict. The input is
    // included on both sides of the comparison so a failure names it.

    #[test]
    fn test_is_snake_case() {
        let cases = vec![
            ("process_payment", true),
            ("add_unit", true),
            ("a_b", true),
            ("process", false),        // no underscore
            ("ProcessPayment", false), // CamelCase
            ("_leading", false),
            ("trailing_", false),
            ("double__under", false),
        ];
        for (input, expected) in cases {
            assert_eq!((input, is_snake_case(input)), (input, expected));
        }
    }

    #[test]
    fn test_is_camel_case() {
        let cases = vec![
            ("CodeGraph", true),
            ("GroundingEngine", true),
            ("MyType2", true),
            ("codegraph", false), // all lower
            ("CODEGRAPH", false), // all upper
            ("A", false),         // too short
            ("Code", false),      // only one uppercase
        ];
        for (input, expected) in cases {
            assert_eq!((input, is_camel_case(input)), (input, expected));
        }
    }

    #[test]
    fn test_is_screaming_case() {
        let cases = vec![
            ("MAX_EDGES_PER_UNIT", true),
            ("API_KEY", true),
            ("max_edges", false),     // lowercase
            ("NOUNDERSCORES", false), // no underscore
            ("_LEADING", false),
            ("TRAILING_", false),
        ];
        for (input, expected) in cases {
            assert_eq!((input, is_screaming_case(input)), (input, expected));
        }
    }
}
713}