1use anyhow::Result;
47use chrono::{DateTime, Utc};
48use serde::{Deserialize, Serialize};
49use std::time::Instant;
50
51use super::schema::FinalPayload;
52use super::{grep_oracle::GrepOracle, tree_sitter_oracle::TreeSitterOracle, QueryType};
53use crate::rlm::repl::RlmAnalysisResult;
54
55#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
57pub enum OracleResult {
58 Golden(ValidatedTrace),
60 Unverified {
62 reason: String,
63 },
64 Failed {
66 reason: String,
67 diff: Option<String>,
68 trace: ValidatedTrace,
69 },
70}
71
72#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
74pub struct TraceStep {
75 pub iteration: usize,
77 pub action: String,
79 pub output: String,
81}
82
83#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
85pub struct ValidatedTrace {
86 pub prompt: String,
88 pub trace: Vec<TraceStep>,
90 #[serde(skip_serializing_if = "Option::is_none")]
92 pub final_payload: Option<FinalPayload>,
93 pub verdict: String,
95 #[serde(skip_serializing_if = "Option::is_none")]
97 pub oracle_diff: Option<String>,
98 pub repo_revision: String,
100 pub timestamp: String,
102 #[serde(skip)]
105 pub answer: String,
106 #[serde(skip)]
108 pub iterations: usize,
109 #[serde(skip)]
111 pub subcalls: usize,
112 #[serde(skip)]
114 pub input_tokens: usize,
115 #[serde(skip)]
117 pub output_tokens: usize,
118 #[serde(skip)]
120 pub elapsed_ms: u64,
121 #[serde(skip)]
123 pub source_path: Option<String>,
124 #[serde(skip)]
126 pub verification_method: VerificationMethod,
127 #[serde(skip)]
129 pub trace_id: String,
130}
131
132#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
134pub enum VerificationMethod {
135 GrepOracle,
137 TreeSitterOracle,
139 #[default]
141 None,
142}
143
/// Validates analysis results against deterministic oracles.
///
/// Derives `Debug` and `Clone` so the validator can be logged and shared
/// between call sites; it holds only plain configuration.
#[derive(Debug, Clone)]
pub struct TraceValidator {
    /// Minimum coverage ratio in `[0.0, 1.0]` at which a claim that is a strict
    /// subset of the oracle's results is still accepted as golden.
    confidence_threshold: f32,
}
149
150impl Default for TraceValidator {
151 fn default() -> Self {
152 Self {
153 confidence_threshold: 0.95,
154 }
155 }
156}
157
158impl TraceValidator {
159 pub fn new() -> Self {
161 Self::default()
162 }
163
164 pub fn with_confidence_threshold(mut self, threshold: f32) -> Self {
166 self.confidence_threshold = threshold.clamp(0.0, 1.0);
167 self
168 }
169
170 pub fn validate(
187 &self,
188 result: &RlmAnalysisResult,
189 source: &str,
190 source_path: Option<&str>,
191 repo_revision: Option<&str>,
192 trace_steps: Option<Vec<TraceStep>>,
193 ) -> OracleResult {
194 let _start = Instant::now();
195
196 let revision = repo_revision
198 .map(|s| s.to_string())
199 .or_else(|| Self::get_git_revision().ok())
200 .unwrap_or_else(|| "unknown".to_string());
201
202 let query = result
204 .sub_queries
205 .first()
206 .map(|sq| sq.query.clone())
207 .unwrap_or_else(|| "unknown query".to_string());
208
209 let final_payload = FinalPayload::parse(&result.answer);
211
212 let base_trace = || ValidatedTrace {
214 prompt: query.clone(),
215 trace: trace_steps.unwrap_or_default(),
216 final_payload: Some(final_payload.clone()),
217 verdict: "unverified".to_string(),
218 oracle_diff: None,
219 repo_revision: revision.clone(),
220 timestamp: Utc::now().to_rfc3339(),
221 answer: result.answer.clone(),
223 iterations: result.iterations,
224 subcalls: result.sub_queries.len(),
225 input_tokens: result.stats.input_tokens,
226 output_tokens: result.stats.output_tokens,
227 elapsed_ms: result.stats.elapsed_ms,
228 source_path: source_path.map(|s| s.to_string()),
229 verification_method: VerificationMethod::None,
230 trace_id: uuid::Uuid::new_v4().to_string(),
231 };
232
233 let verdict = match &final_payload {
235 FinalPayload::Grep(_) => {
236 self.validate_grep_payload(&final_payload, source, source_path, &query, base_trace)
237 }
238 FinalPayload::Ast(_) => {
239 self.validate_ast_payload(&final_payload, source, source_path, &query, base_trace)
240 }
241 FinalPayload::Semantic(_) => {
242 return OracleResult::Unverified {
244 reason: "Semantic queries require LLM understanding - no deterministic oracle available".to_string(),
245 };
246 }
247 FinalPayload::Malformed { error, .. } => {
248 let mut trace = base_trace();
250 trace.verdict = "failed".to_string();
251 OracleResult::Failed {
252 reason: format!("Malformed FINAL payload: {}", error),
253 diff: None,
254 trace,
255 }
256 }
257 };
258
259 verdict
260 }
261
262 fn validate_grep_payload(
264 &self,
265 payload: &FinalPayload,
266 source: &str,
267 source_path: Option<&str>,
268 query: &str,
269 base_trace: impl FnOnce() -> ValidatedTrace,
270 ) -> OracleResult {
271 let grep_payload = match payload {
272 FinalPayload::Grep(p) => p,
273 _ => unreachable!(),
274 };
275
276 let oracle = GrepOracle::new(source.to_string());
277
278 let ground_truth = match oracle.grep(&grep_payload.pattern) {
280 Ok(m) => m,
281 Err(e) => {
282 return OracleResult::Unverified {
283 reason: format!("Could not run grep: {}", e),
284 };
285 }
286 };
287
288 let claimed: Vec<(usize, String)> = grep_payload.matches
290 .iter()
291 .map(|m| (m.line, m.text.clone()))
292 .collect();
293
294 let verification = oracle.verify_matches(&claimed, &ground_truth);
295
296 match verification {
297 super::grep_oracle::GrepVerification::ExactMatch
298 | super::grep_oracle::GrepVerification::UnorderedMatch => {
299 let mut trace = base_trace();
300 trace.verification_method = VerificationMethod::GrepOracle;
301 trace.verdict = "golden".to_string();
302
303 tracing::info!(
304 query = %query,
305 pattern = %grep_payload.pattern,
306 "Grep oracle verified trace as golden"
307 );
308
309 OracleResult::Golden(trace)
310 }
311 super::grep_oracle::GrepVerification::SubsetMatch { claimed, actual } => {
312 let coverage = claimed as f32 / actual.max(1) as f32;
313 if coverage >= self.confidence_threshold {
314 let mut trace = base_trace();
315 trace.verification_method = VerificationMethod::GrepOracle;
316 trace.verdict = "golden".to_string();
317
318 OracleResult::Golden(trace)
319 } else {
320 let diff = format!(
321 "Subset match: model claimed {} but source has {} (coverage: {:.1}%)",
322 claimed, actual, coverage * 100.0
323 );
324 let mut trace = base_trace();
325 trace.verification_method = VerificationMethod::GrepOracle;
326 trace.verdict = "failed".to_string();
327 trace.oracle_diff = Some(diff.clone());
328
329 OracleResult::Failed {
330 reason: diff.clone(),
331 diff: Some(diff),
332 trace,
333 }
334 }
335 }
336 super::grep_oracle::GrepVerification::HasFalsePositives { false_positives } => {
337 let diff = format!(
338 "False positives: {} claims not found in source: {:?}",
339 false_positives.len(),
340 false_positives
341 );
342 let mut trace = base_trace();
343 trace.verification_method = VerificationMethod::GrepOracle;
344 trace.verdict = "failed".to_string();
345 trace.oracle_diff = Some(diff.clone());
346
347 OracleResult::Failed {
348 reason: diff.clone(),
349 diff: Some(diff),
350 trace,
351 }
352 }
353 super::grep_oracle::GrepVerification::HasFalseNegatives { false_negatives } => {
354 let diff = format!(
355 "False negatives: {} items in source not claimed: {:?}",
356 false_negatives.len(),
357 false_negatives
358 );
359 let mut trace = base_trace();
360 trace.verification_method = VerificationMethod::GrepOracle;
361 trace.verdict = "failed".to_string();
362 trace.oracle_diff = Some(diff.clone());
363
364 OracleResult::Failed {
365 reason: diff.clone(),
366 diff: Some(diff),
367 trace,
368 }
369 }
370 super::grep_oracle::GrepVerification::Mismatch => {
371 let diff = "Complete mismatch between claimed and actual matches".to_string();
372 let mut trace = base_trace();
373 trace.verification_method = VerificationMethod::GrepOracle;
374 trace.verdict = "failed".to_string();
375 trace.oracle_diff = Some(diff.clone());
376
377 OracleResult::Failed {
378 reason: diff.clone(),
379 diff: Some(diff),
380 trace,
381 }
382 }
383 super::grep_oracle::GrepVerification::CannotVerify { reason } => {
384 OracleResult::Unverified { reason }
385 }
386 }
387 }
388
389 fn validate_ast_payload(
391 &self,
392 payload: &FinalPayload,
393 source: &str,
394 source_path: Option<&str>,
395 query: &str,
396 base_trace: impl FnOnce() -> ValidatedTrace,
397 ) -> OracleResult {
398 let ast_payload = match payload {
399 FinalPayload::Ast(p) => p,
400 _ => unreachable!(),
401 };
402
403 let mut oracle = TreeSitterOracle::new(source.to_string());
404
405 let actual_results = match ast_payload.query.as_str() {
407 "functions" => {
408 match oracle.get_functions() {
409 Ok(funcs) => funcs.iter().map(|f| f.name.clone()).collect(),
410 Err(e) => {
411 return OracleResult::Unverified {
412 reason: format!("Failed to parse AST: {}", e),
413 };
414 }
415 }
416 }
417 "structs" => {
418 match oracle.get_structs() {
419 Ok(structs) => structs.iter().map(|s| s.name.clone()).collect(),
420 Err(e) => {
421 return OracleResult::Unverified {
422 reason: format!("Failed to parse AST: {}", e),
423 };
424 }
425 }
426 }
427 "enums" => {
428 match oracle.get_enums() {
429 Ok(enums) => enums.iter().map(|e| e.name.clone()).collect(),
430 Err(e) => {
431 return OracleResult::Unverified {
432 reason: format!("Failed to parse AST: {}", e),
433 };
434 }
435 }
436 }
437 _ => {
438 match oracle.get_functions() {
440 Ok(funcs) => funcs.iter().map(|f| f.name.clone()).collect(),
441 Err(_) => vec![],
442 }
443 }
444 };
445
446 let claimed: std::collections::HashSet<_> = ast_payload.results
448 .iter()
449 .map(|r| r.name.clone())
450 .collect();
451 let actual: std::collections::HashSet<_> = actual_results.iter().cloned().collect();
452
453 if claimed == actual {
454 let mut trace = base_trace();
455 trace.verification_method = VerificationMethod::TreeSitterOracle;
456 trace.verdict = "golden".to_string();
457
458 OracleResult::Golden(trace)
459 } else if claimed.is_subset(&actual) {
460 let coverage = claimed.len() as f32 / actual.len().max(1) as f32;
461 if coverage >= self.confidence_threshold {
462 let mut trace = base_trace();
463 trace.verification_method = VerificationMethod::TreeSitterOracle;
464 trace.verdict = "golden".to_string();
465
466 OracleResult::Golden(trace)
467 } else {
468 let diff = format!(
469 "Partial match: claimed {:?}, actual {:?}",
470 claimed, actual
471 );
472 let mut trace = base_trace();
473 trace.verification_method = VerificationMethod::TreeSitterOracle;
474 trace.verdict = "failed".to_string();
475 trace.oracle_diff = Some(diff.clone());
476
477 OracleResult::Failed {
478 reason: diff.clone(),
479 diff: Some(diff),
480 trace,
481 }
482 }
483 } else {
484 let diff = format!(
485 "Mismatch: claimed {:?}, actual {:?}",
486 claimed, actual
487 );
488 let mut trace = base_trace();
489 trace.verification_method = VerificationMethod::TreeSitterOracle;
490 trace.verdict = "failed".to_string();
491 trace.oracle_diff = Some(diff.clone());
492
493 OracleResult::Failed {
494 reason: diff.clone(),
495 diff: Some(diff),
496 trace,
497 }
498 }
499 }
500
501 fn get_git_revision() -> Result<String> {
503 let output = std::process::Command::new("git")
504 .args(["rev-parse", "HEAD"])
505 .output()?;
506
507 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
508 }
509
510 pub fn batch_validate<'a>(
512 &self,
513 traces: impl IntoIterator<Item = (RlmAnalysisResult, &'a str, Option<&'a str>)>,
514 ) -> BatchValidationStats {
515 self.batch_validate_with_options(traces, None, None)
516 }
517
518 pub fn batch_validate_with_options<'a>(
520 &self,
521 traces: impl IntoIterator<Item = (RlmAnalysisResult, &'a str, Option<&'a str>)>,
522 repo_revision: Option<&str>,
523 trace_steps: Option<Vec<TraceStep>>,
524 ) -> BatchValidationStats {
525 let mut stats = BatchValidationStats::default();
526
527 for (result, source, source_path) in traces {
528 match self.validate(&result, source, source_path, repo_revision, trace_steps.clone()) {
529 OracleResult::Golden(trace) => {
530 stats.golden.push(trace);
531 }
532 OracleResult::Unverified { reason } => {
533 stats.unverified.push((result, reason));
534 }
535 OracleResult::Failed { reason, trace, .. } => {
536 stats.failed.push((trace, reason));
537 }
538 }
539 }
540
541 stats
542 }
543}
544
545#[derive(Debug, Clone, Default)]
547pub struct BatchValidationStats {
548 pub golden: Vec<ValidatedTrace>,
550 pub unverified: Vec<(RlmAnalysisResult, String)>,
552 pub failed: Vec<(ValidatedTrace, String)>,
554}
555
556impl BatchValidationStats {
557 pub fn total(&self) -> usize {
559 self.golden.len() + self.unverified.len() + self.failed.len()
560 }
561
562 pub fn golden_rate(&self) -> f32 {
564 let total = self.total();
565 if total == 0 {
566 0.0
567 } else {
568 self.golden.len() as f32 / total as f32
569 }
570 }
571
572 pub fn write_jsonl(&self, path: &str) -> Result<usize> {
574 use std::fs::File;
575 use std::io::{BufWriter, Write};
576
577 let file = File::create(path)?;
578 let mut writer = BufWriter::new(file);
579
580 let mut count = 0;
581 for trace in &self.golden {
582 let json = serde_json::to_string(trace)?;
583 writeln!(writer, "{}", json)?;
584 count += 1;
585 }
586
587 writer.flush()?;
588 Ok(count)
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rlm::RlmStats;

    /// Builds an analysis result with fixed stats, no sub-queries, and the given answer.
    fn make_result(answer: &str, _query: &str) -> RlmAnalysisResult {
        let stats = RlmStats {
            input_tokens: 100,
            output_tokens: 50,
            iterations: 2,
            subcalls: 0,
            elapsed_ms: 500,
            compression_ratio: 1.0,
        };

        RlmAnalysisResult {
            answer: answer.to_string(),
            iterations: 2,
            sub_queries: vec![],
            stats,
        }
    }

    /// Small Rust snippet with two async functions and one struct.
    fn sample_rust_code() -> &'static str {
        r#"
pub async fn process(input: &str) -> Result<String> {
    let data = parse(input)?;
    Ok(data)
}

async fn parse(input: &str) -> Result<String> {
    Ok(input.to_uppercase())
}

pub struct Config {
    pub debug: bool,
}
"#
    }

    /// A grep payload whose matches agree with the source is marked golden.
    #[test]
    fn validate_grep_match() {
        let answer = r#"{"kind": "grep", "file": "test.rs", "pattern": "async fn", "matches": [{"line": 1, "text": "pub async fn process(input: &str) -> Result<String> {"}, {"line": 5, "text": "async fn parse(input: &str) -> Result<String> {"}]}"#;
        let result = make_result(answer, "Find all async functions");
        let validator = TraceValidator::new();

        let outcome = validator.validate(
            &result,
            sample_rust_code(),
            Some("test.rs"),
            Some("abc123"),
            None,
        );

        match outcome {
            OracleResult::Golden(trace) => {
                assert_eq!(trace.verification_method, VerificationMethod::GrepOracle);
                assert_eq!(trace.verdict, "golden");
            }
            OracleResult::Unverified { .. } | OracleResult::Failed { .. } => {
                panic!("Expected golden")
            }
        }
    }

    /// A semantic payload has no deterministic oracle and stays unverified.
    #[test]
    fn validate_semantic_unverified() {
        let answer = r#"{"kind": "semantic", "file": "test.rs", "answer": "This function processes input by parsing it and returning uppercase"}"#;
        let result = make_result(answer, "Explain what the process function does");
        let validator = TraceValidator::new();

        let outcome = validator.validate(
            &result,
            sample_rust_code(),
            Some("test.rs"),
            Some("abc123"),
            None,
        );

        match outcome {
            OracleResult::Unverified { reason } => assert!(reason.contains("Semantic")),
            OracleResult::Golden(_) | OracleResult::Failed { .. } => {
                panic!("Expected unverified")
            }
        }
    }

    /// A batch with one grep and one semantic payload lands in two groups.
    #[test]
    fn batch_validate_mixed() {
        let source = sample_rust_code();
        let grep_result = make_result(
            r#"{"kind": "grep", "file": "x.rs", "pattern": "async", "matches": []}"#,
            "Find async",
        );
        let semantic_result = make_result(
            r#"{"kind": "semantic", "file": "x.rs", "answer": "text"}"#,
            "Explain",
        );
        let traces = vec![(grep_result, source, None), (semantic_result, source, None)];

        let stats = TraceValidator::new().batch_validate(traces);

        assert!(!stats.golden.is_empty());
        assert!(!stats.unverified.is_empty());
        assert_eq!(stats.total(), 2);
    }
}