Skip to main content

batty_cli/shim/
runtime_sdk.rs

//! SDK-mode shim runtime: communicates with Claude Code via NDJSON on
//! stdin/stdout instead of screen-scraping a PTY.
//!
//! Emits the same `Command`/`Event` protocol to the orchestrator as the
//! PTY runtime (`runtime.rs`), making it transparent to all upstream consumers.
6
7use std::collections::VecDeque;
8use std::io::{BufRead, BufReader, Write as IoWrite};
9use std::process::{Child, Command, Stdio};
10use std::sync::{Arc, Mutex};
11use std::thread;
12use std::time::{Duration, Instant};
13
14use anyhow::{Context, Result};
15
16use super::common::{
17    self, MAX_QUEUE_DEPTH, QueuedMessage, SESSION_STATS_INTERVAL_SECS, drain_queue_errors,
18    format_injected_message,
19};
20use super::protocol::{Channel, Command as ShimCommand, Event, ShimState};
21use super::pty_log::PtyLogWriter;
22use super::runtime::ShimArgs;
23use super::sdk_types::{self, SdkControlResponse, SdkOutput, SdkUserMessage};
24
25// ---------------------------------------------------------------------------
26// Configuration
27// ---------------------------------------------------------------------------
28
/// Sleep between `try_wait` polls while waiting for the agent subprocess to
/// exit, in milliseconds.
29const PROCESS_EXIT_POLL_MS: u64 = 100;
/// Seconds `terminate_child` waits after SIGTERM before escalating to a
/// forced kill.
30const GROUP_TERM_GRACE_SECS: u64 = 2;
31
32// ---------------------------------------------------------------------------
33// Shared state
34// ---------------------------------------------------------------------------
35
/// Mutable shim state. `run_sdk` wraps a single instance in `Arc<Mutex<_>>`
/// and shares it between the command loop, the stdout reader thread, and the
/// session-stats thread.
36struct SdkState {
    /// Lifecycle state of the agent as tracked by the shim.
37    state: ShimState,
    /// When `state` was last changed; reported as `since_secs` for `GetState`.
38    state_changed_at: Instant,
    /// When this state was created; basis for `uptime_secs` in session stats.
39    started_at: Instant,
40    /// Session ID returned by Claude Code in its first response.
41    session_id: String,
42    /// Accumulated assistant response text for the current turn.
43    accumulated_response: String,
44    /// Message ID of the currently pending (in-flight) message.
45    pending_message_id: Option<String>,
46    /// Messages queued while the agent is in Working state.
47    message_queue: VecDeque<QueuedMessage>,
48    /// Total bytes of response text received.
49    cumulative_output_bytes: u64,
50}
51
52// ---------------------------------------------------------------------------
53// Main entry point
54// ---------------------------------------------------------------------------
55
56/// Run the SDK-mode shim. This function does not return until the shim exits.
57///
58/// `channel` is the pre-connected socket to the orchestrator (fd 3 or socketpair).
59/// `args.cmd` must be a shell command that launches Claude Code in stream-json mode.
60pub fn run_sdk(args: ShimArgs, channel: Channel) -> Result<()> {
61    // -- Spawn subprocess with piped stdin/stdout/stderr --
62    let mut child = Command::new("bash")
63        .args(["-lc", &args.cmd])
64        .current_dir(&args.cwd)
65        .stdin(Stdio::piped())
66        .stdout(Stdio::piped())
67        .stderr(Stdio::piped())
68        .env_remove("CLAUDECODE") // prevent nested detection
69        .spawn()
70        .with_context(|| format!("[shim-sdk {}] failed to spawn agent", args.id))?;
71
72    let child_pid = child.id();
73    eprintln!(
74        "[shim-sdk {}] spawned agent subprocess (pid {})",
75        args.id, child_pid
76    );
77
78    let child_stdin = child.stdin.take().context("failed to take child stdin")?;
79    let child_stdout = child.stdout.take().context("failed to take child stdout")?;
80    let child_stderr = child.stderr.take().context("failed to take child stderr")?;
81
82    // Shared state
83    let state = Arc::new(Mutex::new(SdkState {
84        state: ShimState::Idle, // SDK mode is immediately ready
85        state_changed_at: Instant::now(),
86        started_at: Instant::now(),
87        session_id: String::new(),
88        accumulated_response: String::new(),
89        pending_message_id: None,
90        message_queue: VecDeque::new(),
91        cumulative_output_bytes: 0,
92    }));
93
94    // Shared stdin writer (used by both command loop and stdout reader for auto-approve)
95    let stdin_writer = Arc::new(Mutex::new(child_stdin));
96
97    // -- PTY log writer (optional — writes readable text, not raw NDJSON) --
98    let pty_log: Option<Arc<Mutex<PtyLogWriter>>> = args
99        .pty_log_path
100        .as_deref()
101        .map(|p| PtyLogWriter::new(p).context("failed to create PTY log"))
102        .transpose()?
103        .map(|w| Arc::new(Mutex::new(w)));
104
105    // -- Channel clones for threads --
106    let mut cmd_channel = channel;
107    let mut evt_channel = cmd_channel
108        .try_clone()
109        .context("failed to clone channel for stdout reader")?;
110
111    // Emit Ready immediately — Claude -p mode accepts input on stdin right away.
112    cmd_channel.send(&Event::Ready)?;
113
114    // -- stdout reader thread --
115    let state_stdout = Arc::clone(&state);
116    let stdin_for_approve = Arc::clone(&stdin_writer);
117    let pty_log_stdout = pty_log.clone();
118    let shim_id = args.id.clone();
119    let stdout_handle = thread::spawn(move || {
120        let reader = BufReader::new(child_stdout);
121        for line_result in reader.lines() {
122            let line = match line_result {
123                Ok(l) => l,
124                Err(e) => {
125                    eprintln!("[shim-sdk {shim_id}] stdout read error: {e}");
126                    break;
127                }
128            };
129
130            if line.trim().is_empty() {
131                continue;
132            }
133
134            let msg: SdkOutput = match serde_json::from_str(&line) {
135                Ok(m) => m,
136                Err(e) => {
137                    eprintln!("[shim-sdk {shim_id}] ignoring unparseable NDJSON line: {e}");
138                    continue;
139                }
140            };
141
142            match msg.msg_type.as_str() {
143                "assistant" => {
144                    // Extract text from the assistant message
145                    if let Some(ref message) = msg.message {
146                        let text = sdk_types::extract_assistant_text(message);
147                        if !text.is_empty() {
148                            let mut st = state_stdout.lock().unwrap();
149                            st.accumulated_response.push_str(&text);
150                            st.cumulative_output_bytes += text.len() as u64;
151
152                            // Update session_id from first response
153                            if st.session_id.is_empty() {
154                                if let Some(ref sid) = msg.session_id {
155                                    st.session_id = sid.clone();
156                                }
157                            }
158                            drop(st);
159
160                            // Write to PTY log for tmux display
161                            if let Some(ref log) = pty_log_stdout {
162                                let _ = log.lock().unwrap().write(text.as_bytes());
163                            }
164                        }
165                    }
166                }
167
168                "stream_event" => {
169                    // Extract incremental text delta
170                    if let Some(ref event) = msg.event {
171                        if let Some(text) = sdk_types::extract_stream_text(event) {
172                            let mut st = state_stdout.lock().unwrap();
173                            st.accumulated_response.push_str(&text);
174                            st.cumulative_output_bytes += text.len() as u64;
175
176                            if st.session_id.is_empty() {
177                                if let Some(ref sid) = msg.session_id {
178                                    st.session_id = sid.clone();
179                                }
180                            }
181                            drop(st);
182
183                            if let Some(ref log) = pty_log_stdout {
184                                let _ = log.lock().unwrap().write(text.as_bytes());
185                            }
186                        }
187                    }
188                }
189
190                "control_request" => {
191                    // Auto-approve tool use requests
192                    if msg.request_subtype().as_deref() == Some("can_use_tool") {
193                        if let (Some(req_id), Some(ref tool_use_id)) =
194                            (msg.request_id.as_ref(), msg.request_tool_use_id())
195                        {
196                            let resp = SdkControlResponse::approve_tool(req_id, tool_use_id);
197                            let ndjson = resp.to_ndjson();
198                            if let Ok(mut writer) = stdin_for_approve.lock() {
199                                let _ = writeln!(writer, "{ndjson}");
200                                let _ = writer.flush();
201                            }
202                        }
203                    }
204                }
205
206                "result" => {
207                    let mut st = state_stdout.lock().unwrap();
208
209                    // Capture session_id
210                    if st.session_id.is_empty() {
211                        if let Some(ref sid) = msg.session_id {
212                            st.session_id = sid.clone();
213                        }
214                    }
215
216                    // Check for context exhaustion
217                    let is_context_exhausted = msg
218                        .errors
219                        .as_ref()
220                        .map(|errs| errs.iter().any(|e| common::detect_context_exhausted(e)))
221                        .unwrap_or(false)
222                        || msg
223                            .result
224                            .as_deref()
225                            .map(common::detect_context_exhausted)
226                            .unwrap_or(false);
227
228                    if is_context_exhausted {
229                        let last_lines = last_n_lines_of(&st.accumulated_response, 5);
230                        let old = st.state;
231                        st.state = ShimState::ContextExhausted;
232                        st.state_changed_at = Instant::now();
233
234                        let drain =
235                            drain_queue_errors(&mut st.message_queue, ShimState::ContextExhausted);
236                        drop(st);
237
238                        let _ = evt_channel.send(&Event::StateChanged {
239                            from: old,
240                            to: ShimState::ContextExhausted,
241                            summary: last_lines.clone(),
242                        });
243                        let _ = evt_channel.send(&Event::ContextExhausted {
244                            message: "Agent reported context exhaustion".into(),
245                            last_lines,
246                        });
247                        for event in drain {
248                            let _ = evt_channel.send(&event);
249                        }
250                        continue;
251                    }
252
253                    // Normal completion: Working → Idle
254                    let response = if st.accumulated_response.is_empty() {
255                        msg.result.clone().unwrap_or_default()
256                    } else {
257                        std::mem::take(&mut st.accumulated_response)
258                    };
259                    let last_lines = last_n_lines_of(&response, 5);
260                    let msg_id = st.pending_message_id.take();
261                    let old = st.state;
262                    st.state = ShimState::Idle;
263                    st.state_changed_at = Instant::now();
264
265                    // Check for queued messages to deliver immediately
266                    let queued_msg = if !st.message_queue.is_empty() {
267                        st.message_queue.pop_front()
268                    } else {
269                        None
270                    };
271
272                    // If injecting a queued message, stay Working
273                    if let Some(ref qm) = queued_msg {
274                        st.pending_message_id = qm.message_id.clone();
275                        st.state = ShimState::Working;
276                        st.state_changed_at = Instant::now();
277                        st.accumulated_response.clear();
278                    }
279
280                    let queue_depth = st.message_queue.len();
281                    let session_id = st.session_id.clone();
282                    drop(st);
283
284                    // Emit completion events
285                    let _ = evt_channel.send(&Event::StateChanged {
286                        from: old,
287                        to: ShimState::Idle,
288                        summary: last_lines.clone(),
289                    });
290                    let _ = evt_channel.send(&Event::Completion {
291                        message_id: msg_id,
292                        response,
293                        last_lines,
294                    });
295
296                    // Inject queued message
297                    if let Some(qm) = queued_msg {
298                        let text = format_injected_message(&qm.from, &qm.body);
299                        let user_msg = SdkUserMessage::new(&session_id, &text);
300                        let ndjson = user_msg.to_ndjson();
301                        if let Ok(mut writer) = stdin_for_approve.lock() {
302                            let _ = writeln!(writer, "{ndjson}");
303                            let _ = writer.flush();
304                        }
305                        let _ = evt_channel.send(&Event::StateChanged {
306                            from: ShimState::Idle,
307                            to: ShimState::Working,
308                            summary: format!("delivering queued message ({queue_depth} remaining)"),
309                        });
310                    }
311                }
312
313                _ => {
314                    // Silently ignore unknown message types (future-proof)
315                }
316            }
317        }
318
319        // stdout EOF — agent process closed
320        let mut st = state_stdout.lock().unwrap();
321        let last_lines = last_n_lines_of(&st.accumulated_response, 10);
322        let old = st.state;
323        st.state = ShimState::Dead;
324        st.state_changed_at = Instant::now();
325
326        let drain = drain_queue_errors(&mut st.message_queue, ShimState::Dead);
327        drop(st);
328
329        let _ = evt_channel.send(&Event::StateChanged {
330            from: old,
331            to: ShimState::Dead,
332            summary: last_lines.clone(),
333        });
334        let _ = evt_channel.send(&Event::Died {
335            exit_code: None,
336            last_lines,
337        });
338        for event in drain {
339            let _ = evt_channel.send(&event);
340        }
341    });
342
343    // -- stderr reader thread --
344    let shim_id_err = args.id.clone();
345    let pty_log_stderr = pty_log;
346    thread::spawn(move || {
347        let reader = BufReader::new(child_stderr);
348        for line_result in reader.lines() {
349            match line_result {
350                Ok(line) => {
351                    eprintln!("[shim-sdk {shim_id_err}] stderr: {line}");
352                    if let Some(ref log) = pty_log_stderr {
353                        let _ = log
354                            .lock()
355                            .unwrap()
356                            .write(format!("[stderr] {line}\n").as_bytes());
357                    }
358                }
359                Err(_) => break,
360            }
361        }
362    });
363
364    // -- Session stats thread --
365    let state_stats = Arc::clone(&state);
366    let mut stats_channel = cmd_channel
367        .try_clone()
368        .context("failed to clone channel for stats")?;
369    thread::spawn(move || {
370        loop {
371            thread::sleep(Duration::from_secs(SESSION_STATS_INTERVAL_SECS));
372            let st = state_stats.lock().unwrap();
373            if st.state == ShimState::Dead {
374                return;
375            }
376            let output_bytes = st.cumulative_output_bytes;
377            let uptime_secs = st.started_at.elapsed().as_secs();
378            drop(st);
379
380            if stats_channel
381                .send(&Event::SessionStats {
382                    output_bytes,
383                    uptime_secs,
384                })
385                .is_err()
386            {
387                return;
388            }
389        }
390    });
391
392    // -- Command loop (main thread) --
393    let state_cmd = Arc::clone(&state);
394    loop {
395        let cmd = match cmd_channel.recv::<ShimCommand>() {
396            Ok(Some(c)) => c,
397            Ok(None) => {
398                eprintln!(
399                    "[shim-sdk {}] orchestrator disconnected, shutting down",
400                    args.id
401                );
402                terminate_child(&mut child);
403                break;
404            }
405            Err(e) => {
406                eprintln!("[shim-sdk {}] channel error: {e}", args.id);
407                terminate_child(&mut child);
408                break;
409            }
410        };
411
412        match cmd {
413            ShimCommand::SendMessage {
414                from,
415                body,
416                message_id,
417            } => {
418                let mut st = state_cmd.lock().unwrap();
419                match st.state {
420                    ShimState::Idle => {
421                        st.pending_message_id = message_id;
422                        st.accumulated_response.clear();
423                        let session_id = st.session_id.clone();
424                        st.state = ShimState::Working;
425                        st.state_changed_at = Instant::now();
426                        drop(st);
427
428                        let text = format_injected_message(&from, &body);
429                        let user_msg = SdkUserMessage::new(&session_id, &text);
430                        let ndjson = user_msg.to_ndjson();
431
432                        if let Ok(mut writer) = stdin_writer.lock() {
433                            if let Err(e) = writeln!(writer, "{ndjson}") {
434                                cmd_channel.send(&Event::Error {
435                                    command: "SendMessage".into(),
436                                    reason: format!("stdin write failed: {e}"),
437                                })?;
438                                continue;
439                            }
440                            let _ = writer.flush();
441                        }
442
443                        cmd_channel.send(&Event::StateChanged {
444                            from: ShimState::Idle,
445                            to: ShimState::Working,
446                            summary: String::new(),
447                        })?;
448                    }
449                    ShimState::Working => {
450                        // Queue the message
451                        if st.message_queue.len() >= MAX_QUEUE_DEPTH {
452                            let dropped = st.message_queue.pop_front();
453                            let dropped_id = dropped.as_ref().and_then(|m| m.message_id.clone());
454                            st.message_queue.push_back(QueuedMessage {
455                                from,
456                                body,
457                                message_id,
458                            });
459                            let depth = st.message_queue.len();
460                            drop(st);
461
462                            cmd_channel.send(&Event::Error {
463                                command: "SendMessage".into(),
464                                reason: format!(
465                                    "message queue full ({MAX_QUEUE_DEPTH}), dropped oldest message{}",
466                                    dropped_id
467                                        .map(|id| format!(" (id: {id})"))
468                                        .unwrap_or_default(),
469                                ),
470                            })?;
471                            cmd_channel.send(&Event::Warning {
472                                message: format!(
473                                    "message queued while agent working (depth: {depth})"
474                                ),
475                                idle_secs: None,
476                            })?;
477                        } else {
478                            st.message_queue.push_back(QueuedMessage {
479                                from,
480                                body,
481                                message_id,
482                            });
483                            let depth = st.message_queue.len();
484                            drop(st);
485
486                            cmd_channel.send(&Event::Warning {
487                                message: format!(
488                                    "message queued while agent working (depth: {depth})"
489                                ),
490                                idle_secs: None,
491                            })?;
492                        }
493                    }
494                    other => {
495                        drop(st);
496                        cmd_channel.send(&Event::Error {
497                            command: "SendMessage".into(),
498                            reason: format!("agent in {other} state, cannot accept message"),
499                        })?;
500                    }
501                }
502            }
503
504            ShimCommand::CaptureScreen { last_n_lines } => {
505                let st = state_cmd.lock().unwrap();
506                let content = match last_n_lines {
507                    Some(n) => last_n_lines_of(&st.accumulated_response, n),
508                    None => st.accumulated_response.clone(),
509                };
510                drop(st);
511                cmd_channel.send(&Event::ScreenCapture {
512                    content,
513                    cursor_row: 0,
514                    cursor_col: 0,
515                })?;
516            }
517
518            ShimCommand::GetState => {
519                let st = state_cmd.lock().unwrap();
520                let since = st.state_changed_at.elapsed().as_secs();
521                let state = st.state;
522                drop(st);
523                cmd_channel.send(&Event::State {
524                    state,
525                    since_secs: since,
526                })?;
527            }
528
529            ShimCommand::Resize { .. } => {
530                // No-op in SDK mode — no PTY to resize.
531            }
532
533            ShimCommand::Ping => {
534                cmd_channel.send(&Event::Pong)?;
535            }
536
537            ShimCommand::Shutdown { timeout_secs } => {
538                eprintln!(
539                    "[shim-sdk {}] shutdown requested (timeout: {}s)",
540                    args.id, timeout_secs
541                );
542                // Close stdin to signal EOF to the subprocess
543                drop(stdin_writer);
544
545                let deadline = Instant::now() + Duration::from_secs(timeout_secs as u64);
546                loop {
547                    if Instant::now() > deadline {
548                        terminate_child(&mut child);
549                        break;
550                    }
551                    match child.try_wait() {
552                        Ok(Some(_)) => break,
553                        _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
554                    }
555                }
556                break;
557            }
558
559            ShimCommand::Kill => {
560                terminate_child(&mut child);
561                break;
562            }
563        }
564    }
565
566    stdout_handle.join().ok();
567    Ok(())
568}
569
570// ---------------------------------------------------------------------------
571// Helpers
572// ---------------------------------------------------------------------------
573
574/// Terminate a child process: SIGTERM, grace period, then SIGKILL.
575fn terminate_child(child: &mut Child) {
576    let pid = child.id();
577
578    #[cfg(unix)]
579    {
580        unsafe {
581            libc::kill(pid as i32, libc::SIGTERM);
582        }
583        let deadline = Instant::now() + Duration::from_secs(GROUP_TERM_GRACE_SECS);
584        loop {
585            if Instant::now() > deadline {
586                break;
587            }
588            match child.try_wait() {
589                Ok(Some(_)) => return,
590                _ => thread::sleep(Duration::from_millis(PROCESS_EXIT_POLL_MS)),
591            }
592        }
593        unsafe {
594            libc::kill(pid as i32, libc::SIGKILL);
595        }
596    }
597
598    #[allow(unreachable_code)]
599    {
600        let _ = child.kill();
601    }
602}
603
/// Return the final `n` lines of `text`, joined with `\n`.
///
/// Line terminators are those recognised by `str::lines`, so a trailing
/// newline does not produce an extra empty line. Asking for more lines than
/// exist returns all of them; `n == 0` or empty input yields an empty string.
fn last_n_lines_of(text: &str, n: usize) -> String {
    // Walk backwards from the end, keep at most `n` lines, then restore order.
    let mut tail: Vec<&str> = text.lines().rev().take(n).collect();
    tail.reverse();
    tail.join("\n")
}
610
611// ---------------------------------------------------------------------------
612// Tests
613// ---------------------------------------------------------------------------
614
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shim::protocol;

    /// `last_n_lines_of` returns the requested tail, everything when asked
    /// for more lines than exist, and nothing when asked for zero.
    #[test]
    fn last_n_lines_basic() {
        let text = "a\nb\nc\nd\ne";
        assert_eq!(last_n_lines_of(text, 3), "c\nd\ne");
        assert_eq!(last_n_lines_of(text, 10), "a\nb\nc\nd\ne");
        assert_eq!(last_n_lines_of(text, 0), "");
    }

    /// Empty input yields an empty tail.
    #[test]
    fn last_n_lines_empty() {
        assert_eq!(last_n_lines_of("", 5), "");
    }

    /// A trailing newline does not count as an extra empty line, and CRLF
    /// terminators are normalised to `\n` in the joined result.
    #[test]
    fn last_n_lines_trailing_newline_and_crlf() {
        assert_eq!(last_n_lines_of("a\nb\n", 1), "b");
        assert_eq!(last_n_lines_of("a\r\nb\r\n", 2), "a\nb");
    }

    /// A freshly constructed `SdkState` (mirroring what `run_sdk` builds)
    /// starts Idle with no session ID and an empty queue.
    #[test]
    fn sdk_state_initial_values() {
        let st = SdkState {
            state: ShimState::Idle,
            state_changed_at: Instant::now(),
            started_at: Instant::now(),
            session_id: String::new(),
            accumulated_response: String::new(),
            pending_message_id: None,
            message_queue: VecDeque::new(),
            cumulative_output_bytes: 0,
        };
        assert_eq!(st.state, ShimState::Idle);
        assert!(st.session_id.is_empty());
        assert!(st.message_queue.is_empty());
    }

    /// Verify the NDJSON shape of the user message that the command loop
    /// writes to the agent's stdin. This checks serialization only; it does
    /// not drive the command loop or the Idle → Working transition.
    #[test]
    fn user_message_ndjson_format() {
        let msg = SdkUserMessage::new("sess-abc", "Fix the bug");
        let json: serde_json::Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["type"], "user");
        assert_eq!(json["session_id"], "sess-abc");
        assert_eq!(json["message"]["role"], "user");
        assert_eq!(json["message"]["content"], "Fix the bug");
    }

    /// Verify that the protocol socketpair still works for our Event types.
    #[test]
    fn channel_round_trip_events() {
        let (parent_sock, child_sock) = protocol::socketpair().unwrap();
        let mut parent = protocol::Channel::new(parent_sock);
        let mut child = protocol::Channel::new(child_sock);

        child.send(&Event::Ready).unwrap();
        let event: Event = parent.recv().unwrap().unwrap();
        assert!(matches!(event, Event::Ready));

        child
            .send(&Event::Completion {
                message_id: Some("m1".into()),
                response: "done".into(),
                last_lines: "done".into(),
            })
            .unwrap();
        let event: Event = parent.recv().unwrap().unwrap();
        match event {
            Event::Completion {
                message_id,
                response,
                ..
            } => {
                assert_eq!(message_id.as_deref(), Some("m1"));
                assert_eq!(response, "done");
            }
            _ => panic!("expected Completion"),
        }
    }

    /// Verify the shared context-exhaustion detector on the kind of strings
    /// the `result` handler feeds it (error entries and result text).
    #[test]
    fn context_exhaustion_from_errors() {
        assert!(common::detect_context_exhausted("context window exceeded"));
        assert!(common::detect_context_exhausted(
            "Error: the conversation is too long"
        ));
        assert!(!common::detect_context_exhausted("all good"));
    }
}