Skip to main content

codetether_agent/rlm/oracle/
tree_sitter_oracle.rs

1//! Tree-sitter oracle for structural AST verification.
2//!
3//! This oracle uses tree-sitter to parse source code and verify structural
4//! queries about function signatures, struct fields, trait implementations, etc.
5//!
6//! # Supported Queries
7//!
8//! - Function signatures (name, args, return type)
9//! - Struct/enum definitions and field listings
10//! - Impl blocks and trait implementations
11//! - Error handling patterns (Result, match arms, ? operator)
12//!
13//! # Features
14//!
15//! - Exposed as a verification oracle for FINAL() answers
16//! - Also exposed as a new DSL command: `ast_query("(function_item)")`
17//!
18//! # Dependencies
19//!
//! Requires the `tree-sitter` and `tree-sitter-rust` crates, plus `streaming-iterator`, `regex`, and `anyhow`.
21
22use anyhow::{Result, anyhow};
23use std::collections::HashMap;
24use streaming_iterator::StreamingIterator;
25
26use super::QueryType;
27
28/// Tree-sitter based oracle for validating structural queries.
29pub struct TreeSitterOracle {
30    /// Source code content
31    source: String,
32    /// Parsed tree-sitter tree (lazy-initialized)
33    tree: Option<tree_sitter::Tree>,
34    /// Language parser
35    parser: Option<tree_sitter::Parser>,
36}
37
38/// Result of tree-sitter AST query.
39#[derive(Debug, Clone, PartialEq)]
40pub struct AstQueryResult {
41    /// Query type that was executed
42    pub query_type: String,
43    /// Matched nodes with their captures
44    pub matches: Vec<AstMatch>,
45}
46
/// A single match from an AST query.
///
/// As filled in by `TreeSitterOracle::query`, `line`, `column` and `text`
/// describe the first capture of the match that has non-empty text, not the
/// whole matched pattern. A match without captures keeps the defaults
/// (`line == 1`, `column == 1`, empty `text`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstMatch {
    /// Line number (1-indexed) where the first capture starts.
    pub line: usize,
    /// Column (1-indexed) where the first capture starts.
    pub column: usize,
    /// Capture name mapped to the captured source text. If a capture name
    /// occurs more than once in a match, the last occurrence wins.
    pub captures: HashMap<String, String>,
    /// Source text of the first capture (empty when there are no captures).
    pub text: String,
}
59
/// Result of tree-sitter oracle verification.
///
/// Every payload is plain data (`usize`, `String`, `Vec<String>`), so the type
/// derives `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeSitterVerification {
    /// Answer matches AST truth exactly.
    ExactMatch,
    /// Answer matches but in different order.
    UnorderedMatch,
    /// Answer is a subset (partial match): `claimed` items were named in the
    /// answer, `actual` items exist in the source.
    SubsetMatch { claimed: usize, actual: usize },
    /// Answer contains incorrect claims; one message per incorrect claim.
    HasErrors { errors: Vec<String> },
    /// Answer is completely different.
    Mismatch,
    /// Could not parse or verify; `reason` explains why.
    CannotVerify { reason: String },
}
76
77impl TreeSitterOracle {
78    /// Create a new tree-sitter oracle for the given source.
79    pub fn new(source: String) -> Self {
80        Self {
81            source,
82            tree: None,
83            parser: None,
84        }
85    }
86
87    /// Initialize the parser (lazy).
88    fn ensure_parser(&mut self) -> Result<()> {
89        if self.parser.is_some() {
90            return Ok(());
91        }
92
93        let mut parser = tree_sitter::Parser::new();
94        parser.set_language(&tree_sitter_rust::LANGUAGE.into())?;
95        self.parser = Some(parser);
96        Ok(())
97    }
98
99    /// Parse the source and return the tree.
100    fn parse(&mut self) -> Result<&tree_sitter::Tree> {
101        self.ensure_parser()?;
102
103        if self.tree.is_none() {
104            let parser = self
105                .parser
106                .as_mut()
107                .ok_or_else(|| anyhow!("Parser not initialized"))?;
108            let tree = parser
109                .parse(&self.source, None)
110                .ok_or_else(|| anyhow!("Failed to parse source"))?;
111            self.tree = Some(tree);
112        }
113
114        Ok(self.tree.as_ref().unwrap())
115    }
116
117    /// Execute a tree-sitter S-expression query.
118    ///
119    /// Example queries:
120    /// - `(function_item name: (identifier) @name)`
121    /// - `(struct_item name: (type_identifier) @name body: (field_declaration_list))`
122    /// - `(impl_item trait: (type_identifier) @trait for: (type_identifier) @for)`
123    pub fn query(&mut self, query_str: &str) -> Result<AstQueryResult> {
124        self.parse()?;
125        let tree = self.tree.as_ref().unwrap();
126        let root = tree.root_node();
127
128        let query = tree_sitter::Query::new(&tree_sitter_rust::LANGUAGE.into(), query_str)?;
129        let mut cursor = tree_sitter::QueryCursor::new();
130
131        let source_bytes = self.source.as_bytes();
132        let mut results = Vec::new();
133
134        let mut matches = cursor.matches(&query, root, source_bytes);
135        while let Some(match_) = matches.next() {
136            let mut captures = HashMap::new();
137            let mut text = String::new();
138            let mut line = 1;
139            let mut column = 1;
140
141            for capture in match_.captures {
142                let node = capture.node;
143                let capture_name = query.capture_names()[capture.index as usize].to_string();
144                let capture_text = node.utf8_text(source_bytes)?.to_string();
145
146                captures.insert(capture_name, capture_text.clone());
147
148                if text.is_empty() {
149                    text = capture_text;
150                    line = node.start_position().row + 1;
151                    column = node.start_position().column + 1;
152                }
153            }
154
155            results.push(AstMatch {
156                line,
157                column,
158                captures,
159                text,
160            });
161        }
162
163        Ok(AstQueryResult {
164            query_type: query_str.to_string(),
165            matches: results,
166        })
167    }
168
169    /// Get all function signatures in the source.
170    pub fn get_functions(&mut self) -> Result<Vec<FunctionSignature>> {
171        let result = self.query(
172            r#"
173            (function_item
174                name: (identifier) @name
175                parameters: (parameters) @params
176                return_type: (_)? @return_type)
177            "#,
178        )?;
179
180        let mut functions = Vec::new();
181        for m in result.matches {
182            let name = m.captures.get("name").cloned().unwrap_or_default();
183            let params = m.captures.get("params").cloned().unwrap_or_default();
184            let return_type = m.captures.get("return_type").cloned();
185
186            functions.push(FunctionSignature {
187                name,
188                params,
189                return_type,
190                line: m.line,
191            });
192        }
193
194        Ok(functions)
195    }
196
197    /// Get all struct definitions in the source.
198    pub fn get_structs(&mut self) -> Result<Vec<StructDefinition>> {
199        let result = self.query(
200            r#"
201            (struct_item
202                name: (type_identifier) @name
203                body: (field_declaration_list)? @body)
204            "#,
205        )?;
206
207        let mut structs = Vec::new();
208        for m in result.matches {
209            let name = m.captures.get("name").cloned().unwrap_or_default();
210            let body = m.captures.get("body").cloned().unwrap_or_default();
211
212            // Extract fields from body
213            let fields = self.extract_struct_fields(&body)?;
214
215            structs.push(StructDefinition {
216                name,
217                fields,
218                line: m.line,
219            });
220        }
221
222        Ok(structs)
223    }
224
225    /// Extract field names from a struct body.
226    fn extract_struct_fields(&self, body: &str) -> Result<Vec<String>> {
227        let mut fields = Vec::new();
228
229        // Simple regex-based extraction (faster than re-parsing)
230        let re = regex::Regex::new(r"(?:pub\s+)?(\w+)\s*:")?;
231        for cap in re.captures_iter(body) {
232            if let Some(name) = cap.get(1) {
233                fields.push(name.as_str().to_string());
234            }
235        }
236
237        Ok(fields)
238    }
239
240    /// Get all enum definitions in the source.
241    pub fn get_enums(&mut self) -> Result<Vec<EnumDefinition>> {
242        let result = self.query(
243            r#"
244            (enum_item
245                name: (type_identifier) @name
246                body: (enum_variant_list)? @body)
247            "#,
248        )?;
249
250        let mut enums = Vec::new();
251        for m in result.matches {
252            let name = m.captures.get("name").cloned().unwrap_or_default();
253            let body = m.captures.get("body").cloned().unwrap_or_default();
254
255            // Extract variants from body
256            let variants = self.extract_enum_variants(&body)?;
257
258            enums.push(EnumDefinition {
259                name,
260                variants,
261                line: m.line,
262            });
263        }
264
265        Ok(enums)
266    }
267
268    /// Extract variant names from an enum body.
269    fn extract_enum_variants(&self, body: &str) -> Result<Vec<String>> {
270        let mut variants = Vec::new();
271
272        let re = regex::Regex::new(r"(\w+)\s*(?:,|=|\{|\()")?;
273        for cap in re.captures_iter(body) {
274            if let Some(name) = cap.get(1) {
275                let name_str = name.as_str();
276                // Skip keywords
277                if !["pub", "fn", "struct", "enum", "impl", "trait"].contains(&name_str) {
278                    variants.push(name_str.to_string());
279                }
280            }
281        }
282
283        Ok(variants)
284    }
285
286    /// Get all impl blocks in the source.
287    pub fn get_impls(&mut self) -> Result<Vec<ImplDefinition>> {
288        let result = self.query(
289            r#"
290            [
291                (impl_item
292                    type: (type_identifier) @type
293                    trait: (type_identifier)? @trait
294                    body: (declaration_list)? @body)
295                (impl_item
296                    for: (type_identifier) @for
297                    trait: (type_identifier) @trait
298                    body: (declaration_list)? @body)
299            ]
300            "#,
301        )?;
302
303        let mut impls = Vec::new();
304        for m in result.matches {
305            let type_name = m
306                .captures
307                .get("type")
308                .or_else(|| m.captures.get("for"))
309                .cloned()
310                .unwrap_or_default();
311            let trait_name = m.captures.get("trait").cloned();
312            let body = m.captures.get("body").cloned().unwrap_or_default();
313
314            impls.push(ImplDefinition {
315                type_name,
316                trait_name,
317                method_count: body.matches("fn ").count(),
318                line: m.line,
319            });
320        }
321
322        Ok(impls)
323    }
324
325    /// Count error handling patterns.
326    pub fn count_error_patterns(&mut self) -> Result<ErrorPatternCounts> {
327        // Count Result<T> types
328        let result_types =
329            self.query(r#"(generic_type type: (type_identifier) @name (#eq? @name "Result"))"#)?;
330
331        // Count ? operators
332        let try_operators = self.query(r#"(try_expression)"#)?;
333
334        // Count .unwrap() calls
335        let unwrap_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "unwrap")))"#)?;
336
337        // Count .expect() calls
338        let expect_calls = self.query(r#"(call_expression function: (field_expression field: (field_identifier) @method (#eq? @method "expect")))"#)?;
339
340        // Count match expressions
341        let match_exprs = self.query(r#"(match_expression)"#)?;
342
343        Ok(ErrorPatternCounts {
344            result_types: result_types.matches.len(),
345            try_operators: try_operators.matches.len(),
346            unwrap_calls: unwrap_calls.matches.len(),
347            expect_calls: expect_calls.matches.len(),
348            match_expressions: match_exprs.matches.len(),
349        })
350    }
351
352    /// Verify a FINAL() answer against AST truth.
353    pub fn verify(&mut self, answer: &str, query: &str) -> TreeSitterVerification {
354        let query_type = Self::classify_query(query);
355
356        match query_type {
357            QueryType::Structural => {
358                // Try to match against different structural queries
359                if query.to_lowercase().contains("function") {
360                    self.verify_functions(answer)
361                } else if query.to_lowercase().contains("struct") {
362                    self.verify_structs(answer)
363                } else if query.to_lowercase().contains("enum") {
364                    self.verify_enums(answer)
365                } else if query.to_lowercase().contains("impl") {
366                    self.verify_impls(answer)
367                } else {
368                    TreeSitterVerification::CannotVerify {
369                        reason: "Unknown structural query type".to_string(),
370                    }
371                }
372            }
373            _ => TreeSitterVerification::CannotVerify {
374                reason: "Not a structural query".to_string(),
375            },
376        }
377    }
378
379    /// Classify query type for tree-sitter routing.
380    fn classify_query(query: &str) -> QueryType {
381        let lower = query.to_lowercase();
382
383        if lower.contains("signature")
384            || lower.contains("parameters")
385            || lower.contains("return type")
386            || lower.contains("fields of")
387            || lower.contains("what fields")
388            || lower.contains("struct definition")
389            || lower.contains("enum variants")
390            || lower.contains("implements")
391            || lower.contains("methods")
392        {
393            return QueryType::Structural;
394        }
395
396        QueryType::Semantic
397    }
398
399    fn verify_functions(&mut self, answer: &str) -> TreeSitterVerification {
400        let functions = match self.get_functions() {
401            Ok(f) => f,
402            Err(e) => {
403                return TreeSitterVerification::CannotVerify {
404                    reason: format!("Failed to parse functions: {}", e),
405                };
406            }
407        };
408
409        // Parse answer to extract claimed function names
410        let claimed_names: Vec<String> = answer
411            .lines()
412            .filter_map(|line| {
413                // Try to extract function name from line
414                let re = regex::Regex::new(r"\bfn\s+(\w+)").ok()?;
415                re.captures(line)
416                    .and_then(|cap| cap.get(1))
417                    .map(|m| m.as_str().to_string())
418            })
419            .collect();
420
421        if claimed_names.is_empty() {
422            return TreeSitterVerification::CannotVerify {
423                reason: "Could not extract function names from answer".to_string(),
424            };
425        }
426
427        let actual_names: Vec<String> = functions.iter().map(|f| f.name.clone()).collect();
428        let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
429        let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
430
431        if claimed_set == actual_set {
432            TreeSitterVerification::ExactMatch
433        } else if claimed_set.is_subset(&actual_set) {
434            TreeSitterVerification::SubsetMatch {
435                claimed: claimed_names.len(),
436                actual: actual_names.len(),
437            }
438        } else {
439            let errors = claimed_names
440                .iter()
441                .filter(|name| !actual_set.contains(*name))
442                .map(|name| format!("Function '{}' not found", name))
443                .collect();
444            TreeSitterVerification::HasErrors { errors }
445        }
446    }
447
448    fn verify_structs(&mut self, answer: &str) -> TreeSitterVerification {
449        let structs = match self.get_structs() {
450            Ok(s) => s,
451            Err(e) => {
452                return TreeSitterVerification::CannotVerify {
453                    reason: format!("Failed to parse structs: {}", e),
454                };
455            }
456        };
457
458        // Similar logic to verify_functions
459        let claimed_names: Vec<String> = answer
460            .lines()
461            .filter_map(|line| {
462                let re = regex::Regex::new(r"\bstruct\s+(\w+)").ok()?;
463                re.captures(line)
464                    .and_then(|cap| cap.get(1))
465                    .map(|m| m.as_str().to_string())
466            })
467            .collect();
468
469        if claimed_names.is_empty() {
470            return TreeSitterVerification::CannotVerify {
471                reason: "Could not extract struct names from answer".to_string(),
472            };
473        }
474
475        let actual_names: Vec<String> = structs.iter().map(|s| s.name.clone()).collect();
476        let claimed_set: std::collections::HashSet<_> = claimed_names.iter().cloned().collect();
477        let actual_set: std::collections::HashSet<_> = actual_names.iter().cloned().collect();
478
479        if claimed_set == actual_set {
480            TreeSitterVerification::ExactMatch
481        } else if claimed_set.is_subset(&actual_set) {
482            TreeSitterVerification::SubsetMatch {
483                claimed: claimed_names.len(),
484                actual: actual_names.len(),
485            }
486        } else {
487            let errors = claimed_names
488                .iter()
489                .filter(|name| !actual_set.contains(*name))
490                .map(|name| format!("Struct '{}' not found", name))
491                .collect();
492            TreeSitterVerification::HasErrors { errors }
493        }
494    }
495
496    fn verify_enums(&mut self, _answer: &str) -> TreeSitterVerification {
497        // Similar pattern
498        TreeSitterVerification::CannotVerify {
499            reason: "Enum verification not yet implemented".to_string(),
500        }
501    }
502
503    fn verify_impls(&mut self, _answer: &str) -> TreeSitterVerification {
504        // Similar pattern
505        TreeSitterVerification::CannotVerify {
506            reason: "Impl verification not yet implemented".to_string(),
507        }
508    }
509}
510
/// Function signature information.
///
/// All fields are plain strings and integers, so the type derives `Eq` in
/// addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FunctionSignature {
    /// Function name as written in the source.
    pub name: String,
    /// Source text of the parameter list.
    pub params: String,
    /// Source text of the return type, or `None` when none is declared.
    pub return_type: Option<String>,
    /// Line (1-indexed) reported for the match that produced this entry.
    pub line: usize,
}
519
/// Struct definition information.
///
/// All fields are plain strings and integers, so the type derives `Eq` in
/// addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StructDefinition {
    /// Struct name as written in the source.
    pub name: String,
    /// Field names extracted from the struct body, in source order.
    pub fields: Vec<String>,
    /// Line (1-indexed) reported for the match that produced this entry.
    pub line: usize,
}
527
/// Enum definition information.
///
/// All fields are plain strings and integers, so the type derives `Eq` in
/// addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumDefinition {
    /// Enum name as written in the source.
    pub name: String,
    /// Variant names extracted from the enum body, in source order.
    pub variants: Vec<String>,
    /// Line (1-indexed) reported for the match that produced this entry.
    pub line: usize,
}
535
/// Impl block information.
///
/// All fields are plain strings and integers, so the type derives `Eq` in
/// addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ImplDefinition {
    /// Source text of the type the impl block is for.
    pub type_name: String,
    /// Source text of the implemented trait, or `None` when no trait was
    /// captured for the block.
    pub trait_name: Option<String>,
    /// Heuristic method count: the number of `"fn "` occurrences in the impl
    /// body text.
    pub method_count: usize,
    /// Line (1-indexed) reported for the match that produced this entry.
    pub line: usize,
}
544
/// Counts of error handling patterns.
///
/// Every field is a plain counter, so the type derives `Eq` and `Default`
/// (all zeros) in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ErrorPatternCounts {
    /// Number of generic types whose type identifier is `Result`.
    pub result_types: usize,
    /// Number of `?` operators (`try_expression` nodes).
    pub try_operators: usize,
    /// Number of method calls named `unwrap`.
    pub unwrap_calls: usize,
    /// Number of method calls named `expect`.
    pub expect_calls: usize,
    /// Number of `match` expressions of any kind, not only error-related ones.
    pub match_expressions: usize,
}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Small Rust program used as the fixture for every test: one struct, one
    /// inherent impl with two methods, two free functions and one enum.
    fn sample_rust_code() -> String {
        let code = r#"
use anyhow::Result;

pub struct Config {
    pub debug: bool,
    pub timeout: u64,
}

impl Config {
    pub fn new() -> Self {
        Self { debug: false, timeout: 30 }
    }
    
    pub fn with_debug(mut self) -> Self {
        self.debug = true;
        self
    }
}

pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data.to_uppercase())
}

