Skip to main content

codetether_agent/rlm/oracle/
mod.rs

1//! Deterministic oracle system for validating RLM REPL trace outputs.
2//!
3//! This module provides oracles that can verify FINAL() answers from RLM traces
4//! without requiring cloud LLM judges. This enables synthetic training data
5//! generation for the BitNet distilled navigation model.
6//!
7//! # Architecture
8//!
9//! - **Grep Oracle**: Pattern-match verification (e.g., "find all async functions")
10//! - **Tree-sitter Oracle**: Structural AST verification (function signatures, etc.)
11//! - **Validator Pipeline**: Routes queries to appropriate oracles and outputs golden traces
12//!
13//! # Usage
14//!
15//! ```ignore
16//! use codetether_agent::rlm::oracle::{TraceValidator, OracleResult};
17//!
18//! let validator = TraceValidator::new();
19//! let result = validator.validate(&analysis_result, &source_file).await;
20//!
21//! match result {
22//!     OracleResult::Golden(trace) => save_to_jsonl(trace),
23//!     OracleResult::Unverified => {} // No oracle available
24//!     OracleResult::Failed(reason) => {} // Oracle disagrees
25//! }
26//! ```
27
28mod grep_oracle;
29mod schema;
30mod templates;
31mod tree_sitter_oracle;
32mod validator;
33
34pub use grep_oracle::GrepOracle;
35pub use schema::{AstPayload, AstResult, FinalPayload, GrepMatch, GrepPayload, SemanticPayload};
36pub use templates::{GeneratedQuery, QueryTemplate, TemplateKind};
37pub use tree_sitter_oracle::TreeSitterOracle;
38pub use validator::{OracleResult, TraceValidator, ValidatedTrace, VerificationMethod};
39
/// How a query is classified so it can be routed to the appropriate oracle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Grep-style pattern matching: answers are line numbers and text matches.
    PatternMatch,
    /// AST-based structural lookup: function signatures, struct fields.
    Structural,
    /// Needs LLM understanding; there is no deterministic oracle for it.
    Semantic,
}
50
51/// Classification of an RLM FINAL() answer format.
52#[derive(Debug, Clone, PartialEq)]
53pub enum FinalAnswerFormat {
54    /// Line-numbered matches (e.g., "42:async fn foo()", "100:pub struct Bar")
55    LineNumberedMatches {
56        matches: Vec<(usize, String)>,
57    },
58    /// Count result (e.g., "Found 15 occurrences")
59    CountResult {
60        count: usize,
61    },
62    /// Structured data (e.g., function signature JSON)
63    StructuredData {
64        data: serde_json::Value,
65    },
66    /// Free-form text (semantic - no deterministic verification)
67    FreeFormText {
68        text: String,
69    },
70}
71
72impl FinalAnswerFormat {
73    /// Parse a FINAL() answer string into its classified format.
74    pub fn parse(answer: &str) -> Self {
75        // Try to parse as line-numbered matches
76        let lines: Vec<&str> = answer.lines().collect();
77        let mut numbered_matches = Vec::new();
78        let mut all_valid = true;
79
80        for line in &lines {
81            // Pattern: "42:text" or "42: text" or "L42: text"
82            let trimmed = line.trim();
83            if let Some(colon_pos) = trimmed.find(':') {
84                let num_part = trimmed[..colon_pos]
85                    .trim()
86                    .trim_start_matches('L')
87                    .trim();
88                if let Ok(line_num) = num_part.parse::<usize>() {
89                    let text_part = trimmed[colon_pos + 1..].trim().to_string();
90                    numbered_matches.push((line_num, text_part));
91                } else {
92                    all_valid = false;
93                    break;
94                }
95            } else if !trimmed.is_empty() {
96                // Non-empty line without line number
97                all_valid = false;
98                break;
99            }
100        }
101
102        if all_valid && !numbered_matches.is_empty() {
103            return Self::LineNumberedMatches {
104                matches: numbered_matches,
105            };
106        }
107
108        // Try to parse as count result
109        let lower = answer.to_lowercase();
110        if lower.contains("found") || lower.contains("count:") || lower.contains("occurrences") {
111            // Extract number from text like "Found 15 async functions"
112            if let Some(count) = extract_count_from_text(answer) {
113                return Self::CountResult { count };
114            }
115        }
116
117        // Try to parse as JSON
118        if answer.trim().starts_with('{') || answer.trim().starts_with('[') {
119            if let Ok(data) = serde_json::from_str::<serde_json::Value>(answer) {
120                return Self::StructuredData { data };
121            }
122        }
123
124        // Default to free-form text
125        Self::FreeFormText {
126            text: answer.to_string(),
127        }
128    }
129}
130
131/// Extract a count number from natural language text.
132fn extract_count_from_text(text: &str) -> Option<usize> {
133    // Look for patterns like "15 functions", "count: 42", "Found 7"
134    let re = regex::Regex::new(r"(?i)(?:found|count:?\s*)\s*(\d+)|(\d+)\s+(?:functions?|matches?|occurrences?|items?|results?)").ok()?;
135    
136    for cap in re.captures_iter(text) {
137        // Try first group (found/count)
138        if let Some(m) = cap.get(1) {
139            if let Ok(n) = m.as_str().parse() {
140                return Some(n);
141            }
142        }
143        // Try second group (number before word)
144        if let Some(m) = cap.get(2) {
145            if let Ok(n) = m.as_str().parse() {
146                return Some(n);
147            }
148        }
149    }
150    
151    None
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Answers made only of `N:text` lines become line-numbered matches.
    #[test]
    fn parse_line_numbered_matches() {
        let parsed = FinalAnswerFormat::parse("42:async fn foo()\n100:pub struct Bar\n");
        if let FinalAnswerFormat::LineNumberedMatches { matches } = parsed {
            assert_eq!(matches.len(), 2);
            assert_eq!(matches[0], (42, "async fn foo()".to_string()));
            assert_eq!(matches[1], (100, "pub struct Bar".to_string()));
        } else {
            panic!("Expected LineNumberedMatches");
        }
    }

    /// A "Found N ..." sentence yields the count N.
    #[test]
    fn parse_count_result() {
        let parsed = FinalAnswerFormat::parse("Found 15 async functions");
        if let FinalAnswerFormat::CountResult { count } = parsed {
            assert_eq!(count, 15);
        } else {
            panic!("Expected CountResult");
        }
    }

    /// A JSON object is kept as structured data.
    #[test]
    fn parse_structured_data() {
        let parsed = FinalAnswerFormat::parse(r#"{"name": "foo", "args": ["x", "y"]}"#);
        if let FinalAnswerFormat::StructuredData { data } = parsed {
            assert_eq!(data["name"], "foo");
        } else {
            panic!("Expected StructuredData");
        }
    }

    /// Plain prose falls through to free-form text.
    #[test]
    fn parse_free_form_text() {
        let parsed =
            FinalAnswerFormat::parse("This function handles error cases by using the ? operator");
        if let FinalAnswerFormat::FreeFormText { text } = parsed {
            assert!(text.contains("error cases"));
        } else {
            panic!("Expected FreeFormText");
        }
    }
}