Skip to main content

batty_cli/shim/
runtime_codex.rs

//! Codex SDK-mode shim runtime: communicates with Codex via JSONL events
//! from `codex exec --json` instead of screen-scraping a PTY.
//!
//! Unlike the Claude SDK runtime (a persistent subprocess fed NDJSON on stdin),
//! Codex uses a **spawn-per-message** model: each `SendMessage` launches a
//! new `codex exec --json` subprocess. Multi-turn context is preserved via
//! `codex exec resume <thread_id>`.
//!
//! Emits the same `Command`/`Event` protocol to the orchestrator as the
//! PTY runtime, making it transparent to all upstream consumers.

12use std::collections::VecDeque;
13use std::io::{BufRead, BufReader};
14use std::process::{Child, Command, Stdio};
15use std::sync::{Arc, Mutex};
16use std::thread;
17use std::time::{Duration, Instant};
18
19use anyhow::{Context, Result};
20
21use super::codex_types::{self, CodexEvent};
22use super::common::{
23    self, MAX_QUEUE_DEPTH, QueuedMessage, SESSION_STATS_INTERVAL_SECS, drain_queue_errors,
24    format_injected_message,
25};
26use super::protocol::{Channel, Command as ShimCommand, Event, ShimState};
27use super::pty_log::PtyLogWriter;
28use super::runtime::ShimArgs;
29
30// ---------------------------------------------------------------------------
31// Configuration
32// ---------------------------------------------------------------------------
33
/// How often (in milliseconds) `terminate_child` polls `try_wait` while a
/// signalled child is shutting down.
const PROCESS_EXIT_POLL_MS: u64 = 100;

