use anyhow::Result;
use regex::Regex;
use std::collections::HashSet;
use super::{FinalAnswerFormat, QueryType};
/// Outcome of checking an answer against the grep ground truth.
///
/// Line-based variants carry `(line number, line text)` pairs, where the
/// ground-truth line numbers are 1-based.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GrepVerification {
    /// The claim equals the ground truth exactly (same entries in the same
    /// order), or a claimed count equals the actual number of matching lines.
    ExactMatch,
    /// The claim holds the same set of entries as the ground truth, but in a
    /// different order or with repeated entries.
    UnorderedMatch,
    /// A claimed count that differs from the actual number of matching lines.
    ///
    /// Despite the name, this is reported whether `claimed` is smaller or
    /// larger than `actual`.
    SubsetMatch {
        /// Count stated in the answer.
        claimed: usize,
        /// Number of lines the grep actually matched.
        actual: usize,
    },
    /// Every ground-truth entry was claimed, but the claim also contains
    /// entries that are not in the ground truth.
    HasFalsePositives {
        /// Claimed entries absent from the ground truth.
        false_positives: Vec<(usize, String)>,
    },
    /// Every claimed entry is in the ground truth, but some ground-truth
    /// entries were not claimed.
    HasFalseNegatives {
        /// Ground-truth entries missing from the claim.
        false_negatives: Vec<(usize, String)>,
    },
    /// The claim both adds entries and misses entries.
    Mismatch,
    /// The answer or query could not be checked with a grep.
    CannotVerify {
        /// Human-readable explanation of why verification was not possible.
        reason: String,
    },
}
/// Ground-truth oracle that answers pattern-match questions by running
/// regular expressions over a fixed piece of source text.
#[derive(Debug, Clone)]
pub struct GrepOracle {
    /// The complete source text the oracle was built from.
    source: String,
    /// `source` split into lines; index `i` holds line number `i + 1`.
    source_lines: Vec<String>,
}
impl GrepOracle {
    /// Builds an oracle over `source`.
    ///
    /// The text is split once with [`str::lines`], so later greps can report
    /// 1-based line numbers without re-scanning the whole string.
    pub fn new(source: String) -> Self {
        let source_lines = source.lines().map(str::to_string).collect();
        Self {
            source,
            source_lines,
        }
    }

    /// Returns the full source text this oracle was built from.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// Classifies a natural-language query by case-insensitive keyword cues.
    ///
    /// Pattern-match cues are checked first, then structural cues. A query
    /// matching neither is [`QueryType::Semantic`].
    ///
    /// NOTE(review): cues are plain substrings, so e.g. "count" also fires on
    /// "account" and "trait" on "portrait".
    pub fn classify_query(query: &str) -> QueryType {
        const PATTERN_MATCH_CUES: &[&str] = &[
            "find all",
            "list all",
            "search for",
            "grep",
            "lines matching",
            "occurrences of",
            "count",
        ];
        const STRUCTURAL_CUES: &[&str] = &[
            "signature",
            "parameters",
            "return type",
            "fields of",
            "implements",
            "trait",
        ];

        let lower = query.to_lowercase();
        let mentions_any = |cues: &[&str]| cues.iter().any(|&cue| lower.contains(cue));

        if mentions_any(PATTERN_MATCH_CUES) {
            QueryType::PatternMatch
        } else if mentions_any(STRUCTURAL_CUES) {
            QueryType::Structural
        } else {
            QueryType::Semantic
        }
    }

