1use std::collections::HashMap;
23#[cfg(feature = "stream")]
24use std::future::Future;
25#[cfg(feature = "uds")]
26use std::path::Path;
27#[cfg(feature = "stream")]
28use std::pin::Pin;
29use std::sync::{Arc, atomic::AtomicU32};
30#[cfg(feature = "stream")]
31use std::time::Duration;
32
33#[cfg(feature = "stream")]
34use microsandbox_protocol::message::FLAG_TERMINAL;
35#[cfg(feature = "stream")]
36use microsandbox_protocol::{codec::MAX_FRAME_SIZE, message::FRAME_HEADER_SIZE};
37use microsandbox_protocol::{
38 codec::{self, RawFrame},
39 core::Ready,
40 message::{Message, MessageType, PROTOCOL_VERSION},
41};
42use serde::Serialize;
43#[cfg(feature = "stream")]
44use tokio::io::{AsyncRead, AsyncWrite};
45#[cfg(feature = "uds")]
46use tokio::net::UnixStream;
47use tokio::sync::{Mutex, mpsc, oneshot};
48use tokio::task::JoinHandle;
49#[cfg(feature = "stream")]
50use tokio::time::Instant;
51
52use super::error::{AgentClientError, AgentClientResult};
53
/// Budget for the whole relay handshake when the caller does not supply a timeout.
#[cfg(feature = "stream")]
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Capacity of the queue feeding the background writer task.
#[cfg(feature = "stream")]
const WRITER_QUEUE_CAPACITY: usize = 1024;
/// Capacity of the per-request reply channel; `request_raw` consumes a single frame.
const REQUEST_QUEUE_CAPACITY: usize = 1;
/// Capacity of per-stream channels, used for both raw frames and decoded messages.
const STREAM_QUEUE_CAPACITY: usize = 1024;

/// Message version (`v`) stamped on messages sent to a legacy (pre-range-handshake) relay.
const LEGACY_PROTOCOL_VERSION: u8 = 1;
/// Width of the id range assumed for a legacy relay, which only announces an id offset.
/// The client uses `[offset + 1, offset + LEGACY_RELAY_ID_RANGE_STEP)`.
/// NOTE(review): presumably mirrors how legacy relays split the id space into 16 slots — confirm.
#[cfg(feature = "stream")]
const LEGACY_RELAY_ID_RANGE_STEP: u32 = u32::MAX / 16;
72
/// Relay handshake generation detected while connecting.
///
/// The generation decides how the relay's opening bytes are parsed. It also decides which
/// message version the client stamps on outgoing messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentProtocol {
    /// The relay announces an explicit `[id_start, id_end)` request-id range before its ready
    /// frame; messages are stamped with `PROTOCOL_VERSION`.
    Current,

    /// The relay only announces an id offset before its ready frame (a sandbox started before
    /// microsandbox 0.5); messages are stamped with `LEGACY_PROTOCOL_VERSION`.
    LegacyV1,
}
89
/// Multiplexing client for the agent protocol over a single relay connection.
///
/// Each request reserves an id from the relay-assigned range. A background reader task routes
/// inbound frames to per-id channels, and a background writer task serializes outbound frames.
/// Both tasks are aborted when the client is dropped.
pub struct AgentClient {
    /// Queue to the background writer task; each command carries a one-shot ack for the result.
    writer: mpsc::Sender<WriterCommand>,
    /// Next candidate request id; only advanced by `reserve_id` while the `pending` lock is held.
    next_id: AtomicU32,
    /// Start of the usable request-id range (inclusive; id 0 is never used).
    id_min: u32,
    /// End of the usable request-id range (exclusive).
    id_max: u32,
    /// Relay handshake generation detected at connect time.
    protocol: AgentProtocol,
    /// Lower of this client's version for `protocol` and the version on the peer's ready
    /// message; message types are gated on it.
    negotiated_version: u8,
    /// Channels awaiting frames, keyed by request id. Shared with the reader task, which
    /// dispatches into it and clears it when the connection's read side ends.
    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
    /// Background reader task; aborted on drop.
    reader_handle: JoinHandle<()>,
    /// Background writer task; aborted on drop.
    writer_handle: JoinHandle<()>,
    /// Raw body of the ready frame received during the handshake.
    ready_body: Vec<u8>,
    /// Decoded `core.ready` payload received during the handshake.
    ready: Ready,
}
120
/// Everything learned from the relay during the handshake; consumed by the client constructor.
#[cfg(feature = "stream")]
struct AgentHandshake {
    /// Start of the usable request-id range (inclusive).
    id_min: u32,
    /// End of the usable request-id range (exclusive).
    id_max: u32,
    /// Relay generation detected from the opening bytes.
    protocol: AgentProtocol,
    /// Lower of this client's version for `protocol` and the ready message's `v`.
    negotiated_version: u8,
    /// Raw body of the ready frame, as received.
    ready_body: Vec<u8>,
    /// Decoded `core.ready` payload.
    ready: Ready,
}
130
/// A frame queued for the background writer task, paired with a channel that receives the
/// write outcome.
// The writer task is only compiled with the `stream` feature; without it these fields are
// never read, hence the conditional `dead_code` allowance.
#[cfg_attr(not(feature = "stream"), allow(dead_code))]
struct WriterCommand {
    /// Frame to write to the relay connection.
    frame: RawFrame,
    /// Receives `Ok(())` once the frame is written, or the write error.
    ack: oneshot::Sender<AgentClientResult<()>>,
}
136
/// Reading primitives used by the handshake.
///
/// The futures are boxed so `perform_handshake` can accept `?Sized` readers. A blanket
/// implementation covers every `AsyncRead + Unpin + Send` type.
#[cfg(feature = "stream")]
trait HandshakeReader {
    /// Fills `out` completely or fails with a handshake error.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>>;

