1use std::collections::VecDeque;
8use std::io::{BufRead, BufReader, Write as IoWrite};
9use std::process::{Child, Command, Stdio};
10use std::sync::mpsc::{self, RecvTimeoutError};
11use std::sync::{Arc, Mutex};
12use std::thread;
13use std::time::{Duration, Instant};
14
15use anyhow::{Context, Result};
16
17use super::common::{
18 self, MAX_QUEUE_DEPTH, QueuedMessage, SESSION_STATS_INTERVAL_SECS, drain_queue_errors,
19 format_injected_message,
20};
21use super::protocol::{Channel, Command as ShimCommand, Event, ShimState};
22use super::pty_log::PtyLogWriter;
23use super::runtime::ShimArgs;
24use super::sdk_types::{self, SdkControlResponse, SdkOutput, SdkUserMessage};
25
/// Poll interval (ms) while waiting for the agent process to exit.
const PROCESS_EXIT_POLL_MS: u64 = 100;
/// Grace period (s) between SIGTERM and SIGKILL in `terminate_child`.
const GROUP_TERM_GRACE_SECS: u64 = 2;
/// Longest stdout silence tolerated while a turn is in progress before the
/// turn is force-completed as stalled.
const WORKING_READ_TIMEOUT: Duration = Duration::from_secs(2 * 60);
/// Prefix of the synthetic response emitted for a stalled turn.
const STALLED_MID_TURN_MARKER: &str = "stalled mid-turn";
/// Maximum number of characters kept in a message preview.
const MESSAGE_PREVIEW_LIMIT: usize = 160;
/// Read timeout (ms) on the command channel; bounds how often keepalives are checked.
const SDK_COMMAND_POLL_MS: u64 = 1000;
/// Idle time (s) after which a keepalive prompt is sent to the agent.
const SDK_KEEPALIVE_IDLE_SECS: u64 = 5 * 60;
/// Prompt sent to the agent as a keepalive.
const SDK_KEEPALIVE_MESSAGE: &str =
    "Continue monitoring. If you have no pending work, reply with 'idle'.";
