Skip to main content

codetether_agent/rlm/oracle/
tree_sitter_oracle.rs

//! Tree-sitter oracle for structural AST verification.
//!
//! This oracle uses tree-sitter to parse Rust source code and verify structural
//! queries about function signatures, struct fields, trait implementations, etc.
//!
//! # Supported Queries
//!
//! - Function signatures (name, parameters, return type)
//! - Struct/enum definitions and field/variant listings
//! - Impl blocks and trait implementations
//! - Error handling patterns (`Result` types, `match` expressions, `?` operator,
//!   `.unwrap()` / `.expect()` calls)
//!
//! # Features
//!
//! - Exposed as a verification oracle for FINAL() answers
//! - Also exposed as a new DSL command: `ast_query("(function_item)")`
//!
//! # Dependencies
//!
//! Requires the `tree-sitter`, `tree-sitter-rust`, `streaming-iterator`, and `regex` crates.
21
22use anyhow::{Result, anyhow};
23use std::collections::HashMap;
24use streaming_iterator::StreamingIterator;
25
26use super::QueryType;
27
28/// Tree-sitter based oracle for validating structural queries.
29pub struct TreeSitterOracle {
30    /// Source code content
31    source: String,
32    /// Parsed tree-sitter tree (lazy-initialized)
33    tree: Option<tree_sitter::Tree>,
34    /// Language parser
35    parser: Option<tree_sitter::Parser>,
36}
37
38/// Result of tree-sitter AST query.
39#[derive(Debug, Clone, PartialEq)]
40pub struct AstQueryResult {
41    /// Query type that was executed
42    pub query_type: String,
43    /// Matched nodes with their captures
44    pub matches: Vec<AstMatch>,
45}
46
/// One match produced by an AST query.
///
/// Position and `text` describe the match's primary (first) capture. A query
/// pattern without any `@capture` has nothing to locate, so it is reported at
/// line 1, column 1 with empty `text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstMatch {
    /// Line of the primary capture (1-indexed).
    pub line: usize,
    /// Column of the primary capture (1-indexed; tree-sitter columns count
    /// bytes, not characters).
    pub column: usize,
    /// Capture name -> source text of the captured node.
    pub captures: HashMap<String, String>,
    /// Source text of the primary capture.
    pub text: String,
}
59
/// Verdict produced when an answer is checked against the parsed AST.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeSitterVerification {
    /// The set of names in the answer equals the set found in the source.
    ExactMatch,
    /// Answer matches but in a different order.
    UnorderedMatch,
    /// Every claimed name exists in the source, but the answer omits some.
    SubsetMatch {
        /// Number of names claimed by the answer.
        claimed: usize,
        /// Number of names found in the source.
        actual: usize,
    },
    /// The answer names items that do not exist; one message per bad claim.
    HasErrors {
        errors: Vec<String>,
    },
    /// Answer is completely different.
    Mismatch,
    /// The source could not be analysed or the answer could not be interpreted.
    CannotVerify {
        reason: String,
    },
}
83
84impl TreeSitterOracle {
85    /// Create a new tree-sitter oracle for the given source.
86    pub fn new(source: String) -> Self {
87        Self {
88            source,
89            tree: None,
90            parser: None,
91        }
92    }
93
94    /// Initialize the parser (lazy).
95    fn ensure_parser(&mut self) -> Result<()> {
96        if self.parser.is_some() {
97            return Ok(());
98        }
99
100        let mut parser = tree_sitter::Parser::new();
101        parser.set_language(&tree_sitter_rust::LANGUAGE.into())?;
102        self.parser = Some(parser);
103        Ok(())
104    }
105
106    /// Parse the source and return the tree.
107    fn parse(&mut self) -> Result<&tree_sitter::Tree> {
108        self.ensure_parser()?;
109        
110        if self.tree.is_none() {
111            let parser = self.parser.as_mut().ok_or_else(|| anyhow!("Parser not initialized"))?;
112            let tree = parser.parse(&self.source, None)
113                .ok_or_else(|| anyhow!("Failed to parse source"))?;
114            self.tree = Some(tree);
115        }
116        
117        Ok(self.tree.as_ref().unwrap())
118    }
119
120    /// Execute a tree-sitter S-expression query.
121    ///
122    /// Example queries:
123    /// - `(function_item name: (identifier) @name)`
124    /// - `(struct_item name: (type_identifier) @name body: (field_declaration_list))`
125    /// - `(impl_item trait: (type_identifier) @trait for: (type_identifier) @for)`
126    pub fn query(&mut self, query_str: &str) -> Result<AstQueryResult> {
127        self.parse()?;
128        let tree = self.tree.as_ref().unwrap();
129        let root = tree.root_node();
130        
131        let query = tree_sitter::Query::new(&tree_sitter_rust::LANGUAGE.into(), query_str)?;
132        let mut cursor = tree_sitter::QueryCursor::new();
133        
134        let source_bytes = self.source.as_bytes();
135        let mut results = Vec::new();
136        
137        let mut matches = cursor.matches(&query, root, source_bytes);
138        while let Some(match_) = matches.next() {
139            let mut captures = HashMap::new();
140            let mut text = String::new();
141            let mut line = 1;
142            let mut column = 1;
143            
144            for capture in match_.captures {
145                let node = capture.node;
146                let capture_name = query.capture_names()[capture.index as usize].to_string();
147                let capture_text = node.utf8_text(source_bytes)?.to_string();
148                
149                captures.insert(capture_name, capture_text.clone());
150                
151                if text.is_empty() {
152                    text = capture_text;
153                    line = node.start_position().row + 1;
154                    column = node.start_position().column + 1;
155                }
156            }
157            
158            results.push(AstMatch {
159                line,
160                column,
161                captures,
162                text,
163            });
164        }
165        
166        Ok(AstQueryResult {
167            query_type: query_str.to_string(),
168            matches: results,
169        })
170    }
171
172    /// Get all function signatures in the source.
173    pub fn get_functions(&mut self) -> Result<Vec<FunctionSignature>> {
174        let result = self.query(
175            r#"
176            (function_item
177                name: (identifier) @name
178                parameters: (parameters) @params
179                return_type: (_)? @return_type)
180            "#
181        )?;
182        
183        let mut functions = Vec::new();
184        for m in result.matches {
185            let name = m.captures.get("name").cloned().unwrap_or_default();
186            let params = m.captures.get("params").cloned().unwrap_or_default();
187            let return_type = m.captures.get("return_type").cloned();
188            
189            functions.push(FunctionSignature {
190                name,
191                params,
192                return_type,
193                line: m.line,
194            });
195        }
196        
197        Ok(functions)
198    }
199
200    /// Get all struct definitions in the source.
201    pub fn get_structs(&mut self) -> Result<Vec<StructDefinition>> {
202        let result = self.query(
203            r#"
204            (struct_item
205                name: (type_identifier) @name
206                body: (field_declaration_list)? @body)
207            "#
208        )?;
209        
210        let mut structs = Vec::new();
211        for m in result.matches {
212            let name = m.captures.get("name").cloned().unwrap_or_default();
213            let body = m.captures.get("body").cloned().unwrap_or_default();
214            
215            // Extract fields from body
216            let fields = self.extract_struct_fields(&body)?;
217            
218            structs.push(StructDefinition {
219                name,
220                fields,
221                line: m.line,
222            });
223        }
224        
225        Ok(structs)
226    }
227
228    /// Extract field names from a struct body.
229    fn extract_struct_fields(&self, body: &str) -> Result<Vec<String>> {
230        let mut fields = Vec::new();
231        
232        // Simple regex-based extraction (faster than re-parsing)
233        let re = regex::Regex::new(r"(?:pub\s+)?(\w+)\s*:")?;
234        for cap in re.captures_iter(body) {
235            if let Some(name) = cap.get(1) {
236                fields.push(name.as_str().to_string());
237            }
238        }
239        
240        Ok(fields)
241    }
242
243    /// Get all enum definitions in the source.
244    pub fn get_enums(&mut self) -> Result<Vec<EnumDefinition>> {
245        let result = self.query(
246            r#"
247            (enum_item
248                name: (type_identifier) @name
249                body: (enum_variant_list)? @body)
250            "#
251        )?;
252        
253        let mut enums = Vec::new();
254        for m in result.matches {
255            let name = m.captures.get("name").cloned().unwrap_or_default();
256            let body = m.captures.get("body").cloned().unwrap_or_default();
257            
258            // Extract variants from body
259            let variants = self.extract_enum_variants(&body)?;
260            
261            enums.push(EnumDefinition {
262                name,
263                variants,
264                line: m.line,
265            });
266        }
267        
268        Ok(enums)
269    }
270
271    /// Extract variant names from an enum body.
272    fn extract_enum_variants(&self, body: &str) -> Result<Vec<String>> {
273        let mut variants = Vec::new();
274        
275        let re = regex::Regex::new(r"(\w+)\s*(?:,|=|\{|\()")?;
276        for cap in re.captures_iter(body) {
277            if let Some(name) = cap.get(1) {
278                let name_str = name.as_str();
279                // Skip keywords
280                if !["pub", "fn", "struct", "enum", "impl", "trait"].contains(&name_str) {
281                    variants.push(name_str.to_string());
282                }
283            }
284        }
285        
286        Ok(variants)
287    }
288
289    /// Get all impl blocks in the source.
290    pub fn get_impls(&mut self) -> Result<Vec<ImplDefinition>> {
291        let result = self.query(
292            r#"
293            [
294                (impl_item
295                    type: (type_identifier) @type
296                    trait: (type_identifier)? @trait
297                    body: (declaration_list)? @body)
298                (impl_item
299                    for: (type_identifier) @for
300                    trait: (type_identifier) @trait
301                    body: (declaration_list)? @body)
302            ]
303            "#
304        )?;
305        
306        let mut impls = Vec::new();
307        for m in result.matches {
308            let type_name = m.captures.get("type")
309                .or_else(|| m.captures.get("for"))
310                .cloned()
311                .unwrap_or_default();
312            let trait_name = m.captures.get("trait").cloned();
313            let body = m.captures.get("body").cloned().unwrap_or_default();
314            
315            impls.push(ImplDefinition {
316                type_name,
317                trait_name,
318                method_count: body.matches("fn ").count(),
319                line: m.line,
320            });
321        }
322        
323        Ok(impls)
324    }
325
326    /// Count error handling patterns.
327    pub fn count_error_patterns(&mut self) -> Result<ErrorPatternCounts> {
328        // Count Result<T> types
329        let result_types = self.query(r#"(generic_type type: (type_identifier) @name (#eq? @name "Result"))"#)?;
330        
331        // Count ? operators
332        let try_operators = self.query(r#"(try_expression)"#)?;
333        
334        // Count .unwrap() calls
335        let unwrap_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "unwrap")))"#)?;
336        
337        // Count .expect() calls
338        let expect_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "expect")))"#)?;
339        
340        // Count match expressions
341        let match_exprs = self.query(r#"(match_expression)"#)?;
342        
343        Ok(ErrorPatternCounts {
344            result_types: result_types.matches.len(),
345            try_operators: try_operators.matches.len(),
346            unwrap_calls: unwrap_calls.matches.len(),
347            expect_calls: expect_calls.matches.len(),
348            match_expressions: match_exprs.matches.len(),
349        })
350    }
351
352    /// Verify a FINAL() answer against AST truth.
353    pub fn verify(&mut self, answer: &str, query: &str) -> TreeSitterVerification {
354        let query_type = Self::classify_query(query);
355        
356        match query_type {
357            QueryType::Structural => {
358                // Try to match against different structural queries
359                if query.to_lowercase().contains("function") {
360                    self.verify_functions(answer)
361                } else if query.to_lowercase().contains("struct") {
362                    self.verify_structs(answer)
363                } else if query.to_lowercase().contains("enum") {
364                    self.verify_enums(answer)
365                } else if query.to_lowercase().contains("impl") {
366                    self.verify_impls(answer)
367                } else {
368                    TreeSitterVerification::CannotVerify {
369                        reason: "Unknown structural query type".to_string(),
370                    }
371                }
372            }
373            _ => TreeSitterVerification::CannotVerify {
374                reason: "Not a structural query".to_string(),
375            }
376        }
377    }
378
379    /// Classify query type for tree-sitter routing.
380    fn classify_query(query: &str) -> QueryType {
381        let lower = query.to_lowercase();
382        
383        if lower.contains("signature")
384            || lower.contains("parameters")
385            || lower.contains("return type")
386            || lower.contains("fields of")
387            || lower.contains("what fields")
388            || lower.contains("struct definition")
389            || lower.contains("enum variants")
390            || lower.contains("implements")
391            || lower.contains("methods")
392        {
393            return QueryType::Structural;
394        }
395        
396        QueryType::Semantic
397    }
398
399    fn verify_functions(&mut self, answer: &str) -> TreeSitterVerification {
400        let functions = match self.get_functions() {
401            Ok(f) => f,
402            Err(e) => return TreeSitterVerification::CannotVerify {
403                reason: format!("Failed to parse functions: {}", e),
404            },
405        };
406        
407        // Parse answer to extract claimed function names
408        let claimed_names: Vec<String> = answer
409            .lines()
410            .filter_map(|line| {
411                // Try to extract function name from line
412                let re = regex::Regex::new(r"\bfn\s+(\w+)").ok()?;
413                re.captures(line)
414                    .and_then(|cap| cap.get(1))
415                    .map(|m| m.as_str().to_string())
416            })
417            .collect();
418        
419        if claimed_names.is_empty() {
420            return TreeSitterVerification::CannotVerify {
421                reason: "Could not extract function names from answer".to_string(),
422            };
423        }
424        
425        let actual_names: Vec<String> = functions.iter().map(|f| f.name.clone()).collect();
426        let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
427        let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
428        
429        if claimed_set == actual_set {
430            TreeSitterVerification::ExactMatch
431        } else if claimed_set.is_subset(&actual_set) {
432            TreeSitterVerification::SubsetMatch {
433                claimed: claimed_names.len(),
434                actual: actual_names.len(),
435            }
436        } else {
437            let errors = claimed_names
438                .iter()
439                .filter(|name| !actual_set.contains(*name))
440                .map(|name| format!("Function '{}' not found", name))
441                .collect();
442            TreeSitterVerification::HasErrors { errors }
443        }
444    }
445
446    fn verify_structs(&mut self, answer: &str) -> TreeSitterVerification {
447        let structs = match self.get_structs() {
448            Ok(s) => s,
449            Err(e) => return TreeSitterVerification::CannotVerify {
450                reason: format!("Failed to parse structs: {}", e),
451            },
452        };
453        
454        // Similar logic to verify_functions
455        let claimed_names: Vec<String> = answer
456            .lines()
457            .filter_map(|line| {
458                let re = regex::Regex::new(r"\bstruct\s+(\w+)").ok()?;
459                re.captures(line)
460                    .and_then(|cap| cap.get(1))
461                    .map(|m| m.as_str().to_string())
462            })
463            .collect();
464        
465        if claimed_names.is_empty() {
466            return TreeSitterVerification::CannotVerify {
467                reason: "Could not extract struct names from answer".to_string(),
468            };
469        }
470        
471        let actual_names: Vec<String> = structs.iter().map(|s| s.name.clone()).collect();
472        let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
473        let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
474        
475        if claimed_set == actual_set {
476            TreeSitterVerification::ExactMatch
477        } else if claimed_set.is_subset(&actual_set) {
478            TreeSitterVerification::SubsetMatch {
479                claimed: claimed_names.len(),
480                actual: actual_names.len(),
481            }
482        } else {
483            let errors = claimed_names
484                .iter()
485                .filter(|name| !actual_set.contains(*name))
486                .map(|name| format!("Struct '{}' not found", name))
487                .collect();
488            TreeSitterVerification::HasErrors { errors }
489        }
490    }
491
492    fn verify_enums(&mut self, _answer: &str) -> TreeSitterVerification {
493        // Similar pattern
494        TreeSitterVerification::CannotVerify {
495            reason: "Enum verification not yet implemented".to_string(),
496        }
497    }
498
499    fn verify_impls(&mut self, _answer: &str) -> TreeSitterVerification {
500        // Similar pattern
501        TreeSitterVerification::CannotVerify {
502            reason: "Impl verification not yet implemented".to_string(),
503        }
504    }
505}
506
/// A function or method signature extracted from the source.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FunctionSignature {
    /// Function name.
    pub name: String,
    /// Verbatim parameter list text, including the surrounding parentheses.
    pub params: String,
    /// Verbatim return type text, or `None` when no `->` type is declared.
    pub return_type: Option<String>,
    /// 1-indexed line where the function appears.
    pub line: usize,
}
515
/// A struct definition extracted from the source.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct StructDefinition {
    /// Struct name.
    pub name: String,
    /// Names of the named fields, in declaration order.
    pub fields: Vec<String>,
    /// 1-indexed line where the struct appears.
    pub line: usize,
}
523
/// An enum definition extracted from the source.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EnumDefinition {
    /// Enum name.
    pub name: String,
    /// Variant names, in declaration order.
    pub variants: Vec<String>,
    /// 1-indexed line where the enum appears.
    pub line: usize,
}
531
/// An `impl` block extracted from the source.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ImplDefinition {
    /// The implementing type (the `Foo` in `impl Foo` / `impl Trait for Foo`).
    pub type_name: String,
    /// The implemented trait, or `None` for an inherent impl.
    pub trait_name: Option<String>,
    /// Number of `fn` items in the impl body.
    pub method_count: usize,
    /// 1-indexed line where the impl block appears.
    pub line: usize,
}
540
/// Tallies of error-handling constructs found in the source.
///
/// All fields are plain counters, so the type is `Copy` and its `Default`
/// is the all-zero tally.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct ErrorPatternCounts {
    /// `Result<...>` generic type usages.
    pub result_types: usize,
    /// `?` try-expressions.
    pub try_operators: usize,
    /// `.unwrap()` method calls.
    pub unwrap_calls: usize,
    /// `.expect(...)` method calls.
    pub expect_calls: usize,
    /// `match` expressions.
    pub match_expressions: usize,
}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust fixture covering a struct, an impl, async and plain fns using
    /// `?`, and a unit-variant enum.
    fn sample_rust_code() -> String {
        String::from(
            r#"
use anyhow::Result;

pub struct Config {
    pub debug: bool,
    pub timeout: u64,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false, timeout: 30 }
    }
    
    pub fn with_debug(mut self) -> Self {
        self.debug = true;
        self
    }
}

pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data.to_uppercase())
}

fn parse(input: &str) -> Result<String> {
    if input.is_empty() {
        return Err(anyhow!("empty input"));
    }
    Ok(input.to_string())
}

enum Status {
    Active,
    Inactive,
    Pending,
}
"#,
        )
    }

    #[test]
    fn get_functions_finds_all() {
        let found = TreeSitterOracle::new(sample_rust_code())
            .get_functions()
            .expect("functions should be extracted");
        assert!(found.len() >= 3);

        for expected in ["new", "with_debug", "process", "parse"] {
            assert!(
                found.iter().any(|sig| sig.name == expected),
                "missing function `{expected}`"
            );
        }
    }

    #[test]
    fn get_structs_finds_all() {
        let found = TreeSitterOracle::new(sample_rust_code())
            .get_structs()
            .expect("structs should be extracted");
        assert!(!found.is_empty());

        let config = found
            .iter()
            .find(|def| def.name == "Config")
            .expect("Config struct should be present");
        for expected in ["debug", "timeout"] {
            assert!(
                config.fields.iter().any(|field| field == expected),
                "missing field `{expected}`"
            );
        }
    }

    #[test]
    fn get_enums_finds_all() {
        let found = TreeSitterOracle::new(sample_rust_code())
            .get_enums()
            .expect("enums should be extracted");
        assert!(!found.is_empty());

        let status = found
            .iter()
            .find(|def| def.name == "Status")
            .expect("Status enum should be present");
        for expected in ["Active", "Inactive"] {
            assert!(
                status.variants.iter().any(|variant| variant == expected),
                "missing variant `{expected}`"
            );
        }
    }

    #[test]
    fn count_error_patterns() {
        let tally = TreeSitterOracle::new(sample_rust_code())
            .count_error_patterns()
            .expect("patterns should be counted");

        // `process` and `parse` both return `Result<String>`.
        assert!(tally.result_types >= 2);
        // `parse(input)?` inside `process`.
        assert!(tally.try_operators >= 1);
    }
}