1use std::collections::HashMap;
24#[cfg(feature = "stream")]
25use std::future::Future;
26#[cfg(any(all(feature = "named-pipe", windows), all(feature = "uds", unix)))]
27use std::path::Path;
28#[cfg(feature = "stream")]
29use std::pin::Pin;
30use std::sync::{Arc, atomic::AtomicU32};
31#[cfg(feature = "stream")]
32use std::time::Duration;
33
34#[cfg(feature = "stream")]
35use microsandbox_protocol::message::FLAG_TERMINAL;
36#[cfg(feature = "stream")]
37use microsandbox_protocol::{codec::MAX_FRAME_SIZE, message::FRAME_HEADER_SIZE};
38use microsandbox_protocol::{
39 codec::{self, RawFrame},
40 core::Ready,
41 message::{Message, MessageType, PROTOCOL_VERSION},
42};
43use serde::Serialize;
44#[cfg(feature = "stream")]
45use tokio::io::{AsyncRead, AsyncWrite};
46#[cfg(all(feature = "uds", unix))]
47use tokio::net::UnixStream;
48#[cfg(all(feature = "named-pipe", windows))]
49use tokio::net::windows::named_pipe::ClientOptions;
50use tokio::sync::{Mutex, mpsc, oneshot};
51use tokio::task::JoinHandle;
52#[cfg(feature = "stream")]
53use tokio::time::Instant;
54
55use super::error::{AgentClientError, AgentClientResult};
56
/// Default time budget for connecting and completing the relay handshake,
/// used by `AgentClient::connect` and `AgentClient::connect_stream`.
#[cfg(feature = "stream")]
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Delay between attempts to open the named pipe while it is missing or busy.
#[cfg(all(feature = "named-pipe", windows))]
const WINDOWS_PIPE_CONNECT_RETRY: Duration = Duration::from_millis(10);

/// Capacity of the queue feeding the background writer task.
#[cfg(feature = "stream")]
const WRITER_QUEUE_CAPACITY: usize = 1024;
/// Capacity of the per-request response channel created by `request_raw`.
const REQUEST_QUEUE_CAPACITY: usize = 1;
/// Capacity of the per-stream frame channels created by `stream_raw` / `stream`.
const STREAM_QUEUE_CAPACITY: usize = 1024;

/// Protocol version stamped on outgoing messages when talking to a legacy relay.
const LEGACY_PROTOCOL_VERSION: u8 = 1;
/// Width of the request-id window derived from a legacy relay's id offset:
/// the client uses ids `offset + 1 .. offset + LEGACY_RELAY_ID_RANGE_STEP`
/// (the u32 id space divided into 16 equal windows).
#[cfg(feature = "stream")]
const LEGACY_RELAY_ID_RANGE_STEP: u32 = u32::MAX / 16;
78
/// Wire-protocol generation spoken by the relay this client is attached to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentProtocol {
    /// The relay announces an explicit `[id_min, id_max)` request-id range in
    /// the handshake; outgoing messages are stamped with `PROTOCOL_VERSION`.
    Current,

    /// A relay in a sandbox started before microsandbox 0.5. It sends an id
    /// offset followed by a length-prefixed ready frame, and outgoing messages
    /// are stamped with `LEGACY_PROTOCOL_VERSION`.
    LegacyV1,
}
95
/// Client for the agent relay that multiplexes many requests and streams over
/// a single connection.
///
/// Outgoing frames go through a background writer task. Incoming frames are
/// routed by a background reader task to the channel registered in `pending`
/// for the frame's request id.
pub struct AgentClient {
    /// Queue feeding the background writer task.
    writer: mpsc::Sender<WriterCommand>,
    /// Next candidate request id considered by `reserve_id`.
    next_id: AtomicU32,
    /// Inclusive lower bound of the request-id range granted by the relay.
    id_min: u32,
    /// Exclusive upper bound of the request-id range granted by the relay.
    id_max: u32,
    /// Protocol generation detected during the handshake.
    protocol: AgentProtocol,
    /// Lower of this client's protocol version and the version carried by the
    /// relay's ready message.
    negotiated_version: u8,
    /// In-flight requests and streams, keyed by request id. The reader task
    /// forwards each incoming frame to the matching sender and drops the entry
    /// when a terminal frame arrives.
    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
    /// Background reader task; aborted when the client is dropped.
    reader_handle: JoinHandle<()>,
    /// Background writer task; aborted when the client is dropped.
    writer_handle: JoinHandle<()>,
    /// Raw (CBOR-encoded) body of the relay's `core.ready` frame.
    ready_body: Vec<u8>,
    /// Decoded `core.ready` payload.
    ready: Ready,
}
126
/// Outcome of a successful relay handshake (see `perform_handshake`), used to
/// construct an `AgentClient`.
#[cfg(feature = "stream")]
struct AgentHandshake {
    /// Inclusive lower bound of the usable request-id range.
    id_min: u32,
    /// Exclusive upper bound of the usable request-id range.
    id_max: u32,
    /// Detected relay protocol generation.
    protocol: AgentProtocol,
    /// `min(protocol.version(), ready message version)`.
    negotiated_version: u8,
    /// Raw body of the ready frame.
    ready_body: Vec<u8>,
    /// Decoded ready payload.
    ready: Ready,
}
136
/// A frame queued for the background writer task, together with the channel
/// on which the writer reports whether that frame was written.
// The fields are only read by `stream_writer_loop`, which is compiled only
// with the "stream" feature.
#[cfg_attr(not(feature = "stream"), allow(dead_code))]
struct WriterCommand {
    frame: RawFrame,
    ack: oneshot::Sender<AgentClientResult<()>>,
}
142
/// Reader abstraction used while performing the relay handshake.
///
/// The methods return boxed futures, which keeps the trait object-safe. A
/// blanket implementation covers every `AsyncRead + Unpin + Send` type.
#[cfg(feature = "stream")]
trait HandshakeReader {
    /// Fill `out` completely, failing with a handshake error otherwise.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>>;

