Skip to main content

agentic_codebase/grounding/
engine.rs

1//! Grounding engine — verifies code claims against the [`CodeGraph`].
2//!
3//! [`CodeGraph`]: crate::graph::CodeGraph
4
5use crate::graph::CodeGraph;
6
7use super::{Evidence, Grounded, GroundingResult};
8
9// ── Common English stop-words to filter from reference extraction ────────────
/// Lower-case English words that are never treated as code references.
///
/// Lookups are expected to lower-case the candidate before comparing, so every
/// entry here is stored in lower case.
const STOP_WORDS: &[&str] = &[
    // articles
    "the", "a", "an",
    // forms of "to be"
    "is", "are", "was", "were", "be", "been", "being",
    // auxiliaries
    "have", "has", "had", "do", "does", "did",
    // modal verbs
    "will", "would", "shall", "should", "may", "might", "must", "can", "could",
    // prepositions
    "to", "of", "in", "for", "on", "with", "at", "by", "from", "as", "into", "about", "between",
    "through", "during", "before", "after", "above", "below", "up", "down", "out", "off", "over",
    "under",
    // adverbs and question words
    "again", "further", "then", "once", "here", "there", "when", "where", "why", "how",
    // quantifiers and qualifiers
    "all", "each", "every", "both", "few", "more", "most", "other", "some", "such", "no", "nor",
    "not", "only", "own", "same", "so", "than", "too", "very", "just",
    // conjunctions
    "because", "but", "and", "or", "if", "while",
    // demonstratives
    "that", "this", "these", "those",
    // pronouns and possessives
    "it", "its", "my", "your", "his", "her", "our", "their", "what", "which", "who", "whom", "we",
    "you", "he", "she", "they", "me", "him", "us", "them", "i",
];
22
23// ── Pattern detection helpers (no regex crate) ──────────────────────────────
24
/// Reports whether `s` looks like a multi-word `snake_case` identifier.
///
/// Accepted shape: `[a-z][a-z0-9]*(_[a-z0-9]+)+` — the first character is an
/// ASCII lowercase letter, every other character is an ASCII lowercase letter
/// or digit, and the words are joined by single underscores. A string without
/// any underscore (a plain word) is rejected, as is one with a leading,
/// trailing, or doubled underscore.
fn is_snake_case(s: &str) -> bool {
    // Checking the first character also rejects the empty string and a
    // leading underscore.
    let starts_lower = s.chars().next().map_or(false, |c| c.is_ascii_lowercase());
    if !starts_lower {
        return false;
    }

    let mut segment_count = 0usize;
    for segment in s.split('_') {
        // An empty segment means a leading, trailing, or doubled underscore.
        let well_formed = !segment.is_empty()
            && segment
                .chars()
                .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit());
        if !well_formed {
            return false;
        }
        segment_count += 1;
    }

    // Two or more segments means at least one underscore was present.
    segment_count >= 2
}
60
/// Reports whether `s` looks like a multi-word `CamelCase` identifier.
///
/// The check is deliberately loose. `s` is accepted when all of these hold:
///
/// - it starts with an ASCII uppercase letter,
/// - it consists solely of ASCII letters and digits,
/// - it contains at least two uppercase letters (so an ordinary capitalised
///   word such as `Code` is rejected), and
/// - it contains at least one lowercase letter (so `CODEGRAPH` is rejected).
///
/// Consequently names with digits (`MyType2`) and acronym runs (`HTTPServer`)
/// are accepted as well.
fn is_camel_case(s: &str) -> bool {
    let mut uppercase_seen = 0usize;
    let mut lowercase_seen = false;

    for (idx, c) in s.chars().enumerate() {
        if !c.is_ascii_alphanumeric() {
            return false;
        }
        if c.is_ascii_uppercase() {
            uppercase_seen += 1;
        } else if idx == 0 {
            // The leading character must be uppercase.
            return false;
        } else if c.is_ascii_lowercase() {
            lowercase_seen = true;
        }
    }

    // Strings shorter than two characters can never reach two uppercase
    // letters, so no separate length check is needed.
    uppercase_seen >= 2 && lowercase_seen
}
90
/// Reports whether `s` looks like a multi-word `SCREAMING_CASE` identifier.
///
/// Accepted shape: `[A-Z][A-Z0-9]*(_[A-Z0-9]+)+` — the first character is an
/// ASCII uppercase letter, every other character is an ASCII uppercase letter
/// or digit, and the words are joined by single underscores. A string without
/// any underscore is rejected, as is one with a leading, trailing, or doubled
/// underscore.
fn is_screaming_case(s: &str) -> bool {
    // Also rejects the empty string and a leading underscore.
    if !s.starts_with(|c: char| c.is_ascii_uppercase()) {
        return false;
    }

    // A segment is valid when it is non-empty (no stray underscore next to it)
    // and made only of uppercase letters and digits.
    let segment_ok = |segment: &str| {
        !segment.is_empty()
            && segment
                .chars()
                .all(|c| c.is_ascii_uppercase() || c.is_ascii_digit())
    };

    // Two or more segments means at least one underscore was present.
    s.split('_').count() >= 2 && s.split('_').all(segment_ok)
}
125
126/// Returns `true` if `word` is a common English stop-word.
127fn is_stop_word(word: &str) -> bool {
128    STOP_WORDS.contains(&word.to_lowercase().as_str())
129}
130
131// ── Reference extraction ─────────────────────────────────────────────────────
132
133/// Extract potential code identifiers from a natural-language claim.
134///
135/// Looks for:
136/// - Backtick-quoted identifiers (e.g. `` `foo_bar` ``)
137/// - `snake_case` tokens
138/// - `CamelCase` tokens
139/// - `SCREAMING_CASE` tokens
140///
141/// Common English stop-words are filtered out.
142pub fn extract_code_references(claim: &str) -> Vec<String> {
143    let mut refs: Vec<String> = Vec::new();
144
145    // 1. Extract backtick-quoted identifiers
146    let mut in_backtick = false;
147    let mut buf = String::new();
148    for ch in claim.chars() {
149        if ch == '`' {
150            if in_backtick {
151                let trimmed = buf.trim().to_string();
152                if !trimmed.is_empty() && !is_stop_word(&trimmed) {
153                    refs.push(trimmed);
154                }
155                buf.clear();
156            }
157            in_backtick = !in_backtick;
158        } else if in_backtick {
159            buf.push(ch);
160        }
161    }
162
163    // 2. Tokenize remaining text and check patterns
164    // Split on anything that isn't alphanumeric or underscore
165    let tokens: Vec<&str> = claim
166        .split(|c: char| !c.is_ascii_alphanumeric() && c != '_')
167        .filter(|t| !t.is_empty())
168        .collect();
169
170    for token in &tokens {
171        if is_stop_word(token) {
172            continue;
173        }
174        if is_snake_case(token) || is_camel_case(token) || is_screaming_case(token) {
175            let s = (*token).to_string();
176            if !refs.contains(&s) {
177                refs.push(s);
178            }
179        }
180    }
181
182    refs
183}
184
185// ── Levenshtein edit distance ────────────────────────────────────────────────
186
/// Compute the Levenshtein edit distance between two strings.
///
/// The distance is the minimum number of single-character insertions,
/// deletions, and substitutions needed to turn `a` into `b`. Characters are
/// Unicode scalar values (`char`), not bytes.
///
/// Uses the standard iterative dynamic-programming approach with two rolling
/// rows. Because the distance is symmetric, the shorter string is always used
/// as the row dimension and the longer one is streamed, so the extra space is
/// O(min(m, n)).
fn levenshtein(a: &str, b: &str) -> usize {
    // Put the shorter string along the row so the buffers stay small.
    let (outer, inner) = if a.chars().count() < b.chars().count() {
        (b, a)
    } else {
        (a, b)
    };
    let inner: Vec<char> = inner.chars().collect();
    let n = inner.len();

    // `prev[j]` is the distance between the already-consumed prefix of `outer`
    // and the first `j` characters of `inner`. Before anything is consumed,
    // that is simply `j` insertions.
    let mut prev: Vec<usize> = (0..=n).collect();
    let mut curr: Vec<usize> = vec![0; n + 1];

    for (i, outer_ch) in outer.chars().enumerate() {
        // Against an empty `inner` prefix: delete everything consumed so far.
        curr[0] = i + 1;
        for (j, inner_ch) in inner.iter().enumerate() {
            let cost = usize::from(outer_ch != *inner_ch);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + cost); // substitution
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the finished row lives in `prev`. When `outer` is
    // empty (both inputs empty) this is the initial row, giving 0.
    prev[n]
}
225
226// ── GroundingEngine ──────────────────────────────────────────────────────────
227
228/// Engine that verifies code claims against a [`CodeGraph`].
229///
230/// Wraps a reference to the code graph and implements the [`Grounded`] trait
231/// to provide anti-hallucination checks.
232///
233/// # Examples
234///
235/// ```ignore
236/// let engine = GroundingEngine::new(&graph);
237/// match engine.ground_claim("process_payment validates the Decimal amount") {
238///     GroundingResult::Verified { evidence, confidence } => { /* all good */ }
239///     GroundingResult::Partial { unsupported, .. } => { /* some unknown refs */ }
240///     GroundingResult::Ungrounded { claim, suggestions } => { /* hallucination */ }
241/// }
242/// ```
243///
244/// [`CodeGraph`]: crate::graph::CodeGraph
245pub struct GroundingEngine<'g> {
246    graph: &'g CodeGraph,
247}
248
249impl<'g> GroundingEngine<'g> {
250    /// Create a new grounding engine backed by the given code graph.
251    pub fn new(graph: &'g CodeGraph) -> Self {
252        Self { graph }
253    }
254
255    /// Build an [`Evidence`] record from a [`CodeUnit`][crate::types::CodeUnit].
256    fn evidence_from_unit(unit: &crate::types::CodeUnit) -> Evidence {
257        Evidence {
258            node_id: unit.id,
259            node_type: unit.unit_type.label().to_string(),
260            name: unit.name.clone(),
261            file_path: unit.file_path.display().to_string(),
262            line_number: Some(unit.span.start_line),
263            snippet: unit.signature.clone(),
264        }
265    }
266}
267
268impl<'g> Grounded for GroundingEngine<'g> {
269    fn ground_claim(&self, claim: &str) -> GroundingResult {
270        let refs = extract_code_references(claim);
271
272        // No identifiable code references — treat as ungrounded.
273        if refs.is_empty() {
274            return GroundingResult::Ungrounded {
275                claim: claim.to_string(),
276                suggestions: Vec::new(),
277            };
278        }
279
280        let mut all_evidence: Vec<Evidence> = Vec::new();
281        let mut supported: Vec<String> = Vec::new();
282        let mut unsupported: Vec<String> = Vec::new();
283
284        for reference in &refs {
285            let evidence = self.find_evidence(reference);
286            if evidence.is_empty() {
287                unsupported.push(reference.clone());
288            } else {
289                supported.push(reference.clone());
290                all_evidence.extend(evidence);
291            }
292        }
293
294        if unsupported.is_empty() {
295            // All references verified.
296            let confidence = 1.0_f32; // all matched
297            GroundingResult::Verified {
298                evidence: all_evidence,
299                confidence,
300            }
301        } else if supported.is_empty() {
302            // Nothing matched — potential hallucination.
303            let mut suggestions: Vec<String> = Vec::new();
304            for u in &unsupported {
305                suggestions.extend(self.suggest_similar(u, 3));
306            }
307            // Deduplicate suggestions
308            suggestions.sort();
309            suggestions.dedup();
310            GroundingResult::Ungrounded {
311                claim: claim.to_string(),
312                suggestions,
313            }
314        } else {
315            // Partial match.
316            let mut suggestions: Vec<String> = Vec::new();
317            for u in &unsupported {
318                suggestions.extend(self.suggest_similar(u, 3));
319            }
320            suggestions.sort();
321            suggestions.dedup();
322            GroundingResult::Partial {
323                supported,
324                unsupported,
325                suggestions,
326            }
327        }
328    }
329
330    fn find_evidence(&self, name: &str) -> Vec<Evidence> {
331        let mut results: Vec<Evidence> = Vec::new();
332
333        // 1. Exact match on simple name
334        for unit in self.graph.units() {
335            if unit.name == name {
336                results.push(Self::evidence_from_unit(unit));
337            }
338        }
339
340        // 2. If no exact match, try qualified_name contains
341        if results.is_empty() {
342            for unit in self.graph.units() {
343                if unit.qualified_name.contains(name) {
344                    results.push(Self::evidence_from_unit(unit));
345                }
346            }
347        }
348
349        // 3. If still empty, try case-insensitive exact match on name
350        if results.is_empty() {
351            let lower = name.to_lowercase();
352            for unit in self.graph.units() {
353                if unit.name.to_lowercase() == lower {
354                    results.push(Self::evidence_from_unit(unit));
355                }
356            }
357        }
358
359        results
360    }
361
362    fn suggest_similar(&self, name: &str, limit: usize) -> Vec<String> {
363        let lower = name.to_lowercase();
364        let threshold = name.len() / 2;
365
366        let mut candidates: Vec<(String, usize)> = Vec::new();
367
368        for unit in self.graph.units() {
369            let unit_lower = unit.name.to_lowercase();
370
371            // Prefix match — always include with distance 0
372            if unit_lower.starts_with(&lower) || lower.starts_with(&unit_lower) {
373                if !candidates.iter().any(|(n, _)| *n == unit.name) {
374                    candidates.push((unit.name.clone(), 0));
375                }
376                continue;
377            }
378
379            // Edit distance
380            let dist = levenshtein(&lower, &unit_lower);
381            if dist <= threshold && dist > 0 && !candidates.iter().any(|(n, _)| *n == unit.name) {
382                candidates.push((unit.name.clone(), dist));
383            }
384        }
385
386        // Sort by distance (ascending), then alphabetically for ties
387        candidates.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
388
389        candidates
390            .into_iter()
391            .take(limit)
392            .map(|(name, _)| name)
393            .collect()
394    }
395}
396
397// ── Tests ────────────────────────────────────────────────────────────────────
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{CodeUnit, CodeUnitType, Language, Span};
    use std::path::PathBuf;

    /// Build a small test graph for grounding tests.
    ///
    /// Contains three functions (`process_payment`, `add_unit`,
    /// `validate_amount`), one type (`CodeGraph`), and one config constant
    /// (`MAX_EDGES_PER_UNIT`).
    fn test_graph() -> CodeGraph {
        let mut graph = CodeGraph::with_default_dimension();

        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Python,
            "process_payment".to_string(),
            "payments.stripe.process_payment".to_string(),
            PathBuf::from("src/payments/stripe.py"),
            Span::new(10, 0, 30, 0),
        ));

        graph.add_unit(CodeUnit::new(
            CodeUnitType::Type,
            Language::Rust,
            "CodeGraph".to_string(),
            "crate::graph::CodeGraph".to_string(),
            PathBuf::from("src/graph/code_graph.rs"),
            Span::new(17, 0, 250, 0),
        ));

        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Rust,
            "add_unit".to_string(),
            "crate::graph::CodeGraph::add_unit".to_string(),
            PathBuf::from("src/graph/code_graph.rs"),
            Span::new(58, 0, 64, 0),
        ));

        graph.add_unit(CodeUnit::new(
            CodeUnitType::Config,
            Language::Rust,
            "MAX_EDGES_PER_UNIT".to_string(),
            "crate::types::MAX_EDGES_PER_UNIT".to_string(),
            PathBuf::from("src/types/mod.rs"),
            Span::new(40, 0, 40, 0),
        ));

        graph.add_unit(CodeUnit::new(
            CodeUnitType::Function,
            Language::Python,
            "validate_amount".to_string(),
            "payments.utils.validate_amount".to_string(),
            PathBuf::from("src/payments/utils.py"),
            Span::new(5, 0, 15, 0),
        ));

        graph
    }

    // ── extract_code_references ──────────────────────────────────────────

    #[test]
    fn extract_snake_case_refs() {
        let refs = extract_code_references("The process_payment function validates the amount");
        assert!(refs.contains(&"process_payment".to_string()));
    }

    #[test]
    fn extract_camel_case_refs() {
        let refs = extract_code_references("The CodeGraph struct holds all units");
        assert!(refs.contains(&"CodeGraph".to_string()));
    }

    #[test]
    fn extract_screaming_case_refs() {
        let refs = extract_code_references("The constant MAX_EDGES_PER_UNIT limits the edge count");
        assert!(refs.contains(&"MAX_EDGES_PER_UNIT".to_string()));
    }

    #[test]
    fn extract_backtick_refs() {
        let refs = extract_code_references("Call `add_unit` to insert a node");
        assert!(refs.contains(&"add_unit".to_string()));
    }

    #[test]
    fn extract_mixed_refs() {
        let refs = extract_code_references(
            "The `process_payment` function in CodeGraph uses MAX_EDGES_PER_UNIT",
        );
        assert!(refs.contains(&"process_payment".to_string()));
        assert!(refs.contains(&"CodeGraph".to_string()));
        assert!(refs.contains(&"MAX_EDGES_PER_UNIT".to_string()));
    }

    #[test]
    fn extract_filters_stop_words() {
        let refs = extract_code_references("the is a an in on");
        assert!(refs.is_empty());

        // Bare stop-words never match an identifier pattern anyway, so the
        // filter is only observable on backtick-quoted text, where anything
        // else would be accepted verbatim. Case must not matter.
        let refs = extract_code_references("`the` result `Is` stored in `add_unit`");
        assert_eq!(refs, vec!["add_unit".to_string()]);
    }

    #[test]
    fn extract_no_duplicates() {
        let refs = extract_code_references(
            "`process_payment` calls process_payment to handle the process_payment flow",
        );
        let count = refs.iter().filter(|r| *r == "process_payment").count();
        assert_eq!(count, 1);
    }

    // ── ground_claim ─────────────────────────────────────────────────────

    #[test]
    fn ground_verified_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let result = engine.ground_claim("The process_payment function exists");
        match result {
            GroundingResult::Verified {
                evidence,
                confidence,
            } => {
                assert!(!evidence.is_empty());
                assert!(confidence > 0.0);
                assert_eq!(evidence[0].name, "process_payment");
            }
            other => panic!("Expected Verified, got {:?}", other),
        }
    }

    #[test]
    fn ground_ungrounded_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let result = engine.ground_claim("The send_invoice function sends emails");
        match result {
            GroundingResult::Ungrounded { claim, .. } => {
                assert!(claim.contains("send_invoice"));
            }
            other => panic!("Expected Ungrounded, got {:?}", other),
        }
    }

    #[test]
    fn ground_partial_claim() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let result = engine.ground_claim("process_payment calls send_notification after success");
        match result {
            GroundingResult::Partial {
                supported,
                unsupported,
                ..
            } => {
                assert!(supported.contains(&"process_payment".to_string()));
                assert!(unsupported.contains(&"send_notification".to_string()));
            }
            other => panic!("Expected Partial, got {:?}", other),
        }
    }

    #[test]
    fn ground_no_refs_is_ungrounded() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let result = engine.ground_claim("This is a normal English sentence.");
        assert!(matches!(result, GroundingResult::Ungrounded { .. }));
    }

    // ── find_evidence ────────────────────────────────────────────────────

    #[test]
    fn find_evidence_exact_name() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let ev = engine.find_evidence("add_unit");
        assert_eq!(ev.len(), 1);
        assert_eq!(ev[0].name, "add_unit");
        assert_eq!(ev[0].node_type, "function");
    }

    #[test]
    fn find_evidence_qualified_fallback() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        // "stripe" appears in the qualified name of process_payment
        let ev = engine.find_evidence("stripe");
        assert!(!ev.is_empty());
        assert_eq!(ev[0].name, "process_payment");
    }

    #[test]
    fn find_evidence_case_insensitive_fallback() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let ev = engine.find_evidence("codegraph");
        assert!(!ev.is_empty());
        assert_eq!(ev[0].name, "CodeGraph");
    }

    #[test]
    fn find_evidence_nonexistent() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let ev = engine.find_evidence("nonexistent_function");
        assert!(ev.is_empty());
    }

    // ── suggest_similar ──────────────────────────────────────────────────

    #[test]
    fn suggest_similar_typo() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("process_paymnt", 5);
        assert!(
            suggestions.contains(&"process_payment".to_string()),
            "Expected process_payment in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_prefix() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("add", 5);
        assert!(
            suggestions.contains(&"add_unit".to_string()),
            "Expected add_unit in {:?}",
            suggestions
        );
    }

    #[test]
    fn suggest_similar_respects_limit() {
        let graph = test_graph();
        let engine = GroundingEngine::new(&graph);

        let suggestions = engine.suggest_similar("a", 2);
        assert!(suggestions.len() <= 2);
    }

    // ── levenshtein ──────────────────────────────────────────────────────

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn levenshtein_one_edit() {
        assert_eq!(levenshtein("kitten", "sitten"), 1);
    }

    #[test]
    fn levenshtein_full_diff() {
        assert_eq!(levenshtein("abc", "xyz"), 3);
    }

    #[test]
    fn levenshtein_empty() {
        assert_eq!(levenshtein("", "hello"), 5);
        assert_eq!(levenshtein("hello", ""), 5);
        assert_eq!(levenshtein("", ""), 0);
    }

    #[test]
    fn levenshtein_is_symmetric() {
        // Inputs of different lengths, in both argument orders.
        assert_eq!(levenshtein("kitten", "sitting"), 3);
        assert_eq!(levenshtein("sitting", "kitten"), 3);
        assert_eq!(levenshtein("a", "abc"), 2);
        assert_eq!(levenshtein("abc", "a"), 2);
    }

    // ── pattern detection helpers ────────────────────────────────────────

    #[test]
    fn test_is_snake_case() {
        assert!(is_snake_case("process_payment"));
        assert!(is_snake_case("add_unit"));
        assert!(is_snake_case("a_b"));
        assert!(!is_snake_case("process")); // no underscore
        assert!(!is_snake_case("ProcessPayment")); // CamelCase
        assert!(!is_snake_case("_leading"));
        assert!(!is_snake_case("trailing_"));
        assert!(!is_snake_case("double__under"));
    }

    #[test]
    fn test_is_camel_case() {
        assert!(is_camel_case("CodeGraph"));
        assert!(is_camel_case("GroundingEngine"));
        assert!(is_camel_case("MyType2"));
        assert!(!is_camel_case("codegraph")); // all lower
        assert!(!is_camel_case("CODEGRAPH")); // all upper
        assert!(!is_camel_case("A")); // too short
        assert!(!is_camel_case("Code")); // only one uppercase
    }

    #[test]
    fn test_is_screaming_case() {
        assert!(is_screaming_case("MAX_EDGES_PER_UNIT"));
        assert!(is_screaming_case("API_KEY"));
        assert!(!is_screaming_case("max_edges")); // lowercase
        assert!(!is_screaming_case("NOUNDERSCORES")); // no underscore
        assert!(!is_screaming_case("_LEADING"));
        assert!(!is_screaming_case("TRAILING_"));
    }
}