Skip to main content

selium_remote_client/
lib.rs

1//! External client library for interacting with Selium runtimes and their guests.
2//!
3//! You **should** use this library in code that _does not_ run on a Selium runtime.
4//! You **should not** use this library in code that _does_ run on a Selium runtime.
5//!
6//! This library can be used for:
7//!  1. Orchestration - creating channels, managing processes, etc.
8//!  2. Data I/O - publishing payloads, calling RPC servers, etc.
9//!
10//! # Examples
11//! ```no_run
12//! use futures::{SinkExt, StreamExt, pin_mut};
13//! use selium_remote_client::{Channel, ClientConfigBuilder, ClientError};
14//!
15//! #[tokio::main]
16//! async fn main() -> Result<(), ClientError> {
17//!     let client = ClientConfigBuilder::default().connect().await?;
18//!     let chunk_size = 64 * 1024;
19//!     let channel = Channel::create(&client, chunk_size).await?;
20//!
21//!     let mut publisher = channel.publish().await?;
22//!     publisher.send(b"ping".to_vec()).await?;
23//!
24//!     let mut subscriber = channel.subscribe(chunk_size).await?;
25//!     pin_mut!(subscriber);
26//!     if let Some(frame) = subscriber.next().await {
27//!         let payload = frame?;
28//!         eprintln!("received {} bytes", payload.len());
29//!     }
30//!
31//!     Ok(())
32//! }
33//! ```
34#![deny(missing_docs)]
35/// Protocol types shared with the remote client control plane.
36pub use selium_remote_client_protocol::{
37    AbiParam, AbiScalarType, AbiScalarValue, AbiSignature, Capability, EntrypointArg,
38    GuestResourceId, GuestUint,
39};
40use std::{
41    collections::VecDeque,
42    fmt::{Debug, Display},
43    fs,
44    io::ErrorKind,
45    net::SocketAddr,
46    path::{Path, PathBuf},
47    pin::Pin,
48    sync::Arc,
49    task::{Context, Poll},
50    time::Duration,
51};
52
53use futures::{Sink, Stream};
54use quinn::{
55    ClientConfig, ConnectError, Connection, ConnectionError, Endpoint, RecvStream, SendStream,
56    TransportConfig, WriteError, crypto::rustls::QuicClientConfig, rustls,
57};
58use rustls::{
59    RootCertStore,
60    pki_types::{CertificateDer, PrivateKeyDer},
61};
62use rustls_pki_types::{PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, pem::SliceIter};
63use selium_remote_client_protocol::{
64    ChannelRef, ProcessStartRequest, Request, Response, decode_response, encode_request,
65};
66use thiserror::Error;
67use tokio::net::lookup_host;
68use tracing::debug;
69
// Crate-local shorthand: every fallible function in this file reports `ClientError`.
70type Result<T> = std::result::Result<T, ClientError>;
71
72/// Default domain exposed by the remote-client guest.
73pub const DEFAULT_DOMAIN: &str = "localhost";
74/// Default port exposed by the remote-client guest.
75pub const DEFAULT_PORT: u16 = 7000;
76/// Default maximum size, in bytes, of a response frame accepted from the remote-client guest.
///
/// This is the initial `response_limit` of [`ClientConfigBuilder`]; it bounds the
/// length-prefixed payload of each response that is read back.
77pub const DEFAULT_RESPONSE_LIMIT: usize = 8 * 1024;
78
79/// Errors returned by the Selium client.
///
/// Several variants carry a `String` rather than a typed source because the
/// underlying error is stringified at the call site.
80#[derive(Debug, Error)]
81pub enum ClientError {
82    /// X.509 material could not be read, parsed or used.
    ///
    /// Besides PEM/DER parse failures this is also returned when a certificate
    /// file cannot be read from disk, when a CA certificate cannot be added to
    /// the root store, and when no QUIC-compatible cipher suite can be selected.
83    #[error("failed to parse certificate: {0}")]
84    Certificate(String),
85    /// TLS configuration could not be constructed.
86    #[error("failed to build TLS config: {0}")]
87    Tls(#[source] rustls::Error),
88    /// DNS resolution failed for the given target, or returned no addresses.
89    #[error("failed to resolve {target}: {source}")]
90    Resolve {
91        /// The `domain:port` string that was resolved.
92        target: String,
93        /// The underlying resolution error.
94        #[source]
95        source: std::io::Error,
96    },
97    /// Failed to bind a local QUIC endpoint.
98    #[error("failed to open client endpoint: {0}")]
99    Endpoint(#[source] std::io::Error),
100    /// The QUIC connection attempt could not be started.
101    #[error("failed to connect: {0}")]
102    Connect(#[source] ConnectError),
103    /// A QUIC connection returned an error.
    ///
    /// Used both when awaiting the connection attempt fails and when opening
    /// the bidirectional stream on an established connection fails.
104    #[error("connection failed: {0}")]
105    Connection(#[source] ConnectionError),
106    /// A write to the QUIC stream failed.
107    #[error("stream write failed: {0}")]
108    Write(#[source] WriteError),
109    /// The QUIC send stream could not be finished cleanly.
110    #[error("stream finish failed: {0}")]
111    Finish(String),
112    /// A read from the QUIC stream failed (carries the stringified read error).
113    #[error("stream read failed: {0}")]
114    Read(String),
115    /// Failed to encode a request payload.
116    #[error("encode request: {0}")]
117    Encode(String),
118    /// Failed to decode a response payload.
    ///
    /// Also returned when a stream ends before a complete frame was received
    /// or when a response grows beyond the configured maximum size.
119    #[error("decode response: {0}")]
120    Decode(String),
121    /// Input arguments were rejected locally (before contacting the remote).
    ///
    /// NOTE: the frame parser in this file also returns this variant when a
    /// *received* frame announces a length above the allowed limit, and the
    /// publisher returns it when `start_send` is called while it is not ready.
122    #[error("invalid request: {0}")]
123    InvalidArgument(&'static str),
124    /// The remote `remote-client` guest returned an error string.
125    #[error("remote error: {0}")]
126    Remote(String),
127    /// The remote `remote-client` guest sent an unexpected response.
    ///
    /// Also reported by the publisher's readiness check once it has been closed.
128    #[error("unexpected response from remote client")]
129    UnexpectedResponse,
130}
131
132/// Configures a [`Client`] connection to the runtime.
///
/// Start from [`Default`], adjust with the consuming setters, then call
/// [`ClientConfigBuilder::connect`].
133#[derive(Clone, Debug)]
134pub struct ClientConfigBuilder {
    /// Host name resolved together with `port`; also reused as the TLS server name.
135    domain: String,
    /// Port of the remote-client guest.
136    port: u16,
    /// Maximum accepted response frame size in bytes (the setter clamps it to at least 1).
137    response_limit: usize,
    /// Directory from which `ca.crt`, `client.crt` and `client.key` are read.
138    cert_dir: PathBuf,
139}
140
141/// Connection to a Selium runtime.
///
/// Cloning is cheap: all clones share the same reference-counted inner state.
142#[derive(Clone, Debug)]
143pub struct Client {
    /// Shared endpoint and connection parameters.
144    inner: Arc<ClientInner>,
145}
146
// State shared by every clone of a `Client`.
147#[derive(Debug)]
148struct ClientInner {
    // Local QUIC endpoint, configured with the client TLS settings.
149    endpoint: Endpoint,
    // Remote address, resolved once while connecting.
150    server_addr: SocketAddr,
    // Server name passed to `Endpoint::connect` for every new session.
151    server_name: String,
    // Maximum accepted response frame size in bytes.
152    response_limit: usize,
153}
154
155/// Data pipelines that transmit bytes bidirectionally.
///
/// A `Channel` is a lightweight handle: it pairs a [`Client`] with the guest
/// resource id of the channel.
156#[derive(Clone, Debug)]
157pub struct Channel {
    /// Client used for every request issued through this handle.
158    client: Client,
    /// Guest resource id identifying the channel on the runtime.
159    handle: GuestResourceId,
160}
161
162/// Handle to a running process in the runtime.
///
/// Pairs a [`Client`] with the guest resource id of the process.
163#[derive(Clone, Debug)]
164pub struct Process {
    /// Client used for every request issued through this handle.
165    client: Client,
    /// Guest resource id identifying the process on the runtime.
166    handle: GuestResourceId,
167}
168
169/// Configures a process to be launched in the runtime.
///
/// The collected arguments are checked against the signature when the start
/// request is built.
170#[derive(Clone, Debug, PartialEq)]
171pub struct ProcessBuilder {
    /// Identifier of the module to launch.
172    module_id: String,
    /// Name of the entrypoint to invoke.
173    entrypoint: String,
    /// Optional log URI forwarded with the start request.
174    log_uri: Option<String>,
    /// Capabilities requested for the process (kept free of duplicates by the setter).
175    capabilities: Vec<Capability>,
    /// ABI signature of the entrypoint.
176    signature: AbiSignature,
    /// Arguments passed to the entrypoint, in call order.
177    args: Vec<EntrypointArg>,
178}
179
// One QUIC connection together with the bidirectional stream opened on it.
180struct QuicSession {
    // Connection the stream halves below belong to.
181    connection: Connection,
    // Sending half of the bidirectional stream.
182    send: SendStream,
    // Receiving half of the bidirectional stream.
183    recv: RecvStream,
184}
185
// State machine behind `Publisher`'s `Sink` implementation.
186enum PublishState {
    // Idle: the send stream is available for the next payload.
187    Ready(SendStream),
    // A write is in flight; the future hands the send stream back when it completes.
188    Writing(Pin<Box<dyn std::future::Future<Output = Result<SendStream>> + Send>>),
    // The stream was finished or a write failed; no further payloads are accepted.
189    Closed,
190}
191
192/// Sink that drains payloads into the channel.
///
/// Implements [`Sink`] for `Vec<u8>` payloads; at most one write is in flight
/// at a time.
193pub struct Publisher {
    // Never read in this file; presumably held so the connection outlives the
    // send stream — TODO confirm.
194    _connection: Connection,
    // Current position in the ready / writing / closed state machine.
195    state: PublishState,
196}
197
198/// Stream of payloads read from a channel.
///
/// Each item is either one complete payload or the error that occurred while
/// reading it.
199pub struct Subscriber {
    // Type-erased frame stream created when subscribing.
200    inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + Send>>,
201}
202
203impl Default for ClientConfigBuilder {
204    /// Create a builder using the default Selium client settings.
205    fn default() -> Self {
206        Self {
207            domain: DEFAULT_DOMAIN.to_string(),
208            port: DEFAULT_PORT,
209            response_limit: DEFAULT_RESPONSE_LIMIT,
210            cert_dir: default_cert_dir(),
211        }
212    }
213}
214
215impl ClientConfigBuilder {
216    /// Override the target domain name.
217    pub fn domain(mut self, domain: impl Into<String>) -> Self {
218        self.domain = domain.into();
219        self
220    }
221
222    /// Override the target port.
223    pub fn port(mut self, port: u16) -> Self {
224        self.port = port;
225        self
226    }
227
228    /// Override the maximum response size accepted from the remote client.
229    pub fn response_limit(mut self, limit: usize) -> Self {
230        self.response_limit = limit.max(1);
231        self
232    }
233
234    /// Override the directory containing the client and CA certificates.
235    pub fn certificate_directory(mut self, dir: impl Into<PathBuf>) -> Self {
236        self.cert_dir = dir.into();
237        self
238    }
239
240    /// Build a [`Client`] with the provided settings.
241    ///
242    /// # Examples
243    /// ```no_run
244    /// # async fn example() -> Result<(), selium_remote_client::ClientError> {
245    /// let _client = selium_remote_client::ClientConfigBuilder::default()
246    ///     .domain("api.selium.io")
247    ///     .port(7000)
248    ///     .connect()
249    ///     .await?;
250    /// # Ok(())
251    /// # }
252    /// ```
253    pub async fn connect(self) -> Result<Client> {
254        Client::connect_with(self).await
255    }
256}
257
258impl Client {
259    /// Construct a client using the defaults from [`ClientConfigBuilder`].
260    pub async fn connect() -> Result<Self> {
261        Client::connect_with(ClientConfigBuilder::default()).await
262    }
263
264    async fn connect_with(config: ClientConfigBuilder) -> Result<Self> {
265        let server_addr = resolve_socket(&config.domain, config.port).await?;
266        let bind_addr = match server_addr {
267            SocketAddr::V4(_) => SocketAddr::from(([0, 0, 0, 0], 0)),
268            SocketAddr::V6(_) => SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0)),
269        };
270
271        let mut endpoint = Endpoint::client(bind_addr).map_err(ClientError::Endpoint)?;
272        endpoint.set_default_client_config(build_client_config(&config.cert_dir)?);
273
274        Ok(Self {
275            inner: Arc::new(ClientInner {
276                endpoint,
277                server_addr,
278                server_name: config.domain,
279                response_limit: config.response_limit,
280            }),
281        })
282    }
283
284    async fn open_session(&self) -> Result<QuicSession> {
285        let connecting = self
286            .inner
287            .endpoint
288            .connect(self.inner.server_addr, &self.inner.server_name)
289            .map_err(ClientError::Connect)?;
290        let connection = connecting.await.map_err(ClientError::Connection)?;
291        let (send, recv) = connection
292            .open_bi()
293            .await
294            .map_err(ClientError::Connection)?;
295        Ok(QuicSession {
296            connection,
297            send,
298            recv,
299        })
300    }
301
302    async fn request(&self, request: Request) -> Result<Response> {
303        let mut session = self.open_session().await?;
304        let payload =
305            encode_request(&request).map_err(|err| ClientError::Encode(err.to_string()))?;
306        debug!("sending request of {} bytes", payload.len());
307        session
308            .send
309            .write_all(&payload)
310            .await
311            .map_err(ClientError::Write)?;
312        session
313            .send
314            .finish()
315            .map_err(|_| ClientError::Finish("stream closed".to_string()))?;
316        debug!("request sent, awaiting response");
317        let mut buffer = VecDeque::new();
318        let response =
319            read_response(&mut session.recv, self.inner.response_limit, &mut buffer).await?;
320        match response {
321            Response::Error(message) => Err(ClientError::Remote(message)),
322            other => Ok(other),
323        }
324    }
325}
326
327impl Channel {
328    /// Create a channel using the remote-client guest.
329    pub async fn create(client: &Client, capacity: u32) -> Result<Self> {
330        let response = client.request(Request::ChannelCreate(capacity)).await?;
331        match response {
332            Response::ChannelCreate(handle) => Ok(Self::new(client, handle)),
333            Response::Error(msg) => Err(ClientError::Remote(msg)),
334            _ => Err(ClientError::UnexpectedResponse),
335        }
336    }
337
338    /// Construct a handle to an existing channel.
339    ///
340    /// # Safety
341    /// `handle` must be a valid channel capability minted for the current client; forged or stale
342    /// handles will be rejected by the remote guest and may cause the connection to be closed.
343    pub fn new(client: &Client, handle: GuestResourceId) -> Self {
344        Self {
345            client: client.clone(),
346            handle,
347        }
348    }
349
350    /// Delete this channel.
351    pub async fn delete(self) -> Result<()> {
352        let response = self
353            .client
354            .request(Request::ChannelDelete(self.handle))
355            .await?;
356        match response {
357            Response::Ok => Ok(()),
358            Response::Error(msg) => Err(ClientError::Remote(msg)),
359            _ => Err(ClientError::UnexpectedResponse),
360        }
361    }
362
363    /// Subscribe to this channel as a [`Subscriber`] of byte payloads.
364    ///
365    /// # Examples
366    /// ```no_run
367    /// use futures::{StreamExt, pin_mut};
368    ///
369    /// # async fn example(channel: &selium_remote_client::Channel) -> Result<(), selium_remote_client::ClientError> {
370    /// let mut subscriber = channel.subscribe(64 * 1024).await?;
371    /// pin_mut!(subscriber);
372    /// if let Some(frame) = subscriber.next().await {
373    ///     let payload = frame?;
374    ///     eprintln!("received {} bytes", payload.len());
375    /// }
376    /// # Ok(())
377    /// # }
378    /// ```
379    pub async fn subscribe(&self, chunk_size: u32) -> Result<Subscriber> {
380        self.subscribe_inner(ChannelRef::Strong(self.handle), chunk_size)
381            .await
382    }
383
384    /// Subscribe to this channel as a [`Subscriber`] of byte payloads.
385    ///
386    /// # Examples
387    /// ```no_run
388    /// use futures::{StreamExt, pin_mut};
389    ///
390    /// # async fn example(channel: &selium_remote_client::Channel) -> Result<(), selium_remote_client::ClientError> {
391    /// let mut subscriber = channel.subscribe_shared(64 * 1024).await?;
392    /// pin_mut!(subscriber);
393    /// if let Some(frame) = subscriber.next().await {
394    ///     let payload = frame?;
395    ///     eprintln!("received {} bytes", payload.len());
396    /// }
397    /// # Ok(())
398    /// # }
399    /// ```
400    pub async fn subscribe_shared(&self, chunk_size: u32) -> Result<Subscriber> {
401        self.subscribe_inner(ChannelRef::Shared(self.handle), chunk_size)
402            .await
403    }
404
405    async fn subscribe_inner(&self, target: ChannelRef, chunk_size: u32) -> Result<Subscriber> {
406        let mut session = self.client.open_session().await?;
407        let chunk_size = GuestUint::try_from(chunk_size)
408            .map_err(|_| ClientError::InvalidArgument("chunk size exceeds u32::MAX"))?;
409        let payload = encode_request(&Request::Subscribe(target, chunk_size))
410            .map_err(|err| ClientError::Encode(err.to_string()))?;
411        session
412            .send
413            .write_all(&payload)
414            .await
415            .map_err(ClientError::Write)?;
416        session
417            .send
418            .finish()
419            .map_err(|_| ClientError::Finish("stream closed".to_string()))?;
420        let max_frame = usize::try_from(chunk_size)
421            .map_err(|_| ClientError::InvalidArgument("chunk size exceeds usize::MAX"))?;
422        let connection = session.connection.clone();
423        let stream = futures::stream::unfold(
424            (session.recv, VecDeque::new(), max_frame, connection),
425            move |(mut recv, mut buffer, max_frame, connection)| async move {
426                match read_subscribed_frame(&mut recv, &mut buffer, max_frame).await {
427                    Ok(Some(frame)) => Some((Ok(frame), (recv, buffer, max_frame, connection))),
428                    Ok(None) => None,
429                    Err(err) => Some((Err(err), (recv, buffer, max_frame, connection))),
430                }
431            },
432        );
433        Ok(Subscriber::new(stream))
434    }
435
436    /// Create a publisher for this channel. The returned [`Sink`] streams raw payloads.
437    ///
438    /// # Examples
439    /// ```no_run
440    /// use futures::SinkExt;
441    ///
442    /// # async fn example(channel: &selium_remote_client::Channel) -> Result<(), selium_remote_client::ClientError> {
443    /// let mut publisher = channel.publish().await?;
444    /// publisher.send(b"hello".to_vec()).await?;
445    /// publisher.close().await?;
446    /// # Ok(())
447    /// # }
448    /// ```
449    pub async fn publish(&self) -> Result<Publisher> {
450        let mut session = self.client.open_session().await?;
451        let payload = encode_request(&Request::Publish(self.handle))
452            .map_err(|err| ClientError::Encode(err.to_string()))?;
453        session
454            .send
455            .write_all(&payload)
456            .await
457            .map_err(ClientError::Write)?;
458
459        let mut buffer = VecDeque::new();
460        read_response_once(
461            &mut session.recv,
462            self.client.inner.response_limit,
463            &mut buffer,
464        )
465        .await?;
466
467        Ok(Publisher {
468            _connection: session.connection,
469            state: PublishState::Ready(session.send),
470        })
471    }
472
473    /// Expose the raw handle.
474    pub fn handle(&self) -> GuestResourceId {
475        self.handle
476    }
477}
478
479impl ProcessBuilder {
480    /// Create a new builder targeting the supplied module and entrypoint.
481    pub fn new(module_id: impl Into<String>, entrypoint: impl Into<String>) -> Self {
482        Self {
483            module_id: module_id.into(),
484            entrypoint: entrypoint.into(),
485            log_uri: None,
486            capabilities: vec![Capability::ChannelLifecycle, Capability::ChannelWriter],
487            signature: AbiSignature::new(Vec::new(), Vec::new()),
488            args: Vec::new(),
489        }
490    }
491
492    /// Add a capability the process should receive.
493    pub fn capability(mut self, capability: Capability) -> Self {
494        if !self.capabilities.contains(&capability) {
495            self.capabilities.push(capability);
496        }
497        self
498    }
499
500    /// Specify the entrypoint ABI signature.
501    pub fn signature(mut self, signature: AbiSignature) -> Self {
502        self.signature = signature;
503        self
504    }
505
506    /// Provide a log URI for the process when Atlas logging is enabled.
507    pub fn log_uri(mut self, value: impl Into<String>) -> Self {
508        self.log_uri = Some(value.into());
509        self
510    }
511
512    /// Append a scalar argument.
513    pub fn arg_scalar(mut self, value: AbiScalarValue) -> Self {
514        self.args.push(EntrypointArg::Scalar(value));
515        self
516    }
517
518    /// Append a UTF-8 string argument.
519    pub fn arg_utf8(self, value: impl Into<String>) -> Self {
520        self.arg_buffer(value.into().into_bytes())
521    }
522
523    /// Append a raw buffer argument.
524    pub fn arg_buffer(mut self, value: impl Into<Vec<u8>>) -> Self {
525        self.args.push(EntrypointArg::Buffer(value.into()));
526        self
527    }
528
529    /// Append a process handle argument.
530    pub fn arg_resource(mut self, handle: impl Into<GuestResourceId>) -> Self {
531        self.args.push(EntrypointArg::Resource(handle.into()));
532        self
533    }
534
535    fn build_request(self) -> Result<ProcessStartRequest> {
536        validate_entrypoint_args(&self.signature, &self.args)?;
537
538        Ok(ProcessStartRequest {
539            module_id: self.module_id,
540            entrypoint: self.entrypoint,
541            log_uri: self.log_uri,
542            capabilities: self.capabilities,
543            signature: self.signature,
544            args: self.args,
545        })
546    }
547}
548
549fn validate_entrypoint_args(signature: &AbiSignature, args: &[EntrypointArg]) -> Result<()> {
550    if signature.params().len() != args.len() {
551        return Err(ClientError::InvalidArgument(
552            "arguments do not satisfy the signature",
553        ));
554    }
555
556    for (param, arg) in signature.params().iter().zip(args.iter()) {
557        match (param, arg) {
558            (AbiParam::Scalar(expected), EntrypointArg::Scalar(actual))
559                if actual.kind() == *expected => {}
560            (AbiParam::Scalar(AbiScalarType::I32), EntrypointArg::Resource(_)) => {}
561            (AbiParam::Scalar(AbiScalarType::U64), EntrypointArg::Resource(_)) => {}
562            (AbiParam::Buffer, EntrypointArg::Buffer(_)) => {}
563            _ => {
564                return Err(ClientError::InvalidArgument(
565                    "arguments do not satisfy the signature",
566                ));
567            }
568        }
569    }
570
571    Ok(())
572}
573
574impl Process {
575    /// Construct a handle to an existing process.
576    ///
577    /// # Safety
578    /// `handle` must be a valid process capability minted for the current client; forged or stale
579    /// handles will be rejected by the remote guest and may cause the connection to be closed.
580    pub fn new(client: &Client, handle: GuestResourceId) -> Self {
581        Self {
582            client: client.clone(),
583            handle,
584        }
585    }
586
587    /// Start a process using the provided builder.
588    pub async fn start(client: &Client, builder: ProcessBuilder) -> Result<Self> {
589        let request = builder.build_request()?;
590        let response = client.request(Request::ProcessStart(request)).await?;
591        match response {
592            Response::ProcessStart(handle) => Ok(Self::new(client, handle)),
593            Response::Error(msg) => Err(ClientError::Remote(msg)),
594            _ => Err(ClientError::UnexpectedResponse),
595        }
596    }
597
598    /// Access the raw process handle.
599    pub fn handle(&self) -> GuestResourceId {
600        self.handle
601    }
602
603    /// Fetch the shared logging channel handle for this process.
604    pub async fn log_channel(&self) -> Result<Channel> {
605        let response = self
606            .client
607            .request(Request::ProcessLogChannel(self.handle))
608            .await?;
609        match response {
610            Response::ProcessLogChannel(handle) => Ok(Channel::new(&self.client, handle)),
611            Response::Error(msg) => Err(ClientError::Remote(msg)),
612            _ => Err(ClientError::UnexpectedResponse),
613        }
614    }
615
616    /// Stop the referenced process.
617    pub async fn stop(self) -> Result<()> {
618        let response = self
619            .client
620            .request(Request::ProcessStop(self.handle))
621            .await?;
622        match response {
623            Response::Ok => Ok(()),
624            Response::Error(msg) => Err(ClientError::Remote(msg)),
625            _ => Err(ClientError::UnexpectedResponse),
626        }
627    }
628}
629
630impl Display for Process {
631    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632        write!(f, "Process({})", self.handle)
633    }
634}
635
636impl Subscriber {
637    fn new(stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'static) -> Self {
638        Self {
639            inner: Box::pin(stream),
640        }
641    }
642}
643
644impl Stream for Subscriber {
645    type Item = Result<Vec<u8>>;
646
647    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
648        let this = self.get_mut();
649        this.inner.as_mut().poll_next(cx)
650    }
651}
652
653impl Sink<Vec<u8>> for Publisher {
654    type Error = ClientError;
655
656    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
657        let this = self.get_mut();
658        match &mut this.state {
659            PublishState::Ready(_) => Poll::Ready(Ok(())),
660            PublishState::Writing(fut) => match fut.as_mut().poll(cx) {
661                Poll::Ready(Ok(stream)) => {
662                    this.state = PublishState::Ready(stream);
663                    Poll::Ready(Ok(()))
664                }
665                Poll::Ready(Err(err)) => {
666                    this.state = PublishState::Closed;
667                    Poll::Ready(Err(err))
668                }
669                Poll::Pending => Poll::Pending,
670            },
671            PublishState::Closed => Poll::Ready(Err(ClientError::UnexpectedResponse)),
672        }
673    }
674
675    fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<()> {
676        let this = self.get_mut();
677        match std::mem::replace(&mut this.state, PublishState::Closed) {
678            PublishState::Ready(stream) => {
679                this.state = PublishState::Writing(Box::pin(write_once(stream, item)));
680                Ok(())
681            }
682            other => {
683                this.state = other;
684                Err(ClientError::InvalidArgument("publisher not ready"))
685            }
686        }
687    }
688
689    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
690        self.poll_ready(cx)
691    }
692
693    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
694        let this = self.get_mut();
695        loop {
696            match std::mem::replace(&mut this.state, PublishState::Closed) {
697                PublishState::Ready(mut stream) => {
698                    stream
699                        .finish()
700                        .map_err(|_| ClientError::Finish("stream closed".to_string()))?;
701                    this.state = PublishState::Closed;
702                    return Poll::Ready(Ok(()));
703                }
704                PublishState::Writing(mut fut) => match fut.as_mut().poll(cx) {
705                    Poll::Ready(Ok(stream)) => {
706                        this.state = PublishState::Ready(stream);
707                    }
708                    Poll::Ready(Err(err)) => {
709                        this.state = PublishState::Closed;
710                        return Poll::Ready(Err(err));
711                    }
712                    Poll::Pending => {
713                        this.state = PublishState::Writing(fut);
714                        return Poll::Pending;
715                    }
716                },
717                PublishState::Closed => return Poll::Ready(Ok(())),
718            }
719        }
720    }
721}
722
723async fn read_subscribed_frame(
724    recv: &mut RecvStream,
725    buffer: &mut VecDeque<u8>,
726    max_frame: usize,
727) -> Result<Option<Vec<u8>>> {
728    loop {
729        if let Some(frame) = try_parse_frame(
730            buffer,
731            max_frame,
732            "frame length exceeds subscribed chunk size",
733        )? {
734            return Ok(Some(frame));
735        }
736
737        let mut chunk = vec![0u8; max_frame.max(4)];
738        match recv.read(&mut chunk).await {
739            Ok(Some(len)) if len > 0 => {
740                chunk.truncate(len);
741                buffer.extend(chunk);
742            }
743            Ok(Some(_)) => {}
744            Ok(None) => {
745                if buffer.is_empty() {
746                    return Ok(None);
747                }
748                return Err(ClientError::Decode(
749                    "stream ended with a partial frame".to_string(),
750                ));
751            }
752            Err(err) => return Err(ClientError::Read(err.to_string())),
753        }
754    }
755}
756
757fn try_parse_frame(
758    buffer: &mut VecDeque<u8>,
759    max_frame: usize,
760    limit_error: &'static str,
761) -> Result<Option<Vec<u8>>> {
762    const LENGTH_PREFIX: usize = 4;
763    if buffer.len() < LENGTH_PREFIX {
764        return Ok(None);
765    }
766
767    let mut len_bytes = [0u8; LENGTH_PREFIX];
768    for (dst, src) in len_bytes.iter_mut().zip(buffer.iter().take(LENGTH_PREFIX)) {
769        *dst = *src;
770    }
771    let frame_len = u32::from_le_bytes(len_bytes) as usize;
772    if frame_len > max_frame {
773        return Err(ClientError::InvalidArgument(limit_error));
774    }
775
776    if buffer.len() < LENGTH_PREFIX + frame_len {
777        return Ok(None);
778    }
779
780    buffer.drain(..LENGTH_PREFIX);
781    let payload = buffer.drain(..frame_len).collect();
782    Ok(Some(payload))
783}
784
785async fn read_response_frame(
786    recv: &mut RecvStream,
787    limit: usize,
788    buffer: &mut VecDeque<u8>,
789) -> Result<Vec<u8>> {
790    let max_buffer = limit.saturating_add(4);
791    loop {
792        if let Some(frame) = try_parse_frame(
793            buffer,
794            limit,
795            "response frame length exceeds response limit",
796        )? {
797            return Ok(frame);
798        }
799
800        let mut chunk = vec![0u8; max_buffer.saturating_sub(buffer.len()).max(4)];
801        match recv.read(&mut chunk).await {
802            Ok(Some(len)) if len > 0 => {
803                debug!("read {len} bytes from response stream");
804                chunk.truncate(len);
805                buffer.extend(chunk);
806                if buffer.len() > max_buffer {
807                    return Err(ClientError::Decode(
808                        "response exceeded maximum size".to_string(),
809                    ));
810                }
811            }
812            Ok(Some(_)) => {}
813            Ok(None) => {
814                return Err(ClientError::Decode(
815                    "response stream ended before frame".to_string(),
816                ));
817            }
818            Err(err) => return Err(ClientError::Read(err.to_string())),
819        }
820    }
821}
822
823async fn read_response(
824    recv: &mut RecvStream,
825    limit: usize,
826    buffer: &mut VecDeque<u8>,
827) -> Result<Response> {
828    let payload = read_response_frame(recv, limit, buffer).await?;
829    decode_response(&payload).map_err(|err| ClientError::Decode(err.to_string()))
830}
831
832async fn read_response_once(
833    recv: &mut RecvStream,
834    limit: usize,
835    buffer: &mut VecDeque<u8>,
836) -> Result<()> {
837    let response = read_response(recv, limit, buffer).await?;
838    match response {
839        Response::Ok => Ok(()),
840        Response::Error(msg) => Err(ClientError::Remote(msg)),
841        _ => Err(ClientError::UnexpectedResponse),
842    }
843}
844
845async fn resolve_socket(domain: &str, port: u16) -> Result<SocketAddr> {
846    let target = format!("{domain}:{port}");
847    let addrs: Vec<SocketAddr> = lookup_host(&target)
848        .await
849        .map_err(|source| ClientError::Resolve {
850            target: target.clone(),
851            source,
852        })?
853        .collect();
854
855    addrs
856        .iter()
857        .copied()
858        .find(SocketAddr::is_ipv4)
859        .or_else(|| addrs.into_iter().next())
860        .ok_or_else(|| ClientError::Resolve {
861            target,
862            source: std::io::Error::new(
863                ErrorKind::AddrNotAvailable,
864                "no socket addresses returned",
865            ),
866        })
867}
868
869fn build_client_config(cert_dir: &Path) -> Result<ClientConfig> {
870    let provider = rustls::crypto::ring::default_provider();
871    let tls_builder = rustls::ClientConfig::builder_with_provider(provider.into())
872        .with_protocol_versions(&[&rustls::version::TLS13])
873        .map_err(ClientError::Tls)?;
874    let roots = build_root_store(cert_dir)?;
875    let client_cert = read_certificate(cert_dir, "client.crt")?;
876    let client_key = read_certificate(cert_dir, "client.key")?;
877    let cert_chain = parse_certificates(&client_cert)?;
878    let key = parse_private_key(&client_key)?;
879    let tls_config = tls_builder
880        .with_root_certificates(roots)
881        .with_client_auth_cert(cert_chain, key)
882        .map_err(ClientError::Tls)?;
883
884    let crypto = QuicClientConfig::try_from(Arc::new(tls_config))
885        .map_err(|_| ClientError::Certificate("failed to select QUIC cipher suite".to_string()))?;
886    let mut client_config = ClientConfig::new(Arc::new(crypto));
887    let mut transport = TransportConfig::default();
888    transport.keep_alive_interval(Some(Duration::from_secs(10)));
889    client_config.transport_config(Arc::new(transport));
890    Ok(client_config)
891}
892
893fn build_root_store(cert_dir: &Path) -> Result<RootCertStore> {
894    let mut store = RootCertStore::empty();
895    let ca_cert = read_certificate(cert_dir, "ca.crt")?;
896    for cert in parse_certificates(&ca_cert)? {
897        store
898            .add(cert)
899            .map_err(|_| ClientError::Certificate("add CA certificate".to_string()))?;
900    }
901    Ok(store)
902}
903
904fn read_certificate(cert_dir: &Path, name: &str) -> Result<Vec<u8>> {
905    let path = cert_dir.join(name);
906    fs::read(&path)
907        .map_err(|err| ClientError::Certificate(format!("read {}: {}", path.display(), err)))
908}
909
910fn default_cert_dir() -> PathBuf {
911    PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../../certs")
912}
913
914fn parse_certificates(bytes: &[u8]) -> Result<Vec<CertificateDer<'static>>> {
915    let parsed: std::result::Result<Vec<_>, _> = SliceIter::new(bytes).collect();
916    let parsed = parsed.map_err(|err| ClientError::Certificate(err.to_string()))?;
917    if !parsed.is_empty() {
918        return Ok(parsed);
919    }
920
921    Ok(vec![CertificateDer::from(bytes.to_vec())])
922}
923
924fn parse_private_key(bytes: &[u8]) -> Result<PrivateKeyDer<'static>> {
925    let pkcs8_result: std::result::Result<Vec<PrivatePkcs8KeyDer>, _> =
926        SliceIter::new(bytes).collect();
927    let pkcs8 = pkcs8_result.map_err(|err| ClientError::Certificate(err.to_string()))?;
928    if let Some(key) = pkcs8.into_iter().next() {
929        return Ok(key.into());
930    }
931
932    let rsa_result: std::result::Result<Vec<PrivatePkcs1KeyDer>, _> =
933        SliceIter::new(bytes).collect();
934    let rsa = rsa_result.map_err(|err| ClientError::Certificate(err.to_string()))?;
935    if let Some(key) = rsa.into_iter().next() {
936        return Ok(key.into());
937    }
938
939    PrivateKeyDer::try_from(bytes.to_vec()).map_err(|_| {
940        ClientError::Certificate("no usable private key found in client key material".to_string())
941    })
942}
943
944async fn write_once(mut stream: SendStream, payload: Vec<u8>) -> Result<SendStream> {
945    if !payload.is_empty() {
946        stream
947            .write_all(&payload)
948            .await
949            .map_err(ClientError::Write)?;
950    }
951    Ok(stream)
952}
953
#[cfg(test)]
mod tests {
    use super::*;