    /// Read one raw frame (via `codec::read_raw_frame` in the blanket impl).
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>>;
}
154
155impl AgentProtocol {
160 fn version(self) -> u8 {
161 match self {
162 Self::Current => PROTOCOL_VERSION,
163 Self::LegacyV1 => LEGACY_PROTOCOL_VERSION,
164 }
165 }
166}
167
168impl AgentClient {
169 #[cfg(any(all(feature = "named-pipe", windows), all(feature = "uds", unix)))]
174 pub async fn connect(sock_path: impl AsRef<Path>) -> AgentClientResult<Self> {
175 Self::connect_with_timeout(sock_path, DEFAULT_HANDSHAKE_TIMEOUT).await
176 }
177
178 #[cfg(any(all(feature = "named-pipe", windows), all(feature = "uds", unix)))]
180 pub async fn connect_with_timeout(
181 sock_path: impl AsRef<Path>,
182 timeout: Duration,
183 ) -> AgentClientResult<Self> {
184 let deadline = Instant::now() + timeout;
185 Self::connect_with_deadline(sock_path, deadline).await
186 }
187
188 #[cfg(any(all(feature = "named-pipe", windows), all(feature = "uds", unix)))]
194 pub async fn connect_with_deadline(
195 sock_path: impl AsRef<Path>,
196 deadline: Instant,
197 ) -> AgentClientResult<Self> {
198 let sock_path = sock_path.as_ref();
199 let stream = connect_local_stream(sock_path, deadline).await?;
200 Self::connect_stream_with_deadline(stream, deadline).await
201 }
202
203 #[cfg(feature = "stream")]
213 pub async fn connect_stream<S>(stream: S) -> AgentClientResult<Self>
214 where
215 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
216 {
217 Self::connect_stream_with_timeout(stream, DEFAULT_HANDSHAKE_TIMEOUT).await
218 }
219
220 #[cfg(feature = "stream")]
223 pub async fn connect_stream_with_timeout<S>(
224 stream: S,
225 timeout: Duration,
226 ) -> AgentClientResult<Self>
227 where
228 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
229 {
230 let deadline = Instant::now() + timeout;
231 Self::connect_stream_with_deadline(stream, deadline).await
232 }
233
234 #[cfg(feature = "stream")]
240 pub async fn connect_stream_with_deadline<S>(
241 stream: S,
242 deadline: Instant,
243 ) -> AgentClientResult<Self>
244 where
245 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
246 {
247 let (mut reader, writer) = tokio::io::split(stream);
248 let handshake = perform_handshake(&mut reader, deadline).await?;
249
250 tracing::info!(
251 id_min = handshake.id_min,
252 id_max = handshake.id_max,
253 protocol = ?handshake.protocol,
254 ready_bytes = handshake.ready_body.len(),
255 boot_time_ns = handshake.ready.boot_time_ns,
256 "agent client: connected to relay"
257 );
258 if handshake.protocol == AgentProtocol::LegacyV1 {
259 tracing::warn!(
262 "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
263 );
264 }
265
266 let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
267 Arc::new(Mutex::new(HashMap::new()));
268
269 let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
270 let reader_handle = tokio::spawn(reader_loop(reader, Arc::clone(&pending)));
271 let writer_handle = tokio::spawn(stream_writer_loop(writer, writer_rx));
272
273 Ok(Self {
274 writer: writer_tx,
275 next_id: AtomicU32::new(first_request_id(handshake.id_min)),
276 id_min: handshake.id_min,
277 id_max: handshake.id_max,
278 protocol: handshake.protocol,
279 negotiated_version: handshake.negotiated_version,
280 pending,
281 reader_handle,
282 writer_handle,
283 ready_body: handshake.ready_body,
284 ready: handshake.ready,
285 })
286 }
287
288 pub async fn close(self) {
291 }
295}
296
297impl AgentClient {
302 pub async fn request_raw(&self, flags: u8, body: Vec<u8>) -> AgentClientResult<RawFrame> {
308 let (tx, mut rx) = mpsc::channel(REQUEST_QUEUE_CAPACITY);
309 let id = self.reserve_id(tx).await?;
310
311 if let Err(e) = self.write_frame_owned(id, flags, body).await {
312 self.pending.lock().await.remove(&id);
313 return Err(e);
314 }
315
316 let frame = rx.recv().await.ok_or(AgentClientError::ReaderClosed(id))?;
317 self.pending.lock().await.remove(&id);
318 Ok(frame)
319 }
320
321 pub async fn stream_raw(
329 &self,
330 flags: u8,
331 body: Vec<u8>,
332 ) -> AgentClientResult<(u32, mpsc::Receiver<RawFrame>)> {
333 let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
334 let id = self.reserve_id(tx).await?;
335
336 if let Err(e) = self.write_frame_owned(id, flags, body).await {
337 self.pending.lock().await.remove(&id);
338 return Err(e);
339 }
340
341 Ok((id, rx))
342 }
343
344 pub async fn send_raw(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
350 self.write_frame(id, flags, body).await
351 }
352
353 pub fn ready_bytes(&self) -> &[u8] {
358 &self.ready_body
359 }
360
361 pub fn protocol(&self) -> AgentProtocol {
363 self.protocol
364 }
365
366 pub fn is_legacy_protocol(&self) -> bool {
368 self.protocol == AgentProtocol::LegacyV1
369 }
370
371 pub fn negotiated_version(&self) -> u8 {
374 self.negotiated_version
375 }
376
377 pub fn agent_version(&self) -> &str {
381 &self.ready.agent_version
382 }
383
384 pub fn supports(&self, t: MessageType) -> bool {
389 t.min_protocol_version() <= self.negotiated_version
390 }
391
392 pub fn ensure_version_compat(&self, t: MessageType) -> AgentClientResult<()> {
396 Self::ensure_version_compat_for(t, self.negotiated_version)
397 }
398
399 pub fn ensure_version_compat_for(t: MessageType, negotiated: u8) -> AgentClientResult<()> {
404 if t.is_available_at(negotiated) {
405 return Ok(());
406 }
407 Err(AgentClientError::UnsupportedOperation {
408 msg_type: t.as_str(),
409 needs: t.min_protocol_version(),
410 peer: negotiated,
411 })
412 }
413}
414
415impl AgentClient {
420 pub async fn request<T: Serialize>(
422 &self,
423 t: MessageType,
424 payload: &T,
425 ) -> AgentClientResult<Message> {
426 self.ensure_version_compat(t)?;
427 let flags = t.flags();
428 let body = encode_message_body(self.protocol.version(), t, payload)?;
429 let frame = self.request_raw(flags, body).await?;
430 Ok(codec::raw_frame_to_message(frame)?)
431 }
432
433 pub async fn stream<T: Serialize>(
436 &self,
437 t: MessageType,
438 payload: &T,
439 ) -> AgentClientResult<(u32, mpsc::Receiver<Message>)> {
440 self.ensure_version_compat(t)?;
441 let flags = t.flags();
442 let body = encode_message_body(self.protocol.version(), t, payload)?;
443 let (id, raw_rx) = self.stream_raw(flags, body).await?;
444
445 let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
446 tokio::spawn(decode_stream_task(raw_rx, tx));
447 Ok((id, rx))
448 }
449
450 pub async fn send<T: Serialize>(
452 &self,
453 id: u32,
454 t: MessageType,
455 payload: &T,
456 ) -> AgentClientResult<()> {
457 self.ensure_version_compat(t)?;
458 let flags = t.flags();
459 let body = encode_message_body(self.protocol.version(), t, payload)?;
460 self.write_frame_owned(id, flags, body).await
461 }
462
463 pub fn ready(&self) -> AgentClientResult<Ready> {
465 Ok(self.ready.clone())
466 }
467}
468
469impl AgentClient {
474 async fn reserve_id(&self, tx: mpsc::Sender<RawFrame>) -> AgentClientResult<u32> {
479 let mut pending = self.pending.lock().await;
480 let attempts = usable_id_count(self.id_min, self.id_max);
481 for _ in 0..attempts {
482 let id = self
483 .next_id
484 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
485 if self.next_id.load(std::sync::atomic::Ordering::Relaxed) >= self.id_max {
486 self.next_id.store(
487 first_request_id(self.id_min),
488 std::sync::atomic::Ordering::Relaxed,
489 );
490 }
491 if id == 0 || id < self.id_min || id >= self.id_max || pending.contains_key(&id) {
492 continue;
493 }
494 pending.insert(id, tx);
495 return Ok(id);
496 }
497
498 Err(AgentClientError::IdRangeExhausted)
499 }
500
501 async fn write_frame(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
503 self.write_frame_owned(id, flags, body.to_vec()).await
504 }
505
506 async fn write_frame_owned(&self, id: u32, flags: u8, body: Vec<u8>) -> AgentClientResult<()> {
508 let (ack, written) = oneshot::channel();
509 self.writer
510 .send(WriterCommand {
511 frame: RawFrame { id, flags, body },
512 ack,
513 })
514 .await
515 .map_err(|_| AgentClientError::Closed)?;
516 written.await.map_err(|_| AgentClientError::Closed)?
517 }
518}
519
/// Open a Unix-domain-socket connection to the relay at `sock_path`.
///
/// The deadline parameter mirrors the named-pipe variant's signature but is
/// not used here: exactly one connect attempt is made.
#[cfg(all(feature = "uds", unix))]
async fn connect_local_stream(
    sock_path: &Path,
    _deadline: Instant,
) -> AgentClientResult<UnixStream> {
    match UnixStream::connect(sock_path).await {
        Ok(stream) => Ok(stream),
        Err(source) => Err(AgentClientError::Connect {
            path: sock_path.to_path_buf(),
            source,
        }),
    }
}
536
/// Open the relay's named pipe. While the pipe is missing or busy and
/// `deadline` has not passed, retry every `WINDOWS_PIPE_CONNECT_RETRY`.
#[cfg(all(feature = "named-pipe", windows))]
async fn connect_local_stream(
    pipe_path: &Path,
    deadline: Instant,
) -> AgentClientResult<tokio::net::windows::named_pipe::NamedPipeClient> {
    loop {
        let source = match ClientOptions::new().open(pipe_path) {
            Ok(client) => return Ok(client),
            Err(source) => source,
        };
        let give_up =
            !is_retryable_named_pipe_connect_error(&source) || Instant::now() >= deadline;
        if give_up {
            return Err(AgentClientError::Connect {
                path: pipe_path.to_path_buf(),
                source,
            });
        }
        tokio::time::sleep(WINDOWS_PIPE_CONNECT_RETRY).await;
    }
}
559
/// Whether a named-pipe open failure is worth retrying: the pipe does not
/// exist (`NotFound`) or all of its instances are busy (`ERROR_PIPE_BUSY`).
#[cfg(all(feature = "named-pipe", windows))]
fn is_retryable_named_pipe_connect_error(error: &std::io::Error) -> bool {
    const ERROR_PIPE_BUSY: i32 = 231;

    let not_found = error.kind() == std::io::ErrorKind::NotFound;
    let pipe_busy = matches!(error.raw_os_error(), Some(ERROR_PIPE_BUSY));
    not_found || pipe_busy
}
566
/// Read the relay's handshake preamble and its `core.ready` frame.
///
/// The relay starts with 8 bytes, interpreted as either:
/// - current protocol: big-endian `id_min` then `id_max`, followed by a frame
///   read with `read_frame_handshake`; or
/// - legacy protocol: a big-endian id offset then the ready frame's length
///   prefix (detected by `looks_like_legacy_relay_handshake`).
///
/// Every read is bounded by `deadline`.
///
/// # Errors
/// `AgentClientError::Handshake` on timeout, read failure, an inverted or
/// empty id range, an undecodable ready frame, or a first frame that is not
/// `core.ready`.
#[cfg(feature = "stream")]
async fn perform_handshake<R>(
    reader: &mut R,
    deadline: Instant,
) -> AgentClientResult<AgentHandshake>
where
    R: HandshakeReader + ?Sized,
{
    let mut range_buf = [0u8; 8];
    tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut range_buf))
        .await
        .map_err(|_| {
            AgentClientError::Handshake("read id range: timed out before relay sent bytes".into())
        })?
        // Label read failures with the step, like the ready-frame reads below.
        .map_err(|e| AgentClientError::Handshake(format!("read id range: {e}")))?;
    let id_start_or_offset = u32::from_be_bytes(range_buf[0..4].try_into().unwrap());
    let id_max_or_frame_len = u32::from_be_bytes(range_buf[4..8].try_into().unwrap());

    let legacy_handshake =
        looks_like_legacy_relay_handshake(id_start_or_offset, id_max_or_frame_len);
    let (id_min, id_max, ready_frame, protocol) = if legacy_handshake {
        // Legacy relay: the second word was the ready frame's length prefix,
        // and usable ids are taken to be
        // `[offset + 1, offset + LEGACY_RELAY_ID_RANGE_STEP)`.
        let id_offset = id_start_or_offset;
        let ready_frame =
            read_raw_frame_after_len_prefix(reader, range_buf[4..8].try_into().unwrap(), deadline)
                .await?;
        (
            id_offset.saturating_add(1),
            id_offset.saturating_add(LEGACY_RELAY_ID_RANGE_STEP),
            ready_frame,
            AgentProtocol::LegacyV1,
        )
    } else if id_start_or_offset >= id_max_or_frame_len {
        return Err(AgentClientError::Handshake(format!(
            "invalid relay id range: start={id_start_or_offset}, end={id_max_or_frame_len}"
        )));
    } else {
        let ready_frame = tokio::time::timeout_at(deadline, reader.read_frame_handshake())
            .await
            .map_err(|_| {
                AgentClientError::Handshake(
                    "read ready frame: timed out before relay sent frame".into(),
                )
            })?
            .map_err(|e| AgentClientError::Handshake(format!("read ready frame: {e}")))?;
        (
            id_start_or_offset,
            id_max_or_frame_len,
            ready_frame,
            AgentProtocol::Current,
        )
    };
    ensure_usable_id_range(id_min, id_max)?;

    let ready_msg = codec::raw_frame_to_message(ready_frame.clone())
        .map_err(|e| AgentClientError::Handshake(format!("decode ready frame: {e}")))?;
    if ready_msg.t != MessageType::Ready {
        return Err(AgentClientError::Handshake(format!(
            "expected core.ready frame, got {}",
            ready_msg.t.as_str()
        )));
    }
    let ready: Ready = ready_msg
        .payload()
        .map_err(|e| AgentClientError::Handshake(format!("decode ready payload: {e}")))?;

    // Use the lower of our protocol version and the version the peer stamped
    // on its ready message.
    let negotiated_version = protocol.version().min(ready_msg.v);

    Ok(AgentHandshake {
        id_min,
        id_max,
        protocol,
        negotiated_version,
        ready_body: ready_frame.body,
        ready,
    })
}
655
/// First request id to hand out for a range starting at `id_min`. Id 0 is
/// never used, so a zero lower bound becomes 1.
fn first_request_id(id_min: u32) -> u32 {
    if id_min == 0 {
        1
    } else {
        id_min
    }
}
659
/// Reject a relay-granted id range that contains no usable nonzero request id.
///
/// # Errors
/// `AgentClientError::Handshake` naming the offending bounds.
#[cfg(feature = "stream")]
fn ensure_usable_id_range(id_min: u32, id_max: u32) -> AgentClientResult<()> {
    match usable_id_count(id_min, id_max) {
        0 => Err(AgentClientError::Handshake(format!(
            "relay id range contains no usable nonzero ids: start={id_min}, end={id_max}"
        ))),
        _ => Ok(()),
    }
}
669
/// Number of nonzero ids in the half-open range `[id_min, id_max)`, or 0 when
/// that range is empty or inverted.
fn usable_id_count(id_min: u32, id_max: u32) -> u32 {
    // Same lower bound as `first_request_id`: id 0 is never handed out.
    let first_usable = id_min.max(1);
    id_max.saturating_sub(first_usable)
}
673
/// Heuristic: do the first two handshake words look like a legacy relay's
/// `(id offset, ready frame length)` rather than a current `(id_min, id_max)`?
///
/// This holds when the second word is a plausible frame length
/// (`FRAME_HEADER_SIZE..=MAX_FRAME_SIZE`) and the first word is either zero or
/// not below the second.
#[cfg(feature = "stream")]
fn looks_like_legacy_relay_handshake(id_min: u32, id_max: u32) -> bool {
    let min_frame_len = FRAME_HEADER_SIZE as u32;
    let plausible_frame_len = (min_frame_len..=MAX_FRAME_SIZE).contains(&id_max);
    let not_a_current_range = id_min == 0 || id_min >= id_max;
    plausible_frame_len && not_a_current_range
}
688
/// Read the rest of a legacy relay's ready frame. Its 4-byte big-endian length
/// prefix (`len_buf`) has already been consumed.
///
/// The frame consists of `FRAME_HEADER_SIZE` header bytes (big-endian id in
/// bytes 0..4, flags in byte 4) followed by the body.
///
/// # Errors
/// `AgentClientError::Handshake` if the length is out of bounds, the read
/// fails, or `deadline` passes first.
#[cfg(feature = "stream")]
async fn read_raw_frame_after_len_prefix<R>(
    reader: &mut R,
    len_buf: [u8; 4],
    deadline: Instant,
) -> AgentClientResult<RawFrame>
where
    R: HandshakeReader + ?Sized,
{
    let frame_len = u32::from_be_bytes(len_buf);
    let min_frame_len = FRAME_HEADER_SIZE as u32;
    if frame_len > MAX_FRAME_SIZE {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too large: {frame_len} bytes (max {MAX_FRAME_SIZE})"
        )));
    }
    if frame_len < min_frame_len {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too short: {frame_len} bytes"
        )));
    }

    let mut frame_bytes = vec![0u8; frame_len as usize];
    let outcome =
        tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut frame_bytes)).await;
    match outcome {
        Err(_elapsed) => {
            return Err(AgentClientError::Handshake(
                "read legacy ready frame: timed out before relay sent frame".into(),
            ));
        }
        Ok(Err(e)) => {
            return Err(AgentClientError::Handshake(format!(
                "read legacy ready frame: {e}"
            )));
        }
        Ok(Ok(())) => {}
    }

    let id = u32::from_be_bytes([
        frame_bytes[0],
        frame_bytes[1],
        frame_bytes[2],
        frame_bytes[3],
    ]);
    let flags = frame_bytes[4];
    // Move the body out without re-slicing; the header stays in `frame_bytes`.
    let body = frame_bytes.split_off(FRAME_HEADER_SIZE);

    Ok(RawFrame { id, flags, body })
}
726
/// Blanket handshake reader for any tokio byte source.
#[cfg(feature = "stream")]
impl<R> HandshakeReader for R
where
    R: tokio::io::AsyncRead + Unpin + Send,
{
    /// `read_exact` into `out`; I/O errors become `AgentClientError::Handshake`.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        let fill = async move {
            match tokio::io::AsyncReadExt::read_exact(self, out).await {
                Ok(_) => Ok(()),
                Err(e) => Err(AgentClientError::Handshake(e.to_string())),
            }
        };
        Box::pin(fill)
    }

    /// One `codec::read_raw_frame`; codec errors become `AgentClientError::Protocol`.
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        let next_frame = async move {
            match codec::read_raw_frame(self).await {
                Ok(frame) => Ok(frame),
                Err(e) => Err(AgentClientError::Protocol(e)),
            }
        };
        Box::pin(next_frame)
    }
}
754
/// Write queued frames in order, acknowledging each one.
///
/// On the first write error, the error is reported to that command's sender
/// and the loop stops. Dropping `rx` then makes queued and future
/// `write_frame_owned` calls fail with `Closed`.
#[cfg(feature = "stream")]
async fn stream_writer_loop<W>(mut writer: W, mut rx: mpsc::Receiver<WriterCommand>)
where
    W: tokio::io::AsyncWrite + Unpin,
{
    while let Some(WriterCommand { frame, ack }) = rx.recv().await {
        match codec::write_raw_frame(&mut writer, &frame).await {
            Ok(_) => {
                let _ = ack.send(Ok(()));
            }
            Err(e) => {
                tracing::debug!("agent client: stream writer error: {e}");
                let _ = ack.send(Err(AgentClientError::Protocol(e)));
                break;
            }
        }
    }
}
769
/// Route frames from the relay to their pending handlers until EOF or a read
/// error.
///
/// Afterwards `pending` is cleared, which drops every registered sender so
/// each waiting receiver observes a closed channel.
#[cfg(feature = "stream")]
async fn reader_loop(mut reader: R, pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>)
where
    R: tokio::io::AsyncRead + Unpin,
{
    loop {
        match codec::read_raw_frame(&mut reader).await {
            Ok(frame) => dispatch_frame(frame, &pending).await,
            Err(e) => {
                tracing::debug!("agent client: reader EOF or error: {e}");
                break;
            }
        }
    }

    pending.lock().await.clear();
}
793
/// Route one incoming frame to the handler registered for its id.
///
/// A terminal frame removes the registration before it is delivered. Frames
/// for unknown ids are dropped with a trace log. `send` waits for channel
/// capacity, so a slow consumer applies backpressure to the reader loop.
#[cfg(feature = "stream")]
async fn dispatch_frame(
    frame: RawFrame,
    pending: &Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) {
    let id = frame.id;
    let is_terminal = (frame.flags & FLAG_TERMINAL) != 0;

    let tx = {
        let mut map = pending.lock().await;
        let Some(tx) = map.get(&id).cloned() else {
            tracing::trace!("agent client: no pending handler for id={id}");
            return;
        };
        if is_terminal {
            map.remove(&id);
        }
        tx
    };

    if tx.send(frame).await.is_err() {
        // The receiver is gone. Drop the registration only if it is still this
        // channel: a terminal frame already removed it above, and while `send`
        // was pending the id may have been re-reserved by a new request whose
        // handler must not be evicted.
        let mut map = pending.lock().await;
        if matches!(map.get(&id), Some(current) if current.same_channel(&tx)) {
            map.remove(&id);
        }
    }
}
818
819async fn decode_stream_task(mut raw_rx: mpsc::Receiver<RawFrame>, tx: mpsc::Sender<Message>) {
821 while let Some(frame) = raw_rx.recv().await {
822 match codec::raw_frame_to_message(frame) {
823 Ok(msg) => {
824 if tx.send(msg).await.is_err() {
825 break;
826 }
827 }
828 Err(e) => {
829 tracing::warn!("agent client: failed to decode frame in stream: {e}");
830 }
832 }
833 }
834}
835
836fn encode_message_body<T: Serialize>(
838 version: u8,
839 t: MessageType,
840 payload: &T,
841) -> AgentClientResult<Vec<u8>> {
842 let mut msg = Message::with_payload(t, 0, payload)?;
843 msg.v = version;
844 let mut body = Vec::new();
845 ciborium::into_writer(&msg, &mut body).map_err(microsandbox_protocol::ProtocolError::from)?;
846 Ok(body)
847}
848
#[cfg(test)]
mod tests {
    #[cfg(all(feature = "uds", unix))]
    use microsandbox_protocol::core::Ready;
    #[cfg(all(feature = "uds", unix))]
    use microsandbox_protocol::exec::ExecRequest;
    #[cfg(all(feature = "uds", unix))]
    use microsandbox_protocol::message::PROTOCOL_VERSION;
    #[cfg(all(feature = "uds", unix))]
    use tokio::io::AsyncWriteExt;
    #[cfg(all(feature = "uds", unix))]
    use tokio::net::UnixListener;
    #[cfg(all(feature = "uds", unix))]
    use tokio::sync::oneshot;

