Skip to main content

codetether_agent/rlm/oracle/
validator.rs

1//! Trace validator pipeline for RLM REPL outputs.
2//!
3//! Orchestrates validation of RLM analysis results by:
//! 1. Classifying the answer by its parsed `FINAL()` payload kind (grep/pattern-match vs AST/structural vs semantic)
5//! 2. Routing to the appropriate oracle
6//! 3. Marking traces as "golden" (verified), "unverified", or "failed"
7//! 4. Outputting golden traces as JSONL for downstream SFT training
8//!
9//! # JSONL Output Format
10//!
11//! ```json
12//! {
13//!   "prompt": "Find all async functions in src/rlm/repl.rs",
14//!   "trace": [{"iteration": 1, "action": "grep(\"async fn\")", "output": "..."}],
15//!   "final_payload": { "kind": "grep", ... },
16//!   "verdict": "golden",
17//!   "oracle_diff": null,
18//!   "repo_revision": "abc123def",
19//!   "timestamp": "2026-02-21T18:40:00Z"
20//! }
21//! ```
22//!
23//! # Usage
24//!
25//! ```ignore
26//! use codetether_agent::rlm::oracle::{TraceValidator, OracleResult};
27//! use codetether_agent::rlm::RlmAnalysisResult;
28//!
//! let validator = TraceValidator::new();
//! let result = validator.validate(
//!     &analysis_result,
//!     &source_code,
//!     Some("src/rlm/repl.rs"), // source_path: metadata only
//!     None,                    // repo_revision: None => `git rev-parse HEAD`, else "unknown"
//!     None,                    // trace_steps
//! );
//!
//! match result {
//!     OracleResult::Golden(trace) => {
//!         // Write to JSONL file for training
//!         writeln!(jsonl_file, "{}", serde_json::to_string(&trace)?)?;
//!     }
//!     OracleResult::Unverified { reason } => {
//!         // No deterministic oracle available - skip or flag for manual review
//!     }
//!     OracleResult::Failed { reason, diff, .. } => {
//!         // Oracle disagrees - discard or investigate
//!     }
//! }
44//! ```
45
46use anyhow::Result;
47use chrono::{DateTime, Utc};
48use serde::{Deserialize, Serialize};
49use std::time::Instant;
50
51use super::schema::FinalPayload;
52use super::{grep_oracle::GrepOracle, tree_sitter_oracle::TreeSitterOracle, QueryType};
53use crate::rlm::repl::RlmAnalysisResult;
54
55/// Result of oracle validation.
56#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
57pub enum OracleResult {
58    /// Answer verified as correct (golden training example).
59    Golden(ValidatedTrace),
60    /// No deterministic oracle available for this query type.
61    Unverified {
62        reason: String,
63    },
64    /// Oracle disagrees with the answer (failed verification).
65    Failed {
66        reason: String,
67        diff: Option<String>,
68        trace: ValidatedTrace,
69    },
70}
71
72/// A single step in the RLM trace.
73#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
74pub struct TraceStep {
75    /// Iteration number (1-indexed)
76    pub iteration: usize,
77    /// Action performed (e.g., "grep(\"async fn\")")
78    pub action: String,
79    /// Output from the action
80    pub output: String,
81}
82
83/// A validated RLM trace ready for training data export.
84#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
85pub struct ValidatedTrace {
86    /// Original query/question
87    pub prompt: String,
88    /// Trace of steps taken
89    pub trace: Vec<TraceStep>,
90    /// Parsed FINAL() payload (if JSON)
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub final_payload: Option<FinalPayload>,
93    /// Verdict: "golden", "failed", or "unverified"
94    pub verdict: String,
95    /// Oracle diff (what the model got wrong - only on failures)
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub oracle_diff: Option<String>,
98    /// Git commit SHA for reproducibility
99    pub repo_revision: String,
100    /// Verification timestamp (ISO 8601)
101    pub timestamp: String,
102    // Legacy fields (kept for compatibility)
103    /// Model's FINAL() answer (raw string)
104    #[serde(skip)]
105    pub answer: String,
106    /// Number of RLM iterations
107    #[serde(skip)]
108    pub iterations: usize,
109    /// Number of sub-LLM calls
110    #[serde(skip)]
111    pub subcalls: usize,
112    /// Token usage - input
113    #[serde(skip)]
114    pub input_tokens: usize,
115    /// Token usage - output
116    #[serde(skip)]
117    pub output_tokens: usize,
118    /// Elapsed time in milliseconds
119    #[serde(skip)]
120    pub elapsed_ms: u64,
121    /// Source file path (if available)
122    #[serde(skip)]
123    pub source_path: Option<String>,
124    /// Oracle verification method used
125    #[serde(skip)]
126    pub verification_method: VerificationMethod,
127    /// Unique trace ID
128    #[serde(skip)]
129    pub trace_id: String,
130}
131
132/// Method used to verify the trace.
133#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
134pub enum VerificationMethod {
135    /// Grep-based pattern matching
136    GrepOracle,
137    /// Tree-sitter AST verification
138    TreeSitterOracle,
139    /// No oracle available
140    #[default]
141    None,
142}
143
144/// Trace validator that routes to appropriate oracles.
145pub struct TraceValidator {
146    /// Minimum confidence threshold for golden classification
147    confidence_threshold: f32,
148}
149
150impl Default for TraceValidator {
151    fn default() -> Self {
152        Self {
153            confidence_threshold: 0.95,
154        }
155    }
156}
157
158impl TraceValidator {
159    /// Create a new trace validator.
160    pub fn new() -> Self {
161        Self::default()
162    }
163
164    /// Set the confidence threshold for golden classification.
165    pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
166        self.confidence_threshold = threshold.clamp(0.0, 1.0);
167        self
168    }
169
170    /// Validate an RLM analysis result against source code.
171    /// Validate an RLM analysis result against source code.
172    ///
173    /// # Arguments
174    ///
175    /// * `result` - The RLM analysis result to validate
176    /// * `source` - The source code that was analyzed
177    /// * `source_path` - Optional path to the source file (for metadata)
178    /// * `repo_revision` - Git commit SHA for reproducibility
179    /// * `trace_steps` - Optional trace steps from execution
180    ///
181    /// # Returns
182    ///
183    /// - `OracleResult::Golden` if the answer is verified correct
184    /// - `OracleResult::Unverified` if no oracle is available
185    /// - `OracleResult::Failed` if the oracle disagrees with the answer
186    pub fn validate(
187        &self,
188        result: &RlmAnalysisResult,
189        source: &str,
190        source_path: Option<&str>,
191        repo_revision: Option<&str>,
192        trace_steps: Option<Vec<TraceStep>>,
193    ) -> OracleResult {
194        let _start = Instant::now();
195        
196        // Get current git revision if not provided
197        let revision = repo_revision
198            .map(|s| s.to_string())
199            .or_else(|| Self::get_git_revision().ok())
200            .unwrap_or_else(|| "unknown".to_string());
201        
202        // Extract the original query from sub_queries or use a placeholder
203        let query = result
204            .sub_queries
205            .first()
206            .map(|sq| sq.query.clone())
207            .unwrap_or_else(|| "unknown query".to_string());
208        
209        // Try to parse the answer as JSON
210        let final_payload = FinalPayload::parse(&result.answer);
211        
212        // Build base trace
213        let base_trace = || ValidatedTrace {
214            prompt: query.clone(),
215            trace: trace_steps.unwrap_or_default(),
216            final_payload: Some(final_payload.clone()),
217            verdict: "unverified".to_string(),
218            oracle_diff: None,
219            repo_revision: revision.clone(),
220            timestamp: Utc::now().to_rfc3339(),
221            // Legacy fields
222            answer: result.answer.clone(),
223            iterations: result.iterations,
224            subcalls: result.sub_queries.len(),
225            input_tokens: result.stats.input_tokens,
226            output_tokens: result.stats.output_tokens,
227            elapsed_ms: result.stats.elapsed_ms,
228            source_path: source_path.map(|s| s.to_string()),
229            verification_method: VerificationMethod::None,
230            trace_id: uuid::Uuid::new_v4().to_string(),
231        };
232
233        // Route based on payload kind first (more reliable than query classification)
234        let verdict = match &final_payload {
235            FinalPayload::Grep(_) => {
236                self.validate_grep_payload(&final_payload, source, source_path, &query, base_trace)
237            }
238            FinalPayload::Ast(_) => {
239                self.validate_ast_payload(&final_payload, source, source_path, &query, base_trace)
240            }
241            FinalPayload::Semantic(_) => {
242                // Semantic queries are unverifiable
243                return OracleResult::Unverified {
244                    reason: "Semantic queries require LLM understanding - no deterministic oracle available".to_string(),
245                };
246            }
247            FinalPayload::Malformed { error, .. } => {
248                // Return failed because the payload is malformed
249                let mut trace = base_trace();
250                trace.verdict = "failed".to_string();
251                OracleResult::Failed {
252                    reason: format!("Malformed FINAL payload: {}", error),
253                    diff: None,
254                    trace,
255                }
256            }
257        };
258        
259        verdict
260    }
261
262    /// Validate using grep payload directly.
263    fn validate_grep_payload(
264        &self,
265        payload: &FinalPayload,
266        source: &str,
267        source_path: Option<&str>,
268        query: &str,
269        base_trace: impl FnOnce() -> ValidatedTrace,
270    ) -> OracleResult {
271        let grep_payload = match payload {
272            FinalPayload::Grep(p) => p,
273            _ => unreachable!(),
274        };
275        
276        let oracle = GrepOracle::new(source.to_string());
277        
278        // Run actual grep to get ground truth
279        let ground_truth = match oracle.grep(&grep_payload.pattern) {
280            Ok(m) => m,
281            Err(e) => {
282                return OracleResult::Unverified {
283                    reason: format!("Could not run grep: {}", e),
284                };
285            }
286        };
287        
288        // Convert payload matches to the format expected by verification
289        let claimed: Vec<(usize, String)> = grep_payload.matches
290            .iter()
291            .map(|m| (m.line, m.text.clone()))
292            .collect();
293        
294        let verification = oracle.verify_matches(&claimed, &ground_truth);
295        
296        match verification {
297            super::grep_oracle::GrepVerification::ExactMatch
298            | super::grep_oracle::GrepVerification::UnorderedMatch => {
299                let mut trace = base_trace();
300                trace.verification_method = VerificationMethod::GrepOracle;
301                trace.verdict = "golden".to_string();
302                
303                tracing::info!(
304                    query = %query,
305                    pattern = %grep_payload.pattern,
306                    "Grep oracle verified trace as golden"
307                );
308                
309                OracleResult::Golden(trace)
310            }
311            super::grep_oracle::GrepVerification::SubsetMatch { claimed, actual } => {
312                let coverage = claimed as f32 / actual.max(1) as f32;
313                if coverage >= self.confidence_threshold {
314                    let mut trace = base_trace();
315                    trace.verification_method = VerificationMethod::GrepOracle;
316                    trace.verdict = "golden".to_string();
317                    
318                    OracleResult::Golden(trace)
319                } else {
320                    let diff = format!(
321                        "Subset match: model claimed {} but source has {} (coverage: {:.1}%)",
322                        claimed, actual, coverage * 100.0
323                    );
324                    let mut trace = base_trace();
325                    trace.verification_method = VerificationMethod::GrepOracle;
326                    trace.verdict = "failed".to_string();
327                    trace.oracle_diff = Some(diff.clone());
328                    
329                    OracleResult::Failed {
330                        reason: diff.clone(),
331                        diff: Some(diff),
332                        trace,
333                    }
334                }
335            }
336            super::grep_oracle::GrepVerification::HasFalsePositives { false_positives } => {
337                let diff = format!(
338                    "False positives: {} claims not found in source: {:?}",
339                    false_positives.len(),
340                    false_positives
341                );
342                let mut trace = base_trace();
343                trace.verification_method = VerificationMethod::GrepOracle;
344                trace.verdict = "failed".to_string();
345                trace.oracle_diff = Some(diff.clone());
346                
347                OracleResult::Failed {
348                    reason: diff.clone(),
349                    diff: Some(diff),
350                    trace,
351                }
352            }
353            super::grep_oracle::GrepVerification::HasFalseNegatives { false_negatives } => {
354                let diff = format!(
355                    "False negatives: {} items in source not claimed: {:?}",
356                    false_negatives.len(),
357                    false_negatives
358                );
359                let mut trace = base_trace();
360                trace.verification_method = VerificationMethod::GrepOracle;
361                trace.verdict = "failed".to_string();
362                trace.oracle_diff = Some(diff.clone());
363                
364                OracleResult::Failed {
365                    reason: diff.clone(),
366                    diff: Some(diff),
367                    trace,
368                }
369            }
370            super::grep_oracle::GrepVerification::Mismatch => {
371                let diff = "Complete mismatch between claimed and actual matches".to_string();
372                let mut trace = base_trace();
373                trace.verification_method = VerificationMethod::GrepOracle;
374                trace.verdict = "failed".to_string();
375                trace.oracle_diff = Some(diff.clone());
376                
377                OracleResult::Failed {
378                    reason: diff.clone(),
379                    diff: Some(diff),
380                    trace,
381                }
382            }
383            super::grep_oracle::GrepVerification::CannotVerify { reason } => {
384                OracleResult::Unverified { reason }
385            }
386        }
387    }
388
389    /// Validate using AST payload directly.
390    fn validate_ast_payload(
391        &self,
392        payload: &FinalPayload,
393        source: &str,
394        source_path: Option<&str>,
395        query: &str,
396        base_trace: impl FnOnce() -> ValidatedTrace,
397    ) -> OracleResult {
398        let ast_payload = match payload {
399            FinalPayload::Ast(p) => p,
400            _ => unreachable!(),
401        };
402        
403        let mut oracle = TreeSitterOracle::new(source.to_string());
404        
405        // Get actual AST results based on query type
406        let actual_results = match ast_payload.query.as_str() {
407            "functions" => {
408                match oracle.get_functions() {
409                    Ok(funcs) => funcs.iter().map(|f| f.name.clone()).collect(),
410                    Err(e) => {
411                        return OracleResult::Unverified {
412                            reason: format!("Failed to parse AST: {}", e),
413                        };
414                    }
415                }
416            }
417            "structs" => {
418                match oracle.get_structs() {
419                    Ok(structs) => structs.iter().map(|s| s.name.clone()).collect(),
420                    Err(e) => {
421                        return OracleResult::Unverified {
422                            reason: format!("Failed to parse AST: {}", e),
423                        };
424                    }
425                }
426            }
427            "enums" => {
428                match oracle.get_enums() {
429                    Ok(enums) => enums.iter().map(|e| e.name.clone()).collect(),
430                    Err(e) => {
431                        return OracleResult::Unverified {
432                            reason: format!("Failed to parse AST: {}", e),
433                        };
434                    }
435                }
436            }
437            _ => {
438                // Generic query - try to match against all
439                match oracle.get_functions() {
440                    Ok(funcs) => funcs.iter().map(|f| f.name.clone()).collect(),
441                    Err(_) => vec![],
442                }
443            }
444        };
445        
446        // Compare with claimed results
447        let claimed: std::collections::HashSet<_> = ast_payload.results
448            .iter()
449            .map(|r| r.name.clone())
450            .collect();
451        let actual: std::collections::HashSet<_> = actual_results.iter().cloned().collect();
452        
453        if claimed == actual {
454            let mut trace = base_trace();
455            trace.verification_method = VerificationMethod::TreeSitterOracle;
456            trace.verdict = "golden".to_string();
457            
458            OracleResult::Golden(trace)
459        } else if claimed.is_subset(&actual) {
460            let coverage = claimed.len() as f32 / actual.len().max(1) as f32;
461            if coverage >= self.confidence_threshold {
462                let mut trace = base_trace();
463                trace.verification_method = VerificationMethod::TreeSitterOracle;
464                trace.verdict = "golden".to_string();
465                
466                OracleResult::Golden(trace)
467            } else {
468                let diff = format!(
469                    "Partial match: claimed {:?}, actual {:?}",
470                    claimed, actual
471                );
472                let mut trace = base_trace();
473                trace.verification_method = VerificationMethod::TreeSitterOracle;
474                trace.verdict = "failed".to_string();
475                trace.oracle_diff = Some(diff.clone());
476                
477                OracleResult::Failed {
478                    reason: diff.clone(),
479                    diff: Some(diff),
480                    trace,
481                }
482            }
483        } else {
484            let diff = format!(
485                "Mismatch: claimed {:?}, actual {:?}",
486                claimed, actual
487            );
488            let mut trace = base_trace();
489            trace.verification_method = VerificationMethod::TreeSitterOracle;
490            trace.verdict = "failed".to_string();
491            trace.oracle_diff = Some(diff.clone());
492            
493            OracleResult::Failed {
494                reason: diff.clone(),
495                diff: Some(diff),
496                trace,
497            }
498        }
499    }
500
501    /// Get the current git revision.
502    fn get_git_revision() -> Result<String> {
503        let output = std::process::Command::new("git")
504            .args(["rev-parse", "HEAD"])
505            .output()?;
506        
507        Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
508    }
509
510    /// Batch validate multiple traces and return statistics.
511    pub fn batch_validate<'a>(
512        &self,
513        traces: impl IntoIterator<Item = (RlmAnalysisResult, &'a str, Option<&'a str>)>,
514    ) -> BatchValidationStats {
515        self.batch_validate_with_options(traces, None, None)
516    }
517
518    /// Batch validate with additional options.
519    pub fn batch_validate_with_options<'a>(
520        &self,
521        traces: impl IntoIterator<Item = (RlmAnalysisResult, &'a str, Option<&'a str>)>,
522        repo_revision: Option<&str>,
523        trace_steps: Option<Vec<TraceStep>>,
524    ) -> BatchValidationStats {
525        let mut stats = BatchValidationStats::default();
526        
527        for (result, source, source_path) in traces {
528            match self.validate(&result, source, source_path, repo_revision, trace_steps.clone()) {
529                OracleResult::Golden(trace) => {
530                    stats.golden.push(trace);
531                }
532                OracleResult::Unverified { reason } => {
533                    stats.unverified.push((result, reason));
534                }
535                OracleResult::Failed { reason, trace, .. } => {
536                    stats.failed.push((trace, reason));
537                }
538            }
539        }
540        
541        stats
542    }
543}
544
545/// Statistics from batch validation.
546#[derive(Debug, Clone, Default)]
547pub struct BatchValidationStats {
548    /// Traces verified as golden (ready for training)
549    pub golden: Vec<ValidatedTrace>,
550    /// Traces that could not be verified
551    pub unverified: Vec<(RlmAnalysisResult, String)>,
552    /// Traces that failed verification
553    pub failed: Vec<(ValidatedTrace, String)>,
554}
555
556impl BatchValidationStats {
557    /// Total number of traces processed.
558    pub fn total(&self) -> usize {
559        self.golden.len() + self.unverified.len() + self.failed.len()
560    }
561    
562    /// Percentage of traces verified as golden.
563    pub fn golden_rate(&self) -> f32 {
564        let total = self.total();
565        if total == 0 {
566            0.0
567        } else {
568            self.golden.len() as f32 / total as f32
569        }
570    }
571    
572    /// Write golden traces to a JSONL file.
573    pub fn write_jsonl(&self, path: &str) -> Result<usize> {
574        use std::fs::File;
575        use std::io::{BufWriter, Write};
576        
577        let file = File::create(path)?;
578        let mut writer = BufWriter::new(file);
579        
580        let mut count = 0;
581        for trace in &self.golden {
582            let json = serde_json::to_string(trace)?;
583            writeln!(writer, "{}", json)?;
584            count += 1;
585        }
586        
587        writer.flush()?;
588        Ok(count)
589    }
590}
591
592#[cfg(test)]
593mod tests {
594    use super::*;
595    use crate::rlm::RlmStats;
596
597    fn make_result(answer: &str, query: &str) -> RlmAnalysisResult {
598        RlmAnalysisResult {
599            answer: answer.to_string(),
600            iterations: 2,
601            sub_queries: vec![],
602            stats: RlmStats {
603                input_tokens: 100,
604                output_tokens: 50,
605                iterations: 2,
606                subcalls: 0,
607                elapsed_ms: 500,
608                compression_ratio: 1.0,
609            },
610        }
611    }
612
613    fn sample_rust_code() -> &'static str {
614        r#"
615pub async fn process(input: &str) -> Result<String> {
616    let data = parse(input)?;
617    Ok(data)
618}
619
620async fn parse(input: &str) -> Result<String> {
621    Ok(input.to_uppercase())
622}
623
624pub struct Config {
625    pub debug: bool,
626}
627"#
628    }
629
630    #[test]
631    fn validate_grep_match() {
632        let validator = TraceValidator::new();
633        let source = sample_rust_code();
634        let result = make_result(
635            r#"{"kind": "grep", "file": "test.rs", "pattern": "async fn", "matches": [{"line": 1, "text": "pub async fn process(input: &str) -> Result<String> {"}, {"line": 5, "text": "async fn parse(input: &str) -> Result<String> {"}]}"#,
636            "Find all async functions",
637        );
638        
639        match validator.validate(&result, source, Some("test.rs"), Some("abc123"), None) {
640            OracleResult::Golden(trace) => {
641                assert_eq!(trace.verification_method, VerificationMethod::GrepOracle);
642                assert_eq!(trace.verdict, "golden");
643            }
644            OracleResult::Unverified { .. } => panic!("Expected golden"),
645            OracleResult::Failed { .. } => panic!("Expected golden"),
646        }
647    }
648
649    #[test]
650    fn validate_semantic_unverified() {
651        let validator = TraceValidator::new();
652        let source = sample_rust_code();
653        let result = make_result(
654            r#"{"kind": "semantic", "file": "test.rs", "answer": "This function processes input by parsing it and returning uppercase"}"#,
655            "Explain what the process function does",
656        );
657        
658        match validator.validate(&result, source, Some("test.rs"), Some("abc123"), None) {
659            OracleResult::Unverified { reason } => {
660                assert!(reason.contains("Semantic"));
661            }
662            OracleResult::Golden(_) => panic!("Expected unverified"),
663            OracleResult::Failed { .. } => panic!("Expected unverified"),
664        }
665    }
666
667    #[test]
668    fn batch_validate_mixed() {
669        let validator = TraceValidator::new();
670        let source = sample_rust_code();
671        
672        let traces = vec![
673            (make_result(r#"{"kind": "grep", "file": "x.rs", "pattern": "async", "matches": []}"#, "Find async"), source, None),
674            (make_result(r#"{"kind": "semantic", "file": "x.rs", "answer": "text"}"#, "Explain"), source, None),
675        ];
676        
677        let stats = validator.batch_validate(traces);
678        
679        assert!(stats.golden.len() >= 1);
680        assert!(stats.unverified.len() >= 1);
681        assert!(stats.total() == 2);
682    }
683}