Skip to main content

codetether_agent/rlm/oracle/
grep_oracle.rs

//! Grep-based oracle for pattern-match verification.
//!
//! This oracle verifies FINAL() answers that claim to find patterns in source code
//! by running actual grep operations and comparing results.
//!
//! # Supported Queries
//!
//! - "Find all async functions"
//! - "Find all structs matching pattern X"
//! - "Count occurrences of Y"
//! - "List all error handling patterns"
//!
//! # Verification Strategy
//!
//! 1. Parse the FINAL() answer to extract claimed matches
//! 2. Run actual grep on the source file with the inferred pattern
//! 3. Compare line numbers and content
//! 4. Return verification result
19
use std::collections::HashSet;
use std::sync::LazyLock;

use anyhow::Result;
use regex::Regex;

use super::{FinalAnswerFormat, QueryType};
25
/// Result of grep oracle verification.
///
/// Produced by [`GrepOracle::verify`] and [`GrepOracle::verify_matches`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepVerification {
    /// Answer matches ground truth exactly, in the same order.
    ExactMatch,
    /// Answer contains the same matches as ground truth, but in a different order.
    UnorderedMatch,
    /// Answer is a subset of ground truth (partial).
    ///
    /// `verify` also returns this variant whenever a claimed *count* differs
    /// from the actual count, carrying both numbers.
    SubsetMatch {
        /// Number of matches the answer claimed.
        claimed: usize,
        /// Number of matches grep actually found.
        actual: usize,
    },
    /// Answer contains claims not in ground truth (false positives).
    HasFalsePositives {
        /// Claimed `(line number, content)` pairs absent from ground truth.
        false_positives: Vec<(usize, String)>,
    },
    /// Answer is missing ground truth items (false negatives).
    HasFalseNegatives {
        /// Ground-truth `(line number, content)` pairs the answer did not claim.
        false_negatives: Vec<(usize, String)>,
    },
    /// Answer has both false positives and false negatives.
    Mismatch,
    /// Verification could not be performed: no pattern could be inferred,
    /// grep failed, or the answer format is not supported.
    CannotVerify {
        /// Human-readable explanation of why verification was not possible.
        reason: String,
    },
}
48
/// Grep-based oracle for validating pattern-match queries.
///
/// Holds the source text under inspection together with a pre-split copy of
/// its lines, so repeated greps do not have to re-split the text.
#[derive(Debug, Clone)]
pub struct GrepOracle {
    /// Source code content, kept verbatim as it was handed to the constructor.
    // Not read by any method visible in this file, hence the lint allowance.
    #[allow(dead_code)]
    source: String,
    /// Source code as lines (stored 0-indexed; reported 1-indexed).
    source_lines: Vec<String>,
}
57
58impl GrepOracle {
59    /// Create a new grep oracle for the given source file content.
60    pub fn new(source: String) -> Self {
61        let source_lines = source.lines().map(|s| s.to_string()).collect();
62        Self {
63            source,
64            source_lines,
65        }
66    }
67
68    /// Classify the query type based on keywords.
69    pub fn classify_query(query: &str) -> QueryType {
70        let lower = query.to_lowercase();
71
72        // Pattern-match indicators
73        if lower.contains("find all")
74            || lower.contains("list all")
75            || lower.contains("search for")
76            || lower.contains("grep")
77            || lower.contains("lines matching")
78            || lower.contains("occurrences of")
79            || lower.contains("count")
80        {
81            return QueryType::PatternMatch;
82        }
83
84        // Structural indicators
85        if lower.contains("signature")
86            || lower.contains("parameters")
87            || lower.contains("return type")
88            || lower.contains("fields of")
89            || lower.contains("implements")
90            || lower.contains("trait")
91        {
92            return QueryType::Structural;
93        }
94
95        QueryType::Semantic
96    }
97
98    /// Infer a grep pattern from the query string.
99    ///
100    /// Returns None if the pattern cannot be reliably inferred.
101    pub fn infer_pattern(query: &str) -> Option<String> {
102        let lower = query.to_lowercase();
103
104        // First, try to extract an explicit quoted literal pattern from the query.
105        // Examples:
106        // - Find all occurrences of 'async fn' in file.rs
107        // - grep for "Result<"
108        let quoted_re =
109            Regex::new(r#"(?i)(?:occurrences?\s+of|grep(?:\s+for)?|matching|containing)\s+['"`]([^'"`]+)['"`]"#).ok()?;
110        if let Some(caps) = quoted_re.captures(query)
111            && let Some(m) = caps.get(1)
112        {
113            return Some(regex::escape(m.as_str()));
114        }
115
116        // Fallback: capture any quoted literal in the query.
117        let any_quoted_re = Regex::new(r#"['"`]([^'"`]+)['"`]"#).ok()?;
118        if let Some(caps) = any_quoted_re.captures(query)
119            && let Some(m) = caps.get(1)
120        {
121            return Some(regex::escape(m.as_str()));
122        }
123
124        // Fallback: handle unquoted "occurrences of X in ..." forms.
125        let bare_occurrences_re = Regex::new(r"(?i)occurrences?\s+of\s+(.+?)(?:\s+in\b|$)").ok()?;
126        if let Some(caps) = bare_occurrences_re.captures(query)
127            && let Some(m) = caps.get(1)
128        {
129            let candidate = m.as_str().trim().trim_matches(&['"', '\'', '`'][..]);
130            if !candidate.is_empty() {
131                return Some(regex::escape(candidate));
132            }
133        }
134
135        // Common Rust patterns
136        let patterns = [
137            // Async functions
138            (r"(?i)find\s+all\s+async\s+functions?", r"\basync\s+fn\b"),
139            (r"(?i)list\s+async\s+functions?", r"\basync\s+fn\b"),
140            (r"(?i)async\s+functions?", r"\basync\s+fn\b"),
141            // Public functions
142            (r"(?i)public\s+functions?", r"\bpub\s+fn\b"),
143            (r"(?i)pub\s+functions?", r"\bpub\s+fn\b"),
144            // Structs
145            (r"(?i)find\s+all\s+structs?", r"\bstruct\b"),
146            (r"(?i)list\s+structs?", r"\bstruct\b"),
147            (r"(?i)all\s+structs?", r"\bstruct\b"),
148            // Enums
149            (r"(?i)find\s+all\s+enums?", r"\benum\b"),
150            (r"(?i)list\s+enums?", r"\benum\b"),
151            // Traits
152            (r"(?i)find\s+all\s+traits?", r"\btrait\b"),
153            (r"(?i)list\s+traits?", r"\btrait\b"),
154            // Impls
155            (r"(?i)find\s+all\s+impls?", r"\bimpl\b"),
156            (r"(?i)implementations?", r"\bimpl\b"),
157            // Error handling
158            (r"(?i)error\s+handling", r"Result|anyhow|Error|\?"),
159            (r"(?i)unwrap\s+calls?", r"\.unwrap\(\)"),
160            (r"(?i)expect\s+calls?", r"\.expect\("),
161            // Imports
162            (r"(?i)use\s+statements?", r"\buse\b"),
163            (r"(?i)imports?", r"\buse\b"),
164            // Tests
165            (r"(?i)test\s+functions?", r"#\[test\]"),
166            (r"(?i)async\s+tests?", r"#\[tokio::test\]"),
167            // Comments
168            (r"(?i)todo\s+comments?", r"TODO|FIXME|XXX"),
169            (r"(?i)comments?", r"//|/\*"),
170            // Macros
171            (r"(?i)macro\s+calls?", r"[a-zA-Z_]+!"),
172            (r"(?i)println!?", r"println!"),
173            // String literals
174            (r"(?i)string\s+literals?", r#""[^"]*""#),
175        ];
176
177        for (pattern_re, grep_pattern) in patterns {
178            if let Ok(re) = Regex::new(pattern_re)
179                && re.is_match(&lower)
180            {
181                return Some(grep_pattern.to_string());
182            }
183        }
184
185        None
186    }
187
188    /// Run grep on the source and return matches with line numbers.
189    pub fn grep(&self, pattern: &str) -> Result<Vec<(usize, String)>> {
190        let re = Regex::new(pattern)?;
191
192        let matches: Vec<(usize, String)> = self
193            .source_lines
194            .iter()
195            .enumerate()
196            .filter(|(_, line)| re.is_match(line))
197            .map(|(i, line)| (i + 1, line.clone())) // 1-indexed
198            .collect();
199
200        Ok(matches)
201    }
202
203    /// Verify a FINAL() answer against ground truth.
204    ///
205    /// Takes the claimed answer and the original query, infers the pattern,
206    /// runs grep, and compares results.
207    pub fn verify(&self, answer: &str, query: &str) -> GrepVerification {
208        let format = FinalAnswerFormat::parse(answer);
209
210        // Infer the grep pattern from the query
211        let pattern = Self::infer_pattern(query)
212            .or_else(|| Self::infer_pattern(answer))
213            .ok_or_else(|| "Could not infer grep pattern from query".to_string());
214        let pattern = match pattern {
215            Ok(p) => p,
216            Err(reason) => {
217                return GrepVerification::CannotVerify { reason };
218            }
219        };
220
221        // Get ground truth
222        let ground_truth = match self.grep(&pattern) {
223            Ok(m) => m,
224            Err(e) => {
225                return GrepVerification::CannotVerify {
226                    reason: format!("Grep failed: {}", e),
227                };
228            }
229        };
230
231        // Compare based on answer format
232        match format {
233            FinalAnswerFormat::LineNumberedMatches { matches: claimed } => {
234                self.verify_matches(&claimed, &ground_truth)
235            }
236            FinalAnswerFormat::CountResult {
237                count: claimed_count,
238            } => {
239                let actual_count = ground_truth.len();
240                if claimed_count == actual_count {
241                    GrepVerification::ExactMatch
242                } else {
243                    GrepVerification::SubsetMatch {
244                        claimed: claimed_count,
245                        actual: actual_count,
246                    }
247                }
248            }
249            FinalAnswerFormat::StructuredData { .. } => GrepVerification::CannotVerify {
250                reason: "Structured data not supported by grep oracle".to_string(),
251            },
252            FinalAnswerFormat::FreeFormText { text } => {
253                // Try to extract line numbers from free-form text
254                if let Some(extracted) = self.extract_line_numbers_from_text(&text) {
255                    self.verify_matches(&extracted, &ground_truth)
256                } else {
257                    GrepVerification::CannotVerify {
258                        reason: "Could not extract line numbers from free-form text".to_string(),
259                    }
260                }
261            }
262        }
263    }
264
265    /// Verify claimed matches against ground truth (public for validator use).
266    pub fn verify_matches(
267        &self,
268        claimed: &[(usize, String)],
269        ground_truth: &[(usize, String)],
270    ) -> GrepVerification {
271        let claimed_set: HashSet<(usize, String)> = claimed.iter().cloned().collect();
272        let truth_set: HashSet<(usize, String)> = ground_truth.iter().cloned().collect();
273
274        // Check for exact match (including order)
275        if claimed == ground_truth {
276            return GrepVerification::ExactMatch;
277        }
278
279        // Check for unordered match
280        if claimed_set == truth_set {
281            return GrepVerification::UnorderedMatch;
282        }
283
284        // Check for false positives (claimed but not in ground truth)
285        let false_positives: Vec<_> = claimed
286            .iter()
287            .filter(|item| !truth_set.contains(item))
288            .cloned()
289            .collect();
290
291        // Check for false negatives (in ground truth but not claimed)
292        let false_negatives: Vec<_> = ground_truth
293            .iter()
294            .filter(|item| !claimed_set.contains(item))
295            .cloned()
296            .collect();
297
298        if !false_positives.is_empty() && !false_negatives.is_empty() {
299            // Both false positives and negatives - complete mismatch
300            GrepVerification::Mismatch
301        } else if !false_positives.is_empty() {
302            GrepVerification::HasFalsePositives { false_positives }
303        } else if !false_negatives.is_empty() {
304            GrepVerification::HasFalseNegatives { false_negatives }
305        } else {
306            // Should not reach here, but treat as mismatch
307            GrepVerification::Mismatch
308        }
309    }
310
311    /// Try to extract line numbers from free-form text.
312    ///
313    /// Looks for patterns like "line 42", "L42:", "lines 10-20", etc.
314    fn extract_line_numbers_from_text(&self, text: &str) -> Option<Vec<(usize, String)>> {
315        let mut results = Vec::new();
316
317        // Pattern: "line 42: text" or "L42: text" or "42: text"
318        let line_re = Regex::new(r"(?i)(?:line\s+|L)?(\d+):\s*(.+)").ok()?;
319
320        for line in text.lines() {
321            if let Some(cap) = line_re.captures(line)
322                && let (Some(num), Some(content)) = (cap.get(1), cap.get(2))
323                && let Ok(line_num) = num.as_str().parse::<usize>()
324            {
325                results.push((line_num, content.as_str().trim().to_string()));
326            }
327        }
328
329        if results.is_empty() {
330            None
331        } else {
332            Some(results)
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust snippet used as the grep target: two async functions, one
    /// struct, one impl and one `#[test]` function.
    fn sample_rust_code() -> String {
        String::from(
            r#"
use anyhow::Result;

/// Process data
pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data)
}

async fn parse(input: &str) -> Result<String> {
    Ok(input.to_uppercase())
}

pub struct Config {
    pub debug: bool,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false }
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_process() {
        assert!(true);
    }
}
"#,
        )
    }

    /// "Find all" and "count" phrasing are both pattern-match queries.
    #[test]
    fn classify_pattern_match_query() {
        for query in ["Find all async functions", "Count occurrences of TODO"] {
            assert_eq!(GrepOracle::classify_query(query), QueryType::PatternMatch);
        }
    }

    /// "async functions" maps to the `async fn` regex.
    #[test]
    fn infer_async_pattern() {
        let inferred = GrepOracle::infer_pattern("Find all async functions");
        assert_eq!(inferred, Some(r"\basync\s+fn\b".to_string()));
    }

    /// "public functions" maps to the `pub fn` regex.
    #[test]
    fn infer_pub_pattern() {
        let inferred = GrepOracle::infer_pattern("List all public functions");
        assert_eq!(inferred, Some(r"\bpub\s+fn\b".to_string()));
    }

    /// A quoted literal after "occurrences of" is used verbatim (escaped).
    #[test]
    fn infer_quoted_literal_pattern() {
        let inferred =
            GrepOracle::infer_pattern("Find all occurrences of 'async fn' in src/rlm/repl.rs");
        assert_eq!(inferred, Some(regex::escape("async fn")));
    }

    /// The sample contains exactly two `async fn` lines.
    #[test]
    fn grep_finds_matches() {
        let oracle = GrepOracle::new(sample_rust_code());
        let hits = oracle.grep(r"\basync\s+fn\b").unwrap();
        assert_eq!(hits.len(), 2);
    }

    /// Feeding grep's own output back as the answer must verify as a match.
    #[test]
    fn verify_exact_match() {
        let oracle = GrepOracle::new(sample_rust_code());
        let hits = oracle.grep(r"\basync\s+fn\b").unwrap_or_default();
        let rendered: Vec<String> = hits
            .iter()
            .map(|(line, text)| format!("{line}:{text}"))
            .collect();
        let answer = rendered.join("\n");

        let result = oracle.verify(&answer, "Find all async functions");
        // May not be exact due to whitespace, but should be close
        assert!(
            matches!(
                result,
                GrepVerification::ExactMatch
                    | GrepVerification::UnorderedMatch
                    | GrepVerification::SubsetMatch { .. }
            ),
            "Expected match, got {:?}",
            result
        );
    }

    /// A correct count answer verifies as an exact match.
    #[test]
    fn verify_count_result() {
        let oracle = GrepOracle::new(sample_rust_code());
        let result = oracle.verify("Found 2 async functions", "Count async functions");
        assert_eq!(result, GrepVerification::ExactMatch);
    }
}
436        let oracle = GrepOracle::new(sample_rust_code());
437        let answer = "Found 2 async functions";
438        let result = oracle.verify(answer, "Count async functions");
439        assert_eq!(result, GrepVerification::ExactMatch);
440    }
441}