Skip to main content

zagens_runtime/repl/
runtime.rs

1//! Long-lived Python REPL runtime.
2//!
3//! One `python3 -u` subprocess lives for the duration of an RLM turn (or an
4//! inline `repl` block sequence in the agent loop). Code blocks are sent
5//! over stdin framed by `__RLM_RUN__`/`__RLM_END__` sentinels; the bootstrap
6//! `exec()`s them into the same global namespace so variables, imports,
7//! and even open file handles persist naturally across rounds.
8//!
9//! Sub-LLM helpers (`llm_query`, `llm_query_batched`, `rlm_query`,
10//! `rlm_query_batched`) are wired through a stdin/stdout RPC protocol:
11//! Python emits `__RLM_REQ_<sid>__::{json}` on stdout, Rust dispatches the
12//! request and writes `__RLM_RESP_<sid>__::{json}` back on stdin. No HTTP
13//! sidecar, no temp ports — the same pipes carry both control and data.
14//!
15//! The session id (`<sid>`) is a UUID generated per spawn, so user output
16//! that happens to contain "REQ" or "FINAL" can't be confused with control
17//! messages.
18
19use std::path::{Path, PathBuf};
20use std::process::Stdio;
21use std::time::{Duration, Instant};
22
23use serde::{Deserialize, Serialize};
24use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
25use tokio::process::{Child, ChildStdin, ChildStdout, Command};
26use uuid::Uuid;
27
28// ---------------------------------------------------------------------------
29// Public types
30// ---------------------------------------------------------------------------
31
32/// Result of executing one code block.
33#[derive(Debug, Clone)]
34pub struct ReplRound {
35    /// Stdout shown to the model as metadata next round.
36    pub stdout: String,
37    /// Full stdout (with sentinels stripped, but otherwise raw).
38    pub full_stdout: String,
39    /// Stderr from this round (if any).
40    pub stderr: String,
41    /// `True` if the user code raised an unhandled Python exception.
42    pub has_error: bool,
43    /// Captured `FINAL(value)` payload, if any.
44    pub final_value: Option<String>,
45    /// Number of `llm_query`/`rlm_query` RPCs the round issued.
46    pub rpc_count: u32,
47    /// Wall-clock duration of the round.
48    pub elapsed: Duration,
49}
50
51/// One RPC request emitted by Python during a round.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type", rename_all = "snake_case")]
54pub enum RpcRequest {
55    /// `llm_query(prompt, model=None, max_tokens=None, system=None)`
56    Llm {
57        prompt: String,
58        #[serde(default)]
59        model: Option<String>,
60        #[serde(default)]
61        max_tokens: Option<u32>,
62        #[serde(default)]
63        system: Option<String>,
64    },
65    /// `llm_query_batched(prompts, model=None)`
66    LlmBatch {
67        prompts: Vec<String>,
68        #[serde(default)]
69        model: Option<String>,
70    },
71    /// `rlm_query(prompt, model=None)` — recursive sub-RLM (paper's `sub_RLM`).
72    Rlm {
73        prompt: String,
74        #[serde(default)]
75        model: Option<String>,
76    },
77    /// `rlm_query_batched(prompts, model=None)`
78    RlmBatch {
79        prompts: Vec<String>,
80        #[serde(default)]
81        model: Option<String>,
82    },
83}
84
85/// Response for one RPC request.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87#[serde(untagged)]
88pub enum RpcResponse {
89    /// Single-text reply (Llm / Rlm).
90    Single(SingleResp),
91    /// Batch reply (LlmBatch / RlmBatch).
92    Batch(BatchResp),
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct SingleResp {
97    #[serde(default)]
98    pub text: String,
99    #[serde(default, skip_serializing_if = "Option::is_none")]
100    pub error: Option<String>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct BatchResp {
105    pub results: Vec<SingleResp>,
106}
107
108/// Trait-object handle for dispatching Python RPCs back into Rust.
109///
110/// Each RLM turn supplies one. Implementations forward to the LLM client
111/// (and recursively into `run_rlm_turn_inner` for `Rlm` / `RlmBatch`).
112pub trait RpcDispatcher: Send + Sync {
113    fn dispatch<'a>(
114        &'a self,
115        req: RpcRequest,
116    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = RpcResponse> + Send + 'a>>;
117}
118
119// ---------------------------------------------------------------------------
120// Constants
121// ---------------------------------------------------------------------------
122
/// Byte budget for the display copy of a round's stdout.
const DEFAULT_STDOUT_LIMIT: usize = 8 * 1024;
/// Upper bound on one round, including any RPCs it issues.
const ROUND_TIMEOUT: Duration = Duration::from_secs(3 * 60);
/// How long to wait for the bootstrap's ready sentinel after spawning.
/// Windows gets a longer budget (presumably for slower process start-up).
const SPAWN_READY_TIMEOUT: Duration = Duration::from_secs(if cfg!(windows) { 30 } else { 10 });
129
130// ---------------------------------------------------------------------------
131// PythonRuntime
132// ---------------------------------------------------------------------------
133
/// Long-lived Python REPL backed by a single interpreter subprocess.
///
/// Code blocks and RPC responses are written to the child's stdin. User
/// output and control sentinels are read back from its stdout.
#[derive(Debug)]
pub struct PythonRuntime {
    /// Interpreter process, spawned with `kill_on_drop(true)`. Its stderr
    /// pipe stays on the `Child` and is drained after each round.
    child: Child,
    /// Pipe carrying framed code blocks and RPC responses to Python.
    stdin: ChildStdin,
    /// Buffered pipe carrying user output and control sentinels from Python.
    stdout: BufReader<ChildStdout>,
    /// Per-spawn session id used in protocol sentinels.
    session_id: String,
    /// Path to the file holding `context` (kept around for cleanup by
    /// `shutdown()` / `Drop`).
    context_path: Option<PathBuf>,
    /// Byte budget passed to `truncate_stdout` for `ReplRound::stdout`.
    stdout_limit: usize,
    /// Rounds started so far; also the round id sent in each run header.
    round_count: u64,
    /// Spawn time, used by `uptime()`.
    started: Instant,
}
148
149impl PythonRuntime {
150    /// Spawn a REPL with no `context` variable and no LLM helpers wired up.
151    /// Used by the agent loop for inline `repl` blocks the model emits in
152    /// regular conversation.
153    pub async fn new() -> Result<Self, String> {
154        Self::spawn_inner(None).await
155    }
156
157    /// Compatibility shim — older RLM code path used to pass a state file.
158    /// The state file is no longer used, but the path doubles as an extra
159    /// scratch location callers can rely on for cleanup symmetry.
160    pub fn with_state_path(_path: PathBuf) -> Self {
161        // Synchronous constructor is no longer meaningful: spawning Python
162        // is async. Callers in turn.rs already use `spawn_with_context` —
163        // this stub is kept only so the public surface compiles for any
164        // out-of-tree user. It returns a deliberately broken runtime that
165        // panics on first use, which is preferable to silently lying.
166        unreachable!(
167            "PythonRuntime::with_state_path is deprecated — \
168             use PythonRuntime::new() or PythonRuntime::spawn_with_context()"
169        )
170    }
171
172    /// Spawn a REPL with `context` (and `ctx`) preloaded from a file. Used
173    /// by the RLM turn loop.
174    pub async fn spawn_with_context(context_path: &Path) -> Result<Self, String> {
175        Self::spawn_inner(Some(context_path)).await
176    }
177
178    async fn spawn_inner(context_path: Option<&Path>) -> Result<Self, String> {
179        let session_id = Uuid::new_v4().simple().to_string();
180        let bootstrap = render_bootstrap(&session_id);
181
182        let python_bin = crate::python_env::find_python()
183            .map(|(bin, _, _)| bin)
184            .unwrap_or_else(|| "python3".to_string());
185        let mut cmd = Command::new(&python_bin);
186        cmd.arg("-I")
187            .arg("-u")
188            .arg("-c")
189            .arg(&bootstrap)
190            .stdin(Stdio::piped())
191            .stdout(Stdio::piped())
192            .stderr(Stdio::piped())
193            .kill_on_drop(true);
194
195        if let Some(path) = context_path {
196            cmd.env("RLM_CONTEXT_FILE", path);
197        }
198
199        let mut child = cmd
200            .spawn()
201            .map_err(|e| format!("failed to spawn {python_bin}: {e}"))?;
202
203        let stdin = child
204            .stdin
205            .take()
206            .ok_or_else(|| format!("{python_bin} stdin pipe missing"))?;
207        let raw_stdout = child
208            .stdout
209            .take()
210            .ok_or_else(|| format!("{python_bin} stdout pipe missing"))?;
211        let stdout = BufReader::new(raw_stdout);
212
213        let mut rt = Self {
214            child,
215            stdin,
216            stdout,
217            session_id: session_id.clone(),
218            context_path: context_path.map(Path::to_path_buf),
219            stdout_limit: DEFAULT_STDOUT_LIMIT,
220            round_count: 0,
221            started: Instant::now(),
222        };
223
224        // Wait for `__RLM_READY_<sid>__` before handing control back. If
225        // Python failed to start (missing module, syntax error in the
226        // bootstrap, etc.), this is where we'll find out.
227        let ready_sentinel = format!("__RLM_READY_{session_id}__");
228        match tokio::time::timeout(SPAWN_READY_TIMEOUT, rt.read_until_ready(&ready_sentinel)).await
229        {
230            Ok(Ok(())) => Ok(rt),
231            Ok(Err(e)) => {
232                let _ = rt.child.kill().await;
233                Err(format!("python3 bootstrap failed: {e}"))
234            }
235            Err(_) => {
236                let _ = rt.child.kill().await;
237                Err(format!(
238                    "python3 bootstrap did not signal ready within {}s",
239                    SPAWN_READY_TIMEOUT.as_secs()
240                ))
241            }
242        }
243    }
244
245    async fn read_until_ready(&mut self, ready_sentinel: &str) -> Result<(), String> {
246        loop {
247            let mut line = String::new();
248            let n = self
249                .stdout
250                .read_line(&mut line)
251                .await
252                .map_err(|e| format!("stdout read: {e}"))?;
253            if n == 0 {
254                return Err("python3 closed stdout before ready signal".to_string());
255            }
256            let trimmed = line.trim_end_matches(['\n', '\r']);
257            if trimmed == ready_sentinel {
258                return Ok(());
259            }
260            // Pre-ready output is rare; ignore it.
261        }
262    }
263
264    /// Execute a Python code block with no RPC dispatcher. Used for inline
265    /// `repl` blocks where `llm_query()` should fall back to a sentinel.
266    pub async fn execute(&mut self, code: &str) -> Result<ReplRound, String> {
267        self.run(code, None::<&dyn RpcDispatcher>).await
268    }
269
270    /// Execute a code block, dispatching any sub-LLM RPCs through `bridge`.
271    ///
272    /// Returns once Python emits `__RLM_DONE_<sid>__` or the round timeout
273    /// elapses (whichever happens first).
274    pub async fn run<D>(&mut self, code: &str, bridge: Option<&D>) -> Result<ReplRound, String>
275    where
276        D: RpcDispatcher + ?Sized,
277    {
278        let started = Instant::now();
279        self.round_count += 1;
280        let round_id = self.round_count;
281
282        // Send the code header + body + end marker in one write.
283        let header = format!("__RLM_RUN_{}__::{round_id}\n", self.session_id);
284        let footer = format!("__RLM_END_{}__\n", self.session_id);
285        let payload = format!("{header}{code}\n{footer}");
286        self.stdin
287            .write_all(payload.as_bytes())
288            .await
289            .map_err(|e| format!("stdin write: {e}"))?;
290        self.stdin
291            .flush()
292            .await
293            .map_err(|e| format!("stdin flush: {e}"))?;
294
295        // Sentinels for this session.
296        let req_prefix = format!("__RLM_REQ_{}__::", self.session_id);
297        let final_prefix = format!("__RLM_FINAL_{}__::", self.session_id);
298        let err_prefix = format!("__RLM_ERR_{}__::", self.session_id);
299        let done_prefix = format!("__RLM_DONE_{}__::", self.session_id);
300
301        let mut stdout_buf = String::new();
302        let mut final_value: Option<String> = None;
303        let mut had_error = false;
304        let mut rpc_count: u32 = 0;
305
306        let read_loop = async {
307            loop {
308                let mut line = String::new();
309                let n = self
310                    .stdout
311                    .read_line(&mut line)
312                    .await
313                    .map_err(|e| format!("stdout read: {e}"))?;
314                if n == 0 {
315                    return Err("python3 closed stdout mid-round".to_string());
316                }
317                let trimmed = line.trim_end_matches(['\n', '\r']);
318
319                if let Some(rest) = trimmed.strip_prefix(&done_prefix) {
320                    let _ = rest;
321                    break;
322                }
323                if let Some(rest) = trimmed.strip_prefix(&final_prefix) {
324                    // Stored as a JSON-encoded string.
325                    let v =
326                        serde_json::from_str::<String>(rest).unwrap_or_else(|_| rest.to_string());
327                    final_value = Some(v);
328                    continue;
329                }
330                if let Some(rest) = trimmed.strip_prefix(&err_prefix) {
331                    let traceback =
332                        serde_json::from_str::<String>(rest).unwrap_or_else(|_| rest.to_string());
333                    had_error = true;
334                    stdout_buf.push_str(&format!("[traceback]\n{traceback}\n"));
335                    continue;
336                }
337                if let Some(rest) = trimmed.strip_prefix(&req_prefix) {
338                    rpc_count = rpc_count.saturating_add(1);
339                    let req: RpcRequest = match serde_json::from_str(rest) {
340                        Ok(r) => r,
341                        Err(e) => {
342                            // Send an error response so Python isn't blocked.
343                            self.send_resp(&RpcResponse::Single(SingleResp {
344                                text: String::new(),
345                                error: Some(format!("malformed RPC: {e}")),
346                            }))
347                            .await?;
348                            continue;
349                        }
350                    };
351                    let resp = match bridge {
352                        Some(b) => b.dispatch(req).await,
353                        None => RpcResponse::Single(SingleResp {
354                            text: String::new(),
355                            error: Some("no LLM bridge bound to this REPL".to_string()),
356                        }),
357                    };
358                    self.send_resp(&resp).await?;
359                    continue;
360                }
361
362                stdout_buf.push_str(&line);
363            }
364            Ok::<_, String>(())
365        };
366
367        match tokio::time::timeout(ROUND_TIMEOUT, read_loop).await {
368            Ok(Ok(())) => {}
369            Ok(Err(e)) => return Err(e),
370            Err(_) => {
371                return Err(format!(
372                    "REPL round timed out after {}s",
373                    ROUND_TIMEOUT.as_secs()
374                ));
375            }
376        }
377
378        let stderr = self.drain_stderr().await;
379        let display = truncate_stdout(stdout_buf.trim_end_matches('\n'), self.stdout_limit);
380
381        Ok(ReplRound {
382            stdout: display,
383            full_stdout: stdout_buf,
384            stderr,
385            has_error: had_error,
386            final_value,
387            rpc_count,
388            elapsed: started.elapsed(),
389        })
390    }
391
392    async fn send_resp(&mut self, resp: &RpcResponse) -> Result<(), String> {
393        let body = serde_json::to_string(resp).map_err(|e| format!("encode rpc resp: {e}"))?;
394        let line = format!("__RLM_RESP_{}__::{body}\n", self.session_id);
395        self.stdin
396            .write_all(line.as_bytes())
397            .await
398            .map_err(|e| format!("stdin write resp: {e}"))?;
399        self.stdin
400            .flush()
401            .await
402            .map_err(|e| format!("stdin flush resp: {e}"))?;
403        Ok(())
404    }
405
406    async fn drain_stderr(&mut self) -> String {
407        // We don't continuously read stderr — drain whatever's pending after
408        // a round so it can show up in error reports without deadlocking
409        // anything during normal operation.
410        let Some(stderr) = self.child.stderr.as_mut() else {
411            return String::new();
412        };
413        use tokio::io::AsyncReadExt;
414        let mut buf = Vec::new();
415        // Best-effort read with a tight deadline; we don't want to block.
416        let fut = async {
417            let mut chunk = [0u8; 4096];
418            loop {
419                match tokio::time::timeout(Duration::from_millis(20), stderr.read(&mut chunk)).await
420                {
421                    Ok(Ok(0)) => break,
422                    Ok(Ok(n)) => buf.extend_from_slice(&chunk[..n]),
423                    _ => break,
424                }
425            }
426        };
427        let _ = fut.await;
428        String::from_utf8_lossy(&buf).to_string()
429    }
430
431    /// Total rounds executed.
432    pub fn round_count(&self) -> u64 {
433        self.round_count
434    }
435
436    /// Wall-clock uptime since spawn.
437    pub fn uptime(&self) -> Duration {
438        self.started.elapsed()
439    }
440
441    /// Cleanly tear down the subprocess.
442    pub async fn shutdown(mut self) {
443        let _ = self.stdin.shutdown().await;
444        let _ = self.child.kill().await;
445        if let Some(path) = self.context_path.take() {
446            let _ = tokio::fs::remove_file(path).await;
447        }
448    }
449}
450
451impl Drop for PythonRuntime {
452    fn drop(&mut self) {
453        // tokio sets `kill_on_drop(true)` on the child; the context file
454        // (if any) is removed on `shutdown()` — drop is best-effort.
455        if let Some(path) = self.context_path.take() {
456            let _ = std::fs::remove_file(path);
457        }
458    }
459}
460
461// ---------------------------------------------------------------------------
462// Bootstrap script
463// ---------------------------------------------------------------------------
464
465/// Render the Python bootstrap with session-specific sentinels baked in.
466/// The sentinels include a UUID to prevent user prints from being mistaken
467/// for control messages.
468fn render_bootstrap(session_id: &str) -> String {
469    BOOTSTRAP_TEMPLATE.replace("__SID__", session_id)
470}
471
const BOOTSTRAP_TEMPLATE: &str = r#"
import json as _json
import os as _os
import sys as _sys
import traceback as _traceback

_SID = "__SID__"
_REQ = f"__RLM_REQ_{_SID}__::"
_RESP = f"__RLM_RESP_{_SID}__::"
_FINAL = f"__RLM_FINAL_{_SID}__::"
_ERR = f"__RLM_ERR_{_SID}__::"
_RUN = f"__RLM_RUN_{_SID}__::"
_END = f"__RLM_END_{_SID}__"
_DONE = f"__RLM_DONE_{_SID}__::"
_READY = f"__RLM_READY_{_SID}__"

# Private handles on the real pipes. User code may rebind sys.stdout or
# sys.stdin (e.g. contextlib.redirect_stdout); protocol traffic must keep
# reaching the Rust driver regardless, or RPCs and DONE markers get lost.
_OUT = _sys.stdout
_IN = _sys.stdin

def _emit(line):
    _OUT.write(line + "\n")
    _OUT.flush()

def _rpc(req):
    _emit(_REQ + _json.dumps(req))
    line = _IN.readline()
    if not line:
        return {"error": "rust driver closed stdin"}
    if line.startswith(_RESP):
        try:
            return _json.loads(line[len(_RESP):])
        except Exception as e:
            return {"error": f"malformed rpc resp: {e}"}
    return {"error": f"unexpected protocol line: {line[:120]!r}"}

def llm_query(prompt, model=None, max_tokens=None, system=None):
    """One-shot sub-LLM call. The model arg is accepted for compatibility but ignored by Rust."""
    resp = _rpc({"type":"llm","prompt":str(prompt),"model":model,
                 "max_tokens":max_tokens,"system":system})
    if isinstance(resp, dict) and resp.get("error"):
        return f"[llm_query error: {resp['error']}]"
    if isinstance(resp, dict):
        return resp.get("text","")
    return str(resp)

def llm_query_batched(prompts, model=None):
    """Run multiple sub-LLM calls concurrently. The model arg is accepted for compatibility but ignored."""
    if not isinstance(prompts, (list, tuple)):
        return ["[llm_query_batched: prompts must be a list]"]
    resp = _rpc({"type":"llm_batch","prompts":[str(p) for p in prompts],"model":model})
    if isinstance(resp, dict) and resp.get("error"):
        return [f"[llm_query_batched: {resp['error']}]" for _ in prompts]
    results = (resp or {}).get("results", []) if isinstance(resp, dict) else []
    if len(results) != len(prompts):
        return [f"[llm_query_batched: size mismatch ({len(results)}/{len(prompts)})]" for _ in prompts]
    out = []
    for r in results:
        if r.get("error"):
            out.append(f"[child err: {r['error']}]")
        else:
            out.append(r.get("text",""))
    return out

def rlm_query(prompt, model=None):
    """Recursive sub-RLM. The model arg is accepted for compatibility but ignored by Rust."""
    resp = _rpc({"type":"rlm","prompt":str(prompt),"model":model})
    if isinstance(resp, dict) and resp.get("error"):
        return f"[rlm_query error: {resp['error']}]"
    if isinstance(resp, dict):
        return resp.get("text","")
    return str(resp)

def rlm_query_batched(prompts, model=None):
    """Run multiple recursive sub-RLMs in parallel. The model arg is accepted for compatibility but ignored."""
    if not isinstance(prompts, (list, tuple)):
        return ["[rlm_query_batched: prompts must be a list]"]
    resp = _rpc({"type":"rlm_batch","prompts":[str(p) for p in prompts],"model":model})
    if isinstance(resp, dict) and resp.get("error"):
        return [f"[rlm_query_batched: {resp['error']}]" for _ in prompts]
    results = (resp or {}).get("results", []) if isinstance(resp, dict) else []
    if len(results) != len(prompts):
        return [f"[rlm_query_batched: size mismatch ({len(results)}/{len(prompts)})]" for _ in prompts]
    out = []
    for r in results:
        if r.get("error"):
            out.append(f"[child err: {r['error']}]")
        else:
            out.append(r.get("text",""))
    return out

def FINAL(value):
    """Signal the loop to stop with this final answer."""
    _emit(_FINAL + _json.dumps(str(value)))

def FINAL_VAR(name):
    """Signal the loop to stop, returning the value of a named variable."""
    name_str = str(name).strip().strip("'\"")
    if name_str in globals():
        FINAL(globals()[name_str])
    else:
        print(f"FINAL_VAR error: variable '{name_str}' not found. "
              f"Use SHOW_VARS() to list available variables.", flush=True)

def SHOW_VARS():
    """Return a dict of {name: type-name} for all user variables in the REPL."""
    out = {}
    for k, v in list(globals().items()):
        if k.startswith('_') or k in _BOOTSTRAP_NAMES:
            continue
        out[k] = type(v).__name__
    return out

def repl_get(name, default=None):
    return globals().get(str(name), default)

def repl_set(name, value):
    globals()[str(name)] = value

# Load the long input as `context` (and `ctx`) from a file. This keeps the
# big string out of the process command-line and out of the LLM's window.
_ctx_file = _os.environ.get("RLM_CONTEXT_FILE","")
context = ""
if _ctx_file:
    try:
        # Underscore name so the file handle does not leak into SHOW_VARS().
        with open(_ctx_file, "r", encoding="utf-8", errors="replace") as _f:
            context = _f.read()
    except Exception as e:
        _sys.stderr.write(f"[bootstrap] failed to load context: {e}\n")
ctx = context  # short alias matching aleph

_BOOTSTRAP_NAMES = {
    "_SID","_REQ","_RESP","_FINAL","_ERR","_RUN","_END","_DONE","_READY",
    "_OUT","_IN","_emit",
    "_rpc","_ctx_file","_BOOTSTRAP_NAMES","_main_loop",
    "llm_query","llm_query_batched","rlm_query","rlm_query_batched",
    "FINAL","FINAL_VAR","SHOW_VARS","repl_get","repl_set",
    "context","ctx",
    "_json","_os","_sys","_traceback",
}

def _main_loop():
    _emit(_READY)
    while True:
        header = _IN.readline()
        if not header:
            return
        if not header.startswith(_RUN):
            continue
        round_id = header.rstrip("\n")[len(_RUN):]
        code_lines = []
        while True:
            line = _IN.readline()
            if not line:
                return
            if line.rstrip("\n") == _END:
                break
            code_lines.append(line)
        code = "".join(code_lines)
        try:
            exec(compile(code, f"<repl-{round_id}>", "exec"), globals())
        except SystemExit:
            _emit(_DONE + round_id)
            return
        except BaseException:
            _emit(_ERR + _json.dumps(_traceback.format_exc()))
        _emit(_DONE + round_id)

_main_loop()
"#;
640
641// ---------------------------------------------------------------------------
642// Helpers
643// ---------------------------------------------------------------------------
644
/// Clip `stdout` so it fits in roughly `limit` bytes.
///
/// Output that already fits is returned unchanged. Otherwise, the first
/// `limit - 80` bytes are kept, rounded down to a UTF-8 character boundary.
/// A marker stating how many bytes were dropped is then appended. The 80
/// bytes of headroom are for that marker, so for `limit >= 80` the result
/// stays within `limit`.
///
/// The budget is measured in bytes on both sides. Previously the check
/// used bytes but the cut counted chars, which let non-ASCII output
/// overshoot the limit several-fold.
fn truncate_stdout(stdout: &str, limit: usize) -> String {
    if stdout.len() <= limit {
        return stdout.to_string();
    }
    // `cut < stdout.len()` here, and index 0 is always a char boundary, so
    // this loop terminates on a valid slice index.
    let mut cut = limit.saturating_sub(80);
    while !stdout.is_char_boundary(cut) {
        cut -= 1;
    }
    let omitted = stdout.len() - cut;
    format!(
        "{}\n\n[... REPL output truncated: {omitted} bytes omitted ...]\n",
        &stdout[..cut]
    )
}
657
658// ---------------------------------------------------------------------------
659// Tests
660// ---------------------------------------------------------------------------
661
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::sync::Mutex;

    /// Fake dispatcher that logs each request it receives.
    ///
    /// It answers with deterministic `stub#<call>` text, so RPC round-trips
    /// can be checked without any network access.
    struct StubBridge {
        /// Every request received, in arrival order.
        calls: Arc<Mutex<Vec<RpcRequest>>>,
        /// Number assigned to the next dispatch call.
        next_call: Arc<AtomicU32>,
    }

    impl StubBridge {
        fn new() -> Self {
            let calls = Arc::new(Mutex::new(Vec::new()));
            let next_call = Arc::new(AtomicU32::new(0));
            Self { calls, next_call }
        }
    }

    impl RpcDispatcher for StubBridge {
        fn dispatch<'a>(
            &'a self,
            req: RpcRequest,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = RpcResponse> + Send + 'a>> {
            Box::pin(async move {
                self.calls.lock().await.push(req.clone());
                let call = self.next_call.fetch_add(1, Ordering::Relaxed);
                let reply = |text: String| SingleResp { text, error: None };
                match req {
                    RpcRequest::Llm { prompt, .. } | RpcRequest::Rlm { prompt, .. } => {
                        RpcResponse::Single(reply(format!("stub#{call}: {prompt}")))
                    }
                    RpcRequest::LlmBatch { prompts, .. } | RpcRequest::RlmBatch { prompts, .. } => {
                        let results = prompts
                            .iter()
                            .enumerate()
                            .map(|(i, p)| reply(format!("stub#{call}.{i}: {p}")))
                            .collect();
                        RpcResponse::Batch(BatchResp { results })
                    }
                }
            })
        }
    }

    /// Write `body` to a uniquely named file under the system temp dir and
    /// return its path.
    fn write_temp_context(body: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join("deepseek_repl_runtime_tests");
        std::fs::create_dir_all(&dir).unwrap();
        let file_name = format!("ctx_{}_{}.txt", std::process::id(), Uuid::new_v4());
        let path = dir.join(file_name);
        std::fs::write(&path, body).unwrap();
        path
    }

    #[tokio::test]
    async fn spawns_and_executes_simple_print() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        let out = rt.execute("print('hello world')").await.expect("execute");
        assert!(out.stdout.contains("hello world"));
        assert!(!out.has_error);
        assert_eq!(out.final_value, None);
        assert_eq!(out.rpc_count, 0);
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn variables_persist_across_rounds() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        for step in ["x = [1, 2, 3]", "x.append(99)"] {
            rt.execute(step).await.expect(step);
        }
        let out = rt.execute("print(x)").await.expect("r3");
        assert!(out.stdout.contains("[1, 2, 3, 99]"));
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn imports_persist_across_rounds() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        rt.execute("import math").await.expect("r1");
        let out = rt.execute("print(math.pi)").await.expect("r2");
        assert!(out.stdout.contains("3.14"));
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn context_loads_from_file() {
        let path = write_temp_context("the quick brown fox");
        let mut rt = PythonRuntime::spawn_with_context(&path)
            .await
            .expect("spawn");
        let out = rt
            .execute("print(len(context), context[:5])")
            .await
            .expect("execute");
        for needle in ["19", "the q"] {
            assert!(out.stdout.contains(needle), "missing {needle:?}");
        }
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn ctx_alias_works() {
        let path = write_temp_context("aleph-style");
        let mut rt = PythonRuntime::spawn_with_context(&path)
            .await
            .expect("spawn");
        let out = rt.execute("print(ctx)").await.expect("execute");
        assert!(out.stdout.contains("aleph-style"));
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn final_is_captured() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        let out = rt
            .execute("FINAL('the answer is 42')")
            .await
            .expect("execute");
        assert_eq!(out.final_value.as_deref(), Some("the answer is 42"));
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn final_var_is_captured() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        rt.execute("answer = 'computed'").await.expect("r1");
        let out = rt.execute("FINAL_VAR('answer')").await.expect("r2");
        assert_eq!(out.final_value.as_deref(), Some("computed"));
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn errors_are_reported_without_killing_runtime() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        let failed = rt.execute("raise ValueError('boom')").await.expect("r1");
        assert!(failed.has_error);
        assert!([&failed.full_stdout, &failed.stdout]
            .iter()
            .any(|text| text.contains("boom")));
        // The interpreter must survive the exception and accept more work.
        let next = rt.execute("print('still here')").await.expect("r2");
        assert!(next.stdout.contains("still here"));
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn rpc_dispatcher_round_trips_llm_query() {
        let bridge = StubBridge::new();
        let calls = Arc::clone(&bridge.calls);

        let mut rt = PythonRuntime::new().await.expect("spawn");
        let out = rt
            .run("print(llm_query('hello'))", Some(&bridge))
            .await
            .expect("execute");
        assert!(out.stdout.contains("stub#0: hello"), "stdout: {:?}", out.stdout);
        assert_eq!(out.rpc_count, 1);

        {
            // Scope the guard so the lock is released before shutdown.
            let recorded = calls.lock().await;
            assert_eq!(recorded.len(), 1);
            match recorded.first() {
                Some(RpcRequest::Llm { prompt, .. }) => assert_eq!(prompt, "hello"),
                other => panic!("expected Llm request, got {other:?}"),
            }
        }
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn rpc_dispatcher_round_trips_batch() {
        let bridge = StubBridge::new();
        let mut rt = PythonRuntime::new().await.expect("spawn");
        let out = rt
            .run(
                "outs = llm_query_batched(['a','b','c']); print('|'.join(outs))",
                Some(&bridge),
            )
            .await
            .expect("execute");
        for (i, letter) in ["a", "b", "c"].iter().enumerate() {
            let want = format!("stub#0.{i}: {letter}");
            assert!(out.stdout.contains(&want), "missing {want:?} in {:?}", out.stdout);
        }
        assert_eq!(out.rpc_count, 1);
        rt.shutdown().await;
    }

    #[tokio::test]
    async fn no_dispatcher_returns_unavailable_sentinel() {
        let mut rt = PythonRuntime::new().await.expect("spawn");
        let out = rt.execute("print(llm_query('hi'))").await.expect("execute");
        let markers = ["[llm_query error:", "no LLM bridge"];
        assert!(
            markers.iter().any(|m| out.stdout.contains(m)),
            "stdout: {:?}",
            out.stdout
        );
        rt.shutdown().await;
    }

    #[test]
    fn truncate_keeps_short_unchanged() {
        assert_eq!(truncate_stdout("hello", 100), "hello");
    }

    #[test]
    fn truncate_clips_long() {
        let clipped = truncate_stdout(&"a".repeat(10_000), 1024);
        assert!(clipped.len() < 1500);
        assert!(clipped.contains("truncated"));
    }
}