Skip to main content

microsandbox_agent_client/
client.rs

1//! Client for connecting to a microsandbox agent relay.
2//!
3//! [`AgentClient`] communicates with `agentd` through an agent relay transport.
4//! During connection, the relay assigns a non-overlapping correlation ID range
5//! and sends the cached `core.ready` payload so the client can begin issuing
6//! commands immediately. Unix domain sockets are available with the `uds`
7//! feature, and WebSocket connections are available with the `websocket`
8//! feature.
9//!
10//! Two API tiers share one socket and one reader task:
11//!
12//! - **Raw** ([`request_raw`](AgentClient::request_raw),
13//!   [`stream_raw`](AgentClient::stream_raw),
14//!   [`send_raw`](AgentClient::send_raw)) — exchange [`RawFrame`]s. The client
15//!   handles framing and correlation IDs; CBOR encoding/decoding is left to the
16//!   caller. Use this when wrapping the client for other languages.
17//! - **Typed** ([`request`](AgentClient::request),
18//!   [`stream`](AgentClient::stream), [`send`](AgentClient::send)) — same
19//!   primitives over [`Message`]; the SDK serializes payloads with CBOR.
20
21use std::collections::HashMap;
22#[cfg(any(feature = "uds", feature = "websocket"))]
23use std::future::Future;
24#[cfg(feature = "uds")]
25use std::path::Path;
26#[cfg(any(feature = "uds", feature = "websocket"))]
27use std::pin::Pin;
28use std::sync::{Arc, atomic::AtomicU32};
29#[cfg(any(feature = "uds", feature = "websocket"))]
30use std::time::Duration;
31
32#[cfg(feature = "websocket")]
33use futures_util::{SinkExt, StreamExt};
34#[cfg(any(feature = "uds", feature = "websocket"))]
35use microsandbox_protocol::message::FLAG_TERMINAL;
36#[cfg(any(feature = "uds", feature = "websocket"))]
37use microsandbox_protocol::{codec::MAX_FRAME_SIZE, message::FRAME_HEADER_SIZE};
38use microsandbox_protocol::{
39    codec::{self, RawFrame},
40    core::Ready,
41    message::{Message, MessageType, PROTOCOL_VERSION},
42};
43use serde::Serialize;
44#[cfg(feature = "uds")]
45use tokio::net::UnixStream;
46use tokio::sync::{Mutex, mpsc, oneshot};
47use tokio::task::JoinHandle;
48#[cfg(any(feature = "uds", feature = "websocket"))]
49use tokio::time::Instant;
50#[cfg(feature = "websocket")]
51use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
52
53use super::error::{AgentClientError, AgentClientResult};
54
55//--------------------------------------------------------------------------------------------------
56// Constants
57//--------------------------------------------------------------------------------------------------
58
/// Default handshake timeout used by [`AgentClient::connect`] and
/// [`AgentClient::connect_websocket`].
#[cfg(any(feature = "uds", feature = "websocket"))]
const DEFAULT_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);

/// Capacity of the channel that carries [`WriterCommand`]s to the writer task.
#[cfg(any(feature = "uds", feature = "websocket"))]
const WRITER_QUEUE_CAPACITY: usize = 1024;
/// Capacity of the per-request response channel created by
/// [`AgentClient::request_raw`], which awaits a single response frame.
const REQUEST_QUEUE_CAPACITY: usize = 1;
/// Capacity of the per-stream frame channels created by
/// [`AgentClient::stream_raw`] and [`AgentClient::stream`].
const STREAM_QUEUE_CAPACITY: usize = 1024;
/// Upper bound on the bytes buffered by the WebSocket byte reader.
// NOTE(review): the extra 12 bytes presumably cover the 8-byte id-range prefix
// plus a 4-byte frame length prefix — confirm against the relay wire format.
#[cfg(feature = "websocket")]
const MAX_WEBSOCKET_BUFFER_SIZE: usize = MAX_FRAME_SIZE as usize + 12;

/// Protocol version reported for [`AgentProtocol::LegacyV1`] connections.
const LEGACY_PROTOCOL_VERSION: u8 = 1;
// TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
// compatibility for versions before 0.5 is no longer supported.
/// Distance added to a legacy relay's id offset to derive the exclusive upper
/// bound of the client's correlation ID range.
#[cfg(any(feature = "uds", feature = "websocket"))]
const LEGACY_RELAY_ID_RANGE_STEP: u32 = u32::MAX / 16;
75
76//--------------------------------------------------------------------------------------------------
77// Types
78//--------------------------------------------------------------------------------------------------
79
/// Agent protocol generation spoken by a connected sandbox relay.
///
/// The generation is detected during the relay handshake from the shape of the
/// first eight bytes the relay sends.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentProtocol {
    /// Current protocol generation. The relay handshake carries an explicit
    /// `[id_min, id_max)` correlation ID range.
    Current,

    /// Pre-0.5 microsandbox relay handshake and agent protocol. The relay
    /// handshake carries a single id offset instead of an explicit range.
    ///
    /// TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
    /// compatibility for versions before 0.5 is no longer supported.
    LegacyV1,
}
92
/// Client for communicating with agentd through the agent relay.
///
/// See the module-level docs for an overview of the two API tiers.
pub struct AgentClient {
    /// Channel to the transport writer task.
    writer: mpsc::Sender<WriterCommand>,
    /// Next correlation ID to allocate. Starts at `id_min`, or at 1 when the
    /// relay assigned a range beginning at 0 (ID 0 is never handed out).
    next_id: AtomicU32,
    /// Lower bound (inclusive) of the assigned ID range, used for wrap-around.
    id_min: u32,
    /// Upper bound (exclusive) of the assigned ID range.
    id_max: u32,
    /// Agent protocol generation for this connection.
    protocol: AgentProtocol,
    /// Negotiated protocol generation: the lower of the version implied by
    /// the detected [`AgentProtocol`] and the generation the sandbox echoed
    /// in its `core.ready` frame. Drives the capability gate on the typed
    /// send path. Distinct from [`Self::protocol`], which selects the wire
    /// codec; see `VERSIONING.md`.
    negotiated_version: u8,
    /// Pending response channels keyed by correlation ID. Shared with the
    /// background reader task, which dispatches incoming frames into them.
    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
    /// Background reader task handle.
    reader_handle: JoinHandle<()>,
    /// Background writer task handle.
    writer_handle: JoinHandle<()>,
    /// Cached `core.ready` frame body (raw CBOR bytes) from the relay handshake.
    ready_body: Vec<u8>,
    /// Decoded `core.ready` payload from the relay handshake.
    ready: Ready,
}
123
/// Everything learned from the relay handshake, independent of transport.
#[cfg(any(feature = "uds", feature = "websocket"))]
struct AgentHandshake {
    /// Lower bound (inclusive) of the correlation ID range for this client.
    id_min: u32,
    /// Upper bound (exclusive) of the correlation ID range for this client.
    id_max: u32,
    /// Protocol generation detected from the handshake layout.
    protocol: AgentProtocol,
    /// Lower of the detected protocol's version and the version carried by
    /// the `core.ready` frame.
    negotiated_version: u8,
    /// Body bytes of the `core.ready` frame as received.
    ready_body: Vec<u8>,
    /// Decoded `core.ready` payload.
    ready: Ready,
}
133
/// A single frame queued for the transport writer task.
// Without a transport feature no writer task consumes the fields, hence the
// conditional `dead_code` allowance.
#[cfg_attr(not(any(feature = "uds", feature = "websocket")), allow(dead_code))]
struct WriterCommand {
    /// Frame to write to the transport.
    frame: RawFrame,
    /// Receives the outcome of the write so the sender can await it.
    ack: oneshot::Sender<AgentClientResult<()>>,
}
139
/// Transport-agnostic byte source used while performing the relay handshake.
///
/// Both methods return boxed, `Send` futures that borrow the reader for `'a`.
#[cfg(any(feature = "uds", feature = "websocket"))]
trait HandshakeReader {
    /// Read exactly `out.len()` bytes into `out`, or fail.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>>;

