Skip to main content

batty_cli/shim/
protocol.rs

//! Wire protocol: Commands (orchestrator→shim) and Events (shim→orchestrator).
//!
//! Transport: length-prefixed JSON over a Unix SOCK_STREAM socketpair.
//! Each frame is a 4-byte big-endian payload length followed by that many
//! bytes of JSON; payloads larger than `MAX_MSG` (1 MiB) are rejected.
5
6use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;
9
10// ---------------------------------------------------------------------------
11// Commands (sent TO the shim)
12// ---------------------------------------------------------------------------
13
14#[derive(Debug, Serialize, Deserialize)]
15#[serde(tag = "cmd")]
16pub enum Command {
17    SendMessage {
18        from: String,
19        body: String,
20        #[serde(skip_serializing_if = "Option::is_none")]
21        message_id: Option<String>,
22    },
23    CaptureScreen {
24        last_n_lines: Option<usize>,
25    },
26    GetState,
27    Resize {
28        rows: u16,
29        cols: u16,
30    },
31    Shutdown {
32        timeout_secs: u32,
33        #[serde(default)]
34        reason: ShutdownReason,
35    },
36    Kill,
37    Ping,
38}
39
/// Why a `Command::Shutdown` was issued.
///
/// Serialized as a snake_case string, e.g. `"restart_handoff"`. Defaults to
/// `Requested`; a `Shutdown` command whose JSON has no `reason` field decodes
/// to that default because the field is marked `#[serde(default)]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ShutdownReason {
    /// The default reason.
    #[default]
    Requested,
    RestartHandoff,
    ContextExhausted,
    TopologyChange,
    DaemonStop,
}
50
51impl ShutdownReason {
52    pub fn label(self) -> &'static str {
53        match self {
54            Self::Requested => "requested",
55            Self::RestartHandoff => "restart_handoff",
56            Self::ContextExhausted => "context_exhausted",
57            Self::TopologyChange => "topology_change",
58            Self::DaemonStop => "daemon_stop",
59        }
60    }
61}
62
63// ---------------------------------------------------------------------------
64// Events (sent FROM the shim)
65// ---------------------------------------------------------------------------
66
/// Notifications and replies sent from the shim to the orchestrator.
///
/// Serialized as an internally tagged JSON object: the variant name goes in an
/// `"event"` field next to the variant's own fields, e.g.
/// `{"event":"MessageDelivered","id":"msg-1"}`. Unit variants encode as just
/// the tag, e.g. `{"event":"Pong"}`. Unless noted otherwise, every field is
/// required when decoding, and `Option` fields are encoded as `null` when
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// The shim reports it is ready (presumably sent once after startup —
    /// confirm in the shim).
    Ready,
    /// The shim moved from state `from` to state `to`; `summary` is free text.
    StateChanged {
        from: ShimState,
        to: ShimState,
        summary: String,
    },
    /// Message `id` was delivered (presumably the `message_id` of an earlier
    /// `Command::SendMessage` — confirm).
    MessageDelivered {
        id: String,
    },
    /// A response is complete (NOTE(review): inferred from the name).
    Completion {
        /// Omitted from the JSON when `None`, and decoded as `None` when absent.
        #[serde(skip_serializing_if = "Option::is_none")]
        message_id: Option<String>,
        response: String,
        last_lines: String,
    },
    /// The agent exited. `exit_code` is `None` (encoded as `null`) when no
    /// code is available.
    Died {
        exit_code: Option<i32>,
        last_lines: String,
    },
    /// The agent's context was reported as exhausted.
    ContextExhausted {
        message: String,
        last_lines: String,
    },
    /// Detailed token-usage snapshot. All fields are required on the wire (no
    /// serde defaults). `usage_pct` is presumably `used_tokens` relative to
    /// `context_limit_tokens` as a percentage — TODO confirm in the sender.
    ContextWarning {
        model: Option<String>,
        output_bytes: u64,
        uptime_secs: u64,
        input_tokens: u64,
        cached_input_tokens: u64,
        cache_creation_input_tokens: u64,
        cache_read_input_tokens: u64,
        output_tokens: u64,
        reasoning_output_tokens: u64,
        used_tokens: u64,
        context_limit_tokens: u64,
        usage_pct: u8,
    },
    /// Context pressure was detected before exhaustion.
    ContextApproaching {
        message: String,
        input_tokens: u64,
        output_tokens: u64,
    },
    /// Screen contents plus cursor position (presumably the reply to
    /// `Command::CaptureScreen` — confirm).
    ScreenCapture {
        content: String,
        cursor_row: u16,
        cursor_col: u16,
    },
    /// Current state and how long it has held (presumably the reply to
    /// `Command::GetState` — confirm).
    State {
        state: ShimState,
        since_secs: u64,
    },
    /// Periodic session counters.
    SessionStats {
        output_bytes: u64,
        uptime_secs: u64,
        // Token counters decode as 0 when absent, so senders may omit them.
        #[serde(default)]
        input_tokens: u64,
        #[serde(default)]
        output_tokens: u64,
        // Omitted from the JSON when `None`; decodes as `None` when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        context_usage_pct: Option<u8>,
    },
    /// Reply to a liveness probe (presumably `Command::Ping`).
    Pong,
    /// Non-fatal warning. `idle_secs` is encoded as `null` when `None`.
    Warning {
        message: String,
        idle_secs: Option<u64>,
    },
    /// Message `id` could not be delivered, with a human-readable `reason`.
    DeliveryFailed {
        id: String,
        reason: String,
    },
    /// A command could not be handled; `command` names it (e.g.
    /// `"SendMessage"`) and `reason` explains why.
    Error {
        command: String,
        reason: String,
    },
}
145
146// ---------------------------------------------------------------------------
147// Shim state
148// ---------------------------------------------------------------------------
149
/// State reported by the shim, carried by `Event::StateChanged` and
/// `Event::State`.
///
/// Serialized as a snake_case string (`"starting"`, `"idle"`, `"working"`,
/// `"dead"`, `"context_exhausted"`), the same spelling `Display` produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShimState {
    Starting,
    Idle,
    Working,
    Dead,
    ContextExhausted,
}
159
160impl std::fmt::Display for ShimState {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            Self::Starting => write!(f, "starting"),
164            Self::Idle => write!(f, "idle"),
165            Self::Working => write!(f, "working"),
166            Self::Dead => write!(f, "dead"),
167            Self::ContextExhausted => write!(f, "context_exhausted"),
168        }
169    }
170}
171
172// ---------------------------------------------------------------------------
173// Framed channel over a Unix socket
174// ---------------------------------------------------------------------------
175
/// Blocking, length-prefixed JSON channel over a Unix stream socket.
///
/// Each frame is a 4-byte big-endian payload length followed by the JSON
/// payload, so message boundaries survive the byte-stream transport.
pub struct Channel {
    /// Connected socket that frames are written to and read from.
    stream: UnixStream,
    /// Scratch buffer reused across `recv` calls to hold incoming payloads;
    /// grown on demand and never shrunk.
    read_buf: Vec<u8>,
}
183
/// Largest JSON payload, in bytes, that a single frame may carry (1 MiB).
///
/// `Channel::send` checks it before writing, and `Channel::recv` checks it
/// before growing its buffer, so a corrupt or hostile length prefix cannot
/// force an unbounded allocation. The limit must fit in the `u32` length
/// prefix; the compile-time assertion below enforces that.
const MAX_MSG: usize = 1024 * 1024;
const _: () = assert!(MAX_MSG <= u32::MAX as usize);
185
186impl Channel {
187    pub fn new(stream: UnixStream) -> Self {
188        Self {
189            stream,
190            read_buf: vec![0u8; 4096],
191        }
192    }
193
194    /// Send a serializable message.
195    pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
196        let json = serde_json::to_vec(msg)?;
197        if json.len() > MAX_MSG {
198            anyhow::bail!("message too large: {} bytes", json.len());
199        }
200        let len = (json.len() as u32).to_be_bytes();
201        self.stream.write_all(&len)?;
202        self.stream.write_all(&json)?;
203        self.stream.flush()?;
204        Ok(())
205    }
206
207    /// Receive a deserializable message. Blocks until a message arrives.
208    /// Returns Ok(None) on clean EOF (peer closed).
209    pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
210        let mut len_buf = [0u8; 4];
211        match self.stream.read_exact(&mut len_buf) {
212            Ok(()) => {}
213            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
214            Err(e) => return Err(e.into()),
215        }
216        let len = u32::from_be_bytes(len_buf) as usize;
217        if len > MAX_MSG {
218            anyhow::bail!("incoming message too large: {} bytes", len);
219        }
220        if self.read_buf.len() < len {
221            self.read_buf.resize(len, 0);
222        }
223        self.stream.read_exact(&mut self.read_buf[..len])?;
224        let msg = serde_json::from_slice(&self.read_buf[..len])?;
225        Ok(Some(msg))
226    }
227
228    /// Set a read timeout on the underlying socket.
229    /// After this, `recv()` will return an error if no data arrives
230    /// within the given duration (instead of blocking forever).
231    pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
232        self.stream.set_read_timeout(timeout)?;
233        Ok(())
234    }
235
236    /// Clone the underlying fd for use in a second thread.
237    pub fn try_clone(&self) -> anyhow::Result<Self> {
238        Ok(Self {
239            stream: self.stream.try_clone()?,
240            read_buf: vec![0u8; 4096],
241        })
242    }
243}
244
245// ---------------------------------------------------------------------------
246// Create a connected socketpair
247// ---------------------------------------------------------------------------
248
249/// Create a connected pair of Unix stream sockets.
250/// Returns (parent_socket, child_socket).
251pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
252    let (a, b) = UnixStream::pair()?;
253    Ok((a, b))
254}
255
256// ---------------------------------------------------------------------------
257// Tests
258// ---------------------------------------------------------------------------
259
#[cfg(test)]
mod tests {
    use super::*;
    use serde::de::DeserializeOwned;
    use std::io::Write;