/// Context usage percentage at which a `ContextWarning` event is emitted.
const PROACTIVE_CONTEXT_WARNING_PCT: u8 = 80;
/// Context window assumed for models the limit table does not recognize.
const DEFAULT_CONTEXT_LIMIT_TOKENS: u64 = 128_000;
41
42struct SdkState {
47 state: ShimState,
48 state_changed_at: Instant,
49 started_at: Instant,
50 session_id: String,
52 accumulated_response: String,
54 pending_message_id: Option<String>,
56 last_sent_message_from: Option<String>,
58 last_sent_message_preview: Option<String>,
60 last_model_name: Option<String>,
62 message_queue: VecDeque<QueuedMessage>,
64 cumulative_output_bytes: u64,
66 model: Option<String>,
68 test_failure_iterations: u8,
70 cumulative_input_tokens: u64,
72 cumulative_output_tokens: u64,
74 context_usage_pct: Option<u8>,
76}
77
78#[derive(Debug, Clone)]
79struct ForcedCompletion {
80 previous_state: ShimState,
81 response: String,
82 last_lines: String,
83 message_id: Option<String>,
84 queued_message: Option<QueuedMessage>,
85 queue_depth: usize,
86 session_id: String,
87}
88
89pub fn run_sdk(args: ShimArgs, channel: Channel) -> Result<()> {
98 let mut child = Command::new("bash")
100 .args(["-lc", &args.cmd])
101 .current_dir(&args.cwd)
102 .stdin(Stdio::piped())
103 .stdout(Stdio::piped())
104 .stderr(Stdio::piped())
105 .env_remove("CLAUDECODE") .spawn()
107 .with_context(|| format!("[shim-sdk {}] failed to spawn agent", args.id))?;
108
109 let child_pid = child.id();
110 eprintln!(
111 "[shim-sdk {}] spawned agent subprocess (pid {})",
112 args.id, child_pid
113 );
114
115 let child_stdin = child.stdin.take().context("failed to take child stdin")?;
116 let child_stdout = child.stdout.take().context("failed to take child stdout")?;
117 let child_stderr = child.stderr.take().context("failed to take child stderr")?;
118
119 let state = Arc::new(Mutex::new(SdkState {
121 state: ShimState::Idle, state_changed_at: Instant::now(),
123 started_at: Instant::now(),
124 session_id: String::new(),
125 accumulated_response: String::new(),
126 pending_message_id: None,
127 last_sent_message_from: None,
128 last_sent_message_preview: None,
129 last_model_name: None,
130 message_queue: VecDeque::new(),
131 cumulative_output_bytes: 0,
132 model: None,
133 test_failure_iterations: 0,
134 cumulative_input_tokens: 0,
135 cumulative_output_tokens: 0,
136 context_usage_pct: None,
137 }));
138
139 let stdin_writer = Arc::new(Mutex::new(child_stdin));
141
142 let pty_log: Option<Arc<Mutex<PtyLogWriter>>> = args
144 .pty_log_path
145 .as_deref()
146 .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
147 .transpose()?
148 .map(|w| Arc::new(Mutex::new(w)));
149
150 let mut cmd_channel = channel;
152 let mut evt_channel = cmd_channel
153 .try_clone()
154 .context("failed to clone channel for stdout reader")?;
155
156 cmd_channel.set_read_timeout(Some(Duration::from_millis(SDK_COMMAND_POLL_MS)))?;
157
158 cmd_channel.send(&Event::Ready)?;
160
161 let state_stdout = Arc::clone(&state);
163 let stdin_for_approve = Arc::clone(&stdin_writer);
164 let pty_log_stdout = pty_log.clone();
165 let shim_id = args.id.clone();
166 let stdout_handle = thread::spawn(move || {
167 let (line_tx, line_rx) = mpsc::channel();
168 thread::spawn(move || {
169 let reader = BufReader::new(child_stdout);
170 for line_result in reader.lines() {
171 if line_tx.send(line_result).is_err() {
172 break;
173 }
174 }
175 });
176
177 loop {
178 let line_result = match stdout_read_timeout(&state_stdout) {
179 Some(timeout) => match line_rx.recv_timeout(timeout) {
180 Ok(line_result) => Some(line_result),
181 Err(RecvTimeoutError::Timeout) => {
182 if let Some(forced) = force_stalled_completion(&state_stdout, &shim_id) {
183 emit_forced_completion(&mut evt_channel, &stdin_for_approve, forced);
184 }
185 continue;
186 }
187 Err(RecvTimeoutError::Disconnected) => None,
188 },
189 None => line_rx.recv().ok(),
190 };
191
192 let Some(line_result) = line_result else {
193 break;
194 };
195 let line = match line_result {
196 Ok(l) => l,
197 Err(e) => {
198 eprintln!("[shim-sdk {shim_id}] stdout read error: {e}");
199 break;
200 }
201 };
202
203 if line.trim().is_empty() {
204 continue;
205 }
206
207 let msg: SdkOutput = match serde_json::from_str(&line) {
208 Ok(m) => m,
209 Err(e) => {
210 eprintln!("[shim-sdk {shim_id}] ignoring unparseable NDJSON line: {e}");
211 continue;
212 }
213 };
214
215 match msg.msg_type.as_str() {
216 "assistant" => {
217 if let Some(ref message) = msg.message {
219 let model_name = msg.model_name();
220 let text = sdk_types::extract_assistant_text(message);
221 if !text.is_empty() {
222 let mut st = state_stdout.lock().unwrap();
223 if !turn_in_flight(&st) {
224 continue;
225 }
226 if st.last_model_name.is_none() {
227 st.last_model_name = model_name.clone();
228 }
229 st.accumulated_response.push_str(&text);
230 st.cumulative_output_bytes += text.len() as u64;
231
232 if st.session_id.is_empty() {
234 if let Some(ref sid) = msg.session_id {
235 st.session_id = sid.clone();
236 }
237 }
238 if st.model.is_none() {
239 st.model = model_name.clone();
240 }
241 drop(st);
242
243 if let Some(ref log) = pty_log_stdout {
245 let _ = log.lock().unwrap().write(text.as_bytes());
246 }
247 }
248 }
249 }
250
251 "stream_event" => {
252 if let Some(ref event) = msg.event {
254 if let Some(text) = sdk_types::extract_stream_text(event) {
255 let mut st = state_stdout.lock().unwrap();
256 if !turn_in_flight(&st) {
257 continue;
258 }
259 st.accumulated_response.push_str(&text);
260 st.cumulative_output_bytes += text.len() as u64;
261
262 if st.session_id.is_empty() {
263 if let Some(ref sid) = msg.session_id {
264 st.session_id = sid.clone();
265 }
266 }
267 drop(st);
268
269 if let Some(ref log) = pty_log_stdout {
270 let _ = log.lock().unwrap().write(text.as_bytes());
271 }
272 }
273 }
274 }
275
276 "control_request" => {
277 if msg.request_subtype().as_deref() == Some("can_use_tool") {
279 if let (Some(req_id), Some(ref tool_use_id)) =
280 (msg.request_id.as_ref(), msg.request_tool_use_id())
281 {
282 let resp = SdkControlResponse::approve_tool(req_id, tool_use_id);
283 let ndjson = resp.to_ndjson();
284 if let Ok(mut writer) = stdin_for_approve.lock() {
285 let _ = writeln!(writer, "{ndjson}");
286 let _ = writer.flush();
287 }
288 }
289 }
290 }
291
292 "result" => {
293 let mut st = state_stdout.lock().unwrap();
294 if !turn_in_flight(&st) {
295 continue;
296 }
297
298 if st.session_id.is_empty() {
300 if let Some(ref sid) = msg.session_id {
301 st.session_id = sid.clone();
302 }
303 }
304 if let Some(model_name) = msg.model_name() {
305 st.last_model_name = Some(model_name);
306 }
307
308 let is_context_exhausted = msg
310 .errors
311 .as_ref()
312 .map(|errs| errs.iter().any(|e| common::detect_context_exhausted(e)))
313 .unwrap_or(false)
314 || msg
315 .result
316 .as_deref()
317 .map(common::detect_context_exhausted)
318 .unwrap_or(false);
319 let context_warning = proactive_context_warning(
320 &msg,
321 st.last_model_name.as_deref(),
322 st.cumulative_output_bytes,
323 st.started_at.elapsed().as_secs(),
324 );
325
326 if is_context_exhausted {
327 let last_lines = last_n_lines_of(&st.accumulated_response, 5);
328 let old = st.state;
329 st.state = ShimState::ContextExhausted;
330 st.state_changed_at = Instant::now();
331
332 let drain =
333 drain_queue_errors(&mut st.message_queue, ShimState::ContextExhausted);
334 drop(st);
335
336 let _ = evt_channel.send(&Event::StateChanged {
337 from: old,
338 to: ShimState::ContextExhausted,
339 summary: last_lines.clone(),
340 });
341 let _ = evt_channel.send(&Event::ContextExhausted {
342 message: "Agent reported context exhaustion".into(),
343 last_lines,
344 });
345 for event in drain {
346 let _ = evt_channel.send(&event);
347 }
348 continue;
349 }
350
351 if let Some(warning) = context_warning.clone() {
352 let _ = evt_channel.send(&Event::ContextWarning {
353 model: warning.model,
354 output_bytes: warning.output_bytes,
355 uptime_secs: warning.uptime_secs,
356 input_tokens: warning.usage.input_tokens,
357 cached_input_tokens: warning.usage.cached_input_tokens,
358 cache_creation_input_tokens: warning.usage.cache_creation_input_tokens,
359 cache_read_input_tokens: warning.usage.cache_read_input_tokens,
360 output_tokens: warning.usage.output_tokens,
361 reasoning_output_tokens: warning.usage.reasoning_output_tokens,
362 used_tokens: warning.used_tokens,
363 context_limit_tokens: warning.context_limit_tokens,
364 usage_pct: warning.usage_pct,
365 });
366 }
367
368 let response = if st.accumulated_response.is_empty() {
370 msg.result.clone().unwrap_or_default()
371 } else {
372 std::mem::take(&mut st.accumulated_response)
373 };
374 if let Some(followup) =
375 common::detect_test_failure_followup(&response, st.test_failure_iterations)
376 {
377 st.pending_message_id = None;
378 st.test_failure_iterations = followup.next_iteration_count;
379 st.last_sent_message_from = Some("batty".into());
380 st.last_sent_message_preview = Some(message_preview(&followup.body));
381 st.state = ShimState::Working;
382 st.state_changed_at = Instant::now();
383 let session_id = st.session_id.clone();
384 drop(st);
385
386 let text = format_injected_message("batty", &followup.body);
387 let user_msg = SdkUserMessage::new(&session_id, &text);
388 let ndjson = user_msg.to_ndjson();
389 if let Ok(mut writer) = stdin_for_approve.lock() {
390 let _ = writeln!(writer, "{ndjson}");
391 let _ = writer.flush();
392 }
393 let _ = evt_channel.send(&Event::Warning {
394 message: followup.notice,
395 idle_secs: None,
396 });
397 continue;
398 }
399 st.test_failure_iterations = 0;
400 let last_lines = last_n_lines_of(&response, 5);
401 let msg_id = st.pending_message_id.take();
402 let old = st.state;
403 st.state = ShimState::Idle;
404 st.state_changed_at = Instant::now();
405
406 let queued_msg = if !st.message_queue.is_empty() {
408 st.message_queue.pop_front()
409 } else {
410 None
411 };
412
413 if let Some(ref qm) = queued_msg {
415 st.pending_message_id = qm.message_id.clone();
416 st.last_sent_message_from = Some(qm.from.clone());
417 st.last_sent_message_preview = Some(message_preview(&qm.body));
418 st.state = ShimState::Working;
419 st.state_changed_at = Instant::now();
420 st.accumulated_response.clear();
421 st.test_failure_iterations = 0;
422 } else {
423 st.last_sent_message_from = None;
424 st.last_sent_message_preview = None;
425 }
426
427 let queue_depth = st.message_queue.len();
428 let session_id = st.session_id.clone();
429 drop(st);
430
431 let _ = evt_channel.send(&Event::StateChanged {
433 from: old,
434 to: ShimState::Idle,
435 summary: last_lines.clone(),
436 });
437 let _ = evt_channel.send(&Event::Completion {
438 message_id: msg_id,
439 response,
440 last_lines,
441 });
442
443 if let Some(qm) = queued_msg {
445 let text = format_injected_message(&qm.from, &qm.body);
446 let user_msg = SdkUserMessage::new(&session_id, &text);
447 let ndjson = user_msg.to_ndjson();
448 if let Ok(mut writer) = stdin_for_approve.lock() {
449 let _ = writeln!(writer, "{ndjson}");
450 let _ = writer.flush();
451 }
452 let _ = evt_channel.send(&Event::StateChanged {
453 from: ShimState::Idle,
454 to: ShimState::Working,
455 summary: format!("delivering queued message ({queue_depth} remaining)"),
456 });
457 }
458 }
459
460 _ => {
461 }
463 }
464 }
465
466 let mut st = state_stdout.lock().unwrap();
468 let last_lines = last_n_lines_of(&st.accumulated_response, 10);
469 let old = st.state;
470 st.state = ShimState::Dead;
471 st.state_changed_at = Instant::now();
472
473 let drain = drain_queue_errors(&mut st.message_queue, ShimState::Dead);
474 drop(st);
475
476 let _ = evt_channel.send(&Event::StateChanged {
477 from: old,
478 to: ShimState::Dead,
479 summary: last_lines.clone(),
480 });
481 let _ = evt_channel.send(&Event::Died {
482 exit_code: None,
483 last_lines,
484 });
485 for event in drain {
486 let _ = evt_channel.send(&event);
487 }
488 });
489
490 let shim_id_err = args.id.clone();
492 let pty_log_stderr = pty_log;
493 thread::spawn(move || {
494 let reader = BufReader::new(child_stderr);
495 for line_result in reader.lines() {
496 match line_result {
497 Ok(line) => {
498 eprintln!("[shim-sdk {shim_id_err}] stderr: {line}");
499 if let Some(ref log) = pty_log_stderr {
500 let _ = log
501 .lock()
502 .unwrap()
503 .write(format!("[stderr] {line}\n").as_bytes());
504 }
505 }
506 Err(_) => break,
507 }
508 }
509 });
510
511 let state_stats = Arc::clone(&state);
513 let mut stats_channel = cmd_channel
514 .try_clone()
515 .context("failed to clone channel for stats")?;
516 thread::spawn(move || {
517 loop {
518 thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
519 let st = state_stats.lock().unwrap();
520 if st.state == ShimState::Dead {
521 return;
522 }
523 let output_bytes = st.cumulative_output_bytes;
524 let uptime_secs = st.started_at.elapsed().as_secs();
525 let input_tokens = st.cumulative_input_tokens;
526 let output_tokens = st.cumulative_output_tokens;
527 let context_usage_pct = st.context_usage_pct;
528 drop(st);
529
530 if stats_channel
531 .send(&Event::SessionStats {
532 output_bytes,
533 uptime_secs,
534 input_tokens,
535 output_tokens,
536 context_usage_pct,
537 })
538 .is_err()
539 {
540 return;
541 }
542 }
543 });
544
545 let state_cmd = Arc::clone(&state);
547 let mut last_keepalive = Instant::now();
548 loop {
549 let cmd = match cmd_channel.recv::<ShimCommand>() {
550 Ok(Some(c)) => c,
551 Ok(None) => {
552 eprintln!(
553 "[shim-sdk {}] orchestrator disconnected, shutting down",
554 args.id
555 );
556 terminate_child(&mut child);
557 break;
558 }
559 Err(error)
560 if error
561 .downcast_ref::<std::io::Error>()
562 .is_some_and(|io_error| {
563 matches!(
564 io_error.kind(),
565 std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
566 )
567 }) =>
568 {
569 maybe_send_keepalive(&state_cmd, &stdin_writer, &mut last_keepalive);
570 continue;
571 }
572 Err(e) => {
573 eprintln!("[shim-sdk {}] channel error: {e}", args.id);
574 terminate_child(&mut child);
575 break;
576 }
577 };
578
579 match cmd {
580 ShimCommand::SendMessage {
581 from,
582 body,
583 message_id,
584 } => {
585 let delivery_id = message_id.clone();
586 last_keepalive = Instant::now();
587 let mut st = state_cmd.lock().unwrap();
588 match st.state {
589 ShimState::Idle => {
590 st.pending_message_id = message_id;
591 st.last_sent_message_from = Some(from.clone());
592 st.last_sent_message_preview = Some(message_preview(&body));
593 st.accumulated_response.clear();
594 st.test_failure_iterations = 0;
595 let session_id = st.session_id.clone();
596 st.state = ShimState::Working;
597 st.state_changed_at = Instant::now();
598 drop(st);
599
600 let text = format_injected_message(&from, &body);
601 let user_msg = SdkUserMessage::new(&session_id, &text);
602 let ndjson = user_msg.to_ndjson();
603
604 if let Ok(mut writer) = stdin_writer.lock() {
605 if let Err(e) = writeln!(writer, "{ndjson}") {
606 if let Some(id) = delivery_id {
607 cmd_channel.send(&Event::DeliveryFailed {
608 id,
609 reason: format!("stdin write failed: {e}"),
610 })?;
611 }
612 cmd_channel.send(&Event::Error {
613 command: "SendMessage".into(),
614 reason: format!("stdin write failed: {e}"),
615 })?;
616 continue;
617 }
618 let _ = writer.flush();
619 }
620
621 if let Some(id) = delivery_id {
622 cmd_channel.send(&Event::MessageDelivered { id })?;
623 }
624 cmd_channel.send(&Event::StateChanged {
625 from: ShimState::Idle,
626 to: ShimState::Working,
627 summary: String::new(),
628 })?;
629 }
630 ShimState::Working => {
631 if st.message_queue.len() >= MAX_QUEUE_DEPTH {
633 let dropped = st.message_queue.pop_front();
634 let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
635 st.message_queue.push_back(QueuedMessage {
636 from,
637 body,
638 message_id,
639 });
640 let depth = st.message_queue.len();
641 drop(st);
642
643 cmd_channel.send(&Event::Error {
644 command: "SendMessage".into(),
645 reason: format!(
646 "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
647 dropped_id
648 .map(|id| format!(" (id: {id})"))
649 .unwrap_or_default(),
650 ),
651 })?;
652 cmd_channel.send(&Event::Warning {
653 message: format!(
654 "message queued while agent working (depth: {depth})"
655 ),
656 idle_secs: None,
657 })?;
658 } else {
659 st.message_queue.push_back(QueuedMessage {
660 from,
661 body,
662 message_id,
663 });
664 let depth = st.message_queue.len();
665 drop(st);
666
667 cmd_channel.send(&Event::Warning {
668 message: format!(
669 "message queued while agent working (depth: {depth})"
670 ),
671 idle_secs: None,
672 })?;
673 }
674 }
675 other => {
676 drop(st);
677 cmd_channel.send(&Event::Error {
678 command: "SendMessage".into(),
679 reason: format!("agent in {other} state, cannot accept message"),
680 })?;
681 }
682 }
683 }
684
685 ShimCommand::CaptureScreen { last_n_lines } => {
686 let st = state_cmd.lock().unwrap();
687 let content = match last_n_lines {
688 Some(n) => last_n_lines_of(&st.accumulated_response, n),
689 None => st.accumulated_response.clone(),
690 };
691 drop(st);
692 cmd_channel.send(&Event::ScreenCapture {
693 content,
694 cursor_row: 0,
695 cursor_col: 0,
696 })?;
697 }
698
699 ShimCommand::GetState => {
700 let st = state_cmd.lock().unwrap();
701 let since = st.state_changed_at.elapsed().as_secs();
702 let state = st.state;
703 drop(st);
704 cmd_channel.send(&Event::State {
705 state,
706 since_secs: since,
707 })?;
708 }
709
710 ShimCommand::Resize { .. } => {
711 }
713
714 ShimCommand::Ping => {
715 last_keepalive = Instant::now();
716 cmd_channel.send(&Event::Pong)?;
717 }
718
719 ShimCommand::Shutdown {
720 timeout_secs,
721 reason,
722 } => {
723 eprintln!(
724 "[shim-sdk {}] shutdown requested ({}, timeout: {}s)",
725 args.id,
726 reason.label(),
727 timeout_secs
728 );
729 if let Err(error) = super::runtime::preserve_work_before_kill(&args.cwd) {
730 eprintln!("[shim-sdk {}] work preservation failed: {error}", args.id);
731 }
732 drop(stdin_writer);
734
735 let deadline = Instant::now() + Duration::from_secs(timeout_secs as u64);
736 loop {
737 if Instant::now() > deadline {
738 terminate_child(&mut child);
739 break;
740 }
741 match child.try_wait() {
742 Ok(Some(_)) => break,
743 _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
744 }
745 }
746 break;
747 }
748
749 ShimCommand::Kill => {
750 if let Err(error) = super::runtime::preserve_work_before_kill(&args.cwd) {
751 eprintln!("[shim-sdk {}] work preservation failed: {error}", args.id);
752 }
753 terminate_child(&mut child);
754 break;
755 }
756 }
757 }
758
759 stdout_handle.join().ok();
760 Ok(())
761}
762
763fn maybe_send_keepalive<W: IoWrite>(
764 state: &Arc<Mutex<SdkState>>,
765 stdin_writer: &Arc<Mutex<W>>,
766 last_keepalive: &mut Instant,
767) {
768 if last_keepalive.elapsed() < Duration::from_secs(SDK_KEEPALIVE_IDLE_SECS) {
769 return;
770 }
771
772 let session_id = {
773 let mut st = state.lock().unwrap();
774 if st.state != ShimState::Idle
775 || st.session_id.is_empty()
776 || st.pending_message_id.is_some()
777 {
778 return;
779 }
780 st.state = ShimState::Working;
781 st.state_changed_at = Instant::now();
782 st.accumulated_response.clear();
783 st.test_failure_iterations = 0;
784 st.session_id.clone()
785 };
786
787 let user_msg = SdkUserMessage::new(&session_id, SDK_KEEPALIVE_MESSAGE);
788 let ndjson = user_msg.to_ndjson();
789 if let Ok(mut writer) = stdin_writer.lock() {
790 if writeln!(writer, "{ndjson}").is_ok() {
791 let _ = writer.flush();
792 *last_keepalive = Instant::now();
793 return;
794 }
795 }
796
797 let mut st = state.lock().unwrap();
798 st.state = ShimState::Idle;
799 st.state_changed_at = Instant::now();
800}
801
802#[derive(Debug, Clone)]
803struct ProactiveContextWarning {
804 model: Option<String>,
805 usage: sdk_types::SdkTokenUsage,
806 output_bytes: u64,
807 uptime_secs: u64,
808 used_tokens: u64,
809 context_limit_tokens: u64,
810 usage_pct: u8,
811}
812
813fn proactive_context_warning(
814 msg: &SdkOutput,
815 last_model_name: Option<&str>,
816 output_bytes: u64,
817 uptime_secs: u64,
818) -> Option<ProactiveContextWarning> {
819 let usage = msg.token_usage()?;
820 let used_tokens = msg.usage_total_tokens();
821 if used_tokens == 0 {
822 return None;
823 }
824
825 let model = msg
826 .model_name()
827 .or_else(|| last_model_name.map(str::to_string));
828 let context_limit_tokens = resolved_model_context_limit_tokens(model.as_deref());
829 let usage_pct = ((used_tokens.saturating_mul(100)) / context_limit_tokens.max(1)) as u8;
830 if usage_pct < PROACTIVE_CONTEXT_WARNING_PCT {
831 return None;
832 }
833
834 Some(ProactiveContextWarning {
835 model,
836 usage,
837 output_bytes,
838 uptime_secs,
839 used_tokens,
840 context_limit_tokens,
841 usage_pct,
842 })
843}
844
845fn resolved_model_context_limit_tokens(model: Option<&str>) -> u64 {
846 let Some(model) = model else {
847 return DEFAULT_CONTEXT_LIMIT_TOKENS;
848 };
849 let normalized = model.to_ascii_lowercase();
850
851 if normalized.contains("1m") {
852 1_000_000
853 } else if normalized.starts_with("claude-") || normalized.contains("claude") {
854 200_000
855 } else {
856 DEFAULT_CONTEXT_LIMIT_TOKENS
857 }
858}
859
/// Test helper: percentage of the resolved context limit used by
/// `total_tokens`, capped at 100. Always returns `Some`.
#[cfg(test)]
fn model_context_usage_pct(model: Option<&str>, total_tokens: u64) -> Option<u8> {
    let pct = total_tokens.saturating_mul(100) / resolved_model_context_limit_tokens(model);
    u8::try_from(pct.min(100)).ok()
}
869
/// Test helper: the context limit for a recognized model, or `None` when the
/// name falls through to the default.
///
/// The matching mirrors `resolved_model_context_limit_tokens`: `"1m"` first,
/// then any name *containing* `"claude"`. Previously this used `starts_with`,
/// so names like `"anthropic/claude-…"` diverged from production.
#[cfg(test)]
fn model_context_limit_tokens(model: &str) -> Option<u64> {
    let model = model.to_ascii_lowercase();
    if model.contains("1m") {
        Some(1_000_000)
    } else if model.contains("claude") {
        Some(200_000)
    } else {
        None
    }
}
881
882fn terminate_child(child: &mut Child) {
884 let pid = child.id();
885
886 #[cfg(unix)]
887 {
888 unsafe {
889 libc::kill(pid as i32, libc::SIGTERM);
890 }
891 let deadline = Instant::now() + Duration::from_secs(GROUP_TERM_GRACE_SECS);
892 loop {
893 if Instant::now() > deadline {
894 break;
895 }
896 match child.try_wait() {
897 Ok(Some(_)) => return,
898 _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
899 }
900 }
901 unsafe {
902 libc::kill(pid as i32, libc::SIGKILL);
903 }
904 }
905
906 #[allow(unreachable_code)]
907 {
908 let _ = child.kill();
909 }
910}
911
/// Return the last `n` lines of `text` joined with `\n`.
///
/// Lines are split as by `str::lines` (so `\r\n` endings are stripped and
/// there is no trailing newline); fewer than `n` lines yields all of them.
fn last_n_lines_of(text: &str, n: usize) -> String {
    let mut tail: Vec<&str> = text.lines().rev().take(n).collect();
    tail.reverse();
    tail.join("\n")
}
918
919fn stdout_read_timeout(state: &Arc<Mutex<SdkState>>) -> Option<Duration> {
920 let st = state.lock().unwrap();
921 (st.state == ShimState::Working).then_some(WORKING_READ_TIMEOUT)
922}
923
924fn turn_in_flight(state: &SdkState) -> bool {
925 state.state == ShimState::Working || state.pending_message_id.is_some()
926}
927
928fn message_preview(body: &str) -> String {
929 let normalized = body.split_whitespace().collect::<Vec<_>>().join(" ");
930 if normalized.chars().count() <= MESSAGE_PREVIEW_LIMIT {
931 normalized
932 } else {
933 let preview: String = normalized.chars().take(MESSAGE_PREVIEW_LIMIT).collect();
934 format!("{preview}...")
935 }
936}
937
938fn stalled_mid_turn_response(from: Option<&str>, preview: Option<&str>) -> String {
939 let source = from.unwrap_or("unknown");
940 let preview = preview.unwrap_or("(unavailable)");
941 format!(
942 "{STALLED_MID_TURN_MARKER}: no stdout from Claude SDK for {}s while working.\nlast_sent_message_from: {source}\nlast_sent_message_preview: {preview}",
943 WORKING_READ_TIMEOUT.as_secs()
944 )
945}
946
947fn force_stalled_completion(
948 state: &Arc<Mutex<SdkState>>,
949 shim_id: &str,
950) -> Option<ForcedCompletion> {
951 let mut st = state.lock().unwrap();
952 if st.state != ShimState::Working {
953 return None;
954 }
955
956 let response = stalled_mid_turn_response(
957 st.last_sent_message_from.as_deref(),
958 st.last_sent_message_preview.as_deref(),
959 );
960 let last_lines = last_n_lines_of(&response, 5);
961 let message_id = st.pending_message_id.take();
962 let previous_state = st.state;
963 let queued_message = st.message_queue.pop_front();
964
965 eprintln!(
966 "[shim-sdk {shim_id}] STALL DETECTED after {}s while working",
967 WORKING_READ_TIMEOUT.as_secs()
968 );
969
970 st.state = ShimState::Idle;
971 st.state_changed_at = Instant::now();
972 st.accumulated_response.clear();
973
974 if let Some(ref queued) = queued_message {
975 st.pending_message_id = queued.message_id.clone();
976 st.last_sent_message_from = Some(queued.from.clone());
977 st.last_sent_message_preview = Some(message_preview(&queued.body));
978 st.state = ShimState::Working;
979 st.state_changed_at = Instant::now();
980 } else {
981 st.last_sent_message_from = None;
982 st.last_sent_message_preview = None;
983 }
984
985 Some(ForcedCompletion {
986 previous_state,
987 response,
988 last_lines,
989 message_id,
990 queued_message,
991 queue_depth: st.message_queue.len(),
992 session_id: st.session_id.clone(),
993 })
994}
995
996fn emit_forced_completion<W: IoWrite>(
997 evt_channel: &mut Channel,
998 stdin_writer: &Arc<Mutex<W>>,
999 forced: ForcedCompletion,
1000) {
1001 let _ = evt_channel.send(&Event::StateChanged {
1002 from: forced.previous_state,
1003 to: ShimState::Idle,
1004 summary: forced.last_lines.clone(),
1005 });
1006 let _ = evt_channel.send(&Event::Completion {
1007 message_id: forced.message_id,
1008 response: forced.response,
1009 last_lines: forced.last_lines,
1010 });
1011
1012 if let Some(qm) = forced.queued_message {
1013 let text = format_injected_message(&qm.from, &qm.body);
1014 let user_msg = SdkUserMessage::new(&forced.session_id, &text);
1015 let ndjson = user_msg.to_ndjson();
1016 if let Ok(mut writer) = stdin_writer.lock() {
1017 let _ = writeln!(writer, "{ndjson}");
1018 let _ = writer.flush();
1019 }
1020 let _ = evt_channel.send(&Event::StateChanged {
1021 from: ShimState::Idle,
1022 to: ShimState::Working,
1023 summary: format!(
1024 "delivering queued message ({} remaining)",
1025 forced.queue_depth
1026 ),
1027 });
1028 }
1029}
1030
1031#[cfg(test)]
1036mod tests {
1037 use super::*;
1038 use crate::shim::protocol;
1039
1040 #[test]
1041 fn last_n_lines_basic() {
1042 let text = "a\nb\nc\nd\ne";
1043 assert_eq!(last_n_lines_of(text, 3), "c\nd\ne");
1044 assert_eq!(last_n_lines_of(text, 10), "a\nb\nc\nd\ne");
1045 assert_eq!(last_n_lines_of(text, 0), "");
1046 }
1047
1048 #[test]
1049 fn last_n_lines_empty() {
1050 assert_eq!(last_n_lines_of("", 5), "");
1051 }
1052
1053 #[test]
1054 fn sdk_state_initial_values() {
1055 let st = SdkState {
1056 state: ShimState::Idle,
1057 state_changed_at: Instant::now(),
1058 started_at: Instant::now(),
1059 session_id: String::new(),
1060 accumulated_response: String::new(),
1061 pending_message_id: None,
1062 last_sent_message_from: None,
1063 last_sent_message_preview: None,
1064 last_model_name: None,
1065 message_queue: VecDeque::new(),
1066 cumulative_output_bytes: 0,
1067 model: None,
1068 test_failure_iterations: 0,
1069 cumulative_input_tokens: 0,
1070 cumulative_output_tokens: 0,
1071 context_usage_pct: None,
1072 };
1073 assert_eq!(st.state, ShimState::Idle);
1074 assert!(st.session_id.is_empty());
1075 assert!(st.message_queue.is_empty());
1076 }
1077
1078 #[test]
1081 fn user_message_ndjson_format() {
1082 let msg = SdkUserMessage::new("sess-abc", "Fix the bug");
1083 let json: serde_json::Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
1084 assert_eq!(json["type"], "user");
1085 assert_eq!(json["session_id"], "sess-abc");
1086 assert_eq!(json["message"]["role"], "user");
1087 assert_eq!(json["message"]["content"], "Fix the bug");
1088 }
1089
1090 #[test]
1092 fn channel_round_trip_events() {
1093 let (parent_sock, child_sock) = protocol::socketpair().unwrap();
1094 let mut parent = protocol::Channel::new(parent_sock);
1095 let mut child = protocol::Channel::new(child_sock);
1096
1097 child.send(&Event::Ready).unwrap();
1098 let event: Event = parent.recv().unwrap().unwrap();
1099 assert!(matches!(event, Event::Ready));
1100
1101 child
1102 .send(&Event::Completion {
1103 message_id: Some("m1".into()),
1104 response: "done".into(),
1105 last_lines: "done".into(),
1106 })
1107 .unwrap();
1108 let event: Event = parent.recv().unwrap().unwrap();
1109 match event {
1110 Event::Completion {
1111 message_id,
1112 response,
1113 ..
1114 } => {
1115 assert_eq!(message_id.as_deref(), Some("m1"));
1116 assert_eq!(response, "done");
1117 }
1118 _ => panic!("expected Completion"),
1119 }
1120 }
1121
1122 #[test]
1124 fn context_exhaustion_from_errors() {
1125 assert!(common::detect_context_exhausted("context window exceeded"));
1126 assert!(common::detect_context_exhausted(
1127 "Error: the conversation is too long"
1128 ));
1129 assert!(!common::detect_context_exhausted("all good"));
1130 }
1131
1132 #[test]
1133 fn message_preview_normalizes_and_truncates() {
1134 let preview = message_preview("hello\n\nthere world");
1135 assert_eq!(preview, "hello there world");
1136
1137 let long = "x".repeat(MESSAGE_PREVIEW_LIMIT + 10);
1138 let truncated = message_preview(&long);
1139 assert!(truncated.ends_with("..."));
1140 assert!(truncated.len() > MESSAGE_PREVIEW_LIMIT);
1141 }
1142
1143 #[test]
1144 fn model_context_limit_tokens_detects_one_million_alias() {
1145 assert_eq!(
1146 model_context_limit_tokens("claude-opus-4-6-1m"),
1147 Some(1_000_000)
1148 );
1149 assert_eq!(
1150 model_context_limit_tokens("claude-sonnet-4-6"),
1151 Some(200_000)
1152 );
1153 assert_eq!(model_context_limit_tokens("gpt-5.4"), None);
1154 }
1155
1156 #[test]
1157 fn model_context_usage_pct_includes_cache_tokens() {
1158 assert_eq!(
1159 model_context_usage_pct(Some("claude-sonnet-4-6"), 180_000),
1160 Some(90)
1161 );
1162 assert_eq!(
1163 model_context_usage_pct(Some("claude-opus-4-6-1m"), 500_000),
1164 Some(50)
1165 );
1166 }
1167
1168 #[test]
1169 fn stalled_mid_turn_response_includes_tracked_message_context() {
1170 let response = stalled_mid_turn_response(Some("manager"), Some("continue task 496"));
1171 assert!(response.starts_with(STALLED_MID_TURN_MARKER));
1172 assert!(response.contains("last_sent_message_from: manager"));
1173 assert!(response.contains("last_sent_message_preview: continue task 496"));
1174 }
1175
1176 #[test]
1177 fn force_stalled_completion_releases_working_turn() {
1178 let state = Arc::new(Mutex::new(SdkState {
1179 state: ShimState::Working,
1180 state_changed_at: Instant::now(),
1181 started_at: Instant::now(),
1182 session_id: "sess-1".into(),
1183 accumulated_response: "partial output".into(),
1184 pending_message_id: Some("msg-1".into()),
1185 last_sent_message_from: Some("manager".into()),
1186 last_sent_message_preview: Some("continue task".into()),
1187 last_model_name: Some("claude-sonnet-4-5".into()),
1188 message_queue: VecDeque::new(),
1189 cumulative_output_bytes: 12,
1190 model: None,
1191 test_failure_iterations: 0,
1192 cumulative_input_tokens: 0,
1193 cumulative_output_tokens: 0,
1194 context_usage_pct: None,
1195 }));
1196
1197 let forced = force_stalled_completion(&state, "sdk-test").expect("forced completion");
1198 assert_eq!(forced.previous_state, ShimState::Working);
1199 assert_eq!(forced.message_id.as_deref(), Some("msg-1"));
1200 assert!(forced.response.starts_with(STALLED_MID_TURN_MARKER));
1201
1202 let st = state.lock().unwrap();
1203 assert_eq!(st.state, ShimState::Idle);
1204 assert!(st.pending_message_id.is_none());
1205 assert!(st.accumulated_response.is_empty());
1206 assert!(st.last_sent_message_from.is_none());
1207 assert!(st.last_sent_message_preview.is_none());
1208 }
1209
1210 #[test]
1211 fn force_stalled_completion_promotes_queued_message() {
1212 let state = Arc::new(Mutex::new(SdkState {
1213 state: ShimState::Working,
1214 state_changed_at: Instant::now(),
1215 started_at: Instant::now(),
1216 session_id: "sess-2".into(),
1217 accumulated_response: String::new(),
1218 pending_message_id: Some("msg-1".into()),
1219 last_sent_message_from: Some("manager".into()),
1220 last_sent_message_preview: Some("first".into()),
1221 last_model_name: Some("claude-sonnet-4-5".into()),
1222 message_queue: VecDeque::from([QueuedMessage {
1223 from: "architect".into(),
1224 body: "second message".into(),
1225 message_id: Some("msg-2".into()),
1226 }]),
1227 cumulative_output_bytes: 0,
1228 model: None,
1229 test_failure_iterations: 0,
1230 cumulative_input_tokens: 0,
1231 cumulative_output_tokens: 0,
1232 context_usage_pct: None,
1233 }));
1234
1235 let forced = force_stalled_completion(&state, "sdk-test").expect("forced completion");
1236 assert!(forced.queued_message.is_some());
1237 assert_eq!(forced.queue_depth, 0);
1238
1239 let st = state.lock().unwrap();
1240 assert_eq!(st.state, ShimState::Working);
1241 assert_eq!(st.pending_message_id.as_deref(), Some("msg-2"));
1242 assert_eq!(st.last_sent_message_from.as_deref(), Some("architect"));
1243 assert_eq!(
1244 st.last_sent_message_preview.as_deref(),
1245 Some("second message")
1246 );
1247 }
1248
1249 #[test]
1250 fn keepalive_is_skipped_before_interval() {
1251 let state = Arc::new(Mutex::new(SdkState {
1252 state: ShimState::Idle,
1253 state_changed_at: Instant::now(),
1254 started_at: Instant::now(),
1255 session_id: "sess-1".into(),
1256 accumulated_response: String::new(),
1257 pending_message_id: None,
1258 last_sent_message_from: None,
1259 last_sent_message_preview: None,
1260 last_model_name: None,
1261 message_queue: VecDeque::new(),
1262 cumulative_output_bytes: 0,
1263 model: None,
1264 test_failure_iterations: 0,
1265 cumulative_input_tokens: 0,
1266 cumulative_output_tokens: 0,
1267 context_usage_pct: None,
1268 }));
1269 let writer = Arc::new(Mutex::new(Vec::<u8>::new()));
1270 let mut last_keepalive = Instant::now();
1271
1272 maybe_send_keepalive(&state, &writer, &mut last_keepalive);
1273
1274 assert!(writer.lock().unwrap().is_empty());
1275 assert_eq!(state.lock().unwrap().state, ShimState::Idle);
1276 }
1277
1278 #[test]
1279 fn keepalive_sends_message_after_interval() {
1280 let state = Arc::new(Mutex::new(SdkState {
1281 state: ShimState::Idle,
1282 state_changed_at: Instant::now(),
1283 started_at: Instant::now(),
1284 session_id: "sess-1".into(),
1285 accumulated_response: "stale output".into(),
1286 pending_message_id: None,
1287 last_sent_message_from: None,
1288 last_sent_message_preview: None,
1289 last_model_name: None,
1290 message_queue: VecDeque::new(),
1291 cumulative_output_bytes: 0,
1292 model: None,
1293 test_failure_iterations: 0,
1294 cumulative_input_tokens: 0,
1295 cumulative_output_tokens: 0,
1296 context_usage_pct: None,
1297 }));
1298 let writer = Arc::new(Mutex::new(Vec::<u8>::new()));
1299 let mut last_keepalive = Instant::now() - Duration::from_secs(SDK_KEEPALIVE_IDLE_SECS + 1);
1300
1301 maybe_send_keepalive(&state, &writer, &mut last_keepalive);
1302
1303 let output = String::from_utf8(writer.lock().unwrap().clone()).unwrap();
1304 assert!(output.contains("\"type\":\"user\""));
1305 assert!(output.contains("\"session_id\":\"sess-1\""));
1306 assert!(output.contains(SDK_KEEPALIVE_MESSAGE));
1307
1308 let st = state.lock().unwrap();
1309 assert_eq!(st.state, ShimState::Working);
1310 assert!(st.accumulated_response.is_empty());
1311 }
1312
1313 #[test]
1314 fn keepalive_is_skipped_without_session() {
1315 let state = Arc::new(Mutex::new(SdkState {
1316 state: ShimState::Idle,
1317 state_changed_at: Instant::now(),
1318 started_at: Instant::now(),
1319 session_id: String::new(),
1320 accumulated_response: String::new(),
1321 pending_message_id: None,
1322 last_sent_message_from: None,
1323 last_sent_message_preview: None,
1324 last_model_name: None,
1325 message_queue: VecDeque::new(),
1326 cumulative_output_bytes: 0,
1327 model: None,
1328 test_failure_iterations: 0,
1329 cumulative_input_tokens: 0,
1330 cumulative_output_tokens: 0,
1331 context_usage_pct: None,
1332 }));
1333 let writer = Arc::new(Mutex::new(Vec::<u8>::new()));
1334 let mut last_keepalive = Instant::now() - Duration::from_secs(SDK_KEEPALIVE_IDLE_SECS + 1);
1335
1336 maybe_send_keepalive(&state, &writer, &mut last_keepalive);
1337
1338 assert!(writer.lock().unwrap().is_empty());
1339 assert_eq!(state.lock().unwrap().state, ShimState::Idle);
1340 }
1341
1342 #[test]
1343 fn proactive_context_warning_uses_model_aware_limits_and_cache_tokens() {
1344 let line = r#"{"type":"result","usage":{"input_tokens":100000,"cached_input_tokens":15000,"cache_creation_input_tokens":10000,"cache_read_input_tokens":5000,"output_tokens":20000,"reasoning_output_tokens":10000}}"#;
1345 let msg: SdkOutput = serde_json::from_str(line).unwrap();
1346 let warning =
1347 proactive_context_warning(&msg, Some("claude-sonnet-4-5"), 42_000, 900).unwrap();
1348
1349 assert_eq!(warning.context_limit_tokens, 200_000);
1350 assert_eq!(warning.used_tokens, 160_000);
1351 assert_eq!(warning.usage_pct, 80);
1352 assert_eq!(warning.model.as_deref(), Some("claude-sonnet-4-5"));
1353 }
1354
1355 #[test]
1356 fn proactive_context_warning_uses_one_million_limit_for_opus_1m() {
1357 let line = r#"{"type":"result","usage":{"input_tokens":700000,"cached_input_tokens":50000,"cache_creation_input_tokens":20000,"cache_read_input_tokens":10000,"output_tokens":10000,"reasoning_output_tokens":10000}}"#;
1358 let msg: SdkOutput = serde_json::from_str(line).unwrap();
1359 let warning =
1360 proactive_context_warning(&msg, Some("claude-opus-4.6-1m"), 42_000, 900).unwrap();
1361
1362 assert_eq!(warning.context_limit_tokens, 1_000_000);
1363 assert_eq!(warning.used_tokens, 800_000);
1364 assert_eq!(warning.usage_pct, 80);
1365 }
1366
1367 #[test]
1368 fn proactive_context_warning_skips_usage_below_threshold() {
1369 let line = r#"{"type":"result","usage":{"input_tokens":20000,"cached_input_tokens":1000,"output_tokens":4000}}"#;
1370 let msg: SdkOutput = serde_json::from_str(line).unwrap();
1371 assert!(proactive_context_warning(&msg, Some("claude-sonnet-4-5"), 10_000, 120).is_none());
1372 }
1373}