1use std::collections::VecDeque;
13use std::io::Write;
14use std::io::{BufRead, BufReader};
15use std::process::{Child, Command, Stdio};
16use std::sync::{Arc, Mutex};
17use std::thread;
18use std::time::{Duration, Instant};
19
20use anyhow::{Context, Result};
21
22use super::codex_types::{self, CodexEvent};
23use super::common::{
24 self, MAX_QUEUE_DEPTH, QueuedMessage, SESSION_STATS_INTERVAL_SECS, drain_queue_errors,
25 format_injected_message,
26};
27use super::protocol::{Channel, Command as ShimCommand, Event, ShimState};
28use super::pty_log::PtyLogWriter;
29use super::runtime::ShimArgs;
30
/// Sleep interval, in milliseconds, between `try_wait` polls while
/// `terminate_child` waits for a signalled child to exit.
const PROCESS_EXIT_POLL_MS: u64 = 100;
/// Seconds `terminate_child` keeps polling after sending `SIGTERM` before it
/// stops waiting and force-kills the child.
const GROUP_TERM_GRACE_SECS: u64 = 2;
37
/// Mutable shim state shared (behind `Arc<Mutex<_>>`) between the command
/// loop, the periodic stats thread and the per-message exec threads.
struct CodexState {
    /// Current lifecycle state reported to the orchestrator.
    state: ShimState,
    /// When `state` was last changed; used to answer `GetState` with `since_secs`.
    state_changed_at: Instant,
    /// When the shim state was created; used to report uptime in `SessionStats`.
    started_at: Instant,
    /// Thread id captured from the most recent `thread.started` event, passed
    /// to `codex_sdk_args` when building the next invocation.
    thread_id: Option<String>,
    /// Agent text of the latest `item.completed` event for the current turn;
    /// cleared when a new message starts and returned in `Completion` and
    /// `CaptureScreen` responses.
    accumulated_response: String,
    /// Id of the message currently being processed, if the sender supplied one.
    pending_message_id: Option<String>,
    /// Messages received while the agent was `Working`, delivered in FIFO order.
    message_queue: VecDeque<QueuedMessage>,
    /// Running total of agent text bytes seen in item events; reported in `SessionStats`.
    cumulative_output_bytes: u64,
    /// Program name handed to `codex_sdk_args` (initialised to `"codex"`).
    program: String,
    /// Working directory used for every spawned codex process.
    cwd: std::path::PathBuf,
    /// Set once a `ContextApproaching` event has been sent so it is emitted only once.
    context_approaching_emitted: bool,
}
63
64pub fn run_codex_sdk(args: ShimArgs, channel: Channel) -> Result<()> {
74 eprintln!("[shim-codex {}] started (spawn-per-message mode)", args.id);
75
76 let state = Arc::new(Mutex::new(CodexState {
78 state: ShimState::Idle,
79 state_changed_at: Instant::now(),
80 started_at: Instant::now(),
81 thread_id: None,
82 accumulated_response: String::new(),
83 pending_message_id: None,
84 message_queue: VecDeque::new(),
85 cumulative_output_bytes: 0,
86 program: "codex".to_string(),
87 cwd: args.cwd.clone(),
88 context_approaching_emitted: false,
89 }));
90
91 let pty_log: Option<Arc<Mutex<PtyLogWriter>>> = args
93 .pty_log_path
94 .as_deref()
95 .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
96 .transpose()?
97 .map(|w| Arc::new(Mutex::new(w)));
98
99 let mut cmd_channel = channel;
101
102 cmd_channel.send(&Event::Ready)?;
104
105 let state_stats = Arc::clone(&state);
107 let mut stats_channel = cmd_channel
108 .try_clone()
109 .context("failed to clone channel for stats")?;
110 thread::spawn(move || {
111 loop {
112 thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
113 let st = state_stats.lock().unwrap();
114 if st.state == ShimState::Dead {
115 return;
116 }
117 let output_bytes = st.cumulative_output_bytes;
118 let uptime_secs = st.started_at.elapsed().as_secs();
119 drop(st);
120
121 if stats_channel
122 .send(&Event::SessionStats {
123 output_bytes,
124 uptime_secs,
125 input_tokens: 0,
126 output_tokens: 0,
127 context_usage_pct: None,
128 })
129 .is_err()
130 {
131 return;
132 }
133 }
134 });
135
136 let state_cmd = Arc::clone(&state);
138 let shim_id = args.id.clone();
139 loop {
140 let cmd = match cmd_channel.recv::<ShimCommand>() {
141 Ok(Some(c)) => c,
142 Ok(None) => {
143 eprintln!("[shim-codex {shim_id}] orchestrator disconnected");
144 break;
145 }
146 Err(e) => {
147 eprintln!("[shim-codex {shim_id}] channel error: {e}");
148 break;
149 }
150 };
151
152 match cmd {
153 ShimCommand::SendMessage {
154 from,
155 body,
156 message_id,
157 } => {
158 let delivery_id = message_id.clone();
159 let mut st = state_cmd.lock().unwrap();
160 match st.state {
161 ShimState::Idle => {
162 st.pending_message_id = message_id;
163 st.accumulated_response.clear();
164 st.state = ShimState::Working;
165 st.state_changed_at = Instant::now();
166 let thread_id = st.thread_id.clone();
167 let program = st.program.clone();
168 let cwd = st.cwd.clone();
169 drop(st);
170
171 cmd_channel.send(&Event::StateChanged {
172 from: ShimState::Idle,
173 to: ShimState::Working,
174 summary: String::new(),
175 })?;
176 if let Some(id) = delivery_id {
177 cmd_channel.send(&Event::MessageDelivered { id })?;
178 }
179
180 let text = format_injected_message(&from, &body);
182 let (exec_program, exec_args) =
183 codex_types::codex_sdk_args(&program, thread_id.as_deref());
184
185 let mut evt_channel = cmd_channel
186 .try_clone()
187 .context("failed to clone channel for codex exec")?;
188 let state_exec = Arc::clone(&state_cmd);
189 let pty_log_exec = pty_log.clone();
190 let shim_id_exec = shim_id.clone();
191
192 thread::spawn(move || {
194 run_codex_exec(
195 &shim_id_exec,
196 &exec_program,
197 &exec_args,
198 &text,
199 &cwd,
200 &state_exec,
201 &mut evt_channel,
202 pty_log_exec.as_ref(),
203 );
204 });
205 }
206 ShimState::Working => {
207 if st.message_queue.len() >= MAX_QUEUE_DEPTH {
209 let dropped = st.message_queue.pop_front();
210 let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
211 st.message_queue.push_back(QueuedMessage {
212 from,
213 body,
214 message_id,
215 });
216 let depth = st.message_queue.len();
217 drop(st);
218
219 cmd_channel.send(&Event::Error {
220 command: "SendMessage".into(),
221 reason: format!(
222 "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
223 dropped_id
224 .map(|id| format!(" (id: {id})"))
225 .unwrap_or_default(),
226 ),
227 })?;
228 cmd_channel.send(&Event::Warning {
229 message: format!(
230 "message queued while agent working (depth: {depth})"
231 ),
232 idle_secs: None,
233 })?;
234 } else {
235 st.message_queue.push_back(QueuedMessage {
236 from,
237 body,
238 message_id,
239 });
240 let depth = st.message_queue.len();
241 drop(st);
242
243 cmd_channel.send(&Event::Warning {
244 message: format!(
245 "message queued while agent working (depth: {depth})"
246 ),
247 idle_secs: None,
248 })?;
249 }
250 }
251 other => {
252 drop(st);
253 cmd_channel.send(&Event::Error {
254 command: "SendMessage".into(),
255 reason: format!("agent in {other} state, cannot accept message"),
256 })?;
257 }
258 }
259 }
260
261 ShimCommand::CaptureScreen { last_n_lines } => {
262 let st = state_cmd.lock().unwrap();
263 let content = match last_n_lines {
264 Some(n) => last_n_lines_of(&st.accumulated_response, n),
265 None => st.accumulated_response.clone(),
266 };
267 drop(st);
268 cmd_channel.send(&Event::ScreenCapture {
269 content,
270 cursor_row: 0,
271 cursor_col: 0,
272 })?;
273 }
274
275 ShimCommand::GetState => {
276 let st = state_cmd.lock().unwrap();
277 let since = st.state_changed_at.elapsed().as_secs();
278 let state = st.state;
279 drop(st);
280 cmd_channel.send(&Event::State {
281 state,
282 since_secs: since,
283 })?;
284 }
285
286 ShimCommand::Resize { .. } => {
287 }
289
290 ShimCommand::Ping => {
291 cmd_channel.send(&Event::Pong)?;
292 }
293
294 ShimCommand::Shutdown { reason, .. } => {
295 eprintln!(
296 "[shim-codex {shim_id}] shutdown requested ({})",
297 reason.label()
298 );
299 if let Err(error) = super::runtime::preserve_work_before_kill(&args.cwd) {
300 eprintln!(
301 "[shim-codex {shim_id}] failed to preserve work before shutdown: {error:#}"
302 );
303 }
304 let mut st = state_cmd.lock().unwrap();
305 st.state = ShimState::Dead;
306 st.state_changed_at = Instant::now();
307 drop(st);
308 break;
309 }
310
311 ShimCommand::Kill => {
312 if let Err(error) = super::runtime::preserve_work_before_kill(&args.cwd) {
313 eprintln!(
314 "[shim-codex {shim_id}] failed to preserve work before kill: {error:#}"
315 );
316 }
317 let mut st = state_cmd.lock().unwrap();
318 st.state = ShimState::Dead;
319 st.state_changed_at = Instant::now();
320 drop(st);
321 break;
322 }
323 }
324 }
325
326 Ok(())
327}
328
329#[allow(clippy::too_many_arguments)]
336fn run_codex_exec(
337 shim_id: &str,
338 program: &str,
339 args: &[String],
340 prompt: &str,
341 cwd: &std::path::Path,
342 state: &Arc<Mutex<CodexState>>,
343 evt_channel: &mut Channel,
344 pty_log: Option<&Arc<Mutex<PtyLogWriter>>>,
345) {
346 let mut child = match Command::new(program)
348 .args(args)
349 .current_dir(cwd)
350 .stdin(Stdio::piped())
351 .stdout(Stdio::piped())
352 .stderr(Stdio::piped())
353 .env_remove("CLAUDECODE")
354 .spawn()
355 {
356 Ok(c) => c,
357 Err(e) => {
358 eprintln!("[shim-codex {shim_id}] failed to spawn codex exec: {e}");
359 let mut st = state.lock().unwrap();
360 let msg_id = st.pending_message_id.take();
361 st.state = ShimState::Idle;
362 st.state_changed_at = Instant::now();
363 drop(st);
364 let _ = evt_channel.send(&Event::Error {
365 command: "SendMessage".into(),
366 reason: format!("codex exec spawn failed: {e}"),
367 });
368 let _ = evt_channel.send(&Event::StateChanged {
369 from: ShimState::Working,
370 to: ShimState::Idle,
371 summary: format!("spawn failed: {e}"),
372 });
373 let _ = evt_channel.send(&Event::Completion {
374 message_id: msg_id,
375 response: String::new(),
376 last_lines: format!("spawn failed: {e}"),
377 });
378 return;
379 }
380 };
381
382 let child_pid = child.id();
383 eprintln!("[shim-codex {shim_id}] codex exec spawned (pid {child_pid})");
384
385 if let Some(mut stdin) = child.stdin.take() {
386 if let Err(e) = stdin.write_all(prompt.as_bytes()) {
387 eprintln!("[shim-codex {shim_id}] failed to write prompt to stdin: {e}");
388 }
389 }
390
391 let stdout = child.stdout.take().unwrap();
392 let stderr = child.stderr.take().unwrap();
393
394 let shim_id_err = shim_id.to_string();
396 let pty_log_err = pty_log.map(Arc::clone);
397 thread::spawn(move || {
398 let reader = BufReader::new(stderr);
399 for line_result in reader.lines() {
400 match line_result {
401 Ok(line) => {
402 eprintln!("[shim-codex {shim_id_err}] stderr: {line}");
403 if let Some(ref log) = pty_log_err {
404 let _ = log
405 .lock()
406 .unwrap()
407 .write(format!("[stderr] {line}\n").as_bytes());
408 }
409 }
410 Err(_) => break,
411 }
412 }
413 });
414
415 let reader = BufReader::new(stdout);
417 for line_result in reader.lines() {
418 let line = match line_result {
419 Ok(l) => l,
420 Err(e) => {
421 eprintln!("[shim-codex {shim_id}] stdout read error: {e}");
422 break;
423 }
424 };
425
426 if line.trim().is_empty() {
427 continue;
428 }
429
430 let evt: CodexEvent = match serde_json::from_str(&line) {
431 Ok(e) => e,
432 Err(e) => {
433 eprintln!("[shim-codex {shim_id}] ignoring unparseable JSONL: {e}");
434 continue;
435 }
436 };
437
438 match evt.event_type.as_str() {
439 "thread.started" => {
440 if let Some(tid) = evt.thread_id {
441 let mut st = state.lock().unwrap();
442 st.thread_id = Some(tid.clone());
443 eprintln!("[shim-codex {shim_id}] thread started: {tid}");
444 }
445 }
446
447 "item.completed" | "item.updated" => {
448 if let Some(ref item) = evt.item {
449 if let Some(text) = item.agent_text() {
450 if !text.is_empty() {
451 let mut st = state.lock().unwrap();
452 if evt.event_type == "item.completed" {
455 st.accumulated_response = text.to_string();
456 }
457 st.cumulative_output_bytes += text.len() as u64;
458
459 if !st.context_approaching_emitted
461 && common::detect_context_approaching_limit(text)
462 {
463 st.context_approaching_emitted = true;
464 drop(st);
465 let _ = evt_channel.send(&Event::ContextApproaching {
466 message: "Agent output contains context-pressure signals"
467 .into(),
468 input_tokens: 0,
469 output_tokens: 0,
470 });
471 } else {
472 drop(st);
473 }
474
475 if let Some(log) = pty_log {
476 let _ = log.lock().unwrap().write(text.as_bytes());
477 let _ = log.lock().unwrap().write(b"\n");
478 }
479 }
480 }
481 }
482 }
483
484 "turn.failed" => {
485 let error_msg = evt
486 .error
487 .as_ref()
488 .map(|e| e.message.clone())
489 .unwrap_or_else(|| "unknown error".to_string());
490 eprintln!("[shim-codex {shim_id}] turn failed: {error_msg}");
491
492 if common::detect_context_exhausted(&error_msg) {
494 let mut st = state.lock().unwrap();
495 let last_lines = last_n_lines_of(&st.accumulated_response, 5);
496 st.state = ShimState::ContextExhausted;
497 st.state_changed_at = Instant::now();
498 let drain =
499 drain_queue_errors(&mut st.message_queue, ShimState::ContextExhausted);
500 drop(st);
501
502 let _ = evt_channel.send(&Event::StateChanged {
503 from: ShimState::Working,
504 to: ShimState::ContextExhausted,
505 summary: last_lines.clone(),
506 });
507 let _ = evt_channel.send(&Event::ContextExhausted {
508 message: error_msg,
509 last_lines,
510 });
511 for event in drain {
512 let _ = evt_channel.send(&event);
513 }
514 return;
515 }
516 }
517
518 "error" => {
519 let error_msg = evt
520 .error
521 .as_ref()
522 .map(|e| e.message.clone())
523 .unwrap_or_else(|| "stream error".to_string());
524 eprintln!("[shim-codex {shim_id}] error event: {error_msg}");
525
526 let lower = error_msg.to_ascii_lowercase();
529 if lower.contains("usage limit")
530 || lower.contains("quota")
531 || lower.contains("billing")
532 || lower.contains("purchase more credits")
533 {
534 eprintln!("[shim-codex {shim_id}] QUOTA EXHAUSTED: {error_msg}");
535 let _ = evt_channel.send(&Event::Error {
536 command: "QuotaExhausted".into(),
537 reason: error_msg.clone(),
538 });
539 }
540 }
541
542 _ => {}
544 }
545 }
546
547 let exit_code = child.wait().ok().and_then(|s| s.code());
549 eprintln!("[shim-codex {shim_id}] codex exec exited (code: {exit_code:?})");
550
551 let mut st = state.lock().unwrap();
553 let response = std::mem::take(&mut st.accumulated_response);
554 let last_lines = last_n_lines_of(&response, 5);
555 let msg_id = st.pending_message_id.take();
556 st.state = ShimState::Idle;
557 st.state_changed_at = Instant::now();
558
559 let queued_msg = if !st.message_queue.is_empty() {
561 st.message_queue.pop_front()
562 } else {
563 None
564 };
565
566 if let Some(ref qm) = queued_msg {
567 st.pending_message_id = qm.message_id.clone();
568 st.state = ShimState::Working;
569 st.state_changed_at = Instant::now();
570 st.accumulated_response.clear();
571 }
572
573 let thread_id = st.thread_id.clone();
574 let program = st.program.clone();
575 let cwd_owned = st.cwd.clone();
576 let queue_depth = st.message_queue.len();
577 drop(st);
578
579 let _ = evt_channel.send(&Event::StateChanged {
580 from: ShimState::Working,
581 to: ShimState::Idle,
582 summary: last_lines.clone(),
583 });
584 let _ = evt_channel.send(&Event::Completion {
585 message_id: msg_id,
586 response,
587 last_lines,
588 });
589
590 if let Some(qm) = queued_msg {
592 let _ = evt_channel.send(&Event::StateChanged {
593 from: ShimState::Idle,
594 to: ShimState::Working,
595 summary: format!("delivering queued message ({queue_depth} remaining)"),
596 });
597
598 let text = format_injected_message(&qm.from, &qm.body);
599 let (exec_program, exec_args) = codex_types::codex_sdk_args(&program, thread_id.as_deref());
600
601 run_codex_exec(
603 shim_id,
604 &exec_program,
605 &exec_args,
606 &text,
607 &cwd_owned,
608 state,
609 evt_channel,
610 pty_log,
611 );
612 }
613}
614
615#[allow(dead_code)]
621fn terminate_child(child: &mut Child) {
622 let pid = child.id();
623 #[cfg(unix)]
624 {
625 unsafe {
626 libc::kill(pid as i32, libc::SIGTERM);
627 }
628 let deadline = Instant::now() + Duration::from_secs(GROUP_TERM_GRACE_SECS);
629 loop {
630 if Instant::now() > deadline {
631 break;
632 }
633 match child.try_wait() {
634 Ok(Some(_)) => return,
635 _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
636 }
637 }
638 unsafe {
639 libc::kill(pid as i32, libc::SIGKILL);
640 }
641 }
642 #[allow(unreachable_code)]
643 {
644 let _ = child.kill();
645 }
646}
647
/// Returns the last `n` lines of `text`, joined with `\n`.
///
/// If `text` has fewer than `n` lines, all of them are returned. Line
/// terminators are those recognised by [`str::lines`], so the result never
/// ends with a trailing newline.
fn last_n_lines_of(text: &str, n: usize) -> String {
    // Walk from the end so only the lines we keep are collected, then restore
    // their original order.
    let mut tail: Vec<&str> = text.lines().rev().take(n).collect();
    tail.reverse();
    tail.join("\n")
}
654
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shim::protocol;

    /// `last_n_lines_of` keeps the trailing lines and tolerates short or empty input.
    #[test]
    fn last_n_lines_basic() {
        let cases = [
            ("a\nb\nc", 2, "b\nc"),
            ("a\nb\nc", 10, "a\nb\nc"),
            ("", 5, ""),
        ];
        for (input, n, expected) in cases {
            assert_eq!(last_n_lines_of(input, n), expected);
        }
    }

    /// A freshly built state is idle and has no thread id yet.
    #[test]
    fn codex_state_initial() {
        let fresh = CodexState {
            state: ShimState::Idle,
            state_changed_at: Instant::now(),
            started_at: Instant::now(),
            thread_id: None,
            accumulated_response: String::new(),
            pending_message_id: None,
            message_queue: VecDeque::new(),
            cumulative_output_bytes: 0,
            program: String::from("codex"),
            cwd: std::path::PathBuf::from("/tmp"),
            context_approaching_emitted: false,
        };

        assert_eq!(fresh.state, ShimState::Idle);
        assert!(fresh.thread_id.is_none());
    }

    /// An event sent on one end of a socket pair arrives intact on the other.
    #[test]
    fn channel_events_roundtrip() {
        let (orchestrator_sock, shim_sock) = protocol::socketpair().unwrap();
        let mut orchestrator_end = protocol::Channel::new(orchestrator_sock);
        let mut shim_end = protocol::Channel::new(shim_sock);

        shim_end.send(&Event::Ready).unwrap();

        let received: Event = orchestrator_end.recv().unwrap().unwrap();
        assert!(matches!(received, Event::Ready));
    }
}