codetether_agent/rlm/oracle/
mod.rs1mod grep_oracle;
29mod schema;
30mod templates;
31mod tree_sitter_oracle;
32mod validator;
33
34pub use grep_oracle::GrepOracle;
35pub use schema::{AstPayload, AstResult, FinalPayload, GrepMatch, GrepPayload, SemanticPayload};
36pub use templates::{GeneratedQuery, QueryTemplate, TemplateKind};
37pub use tree_sitter_oracle::TreeSitterOracle;
38pub use validator::{OracleResult, TraceValidator, ValidatedTrace, VerificationMethod};
39
/// Broad category of query handled by the oracle layer.
///
/// NOTE(review): the pairing of categories with concrete oracles (pattern
/// matches with `GrepOracle`, structural queries with `TreeSitterOracle`) is
/// inferred from names only — confirm against the validator.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// Text / pattern matching query (presumably grep-style).
    PatternMatch,
    /// Query about code structure (presumably AST-level).
    Structural,
    /// Query whose answer needs semantic judgment rather than matching.
    Semantic,
}
50
51#[derive(Debug, Clone, PartialEq)]
53pub enum FinalAnswerFormat {
54 LineNumberedMatches {
56 matches: Vec<(usize, String)>,
57 },
58 CountResult {
60 count: usize,
61 },
62 StructuredData {
64 data: serde_json::Value,
65 },
66 FreeFormText {
68 text: String,
69 },
70}
71
72impl FinalAnswerFormat {
73 pub fn parse(answer: &str) -> Self {
75 let lines: Vec<&str> = answer.lines().collect();
77 let mut numbered_matches = Vec::new();
78 let mut all_valid = true;
79
80 for line in &lines {
81 let trimmed = line.trim();
83 if let Some(colon_pos) = trimmed.find(':') {
84 let num_part = trimmed[..colon_pos]
85 .trim()
86 .trim_start_matches('L')
87 .trim();
88 if let Ok(line_num) = num_part.parse::<usize>() {
89 let text_part = trimmed[colon_pos + 1..].trim().to_string();
90 numbered_matches.push((line_num, text_part));
91 } else {
92 all_valid = false;
93 break;
94 }
95 } else if !trimmed.is_empty() {
96 all_valid = false;
98 break;
99 }
100 }
101
102 if all_valid && !numbered_matches.is_empty() {
103 return Self::LineNumberedMatches {
104 matches: numbered_matches,
105 };
106 }
107
108 let lower = answer.to_lowercase();
110 if lower.contains("found") || lower.contains("count:") || lower.contains("occurrences") {
111 if let Some(count) = extract_count_from_text(answer) {
113 return Self::CountResult { count };
114 }
115 }
116
117 if answer.trim().starts_with('{') || answer.trim().starts_with('[') {
119 if let Ok(data) = serde_json::from_str::<serde_json::Value>(answer) {
120 return Self::StructuredData { data };
121 }
122 }
123
124 Self::FreeFormText {
126 text: answer.to_string(),
127 }
128 }
129}
130
131fn extract_count_from_text(text: &str) -> Option<usize> {
133 let re = regex::Regex::new(r"(?i)(?:found|count:?\s*)\s*(\d+)|(\d+)\s+(?:functions?|matches?|occurrences?|items?|results?)").ok()?;
135
136 for cap in re.captures_iter(text) {
137 if let Some(m) = cap.get(1) {
139 if let Ok(n) = m.as_str().parse() {
140 return Some(n);
141 }
142 }
143 if let Some(m) = cap.get(2) {
145 if let Ok(n) = m.as_str().parse() {
146 return Some(n);
147 }
148 }
149 }
150
151 None
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_line_numbered_matches() {
        let answer = "42:async fn foo()\n100:pub struct Bar\n";
        let format = FinalAnswerFormat::parse(answer);
        match format {
            FinalAnswerFormat::LineNumberedMatches { matches } => {
                assert_eq!(matches.len(), 2);
                assert_eq!(matches[0], (42, "async fn foo()".to_string()));
                assert_eq!(matches[1], (100, "pub struct Bar".to_string()));
            }
            _ => panic!("Expected LineNumberedMatches"),
        }
    }

    #[test]
    fn parse_line_numbered_matches_with_l_prefix_and_blank_lines() {
        // Leading `L` before the number is accepted; blank lines are ignored.
        let answer = "L7: foo\n\n  L12:bar  \n";
        match FinalAnswerFormat::parse(answer) {
            FinalAnswerFormat::LineNumberedMatches { matches } => {
                assert_eq!(
                    matches,
                    vec![(7, "foo".to_string()), (12, "bar".to_string())]
                );
            }
            other => panic!("Expected LineNumberedMatches, got {:?}", other),
        }
    }

    #[test]
    fn parse_mixed_lines_fall_back_to_free_form_text() {
        // One non-numbered line disqualifies the whole answer from LineNumberedMatches.
        let answer = "42: foo\nnot a numbered line";
        match FinalAnswerFormat::parse(answer) {
            FinalAnswerFormat::FreeFormText { text } => assert_eq!(text, answer),
            other => panic!("Expected FreeFormText, got {:?}", other),
        }
    }

    #[test]
    fn parse_empty_answer_is_free_form_text() {
        match FinalAnswerFormat::parse("") {
            FinalAnswerFormat::FreeFormText { text } => assert!(text.is_empty()),
            other => panic!("Expected FreeFormText, got {:?}", other),
        }
    }

    #[test]
    fn parse_count_result() {
        let answer = "Found 15 async functions";
        let format = FinalAnswerFormat::parse(answer);
        match format {
            FinalAnswerFormat::CountResult { count } => assert_eq!(count, 15),
            _ => panic!("Expected CountResult"),
        }
    }

    #[test]
    fn parse_count_with_trailing_noun() {
        // The number precedes the noun rather than following "found".
        let answer = "There are 7 occurrences of unwrap()";
        match FinalAnswerFormat::parse(answer) {
            FinalAnswerFormat::CountResult { count } => assert_eq!(count, 7),
            other => panic!("Expected CountResult, got {:?}", other),
        }
    }

    #[test]
    fn parse_structured_data() {
        let answer = r#"{"name": "foo", "args": ["x", "y"]}"#;
        let format = FinalAnswerFormat::parse(answer);
        match format {
            FinalAnswerFormat::StructuredData { data } => {
                assert_eq!(data["name"], "foo");
            }
            _ => panic!("Expected StructuredData"),
        }
    }

    #[test]
    fn parse_free_form_text() {
        let answer = "This function handles error cases by using the ? operator";
        let format = FinalAnswerFormat::parse(answer);
        match format {
            FinalAnswerFormat::FreeFormText { text } => {
                assert!(text.contains("error cases"));
            }
            _ => panic!("Expected FreeFormText"),
        }
    }
}