Skip to main content

batty_cli/shim/
protocol.rs

//! Wire protocol: Commands (orchestrator→shim) and Events (shim→orchestrator).
//!
//! Transport: length-prefixed JSON over a Unix SOCK_STREAM socketpair.
//! Each frame is a 4-byte big-endian length prefix followed by the JSON payload.

6use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;

// ---------------------------------------------------------------------------
// Commands (sent TO the shim)
// ---------------------------------------------------------------------------

14#[derive(Debug, Serialize, Deserialize)]
15#[serde(tag = "cmd")]
16pub enum Command {
17    SendMessage {
18        from: String,
19        body: String,
20        #[serde(skip_serializing_if = "Option::is_none")]
21        message_id: Option<String>,
22    },
23    CaptureScreen {
24        last_n_lines: Option<usize>,
25    },
26    GetState,
27    Resize {
28        rows: u16,
29        cols: u16,
30    },
31    Shutdown {
32        timeout_secs: u32,
33    },
34    Kill,
35    Ping,
36}

// ---------------------------------------------------------------------------
// Events (sent FROM the shim)
// ---------------------------------------------------------------------------

42#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(tag = "event")]
44pub enum Event {
45    Ready,
46    StateChanged {
47        from: ShimState,
48        to: ShimState,
49        summary: String,
50    },
51    Completion {
52        #[serde(skip_serializing_if = "Option::is_none")]
53        message_id: Option<String>,
54        response: String,
55        last_lines: String,
56    },
57    Died {
58        exit_code: Option<i32>,
59        last_lines: String,
60    },
61    ContextExhausted {
62        message: String,
63        last_lines: String,
64    },
65    ScreenCapture {
66        content: String,
67        cursor_row: u16,
68        cursor_col: u16,
69    },
70    State {
71        state: ShimState,
72        since_secs: u64,
73    },
74    SessionStats {
75        output_bytes: u64,
76        uptime_secs: u64,
77    },
78    Pong,
79    Warning {
80        message: String,
81        idle_secs: Option<u64>,
82    },
83    Error {
84        command: String,
85        reason: String,
86    },
87}

// ---------------------------------------------------------------------------
// Shim state
// ---------------------------------------------------------------------------

93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94#[serde(rename_all = "snake_case")]
95pub enum ShimState {
96    Starting,
97    Idle,
98    Working,
99    Dead,
100    ContextExhausted,
101}
102
103impl std::fmt::Display for ShimState {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match self {
106            Self::Starting => write!(f, "starting"),
107            Self::Idle => write!(f, "idle"),
108            Self::Working => write!(f, "working"),
109            Self::Dead => write!(f, "dead"),
110            Self::ContextExhausted => write!(f, "context_exhausted"),
111        }
112    }
113}

// ---------------------------------------------------------------------------
// Framed channel over a Unix socket
// ---------------------------------------------------------------------------

/// Blocking channel that exchanges length-prefixed JSON frames over a Unix
/// stream socket.
///
/// Frame layout: a 4-byte big-endian payload length, then that many bytes of
/// JSON. Payloads larger than `MAX_MSG` are refused when sending and when
/// receiving.
pub struct Channel {
    /// Connected socket that every frame is written to and read from.
    stream: UnixStream,
    /// Scratch buffer reused across `recv` calls; grown on demand, never past `MAX_MSG`.
    read_buf: Vec<u8>,
}

/// Largest accepted JSON payload per frame: 1 MiB (1_048_576 bytes).
const MAX_MSG: usize = 1 << 20;

129impl Channel {
130    pub fn new(stream: UnixStream) -> Self {
131        Self {
132            stream,
133            read_buf: vec![0u8; 4096],
134        }
135    }
136
137    /// Send a serializable message.
138    pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
139        let json = serde_json::to_vec(msg)?;
140        if json.len() > MAX_MSG {
141            anyhow::bail!("message too large: {} bytes", json.len());
142        }
143        let len = (json.len() as u32).to_be_bytes();
144        self.stream.write_all(&len)?;
145        self.stream.write_all(&json)?;
146        self.stream.flush()?;
147        Ok(())
148    }
149
150    /// Receive a deserializable message. Blocks until a message arrives.
151    /// Returns Ok(None) on clean EOF (peer closed).
152    pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
153        let mut len_buf = [0u8; 4];
154        match self.stream.read_exact(&mut len_buf) {
155            Ok(()) => {}
156            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
157            Err(e) => return Err(e.into()),
158        }
159        let len = u32::from_be_bytes(len_buf) as usize;
160        if len > MAX_MSG {
161            anyhow::bail!("incoming message too large: {} bytes", len);
162        }
163        if self.read_buf.len() < len {
164            self.read_buf.resize(len, 0);
165        }
166        self.stream.read_exact(&mut self.read_buf[..len])?;
167        let msg = serde_json::from_slice(&self.read_buf[..len])?;
168        Ok(Some(msg))
169    }
170
171    /// Set a read timeout on the underlying socket.
172    /// After this, `recv()` will return an error if no data arrives
173    /// within the given duration (instead of blocking forever).
174    pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
175        self.stream.set_read_timeout(timeout)?;
176        Ok(())
177    }
178
179    /// Clone the underlying fd for use in a second thread.
180    pub fn try_clone(&self) -> anyhow::Result<Self> {
181        Ok(Self {
182            stream: self.stream.try_clone()?,
183            read_buf: vec![0u8; 4096],
184        })
185    }
186}

