// codetether_agent/rlm/oracle/tree_sitter_oracle.rs
use anyhow::{Result, anyhow};
use std::collections::{HashMap, HashSet};
use streaming_iterator::StreamingIterator;

use super::QueryType;
27
28pub struct TreeSitterOracle {
30 source: String,
32 tree: Option<tree_sitter::Tree>,
34 parser: Option<tree_sitter::Parser>,
36}
37
38#[derive(Debug, Clone, PartialEq)]
40pub struct AstQueryResult {
41 pub query_type: String,
43 pub matches: Vec<AstMatch>,
45}
46
/// A single match of a tree-sitter query pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstMatch {
    /// 1-based line of the capture that supplied `text`
    /// (1 when the match had no captures).
    pub line: usize,
    /// 1-based column of the capture that supplied `text`
    /// (1 when the match had no captures).
    pub column: usize,
    /// Source text of every capture, keyed by capture name.
    pub captures: HashMap<String, String>,
    /// Source text of the first capture with non-empty text, or empty.
    pub text: String,
}
59
/// Verdict produced when an answer is checked against the parsed source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeSitterVerification {
    /// The claimed names are exactly the names found in the source.
    ExactMatch,
    /// Same items in a different order.
    /// NOTE(review): no code in this file constructs this variant.
    UnorderedMatch,
    /// Every claimed name exists, but the source contains more.
    SubsetMatch {
        /// Number of names claimed by the answer.
        claimed: usize,
        /// Number of names found in the source.
        actual: usize,
    },
    /// At least one claimed name is absent from the source.
    HasErrors {
        /// One human-readable message per missing name.
        errors: Vec<String>,
    },
    /// The answer does not match the source.
    /// NOTE(review): no code in this file constructs this variant.
    Mismatch,
    /// Verification could not be attempted.
    CannotVerify {
        /// Why verification was not possible.
        reason: String,
    },
}
83
84impl TreeSitterOracle {
85 pub fn new(source: String) -> Self {
87 Self {
88 source,
89 tree: None,
90 parser: None,
91 }
92 }
93
94 fn ensure_parser(&mut self) -> Result<()> {
96 if self.parser.is_some() {
97 return Ok(());
98 }
99
100 let mut parser = tree_sitter::Parser::new();
101 parser.set_language(&tree_sitter_rust::LANGUAGE.into())?;
102 self.parser = Some(parser);
103 Ok(())
104 }
105
106 fn parse(&mut self) -> Result<&tree_sitter::Tree> {
108 self.ensure_parser()?;
109
110 if self.tree.is_none() {
111 let parser = self.parser.as_mut().ok_or_else(|| anyhow!("Parser not initialized"))?;
112 let tree = parser.parse(&self.source, None)
113 .ok_or_else(|| anyhow!("Failed to parse source"))?;
114 self.tree = Some(tree);
115 }
116
117 Ok(self.tree.as_ref().unwrap())
118 }
119
120 pub fn query(&mut self, query_str: &str) -> Result<AstQueryResult> {
127 self.parse()?;
128 let tree = self.tree.as_ref().unwrap();
129 let root = tree.root_node();
130
131 let query = tree_sitter::Query::new(&tree_sitter_rust::LANGUAGE.into(), query_str)?;
132 let mut cursor = tree_sitter::QueryCursor::new();
133
134 let source_bytes = self.source.as_bytes();
135 let mut results = Vec::new();
136
137 let mut matches = cursor.matches(&query, root, source_bytes);
138 while let Some(match_) = matches.next() {
139 let mut captures = HashMap::new();
140 let mut text = String::new();
141 let mut line = 1;
142 let mut column = 1;
143
144 for capture in match_.captures {
145 let node = capture.node;
146 let capture_name = query.capture_names()[capture.index as usize].to_string();
147 let capture_text = node.utf8_text(source_bytes)?.to_string();
148
149 captures.insert(capture_name, capture_text.clone());
150
151 if text.is_empty() {
152 text = capture_text;
153 line = node.start_position().row + 1;
154 column = node.start_position().column + 1;
155 }
156 }
157
158 results.push(AstMatch {
159 line,
160 column,
161 captures,
162 text,
163 });
164 }
165
166 Ok(AstQueryResult {
167 query_type: query_str.to_string(),
168 matches: results,
169 })
170 }
171
172 pub fn get_functions(&mut self) -> Result<Vec<FunctionSignature>> {
174 let result = self.query(
175 r#"
176 (function_item
177 name: (identifier) @name
178 parameters: (parameters) @params
179 return_type: (_)? @return_type)
180 "#
181 )?;
182
183 let mut functions = Vec::new();
184 for m in result.matches {
185 let name = m.captures.get("name").cloned().unwrap_or_default();
186 let params = m.captures.get("params").cloned().unwrap_or_default();
187 let return_type = m.captures.get("return_type").cloned();
188
189 functions.push(FunctionSignature {
190 name,
191 params,
192 return_type,
193 line: m.line,
194 });
195 }
196
197 Ok(functions)
198 }
199
200 pub fn get_structs(&mut self) -> Result<Vec<StructDefinition>> {
202 let result = self.query(
203 r#"
204 (struct_item
205 name: (type_identifier) @name
206 body: (field_declaration_list)? @body)
207 "#
208 )?;
209
210 let mut structs = Vec::new();
211 for m in result.matches {
212 let name = m.captures.get("name").cloned().unwrap_or_default();
213 let body = m.captures.get("body").cloned().unwrap_or_default();
214
215 let fields = self.extract_struct_fields(&body)?;
217
218 structs.push(StructDefinition {
219 name,
220 fields,
221 line: m.line,
222 });
223 }
224
225 Ok(structs)
226 }
227
228 fn extract_struct_fields(&self, body: &str) -> Result<Vec<String>> {
230 let mut fields = Vec::new();
231
232 let re = regex::Regex::new(r"(?:pub\s+)?(\w+)\s*:")?;
234 for cap in re.captures_iter(body) {
235 if let Some(name) = cap.get(1) {
236 fields.push(name.as_str().to_string());
237 }
238 }
239
240 Ok(fields)
241 }
242
243 pub fn get_enums(&mut self) -> Result<Vec<EnumDefinition>> {
245 let result = self.query(
246 r#"
247 (enum_item
248 name: (type_identifier) @name
249 body: (enum_variant_list)? @body)
250 "#
251 )?;
252
253 let mut enums = Vec::new();
254 for m in result.matches {
255 let name = m.captures.get("name").cloned().unwrap_or_default();
256 let body = m.captures.get("body").cloned().unwrap_or_default();
257
258 let variants = self.extract_enum_variants(&body)?;
260
261 enums.push(EnumDefinition {
262 name,
263 variants,
264 line: m.line,
265 });
266 }
267
268 Ok(enums)
269 }
270
271 fn extract_enum_variants(&self, body: &str) -> Result<Vec<String>> {
273 let mut variants = Vec::new();
274
275 let re = regex::Regex::new(r"(\w+)\s*(?:,|=|\{|\()")?;
276 for cap in re.captures_iter(body) {
277 if let Some(name) = cap.get(1) {
278 let name_str = name.as_str();
279 if !["pub", "fn", "struct", "enum", "impl", "trait"].contains(&name_str) {
281 variants.push(name_str.to_string());
282 }
283 }
284 }
285
286 Ok(variants)
287 }
288
289 pub fn get_impls(&mut self) -> Result<Vec<ImplDefinition>> {
291 let result = self.query(
292 r#"
293 [
294 (impl_item
295 type: (type_identifier) @type
296 trait: (type_identifier)? @trait
297 body: (declaration_list)? @body)
298 (impl_item
299 for: (type_identifier) @for
300 trait: (type_identifier) @trait
301 body: (declaration_list)? @body)
302 ]
303 "#
304 )?;
305
306 let mut impls = Vec::new();
307 for m in result.matches {
308 let type_name = m.captures.get("type")
309 .or_else(|| m.captures.get("for"))
310 .cloned()
311 .unwrap_or_default();
312 let trait_name = m.captures.get("trait").cloned();
313 let body = m.captures.get("body").cloned().unwrap_or_default();
314
315 impls.push(ImplDefinition {
316 type_name,
317 trait_name,
318 method_count: body.matches("fn ").count(),
319 line: m.line,
320 });
321 }
322
323 Ok(impls)
324 }
325
326 pub fn count_error_patterns(&mut self) -> Result<ErrorPatternCounts> {
328 let result_types = self.query(r#"(generic_type type: (type_identifier) @name (#eq? @name "Result"))"#)?;
330
331 let try_operators = self.query(r#"(try_expression)"#)?;
333
334 let unwrap_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "unwrap")))"#)?;
336
337 let expect_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "expect")))"#)?;
339
340 let match_exprs = self.query(r#"(match_expression)"#)?;
342
343 Ok(ErrorPatternCounts {
344 result_types: result_types.matches.len(),
345 try_operators: try_operators.matches.len(),
346 unwrap_calls: unwrap_calls.matches.len(),
347 expect_calls: expect_calls.matches.len(),
348 match_expressions: match_exprs.matches.len(),
349 })
350 }
351
352 pub fn verify(&mut self, answer: &str, query: &str) -> TreeSitterVerification {
354 let query_type = Self::classify_query(query);
355
356 match query_type {
357 QueryType::Structural => {
358 if query.to_lowercase().contains("function") {
360 self.verify_functions(answer)
361 } else if query.to_lowercase().contains("struct") {
362 self.verify_structs(answer)
363 } else if query.to_lowercase().contains("enum") {
364 self.verify_enums(answer)
365 } else if query.to_lowercase().contains("impl") {
366 self.verify_impls(answer)
367 } else {
368 TreeSitterVerification::CannotVerify {
369 reason: "Unknown structural query type".to_string(),
370 }
371 }
372 }
373 _ => TreeSitterVerification::CannotVerify {
374 reason: "Not a structural query".to_string(),
375 }
376 }
377 }
378
379 fn classify_query(query: &str) -> QueryType {
381 let lower = query.to_lowercase();
382
383 if lower.contains("signature")
384 || lower.contains("parameters")
385 || lower.contains("return type")
386 || lower.contains("fields of")
387 || lower.contains("what fields")
388 || lower.contains("struct definition")
389 || lower.contains("enum variants")
390 || lower.contains("implements")
391 || lower.contains("methods")
392 {
393 return QueryType::Structural;
394 }
395
396 QueryType::Semantic
397 }
398
399 fn verify_functions(&mut self, answer: &str) -> TreeSitterVerification {
400 let functions = match self.get_functions() {
401 Ok(f) => f,
402 Err(e) => return TreeSitterVerification::CannotVerify {
403 reason: format!("Failed to parse functions: {}", e),
404 },
405 };
406
407 let claimed_names: Vec<String> = answer
409 .lines()
410 .filter_map(|line| {
411 let re = regex::Regex::new(r"\bfn\s+(\w+)").ok()?;
413 re.captures(line)
414 .and_then(|cap| cap.get(1))
415 .map(|m| m.as_str().to_string())
416 })
417 .collect();
418
419 if claimed_names.is_empty() {
420 return TreeSitterVerification::CannotVerify {
421 reason: "Could not extract function names from answer".to_string(),
422 };
423 }
424
425 let actual_names: Vec<String> = functions.iter().map(|f| f.name.clone()).collect();
426 let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
427 let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
428
429 if claimed_set == actual_set {
430 TreeSitterVerification::ExactMatch
431 } else if claimed_set.is_subset(&actual_set) {
432 TreeSitterVerification::SubsetMatch {
433 claimed: claimed_names.len(),
434 actual: actual_names.len(),
435 }
436 } else {
437 let errors = claimed_names
438 .iter()
439 .filter(|name| !actual_set.contains(*name))
440 .map(|name| format!("Function '{}' not found", name))
441 .collect();
442 TreeSitterVerification::HasErrors { errors }
443 }
444 }
445
446 fn verify_structs(&mut self, answer: &str) -> TreeSitterVerification {
447 let structs = match self.get_structs() {
448 Ok(s) => s,
449 Err(e) => return TreeSitterVerification::CannotVerify {
450 reason: format!("Failed to parse structs: {}", e),
451 },
452 };
453
454 let claimed_names: Vec<String> = answer
456 .lines()
457 .filter_map(|line| {
458 let re = regex::Regex::new(r"\bstruct\s+(\w+)").ok()?;
459 re.captures(line)
460 .and_then(|cap| cap.get(1))
461 .map(|m| m.as_str().to_string())
462 })
463 .collect();
464
465 if claimed_names.is_empty() {
466 return TreeSitterVerification::CannotVerify {
467 reason: "Could not extract struct names from answer".to_string(),
468 };
469 }
470
471 let actual_names: Vec<String> = structs.iter().map(|s| s.name.clone()).collect();
472 let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
473 let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
474
475 if claimed_set == actual_set {
476 TreeSitterVerification::ExactMatch
477 } else if claimed_set.is_subset(&actual_set) {
478 TreeSitterVerification::SubsetMatch {
479 claimed: claimed_names.len(),
480 actual: actual_names.len(),
481 }
482 } else {
483 let errors = claimed_names
484 .iter()
485 .filter(|name| !actual_set.contains(*name))
486 .map(|name| format!("Struct '{}' not found", name))
487 .collect();
488 TreeSitterVerification::HasErrors { errors }
489 }
490 }
491
492 fn verify_enums(&mut self, _answer: &str) -> TreeSitterVerification {
493 TreeSitterVerification::CannotVerify {
495 reason: "Enum verification not yet implemented".to_string(),
496 }
497 }
498
499 fn verify_impls(&mut self, _answer: &str) -> TreeSitterVerification {
500 TreeSitterVerification::CannotVerify {
502 reason: "Impl verification not yet implemented".to_string(),
503 }
504 }
505}
506
/// A function found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Function name.
    pub name: String,
    /// Source text of the parameter list, including the parentheses.
    pub params: String,
    /// Source text of the return type, or `None` when none is declared.
    pub return_type: Option<String>,
    /// 1-based line of the function name.
    pub line: usize,
}
515
/// A struct found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructDefinition {
    /// Struct name.
    pub name: String,
    /// Field names in declaration order; empty when the struct has no
    /// braced field list.
    pub fields: Vec<String>,
    /// 1-based line of the struct name.
    pub line: usize,
}
523
/// An enum found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumDefinition {
    /// Enum name.
    pub name: String,
    /// Variant names in declaration order.
    pub variants: Vec<String>,
    /// 1-based line of the enum name.
    pub line: usize,
}
531
/// An `impl` block found in the source.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImplDefinition {
    /// Text of the type the block is implemented for.
    pub type_name: String,
    /// Text of the implemented trait, or `None` when no trait was captured.
    pub trait_name: Option<String>,
    /// Approximate method count: the number of `"fn "` occurrences in the
    /// block's body text.
    pub method_count: usize,
    /// 1-based line where the block was matched.
    pub line: usize,
}
540
/// Occurrence counts of common error-handling constructs in the source.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ErrorPatternCounts {
    /// Generic types whose base identifier is `Result`.
    pub result_types: usize,
    /// `?` (try) expressions.
    pub try_operators: usize,
    /// Method calls named `unwrap`.
    pub unwrap_calls: usize,
    /// Method calls named `expect`.
    pub expect_calls: usize,
    /// `match` expressions.
    pub match_expressions: usize,
}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust program used as the fixture for every test: one struct,
    /// one inherent impl with two methods, two free functions and one enum.
    fn sample_rust_code() -> String {
        r#"
use anyhow::Result;

pub struct Config {
    pub debug: bool,
    pub timeout: u64,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false, timeout: 30 }
    }

    pub fn with_debug(mut self) -> Self {
        self.debug = true;
        self
    }
}

pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data.to_uppercase())
}

fn parse(input: &str) -> Result<String> {
    if input.is_empty() {
        return Err(anyhow!("empty input"));
    }
    Ok(input.to_string())
}

enum Status {
    Active,
    Inactive,
    Pending,
}
"#
        .to_string()
    }

    #[test]
    fn get_functions_finds_all() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let functions = oracle.get_functions().unwrap();
        assert!(functions.len() >= 3);

        let names: Vec<&str> = functions.iter().map(|f| f.name.as_str()).collect();
        assert!(names.contains(&"new"));
        assert!(names.contains(&"with_debug"));
        assert!(names.contains(&"process"));
        assert!(names.contains(&"parse"));
    }

    #[test]
    fn get_structs_finds_all() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let structs = oracle.get_structs().unwrap();
        assert!(!structs.is_empty());

        let config_struct = structs.iter().find(|s| s.name == "Config").unwrap();
        assert!(config_struct.fields.contains(&"debug".to_string()));
        assert!(config_struct.fields.contains(&"timeout".to_string()));
    }

    #[test]
    fn get_enums_finds_all() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let enums = oracle.get_enums().unwrap();
        assert!(!enums.is_empty());

        let status_enum = enums.iter().find(|e| e.name == "Status").unwrap();
        assert!(status_enum.variants.contains(&"Active".to_string()));
        assert!(status_enum.variants.contains(&"Inactive".to_string()));
        assert!(status_enum.variants.contains(&"Pending".to_string()));
    }

    #[test]
    fn count_error_patterns() {
        let mut oracle = TreeSitterOracle::new(sample_rust_code());
        let counts = oracle.count_error_patterns().unwrap();

        // Two `Result<String>` return types and one `?` in `process`.
        assert!(counts.result_types >= 2);
        assert!(counts.try_operators >= 1);
    }
}