1use crate::writer::JsonlWriter;
26use crate::{Attempt, Error, MineTask, Result};
27use camino::{Utf8Path, Utf8PathBuf};
28use serde::{Deserialize, Serialize};
29use sha2::{Digest, Sha256};
30use std::fmt::Write as _;
31use std::process::Stdio;
32use std::time::Duration;
33use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
34use tokio::process::{Child, ChildStdin, ChildStdout, Command};
35use tokio::time;
36use tracing::warn;
37
/// Attempt id assigned to a runner reply that does not carry an `attempt_id` field.
const DEFAULT_ATTEMPT_ID: &str = "attempt";
40
/// Name of the environment variable through which the spawned runner receives the
/// configured lake root path.
pub const LAKE_ROOT_ENV: &str = "LEAN_AGENT_LAKE_ROOT";
43
/// One reply parsed from a single JSON line of the runner's stdout.
///
/// Only `task_id` and `replacement` are mandatory when deserializing; every other
/// field has a serde default, and the optional fields are left out of serialized
/// output when they are `None`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RunnerResponse {
    /// Task id reported by the runner. `run_eval` compares it with the id of the
    /// task it sent and counts a mismatch when they differ.
    pub task_id: String,
    /// Identifier of this attempt; falls back to `"attempt"` when the reply omits it.
    #[serde(default = "default_attempt_id")]
    pub attempt_id: String,
    /// Replacement text proposed by the runner, copied verbatim into the attempt.
    pub replacement: String,
    /// Optional model name reported by the runner.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Optional prompt hash reported by the runner. When absent, the attempt records
    /// the SHA-256 hex digest of the serialized task line instead.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_hash: Option<String>,
    /// Optional free-form JSON passed through to the attempt unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