// ---------------------------------------------------------------------------
// Create a connected socketpair
// ---------------------------------------------------------------------------

192/// Create a connected pair of Unix stream sockets.
193/// Returns (parent_socket, child_socket).
194pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
195    let (a, b) = UnixStream::pair()?;
196    Ok((a, b))
197}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly connected `(sender, receiver)` pair of channels.
    fn channel_pair() -> (Channel, Channel) {
        let (a, b) = socketpair().unwrap();
        (Channel::new(a), Channel::new(b))
    }

    /// Push `msg` through a fresh channel pair and return what the far end decodes.
    fn roundtrip<M>(msg: &M) -> M
    where
        M: Serialize + for<'de> Deserialize<'de>,
    {
        let (mut tx, mut rx) = channel_pair();
        tx.send(msg).unwrap();
        rx.recv::<M>().unwrap().unwrap()
    }

    #[test]
    fn roundtrip_command_send_message() {
        let echoed = roundtrip(&Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        });
        if let Command::SendMessage {
            from,
            body,
            message_id,
        } = echoed
        {
            assert_eq!(from, "user");
            assert_eq!(body, "say hello");
            assert_eq!(message_id.as_deref(), Some("msg-1"));
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        let echoed = roundtrip(&Command::CaptureScreen {
            last_n_lines: Some(10),
        });
        if let Command::CaptureScreen { last_n_lines } = echoed {
            assert_eq!(last_n_lines, Some(10));
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        let echoed = roundtrip(&Command::Resize {
            rows: 50,
            cols: 220,
        });
        if let Command::Resize { rows, cols } = echoed {
            assert_eq!(rows, 50);
            assert_eq!(cols, 220);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        let echoed = roundtrip(&Command::Shutdown { timeout_secs: 30 });
        if let Command::Shutdown { timeout_secs } = echoed {
            assert_eq!(timeout_secs, 30);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    #[test]
    fn roundtrip_event_completion() {
        let echoed = roundtrip(&Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        });
        if let Event::Completion { response, .. } = echoed {
            assert_eq!(response, "Hello!");
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        let echoed = roundtrip(&Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        });
        if let Event::StateChanged { from, to, summary } = echoed {
            assert_eq!(from, ShimState::Idle);
            assert_eq!(to, ShimState::Working);
            assert_eq!(summary, "working now");
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_died() {
        let echoed = roundtrip(&Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        });
        if let Event::Died {
            exit_code,
            last_lines,
        } = echoed
        {
            assert_eq!(exit_code, Some(1));
            assert_eq!(last_lines, "error occurred");
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        let echoed = roundtrip(&Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        });
        if let Event::ContextExhausted {
            message,
            last_lines,
        } = echoed
        {
            assert_eq!(message, "context full");
            assert_eq!(last_lines, "last output");
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        let echoed = roundtrip(&Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        });
        if let Event::ScreenCapture {
            content,
            cursor_row,
            cursor_col,
        } = echoed
        {
            assert_eq!(content, "screen data");
            assert_eq!(cursor_row, 5);
            assert_eq!(cursor_col, 10);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_session_stats() {
        let echoed = roundtrip(&Event::SessionStats {
            output_bytes: 123_456,
            uptime_secs: 61,
        });
        if let Event::SessionStats {
            output_bytes,
            uptime_secs,
        } = echoed
        {
            assert_eq!(output_bytes, 123_456);
            assert_eq!(uptime_secs, 61);
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_error() {
        let echoed = roundtrip(&Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        });
        if let Event::Error { command, reason } = echoed {
            assert_eq!(command, "SendMessage");
            assert_eq!(reason, "agent busy");
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        let echoed = roundtrip(&Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        });
        if let Event::Warning { message, idle_secs } = echoed {
            assert_eq!(message, "no screen change");
            assert_eq!(idle_secs, Some(300));
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn eof_returns_none() {
        let (closed_end, open_end) = socketpair().unwrap();
        drop(closed_end); // the peer goes away before sending anything
        let mut receiver = Channel::new(open_end);
        assert!(receiver.recv::<Command>().unwrap().is_none());
    }

    #[test]
    fn all_states_serialize() {
        let states = [
            ShimState::Starting,
            ShimState::Idle,
            ShimState::Working,
            ShimState::Dead,
            ShimState::ContextExhausted,
        ];
        for original in states.iter().copied() {
            let encoded = serde_json::to_string(&original).unwrap();
            let decoded: ShimState = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, decoded);
        }
    }

    #[test]
    fn shim_state_display() {
        let cases = [
            (ShimState::Starting, "starting"),
            (ShimState::Idle, "idle"),
            (ShimState::Working, "working"),
            (ShimState::Dead, "dead"),
            (ShimState::ContextExhausted, "context_exhausted"),
        ];
        for (state, expected) in cases.iter() {
            assert_eq!(state.to_string(), *expected);
        }
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        // Basic connectivity: a frame written on one end is readable on the other.
        let (left_sock, right_sock) = socketpair().unwrap();
        let mut left = Channel::new(left_sock);
        let mut right = Channel::new(right_sock);
        left.send(&Command::Ping).unwrap();
        let got: Command = right.recv().unwrap().unwrap();
        assert!(matches!(got, Command::Ping));
    }
}
570}