mod grep_oracle;
mod schema;
mod templates;
mod tree_sitter_oracle;
mod validator;
pub use grep_oracle::GrepOracle;
pub use schema::{AstPayload, AstResult, FinalPayload, GrepMatch, GrepPayload, SemanticPayload};
pub use templates::{GeneratedQuery, QueryTemplate, TemplateKind};
pub use tree_sitter_oracle::TreeSitterOracle;
pub use validator::{OracleResult, TraceValidator, ValidatedTrace, VerificationMethod};
/// Broad category of a query.
///
/// NOTE(review): nothing in this file consumes this enum, so the meaning of each
/// variant is inferred from its name and from the sibling modules re-exported above
/// (grep oracle, tree-sitter oracle, semantic payload) — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
/// Presumably a textual / pattern search (grep-like) — TODO confirm.
PatternMatch,
/// Presumably a syntax-tree (AST) level query — TODO confirm.
Structural,
/// Presumably a query that needs semantic understanding of the code — TODO confirm.
Semantic,
}
/// Shape detected for a final answer string.
///
/// Values are produced by [`FinalAnswerFormat::parse`], which classifies a raw
/// answer into exactly one of these variants.
#[derive(Debug, Clone, PartialEq)]
pub enum FinalAnswerFormat {
/// Every non-blank line of the answer had the form `<number>:<text>`
/// (an optional leading `L` before the number is accepted).
LineNumberedMatches {
/// `(line number, text after the colon)` pairs in input order; the text is
/// whitespace-trimmed.
matches: Vec<(usize, String)>,
},
/// The answer stated a count (e.g. `"Found 15 async functions"`).
CountResult {
/// The number extracted from the answer text.
count: usize,
},
/// The answer started with `{` or `[` and parsed as JSON.
StructuredData {
/// The parsed JSON value.
data: serde_json::Value,
},
/// Fallback when no other shape was recognised.
FreeFormText {
/// The answer exactly as it was passed in.
text: String,
},
}
impl FinalAnswerFormat {
    /// Classifies a raw answer string into one of the [`FinalAnswerFormat`] variants.
    ///
    /// The checks run from most to least specific; the first one that succeeds wins:
    ///
    /// 1. **Line-numbered matches** – every non-blank line looks like `42:text` or
    ///    `L42: text`, and there is at least one such line.
    /// 2. **Structured data** – the trimmed answer starts with `{` or `[` and parses
    ///    as JSON.
    /// 3. **Count** – the answer mentions `found`, `count:` or `occurrences`
    ///    (case-insensitive) and a number can be extracted from it.
    /// 4. **Free-form text** – everything else; the input is stored verbatim.
    ///
    /// JSON is tried before the count heuristic so that a payload such as
    /// `{"message": "found 3 matches"}` is kept as structured data instead of being
    /// collapsed into a bare count.
    pub fn parse(answer: &str) -> Self {
        if let Some(matches) = Self::parse_numbered_lines(answer) {
            return Self::LineNumberedMatches { matches };
        }

        let trimmed = answer.trim();
        if trimmed.starts_with('{') || trimmed.starts_with('[') {
            if let Ok(data) = serde_json::from_str::<serde_json::Value>(trimmed) {
                return Self::StructuredData { data };
            }
        }

        // The keyword gate keeps the count regex from firing on arbitrary prose that
        // merely contains a number followed by a noun.
        let lower = answer.to_lowercase();
        if lower.contains("found") || lower.contains("count:") || lower.contains("occurrences") {
            if let Some(count) = extract_count_from_text(answer) {
                return Self::CountResult { count };
            }
        }

        Self::FreeFormText {
            text: answer.to_string(),
        }
    }