fn parse(input: &str) -> Result<String> {
    if input.is_empty() {
        return Err(anyhow!("empty input"));
    }
    Ok(input.to_string())
}

enum Status {
    Active,
    Inactive,
    Pending,
}
"#;
        code.to_owned()
    }

    /// Fresh oracle over the fixture source.
    fn sample_oracle() -> TreeSitterOracle {
        TreeSitterOracle::new(sample_rust_code())
    }

    #[test]
    fn get_functions_finds_all() {
        let mut oracle = sample_oracle();
        let found = oracle.get_functions().unwrap();
        assert!(found.len() >= 3);

        for expected in ["new", "with_debug", "process", "parse"] {
            assert!(
                found.iter().any(|f| f.name == expected),
                "missing function `{expected}`"
            );
        }
    }

    #[test]
    fn get_structs_finds_all() {
        let mut oracle = sample_oracle();
        let found = oracle.get_structs().unwrap();
        assert!(!found.is_empty());

        let config = found
            .iter()
            .find(|s| s.name == "Config")
            .expect("Config struct should be reported");
        for field in ["debug", "timeout"] {
            assert!(
                config.fields.iter().any(|f| f == field),
                "missing field `{field}`"
            );
        }
    }

    #[test]
    fn get_enums_finds_all() {
        let mut oracle = sample_oracle();
        let found = oracle.get_enums().unwrap();
        assert!(!found.is_empty());

        let status = found
            .iter()
            .find(|e| e.name == "Status")
            .expect("Status enum should be reported");
        for variant in ["Active", "Inactive"] {
            assert!(
                status.variants.iter().any(|v| v == variant),
                "missing variant `{variant}`"
            );
        }
    }

    #[test]
    fn count_error_patterns() {
        let mut oracle = sample_oracle();
        let counts = oracle.count_error_patterns().unwrap();

        // The fixture declares two `Result<String>` return types ...
        assert!(counts.result_types >= 2);
        // ... and uses `?` once.
        assert!(counts.try_operators >= 1);
    }
}