#![deny(missing_docs)]
//! QUIC client for the Selium remote-client protocol.
//!
//! A [`Client`] is obtained from [`Client::connect`] or from a [`ClientConfigBuilder`]. With a
//! client in hand, [`Channel`] creates, deletes, publishes to and subscribes to remote channels,
//! while [`ProcessBuilder`] and [`Process`] start and stop remote processes.
35pub use selium_remote_client_protocol::{
37 AbiParam, AbiScalarType, AbiScalarValue, AbiSignature, Capability, EntrypointArg,
38 GuestResourceId, GuestUint,
39};
40use std::{
41 collections::VecDeque,
42 fmt::{Debug, Display},
43 fs,
44 io::ErrorKind,
45 net::SocketAddr,
46 path::{Path, PathBuf},
47 pin::Pin,
48 sync::Arc,
49 task::{Context, Poll},
50 time::Duration,
51};
52
53use futures::{Sink, Stream};
54use quinn::{
55 ClientConfig, ConnectError, Connection, ConnectionError, Endpoint, RecvStream, SendStream,
56 TransportConfig, WriteError, crypto::rustls::QuicClientConfig, rustls,
57};
58use rustls::{
59 RootCertStore,
60 pki_types::{CertificateDer, PrivateKeyDer},
61};
62use rustls_pki_types::{PrivatePkcs1KeyDer, PrivatePkcs8KeyDer, pem::SliceIter};
63use selium_remote_client_protocol::{
64 ChannelRef, ProcessStartRequest, Request, Response, decode_response, encode_request,
65};
66use thiserror::Error;
67use tokio::net::lookup_host;
68use tracing::debug;
69
70type Result<T> = std::result::Result<T, ClientError>;
71
72pub const DEFAULT_DOMAIN: &str = "localhost";
74pub const DEFAULT_PORT: u16 = 7000;
76pub const DEFAULT_RESPONSE_LIMIT: usize = 8 * 1024;
78
79#[derive(Debug, Error)]
81pub enum ClientError {
82 #[error("failed to parse certificate: {0}")]
84 Certificate(String),
85 #[error("failed to build TLS config: {0}")]
87 Tls(#[source] rustls::Error),
88 #[error("failed to resolve {target}: {source}")]
90 Resolve {
91 target: String,
93 #[source]
95 source: std::io::Error,
96 },
97 #[error("failed to open client endpoint: {0}")]
99 Endpoint(#[source] std::io::Error),
100 #[error("failed to connect: {0}")]
102 Connect(#[source] ConnectError),
103 #[error("connection failed: {0}")]
105 Connection(#[source] ConnectionError),
106 #[error("stream write failed: {0}")]
108 Write(#[source] WriteError),
109 #[error("stream finish failed: {0}")]
111 Finish(String),
112 #[error("stream read failed: {0}")]
114 Read(String),
115 #[error("encode request: {0}")]
117 Encode(String),
118 #[error("decode response: {0}")]
120 Decode(String),
121 #[error("invalid request: {0}")]
123 InvalidArgument(&'static str),
124 #[error("remote error: {0}")]
126 Remote(String),
127 #[error("unexpected response from remote client")]
129 UnexpectedResponse,
130}
131
/// Builder that collects the settings used to create a [`Client`].
#[derive(Clone, Debug)]
pub struct ClientConfigBuilder {
    /// Host name to resolve; also used as the TLS server name.
    domain: String,
    /// Server port.
    port: u16,
    /// Maximum accepted size, in bytes, of a response frame payload.
    response_limit: usize,
    /// Directory holding `ca.crt`, `client.crt` and `client.key`.
    cert_dir: PathBuf,
}
140
141#[derive(Clone, Debug)]
143pub struct Client {
144 inner: Arc<ClientInner>,
145}
146
147#[derive(Debug)]
148struct ClientInner {
149 endpoint: Endpoint,
150 server_addr: SocketAddr,
151 server_name: String,
152 response_limit: usize,
153}
154
155#[derive(Clone, Debug)]
157pub struct Channel {
158 client: Client,
159 handle: GuestResourceId,
160}
161
162#[derive(Clone, Debug)]
164pub struct Process {
165 client: Client,
166 handle: GuestResourceId,
167}
168
169#[derive(Clone, Debug, PartialEq)]
171pub struct ProcessBuilder {
172 module_id: String,
173 entrypoint: String,
174 log_uri: Option<String>,
175 capabilities: Vec<Capability>,
176 signature: AbiSignature,
177 args: Vec<EntrypointArg>,
178}
179
180struct QuicSession {
181 connection: Connection,
182 send: SendStream,
183 recv: RecvStream,
184}
185
186enum PublishState {
187 Ready(SendStream),
188 Writing(Pin<Box<dyn std::future::Future<Output = Result<SendStream>> + Send>>),
189 Closed,
190}
191
192pub struct Publisher {
194 _connection: Connection,
195 state: PublishState,
196}
197
198pub struct Subscriber {
200 inner: Pin<Box<dyn Stream<Item = Result<Vec<u8>>> + Send>>,
201}
202
203impl Default for ClientConfigBuilder {
204 fn default() -> Self {
206 Self {
207 domain: DEFAULT_DOMAIN.to_string(),
208 port: DEFAULT_PORT,
209 response_limit: DEFAULT_RESPONSE_LIMIT,
210 cert_dir: default_cert_dir(),
211 }
212 }
213}
214
215impl ClientConfigBuilder {
216 pub fn domain(mut self, domain: impl Into<String>) -> Self {
218 self.domain = domain.into();
219 self
220 }
221
222 pub fn port(mut self, port: u16) -> Self {
224 self.port = port;
225 self
226 }
227
228 pub fn response_limit(mut self, limit: usize) -> Self {
230 self.response_limit = limit.max(1);
231 self
232 }
233
234 pub fn certificate_directory(mut self, dir: impl Into<PathBuf>) -> Self {
236 self.cert_dir = dir.into();
237 self
238 }
239
240 pub async fn connect(self) -> Result<Client> {
254 Client::connect_with(self).await
255 }
256}
257
258impl Client {
259 pub async fn connect() -> Result<Self> {
261 Client::connect_with(ClientConfigBuilder::default()).await
262 }
263
264 async fn connect_with(config: ClientConfigBuilder) -> Result<Self> {
265 let server_addr = resolve_socket(&config.domain, config.port).await?;
266 let bind_addr = match server_addr {
267 SocketAddr::V4(_) => SocketAddr::from(([0, 0, 0, 0], 0)),
268 SocketAddr::V6(_) => SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], 0)),
269 };
270
271 let mut endpoint = Endpoint::client(bind_addr).map_err(ClientError::Endpoint)?;
272 endpoint.set_default_client_config(build_client_config(&config.cert_dir)?);
273
274 Ok(Self {
275 inner: Arc::new(ClientInner {
276 endpoint,
277 server_addr,
278 server_name: config.domain,
279 response_limit: config.response_limit,
280 }),
281 })
282 }
283
284 async fn open_session(&self) -> Result<QuicSession> {
285 let connecting = self
286 .inner
287 .endpoint
288 .connect(self.inner.server_addr, &self.inner.server_name)
289 .map_err(ClientError::Connect)?;
290 let connection = connecting.await.map_err(ClientError::Connection)?;
291 let (send, recv) = connection
292 .open_bi()
293 .await
294 .map_err(ClientError::Connection)?;
295 Ok(QuicSession {
296 connection,
297 send,
298 recv,
299 })
300 }
301
302 async fn request(&self, request: Request) -> Result<Response> {
303 let mut session = self.open_session().await?;
304 let payload =
305 encode_request(&request).map_err(|err| ClientError::Encode(err.to_string()))?;
306 debug!("sending request of {} bytes", payload.len());
307 session
308 .send
309 .write_all(&payload)
310 .await
311 .map_err(ClientError::Write)?;
312 session
313 .send
314 .finish()
315 .map_err(|_| ClientError::Finish("stream closed".to_string()))?;
316 debug!("request sent, awaiting response");
317 let mut buffer = VecDeque::new();
318 let response =
319 read_response(&mut session.recv, self.inner.response_limit, &mut buffer).await?;
320 match response {
321 Response::Error(message) => Err(ClientError::Remote(message)),
322 other => Ok(other),
323 }
324 }
325}
326
327impl Channel {
328 pub async fn create(client: &Client, capacity: u32) -> Result<Self> {
330 let response = client.request(Request::ChannelCreate(capacity)).await?;
331 match response {
332 Response::ChannelCreate(handle) => Ok(Self::new(client, handle)),
333 Response::Error(msg) => Err(ClientError::Remote(msg)),
334 _ => Err(ClientError::UnexpectedResponse),
335 }
336 }
337
338 pub fn new(client: &Client, handle: GuestResourceId) -> Self {
344 Self {
345 client: client.clone(),
346 handle,
347 }
348 }
349
350 pub async fn delete(self) -> Result<()> {
352 let response = self
353 .client
354 .request(Request::ChannelDelete(self.handle))
355 .await?;
356 match response {
357 Response::Ok => Ok(()),
358 Response::Error(msg) => Err(ClientError::Remote(msg)),
359 _ => Err(ClientError::UnexpectedResponse),
360 }
361 }
362
363 pub async fn subscribe(&self, chunk_size: u32) -> Result<Subscriber> {
380 self.subscribe_inner(ChannelRef::Strong(self.handle), chunk_size)
381 .await
382 }
383
384 pub async fn subscribe_shared(&self, chunk_size: u32) -> Result<Subscriber> {
401 self.subscribe_inner(ChannelRef::Shared(self.handle), chunk_size)
402 .await
403 }
404
405 async fn subscribe_inner(&self, target: ChannelRef, chunk_size: u32) -> Result<Subscriber> {
406 let mut session = self.client.open_session().await?;
407 let chunk_size = GuestUint::try_from(chunk_size)
408 .map_err(|_| ClientError::InvalidArgument("chunk size exceeds u32::MAX"))?;
409 let payload = encode_request(&Request::Subscribe(target, chunk_size))
410 .map_err(|err| ClientError::Encode(err.to_string()))?;
411 session
412 .send
413 .write_all(&payload)
414 .await
415 .map_err(ClientError::Write)?;
416 session
417 .send
418 .finish()
419 .map_err(|_| ClientError::Finish("stream closed".to_string()))?;
420 let max_frame = usize::try_from(chunk_size)
421 .map_err(|_| ClientError::InvalidArgument("chunk size exceeds usize::MAX"))?;
422 let connection = session.connection.clone();
423 let stream = futures::stream::unfold(
424 (session.recv, VecDeque::new(), max_frame, connection),
425 move |(mut recv, mut buffer, max_frame, connection)| async move {
426 match read_subscribed_frame(&mut recv, &mut buffer, max_frame).await {
427 Ok(Some(frame)) => Some((Ok(frame), (recv, buffer, max_frame, connection))),
428 Ok(None) => None,
429 Err(err) => Some((Err(err), (recv, buffer, max_frame, connection))),
430 }
431 },
432 );
433 Ok(Subscriber::new(stream))
434 }
435
436 pub async fn publish(&self) -> Result<Publisher> {
450 let mut session = self.client.open_session().await?;
451 let payload = encode_request(&Request::Publish(self.handle))
452 .map_err(|err| ClientError::Encode(err.to_string()))?;
453 session
454 .send
455 .write_all(&payload)
456 .await
457 .map_err(ClientError::Write)?;
458
459 let mut buffer = VecDeque::new();
460 read_response_once(
461 &mut session.recv,
462 self.client.inner.response_limit,
463 &mut buffer,
464 )
465 .await?;
466
467 Ok(Publisher {
468 _connection: session.connection,
469 state: PublishState::Ready(session.send),
470 })
471 }
472
473 pub fn handle(&self) -> GuestResourceId {
475 self.handle
476 }
477}
478
479impl ProcessBuilder {
480 pub fn new(module_id: impl Into<String>, entrypoint: impl Into<String>) -> Self {
482 Self {
483 module_id: module_id.into(),
484 entrypoint: entrypoint.into(),
485 log_uri: None,
486 capabilities: vec![Capability::ChannelLifecycle, Capability::ChannelWriter],
487 signature: AbiSignature::new(Vec::new(), Vec::new()),
488 args: Vec::new(),
489 }
490 }
491
492 pub fn capability(mut self, capability: Capability) -> Self {
494 if !self.capabilities.contains(&capability) {
495 self.capabilities.push(capability);
496 }
497 self
498 }
499
500 pub fn signature(mut self, signature: AbiSignature) -> Self {
502 self.signature = signature;
503 self
504 }
505
506 pub fn log_uri(mut self, value: impl Into<String>) -> Self {
508 self.log_uri = Some(value.into());
509 self
510 }
511
512 pub fn arg_scalar(mut self, value: AbiScalarValue) -> Self {
514 self.args.push(EntrypointArg::Scalar(value));
515 self
516 }
517
518 pub fn arg_utf8(self, value: impl Into<String>) -> Self {
520 self.arg_buffer(value.into().into_bytes())
521 }
522
523 pub fn arg_buffer(mut self, value: impl Into<Vec<u8>>) -> Self {
525 self.args.push(EntrypointArg::Buffer(value.into()));
526 self
527 }
528
529 pub fn arg_resource(mut self, handle: impl Into<GuestResourceId>) -> Self {
531 self.args.push(EntrypointArg::Resource(handle.into()));
532 self
533 }
534
535 fn build_request(self) -> Result<ProcessStartRequest> {
536 validate_entrypoint_args(&self.signature, &self.args)?;
537
538 Ok(ProcessStartRequest {
539 module_id: self.module_id,
540 entrypoint: self.entrypoint,
541 log_uri: self.log_uri,
542 capabilities: self.capabilities,
543 signature: self.signature,
544 args: self.args,
545 })
546 }
547}
548
549fn validate_entrypoint_args(signature: &AbiSignature, args: &[EntrypointArg]) -> Result<()> {
550 if signature.params().len() != args.len() {
551 return Err(ClientError::InvalidArgument(
552 "arguments do not satisfy the signature",
553 ));
554 }
555
556 for (param, arg) in signature.params().iter().zip(args.iter()) {
557 match (param, arg) {
558 (AbiParam::Scalar(expected), EntrypointArg::Scalar(actual))
559 if actual.kind() == *expected => {}
560 (AbiParam::Scalar(AbiScalarType::I32), EntrypointArg::Resource(_)) => {}
561 (AbiParam::Scalar(AbiScalarType::U64), EntrypointArg::Resource(_)) => {}
562 (AbiParam::Buffer, EntrypointArg::Buffer(_)) => {}
563 _ => {
564 return Err(ClientError::InvalidArgument(
565 "arguments do not satisfy the signature",
566 ));
567 }
568 }
569 }
570
571 Ok(())
572}
573
574impl Process {
575 pub fn new(client: &Client, handle: GuestResourceId) -> Self {
581 Self {
582 client: client.clone(),
583 handle,
584 }
585 }
586
587 pub async fn start(client: &Client, builder: ProcessBuilder) -> Result<Self> {
589 let request = builder.build_request()?;
590 let response = client.request(Request::ProcessStart(request)).await?;
591 match response {
592 Response::ProcessStart(handle) => Ok(Self::new(client, handle)),
593 Response::Error(msg) => Err(ClientError::Remote(msg)),
594 _ => Err(ClientError::UnexpectedResponse),
595 }
596 }
597
598 pub fn handle(&self) -> GuestResourceId {
600 self.handle
601 }
602
603 pub async fn log_channel(&self) -> Result<Channel> {
605 let response = self
606 .client
607 .request(Request::ProcessLogChannel(self.handle))
608 .await?;
609 match response {
610 Response::ProcessLogChannel(handle) => Ok(Channel::new(&self.client, handle)),
611 Response::Error(msg) => Err(ClientError::Remote(msg)),
612 _ => Err(ClientError::UnexpectedResponse),
613 }
614 }
615
616 pub async fn stop(self) -> Result<()> {
618 let response = self
619 .client
620 .request(Request::ProcessStop(self.handle))
621 .await?;
622 match response {
623 Response::Ok => Ok(()),
624 Response::Error(msg) => Err(ClientError::Remote(msg)),
625 _ => Err(ClientError::UnexpectedResponse),
626 }
627 }
628}
629
630impl Display for Process {
631 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632 write!(f, "Process({})", self.handle)
633 }
634}
635
636impl Subscriber {
637 fn new(stream: impl Stream<Item = Result<Vec<u8>>> + Send + 'static) -> Self {
638 Self {
639 inner: Box::pin(stream),
640 }
641 }
642}
643
644impl Stream for Subscriber {
645 type Item = Result<Vec<u8>>;
646
647 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
648 let this = self.get_mut();
649 this.inner.as_mut().poll_next(cx)
650 }
651}
652
653impl Sink<Vec<u8>> for Publisher {
654 type Error = ClientError;
655
656 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
657 let this = self.get_mut();
658 match &mut this.state {
659 PublishState::Ready(_) => Poll::Ready(Ok(())),
660 PublishState::Writing(fut) => match fut.as_mut().poll(cx) {
661 Poll::Ready(Ok(stream)) => {
662 this.state = PublishState::Ready(stream);
663 Poll::Ready(Ok(()))
664 }
665 Poll::Ready(Err(err)) => {
666 this.state = PublishState::Closed;
667 Poll::Ready(Err(err))
668 }
669 Poll::Pending => Poll::Pending,
670 },
671 PublishState::Closed => Poll::Ready(Err(ClientError::UnexpectedResponse)),
672 }
673 }
674
675 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<()> {
676 let this = self.get_mut();
677 match std::mem::replace(&mut this.state, PublishState::Closed) {
678 PublishState::Ready(stream) => {
679 this.state = PublishState::Writing(Box::pin(write_once(stream, item)));
680 Ok(())
681 }
682 other => {
683 this.state = other;
684 Err(ClientError::InvalidArgument("publisher not ready"))
685 }
686 }
687 }
688
689 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
690 self.poll_ready(cx)
691 }
692
693 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
694 let this = self.get_mut();
695 loop {
696 match std::mem::replace(&mut this.state, PublishState::Closed) {
697 PublishState::Ready(mut stream) => {
698 stream
699 .finish()
700 .map_err(|_| ClientError::Finish("stream closed".to_string()))?;
701 this.state = PublishState::Closed;
702 return Poll::Ready(Ok(()));
703 }
704 PublishState::Writing(mut fut) => match fut.as_mut().poll(cx) {
705 Poll::Ready(Ok(stream)) => {
706 this.state = PublishState::Ready(stream);
707 }
708 Poll::Ready(Err(err)) => {
709 this.state = PublishState::Closed;
710 return Poll::Ready(Err(err));
711 }
712 Poll::Pending => {
713 this.state = PublishState::Writing(fut);
714 return Poll::Pending;
715 }
716 },
717 PublishState::Closed => return Poll::Ready(Ok(())),
718 }
719 }
720 }
721}
722
723async fn read_subscribed_frame(
724 recv: &mut RecvStream,
725 buffer: &mut VecDeque<u8>,
726 max_frame: usize,
727) -> Result<Option<Vec<u8>>> {
728 loop {
729 if let Some(frame) = try_parse_frame(
730 buffer,
731 max_frame,
732 "frame length exceeds subscribed chunk size",
733 )? {
734 return Ok(Some(frame));
735 }
736
737 let mut chunk = vec![0u8; max_frame.max(4)];
738 match recv.read(&mut chunk).await {
739 Ok(Some(len)) if len > 0 => {
740 chunk.truncate(len);
741 buffer.extend(chunk);
742 }
743 Ok(Some(_)) => {}
744 Ok(None) => {
745 if buffer.is_empty() {
746 return Ok(None);
747 }
748 return Err(ClientError::Decode(
749 "stream ended with a partial frame".to_string(),
750 ));
751 }
752 Err(err) => return Err(ClientError::Read(err.to_string())),
753 }
754 }
755}
756
757fn try_parse_frame(
758 buffer: &mut VecDeque<u8>,
759 max_frame: usize,
760 limit_error: &'static str,
761) -> Result<Option<Vec<u8>>> {
762 const LENGTH_PREFIX: usize = 4;
763 if buffer.len() < LENGTH_PREFIX {
764 return Ok(None);
765 }
766
767 let mut len_bytes = [0u8; LENGTH_PREFIX];
768 for (dst, src) in len_bytes.iter_mut().zip(buffer.iter().take(LENGTH_PREFIX)) {
769 *dst = *src;
770 }
771 let frame_len = u32::from_le_bytes(len_bytes) as usize;
772 if frame_len > max_frame {
773 return Err(ClientError::InvalidArgument(limit_error));
774 }
775
776 if buffer.len() < LENGTH_PREFIX + frame_len {
777 return Ok(None);
778 }
779
780 buffer.drain(..LENGTH_PREFIX);
781 let payload = buffer.drain(..frame_len).collect();
782 Ok(Some(payload))
783}
784
785async fn read_response_frame(
786 recv: &mut RecvStream,
787 limit: usize,
788 buffer: &mut VecDeque<u8>,
789) -> Result<Vec<u8>> {
790 let max_buffer = limit.saturating_add(4);
791 loop {
792 if let Some(frame) = try_parse_frame(
793 buffer,
794 limit,
795 "response frame length exceeds response limit",
796 )? {
797 return Ok(frame);
798 }
799
800 let mut chunk = vec![0u8; max_buffer.saturating_sub(buffer.len()).max(4)];
801 match recv.read(&mut chunk).await {
802 Ok(Some(len)) if len > 0 => {
803 debug!("read {len} bytes from response stream");
804 chunk.truncate(len);
805 buffer.extend(chunk);
806 if buffer.len() > max_buffer {
807 return Err(ClientError::Decode(
808 "response exceeded maximum size".to_string(),
809 ));
810 }
811 }
812 Ok(Some(_)) => {}
813 Ok(None) => {
814 return Err(ClientError::Decode(
815 "response stream ended before frame".to_string(),
816 ));
817 }
818 Err(err) => return Err(ClientError::Read(err.to_string())),
819 }
820 }
821}
822
823async fn read_response(
824 recv: &mut RecvStream,
825 limit: usize,
826 buffer: &mut VecDeque<u8>,
827) -> Result<Response> {
828 let payload = read_response_frame(recv, limit, buffer).await?;
829 decode_response(&payload).map_err(|err| ClientError::Decode(err.to_string()))
830}
831
832async fn read_response_once(
833 recv: &mut RecvStream,
834 limit: usize,
835 buffer: &mut VecDeque<u8>,
836) -> Result<()> {
837 let response = read_response(recv, limit, buffer).await?;
838 match response {
839 Response::Ok => Ok(()),
840 Response::Error(msg) => Err(ClientError::Remote(msg)),
841 _ => Err(ClientError::UnexpectedResponse),
842 }
843}
844
845async fn resolve_socket(domain: &str, port: u16) -> Result<SocketAddr> {
846 let target = format!("{domain}:{port}");
847 let addrs: Vec<SocketAddr> = lookup_host(&target)
848 .await
849 .map_err(|source| ClientError::Resolve {
850 target: target.clone(),
851 source,
852 })?
853 .collect();
854
855 addrs
856 .iter()
857 .copied()
858 .find(SocketAddr::is_ipv4)
859 .or_else(|| addrs.into_iter().next())
860 .ok_or_else(|| ClientError::Resolve {
861 target,
862 source: std::io::Error::new(
863 ErrorKind::AddrNotAvailable,
864 "no socket addresses returned",
865 ),
866 })
867}
868
869fn build_client_config(cert_dir: &Path) -> Result<ClientConfig> {
870 let provider = rustls::crypto::ring::default_provider();
871 let tls_builder = rustls::ClientConfig::builder_with_provider(provider.into())
872 .with_protocol_versions(&[&rustls::version::TLS13])
873 .map_err(ClientError::Tls)?;
874 let roots = build_root_store(cert_dir)?;
875 let client_cert = read_certificate(cert_dir, "client.crt")?;
876 let client_key = read_certificate(cert_dir, "client.key")?;
877 let cert_chain = parse_certificates(&client_cert)?;
878 let key = parse_private_key(&client_key)?;
879 let tls_config = tls_builder
880 .with_root_certificates(roots)
881 .with_client_auth_cert(cert_chain, key)
882 .map_err(ClientError::Tls)?;
883
884 let crypto = QuicClientConfig::try_from(Arc::new(tls_config))
885 .map_err(|_| ClientError::Certificate("failed to select QUIC cipher suite".to_string()))?;
886 let mut client_config = ClientConfig::new(Arc::new(crypto));
887 let mut transport = TransportConfig::default();
888 transport.keep_alive_interval(Some(Duration::from_secs(10)));
889 client_config.transport_config(Arc::new(transport));
890 Ok(client_config)
891}
892
893fn build_root_store(cert_dir: &Path) -> Result<RootCertStore> {
894 let mut store = RootCertStore::empty();
895 let ca_cert = read_certificate(cert_dir, "ca.crt")?;
896 for cert in parse_certificates(&ca_cert)? {
897 store
898 .add(cert)
899 .map_err(|_| ClientError::Certificate("add CA certificate".to_string()))?;
900 }
901 Ok(store)
902}
903
904fn read_certificate(cert_dir: &Path, name: &str) -> Result<Vec<u8>> {
905 let path = cert_dir.join(name);
906 fs::read(&path)
907 .map_err(|err| ClientError::Certificate(format!("read {}: {}", path.display(), err)))
908}
909
910fn default_cert_dir() -> PathBuf {
911 PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../../certs")
912}
913
914fn parse_certificates(bytes: &[u8]) -> Result<Vec<CertificateDer<'static>>> {
915 let parsed: std::result::Result<Vec<_>, _> = SliceIter::new(bytes).collect();
916 let parsed = parsed.map_err(|err| ClientError::Certificate(err.to_string()))?;
917 if !parsed.is_empty() {
918 return Ok(parsed);
919 }
920
921 Ok(vec![CertificateDer::from(bytes.to_vec())])
922}
923
924fn parse_private_key(bytes: &[u8]) -> Result<PrivateKeyDer<'static>> {
925 let pkcs8_result: std::result::Result<Vec<PrivatePkcs8KeyDer>, _> =
926 SliceIter::new(bytes).collect();
927 let pkcs8 = pkcs8_result.map_err(|err| ClientError::Certificate(err.to_string()))?;
928 if let Some(key) = pkcs8.into_iter().next() {
929 return Ok(key.into());
930 }
931
932 let rsa_result: std::result::Result<Vec<PrivatePkcs1KeyDer>, _> =
933 SliceIter::new(bytes).collect();
934 let rsa = rsa_result.map_err(|err| ClientError::Certificate(err.to_string()))?;
935 if let Some(key) = rsa.into_iter().next() {
936 return Ok(key.into());
937 }
938
939 PrivateKeyDer::try_from(bytes.to_vec()).map_err(|_| {
940 ClientError::Certificate("no usable private key found in client key material".to_string())
941 })
942}
943
944async fn write_once(mut stream: SendStream, payload: Vec<u8>) -> Result<SendStream> {
945 if !payload.is_empty() {
946 stream
947 .write_all(&payload)
948 .await
949 .map_err(ClientError::Write)?;
950 }
951 Ok(stream)
952}
953
#[cfg(test)]
mod tests {
    use super::*;

