Skip to main content

batty_cli/shim/
runtime_codex.rs

//! Codex SDK-mode shim runtime: communicates with Codex via JSONL events
//! from `codex exec --json` instead of screen-scraping a PTY.
//!
//! Unlike the Claude SDK runtime (a persistent subprocess fed NDJSON on stdin),
//! Codex uses a **spawn-per-message** model: each `SendMessage` launches a
//! new `codex exec --json` subprocess. Multi-turn context is preserved via
//! `codex exec resume <thread_id>`.
//!
//! Emits the same `Command`/`Event` protocol to the orchestrator as the
//! PTY runtime, making it transparent to all upstream consumers.

12use std::collections::VecDeque;
13use std::io::Write;
14use std::io::{BufRead, BufReader};
15use std::process::{Child, Command, Stdio};
16use std::sync::{Arc, Mutex};
17use std::thread;
18use std::time::{Duration, Instant};
19
20use anyhow::{Context, Result};
21
22use super::codex_types::{self, CodexEvent};
23use super::common::{
24    self, MAX_QUEUE_DEPTH, QueuedMessage, SESSION_STATS_INTERVAL_SECS, drain_queue_errors,
25    format_injected_message,
26};
27use super::protocol::{Channel, Command as ShimCommand, Event, ShimState};
28use super::pty_log::PtyLogWriter;
29use super::runtime::ShimArgs;
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Seconds `terminate_child` waits after `SIGTERM` before escalating to a hard kill.
const GROUP_TERM_GRACE_SECS: u64 = 2;
/// Milliseconds between exit polls while `terminate_child` waits out the grace period.
const PROCESS_EXIT_POLL_MS: u64 = 100;

