Skip to main content

microsandbox_agent_client/
client.rs

//! Client for connecting to a microsandbox agent relay.
//!
//! [`AgentClient`] communicates with `agentd` through an agent relay transport.
//! During connection, the relay assigns a non-overlapping correlation ID range
//! and sends the cached `core.ready` payload so the client can begin issuing
//! commands immediately. Unix domain sockets are available with the `uds`
//! feature; the `stream` feature drives the client over any
//! `AsyncRead + AsyncWrite` byte stream (e.g. a caller-owned, pre-authenticated
//! transport adapted to bytes).
//!
//! Two API tiers share one socket and one reader task:
//!
//! - **Raw** ([`request_raw`](AgentClient::request_raw),
//!   [`stream_raw`](AgentClient::stream_raw),
//!   [`send_raw`](AgentClient::send_raw)) — exchange [`RawFrame`]s. The client
//!   handles framing and correlation IDs; CBOR encoding/decoding is left to the
//!   caller. Use this when wrapping the client for other languages.
//! - **Typed** ([`request`](AgentClient::request),
//!   [`stream`](AgentClient::stream), [`send`](AgentClient::send)) — same
//!   primitives over [`Message`]; the SDK serializes payloads with CBOR.
21
22use std::collections::HashMap;
23#[cfg(feature = "stream")]
24use std::future::Future;
25#[cfg(feature = "uds")]
26use std::path::Path;
27#[cfg(feature = "stream")]
28use std::pin::Pin;
29use std::sync::{Arc, atomic::AtomicU32};
30#[cfg(feature = "stream")]
31use std::time::Duration;
32
33#[cfg(feature = "stream")]
34use microsandbox_protocol::message::FLAG_TERMINAL;
35#[cfg(feature = "stream")]
36use microsandbox_protocol::{codec::MAX_FRAME_SIZE, message::FRAME_HEADER_SIZE};
37use microsandbox_protocol::{
38    codec::{self, RawFrame},
39    core::Ready,
40    message::{Message, MessageType, PROTOCOL_VERSION},
41};
42use serde::Serialize;
43#[cfg(feature = "stream")]
44use tokio::io::{AsyncRead, AsyncWrite};
45#[cfg(feature = "uds")]
46use tokio::net::UnixStream;
47use tokio::sync::{Mutex, mpsc, oneshot};
48use tokio::task::JoinHandle;
49#[cfg(feature = "stream")]
50use tokio::time::Instant;
51
52use super::error::{AgentClientError, AgentClientResult};
53
54//--------------------------------------------------------------------------------------------------
55// Constants
56//--------------------------------------------------------------------------------------------------
57
/// How long [`AgentClient::connect`] and [`AgentClient::connect_stream`] wait
/// for the relay handshake when the caller does not pick a timeout.
#[cfg(feature = "stream")]
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the command queue feeding the transport writer task.
#[cfg(feature = "stream")]
const WRITER_QUEUE_CAPACITY: usize = 1024;

/// Capacity of the per-id response channel opened by `request_raw`.
const REQUEST_QUEUE_CAPACITY: usize = 1;

/// Capacity of the per-id frame channel opened by `stream_raw`, and of the
/// decoded-message channel opened by `stream`.
const STREAM_QUEUE_CAPACITY: usize = 1024;

/// Protocol version reported for [`AgentProtocol::LegacyV1`] connections.
const LEGACY_PROTOCOL_VERSION: u8 = 1;

/// Width added to a legacy relay's id offset to form the exclusive upper bound
/// of the client's correlation-id range.
///
/// TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
/// compatibility for versions before 0.5 is no longer supported.
#[cfg(feature = "stream")]
const LEGACY_RELAY_ID_RANGE_STEP: u32 = u32::MAX / 16;
72
73//--------------------------------------------------------------------------------------------------
74// Types
75//--------------------------------------------------------------------------------------------------
76
/// Which generation of the agent protocol a connected sandbox relay speaks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentProtocol {
    /// The protocol generation this client targets natively.
    Current,

    /// The relay handshake and agent protocol used by microsandbox before 0.5.
    ///
    /// TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
    /// compatibility for versions before 0.5 is no longer supported.
    LegacyV1,
}
89
90/// Client for communicating with agentd through the agent relay.
91///
92/// See the module-level docs for an overview of the two API tiers.
93pub struct AgentClient {
94    /// Channel to the transport writer task.
95    writer: mpsc::Sender<WriterCommand>,
96    /// Next correlation ID to allocate (starts at `id_min`).
97    next_id: AtomicU32,
98    /// Lower bound (inclusive) of the assigned ID range, used for wrap-around.
99    id_min: u32,
100    /// Upper bound (exclusive) of the assigned ID range.
101    id_max: u32,
102    /// Agent protocol generation for this connection.
103    protocol: AgentProtocol,
104    /// Negotiated protocol generation: `min(our PROTOCOL_VERSION, the
105    /// generation the sandbox echoed in its `core.ready` frame)`. Drives the
106    /// capability gate on the typed send path. Distinct from [`Self::protocol`],
107    /// which selects the wire codec; see `VERSIONING.md`.
108    negotiated_version: u8,
109    /// Pending response channels keyed by correlation ID.
110    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
111    /// Background reader task handle.
112    reader_handle: JoinHandle<()>,
113    /// Background writer task handle.
114    writer_handle: JoinHandle<()>,
115    /// Cached `core.ready` frame body (raw CBOR bytes) from the relay handshake.
116    ready_body: Vec<u8>,
117    /// Decoded `core.ready` payload from the relay handshake.
118    ready: Ready,
119}
120
/// Everything `perform_handshake` learns from the relay prologue; consumed when
/// building an [`AgentClient`].
#[cfg(feature = "stream")]
struct AgentHandshake {
    /// Inclusive lower bound of the assigned correlation-id range.
    id_min: u32,
    /// Exclusive upper bound of the assigned correlation-id range.
    id_max: u32,
    /// Handshake form (current or legacy) that the relay used.
    protocol: AgentProtocol,
    /// Lower of the client's protocol version and the version carried by the
    /// ready frame.
    negotiated_version: u8,
    /// Raw body bytes of the `core.ready` frame.
    ready_body: Vec<u8>,
    /// Decoded `core.ready` payload.
    ready: Ready,
}
130
131#[cfg_attr(not(feature = "stream"), allow(dead_code))]
132struct WriterCommand {
133    frame: RawFrame,
134    ack: oneshot::Sender<AgentClientResult<()>>,
135}
136
/// The two read primitives the handshake needs, expressed with boxed futures so
/// the trait can be used behind `?Sized` readers.
#[cfg(feature = "stream")]
trait HandshakeReader {
    /// Fill `out` completely from the transport, or fail.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>>;

