1use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;
9
10#[derive(Debug, Serialize, Deserialize)]
15#[serde(tag = "cmd")]
16pub enum Command {
17 SendMessage {
18 from: String,
19 body: String,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 message_id: Option<String>,
22 },
23 CaptureScreen {
24 last_n_lines: Option<usize>,
25 },
26 GetState,
27 Resize {
28 rows: u16,
29 cols: u16,
30 },
31 Shutdown {
32 timeout_secs: u32,
33 },
34 Kill,
35 Ping,
36}
37
/// An event message carried over a [`Channel`].
///
/// Serialized as internally tagged JSON: the variant name is stored in an
/// `"event"` field next to the variant's own fields, e.g. `{"event":"Pong"}`.
/// Variant and field names are part of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// Unit variant: `{"event":"Ready"}`.
    /// NOTE(review): presumably signals the shim finished starting — confirm.
    Ready,
    /// Transition from one [`ShimState`] to another, with a free-form `summary`.
    StateChanged {
        from: ShimState,
        to: ShimState,
        summary: String,
    },
    /// `message_id` is left out of the JSON entirely when `None`.
    /// NOTE(review): presumably echoes the `message_id` of the originating
    /// `Command::SendMessage` — confirm.
    Completion {
        #[serde(skip_serializing_if = "Option::is_none")]
        message_id: Option<String>,
        response: String,
        last_lines: String,
    },
    /// `exit_code` serializes as `null` when `None`.
    /// NOTE(review): `None` presumably means no exit code was available — confirm.
    Died {
        exit_code: Option<i32>,
        last_lines: String,
    },
    /// Carries a `message` plus the `last_lines` of output.
    ContextExhausted {
        message: String,
        last_lines: String,
    },
    /// Screen `content` plus the cursor position (`cursor_row`, `cursor_col`).
    /// NOTE(review): presumably the reply to `Command::CaptureScreen` — confirm.
    ScreenCapture {
        content: String,
        cursor_row: u16,
        cursor_col: u16,
    },
    /// Current `state`.
    /// NOTE(review): `since_secs` is presumably seconds spent in that state — confirm.
    State {
        state: ShimState,
        since_secs: u64,
    },
    /// Counters: `output_bytes` and `uptime_secs`.
    SessionStats {
        output_bytes: u64,
        uptime_secs: u64,
    },
    /// Unit variant: `{"event":"Pong"}`.
    /// NOTE(review): presumably the reply to `Command::Ping` — confirm.
    Pong,
    /// Advisory `message`. Unlike `message_id` elsewhere, `idle_secs` has no
    /// skip attribute, so `None` serializes as `null`.
    Warning {
        message: String,
        idle_secs: Option<u64>,
    },
    /// Failure report: `command` names the command as a string, `reason`
    /// describes the failure.
    Error {
        command: String,
        reason: String,
    },
}
88
/// State value carried by [`Event::StateChanged`] and [`Event::State`].
///
/// Serialized as a snake_case string (e.g. `ContextExhausted` becomes
/// `"context_exhausted"`); the `Display` impl produces the same spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShimState {
    Starting,
    Idle,
    Working,
    Dead,
    ContextExhausted,
}
102
103impl std::fmt::Display for ShimState {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match self {
106 Self::Starting => write!(f, "starting"),
107 Self::Idle => write!(f, "idle"),
108 Self::Working => write!(f, "working"),
109 Self::Dead => write!(f, "dead"),
110 Self::ContextExhausted => write!(f, "context_exhausted"),
111 }
112 }
113}
114
/// Bidirectional message channel over a connected [`UnixStream`].
///
/// Each frame on the wire is a 4-byte big-endian length followed by that many
/// bytes of JSON (see `Channel::send` / `Channel::recv`).
pub struct Channel {
    /// Underlying socket; all reads and writes go through it.
    stream: UnixStream,
    /// Reusable scratch buffer for incoming frame bodies. It is grown in
    /// `recv` when a frame is larger than its current length and never shrunk.
    read_buf: Vec<u8>,
}
126
127const MAX_MSG: usize = 1_048_576; impl Channel {
130 pub fn new(stream: UnixStream) -> Self {
131 Self {
132 stream,
133 read_buf: vec![0u8; 4096],
134 }
135 }
136
137 pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
139 let json = serde_json::to_vec(msg)?;
140 if json.len() > MAX_MSG {
141 anyhow::bail!("message too large: {} bytes", json.len());
142 }
143 let len = (json.len() as u32).to_be_bytes();
144 self.stream.write_all(&len)?;
145 self.stream.write_all(&json)?;
146 self.stream.flush()?;
147 Ok(())
148 }
149
150 pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
153 let mut len_buf = [0u8; 4];
154 match self.stream.read_exact(&mut len_buf) {
155 Ok(()) => {}
156 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
157 Err(e) => return Err(e.into()),
158 }
159 let len = u32::from_be_bytes(len_buf) as usize;
160 if len > MAX_MSG {
161 anyhow::bail!("incoming message too large: {} bytes", len);
162 }
163 if self.read_buf.len() < len {
164 self.read_buf.resize(len, 0);
165 }
166 self.stream.read_exact(&mut self.read_buf[..len])?;
167 let msg = serde_json::from_slice(&self.read_buf[..len])?;
168 Ok(Some(msg))
169 }
170
171 pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
175 self.stream.set_read_timeout(timeout)?;
176 Ok(())
177 }
178
179 pub fn try_clone(&self) -> anyhow::Result<Self> {
181 Ok(Self {
182 stream: self.stream.try_clone()?,
183 read_buf: vec![0u8; 4096],
184 })
185 }
186}
187
188pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
195 let (a, b) = UnixStream::pair()?;
196 Ok((a, b))
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Builds a connected (sender, receiver) pair of channels.
    fn channel_pair() -> (Channel, Channel) {
        let (a, b) = socketpair().unwrap();
        (Channel::new(a), Channel::new(b))
    }

    /// Sends `msg` through a fresh channel pair and returns what the peer decodes.
    fn roundtrip<T>(msg: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let (mut sender, mut receiver) = channel_pair();
        sender.send(msg).unwrap();
        receiver
            .recv::<T>()
            .unwrap()
            .expect("a frame was sent, so recv must not report EOF")
    }

    #[test]
    fn roundtrip_command_send_message() {
        let received = roundtrip(&Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        });
        match received {
            Command::SendMessage {
                from,
                body,
                message_id,
            } => {
                assert_eq!(from, "user");
                assert_eq!(body, "say hello");
                assert_eq!(message_id.as_deref(), Some("msg-1"));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        match roundtrip(&Command::CaptureScreen {
            last_n_lines: Some(10),
        }) {
            Command::CaptureScreen { last_n_lines } => assert_eq!(last_n_lines, Some(10)),
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        match roundtrip(&Command::Resize {
            rows: 50,
            cols: 220,
        }) {
            Command::Resize { rows, cols } => {
                assert_eq!(rows, 50);
                assert_eq!(cols, 220);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        match roundtrip(&Command::Shutdown { timeout_secs: 30 }) {
            Command::Shutdown { timeout_secs } => assert_eq!(timeout_secs, 30),
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    #[test]
    fn roundtrip_event_completion() {
        match roundtrip(&Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        }) {
            Event::Completion {
                message_id,
                response,
                last_lines,
            } => {
                assert_eq!(message_id, None);
                assert_eq!(response, "Hello!");
                assert_eq!(last_lines, "Hello!\n❯");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        match roundtrip(&Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        }) {
            Event::StateChanged { from, to, summary } => {
                assert_eq!(from, ShimState::Idle);
                assert_eq!(to, ShimState::Working);
                assert_eq!(summary, "working now");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_died() {
        match roundtrip(&Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        }) {
            Event::Died {
                exit_code,
                last_lines,
            } => {
                assert_eq!(exit_code, Some(1));
                assert_eq!(last_lines, "error occurred");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        match roundtrip(&Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        }) {
            Event::ContextExhausted {
                message,
                last_lines,
            } => {
                assert_eq!(message, "context full");
                assert_eq!(last_lines, "last output");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        match roundtrip(&Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        }) {
            Event::ScreenCapture {
                content,
                cursor_row,
                cursor_col,
            } => {
                assert_eq!(content, "screen data");
                assert_eq!(cursor_row, 5);
                assert_eq!(cursor_col, 10);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_session_stats() {
        match roundtrip(&Event::SessionStats {
            output_bytes: 123_456,
            uptime_secs: 61,
        }) {
            Event::SessionStats {
                output_bytes,
                uptime_secs,
            } => {
                assert_eq!(output_bytes, 123_456);
                assert_eq!(uptime_secs, 61);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_error() {
        match roundtrip(&Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        }) {
            Event::Error { command, reason } => {
                assert_eq!(command, "SendMessage");
                assert_eq!(reason, "agent busy");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        match roundtrip(&Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        }) {
            Event::Warning { message, idle_secs } => {
                assert_eq!(message, "no screen change");
                assert_eq!(idle_secs, Some(300));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn wire_format_uses_internal_tags() {
        assert_eq!(serde_json::to_string(&Command::Ping).unwrap(), r#"{"cmd":"Ping"}"#);
        assert_eq!(serde_json::to_string(&Event::Pong).unwrap(), r#"{"event":"Pong"}"#);
        assert_eq!(
            serde_json::to_string(&Command::Resize { rows: 1, cols: 2 }).unwrap(),
            r#"{"cmd":"Resize","rows":1,"cols":2}"#
        );
    }

    #[test]
    fn optional_message_ids_are_omitted_when_none() {
        let cmd_json = serde_json::to_string(&Command::SendMessage {
            from: "u".into(),
            body: "b".into(),
            message_id: None,
        })
        .unwrap();
        assert!(!cmd_json.contains("message_id"), "{}", cmd_json);

        let evt_json = serde_json::to_string(&Event::Completion {
            message_id: None,
            response: "r".into(),
            last_lines: "l".into(),
        })
        .unwrap();
        assert!(!evt_json.contains("message_id"), "{}", evt_json);
    }

    #[test]
    fn multiple_frames_arrive_in_order() {
        let (mut sender, mut receiver) = channel_pair();
        // Larger than the initial 4 KiB read buffer, yet small enough that all
        // three frames fit in the socket buffer before anything is read.
        let big = "x".repeat(5000);
        sender.send(&Command::Ping).unwrap();
        sender
            .send(&Command::SendMessage {
                from: "user".into(),
                body: big.clone(),
                message_id: None,
            })
            .unwrap();
        sender.send(&Command::Kill).unwrap();

        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Ping)));
        match receiver.recv::<Command>().unwrap() {
            Some(Command::SendMessage { body, .. }) => assert_eq!(body, big),
            other => panic!("unexpected frame: {:?}", other),
        }
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Kill)));
    }

    #[test]
    fn send_rejects_oversized_message_without_writing() {
        let (mut sender, mut receiver) = channel_pair();
        let huge = Command::SendMessage {
            from: "user".into(),
            body: "x".repeat(MAX_MSG + 1),
            message_id: None,
        };
        assert!(sender.send(&huge).is_err());
        // Nothing reached the socket, so the next frame decodes normally.
        sender.send(&Command::Ping).unwrap();
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Ping)));
    }

    #[test]
    fn recv_rejects_oversized_length_prefix() {
        let (mut raw, b) = socketpair().unwrap();
        let mut receiver = Channel::new(b);
        raw.write_all(&(MAX_MSG as u32 + 1).to_be_bytes()).unwrap();
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn invalid_json_frame_does_not_desync_stream() {
        let (mut raw, b) = socketpair().unwrap();
        let mut receiver = Channel::new(b);
        let garbage = b"not json";
        raw.write_all(&(garbage.len() as u32).to_be_bytes()).unwrap();
        raw.write_all(garbage).unwrap();
        assert!(receiver.recv::<Command>().is_err());

        let mut sender = Channel::new(raw);
        sender.send(&Command::Ping).unwrap();
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Ping)));
    }

    #[test]
    fn cloned_channel_shares_socket() {
        let (sender, mut receiver) = channel_pair();
        let mut clone = sender.try_clone().unwrap();
        clone.send(&Command::Ping).unwrap();
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Ping)));
    }

    #[test]
    fn eof_returns_none() {
        let (a, b) = socketpair().unwrap();
        drop(a);
        let mut receiver = Channel::new(b);
        let result: Option<Command> = receiver.recv().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn all_states_serialize() {
        for state in [
            ShimState::Starting,
            ShimState::Idle,
            ShimState::Working,
            ShimState::Dead,
            ShimState::ContextExhausted,
        ] {
            let json = serde_json::to_string(&state).unwrap();
            let back: ShimState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }

    #[test]
    fn shim_state_display() {
        assert_eq!(ShimState::Starting.to_string(), "starting");
        assert_eq!(ShimState::Idle.to_string(), "idle");
        assert_eq!(ShimState::Working.to_string(), "working");
        assert_eq!(ShimState::Dead.to_string(), "dead");
        assert_eq!(ShimState::ContextExhausted.to_string(), "context_exhausted");
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        let (mut ch_a, mut ch_b) = channel_pair();
        ch_a.send(&Command::Ping).unwrap();
        let msg: Command = ch_b.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));
    }
}
570}