Skip to main content

codetether_agent/rlm/oracle/
mod.rs

1//! Deterministic oracle system for validating RLM REPL trace outputs.
2//!
3//! This module provides oracles that can verify FINAL() answers from RLM traces
4//! without requiring cloud LLM judges. This enables synthetic training data
5//! generation for the BitNet distilled navigation model.
6//!
7//! # Architecture
8//!
9//! - **Grep Oracle**: Pattern-match verification (e.g., "find all async functions")
10//! - **Tree-sitter Oracle**: Structural AST verification (function signatures, etc.)
11//! - **Validator Pipeline**: Routes queries to appropriate oracles and outputs golden traces
12//!
13//! # Usage
14//!
15//! ```ignore
16//! use codetether_agent::rlm::oracle::{TraceValidator, OracleResult};
17//!
18//! let validator = TraceValidator::new();
19//! let result = validator.validate(&analysis_result, &source_file).await;
20//!
21//! match result {
22//!     OracleResult::Golden(trace) => save_to_jsonl(trace),
23//!     OracleResult::Unverified => {} // No oracle available
24//!     OracleResult::Failed(reason) => {} // Oracle disagrees
25//! }
26//! ```
27
28mod ast_validation;
29mod batch;
30mod batch_write;
31mod consensus;
32mod consensus_helpers;
33mod grep_oracle;
34mod grep_validation;
35mod query_type;
36mod record;
37mod schema;
38#[path = "storage/mod.rs"]
39mod storage;
40mod templates;
41mod trace_types;
42mod tree_sitter_oracle;
43mod types;
44#[path = "validator/mod.rs"]
45mod validator;
46
47pub use grep_oracle::{GrepOracle, GrepVerification};
48pub use record::OracleTraceRecord;
49pub use schema::{AstPayload, AstResult, FinalPayload, GrepMatch, GrepPayload, SemanticPayload};
50pub use storage::{
51    OracleTracePersistResult, OracleTraceStorage, OracleTraceSyncStats, default_spool_dir,
52};
53pub use templates::{GeneratedQuery, QueryTemplate, TemplateKind};
54pub use trace_types::{OracleResult, ValidatedTrace};
55pub use tree_sitter_oracle::{TreeSitterOracle, TreeSitterVerification};
56pub use types::{TraceStep, VerificationMethod};
57pub use validator::{BatchValidationStats, SplitWriteStats, TraceValidator};
58
/// Routing category for a query: selects which oracle, if any, can check its answer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryType {
    /// Checkable by grep-style pattern matching (line numbers, literal text hits).
    PatternMatch,
    /// Checkable against the syntax tree (function signatures, struct fields).
    Structural,
    /// Needs LLM-level understanding; there is no deterministic oracle for it.
    Semantic,
}
69
70/// Classification of an RLM FINAL() answer format.
71#[derive(Debug, Clone, PartialEq)]
72pub enum FinalAnswerFormat {
73    /// Line-numbered matches (e.g., "42:async fn foo()", "100:pub struct Bar")
74    LineNumberedMatches { matches: Vec<(usize, String)> },
75    /// Count result (e.g., "Found 15 occurrences")
76    CountResult { count: usize },
77    /// Structured data (e.g., function signature JSON)
78    StructuredData { data: serde_json::Value },
79    /// Free-form text (semantic - no deterministic verification)
80    FreeFormText { text: String },
81}
82
83impl FinalAnswerFormat {
84    /// Parse a FINAL() answer string into its classified format.
85    pub fn parse(answer: &str) -> Self {
86        // Try to parse as line-numbered matches
87        let lines: Vec<&str> = answer.lines().collect();
88        let mut numbered_matches = Vec::new();
89        let mut all_valid = true;
90
91        for line in &lines {
92            // Pattern: "42:text" or "42: text" or "L42: text"
93            let trimmed = line.trim();
94            if let Some(colon_pos) = trimmed.find(':') {
95                let num_part = trimmed[..colon_pos].trim().trim_start_matches('L').trim();
96                if let Ok(line_num) = num_part.parse::<usize>() {
97                    let text_part = trimmed[colon_pos + 1..].trim().to_string();
98                    numbered_matches.push((line_num, text_part));
99                } else {
100                    all_valid = false;
101                    break;
102                }
103            } else if !trimmed.is_empty() {
104                // Non-empty line without line number
105                all_valid = false;
106                break;
107            }
108        }
109
110        if all_valid && !numbered_matches.is_empty() {
111            return Self::LineNumberedMatches {
112                matches: numbered_matches,
113            };
114        }
115
116        // Try to parse as count result
117        let lower = answer.to_lowercase();
118        if lower.contains("found") || lower.contains("count:") || lower.contains("occurrences") {
119            // Extract number from text like "Found 15 async functions"
120            if let Some(count) = extract_count_from_text(answer) {
121                return Self::CountResult { count };
122            }
123        }
124
125        // Try to parse as JSON
126        if (answer.trim().starts_with('{') || answer.trim().starts_with('['))
127            && let Ok(data) = serde_json::from_str::<serde_json::Value>(answer)
128        {
129            return Self::StructuredData { data };
130        }
131
132        // Default to free-form text
133        Self::FreeFormText {
134            text: answer.to_string(),
135        }
136    }
137}
138
139/// Extract a count number from natural language text.
140fn extract_count_from_text(text: &str) -> Option<usize> {
141    // Look for patterns like "15 functions", "count: 42", "Found 7"
142    let re = regex::Regex::new(r"(?i)(?:found|count:?\s*)\s*(\d+)|(\d+)\s+(?:functions?|matches?|occurrences?|items?|results?)").ok()?;
143
144    for cap in re.captures_iter(text) {
145        // Try first group (found/count)
146        if let Some(m) = cap.get(1)
147            && let Ok(n) = m.as_str().parse()
148        {
149            return Some(n);
150        }
151        // Try second group (number before word)
152        if let Some(m) = cap.get(2)
153            && let Ok(n) = m.as_str().parse()
154        {
155            return Some(n);
156        }
157    }
158
159    None
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_line_numbered_matches() {
        let parsed = FinalAnswerFormat::parse("42:async fn foo()\n100:pub struct Bar\n");
        let FinalAnswerFormat::LineNumberedMatches { matches } = parsed else {
            panic!("Expected LineNumberedMatches");
        };
        assert_eq!(
            matches,
            vec![
                (42, "async fn foo()".to_string()),
                (100, "pub struct Bar".to_string()),
            ]
        );
    }

    #[test]
    fn parse_count_result() {
        let parsed = FinalAnswerFormat::parse("Found 15 async functions");
        let FinalAnswerFormat::CountResult { count } = parsed else {
            panic!("Expected CountResult");
        };
        assert_eq!(count, 15);
    }

    #[test]
    fn parse_structured_data() {
        let parsed = FinalAnswerFormat::parse(r#"{"name": "foo", "args": ["x", "y"]}"#);
        let FinalAnswerFormat::StructuredData { data } = parsed else {
            panic!("Expected StructuredData");
        };
        assert_eq!(data["name"], "foo");
    }

    #[test]
    fn parse_free_form_text() {
        let parsed =
            FinalAnswerFormat::parse("This function handles error cases by using the ? operator");
        let FinalAnswerFormat::FreeFormText { text } = parsed else {
            panic!("Expected FreeFormText");
        };
        assert!(text.contains("error cases"));
    }
}