1use crate::accept::{AcceptReport, AcceptRequest, RejectReason};
15use crate::patch::{SpanReplacement, apply_edits};
16use crate::workspace::{CopyOptions, Workspace};
17use crate::writer::JsonlWriter;
18use crate::{
19 AllowedEdit, Diagnostic, DiagnosticSeverity, FileStatus, FileTrace, GoalState, LeanFile,
20 Provenance, Result, TraceConfig, accept, capture_provenance, run_lean_file,
21};
22use camino::{Utf8Path, Utf8PathBuf};
23use serde::{Deserialize, Serialize};
24use std::collections::{BTreeSet, HashMap};
25use std::time::{Duration, Instant};
26use tracing::{debug, info, warn};
27
/// Key used to compare diagnostics between runs: the diagnostic's line (if it
/// has one) paired with its message as normalized by `signature`.
type DiagnosticSignature = (Option<u32>, String);
33
/// One candidate patch to replay: replacement text for a single allowed line
/// span, optional extra edits, and bookkeeping fields.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct Attempt {
    /// Task identifier; copied verbatim into the resulting `ReplayResult`.
    pub task_id: String,
    /// Attempt identifier; deserializes to `"attempt"` when the field is absent.
    #[serde(default = "default_attempt_id")]
    pub attempt_id: String,
    /// File and line range that `replacement` is substituted into.
    pub allowed_edit: AllowedEdit,
    /// Text that replaces the `allowed_edit` span.
    pub replacement: String,
    /// File to compile after patching; `target()` falls back to
    /// `allowed_edit.file` when this is `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub target_file: Option<Utf8PathBuf>,
    /// Additional span replacements applied together with the primary one.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub extra_edits: Vec<SpanReplacement>,
    /// Diagnostic the attempt is meant to fix. When present, the attempt counts
    /// as having resolved it if no patched diagnostic shares its signature.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub original_diagnostic: Option<Diagnostic>,
    /// Optional model label; not read by the replay code in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Optional prompt hash; not read by the replay code in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_hash: Option<String>,
    /// Optional free-form JSON; not read by the replay code in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