    /// Error text passed to `try_parse_frame` by every framing test below.
    const OVERSIZE_MSG: &str = "frame length exceeds subscribed chunk size";

    /// Supplying an argument to a builder whose signature is empty must be rejected.
    #[test]
    fn process_builder_validates_signature() {
        let empty_signature = AbiSignature::new(Vec::new(), Vec::new());
        let builder = ProcessBuilder::new("module", "entry")
            .signature(empty_signature)
            .arg_resource(7u64);

        let outcome = builder.build_request();
        assert!(matches!(outcome, Err(ClientError::InvalidArgument(_))));
    }

    /// An encoded `ChannelCreate` request decodes back to the same variant and capacity.
    #[test]
    fn request_round_trips() {
        let original = Request::ChannelCreate(64 * 1024);
        let wire = encode_request(&original).expect("encode");

        match selium_remote_client_protocol::decode_request(&wire).expect("decode") {
            Request::ChannelCreate(capacity) => assert_eq!(capacity, 64 * 1024),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    /// Three buffered bytes yield no frame and leave the buffer untouched.
    #[test]
    fn try_parse_frame_requires_complete_prefix() {
        let mut buffer: VecDeque<u8> = [0u8, 1, 2].into_iter().collect();

        let parsed = try_parse_frame(&mut buffer, 8, OVERSIZE_MSG).expect("parse result");

        assert!(parsed.is_none());
        assert_eq!(buffer.len(), 3);
    }

    /// A complete frame is returned and its bytes are drained from the buffer.
    #[test]
    fn try_parse_frame_extracts_payload() {
        // First four bytes: length prefix = 3; remaining three bytes: the payload.
        let mut buffer = VecDeque::from([3u8, 0, 0, 0, 1, 2, 3]);

        let payload = try_parse_frame(&mut buffer, 8, OVERSIZE_MSG)
            .expect("parse result")
            .expect("frame available");

        assert_eq!(payload, vec![1u8, 2, 3]);
        assert!(buffer.is_empty());
    }

    /// A prefix of 9 with a limit argument of 4 is rejected as an invalid argument.
    #[test]
    fn try_parse_frame_rejects_oversized_frame() {
        let mut buffer = VecDeque::from([9u8, 0, 0, 0, 1, 2, 3]);

        let outcome = try_parse_frame(&mut buffer, 4, OVERSIZE_MSG);

        assert!(matches!(outcome, Err(ClientError::InvalidArgument(_))));
    }
}