// codetether_agent/rlm/oracle/grep_oracle.rs
use std::collections::HashSet;
use std::sync::LazyLock;

use anyhow::Result;
use regex::Regex;

use super::{FinalAnswerFormat, QueryType};

/// Outcome of checking an answer against ground truth obtained by re-running a grep.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepVerification {
    /// The answer agrees with the ground truth exactly: the same matches in the same
    /// order, or the same count.
    ExactMatch,
    /// The claimed matches equal the ground truth as a set, but not as an ordered
    /// sequence (different order, or repeated entries).
    UnorderedMatch,
    /// A claimed match count differs from the actual count.
    ///
    /// Despite the name, this is reported for any count mismatch, whether the claim
    /// is below or above the actual number of matching lines.
    SubsetMatch { claimed: usize, actual: usize },
    /// Every ground-truth match was claimed, but some claimed `(line, text)` entries
    /// are not real matches.
    HasFalsePositives {
        /// Claimed entries that do not appear in the ground truth.
        false_positives: Vec<(usize, String)>,
    },
    /// Every claimed entry is a real match, but some ground-truth matches are missing.
    HasFalseNegatives {
        /// Ground-truth entries that the answer did not claim.
        false_negatives: Vec<(usize, String)>,
    },
    /// The answer both claims non-matches and misses real matches.
    Mismatch,
    /// The oracle could not check the answer (no pattern, bad regex, unsupported
    /// answer format, ...).
    CannotVerify { reason: String },
}

/// Checks grep-style answers against a source text by re-running the search itself.
#[derive(Debug)]
pub struct GrepOracle {
    /// The full source text the oracle was built from.
    // Not read anywhere in this file; the allow silences the dead-code lint while
    // the original text is kept alongside its split form.
    #[allow(dead_code)]
    source: String,
    /// `source` split with `str::lines`; index `i` holds 1-based line `i + 1`.
    source_lines: Vec<String>,
}

