1use std::collections::HashMap;
22#[cfg(any(feature = "uds", feature = "websocket"))]
23use std::future::Future;
24#[cfg(feature = "uds")]
25use std::path::Path;
26#[cfg(any(feature = "uds", feature = "websocket"))]
27use std::pin::Pin;
28use std::sync::{Arc, atomic::AtomicU32};
29#[cfg(any(feature = "uds", feature = "websocket"))]
30use std::time::Duration;
31
32#[cfg(feature = "websocket")]
33use futures_util::{SinkExt, StreamExt};
34#[cfg(any(feature = "uds", feature = "websocket"))]
35use microsandbox_protocol::message::FLAG_TERMINAL;
36#[cfg(any(feature = "uds", feature = "websocket"))]
37use microsandbox_protocol::{codec::MAX_FRAME_SIZE, message::FRAME_HEADER_SIZE};
38use microsandbox_protocol::{
39 codec::{self, RawFrame},
40 core::Ready,
41 message::{Message, MessageType, PROTOCOL_VERSION},
42};
43use serde::Serialize;
44#[cfg(feature = "uds")]
45use tokio::net::UnixStream;
46use tokio::sync::{Mutex, mpsc, oneshot};
47use tokio::task::JoinHandle;
48#[cfg(any(feature = "uds", feature = "websocket"))]
49use tokio::time::Instant;
50#[cfg(feature = "websocket")]
51use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
52
53use super::error::{AgentClientError, AgentClientResult};
54
/// Time allowed for connecting and completing the relay handshake when the
/// caller does not supply an explicit timeout or deadline.
#[cfg(any(feature = "uds", feature = "websocket"))]
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the queue between callers and the background writer task.
#[cfg(any(feature = "uds", feature = "websocket"))]
const WRITER_QUEUE_CAPACITY: usize = 1024;
/// Channel capacity for single-response requests (`request_raw` consumes only one frame).
const REQUEST_QUEUE_CAPACITY: usize = 1;
/// Channel capacity for streamed responses; when a stream's channel is full the
/// reader task waits in `dispatch_frame` until the consumer catches up.
const STREAM_QUEUE_CAPACITY: usize = 1024;
/// Upper bound on unconsumed bytes buffered by `WebSocketByteReader`: one
/// maximum-size frame plus a small margin (NOTE(review): the extra 12 bytes
/// presumably cover the 4-byte length prefix plus slack — confirm).
#[cfg(feature = "websocket")]
const MAX_WEBSOCKET_BUFFER_SIZE: usize = MAX_FRAME_SIZE as usize + 12;

/// Protocol version spoken by legacy (`AgentProtocol::LegacyV1`) relays and agents.
const LEGACY_PROTOCOL_VERSION: u8 = 1;
/// Width of the id range a legacy relay implicitly assigns to a client,
/// counted from the id offset it sends in its handshake.
#[cfg(any(feature = "uds", feature = "websocket"))]
const LEGACY_RELAY_ID_RANGE_STEP: u32 = u32::MAX / 16;
75
/// Relay/agent protocol generation detected from the relay's handshake bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentProtocol {
    /// Current relays: the handshake starts with an explicit `[start, end)`
    /// request-id range, followed by a length-prefixed ready frame.
    Current,

    /// Relays from sandboxes started before microsandbox 0.5: the handshake is a
    /// 4-byte id offset followed directly by the ready frame, and messages use
    /// protocol version 1.
    LegacyV1,
}
92
/// Client for the agent relay protocol.
///
/// Requests are multiplexed over a single connection by frame id. A background
/// writer task serializes outgoing frames, and a background reader task routes
/// incoming frames to the per-id channel registered in `pending`.
pub struct AgentClient {
    /// Queue to the background writer task; each command carries an ack channel
    /// that receives the write outcome.
    writer: mpsc::Sender<WriterCommand>,
    /// Next candidate request id considered by `reserve_id`.
    next_id: AtomicU32,
    /// Inclusive lower bound of the request-id range assigned during the handshake.
    id_min: u32,
    /// Exclusive upper bound of the request-id range assigned during the handshake.
    id_max: u32,
    /// Protocol generation detected during the handshake.
    protocol: AgentProtocol,
    /// The lower of this generation's protocol version and the version carried
    /// by the agent's ready message.
    negotiated_version: u8,
    /// In-flight request/stream ids mapped to the channel that receives their
    /// response frames.
    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
    /// Handle of the spawned reader task.
    reader_handle: JoinHandle<()>,
    /// Handle of the spawned writer task.
    writer_handle: JoinHandle<()>,
    /// Raw body of the ready frame exactly as received from the relay.
    ready_body: Vec<u8>,
    /// Decoded payload of the ready frame.
    ready: Ready,
}
123
/// Result of a successful relay handshake, used to construct an `AgentClient`.
#[cfg(any(feature = "uds", feature = "websocket"))]
struct AgentHandshake {
    /// Inclusive lower bound of the usable request-id range.
    id_min: u32,
    /// Exclusive upper bound of the usable request-id range.
    id_max: u32,
    /// Protocol generation detected from the handshake bytes.
    protocol: AgentProtocol,
    /// Protocol version both sides will use.
    negotiated_version: u8,
    /// Raw body of the ready frame.
    ready_body: Vec<u8>,
    /// Decoded ready payload.
    ready: Ready,
}
133
/// A frame queued for the background writer task.
#[cfg_attr(not(any(feature = "uds", feature = "websocket")), allow(dead_code))]
struct WriterCommand {
    /// The frame to write.
    frame: RawFrame,
    /// Receives `Ok(())` once the frame has been written, or the write error.
    ack: oneshot::Sender<AgentClientResult<()>>,
}
139
/// Minimal read interface used by `perform_handshake`, implemented for
/// UDS byte streams and for `WebSocketByteReader`.
///
/// Methods return boxed `Send` futures borrowing the reader (and output buffer)
/// for `'a`.
#[cfg(any(feature = "uds", feature = "websocket"))]
trait HandshakeReader {
    /// Reads exactly `out.len()` bytes into `out`.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>>;