    /// Reads one complete codec frame.
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>>;
}
148
149impl AgentProtocol {
154 fn version(self) -> u8 {
155 match self {
156 Self::Current => PROTOCOL_VERSION,
157 Self::LegacyV1 => LEGACY_PROTOCOL_VERSION,
158 }
159 }
160}
161
162impl AgentClient {
163 #[cfg(feature = "uds")]
166 pub async fn connect(sock_path: impl AsRef<Path>) -> AgentClientResult<Self> {
167 Self::connect_with_timeout(sock_path, DEFAULT_HANDSHAKE_TIMEOUT).await
168 }
169
170 #[cfg(feature = "uds")]
173 pub async fn connect_with_timeout(
174 sock_path: impl AsRef<Path>,
175 timeout: Duration,
176 ) -> AgentClientResult<Self> {
177 let deadline = Instant::now() + timeout;
178 Self::connect_with_deadline(sock_path, deadline).await
179 }
180
181 #[cfg(feature = "uds")]
187 pub async fn connect_with_deadline(
188 sock_path: impl AsRef<Path>,
189 deadline: Instant,
190 ) -> AgentClientResult<Self> {
191 let sock_path = sock_path.as_ref();
192 let stream =
193 UnixStream::connect(sock_path)
194 .await
195 .map_err(|source| AgentClientError::Connect {
196 path: sock_path.to_path_buf(),
197 source,
198 })?;
199 Self::connect_stream_with_deadline(stream, deadline).await
200 }
201
202 #[cfg(feature = "stream")]
212 pub async fn connect_stream<S>(stream: S) -> AgentClientResult<Self>
213 where
214 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
215 {
216 Self::connect_stream_with_timeout(stream, DEFAULT_HANDSHAKE_TIMEOUT).await
217 }
218
219 #[cfg(feature = "stream")]
222 pub async fn connect_stream_with_timeout<S>(
223 stream: S,
224 timeout: Duration,
225 ) -> AgentClientResult<Self>
226 where
227 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
228 {
229 let deadline = Instant::now() + timeout;
230 Self::connect_stream_with_deadline(stream, deadline).await
231 }
232
233 #[cfg(feature = "stream")]
239 pub async fn connect_stream_with_deadline<S>(
240 stream: S,
241 deadline: Instant,
242 ) -> AgentClientResult<Self>
243 where
244 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
245 {
246 let (mut reader, writer) = tokio::io::split(stream);
247 let handshake = perform_handshake(&mut reader, deadline).await?;
248
249 tracing::info!(
250 id_min = handshake.id_min,
251 id_max = handshake.id_max,
252 protocol = ?handshake.protocol,
253 ready_bytes = handshake.ready_body.len(),
254 boot_time_ns = handshake.ready.boot_time_ns,
255 "agent client: connected to relay"
256 );
257 if handshake.protocol == AgentProtocol::LegacyV1 {
258 tracing::warn!(
261 "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
262 );
263 }
264
265 let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
266 Arc::new(Mutex::new(HashMap::new()));
267
268 let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
269 let reader_handle = tokio::spawn(reader_loop(reader, Arc::clone(&pending)));
270 let writer_handle = tokio::spawn(stream_writer_loop(writer, writer_rx));
271
272 Ok(Self {
273 writer: writer_tx,
274 next_id: AtomicU32::new(first_request_id(handshake.id_min)),
275 id_min: handshake.id_min,
276 id_max: handshake.id_max,
277 protocol: handshake.protocol,
278 negotiated_version: handshake.negotiated_version,
279 pending,
280 reader_handle,
281 writer_handle,
282 ready_body: handshake.ready_body,
283 ready: handshake.ready,
284 })
285 }
286
287 pub async fn close(self) {
290 }
294}
295
296impl AgentClient {
301 pub async fn request_raw(&self, flags: u8, body: Vec<u8>) -> AgentClientResult<RawFrame> {
307 let (tx, mut rx) = mpsc::channel(REQUEST_QUEUE_CAPACITY);
308 let id = self.reserve_id(tx).await?;
309
310 if let Err(e) = self.write_frame_owned(id, flags, body).await {
311 self.pending.lock().await.remove(&id);
312 return Err(e);
313 }
314
315 let frame = rx.recv().await.ok_or(AgentClientError::ReaderClosed(id))?;
316 self.pending.lock().await.remove(&id);
317 Ok(frame)
318 }
319
320 pub async fn stream_raw(
328 &self,
329 flags: u8,
330 body: Vec<u8>,
331 ) -> AgentClientResult<(u32, mpsc::Receiver<RawFrame>)> {
332 let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
333 let id = self.reserve_id(tx).await?;
334
335 if let Err(e) = self.write_frame_owned(id, flags, body).await {
336 self.pending.lock().await.remove(&id);
337 return Err(e);
338 }
339
340 Ok((id, rx))
341 }
342
343 pub async fn send_raw(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
349 self.write_frame(id, flags, body).await
350 }
351
352 pub fn ready_bytes(&self) -> &[u8] {
357 &self.ready_body
358 }
359
360 pub fn protocol(&self) -> AgentProtocol {
362 self.protocol
363 }
364
365 pub fn is_legacy_protocol(&self) -> bool {
367 self.protocol == AgentProtocol::LegacyV1
368 }
369
370 pub fn negotiated_version(&self) -> u8 {
373 self.negotiated_version
374 }
375
376 pub fn agent_version(&self) -> &str {
380 &self.ready.agent_version
381 }
382
383 pub fn supports(&self, t: MessageType) -> bool {
388 t.min_protocol_version() <= self.negotiated_version
389 }
390
391 pub fn ensure_version_compat(&self, t: MessageType) -> AgentClientResult<()> {
395 Self::ensure_version_compat_for(t, self.negotiated_version)
396 }
397
398 pub fn ensure_version_compat_for(t: MessageType, negotiated: u8) -> AgentClientResult<()> {
403 if t.is_available_at(negotiated) {
404 return Ok(());
405 }
406 Err(AgentClientError::UnsupportedOperation {
407 msg_type: t.as_str(),
408 needs: t.min_protocol_version(),
409 peer: negotiated,
410 })
411 }
412}
413
414impl AgentClient {
419 pub async fn request<T: Serialize>(
421 &self,
422 t: MessageType,
423 payload: &T,
424 ) -> AgentClientResult<Message> {
425 self.ensure_version_compat(t)?;
426 let flags = t.flags();
427 let body = encode_message_body(self.protocol.version(), t, payload)?;
428 let frame = self.request_raw(flags, body).await?;
429 Ok(codec::raw_frame_to_message(frame)?)
430 }
431
432 pub async fn stream<T: Serialize>(
435 &self,
436 t: MessageType,
437 payload: &T,
438 ) -> AgentClientResult<(u32, mpsc::Receiver<Message>)> {
439 self.ensure_version_compat(t)?;
440 let flags = t.flags();
441 let body = encode_message_body(self.protocol.version(), t, payload)?;
442 let (id, raw_rx) = self.stream_raw(flags, body).await?;
443
444 let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
445 tokio::spawn(decode_stream_task(raw_rx, tx));
446 Ok((id, rx))
447 }
448
449 pub async fn send<T: Serialize>(
451 &self,
452 id: u32,
453 t: MessageType,
454 payload: &T,
455 ) -> AgentClientResult<()> {
456 self.ensure_version_compat(t)?;
457 let flags = t.flags();
458 let body = encode_message_body(self.protocol.version(), t, payload)?;
459 self.write_frame_owned(id, flags, body).await
460 }
461
462 pub fn ready(&self) -> AgentClientResult<Ready> {
464 Ok(self.ready.clone())
465 }
466}
467
468impl AgentClient {
473 async fn reserve_id(&self, tx: mpsc::Sender<RawFrame>) -> AgentClientResult<u32> {
478 let mut pending = self.pending.lock().await;
479 let attempts = usable_id_count(self.id_min, self.id_max);
480 for _ in 0..attempts {
481 let id = self
482 .next_id
483 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
484 if self.next_id.load(std::sync::atomic::Ordering::Relaxed) >= self.id_max {
485 self.next_id.store(
486 first_request_id(self.id_min),
487 std::sync::atomic::Ordering::Relaxed,
488 );
489 }
490 if id == 0 || id < self.id_min || id >= self.id_max || pending.contains_key(&id) {
491 continue;
492 }
493 pending.insert(id, tx);
494 return Ok(id);
495 }
496
497 Err(AgentClientError::IdRangeExhausted)
498 }
499
500 async fn write_frame(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
502 self.write_frame_owned(id, flags, body.to_vec()).await
503 }
504
505 async fn write_frame_owned(&self, id: u32, flags: u8, body: Vec<u8>) -> AgentClientResult<()> {
507 let (ack, written) = oneshot::channel();
508 self.writer
509 .send(WriterCommand {
510 frame: RawFrame { id, flags, body },
511 ack,
512 })
513 .await
514 .map_err(|_| AgentClientError::Closed)?;
515 written.await.map_err(|_| AgentClientError::Closed)?
516 }
517}
518
/// Performs the relay handshake on `reader`; every read is bounded by `deadline`.
///
/// The relay opens with 8 bytes, read as two big-endian `u32` words:
///
/// - **Current relays** send `[id_start][id_end]`, the half-open request-id range this client
///   may use, followed by a `core.ready` frame.
/// - **Legacy relays** send `[id_offset]` and then the ready frame directly, so the second word
///   is the frame's length prefix. The client then uses
///   `[id_offset + 1, id_offset + LEGACY_RELAY_ID_RANGE_STEP)` (saturating).
///
/// [`looks_like_legacy_relay_handshake`] decides which layout arrived. The ready frame must
/// decode to a `core.ready` message. Its raw body and decoded payload are both returned.
///
/// # Errors
/// Returns [`AgentClientError::Handshake`] for any of the following:
/// - a read times out or fails;
/// - the id range is inverted or contains no usable ids;
/// - the ready frame is malformed or is not `core.ready`.
#[cfg(feature = "stream")]
async fn perform_handshake<R>(
    reader: &mut R,
    deadline: Instant,
) -> AgentClientResult<AgentHandshake>
where
    R: HandshakeReader + ?Sized,
{
    let mut range_buf = [0u8; 8];
    tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut range_buf))
        .await
        .map_err(|_| {
            AgentClientError::Handshake("read id range: timed out before relay sent bytes".into())
        })?
        // Label read failures with the handshake step, as the ready-frame reads below do.
        .map_err(|e| AgentClientError::Handshake(format!("read id range: {e}")))?;

    // Split the prefix into its two big-endian words without fallible slice conversions.
    let [s0, s1, s2, s3, w0, w1, w2, w3] = range_buf;
    let second_word = [w0, w1, w2, w3];
    let id_start_or_offset = u32::from_be_bytes([s0, s1, s2, s3]);
    let id_max_or_frame_len = u32::from_be_bytes(second_word);

    let legacy_handshake =
        looks_like_legacy_relay_handshake(id_start_or_offset, id_max_or_frame_len);
    let (id_min, id_max, ready_frame, protocol) = if legacy_handshake {
        // A legacy relay sends only an id offset, so the second word is already the ready
        // frame's length prefix.
        let id_offset = id_start_or_offset;
        let ready_frame = read_raw_frame_after_len_prefix(reader, second_word, deadline).await?;
        (
            id_offset.saturating_add(1),
            id_offset.saturating_add(LEGACY_RELAY_ID_RANGE_STEP),
            ready_frame,
            AgentProtocol::LegacyV1,
        )
    } else if id_start_or_offset >= id_max_or_frame_len {
        return Err(AgentClientError::Handshake(format!(
            "invalid relay id range: start={id_start_or_offset}, end={id_max_or_frame_len}"
        )));
    } else {
        let ready_frame = tokio::time::timeout_at(deadline, reader.read_frame_handshake())
            .await
            .map_err(|_| {
                AgentClientError::Handshake(
                    "read ready frame: timed out before relay sent frame".into(),
                )
            })?
            .map_err(|e| AgentClientError::Handshake(format!("read ready frame: {e}")))?;
        (
            id_start_or_offset,
            id_max_or_frame_len,
            ready_frame,
            AgentProtocol::Current,
        )
    };
    ensure_usable_id_range(id_min, id_max)?;

    // Decoding consumes a frame, but the raw body is kept for `AgentClient::ready_bytes`.
    let ready_msg = codec::raw_frame_to_message(ready_frame.clone())
        .map_err(|e| AgentClientError::Handshake(format!("decode ready frame: {e}")))?;
    if ready_msg.t != MessageType::Ready {
        return Err(AgentClientError::Handshake(format!(
            "expected core.ready frame, got {}",
            ready_msg.t.as_str()
        )));
    }
    let ready: Ready = ready_msg
        .payload()
        .map_err(|e| AgentClientError::Handshake(format!("decode ready payload: {e}")))?;

    // Never claim more than either side announced; message types are gated on this value.
    let negotiated_version = protocol.version().min(ready_msg.v);

    Ok(AgentHandshake {
        id_min,
        id_max,
        protocol,
        negotiated_version,
        ready_body: ready_frame.body,
        ready,
    })
}
611
/// First id to hand out for a range starting at `id_min`.
///
/// Id 0 is never used as a request id, so a range that starts at 0 effectively starts at 1.
fn first_request_id(id_min: u32) -> u32 {
    if id_min == 0 { 1 } else { id_min }
}
615
/// Rejects a relay id range that leaves no nonzero id for requests.
///
/// # Errors
/// Returns [`AgentClientError::Handshake`] when `[id_min, id_max)` contains no usable id.
#[cfg(feature = "stream")]
fn ensure_usable_id_range(id_min: u32, id_max: u32) -> AgentClientResult<()> {
    match usable_id_count(id_min, id_max) {
        0 => Err(AgentClientError::Handshake(format!(
            "relay id range contains no usable nonzero ids: start={id_min}, end={id_max}"
        ))),
        _ => Ok(()),
    }
}
625
626fn usable_id_count(id_min: u32, id_max: u32) -> u32 {
627 id_max.saturating_sub(first_request_id(id_min))
628}
629
/// Heuristic: do the relay's two opening words match the legacy layout?
///
/// A legacy relay sends `[id_offset]` followed directly by the ready frame. Its second word is
/// therefore that frame's length prefix, so it must be a plausible frame length. The first word
/// must also not describe a forward id range: it is either 0 or at least as large as that
/// length. NOTE(review): this relies on current relays never announcing `id_start == 0` with a
/// small `id_end` — confirm against the relay.
#[cfg(feature = "stream")]
fn looks_like_legacy_relay_handshake(id_min: u32, id_max: u32) -> bool {
    let plausible_frame_len = ((FRAME_HEADER_SIZE as u32)..=MAX_FRAME_SIZE).contains(&id_max);
    let not_a_forward_range = id_min == 0 || id_min >= id_max;
    plausible_frame_len && not_a_forward_range
}
644
/// Reads the rest of a legacy ready frame whose 4-byte big-endian length prefix has already
/// been consumed (`len_buf`).
///
/// The frame bytes are the id (4 bytes, big-endian) and flags (byte 4). The body starts at
/// `FRAME_HEADER_SIZE`.
///
/// # Errors
/// Returns [`AgentClientError::Handshake`] if the length is out of bounds, or the read times out
/// or fails.
#[cfg(feature = "stream")]
async fn read_raw_frame_after_len_prefix<R>(
    reader: &mut R,
    len_buf: [u8; 4],
    deadline: Instant,
) -> AgentClientResult<RawFrame>
where
    R: HandshakeReader + ?Sized,
{
    let frame_len = u32::from_be_bytes(len_buf);
    match frame_len {
        n if n > MAX_FRAME_SIZE => {
            return Err(AgentClientError::Handshake(format!(
                "legacy ready frame too large: {frame_len} bytes (max {MAX_FRAME_SIZE})"
            )));
        }
        n if n < FRAME_HEADER_SIZE as u32 => {
            return Err(AgentClientError::Handshake(format!(
                "legacy ready frame too short: {frame_len} bytes"
            )));
        }
        _ => {}
    }

    let mut data = vec![0u8; frame_len as usize];
    let read = tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut data)).await;
    match read {
        Err(_) => {
            return Err(AgentClientError::Handshake(
                "read legacy ready frame: timed out before relay sent frame".into(),
            ));
        }
        Ok(Err(e)) => {
            return Err(AgentClientError::Handshake(format!(
                "read legacy ready frame: {e}"
            )));
        }
        Ok(Ok(())) => {}
    }

    let id = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
    let flags = data[4];
    let body = data.split_off(FRAME_HEADER_SIZE);
    Ok(RawFrame { id, flags, body })
}
682
/// Blanket [`HandshakeReader`] for any tokio reader.
///
/// This lets `perform_handshake` run directly on the read half of the relay stream.
#[cfg(feature = "stream")]
impl<R> HandshakeReader for R
where
    R: tokio::io::AsyncRead + Unpin + Send,
{
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        // Any I/O failure, including EOF before `out` is full, becomes a handshake error.
        let fut = async move {
            match tokio::io::AsyncReadExt::read_exact(self, out).await {
                Ok(_) => Ok(()),
                Err(err) => Err(AgentClientError::Handshake(err.to_string())),
            }
        };
        Box::pin(fut)
    }

    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        let fut = async move {
            match codec::read_raw_frame(self).await {
                Ok(frame) => Ok(frame),
                Err(err) => Err(AgentClientError::Protocol(err)),
            }
        };
        Box::pin(fut)
    }
}
710
/// Background writer: writes queued frames in order and reports each outcome on its ack.
///
/// Stops on the first write error, or when every command sender is gone. Commands still queued
/// at that point are dropped, so their callers observe a closed ack.
#[cfg(feature = "stream")]
async fn stream_writer_loop<W>(mut writer: W, mut rx: mpsc::Receiver<WriterCommand>)
where
    W: tokio::io::AsyncWrite + Unpin,
{
    while let Some(WriterCommand { frame, ack }) = rx.recv().await {
        match codec::write_raw_frame(&mut writer, &frame).await {
            Ok(_) => {
                // The caller may have stopped waiting; a dropped ack receiver is harmless.
                let _ = ack.send(Ok(()));
            }
            Err(e) => {
                tracing::debug!("agent client: stream writer error: {e}");
                let _ = ack.send(Err(AgentClientError::Protocol(e)));
                break;
            }
        }
    }
}
725
/// Background reader: decodes frames from the relay and dispatches each by id.
///
/// Runs until the first read error or EOF. It then clears `pending`, which drops every
/// registered sender so that waiting receivers observe the channel closing.
///
/// NOTE(review): nothing clears entries registered after this task has exited. A request issued
/// after that point waits on a receiver that nothing will complete. Consider marking the client
/// closed here so `reserve_id` can fail fast.
#[cfg(feature = "stream")]
async fn reader_loop<R>(mut reader: R, pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>)
where
    R: tokio::io::AsyncRead + Unpin,
{
    let exit_reason = loop {
        match codec::read_raw_frame(&mut reader).await {
            Ok(frame) => dispatch_frame(frame, &pending).await,
            Err(e) => break e,
        }
    };
    tracing::debug!("agent client: reader EOF or error: {exit_reason}");

    pending.lock().await.clear();
}
749
/// Routes one inbound frame to the channel registered for its id.
///
/// Behavior:
/// - Frames with no registered handler are dropped, with a trace log.
/// - A frame flagged `FLAG_TERMINAL` unregisters its id before delivery.
/// - Delivery awaits channel capacity, so a slow consumer applies backpressure to the reader
///   loop.
/// - If delivery fails because the receiver is gone, the id is unregistered. This happens only
///   while the entry still belongs to the same channel.
#[cfg(feature = "stream")]
async fn dispatch_frame(
    frame: RawFrame,
    pending: &Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) {
    let id = frame.id;
    let is_terminal = (frame.flags & FLAG_TERMINAL) != 0;

    let tx = {
        let mut map = pending.lock().await;
        let Some(tx) = map.get(&id).cloned() else {
            tracing::trace!("agent client: no pending handler for id={id}");
            return;
        };
        if is_terminal {
            map.remove(&id);
        }
        tx
    };

    if tx.send(frame).await.is_err() {
        // While `send` was pending, the id may have been released (by the terminal removal
        // above, or by `request_raw` finishing) and re-reserved by a new request. Removing it
        // unconditionally would evict that new registration. Only drop the entry if it is
        // still this dead channel.
        let mut map = pending.lock().await;
        if matches!(map.get(&id), Some(current) if current.same_channel(&tx)) {
            map.remove(&id);
        }
    }
}
774
775async fn decode_stream_task(mut raw_rx: mpsc::Receiver<RawFrame>, tx: mpsc::Sender<Message>) {
777 while let Some(frame) = raw_rx.recv().await {
778 match codec::raw_frame_to_message(frame) {
779 Ok(msg) => {
780 if tx.send(msg).await.is_err() {
781 break;
782 }
783 }
784 Err(e) => {
785 tracing::warn!("agent client: failed to decode frame in stream: {e}");
786 }
788 }
789 }
790}
791
792fn encode_message_body<T: Serialize>(
794 version: u8,
795 t: MessageType,
796 payload: &T,
797) -> AgentClientResult<Vec<u8>> {
798 let mut msg = Message::with_payload(t, 0, payload)?;
799 msg.v = version;
800 let mut body = Vec::new();
801 ciborium::into_writer(&msg, &mut body).map_err(microsandbox_protocol::ProtocolError::from)?;
802 Ok(body)
803}
804
// Unit tests for the pure helpers plus loopback handshake/exec tests. The Unix-socket tests need
// the `uds` feature; the in-memory duplex test needs `stream`.
#[cfg(test)]
mod tests {
    #[cfg(feature = "uds")]
    use microsandbox_protocol::core::Ready;
    #[cfg(feature = "uds")]
    use microsandbox_protocol::exec::ExecRequest;
    #[cfg(feature = "uds")]
    use microsandbox_protocol::message::PROTOCOL_VERSION;
    #[cfg(feature = "uds")]
    use tokio::io::AsyncWriteExt;
    #[cfg(feature = "uds")]
    use tokio::net::UnixListener;
    #[cfg(feature = "uds")]
    use tokio::sync::oneshot;