38// ---------------------------------------------------------------------------
39// Shared state
40// ---------------------------------------------------------------------------
41
42struct CodexState {
43    state: ShimState,
44    state_changed_at: Instant,
45    started_at: Instant,
46    /// Thread ID from the first `thread.started` event, used for resume.
47    thread_id: Option<String>,
48    /// Accumulated agent response text for the current turn.
49    accumulated_response: String,
50    /// Message ID of the currently pending (in-flight) message.
51    pending_message_id: Option<String>,
52    /// Messages queued while the agent is in Working state.
53    message_queue: VecDeque<QueuedMessage>,
54    /// Total bytes of response text received.
55    cumulative_output_bytes: u64,
56    /// The codex binary name/path.
57    program: String,
58    /// Working directory for spawning subprocesses.
59    cwd: std::path::PathBuf,
60    /// Whether a ContextApproaching event has already been emitted this session.
61    context_approaching_emitted: bool,
62}
63
64// ---------------------------------------------------------------------------
65// Main entry point
66// ---------------------------------------------------------------------------
67
68/// Run the Codex SDK-mode shim. This function does not return until the shim exits.
69///
70/// Unlike the Claude SDK runtime, each message spawns a new `codex exec --json`
71/// subprocess. The shim manages the lifecycle per-message and emits the same
72/// `Command`/`Event` protocol to the orchestrator.
73pub fn run_codex_sdk(args: ShimArgs, channel: Channel) -> Result<()> {
74    eprintln!("[shim-codex {}] started (spawn-per-message mode)", args.id);
75
76    // Shared state
77    let state = Arc::new(Mutex::new(CodexState {
78        state: ShimState::Idle,
79        state_changed_at: Instant::now(),
80        started_at: Instant::now(),
81        thread_id: None,
82        accumulated_response: String::new(),
83        pending_message_id: None,
84        message_queue: VecDeque::new(),
85        cumulative_output_bytes: 0,
86        program: "codex".to_string(),
87        cwd: args.cwd.clone(),
88        context_approaching_emitted: false,
89    }));
90
91    // PTY log writer (optional — writes readable text for tmux display)
92    let pty_log: Option<Arc<Mutex<PtyLogWriter>>> = args
93        .pty_log_path
94        .as_deref()
95        .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
96        .transpose()?
97        .map(|w| Arc::new(Mutex::new(w)));
98
99    // Channel clones
100    let mut cmd_channel = channel;
101
102    // Emit Ready immediately — no persistent subprocess to wait for.
103    cmd_channel.send(&Event::Ready)?;
104
105    // Session stats thread
106    let state_stats = Arc::clone(&state);
107    let mut stats_channel = cmd_channel
108        .try_clone()
109        .context("failed to clone channel for stats")?;
110    thread::spawn(move || {
111        loop {
112            thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
113            let st = state_stats.lock().unwrap();
114            if st.state == ShimState::Dead {
115                return;
116            }
117            let output_bytes = st.cumulative_output_bytes;
118            let uptime_secs = st.started_at.elapsed().as_secs();
119            drop(st);
120
121            if stats_channel
122                .send(&Event::SessionStats {
123                    output_bytes,
124                    uptime_secs,
125                    input_tokens: 0,
126                    output_tokens: 0,
127                    context_usage_pct: None,
128                })
129                .is_err()
130            {
131                return;
132            }
133        }
134    });
135
136    // Command loop (main thread)
137    let state_cmd = Arc::clone(&state);
138    let shim_id = args.id.clone();
139    loop {
140        let cmd = match cmd_channel.recv::<ShimCommand>() {
141            Ok(Some(c)) => c,
142            Ok(None) => {
143                eprintln!("[shim-codex {shim_id}] orchestrator disconnected");
144                break;
145            }
146            Err(e) => {
147                eprintln!("[shim-codex {shim_id}] channel error: {e}");
148                break;
149            }
150        };
151
152        match cmd {
153            ShimCommand::SendMessage {
154                from,
155                body,
156                message_id,
157            } => {
158                let delivery_id = message_id.clone();
159                let mut st = state_cmd.lock().unwrap();
160                match st.state {
161                    ShimState::Idle => {
162                        st.pending_message_id = message_id;
163                        st.accumulated_response.clear();
164                        st.state = ShimState::Working;
165                        st.state_changed_at = Instant::now();
166                        let thread_id = st.thread_id.clone();
167                        let program = st.program.clone();
168                        let cwd = st.cwd.clone();
169                        drop(st);
170
171                        cmd_channel.send(&Event::StateChanged {
172                            from: ShimState::Idle,
173                            to: ShimState::Working,
174                            summary: String::new(),
175                        })?;
176                        if let Some(id) = delivery_id {
177                            cmd_channel.send(&Event::MessageDelivered { id })?;
178                        }
179
180                        // Spawn codex exec subprocess for this message
181                        let text = format_injected_message(&from, &body);
182                        let (exec_program, exec_args) =
183                            codex_types::codex_sdk_args(&program, thread_id.as_deref());
184
185                        let mut evt_channel = cmd_channel
186                            .try_clone()
187                            .context("failed to clone channel for codex exec")?;
188                        let state_exec = Arc::clone(&state_cmd);
189                        let pty_log_exec = pty_log.clone();
190                        let shim_id_exec = shim_id.clone();
191
192                        // Run the codex exec subprocess in a background thread
193                        thread::spawn(move || {
194                            run_codex_exec(
195                                &shim_id_exec,
196                                &exec_program,
197                                &exec_args,
198                                &text,
199                                &cwd,
200                                &state_exec,
201                                &mut evt_channel,
202                                pty_log_exec.as_ref(),
203                            );
204                        });
205                    }
206                    ShimState::Working => {
207                        // Queue the message
208                        if st.message_queue.len() >= MAX_QUEUE_DEPTH {
209                            let dropped = st.message_queue.pop_front();
210                            let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
211                            st.message_queue.push_back(QueuedMessage {
212                                from,
213                                body,
214                                message_id,
215                            });
216                            let depth = st.message_queue.len();
217                            drop(st);
218
219                            cmd_channel.send(&Event::Error {
220                                command: "SendMessage".into(),
221                                reason: format!(
222                                    "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
223                                    dropped_id
224                                        .map(|id| format!(" (id: {id})"))
225                                        .unwrap_or_default(),
226                                ),
227                            })?;
228                            cmd_channel.send(&Event::Warning {
229                                message: format!(
230                                    "message queued while agent working (depth: {depth})"
231                                ),
232                                idle_secs: None,
233                            })?;
234                        } else {
235                            st.message_queue.push_back(QueuedMessage {
236                                from,
237                                body,
238                                message_id,
239                            });
240                            let depth = st.message_queue.len();
241                            drop(st);
242
243                            cmd_channel.send(&Event::Warning {
244                                message: format!(
245                                    "message queued while agent working (depth: {depth})"
246                                ),
247                                idle_secs: None,
248                            })?;
249                        }
250                    }
251                    other => {
252                        drop(st);
253                        cmd_channel.send(&Event::Error {
254                            command: "SendMessage".into(),
255                            reason: format!("agent in {other} state, cannot accept message"),
256                        })?;
257                    }
258                }
259            }
260
261            ShimCommand::CaptureScreen { last_n_lines } => {
262                let st = state_cmd.lock().unwrap();
263                let content = match last_n_lines {
264                    Some(n) => last_n_lines_of(&st.accumulated_response, n),
265                    None => st.accumulated_response.clone(),
266                };
267                drop(st);
268                cmd_channel.send(&Event::ScreenCapture {
269                    content,
270                    cursor_row: 0,
271                    cursor_col: 0,
272                })?;
273            }
274
275            ShimCommand::GetState => {
276                let st = state_cmd.lock().unwrap();
277                let since = st.state_changed_at.elapsed().as_secs();
278                let state = st.state;
279                drop(st);
280                cmd_channel.send(&Event::State {
281                    state,
282                    since_secs: since,
283                })?;
284            }
285
286            ShimCommand::Resize { .. } => {
287                // No-op — no PTY.
288            }
289
290            ShimCommand::Ping => {
291                cmd_channel.send(&Event::Pong)?;
292            }
293
294            ShimCommand::Shutdown { reason, .. } => {
295                eprintln!(
296                    "[shim-codex {shim_id}] shutdown requested ({})",
297                    reason.label()
298                );
299                if let Err(error) = super::runtime::preserve_work_before_kill(&args.cwd) {
300                    eprintln!(
301                        "[shim-codex {shim_id}] failed to preserve work before shutdown: {error:#}"
302                    );
303                }
304                let mut st = state_cmd.lock().unwrap();
305                st.state = ShimState::Dead;
306                st.state_changed_at = Instant::now();
307                drop(st);
308                break;
309            }
310
311            ShimCommand::Kill => {
312                if let Err(error) = super::runtime::preserve_work_before_kill(&args.cwd) {
313                    eprintln!(
314                        "[shim-codex {shim_id}] failed to preserve work before kill: {error:#}"
315                    );
316                }
317                let mut st = state_cmd.lock().unwrap();
318                st.state = ShimState::Dead;
319                st.state_changed_at = Instant::now();
320                drop(st);
321                break;
322            }
323        }
324    }
325
326    Ok(())
327}
328
329// ---------------------------------------------------------------------------
330// Per-message subprocess execution
331// ---------------------------------------------------------------------------
332
333/// Spawn `codex exec --json ...` and process its JSONL output.
334/// When the subprocess exits, transition back to Idle and drain the queue.
335#[allow(clippy::too_many_arguments)]
336fn run_codex_exec(
337    shim_id: &str,
338    program: &str,
339    args: &[String],
340    prompt: &str,
341    cwd: &std::path::Path,
342    state: &Arc<Mutex<CodexState>>,
343    evt_channel: &mut Channel,
344    pty_log: Option<&Arc<Mutex<PtyLogWriter>>>,
345) {
346    // Spawn the subprocess
347    let mut child = match Command::new(program)
348        .args(args)
349        .current_dir(cwd)
350        .stdin(Stdio::piped())
351        .stdout(Stdio::piped())
352        .stderr(Stdio::piped())
353        .env_remove("CLAUDECODE")
354        .spawn()
355    {
356        Ok(c) => c,
357        Err(e) => {
358            eprintln!("[shim-codex {shim_id}] failed to spawn codex exec: {e}");
359            let mut st = state.lock().unwrap();
360            let msg_id = st.pending_message_id.take();
361            st.state = ShimState::Idle;
362            st.state_changed_at = Instant::now();
363            drop(st);
364            let _ = evt_channel.send(&Event::Error {
365                command: "SendMessage".into(),
366                reason: format!("codex exec spawn failed: {e}"),
367            });
368            let _ = evt_channel.send(&Event::StateChanged {
369                from: ShimState::Working,
370                to: ShimState::Idle,
371                summary: format!("spawn failed: {e}"),
372            });
373            let _ = evt_channel.send(&Event::Completion {
374                message_id: msg_id,
375                response: String::new(),
376                last_lines: format!("spawn failed: {e}"),
377            });
378            return;
379        }
380    };
381
382    let child_pid = child.id();
383    eprintln!("[shim-codex {shim_id}] codex exec spawned (pid {child_pid})");
384
385    if let Some(mut stdin) = child.stdin.take() {
386        if let Err(e) = stdin.write_all(prompt.as_bytes()) {
387            eprintln!("[shim-codex {shim_id}] failed to write prompt to stdin: {e}");
388        }
389    }
390
391    let stdout = child.stdout.take().unwrap();
392    let stderr = child.stderr.take().unwrap();
393
394    // stderr reader (log only)
395    let shim_id_err = shim_id.to_string();
396    let pty_log_err = pty_log.map(Arc::clone);
397    thread::spawn(move || {
398        let reader = BufReader::new(stderr);
399        for line_result in reader.lines() {
400            match line_result {
401                Ok(line) => {
402                    eprintln!("[shim-codex {shim_id_err}] stderr: {line}");
403                    if let Some(ref log) = pty_log_err {
404                        let _ = log
405                            .lock()
406                            .unwrap()
407                            .write(format!("[stderr] {line}\n").as_bytes());
408                    }
409                }
410                Err(_) => break,
411            }
412        }
413    });
414
415    // stdout JSONL reader
416    let reader = BufReader::new(stdout);
417    for line_result in reader.lines() {
418        let line = match line_result {
419            Ok(l) => l,
420            Err(e) => {
421                eprintln!("[shim-codex {shim_id}] stdout read error: {e}");
422                break;
423            }
424        };
425
426        if line.trim().is_empty() {
427            continue;
428        }
429
430        let evt: CodexEvent = match serde_json::from_str(&line) {
431            Ok(e) => e,
432            Err(e) => {
433                eprintln!("[shim-codex {shim_id}] ignoring unparseable JSONL: {e}");
434                continue;
435            }
436        };
437
438        match evt.event_type.as_str() {
439            "thread.started" => {
440                if let Some(tid) = evt.thread_id {
441                    let mut st = state.lock().unwrap();
442                    st.thread_id = Some(tid.clone());
443                    eprintln!("[shim-codex {shim_id}] thread started: {tid}");
444                }
445            }
446
447            "item.completed" | "item.updated" => {
448                if let Some(ref item) = evt.item {
449                    if let Some(text) = item.agent_text() {
450                        if !text.is_empty() {
451                            let mut st = state.lock().unwrap();
452                            // Replace accumulated response on each complete agent_message
453                            // (Codex sends the full text each time, not deltas)
454                            if evt.event_type == "item.completed" {
455                                st.accumulated_response = text.to_string();
456                            }
457                            st.cumulative_output_bytes += text.len() as u64;
458
459                            // Check for context approaching limit (proactive).
460                            if !st.context_approaching_emitted
461                                && common::detect_context_approaching_limit(text)
462                            {
463                                st.context_approaching_emitted = true;
464                                drop(st);
465                                let _ = evt_channel.send(&Event::ContextApproaching {
466                                    message: "Agent output contains context-pressure signals"
467                                        .into(),
468                                    input_tokens: 0,
469                                    output_tokens: 0,
470                                });
471                            } else {
472                                drop(st);
473                            }
474
475                            if let Some(log) = pty_log {
476                                let _ = log.lock().unwrap().write(text.as_bytes());
477                                let _ = log.lock().unwrap().write(b"\n");
478                            }
479                        }
480                    }
481                }
482            }
483
484            "turn.failed" => {
485                let error_msg = evt
486                    .error
487                    .as_ref()
488                    .map(|e| e.message.clone())
489                    .unwrap_or_else(|| "unknown error".to_string());
490                eprintln!("[shim-codex {shim_id}] turn failed: {error_msg}");
491
492                // Check for context exhaustion
493                if common::detect_context_exhausted(&error_msg) {
494                    let mut st = state.lock().unwrap();
495                    let last_lines = last_n_lines_of(&st.accumulated_response, 5);
496                    st.state = ShimState::ContextExhausted;
497                    st.state_changed_at = Instant::now();
498                    let drain =
499                        drain_queue_errors(&mut st.message_queue, ShimState::ContextExhausted);
500                    drop(st);
501
502                    let _ = evt_channel.send(&Event::StateChanged {
503                        from: ShimState::Working,
504                        to: ShimState::ContextExhausted,
505                        summary: last_lines.clone(),
506                    });
507                    let _ = evt_channel.send(&Event::ContextExhausted {
508                        message: error_msg,
509                        last_lines,
510                    });
511                    for event in drain {
512                        let _ = evt_channel.send(&event);
513                    }
514                    return;
515                }
516            }
517
518            "error" => {
519                let error_msg = evt
520                    .error
521                    .as_ref()
522                    .map(|e| e.message.clone())
523                    .unwrap_or_else(|| "stream error".to_string());
524                eprintln!("[shim-codex {shim_id}] error event: {error_msg}");
525
526                // Detect quota/billing exhaustion — emit a specific event so
527                // the daemon can pause dispatch and alert the human.
528                let lower = error_msg.to_ascii_lowercase();
529                if lower.contains("usage limit")
530                    || lower.contains("quota")
531                    || lower.contains("billing")
532                    || lower.contains("purchase more credits")
533                {
534                    eprintln!("[shim-codex {shim_id}] QUOTA EXHAUSTED: {error_msg}");
535                    let _ = evt_channel.send(&Event::Error {
536                        command: "QuotaExhausted".into(),
537                        reason: error_msg.clone(),
538                    });
539                }
540            }
541
542            // turn.started, turn.completed, item.started — informational, no action
543            _ => {}
544        }
545    }
546
547    // stdout closed — subprocess finished. Wait for exit.
548    let exit_code = child.wait().ok().and_then(|s| s.code());
549    eprintln!("[shim-codex {shim_id}] codex exec exited (code: {exit_code:?})");
550
551    // Transition Working → Idle, emit Completion
552    let mut st = state.lock().unwrap();
553    let response = std::mem::take(&mut st.accumulated_response);
554    let last_lines = last_n_lines_of(&response, 5);
555    let msg_id = st.pending_message_id.take();
556    st.state = ShimState::Idle;
557    st.state_changed_at = Instant::now();
558
559    // Check for queued messages
560    let queued_msg = if !st.message_queue.is_empty() {
561        st.message_queue.pop_front()
562    } else {
563        None
564    };
565
566    if let Some(ref qm) = queued_msg {
567        st.pending_message_id = qm.message_id.clone();
568        st.state = ShimState::Working;
569        st.state_changed_at = Instant::now();
570        st.accumulated_response.clear();
571    }
572
573    let thread_id = st.thread_id.clone();
574    let program = st.program.clone();
575    let cwd_owned = st.cwd.clone();
576    let queue_depth = st.message_queue.len();
577    drop(st);
578
579    let _ = evt_channel.send(&Event::StateChanged {
580        from: ShimState::Working,
581        to: ShimState::Idle,
582        summary: last_lines.clone(),
583    });
584    let _ = evt_channel.send(&Event::Completion {
585        message_id: msg_id,
586        response,
587        last_lines,
588    });
589
590    // Drain queued message by spawning another codex exec
591    if let Some(qm) = queued_msg {
592        let _ = evt_channel.send(&Event::StateChanged {
593            from: ShimState::Idle,
594            to: ShimState::Working,
595            summary: format!("delivering queued message ({queue_depth} remaining)"),
596        });
597
598        let text = format_injected_message(&qm.from, &qm.body);
599        let (exec_program, exec_args) = codex_types::codex_sdk_args(&program, thread_id.as_deref());
600
601        // Recursive call for queued message (same thread)
602        run_codex_exec(
603            shim_id,
604            &exec_program,
605            &exec_args,
606            &text,
607            &cwd_owned,
608            state,
609            evt_channel,
610            pty_log,
611        );
612    }
613}
614
615// ---------------------------------------------------------------------------
616// Helpers
617// ---------------------------------------------------------------------------
618
619/// Terminate a child process: SIGTERM, grace period, then SIGKILL.
620#[allow(dead_code)]
621fn terminate_child(child: &mut Child) {
622    let pid = child.id();
623    #[cfg(unix)]
624    {
625        unsafe {
626            libc::kill(pid as i32, libc::SIGTERM);
627        }
628        let deadline = Instant::now() + Duration::from_secs(GROUP_TERM_GRACE_SECS);
629        loop {
630            if Instant::now() > deadline {
631                break;
632            }
633            match child.try_wait() {
634                Ok(Some(_)) => return,
635                _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
636            }
637        }
638        unsafe {
639            libc::kill(pid as i32, libc::SIGKILL);
640        }
641    }
642    #[allow(unreachable_code)]
643    {
644        let _ = child.kill();
645    }
646}
647
/// Return the last `n` lines of `text`, joined with `\n`.
///
/// Lines are split with [`str::lines`], so `\n` and `\r\n` terminators are
/// stripped and a trailing terminator does not produce an empty final line.
/// Returns an empty string when `n == 0` or `text` is empty. Lines are walked
/// from the end, so only the requested tail is collected rather than every
/// line of a potentially long response.
fn last_n_lines_of(text: &str, n: usize) -> String {
    let mut tail: Vec<&str> = text.lines().rev().take(n).collect();
    tail.reverse();
    tail.join("\n")
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::shim::protocol;

    #[test]
    fn last_n_lines_basic() {
        assert_eq!(last_n_lines_of("a\nb\nc", 2), "b\nc");
        assert_eq!(last_n_lines_of("a\nb\nc", 10), "a\nb\nc");
        assert_eq!(last_n_lines_of("", 5), "");
    }

    #[test]
    fn last_n_lines_edge_cases() {
        // Asking for zero lines yields nothing.
        assert_eq!(last_n_lines_of("a\nb", 0), "");
        // A trailing newline does not produce an empty final line.
        assert_eq!(last_n_lines_of("a\nb\n", 1), "b");
        // CRLF terminators are stripped; the result is rejoined with `\n`.
        assert_eq!(last_n_lines_of("x\r\ny\r\nz", 2), "y\nz");
        // Interior blank lines are preserved.
        assert_eq!(last_n_lines_of("a\n\nb", 2), "\nb");
    }

    #[test]
    fn codex_state_initial() {
        let st = CodexState {
            state: ShimState::Idle,
            state_changed_at: Instant::now(),
            started_at: Instant::now(),
            thread_id: None,
            accumulated_response: String::new(),
            pending_message_id: None,
            message_queue: VecDeque::new(),
            cumulative_output_bytes: 0,
            program: "codex".into(),
            cwd: std::path::PathBuf::from("/tmp"),
            context_approaching_emitted: false,
        };
        assert_eq!(st.state, ShimState::Idle);
        assert!(st.thread_id.is_none());
    }

    #[test]
    fn channel_events_roundtrip() {
        let (parent_sock, child_sock) = protocol::socketpair().unwrap();
        let mut parent = protocol::Channel::new(parent_sock);
        let mut child = protocol::Channel::new(child_sock);

        child.send(&Event::Ready).unwrap();
        let event: Event = parent.recv().unwrap().unwrap();
        assert!(matches!(event, Event::Ready));
    }
}