codetether_agent/rlm/oracle/
grep_oracle.rs1use anyhow::Result;
21use regex::Regex;
22use std::collections::HashSet;
23
24use super::{FinalAnswerFormat, QueryType};
25
/// Outcome of checking a claimed grep-style answer against the lines that
/// actually match in the source.
///
/// `Eq` is derived alongside `PartialEq` because every payload (`usize`,
/// `String`, `Vec<(usize, String)>`) has total equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepVerification {
    /// The claim equals the ground truth: same entries in the same order, or,
    /// for count answers, the same number of matches.
    ExactMatch,
    /// The claim contains the same set of `(line number, content)` entries as
    /// the ground truth, but the sequences are not equal (e.g. different order).
    UnorderedMatch,
    /// A claimed match count that differs from the real one.
    ///
    /// NOTE(review): `GrepOracle::verify` produces this for any count
    /// mismatch, including `claimed > actual`, despite the "subset" name.
    SubsetMatch {
        /// Number of matches stated in the answer.
        claimed: usize,
        /// Number of lines that really match.
        actual: usize,
    },
    /// Every real match was reported, but the claim also lists entries that do
    /// not match.
    HasFalsePositives {
        /// Claimed `(line number, content)` entries absent from the ground truth.
        false_positives: Vec<(usize, String)>,
    },
    /// Everything claimed is a real match, but some real matches are missing.
    HasFalseNegatives {
        /// Ground-truth `(line number, content)` entries absent from the claim.
        false_negatives: Vec<(usize, String)>,
    },
    /// The claim both lists non-matching entries and misses real ones.
    Mismatch,
    /// Verification could not be performed at all.
    CannotVerify {
        /// Human-readable explanation of why verification was impossible.
        reason: String,
    },
}
53
/// Verifies grep-style answers against a fixed piece of source text.
#[derive(Debug, Clone)]
pub struct GrepOracle {
    /// The complete source text the oracle was built from.
    source: String,
    /// `source` split into individual lines; index `i` holds line `i + 1`.
    source_lines: Vec<String>,
}
61
62impl GrepOracle {
63 pub fn new(source: String) -> Self {
65 let source_lines = source.lines().map(|s| s.to_string()).collect();
66 Self {
67 source,
68 source_lines,
69 }
70 }
71
72 pub fn classify_query(query: &str) -> QueryType {
74 let lower = query.to_lowercase();
75
76 if lower.contains("find all")
78 || lower.contains("list all")
79 || lower.contains("search for")
80 || lower.contains("grep")
81 || lower.contains("lines matching")
82 || lower.contains("occurrences of")
83 || lower.contains("count")
84 {
85 return QueryType::PatternMatch;
86 }
87
88 if lower.contains("signature")
90 || lower.contains("parameters")
91 || lower.contains("return type")
92 || lower.contains("fields of")
93 || lower.contains("implements")
94 || lower.contains("trait")
95 {
96 return QueryType::Structural;
97 }
98
99 QueryType::Semantic
100 }
101
102 pub fn infer_pattern(query: &str) -> Option<String> {
106 let lower = query.to_lowercase();
107
108 let patterns = [
110 (r"(?i)find\s+all\s+async\s+functions?", r"\basync\s+fn\b"),
112 (r"(?i)list\s+async\s+functions?", r"\basync\s+fn\b"),
113 (r"(?i)async\s+functions?", r"\basync\s+fn\b"),
114
115 (r"(?i)public\s+functions?", r"\bpub\s+fn\b"),
117 (r"(?i)pub\s+functions?", r"\bpub\s+fn\b"),
118
119 (r"(?i)find\s+all\s+structs?", r"\bstruct\b"),
121 (r"(?i)list\s+structs?", r"\bstruct\b"),
122 (r"(?i)all\s+structs?", r"\bstruct\b"),
123
124 (r"(?i)find\s+all\s+enums?", r"\benum\b"),
126 (r"(?i)list\s+enums?", r"\benum\b"),
127
128 (r"(?i)find\s+all\s+traits?", r"\btrait\b"),
130 (r"(?i)list\s+traits?", r"\btrait\b"),
131
132 (r"(?i)find\s+all\s+impls?", r"\bimpl\b"),
134 (r"(?i)implementations?", r"\bimpl\b"),
135
136 (r"(?i)error\s+handling", r"Result|anyhow|Error|\?"),
138 (r"(?i)unwrap\s+calls?", r"\.unwrap\(\)"),
139 (r"(?i)expect\s+calls?", r"\.expect\("),
140
141 (r"(?i)use\s+statements?", r"\buse\b"),
143 (r"(?i)imports?", r"\buse\b"),
144
145 (r"(?i)test\s+functions?", r"#\[test\]"),
147 (r"(?i)async\s+tests?", r"#\[tokio::test\]"),
148
149 (r"(?i)todo\s+comments?", r"TODO|FIXME|XXX"),
151 (r"(?i)comments?", r"//|/\*"),
152
153 (r"(?i)macro\s+calls?", r"[a-zA-Z_]+!"),
155 (r"(?i)println!?", r"println!"),
156
157 (r"(?i)string\s+literals?", r#""[^"]*""#),
159 ];
160
161 for (pattern_re, grep_pattern) in patterns {
162 if let Ok(re) = Regex::new(pattern_re) {
163 if re.is_match(&lower) {
164 return Some(grep_pattern.to_string());
165 }
166 }
167 }
168
169 None
170 }
171
172 pub fn grep(&self, pattern: &str) -> Result<Vec<(usize, String)>> {
174 let re = Regex::new(pattern)?;
175
176 let matches: Vec<(usize, String)> = self
177 .source_lines
178 .iter()
179 .enumerate()
180 .filter(|(_, line)| re.is_match(line))
181 .map(|(i, line)| (i + 1, line.clone())) .collect();
183
184 Ok(matches)
185 }
186
187 pub fn verify(&self, answer: &str, query: &str) -> GrepVerification {
192 let format = FinalAnswerFormat::parse(answer);
193
194 let pattern = match Self::infer_pattern(query) {
196 Some(p) => p,
197 None => {
198 if let FinalAnswerFormat::CountResult { count: _ } = format {
200 return GrepVerification::CannotVerify {
202 reason: "Could not infer grep pattern from query".to_string(),
203 };
204 }
205 return GrepVerification::CannotVerify {
206 reason: "Could not infer grep pattern from query".to_string(),
207 };
208 }
209 };
210
211 let ground_truth = match self.grep(&pattern) {
213 Ok(m) => m,
214 Err(e) => {
215 return GrepVerification::CannotVerify {
216 reason: format!("Grep failed: {}", e),
217 }
218 }
219 };
220
221 match format {
223 FinalAnswerFormat::LineNumberedMatches { matches: claimed } => {
224 self.verify_matches(&claimed, &ground_truth)
225 }
226 FinalAnswerFormat::CountResult { count: claimed_count } => {
227 let actual_count = ground_truth.len();
228 if claimed_count == actual_count {
229 GrepVerification::ExactMatch
230 } else {
231 GrepVerification::SubsetMatch {
232 claimed: claimed_count,
233 actual: actual_count,
234 }
235 }
236 }
237 FinalAnswerFormat::StructuredData { .. } => {
238 GrepVerification::CannotVerify {
239 reason: "Structured data not supported by grep oracle".to_string(),
240 }
241 }
242 FinalAnswerFormat::FreeFormText { text } => {
243 if let Some(extracted) = self.extract_line_numbers_from_text(&text) {
245 self.verify_matches(&extracted, &ground_truth)
246 } else {
247 GrepVerification::CannotVerify {
248 reason: "Could not extract line numbers from free-form text".to_string(),
249 }
250 }
251 }
252 }
253 }
254
255 pub fn verify_matches(
257 &self,
258 claimed: &[(usize, String)],
259 ground_truth: &[(usize, String)],
260 ) -> GrepVerification {
261 let claimed_set: HashSet<(usize, String)> = claimed.iter().cloned().collect();
262 let truth_set: HashSet<(usize, String)> = ground_truth.iter().cloned().collect();
263
264 if claimed == ground_truth {
266 return GrepVerification::ExactMatch;
267 }
268
269 if claimed_set == truth_set {
271 return GrepVerification::UnorderedMatch;
272 }
273
274 let false_positives: Vec<_> = claimed
276 .iter()
277 .filter(|item| !truth_set.contains(item))
278 .cloned()
279 .collect();
280
281 let false_negatives: Vec<_> = ground_truth
283 .iter()
284 .filter(|item| !claimed_set.contains(item))
285 .cloned()
286 .collect();
287
288 if !false_positives.is_empty() && !false_negatives.is_empty() {
289 GrepVerification::Mismatch
291 } else if !false_positives.is_empty() {
292 GrepVerification::HasFalsePositives { false_positives }
293 } else if !false_negatives.is_empty() {
294 GrepVerification::HasFalseNegatives { false_negatives }
295 } else {
296 GrepVerification::Mismatch
298 }
299 }
300
301 fn extract_line_numbers_from_text(&self, text: &str) -> Option<Vec<(usize, String)>> {
305 let mut results = Vec::new();
306
307 let line_re = Regex::new(r"(?i)(?:line\s+|L)?(\d+):\s*(.+)").ok()?;
309
310 for line in text.lines() {
311 if let Some(cap) = line_re.captures(line) {
312 if let (Some(num), Some(content)) = (cap.get(1), cap.get(2)) {
313 if let Ok(line_num) = num.as_str().parse::<usize>() {
314 results.push((line_num, content.as_str().trim().to_string()));
315 }
316 }
317 }
318 }
319
320 if results.is_empty() {
321 None
322 } else {
323 Some(results)
324 }
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust snippet used as the search corpus.
    ///
    /// The raw string starts with a newline, so line 1 is empty and the two
    /// `async fn` declarations sit on lines 5 and 10.
    fn sample_rust_code() -> String {
        r#"
use anyhow::Result;

/// Process data
pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data)
}

async fn parse(input: &str) -> Result<String> {
    Ok(input.to_uppercase())
}

pub struct Config {
    pub debug: bool,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false }
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_process() {
        assert!(true);
    }
}
"#
        .to_string()
    }

    #[test]
    fn classify_pattern_match_query() {
        assert_eq!(
            GrepOracle::classify_query("Find all async functions"),
            QueryType::PatternMatch
        );
        assert_eq!(
            GrepOracle::classify_query("Count occurrences of TODO"),
            QueryType::PatternMatch
        );
    }

    #[test]
    fn infer_async_pattern() {
        let pattern = GrepOracle::infer_pattern("Find all async functions");
        assert_eq!(pattern, Some(r"\basync\s+fn\b".to_string()));
    }

    #[test]
    fn infer_pub_pattern() {
        let pattern = GrepOracle::infer_pattern("List all public functions");
        assert_eq!(pattern, Some(r"\bpub\s+fn\b".to_string()));
    }

    #[test]
    fn grep_finds_matches() {
        let oracle = GrepOracle::new(sample_rust_code());
        let matches = oracle.grep(r"\basync\s+fn\b").unwrap();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn grep_reports_one_based_line_numbers() {
        let oracle = GrepOracle::new(sample_rust_code());
        let matches = oracle.grep(r"\basync\s+fn\b").unwrap();
        let numbers: Vec<usize> = matches.iter().map(|(line_no, _)| *line_no).collect();
        assert_eq!(numbers, vec![5, 10]);
    }

    #[test]
    fn verify_exact_match() {
        let oracle = GrepOracle::new(sample_rust_code());
        // Line numbers must agree with the fixture: the async fns are on
        // lines 5 and 10 (the leading blank line and the doc comment count).
        let answer = "5:pub async fn process(input: &str) -> Result<String> {\n10:async fn parse(input: &str) -> Result<String> {";
        let result = oracle.verify(answer, "Find all async functions");
        match result {
            GrepVerification::ExactMatch
            | GrepVerification::UnorderedMatch
            | GrepVerification::SubsetMatch { .. } => (),
            _ => panic!("Expected match, got {:?}", result),
        }
    }

    #[test]
    fn verify_count_result() {
        let oracle = GrepOracle::new(sample_rust_code());
        let answer = "Found 2 async functions";
        let result = oracle.verify(answer, "Count async functions");
        assert_eq!(result, GrepVerification::ExactMatch);
    }

    #[test]
    fn verify_matches_distinguishes_outcomes() {
        let oracle = GrepOracle::new(String::new());
        let truth = vec![(1, "a".to_string()), (2, "b".to_string())];

        assert_eq!(
            oracle.verify_matches(&truth, &truth),
            GrepVerification::ExactMatch
        );

        let reversed: Vec<(usize, String)> = truth.iter().rev().cloned().collect();
        assert_eq!(
            oracle.verify_matches(&reversed, &truth),
            GrepVerification::UnorderedMatch
        );

        let with_extra = vec![
            (1, "a".to_string()),
            (2, "b".to_string()),
            (3, "c".to_string()),
        ];
        assert_eq!(
            oracle.verify_matches(&with_extra, &truth),
            GrepVerification::HasFalsePositives {
                false_positives: vec![(3, "c".to_string())],
            }
        );

        assert_eq!(
            oracle.verify_matches(&truth[..1], &truth),
            GrepVerification::HasFalseNegatives {
                false_negatives: vec![(2, "b".to_string())],
            }
        );

        let unrelated = vec![(9, "z".to_string())];
        assert_eq!(
            oracle.verify_matches(&unrelated, &truth),
            GrepVerification::Mismatch
        );
    }
}