Skip to main content

lean_agent_core/
eval.rs

1//! The agent stage: hand mined tasks to an external runner, collect attempts.
2//!
3//! `lean-agent` itself never calls a model. Instead it talks to a user-supplied
4//! runner over a line-oriented process contract: one task JSON per line goes to
5//! the runner's stdin, and one attempt JSON per line is read back from its
6//! stdout. The runner owns prompt construction and any model calls; this crate
7//! owns the task model, the pairing, and turning each reply into a replayable
8//! [`Attempt`].
9//!
10//! ## Process contract
11//!
//! - The runner is spawned once and driven in lock step: a task is written and
//!   flushed, then exactly one non-blank reply line is read before the next
//!   task is sent. A well-behaved runner emits one line per task and flushes
//!   after each.
15//! - Blank lines from the runner are ignored, so a chatty runner does not
16//!   desynchronize the exchange.
17//! - The lake root is passed to the runner in the `LEAN_AGENT_LAKE_ROOT`
18//!   environment variable so the runner can read the project if it needs to.
19//!
//! Each reply is `{task_id, attempt_id?, replacement, model?, prompt_hash?,
//! metadata?}`; an omitted `attempt_id` defaults to `"attempt"`. The reply
//! carries only the proof text; the editable span, target file, and backing
//! diagnostic come from the mined task, so the merged [`Attempt`] is everything
//! `replay` needs.
24
25use crate::writer::JsonlWriter;
26use crate::{Attempt, Error, MineTask, Result};
27use camino::{Utf8Path, Utf8PathBuf};
28use serde::{Deserialize, Serialize};
29use sha2::{Digest, Sha256};
30use std::fmt::Write as _;
31use std::process::Stdio;
32use std::time::Duration;
33use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, Lines};
34use tokio::process::{Child, ChildStdin, ChildStdout, Command};
35use tokio::time;
36use tracing::warn;
37
/// Default attempt identifier when a runner omits one.
///
/// Applied to [`RunnerResponse::attempt_id`] through its
/// `#[serde(default = "default_attempt_id")]` attribute.
const DEFAULT_ATTEMPT_ID: &str = "attempt";
40
/// Environment variable carrying the lake root through to the runner.
///
/// Set on the runner process when it is spawned, to the value of
/// [`EvalOptions::lake_root`].
pub const LAKE_ROOT_ENV: &str = "LEAN_AGENT_LAKE_ROOT";
43
/// One line the runner writes back for one task.
///
/// Only `replacement` is required beyond the identifiers; the span and file are
/// taken from the mined task, so the runner stays a pure text producer.
///
/// In the JSON line, `task_id` and `replacement` must be present (they carry no
/// serde default). `attempt_id` falls back to [`DEFAULT_ATTEMPT_ID`]. The
/// optional fields may be omitted on input and are skipped on output when
/// `None`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct RunnerResponse {
    /// Task this reply answers; advisory, since pairing is positional. A value
    /// that differs from the sent task is counted as an id mismatch, and the
    /// sent task's id is kept on the resulting attempt.
    pub task_id: String,
    /// Identifier for this attempt; defaults to `attempt` when omitted.
    #[serde(default = "default_attempt_id")]
    pub attempt_id: String,
    /// New content the runner proposes for the task's editable span.
    pub replacement: String,
    /// Model the runner used, when it reports one.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Hash of the prompt the runner built, when it reports one. When absent,
    /// the attempt records a SHA-256 of the exact task line that was sent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_hash: Option<String>,
    /// Free-form runner metadata (cost, latency, sampling settings). Any JSON
    /// value is accepted and carried through to the attempt unchanged.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