    /// Read one complete raw frame from the transport.
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>>;
}
148
149//--------------------------------------------------------------------------------------------------
150// Methods: Connection lifecycle
151//--------------------------------------------------------------------------------------------------
152
153impl AgentProtocol {
154    fn version(self) -> u8 {
155        match self {
156            Self::Current => PROTOCOL_VERSION,
157            Self::LegacyV1 => LEGACY_PROTOCOL_VERSION,
158        }
159    }
160}
161
162impl AgentClient {
163    /// Connect to a Unix domain socket agent relay using the default 10s
164    /// handshake timeout.
165    #[cfg(feature = "uds")]
166    pub async fn connect(sock_path: impl AsRef<Path>) -> AgentClientResult<Self> {
167        Self::connect_with_timeout(sock_path, DEFAULT_HANDSHAKE_TIMEOUT).await
168    }
169
170    /// Connect to a Unix domain socket agent relay using an explicit
171    /// handshake timeout.
172    #[cfg(feature = "uds")]
173    pub async fn connect_with_timeout(
174        sock_path: impl AsRef<Path>,
175        timeout: Duration,
176    ) -> AgentClientResult<Self> {
177        let deadline = Instant::now() + timeout;
178        Self::connect_with_deadline(sock_path, deadline).await
179    }
180
181    /// Connect with an explicit handshake deadline.
182    ///
183    /// `deadline` bounds both handshake reads. Without it, an accepted
184    /// connection that stalls (e.g. a sandbox alive but wedged before
185    /// writing the handshake bytes) would block this call indefinitely.
186    #[cfg(feature = "uds")]
187    pub async fn connect_with_deadline(
188        sock_path: impl AsRef<Path>,
189        deadline: Instant,
190    ) -> AgentClientResult<Self> {
191        let sock_path = sock_path.as_ref();
192        let stream =
193            UnixStream::connect(sock_path)
194                .await
195                .map_err(|source| AgentClientError::Connect {
196                    path: sock_path.to_path_buf(),
197                    source,
198                })?;
199        Self::connect_stream_with_deadline(stream, deadline).await
200    }
201
202    /// Connect over an arbitrary byte-stream transport using the default 10s
203    /// handshake timeout.
204    ///
205    /// The stream must be a transparent pipe to the agent relay: the relay's
206    /// `[id_min][id_max]` + `core.ready` prologue and the framed protocol that
207    /// follows flow over it verbatim. This is the injection point for
208    /// caller-owned transports — e.g. a pre-authenticated WebSocket adapted to
209    /// bytes — so the caller owns the dial and its credentials and this crate
210    /// stays transport- (and dependency-) agnostic.
211    #[cfg(feature = "stream")]
212    pub async fn connect_stream<S>(stream: S) -> AgentClientResult<Self>
213    where
214        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
215    {
216        Self::connect_stream_with_timeout(stream, DEFAULT_HANDSHAKE_TIMEOUT).await
217    }
218
219    /// Connect over an arbitrary byte-stream transport using an explicit
220    /// handshake timeout.
221    #[cfg(feature = "stream")]
222    pub async fn connect_stream_with_timeout<S>(
223        stream: S,
224        timeout: Duration,
225    ) -> AgentClientResult<Self>
226    where
227        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
228    {
229        let deadline = Instant::now() + timeout;
230        Self::connect_stream_with_deadline(stream, deadline).await
231    }
232
233    /// Connect over an arbitrary byte-stream transport with an explicit
234    /// handshake deadline.
235    ///
236    /// `deadline` bounds both handshake reads so an accepted-but-stalled
237    /// transport cannot block this call indefinitely.
238    #[cfg(feature = "stream")]
239    pub async fn connect_stream_with_deadline<S>(
240        stream: S,
241        deadline: Instant,
242    ) -> AgentClientResult<Self>
243    where
244        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
245    {
246        let (mut reader, writer) = tokio::io::split(stream);
247        let handshake = perform_handshake(&mut reader, deadline).await?;
248
249        tracing::info!(
250            id_min = handshake.id_min,
251            id_max = handshake.id_max,
252            protocol = ?handshake.protocol,
253            ready_bytes = handshake.ready_body.len(),
254            boot_time_ns = handshake.ready.boot_time_ns,
255            "agent client: connected to relay"
256        );
257        if handshake.protocol == AgentProtocol::LegacyV1 {
258            // TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
259            // compatibility for versions before 0.5 is no longer supported.
260            tracing::warn!(
261                "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
262            );
263        }
264
265        let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
266            Arc::new(Mutex::new(HashMap::new()));
267
268        let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
269        let reader_handle = tokio::spawn(reader_loop(reader, Arc::clone(&pending)));
270        let writer_handle = tokio::spawn(stream_writer_loop(writer, writer_rx));
271
272        Ok(Self {
273            writer: writer_tx,
274            next_id: AtomicU32::new(first_request_id(handshake.id_min)),
275            id_min: handshake.id_min,
276            id_max: handshake.id_max,
277            protocol: handshake.protocol,
278            negotiated_version: handshake.negotiated_version,
279            pending,
280            reader_handle,
281            writer_handle,
282            ready_body: handshake.ready_body,
283            ready: handshake.ready,
284        })
285    }
286
287    /// Close the connection. Drops the writer and aborts the reader task;
288    /// any in-flight requests resolve with [`AgentClientError::Closed`].
289    pub async fn close(self) {
290        // Drop runs: reader aborts via Drop impl, writer closes when the
291        // last Arc reference dies. Senders in `pending` drop with self,
292        // resolving outstanding waiters.
293    }
294}
295
296//--------------------------------------------------------------------------------------------------
297// Methods: Raw transport (CBOR-blind)
298//--------------------------------------------------------------------------------------------------
299
300impl AgentClient {
301    /// One-shot raw request: alloc id, send a frame with `(flags, body)`,
302    /// await one response frame with the matching id.
303    ///
304    /// Use this for protocol RPCs that produce exactly one terminal response
305    /// (e.g. `FsRequest` → `FsResponse`).
306    pub async fn request_raw(&self, flags: u8, body: Vec<u8>) -> AgentClientResult<RawFrame> {
307        let (tx, mut rx) = mpsc::channel(REQUEST_QUEUE_CAPACITY);
308        let id = self.reserve_id(tx).await?;
309
310        if let Err(e) = self.write_frame_owned(id, flags, body).await {
311            self.pending.lock().await.remove(&id);
312            return Err(e);
313        }
314
315        let frame = rx.recv().await.ok_or(AgentClientError::ReaderClosed(id))?;
316        self.pending.lock().await.remove(&id);
317        Ok(frame)
318    }
319
320    /// Open a streaming raw session: alloc id, register a subscription,
321    /// send the opening frame, return `(id, receiver)`.
322    ///
323    /// The receiver yields every frame the relay forwards for this `id`
324    /// until a frame with [`FLAG_TERMINAL`] arrives or the receiver is dropped.
325    /// Use [`send_raw`](Self::send_raw) with the returned id to send
326    /// follow-up frames within the session.
327    pub async fn stream_raw(
328        &self,
329        flags: u8,
330        body: Vec<u8>,
331    ) -> AgentClientResult<(u32, mpsc::Receiver<RawFrame>)> {
332        let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
333        let id = self.reserve_id(tx).await?;
334
335        if let Err(e) = self.write_frame_owned(id, flags, body).await {
336            self.pending.lock().await.remove(&id);
337            return Err(e);
338        }
339
340        Ok((id, rx))
341    }
342
343    /// Send a follow-up raw frame on an existing correlation id.
344    ///
345    /// Use for messages that belong to a session started via
346    /// [`stream_raw`](Self::stream_raw) (e.g. `ExecStdin`, `ExecSignal`,
347    /// `ExecResize`, `FsData` chunks).
348    pub async fn send_raw(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
349        self.write_frame(id, flags, body).await
350    }
351
352    /// The cached `core.ready` handshake frame body bytes (CBOR-encoded).
353    ///
354    /// Useful for bindings that want to deserialize the ready payload with
355    /// their own CBOR tooling. For typed access, use [`ready`](Self::ready).
356    pub fn ready_bytes(&self) -> &[u8] {
357        &self.ready_body
358    }
359
360    /// Agent protocol generation for this connection.
361    pub fn protocol(&self) -> AgentProtocol {
362        self.protocol
363    }
364
365    /// Returns `true` if this connection is using the legacy pre-0.5 protocol.
366    pub fn is_legacy_protocol(&self) -> bool {
367        self.protocol == AgentProtocol::LegacyV1
368    }
369
370    /// The negotiated protocol generation for this connection: the lower of what
371    /// this client speaks and what the sandbox advertised at handshake.
372    pub fn negotiated_version(&self) -> u8 {
373        self.negotiated_version
374    }
375
376    /// The runtime's self-reported package version, taken from its `core.ready`
377    /// frame. Empty when the runtime predates this field (an older agent), in
378    /// which case fall back to the generation for diagnostics.
379    pub fn agent_version(&self) -> &str {
380        &self.ready.agent_version
381    }
382
383    /// Whether the connected sandbox is new enough to handle the given message
384    /// type. The single source of truth for feature gating: callers that can't
385    /// gate by sending (e.g. the SSH/SFTP layer) consult this instead of
386    /// inspecting the protocol generation directly.
387    pub fn supports(&self, t: MessageType) -> bool {
388        t.min_protocol_version() <= self.negotiated_version
389    }
390
391    /// Reject a message type the connected sandbox is too old to handle, against
392    /// this connection's negotiated generation. Fails before any bytes are sent,
393    /// so only that one operation fails and the session continues.
394    pub fn ensure_version_compat(&self, t: MessageType) -> AgentClientResult<()> {
395        Self::ensure_version_compat_for(t, self.negotiated_version)
396    }
397
398    /// Check a message type against an explicit negotiated generation.
399    ///
400    /// The single place the rule lives. Exposed for callers that hold the
401    /// negotiated generation but not the live client (e.g. the SSH/SFTP layer).
402    pub fn ensure_version_compat_for(t: MessageType, negotiated: u8) -> AgentClientResult<()> {
403        if t.is_available_at(negotiated) {
404            return Ok(());
405        }
406        Err(AgentClientError::UnsupportedOperation {
407            msg_type: t.as_str(),
408            needs: t.min_protocol_version(),
409            peer: negotiated,
410        })
411    }
412}
413
414//--------------------------------------------------------------------------------------------------
415// Methods: Typed transport (CBOR-aware)
416//--------------------------------------------------------------------------------------------------
417
418impl AgentClient {
419    /// One-shot typed request. Flags are derived from the message type.
420    pub async fn request<T: Serialize>(
421        &self,
422        t: MessageType,
423        payload: &T,
424    ) -> AgentClientResult<Message> {
425        self.ensure_version_compat(t)?;
426        let flags = t.flags();
427        let body = encode_message_body(self.protocol.version(), t, payload)?;
428        let frame = self.request_raw(flags, body).await?;
429        Ok(codec::raw_frame_to_message(frame)?)
430    }
431
432    /// Open a streaming typed session. Flags are derived from the message type.
433    /// Returns the assigned id and a typed receiver.
434    pub async fn stream<T: Serialize>(
435        &self,
436        t: MessageType,
437        payload: &T,
438    ) -> AgentClientResult<(u32, mpsc::Receiver<Message>)> {
439        self.ensure_version_compat(t)?;
440        let flags = t.flags();
441        let body = encode_message_body(self.protocol.version(), t, payload)?;
442        let (id, raw_rx) = self.stream_raw(flags, body).await?;
443
444        let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
445        tokio::spawn(decode_stream_task(raw_rx, tx));
446        Ok((id, rx))
447    }
448
449    /// Send a follow-up typed message on an existing correlation id.
450    pub async fn send<T: Serialize>(
451        &self,
452        id: u32,
453        t: MessageType,
454        payload: &T,
455    ) -> AgentClientResult<()> {
456        self.ensure_version_compat(t)?;
457        let flags = t.flags();
458        let body = encode_message_body(self.protocol.version(), t, payload)?;
459        self.write_frame_owned(id, flags, body).await
460    }
461
462    /// Decode the cached handshake `core.ready` payload.
463    pub fn ready(&self) -> AgentClientResult<Ready> {
464        Ok(self.ready.clone())
465    }
466}
467
468//--------------------------------------------------------------------------------------------------
469// Methods: Internals
470//--------------------------------------------------------------------------------------------------
471
472impl AgentClient {
473    /// Reserve a unique correlation ID from the relay-assigned range.
474    ///
475    /// Wraps around within the assigned range and skips IDs that still have an
476    /// active pending request or stream.
477    async fn reserve_id(&self, tx: mpsc::Sender<RawFrame>) -> AgentClientResult<u32> {
478        let mut pending = self.pending.lock().await;
479        let attempts = usable_id_count(self.id_min, self.id_max);
480        for _ in 0..attempts {
481            let id = self
482                .next_id
483                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
484            if self.next_id.load(std::sync::atomic::Ordering::Relaxed) >= self.id_max {
485                self.next_id.store(
486                    first_request_id(self.id_min),
487                    std::sync::atomic::Ordering::Relaxed,
488                );
489            }
490            if id == 0 || id < self.id_min || id >= self.id_max || pending.contains_key(&id) {
491                continue;
492            }
493            pending.insert(id, tx);
494            return Ok(id);
495        }
496
497        Err(AgentClientError::IdRangeExhausted)
498    }
499
500    /// Write a single framed message to the socket.
501    async fn write_frame(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
502        self.write_frame_owned(id, flags, body.to_vec()).await
503    }
504
505    /// Write a single framed message to the socket, taking ownership of the body.
506    async fn write_frame_owned(&self, id: u32, flags: u8, body: Vec<u8>) -> AgentClientResult<()> {
507        let (ack, written) = oneshot::channel();
508        self.writer
509            .send(WriterCommand {
510                frame: RawFrame { id, flags, body },
511                ack,
512            })
513            .await
514            .map_err(|_| AgentClientError::Closed)?;
515        written.await.map_err(|_| AgentClientError::Closed)?
516    }
517}
518
519//--------------------------------------------------------------------------------------------------
520// Functions
521//--------------------------------------------------------------------------------------------------
522
/// Read and validate the relay prologue, detecting which handshake form the
/// relay used.
///
/// Wire layouts:
///
/// - current: `[id_min: u32 BE][id_max: u32 BE][ready_frame_bytes...]`
/// - legacy pre-0.5: `[id_offset: u32 BE][ready_frame_bytes...]`
///
/// Eight bytes are read first, which is enough to tell the forms apart: for a
/// legacy relay the second word is already the ready frame's length prefix.
/// Every read is bounded by `deadline`.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] on timeout, on an invalid or
/// unusable id range, and when the ready frame cannot be read or decoded or is
/// not `core.ready`.
#[cfg(feature = "stream")]
async fn perform_handshake<R>(
    reader: &mut R,
    deadline: Instant,
) -> AgentClientResult<AgentHandshake>
where
    R: HandshakeReader + ?Sized,
{
    let mut prologue = [0u8; 8];
    match tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut prologue)).await {
        Ok(read_outcome) => read_outcome?,
        Err(_) => {
            return Err(AgentClientError::Handshake(
                "read id range: timed out before relay sent bytes".into(),
            ));
        }
    }

    let [b0, b1, b2, b3, b4, b5, b6, b7] = prologue;
    let second_word_bytes = [b4, b5, b6, b7];
    let id_start_or_offset = u32::from_be_bytes([b0, b1, b2, b3]);
    let id_max_or_frame_len = u32::from_be_bytes(second_word_bytes);

    let (id_min, id_max, ready_frame, protocol) =
        if looks_like_legacy_relay_handshake(id_start_or_offset, id_max_or_frame_len) {
            // Legacy: the second word was the length prefix of the ready frame,
            // so only the remainder of that frame is still on the wire.
            let frame =
                read_raw_frame_after_len_prefix(reader, second_word_bytes, deadline).await?;
            (
                id_start_or_offset.saturating_add(1),
                id_start_or_offset.saturating_add(LEGACY_RELAY_ID_RANGE_STEP),
                frame,
                AgentProtocol::LegacyV1,
            )
        } else if id_start_or_offset >= id_max_or_frame_len {
            return Err(AgentClientError::Handshake(format!(
                "invalid relay id range: start={id_start_or_offset}, end={id_max_or_frame_len}"
            )));
        } else {
            let frame =
                match tokio::time::timeout_at(deadline, reader.read_frame_handshake()).await {
                    Ok(Ok(frame)) => frame,
                    Ok(Err(e)) => {
                        return Err(AgentClientError::Handshake(format!(
                            "read ready frame: {e}"
                        )));
                    }
                    Err(_) => {
                        return Err(AgentClientError::Handshake(
                            "read ready frame: timed out before relay sent frame".into(),
                        ));
                    }
                };
            (
                id_start_or_offset,
                id_max_or_frame_len,
                frame,
                AgentProtocol::Current,
            )
        };
    ensure_usable_id_range(id_min, id_max)?;

    // Decoding consumes a frame, and the raw body is still needed afterwards,
    // hence the clone.
    let ready_msg = match codec::raw_frame_to_message(ready_frame.clone()) {
        Ok(msg) => msg,
        Err(e) => {
            return Err(AgentClientError::Handshake(format!(
                "decode ready frame: {e}"
            )));
        }
    };
    if ready_msg.t != MessageType::Ready {
        return Err(AgentClientError::Handshake(format!(
            "expected core.ready frame, got {}",
            ready_msg.t.as_str()
        )));
    }
    let ready: Ready = match ready_msg.payload() {
        Ok(payload) => payload,
        Err(e) => {
            return Err(AgentClientError::Handshake(format!(
                "decode ready payload: {e}"
            )));
        }
    };

    // Negotiated capability generation: the lower of what this client speaks
    // and what the sandbox echoed in its ready frame (`ready_msg.v`). When a
    // newer host meets an older runtime this is the runtime's generation, so
    // the send gate withholds features the runtime cannot handle. The codec
    // generation (`protocol`) is determined separately, above.
    let negotiated_version = std::cmp::min(protocol.version(), ready_msg.v);

    Ok(AgentHandshake {
        id_min,
        id_max,
        protocol,
        negotiated_version,
        ready_body: ready_frame.body,
        ready,
    })
}
611
/// First correlation id that may be handed out for a range starting at
/// `id_min`: `id_min` itself, except that zero is bumped to one.
fn first_request_id(id_min: u32) -> u32 {
    if id_min == 0 { 1 } else { id_min }
}
615
/// Reject an id range from which no nonzero correlation id could ever be
/// allocated.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] when the range has no usable ids.
#[cfg(feature = "stream")]
fn ensure_usable_id_range(id_min: u32, id_max: u32) -> AgentClientResult<()> {
    match usable_id_count(id_min, id_max) {
        0 => Err(AgentClientError::Handshake(format!(
            "relay id range contains no usable nonzero ids: start={id_min}, end={id_max}"
        ))),
        _ => Ok(()),
    }
}
625
626fn usable_id_count(id_min: u32, id_max: u32) -> u32 {
627    id_max.saturating_sub(first_request_id(id_min))
628}
629
/// Decide whether the first two big-endian words of the prologue came from a
/// legacy (pre-0.5) relay.
///
/// `id_min` is the first word and `id_max` the second, named for their meaning
/// in the current handshake.
#[cfg(feature = "stream")]
fn looks_like_legacy_relay_handshake(id_min: u32, id_max: u32) -> bool {
    // TODO(upgrade-0.6): Remove in 0.6.x or later once pre-0.5 relay
    // handshakes are no longer accepted.
    //
    // A legacy relay sends its id offset in the first word, and the second
    // word is already the length prefix of the ready frame. In the v2
    // handshake the second word is the exclusive upper id bound, which is far
    // larger than any valid frame length. Tests can produce tiny current
    // ranges, so the current interpretation wins whenever the pair is a valid
    // range that starts at a nonzero id.
    let second_word_is_frame_len = ((FRAME_HEADER_SIZE as u32)..=MAX_FRAME_SIZE).contains(&id_max);
    let first_word_is_not_range_start = id_min == 0 || id_min >= id_max;
    second_word_is_frame_len && first_word_is_not_range_start
}
644
/// Finish reading a legacy ready frame whose 4-byte big-endian length prefix
/// (`len_buf`) has already been consumed from `reader`.
///
/// The frame's remaining bytes are the id (u32 BE), one flags byte, and, from
/// offset `FRAME_HEADER_SIZE` onward, the body. The read is bounded by
/// `deadline`.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] when the announced length exceeds
/// `MAX_FRAME_SIZE` or is below `FRAME_HEADER_SIZE`, on timeout, or when the
/// read fails.
#[cfg(feature = "stream")]
async fn read_raw_frame_after_len_prefix<R>(
    reader: &mut R,
    len_buf: [u8; 4],
    deadline: Instant,
) -> AgentClientResult<RawFrame>
where
    R: HandshakeReader + ?Sized,
{
    let frame_len = u32::from_be_bytes(len_buf);
    if frame_len > MAX_FRAME_SIZE {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too large: {frame_len} bytes (max {MAX_FRAME_SIZE})"
        )));
    }
    if frame_len < FRAME_HEADER_SIZE as u32 {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too short: {frame_len} bytes"
        )));
    }

    let mut data = vec![0u8; frame_len as usize];
    match tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut data)).await {
        Ok(Ok(())) => {}
        Ok(Err(e)) => {
            return Err(AgentClientError::Handshake(format!(
                "read legacy ready frame: {e}"
            )));
        }
        Err(_) => {
            return Err(AgentClientError::Handshake(
                "read legacy ready frame: timed out before relay sent frame".into(),
            ));
        }
    }

    // The length check above guarantees the header bytes are present.
    let id = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
    let flags = data[4];
    let body = data[FRAME_HEADER_SIZE..].to_vec();

    Ok(RawFrame { id, flags, body })
}
682
/// Every `AsyncRead + Unpin + Send` transport can serve as a handshake reader.
#[cfg(feature = "stream")]
impl<R> HandshakeReader for R
where
    R: tokio::io::AsyncRead + Unpin + Send,
{
    /// Fill `out` with `read_exact`; an I/O failure is reported as a
    /// [`AgentClientError::Handshake`] carrying the error's text.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        Box::pin(async move {
            match tokio::io::AsyncReadExt::read_exact(self, out).await {
                Ok(_) => Ok(()),
                Err(io_err) => Err(AgentClientError::Handshake(io_err.to_string())),
            }
        })
    }

    /// Read one frame through the codec; a failure is reported as
    /// [`AgentClientError::Protocol`].
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        Box::pin(async move {
            match codec::read_raw_frame(self).await {
                Ok(frame) => Ok(frame),
                Err(protocol_err) => Err(AgentClientError::Protocol(protocol_err)),
            }
        })
    }
}
710
/// Background task that drains the writer queue, writing each frame to the
/// transport and reporting the outcome on the command's `ack` channel.
///
/// Ends when the queue is closed or after the first failed write; the failure
/// is reported to the command that hit it.
#[cfg(feature = "stream")]
async fn stream_writer_loop<W>(mut writer: W, mut rx: mpsc::Receiver<WriterCommand>)
where
    W: tokio::io::AsyncWrite + Unpin,
{
    while let Some(WriterCommand { frame, ack }) = rx.recv().await {
        match codec::write_raw_frame(&mut writer, &frame).await {
            Ok(_) => {
                // The requester may have stopped waiting; nothing to do then.
                let _ = ack.send(Ok(()));
            }
            Err(e) => {
                tracing::debug!("agent client: stream writer error: {e}");
                let _ = ack.send(Err(AgentClientError::Protocol(e)));
                break;
            }
        }
    }
}
725
/// Background task that pulls frames off the relay transport and routes each
/// one to the pending channel registered for its correlation ID. Works purely
/// on raw frames — no CBOR.
///
/// Stops at the first read error or EOF and then empties the pending map.
#[cfg(feature = "stream")]
async fn reader_loop<R>(mut reader: R, pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>)
where
    R: tokio::io::AsyncRead + Unpin,
{
    loop {
        match codec::read_raw_frame(&mut reader).await {
            Ok(frame) => dispatch_frame(frame, &pending).await,
            Err(e) => {
                tracing::debug!("agent client: reader EOF or error: {e}");
                break;
            }
        }
    }

    // No more frames will arrive: drop every registered sender so that
    // outstanding receivers wake up.
    pending.lock().await.clear();
}
749
/// Deliver `frame` to the channel registered under its correlation ID.
///
/// Frames with no registered channel are dropped with a trace log. A frame
/// carrying `FLAG_TERMINAL` unregisters its id before delivery, and a failed
/// delivery unregisters the id as well.
#[cfg(feature = "stream")]
async fn dispatch_frame(
    frame: RawFrame,
    pending: &Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) {
    let id = frame.id;
    let ends_session = (frame.flags & FLAG_TERMINAL) != 0;

    // Take the sender out of the map scope so the lock is released before the
    // potentially waiting `send` below.
    let sink = {
        let mut registry = pending.lock().await;
        let found = if ends_session {
            registry.remove(&id)
        } else {
            registry.get(&id).cloned()
        };
        match found {
            Some(sender) => sender,
            None => {
                tracing::trace!("agent client: no pending handler for id={id}");
                return;
            }
        }
    };

    if sink.send(frame).await.is_err() {
        pending.lock().await.remove(&id);
    }
}
774
775/// Translate a stream of raw frames into typed messages.
776async fn decode_stream_task(mut raw_rx: mpsc::Receiver<RawFrame>, tx: mpsc::Sender<Message>) {
777    while let Some(frame) = raw_rx.recv().await {
778        match codec::raw_frame_to_message(frame) {
779            Ok(msg) => {
780                if tx.send(msg).await.is_err() {
781                    break;
782                }
783            }
784            Err(e) => {
785                tracing::warn!("agent client: failed to decode frame in stream: {e}");
786                // Continue — single malformed frame shouldn't kill the stream.
787            }
788        }
789    }
790}
791
792/// Encode a typed payload to a CBOR `Message` body.
793fn encode_message_body<T: Serialize>(
794    version: u8,
795    t: MessageType,
796    payload: &T,
797) -> AgentClientResult<Vec<u8>> {
798    let mut msg = Message::with_payload(t, 0, payload)?;
799    msg.v = version;
800    let mut body = Vec::new();
801    ciborium::into_writer(&msg, &mut body).map_err(microsandbox_protocol::ProtocolError::from)?;
802    Ok(body)
803}
804
805//--------------------------------------------------------------------------------------------------
806// Tests
807//--------------------------------------------------------------------------------------------------
808
809#[cfg(test)]
810mod tests {
811    #[cfg(feature = "uds")]
812    use microsandbox_protocol::core::Ready;
813    #[cfg(feature = "uds")]
814    use microsandbox_protocol::exec::ExecRequest;
815    #[cfg(feature = "uds")]
816    use microsandbox_protocol::message::PROTOCOL_VERSION;
817    #[cfg(feature = "uds")]
818    use tokio::io::AsyncWriteExt;
819    #[cfg(feature = "uds")]
820    use tokio::net::UnixListener;
821    #[cfg(feature = "uds")]
822    use tokio::sync::oneshot;
823
824    use super::*;
825
826    #[cfg(feature = "uds")]
827    #[tokio::test]
828    async fn connect_decodes_ready_payload() {
829        let temp = tempfile::tempdir().unwrap();
830        let sock_path = temp.path().join("agent.sock");
831        let listener = UnixListener::bind(&sock_path).unwrap();
832        let ready = Ready {
833            boot_time_ns: 11,
834            init_time_ns: 22,
835            ready_time_ns: 33,
836            agent_version: "9.9.9".to_string(),
837        };
838        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
839
840        tokio::spawn(async move {
841            let (mut socket, _) = listener.accept().await.unwrap();
842            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
843            socket.write_all(&8u32.to_be_bytes()).await.unwrap();
844            codec::write_message(&mut socket, &ready_msg).await.unwrap();
845        });
846
847        let client =
848            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
849                .await
850                .unwrap();
851
852        assert_eq!(client.protocol(), AgentProtocol::Current);
853        // Both peers speak the current generation, so that is what is negotiated.
854        assert_eq!(client.negotiated_version(), PROTOCOL_VERSION);
855        assert!(client.supports(MessageType::FsRequest));
856        // The runtime's self-reported version round-trips from the ready frame.
857        assert_eq!(client.agent_version(), "9.9.9");
858        let decoded = client.ready().unwrap();
859        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
860        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
861        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
862
863        let raw_msg: Message = ciborium::from_reader(client.ready_bytes()).unwrap();
864        assert_eq!(raw_msg.t, MessageType::Ready);
865    }
866
867    #[cfg(feature = "uds")]
868    #[tokio::test]
869    async fn connect_negotiates_down_to_older_guest_generation() {
870        let temp = tempfile::tempdir().unwrap();
871        let sock_path = temp.path().join("agent.sock");
872        let listener = UnixListener::bind(&sock_path).unwrap();
873        let ready = Ready {
874            boot_time_ns: 1,
875            init_time_ns: 2,
876            ready_time_ns: 3,
877            ..Default::default()
878        };
879        // A current-codec guest that advertises an older capability generation in
880        // its ready frame (a runtime one generation behind this host).
881        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
882        ready_msg.v = 1;
883
884        tokio::spawn(async move {
885            let (mut socket, _) = listener.accept().await.unwrap();
886            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
887            socket
888                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
889                .await
890                .unwrap();
891            codec::write_message(&mut socket, &ready_msg).await.unwrap();
892        });
893
894        let client =
895            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
896                .await
897                .unwrap();
898
899        // Current codec, but the capability gate is pinned to the guest's older
900        // generation: min(host PROTOCOL_VERSION, guest's advertised 1) == 1.
901        assert_eq!(client.protocol(), AgentProtocol::Current);
902        assert_eq!(client.negotiated_version(), 1);
903        // Exec is in the baseline; filesystem is not, at generation 1.
904        assert!(client.supports(MessageType::ExecRequest));
905        assert!(!client.supports(MessageType::FsRequest));
906    }
907
908    #[cfg(feature = "uds")]
909    #[tokio::test]
910    async fn connect_accepts_legacy_relay_handshake() {
911        assert_accepts_legacy_relay_handshake(0).await;
912        assert_accepts_legacy_relay_handshake(268_435_455).await;
913    }
914
915    #[cfg(feature = "uds")]
916    #[tokio::test]
917    async fn legacy_relay_requests_use_v1_and_legacy_id_range() {
918        let temp = tempfile::tempdir().unwrap();
919        let sock_path = temp.path().join("agent.sock");
920        let listener = UnixListener::bind(&sock_path).unwrap();
921        let ready = Ready {
922            boot_time_ns: 11,
923            init_time_ns: 22,
924            ready_time_ns: 33,
925            ..Default::default()
926        };
927        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
928        let id_offset = 268_435_455u32;
929        let (frame_tx, frame_rx) = oneshot::channel();
930
931        tokio::spawn(async move {
932            let (mut socket, _) = listener.accept().await.unwrap();
933            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
934            codec::write_message(&mut socket, &ready_msg).await.unwrap();
935            let frame = codec::read_raw_frame(&mut socket).await.unwrap();
936            frame_tx.send(frame).unwrap();
937        });
938
939        let client =
940            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
941                .await
942                .unwrap();
943        let request = ExecRequest {
944            cmd: "/bin/true".into(),
945            args: Vec::new(),
946            env: Vec::new(),
947            cwd: None,
948            user: None,
949            tty: false,
950            rows: 24,
951            cols: 80,
952            rlimits: Vec::new(),
953        };
954        let (id, _rx) = client
955            .stream(MessageType::ExecRequest, &request)
956            .await
957            .unwrap();
958
959        let frame = frame_rx.await.unwrap();
960        let message = codec::raw_frame_to_message(frame).unwrap();
961
962        assert_eq!(id, id_offset + 1);
963        assert_eq!(message.id, id_offset + 1);
964        assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
965        assert_eq!(message.t, MessageType::ExecRequest);
966    }
967
968    #[test]
969    fn version_compat_across_generations() {
970        use MessageType::{ExecRequest, FsRequest};
971        // (message type, peer generation, expected allowed). Generation 1 is the
972        // pre-0.5 legacy runtime (no filesystem); generation 2 introduced the
973        // Fs* types; generation 5 is current.
974        let cases = [
975            (ExecRequest, 1, true),
976            (ExecRequest, 2, true),
977            (ExecRequest, 3, true),
978            (FsRequest, 1, false),
979            (FsRequest, 2, true),
980            (FsRequest, 3, true),
981        ];
982        for (t, generation, allowed) in cases {
983            assert_eq!(
984                AgentClient::ensure_version_compat_for(t, generation).is_ok(),
985                allowed,
986                "{t:?} at generation {generation}"
987            );
988        }
989    }
990
991    #[test]
992    fn version_compat_rejection_is_typed() {
993        // Filesystem on the legacy (generation 1) runtime is rejected before any
994        // send, with the structured error whose message tells the user to restart.
995        let err =
996            AgentClient::ensure_version_compat_for(MessageType::FsRequest, LEGACY_PROTOCOL_VERSION)
997                .unwrap_err();
998        assert!(matches!(
999            err,
1000            AgentClientError::UnsupportedOperation {
1001                needs: 2,
1002                peer: 1,
1003                ..
1004            }
1005        ));
1006    }
1007
1008    #[cfg(feature = "uds")]
1009    #[tokio::test]
1010    async fn connect_preserves_current_peer_protocol_version() {
1011        let temp = tempfile::tempdir().unwrap();
1012        let sock_path = temp.path().join("agent.sock");
1013        let listener = UnixListener::bind(&sock_path).unwrap();
1014        let ready = Ready {
1015            boot_time_ns: 11,
1016            init_time_ns: 22,
1017            ready_time_ns: 33,
1018            ..Default::default()
1019        };
1020        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1021        ready_msg.v = 2;
1022
1023        tokio::spawn(async move {
1024            let (mut socket, _) = listener.accept().await.unwrap();
1025            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1026            socket
1027                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
1028                .await
1029                .unwrap();
1030            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1031        });
1032
1033        let client =
1034            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1035                .await
1036                .unwrap();
1037
1038        assert_eq!(client.protocol(), AgentProtocol::Current);
1039        // The runtime reported generation 2, so that is the negotiated capability.
1040        assert_eq!(client.negotiated_version(), 2);
1041        // TCP forwarding (generation 4) is unavailable to a generation-2 runtime.
1042        assert!(!client.supports(MessageType::TcpConnect));
1043    }
1044
1045    #[cfg(feature = "uds")]
1046    async fn assert_accepts_legacy_relay_handshake(id_offset: u32) {
1047        let temp = tempfile::tempdir().unwrap();
1048        let sock_path = temp.path().join("agent.sock");
1049        let listener = UnixListener::bind(&sock_path).unwrap();
1050        let ready = Ready {
1051            boot_time_ns: 11,
1052            init_time_ns: 22,
1053            ready_time_ns: 33,
1054            ..Default::default()
1055        };
1056        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1057
1058        tokio::spawn(async move {
1059            let (mut socket, _) = listener.accept().await.unwrap();
1060            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
1061            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1062        });
1063
1064        let client =
1065            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1066                .await
1067                .unwrap();
1068
1069        assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
1070        assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
1071        let decoded = client.ready().unwrap();
1072        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
1073        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
1074        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
1075    }
1076
1077    #[cfg(feature = "stream")]
1078    #[tokio::test]
1079    async fn connect_stream_handshakes_and_streams_exec() {
1080        use microsandbox_protocol::exec::{ExecExited, ExecRequest, ExecStdout};
1081        use tokio::io::AsyncWriteExt;
1082
1083        let (client_io, mut server_io) = tokio::io::duplex(64 * 1024);
1084        let ready = Ready {
1085            boot_time_ns: 11,
1086            init_time_ns: 22,
1087            ready_time_ns: 33,
1088            agent_version: "stream-test".to_string(),
1089        };
1090        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1091
1092        tokio::spawn(async move {
1093            // Relay handshake: [id_min][id_max] then the core.ready frame.
1094            server_io.write_all(&1u32.to_be_bytes()).await.unwrap();
1095            server_io.write_all(&1024u32.to_be_bytes()).await.unwrap();
1096            codec::write_message(&mut server_io, &ready_msg)
1097                .await
1098                .unwrap();
1099
1100            // One exec stream echoed back: stdout, then a terminal exited.
1101            let request = codec::read_raw_frame(&mut server_io).await.unwrap();
1102            let stdout = Message::with_payload(
1103                MessageType::ExecStdout,
1104                request.id,
1105                &ExecStdout {
1106                    data: b"hi".to_vec(),
1107                },
1108            )
1109            .unwrap();
1110            codec::write_message(&mut server_io, &stdout).await.unwrap();
1111            let exited =
1112                Message::with_payload(MessageType::ExecExited, request.id, &ExecExited { code: 0 })
1113                    .unwrap();
1114            codec::write_message(&mut server_io, &exited).await.unwrap();
1115        });
1116
1117        let client = AgentClient::connect_stream_with_deadline(
1118            client_io,
1119            Instant::now() + Duration::from_secs(1),
1120        )
1121        .await
1122        .unwrap();
1123
1124        assert_eq!(client.protocol(), AgentProtocol::Current);
1125        assert_eq!(client.agent_version(), "stream-test");
1126        assert!(client.supports(MessageType::ExecRequest));
1127
1128        let request = ExecRequest {
1129            cmd: "echo".into(),
1130            args: vec!["hi".into()],
1131            env: Vec::new(),
1132            cwd: None,
1133            user: None,
1134            tty: false,
1135            rows: 24,
1136            cols: 80,
1137            rlimits: Vec::new(),
1138        };
1139        let (_id, mut rx) = client
1140            .stream(MessageType::ExecRequest, &request)
1141            .await
1142            .unwrap();
1143
1144        let first = rx.recv().await.unwrap();
1145        assert_eq!(first.t, MessageType::ExecStdout);
1146        let out: ExecStdout = first.payload().unwrap();
1147        assert_eq!(out.data, b"hi");
1148
1149        let second = rx.recv().await.unwrap();
1150        assert_eq!(second.t, MessageType::ExecExited);
1151        let exit: ExecExited = second.payload().unwrap();
1152        assert_eq!(exit.code, 0);
1153    }
1154}
1155
1156//--------------------------------------------------------------------------------------------------
1157// Trait Implementations
1158//--------------------------------------------------------------------------------------------------
1159
1160impl Drop for AgentClient {
1161    fn drop(&mut self) {
1162        self.reader_handle.abort();
1163        self.writer_handle.abort();
1164    }
1165}