Skip to main content

codetether_agent/rlm/oracle/
grep_oracle.rs

1//! Grep-based oracle for pattern-match verification.
2//!
3//! This oracle verifies FINAL() answers that claim to find patterns in source code
4//! by running actual grep operations and comparing results.
5//!
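//! # Supported Queries
//!
//! A query is verifiable only when its wording matches one of the built-in
//! categories in `GrepOracle::infer_pattern`; arbitrary patterns named inside
//! the query are not extracted. Examples:
//!
//! - "Find all async functions"
//! - "Find all structs"
//! - "Count occurrences of unwrap calls"
//! - "List all error handling patterns"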
12//!
13//! # Verification Strategy
14//!
15//! 1. Parse the FINAL() answer to extract claimed matches
16//! 2. Run actual grep on the source file with the inferred pattern
17//! 3. Compare line numbers and content
18//! 4. Return verification result
19
20use anyhow::Result;
21use regex::Regex;
22use std::collections::HashSet;
23
24use super::{FinalAnswerFormat, QueryType};
25
/// Result of grep oracle verification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepVerification {
    /// Answer matches ground truth exactly, in the same order.
    ExactMatch,
    /// Answer contains the same items as ground truth, but in a different order.
    UnorderedMatch,
    /// The claimed number of items differs from the actual number (in either
    /// direction), e.g. a count answer that does not equal the grep total.
    SubsetMatch {
        claimed: usize,
        actual: usize,
    },
    /// Answer contains claims not in ground truth (false positives), while
    /// every ground-truth item was claimed.
    HasFalsePositives {
        false_positives: Vec<(usize, String)>,
    },
    /// Answer is missing ground-truth items (false negatives), while every
    /// claim is in ground truth.
    HasFalseNegatives {
        false_negatives: Vec<(usize, String)>,
    },
    /// Answer and ground truth disagree in both directions: there are false
    /// positives AND false negatives (the two may still partially overlap).
    Mismatch,
    /// Verification was not possible (no pattern could be inferred, the grep
    /// itself failed, or the answer format is unsupported); see `reason`.
    CannotVerify {
        reason: String,
    },
}
53
/// Oracle that re-runs regex searches over a source file so that
/// pattern-match answers can be checked against real results.
pub struct GrepOracle {
    /// The complete, unmodified source text.
    source: String,
    /// The source text split into lines; element `i` is line number `i + 1`.
    source_lines: Vec<String>,
}
61
62impl GrepOracle {
63    /// Create a new grep oracle for the given source file content.
64    pub fn new(source: String) -> Self {
65        let source_lines = source.lines().map(|s| s.to_string()).collect();
66        Self {
67            source,
68            source_lines,
69        }
70    }
71
72    /// Classify the query type based on keywords.
73    pub fn classify_query(query: &str) -> QueryType {
74        let lower = query.to_lowercase();
75        
76        // Pattern-match indicators
77        if lower.contains("find all")
78            || lower.contains("list all")
79            || lower.contains("search for")
80            || lower.contains("grep")
81            || lower.contains("lines matching")
82            || lower.contains("occurrences of")
83            || lower.contains("count")
84        {
85            return QueryType::PatternMatch;
86        }
87        
88        // Structural indicators
89        if lower.contains("signature")
90            || lower.contains("parameters")
91            || lower.contains("return type")
92            || lower.contains("fields of")
93            || lower.contains("implements")
94            || lower.contains("trait")
95        {
96            return QueryType::Structural;
97        }
98        
99        QueryType::Semantic
100    }
101
102    /// Infer a grep pattern from the query string.
103    ///
104    /// Returns None if the pattern cannot be reliably inferred.
105    pub fn infer_pattern(query: &str) -> Option<String> {
106        let lower = query.to_lowercase();
107        
108        // Common Rust patterns
109        let patterns = [
110            // Async functions
111            (r"(?i)find\s+all\s+async\s+functions?", r"\basync\s+fn\b"),
112            (r"(?i)list\s+async\s+functions?", r"\basync\s+fn\b"),
113            (r"(?i)async\s+functions?", r"\basync\s+fn\b"),
114            
115            // Public functions
116            (r"(?i)public\s+functions?", r"\bpub\s+fn\b"),
117            (r"(?i)pub\s+functions?", r"\bpub\s+fn\b"),
118            
119            // Structs
120            (r"(?i)find\s+all\s+structs?", r"\bstruct\b"),
121            (r"(?i)list\s+structs?", r"\bstruct\b"),
122            (r"(?i)all\s+structs?", r"\bstruct\b"),
123            
124            // Enums
125            (r"(?i)find\s+all\s+enums?", r"\benum\b"),
126            (r"(?i)list\s+enums?", r"\benum\b"),
127            
128            // Traits
129            (r"(?i)find\s+all\s+traits?", r"\btrait\b"),
130            (r"(?i)list\s+traits?", r"\btrait\b"),
131            
132            // Impls
133            (r"(?i)find\s+all\s+impls?", r"\bimpl\b"),
134            (r"(?i)implementations?", r"\bimpl\b"),
135            
136            // Error handling
137            (r"(?i)error\s+handling", r"Result|anyhow|Error|\?"),
138            (r"(?i)unwrap\s+calls?", r"\.unwrap\(\)"),
139            (r"(?i)expect\s+calls?", r"\.expect\("),
140            
141            // Imports
142            (r"(?i)use\s+statements?", r"\buse\b"),
143            (r"(?i)imports?", r"\buse\b"),
144            
145            // Tests
146            (r"(?i)test\s+functions?", r"#\[test\]"),
147            (r"(?i)async\s+tests?", r"#\[tokio::test\]"),
148            
149            // Comments
150            (r"(?i)todo\s+comments?", r"TODO|FIXME|XXX"),
151            (r"(?i)comments?", r"//|/\*"),
152            
153            // Macros
154            (r"(?i)macro\s+calls?", r"[a-zA-Z_]+!"),
155            (r"(?i)println!?", r"println!"),
156            
157            // String literals
158            (r"(?i)string\s+literals?", r#""[^"]*""#),
159        ];
160
161        for (pattern_re, grep_pattern) in patterns {
162            if let Ok(re) = Regex::new(pattern_re) {
163                if re.is_match(&lower) {
164                    return Some(grep_pattern.to_string());
165                }
166            }
167        }
168        
169        None
170    }
171
172    /// Run grep on the source and return matches with line numbers.
173    pub fn grep(&self, pattern: &str) -> Result<Vec<(usize, String)>> {
174        let re = Regex::new(pattern)?;
175        
176        let matches: Vec<(usize, String)> = self
177            .source_lines
178            .iter()
179            .enumerate()
180            .filter(|(_, line)| re.is_match(line))
181            .map(|(i, line)| (i + 1, line.clone())) // 1-indexed
182            .collect();
183        
184        Ok(matches)
185    }
186
187    /// Verify a FINAL() answer against ground truth.
188    ///
189    /// Takes the claimed answer and the original query, infers the pattern,
190    /// runs grep, and compares results.
191    pub fn verify(&self, answer: &str, query: &str) -> GrepVerification {
192        let format = FinalAnswerFormat::parse(answer);
193        
194        // Infer the grep pattern from the query
195        let pattern = match Self::infer_pattern(query) {
196            Some(p) => p,
197            None => {
198                // Try to extract pattern from answer if it looks like a count
199                if let FinalAnswerFormat::CountResult { count: _ } = format {
200                    // For count results, we still need a pattern
201                    return GrepVerification::CannotVerify {
202                        reason: "Could not infer grep pattern from query".to_string(),
203                    };
204                }
205                return GrepVerification::CannotVerify {
206                    reason: "Could not infer grep pattern from query".to_string(),
207                };
208            }
209        };
210
211        // Get ground truth
212        let ground_truth = match self.grep(&pattern) {
213            Ok(m) => m,
214            Err(e) => {
215                return GrepVerification::CannotVerify {
216                    reason: format!("Grep failed: {}", e),
217                }
218            }
219        };
220
221        // Compare based on answer format
222        match format {
223            FinalAnswerFormat::LineNumberedMatches { matches: claimed } => {
224                self.verify_matches(&claimed, &ground_truth)
225            }
226            FinalAnswerFormat::CountResult { count: claimed_count } => {
227                let actual_count = ground_truth.len();
228                if claimed_count == actual_count {
229                    GrepVerification::ExactMatch
230                } else {
231                    GrepVerification::SubsetMatch {
232                        claimed: claimed_count,
233                        actual: actual_count,
234                    }
235                }
236            }
237            FinalAnswerFormat::StructuredData { .. } => {
238                GrepVerification::CannotVerify {
239                    reason: "Structured data not supported by grep oracle".to_string(),
240                }
241            }
242            FinalAnswerFormat::FreeFormText { text } => {
243                // Try to extract line numbers from free-form text
244                if let Some(extracted) = self.extract_line_numbers_from_text(&text) {
245                    self.verify_matches(&extracted, &ground_truth)
246                } else {
247                    GrepVerification::CannotVerify {
248                        reason: "Could not extract line numbers from free-form text".to_string(),
249                    }
250                }
251            }
252        }
253    }
254
255    /// Verify claimed matches against ground truth (public for validator use).
256    pub fn verify_matches(
257        &self,
258        claimed: &[(usize, String)],
259        ground_truth: &[(usize, String)],
260    ) -> GrepVerification {
261        let claimed_set: HashSet<(usize, String)> = claimed.iter().cloned().collect();
262        let truth_set: HashSet<(usize, String)> = ground_truth.iter().cloned().collect();
263
264        // Check for exact match (including order)
265        if claimed == ground_truth {
266            return GrepVerification::ExactMatch;
267        }
268
269        // Check for unordered match
270        if claimed_set == truth_set {
271            return GrepVerification::UnorderedMatch;
272        }
273
274        // Check for false positives (claimed but not in ground truth)
275        let false_positives: Vec<_> = claimed
276            .iter()
277            .filter(|item| !truth_set.contains(item))
278            .cloned()
279            .collect();
280
281        // Check for false negatives (in ground truth but not claimed)
282        let false_negatives: Vec<_> = ground_truth
283            .iter()
284            .filter(|item| !claimed_set.contains(item))
285            .cloned()
286            .collect();
287
288        if !false_positives.is_empty() && !false_negatives.is_empty() {
289            // Both false positives and negatives - complete mismatch
290            GrepVerification::Mismatch
291        } else if !false_positives.is_empty() {
292            GrepVerification::HasFalsePositives { false_positives }
293        } else if !false_negatives.is_empty() {
294            GrepVerification::HasFalseNegatives { false_negatives }
295        } else {
296            // Should not reach here, but treat as mismatch
297            GrepVerification::Mismatch
298        }
299    }
300
301    /// Try to extract line numbers from free-form text.
302    ///
303    /// Looks for patterns like "line 42", "L42:", "lines 10-20", etc.
304    fn extract_line_numbers_from_text(&self, text: &str) -> Option<Vec<(usize, String)>> {
305        let mut results = Vec::new();
306        
307        // Pattern: "line 42: text" or "L42: text" or "42: text"
308        let line_re = Regex::new(r"(?i)(?:line\s+|L)?(\d+):\s*(.+)").ok()?;
309        
310        for line in text.lines() {
311            if let Some(cap) = line_re.captures(line) {
312                if let (Some(num), Some(content)) = (cap.get(1), cap.get(2)) {
313                    if let Ok(line_num) = num.as_str().parse::<usize>() {
314                        results.push((line_num, content.as_str().trim().to_string()));
315                    }
316                }
317            }
318        }
319        
320        if results.is_empty() {
321            None
322        } else {
323            Some(results)
324        }
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample source. Line 1 is the empty line right after the opening `r#"`,
    /// so `process` is on line 5 and `parse` on line 10.
    fn sample_rust_code() -> String {
        r#"
use anyhow::Result;

/// Process data
pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data)
}

async fn parse(input: &str) -> Result<String> {
    Ok(input.to_uppercase())
}

pub struct Config {
    pub debug: bool,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false }
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_process() {
        assert!(true);
    }
}
"#.to_string()
    }

    #[test]
    fn classify_pattern_match_query() {
        assert_eq!(
            GrepOracle::classify_query("Find all async functions"),
            QueryType::PatternMatch
        );
        assert_eq!(
            GrepOracle::classify_query("Count occurrences of TODO"),
            QueryType::PatternMatch
        );
    }

    #[test]
    fn infer_async_pattern() {
        let pattern = GrepOracle::infer_pattern("Find all async functions");
        assert_eq!(pattern, Some(r"\basync\s+fn\b".to_string()));
    }

    #[test]
    fn infer_pub_pattern() {
        let pattern = GrepOracle::infer_pattern("List all public functions");
        assert_eq!(pattern, Some(r"\bpub\s+fn\b".to_string()));
    }

    #[test]
    fn grep_finds_matches() {
        let oracle = GrepOracle::new(sample_rust_code());
        let matches = oracle.grep(r"\basync\s+fn\b").unwrap();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn grep_reports_one_indexed_lines() {
        let oracle = GrepOracle::new(sample_rust_code());
        let matches = oracle.grep(r"\basync\s+fn\b").unwrap();
        assert_eq!(
            matches,
            vec![
                (
                    5,
                    "pub async fn process(input: &str) -> Result<String> {".to_string()
                ),
                (
                    10,
                    "async fn parse(input: &str) -> Result<String> {".to_string()
                ),
            ]
        );
    }

    #[test]
    fn verify_exact_match() {
        let oracle = GrepOracle::new(sample_rust_code());
        let answer = "5:pub async fn process(input: &str) -> Result<String> {\n10:async fn parse(input: &str) -> Result<String> {";
        let result = oracle.verify(answer, "Find all async functions");
        // The exact variant depends on how FinalAnswerFormat::parse reads the
        // answer; any agreeing result is accepted.
        match result {
            GrepVerification::ExactMatch
            | GrepVerification::UnorderedMatch
            | GrepVerification::SubsetMatch { .. } => (),
            _ => panic!("Expected match, got {:?}", result),
        }
    }

    #[test]
    fn verify_matches_reports_missing_lines() {
        let oracle = GrepOracle::new(sample_rust_code());
        let truth = oracle.grep(r"\basync\s+fn\b").unwrap();
        let claimed = vec![truth[0].clone()];
        assert_eq!(
            oracle.verify_matches(&claimed, &truth),
            GrepVerification::HasFalseNegatives {
                false_negatives: vec![truth[1].clone()],
            }
        );
    }

    #[test]
    fn verify_count_result() {
        let oracle = GrepOracle::new(sample_rust_code());
        let answer = "Found 2 async functions";
        let result = oracle.verify(answer, "Count async functions");
        assert_eq!(result, GrepVerification::ExactMatch);
    }
}