65
/// Serde fallback for an attempt record that carries no `attempt_id`.
fn default_attempt_id() -> String {
    String::from("attempt")
}
69
70impl Attempt {
71 #[must_use]
73 pub fn primary_span(&self) -> SpanReplacement {
74 SpanReplacement {
75 file: self.allowed_edit.file.clone(),
76 start_line: self.allowed_edit.start_line,
77 end_line: self.allowed_edit.end_line,
78 replacement: self.replacement.clone(),
79 }
80 }
81
82 #[must_use]
84 pub fn target(&self) -> Utf8PathBuf {
85 self.target_file
86 .clone()
87 .unwrap_or_else(|| self.allowed_edit.file.clone())
88 }
89}
90
/// Final outcome of replaying one attempt. Serialized in `snake_case`
/// (for example `patch_refused`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReplayStatus {
    /// The patched file compiled and was not rejected by the accept check.
    Passed,
    /// The patched file compiled but the accept check did not accept it.
    Rejected,
    /// The compile run reported `FileStatus::Failed`.
    Failed,
    /// Applying the edits to the workspace returned an error.
    PatchRefused,
    /// The compile run reported `FileStatus::TimedOut`.
    TimedOut,
    /// The compile run reported `FileStatus::RunnerError`, or the replay
    /// workspace could not be materialized.
    RunnerError,
}
109
/// Per-attempt record produced by the replay loop and written to the JSONL
/// output.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ReplayResult {
    /// Copied from `Attempt::task_id`.
    pub task_id: String,
    /// Copied from `Attempt::attempt_id`.
    pub attempt_id: String,
    /// Final status of the attempt.
    pub status: ReplayStatus,
    /// `true` when the compile run reported `FileStatus::Passed`; stays `true`
    /// even if the status is later changed to `Rejected`.
    pub compile_passed: bool,
    /// Outcome of the accept check; `false` when that check did not run.
    #[serde(default)]
    pub accepted: bool,
    /// Number of diagnostics in the patched trace (all severities).
    pub diagnostic_count: usize,
    /// Number of error signatures in the patched run that are absent from the
    /// baseline.
    pub new_errors: usize,
    /// Whether the original diagnostic (or, without one, every baseline error)
    /// is gone from the patched run.
    pub resolved_original_error: bool,
    /// `true` when `new_errors` is greater than zero.
    pub regression: bool,
    /// Wall-clock duration of the whole attempt, in milliseconds.
    pub elapsed_ms: u64,
    /// First goal state found among the patched diagnostics, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub final_goal_state: Option<GoalState>,
    /// Report returned by the accept check, when it ran.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub guards: Option<AcceptReport>,
    /// Rejection reason returned by the accept check, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reject_reason: Option<RejectReason>,
    /// Error text: the workspace or patch error for attempts that never
    /// compiled, or the first diagnostic's message for timed-out and
    /// runner-error compiles.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub patch_error: Option<String>,
}
148
/// Settings shared by every attempt in a replay run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ReplayOptions {
    /// Project root that each replay workspace is materialized from.
    pub lake_root: Utf8PathBuf,
    /// Time limit used for compile runs, the cache fetch, and the accept check.
    pub timeout: Duration,
    /// Passed to `Workspace::materialize` for attempt workspaces.
    pub keep_workdir: bool,
    /// Passed to `apply_edits` when applying an attempt's edits.
    pub allow_multi_file: bool,
    /// Whether to compile the unpatched target to obtain baseline errors.
    pub compute_baseline: bool,
    /// Forwarded to the accept check as `run_reverse_dep`.
    pub reverse_dep: bool,
    /// Whether to try `lake exe cache get` in a workspace before compiling
    /// (only attempted when the lake manifest mentions `mathlib`).
    pub cache_get: bool,
}
168
169impl ReplayOptions {
170 #[must_use]
173 pub fn new(lake_root: Utf8PathBuf) -> Self {
174 Self {
175 lake_root,
176 timeout: Duration::from_secs(60),
177 keep_workdir: false,
178 allow_multi_file: false,
179 compute_baseline: true,
180 reverse_dep: true,
181 cache_get: true,
182 }
183 }
184}
185
/// Running totals over a replay run, one bucket per group of statuses.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ReplaySummary {
    /// Total number of attempts recorded.
    pub attempts: usize,
    /// Attempts that ended with `ReplayStatus::Passed`.
    pub compiled_pass: usize,
    /// Attempts that ended with `ReplayStatus::Rejected`.
    pub rejected: usize,
    /// Attempts that ended with `ReplayStatus::Failed`.
    pub compiled_fail: usize,
    /// Attempts that ended with `ReplayStatus::PatchRefused`.
    pub patch_refused: usize,
    /// Attempts that ended with `ReplayStatus::TimedOut` or
    /// `ReplayStatus::RunnerError`.
    pub errored: usize,
}
202
203impl ReplaySummary {
204 fn record(&mut self, result: &ReplayResult) {
205 self.attempts += 1;
206 match result.status {
207 ReplayStatus::Passed => self.compiled_pass += 1,
208 ReplayStatus::Rejected => self.rejected += 1,
209 ReplayStatus::Failed => self.compiled_fail += 1,
210 ReplayStatus::PatchRefused => self.patch_refused += 1,
211 ReplayStatus::TimedOut | ReplayStatus::RunnerError => self.errored += 1,
212 }
213 }
214}
215
/// Error signatures observed when compiling a target without any patch
/// applied. The default value is an empty set.
#[derive(Clone, Debug, Default)]
struct Baseline {
    /// Signatures of the error-severity diagnostics from the unpatched compile.
    error_signatures: BTreeSet<DiagnosticSignature>,
}
221
222pub async fn run_replay(
224 options: &ReplayOptions,
225 attempts: &[Attempt],
226 writer: &mut JsonlWriter,
227) -> Result<ReplaySummary> {
228 let provenance = capture_provenance(options.lake_root.as_path()).await;
229 let copy_options = CopyOptions::default();
230 let mut baselines: HashMap<String, Baseline> = HashMap::new();
231 let mut summary = ReplaySummary::default();
232
233 for attempt in attempts {
234 let result =
235 replay_attempt(options, &provenance, ©_options, &mut baselines, attempt).await;
236 writer.write_record(&result)?;
237 summary.record(&result);
238 }
239
240 writer.flush()?;
241 Ok(summary)
242}
243
244async fn replay_attempt(
247 options: &ReplayOptions,
248 provenance: &Provenance,
249 copy_options: &CopyOptions,
250 baselines: &mut HashMap<String, Baseline>,
251 attempt: &Attempt,
252) -> ReplayResult {
253 let started = Instant::now();
254 let target = attempt.target();
255
256 let baseline = if options.compute_baseline {
257 baseline_for(options, provenance, copy_options, baselines, &target).await
258 } else {
259 Baseline::default()
260 };
261
262 let workspace =
263 match Workspace::materialize(&options.lake_root, options.keep_workdir, copy_options) {
264 Ok(workspace) => workspace,
265 Err(err) => {
266 return terminal_result(
267 attempt,
268 ReplayStatus::RunnerError,
269 started.elapsed(),
270 err.to_string(),
271 );
272 }
273 };
274 if workspace.is_kept() {
275 info!(task = %attempt.task_id, attempt = %attempt.attempt_id, workdir = %workspace.root(), "kept replay workspace");
276 }
277
278 let mut edits = vec![attempt.primary_span()];
279 edits.extend(attempt.extra_edits.iter().cloned());
280 if let Err(err) = apply_edits(workspace.root(), &edits, options.allow_multi_file) {
281 return terminal_result(
282 attempt,
283 ReplayStatus::PatchRefused,
284 started.elapsed(),
285 err.to_string(),
286 );
287 }
288
289 if options.cache_get {
290 cache_get_if_available(workspace.root(), options.timeout).await;
291 }
292
293 let trace = run_lean_file(
294 &compile_config(workspace.root(), options.timeout),
295 provenance,
296 LeanFile(target.clone()),
297 )
298 .await;
299
300 let mut result = score(attempt, &baseline, &trace, started.elapsed());
301 if result.compile_passed {
302 let request = AcceptRequest {
303 lake_root: &options.lake_root,
304 workspace_root: workspace.root(),
305 target: &target,
306 edit_line: attempt.allowed_edit.start_line,
307 patched_diagnostics: &trace.diagnostics,
308 provenance,
309 timeout: options.timeout,
310 run_reverse_dep: options.reverse_dep,
311 negative_control: None,
312 };
313 let outcome = accept::evaluate(&request).await;
314 result.accepted = outcome.accepted;
315 result.guards = Some(outcome.report);
316 result.reject_reason = outcome.reject_reason;
317 if !outcome.accepted {
318 result.status = ReplayStatus::Rejected;
319 }
320 }
321 result.elapsed_ms = millis(started.elapsed());
322 result
323}
324
325async fn cache_get_if_available(workspace_root: &Utf8Path, timeout: Duration) {
331 let manifest = workspace_root.join("lake-manifest.json");
332 let Ok(text) = std::fs::read_to_string(&manifest) else {
333 return;
334 };
335 if !text.contains("mathlib") {
336 return;
337 }
338 let mut command = tokio::process::Command::new("lake");
339 command
340 .args(["exe", "cache", "get"])
341 .current_dir(workspace_root)
342 .kill_on_drop(true)
343 .stdout(std::process::Stdio::null())
344 .stderr(std::process::Stdio::null());
345 match command.spawn() {
346 Ok(child) => {
347 if tokio::time::timeout(timeout, child.wait_with_output())
348 .await
349 .is_err()
350 {
351 warn!("lake exe cache get timed out; continuing without it");
352 }
353 }
354 Err(err) => debug!(error = %err, "lake exe cache get unavailable; continuing"),
355 }
356}
357
358async fn baseline_for(
360 options: &ReplayOptions,
361 provenance: &Provenance,
362 copy_options: &CopyOptions,
363 baselines: &mut HashMap<String, Baseline>,
364 target: &Utf8Path,
365) -> Baseline {
366 if let Some(baseline) = baselines.get(target.as_str()) {
367 return baseline.clone();
368 }
369 let baseline = compute_baseline(options, provenance, copy_options, target).await;
370 baselines.insert(target.as_str().to_owned(), baseline.clone());
371 baseline
372}
373
374async fn compute_baseline(
376 options: &ReplayOptions,
377 provenance: &Provenance,
378 copy_options: &CopyOptions,
379 target: &Utf8Path,
380) -> Baseline {
381 let workspace = match Workspace::materialize(&options.lake_root, false, copy_options) {
382 Ok(workspace) => workspace,
383 Err(err) => {
384 warn!(target = %target, error = %err, "baseline workspace failed; scoring without it");
385 return Baseline::default();
386 }
387 };
388 let trace = run_lean_file(
389 &compile_config(workspace.root(), options.timeout),
390 provenance,
391 LeanFile(target.to_path_buf()),
392 )
393 .await;
394 Baseline {
395 error_signatures: error_signatures(&trace.diagnostics),
396 }
397}
398
399fn score(
401 attempt: &Attempt,
402 baseline: &Baseline,
403 trace: &FileTrace,
404 elapsed: Duration,
405) -> ReplayResult {
406 let patched_errors = error_signatures(&trace.diagnostics);
407 let new_errors = patched_errors
408 .iter()
409 .filter(|signature| !baseline.error_signatures.contains(*signature))
410 .count();
411
412 let status = match trace.status {
413 FileStatus::Passed => ReplayStatus::Passed,
414 FileStatus::Failed => ReplayStatus::Failed,
415 FileStatus::TimedOut => ReplayStatus::TimedOut,
416 FileStatus::RunnerError => ReplayStatus::RunnerError,
417 };
418 let compile_passed = trace.status == FileStatus::Passed;
419
420 let patch_error = match status {
421 ReplayStatus::TimedOut | ReplayStatus::RunnerError => trace
422 .diagnostics
423 .first()
424 .map(|diagnostic| diagnostic.message.clone()),
425 _ => None,
426 };
427
428 ReplayResult {
429 task_id: attempt.task_id.clone(),
430 attempt_id: attempt.attempt_id.clone(),
431 status,
432 compile_passed,
433 accepted: false,
434 diagnostic_count: trace.diagnostics.len(),
435 new_errors,
436 resolved_original_error: resolved_original(attempt, baseline, &trace.diagnostics),
437 regression: new_errors > 0,
438 elapsed_ms: millis(elapsed),
439 final_goal_state: trace
440 .diagnostics
441 .iter()
442 .find_map(|diagnostic| diagnostic.goal_state.clone()),
443 guards: None,
444 reject_reason: None,
445 patch_error,
446 }
447}
448
449fn resolved_original(attempt: &Attempt, baseline: &Baseline, patched: &[Diagnostic]) -> bool {
455 if let Some(original) = &attempt.original_diagnostic {
456 let target = signature(original);
457 return !patched
458 .iter()
459 .any(|diagnostic| signature(diagnostic) == target);
460 }
461 if baseline.error_signatures.is_empty() {
462 return false;
463 }
464 let patched_errors = error_signatures(patched);
465 baseline
466 .error_signatures
467 .iter()
468 .all(|signature| !patched_errors.contains(signature))
469}
470
471fn terminal_result(
473 attempt: &Attempt,
474 status: ReplayStatus,
475 elapsed: Duration,
476 message: String,
477) -> ReplayResult {
478 ReplayResult {
479 task_id: attempt.task_id.clone(),
480 attempt_id: attempt.attempt_id.clone(),
481 status,
482 compile_passed: false,
483 accepted: false,
484 diagnostic_count: 0,
485 new_errors: 0,
486 resolved_original_error: false,
487 regression: false,
488 elapsed_ms: millis(elapsed),
489 final_goal_state: None,
490 guards: None,
491 reject_reason: None,
492 patch_error: Some(message),
493 }
494}
495
496fn compile_config(lake_root: &Utf8Path, timeout: Duration) -> TraceConfig {
498 let mut config = TraceConfig::new(lake_root.to_path_buf());
499 config.timeout = timeout;
500 config.include_warnings = true;
501 config
502}
503
504fn error_signatures(diagnostics: &[Diagnostic]) -> BTreeSet<DiagnosticSignature> {
506 diagnostics
507 .iter()
508 .filter(|diagnostic| diagnostic.severity == DiagnosticSeverity::Error)
509 .map(signature)
510 .collect()
511}
512
513fn signature(diagnostic: &Diagnostic) -> DiagnosticSignature {
519 let message = match &diagnostic.file {
520 Some(file) => diagnostic
521 .message
522 .replace(file.as_str(), "")
523 .trim()
524 .to_owned(),
525 None => diagnostic.message.clone(),
526 };
527 (diagnostic.line, message)
528}
529
/// Converts a duration to whole milliseconds, saturating at `u64::MAX` when
/// the value does not fit.
fn millis(elapsed: Duration) -> u64 {
    match u64::try_from(elapsed.as_millis()) {
        Ok(fits) => fits,
        Err(_) => u64::MAX,
    }
}
534
// Unit tests for attempt deserialization and for `score`, using hand-built
// traces so no Lean toolchain is needed.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use uuid::Uuid;

    /// Builds a diagnostic at `line` in a fixed file with the given severity and
    /// message, and no goal state.
    fn diagnostic(line: u32, severity: DiagnosticSeverity, message: &str) -> Diagnostic {
        Diagnostic {
            file: Some(Utf8PathBuf::from("/work/Demo.lean")),
            line: Some(line),
            column: Some(0),
            severity,
            message: message.to_owned(),
            goal_state: None,
        }
    }

    /// Builds a trace with the given status and diagnostics; every other field
    /// is filler that `score` does not read.
    fn trace_with(status: FileStatus, diagnostics: Vec<Diagnostic>) -> FileTrace {
        FileTrace {
            run_id: Uuid::new_v4(),
            file: LeanFile(Utf8PathBuf::from("Demo.lean")),
            status,
            exit_code: Some(0),
            elapsed: Duration::from_millis(5),
            diagnostics,
            stdout: None,
            stderr: None,
            lean_version: None,
            lake_version: None,
            git_commit: None,
            created_at: Utc::now(),
        }
    }

    /// Builds a single-line attempt against `Demo.lean` with the given original
    /// diagnostic.
    fn attempt(original: Option<Diagnostic>) -> Attempt {
        Attempt {
            task_id: "Demo.demo_one:1".to_owned(),
            attempt_id: "a1".to_owned(),
            allowed_edit: AllowedEdit {
                file: Utf8PathBuf::from("Demo.lean"),
                start_line: 1,
                end_line: 1,
            },
            replacement: "theorem demo_one : 1 + 1 = 2 := by rfl".to_owned(),
            target_file: None,
            extra_edits: Vec::new(),
            original_diagnostic: original,
            model: None,
            prompt_hash: None,
            metadata: None,
        }
    }

    // A record with only the required fields parses, and the optional fields
    // take their defaults.
    #[test]
    fn minimal_attempt_deserializes_with_defaults() -> Result<()> {
        let line = r#"{"task_id":"T","allowed_edit":{"file":"Demo.lean","start_line":1,"end_line":1},"replacement":"by rfl"}"#;
        let parsed: Attempt = serde_json::from_str(line)?;
        assert_eq!(parsed.task_id, "T");
        assert_eq!(parsed.attempt_id, "attempt");
        assert_eq!(parsed.target(), Utf8PathBuf::from("Demo.lean"));
        assert!(parsed.extra_edits.is_empty());
        let span = parsed.primary_span();
        assert_eq!(span.start_line, 1);
        assert_eq!(span.replacement, "by rfl");
        Ok(())
    }

    // A passing compile with no diagnostics: nothing new, and the original
    // warning counts as resolved because its signature is gone.
    #[test]
    fn clean_proof_passes_with_no_new_errors() {
        let warning = diagnostic(1, DiagnosticSeverity::Warning, "declaration uses `sorry`");
        let baseline = Baseline::default();
        let trace = trace_with(FileStatus::Passed, Vec::new());
        let result = score(
            &attempt(Some(warning)),
            &baseline,
            &trace,
            Duration::from_millis(12),
        );
        assert_eq!(result.status, ReplayStatus::Passed);
        assert!(result.compile_passed);
        assert_eq!(result.new_errors, 0);
        assert!(!result.regression);
        assert!(result.resolved_original_error);
        assert_eq!(result.diagnostic_count, 0);
    }

    // A failing compile with an error the (empty) baseline did not have is a
    // regression.
    #[test]
    fn broken_proof_fails_and_flags_regression() {
        let warning = diagnostic(1, DiagnosticSeverity::Warning, "declaration uses `sorry`");
        let baseline = Baseline::default();
        let error = diagnostic(1, DiagnosticSeverity::Error, "Type mismatch");
        let trace = trace_with(FileStatus::Failed, vec![error]);
        let result = score(
            &attempt(Some(warning)),
            &baseline,
            &trace,
            Duration::from_millis(20),
        );
        assert_eq!(result.status, ReplayStatus::Failed);
        assert!(!result.compile_passed);
        assert_eq!(result.new_errors, 1);
        assert!(result.regression);
        // Still "resolved": the original warning's signature no longer appears
        // among the patched diagnostics, even though the compile failed.
        assert!(result.resolved_original_error);
    }

    // Without an original diagnostic, clearing every baseline error counts as
    // resolving it.
    #[test]
    fn fixing_a_baseline_error_resolves_without_original_diagnostic() {
        let baseline = Baseline {
            error_signatures: error_signatures(&[diagnostic(
                3,
                DiagnosticSeverity::Error,
                "unsolved goals",
            )]),
        };
        let trace = trace_with(FileStatus::Passed, Vec::new());
        let result = score(&attempt(None), &baseline, &trace, Duration::from_millis(8));
        assert!(result.compile_passed);
        assert!(result.resolved_original_error);
        assert_eq!(result.new_errors, 0);
        assert!(!result.regression);
    }

    // An error already present in the baseline is neither new nor resolved when
    // it survives the patch.
    #[test]
    fn baseline_error_that_persists_is_not_a_new_error() {
        let persistent = diagnostic(3, DiagnosticSeverity::Error, "unsolved goals");
        let baseline = Baseline {
            error_signatures: error_signatures(std::slice::from_ref(&persistent)),
        };
        let trace = trace_with(FileStatus::Failed, vec![persistent]);
        let result = score(&attempt(None), &baseline, &trace, Duration::from_millis(8));
        assert_eq!(result.new_errors, 0);
        assert!(!result.regression);
        assert!(!result.resolved_original_error);
    }
}
670}