    use super::*;

    /// A current relay (explicit `[1, 8)` id range) yields a `Current` client
    /// that exposes both the decoded ready payload and the raw ready bytes.
    #[cfg(all(feature = "uds", unix))]
    #[tokio::test]
    async fn connect_decodes_ready_payload() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            agent_version: "9.9.9".to_string(),
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket.write_all(&8u32.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), PROTOCOL_VERSION);
        assert!(client.supports(MessageType::FsRequest));
        assert_eq!(client.agent_version(), "9.9.9");
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);

        let raw_msg: Message = ciborium::from_reader(client.ready_bytes()).unwrap();
        assert_eq!(raw_msg.t, MessageType::Ready);
    }

    /// Same handshake as above, served over a Windows named pipe.
    #[cfg(all(feature = "named-pipe", windows))]
    #[tokio::test]
    async fn connect_decodes_ready_payload_from_named_pipe() {
        use microsandbox_protocol::core::Ready;
        use microsandbox_protocol::message::PROTOCOL_VERSION;
        use tokio::io::AsyncWriteExt;
        use tokio::net::windows::named_pipe::{PipeMode, ServerOptions};

        let pipe_path = unique_named_pipe("ready");
        let server = ServerOptions::new()
            .first_pipe_instance(true)
            .pipe_mode(PipeMode::Byte)
            .create(&pipe_path)
            .unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            agent_version: "named-pipe-test".to_string(),
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            let mut server = server;
            server.connect().await.unwrap();
            server.write_all(&1u32.to_be_bytes()).await.unwrap();
            server.write_all(&8u32.to_be_bytes()).await.unwrap();
            codec::write_message(&mut server, &ready_msg).await.unwrap();
        });

        let client = AgentClient::connect_with_deadline(
            std::path::Path::new(&pipe_path),
            Instant::now() + Duration::from_secs(1),
        )
        .await
        .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), PROTOCOL_VERSION);
        assert_eq!(client.agent_version(), "named-pipe-test");
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
    }

    /// A current relay whose ready message carries version 1 negotiates down
    /// to 1, gating message types that need a newer version.
    #[cfg(all(feature = "uds", unix))]
    #[tokio::test]
    async fn connect_negotiates_down_to_older_guest_generation() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 1,
            init_time_ns: 2,
            ready_time_ns: 3,
            ..Default::default()
        };
        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        ready_msg.v = 1;

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket
                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
                .await
                .unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), 1);
        assert!(client.supports(MessageType::ExecRequest));
        assert!(!client.supports(MessageType::FsRequest));
    }

    /// Legacy relays (id offset + length-prefixed ready frame) are detected
    /// for both the first offset (0) and a later one.
    #[cfg(all(feature = "uds", unix))]
    #[tokio::test]
    async fn connect_accepts_legacy_relay_handshake() {
        assert_accepts_legacy_relay_handshake(0).await;
        assert_accepts_legacy_relay_handshake(268_435_455).await;
    }

    /// Requests to a legacy relay are stamped with version 1 and use the first
    /// id after the relay's offset.
    #[cfg(all(feature = "uds", unix))]
    #[tokio::test]
    async fn legacy_relay_requests_use_v1_and_legacy_id_range() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            ..Default::default()
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        let id_offset = 268_435_455u32;
        let (frame_tx, frame_rx) = oneshot::channel();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
            let frame = codec::read_raw_frame(&mut socket).await.unwrap();
            frame_tx.send(frame).unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();
        let request = ExecRequest {
            cmd: "/bin/true".into(),
            args: Vec::new(),
            env: Vec::new(),
            cwd: None,
            user: None,
            tty: false,
            rows: 24,
            cols: 80,
            rlimits: Vec::new(),
        };
        let (id, _rx) = client
            .stream(MessageType::ExecRequest, &request)
            .await
            .unwrap();

        let frame = frame_rx.await.unwrap();
        let message = codec::raw_frame_to_message(frame).unwrap();

        assert_eq!(id, id_offset + 1);
        assert_eq!(message.id, id_offset + 1);
        assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
        assert_eq!(message.t, MessageType::ExecRequest);
    }

    /// Availability matrix of representative message types per generation.
    #[test]
    fn version_compat_across_generations() {
        use MessageType::{ExecRequest, FsRequest};
        let cases = [
            (ExecRequest, 1, true),
            (ExecRequest, 2, true),
            (ExecRequest, 3, true),
            (FsRequest, 1, false),
            (FsRequest, 2, true),
            (FsRequest, 3, true),
        ];
        for (t, generation, allowed) in cases {
            assert_eq!(
                AgentClient::ensure_version_compat_for(t, generation).is_ok(),
                allowed,
                "{t:?} at generation {generation}"
            );
        }
    }

    /// Rejections carry the needed and peer versions.
    #[test]
    fn version_compat_rejection_is_typed() {
        let err =
            AgentClient::ensure_version_compat_for(MessageType::FsRequest, LEGACY_PROTOCOL_VERSION)
                .unwrap_err();
        assert!(matches!(
            err,
            AgentClientError::UnsupportedOperation {
                needs: 2,
                peer: 1,
                ..
            }
        ));
    }

    /// A version-2 ready message is honored as-is by a current relay.
    #[cfg(all(feature = "uds", unix))]
    #[tokio::test]
    async fn connect_preserves_current_peer_protocol_version() {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            ..Default::default()
        };
        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
        ready_msg.v = 2;

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
            socket
                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
                .await
                .unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.negotiated_version(), 2);
        assert!(!client.supports(MessageType::TcpConnect));
    }

    /// Serve a legacy handshake at `id_offset` and check the client detects it.
    #[cfg(all(feature = "uds", unix))]
    async fn assert_accepts_legacy_relay_handshake(id_offset: u32) {
        let temp = tempfile::tempdir().unwrap();
        let sock_path = temp.path().join("agent.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            ..Default::default()
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
            codec::write_message(&mut socket, &ready_msg).await.unwrap();
        });

        let client =
            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
                .await
                .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
        assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
        let decoded = client.ready().unwrap();
        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
    }

    /// Pipe name unique to this process and moment, so parallel tests do not collide.
    #[cfg(all(feature = "named-pipe", windows))]
    fn unique_named_pipe(name: &str) -> String {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        format!(
            r"\\.\pipe\msb-agent-client-{name}-{}-{nanos}",
            std::process::id()
        )
    }

    /// End-to-end over an in-memory duplex: handshake, then an exec stream
    /// that yields stdout followed by the exit status.
    #[cfg(feature = "stream")]
    #[tokio::test]
    async fn connect_stream_handshakes_and_streams_exec() {
        use microsandbox_protocol::exec::{ExecExited, ExecRequest, ExecStdout};
        use tokio::io::AsyncWriteExt;

        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
        let ready = Ready {
            boot_time_ns: 11,
            init_time_ns: 22,
            ready_time_ns: 33,
            agent_version: "stream-test".to_string(),
        };
        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();

        tokio::spawn(async move {
            server_io.write_all(&1u32.to_be_bytes()).await.unwrap();
            server_io.write_all(&1024u32.to_be_bytes()).await.unwrap();
            codec::write_message(&mut server_io, &ready_msg)
                .await
                .unwrap();

            let request = codec::read_raw_frame(&mut server_io).await.unwrap();
            let stdout = Message::with_payload(
                MessageType::ExecStdout,
                request.id,
                &ExecStdout {
                    data: b"hi".to_vec(),
                },
            )
            .unwrap();
            codec::write_message(&mut server_io, &stdout).await.unwrap();
            let exited =
                Message::with_payload(MessageType::ExecExited, request.id, &ExecExited { code: 0 })
                    .unwrap();
            codec::write_message(&mut server_io, &exited).await.unwrap();
        });

        let client = AgentClient::connect_stream_with_deadline(
            client_io,
            Instant::now() + Duration::from_secs(1),
        )
        .await
        .unwrap();

        assert_eq!(client.protocol(), AgentProtocol::Current);
        assert_eq!(client.agent_version(), "stream-test");
        assert!(client.supports(MessageType::ExecRequest));

        let request = ExecRequest {
            cmd: "echo".into(),
            args: vec!["hi".into()],
            env: Vec::new(),
            cwd: None,
            user: None,
            tty: false,
            rows: 24,
            cols: 80,
            rlimits: Vec::new(),
        };
        let (_id, mut rx) = client
            .stream(MessageType::ExecRequest, &request)
            .await
            .unwrap();

        let first = rx.recv().await.unwrap();
        assert_eq!(first.t, MessageType::ExecStdout);
        let out: ExecStdout = first.payload().unwrap();
        assert_eq!(out.data, b"hi");

        let second = rx.recv().await.unwrap();
        assert_eq!(second.t, MessageType::ExecExited);
        let exit: ExecExited = second.payload().unwrap();
        assert_eq!(exit.code, 0);
    }
}
1257
1258impl Drop for AgentClient {
1263 fn drop(&mut self) {
1264 self.reader_handle.abort();
1265 self.writer_handle.abort();
1266 }
1267}