codetether_agent/rlm/oracle/
tree_sitter_oracle.rs1use anyhow::{Result, anyhow};
23use std::collections::HashMap;
24use streaming_iterator::StreamingIterator;
25
26use super::QueryType;
27
28pub struct TreeSitterOracle {
30 source: String,
32 tree: Option<tree_sitter::Tree>,
34 parser: Option<tree_sitter::Parser>,
36}
37
/// Outcome of running a single tree-sitter query over the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstQueryResult {
    /// The query string that produced these matches.
    pub query_type: String,
    /// One entry per pattern match, in the order the query cursor reported them.
    pub matches: Vec<AstMatch>,
}

/// A single pattern match produced by a tree-sitter query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstMatch {
    /// 1-based line of the capture that supplied `text` (1 when the match has no captures).
    pub line: usize,
    /// 1-based column of the capture that supplied `text` (1 when the match has no captures).
    pub column: usize,
    /// Capture name -> captured source text.
    pub captures: HashMap<String, String>,
    /// Text of the first capture whose text is non-empty, or an empty string if there is none.
    pub text: String,
}
59
/// Verdict produced when an answer is checked against the parsed source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeSitterVerification {
    /// The names claimed by the answer are exactly the names found in the source.
    ExactMatch,
    /// NOTE(review): presumably "same items, different order" — nothing in this file
    /// constructs this variant; confirm against the other oracles.
    UnorderedMatch,
    /// Every claimed name exists in the source, but the source contains more.
    SubsetMatch {
        /// Number of names claimed by the answer.
        claimed: usize,
        /// Number of names found in the source.
        actual: usize,
    },
    /// At least one claim could not be confirmed; `errors` describes each failure.
    HasErrors { errors: Vec<String> },
    /// NOTE(review): presumably a complete disagreement — nothing in this file
    /// constructs this variant; confirm against the other oracles.
    Mismatch,
    /// Verification was not possible; `reason` explains why.
    CannotVerify { reason: String },
}
76
77impl TreeSitterOracle {
78 pub fn new(source: String) -> Self {
80 Self {
81 source,
82 tree: None,
83 parser: None,
84 }
85 }
86
87 fn ensure_parser(&mut self) -> Result<()> {
89 if self.parser.is_some() {
90 return Ok(());
91 }
92
93 let mut parser = tree_sitter::Parser::new();
94 parser.set_language(&tree_sitter_rust::LANGUAGE.into())?;
95 self.parser = Some(parser);
96 Ok(())
97 }
98
99 fn parse(&mut self) -> Result<&tree_sitter::Tree> {
101 self.ensure_parser()?;
102
103 if self.tree.is_none() {
104 let parser = self
105 .parser
106 .as_mut()
107 .ok_or_else(|| anyhow!("Parser not initialized"))?;
108 let tree = parser
109 .parse(&self.source, None)
110 .ok_or_else(|| anyhow!("Failed to parse source"))?;
111 self.tree = Some(tree);
112 }
113
114 Ok(self.tree.as_ref().unwrap())
115 }
116
117 pub fn query(&mut self, query_str: &str) -> Result<AstQueryResult> {
124 self.parse()?;
125 let tree = self.tree.as_ref().unwrap();
126 let root = tree.root_node();
127
128 let query = tree_sitter::Query::new(&tree_sitter_rust::LANGUAGE.into(), query_str)?;
129 let mut cursor = tree_sitter::QueryCursor::new();
130
131 let source_bytes = self.source.as_bytes();
132 let mut results = Vec::new();
133
134 let mut matches = cursor.matches(&query, root, source_bytes);
135 while let Some(match_) = matches.next() {
136 let mut captures = HashMap::new();
137 let mut text = String::new();
138 let mut line = 1;
139 let mut column = 1;
140
141 for capture in match_.captures {
142 let node = capture.node;
143 let capture_name = query.capture_names()[capture.index as usize].to_string();
144 let capture_text = node.utf8_text(source_bytes)?.to_string();
145
146 captures.insert(capture_name, capture_text.clone());
147
148 if text.is_empty() {
149 text = capture_text;
150 line = node.start_position().row + 1;
151 column = node.start_position().column + 1;
152 }
153 }
154
155 results.push(AstMatch {
156 line,
157 column,
158 captures,
159 text,
160 });
161 }
162
163 Ok(AstQueryResult {
164 query_type: query_str.to_string(),
165 matches: results,
166 })
167 }
168
169 pub fn get_functions(&mut self) -> Result<Vec<FunctionSignature>> {
171 let result = self.query(
172 r#"
173 (function_item
174 name: (identifier) @name
175 parameters: (parameters) @params
176 return_type: (_)? @return_type)
177 "#,
178 )?;
179
180 let mut functions = Vec::new();
181 for m in result.matches {
182 let name = m.captures.get("name").cloned().unwrap_or_default();
183 let params = m.captures.get("params").cloned().unwrap_or_default();
184 let return_type = m.captures.get("return_type").cloned();
185
186 functions.push(FunctionSignature {
187 name,
188 params,
189 return_type,
190 line: m.line,
191 });
192 }
193
194 Ok(functions)
195 }
196
197 pub fn get_structs(&mut self) -> Result<Vec<StructDefinition>> {
199 let result = self.query(
200 r#"
201 (struct_item
202 name: (type_identifier) @name
203 body: (field_declaration_list)? @body)
204 "#,
205 )?;
206
207 let mut structs = Vec::new();
208 for m in result.matches {
209 let name = m.captures.get("name").cloned().unwrap_or_default();
210 let body = m.captures.get("body").cloned().unwrap_or_default();
211
212 let fields = self.extract_struct_fields(&body)?;
214
215 structs.push(StructDefinition {
216 name,
217 fields,
218 line: m.line,
219 });
220 }
221
222 Ok(structs)
223 }
224
225 fn extract_struct_fields(&self, body: &str) -> Result<Vec<String>> {
227 let mut fields = Vec::new();
228
229 let re = regex::Regex::new(r"(?:pub\s+)?(\w+)\s*:")?;
231 for cap in re.captures_iter(body) {
232 if let Some(name) = cap.get(1) {
233 fields.push(name.as_str().to_string());
234 }
235 }
236
237 Ok(fields)
238 }
239
240 pub fn get_enums(&mut self) -> Result<Vec<EnumDefinition>> {
242 let result = self.query(
243 r#"
244 (enum_item
245 name: (type_identifier) @name
246 body: (enum_variant_list)? @body)
247 "#,
248 )?;
249
250 let mut enums = Vec::new();
251 for m in result.matches {
252 let name = m.captures.get("name").cloned().unwrap_or_default();
253 let body = m.captures.get("body").cloned().unwrap_or_default();
254
255 let variants = self.extract_enum_variants(&body)?;
257
258 enums.push(EnumDefinition {
259 name,
260 variants,
261 line: m.line,
262 });
263 }
264
265 Ok(enums)
266 }
267
268 fn extract_enum_variants(&self, body: &str) -> Result<Vec<String>> {
270 let mut variants = Vec::new();
271
272 let re = regex::Regex::new(r"(\w+)\s*(?:,|=|\{|\()")?;
273 for cap in re.captures_iter(body) {
274 if let Some(name) = cap.get(1) {
275 let name_str = name.as_str();
276 if !["pub", "fn", "struct", "enum", "impl", "trait"].contains(&name_str) {
278 variants.push(name_str.to_string());
279 }
280 }
281 }
282
283 Ok(variants)
284 }
285
286 pub fn get_impls(&mut self) -> Result<Vec<ImplDefinition>> {
288 let result = self.query(
289 r#"
290 [
291 (impl_item
292 type: (type_identifier) @type
293 trait: (type_identifier)? @trait
294 body: (declaration_list)? @body)
295 (impl_item
296 for: (type_identifier) @for
297 trait: (type_identifier) @trait
298 body: (declaration_list)? @body)
299 ]
300 "#,
301 )?;
302
303 let mut impls = Vec::new();
304 for m in result.matches {
305 let type_name = m
306 .captures
307 .get("type")
308 .or_else(|| m.captures.get("for"))
309 .cloned()
310 .unwrap_or_default();
311 let trait_name = m.captures.get("trait").cloned();
312 let body = m.captures.get("body").cloned().unwrap_or_default();
313
314 impls.push(ImplDefinition {
315 type_name,
316 trait_name,
317 method_count: body.matches("fn ").count(),
318 line: m.line,
319 });
320 }
321
322 Ok(impls)
323 }
324
325 pub fn count_error_patterns(&mut self) -> Result<ErrorPatternCounts> {
327 let result_types =
329 self.query(r#"(generic_type type: (type_identifier) @name (#eq? @name "Result"))"#)?;
330
331 let try_operators = self.query(r#"(try_expression)"#)?;
333
334 let unwrap_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "unwrap")))"#)?;
336
337 let expect_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "expect")))"#)?;
339
340 let match_exprs = self.query(r#"(match_expression)"#)?;
342
343 Ok(ErrorPatternCounts {
344 result_types: result_types.matches.len(),
345 try_operators: try_operators.matches.len(),
346 unwrap_calls: unwrap_calls.matches.len(),
347 expect_calls: expect_calls.matches.len(),
348 match_expressions: match_exprs.matches.len(),
349 })
350 }
351
352 pub fn verify(&mut self, answer: &str, query: &str) -> TreeSitterVerification {
354 let query_type = Self::classify_query(query);
355
356 match query_type {
357 QueryType::Structural => {
358 if query.to_lowercase().contains("function") {
360 self.verify_functions(answer)
361 } else if query.to_lowercase().contains("struct") {
362 self.verify_structs(answer)
363 } else if query.to_lowercase().contains("enum") {
364 self.verify_enums(answer)
365 } else if query.to_lowercase().contains("impl") {
366 self.verify_impls(answer)
367 } else {
368 TreeSitterVerification::CannotVerify {
369 reason: "Unknown structural query type".to_string(),
370 }
371 }
372 }
373 _ => TreeSitterVerification::CannotVerify {
374 reason: "Not a structural query".to_string(),
375 },
376 }
377 }
378
379 fn classify_query(query: &str) -> QueryType {
381 let lower = query.to_lowercase();
382
383 if lower.contains("signature")
384 || lower.contains("parameters")
385 || lower.contains("return type")
386 || lower.contains("fields of")
387 || lower.contains("what fields")
388 || lower.contains("struct definition")
389 || lower.contains("enum variants")
390 || lower.contains("implements")
391 || lower.contains("methods")
392 {
393 return QueryType::Structural;
394 }
395
396 QueryType::Semantic
397 }
398
399 fn verify_functions(&mut self, answer: &str) -> TreeSitterVerification {
400 let functions = match self.get_functions() {
401 Ok(f) => f,
402 Err(e) => {
403 return TreeSitterVerification::CannotVerify {
404 reason: format!("Failed to parse functions: {}", e),
405 };
406 }
407 };
408
409 let claimed_names: Vec<String> = answer
411 .lines()
412 .filter_map(|line| {
413 let re = regex::Regex::new(r"\bfn\s+(\w+)").ok()?;
415 re.captures(line)
416 .and_then(|cap| cap.get(1))
417 .map(|m| m.as_str().to_string())
418 })
419 .collect();
420
421 if claimed_names.is_empty() {
422 return TreeSitterVerification::CannotVerify {
423 reason: "Could not extract function names from answer".to_string(),
424 };
425 }
426
427 let actual_names: Vec<String> = functions.iter().map(|f| f.name.clone()).collect();
428 let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
429 let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
430
431 if claimed_set == actual_set {
432 TreeSitterVerification::ExactMatch
433 } else if claimed_set.is_subset(&actual_set) {
434 TreeSitterVerification::SubsetMatch {
435 claimed: claimed_names.len(),
436 actual: actual_names.len(),
437 }
438 } else {
439 let errors = claimed_names
440 .iter()
441 .filter(|name| !actual_set.contains(*name))
442 .map(|name| format!("Function '{}' not found", name))
443 .collect();
444 TreeSitterVerification::HasErrors { errors }
445 }
446 }
447
448 fn verify_structs(&mut self, answer: &str) -> TreeSitterVerification {
449 let structs = match self.get_structs() {
450 Ok(s) => s,
451 Err(e) => {
452 return TreeSitterVerification::CannotVerify {
453 reason: format!("Failed to parse structs: {}", e),
454 };
455 }
456 };
457
458 let claimed_names: Vec<String> = answer
460 .lines()
461 .filter_map(|line| {
462 let re = regex::Regex::new(r"\bstruct\s+(\w+)").ok()?;
463 re.captures(line)
464 .and_then(|cap| cap.get(1))
465 .map(|m| m.as_str().to_string())
466 })
467 .collect();
468
469 if claimed_names.is_empty() {
470 return TreeSitterVerification::CannotVerify {
471 reason: "Could not extract struct names from answer".to_string(),
472 };
473 }
474
475 let actual_names: Vec<String> = structs.iter().map(|s| s.name.clone()).collect();
476 let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
477 let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
478
479 if claimed_set == actual_set {
480 TreeSitterVerification::ExactMatch
481 } else if claimed_set.is_subset(&actual_set) {
482 TreeSitterVerification::SubsetMatch {
483 claimed: claimed_names.len(),
484 actual: actual_names.len(),
485 }
486 } else {
487 let errors = claimed_names
488 .iter()
489 .filter(|name| !actual_set.contains(*name))
490 .map(|name| format!("Struct '{}' not found", name))
491 .collect();
492 TreeSitterVerification::HasErrors { errors }
493 }
494 }
495
496 fn verify_enums(&mut self, _answer: &str) -> TreeSitterVerification {
497 TreeSitterVerification::CannotVerify {
499 reason: "Enum verification not yet implemented".to_string(),
500 }
501 }
502
503 fn verify_impls(&mut self, _answer: &str) -> TreeSitterVerification {
504 TreeSitterVerification::CannotVerify {
506 reason: "Impl verification not yet implemented".to_string(),
507 }
508 }
509}
510
/// Signature details of one function found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Function name.
    pub name: String,
    /// Raw text of the parameter list.
    pub params: String,
    /// Raw text of the declared return type, or `None` when none is declared.
    pub return_type: Option<String>,
    /// 1-based line reported for the match.
    pub line: usize,
}
519
/// One struct found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructDefinition {
    /// Struct name.
    pub name: String,
    /// Field names extracted from the struct body.
    pub fields: Vec<String>,
    /// 1-based line reported for the match.
    pub line: usize,
}
527
/// One enum found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumDefinition {
    /// Enum name.
    pub name: String,
    /// Variant names extracted from the enum body.
    pub variants: Vec<String>,
    /// 1-based line reported for the match.
    pub line: usize,
}
535
/// One `impl` block found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImplDefinition {
    /// Name of the implementing type.
    pub type_name: String,
    /// Name of the implemented trait, or `None` when no trait was captured.
    pub trait_name: Option<String>,
    /// Number of `"fn "` occurrences in the body text (a textual estimate).
    pub method_count: usize,
    /// 1-based line reported for the match.
    pub line: usize,
}
544
/// Number of matches found for each error-handling construct.
///
/// `Default` yields all-zero counts.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorPatternCounts {
    /// Generic types whose type name is `Result`.
    pub result_types: usize,
    /// `?` (try) expressions.
    pub try_operators: usize,
    /// Method calls named `unwrap`.
    pub unwrap_calls: usize,
    /// Method calls named `expect`.
    pub expect_calls: usize,
    /// `match` expressions.
    pub match_expressions: usize,
}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust program containing one struct, one inherent impl with two methods,
    /// two free functions and one enum.
    fn sample_rust_code() -> String {
        r#"
use anyhow::Result;

pub struct Config {
    pub debug: bool,
    pub timeout: u64,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false, timeout: 30 }
    }

    pub fn with_debug(mut self) -> Self {
        self.debug = true;
        self
    }
}

pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data.to_uppercase())
}

fn parse(input: &str) -> Result<String> {
    if input.is_empty() {
        return Err(anyhow!("empty input"));
    }
    Ok(input.to_string())
}

enum Status {
    Active,
    Inactive,
    Pending,
}
"#
        .to_string()
    }

    #[test]
    fn get_functions_finds_all() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let functions = oracle.get_functions().unwrap();
        assert!(functions.len() >= 3);

        let names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect();
        assert!(names.contains(&"new"));
        assert!(names.contains(&"with_debug"));
        assert!(names.contains(&"process"));
        assert!(names.contains(&"parse"));
    }

    #[test]
    fn get_structs_finds_all() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let structs = oracle.get_structs().unwrap();
        assert!(!structs.is_empty());

        let config_struct = structs.iter().find(|s| s.name == "Config").unwrap();
        assert!(config_struct.fields.contains(&"debug".to_string()));
        assert!(config_struct.fields.contains(&"timeout".to_string()));
    }

    #[test]
    fn get_enums_finds_all() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let enums = oracle.get_enums().unwrap();
        assert!(!enums.is_empty());

        let status_enum = enums.iter().find(|e| e.name == "Status").unwrap();
        assert!(status_enum.variants.contains(&"Active".to_string()));
        assert!(status_enum.variants.contains(&"Inactive".to_string()));
    }

    #[test]
    fn count_error_patterns() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let counts = oracle.count_error_patterns().unwrap();

        // The sample mentions `Result<String>` twice and uses `?` once.
        assert!(counts.result_types >= 2);
        assert!(counts.try_operators >= 1);
    }
}