    /// Maps a natural-language query to a grep regex for the construct it names.
    ///
    /// The query is lowercased and tested against an ordered table of cue
    /// regexes. The first matching cue wins, so more specific cues appear
    /// before broader ones (e.g. "todo comments" before "comments").
    ///
    /// Returns `None` when no cue matches the query.
    pub fn infer_pattern(query: &str) -> Option<String> {
        let lower = query.to_lowercase();
        // (cue tested against the lowercased query, grep regex to run over the source)
        let patterns = [
            (r"(?i)find\s+all\s+async\s+functions?", r"\basync\s+fn\b"),
            (r"(?i)list\s+async\s+functions?", r"\basync\s+fn\b"),
            (r"(?i)async\s+functions?", r"\basync\s+fn\b"),
            (r"(?i)public\s+functions?", r"\bpub\s+fn\b"),
            (r"(?i)pub\s+functions?", r"\bpub\s+fn\b"),
            (r"(?i)find\s+all\s+structs?", r"\bstruct\b"),
            (r"(?i)list\s+structs?", r"\bstruct\b"),
            (r"(?i)all\s+structs?", r"\bstruct\b"),
            // Enums, traits and impls get the same "list"/"all" forms as structs,
            // so queries like "list all enums" resolve instead of falling through.
            (r"(?i)find\s+all\s+enums?", r"\benum\b"),
            (r"(?i)list\s+enums?", r"\benum\b"),
            (r"(?i)all\s+enums?", r"\benum\b"),
            (r"(?i)find\s+all\s+traits?", r"\btrait\b"),
            (r"(?i)list\s+traits?", r"\btrait\b"),
            (r"(?i)all\s+traits?", r"\btrait\b"),
            (r"(?i)find\s+all\s+impls?", r"\bimpl\b"),
            (r"(?i)list\s+impls?", r"\bimpl\b"),
            (r"(?i)all\s+impls?", r"\bimpl\b"),
            (r"(?i)implementations?", r"\bimpl\b"),
            (r"(?i)error\s+handling", r"Result|anyhow|Error|\?"),
            (r"(?i)unwrap\s+calls?", r"\.unwrap\(\)"),
            (r"(?i)expect\s+calls?", r"\.expect\("),
            (r"(?i)use\s+statements?", r"\buse\b"),
            (r"(?i)imports?", r"\buse\b"),
            (r"(?i)test\s+functions?", r"#\[test\]"),
            (r"(?i)async\s+tests?", r"#\[tokio::test\]"),
            (r"(?i)todo\s+comments?", r"TODO|FIXME|XXX"),
            (r"(?i)comments?", r"//|/\*"),
            (r"(?i)macro\s+calls?", r"[a-zA-Z_]+!"),
            (r"(?i)println!?", r"println!"),
            (r"(?i)string\s+literals?", r#""[^"]*""#),
        ];

        patterns.iter().find_map(|&(cue, grep_pattern)| {
            // A cue that fails to compile is skipped rather than aborting the search.
            let cue_re = Regex::new(cue).ok()?;
            if cue_re.is_match(&lower) {
                Some(grep_pattern.to_string())
            } else {
                None
            }
        })
    }

    /// Returns every source line matching `pattern`, in source order.
    ///
    /// Each entry is `(1-based line number, full line text)`.
    ///
    /// # Errors
    ///
    /// Returns an error if `pattern` is not a valid regular expression.
    pub fn grep(&self, pattern: &str) -> Result<Vec<(usize, String)>> {
        let re = Regex::new(pattern)?;
        let matches: Vec<(usize, String)> = self
            .source_lines
            .iter()
            .enumerate()
            .filter(|(_, line)| re.is_match(line))
            .map(|(index, line)| (index + 1, line.clone()))
            .collect();
        Ok(matches)
    }

    /// Checks `answer` against the grep ground truth implied by `query`.
    ///
    /// The answer is parsed with `FinalAnswerFormat::parse`, and the grep
    /// pattern is inferred with [`GrepOracle::infer_pattern`]. Then:
    /// - line-numbered matches, and line numbers extracted from free-form
    ///   text, are compared with [`GrepOracle::verify_matches`];
    /// - a count answer yields `ExactMatch` when equal to the number of
    ///   matching lines, otherwise `SubsetMatch { claimed, actual }`;
    /// - structured data, or free-form text without line numbers, yields
    ///   `CannotVerify`.
    ///
    /// `CannotVerify` is also returned when no pattern can be inferred or
    /// the grep fails.
    pub fn verify(&self, answer: &str, query: &str) -> GrepVerification {
        let format = FinalAnswerFormat::parse(answer);
        let pattern = match Self::infer_pattern(query) {
            Some(pattern) => pattern,
            None => {
                return GrepVerification::CannotVerify {
                    reason: "Could not infer grep pattern from query".to_string(),
                }
            }
        };
        let ground_truth = match self.grep(&pattern) {
            Ok(matches) => matches,
            Err(e) => {
                return GrepVerification::CannotVerify {
                    reason: format!("Grep failed: {}", e),
                }
            }
        };
        match format {
            FinalAnswerFormat::LineNumberedMatches { matches: claimed } => {
                self.verify_matches(&claimed, &ground_truth)
            }
            FinalAnswerFormat::CountResult { count: claimed_count } => {
                let actual_count = ground_truth.len();
                if claimed_count == actual_count {
                    GrepVerification::ExactMatch
                } else {
                    GrepVerification::SubsetMatch {
                        claimed: claimed_count,
                        actual: actual_count,
                    }
                }
            }
            FinalAnswerFormat::StructuredData { .. } => GrepVerification::CannotVerify {
                reason: "Structured data not supported by grep oracle".to_string(),
            },
            FinalAnswerFormat::FreeFormText { text } => {
                match self.extract_line_numbers_from_text(&text) {
                    Some(extracted) => self.verify_matches(&extracted, &ground_truth),
                    None => GrepVerification::CannotVerify {
                        reason: "Could not extract line numbers from free-form text"
                            .to_string(),
                    },
                }
            }
        }
    }

    /// Compares claimed `(line, text)` entries against the ground truth.
    ///
    /// Entries are compared on line number plus the text with surrounding
    /// whitespace trimmed. Extracted claims lose their indentation, so a raw
    /// comparison would reject every indented line.
    ///
    /// Returns:
    /// - `ExactMatch` when both sequences agree in order;
    /// - `UnorderedMatch` when they hold the same set of entries;
    /// - `HasFalsePositives` / `HasFalseNegatives` when only one kind of
    ///   discrepancy exists;
    /// - `Mismatch` otherwise.
    ///
    /// Reported entries are the original, untrimmed values.
    pub fn verify_matches(
        &self,
        claimed: &[(usize, String)],
        ground_truth: &[(usize, String)],
    ) -> GrepVerification {
        // Whitespace-insensitive identity of an entry.
        fn key(entry: &(usize, String)) -> (usize, &str) {
            (entry.0, entry.1.trim())
        }

        let claimed_keys: Vec<(usize, &str)> = claimed.iter().map(key).collect();
        let truth_keys: Vec<(usize, &str)> = ground_truth.iter().map(key).collect();
        if claimed_keys == truth_keys {
            return GrepVerification::ExactMatch;
        }

        let claimed_set: HashSet<(usize, &str)> = claimed_keys.iter().copied().collect();
        let truth_set: HashSet<(usize, &str)> = truth_keys.iter().copied().collect();
        if claimed_set == truth_set {
            return GrepVerification::UnorderedMatch;
        }

        let false_positives: Vec<(usize, String)> = claimed
            .iter()
            .filter(|&entry| !truth_set.contains(&key(entry)))
            .cloned()
            .collect();
        let false_negatives: Vec<(usize, String)> = ground_truth
            .iter()
            .filter(|&entry| !claimed_set.contains(&key(entry)))
            .cloned()
            .collect();

        match (false_positives.is_empty(), false_negatives.is_empty()) {
            (false, false) => GrepVerification::Mismatch,
            (false, true) => GrepVerification::HasFalsePositives { false_positives },
            (true, false) => GrepVerification::HasFalseNegatives { false_negatives },
            // Unreachable in practice: differing sets imply at least one side is non-empty.
            (true, true) => GrepVerification::Mismatch,
        }
    }

    /// Pulls `(line number, content)` pairs out of free-form answer text.
    ///
    /// Each text line is searched for `<digits>:` (optionally prefixed by
    /// `line ` or `L`, case-insensitively). The content after the colon is
    /// trimmed. Returns `None` if no line yields a pair.
    fn extract_line_numbers_from_text(&self, text: &str) -> Option<Vec<(usize, String)>> {
        let line_re = Regex::new(r"(?i)(?:line\s+|L)?(\d+):\s*(.+)").ok()?;
        let results: Vec<(usize, String)> = text
            .lines()
            .filter_map(|line| {
                let cap = line_re.captures(line)?;
                // Digit runs too large for usize are skipped rather than failing.
                let line_num = cap.get(1)?.as_str().parse::<usize>().ok()?;
                let content = cap.get(2)?.as_str().trim().to_string();
                Some((line_num, content))
            })
            .collect();
        if results.is_empty() {
            None
        } else {
            Some(results)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust snippet used as the grep corpus by the tests below.
    fn sample_rust_code() -> String {
        String::from(
            r#"
use anyhow::Result;
/// Process data
pub async fn process(input: &str) -> Result<String> {
let data = parse(input)?;
Ok(data)
}
async fn parse(input: &str) -> Result<String> {
Ok(input.to_uppercase())
}
pub struct Config {
pub debug: bool,
}
impl Config {
pub fn new() -> Self {
Self { debug: false }
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_process() {
assert!(true);
}
}
"#,
        )
    }

    #[test]
    fn classify_pattern_match_query() {
        for &query in &["Find all async functions", "Count occurrences of TODO"] {
            assert_eq!(GrepOracle::classify_query(query), QueryType::PatternMatch);
        }
    }

    #[test]
    fn infer_async_pattern() {
        let expected = r"\basync\s+fn\b".to_string();
        assert_eq!(
            GrepOracle::infer_pattern("Find all async functions"),
            Some(expected)
        );
    }

    #[test]
    fn infer_pub_pattern() {
        assert_eq!(
            GrepOracle::infer_pattern("List all public functions").as_deref(),
            Some(r"\bpub\s+fn\b")
        );
    }

    #[test]
    fn grep_finds_matches() {
        let hits = GrepOracle::new(sample_rust_code())
            .grep(r"\basync\s+fn\b")
            .unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn verify_exact_match() {
        let oracle = GrepOracle::new(sample_rust_code());
        let answer = concat!(
            "3:pub async fn process(input: &str) -> Result<String> {\n",
            "8:async fn parse(input: &str) -> Result<String> {"
        );
        let outcome = oracle.verify(answer, "Find all async functions");
        assert!(
            matches!(
                outcome,
                GrepVerification::ExactMatch
                    | GrepVerification::UnorderedMatch
                    | GrepVerification::SubsetMatch { .. }
            ),
            "Expected match, got {:?}",
            outcome
        );
    }

    #[test]
    fn verify_count_result() {
        let outcome = GrepOracle::new(sample_rust_code())
            .verify("Found 2 async functions", "Count async functions");
        assert_eq!(outcome, GrepVerification::ExactMatch);
    }
}