    /// Parses `answer` as a list of `<number>:<text>` lines.
    ///
    /// Blank lines are skipped. Returns `None` as soon as a non-blank line has no
    /// colon or its prefix is not a number, and also when no numbered line was found.
    fn parse_numbered_lines(answer: &str) -> Option<Vec<(usize, String)>> {
        let mut matches = Vec::new();
        for line in answer.lines().map(str::trim).filter(|line| !line.is_empty()) {
            let (prefix, text) = line.split_once(':')?;
            // Accept both `42` and `L42` as the line-number prefix.
            let line_num = prefix
                .trim()
                .trim_start_matches('L')
                .trim()
                .parse::<usize>()
                .ok()?;
            matches.push((line_num, text.trim().to_string()));
        }

        if matches.is_empty() {
            None
        } else {
            Some(matches)
        }
    }
}
fn extract_count_from_text(text: &str) -> Option<usize> {
let re = regex::Regex::new(r"(?i)(?:found|count:?\s*)\s*(\d+)|(\d+)\s+(?:functions?|matches?|occurrences?|items?|results?)").ok()?;
for cap in re.captures_iter(text) {
if let Some(m) = cap.get(1) {
if let Ok(n) = m.as_str().parse() {
return Some(n);
}
}
if let Some(m) = cap.get(2) {
if let Ok(n) = m.as_str().parse() {
return Some(n);
}
}
}
None
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`FinalAnswerFormat::parse`]: one case per variant plus the
    //! edge cases of the line-numbered heuristic.

    use super::*;

    #[test]
    fn parse_line_numbered_matches() {
        let answer = "42:async fn foo()\n100:pub struct Bar\n";
        assert_eq!(
            FinalAnswerFormat::parse(answer),
            FinalAnswerFormat::LineNumberedMatches {
                matches: vec![
                    (42, "async fn foo()".to_string()),
                    (100, "pub struct Bar".to_string()),
                ],
            }
        );
    }

    #[test]
    fn parse_l_prefixed_line_numbers_and_blank_lines() {
        // An `L` prefix is accepted and blank lines between matches are ignored.
        let answer = "L42: async fn foo()\n\n  L100 : pub struct Bar  \n";
        assert_eq!(
            FinalAnswerFormat::parse(answer),
            FinalAnswerFormat::LineNumberedMatches {
                matches: vec![
                    (42, "async fn foo()".to_string()),
                    (100, "pub struct Bar".to_string()),
                ],
            }
        );
    }

    #[test]
    fn parse_mixed_lines_fall_back_to_free_form() {
        // A single non-numbered line disqualifies the whole answer.
        let answer = "42:async fn foo()\nnot a numbered line";
        assert_eq!(
            FinalAnswerFormat::parse(answer),
            FinalAnswerFormat::FreeFormText {
                text: answer.to_string(),
            }
        );
    }

    #[test]
    fn parse_count_result() {
        let answer = "Found 15 async functions";
        assert_eq!(
            FinalAnswerFormat::parse(answer),
            FinalAnswerFormat::CountResult { count: 15 }
        );
    }

    #[test]
    fn parse_structured_data() {
        let answer = r#"{"name": "foo", "args": ["x", "y"]}"#;
        match FinalAnswerFormat::parse(answer) {
            FinalAnswerFormat::StructuredData { data } => {
                assert_eq!(data["name"], "foo");
                assert_eq!(data["args"][1], "y");
            }
            other => panic!("Expected StructuredData, got {other:?}"),
        }
    }

    #[test]
    fn parse_structured_data_array() {
        let answer = "[1, 2, 3]";
        match FinalAnswerFormat::parse(answer) {
            FinalAnswerFormat::StructuredData { data } => {
                assert_eq!(data[0], 1);
                assert_eq!(data[2], 3);
            }
            other => panic!("Expected StructuredData, got {other:?}"),
        }
    }

    #[test]
    fn parse_free_form_text() {
        let answer = "This function handles error cases by using the ? operator";
        assert_eq!(
            FinalAnswerFormat::parse(answer),
            FinalAnswerFormat::FreeFormText {
                text: answer.to_string(),
            }
        );
    }

    #[test]
    fn parse_empty_input_is_free_form() {
        assert_eq!(
            FinalAnswerFormat::parse(""),
            FinalAnswerFormat::FreeFormText {
                text: String::new(),
            }
        );
    }
}