    /// Build a connected `(sender, receiver)` pair of framed channels.
    fn channel_pair() -> (Channel, Channel) {
        let (a, b) = socketpair().unwrap();
        (Channel::new(a), Channel::new(b))
    }

    /// Send `msg` through a fresh socketpair and return what the peer decodes.
    fn roundtrip<T: serde::Serialize + DeserializeOwned>(msg: &T) -> T {
        let (mut sender, mut receiver) = channel_pair();
        sender.send(msg).unwrap();
        receiver
            .recv::<T>()
            .unwrap()
            .expect("sender is still open, so a full frame must arrive")
    }

    // -- Commands ----------------------------------------------------------

    #[test]
    fn roundtrip_command_send_message() {
        let received = roundtrip(&Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        });
        match received {
            Command::SendMessage {
                from,
                body,
                message_id,
            } => {
                assert_eq!(from, "user");
                assert_eq!(body, "say hello");
                assert_eq!(message_id.as_deref(), Some("msg-1"));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn send_message_omits_absent_message_id_on_wire() {
        let json = serde_json::to_value(Command::SendMessage {
            from: "user".into(),
            body: "hi".into(),
            message_id: None,
        })
        .unwrap();
        assert_eq!(json["cmd"], "SendMessage");
        assert!(json.get("message_id").is_none());
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        match roundtrip(&Command::CaptureScreen {
            last_n_lines: Some(10),
        }) {
            Command::CaptureScreen { last_n_lines } => assert_eq!(last_n_lines, Some(10)),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        match roundtrip(&Command::Resize {
            rows: 50,
            cols: 220,
        }) {
            Command::Resize { rows, cols } => {
                assert_eq!(rows, 50);
                assert_eq!(cols, 220);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        let received = roundtrip(&Command::Shutdown {
            timeout_secs: 30,
            reason: ShutdownReason::Requested,
        });
        match received {
            Command::Shutdown {
                timeout_secs,
                reason,
            } => {
                assert_eq!(timeout_secs, 30);
                assert_eq!(reason, ShutdownReason::Requested);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn shutdown_reason_defaults_to_requested_when_absent() {
        let cmd: Command =
            serde_json::from_str(r#"{"cmd":"Shutdown","timeout_secs":5}"#).unwrap();
        match cmd {
            Command::Shutdown {
                timeout_secs,
                reason,
            } => {
                assert_eq!(timeout_secs, 5);
                assert_eq!(reason, ShutdownReason::Requested);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn shutdown_reason_labels_restart_handoff_explicitly() {
        assert_eq!(ShutdownReason::RestartHandoff.label(), "restart_handoff");
        assert_ne!(
            ShutdownReason::RestartHandoff.label(),
            "orchestrator disconnected"
        );
    }

    #[test]
    fn shutdown_reason_label_matches_wire_name() {
        for reason in [
            ShutdownReason::Requested,
            ShutdownReason::RestartHandoff,
            ShutdownReason::ContextExhausted,
            ShutdownReason::TopologyChange,
            ShutdownReason::DaemonStop,
        ] {
            assert_eq!(serde_json::to_value(reason).unwrap(), reason.label());
        }
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    // -- Events ------------------------------------------------------------

    #[test]
    fn roundtrip_event_completion() {
        let received = roundtrip(&Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        });
        match received {
            Event::Completion {
                message_id,
                response,
                last_lines,
            } => {
                assert_eq!(message_id, None);
                assert_eq!(response, "Hello!");
                assert_eq!(last_lines, "Hello!\n❯");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_message_delivered() {
        match roundtrip(&Event::MessageDelivered { id: "msg-1".into() }) {
            Event::MessageDelivered { id } => assert_eq!(id, "msg-1"),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        let received = roundtrip(&Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        });
        match received {
            Event::StateChanged { from, to, summary } => {
                assert_eq!(from, ShimState::Idle);
                assert_eq!(to, ShimState::Working);
                assert_eq!(summary, "working now");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_delivery_failed() {
        let received = roundtrip(&Event::DeliveryFailed {
            id: "msg-1".into(),
            reason: "stdin write failed".into(),
        });
        match received {
            Event::DeliveryFailed { id, reason } => {
                assert_eq!(id, "msg-1");
                assert_eq!(reason, "stdin write failed");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_died() {
        let received = roundtrip(&Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        });
        match received {
            Event::Died {
                exit_code,
                last_lines,
            } => {
                assert_eq!(exit_code, Some(1));
                assert_eq!(last_lines, "error occurred");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        let received = roundtrip(&Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        });
        match received {
            Event::ContextExhausted {
                message,
                last_lines,
            } => {
                assert_eq!(message, "context full");
                assert_eq!(last_lines, "last output");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        let received = roundtrip(&Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        });
        match received {
            Event::ScreenCapture {
                content,
                cursor_row,
                cursor_col,
            } => {
                assert_eq!(content, "screen data");
                assert_eq!(cursor_row, 5);
                assert_eq!(cursor_col, 10);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_warning() {
        let received = roundtrip(&Event::ContextWarning {
            model: Some("claude-sonnet-4-5".into()),
            output_bytes: 12_345,
            uptime_secs: 61,
            input_tokens: 80_000,
            cached_input_tokens: 5_000,
            cache_creation_input_tokens: 4_000,
            cache_read_input_tokens: 3_000,
            output_tokens: 6_000,
            reasoning_output_tokens: 2_000,
            used_tokens: 100_000,
            context_limit_tokens: 200_000,
            usage_pct: 50,
        });
        match received {
            Event::ContextWarning {
                model,
                output_bytes,
                uptime_secs,
                input_tokens,
                cached_input_tokens,
                cache_creation_input_tokens,
                cache_read_input_tokens,
                output_tokens,
                reasoning_output_tokens,
                used_tokens,
                context_limit_tokens,
                usage_pct,
            } => {
                assert_eq!(model.as_deref(), Some("claude-sonnet-4-5"));
                assert_eq!(output_bytes, 12_345);
                assert_eq!(uptime_secs, 61);
                assert_eq!(input_tokens, 80_000);
                assert_eq!(cached_input_tokens, 5_000);
                assert_eq!(cache_creation_input_tokens, 4_000);
                assert_eq!(cache_read_input_tokens, 3_000);
                assert_eq!(output_tokens, 6_000);
                assert_eq!(reasoning_output_tokens, 2_000);
                assert_eq!(used_tokens, 100_000);
                assert_eq!(context_limit_tokens, 200_000);
                assert_eq!(usage_pct, 50);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_session_stats() {
        let received = roundtrip(&Event::SessionStats {
            output_bytes: 123_456,
            uptime_secs: 61,
            input_tokens: 5000,
            output_tokens: 1200,
            context_usage_pct: Some(84),
        });
        match received {
            Event::SessionStats {
                output_bytes,
                uptime_secs,
                input_tokens,
                output_tokens,
                context_usage_pct,
            } => {
                assert_eq!(output_bytes, 123_456);
                assert_eq!(uptime_secs, 61);
                assert_eq!(input_tokens, 5000);
                assert_eq!(output_tokens, 1200);
                assert_eq!(context_usage_pct, Some(84));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn session_stats_optional_fields_default_when_absent() {
        let evt: Event =
            serde_json::from_str(r#"{"event":"SessionStats","output_bytes":1,"uptime_secs":2}"#)
                .unwrap();
        match evt {
            Event::SessionStats {
                output_bytes,
                uptime_secs,
                input_tokens,
                output_tokens,
                context_usage_pct,
            } => {
                assert_eq!(output_bytes, 1);
                assert_eq!(uptime_secs, 2);
                assert_eq!(input_tokens, 0);
                assert_eq!(output_tokens, 0);
                assert_eq!(context_usage_pct, None);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_approaching() {
        let received = roundtrip(&Event::ContextApproaching {
            message: "context pressure detected".into(),
            input_tokens: 80000,
            output_tokens: 20000,
        });
        match received {
            Event::ContextApproaching {
                message,
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(message, "context pressure detected");
                assert_eq!(input_tokens, 80000);
                assert_eq!(output_tokens, 20000);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_error() {
        let received = roundtrip(&Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        });
        match received {
            Event::Error { command, reason } => {
                assert_eq!(command, "SendMessage");
                assert_eq!(reason, "agent busy");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        let received = roundtrip(&Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        });
        match received {
            Event::Warning { message, idle_secs } => {
                assert_eq!(message, "no screen change");
                assert_eq!(idle_secs, Some(300));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    // -- Framing -----------------------------------------------------------

    #[test]
    fn eof_returns_none() {
        let (a, b) = socketpair().unwrap();
        drop(a); // close sender before anything is written
        let mut receiver = Channel::new(b);
        let result: Option<Command> = receiver.recv().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn large_frame_then_small_frame() {
        let (mut sender, mut receiver) = channel_pair();
        // Bigger than the 4 KiB initial read buffer and than typical socket
        // buffers, so the writer runs on its own thread to avoid blocking.
        let body = "x".repeat(64 * 1024);
        let expected = body.clone();
        let writer = std::thread::spawn(move || {
            sender
                .send(&Command::SendMessage {
                    from: "user".into(),
                    body,
                    message_id: None,
                })
                .unwrap();
            sender.send(&Command::Ping).unwrap();
        });
        match receiver.recv::<Command>().unwrap() {
            Some(Command::SendMessage { body, .. }) => assert_eq!(body, expected),
            other => panic!("unexpected frame: {other:?}"),
        }
        // The grown buffer must still decode a following, smaller frame.
        assert!(matches!(
            receiver.recv::<Command>().unwrap(),
            Some(Command::Ping)
        ));
        writer.join().unwrap();
    }

    #[test]
    fn send_rejects_oversized_payload() {
        let (mut sender, _receiver) = channel_pair();
        let evt = Event::Warning {
            message: "x".repeat(MAX_MSG),
            idle_secs: None,
        };
        assert!(sender.send(&evt).is_err());
    }

    #[test]
    fn recv_rejects_oversized_length_prefix() {
        let (mut raw, b) = socketpair().unwrap();
        let mut receiver = Channel::new(b);
        let too_big = u32::try_from(MAX_MSG + 1).unwrap();
        raw.write_all(&too_big.to_be_bytes()).unwrap();
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn recv_rejects_malformed_json() {
        let (mut raw, b) = socketpair().unwrap();
        let mut receiver = Channel::new(b);
        let payload = b"not json";
        let len = u32::try_from(payload.len()).unwrap();
        raw.write_all(&len.to_be_bytes()).unwrap();
        raw.write_all(payload).unwrap();
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn cloned_channel_writes_to_same_socket() {
        let (sender, mut receiver) = channel_pair();
        let mut clone = sender.try_clone().unwrap();
        clone.send(&Command::Ping).unwrap();
        assert!(matches!(
            receiver.recv::<Command>().unwrap(),
            Some(Command::Ping)
        ));
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        // Basic connectivity: write on one end, read on the other.
        let (mut ch_a, mut ch_b) = channel_pair();
        ch_a.send(&Command::Ping).unwrap();
        let msg: Command = ch_b.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));
    }

    // -- Shim state --------------------------------------------------------

    #[test]
    fn all_states_serialize() {
        for state in [
            ShimState::Starting,
            ShimState::Idle,
            ShimState::Working,
            ShimState::Dead,
            ShimState::ContextExhausted,
        ] {
            let json = serde_json::to_string(&state).unwrap();
            // Wire spelling and Display spelling must agree.
            assert_eq!(json, format!("\"{state}\""));
            let back: ShimState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }

    #[test]
    fn shim_state_display() {
        assert_eq!(ShimState::Starting.to_string(), "starting");
        assert_eq!(ShimState::Idle.to_string(), "idle");
        assert_eq!(ShimState::Working.to_string(), "working");
        assert_eq!(ShimState::Dead.to_string(), "dead");
        assert_eq!(ShimState::ContextExhausted.to_string(), "context_exhausted");
    }
}
773}