    /// Read one complete raw frame from the transport, or fail.
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>>;
}
151
152//--------------------------------------------------------------------------------------------------
153// Methods: Connection lifecycle
154//--------------------------------------------------------------------------------------------------
155
156impl AgentProtocol {
157    fn version(self) -> u8 {
158        match self {
159            Self::Current => PROTOCOL_VERSION,
160            Self::LegacyV1 => LEGACY_PROTOCOL_VERSION,
161        }
162    }
163}
164
165impl AgentClient {
166    /// Connect to a Unix domain socket agent relay using the default 10s
167    /// handshake timeout.
168    #[cfg(feature = "uds")]
169    pub async fn connect(sock_path: impl AsRef<Path>) -> AgentClientResult<Self> {
170        Self::connect_with_timeout(sock_path, DEFAULT_HANDSHAKE_TIMEOUT).await
171    }
172
173    /// Connect to a Unix domain socket agent relay using an explicit
174    /// handshake timeout.
175    #[cfg(feature = "uds")]
176    pub async fn connect_with_timeout(
177        sock_path: impl AsRef<Path>,
178        timeout: Duration,
179    ) -> AgentClientResult<Self> {
180        let deadline = Instant::now() + timeout;
181        Self::connect_with_deadline(sock_path, deadline).await
182    }
183
184    /// Connect with an explicit handshake deadline.
185    ///
186    /// `deadline` bounds both handshake reads. Without it, an accepted
187    /// connection that stalls (e.g. a sandbox alive but wedged before
188    /// writing the handshake bytes) would block this call indefinitely.
189    #[cfg(feature = "uds")]
190    pub async fn connect_with_deadline(
191        sock_path: impl AsRef<Path>,
192        deadline: Instant,
193    ) -> AgentClientResult<Self> {
194        let sock_path = sock_path.as_ref();
195        let stream =
196            UnixStream::connect(sock_path)
197                .await
198                .map_err(|source| AgentClientError::Connect {
199                    path: sock_path.to_path_buf(),
200                    source,
201                })?;
202
203        let (mut reader, writer) = stream.into_split();
204        let handshake = perform_handshake(&mut reader, deadline).await?;
205
206        tracing::info!(
207            id_min = handshake.id_min,
208            id_max = handshake.id_max,
209            protocol = ?handshake.protocol,
210            ready_bytes = handshake.ready_body.len(),
211            boot_time_ns = handshake.ready.boot_time_ns,
212            "agent client: connected to relay"
213        );
214        if handshake.protocol == AgentProtocol::LegacyV1 {
215            // TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
216            // compatibility for versions before 0.5 is no longer supported.
217            tracing::warn!(
218                "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
219            );
220        }
221
222        let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
223            Arc::new(Mutex::new(HashMap::new()));
224
225        let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
226        let reader_handle = tokio::spawn(reader_loop(reader, Arc::clone(&pending)));
227        let writer_handle = tokio::spawn(uds_writer_loop(writer, writer_rx));
228
229        Ok(Self {
230            writer: writer_tx,
231            next_id: AtomicU32::new(first_request_id(handshake.id_min)),
232            id_min: handshake.id_min,
233            id_max: handshake.id_max,
234            protocol: handshake.protocol,
235            negotiated_version: handshake.negotiated_version,
236            pending,
237            reader_handle,
238            writer_handle,
239            ready_body: handshake.ready_body,
240            ready: handshake.ready,
241        })
242    }
243
244    /// Connect to an agent relay over WebSocket using the default 10s
245    /// handshake timeout.
246    #[cfg(feature = "websocket")]
247    pub async fn connect_websocket(url: &str) -> AgentClientResult<Self> {
248        Self::connect_websocket_with_timeout(url, DEFAULT_HANDSHAKE_TIMEOUT).await
249    }
250
251    /// Connect to an agent relay over WebSocket using an explicit handshake
252    /// timeout.
253    #[cfg(feature = "websocket")]
254    pub async fn connect_websocket_with_timeout(
255        url: &str,
256        timeout: Duration,
257    ) -> AgentClientResult<Self> {
258        let deadline = Instant::now() + timeout;
259        Self::connect_websocket_with_deadline(url, deadline).await
260    }
261
262    /// Connect to an agent relay over WebSocket with an explicit handshake
263    /// deadline.
264    #[cfg(feature = "websocket")]
265    pub async fn connect_websocket_with_deadline(
266        url: &str,
267        deadline: Instant,
268    ) -> AgentClientResult<Self> {
269        let (stream, _) = tokio_tungstenite::connect_async(url)
270            .await
271            .map_err(|e| AgentClientError::WebSocket(format!("connect {url}: {e}")))?;
272        let (writer, reader) = stream.split();
273        let mut reader = WebSocketByteReader::new(reader);
274        let handshake = perform_handshake(&mut reader, deadline).await?;
275        if handshake.protocol == AgentProtocol::LegacyV1 {
276            // TODO(upgrade-0.6): Remove in 0.6.x or later once live-sandbox
277            // compatibility for versions before 0.5 is no longer supported.
278            tracing::warn!(
279                "agent client: connected to a sandbox started before microsandbox 0.5; exec compatibility is temporary and filesystem/SFTP require stop/start"
280            );
281        }
282
283        let pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>> =
284            Arc::new(Mutex::new(HashMap::new()));
285        let (writer_tx, writer_rx) = mpsc::channel(WRITER_QUEUE_CAPACITY);
286        let reader_handle = tokio::spawn(websocket_reader_loop(reader, Arc::clone(&pending)));
287        let writer_handle = tokio::spawn(websocket_writer_loop(writer, writer_rx));
288
289        Ok(Self {
290            writer: writer_tx,
291            next_id: AtomicU32::new(first_request_id(handshake.id_min)),
292            id_min: handshake.id_min,
293            id_max: handshake.id_max,
294            protocol: handshake.protocol,
295            negotiated_version: handshake.negotiated_version,
296            pending,
297            reader_handle,
298            writer_handle,
299            ready_body: handshake.ready_body,
300            ready: handshake.ready,
301        })
302    }
303
304    /// Close the connection. Drops the writer and aborts the reader task;
305    /// any in-flight requests resolve with [`AgentClientError::Closed`].
306    pub async fn close(self) {
307        // Drop runs: reader aborts via Drop impl, writer closes when the
308        // last Arc reference dies. Senders in `pending` drop with self,
309        // resolving outstanding waiters.
310    }
311}
312
313//--------------------------------------------------------------------------------------------------
314// Methods: Raw transport (CBOR-blind)
315//--------------------------------------------------------------------------------------------------
316
317impl AgentClient {
318    /// One-shot raw request: alloc id, send a frame with `(flags, body)`,
319    /// await one response frame with the matching id.
320    ///
321    /// Use this for protocol RPCs that produce exactly one terminal response
322    /// (e.g. `FsRequest` → `FsResponse`).
323    pub async fn request_raw(&self, flags: u8, body: Vec<u8>) -> AgentClientResult<RawFrame> {
324        let (tx, mut rx) = mpsc::channel(REQUEST_QUEUE_CAPACITY);
325        let id = self.reserve_id(tx).await?;
326
327        if let Err(e) = self.write_frame_owned(id, flags, body).await {
328            self.pending.lock().await.remove(&id);
329            return Err(e);
330        }
331
332        let frame = rx.recv().await.ok_or(AgentClientError::ReaderClosed(id))?;
333        self.pending.lock().await.remove(&id);
334        Ok(frame)
335    }
336
337    /// Open a streaming raw session: alloc id, register a subscription,
338    /// send the opening frame, return `(id, receiver)`.
339    ///
340    /// The receiver yields every frame the relay forwards for this `id`
341    /// until a frame with [`FLAG_TERMINAL`] arrives or the receiver is dropped.
342    /// Use [`send_raw`](Self::send_raw) with the returned id to send
343    /// follow-up frames within the session.
344    pub async fn stream_raw(
345        &self,
346        flags: u8,
347        body: Vec<u8>,
348    ) -> AgentClientResult<(u32, mpsc::Receiver<RawFrame>)> {
349        let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
350        let id = self.reserve_id(tx).await?;
351
352        if let Err(e) = self.write_frame_owned(id, flags, body).await {
353            self.pending.lock().await.remove(&id);
354            return Err(e);
355        }
356
357        Ok((id, rx))
358    }
359
360    /// Send a follow-up raw frame on an existing correlation id.
361    ///
362    /// Use for messages that belong to a session started via
363    /// [`stream_raw`](Self::stream_raw) (e.g. `ExecStdin`, `ExecSignal`,
364    /// `ExecResize`, `FsData` chunks).
365    pub async fn send_raw(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
366        self.write_frame(id, flags, body).await
367    }
368
369    /// The cached `core.ready` handshake frame body bytes (CBOR-encoded).
370    ///
371    /// Useful for bindings that want to deserialize the ready payload with
372    /// their own CBOR tooling. For typed access, use [`ready`](Self::ready).
373    pub fn ready_bytes(&self) -> &[u8] {
374        &self.ready_body
375    }
376
377    /// Agent protocol generation for this connection.
378    pub fn protocol(&self) -> AgentProtocol {
379        self.protocol
380    }
381
382    /// Returns `true` if this connection is using the legacy pre-0.5 protocol.
383    pub fn is_legacy_protocol(&self) -> bool {
384        self.protocol == AgentProtocol::LegacyV1
385    }
386
387    /// The negotiated protocol generation for this connection: the lower of what
388    /// this client speaks and what the sandbox advertised at handshake.
389    pub fn negotiated_version(&self) -> u8 {
390        self.negotiated_version
391    }
392
393    /// The runtime's self-reported package version, taken from its `core.ready`
394    /// frame. Empty when the runtime predates this field (an older agent), in
395    /// which case fall back to the generation for diagnostics.
396    pub fn agent_version(&self) -> &str {
397        &self.ready.agent_version
398    }
399
400    /// Whether the connected sandbox is new enough to handle the given message
401    /// type. The single source of truth for feature gating: callers that can't
402    /// gate by sending (e.g. the SSH/SFTP layer) consult this instead of
403    /// inspecting the protocol generation directly.
404    pub fn supports(&self, t: MessageType) -> bool {
405        t.min_protocol_version() <= self.negotiated_version
406    }
407
408    /// Reject a message type the connected sandbox is too old to handle, against
409    /// this connection's negotiated generation. Fails before any bytes are sent,
410    /// so only that one operation fails and the session continues.
411    pub fn ensure_version_compat(&self, t: MessageType) -> AgentClientResult<()> {
412        Self::ensure_version_compat_for(t, self.negotiated_version)
413    }
414
415    /// Check a message type against an explicit negotiated generation.
416    ///
417    /// The single place the rule lives. Exposed for callers that hold the
418    /// negotiated generation but not the live client (e.g. the SSH/SFTP layer).
419    pub fn ensure_version_compat_for(t: MessageType, negotiated: u8) -> AgentClientResult<()> {
420        if t.is_available_at(negotiated) {
421            return Ok(());
422        }
423        Err(AgentClientError::UnsupportedOperation {
424            msg_type: t.as_str(),
425            needs: t.min_protocol_version(),
426            peer: negotiated,
427        })
428    }
429}
430
431//--------------------------------------------------------------------------------------------------
432// Methods: Typed transport (CBOR-aware)
433//--------------------------------------------------------------------------------------------------
434
435impl AgentClient {
436    /// One-shot typed request. Flags are derived from the message type.
437    pub async fn request<T: Serialize>(
438        &self,
439        t: MessageType,
440        payload: &T,
441    ) -> AgentClientResult<Message> {
442        self.ensure_version_compat(t)?;
443        let flags = t.flags();
444        let body = encode_message_body(self.protocol.version(), t, payload)?;
445        let frame = self.request_raw(flags, body).await?;
446        Ok(codec::raw_frame_to_message(frame)?)
447    }
448
449    /// Open a streaming typed session. Flags are derived from the message type.
450    /// Returns the assigned id and a typed receiver.
451    pub async fn stream<T: Serialize>(
452        &self,
453        t: MessageType,
454        payload: &T,
455    ) -> AgentClientResult<(u32, mpsc::Receiver<Message>)> {
456        self.ensure_version_compat(t)?;
457        let flags = t.flags();
458        let body = encode_message_body(self.protocol.version(), t, payload)?;
459        let (id, raw_rx) = self.stream_raw(flags, body).await?;
460
461        let (tx, rx) = mpsc::channel(STREAM_QUEUE_CAPACITY);
462        tokio::spawn(decode_stream_task(raw_rx, tx));
463        Ok((id, rx))
464    }
465
466    /// Send a follow-up typed message on an existing correlation id.
467    pub async fn send<T: Serialize>(
468        &self,
469        id: u32,
470        t: MessageType,
471        payload: &T,
472    ) -> AgentClientResult<()> {
473        self.ensure_version_compat(t)?;
474        let flags = t.flags();
475        let body = encode_message_body(self.protocol.version(), t, payload)?;
476        self.write_frame_owned(id, flags, body).await
477    }
478
479    /// Decode the cached handshake `core.ready` payload.
480    pub fn ready(&self) -> AgentClientResult<Ready> {
481        Ok(self.ready.clone())
482    }
483}
484
485//--------------------------------------------------------------------------------------------------
486// Methods: Internals
487//--------------------------------------------------------------------------------------------------
488
489impl AgentClient {
490    /// Reserve a unique correlation ID from the relay-assigned range.
491    ///
492    /// Wraps around within the assigned range and skips IDs that still have an
493    /// active pending request or stream.
494    async fn reserve_id(&self, tx: mpsc::Sender<RawFrame>) -> AgentClientResult<u32> {
495        let mut pending = self.pending.lock().await;
496        let attempts = usable_id_count(self.id_min, self.id_max);
497        for _ in 0..attempts {
498            let id = self
499                .next_id
500                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
501            if self.next_id.load(std::sync::atomic::Ordering::Relaxed) >= self.id_max {
502                self.next_id.store(
503                    first_request_id(self.id_min),
504                    std::sync::atomic::Ordering::Relaxed,
505                );
506            }
507            if id == 0 || id < self.id_min || id >= self.id_max || pending.contains_key(&id) {
508                continue;
509            }
510            pending.insert(id, tx);
511            return Ok(id);
512        }
513
514        Err(AgentClientError::IdRangeExhausted)
515    }
516
517    /// Write a single framed message to the socket.
518    async fn write_frame(&self, id: u32, flags: u8, body: &[u8]) -> AgentClientResult<()> {
519        self.write_frame_owned(id, flags, body.to_vec()).await
520    }
521
522    /// Write a single framed message to the socket, taking ownership of the body.
523    async fn write_frame_owned(&self, id: u32, flags: u8, body: Vec<u8>) -> AgentClientResult<()> {
524        let (ack, written) = oneshot::channel();
525        self.writer
526            .send(WriterCommand {
527                frame: RawFrame { id, flags, body },
528                ack,
529            })
530            .await
531            .map_err(|_| AgentClientError::Closed)?;
532        written.await.map_err(|_| AgentClientError::Closed)?
533    }
534}
535
536//--------------------------------------------------------------------------------------------------
537// Functions
538//--------------------------------------------------------------------------------------------------
539
/// Run the relay handshake on `reader`, with every read bounded by `deadline`.
///
/// Two wire layouts are accepted:
///
/// - Current: `[id_min: u32 BE][id_max: u32 BE][ready_frame_bytes...]`
/// - Legacy pre-0.5: `[id_offset: u32 BE][ready_frame_bytes...]`
///
/// The first eight bytes are read up-front to tell the two apart; for a
/// legacy relay the second word is already the ready frame's length prefix.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] on timeouts, on an invalid or
/// unusable id range, and when the ready frame cannot be read or decoded or
/// is not `core.ready`. An error from the initial 8-byte read is passed
/// through as produced by the reader.
#[cfg(any(feature = "uds", feature = "websocket"))]
async fn perform_handshake<R>(
    reader: &mut R,
    deadline: Instant,
) -> AgentClientResult<AgentHandshake>
where
    R: HandshakeReader + ?Sized,
{
    let mut range_buf = [0u8; 8];
    match tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut range_buf)).await {
        Ok(read) => read?,
        Err(_) => {
            return Err(AgentClientError::Handshake(
                "read id range: timed out before relay sent bytes".into(),
            ));
        }
    }
    let [a0, a1, a2, a3, b0, b1, b2, b3] = range_buf;
    let second_word_bytes = [b0, b1, b2, b3];
    let id_start_or_offset = u32::from_be_bytes([a0, a1, a2, a3]);
    let id_max_or_frame_len = u32::from_be_bytes(second_word_bytes);

    let (id_min, id_max, ready_frame, protocol) =
        if looks_like_legacy_relay_handshake(id_start_or_offset, id_max_or_frame_len) {
            // Legacy: the second word was the length prefix of the ready frame.
            let ready_frame =
                read_raw_frame_after_len_prefix(reader, second_word_bytes, deadline).await?;
            (
                id_start_or_offset.saturating_add(1),
                id_start_or_offset.saturating_add(LEGACY_RELAY_ID_RANGE_STEP),
                ready_frame,
                AgentProtocol::LegacyV1,
            )
        } else {
            if id_start_or_offset >= id_max_or_frame_len {
                return Err(AgentClientError::Handshake(format!(
                    "invalid relay id range: start={id_start_or_offset}, end={id_max_or_frame_len}"
                )));
            }
            let ready_frame =
                match tokio::time::timeout_at(deadline, reader.read_frame_handshake()).await {
                    Ok(Ok(frame)) => frame,
                    Ok(Err(e)) => {
                        return Err(AgentClientError::Handshake(format!(
                            "read ready frame: {e}"
                        )));
                    }
                    Err(_) => {
                        return Err(AgentClientError::Handshake(
                            "read ready frame: timed out before relay sent frame".into(),
                        ));
                    }
                };
            (
                id_start_or_offset,
                id_max_or_frame_len,
                ready_frame,
                AgentProtocol::Current,
            )
        };
    ensure_usable_id_range(id_min, id_max)?;

    // Keep the raw body for callers that decode `core.ready` themselves; the
    // frame itself is consumed by the decoder.
    let ready_body = ready_frame.body.clone();
    let ready_msg = codec::raw_frame_to_message(ready_frame)
        .map_err(|e| AgentClientError::Handshake(format!("decode ready frame: {e}")))?;
    if ready_msg.t != MessageType::Ready {
        return Err(AgentClientError::Handshake(format!(
            "expected core.ready frame, got {}",
            ready_msg.t.as_str()
        )));
    }
    let ready: Ready = ready_msg
        .payload()
        .map_err(|e| AgentClientError::Handshake(format!("decode ready payload: {e}")))?;

    // The negotiated capability generation is the lower of what we speak and
    // what the sandbox echoed in its ready frame (`ready_msg.v`). When a newer
    // host meets an older runtime this is the runtime's generation, so the
    // send gate withholds features it can't handle. The codec generation
    // (`protocol`) is negotiated separately.
    let negotiated_version = std::cmp::min(protocol.version(), ready_msg.v);

    Ok(AgentHandshake {
        id_min,
        id_max,
        protocol,
        negotiated_version,
        ready_body,
        ready,
    })
}
628
/// First correlation id that may be handed out for a range starting at
/// `id_min`: the range start itself, except that id 0 is never used.
fn first_request_id(id_min: u32) -> u32 {
    if id_min == 0 { 1 } else { id_min }
}
632
/// Validate that the relay-assigned range `[id_min, id_max)` contains at
/// least one usable (nonzero) correlation id.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] when the range has no usable id.
#[cfg(any(feature = "uds", feature = "websocket"))]
fn ensure_usable_id_range(id_min: u32, id_max: u32) -> AgentClientResult<()> {
    match usable_id_count(id_min, id_max) {
        0 => Err(AgentClientError::Handshake(format!(
            "relay id range contains no usable nonzero ids: start={id_min}, end={id_max}"
        ))),
        _ => Ok(()),
    }
}
642
643fn usable_id_count(id_min: u32, id_max: u32) -> u32 {
644    id_max.saturating_sub(first_request_id(id_min))
645}
646
/// Heuristically decide whether the first two handshake words came from a
/// legacy (pre-0.5) relay.
///
/// `id_min` and `id_max` are the first and second big-endian `u32` of the
/// handshake, named for their meaning in the current layout.
#[cfg(any(feature = "uds", feature = "websocket"))]
fn looks_like_legacy_relay_handshake(id_min: u32, id_max: u32) -> bool {
    // TODO(upgrade-0.6): Remove in 0.6.x or later once pre-0.5 relay
    // handshakes are no longer accepted.
    //
    // A legacy relay sends the id offset first and then immediately the
    // ready-frame length prefix, so the second word must be a plausible frame
    // length. In the current handshake the second word is the exclusive upper
    // id bound, which is normally far larger than any valid frame length.
    let plausible_frame_len = ((FRAME_HEADER_SIZE as u32)..=MAX_FRAME_SIZE).contains(&id_max);

    // Tiny current ranges are possible in tests, so the current
    // interpretation wins whenever the words form a valid range that starts
    // at a nonzero id.
    let invalid_as_current_range = id_min == 0 || id_min >= id_max;

    plausible_frame_len && invalid_as_current_range
}
661
/// Read the remainder of a legacy ready frame whose 4-byte big-endian length
/// prefix (`len_buf`) has already been consumed from `reader`.
///
/// The frame data is laid out as a big-endian `u32` id, one flags byte, and
/// the body starting at `FRAME_HEADER_SIZE`.
///
/// # Errors
///
/// Returns [`AgentClientError::Handshake`] when the announced length is out
/// of bounds, when the read does not finish before `deadline`, or when the
/// read itself fails.
#[cfg(any(feature = "uds", feature = "websocket"))]
async fn read_raw_frame_after_len_prefix<R>(
    reader: &mut R,
    len_buf: [u8; 4],
    deadline: Instant,
) -> AgentClientResult<RawFrame>
where
    R: HandshakeReader + ?Sized,
{
    let frame_len = u32::from_be_bytes(len_buf);
    if frame_len > MAX_FRAME_SIZE {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too large: {frame_len} bytes (max {MAX_FRAME_SIZE})"
        )));
    }
    if frame_len < FRAME_HEADER_SIZE as u32 {
        return Err(AgentClientError::Handshake(format!(
            "legacy ready frame too short: {frame_len} bytes"
        )));
    }

    let mut data = vec![0u8; frame_len as usize];
    match tokio::time::timeout_at(deadline, reader.read_exact_handshake(&mut data)).await {
        Ok(Ok(())) => {}
        Ok(Err(e)) => {
            return Err(AgentClientError::Handshake(format!(
                "read legacy ready frame: {e}"
            )));
        }
        Err(_) => {
            return Err(AgentClientError::Handshake(
                "read legacy ready frame: timed out before relay sent frame".into(),
            ));
        }
    }

    let id = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
    let flags = data[4];
    // Everything after the header is the frame body.
    let body = data.split_off(FRAME_HEADER_SIZE);

    Ok(RawFrame { id, flags, body })
}
699
/// Handshake reads for any byte stream implementing `AsyncRead`.
#[cfg(feature = "uds")]
impl<R> HandshakeReader for R
where
    R: tokio::io::AsyncRead + Unpin + Send,
{
    /// Fill `out` completely from the stream, reporting I/O failures as
    /// [`AgentClientError::Handshake`].
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        Box::pin(async move {
            match tokio::io::AsyncReadExt::read_exact(self, out).await {
                Ok(_) => Ok(()),
                Err(e) => Err(AgentClientError::Handshake(e.to_string())),
            }
        })
    }

    /// Read one raw frame via the protocol codec, reporting failures as
    /// [`AgentClientError::Protocol`].
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        Box::pin(async move {
            match codec::read_raw_frame(self).await {
                Ok(frame) => Ok(frame),
                Err(e) => Err(AgentClientError::Protocol(e)),
            }
        })
    }
}
727
/// Background task that writes queued frames to the Unix socket in order and
/// acknowledges each one.
///
/// The loop ends when the command channel closes or after the first write
/// error, which is reported to the command that hit it.
#[cfg(feature = "uds")]
async fn uds_writer_loop(
    mut writer: tokio::net::unix::OwnedWriteHalf,
    mut rx: mpsc::Receiver<WriterCommand>,
) {
    while let Some(WriterCommand { frame, ack }) = rx.recv().await {
        match codec::write_raw_frame(&mut writer, &frame).await {
            Ok(_) => {
                // The requester may have gone away; a failed ack is harmless.
                let _ = ack.send(Ok(()));
            }
            Err(e) => {
                tracing::debug!("agent client: UDS writer error: {e}");
                let _ = ack.send(Err(AgentClientError::Protocol(e)));
                break;
            }
        }
    }
}
742
/// Background task that reads frames from the relay and routes each one to
/// the pending channel registered for its correlation ID. Works purely on raw
/// frames — no CBOR decoding happens here.
///
/// When reading fails or reaches EOF, the pending map is cleared so that
/// every outstanding receiver wakes up.
#[cfg(feature = "uds")]
async fn reader_loop<R>(mut reader: R, pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>)
where
    R: tokio::io::AsyncRead + Unpin,
{
    loop {
        match codec::read_raw_frame(&mut reader).await {
            Ok(frame) => {
                dispatch_frame(frame, &pending).await;
            }
            Err(e) => {
                tracing::debug!("agent client: reader EOF or error: {e}");
                break;
            }
        }
    }

    // Dropping every sender is what wakes the outstanding receivers.
    pending.lock().await.clear();
}
766
/// Adapter that exposes a stream of WebSocket messages as a byte source by
/// accumulating binary message payloads in an internal buffer.
#[cfg(feature = "websocket")]
struct WebSocketByteReader<S> {
    /// Underlying stream of WebSocket messages.
    stream: S,
    /// Bytes received from binary messages.
    buffer: Vec<u8>,
    // NOTE(review): presumably the offset of the first unconsumed byte in
    // `buffer` — confirm against `available_bytes` / `compact_consumed`.
    cursor: usize,
}
773
774#[cfg(feature = "websocket")]
775impl<S> WebSocketByteReader<S>
776where
777    S: futures_util::Stream<Item = Result<WebSocketMessage, tokio_tungstenite::tungstenite::Error>>
778        + Unpin
779        + Send,
780{
781    fn new(stream: S) -> Self {
782        Self {
783            stream,
784            buffer: Vec::new(),
785            cursor: 0,
786        }
787    }
788
789    async fn read_exact(&mut self, out: &mut [u8]) -> AgentClientResult<()> {
790        while self.available_bytes() < out.len() {
791            let Some(message) = self.stream.next().await else {
792                return Err(AgentClientError::WebSocket(
793                    "websocket closed before enough bytes were available".to_string(),
794                ));
795            };
796            match message.map_err(|e| AgentClientError::WebSocket(e.to_string()))? {
797                WebSocketMessage::Binary(bytes) => {
798                    if self.available_bytes() + bytes.len() > MAX_WEBSOCKET_BUFFER_SIZE {
799                        return Err(AgentClientError::WebSocket(format!(
800                            "websocket buffer exceeded maximum size: {} bytes (max {MAX_WEBSOCKET_BUFFER_SIZE})",
801                            self.available_bytes() + bytes.len()
802                        )));
803                    }
804                    self.compact_consumed();
805                    self.buffer.extend_from_slice(&bytes);
806                }
807                WebSocketMessage::Close(_) => {
808                    return Err(AgentClientError::WebSocket(
809                        "websocket closed before enough bytes were available".to_string(),
810                    ));
811                }
812                WebSocketMessage::Ping(_) | WebSocketMessage::Pong(_) => {}
813                WebSocketMessage::Text(_) | WebSocketMessage::Frame(_) => {
814                    return Err(AgentClientError::WebSocket(
815                        "websocket message is not binary".to_string(),
816                    ));
817                }
818            }
819        }
820
821        out.copy_from_slice(&self.buffer[self.cursor..self.cursor + out.len()]);
822        self.cursor += out.len();
823        self.compact_consumed();
824        Ok(())
825    }
826
827    fn available_bytes(&self) -> usize {
828        self.buffer.len().saturating_sub(self.cursor)
829    }
830
831    fn compact_consumed(&mut self) {
832        if self.cursor == 0 {
833            return;
834        }
835        if self.cursor == self.buffer.len() {
836            self.buffer.clear();
837        } else {
838            self.buffer.drain(..self.cursor);
839        }
840        self.cursor = 0;
841    }
842
843    async fn read_raw_frame(&mut self) -> AgentClientResult<RawFrame> {
844        let mut len_buf = [0u8; 4];
845        self.read_exact(&mut len_buf).await?;
846        let frame_len = u32::from_be_bytes(len_buf);
847        if frame_len > MAX_FRAME_SIZE {
848            return Err(AgentClientError::Protocol(
849                microsandbox_protocol::ProtocolError::FrameTooLarge {
850                    size: frame_len,
851                    max: MAX_FRAME_SIZE,
852                },
853            ));
854        }
855        if frame_len < FRAME_HEADER_SIZE as u32 {
856            return Err(AgentClientError::Protocol(
857                microsandbox_protocol::ProtocolError::FrameTooShort {
858                    size: frame_len,
859                    min: FRAME_HEADER_SIZE as u32,
860                },
861            ));
862        }
863
864        let mut payload = vec![0u8; frame_len as usize];
865        self.read_exact(&mut payload).await?;
866        let id = u32::from_be_bytes(payload[0..4].try_into().unwrap());
867        let flags = payload[4];
868        let body = payload[FRAME_HEADER_SIZE..].to_vec();
869        Ok(RawFrame { id, flags, body })
870    }
871}
872
/// Handshake reads for the WebSocket transport simply forward to the
/// reader's own buffered `read_exact` / `read_raw_frame` methods.
#[cfg(feature = "websocket")]
impl<S> HandshakeReader for WebSocketByteReader<S>
where
    S: futures_util::Stream<Item = Result<WebSocketMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin
        + Send,
{
    /// Fills `out` from the buffered WebSocket byte stream.
    fn read_exact_handshake<'a>(
        &'a mut self,
        out: &'a mut [u8],
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<()>> + Send + 'a>> {
        Box::pin(self.read_exact(out))
    }

    /// Reads one raw frame from the buffered WebSocket byte stream.
    fn read_frame_handshake<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = AgentClientResult<RawFrame>> + Send + 'a>> {
        Box::pin(self.read_raw_frame())
    }
}
893
/// Reader task for the WebSocket transport.
///
/// Reads raw frames from `reader` and routes each one through
/// [`dispatch_frame`] until the first read error, then clears `pending` so
/// that every registered sender is dropped.
#[cfg(feature = "websocket")]
async fn websocket_reader_loop<S>(
    mut reader: WebSocketByteReader<S>,
    pending: Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) where
    S: futures_util::Stream<Item = Result<WebSocketMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin
        + Send,
{
    // The loop only ever exits with the error that ended the read side.
    let e = loop {
        match reader.read_raw_frame().await {
            Ok(frame) => dispatch_frame(frame, &pending).await,
            Err(e) => break e,
        }
    };
    tracing::debug!("agent client: websocket reader EOF or error: {e}");

    pending.lock().await.clear();
}
918
/// Writer task for the WebSocket transport.
///
/// Each queued command's frame is encoded into a fresh buffer and sent as one
/// binary WebSocket message; the result is reported through the command's
/// `ack` channel. The task ends when the command channel closes or right
/// after the first encode or send failure.
#[cfg(feature = "websocket")]
async fn websocket_writer_loop<S>(mut writer: S, mut rx: mpsc::Receiver<WriterCommand>)
where
    S: futures_util::Sink<WebSocketMessage, Error = tokio_tungstenite::tungstenite::Error> + Unpin,
{
    loop {
        let Some(command) = rx.recv().await else {
            break;
        };

        let mut buf = Vec::new();
        let outcome = match codec::encode_raw_to_buf(&command.frame, &mut buf) {
            Err(e) => {
                tracing::debug!("agent client: websocket encode error: {e}");
                Err(AgentClientError::Protocol(e))
            }
            Ok(_) => match writer.send(WebSocketMessage::Binary(buf.into())).await {
                Err(e) => {
                    tracing::debug!("agent client: websocket writer error: {e}");
                    Err(AgentClientError::WebSocket(e.to_string()))
                }
                Ok(_) => Ok(()),
            },
        };

        let failed = outcome.is_err();
        // The requester may have stopped waiting; a failed ack is fine.
        let _ = command.ack.send(outcome);
        if failed {
            break;
        }
    }
}
941
/// Routes `frame` to the channel registered under its correlation ID.
///
/// A frame carrying `FLAG_TERMINAL` unregisters its handler while it is
/// looked up. Frames without a registered handler are traced and dropped. If
/// the receiving side of the channel is gone, the entry is removed from
/// `pending`.
#[cfg(any(feature = "uds", feature = "websocket"))]
async fn dispatch_frame(
    frame: RawFrame,
    pending: &Arc<Mutex<HashMap<u32, mpsc::Sender<RawFrame>>>>,
) {
    let id = frame.id;
    let is_terminal = (frame.flags & FLAG_TERMINAL) != 0;

    // Hold the lock only for the lookup, never across the channel send below.
    let handler = {
        let mut map = pending.lock().await;
        if is_terminal {
            map.remove(&id)
        } else {
            map.get(&id).cloned()
        }
    };

    let Some(tx) = handler else {
        tracing::trace!("agent client: no pending handler for id={id}");
        return;
    };

    if tx.send(frame).await.is_err() {
        pending.lock().await.remove(&id);
    }
}
966
967/// Translate a stream of raw frames into typed messages.
968async fn decode_stream_task(mut raw_rx: mpsc::Receiver<RawFrame>, tx: mpsc::Sender<Message>) {
969    while let Some(frame) = raw_rx.recv().await {
970        match codec::raw_frame_to_message(frame) {
971            Ok(msg) => {
972                if tx.send(msg).await.is_err() {
973                    break;
974                }
975            }
976            Err(e) => {
977                tracing::warn!("agent client: failed to decode frame in stream: {e}");
978                // Continue — single malformed frame shouldn't kill the stream.
979            }
980        }
981    }
982}
983
984/// Encode a typed payload to a CBOR `Message` body.
985fn encode_message_body<T: Serialize>(
986    version: u8,
987    t: MessageType,
988    payload: &T,
989) -> AgentClientResult<Vec<u8>> {
990    let mut msg = Message::with_payload(t, 0, payload)?;
991    msg.v = version;
992    let mut body = Vec::new();
993    ciborium::into_writer(&msg, &mut body).map_err(microsandbox_protocol::ProtocolError::from)?;
994    Ok(body)
995}
996
997//--------------------------------------------------------------------------------------------------
998// Tests
999//--------------------------------------------------------------------------------------------------
1000
1001#[cfg(test)]
1002mod tests {
1003    #[cfg(any(feature = "uds", feature = "websocket"))]
1004    use microsandbox_protocol::core::Ready;
1005    #[cfg(any(feature = "uds", feature = "websocket"))]
1006    use microsandbox_protocol::exec::ExecRequest;
1007    #[cfg(feature = "websocket")]
1008    use microsandbox_protocol::fs::{FsOp, FsRequest, FsResponse};
1009    #[cfg(feature = "uds")]
1010    use microsandbox_protocol::message::PROTOCOL_VERSION;
1011    #[cfg(feature = "uds")]
1012    use tokio::io::AsyncWriteExt;
1013    #[cfg(feature = "websocket")]
1014    use tokio::net::TcpListener;
1015    #[cfg(feature = "uds")]
1016    use tokio::net::UnixListener;
1017    #[cfg(any(feature = "uds", feature = "websocket"))]
1018    use tokio::sync::oneshot;
1019
1020    use super::*;
1021
1022    #[cfg(feature = "uds")]
1023    #[tokio::test]
1024    async fn connect_decodes_ready_payload() {
1025        let temp = tempfile::tempdir().unwrap();
1026        let sock_path = temp.path().join("agent.sock");
1027        let listener = UnixListener::bind(&sock_path).unwrap();
1028        let ready = Ready {
1029            boot_time_ns: 11,
1030            init_time_ns: 22,
1031            ready_time_ns: 33,
1032            agent_version: "9.9.9".to_string(),
1033        };
1034        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1035
1036        tokio::spawn(async move {
1037            let (mut socket, _) = listener.accept().await.unwrap();
1038            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1039            socket.write_all(&8u32.to_be_bytes()).await.unwrap();
1040            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1041        });
1042
1043        let client =
1044            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1045                .await
1046                .unwrap();
1047
1048        assert_eq!(client.protocol(), AgentProtocol::Current);
1049        // Both peers speak the current generation, so that is what is negotiated.
1050        assert_eq!(client.negotiated_version(), PROTOCOL_VERSION);
1051        assert!(client.supports(MessageType::FsRequest));
1052        // The runtime's self-reported version round-trips from the ready frame.
1053        assert_eq!(client.agent_version(), "9.9.9");
1054        let decoded = client.ready().unwrap();
1055        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
1056        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
1057        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
1058
1059        let raw_msg: Message = ciborium::from_reader(client.ready_bytes()).unwrap();
1060        assert_eq!(raw_msg.t, MessageType::Ready);
1061    }
1062
1063    #[cfg(feature = "uds")]
1064    #[tokio::test]
1065    async fn connect_negotiates_down_to_older_guest_generation() {
1066        let temp = tempfile::tempdir().unwrap();
1067        let sock_path = temp.path().join("agent.sock");
1068        let listener = UnixListener::bind(&sock_path).unwrap();
1069        let ready = Ready {
1070            boot_time_ns: 1,
1071            init_time_ns: 2,
1072            ready_time_ns: 3,
1073            ..Default::default()
1074        };
1075        // A current-codec guest that advertises an older capability generation in
1076        // its ready frame (a runtime one generation behind this host).
1077        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1078        ready_msg.v = 1;
1079
1080        tokio::spawn(async move {
1081            let (mut socket, _) = listener.accept().await.unwrap();
1082            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1083            socket
1084                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
1085                .await
1086                .unwrap();
1087            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1088        });
1089
1090        let client =
1091            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1092                .await
1093                .unwrap();
1094
1095        // Current codec, but the capability gate is pinned to the guest's older
1096        // generation: min(host PROTOCOL_VERSION, guest's advertised 1) == 1.
1097        assert_eq!(client.protocol(), AgentProtocol::Current);
1098        assert_eq!(client.negotiated_version(), 1);
1099        // Exec is in the baseline; filesystem is not, at generation 1.
1100        assert!(client.supports(MessageType::ExecRequest));
1101        assert!(!client.supports(MessageType::FsRequest));
1102    }
1103
1104    #[cfg(feature = "uds")]
1105    #[tokio::test]
1106    async fn connect_accepts_legacy_relay_handshake() {
1107        assert_accepts_legacy_relay_handshake(0).await;
1108        assert_accepts_legacy_relay_handshake(268_435_455).await;
1109    }
1110
1111    #[cfg(feature = "uds")]
1112    #[tokio::test]
1113    async fn legacy_relay_requests_use_v1_and_legacy_id_range() {
1114        let temp = tempfile::tempdir().unwrap();
1115        let sock_path = temp.path().join("agent.sock");
1116        let listener = UnixListener::bind(&sock_path).unwrap();
1117        let ready = Ready {
1118            boot_time_ns: 11,
1119            init_time_ns: 22,
1120            ready_time_ns: 33,
1121            ..Default::default()
1122        };
1123        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1124        let id_offset = 268_435_455u32;
1125        let (frame_tx, frame_rx) = oneshot::channel();
1126
1127        tokio::spawn(async move {
1128            let (mut socket, _) = listener.accept().await.unwrap();
1129            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
1130            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1131            let frame = codec::read_raw_frame(&mut socket).await.unwrap();
1132            frame_tx.send(frame).unwrap();
1133        });
1134
1135        let client =
1136            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1137                .await
1138                .unwrap();
1139        let request = ExecRequest {
1140            cmd: "/bin/true".into(),
1141            args: Vec::new(),
1142            env: Vec::new(),
1143            cwd: None,
1144            user: None,
1145            tty: false,
1146            rows: 24,
1147            cols: 80,
1148            rlimits: Vec::new(),
1149        };
1150        let (id, _rx) = client
1151            .stream(MessageType::ExecRequest, &request)
1152            .await
1153            .unwrap();
1154
1155        let frame = frame_rx.await.unwrap();
1156        let message = codec::raw_frame_to_message(frame).unwrap();
1157
1158        assert_eq!(id, id_offset + 1);
1159        assert_eq!(message.id, id_offset + 1);
1160        assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
1161        assert_eq!(message.t, MessageType::ExecRequest);
1162    }
1163
1164    #[test]
1165    fn version_compat_across_generations() {
1166        use MessageType::{ExecRequest, FsRequest};
1167        // (message type, peer generation, expected allowed). Generation 1 is the
1168        // pre-0.5 legacy runtime (no filesystem); generation 2 introduced the
1169        // Fs* types; generation 5 is current.
1170        let cases = [
1171            (ExecRequest, 1, true),
1172            (ExecRequest, 2, true),
1173            (ExecRequest, 3, true),
1174            (FsRequest, 1, false),
1175            (FsRequest, 2, true),
1176            (FsRequest, 3, true),
1177        ];
1178        for (t, generation, allowed) in cases {
1179            assert_eq!(
1180                AgentClient::ensure_version_compat_for(t, generation).is_ok(),
1181                allowed,
1182                "{t:?} at generation {generation}"
1183            );
1184        }
1185    }
1186
1187    #[test]
1188    fn version_compat_rejection_is_typed() {
1189        // Filesystem on the legacy (generation 1) runtime is rejected before any
1190        // send, with the structured error whose message tells the user to restart.
1191        let err =
1192            AgentClient::ensure_version_compat_for(MessageType::FsRequest, LEGACY_PROTOCOL_VERSION)
1193                .unwrap_err();
1194        assert!(matches!(
1195            err,
1196            AgentClientError::UnsupportedOperation {
1197                needs: 2,
1198                peer: 1,
1199                ..
1200            }
1201        ));
1202    }
1203
1204    #[cfg(feature = "uds")]
1205    #[tokio::test]
1206    async fn connect_preserves_current_peer_protocol_version() {
1207        let temp = tempfile::tempdir().unwrap();
1208        let sock_path = temp.path().join("agent.sock");
1209        let listener = UnixListener::bind(&sock_path).unwrap();
1210        let ready = Ready {
1211            boot_time_ns: 11,
1212            init_time_ns: 22,
1213            ready_time_ns: 33,
1214            ..Default::default()
1215        };
1216        let mut ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1217        ready_msg.v = 2;
1218
1219        tokio::spawn(async move {
1220            let (mut socket, _) = listener.accept().await.unwrap();
1221            socket.write_all(&1u32.to_be_bytes()).await.unwrap();
1222            socket
1223                .write_all(&microsandbox_protocol::AGENT_RELAY_ID_RANGE_STEP.to_be_bytes())
1224                .await
1225                .unwrap();
1226            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1227        });
1228
1229        let client =
1230            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1231                .await
1232                .unwrap();
1233
1234        assert_eq!(client.protocol(), AgentProtocol::Current);
1235        // The runtime reported generation 2, so that is the negotiated capability.
1236        assert_eq!(client.negotiated_version(), 2);
1237        // TCP forwarding (generation 4) is unavailable to a generation-2 runtime.
1238        assert!(!client.supports(MessageType::TcpConnect));
1239    }
1240
1241    #[cfg(feature = "uds")]
1242    async fn assert_accepts_legacy_relay_handshake(id_offset: u32) {
1243        let temp = tempfile::tempdir().unwrap();
1244        let sock_path = temp.path().join("agent.sock");
1245        let listener = UnixListener::bind(&sock_path).unwrap();
1246        let ready = Ready {
1247            boot_time_ns: 11,
1248            init_time_ns: 22,
1249            ready_time_ns: 33,
1250            ..Default::default()
1251        };
1252        let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1253
1254        tokio::spawn(async move {
1255            let (mut socket, _) = listener.accept().await.unwrap();
1256            socket.write_all(&id_offset.to_be_bytes()).await.unwrap();
1257            codec::write_message(&mut socket, &ready_msg).await.unwrap();
1258        });
1259
1260        let client =
1261            AgentClient::connect_with_deadline(&sock_path, Instant::now() + Duration::from_secs(1))
1262                .await
1263                .unwrap();
1264
1265        assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
1266        assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
1267        let decoded = client.ready().unwrap();
1268        assert_eq!(decoded.boot_time_ns, ready.boot_time_ns);
1269        assert_eq!(decoded.init_time_ns, ready.init_time_ns);
1270        assert_eq!(decoded.ready_time_ns, ready.ready_time_ns);
1271    }
1272
1273    #[cfg(feature = "websocket")]
1274    #[tokio::test]
1275    async fn websocket_connects_and_completes_request() {
1276        use futures_util::{SinkExt, StreamExt};
1277        use tokio_tungstenite::accept_async;
1278        use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
1279
1280        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1281        let addr = listener.local_addr().unwrap();
1282        let ready = Ready {
1283            boot_time_ns: 11,
1284            init_time_ns: 22,
1285            ready_time_ns: 33,
1286            agent_version: "ws-test".to_string(),
1287        };
1288
1289        tokio::spawn(async move {
1290            let (stream, _) = listener.accept().await.unwrap();
1291            let mut ws = accept_async(stream).await.unwrap();
1292
1293            let mut range = Vec::new();
1294            range.extend_from_slice(&1u32.to_be_bytes());
1295            range.extend_from_slice(&1024u32.to_be_bytes());
1296            ws.send(WebSocketMessage::Binary(range.into()))
1297                .await
1298                .unwrap();
1299
1300            let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1301            let mut ready_packet = Vec::new();
1302            codec::encode_to_buf(&ready_msg, &mut ready_packet).unwrap();
1303            ws.send(WebSocketMessage::Binary(ready_packet.into()))
1304                .await
1305                .unwrap();
1306
1307            let request_packet = loop {
1308                match ws.next().await.unwrap().unwrap() {
1309                    WebSocketMessage::Binary(bytes) => break bytes.to_vec(),
1310                    _ => continue,
1311                }
1312            };
1313            let mut request_buf = request_packet;
1314            let request = codec::try_decode_raw_from_buf(&mut request_buf)
1315                .unwrap()
1316                .unwrap();
1317            let request_msg = codec::raw_frame_to_message(request.clone()).unwrap();
1318            assert_eq!(request_msg.t, MessageType::FsRequest);
1319
1320            let response = Message::with_payload(
1321                MessageType::FsResponse,
1322                request.id,
1323                &FsResponse {
1324                    ok: true,
1325                    error: None,
1326                    data: None,
1327                },
1328            )
1329            .unwrap();
1330            let mut response_packet = Vec::new();
1331            codec::encode_to_buf(&response, &mut response_packet).unwrap();
1332            ws.send(WebSocketMessage::Binary(response_packet.into()))
1333                .await
1334                .unwrap();
1335        });
1336
1337        let client = AgentClient::connect_websocket(&format!("ws://{addr}"))
1338            .await
1339            .unwrap();
1340        assert_eq!(client.agent_version(), "ws-test");
1341
1342        let response = client
1343            .request(
1344                MessageType::FsRequest,
1345                &FsRequest {
1346                    op: FsOp::Stat {
1347                        path: "/tmp".to_string(),
1348                        follow_symlink: true,
1349                    },
1350                },
1351            )
1352            .await
1353            .unwrap();
1354        assert_eq!(response.t, MessageType::FsResponse);
1355        let payload: FsResponse = response.payload().unwrap();
1356        assert!(payload.ok);
1357    }
1358
1359    #[cfg(feature = "websocket")]
1360    #[tokio::test]
1361    async fn websocket_accepts_legacy_relay_handshake() {
1362        use futures_util::{SinkExt, StreamExt};
1363        use tokio_tungstenite::accept_async;
1364        use tokio_tungstenite::tungstenite::Message as WebSocketMessage;
1365
1366        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1367        let addr = listener.local_addr().unwrap();
1368        let ready = Ready {
1369            boot_time_ns: 11,
1370            init_time_ns: 22,
1371            ready_time_ns: 33,
1372            agent_version: "legacy-ws-test".to_string(),
1373        };
1374        let id_offset = 268_435_455u32;
1375        let (frame_tx, frame_rx) = oneshot::channel();
1376
1377        tokio::spawn(async move {
1378            let (stream, _) = listener.accept().await.unwrap();
1379            let mut ws = accept_async(stream).await.unwrap();
1380
1381            ws.send(WebSocketMessage::Binary(
1382                id_offset.to_be_bytes().to_vec().into(),
1383            ))
1384            .await
1385            .unwrap();
1386
1387            let ready_msg = Message::with_payload(MessageType::Ready, 0, &ready).unwrap();
1388            let mut ready_packet = Vec::new();
1389            codec::encode_to_buf(&ready_msg, &mut ready_packet).unwrap();
1390            ws.send(WebSocketMessage::Binary(ready_packet.into()))
1391                .await
1392                .unwrap();
1393
1394            let request_packet = loop {
1395                match ws.next().await.unwrap().unwrap() {
1396                    WebSocketMessage::Binary(bytes) => break bytes.to_vec(),
1397                    _ => continue,
1398                }
1399            };
1400            let mut request_buf = request_packet;
1401            let request = codec::try_decode_raw_from_buf(&mut request_buf)
1402                .unwrap()
1403                .unwrap();
1404            frame_tx.send(request).unwrap();
1405        });
1406
1407        let client = AgentClient::connect_websocket(&format!("ws://{addr}"))
1408            .await
1409            .unwrap();
1410        assert_eq!(client.protocol(), AgentProtocol::LegacyV1);
1411        assert_eq!(client.negotiated_version(), LEGACY_PROTOCOL_VERSION);
1412        assert_eq!(client.agent_version(), "legacy-ws-test");
1413
1414        let request = ExecRequest {
1415            cmd: "/bin/true".into(),
1416            args: Vec::new(),
1417            env: Vec::new(),
1418            cwd: None,
1419            user: None,
1420            tty: false,
1421            rows: 24,
1422            cols: 80,
1423            rlimits: Vec::new(),
1424        };
1425        let (id, _rx) = client
1426            .stream(MessageType::ExecRequest, &request)
1427            .await
1428            .unwrap();
1429        let frame = frame_rx.await.unwrap();
1430        let message = codec::raw_frame_to_message(frame).unwrap();
1431
1432        assert_eq!(id, id_offset + 1);
1433        assert_eq!(message.id, id_offset + 1);
1434        assert_eq!(message.v, LEGACY_PROTOCOL_VERSION);
1435        assert_eq!(message.t, MessageType::ExecRequest);
1436    }
1437}
1438
1439//--------------------------------------------------------------------------------------------------
1440// Trait Implementations
1441//--------------------------------------------------------------------------------------------------
1442
1443impl Drop for AgentClient {
1444    fn drop(&mut self) {
1445        self.reader_handle.abort();
1446        self.writer_handle.abort();
1447    }
1448}