67
68fn default_attempt_id() -> String {
69    DEFAULT_ATTEMPT_ID.to_owned()
70}
71
/// Runtime options for an eval run.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EvalOptions {
    /// Runner executable or script that speaks the process contract. A path
    /// that names an existing file is canonicalized before spawning; anything
    /// else is passed through unchanged so a bare command name can still be
    /// looked up on `PATH`.
    pub runner: Utf8PathBuf,
    /// Lake workspace root, forwarded to the runner via [`LAKE_ROOT_ENV`].
    pub lake_root: Utf8PathBuf,
    /// How long to wait for one reply before treating the runner as stuck.
    /// When it elapses, the runner is killed and no further tasks are sent.
    pub timeout: Duration,
}
82
83impl EvalOptions {
84    /// Options for `runner` rooted at `lake_root`, with a two-minute per-task
85    /// reply timeout.
86    #[must_use]
87    pub fn new(runner: Utf8PathBuf, lake_root: Utf8PathBuf) -> Self {
88        Self {
89            runner,
90            lake_root,
91            timeout: Duration::from_secs(120),
92        }
93    }
94}
95
/// Counts from an eval run.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct EvalSummary {
    /// Tasks sent to the runner. A task is counted before its send is
    /// attempted, so a task whose send failed is still included.
    pub tasks_read: usize,
    /// Attempts written to the output.
    pub attempts_written: usize,
    /// Tasks where the runner errored, timed out, or replied malformed. At most
    /// one is recorded per task.
    pub runner_errors: usize,
    /// Replies whose `task_id` did not match the task that was sent. These
    /// attempts are still written, under the sent task's id.
    pub id_mismatches: usize,
}
108
/// Outcome of reading one reply line from the runner.
///
/// Only `Malformed` lets the run continue with the next task; `Closed` and
/// `TimedOut` stop it.
enum ReplyOutcome {
    /// A reply parsed into a [`RunnerResponse`] (boxed to keep the enum small).
    Parsed(Box<RunnerResponse>),
    /// A line arrived but did not parse, or reading the output failed; the
    /// detail explains why.
    Malformed(String),
    /// The runner closed its output before replying.
    Closed,
    /// No reply arrived within the timeout.
    TimedOut,
}
120
121/// Stream every task to the runner and write one attempt per reply.
122///
123/// Failures are converted into counts and log lines rather than aborting the
124/// run, except when the runner cannot be started at all.
125pub async fn run_eval(
126    options: &EvalOptions,
127    tasks: &[MineTask],
128    writer: &mut JsonlWriter,
129) -> Result<EvalSummary> {
130    let mut summary = EvalSummary::default();
131    if tasks.is_empty() {
132        return Ok(summary);
133    }
134
135    let runner = resolve_runner(&options.runner)?;
136    let mut child = spawn_runner(&runner, &options.lake_root)?;
137    let mut stdin = child.stdin.take().ok_or_else(|| Error::RunnerProtocol {
138        detail: "runner stdin was not captured".to_owned(),
139    })?;
140    let stdout = child.stdout.take().ok_or_else(|| Error::RunnerProtocol {
141        detail: "runner stdout was not captured".to_owned(),
142    })?;
143    let mut lines = BufReader::new(stdout).lines();
144
145    for task in tasks {
146        summary.tasks_read += 1;
147        let task_line = serde_json::to_string(task)?;
148
149        if let Err(err) = write_task(&mut stdin, &task_line).await {
150            warn!(task = %task.task_id, error = %err, "failed to send task to runner; stopping");
151            summary.runner_errors += 1;
152            break;
153        }
154
155        let response = match read_reply(&mut lines, options.timeout).await {
156            ReplyOutcome::Parsed(response) => *response,
157            ReplyOutcome::Malformed(detail) => {
158                warn!(task = %task.task_id, %detail, "runner reply was malformed; skipping task");
159                summary.runner_errors += 1;
160                continue;
161            }
162            ReplyOutcome::Closed => {
163                warn!(task = %task.task_id, "runner closed its output early; stopping");
164                summary.runner_errors += 1;
165                break;
166            }
167            ReplyOutcome::TimedOut => {
168                warn!(task = %task.task_id, seconds = options.timeout.as_secs(), "runner timed out; stopping");
169                summary.runner_errors += 1;
170                let _ = child.start_kill();
171                break;
172            }
173        };
174
175        if response.task_id != task.task_id {
176            summary.id_mismatches += 1;
177            warn!(sent = %task.task_id, got = %response.task_id, "runner task_id mismatch; keeping the sent id");
178        }
179
180        let attempt = merge_attempt(task, response, &task_line);
181        writer.write_record(&attempt)?;
182        summary.attempts_written += 1;
183    }
184
185    drop(stdin);
186    let _ = child.wait().await;
187    writer.flush()?;
188    Ok(summary)
189}
190
191/// Merge a mined task with a runner reply into a replayable attempt.
192///
193/// The task is authoritative for the editable span, target file, and backing
194/// diagnostic; the reply supplies the proof text and provenance. The prompt
195/// hash falls back to a hash of the exact task line that was sent.
196fn merge_attempt(task: &MineTask, response: RunnerResponse, task_line: &str) -> Attempt {
197    let prompt_hash = response
198        .prompt_hash
199        .unwrap_or_else(|| sha256_hex(task_line));
200    Attempt {
201        task_id: task.task_id.clone(),
202        attempt_id: response.attempt_id,
203        allowed_edit: task.allowed_edit.clone(),
204        replacement: response.replacement,
205        target_file: None,
206        extra_edits: Vec::new(),
207        original_diagnostic: task.diagnostic.clone(),
208        model: response.model,
209        prompt_hash: Some(prompt_hash),
210        metadata: response.metadata,
211    }
212}
213
214/// Send one task line to the runner and flush so it can act immediately.
215async fn write_task(stdin: &mut ChildStdin, task_line: &str) -> Result<()> {
216    stdin.write_all(task_line.as_bytes()).await?;
217    stdin.write_all(b"\n").await?;
218    stdin.flush().await?;
219    Ok(())
220}
221
222/// Read one reply, skipping blank lines, within `timeout`.
223async fn read_reply(lines: &mut Lines<BufReader<ChildStdout>>, timeout: Duration) -> ReplyOutcome {
224    loop {
225        match time::timeout(timeout, lines.next_line()).await {
226            Err(_) => return ReplyOutcome::TimedOut,
227            Ok(Err(err)) => {
228                return ReplyOutcome::Malformed(format!("reading runner output: {err}"));
229            }
230            Ok(Ok(None)) => return ReplyOutcome::Closed,
231            Ok(Ok(Some(line))) => {
232                let trimmed = line.trim();
233                if trimmed.is_empty() {
234                    continue;
235                }
236                return match serde_json::from_str::<RunnerResponse>(trimmed) {
237                    Ok(response) => ReplyOutcome::Parsed(Box::new(response)),
238                    Err(err) => ReplyOutcome::Malformed(format!("parsing runner attempt: {err}")),
239                };
240            }
241        }
242    }
243}
244
245/// Spawn the runner with the lake root in its environment.
246fn spawn_runner(runner: &Utf8Path, lake_root: &Utf8Path) -> Result<Child> {
247    Command::new(runner.as_str())
248        .env(LAKE_ROOT_ENV, lake_root.as_str())
249        .stdin(Stdio::piped())
250        .stdout(Stdio::piped())
251        .stderr(Stdio::inherit())
252        .kill_on_drop(true)
253        .spawn()
254        .map_err(|source| Error::RunnerSpawn {
255            runner: runner.to_path_buf(),
256            source,
257        })
258}
259
260/// Resolve the runner path: canonicalize a real file, else pass it through so a
261/// bare command name can still be found on `PATH`.
262fn resolve_runner(runner: &Utf8Path) -> Result<Utf8PathBuf> {
263    match std::fs::canonicalize(runner) {
264        Ok(path) => Utf8PathBuf::from_path_buf(path).map_err(|path| Error::NonUtf8Path { path }),
265        Err(_) => Ok(runner.to_path_buf()),
266    }
267}
268
269/// Lowercase hex SHA-256 of `input`.
270fn sha256_hex(input: &str) -> String {
271    let digest = Sha256::digest(input.as_bytes());
272    let mut out = String::with_capacity(digest.len() * 2);
273    for byte in digest {
274        let _ = write!(out, "{byte:02x}");
275    }
276    out
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        AllowedEdit, Diagnostic, DiagnosticSeverity, GoalState, LeanFile, MineKind, TargetSpan,
    };
    use camino::Utf8PathBuf;

    /// Build a minimal `sorry` task in `Demo.lean` whose target span and
    /// allowed edit both sit on `line`. With `with_diagnostic`, an
    /// "unsolved goals" error diagnostic on that line is attached as well.
    fn sample_task(task_id: &str, line: u32, with_diagnostic: bool) -> MineTask {
        let diagnostic = with_diagnostic.then(|| Diagnostic {
            file: Some(Utf8PathBuf::from("Demo.lean")),
            line: Some(line),
            column: Some(2),
            severity: DiagnosticSeverity::Error,
            message: "error: unsolved goals".to_owned(),
            goal_state: Some(GoalState("⊢ n = n".to_owned())),
        });
        MineTask {
            task_id: task_id.to_owned(),
            project: "demo".to_owned(),
            file: LeanFile(Utf8PathBuf::from("Demo.lean")),
            declaration: None,
            kind: MineKind::Sorry,
            line,
            column: 2,
            imports: vec!["import Init".to_owned()],
            source_before: "theorem t : True := by\n  ".to_owned(),
            // Columns 2..7 cover the five characters of `sorry`.
            target_span: TargetSpan {
                start_line: line,
                start_column: 2,
                end_line: line,
                end_column: 7,
                text: "sorry".to_owned(),
            },
            source_after: "\n".to_owned(),
            diagnostic,
            goal_state: None,
            allowed_edit: AllowedEdit {
                file: Utf8PathBuf::from("Demo.lean"),
                start_line: line,
                end_line: line,
            },
            instructions: "Replace only the target span.".to_owned(),
        }
    }

    #[test]
    fn sha256_hex_matches_known_vector() {
        // SHA-256 of the empty string.
        assert_eq!(
            sha256_hex(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    /// The task supplies the span and diagnostic; the reply supplies the proof
    /// text and provenance, including a reported prompt hash.
    #[test]
    fn merge_takes_span_from_task_and_proof_from_reply() {
        let task = sample_task("Demo.t:2", 2, true);
        let response = RunnerResponse {
            task_id: "Demo.t:2".to_owned(),
            attempt_id: "cand-1".to_owned(),
            replacement: "  rfl".to_owned(),
            model: Some("test-model".to_owned()),
            prompt_hash: Some("deadbeef".to_owned()),
            metadata: Some(serde_json::json!({"latency_ms": 12})),
        };
        let attempt = merge_attempt(&task, response, "{\"task_id\":\"Demo.t:2\"}");
        assert_eq!(attempt.task_id, "Demo.t:2");
        assert_eq!(attempt.attempt_id, "cand-1");
        assert_eq!(attempt.replacement, "  rfl");
        assert_eq!(attempt.allowed_edit.start_line, 2);
        assert_eq!(attempt.allowed_edit.end_line, 2);
        assert_eq!(attempt.model.as_deref(), Some("test-model"));
        assert_eq!(attempt.prompt_hash.as_deref(), Some("deadbeef"));
        assert!(attempt.original_diagnostic.is_some());
        assert!(attempt.metadata.is_some());
    }

    /// Without a reported prompt hash, the hash of the sent task line is used.
    #[test]
    fn merge_computes_prompt_hash_when_reply_omits_it() {
        let task = sample_task("Demo.t:2", 2, false);
        let response = RunnerResponse {
            task_id: "Demo.t:2".to_owned(),
            attempt_id: DEFAULT_ATTEMPT_ID.to_owned(),
            replacement: "  rfl".to_owned(),
            model: None,
            prompt_hash: None,
            metadata: None,
        };
        let attempt = merge_attempt(&task, response, "task-line");
        assert_eq!(
            attempt.prompt_hash.as_deref(),
            Some(sha256_hex("task-line").as_str())
        );
        assert!(attempt.original_diagnostic.is_none());
    }

    /// A reply with only the required fields parses, with serde defaults
    /// filling in the rest.
    #[test]
    fn minimal_runner_response_deserializes_with_defaults() -> Result<()> {
        let line = r#"{"task_id":"T","replacement":"  rfl"}"#;
        let parsed: RunnerResponse = serde_json::from_str(line)?;
        assert_eq!(parsed.task_id, "T");
        assert_eq!(parsed.attempt_id, DEFAULT_ATTEMPT_ID);
        assert_eq!(parsed.replacement, "  rfl");
        assert!(parsed.model.is_none());
        assert!(parsed.metadata.is_none());
        Ok(())
    }

    /// End-to-end run against the shipped `scripts/echo_runner.sh`.
    ///
    /// The assertions assume that script answers every task with `  rfl` and
    /// model `echo-runner`, and echoes each task id back (so no id mismatches
    /// are counted). Note that the test resets the script's mode to 0o755 in
    /// the checkout.
    #[cfg(unix)]
    #[tokio::test]
    async fn run_eval_streams_tasks_through_the_example_runner() -> Result<()> {
        use std::os::unix::fs::PermissionsExt;
        use tempfile::TempDir;

        let runner = Utf8PathBuf::from(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../scripts/echo_runner.sh"
        ));
        // Make sure the shipped example stays executable.
        let mut perms = std::fs::metadata(runner.as_std_path())?.permissions();
        perms.set_mode(0o755);
        std::fs::set_permissions(runner.as_std_path(), perms)?;

        let tasks = vec![
            sample_task("A.foo:2", 2, false),
            sample_task("B.bar:3", 3, true),
        ];
        let dir = TempDir::new()?;
        let out = Utf8PathBuf::from_path_buf(dir.path().join("attempts.jsonl"))
            .map_err(|path| Error::NonUtf8Path { path })?;

        let options = EvalOptions {
            runner,
            lake_root: Utf8PathBuf::from("."),
            timeout: Duration::from_secs(30),
        };
        let mut writer = JsonlWriter::create(&out)?;
        let summary = run_eval(&options, &tasks, &mut writer).await?;

        assert_eq!(summary.tasks_read, 2);
        assert_eq!(summary.attempts_written, 2);
        assert_eq!(summary.runner_errors, 0);
        assert_eq!(summary.id_mismatches, 0);

        // Read the JSONL output back, ignoring blank lines.
        let content = std::fs::read_to_string(out.as_std_path())?;
        let mut attempts = Vec::new();
        for line in content.lines().filter(|line| !line.trim().is_empty()) {
            attempts.push(serde_json::from_str::<Attempt>(line)?);
        }
        assert_eq!(attempts.len(), 2);
        // task_id comes from the sent task, replacement from the runner.
        assert_eq!(attempts[0].task_id, "A.foo:2");
        assert_eq!(attempts[1].task_id, "B.bar:3");
        assert_eq!(attempts[0].replacement, "  rfl");
        assert_eq!(attempts[0].model.as_deref(), Some("echo-runner"));
        // prompt_hash falls back to a hash of the sent task line.
        assert!(attempts[0].prompt_hash.is_some());
        // The error task carries its diagnostic through for replay scoring.
        assert!(attempts[1].original_diagnostic.is_some());
        Ok(())
    }
}