    use super::*;

    #[cfg(feature = "uds")]
    #[tokio::test]
    async fn connect_decodes_ready_payload() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            agent_version: "9.9.9".to_string(),
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket.write_all(&8u32.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), PROTOCOL_VERSION);
        // A current-generation peer can serve filesystem requests.
        assert!(client.supports(MessageType::FsRequest));
        assert_eq!(client.agent_version(), "9.9.9");
        // The decoded ready payload round-trips field by field.
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);

        let raw_msg: Message = ciborium::from_reader(client.ready_bytes()).unwrap();
        assert_eq!(raw_msg.t, MessageType::Ready);
    }

    #[cfg(feature = "uds")]
    #[tokio::test]
    async fn connect_negotiates_down_to_older_guest_generation() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 1,
            init_time_ns: 2,
            ready_time_ns: 3,
            ..Default::default()
        };
        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        // Simulate a guest agent that stamps generation 1 on its ready message behind a
        // current-generation relay.
        ready_msg.v = 1;

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket
                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
                .await
                .unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        // The relay handshake is current, but the version negotiates down to the guest's.
        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), 1);
        assert!(client.supports(MessageType::ExecRequest));
        // Filesystem requests need generation 2.
        assert!(!client.supports(MessageType::FsRequest));
    }

    #[cfg(feature = "uds")]
    #[tokio::test]
    async fn connect_accepts_legacy_relay_handshake() {
        assert_accepts_legacy_relay_handshake(0).await;
        assert_accepts_legacy_relay_handshake(268_435_455).await;
    }

    #[cfg(feature = "uds")]
    #[tokio::test]
    async fn legacy_relay_requests_use_v1_and_legacy_id_range() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            ..Default::default()
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        let id_offset = 268_435_455u32;
        let (frame_tx, frame_rx) = oneshot::channel();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
            let frame = codec::read_raw_frame(&mut socket).await.unwrap();
            frame_tx.send(frame).unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();
        let request = ExecRequest {
            cmd: "/bin/true".into(),
            args: Vec::new(),
            env: Vec::new(),
            cwd: None,
            user: None,
            tty: false,
            rows: 24,
            cols: 80,
            rlimits: Vec::new(),
        };
        let (id, _rx) = client
            .stream(MessageType::ExecRequest, &request)
            .await
            .unwrap();

        let frame = frame_rx.await.unwrap();
        let message = codec::raw_frame_to_message(frame).unwrap();

        assert_eq!(id, id_offset + 1);
        assert_eq!(message.id, id_offset + 1);
        assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
        assert_eq!(message.t, MessageType::ExecRequest);
    }

    #[test]
    fn version_compat_across_generations() {
        use MessageType::{ExecRequest, FsRequest};
        // (message type, negotiated generation, expected to be allowed)
        let cases = [
            (ExecRequest, 1, true),
            (ExecRequest, 2, true),
            (ExecRequest, 3, true),
            (FsRequest, 1, false),
            (FsRequest, 2, true),
            (FsRequest, 3, true),
        ];
        for (t, generation, allowed) in cases {
            assert_eq!(
                AgentClient::ensure_version_compat_for(t, generation).is_ok(),
                allowed,
                "{t:?} at generation {generation}"
            );
        }
    }

    #[test]
    fn version_compat_rejection_is_typed() {
        // Filesystem requests against a legacy (generation 1) peer are rejected with a typed
        // error that names both versions.
        let err =
            AgentClient::ensure_version_compat_for(MessageType::FsRequest, LEGACY_PROTOCOL_VERSION)
                .unwrap_err();
        assert!(matches!(
            err,
            AgentClientError::UnsupportedOperation {
                needs: 2,
                peer: 1,
                ..
            }
        ));
    }

    #[test]
    fn usable_id_count_excludes_reserved_zero_id() {
        assert_eq!(first_request_id(0), 1);
        assert_eq!(first_request_id(7), 7);
        assert_eq!(usable_id_count(0, 10), 9);
        assert_eq!(usable_id_count(1, 10), 9);
        assert_eq!(usable_id_count(0, 1), 0);
        assert_eq!(usable_id_count(10, 5), 0);
    }

    #[cfg(feature = "stream")]
    #[test]
    fn legacy_detection_distinguishes_relay_generations() {
        let min_frame = FRAME_HEADER_SIZE as u32;
        // A current relay's forward range is never mistaken for a legacy length prefix.
        assert!(!looks_like_legacy_relay_handshake(1, 1024));
        // Legacy offsets are zero or at least as large as the ready frame's length.
        assert!(looks_like_legacy_relay_handshake(0, min_frame));
        assert!(looks_like_legacy_relay_handshake(
            LEGACY_RELAY_ID_RANGE_STEP,
            min_frame
        ));
        // A "length" too short to hold a frame header is not a legacy prefix.
        assert!(!looks_like_legacy_relay_handshake(0, min_frame - 1));
    }

    #[cfg(feature = "uds")]
    #[tokio::test]
    async fn connect_preserves_current_peer_protocol_version() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            ..Default::default()
        };
        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        ready_msg.v = 2;

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket
                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
                .await
                .unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), 2);
        // TcpConnect is not available at generation 2.
        assert!(!client.supports(MessageType::TcpConnect));
    }

    #[cfg(feature = "uds")]
    async fn assert_accepts_legacy_relay_handshake(id_offset: u32) {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            ..Default::default()
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
        assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
    }

    #[cfg(feature = "stream")]
    #[tokio::test]
    async fn connect_stream_handshakes_and_streams_exec() {
        use microsandbox_protocol::exec::{ExecExited, ExecRequest, ExecStdout};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            agent_version: "stream-test".to_string(),
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            // Current relay handshake: id range [1, 1024), then the ready frame.
            server_io.write_all(&1u32.to_be_bytes()).await.unwrap();
            server_io.write_all(&1024u32.to_be_bytes()).await.unwrap();
            codec::write_message(&mut server_io, &ready_msg)
                .await
                .unwrap();

            // Answer the exec request with one stdout chunk followed by the exit status.
            let request = codec::read_raw_frame(&mut server_io).await.unwrap();
            let stdout = Message::with_payload(
                MessageType::ExecStdout,
                request.id,
                &ExecStdout {
                    data: b"hi".to_vec(),
                },
            )
            .unwrap();
            codec::write_message(&mut server_io, &stdout).await.unwrap();
            let exited =
                Message::with_payload(MessageType::ExecExited, request.id, &ExecExited { code: 0 })
                    .unwrap();
            codec::write_message(&mut server_io, &exited).await.unwrap();
        });

        let client = AgentClient::connect_stream_with_deadline(
            client_io,
            Instant::now() + Duration::from_secs(1),
        )
        .await
        .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.agent_version(), "stream-test");
        assert!(client.supports(MessageType::ExecRequest));

        let request = ExecRequest {
            cmd: "echo".into(),
            args: vec!["hi".into()],
            env: Vec::new(),
            cwd: None,
            user: None,
            tty: false,
            rows: 24,
            cols: 80,
            rlimits: Vec::new(),
        };
        let (_id, mut rx) = client
            .stream(MessageType::ExecRequest, &request)
            .await
            .unwrap();

        let first = rx.recv().await.unwrap();
        assert_eq!(first.t, MessageType::ExecStdout);
        let out: ExecStdout = first.payload().unwrap();
        assert_eq!(out.data, b"hi");

        let second = rx.recv().await.unwrap();
        assert_eq!(second.t, MessageType::ExecExited);
        let exit: ExecExited = second.payload().unwrap();
        assert_eq!(exit.code, 0);
    }
}
1155
1156impl Drop for AgentClient {
1161 fn drop(&mut self) {
1162 self.reader_handle.abort();
1163 self.writer_handle.abort();
1164 }
1165}