Skip to main content

batty_cli/shim/
protocol.rs

//! Wire protocol: Commands (orchestrator→shim) and Events (shim→orchestrator).
//!
//! Transport: length-prefixed JSON over a Unix SOCK_STREAM socketpair.
//! Each frame is a 4-byte big-endian length prefix followed by the JSON payload.
5
6use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;
9
// ---------------------------------------------------------------------------
// Commands (sent TO the shim)
// ---------------------------------------------------------------------------
13
14#[derive(Debug, Serialize, Deserialize)]
15#[serde(tag = "cmd")]
16pub enum Command {
17    SendMessage {
18        from: String,
19        body: String,
20        #[serde(skip_serializing_if = "Option::is_none")]
21        message_id: Option<String>,
22    },
23    CaptureScreen {
24        last_n_lines: Option<usize>,
25    },
26    GetState,
27    Resize {
28        rows: u16,
29        cols: u16,
30    },
31    Shutdown {
32        timeout_secs: u32,
33    },
34    Kill,
35    Ping,
36}
37
// ---------------------------------------------------------------------------
// Events (sent FROM the shim)
// ---------------------------------------------------------------------------
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(tag = "event")]
44pub enum Event {
45    Ready,
46    StateChanged {
47        from: ShimState,
48        to: ShimState,
49        summary: String,
50    },
51    Completion {
52        #[serde(skip_serializing_if = "Option::is_none")]
53        message_id: Option<String>,
54        response: String,
55        last_lines: String,
56    },
57    Died {
58        exit_code: Option<i32>,
59        last_lines: String,
60    },
61    ContextExhausted {
62        message: String,
63        last_lines: String,
64    },
65    ScreenCapture {
66        content: String,
67        cursor_row: u16,
68        cursor_col: u16,
69    },
70    State {
71        state: ShimState,
72        since_secs: u64,
73    },
74    Pong,
75    Warning {
76        message: String,
77        idle_secs: Option<u64>,
78    },
79    Error {
80        command: String,
81        reason: String,
82    },
83}
84
// ---------------------------------------------------------------------------
// Shim state
// ---------------------------------------------------------------------------
88
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90#[serde(rename_all = "snake_case")]
91pub enum ShimState {
92    Starting,
93    Idle,
94    Working,
95    Dead,
96    ContextExhausted,
97}
98
99impl std::fmt::Display for ShimState {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            Self::Starting => write!(f, "starting"),
103            Self::Idle => write!(f, "idle"),
104            Self::Working => write!(f, "working"),
105            Self::Dead => write!(f, "dead"),
106            Self::ContextExhausted => write!(f, "context_exhausted"),
107        }
108    }
109}
110
// ---------------------------------------------------------------------------
// Framed channel over a Unix socket
// ---------------------------------------------------------------------------
114
/// Blocking channel exchanging length-prefixed JSON frames over a Unix
/// stream socket.
///
/// Every frame is a 4-byte big-endian payload length followed by exactly that
/// many bytes of JSON.
pub struct Channel {
    /// Connected socket that frames are written to and read from.
    stream: UnixStream,
    /// Receive scratch buffer reused across `recv` calls; it only grows, and
    /// never beyond `MAX_MSG` because oversized frames are rejected first.
    read_buf: Vec<u8>,
}
122
/// Largest JSON payload (excluding the 4-byte length prefix) that a
/// `Channel` will send or accept: 1 MiB.
const MAX_MSG: usize = 1 << 20;
124
125impl Channel {
126    pub fn new(stream: UnixStream) -> Self {
127        Self {
128            stream,
129            read_buf: vec![0u8; 4096],
130        }
131    }
132
133    /// Send a serializable message.
134    pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
135        let json = serde_json::to_vec(msg)?;
136        if json.len() > MAX_MSG {
137            anyhow::bail!("message too large: {} bytes", json.len());
138        }
139        let len = (json.len() as u32).to_be_bytes();
140        self.stream.write_all(&len)?;
141        self.stream.write_all(&json)?;
142        self.stream.flush()?;
143        Ok(())
144    }
145
146    /// Receive a deserializable message. Blocks until a message arrives.
147    /// Returns Ok(None) on clean EOF (peer closed).
148    pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
149        let mut len_buf = [0u8; 4];
150        match self.stream.read_exact(&mut len_buf) {
151            Ok(()) => {}
152            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
153            Err(e) => return Err(e.into()),
154        }
155        let len = u32::from_be_bytes(len_buf) as usize;
156        if len > MAX_MSG {
157            anyhow::bail!("incoming message too large: {} bytes", len);
158        }
159        if self.read_buf.len() < len {
160            self.read_buf.resize(len, 0);
161        }
162        self.stream.read_exact(&mut self.read_buf[..len])?;
163        let msg = serde_json::from_slice(&self.read_buf[..len])?;
164        Ok(Some(msg))
165    }
166
167    /// Set a read timeout on the underlying socket.
168    /// After this, `recv()` will return an error if no data arrives
169    /// within the given duration (instead of blocking forever).
170    pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
171        self.stream.set_read_timeout(timeout)?;
172        Ok(())
173    }
174
175    /// Clone the underlying fd for use in a second thread.
176    pub fn try_clone(&self) -> anyhow::Result<Self> {
177        Ok(Self {
178            stream: self.stream.try_clone()?,
179            read_buf: vec![0u8; 4096],
180        })
181    }
182}
183
// ---------------------------------------------------------------------------
// Create a connected socketpair
// ---------------------------------------------------------------------------
187
188/// Create a connected pair of Unix stream sockets.
189/// Returns (parent_socket, child_socket).
190pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
191    let (a, b) = UnixStream::pair()?;
192    Ok((a, b))
193}
194
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};

    /// A connected `(sender, receiver)` pair of channels.
    fn pair() -> (Channel, Channel) {
        let (a, b) = socketpair().unwrap();
        (Channel::new(a), Channel::new(b))
    }

    /// Send `msg` through a fresh socketpair and return what the peer decodes.
    fn roundtrip<T: Serialize + for<'de> Deserialize<'de>>(msg: &T) -> T {
        let (mut tx, mut rx) = pair();
        tx.send(msg).unwrap();
        rx.recv::<T>()
            .unwrap()
            .expect("sender is still open, so recv must not report EOF")
    }

    // -- Commands -----------------------------------------------------------

    #[test]
    fn roundtrip_command_send_message() {
        let cmd = Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        };
        match roundtrip(&cmd) {
            Command::SendMessage {
                from,
                body,
                message_id,
            } => {
                assert_eq!(from, "user");
                assert_eq!(body, "say hello");
                assert_eq!(message_id.as_deref(), Some("msg-1"));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        let cmd = Command::CaptureScreen {
            last_n_lines: Some(10),
        };
        match roundtrip(&cmd) {
            Command::CaptureScreen { last_n_lines } => assert_eq!(last_n_lines, Some(10)),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        match roundtrip(&Command::Resize {
            rows: 50,
            cols: 220,
        }) {
            Command::Resize { rows, cols } => {
                assert_eq!(rows, 50);
                assert_eq!(cols, 220);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        match roundtrip(&Command::Shutdown { timeout_secs: 30 }) {
            Command::Shutdown { timeout_secs } => assert_eq!(timeout_secs, 30),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    // -- Events -------------------------------------------------------------

    #[test]
    fn roundtrip_event_completion() {
        let evt = Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        };
        match roundtrip(&evt) {
            Event::Completion {
                message_id,
                response,
                last_lines,
            } => {
                assert!(message_id.is_none());
                assert_eq!(response, "Hello!");
                assert_eq!(last_lines, "Hello!\n❯");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        let evt = Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        };
        match roundtrip(&evt) {
            Event::StateChanged { from, to, summary } => {
                assert_eq!(from, ShimState::Idle);
                assert_eq!(to, ShimState::Working);
                assert_eq!(summary, "working now");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_died() {
        let evt = Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        };
        match roundtrip(&evt) {
            Event::Died {
                exit_code,
                last_lines,
            } => {
                assert_eq!(exit_code, Some(1));
                assert_eq!(last_lines, "error occurred");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        let evt = Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        };
        match roundtrip(&evt) {
            Event::ContextExhausted {
                message,
                last_lines,
            } => {
                assert_eq!(message, "context full");
                assert_eq!(last_lines, "last output");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        let evt = Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        };
        match roundtrip(&evt) {
            Event::ScreenCapture {
                content,
                cursor_row,
                cursor_col,
            } => {
                assert_eq!(content, "screen data");
                assert_eq!(cursor_row, 5);
                assert_eq!(cursor_col, 10);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_event_error() {
        let evt = Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        };
        match roundtrip(&evt) {
            Event::Error { command, reason } => {
                assert_eq!(command, "SendMessage");
                assert_eq!(reason, "agent busy");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        let evt = Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        };
        match roundtrip(&evt) {
            Event::Warning { message, idle_secs } => {
                assert_eq!(message, "no screen change");
                assert_eq!(idle_secs, Some(300));
            }
            _ => panic!("wrong variant"),
        }
    }

    // -- Framing ------------------------------------------------------------

    #[test]
    fn eof_returns_none() {
        let (a, b) = socketpair().unwrap();
        drop(a); // close sender
        let mut receiver = Channel::new(b);
        let result: Option<Command> = receiver.recv().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn messages_arrive_in_order_then_eof() {
        let (mut tx, mut rx) = pair();
        tx.send(&Command::Ping).unwrap();
        tx.send(&Command::GetState).unwrap();
        tx.send(&Command::Kill).unwrap();
        drop(tx);
        assert!(matches!(rx.recv::<Command>().unwrap(), Some(Command::Ping)));
        assert!(matches!(rx.recv::<Command>().unwrap(), Some(Command::GetState)));
        assert!(matches!(rx.recv::<Command>().unwrap(), Some(Command::Kill)));
        assert!(rx.recv::<Command>().unwrap().is_none());
    }

    #[test]
    fn wire_format_is_big_endian_length_then_json() {
        let (a, mut raw) = socketpair().unwrap();
        let mut tx = Channel::new(a);
        tx.send(&Command::Ping).unwrap();

        let mut len_buf = [0u8; 4];
        raw.read_exact(&mut len_buf).unwrap();
        let mut payload = vec![0u8; u32::from_be_bytes(len_buf) as usize];
        raw.read_exact(&mut payload).unwrap();
        assert_eq!(payload, br#"{"cmd":"Ping"}"#);
    }

    #[test]
    fn truncated_payload_is_an_error() {
        let (mut raw, b) = socketpair().unwrap();
        raw.write_all(&10u32.to_be_bytes()).unwrap();
        raw.write_all(b"abc").unwrap();
        drop(raw); // peer disappears mid-frame
        let mut rx = Channel::new(b);
        assert!(rx.recv::<Command>().is_err());
    }

    #[test]
    fn oversized_length_prefix_is_rejected() {
        let (mut raw, b) = socketpair().unwrap();
        raw.write_all(&(MAX_MSG as u32 + 1).to_be_bytes()).unwrap();
        let mut rx = Channel::new(b);
        assert!(rx.recv::<Command>().is_err());
    }

    #[test]
    fn oversized_send_is_rejected_without_writing() {
        let (mut tx, mut rx) = pair();
        let huge = Command::SendMessage {
            from: "user".into(),
            body: "x".repeat(MAX_MSG),
            message_id: None,
        };
        assert!(tx.send(&huge).is_err());
        // Nothing reached the socket, so the next frame decodes cleanly.
        tx.send(&Command::Ping).unwrap();
        assert!(matches!(rx.recv::<Command>().unwrap(), Some(Command::Ping)));
    }

    // -- ShimState ----------------------------------------------------------

    #[test]
    fn all_states_serialize() {
        for state in [
            ShimState::Starting,
            ShimState::Idle,
            ShimState::Working,
            ShimState::Dead,
            ShimState::ContextExhausted,
        ] {
            let json = serde_json::to_string(&state).unwrap();
            let back: ShimState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }

    #[test]
    fn shim_state_display() {
        assert_eq!(ShimState::Starting.to_string(), "starting");
        assert_eq!(ShimState::Idle.to_string(), "idle");
        assert_eq!(ShimState::Working.to_string(), "working");
        assert_eq!(ShimState::Dead.to_string(), "dead");
        assert_eq!(ShimState::ContextExhausted.to_string(), "context_exhausted");
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        // Basic connectivity: write on one end, read on the other.
        let (mut ch_a, mut ch_b) = pair();
        ch_a.send(&Command::Ping).unwrap();
        let msg: Command = ch_b.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));
    }
}