1use serde::{Deserialize, Serialize};
7use std::io::{self, Read, Write};
8use std::os::unix::net::UnixStream;
9
10#[derive(Debug, Serialize, Deserialize)]
15#[serde(tag = "cmd")]
16pub enum Command {
17 SendMessage {
18 from: String,
19 body: String,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 message_id: Option<String>,
22 },
23 CaptureScreen {
24 last_n_lines: Option<usize>,
25 },
26 GetState,
27 Resize {
28 rows: u16,
29 cols: u16,
30 },
31 Shutdown {
32 timeout_secs: u32,
33 },
34 Kill,
35 Ping,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(tag = "event")]
44pub enum Event {
45 Ready,
46 StateChanged {
47 from: ShimState,
48 to: ShimState,
49 summary: String,
50 },
51 Completion {
52 #[serde(skip_serializing_if = "Option::is_none")]
53 message_id: Option<String>,
54 response: String,
55 last_lines: String,
56 },
57 Died {
58 exit_code: Option<i32>,
59 last_lines: String,
60 },
61 ContextExhausted {
62 message: String,
63 last_lines: String,
64 },
65 ScreenCapture {
66 content: String,
67 cursor_row: u16,
68 cursor_col: u16,
69 },
70 State {
71 state: ShimState,
72 since_secs: u64,
73 },
74 Pong,
75 Warning {
76 message: String,
77 idle_secs: Option<u64>,
78 },
79 Error {
80 command: String,
81 reason: String,
82 },
83}
84
85#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
90#[serde(rename_all = "snake_case")]
91pub enum ShimState {
92 Starting,
93 Idle,
94 Working,
95 Dead,
96 ContextExhausted,
97}
98
99impl std::fmt::Display for ShimState {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 Self::Starting => write!(f, "starting"),
103 Self::Idle => write!(f, "idle"),
104 Self::Working => write!(f, "working"),
105 Self::Dead => write!(f, "dead"),
106 Self::ContextExhausted => write!(f, "context_exhausted"),
107 }
108 }
109}
110
111pub struct Channel {
119 stream: UnixStream,
120 read_buf: Vec<u8>,
121}
122
123const MAX_MSG: usize = 1_048_576; impl Channel {
126 pub fn new(stream: UnixStream) -> Self {
127 Self {
128 stream,
129 read_buf: vec![0u8; 4096],
130 }
131 }
132
133 pub fn send<T: Serialize>(&mut self, msg: &T) -> anyhow::Result<()> {
135 let json = serde_json::to_vec(msg)?;
136 if json.len() > MAX_MSG {
137 anyhow::bail!("message too large: {} bytes", json.len());
138 }
139 let len = (json.len() as u32).to_be_bytes();
140 self.stream.write_all(&len)?;
141 self.stream.write_all(&json)?;
142 self.stream.flush()?;
143 Ok(())
144 }
145
146 pub fn recv<T: for<'de> Deserialize<'de>>(&mut self) -> anyhow::Result<Option<T>> {
149 let mut len_buf = [0u8; 4];
150 match self.stream.read_exact(&mut len_buf) {
151 Ok(()) => {}
152 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
153 Err(e) => return Err(e.into()),
154 }
155 let len = u32::from_be_bytes(len_buf) as usize;
156 if len > MAX_MSG {
157 anyhow::bail!("incoming message too large: {} bytes", len);
158 }
159 if self.read_buf.len() < len {
160 self.read_buf.resize(len, 0);
161 }
162 self.stream.read_exact(&mut self.read_buf[..len])?;
163 let msg = serde_json::from_slice(&self.read_buf[..len])?;
164 Ok(Some(msg))
165 }
166
167 pub fn set_read_timeout(&mut self, timeout: Option<std::time::Duration>) -> anyhow::Result<()> {
171 self.stream.set_read_timeout(timeout)?;
172 Ok(())
173 }
174
175 pub fn try_clone(&self) -> anyhow::Result<Self> {
177 Ok(Self {
178 stream: self.stream.try_clone()?,
179 read_buf: vec![0u8; 4096],
180 })
181 }
182}
183
184pub fn socketpair() -> anyhow::Result<(UnixStream, UnixStream)> {
191 let (a, b) = UnixStream::pair()?;
192 Ok((a, b))
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use std::io::Write;
    use std::time::Duration;

    /// Every `ShimState` variant, for exhaustive loops.
    const ALL_STATES: [ShimState; 5] = [
        ShimState::Starting,
        ShimState::Idle,
        ShimState::Working,
        ShimState::Dead,
        ShimState::ContextExhausted,
    ];

    /// A connected (sender, receiver) pair of channels over a fresh socket pair.
    fn channel_pair() -> (Channel, Channel) {
        let (a, b) = socketpair().unwrap();
        (Channel::new(a), Channel::new(b))
    }

    /// Sends `msg` through a fresh channel pair and returns what the peer decodes.
    ///
    /// Send and receive run on one thread, so `msg` must fit in the socket's
    /// send buffer; large frames need a writer thread instead.
    fn roundtrip<T>(msg: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let (mut sender, mut receiver) = channel_pair();
        sender.send(msg).unwrap();
        receiver
            .recv::<T>()
            .unwrap()
            .expect("sender is still open, so a frame must be available")
    }

    #[test]
    fn roundtrip_command_send_message() {
        let received = roundtrip(&Command::SendMessage {
            from: "user".into(),
            body: "say hello".into(),
            message_id: Some("msg-1".into()),
        });
        match received {
            Command::SendMessage {
                from,
                body,
                message_id,
            } => {
                assert_eq!(from, "user");
                assert_eq!(body, "say hello");
                assert_eq!(message_id.as_deref(), Some("msg-1"));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_capture_screen() {
        let received = roundtrip(&Command::CaptureScreen {
            last_n_lines: Some(10),
        });
        match received {
            Command::CaptureScreen { last_n_lines } => assert_eq!(last_n_lines, Some(10)),
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_get_state() {
        assert!(matches!(roundtrip(&Command::GetState), Command::GetState));
    }

    #[test]
    fn roundtrip_command_resize() {
        let received = roundtrip(&Command::Resize { rows: 50, cols: 220 });
        match received {
            Command::Resize { rows, cols } => {
                assert_eq!(rows, 50);
                assert_eq!(cols, 220);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_shutdown() {
        let received = roundtrip(&Command::Shutdown { timeout_secs: 30 });
        match received {
            Command::Shutdown { timeout_secs } => assert_eq!(timeout_secs, 30),
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_command_kill() {
        assert!(matches!(roundtrip(&Command::Kill), Command::Kill));
    }

    #[test]
    fn roundtrip_command_ping() {
        assert!(matches!(roundtrip(&Command::Ping), Command::Ping));
    }

    #[test]
    fn roundtrip_event_completion() {
        let received = roundtrip(&Event::Completion {
            message_id: None,
            response: "Hello!".into(),
            last_lines: "Hello!\n❯".into(),
        });
        match received {
            Event::Completion {
                message_id,
                response,
                last_lines,
            } => {
                assert_eq!(message_id, None);
                assert_eq!(response, "Hello!");
                assert_eq!(last_lines, "Hello!\n❯");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_state_changed() {
        let received = roundtrip(&Event::StateChanged {
            from: ShimState::Idle,
            to: ShimState::Working,
            summary: "working now".into(),
        });
        match received {
            Event::StateChanged { from, to, summary } => {
                assert_eq!(from, ShimState::Idle);
                assert_eq!(to, ShimState::Working);
                assert_eq!(summary, "working now");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_ready() {
        assert!(matches!(roundtrip(&Event::Ready), Event::Ready));
    }

    #[test]
    fn roundtrip_event_pong() {
        assert!(matches!(roundtrip(&Event::Pong), Event::Pong));
    }

    #[test]
    fn roundtrip_event_died() {
        let received = roundtrip(&Event::Died {
            exit_code: Some(1),
            last_lines: "error occurred".into(),
        });
        match received {
            Event::Died {
                exit_code,
                last_lines,
            } => {
                assert_eq!(exit_code, Some(1));
                assert_eq!(last_lines, "error occurred");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_context_exhausted() {
        let received = roundtrip(&Event::ContextExhausted {
            message: "context full".into(),
            last_lines: "last output".into(),
        });
        match received {
            Event::ContextExhausted {
                message,
                last_lines,
            } => {
                assert_eq!(message, "context full");
                assert_eq!(last_lines, "last output");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_screen_capture() {
        let received = roundtrip(&Event::ScreenCapture {
            content: "screen data".into(),
            cursor_row: 5,
            cursor_col: 10,
        });
        match received {
            Event::ScreenCapture {
                content,
                cursor_row,
                cursor_col,
            } => {
                assert_eq!(content, "screen data");
                assert_eq!(cursor_row, 5);
                assert_eq!(cursor_col, 10);
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_error() {
        let received = roundtrip(&Event::Error {
            command: "SendMessage".into(),
            reason: "agent busy".into(),
        });
        match received {
            Event::Error { command, reason } => {
                assert_eq!(command, "SendMessage");
                assert_eq!(reason, "agent busy");
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn roundtrip_event_warning() {
        let received = roundtrip(&Event::Warning {
            message: "no screen change".into(),
            idle_secs: Some(300),
        });
        match received {
            Event::Warning { message, idle_secs } => {
                assert_eq!(message, "no screen change");
                assert_eq!(idle_secs, Some(300));
            }
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn command_wire_format() {
        assert_eq!(
            serde_json::to_string(&Command::Ping).unwrap(),
            r#"{"cmd":"Ping"}"#
        );
        assert_eq!(
            serde_json::to_string(&Command::Resize { rows: 50, cols: 220 }).unwrap(),
            r#"{"cmd":"Resize","rows":50,"cols":220}"#
        );
        let without_id = Command::SendMessage {
            from: "u".into(),
            body: "b".into(),
            message_id: None,
        };
        assert_eq!(
            serde_json::to_string(&without_id).unwrap(),
            r#"{"cmd":"SendMessage","from":"u","body":"b"}"#
        );
    }

    #[test]
    fn completion_omits_absent_message_id() {
        let value = serde_json::to_value(&Event::Completion {
            message_id: None,
            response: "r".into(),
            last_lines: "l".into(),
        })
        .unwrap();
        assert_eq!(value["event"], "Completion");
        assert!(value.get("message_id").is_none());
    }

    #[test]
    fn eof_returns_none() {
        let (a, b) = socketpair().unwrap();
        drop(a);
        let mut receiver = Channel::new(b);
        let result: Option<Command> = receiver.recv().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn frames_arrive_in_order_then_eof() {
        let (mut sender, mut receiver) = channel_pair();
        for cmd in [Command::Ping, Command::GetState, Command::Kill] {
            sender.send(&cmd).unwrap();
        }
        drop(sender);
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Ping)));
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::GetState)));
        assert!(matches!(receiver.recv::<Command>().unwrap(), Some(Command::Kill)));
        assert!(receiver.recv::<Command>().unwrap().is_none());
    }

    #[test]
    fn frame_larger_than_initial_buffer_roundtrips() {
        let (a, b) = socketpair().unwrap();
        let body = "x".repeat(64 * 1024);
        let expected = body.clone();
        // Write from another thread: the frame may exceed the socket buffer.
        let writer = std::thread::spawn(move || {
            let mut sender = Channel::new(a);
            sender
                .send(&Command::SendMessage {
                    from: "user".into(),
                    body,
                    message_id: None,
                })
                .unwrap();
        });
        let mut receiver = Channel::new(b);
        let received = receiver.recv::<Command>().unwrap().unwrap();
        writer.join().unwrap();
        match received {
            Command::SendMessage { body, .. } => assert_eq!(body, expected),
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn oversized_outgoing_message_is_rejected() {
        let (mut sender, _receiver) = channel_pair();
        let cmd = Command::SendMessage {
            from: "user".into(),
            body: "x".repeat(MAX_MSG),
            message_id: None,
        };
        assert!(sender.send(&cmd).is_err());
    }

    #[test]
    fn oversized_incoming_frame_is_rejected() {
        let (mut raw, b) = socketpair().unwrap();
        let declared = MAX_MSG as u32 + 1;
        raw.write_all(&declared.to_be_bytes()).unwrap();
        let mut receiver = Channel::new(b);
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn truncated_payload_is_an_error() {
        let (mut raw, b) = socketpair().unwrap();
        raw.write_all(&16u32.to_be_bytes()).unwrap();
        raw.write_all(br#"{"cmd":"#).unwrap();
        drop(raw);
        let mut receiver = Channel::new(b);
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn read_timeout_surfaces_as_error() {
        // Keep the peer alive so the receiver times out instead of seeing EOF.
        let (_peer, b) = socketpair().unwrap();
        let mut receiver = Channel::new(b);
        receiver
            .set_read_timeout(Some(Duration::from_millis(20)))
            .unwrap();
        assert!(receiver.recv::<Command>().is_err());
    }

    #[test]
    fn cloned_channel_writes_to_same_socket() {
        let (a, b) = socketpair().unwrap();
        let original = Channel::new(a);
        let mut clone = original.try_clone().unwrap();
        let mut receiver = Channel::new(b);
        clone.send(&Command::Ping).unwrap();
        let msg: Command = receiver.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));
    }

    #[test]
    fn all_states_serialize() {
        for state in ALL_STATES {
            let json = serde_json::to_string(&state).unwrap();
            let back: ShimState = serde_json::from_str(&json).unwrap();
            assert_eq!(state, back);
        }
    }

    #[test]
    fn shim_state_display() {
        assert_eq!(ShimState::Starting.to_string(), "starting");
        assert_eq!(ShimState::Idle.to_string(), "idle");
        assert_eq!(ShimState::Working.to_string(), "working");
        assert_eq!(ShimState::Dead.to_string(), "dead");
        assert_eq!(ShimState::ContextExhausted.to_string(), "context_exhausted");
    }

    #[test]
    fn shim_state_display_matches_wire_name() {
        for state in ALL_STATES {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, format!("\"{}\"", state));
        }
    }

    #[test]
    fn socketpair_creates_connected_pair() {
        let (mut ch_a, mut ch_b) = channel_pair();
        ch_a.send(&Command::Ping).unwrap();
        let msg: Command = ch_b.recv().unwrap().unwrap();
        assert!(matches!(msg, Command::Ping));

        ch_b.send(&Event::Pong).unwrap();
        let reply: Event = ch_a.recv().unwrap().unwrap();
        assert!(matches!(reply, Event::Pong));
    }
}