67
68fn default_attempt_id() -> String {
69 DEFAULT_ATTEMPT_ID.to_owned()
70}
71
/// Settings that control how `run_eval` launches and talks to the runner.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EvalOptions {
    /// Path of the runner executable to spawn.
    pub runner: Utf8PathBuf,
    /// Lake root handed to the runner through the `LEAN_AGENT_LAKE_ROOT`
    /// environment variable.
    pub lake_root: Utf8PathBuf,
    /// Time limit applied while waiting for the runner's reply to a task.
    pub timeout: Duration,
}
82
83impl EvalOptions {
84 #[must_use]
87 pub fn new(runner: Utf8PathBuf, lake_root: Utf8PathBuf) -> Self {
88 Self {
89 runner,
90 lake_root,
91 timeout: Duration::from_secs(120),
92 }
93 }
94}
95
/// Counters describing what happened during one `run_eval` call.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct EvalSummary {
    /// Number of tasks the loop started processing (counted before the task is sent).
    pub tasks_read: usize,
    /// Number of attempt records handed to the writer.
    pub attempts_written: usize,
    /// Number of tasks that failed because of the runner: a failed send, a malformed
    /// reply, an early close of its output, or a timeout.
    pub runner_errors: usize,
    /// Number of replies whose `task_id` differed from the id of the task that was sent.
    pub id_mismatches: usize,
}
108
/// Result of waiting for one reply line from the runner.
enum ReplyOutcome {
    /// A non-blank line was read and deserialized into a [`RunnerResponse`].
    Parsed(Box<RunnerResponse>),
    /// Reading the line failed or the line was not a valid [`RunnerResponse`];
    /// the string describes the problem.
    Malformed(String),
    /// The runner's stdout reached end of input before a reply arrived.
    Closed,
    /// No reply arrived within the configured timeout.
    TimedOut,
}
120
121pub async fn run_eval(
126 options: &EvalOptions,
127 tasks: &[MineTask],
128 writer: &mut JsonlWriter,
129) -> Result<EvalSummary> {
130 let mut summary = EvalSummary::default();
131 if tasks.is_empty() {
132 return Ok(summary);
133 }
134
135 let runner = resolve_runner(&options.runner)?;
136 let mut child = spawn_runner(&runner, &options.lake_root)?;
137 let mut stdin = child.stdin.take().ok_or_else(|| Error::RunnerProtocol {
138 detail: "runner stdin was not captured".to_owned(),
139 })?;
140 let stdout = child.stdout.take().ok_or_else(|| Error::RunnerProtocol {
141 detail: "runner stdout was not captured".to_owned(),
142 })?;
143 let mut lines = BufReader::new(stdout).lines();
144
145 for task in tasks {
146 summary.tasks_read += 1;
147 let task_line = serde_json::to_string(task)?;
148
149 if let Err(err) = write_task(&mut stdin, &task_line).await {
150 warn!(task = %task.task_id, error = %err, "failed to send task to runner; stopping");
151 summary.runner_errors += 1;
152 break;
153 }
154
155 let response = match read_reply(&mut lines, options.timeout).await {
156 ReplyOutcome::Parsed(response) => *response,
157 ReplyOutcome::Malformed(detail) => {
158 warn!(task = %task.task_id, %detail, "runner reply was malformed; skipping task");
159 summary.runner_errors += 1;
160 continue;
161 }
162 ReplyOutcome::Closed => {
163 warn!(task = %task.task_id, "runner closed its output early; stopping");
164 summary.runner_errors += 1;
165 break;
166 }
167 ReplyOutcome::TimedOut => {
168 warn!(task = %task.task_id, seconds = options.timeout.as_secs(), "runner timed out; stopping");
169 summary.runner_errors += 1;
170 let _ = child.start_kill();
171 break;
172 }
173 };
174
175 if response.task_id != task.task_id {
176 summary.id_mismatches += 1;
177 warn!(sent = %task.task_id, got = %response.task_id, "runner task_id mismatch; keeping the sent id");
178 }
179
180 let attempt = merge_attempt(task, response, &task_line);
181 writer.write_record(&attempt)?;
182 summary.attempts_written += 1;
183 }
184
185 drop(stdin);
186 let _ = child.wait().await;
187 writer.flush()?;
188 Ok(summary)
189}
190
191fn merge_attempt(task: &MineTask, response: RunnerResponse, task_line: &str) -> Attempt {
197 let prompt_hash = response
198 .prompt_hash
199 .unwrap_or_else(|| sha256_hex(task_line));
200 Attempt {
201 task_id: task.task_id.clone(),
202 attempt_id: response.attempt_id,
203 allowed_edit: task.allowed_edit.clone(),
204 replacement: response.replacement,
205 target_file: None,
206 extra_edits: Vec::new(),
207 original_diagnostic: task.diagnostic.clone(),
208 model: response.model,
209 prompt_hash: Some(prompt_hash),
210 metadata: response.metadata,
211 }
212}
213
214async fn write_task(stdin: &mut ChildStdin, task_line: &str) -> Result<()> {
216 stdin.write_all(task_line.as_bytes()).await?;
217 stdin.write_all(b"\n").await?;
218 stdin.flush().await?;
219 Ok(())
220}
221
222async fn read_reply(lines: &mut Lines<BufReader<ChildStdout>>, timeout: Duration) -> ReplyOutcome {
224 loop {
225 match time::timeout(timeout, lines.next_line()).await {
226 Err(_) => return ReplyOutcome::TimedOut,
227 Ok(Err(err)) => {
228 return ReplyOutcome::Malformed(format!("reading runner output: {err}"));
229 }
230 Ok(Ok(None)) => return ReplyOutcome::Closed,
231 Ok(Ok(Some(line))) => {
232 let trimmed = line.trim();
233 if trimmed.is_empty() {
234 continue;
235 }
236 return match serde_json::from_str::<RunnerResponse>(trimmed) {
237 Ok(response) => ReplyOutcome::Parsed(Box::new(response)),
238 Err(err) => ReplyOutcome::Malformed(format!("parsing runner attempt: {err}")),
239 };
240 }
241 }
242 }
243}
244
245fn spawn_runner(runner: &Utf8Path, lake_root: &Utf8Path) -> Result<Child> {
247 Command::new(runner.as_str())
248 .env(LAKE_ROOT_ENV, lake_root.as_str())
249 .stdin(Stdio::piped())
250 .stdout(Stdio::piped())
251 .stderr(Stdio::inherit())
252 .kill_on_drop(true)
253 .spawn()
254 .map_err(|source| Error::RunnerSpawn {
255 runner: runner.to_path_buf(),
256 source,
257 })
258}
259
260fn resolve_runner(runner: &Utf8Path) -> Result<Utf8PathBuf> {
263 match std::fs::canonicalize(runner) {
264 Ok(path) => Utf8PathBuf::from_path_buf(path).map_err(|path| Error::NonUtf8Path { path }),
265 Err(_) => Ok(runner.to_path_buf()),
266 }
267}
268
269fn sha256_hex(input: &str) -> String {
271 let digest = Sha256::digest(input.as_bytes());
272 let mut out = String::with_capacity(digest.len() * 2);
273 for byte in digest {
274 let _ = write!(out, "{byte:02x}");
275 }
276 out
277}
278
// Unit tests for reply merging and parsing, plus one end-to-end run against the
// example shell runner (Unix only).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AllowedEdit, Diagnostic, DiagnosticSeverity, GoalState, LeanFile, MineKind, TargetSpan,
    };
    use camino::Utf8PathBuf;

    /// Builds a single-line `sorry` task in `Demo.lean` at `line`, optionally with an
    /// "unsolved goals" diagnostic attached.
    fn sample_task(task_id: &str, line: u32, with_diagnostic: bool) -> MineTask {
        let diagnostic = with_diagnostic.then(|| Diagnostic {
            file: Some(Utf8PathBuf::from("Demo.lean")),
            line: Some(line),
            column: Some(2),
            severity: DiagnosticSeverity::Error,
            message: "error: unsolved goals".to_owned(),
            goal_state: Some(GoalState("⊢ n = n".to_owned())),
        });
        MineTask {
            task_id: task_id.to_owned(),
            project: "demo".to_owned(),
            file: LeanFile(Utf8PathBuf::from("Demo.lean")),
            declaration: None,
            kind: MineKind::Sorry,
            line,
            column: 2,
            imports: vec!["import Init".to_owned()],
            source_before: "theorem t : True := by\n  ".to_owned(),
            target_span: TargetSpan {
                start_line: line,
                start_column: 2,
                end_line: line,
                end_column: 7,
                text: "sorry".to_owned(),
            },
            source_after: "\n".to_owned(),
            diagnostic,
            goal_state: None,
            allowed_edit: AllowedEdit {
                file: Utf8PathBuf::from("Demo.lean"),
                start_line: line,
                end_line: line,
            },
            instructions: "Replace only the target span.".to_owned(),
        }
    }

    /// The hex encoder must reproduce the well-known SHA-256 digest of the empty input.
    #[test]
    fn sha256_hex_matches_known_vector() {
        assert_eq!(
            // SHA-256 of the empty string.
            sha256_hex(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    /// Edit range and diagnostic come from the task; attempt id, replacement, model,
    /// prompt hash, and metadata come from the reply.
    #[test]
    fn merge_takes_span_from_task_and_proof_from_reply() {
        let task = sample_task("Demo.t:2", 2, true);
        let response = RunnerResponse {
            task_id: "Demo.t:2".to_owned(),
            attempt_id: "cand-1".to_owned(),
            replacement: "  rfl".to_owned(),
            model: Some("test-model".to_owned()),
            prompt_hash: Some("deadbeef".to_owned()),
            metadata: Some(serde_json::json!({"latency_ms": 12})),
        };
        let attempt = merge_attempt(&task, response, "{\"task_id\":\"Demo.t:2\"}");
        assert_eq!(attempt.task_id, "Demo.t:2");
        assert_eq!(attempt.attempt_id, "cand-1");
        assert_eq!(attempt.replacement, "  rfl");
        assert_eq!(attempt.allowed_edit.start_line, 2);
        assert_eq!(attempt.allowed_edit.end_line, 2);
        assert_eq!(attempt.model.as_deref(), Some("test-model"));
        assert_eq!(attempt.prompt_hash.as_deref(), Some("deadbeef"));
        assert!(attempt.original_diagnostic.is_some());
        assert!(attempt.metadata.is_some());
    }

    /// Without a prompt hash in the reply, the hash of the task line is recorded.
    #[test]
    fn merge_computes_prompt_hash_when_reply_omits_it() {
        let task = sample_task("Demo.t:2", 2, false);
        let response = RunnerResponse {
            task_id: "Demo.t:2".to_owned(),
            attempt_id: DEFAULT_ATTEMPT_ID.to_owned(),
            replacement: "  rfl".to_owned(),
            model: None,
            prompt_hash: None,
            metadata: None,
        };
        let attempt = merge_attempt(&task, response, "task-line");
        assert_eq!(
            attempt.prompt_hash.as_deref(),
            Some(sha256_hex("task-line").as_str())
        );
        assert!(attempt.original_diagnostic.is_none());
    }

    /// A reply with only `task_id` and `replacement` parses, with defaults for the rest.
    #[test]
    fn minimal_runner_response_deserializes_with_defaults() -> Result<()> {
        let line = r#"{"task_id":"T","replacement":"  rfl"}"#;
        let parsed: RunnerResponse = serde_json::from_str(line)?;
        assert_eq!(parsed.task_id, "T");
        assert_eq!(parsed.attempt_id, DEFAULT_ATTEMPT_ID);
        assert_eq!(parsed.replacement, "  rfl");
        assert!(parsed.model.is_none());
        assert!(parsed.metadata.is_none());
        Ok(())
    }

    /// End-to-end: two tasks go through `scripts/echo_runner.sh` and come back as two
    /// attempts in the output file, in order. Unix only because it sets file modes.
    #[cfg(unix)]
    #[tokio::test]
    async fn run_eval_streams_tasks_through_the_example_runner() -> Result<()> {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::TempDir;

        let runner = Utf8PathBuf::from(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../scripts/echo_runner.sh"
        ));
        let mut perms = std::fs::metadata(runner.as_std_path())?.permissions();
        // Make sure the script is executable before it is spawned.
        perms.set_mode(0o755);
        std::fs::set_permissions(runner.as_std_path(), perms)?;

        let tasks = vec![
            sample_task("A.foo:2", 2, false),
            sample_task("B.bar:3", 3, true),
        ];
        let dir = TempDir::new()?;
        let out = Utf8PathBuf::from_path_buf(dir.path().join("attempts.jsonl"))
            .map_err(|path| Error::NonUtf8Path { path })?;

        let options = EvalOptions {
            runner,
            lake_root: Utf8PathBuf::from("."),
            timeout: Duration::from_secs(30),
        };
        let mut writer = JsonlWriter::create(&out)?;
        let summary = run_eval(&options, &tasks, &mut writer).await?;

        assert_eq!(summary.tasks_read, 2);
        assert_eq!(summary.attempts_written, 2);
        assert_eq!(summary.runner_errors, 0);
        assert_eq!(summary.id_mismatches, 0);

        // Read the attempts back from disk, ignoring blank lines.
        let content = std::fs::read_to_string(out.as_std_path())?;
        let mut attempts = Vec::new();
        for line in content.lines().filter(|line| !line.trim().is_empty()) {
            attempts.push(serde_json::from_str::<Attempt>(line)?);
        }
        assert_eq!(attempts.len(), 2);
        assert_eq!(attempts[0].task_id, "A.foo:2");
        // Attempts are written in the same order the tasks were sent.
        assert_eq!(attempts[1].task_id, "B.bar:3");
        assert_eq!(attempts[0].replacement, "  rfl");
        assert_eq!(attempts[0].model.as_deref(), Some("echo-runner"));
        assert!(attempts[0].prompt_hash.is_some());
        // The diagnostic attached to the second task is carried into its attempt.
        assert!(attempts[1].original_diagnostic.is_some());
        Ok(())
    }
}