    /// Reads one complete length-prefixed frame.
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>>;
}
151
152impl AgentProtocol {
157 fn version(self) -> u8 {
158 match self {
159 Self::Current => PROTOCOL_VERSION,
160 Self::LegacyV1 => LEGACY_PROTOCOL_VERSION,
161 }
162 }
163}
164
165impl AgentClient {
166 #[cfg(feature = "uds")]
169 pub async fn connect(sock_path: impl AsRef<Path>) -> AgentClientResult<Self> {
170 Self::connect_with_timeout(sock_path, DEFAULT_HANDSHAKE_TIMEOUT).await
171 }
172
173 #[cfg(feature = "uds")]
176 pub async fn connect_with_timeout(
177 sock_path: impl AsRef<Path>,
178 timeout: Duration,
179 ) -> AgentClientResult<Self> {
180 let deadline = Instant::now() + timeout;
181 Self::connect_with_deadline(sock_path, deadline).await
182 }
183
184 #[cfg(feature = "uds")]
190 pub async fn connect_with_deadline(
191 sock_path: impl AsRef<Path>,
192 deadline: Instant,
193 ) -> AgentClientResult<Self> {
194 let sock_path = sock_path.as_ref();
195 let stream =
196 UnixStream::connect(sock_path)
197 .await
198 .map_err(|source| AgentClientError::Connect {
199 path: sock_path.to_path_buf(),
200 source,
201 })?;
202
203 let (mut reader, writer) = stream.into_split();
204 let handshake = perform_handshake(&mut reader, deadline).await?;
205
206 tracing::info!(
207 id_min = handshake.id_min,
208 id_max = handshake.id_max,
209 protocol = ?handshake.protocol,
210 ready_bytes = handshake.ready_body.len(),
211 boot_time_ns = handshake.ready.boot_time_ns,
212 "agent client: connected to relay"
213 );
214 if handshake.protocol == AgentProtocol::LegacyV1 {
215 tracing::warn!(
218 "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
219 );
220 }
221
222 let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
223 Arc::new(Mutex::new(HashMap::new()));
224
225 let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
226 let reader_handle = tokio::spawn(reader_loop(reader, Arc::clone(&pending)));
227 let writer_handle = tokio::spawn(uds_writer_loop(writer, writer_rx));
228
229 Ok(Self {
230 writer: writer_tx,
231 next_id: AtomicU32::new(first_request_id(handshake.id_min)),
232 id_min: handshake.id_min,
233 id_max: handshake.id_max,
234 protocol: handshake.protocol,
235 negotiated_version: handshake.negotiated_version,
236 pending,
237 reader_handle,
238 writer_handle,
239 ready_body: handshake.ready_body,
240 ready: handshake.ready,
241 })
242 }
243
244 #[cfg(feature = "websocket")]
247 pub async fn connect_websocket(url: &str) -> AgentClientResult<Self> {
248 Self::connect_websocket_with_timeout(url, DEFAULT_HANDSHAKE_TIMEOUT).await
249 }
250
251 #[cfg(feature = "websocket")]
254 pub async fn connect_websocket_with_timeout(
255 url: &str,
256 timeout: Duration,
257 ) -> AgentClientResult<Self> {
258 let deadline = Instant::now() + timeout;
259 Self::connect_websocket_with_deadline(url, deadline).await
260 }
261
262 #[cfg(feature = "websocket")]
265 pub async fn connect_websocket_with_deadline(
266 url: &str,
267 deadline: Instant,
268 ) -> AgentClientResult<Self> {
269 let (stream, _) = tokio_tungstenite::connect_async(url)
270 .await
271 .map_err(|e| AgentClientError::WebSocket(format!("connect {url}: {e}")))?;
272 let (writer, reader) = stream.split();
273 let mut reader = WebSocketByteReader::new(reader);
274 let handshake = perform_handshake(&mut reader, deadline).await?;
275 if handshake.protocol == AgentProtocol::LegacyV1 {
276 tracing::warn!(
279 "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
280 );
281 }
282
283 let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
284 Arc::new(Mutex::new(HashMap::new()));
285 let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
286 let reader_handle = tokio::spawn(websocket_reader_loop(reader, Arc::clone(&pending)));
287 let writer_handle = tokio::spawn(websocket_writer_loop(writer, writer_rx));
288
289 Ok(Self {
290 writer: writer_tx,
291 next_id: AtomicU32::new(first_request_id(handshake.id_min)),
292 id_min: handshake.id_min,
293 id_max: handshake.id_max,
294 protocol: handshake.protocol,
295 negotiated_version: handshake.negotiated_version,
296 pending,
297 reader_handle,
298 writer_handle,
299 ready_body: handshake.ready_body,
300 ready: handshake.ready,
301 })
302 }
303
304 pub async fn close(self) {
307 }
311}
312
313impl AgentClient {
318 pub async fn request_raw(&self, flags: u8, body: Vec<u8>) -> AgentClientResult<RawFrame> {
324 let (tx, mut rx) = mpsc::channel(REQUEST_QUEUE_CAPACITY);
325 let id = self.reserve_id(tx).await?;
326
327 if let Err(e) = self.write_frame_owned(id, flags, body).await {
328 self.pending.lock().await.remove(&id);
329 return Err(e);
330 }
331
332 let frame = rx.recv().await.ok_or(AgentClientError::ReaderClosed(id))?;
333 self.pending.lock().await.remove(&id);
334 Ok(frame)
335 }
336
337 pub async fn stream_raw(
345 &self,
346 flags: u8,
347 body: Vec<u8>,
348 ) -> AgentClientResult<(u32, mpsc::Receiver<RawFrame>)> {
349 let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
350 let id = self.reserve_id(tx).await?;
351
352 if let Err(e) = self.write_frame_owned(id, flags, body).await {
353 self.pending.lock().await.remove(&id);
354 return Err(e);
355 }
356
357 Ok((id, rx))
358 }
359
360 pub async fn send_raw(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
366 self.write_frame(id, flags, body).await
367 }
368
369 pub fn ready_bytes(&self) -> &[u8] {
374 &self.ready_body
375 }
376
377 pub fn protocol(&self) -> AgentProtocol {
379 self.protocol
380 }
381
382 pub fn is_legacy_protocol(&self) -> bool {
384 self.protocol == AgentProtocol::LegacyV1
385 }
386
387 pub fn negotiated_version(&self) -> u8 {
390 self.negotiated_version
391 }
392
393 pub fn agent_version(&self) -> &str {
397 &self.ready.agent_version
398 }
399
400 pub fn supports(&self, t: MessageType) -> bool {
405 t.min_protocol_version() <= self.negotiated_version
406 }
407
408 pub fn ensure_version_compat(&self, t: MessageType) -> AgentClientResult<()> {
412 Self::ensure_version_compat_for(t, self.negotiated_version)
413 }
414
415 pub fn ensure_version_compat_for(t: MessageType, negotiated: u8) -> AgentClientResult<()> {
420 if t.is_available_at(negotiated) {
421 return Ok(());
422 }
423 Err(AgentClientError::UnsupportedOperation {
424 msg_type: t.as_str(),
425 needs: t.min_protocol_version(),
426 peer: negotiated,
427 })
428 }
429}
430
431impl AgentClient {
436 pub async fn request<T: Serialize>(
438 &self,
439 t: MessageType,
440 payload: &T,
441 ) -> AgentClientResult<Message> {
442 self.ensure_version_compat(t)?;
443 let flags = t.flags();
444 let body = encode_message_body(self.protocol.version(), t, payload)?;
445 let frame = self.request_raw(flags, body).await?;
446 Ok(codec::raw_frame_to_message(frame)?)
447 }
448
449 pub async fn stream<T: Serialize>(
452 &self,
453 t: MessageType,
454 payload: &T,
455 ) -> AgentClientResult<(u32, mpsc::Receiver<Message>)> {
456 self.ensure_version_compat(t)?;
457 let flags = t.flags();
458 let body = encode_message_body(self.protocol.version(), t, payload)?;
459 let (id, raw_rx) = self.stream_raw(flags, body).await?;
460
461 let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
462 tokio::spawn(decode_stream_task(raw_rx, tx));
463 Ok((id, rx))
464 }
465
466 pub async fn send<T: Serialize>(
468 &self,
469 id: u32,
470 t: MessageType,
471 payload: &T,
472 ) -> AgentClientResult<()> {
473 self.ensure_version_compat(t)?;
474 let flags = t.flags();
475 let body = encode_message_body(self.protocol.version(), t, payload)?;
476 self.write_frame_owned(id, flags, body).await
477 }
478
479 pub fn ready(&self) -> AgentClientResult<Ready> {
481 Ok(self.ready.clone())
482 }
483}
484
485impl AgentClient {
490 async fn reserve_id(&self, tx: mpsc::Sender<RawFrame>) -> AgentClientResult<u32> {
495 let mut pending = self.pending.lock().await;
496 let attempts = usable_id_count(self.id_min, self.id_max);
497 for _ in 0..attempts {
498 let id = self
499 .next_id
500 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
501 if self.next_id.load(std::sync::atomic::Ordering::Relaxed) >= self.id_max {
502 self.next_id.store(
503 first_request_id(self.id_min),
504 std::sync::atomic::Ordering::Relaxed,
505 );
506 }
507 if id == 0 || id < self.id_min || id >= self.id_max || pending.contains_key(&id) {
508 continue;
509 }
510 pending.insert(id, tx);
511 return Ok(id);
512 }
513
514 Err(AgentClientError::IdRangeExhausted)
515 }
516
517 async fn write_frame(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
519 self.write_frame_owned(id, flags, body.to_vec()).await
520 }
521
522 async fn write_frame_owned(&self, id: u32, flags: u8, body: Vec<u8>) -> AgentClientResult<()> {
524 let (ack, written) = oneshot::channel();
525 self.writer
526 .send(WriterCommand {
527 frame: RawFrame { id, flags, body },
528 ack,
529 })
530 .await
531 .map_err(|_| AgentClientError::Closed)?;
532 written.await.map_err(|_| AgentClientError::Closed)?
533 }
534}
535
/// Reads the relay's handshake preamble and the agent's `core.ready` frame.
///
/// The relay first sends 8 bytes, interpreted one of two ways:
///
/// * Current relays: two big-endian `u32`s forming the `[start, end)` request-id
///   range this client may use, followed by a length-prefixed ready frame.
/// * Legacy relays: a big-endian `u32` id offset followed immediately by the ready
///   frame, so the second word is really that frame's length prefix. The client
///   then uses ids in `[offset + 1, offset + LEGACY_RELAY_ID_RANGE_STEP)` (saturating).
///
/// `looks_like_legacy_relay_handshake` chooses between the two. Every read is
/// bounded by `deadline`. The negotiated version is the lower of this generation's
/// version and the `v` field of the ready message.
///
/// # Errors
///
/// Returns an error if a read fails or times out, the id range is inverted or has
/// no usable nonzero ids, the ready frame cannot be decoded, the frame's type is
/// not `core.ready`, or the ready payload cannot be decoded.
#[cfg(any(feature = "uds", feature = "websocket"))]
async fn perform_handshake<R>(
    reader: &mut R,
    deadline: Instant,
) -> AgentClientResult<AgentHandshake>
where
    R: HandshakeReader + ?Sized,
{
    // First 8 bytes: a current-style id range, or a legacy id offset plus the
    // ready frame's 4-byte length prefix.
    let mut range_buf = [0u8; 8];
    tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut range_buf))
        .await
        .map_err(|_| {
            AgentClientError::Handshake("read id range: timed out before relay sent bytes".into())
        })??;
    // The slices are exactly 4 bytes long, so these conversions cannot fail.
    let id_start_or_offset = u32::from_be_bytes(range_buf[0..4].try_into().unwrap());
    let id_max_or_frame_len = u32::from_be_bytes(range_buf[4..8].try_into().unwrap());

    let legacy_handshake =
        looks_like_legacy_relay_handshake(id_start_or_offset, id_max_or_frame_len);
    let (id_min, id_max, ready_frame, protocol) = if legacy_handshake {
        // Legacy: the second word was the ready frame's length prefix, so the rest of
        // that frame follows directly; the id range is derived from the offset.
        let id_offset = id_start_or_offset;
        let ready_frame =
            read_raw_frame_after_len_prefix(reader, range_buf[4..8].try_into().unwrap(), deadline)
                .await?;
        (
            id_offset.saturating_add(1),
            id_offset.saturating_add(LEGACY_RELAY_ID_RANGE_STEP),
            ready_frame,
            AgentProtocol::LegacyV1,
        )
    } else if id_start_or_offset >= id_max_or_frame_len {
        return Err(AgentClientError::Handshake(format!(
            "invalid relay id range: start={id_start_or_offset}, end={id_max_or_frame_len}"
        )));
    } else {
        // Current: explicit id range, followed by a complete length-prefixed ready frame.
        let ready_frame = tokio::time::timeout_at(deadline, reader.read_frame_handshake())
            .await
            .map_err(|_| {
                AgentClientError::Handshake(
                    "read ready frame: timed out before relay sent frame".into(),
                )
            })?
            .map_err(|e| AgentClientError::Handshake(format!("read ready frame: {e}")))?;
        (
            id_start_or_offset,
            id_max_or_frame_len,
            ready_frame,
            AgentProtocol::Current,
        )
    };
    ensure_usable_id_range(id_min, id_max)?;

    // Decode a copy so the raw body can still be returned alongside the decoded payload.
    let ready_msg = codec::raw_frame_to_message(ready_frame.clone())
        .map_err(|e| AgentClientError::Handshake(format!("decode ready frame: {e}")))?;
    if ready_msg.t != MessageType::Ready {
        return Err(AgentClientError::Handshake(format!(
            "expected core.ready frame, got {}",
            ready_msg.t.as_str()
        )));
    }
    let ready: Ready = ready_msg
        .payload()
        .map_err(|e| AgentClientError::Handshake(format!("decode ready payload: {e}")))?;

    // Use the lower of this generation's version and the agent's advertised version.
    let negotiated_version = protocol.version().min(ready_msg.v);

    Ok(AgentHandshake {
        id_min,
        id_max,
        protocol,
        negotiated_version,
        ready_body: ready_frame.body,
        ready,
    })
}
628
/// Returns the first id to hand out for a range starting at `id_min`.
///
/// Id 0 is never allocated, so a range starting at 0 begins at 1 instead.
fn first_request_id(id_min: u32) -> u32 {
    if id_min == 0 {
        1
    } else {
        id_min
    }
}
632
/// Rejects a relay-assigned id range that contains no usable nonzero id.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] describing the empty range.
#[cfg(any(feature = "uds", feature = "websocket"))]
fn ensure_usable_id_range(id_min: u32, id_max: u32) -> AgentClientResult<()> {
    match usable_id_count(id_min, id_max) {
        0 => Err(AgentClientError::Handshake(format!(
            "relay id range contains no usable nonzero ids: start={id_min}, end={id_max}"
        ))),
        _ => Ok(()),
    }
}
642
/// Number of allocatable ids in `[id_min, id_max)` once id 0 is excluded.
///
/// Returns 0 for empty or inverted ranges.
fn usable_id_count(id_min: u32, id_max: u32) -> u32 {
    // Same nonzero floor as `first_request_id`: id 0 is never handed out.
    let first = if id_min == 0 { 1 } else { id_min };
    id_max.saturating_sub(first)
}
646
/// Heuristically decides whether the first 8 handshake bytes came from a legacy relay.
///
/// In a legacy handshake the second word is the ready frame's length prefix, so it
/// must be a plausible frame length. The first word (the legacy id offset) must also
/// not look like the start of a current-style `[start, end)` range: it is either 0
/// or not below the second word.
#[cfg(any(feature = "uds", feature = "websocket"))]
fn looks_like_legacy_relay_handshake(id_min: u32, id_max: u32) -> bool {
    let min_frame_len = FRAME_HEADER_SIZE as u32;
    let plausible_frame_len = (min_frame_len..=MAX_FRAME_SIZE).contains(&id_max);
    let not_a_range_start = id_min == 0 || id_min >= id_max;
    plausible_frame_len && not_a_range_start
}
661
/// Reads the remainder of a legacy ready frame whose 4-byte length prefix has
/// already been consumed (`len_buf`), finishing before `deadline`.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] if the declared length is outside
/// `[FRAME_HEADER_SIZE, MAX_FRAME_SIZE]`, or if the read fails or times out.
#[cfg(any(feature = "uds", feature = "websocket"))]
async fn read_raw_frame_after_len_prefix<R>(
    reader: &mut R,
    len_buf: [u8; 4],
    deadline: Instant,
) -> AgentClientResult<RawFrame>
where
    R: HandshakeReader + ?Sized,
{
    let frame_len = u32::from_be_bytes(len_buf);
    match frame_len {
        len if len > MAX_FRAME_SIZE => {
            return Err(AgentClientError::Handshake(format!(
                "legacy ready frame too large: {len} bytes (max {MAX_FRAME_SIZE})"
            )));
        }
        len if len < FRAME_HEADER_SIZE as u32 => {
            return Err(AgentClientError::Handshake(format!(
                "legacy ready frame too short: {len} bytes"
            )));
        }
        _ => {}
    }

    let mut frame = vec![0u8; frame_len as usize];
    let outcome = tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut frame)).await;
    match outcome {
        Err(_elapsed) => {
            return Err(AgentClientError::Handshake(
                "read legacy ready frame: timed out before relay sent frame".into(),
            ));
        }
        Ok(Err(e)) => {
            return Err(AgentClientError::Handshake(format!(
                "read legacy ready frame: {e}"
            )));
        }
        Ok(Ok(())) => {}
    }

    // Header layout: big-endian id in bytes 0..4, flags in byte 4; body after the header.
    let id = u32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
    let flags = frame[4];
    let body = frame[FRAME_HEADER_SIZE..].to_vec();
    Ok(RawFrame { id, flags, body })
}
699
/// Handshake reads over any tokio byte stream (used for the UDS read half).
#[cfg(feature = "uds")]
impl<R> HandshakeReader for R
where
    R: tokio::io::AsyncRead + Unpin + Send,
{
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        Box::pin(async move {
            match tokio::io::AsyncReadExt::read_exact(self, out).await {
                Ok(_) => Ok(()),
                Err(err) => Err(AgentClientError::Handshake(err.to_string())),
            }
        })
    }

    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        Box::pin(async move {
            let outcome = codec::read_raw_frame(self).await;
            outcome.map_err(AgentClientError::Protocol)
        })
    }
}
727
/// Writes queued frames to the UDS write half, acknowledging each one.
///
/// Stops after the first write error (the stream may be partially written) or
/// once every queue sender has been dropped and the queue is drained.
#[cfg(feature = "uds")]
async fn uds_writer_loop(
    mut writer: tokio::net::unix::OwnedWriteHalf,
    mut rx: mpsc::Receiver<WriterCommand>,
) {
    while let Some(WriterCommand { frame, ack }) = rx.recv().await {
        match codec::write_raw_frame(&mut writer, &frame).await {
            Ok(_) => {
                let _ = ack.send(Ok(()));
            }
            Err(e) => {
                tracing::debug!("agent client: UDS writer error: {e}");
                let _ = ack.send(Err(AgentClientError::Protocol(e)));
                break;
            }
        }
    }
}
742
/// Reads frames from the UDS stream and dispatches each to its pending handler
/// until the stream reports EOF or an error.
#[cfg(feature = "uds")]
async fn reader_loop<R>(mut reader: R, pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>)
where
    R: tokio::io::AsyncRead + Unpin,
{
    let failure = loop {
        match codec::read_raw_frame(&mut reader).await {
            Ok(frame) => dispatch_frame(frame, &pending).await,
            Err(e) => break e,
        }
    };
    tracing::debug!("agent client: reader EOF or error: {failure}");

    // Dropping every registered sender closes those channels, so waiting callers
    // observe end-of-stream (e.g. `request_raw` reports `ReaderClosed`).
    pending.lock().await.clear();
}
766
/// Adapts a stream of WebSocket messages into a byte reader for length-prefixed frames.
#[cfg(feature = "websocket")]
struct WebSocketByteReader<S> {
    /// Underlying WebSocket message stream.
    stream: S,
    /// Bytes received from binary messages but not yet fully consumed.
    buffer: Vec<u8>,
    /// Offset of the first unconsumed byte in `buffer`.
    cursor: usize,
}
773
#[cfg(feature = "websocket")]
impl<S> WebSocketByteReader<S>
where
    S: futures_util::Stream<Item = Result<WebSocketMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin
        + Send,
{
    /// Wraps `stream` with an empty buffer.
    fn new(stream: S) -> Self {
        Self {
            stream,
            buffer: Vec::new(),
            cursor: 0,
        }
    }

    /// Fills `out` completely, pulling further WebSocket messages as needed.
    ///
    /// Binary message payloads are appended to the internal buffer, and ping/pong
    /// messages are skipped. End of stream, a close message, a text or raw-frame
    /// message, or buffered data that would exceed `MAX_WEBSOCKET_BUFFER_SIZE`
    /// produces [`AgentClientError::WebSocket`].
    async fn read_exact(&mut self, out: &mut [u8]) -> AgentClientResult<()> {
        while self.available_bytes() < out.len() {
            let Some(message) = self.stream.next().await else {
                return Err(AgentClientError::WebSocket(
                    "websocket closed before enough bytes were available".to_string(),
                ));
            };
            match message.map_err(|e| AgentClientError::WebSocket(e.to_string()))? {
                WebSocketMessage::Binary(bytes) => {
                    // Bound memory before growing the buffer.
                    if self.available_bytes() + bytes.len() > MAX_WEBSOCKET_BUFFER_SIZE {
                        return Err(AgentClientError::WebSocket(format!(
                            "websocket buffer exceeded maximum size: {} bytes (max {MAX_WEBSOCKET_BUFFER_SIZE})",
                            self.available_bytes() + bytes.len()
                        )));
                    }
                    // Discard already-consumed bytes so the buffer holds only unread data.
                    self.compact_consumed();
                    self.buffer.extend_from_slice(&bytes);
                }
                WebSocketMessage::Close(_) => {
                    return Err(AgentClientError::WebSocket(
                        "websocket closed before enough bytes were available".to_string(),
                    ));
                }
                WebSocketMessage::Ping(_) | WebSocketMessage::Pong(_) => {}
                WebSocketMessage::Text(_) | WebSocketMessage::Frame(_) => {
                    return Err(AgentClientError::WebSocket(
                        "websocket message is not binary".to_string(),
                    ));
                }
            }
        }

        // Enough bytes are buffered: copy them out and advance past them.
        out.copy_from_slice(&self.buffer[self.cursor..self.cursor + out.len()]);
        self.cursor += out.len();
        self.compact_consumed();
        Ok(())
    }

    /// Number of buffered bytes not yet consumed.
    fn available_bytes(&self) -> usize {
        self.buffer.len().saturating_sub(self.cursor)
    }

    /// Removes consumed bytes from the front of the buffer and resets `cursor` to 0.
    fn compact_consumed(&mut self) {
        if self.cursor == 0 {
            return;
        }
        if self.cursor == self.buffer.len() {
            self.buffer.clear();
        } else {
            self.buffer.drain(..self.cursor);
        }
        self.cursor = 0;
    }

    /// Reads one frame: a 4-byte big-endian length, then that many bytes holding a
    /// big-endian id (bytes 0..4), a flags byte (byte 4), and the body starting at
    /// `FRAME_HEADER_SIZE`.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the length exceeds `MAX_FRAME_SIZE` or is shorter
    /// than the header, or any error from `read_exact`.
    async fn read_raw_frame(&mut self) -> AgentClientResult<RawFrame> {
        let mut len_buf = [0u8; 4];
        self.read_exact(&mut len_buf).await?;
        let frame_len = u32::from_be_bytes(len_buf);
        if frame_len > MAX_FRAME_SIZE {
            return Err(AgentClientError::Protocol(
                microsandbox_protocol::ProtocolError::FrameTooLarge {
                    size: frame_len,
                    max: MAX_FRAME_SIZE,
                },
            ));
        }
        if frame_len < FRAME_HEADER_SIZE as u32 {
            return Err(AgentClientError::Protocol(
                microsandbox_protocol::ProtocolError::FrameTooShort {
                    size: frame_len,
                    min: FRAME_HEADER_SIZE as u32,
                },
            ));
        }

        let mut payload = vec![0u8; frame_len as usize];
        self.read_exact(&mut payload).await?;
        // `payload` holds at least FRAME_HEADER_SIZE bytes (checked above).
        let id = u32::from_be_bytes(payload[0..4].try_into().unwrap());
        let flags = payload[4];
        let body = payload[FRAME_HEADER_SIZE..].to_vec();
        Ok(RawFrame { id, flags, body })
    }
}
872
/// Handshake reads over a WebSocket connection, delegating to the inherent
/// buffered reader methods.
#[cfg(feature = "websocket")]
impl<S> HandshakeReader for WebSocketByteReader<S>
where
    S: futures_util::Stream<Item = Result<WebSocketMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin
        + Send,
{
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        Box::pin(WebSocketByteReader::read_exact(self, out))
    }

    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        Box::pin(WebSocketByteReader::read_raw_frame(self))
    }
}
893
/// Reads frames from the WebSocket connection and dispatches each to its pending
/// handler until the connection closes or a read fails.
#[cfg(feature = "websocket")]
async fn websocket_reader_loop<S>(
    mut reader: WebSocketByteReader<S>,
    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) where
    S: futures_util::Stream<Item = Result<WebSocketMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin
        + Send,
{
    let failure = loop {
        match reader.read_raw_frame().await {
            Ok(frame) => dispatch_frame(frame, &pending).await,
            Err(e) => break e,
        }
    };
    tracing::debug!("agent client: websocket reader EOF or error: {failure}");

    // Dropping every registered sender closes those channels so waiting callers
    // observe end-of-stream.
    pending.lock().await.clear();
}
918
/// Encodes queued frames and sends each as one binary WebSocket message,
/// acknowledging every command with its outcome.
///
/// An encoding failure affects only that command: nothing has been sent yet, so
/// the connection remains consistent and the loop keeps serving other callers.
/// A send failure ends the loop, because the connection is no longer usable.
/// The loop also ends once every queue sender is dropped and the queue is drained.
#[cfg(feature = "websocket")]
async fn websocket_writer_loop<S>(mut writer: S, mut rx: mpsc::Receiver<WriterCommand>)
where
    S: futures_util::Sink<WebSocketMessage, Error = tokio_tungstenite::tungstenite::Error> + Unpin,
{
    while let Some(WriterCommand { frame, ack }) = rx.recv().await {
        // Capacity hint only: presumably a 4-byte length prefix plus header plus body,
        // mirroring what `WebSocketByteReader::read_raw_frame` parses.
        let mut buf = Vec::with_capacity(4 + FRAME_HEADER_SIZE + frame.body.len());
        if let Err(e) = codec::encode_raw_to_buf(&frame, &mut buf) {
            tracing::debug!("agent client: websocket encode error: {e}");
            let _ = ack.send(Err(AgentClientError::Protocol(e)));
            continue;
        }
        if let Err(e) = writer.send(WebSocketMessage::Binary(buf.into())).await {
            tracing::debug!("agent client: websocket writer error: {e}");
            let _ = ack.send(Err(AgentClientError::WebSocket(e.to_string())));
            break;
        }
        let _ = ack.send(Ok(()));
    }
}
941
/// Delivers `frame` to the handler registered under its id.
///
/// Frames flagged `FLAG_TERMINAL` also unregister the id. Frames with no registered
/// handler are dropped with a trace log. If the handler's receiver has been dropped,
/// the id is unregistered so later frames for it are discarded.
#[cfg(any(feature = "uds", feature = "websocket"))]
async fn dispatch_frame(
    frame: RawFrame,
    pending: &Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) {
    let id = frame.id;
    let terminal = (frame.flags & FLAG_TERMINAL) != 0;

    // Look up (and for terminal frames, unregister) the handler under the lock,
    // but release the lock before the potentially blocking send.
    let handler = {
        let mut handlers = pending.lock().await;
        let found = if terminal {
            handlers.remove(&id)
        } else {
            handlers.get(&id).cloned()
        };
        match found {
            Some(handler) => handler,
            None => {
                tracing::trace!("agent client: no pending handler for id={id}");
                return;
            }
        }
    };

    if handler.send(frame).await.is_err() {
        pending.lock().await.remove(&id);
    }
}
966
967async fn decode_stream_task(mut raw_rx: mpsc::Receiver<RawFrame>, tx: mpsc::Sender<Message>) {
969 while let Some(frame) = raw_rx.recv().await {
970 match codec::raw_frame_to_message(frame) {
971 Ok(msg) => {
972 if tx.send(msg).await.is_err() {
973 break;
974 }
975 }
976 Err(e) => {
977 tracing::warn!("agent client: failed to decode frame in stream: {e}");
978 }
980 }
981 }
982}
983
984fn encode_message_body<T: Serialize>(
986 version: u8,
987 t: MessageType,
988 payload: &T,
989) -> AgentClientResult<Vec<u8>> {
990 let mut msg = Message::with_payload(t, 0, payload)?;
991 msg.v = version;
992 let mut body = Vec::new();
993 ciborium::into_writer(&msg, &mut body).map_err(microsandbox_protocol::ProtocolError::from)?;
994 Ok(body)
995}
996
997#[cfg(test)]
1002mod tests {
1003 #[cfg(any(feature = "uds", feature = "websocket"))]
1004 use microsandbox_protocol::core::Ready;
1005 #[cfg(any(feature = "uds", feature = "websocket"))]
1006 use microsandbox_protocol::exec::ExecRequest;
1007 #[cfg(feature = "websocket")]
1008 use microsandbox_protocol::fs::{FsOp, FsRequest, FsResponse};
1009 #[cfg(feature = "uds")]
1010 use microsandbox_protocol::message::PROTOCOL_VERSION;
1011 #[cfg(feature = "uds")]
1012 use tokio::io::AsyncWriteExt;
1013 #[cfg(feature = "websocket")]
1014 use tokio::net::TcpListener;
1015 #[cfg(feature = "uds")]
1016 use tokio::net::UnixListener;
1017 #[cfg(any(feature = "uds", feature = "websocket"))]
1018 use tokio::sync::oneshot;
1019
1020 use super::*;
1021
1022 #[cfg(feature = "uds")]
1023 #[tokio::test]
1024 async fn connect_decodes_ready_payload() {
1025 let temp = tempfile::tempdir().unwrap();
1026 let sock_path = temp.path().join("agent.sock");
1027 let listener = UnixListener::bind(&sock_path).unwrap();
1028 let ready = Ready {
1029 boot_time_ns: 11,
1030 init_time_ns: 22,
1031 ready_time_ns: 33,
1032 agent_version: "9.9.9".to_string(),
1033 };
1034 let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1035
1036 tokio::spawn(async move {
1037 let (mut socket, _) = listener.accept().await.unwrap();
1038 socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1039 socket.write_all(&8u32.to_be_bytes()).await.unwrap();
1040 codec::write_message(&mut socket, &ready_msg).await.unwrap();
1041 });
1042
1043 let client =
1044 AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1045 .await
1046 .unwrap();
1047
1048 assert_eq!(client.protocol(), AgentProtocol::Current);
1049 assert_eq!(client.negotiated_version(), PROTOCOL_VERSION);
1051 assert!(client.supports(MessageType::FsRequest));
1052 assert_eq!(client.agent_version(), "9.9.9");
1054 let decoded = client.ready().unwrap();
1055 assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
1056 assert_eq!(decoded.init_time_ns, ready.init_time_ns);
1057 assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
1058
1059 let raw_msg: Message = ciborium::from_reader(client.ready_bytes()).unwrap();
1060 assert_eq!(raw_msg.t, MessageType::Ready);
1061 }
1062
1063 #[cfg(feature = "uds")]
1064 #[tokio::test]
1065 async fn connect_negotiates_down_to_older_guest_generation() {
1066 let temp = tempfile::tempdir().unwrap();
1067 let sock_path = temp.path().join("agent.sock");
1068 let listener = UnixListener::bind(&sock_path).unwrap();
1069 let ready = Ready {
1070 boot_time_ns: 1,
1071 init_time_ns: 2,
1072 ready_time_ns: 3,
1073 ..Default::default()
1074 };
1075 let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1078 ready_msg.v = 1;
1079
1080 tokio::spawn(async move {
1081 let (mut socket, _) = listener.accept().await.unwrap();
1082 socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1083 socket
1084 .write_all(µsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
1085 .await
1086 .unwrap();
1087 codec::write_message(&mut socket, &ready_msg).await.unwrap();
1088 });
1089
1090 let client =
1091 AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1092 .await
1093 .unwrap();
1094
1095 assert_eq!(client.protocol(), AgentProtocol::Current);
1098 assert_eq!(client.negotiated_version(), 1);
1099 assert!(client.supports(MessageType::ExecRequest));
1101 assert!(!client.supports(MessageType::FsRequest));
1102 }
1103
1104 #[cfg(feature = "uds")]
1105 #[tokio::test]
1106 async fn connect_accepts_legacy_relay_handshake() {
1107 assert_accepts_legacy_relay_handshake(0).await;
1108 assert_accepts_legacy_relay_handshake(268_435_455).await;
1109 }
1110
1111 #[cfg(feature = "uds")]
1112 #[tokio::test]
1113 async fn legacy_relay_requests_use_v1_and_legacy_id_range() {
1114 let temp = tempfile::tempdir().unwrap();
1115 let sock_path = temp.path().join("agent.sock");
1116 let listener = UnixListener::bind(&sock_path).unwrap();
1117 let ready = Ready {
1118 boot_time_ns: 11,
1119 init_time_ns: 22,
1120 ready_time_ns: 33,
1121 ..Default::default()
1122 };
1123 let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1124 let id_offset = 268_435_455u32;
1125 let (frame_tx, frame_rx) = oneshot::channel();
1126
1127 tokio::spawn(async move {
1128 let (mut socket, _) = listener.accept().await.unwrap();
1129 socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
1130 codec::write_message(&mut socket, &ready_msg).await.unwrap();
1131 let frame = codec::read_raw_frame(&mut socket).await.unwrap();
1132 frame_tx.send(frame).unwrap();
1133 });
1134
1135 let client =
1136 AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1137 .await
1138 .unwrap();
1139 let request = ExecRequest {
1140 cmd: "/bin/true".into(),
1141 args: Vec::new(),
1142 env: Vec::new(),
1143 cwd: None,
1144 user: None,
1145 tty: false,
1146 rows: 24,
1147 cols: 80,
1148 rlimits: Vec::new(),
1149 };
1150 let (id, _rx) = client
1151 .stream(MessageType::ExecRequest, &request)
1152 .await
1153 .unwrap();
1154
1155 let frame = frame_rx.await.unwrap();
1156 let message = codec::raw_frame_to_message(frame).unwrap();
1157
1158 assert_eq!(id, id_offset + 1);
1159 assert_eq!(message.id, id_offset + 1);
1160 assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
1161 assert_eq!(message.t, MessageType::ExecRequest);
1162 }
1163
1164 #[test]
1165 fn version_compat_across_generations() {
1166 use MessageType::{ExecRequest, FsRequest};
1167 let cases = [
1171 (ExecRequest, 1, true),
1172 (ExecRequest, 2, true),
1173 (ExecRequest, 3, true),
1174 (FsRequest, 1, false),
1175 (FsRequest, 2, true),
1176 (FsRequest, 3, true),
1177 ];
1178 for (t, generation, allowed) in cases {
1179 assert_eq!(
1180 AgentClient::ensure_version_compat_for(t, generation).is_ok(),
1181 allowed,
1182 "{t:?} at generation {generation}"
1183 );
1184 }
1185 }
1186
1187 #[test]
1188 fn version_compat_rejection_is_typed() {
1189 let err =
1192 AgentClient::ensure_version_compat_for(MessageType::FsRequest, LEGACY_PROTOCOL_VERSION)
1193 .unwrap_err();
1194 assert!(matches!(
1195 err,
1196 AgentClientError::UnsupportedOperation {
1197 needs: 2,
1198 peer: 1,
1199 ..
1200 }
1201 ));
1202 }
1203
1204 #[cfg(feature = "uds")]
1205 #[tokio::test]
1206 async fn connect_preserves_current_peer_protocol_version() {
1207 let temp = tempfile::tempdir().unwrap();
1208 let sock_path = temp.path().join("agent.sock");
1209 let listener = UnixListener::bind(&sock_path).unwrap();
1210 let ready = Ready {
1211 boot_time_ns: 11,
1212 init_time_ns: 22,
1213 ready_time_ns: 33,
1214 ..Default::default()
1215 };
1216 let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1217 ready_msg.v = 2;
1218
1219 tokio::spawn(async move {
1220 let (mut socket, _) = listener.accept().await.unwrap();
1221 socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1222 socket
1223 .write_all(µsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
1224 .await
1225 .unwrap();
1226 codec::write_message(&mut socket, &ready_msg).await.unwrap();
1227 });
1228
1229 let client =
1230 AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1231 .await
1232 .unwrap();
1233
1234 assert_eq!(client.protocol(), AgentProtocol::Current);
1235 assert_eq!(client.negotiated_version(), 2);
1237 assert!(!client.supports(MessageType::TcpConnect));
1239 }
1240
1241 #[cfg(feature = "uds")]
1242 async fn assert_accepts_legacy_relay_handshake(id_offset: u32) {
1243 let temp = tempfile::tempdir().unwrap();
1244 let sock_path = temp.path().join("agent.sock");
1245 let listener = UnixListener::bind(&sock_path).unwrap();
1246 let ready = Ready {
1247 boot_time_ns: 11,
1248 init_time_ns: 22,
1249 ready_time_ns: 33,
1250 ..Default::default()
1251 };
1252 let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1253
1254 tokio::spawn(async move {
1255 let (mut socket, _) = listener.accept().await.unwrap();
1256 socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
1257 codec::write_message(&mut socket, &ready_msg).await.unwrap();
1258 });
1259
1260 let client =
1261 AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1262 .await
1263 .unwrap();
1264
1265 assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
1266 assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
1267 let decoded = client.ready().unwrap();
1268 assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
1269 assert_eq!(decoded.init_time_ns, ready.init_time_ns);
1270 assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
1271 }
1272
1273 #[cfg(feature = "websocket")]
1274 #[tokio::test]
1275 async fn websocket_connects_and_completes_request() {
1276 use futures_util::{SinkExt, StreamExt};
1277 use tokio_tungstenite::accept_async;
1278 use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
1279
1280 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1281 let addr = listener.local_addr().unwrap();
1282 let ready = Ready {
1283 boot_time_ns: 11,
1284 init_time_ns: 22,
1285 ready_time_ns: 33,
1286 agent_version: "ws-test".to_string(),
1287 };
1288
1289 tokio::spawn(async move {
1290 let (stream, _) = listener.accept().await.unwrap();
1291 let mut ws = accept_async(stream).await.unwrap();
1292
1293 let mut range = Vec::new();
1294 range.extend_from_slice(&1u32.to_be_bytes());
1295 range.extend_from_slice(&1024u32.to_be_bytes());
1296 ws.send(WebSocketMessage::Binary(range.into()))
1297 .await
1298 .unwrap();
1299
1300 let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1301 let mut ready_packet = Vec::new();
1302 codec::encode_to_buf(&ready_msg, &mut ready_packet).unwrap();
1303 ws.send(WebSocketMessage::Binary(ready_packet.into()))
1304 .await
1305 .unwrap();
1306
1307 let request_packet = loop {
1308 match ws.next().await.unwrap().unwrap() {
1309 WebSocketMessage::Binary(bytes) => break bytes.to_vec(),
1310 _ => continue,
1311 }
1312 };
1313 let mut request_buf = request_packet;
1314 let request = codec::try_decode_raw_from_buf(&mut request_buf)
1315 .unwrap()
1316 .unwrap();
1317 let request_msg = codec::raw_frame_to_message(request.clone()).unwrap();
1318 assert_eq!(request_msg.t, MessageType::FsRequest);
1319
1320 let response = Message::with_payload(
1321 MessageType::FsResponse,
1322 request.id,
1323 &FsResponse {
1324 ok: true,
1325 error: None,
1326 data: None,
1327 },
1328 )
1329 .unwrap();
1330 let mut response_packet = Vec::new();
1331 codec::encode_to_buf(&response, &mut response_packet).unwrap();
1332 ws.send(WebSocketMessage::Binary(response_packet.into()))
1333 .await
1334 .unwrap();
1335 });
1336
1337 let client = AgentClient::connect_websocket(&format!("ws://{addr}"))
1338 .await
1339 .unwrap();
1340 assert_eq!(client.agent_version(), "ws-test");
1341
1342 let response = client
1343 .request(
1344 MessageType::FsRequest,
1345 &FsRequest {
1346 op: FsOp::Stat {
1347 path: "/tmp".to_string(),
1348 follow_symlink: true,
1349 },
1350 },
1351 )
1352 .await
1353 .unwrap();
1354 assert_eq!(response.t, MessageType::FsResponse);
1355 let payload: FsResponse = response.payload().unwrap();
1356 assert!(payload.ok);
1357 }
1358
1359 #[cfg(feature = "websocket")]
1360 #[tokio::test]
1361 async fn websocket_accepts_legacy_relay_handshake() {
1362 use futures_util::{SinkExt, StreamExt};
1363 use tokio_tungstenite::accept_async;
1364 use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
1365
1366 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1367 let addr = listener.local_addr().unwrap();
1368 let ready = Ready {
1369 boot_time_ns: 11,
1370 init_time_ns: 22,
1371 ready_time_ns: 33,
1372 agent_version: "legacy-ws-test".to_string(),
1373 };
1374 let id_offset = 268_435_455u32;
1375 let (frame_tx, frame_rx) = oneshot::channel();
1376
1377 tokio::spawn(async move {
1378 let (stream, _) = listener.accept().await.unwrap();
1379 let mut ws = accept_async(stream).await.unwrap();
1380
1381 ws.send(WebSocketMessage::Binary(
1382 id_offset.to_be_bytes().to_vec().into(),
1383 ))
1384 .await
1385 .unwrap();
1386
1387 let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1388 let mut ready_packet = Vec::new();
1389 codec::encode_to_buf(&ready_msg, &mut ready_packet).unwrap();
1390 ws.send(WebSocketMessage::Binary(ready_packet.into()))
1391 .await
1392 .unwrap();
1393
1394 let request_packet = loop {
1395 match ws.next().await.unwrap().unwrap() {
1396 WebSocketMessage::Binary(bytes) => break bytes.to_vec(),
1397 _ => continue,
1398 }
1399 };
1400 let mut request_buf = request_packet;
1401 let request = codec::try_decode_raw_from_buf(&mut request_buf)
1402 .unwrap()
1403 .unwrap();
1404 frame_tx.send(request).unwrap();
1405 });
1406
1407 let client = AgentClient::connect_websocket(&format!("ws://{addr}"))
1408 .await
1409 .unwrap();
1410 assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
1411 assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
1412 assert_eq!(client.agent_version(), "legacy-ws-test");
1413
1414 let request = ExecRequest {
1415 cmd: "/bin/true".into(),
1416 args: Vec::new(),
1417 env: Vec::new(),
1418 cwd: None,
1419 user: None,
1420 tty: false,
1421 rows: 24,
1422 cols: 80,
1423 rlimits: Vec::new(),
1424 };
1425 let (id, _rx) = client
1426 .stream(MessageType::ExecRequest, &request)
1427 .await
1428 .unwrap();
1429 let frame = frame_rx.await.unwrap();
1430 let message = codec::raw_frame_to_message(frame).unwrap();
1431
1432 assert_eq!(id, id_offset + 1);
1433 assert_eq!(message.id, id_offset + 1);
1434 assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
1435 assert_eq!(message.t, MessageType::ExecRequest);
1436 }
1437}
1438
1439impl Drop for AgentClient {
1444 fn drop(&mut self) {
1445 self.reader_handle.abort();
1446 self.writer_handle.abort();
1447 }
1448}