1use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;
9
10#[derive(Debug, Serialize, Deserialize)]
15#[serde(tag = "cmd")]
16pub enum Command {
17 SendMessage {
18 from: String,
19 body: String,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 message_id: Option<String>,
22 },
23 CaptureScreen {
24 last_n_lines: Option<usize>,
25 },
26 GetState,
27 Resize {
28 rows: u16,
29 cols: u16,
30 },
31 Shutdown {
32 timeout_secs: u32,
33 #[serde(default)]
34 reason: ShutdownReason,
35 },
36 Kill,
37 Ping,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
41#[serde(rename_all = "snake_case")]
42pub enum ShutdownReason {
43 #[default]
44 Requested,
45 RestartHandoff,
46 ContextExhausted,
47 TopologyChange,
48 DaemonStop,
49}
50
51impl ShutdownReason {
52 pub fn label(self) -> &'static str {
53 match self {
54 Self::Requested => "requested",
55 Self::RestartHandoff => "restart_handoff",
56 Self::ContextExhausted => "context_exhausted",
57 Self::TopologyChange => "topology_change",
58 Self::DaemonStop => "daemon_stop",
59 }
60 }
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
68#[serde(tag = "event")]
69pub enum Event {
70 Ready,
71 StateChanged {
72 from: ShimState,
73 to: ShimState,
74 summary: String,
75 },
76 MessageDelivered {
77 id: String,
78 },
79 Completion {
80 #[serde(skip_serializing_if = "Option::is_none")]
81 message_id: Option<String>,
82 response: String,
83 last_lines: String,
84 },
85 Died {
86 exit_code: Option<i32>,
87 last_lines: String,
88 },
89 ContextExhausted {
90 message: String,
91 last_lines: String,
92 },
93 ContextWarning {
94 model: Option<String>,
95 output_bytes: u64,
96 uptime_secs: u64,
97 input_tokens: u64,
98 cached_input_tokens: u64,
99 cache_creation_input_tokens: u64,
100 cache_read_input_tokens: u64,
101 output_tokens: u64,
102 reasoning_output_tokens: u64,
103 used_tokens: u64,
104 context_limit_tokens: u64,
105 usage_pct: u8,
106 },
107 ContextApproaching {
108 message: String,
109 input_tokens: u64,
110 output_tokens: u64,
111 },
112 ScreenCapture {
113 content: String,
114 cursor_row: u16,
115 cursor_col: u16,
116 },
117 State {
118 state: ShimState,
119 since_secs: u64,
120 },
121 SessionStats {
122 output_bytes: u64,
123 uptime_secs: u64,
124 #[serde(default)]
125 input_tokens: u64,
126 #[serde(default)]
127 output_tokens: u64,
128 #[serde(default, skip_serializing_if = "Option::is_none")]
129 context_usage_pct: Option<u8>,
130 },
131 Pong,
132 Warning {
133 message: String,
134 idle_secs: Option<u64>,
135 },
136 DeliveryFailed {
137 id: String,
138 reason: String,
139 },
140 Error {
141 command: String,
142 reason: String,
143 },
144}
145
146#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
151#[serde(rename_all = "snake_case")]
152pub enum ShimState {
153 Starting,
154 Idle,
155 Working,
156 Dead,
157 ContextExhausted,
158}
159
160impl std::fmt::Display for ShimState {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 Self::Starting => write!(f, "starting"),
164 Self::Idle => write!(f, "idle"),
165 Self::Working => write!(f, "working"),
166 Self::Dead => write!(f, "dead"),
167 Self::ContextExhausted => write!(f, "context_exhausted"),
168 }
169 }
170}
171
/// Blocking, framed message channel over a connected [`UnixStream`].
///
/// Every message is JSON preceded by a 4-byte big-endian length prefix.
pub struct Channel {
    /// Socket that all frames are written to and read from.
    stream: UnixStream,
    /// Scratch space reused across `recv` calls to hold an incoming frame body.
    read_buf: Vec<u8>,
}
183
184const MAX_MSG: usize = 1_048_576; impl Channel {
187 pub fn new(stream: UnixStream) -> Self {
188 Self {
189 stream,
190 read_buf: vec![0u8; 4096],
191 }
192 }
193
194 pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
196 let json = serde_json::to_vec(msg)?;
197 if json.len() > MAX_MSG {
198 anyhow::bail!("message too large: {} bytes", json.len());
199 }
200 let len = (json.len() as u32).to_be_bytes();
201 self.stream.write_all(&len)?;
202 self.stream.write_all(&json)?;
203 self.stream.flush()?;
204 Ok(())
205 }
206
207 pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
210 let mut len_buf = [0u8; 4];
211 match self.stream.read_exact(&mut len_buf) {
212 Ok(()) => {}
213 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
214 Err(e) => return Err(e.into()),
215 }
216 let len = u32::from_be_bytes(len_buf) as usize;
217 if len > MAX_MSG {
218 anyhow::bail!("incoming message too large: {} bytes", len);
219 }
220 if self.read_buf.len() < len {
221 self.read_buf.resize(len, 0);
222 }
223 self.stream.read_exact(&mut self.read_buf[..len])?;
224 let msg = serde_json::from_slice(&self.read_buf[..len])?;
225 Ok(Some(msg))
226 }
227
228 pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
232 self.stream.set_read_timeout(timeout)?;
233 Ok(())
234 }
235
236 pub fn set_write_timeout(
248 &mut self,
249 timeout: Option<std::time::Duration>,
250 ) -> anyhow::Result<()> {
251 self.stream.set_write_timeout(timeout)?;
252 Ok(())
253 }
254
255 pub fn try_clone(&self) -> anyhow::Result<Self> {
257 Ok(Self {
258 stream: self.stream.try_clone()?,
259 read_buf: vec![0u8; 4096],
260 })
261 }
262}
263
264pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
271 let (a, b) = UnixStream::pair()?;
272 Ok((a, b))
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Sends `value` through a fresh socketpair and returns what the other end decodes.
    fn roundtrip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let (a, b) = socketpair().unwrap();
        let mut sender = Channel::new(a);
        let mut receiver = Channel::new(b);
        sender.send(value).unwrap();
        receiver
            .recv::<T>()
            .unwrap()
            .expect("peer should deliver a frame before EOF")
    }

    #[test]
    fn roundtrip_command_send_message() {
        let cmd = Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        };
        match roundtrip(&cmd) {
            Command::SendMessage {
                from,
                body,
                message_id,
            } => {
                assert_eq!(from, "user");
                assert_eq!(body, "say hello");
                assert_eq!(message_id.as_deref(), Some("msg-1"));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        let cmd = Command::CaptureScreen {
            last_n_lines: Some(10),
        };
        match roundtrip(&cmd) {
            Command::CaptureScreen { last_n_lines } => assert_eq!(last_n_lines, Some(10)),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        let cmd = Command::Resize {
            rows: 50,
            cols: 220,
        };
        match roundtrip(&cmd) {
            Command::Resize { rows, cols } => {
                assert_eq!(rows, 50);
                assert_eq!(cols, 220);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        let cmd = Command::Shutdown {
            timeout_secs: 30,
            reason: ShutdownReason::Requested,
        };
        match roundtrip(&cmd) {
            Command::Shutdown {
                timeout_secs,
                reason,
            } => {
                assert_eq!(timeout_secs, 30);
                assert_eq!(reason, ShutdownReason::Requested);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn shutdown_reason_defaults_to_requested_when_missing() {
        let cmd: Command =
            serde_json::from_str(r#"{"cmd":"Shutdown","timeout_secs":5}"#).unwrap();
        assert!(matches!(
            cmd,
            Command::Shutdown {
                timeout_secs: 5,
                reason: ShutdownReason::Requested,
            }
        ));
    }

    #[test]
    fn shutdown_reason_labels_restart_handoff_explicitly() {
        assert_eq!(ShutdownReason::RestartHandoff.label(), "restart_handoff");
        assert_ne!(
            ShutdownReason::RestartHandoff.label(),
            "orchestrator disconnected"
        );
    }

    #[test]
    fn shutdown_reason_label_matches_serde_name() {
        for reason in [
            ShutdownReason::Requested,
            ShutdownReason::RestartHandoff,
            ShutdownReason::ContextExhausted,
            ShutdownReason::TopologyChange,
            ShutdownReason::DaemonStop,
        ] {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, format!("\"{}\"", reason.label()));
        }
    }

    #[test]
    fn command_wire_format_uses_cmd_tag() {
        assert_eq!(
            serde_json::to_string(&Command::Ping).unwrap(),
            r#"{"cmd":"Ping"}"#
        );
        let shutdown = Command::Shutdown {
            timeout_secs: 30,
            reason: ShutdownReason::RestartHandoff,
        };
        assert_eq!(
            serde_json::to_string(&shutdown).unwrap(),
            r#"{"cmd":"Shutdown","timeout_secs":30,"reason":"restart_handoff"}"#
        );
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    #[test]
    fn roundtrip_event_completion() {
        let evt = Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        };
        match roundtrip(&evt) {
            Event::Completion {
                message_id,
                response,
                last_lines,
            } => {
                assert_eq!(message_id, None);
                assert_eq!(response, "Hello!");
                assert_eq!(last_lines, "Hello!\n❯");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_message_delivered() {
        let evt = Event::MessageDelivered { id: "msg-1".into() };
        match roundtrip(&evt) {
            Event::MessageDelivered { id } => assert_eq!(id, "msg-1"),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        let evt = Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        };
        match roundtrip(&evt) {
            Event::StateChanged { from, to, summary } => {
                assert_eq!(from, ShimState::Idle);
                assert_eq!(to, ShimState::Working);
                assert_eq!(summary, "working now");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_delivery_failed() {
        let evt = Event::DeliveryFailed {
            id: "msg-1".into(),
            reason: "stdin write failed".into(),
        };
        match roundtrip(&evt) {
            Event::DeliveryFailed { id, reason } => {
                assert_eq!(id, "msg-1");
                assert_eq!(reason, "stdin write failed");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_died() {
        let evt = Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        };
        match roundtrip(&evt) {
            Event::Died {
                exit_code,
                last_lines,
            } => {
                assert_eq!(exit_code, Some(1));
                assert_eq!(last_lines, "error occurred");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        let evt = Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        };
        match roundtrip(&evt) {
            Event::ContextExhausted {
                message,
                last_lines,
            } => {
                assert_eq!(message, "context full");
                assert_eq!(last_lines, "last output");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        let evt = Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        };
        match roundtrip(&evt) {
            Event::ScreenCapture {
                content,
                cursor_row,
                cursor_col,
            } => {
                assert_eq!(content, "screen data");
                assert_eq!(cursor_row, 5);
                assert_eq!(cursor_col, 10);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_warning() {
        let evt = Event::ContextWarning {
            model: Some("claude-sonnet-4-5".into()),
            output_bytes: 12_345,
            uptime_secs: 61,
            input_tokens: 80_000,
            cached_input_tokens: 5_000,
            cache_creation_input_tokens: 4_000,
            cache_read_input_tokens: 3_000,
            output_tokens: 6_000,
            reasoning_output_tokens: 2_000,
            used_tokens: 100_000,
            context_limit_tokens: 200_000,
            usage_pct: 50,
        };
        match roundtrip(&evt) {
            Event::ContextWarning {
                model,
                output_bytes,
                uptime_secs,
                input_tokens,
                cached_input_tokens,
                cache_creation_input_tokens,
                cache_read_input_tokens,
                output_tokens,
                reasoning_output_tokens,
                used_tokens,
                context_limit_tokens,
                usage_pct,
            } => {
                assert_eq!(model.as_deref(), Some("claude-sonnet-4-5"));
                assert_eq!(output_bytes, 12_345);
                assert_eq!(uptime_secs, 61);
                assert_eq!(input_tokens, 80_000);
                assert_eq!(cached_input_tokens, 5_000);
                assert_eq!(cache_creation_input_tokens, 4_000);
                assert_eq!(cache_read_input_tokens, 3_000);
                assert_eq!(output_tokens, 6_000);
                assert_eq!(reasoning_output_tokens, 2_000);
                assert_eq!(used_tokens, 100_000);
                assert_eq!(context_limit_tokens, 200_000);
                assert_eq!(usage_pct, 50);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_session_stats() {
        let evt = Event::SessionStats {
            output_bytes: 123_456,
            uptime_secs: 61,
            input_tokens: 5000,
            output_tokens: 1200,
            context_usage_pct: Some(84),
        };
        match roundtrip(&evt) {
            Event::SessionStats {
                output_bytes,
                uptime_secs,
                input_tokens,
                output_tokens,
                context_usage_pct,
            } => {
                assert_eq!(output_bytes, 123_456);
                assert_eq!(uptime_secs, 61);
                assert_eq!(input_tokens, 5000);
                assert_eq!(output_tokens, 1200);
                assert_eq!(context_usage_pct, Some(84));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_context_approaching() {
        let evt = Event::ContextApproaching {
            message: "context pressure detected".into(),
            input_tokens: 80000,
            output_tokens: 20000,
        };
        match roundtrip(&evt) {
            Event::ContextApproaching {
                message,
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(message, "context pressure detected");
                assert_eq!(input_tokens, 80000);
                assert_eq!(output_tokens, 20000);
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_error() {
        let evt = Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        };
        match roundtrip(&evt) {
            Event::Error { command, reason } => {
                assert_eq!(command, "SendMessage");
                assert_eq!(reason, "agent busy");
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        let evt = Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        };
        match roundtrip(&evt) {
            Event::Warning { message, idle_secs } => {
                assert_eq!(message, "no screen change");
                assert_eq!(idle_secs, Some(300));
            }
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn eof_returns_none() {
        let (a, b) = socketpair().unwrap();
        drop(a);
        let mut receiver = Channel::new(b);
        let result: Option<Command> = receiver.recv().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn send_rejects_oversized_message() {
        let (a, _peer) = socketpair().unwrap();
        let mut sender = Channel::new(a);
        // The body alone is MAX_MSG bytes, so the JSON encoding is strictly larger.
        let cmd = Command::SendMessage {
            from: "daemon".into(),
            body: "x".repeat(MAX_MSG),
            message_id: None,
        };
        let err = sender.send(&cmd).unwrap_err();
        assert!(
            err.to_string().contains("too large"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn recv_rejects_oversized_length_prefix() {
        let (mut raw, b) = socketpair().unwrap();
        let too_big = u32::try_from(MAX_MSG + 1).unwrap();
        raw.write_all(&too_big.to_be_bytes()).unwrap();
        let mut receiver = Channel::new(b);
        let err = receiver.recv::<Command>().unwrap_err();
        assert!(
            err.to_string().contains("too large"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn all_states_serialize() {
        for state in [
            ShimState::Starting,
            ShimState::Idle,
            ShimState::Working,
            ShimState::Dead,
            ShimState::ContextExhausted,
        ] {
            let json = serde_json::to_string(&state).unwrap();
            let back: ShimState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }

    #[test]
    fn shim_state_display() {
        assert_eq!(ShimState::Starting.to_string(), "starting");
        assert_eq!(ShimState::Idle.to_string(), "idle");
        assert_eq!(ShimState::Working.to_string(), "working");
        assert_eq!(ShimState::Dead.to_string(), "dead");
        assert_eq!(ShimState::ContextExhausted.to_string(), "context_exhausted");
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        let (a, b) = socketpair().unwrap();
        let mut ch_a = Channel::new(a);
        let mut ch_b = Channel::new(b);
        ch_a.send(&Command::Ping).unwrap();
        let msg: Command = ch_b.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));
    }

    #[test]
    fn send_times_out_when_peer_stops_reading() {
        // Keep the peer end open but never read from it. The socket buffer fills
        // and writes block until the timeout fires instead of failing at once.
        let (a, _peer) = socketpair().unwrap();
        let mut sender = Channel::new(a);
        sender
            .set_write_timeout(Some(std::time::Duration::from_millis(50)))
            .unwrap();

        // Large enough that a few frames overflow the kernel socket buffer.
        let cmd = Command::SendMessage {
            from: "daemon".into(),
            body: "x".repeat(256 * 1024),
            message_id: None,
        };

        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5);
        let error = loop {
            match sender.send(&cmd) {
                Ok(()) => assert!(
                    std::time::Instant::now() < deadline,
                    "send should have timed out within 5s"
                ),
                Err(error) => break error,
            }
        };

        let io_error = error
            .downcast_ref::<std::io::Error>()
            .expect("write timeout should surface as an io::Error");
        assert!(
            matches!(
                io_error.kind(),
                std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
            ),
            "expected WouldBlock/TimedOut error, got {:?}",
            io_error.kind()
        );
    }
}