1use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;
9
/// Commands carried over a [`Channel`] as internally tagged JSON.
///
/// The variant name is stored under the `"cmd"` key next to the variant's
/// fields, e.g. `{"cmd":"Resize","rows":50,"cols":220}`; unit variants encode
/// as the tag alone (`{"cmd":"Ping"}`).
///
/// NOTE(review): judging by the names used in this file (`ShimState`,
/// "orchestrator" in the tests) these presumably flow from an orchestrator to
/// a shim process — confirm.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "cmd")]
pub enum Command {
    /// Deliver `body`, attributed to `from` (presumably to the managed agent).
    SendMessage {
        from: String,
        body: String,
        /// Optional correlation id; the key is omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        message_id: Option<String>,
    },
    /// Request a capture of the screen contents.
    CaptureScreen {
        /// Optional line limit. Serialized as `null` when `None`; a missing
        /// key also decodes as `None`.
        last_n_lines: Option<usize>,
    },
    /// Query the current state (presumably answered with `Event::State` — confirm).
    GetState,
    /// Change the screen dimensions to `rows` x `cols`.
    Resize {
        rows: u16,
        cols: u16,
    },
    /// Ask for a shutdown.
    Shutdown {
        /// Grace period; NOTE(review): seconds is inferred from the name only.
        timeout_secs: u32,
        /// Why the shutdown was requested. Falls back to
        /// `ShutdownReason::Requested` when the key is absent from the JSON.
        #[serde(default)]
        reason: ShutdownReason,
    },
    /// Presumably an immediate, non-graceful stop, as opposed to `Shutdown` — confirm.
    Kill,
    /// Liveness probe (presumably answered with `Event::Pong` — confirm).
    Ping,
}
39
/// Why a `Command::Shutdown` was issued.
///
/// Serialized as a snake_case string (e.g. `"restart_handoff"`); `label()`
/// returns the same spelling. `Requested` is the `Default`, which is what a
/// `Command::Shutdown` decodes to when its `reason` key is missing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ShutdownReason {
    /// Generic request; the fallback when no reason is supplied.
    #[default]
    Requested,
    RestartHandoff,
    ContextExhausted,
    TopologyChange,
    DaemonStop,
}
50
51impl ShutdownReason {
52 pub fn label(self) -> &'static str {
53 match self {
54 Self::Requested => "requested",
55 Self::RestartHandoff => "restart_handoff",
56 Self::ContextExhausted => "context_exhausted",
57 Self::TopologyChange => "topology_change",
58 Self::DaemonStop => "daemon_stop",
59 }
60 }
61}
62
/// Events carried over a [`Channel`] as internally tagged JSON.
///
/// The variant name (PascalCase, since no `rename_all` is applied) is stored
/// under the `"event"` key next to the variant's fields, e.g.
/// `{"event":"MessageDelivered","id":"msg-1"}`; unit variants encode as the
/// tag alone (`{"event":"Pong"}`).
///
/// NOTE(review): these presumably flow from the shim back to the orchestrator,
/// mirroring [`Command`] — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
    /// Presumably sent once the shim can accept commands — confirm.
    Ready,
    /// A transition between two [`ShimState`]s, with a free-text summary.
    StateChanged {
        from: ShimState,
        to: ShimState,
        summary: String,
    },
    /// The message with this `id` was delivered.
    MessageDelivered {
        id: String,
    },
    /// A finished response.
    Completion {
        /// Correlation id; the key is omitted from the JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        message_id: Option<String>,
        response: String,
        last_lines: String,
    },
    /// The managed process exited. `exit_code` is `None` (JSON `null`) when
    /// no code is available.
    Died {
        exit_code: Option<i32>,
        last_lines: String,
    },
    /// The context window was exhausted.
    ContextExhausted {
        message: String,
        last_lines: String,
    },
    /// Detailed token-usage snapshot.
    ContextWarning {
        /// Model identifier if known; serialized as `null` when `None`.
        model: Option<String>,
        output_bytes: u64,
        uptime_secs: u64,
        input_tokens: u64,
        cached_input_tokens: u64,
        cache_creation_input_tokens: u64,
        cache_read_input_tokens: u64,
        output_tokens: u64,
        reasoning_output_tokens: u64,
        used_tokens: u64,
        context_limit_tokens: u64,
        /// NOTE(review): presumably `used_tokens` as a percentage of
        /// `context_limit_tokens` — this type does not compute it; confirm with
        /// the producer.
        usage_pct: u8,
    },
    /// Lighter-weight context-pressure notice.
    ContextApproaching {
        message: String,
        input_tokens: u64,
        output_tokens: u64,
    },
    /// Screen contents plus cursor position.
    ScreenCapture {
        content: String,
        cursor_row: u16,
        cursor_col: u16,
    },
    /// Current state and how long it has held (presumably the reply to
    /// `Command::GetState` — confirm).
    State {
        state: ShimState,
        since_secs: u64,
    },
    /// Periodic session statistics.
    SessionStats {
        output_bytes: u64,
        uptime_secs: u64,
        /// Decodes as 0 when the key is absent.
        #[serde(default)]
        input_tokens: u64,
        /// Decodes as 0 when the key is absent.
        #[serde(default)]
        output_tokens: u64,
        /// Omitted from the JSON when `None`; decodes as `None` when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        context_usage_pct: Option<u8>,
    },
    /// Reply to `Command::Ping` (presumably — confirm).
    Pong,
    /// Non-fatal warning. `idle_secs` is serialized as `null` when `None`.
    Warning {
        message: String,
        idle_secs: Option<u64>,
    },
    /// Delivery of the message with this `id` failed, with a reason.
    DeliveryFailed {
        id: String,
        reason: String,
    },
    /// A command could not be handled; `command` names it (e.g. `"SendMessage"`).
    Error {
        command: String,
        reason: String,
    },
}
145
/// Lifecycle state reported in [`Event::StateChanged`] and [`Event::State`].
///
/// Serialized as a snake_case string (e.g. `"context_exhausted"`); the
/// `Display` impl prints the same spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ShimState {
    Starting,
    Idle,
    Working,
    Dead,
    ContextExhausted,
}
159
160impl std::fmt::Display for ShimState {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 Self::Starting => write!(f, "starting"),
164 Self::Idle => write!(f, "idle"),
165 Self::Working => write!(f, "working"),
166 Self::Dead => write!(f, "dead"),
167 Self::ContextExhausted => write!(f, "context_exhausted"),
168 }
169 }
170}
171
/// Bidirectional channel exchanging length-prefixed JSON frames over a Unix
/// stream socket.
///
/// Each frame is a 4-byte big-endian payload length followed by that many
/// bytes of JSON (see `send` / `recv`).
pub struct Channel {
    /// Connected socket; every read and write goes through it.
    stream: UnixStream,
    /// Scratch buffer reused by `recv`; grown on demand and never shrunk.
    read_buf: Vec<u8>,
}
183
184const MAX_MSG: usize = 1_048_576; impl Channel {
187 pub fn new(stream: UnixStream) -> Self {
188 Self {
189 stream,
190 read_buf: vec![0u8; 4096],
191 }
192 }
193
194 pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
196 let json = serde_json::to_vec(msg)?;
197 if json.len() > MAX_MSG {
198 anyhow::bail!("message too large: {} bytes", json.len());
199 }
200 let len = (json.len() as u32).to_be_bytes();
201 self.stream.write_all(&len)?;
202 self.stream.write_all(&json)?;
203 self.stream.flush()?;
204 Ok(())
205 }
206
207 pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
210 let mut len_buf = [0u8; 4];
211 match self.stream.read_exact(&mut len_buf) {
212 Ok(()) => {}
213 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
214 Err(e) => return Err(e.into()),
215 }
216 let len = u32::from_be_bytes(len_buf) as usize;
217 if len > MAX_MSG {
218 anyhow::bail!("incoming message too large: {} bytes", len);
219 }
220 if self.read_buf.len() < len {
221 self.read_buf.resize(len, 0);
222 }
223 self.stream.read_exact(&mut self.read_buf[..len])?;
224 let msg = serde_json::from_slice(&self.read_buf[..len])?;
225 Ok(Some(msg))
226 }
227
228 pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
232 self.stream.set_read_timeout(timeout)?;
233 Ok(())
234 }
235
236 pub fn try_clone(&self) -> anyhow::Result<Self> {
238 Ok(Self {
239 stream: self.stream.try_clone()?,
240 read_buf: vec![0u8; 4096],
241 })
242 }
243}
244
245pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
252 let (a, b) = UnixStream::pair()?;
253 Ok((a, b))
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ShutdownReason` variant, for exhaustive checks.
    const ALL_REASONS: [ShutdownReason; 5] = [
        ShutdownReason::Requested,
        ShutdownReason::RestartHandoff,
        ShutdownReason::ContextExhausted,
        ShutdownReason::TopologyChange,
        ShutdownReason::DaemonStop,
    ];

    /// Every `ShimState` variant, for exhaustive checks.
    const ALL_STATES: [ShimState; 5] = [
        ShimState::Starting,
        ShimState::Idle,
        ShimState::Working,
        ShimState::Dead,
        ShimState::ContextExhausted,
    ];

    /// Builds two `Channel`s over the two ends of a fresh socketpair.
    fn channel_pair() -> (Channel, Channel) {
        let (a, b) = socketpair().expect("socketpair should succeed");
        (Channel::new(a), Channel::new(b))
    }

    /// Sends `msg` through a fresh channel pair and returns what the peer decodes.
    fn roundtrip<T>(msg: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let (mut sender, mut receiver) = channel_pair();
        sender.send(msg).expect("send should succeed");
        receiver
            .recv::<T>()
            .expect("recv should succeed")
            .expect("peer should deliver a frame before EOF")
    }

    // ---- Command round-trips ----

    #[test]
    fn roundtrip_command_send_message() {
        let cmd = Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        };
        match roundtrip(&cmd) {
            Command::SendMessage {
                from,
                body,
                message_id,
            } => {
                assert_eq!(from, "user");
                assert_eq!(body, "say hello");
                assert_eq!(message_id.as_deref(), Some("msg-1"));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn send_message_omits_absent_message_id() {
        let cmd = Command::SendMessage {
            from: "user".into(),
            body: "hi".into(),
            message_id: None,
        };
        let json = serde_json::to_value(&cmd).unwrap();
        assert_eq!(json["cmd"], "SendMessage");
        assert!(json.get("message_id").is_none());
        match roundtrip(&cmd) {
            Command::SendMessage { message_id, .. } => assert_eq!(message_id, None),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        let cmd = Command::CaptureScreen {
            last_n_lines: Some(10),
        };
        match roundtrip(&cmd) {
            Command::CaptureScreen { last_n_lines } => assert_eq!(last_n_lines, Some(10)),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        match roundtrip(&Command::Resize { rows: 50, cols: 220 }) {
            Command::Resize { rows, cols } => {
                assert_eq!(rows, 50);
                assert_eq!(cols, 220);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        for reason in ALL_REASONS {
            match roundtrip(&Command::Shutdown {
                timeout_secs: 30,
                reason,
            }) {
                Command::Shutdown {
                    timeout_secs,
                    reason: got,
                } => {
                    assert_eq!(timeout_secs, 30);
                    assert_eq!(got, reason);
                }
                other => panic!("wrong variant: {other:?}"),
            }
        }
    }

    #[test]
    fn shutdown_reason_defaults_to_requested_when_absent() {
        let cmd: Command =
            serde_json::from_str(r#"{"cmd":"Shutdown","timeout_secs":5}"#).unwrap();
        match cmd {
            Command::Shutdown {
                timeout_secs,
                reason,
            } => {
                assert_eq!(timeout_secs, 5);
                assert_eq!(reason, ShutdownReason::Requested);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn shutdown_reason_labels_restart_handoff_explicitly() {
        assert_eq!(ShutdownReason::RestartHandoff.label(), "restart_handoff");
        assert_ne!(
            ShutdownReason::RestartHandoff.label(),
            "orchestrator disconnected"
        );
    }

    #[test]
    fn shutdown_reason_labels_match_serde_names() {
        for reason in ALL_REASONS {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, format!("\"{}\"", reason.label()));
        }
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    // ---- Event round-trips ----

    #[test]
    fn roundtrip_event_completion() {
        let evt = Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        };
        match roundtrip(&evt) {
            Event::Completion {
                message_id,
                response,
                last_lines,
            } => {
                assert_eq!(message_id, None);
                assert_eq!(response, "Hello!");
                assert_eq!(last_lines, "Hello!\n❯");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_message_delivered() {
        match roundtrip(&Event::MessageDelivered { id: "msg-1".into() }) {
            Event::MessageDelivered { id } => assert_eq!(id, "msg-1"),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        let evt = Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        };
        match roundtrip(&evt) {
            Event::StateChanged { from, to, summary } => {
                assert_eq!(from, ShimState::Idle);
                assert_eq!(to, ShimState::Working);
                assert_eq!(summary, "working now");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_delivery_failed() {
        let evt = Event::DeliveryFailed {
            id: "msg-1".into(),
            reason: "stdin write failed".into(),
        };
        match roundtrip(&evt) {
            Event::DeliveryFailed { id, reason } => {
                assert_eq!(id, "msg-1");
                assert_eq!(reason, "stdin write failed");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_died() {
        let evt = Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        };
        match roundtrip(&evt) {
            Event::Died {
                exit_code,
                last_lines,
            } => {
                assert_eq!(exit_code, Some(1));
                assert_eq!(last_lines, "error occurred");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        let evt = Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        };
        match roundtrip(&evt) {
            Event::ContextExhausted {
                message,
                last_lines,
            } => {
                assert_eq!(message, "context full");
                assert_eq!(last_lines, "last output");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        let evt = Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        };
        match roundtrip(&evt) {
            Event::ScreenCapture {
                content,
                cursor_row,
                cursor_col,
            } => {
                assert_eq!(content, "screen data");
                assert_eq!(cursor_row, 5);
                assert_eq!(cursor_col, 10);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_warning() {
        let evt = Event::ContextWarning {
            model: Some("claude-sonnet-4-5".into()),
            output_bytes: 12_345,
            uptime_secs: 61,
            input_tokens: 80_000,
            cached_input_tokens: 5_000,
            cache_creation_input_tokens: 4_000,
            cache_read_input_tokens: 3_000,
            output_tokens: 6_000,
            reasoning_output_tokens: 2_000,
            used_tokens: 100_000,
            context_limit_tokens: 200_000,
            usage_pct: 50,
        };
        match roundtrip(&evt) {
            Event::ContextWarning {
                model,
                output_bytes,
                uptime_secs,
                input_tokens,
                cached_input_tokens,
                cache_creation_input_tokens,
                cache_read_input_tokens,
                output_tokens,
                reasoning_output_tokens,
                used_tokens,
                context_limit_tokens,
                usage_pct,
            } => {
                assert_eq!(model.as_deref(), Some("claude-sonnet-4-5"));
                assert_eq!(output_bytes, 12_345);
                assert_eq!(uptime_secs, 61);
                assert_eq!(input_tokens, 80_000);
                assert_eq!(cached_input_tokens, 5_000);
                assert_eq!(cache_creation_input_tokens, 4_000);
                assert_eq!(cache_read_input_tokens, 3_000);
                assert_eq!(output_tokens, 6_000);
                assert_eq!(reasoning_output_tokens, 2_000);
                assert_eq!(used_tokens, 100_000);
                assert_eq!(context_limit_tokens, 200_000);
                assert_eq!(usage_pct, 50);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_session_stats() {
        let evt = Event::SessionStats {
            output_bytes: 123_456,
            uptime_secs: 61,
            input_tokens: 5000,
            output_tokens: 1200,
            context_usage_pct: Some(84),
        };
        match roundtrip(&evt) {
            Event::SessionStats {
                output_bytes,
                uptime_secs,
                input_tokens,
                output_tokens,
                context_usage_pct,
            } => {
                assert_eq!(output_bytes, 123_456);
                assert_eq!(uptime_secs, 61);
                assert_eq!(input_tokens, 5000);
                assert_eq!(output_tokens, 1200);
                assert_eq!(context_usage_pct, Some(84));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn session_stats_defaults_missing_optional_fields() {
        let evt: Event = serde_json::from_str(
            r#"{"event":"SessionStats","output_bytes":7,"uptime_secs":3}"#,
        )
        .unwrap();
        match evt {
            Event::SessionStats {
                output_bytes,
                uptime_secs,
                input_tokens,
                output_tokens,
                context_usage_pct,
            } => {
                assert_eq!(output_bytes, 7);
                assert_eq!(uptime_secs, 3);
                assert_eq!(input_tokens, 0);
                assert_eq!(output_tokens, 0);
                assert_eq!(context_usage_pct, None);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_approaching() {
        let evt = Event::ContextApproaching {
            message: "context pressure detected".into(),
            input_tokens: 80000,
            output_tokens: 20000,
        };
        match roundtrip(&evt) {
            Event::ContextApproaching {
                message,
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(message, "context pressure detected");
                assert_eq!(input_tokens, 80000);
                assert_eq!(output_tokens, 20000);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_error() {
        let evt = Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        };
        match roundtrip(&evt) {
            Event::Error { command, reason } => {
                assert_eq!(command, "SendMessage");
                assert_eq!(reason, "agent busy");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        let evt = Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        };
        match roundtrip(&evt) {
            Event::Warning { message, idle_secs } => {
                assert_eq!(message, "no screen change");
                assert_eq!(idle_secs, Some(300));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    // ---- Framing ----

    #[test]
    fn eof_returns_none() {
        let (a, b) = socketpair().unwrap();
        drop(a);
        let mut receiver = Channel::new(b);
        let result: Option<Command> = receiver.recv().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn send_rejects_oversized_message() {
        let (mut sender, _receiver) = channel_pair();
        let cmd = Command::SendMessage {
            from: "user".into(),
            body: "x".repeat(MAX_MSG),
            message_id: None,
        };
        assert!(sender.send(&cmd).is_err());
    }

    #[test]
    fn recv_rejects_oversized_length_prefix() {
        let (mut raw, b) = socketpair().unwrap();
        let mut receiver = Channel::new(b);
        let bogus_len = u32::try_from(MAX_MSG + 1).unwrap().to_be_bytes();
        raw.write_all(&bogus_len).unwrap();
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn recv_handles_consecutive_frames_of_different_sizes() {
        let (mut sender, mut receiver) = channel_pair();
        // Larger than the initial 4 KiB read buffer, but small enough to sit in
        // the socket buffer before anyone reads it.
        let big = "y".repeat(5000);
        sender
            .send(&Event::Completion {
                message_id: None,
                response: big.clone(),
                last_lines: String::new(),
            })
            .unwrap();
        sender.send(&Event::Pong).unwrap();

        match receiver.recv::<Event>().unwrap().unwrap() {
            Event::Completion { response, .. } => assert_eq!(response, big),
            other => panic!("wrong variant: {other:?}"),
        }
        assert!(matches!(
            receiver.recv::<Event>().unwrap().unwrap(),
            Event::Pong
        ));
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        let (mut ch_a, mut ch_b) = channel_pair();
        ch_a.send(&Command::Ping).unwrap();
        let msg: Command = ch_b.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));
    }

    // ---- ShimState ----

    #[test]
    fn all_states_serialize() {
        for state in ALL_STATES {
            let json = serde_json::to_string(&state).unwrap();
            let back: ShimState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }

    #[test]
    fn shim_state_display() {
        assert_eq!(ShimState::Starting.to_string(), "starting");
        assert_eq!(ShimState::Idle.to_string(), "idle");
        assert_eq!(ShimState::Working.to_string(), "working");
        assert_eq!(ShimState::Dead.to_string(), "dead");
        assert_eq!(ShimState::ContextExhausted.to_string(), "context_exhausted");
    }

    #[test]
    fn shim_state_display_matches_serde() {
        for state in ALL_STATES {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, format!("\"{state}\""));
        }
    }
}