58impl GrepOracle {
59 pub fn new(source: String) -> Self {
61 let source_lines = source.lines().map(|s| s.to_string()).collect();
62 Self {
63 source,
64 source_lines,
65 }
66 }
67
68 pub fn classify_query(query: &str) -> QueryType {
70 let lower = query.to_lowercase();
71
72 if lower.contains("find all")
74 || lower.contains("list all")
75 || lower.contains("search for")
76 || lower.contains("grep")
77 || lower.contains("lines matching")
78 || lower.contains("occurrences of")
79 || lower.contains("count")
80 {
81 return QueryType::PatternMatch;
82 }
83
84 if lower.contains("signature")
86 || lower.contains("parameters")
87 || lower.contains("return type")
88 || lower.contains("fields of")
89 || lower.contains("implements")
90 || lower.contains("trait")
91 {
92 return QueryType::Structural;
93 }
94
95 QueryType::Semantic
96 }
97
98 pub fn infer_pattern(query: &str) -> Option<String> {
102 let lower = query.to_lowercase();
103
104 let quoted_re =
109 Regex::new(r#"(?i)(?:occurrences?\s+of|grep(?:\s+for)?|matching|containing)\s+['"`]([^'"`]+)['"`]"#).ok()?;
110 if let Some(caps) = quoted_re.captures(query)
111 && let Some(m) = caps.get(1)
112 {
113 return Some(regex::escape(m.as_str()));
114 }
115
116 let any_quoted_re = Regex::new(r#"['"`]([^'"`]+)['"`]"#).ok()?;
118 if let Some(caps) = any_quoted_re.captures(query)
119 && let Some(m) = caps.get(1)
120 {
121 return Some(regex::escape(m.as_str()));
122 }
123
124 let bare_occurrences_re = Regex::new(r"(?i)occurrences?\s+of\s+(.+?)(?:\s+in\b|$)").ok()?;
126 if let Some(caps) = bare_occurrences_re.captures(query)
127 && let Some(m) = caps.get(1)
128 {
129 let candidate = m.as_str().trim().trim_matches(&['"', '\'', '`'][..]);
130 if !candidate.is_empty() {
131 return Some(regex::escape(candidate));
132 }
133 }
134
135 let patterns = [
137 (r"(?i)find\s+all\s+async\s+functions?", r"\basync\s+fn\b"),
139 (r"(?i)list\s+async\s+functions?", r"\basync\s+fn\b"),
140 (r"(?i)async\s+functions?", r"\basync\s+fn\b"),
141 (r"(?i)public\s+functions?", r"\bpub\s+fn\b"),
143 (r"(?i)pub\s+functions?", r"\bpub\s+fn\b"),
144 (r"(?i)find\s+all\s+structs?", r"\bstruct\b"),
146 (r"(?i)list\s+structs?", r"\bstruct\b"),
147 (r"(?i)all\s+structs?", r"\bstruct\b"),
148 (r"(?i)find\s+all\s+enums?", r"\benum\b"),
150 (r"(?i)list\s+enums?", r"\benum\b"),
151 (r"(?i)find\s+all\s+traits?", r"\btrait\b"),
153 (r"(?i)list\s+traits?", r"\btrait\b"),
154 (r"(?i)find\s+all\s+impls?", r"\bimpl\b"),
156 (r"(?i)implementations?", r"\bimpl\b"),
157 (r"(?i)error\s+handling", r"Result|anyhow|Error|\?"),
159 (r"(?i)unwrap\s+calls?", r"\.unwrap\(\)"),
160 (r"(?i)expect\s+calls?", r"\.expect\("),
161 (r"(?i)use\s+statements?", r"\buse\b"),
163 (r"(?i)imports?", r"\buse\b"),
164 (r"(?i)test\s+functions?", r"#\[test\]"),
166 (r"(?i)async\s+tests?", r"#\[tokio::test\]"),
167 (r"(?i)todo\s+comments?", r"TODO|FIXME|XXX"),
169 (r"(?i)comments?", r"//|/\*"),
170 (r"(?i)macro\s+calls?", r"[a-zA-Z_]+!"),
172 (r"(?i)println!?", r"println!"),
173 (r"(?i)string\s+literals?", r#""[^"]*""#),
175 ];
176
177 for (pattern_re, grep_pattern) in patterns {
178 if let Ok(re) = Regex::new(pattern_re)
179 && re.is_match(&lower)
180 {
181 return Some(grep_pattern.to_string());
182 }
183 }
184
185 None
186 }
187
188 pub fn grep(&self, pattern: &str) -> Result<Vec<(usize, String)>> {
190 let re = Regex::new(pattern)?;
191
192 let matches: Vec<(usize, String)> = self
193 .source_lines
194 .iter()
195 .enumerate()
196 .filter(|(_, line)| re.is_match(line))
197 .map(|(i, line)| (i + 1, line.clone())) .collect();
199
200 Ok(matches)
201 }
202
203 pub fn verify(&self, answer: &str, query: &str) -> GrepVerification {
208 let format = FinalAnswerFormat::parse(answer);
209
210 let pattern = Self::infer_pattern(query)
212 .or_else(|| Self::infer_pattern(answer))
213 .ok_or_else(|| "Could not infer grep pattern from query".to_string());
214 let pattern = match pattern {
215 Ok(p) => p,
216 Err(reason) => {
217 return GrepVerification::CannotVerify { reason };
218 }
219 };
220
221 let ground_truth = match self.grep(&pattern) {
223 Ok(m) => m,
224 Err(e) => {
225 return GrepVerification::CannotVerify {
226 reason: format!("Grep failed: {}", e),
227 };
228 }
229 };
230
231 match format {
233 FinalAnswerFormat::LineNumberedMatches { matches: claimed } => {
234 self.verify_matches(&claimed, &ground_truth)
235 }
236 FinalAnswerFormat::CountResult {
237 count: claimed_count,
238 } => {
239 let actual_count = ground_truth.len();
240 if claimed_count == actual_count {
241 GrepVerification::ExactMatch
242 } else {
243 GrepVerification::SubsetMatch {
244 claimed: claimed_count,
245 actual: actual_count,
246 }
247 }
248 }
249 FinalAnswerFormat::StructuredData { .. } => GrepVerification::CannotVerify {
250 reason: "Structured data not supported by grep oracle".to_string(),
251 },
252 FinalAnswerFormat::FreeFormText { text } => {
253 if let Some(extracted) = self.extract_line_numbers_from_text(&text) {
255 self.verify_matches(&extracted, &ground_truth)
256 } else {
257 GrepVerification::CannotVerify {
258 reason: "Could not extract line numbers from free-form text".to_string(),
259 }
260 }
261 }
262 }
263 }
264
265 pub fn verify_matches(
267 &self,
268 claimed: &[(usize, String)],
269 ground_truth: &[(usize, String)],
270 ) -> GrepVerification {
271 let claimed_set: HashSet<(usize, String)> = claimed.iter().cloned().collect();
272 let truth_set: HashSet<(usize, String)> = ground_truth.iter().cloned().collect();
273
274 if claimed == ground_truth {
276 return GrepVerification::ExactMatch;
277 }
278
279 if claimed_set == truth_set {
281 return GrepVerification::UnorderedMatch;
282 }
283
284 let false_positives: Vec<_> = claimed
286 .iter()
287 .filter(|item| !truth_set.contains(item))
288 .cloned()
289 .collect();
290
291 let false_negatives: Vec<_> = ground_truth
293 .iter()
294 .filter(|item| !claimed_set.contains(item))
295 .cloned()
296 .collect();
297
298 if !false_positives.is_empty() && !false_negatives.is_empty() {
299 GrepVerification::Mismatch
301 } else if !false_positives.is_empty() {
302 GrepVerification::HasFalsePositives { false_positives }
303 } else if !false_negatives.is_empty() {
304 GrepVerification::HasFalseNegatives { false_negatives }
305 } else {
306 GrepVerification::Mismatch
308 }
309 }
310
311 fn extract_line_numbers_from_text(&self, text: &str) -> Option<Vec<(usize, String)>> {
315 let mut results = Vec::new();
316
317 let line_re = Regex::new(r"(?i)(?:line\s+|L)?(\d+):\s*(.+)").ok()?;
319
320 for line in text.lines() {
321 if let Some(cap) = line_re.captures(line)
322 && let (Some(num), Some(content)) = (cap.get(1), cap.get(2))
323 && let Ok(line_num) = num.as_str().parse::<usize>()
324 {
325 results.push((line_num, content.as_str().trim().to_string()));
326 }
327 }
328
329 if results.is_empty() {
330 None
331 } else {
332 Some(results)
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust snippet used as the oracle's source. Its first line is empty
    /// because the raw string starts with a newline.
    fn sample_rust_code() -> String {
        r#"
use anyhow::Result;

/// Process data
pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data)
}

async fn parse(input: &str) -> Result<String> {
    Ok(input.to_uppercase())
}

pub struct Config {
    pub debug: bool,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false }
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_process() {
        assert!(true);
    }
}
"#
        .to_string()
    }

    #[test]
    fn classify_pattern_match_query() {
        assert_eq!(
            GrepOracle::classify_query("Find all async functions"),
            QueryType::PatternMatch
        );
        assert_eq!(
            GrepOracle::classify_query("Count occurrences of TODO"),
            QueryType::PatternMatch
        );
    }

    #[test]
    fn classify_structural_and_semantic_queries() {
        assert_eq!(
            GrepOracle::classify_query("What is the signature of process?"),
            QueryType::Structural
        );
        assert_eq!(
            GrepOracle::classify_query("Explain how the config is loaded"),
            QueryType::Semantic
        );
    }

    #[test]
    fn infer_async_pattern() {
        let pattern = GrepOracle::infer_pattern("Find all async functions");
        assert_eq!(pattern, Some(r"\basync\s+fn\b".to_string()));
    }

    #[test]
    fn infer_pub_pattern() {
        let pattern = GrepOracle::infer_pattern("List all public functions");
        assert_eq!(pattern, Some(r"\bpub\s+fn\b".to_string()));
    }

    #[test]
    fn infer_quoted_literal_pattern() {
        let pattern =
            GrepOracle::infer_pattern("Find all occurrences of 'async fn' in src/rlm/repl.rs");
        assert_eq!(pattern, Some(regex::escape("async fn")));
    }

    #[test]
    fn infer_pattern_returns_none_without_cues() {
        assert_eq!(
            GrepOracle::infer_pattern("Explain how the config is loaded"),
            None
        );
    }

    #[test]
    fn grep_finds_matches() {
        let oracle = GrepOracle::new(sample_rust_code());
        let matches = oracle.grep(r"\basync\s+fn\b").unwrap();
        assert_eq!(matches.len(), 2);
    }

    #[test]
    fn grep_reports_one_based_line_numbers() {
        let oracle = GrepOracle::new(sample_rust_code());
        let matches = oracle.grep(r"^use ").unwrap();
        assert_eq!(matches, vec![(2, "use anyhow::Result;".to_string())]);
    }

    #[test]
    fn grep_rejects_invalid_regex() {
        let oracle = GrepOracle::new(sample_rust_code());
        assert!(oracle.grep("(unclosed").is_err());
    }

    #[test]
    fn verify_matches_classifies_differences() {
        let oracle = GrepOracle::new(String::new());
        let truth = vec![(5, "a".to_string()), (10, "b".to_string())];

        assert_eq!(
            oracle.verify_matches(&truth, &truth),
            GrepVerification::ExactMatch
        );

        let reversed = vec![(10, "b".to_string()), (5, "a".to_string())];
        assert_eq!(
            oracle.verify_matches(&reversed, &truth),
            GrepVerification::UnorderedMatch
        );

        let missing = vec![(5, "a".to_string())];
        assert_eq!(
            oracle.verify_matches(&missing, &truth),
            GrepVerification::HasFalseNegatives {
                false_negatives: vec![(10, "b".to_string())]
            }
        );

        let extra = vec![
            (5, "a".to_string()),
            (10, "b".to_string()),
            (11, "c".to_string()),
        ];
        assert_eq!(
            oracle.verify_matches(&extra, &truth),
            GrepVerification::HasFalsePositives {
                false_positives: vec![(11, "c".to_string())]
            }
        );

        let disjoint = vec![(7, "z".to_string())];
        assert_eq!(
            oracle.verify_matches(&disjoint, &truth),
            GrepVerification::Mismatch
        );
    }

    #[test]
    fn verify_exact_match() {
        let oracle = GrepOracle::new(sample_rust_code());
        let answer = oracle
            .grep(r"\basync\s+fn\b")
            .unwrap_or_default()
            .iter()
            .map(|(line, text)| format!("{line}:{text}"))
            .collect::<Vec<_>>()
            .join("\n");
        let result = oracle.verify(&answer, "Find all async functions");
        match result {
            GrepVerification::ExactMatch
            | GrepVerification::UnorderedMatch
            | GrepVerification::SubsetMatch { .. } => (),
            _ => panic!("Expected match, got {:?}", result),
        }
    }

    #[test]
    fn verify_count_result() {
        let oracle = GrepOracle::new(sample_rust_code());
        let answer = "Found 2 async functions";
        let result = oracle.verify(answer, "Count async functions");
        assert_eq!(result, GrepVerification::ExactMatch);
    }
}