/// How long (in seconds) a child gets to exit after SIGTERM before it is
/// killed outright.
const GROUP_TERM_GRACE_SECS: u64 = 2;
36
37// ---------------------------------------------------------------------------
38// Shared state
39// ---------------------------------------------------------------------------
40
41struct CodexState {
42    state: ShimState,
43    state_changed_at: Instant,
44    started_at: Instant,
45    /// Thread ID from the first `thread.started` event, used for resume.
46    thread_id: Option<String>,
47    /// Accumulated agent response text for the current turn.
48    accumulated_response: String,
49    /// Message ID of the currently pending (in-flight) message.
50    pending_message_id: Option<String>,
51    /// Messages queued while the agent is in Working state.
52    message_queue: VecDeque<QueuedMessage>,
53    /// Total bytes of response text received.
54    cumulative_output_bytes: u64,
55    /// The codex binary name/path.
56    program: String,
57    /// Working directory for spawning subprocesses.
58    cwd: std::path::PathBuf,
59}
60
61// ---------------------------------------------------------------------------
62// Main entry point
63// ---------------------------------------------------------------------------
64
65/// Run the Codex SDK-mode shim. This function does not return until the shim exits.
66///
67/// Unlike the Claude SDK runtime, each message spawns a new `codex exec --json`
68/// subprocess. The shim manages the lifecycle per-message and emits the same
69/// `Command`/`Event` protocol to the orchestrator.
70pub fn run_codex_sdk(args: ShimArgs, channel: Channel) -> Result<()> {
71    eprintln!("[shim-codex {}] started (spawn-per-message mode)", args.id);
72
73    // Shared state
74    let state = Arc::new(Mutex::new(CodexState {
75        state: ShimState::Idle,
76        state_changed_at: Instant::now(),
77        started_at: Instant::now(),
78        thread_id: None,
79        accumulated_response: String::new(),
80        pending_message_id: None,
81        message_queue: VecDeque::new(),
82        cumulative_output_bytes: 0,
83        program: "codex".to_string(),
84        cwd: args.cwd.clone(),
85    }));
86
87    // PTY log writer (optional — writes readable text for tmux display)
88    let pty_log: Option<Arc<Mutex<PtyLogWriter>>> = args
89        .pty_log_path
90        .as_deref()
91        .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
92        .transpose()?
93        .map(|w| Arc::new(Mutex::new(w)));
94
95    // Channel clones
96    let mut cmd_channel = channel;
97
98    // Emit Ready immediately — no persistent subprocess to wait for.
99    cmd_channel.send(&Event::Ready)?;
100
101    // Session stats thread
102    let state_stats = Arc::clone(&state);
103    let mut stats_channel = cmd_channel
104        .try_clone()
105        .context("failed to clone channel for stats")?;
106    thread::spawn(move || {
107        loop {
108            thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
109            let st = state_stats.lock().unwrap();
110            if st.state == ShimState::Dead {
111                return;
112            }
113            let output_bytes = st.cumulative_output_bytes;
114            let uptime_secs = st.started_at.elapsed().as_secs();
115            drop(st);
116
117            if stats_channel
118                .send(&Event::SessionStats {
119                    output_bytes,
120                    uptime_secs,
121                })
122                .is_err()
123            {
124                return;
125            }
126        }
127    });
128
129    // Command loop (main thread)
130    let state_cmd = Arc::clone(&state);
131    let shim_id = args.id.clone();
132    loop {
133        let cmd = match cmd_channel.recv::<ShimCommand>() {
134            Ok(Some(c)) => c,
135            Ok(None) => {
136                eprintln!("[shim-codex {shim_id}] orchestrator disconnected");
137                break;
138            }
139            Err(e) => {
140                eprintln!("[shim-codex {shim_id}] channel error: {e}");
141                break;
142            }
143        };
144
145        match cmd {
146            ShimCommand::SendMessage {
147                from,
148                body,
149                message_id,
150            } => {
151                let mut st = state_cmd.lock().unwrap();
152                match st.state {
153                    ShimState::Idle => {
154                        st.pending_message_id = message_id;
155                        st.accumulated_response.clear();
156                        st.state = ShimState::Working;
157                        st.state_changed_at = Instant::now();
158                        let thread_id = st.thread_id.clone();
159                        let program = st.program.clone();
160                        let cwd = st.cwd.clone();
161                        drop(st);
162
163                        cmd_channel.send(&Event::StateChanged {
164                            from: ShimState::Idle,
165                            to: ShimState::Working,
166                            summary: String::new(),
167                        })?;
168
169                        // Spawn codex exec subprocess for this message
170                        let text = format_injected_message(&from, &body);
171                        let exec_cmd =
172                            codex_types::codex_sdk_command(&program, &text, thread_id.as_deref());
173
174                        let mut evt_channel = cmd_channel
175                            .try_clone()
176                            .context("failed to clone channel for codex exec")?;
177                        let state_exec = Arc::clone(&state_cmd);
178                        let pty_log_exec = pty_log.clone();
179                        let shim_id_exec = shim_id.clone();
180
181                        // Run the codex exec subprocess in a background thread
182                        thread::spawn(move || {
183                            run_codex_exec(
184                                &shim_id_exec,
185                                &exec_cmd,
186                                &cwd,
187                                &state_exec,
188                                &mut evt_channel,
189                                pty_log_exec.as_ref(),
190                            );
191                        });
192                    }
193                    ShimState::Working => {
194                        // Queue the message
195                        if st.message_queue.len() >= MAX_QUEUE_DEPTH {
196                            let dropped = st.message_queue.pop_front();
197                            let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
198                            st.message_queue.push_back(QueuedMessage {
199                                from,
200                                body,
201                                message_id,
202                            });
203                            let depth = st.message_queue.len();
204                            drop(st);
205
206                            cmd_channel.send(&Event::Error {
207                                command: "SendMessage".into(),
208                                reason: format!(
209                                    "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
210                                    dropped_id
211                                        .map(|id| format!(" (id: {id})"))
212                                        .unwrap_or_default(),
213                                ),
214                            })?;
215                            cmd_channel.send(&Event::Warning {
216                                message: format!(
217                                    "message queued while agent working (depth: {depth})"
218                                ),
219                                idle_secs: None,
220                            })?;
221                        } else {
222                            st.message_queue.push_back(QueuedMessage {
223                                from,
224                                body,
225                                message_id,
226                            });
227                            let depth = st.message_queue.len();
228                            drop(st);
229
230                            cmd_channel.send(&Event::Warning {
231                                message: format!(
232                                    "message queued while agent working (depth: {depth})"
233                                ),
234                                idle_secs: None,
235                            })?;
236                        }
237                    }
238                    other => {
239                        drop(st);
240                        cmd_channel.send(&Event::Error {
241                            command: "SendMessage".into(),
242                            reason: format!("agent in {other} state, cannot accept message"),
243                        })?;
244                    }
245                }
246            }
247
248            ShimCommand::CaptureScreen { last_n_lines } => {
249                let st = state_cmd.lock().unwrap();
250                let content = match last_n_lines {
251                    Some(n) => last_n_lines_of(&st.accumulated_response, n),
252                    None => st.accumulated_response.clone(),
253                };
254                drop(st);
255                cmd_channel.send(&Event::ScreenCapture {
256                    content,
257                    cursor_row: 0,
258                    cursor_col: 0,
259                })?;
260            }
261
262            ShimCommand::GetState => {
263                let st = state_cmd.lock().unwrap();
264                let since = st.state_changed_at.elapsed().as_secs();
265                let state = st.state;
266                drop(st);
267                cmd_channel.send(&Event::State {
268                    state,
269                    since_secs: since,
270                })?;
271            }
272
273            ShimCommand::Resize { .. } => {
274                // No-op — no PTY.
275            }
276
277            ShimCommand::Ping => {
278                cmd_channel.send(&Event::Pong)?;
279            }
280
281            ShimCommand::Shutdown { .. } => {
282                eprintln!("[shim-codex {shim_id}] shutdown requested");
283                let mut st = state_cmd.lock().unwrap();
284                st.state = ShimState::Dead;
285                st.state_changed_at = Instant::now();
286                drop(st);
287                break;
288            }
289
290            ShimCommand::Kill => {
291                let mut st = state_cmd.lock().unwrap();
292                st.state = ShimState::Dead;
293                st.state_changed_at = Instant::now();
294                drop(st);
295                break;
296            }
297        }
298    }
299
300    Ok(())
301}
302
303// ---------------------------------------------------------------------------
304// Per-message subprocess execution
305// ---------------------------------------------------------------------------
306
307/// Spawn `codex exec --json ...` and process its JSONL output.
308/// When the subprocess exits, transition back to Idle and drain the queue.
309fn run_codex_exec(
310    shim_id: &str,
311    exec_cmd: &str,
312    cwd: &std::path::Path,
313    state: &Arc<Mutex<CodexState>>,
314    evt_channel: &mut Channel,
315    pty_log: Option<&Arc<Mutex<PtyLogWriter>>>,
316) {
317    // Spawn the subprocess
318    let mut child = match Command::new("bash")
319        .args(["-lc", exec_cmd])
320        .current_dir(cwd)
321        .stdin(Stdio::null())
322        .stdout(Stdio::piped())
323        .stderr(Stdio::piped())
324        .env_remove("CLAUDECODE")
325        .spawn()
326    {
327        Ok(c) => c,
328        Err(e) => {
329            eprintln!("[shim-codex {shim_id}] failed to spawn codex exec: {e}");
330            let mut st = state.lock().unwrap();
331            let msg_id = st.pending_message_id.take();
332            st.state = ShimState::Idle;
333            st.state_changed_at = Instant::now();
334            drop(st);
335            let _ = evt_channel.send(&Event::Error {
336                command: "SendMessage".into(),
337                reason: format!("codex exec spawn failed: {e}"),
338            });
339            let _ = evt_channel.send(&Event::StateChanged {
340                from: ShimState::Working,
341                to: ShimState::Idle,
342                summary: format!("spawn failed: {e}"),
343            });
344            let _ = evt_channel.send(&Event::Completion {
345                message_id: msg_id,
346                response: String::new(),
347                last_lines: format!("spawn failed: {e}"),
348            });
349            return;
350        }
351    };
352
353    let child_pid = child.id();
354    eprintln!("[shim-codex {shim_id}] codex exec spawned (pid {child_pid})");
355
356    let stdout = child.stdout.take().unwrap();
357    let stderr = child.stderr.take().unwrap();
358
359    // stderr reader (log only)
360    let shim_id_err = shim_id.to_string();
361    let pty_log_err = pty_log.map(Arc::clone);
362    thread::spawn(move || {
363        let reader = BufReader::new(stderr);
364        for line_result in reader.lines() {
365            match line_result {
366                Ok(line) => {
367                    eprintln!("[shim-codex {shim_id_err}] stderr: {line}");
368                    if let Some(ref log) = pty_log_err {
369                        let _ = log
370                            .lock()
371                            .unwrap()
372                            .write(format!("[stderr] {line}\n").as_bytes());
373                    }
374                }
375                Err(_) => break,
376            }
377        }
378    });
379
380    // stdout JSONL reader
381    let reader = BufReader::new(stdout);
382    for line_result in reader.lines() {
383        let line = match line_result {
384            Ok(l) => l,
385            Err(e) => {
386                eprintln!("[shim-codex {shim_id}] stdout read error: {e}");
387                break;
388            }
389        };
390
391        if line.trim().is_empty() {
392            continue;
393        }
394
395        let evt: CodexEvent = match serde_json::from_str(&line) {
396            Ok(e) => e,
397            Err(e) => {
398                eprintln!("[shim-codex {shim_id}] ignoring unparseable JSONL: {e}");
399                continue;
400            }
401        };
402
403        match evt.event_type.as_str() {
404            "thread.started" => {
405                if let Some(tid) = evt.thread_id {
406                    let mut st = state.lock().unwrap();
407                    st.thread_id = Some(tid.clone());
408                    eprintln!("[shim-codex {shim_id}] thread started: {tid}");
409                }
410            }
411
412            "item.completed" | "item.updated" => {
413                if let Some(ref item) = evt.item {
414                    if let Some(text) = item.agent_text() {
415                        if !text.is_empty() {
416                            let mut st = state.lock().unwrap();
417                            // Replace accumulated response on each complete agent_message
418                            // (Codex sends the full text each time, not deltas)
419                            if evt.event_type == "item.completed" {
420                                st.accumulated_response = text.to_string();
421                            }
422                            st.cumulative_output_bytes += text.len() as u64;
423                            drop(st);
424
425                            if let Some(log) = pty_log {
426                                let _ = log.lock().unwrap().write(text.as_bytes());
427                                let _ = log.lock().unwrap().write(b"\n");
428                            }
429                        }
430                    }
431                }
432            }
433
434            "turn.failed" => {
435                let error_msg = evt
436                    .error
437                    .as_ref()
438                    .map(|e| e.message.clone())
439                    .unwrap_or_else(|| "unknown error".to_string());
440                eprintln!("[shim-codex {shim_id}] turn failed: {error_msg}");
441
442                // Check for context exhaustion
443                if common::detect_context_exhausted(&error_msg) {
444                    let mut st = state.lock().unwrap();
445                    let last_lines = last_n_lines_of(&st.accumulated_response, 5);
446                    st.state = ShimState::ContextExhausted;
447                    st.state_changed_at = Instant::now();
448                    let drain =
449                        drain_queue_errors(&mut st.message_queue, ShimState::ContextExhausted);
450                    drop(st);
451
452                    let _ = evt_channel.send(&Event::StateChanged {
453                        from: ShimState::Working,
454                        to: ShimState::ContextExhausted,
455                        summary: last_lines.clone(),
456                    });
457                    let _ = evt_channel.send(&Event::ContextExhausted {
458                        message: error_msg,
459                        last_lines,
460                    });
461                    for event in drain {
462                        let _ = evt_channel.send(&event);
463                    }
464                    return;
465                }
466            }
467
468            "error" => {
469                let error_msg = evt
470                    .error
471                    .as_ref()
472                    .map(|e| e.message.clone())
473                    .unwrap_or_else(|| "stream error".to_string());
474                eprintln!("[shim-codex {shim_id}] error event: {error_msg}");
475            }
476
477            // turn.started, turn.completed, item.started — informational, no action
478            _ => {}
479        }
480    }
481
482    // stdout closed — subprocess finished. Wait for exit.
483    let exit_code = child.wait().ok().and_then(|s| s.code());
484    eprintln!("[shim-codex {shim_id}] codex exec exited (code: {exit_code:?})");
485
486    // Transition Working → Idle, emit Completion
487    let mut st = state.lock().unwrap();
488    let response = std::mem::take(&mut st.accumulated_response);
489    let last_lines = last_n_lines_of(&response, 5);
490    let msg_id = st.pending_message_id.take();
491    st.state = ShimState::Idle;
492    st.state_changed_at = Instant::now();
493
494    // Check for queued messages
495    let queued_msg = if !st.message_queue.is_empty() {
496        st.message_queue.pop_front()
497    } else {
498        None
499    };
500
501    if let Some(ref qm) = queued_msg {
502        st.pending_message_id = qm.message_id.clone();
503        st.state = ShimState::Working;
504        st.state_changed_at = Instant::now();
505        st.accumulated_response.clear();
506    }
507
508    let thread_id = st.thread_id.clone();
509    let program = st.program.clone();
510    let cwd_owned = st.cwd.clone();
511    let queue_depth = st.message_queue.len();
512    drop(st);
513
514    let _ = evt_channel.send(&Event::StateChanged {
515        from: ShimState::Working,
516        to: ShimState::Idle,
517        summary: last_lines.clone(),
518    });
519    let _ = evt_channel.send(&Event::Completion {
520        message_id: msg_id,
521        response,
522        last_lines,
523    });
524
525    // Drain queued message by spawning another codex exec
526    if let Some(qm) = queued_msg {
527        let _ = evt_channel.send(&Event::StateChanged {
528            from: ShimState::Idle,
529            to: ShimState::Working,
530            summary: format!("delivering queued message ({queue_depth} remaining)"),
531        });
532
533        let text = format_injected_message(&qm.from, &qm.body);
534        let exec_cmd = codex_types::codex_sdk_command(&program, &text, thread_id.as_deref());
535
536        // Recursive call for queued message (same thread)
537        run_codex_exec(shim_id, &exec_cmd, &cwd_owned, state, evt_channel, pty_log);
538    }
539}
540
541// ---------------------------------------------------------------------------
542// Helpers
543// ---------------------------------------------------------------------------
544
545/// Terminate a child process: SIGTERM, grace period, then SIGKILL.
546#[allow(dead_code)]
547fn terminate_child(child: &mut Child) {
548    let pid = child.id();
549    #[cfg(unix)]
550    {
551        unsafe {
552            libc::kill(pid as i32, libc::SIGTERM);
553        }
554        let deadline = Instant::now() + Duration::from_secs(GROUP_TERM_GRACE_SECS);
555        loop {
556            if Instant::now() > deadline {
557                break;
558            }
559            match child.try_wait() {
560                Ok(Some(_)) => return,
561                _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
562            }
563        }
564        unsafe {
565            libc::kill(pid as i32, libc::SIGKILL);
566        }
567    }
568    #[allow(unreachable_code)]
569    {
570        let _ = child.kill();
571    }
572}
573
/// Return the final `n` lines of `text`, joined with `\n`.
///
/// Lines are split with [`str::lines`], so `\r\n` endings and a single trailing
/// newline do not appear in the result. When `text` has `n` lines or fewer,
/// all of them are returned; `n == 0` yields an empty string.
fn last_n_lines_of(text: &str, n: usize) -> String {
    let total = text.lines().count();
    let skip = total.saturating_sub(n);
    text.lines().skip(skip).collect::<Vec<&str>>().join("\n")
}
580
581// ---------------------------------------------------------------------------
582// Tests
583// ---------------------------------------------------------------------------
584
/// Unit tests for the Codex shim helpers and shared state.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shim::protocol;

    #[test]
    fn last_n_lines_basic() {
        assert_eq!(last_n_lines_of("a\nb\nc", 2), "b\nc");
        assert_eq!(last_n_lines_of("a\nb\nc", 10), "a\nb\nc");
        assert_eq!(last_n_lines_of("", 5), "");
    }

    #[test]
    fn last_n_lines_edge_cases() {
        // n == 0 selects nothing.
        assert_eq!(last_n_lines_of("a\nb", 0), "");
        // A trailing newline does not produce an empty final line.
        assert_eq!(last_n_lines_of("a\nb\n", 1), "b");
        // CRLF line endings are stripped, not carried into the output.
        assert_eq!(last_n_lines_of("x\r\ny\r\nz", 2), "y\nz");
        // Interior blank lines are preserved.
        assert_eq!(last_n_lines_of("a\n\nb", 2), "\nb");
    }

    #[test]
    fn codex_state_initial() {
        let st = CodexState {
            state: ShimState::Idle,
            state_changed_at: Instant::now(),
            started_at: Instant::now(),
            thread_id: None,
            accumulated_response: String::new(),
            pending_message_id: None,
            message_queue: VecDeque::new(),
            cumulative_output_bytes: 0,
            program: "codex".into(),
            cwd: std::path::PathBuf::from("/tmp"),
        };
        assert_eq!(st.state, ShimState::Idle);
        assert!(st.thread_id.is_none());
    }

    #[test]
    fn channel_events_roundtrip() {
        let (parent_sock, child_sock) = protocol::socketpair().unwrap();
        let mut parent = protocol::Channel::new(parent_sock);
        let mut child = protocol::Channel::new(child_sock);

        child.send(&Event::Ready).unwrap();
        let event: Event = parent.recv().unwrap().unwrap();
        assert!(matches!(event, Event::Ready));
    }
}