1use std::collections::VecDeque;
8use std::io::{BufRead, BufReader, Write as IoWrite};
9use std::process::{Child, Command, Stdio};
10use std::sync::{Arc, Mutex};
11use std::thread;
12use std::time::{Duration, Instant};
13
14use anyhow::{Context, Result};
15
16use super::common::{
17 self, MAX_QUEUE_DEPTH, QueuedMessage, SESSION_STATS_INTERVAL_SECS, drain_queue_errors,
18 format_injected_message,
19};
20use super::protocol::{Channel, Command as ShimCommand, Event, ShimState};
21use super::pty_log::PtyLogWriter;
22use super::runtime::ShimArgs;
23use super::sdk_types::{self, SdkControlResponse, SdkOutput, SdkUserMessage};
24
/// Interval, in milliseconds, between `try_wait` polls while waiting for the
/// agent process to exit (shutdown wait loop and `terminate_child`).
const PROCESS_EXIT_POLL_MS: u64 = 100;
/// Seconds `terminate_child` waits after SIGTERM before force-killing the agent.
const GROUP_TERM_GRACE_SECS: u64 = 2;
31
32struct SdkState {
37 state: ShimState,
38 state_changed_at: Instant,
39 started_at: Instant,
40 session_id: String,
42 accumulated_response: String,
44 pending_message_id: Option<String>,
46 message_queue: VecDeque<QueuedMessage>,
48 cumulative_output_bytes: u64,
50}
51
52pub fn run_sdk(args: ShimArgs, channel: Channel) -> Result<()> {
61 let mut child = Command::new("bash")
63 .args(["-lc", &args.cmd])
64 .current_dir(&args.cwd)
65 .stdin(Stdio::piped())
66 .stdout(Stdio::piped())
67 .stderr(Stdio::piped())
68 .env_remove("CLAUDECODE") .spawn()
70 .with_context(|| format!("[shim-sdk {}] failed to spawn agent", args.id))?;
71
72 let child_pid = child.id();
73 eprintln!(
74 "[shim-sdk {}] spawned agent subprocess (pid {})",
75 args.id, child_pid
76 );
77
78 let child_stdin = child.stdin.take().context("failed to take child stdin")?;
79 let child_stdout = child.stdout.take().context("failed to take child stdout")?;
80 let child_stderr = child.stderr.take().context("failed to take child stderr")?;
81
82 let state = Arc::new(Mutex::new(SdkState {
84 state: ShimState::Idle, state_changed_at: Instant::now(),
86 started_at: Instant::now(),
87 session_id: String::new(),
88 accumulated_response: String::new(),
89 pending_message_id: None,
90 message_queue: VecDeque::new(),
91 cumulative_output_bytes: 0,
92 }));
93
94 let stdin_writer = Arc::new(Mutex::new(child_stdin));
96
97 let pty_log: Option<Arc<Mutex<PtyLogWriter>>> = args
99 .pty_log_path
100 .as_deref()
101 .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
102 .transpose()?
103 .map(|w| Arc::new(Mutex::new(w)));
104
105 let mut cmd_channel = channel;
107 let mut evt_channel = cmd_channel
108 .try_clone()
109 .context("failed to clone channel for stdout reader")?;
110
111 cmd_channel.send(&Event::Ready)?;
113
114 let state_stdout = Arc::clone(&state);
116 let stdin_for_approve = Arc::clone(&stdin_writer);
117 let pty_log_stdout = pty_log.clone();
118 let shim_id = args.id.clone();
119 let stdout_handle = thread::spawn(move || {
120 let reader = BufReader::new(child_stdout);
121 for line_result in reader.lines() {
122 let line = match line_result {
123 Ok(l) => l,
124 Err(e) => {
125 eprintln!("[shim-sdk {shim_id}] stdout read error: {e}");
126 break;
127 }
128 };
129
130 if line.trim().is_empty() {
131 continue;
132 }
133
134 let msg: SdkOutput = match serde_json::from_str(&line) {
135 Ok(m) => m,
136 Err(e) => {
137 eprintln!("[shim-sdk {shim_id}] ignoring unparseable NDJSON line: {e}");
138 continue;
139 }
140 };
141
142 match msg.msg_type.as_str() {
143 "assistant" => {
144 if let Some(ref message) = msg.message {
146 let text = sdk_types::extract_assistant_text(message);
147 if !text.is_empty() {
148 let mut st = state_stdout.lock().unwrap();
149 st.accumulated_response.push_str(&text);
150 st.cumulative_output_bytes += text.len() as u64;
151
152 if st.session_id.is_empty() {
154 if let Some(ref sid) = msg.session_id {
155 st.session_id = sid.clone();
156 }
157 }
158 drop(st);
159
160 if let Some(ref log) = pty_log_stdout {
162 let _ = log.lock().unwrap().write(text.as_bytes());
163 }
164 }
165 }
166 }
167
168 "stream_event" => {
169 if let Some(ref event) = msg.event {
171 if let Some(text) = sdk_types::extract_stream_text(event) {
172 let mut st = state_stdout.lock().unwrap();
173 st.accumulated_response.push_str(&text);
174 st.cumulative_output_bytes += text.len() as u64;
175
176 if st.session_id.is_empty() {
177 if let Some(ref sid) = msg.session_id {
178 st.session_id = sid.clone();
179 }
180 }
181 drop(st);
182
183 if let Some(ref log) = pty_log_stdout {
184 let _ = log.lock().unwrap().write(text.as_bytes());
185 }
186 }
187 }
188 }
189
190 "control_request" => {
191 if msg.request_subtype().as_deref() == Some("can_use_tool") {
193 if let (Some(req_id), Some(ref tool_use_id)) =
194 (msg.request_id.as_ref(), msg.request_tool_use_id())
195 {
196 let resp = SdkControlResponse::approve_tool(req_id, tool_use_id);
197 let ndjson = resp.to_ndjson();
198 if let Ok(mut writer) = stdin_for_approve.lock() {
199 let _ = writeln!(writer, "{ndjson}");
200 let _ = writer.flush();
201 }
202 }
203 }
204 }
205
206 "result" => {
207 let mut st = state_stdout.lock().unwrap();
208
209 if st.session_id.is_empty() {
211 if let Some(ref sid) = msg.session_id {
212 st.session_id = sid.clone();
213 }
214 }
215
216 let is_context_exhausted = msg
218 .errors
219 .as_ref()
220 .map(|errs| errs.iter().any(|e| common::detect_context_exhausted(e)))
221 .unwrap_or(false)
222 || msg
223 .result
224 .as_deref()
225 .map(common::detect_context_exhausted)
226 .unwrap_or(false);
227
228 if is_context_exhausted {
229 let last_lines = last_n_lines_of(&st.accumulated_response, 5);
230 let old = st.state;
231 st.state = ShimState::ContextExhausted;
232 st.state_changed_at = Instant::now();
233
234 let drain =
235 drain_queue_errors(&mut st.message_queue, ShimState::ContextExhausted);
236 drop(st);
237
238 let _ = evt_channel.send(&Event::StateChanged {
239 from: old,
240 to: ShimState::ContextExhausted,
241 summary: last_lines.clone(),
242 });
243 let _ = evt_channel.send(&Event::ContextExhausted {
244 message: "Agent reported context exhaustion".into(),
245 last_lines,
246 });
247 for event in drain {
248 let _ = evt_channel.send(&event);
249 }
250 continue;
251 }
252
253 let response = if st.accumulated_response.is_empty() {
255 msg.result.clone().unwrap_or_default()
256 } else {
257 std::mem::take(&mut st.accumulated_response)
258 };
259 let last_lines = last_n_lines_of(&response, 5);
260 let msg_id = st.pending_message_id.take();
261 let old = st.state;
262 st.state = ShimState::Idle;
263 st.state_changed_at = Instant::now();
264
265 let queued_msg = if !st.message_queue.is_empty() {
267 st.message_queue.pop_front()
268 } else {
269 None
270 };
271
272 if let Some(ref qm) = queued_msg {
274 st.pending_message_id = qm.message_id.clone();
275 st.state = ShimState::Working;
276 st.state_changed_at = Instant::now();
277 st.accumulated_response.clear();
278 }
279
280 let queue_depth = st.message_queue.len();
281 let session_id = st.session_id.clone();
282 drop(st);
283
284 let _ = evt_channel.send(&Event::StateChanged {
286 from: old,
287 to: ShimState::Idle,
288 summary: last_lines.clone(),
289 });
290 let _ = evt_channel.send(&Event::Completion {
291 message_id: msg_id,
292 response,
293 last_lines,
294 });
295
296 if let Some(qm) = queued_msg {
298 let text = format_injected_message(&qm.from, &qm.body);
299 let user_msg = SdkUserMessage::new(&session_id, &text);
300 let ndjson = user_msg.to_ndjson();
301 if let Ok(mut writer) = stdin_for_approve.lock() {
302 let _ = writeln!(writer, "{ndjson}");
303 let _ = writer.flush();
304 }
305 let _ = evt_channel.send(&Event::StateChanged {
306 from: ShimState::Idle,
307 to: ShimState::Working,
308 summary: format!("delivering queued message ({queue_depth} remaining)"),
309 });
310 }
311 }
312
313 _ => {
314 }
316 }
317 }
318
319 let mut st = state_stdout.lock().unwrap();
321 let last_lines = last_n_lines_of(&st.accumulated_response, 10);
322 let old = st.state;
323 st.state = ShimState::Dead;
324 st.state_changed_at = Instant::now();
325
326 let drain = drain_queue_errors(&mut st.message_queue, ShimState::Dead);
327 drop(st);
328
329 let _ = evt_channel.send(&Event::StateChanged {
330 from: old,
331 to: ShimState::Dead,
332 summary: last_lines.clone(),
333 });
334 let _ = evt_channel.send(&Event::Died {
335 exit_code: None,
336 last_lines,
337 });
338 for event in drain {
339 let _ = evt_channel.send(&event);
340 }
341 });
342
343 let shim_id_err = args.id.clone();
345 let pty_log_stderr = pty_log;
346 thread::spawn(move || {
347 let reader = BufReader::new(child_stderr);
348 for line_result in reader.lines() {
349 match line_result {
350 Ok(line) => {
351 eprintln!("[shim-sdk {shim_id_err}] stderr: {line}");
352 if let Some(ref log) = pty_log_stderr {
353 let _ = log
354 .lock()
355 .unwrap()
356 .write(format!("[stderr] {line}\n").as_bytes());
357 }
358 }
359 Err(_) => break,
360 }
361 }
362 });
363
364 let state_stats = Arc::clone(&state);
366 let mut stats_channel = cmd_channel
367 .try_clone()
368 .context("failed to clone channel for stats")?;
369 thread::spawn(move || {
370 loop {
371 thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
372 let st = state_stats.lock().unwrap();
373 if st.state == ShimState::Dead {
374 return;
375 }
376 let output_bytes = st.cumulative_output_bytes;
377 let uptime_secs = st.started_at.elapsed().as_secs();
378 drop(st);
379
380 if stats_channel
381 .send(&Event::SessionStats {
382 output_bytes,
383 uptime_secs,
384 })
385 .is_err()
386 {
387 return;
388 }
389 }
390 });
391
392 let state_cmd = Arc::clone(&state);
394 loop {
395 let cmd = match cmd_channel.recv::<ShimCommand>() {
396 Ok(Some(c)) => c,
397 Ok(None) => {
398 eprintln!(
399 "[shim-sdk {}] orchestrator disconnected, shutting down",
400 args.id
401 );
402 terminate_child(&mut child);
403 break;
404 }
405 Err(e) => {
406 eprintln!("[shim-sdk {}] channel error: {e}", args.id);
407 terminate_child(&mut child);
408 break;
409 }
410 };
411
412 match cmd {
413 ShimCommand::SendMessage {
414 from,
415 body,
416 message_id,
417 } => {
418 let mut st = state_cmd.lock().unwrap();
419 match st.state {
420 ShimState::Idle => {
421 st.pending_message_id = message_id;
422 st.accumulated_response.clear();
423 let session_id = st.session_id.clone();
424 st.state = ShimState::Working;
425 st.state_changed_at = Instant::now();
426 drop(st);
427
428 let text = format_injected_message(&from, &body);
429 let user_msg = SdkUserMessage::new(&session_id, &text);
430 let ndjson = user_msg.to_ndjson();
431
432 if let Ok(mut writer) = stdin_writer.lock() {
433 if let Err(e) = writeln!(writer, "{ndjson}") {
434 cmd_channel.send(&Event::Error {
435 command: "SendMessage".into(),
436 reason: format!("stdin write failed: {e}"),
437 })?;
438 continue;
439 }
440 let _ = writer.flush();
441 }
442
443 cmd_channel.send(&Event::StateChanged {
444 from: ShimState::Idle,
445 to: ShimState::Working,
446 summary: String::new(),
447 })?;
448 }
449 ShimState::Working => {
450 if st.message_queue.len() >= MAX_QUEUE_DEPTH {
452 let dropped = st.message_queue.pop_front();
453 let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
454 st.message_queue.push_back(QueuedMessage {
455 from,
456 body,
457 message_id,
458 });
459 let depth = st.message_queue.len();
460 drop(st);
461
462 cmd_channel.send(&Event::Error {
463 command: "SendMessage".into(),
464 reason: format!(
465 "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
466 dropped_id
467 .map(|id| format!(" (id: {id})"))
468 .unwrap_or_default(),
469 ),
470 })?;
471 cmd_channel.send(&Event::Warning {
472 message: format!(
473 "message queued while agent working (depth: {depth})"
474 ),
475 idle_secs: None,
476 })?;
477 } else {
478 st.message_queue.push_back(QueuedMessage {
479 from,
480 body,
481 message_id,
482 });
483 let depth = st.message_queue.len();
484 drop(st);
485
486 cmd_channel.send(&Event::Warning {
487 message: format!(
488 "message queued while agent working (depth: {depth})"
489 ),
490 idle_secs: None,
491 })?;
492 }
493 }
494 other => {
495 drop(st);
496 cmd_channel.send(&Event::Error {
497 command: "SendMessage".into(),
498 reason: format!("agent in {other} state, cannot accept message"),
499 })?;
500 }
501 }
502 }
503
504 ShimCommand::CaptureScreen { last_n_lines } => {
505 let st = state_cmd.lock().unwrap();
506 let content = match last_n_lines {
507 Some(n) => last_n_lines_of(&st.accumulated_response, n),
508 None => st.accumulated_response.clone(),
509 };
510 drop(st);
511 cmd_channel.send(&Event::ScreenCapture {
512 content,
513 cursor_row: 0,
514 cursor_col: 0,
515 })?;
516 }
517
518 ShimCommand::GetState => {
519 let st = state_cmd.lock().unwrap();
520 let since = st.state_changed_at.elapsed().as_secs();
521 let state = st.state;
522 drop(st);
523 cmd_channel.send(&Event::State {
524 state,
525 since_secs: since,
526 })?;
527 }
528
529 ShimCommand::Resize { .. } => {
530 }
532
533 ShimCommand::Ping => {
534 cmd_channel.send(&Event::Pong)?;
535 }
536
537 ShimCommand::Shutdown { timeout_secs } => {
538 eprintln!(
539 "[shim-sdk {}] shutdown requested (timeout: {}s)",
540 args.id, timeout_secs
541 );
542 drop(stdin_writer);
544
545 let deadline = Instant::now() + Duration::from_secs(timeout_secs as u64);
546 loop {
547 if Instant::now() > deadline {
548 terminate_child(&mut child);
549 break;
550 }
551 match child.try_wait() {
552 Ok(Some(_)) => break,
553 _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
554 }
555 }
556 break;
557 }
558
559 ShimCommand::Kill => {
560 terminate_child(&mut child);
561 break;
562 }
563 }
564 }
565
566 stdout_handle.join().ok();
567 Ok(())
568}
569
570fn terminate_child(child: &mut Child) {
576 let pid = child.id();
577
578 #[cfg(unix)]
579 {
580 unsafe {
581 libc::kill(pid as i32, libc::SIGTERM);
582 }
583 let deadline = Instant::now() + Duration::from_secs(GROUP_TERM_GRACE_SECS);
584 loop {
585 if Instant::now() > deadline {
586 break;
587 }
588 match child.try_wait() {
589 Ok(Some(_)) => return,
590 _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
591 }
592 }
593 unsafe {
594 libc::kill(pid as i32, libc::SIGKILL);
595 }
596 }
597
598 #[allow(unreachable_code)]
599 {
600 let _ = child.kill();
601 }
602}
603
/// Return the final `n` lines of `text`, joined with `\n`.
///
/// Lines are split with `str::lines`, so `\r\n` endings are normalised and a
/// trailing newline does not yield an empty final line. The result is empty
/// when `n == 0` or `text` is empty.
fn last_n_lines_of(text: &str, n: usize) -> String {
    // Walk from the end so only the requested tail is collected.
    let mut tail: Vec<&str> = text.lines().rev().take(n).collect();
    tail.reverse();
    tail.join("\n")
}
610
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shim::protocol;

    #[test]
    fn last_n_lines_basic() {
        let text = "a\nb\nc\nd\ne";
        let cases = [(3, "c\nd\ne"), (10, "a\nb\nc\nd\ne"), (0, "")];
        for (n, expected) in cases {
            assert_eq!(last_n_lines_of(text, n), expected);
        }
    }

    #[test]
    fn last_n_lines_empty() {
        let tail = last_n_lines_of("", 5);
        assert_eq!(tail, "");
    }

    #[test]
    fn sdk_state_initial_values() {
        let now = Instant::now();
        let fresh = SdkState {
            state: ShimState::Idle,
            state_changed_at: now,
            started_at: Instant::now(),
            session_id: String::new(),
            accumulated_response: String::new(),
            pending_message_id: None,
            message_queue: VecDeque::new(),
            cumulative_output_bytes: 0,
        };
        assert_eq!(fresh.state, ShimState::Idle);
        assert!(fresh.session_id.is_empty());
        assert!(fresh.message_queue.is_empty());
    }

    /// A user message serialises to the NDJSON shape the SDK expects.
    #[test]
    fn user_message_ndjson_format() {
        let ndjson = SdkUserMessage::new("sess-abc", "Fix the bug").to_ndjson();
        let parsed: serde_json::Value = serde_json::from_str(&ndjson).unwrap();
        let message = &parsed["message"];
        assert_eq!(parsed["type"], "user");
        assert_eq!(parsed["session_id"], "sess-abc");
        assert_eq!(message["role"], "user");
        assert_eq!(message["content"], "Fix the bug");
    }

    /// Events survive a trip across a socketpair-backed channel.
    #[test]
    fn channel_round_trip_events() {
        let (orchestrator_sock, shim_sock) = protocol::socketpair().unwrap();
        let mut orchestrator = protocol::Channel::new(orchestrator_sock);
        let mut shim = protocol::Channel::new(shim_sock);

        shim.send(&Event::Ready).unwrap();
        let ready: Event = orchestrator.recv().unwrap().unwrap();
        assert!(matches!(ready, Event::Ready));

        let completion = Event::Completion {
            message_id: Some("m1".into()),
            response: "done".into(),
            last_lines: "done".into(),
        };
        shim.send(&completion).unwrap();
        let received: Event = orchestrator.recv().unwrap().unwrap();
        if let Event::Completion {
            message_id,
            response,
            ..
        } = received
        {
            assert_eq!(message_id.as_deref(), Some("m1"));
            assert_eq!(response, "done");
        } else {
            panic!("expected Completion");
        }
    }

    /// Context-exhaustion detection recognises known phrasings only.
    #[test]
    fn context_exhaustion_from_errors() {
        let exhausted = [
            "context window exceeded",
            "Error: the conversation is too long",
        ];
        for text in exhausted {
            assert!(common::detect_context_exhausted(text));
        }
        assert!(!common::detect_context_exhausted("all good"));
    }
}