    /// Limit message passed to `try_parse_frame` by the frame-parsing tests.
    const LIMIT_MSG: &str = "frame length exceeds subscribed chunk size";

    /// An argument supplied against an empty signature must be rejected.
    #[test]
    fn process_builder_validates_signature() {
        let empty_signature = AbiSignature::new(Vec::new(), Vec::new());
        let outcome = ProcessBuilder::new("module", "entry")
            .signature(empty_signature)
            .arg_resource(7u64)
            .build_request();

        assert!(matches!(outcome, Err(ClientError::InvalidArgument(_))));
    }

    /// Encoding then decoding a request yields the same variant and value.
    #[test]
    fn request_round_trips() {
        let encoded = encode_request(&Request::ChannelCreate(64 * 1024)).expect("encode");
        let decoded = selium_remote_client_protocol::decode_request(&encoded).expect("decode");
        match decoded {
            Request::ChannelCreate(capacity) => assert_eq!(capacity, 64 * 1024),
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    /// Fewer than four buffered bytes: no frame, and nothing is consumed.
    #[test]
    fn try_parse_frame_requires_complete_prefix() {
        let mut buffer = VecDeque::from([0u8, 1u8, 2u8]);
        let parsed = try_parse_frame(&mut buffer, 8, LIMIT_MSG).expect("parse result");
        assert!(parsed.is_none());
        assert_eq!(buffer.len(), 3);
    }

    /// A complete frame is returned and removed from the buffer.
    #[test]
    fn try_parse_frame_extracts_payload() {
        let mut buffer = VecDeque::new();
        // Little-endian length prefix announcing three payload bytes.
        buffer.extend([3u8, 0, 0, 0]);
        buffer.extend([1u8, 2, 3]);

        let payload = try_parse_frame(&mut buffer, 8, LIMIT_MSG)
            .expect("parse result")
            .expect("frame available");

        assert_eq!(payload, vec![1u8, 2, 3]);
        assert!(buffer.is_empty());
    }

    /// A length prefix above the limit is reported as an invalid argument.
    #[test]
    fn try_parse_frame_rejects_oversized_frame() {
        let mut buffer = VecDeque::from([9u8, 0, 0, 0, 1, 2, 3]);
        let outcome = try_parse_frame(&mut buffer, 4, LIMIT_MSG);
        assert!(matches!(outcome, Err(ClientError::InvalidArgument(_))));
    }
}