Skip to main content

aranya_daemon_api/crypto/
txp.rs

//! Encrypted tarpc [`Transport`]s.
//!
//! [`Transport`]: tarpc::Transport

5use core::{
6    borrow::Borrow,
7    error, fmt,
8    marker::PhantomData,
9    pin::{pin, Pin},
10    task::{Context, Poll},
11};
12use std::{iter, sync::Arc};
13
14use aranya_crypto::{
15    dangerous::spideroak_crypto::{
16        aead::{Aead, Tag},
17        hpke::{Hpke, HpkeError, Mode, OpenCtx, SealCtx, Seq},
18        import::Import,
19        kem::Kem,
20    },
21    CipherSuite, Csprng,
22};
23use buggy::BugExt;
24use bytes::{Bytes, BytesMut};
25use futures_util::{ready, Sink, Stream, TryStream};
26use pin_project::pin_project;
27use serde::{de::DeserializeOwned, Deserialize, Serialize};
28pub use tarpc::tokio_util::codec::length_delimited::{Builder, LengthDelimitedCodec};
29use tarpc::{
30    serde_transport::{self, Transport},
31    tokio_serde::{formats::MessagePack, Deserializer, Serializer},
32    tokio_util::codec::Framed,
33};
34use tokio::io::{self, AsyncRead, AsyncWrite};
35
36use crate::crypto::{ApiKey, PublicApiKey};
37
/// Converts any boxable error into an [`io::Error`] whose kind is
/// [`io::ErrorKind::Other`].
///
/// Used with `map_err` throughout this module to funnel crypto and
/// internal errors into the `io::Error` type used by the `Stream`
/// and `Sink` impls.
fn other<E>(err: E) -> io::Error
where
    E: Into<Box<dyn error::Error + Send + Sync>>,
{
    let boxed: Box<dyn error::Error + Send + Sync> = err.into();
    io::Error::other(boxed)
}
44
/// The KEM peer encapsulation ("enc") type for cipher suite `CS`.
///
/// The client produces one in [`Ctx::client`]; the server imports
/// one from raw bytes in [`Ctx::server`].
type Encap<CS> = <<CS as CipherSuite>::Kem as Kem>::Encap;
46
/// HPKE encryption context.
///
/// The client creates one the first time it tries to write to
/// the server (and again each time it rekeys). It sends the HPKE
/// peer encapsulation to the server, then begins sending
/// ciphertext.
///
/// The server creates one each time it receives an HPKE peer
/// encapsulation from the client.
///
/// Each direction uses its own key and nonce: the client's `seal`
/// context pairs with the server's `open` context and vice versa.
struct Ctx<CS: CipherSuite> {
    /// Encrypts the messages this side sends.
    seal: SealCtx<<CS as CipherSuite>::Aead>,
    /// Decrypts the messages received from the peer.
    open: OpenCtx<<CS as CipherSuite>::Aead>,
}
59
60impl<CS: CipherSuite> Ctx<CS> {
61    // Contextual binding for exporting the server's encryption
62    // key and nonce.
63    const SERVER_KEY_CTX: &[u8] = b"aranya daemon api server seal key";
64    const SERVER_NONCE_CTX: &[u8] = b"aranya daemon api server seal nonce";
65
66    /// Creates the HPKE encryption context for the client.
67    fn client<R: Csprng>(
68        rng: R,
69        pk: &PublicApiKey<CS>,
70        info: &[u8],
71    ) -> Result<(Self, Encap<CS>), HpkeError> {
72        let (enc, send) = Hpke::<CS::Kem, CS::Kdf, CS::Aead>::setup_send(
73            rng,
74            Mode::Base,
75            pk.as_inner(),
76            iter::once(info),
77        )?;
78        // NB: These are the reverse of the server's keys.
79        let (open_key, open_nonce) = {
80            let key = send.export(Self::SERVER_KEY_CTX)?;
81            let nonce = send.export(Self::SERVER_NONCE_CTX)?;
82            (key, nonce)
83        };
84        let (seal_key, seal_nonce) = send
85            .into_raw_parts()
86            .assume("should be able to decompose `SendCtx`")?;
87
88        let ctx = Self {
89            seal: SealCtx::new(&seal_key, &seal_nonce, Seq::ZERO)?,
90            open: OpenCtx::new(&open_key, &open_nonce, Seq::ZERO)?,
91        };
92        Ok((ctx, enc))
93    }
94
95    /// Creates the HPKE encryption context for the server.
96    fn server(sk: &ApiKey<CS>, info: &[u8], enc: &[u8]) -> Result<Self, HpkeError> {
97        let enc = Encap::<CS>::import(enc)?;
98
99        let recv = Hpke::<CS::Kem, CS::Kdf, CS::Aead>::setup_recv(
100            Mode::Base,
101            &enc,
102            sk.as_inner(),
103            iter::once(info),
104        )?;
105        // NB: These are the reverse of the client's keys.
106        let (seal_key, seal_nonce) = {
107            let key = recv.export(Self::SERVER_KEY_CTX)?;
108            let nonce = recv.export(Self::SERVER_NONCE_CTX)?;
109            (key, nonce)
110        };
111        let (open_key, open_nonce) = recv
112            .into_raw_parts()
113            .assume("should be able to decompose `SendCtx`")?;
114
115        Ok(Self {
116            seal: SealCtx::new(&seal_key, &seal_nonce, Seq::ZERO)?,
117            open: OpenCtx::new(&open_key, &open_nonce, Seq::ZERO)?,
118        })
119    }
120
121    /// Serializes `item`, encrypts and authenticates the
122    /// resulting bytes, and returns the ciphertext.
123    ///
124    /// `side` represents the current side performing the
125    /// encryption.
126    fn encrypt<Item, SinkItem>(&mut self, item: SinkItem, side: Side) -> io::Result<Data>
127    where
128        SinkItem: Serialize,
129    {
130        let codec = MessagePack::<Item, SinkItem>::default();
131        let mut plaintext = BytesMut::from(pin!(codec).serialize(&item)?);
132        let mut tag = BytesMut::from(&*Tag::<CS::Aead>::default());
133        let ad = auth_data(self.seal.seq(), side);
134        let seq = self
135            .seal
136            .seal_in_place(&mut plaintext, &mut tag, &ad)
137            .map_err(other)?;
138        Ok(Data {
139            seq: seq.to_u64(),
140            ciphertext: plaintext,
141            tag: tag.freeze(),
142        })
143    }
144
145    /// Decrypts and authenticates `data`, then deserializes the
146    /// resulting plaintext and returns the resulting `Item`.
147    ///
148    /// `side` represents the side that created `data`.
149    fn decrypt<Item, SinkItem>(&mut self, data: Data, side: Side) -> io::Result<Item>
150    where
151        Item: DeserializeOwned,
152    {
153        let Data {
154            seq,
155            mut ciphertext,
156            tag,
157        } = data;
158        let ad = auth_data(Seq::new(seq), side);
159        self.open
160            .open_in_place_at(&mut ciphertext, &tag, &ad, Seq::new(seq))
161            .map_err(other)?;
162        let codec = MessagePack::<Item, SinkItem>::default();
163        let item = pin!(codec).deserialize(&ciphertext)?;
164        Ok(item)
165    }
166}
167
168impl<CS: CipherSuite> fmt::Debug for Ctx<CS> {
169    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170        f.debug_struct("Ctx").finish_non_exhaustive()
171    }
172}
173
174/// Generates the AD for encryption/decryption.
175///
176/// We include the sequence number in the AD per the advice in
177/// [RFC 9180] section 9.7.1.
178///
179/// [RFC 9180]: https://www.rfc-editor.org/rfc/rfc9180.html
180fn auth_data(seq: Seq, side: Side) -> [u8; 8 + 14] {
181    let base = match side {
182        Side::Server => b"server base ad",
183        Side::Client => b"client base ad",
184    };
185
186    // ad = seq || base
187    let mut ad = [0; 8 + 14];
188    ad[..8].copy_from_slice(&seq.to_u64().to_le_bytes());
189    ad[8..].copy_from_slice(base);
190    ad
191}
192
/// Identifies which end of the connection encrypted a message.
///
/// It selects the label mixed into the AD by [`auth_data`]. As a
/// result, the AD for client-to-server messages differs from the AD
/// for server-to-client messages.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Side {
    /// The server end (see [`ServerConn`]).
    Server,
    /// The client end (see [`ClientConn`]).
    Client,
}
198
199/// Creates a client-side transport.
200pub fn client<S, R, CS, Item, SinkItem>(
201    io: S,
202    codec: LengthDelimitedCodec,
203    rng: R,
204    pk: PublicApiKey<CS>,
205    info: &[u8],
206) -> ClientConn<S, R, CS, Item, SinkItem>
207where
208    S: AsyncRead + AsyncWrite,
209    CS: CipherSuite,
210{
211    ClientConn {
212        inner: serde_transport::new(Framed::new(io, codec), MessagePack::default()),
213        rng,
214        pk,
215        info: Box::from(info),
216        ctx: None,
217        rekeys: 0,
218        _marker: PhantomData,
219    }
220}
221
/// An encrypted [`Transport`][tarpc::Transport] for the client.
///
/// It is created by [`client`].
///
/// Outgoing items are encrypted and sent as `Data` messages.
/// Whenever a new HPKE context is generated, a `Rekey` message
/// carrying its encapsulation is sent ahead of them. Incoming `Data`
/// messages from the server are decrypted into `Item`s.
#[pin_project]
pub struct ClientConn<S, R, CS, Item, SinkItem>
where
    CS: CipherSuite,
{
    /// The underlying transport.
    #[pin]
    inner: Transport<S, ServerMsg, ClientMsg, MessagePack<ServerMsg, ClientMsg>>,
    /// For rekeying.
    rng: R,
    /// The server's public key.
    pk: PublicApiKey<CS>,
    /// The "info" parameter when rekeying.
    info: Box<[u8]>,
    /// This is set to `Some` the first time the connection (as
    /// a `Sink`) is polled for readiness.
    ///
    /// It is periodically updated via rekeying in order to keep
    /// the keys fresh.
    ctx: Option<Ctx<CS>>,
    /// The number of times we've rekeyed, including the initial
    /// keying.
    ///
    /// Mostly for debugging purposes.
    rekeys: usize,
    _marker: PhantomData<fn() -> (Item, SinkItem)>,
}
252
253impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
254where
255    S: AsyncRead + AsyncWrite,
256    CS: CipherSuite,
257    SinkItem: Serialize,
258{
259    /// Serializes `item`, encrypts and authenticates the
260    /// resulting bytes, and returns the ciphertext.
261    ///
262    /// It is an error if `self.ctx` has not yet been
263    /// initialized.
264    fn encrypt(&mut self, item: SinkItem) -> io::Result<Data> {
265        self.ctx
266            .as_mut()
267            .assume("`self.ctx` should be `Some`")
268            .map_err(other)?
269            .encrypt::<Item, SinkItem>(item, Side::Client)
270            .map_err(other)
271    }
272}
273
274impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
275where
276    CS: CipherSuite,
277    Item: DeserializeOwned,
278{
279    /// Decrypts and authenticates `data`, then deserializes the
280    /// resulting plaintext and returns the resulting `Item`.
281    ///
282    /// It is an error if `self.ctx` has not yet been
283    /// initialized.
284    fn decrypt(&mut self, data: Data) -> io::Result<Item> {
285        self.ctx
286            .as_mut()
287            .assume("`self.ctx` should be `Some`")
288            .map_err(other)?
289            .decrypt::<Item, SinkItem>(data, Side::Server)
290            .map_err(other)
291    }
292}
293
294impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
295where
296    R: Csprng,
297    CS: CipherSuite,
298{
299    /// Returns `Some` with the `Rekey` message to send to the
300    /// server if we need to rekey, or `None` otherwise.
301    fn try_rekey(&mut self) -> Result<Option<ClientMsg>, HpkeError> {
302        if !self.need_rekey() {
303            return Ok(None);
304        }
305        let enc = self.rekey()?;
306        let msg = ClientMsg::Rekey(Rekey {
307            enc: Bytes::from(enc.borrow().to_vec()),
308        });
309        Ok(Some(msg))
310    }
311
312    /// Reports whether we need to generate a new HPKE encryption
313    /// context.
314    fn need_rekey(&self) -> bool {
315        let Some(ctx) = self.ctx.as_ref() else {
316            return true;
317        };
318        // To prevent us from reaching the end of the sequence,
319        // rekey when we're halfway there.
320        let max = Seq::max::<<CS::Aead as Aead>::NonceSize>();
321        let seq = ctx.seal.seq().to_u64();
322        seq >= max / 2
323    }
324
325    /// Generates a new HPKE encryption context and returns the
326    /// resulting peer encapsulation.
327    fn rekey(&mut self) -> Result<Encap<CS>, HpkeError> {
328        let (ctx, enc) = Ctx::client(&mut self.rng, &self.pk, &self.info)?;
329        self.ctx = Some(ctx);
330        // Rekeying takes so long (relatively speaking, anyway)
331        // that this should never overflow.
332        self.rekeys = self
333            .rekeys
334            .checked_add(1)
335            .assume("rekey count should not overflow")?;
336        Ok(enc)
337    }
338}
339
340impl<S, R, CS, Item, SinkItem> Stream for ClientConn<S, R, CS, Item, SinkItem>
341where
342    S: AsyncRead + AsyncWrite + Unpin,
343    R: Csprng,
344    CS: CipherSuite,
345    Item: DeserializeOwned,
346{
347    type Item = io::Result<Item>;
348
349    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
350        if self.ctx.is_none() {
351            // In tarpc the client always writes first. We create
352            // our encryption context the first time we write, so
353            // if we get here we haven't written yet.
354            // TODO(eric): should we return an error instead?
355            return Poll::Pending;
356        }
357        let Some(msg) = ready!(self.as_mut().project().inner.poll_next(cx)?) else {
358            return Poll::Ready(None);
359        };
360        match msg {
361            ServerMsg::Data(data) => {
362                let pt = self.decrypt(data)?;
363                Poll::Ready(Some(Ok(pt)))
364            }
365        }
366    }
367}
368
impl<S, R, CS, Item, SinkItem> Sink<SinkItem> for ClientConn<S, R, CS, Item, SinkItem>
where
    S: AsyncRead + AsyncWrite + Unpin,
    R: Csprng,
    CS: CipherSuite,
    SinkItem: Serialize,
{
    type Error = io::Error;

    /// Waits until the underlying transport can accept a message.
    ///
    /// This is also where the client (re)keys. A new HPKE context is
    /// created when none exists yet or the current one is due for
    /// rekeying (see `need_rekey`). In that case, a `Rekey` message
    /// carrying its encapsulation is queued on the transport before
    /// readiness is reported.
    ///
    /// Consequently `self.ctx` is `Some` whenever this returns
    /// `Ready(Ok(()))`, which is what `start_send` relies on.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        ready!(self.as_mut().project().inner.poll_ready(cx)?);

        // Do we need to rekey?
        if let Some(msg) = self.try_rekey().map_err(other)? {
            // We updated our keys, so forward the message on to
            // the server.
            self.as_mut().project().inner.start_send(msg)?;

            // Each call to `start_send` must be preceded by
            // a call to `poll_ready`, so call `poll_ready`
            // again. If this returns `Pending`, the next call
            // sees the fresh context (sequence zero) and does not
            // rekey a second time.
            ready!(self.as_mut().project().inner.poll_ready(cx)?);
        }

        Poll::Ready(Ok(()))
    }

    /// Encrypts `item` with the current HPKE context and queues it
    /// as a `Data` message.
    fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
        let data = self.encrypt(item)?;
        self.project().inner.start_send(ClientMsg::Data(data))?;
        Ok(())
    }

    /// Flushes the underlying transport.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_flush(cx)
    }

    /// Closes the underlying transport.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}
410
411impl<S, R, CS, Item, SinkItem> fmt::Debug for ClientConn<S, R, CS, Item, SinkItem>
412where
413    CS: CipherSuite,
414{
415    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416        f.debug_struct("Server")
417            .field("pk", &self.pk)
418            .field("info", &self.info)
419            .field("ctx", &self.ctx)
420            .field("rekeys", &self.rekeys)
421            .finish_non_exhaustive()
422    }
423}
424
/// A message (request) sent by the client to the server.
///
/// NOTE(review): depending on how the MessagePack codec encodes
/// enums, variant order may be part of the wire format; keep it
/// stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
enum ClientMsg {
    /// An encrypted, serialized request.
    Data(Data),
    /// A new HPKE encapsulation. The server replaces its context
    /// before handling any later `Data`.
    Rekey(Rekey),
}
432
/// Some encrypted data.
///
/// NOTE(review): depending on how the MessagePack codec encodes
/// structs, field order may be part of the wire format; keep it
/// stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct Data {
    /// The position of this ciphertext in the stream of
    /// messages.
    ///
    /// It is also bound into the AD (see [`auth_data`]) and used as
    /// the sequence number when opening.
    seq: u64,
    /// The ciphertext (the serialized plaintext, encrypted in
    /// place).
    ciphertext: BytesMut,
    /// The authentication tag.
    tag: Bytes,
}
444
/// Instructs the server to rekey.
///
/// The server replaces its HPKE context with one derived from `enc`
/// (see [`Ctx::server`]).
#[derive(Clone, Debug, Serialize, Deserialize)]
struct Rekey {
    /// The HPKE peer encapsulation, as raw bytes.
    enc: Bytes,
}
451
452/// Creates a server-side transport.
453pub fn server<L, CS, Item, SinkItem>(
454    listener: L,
455    codec: LengthDelimitedCodec,
456    sk: ApiKey<CS>,
457    info: &[u8],
458) -> Server<L, CS, Item, SinkItem>
459where
460    CS: CipherSuite,
461{
462    Server {
463        listener,
464        codec,
465        sk: Arc::new(sk),
466        info: Arc::from(info),
467        _marker: PhantomData,
468    }
469}
470
471/// Creates [`ServerConn`]s.
472///
473/// It is created by [`server`]
474#[derive(Debug)]
475#[pin_project]
476pub struct Server<L, CS, Item, SinkItem>
477where
478    CS: CipherSuite,
479{
480    #[pin]
481    listener: L,
482    codec: LengthDelimitedCodec,
483    /// The server's secret key.
484    sk: Arc<ApiKey<CS>>,
485    /// The "info" parameter when rekeying.
486    info: Arc<[u8]>,
487    _marker: PhantomData<fn() -> (Item, SinkItem)>,
488}
489
490impl<S, L, CS, Item, SinkItem> Stream for Server<L, CS, Item, SinkItem>
491where
492    S: AsyncRead + AsyncWrite,
493    L: TryStream<Ok = S, Error = io::Error>,
494    CS: CipherSuite,
495{
496    type Item = io::Result<ServerConn<S, CS, Item, SinkItem>>;
497
498    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
499        let Some(io) = ready!(self.as_mut().project().listener.try_poll_next(cx)?) else {
500            return Poll::Ready(None);
501        };
502        let conn = ServerConn {
503            inner: serde_transport::new(
504                Framed::new(io, self.codec.clone()),
505                MessagePack::default(),
506            ),
507            sk: Arc::clone(&self.sk),
508            info: Arc::clone(&self.info),
509            ctx: None,
510            _marker: PhantomData,
511        };
512        Poll::Ready(Some(Ok(conn)))
513    }
514}
515
/// An encrypted [`Transport`][tarpc::Transport] for the server.
///
/// It is created by reading from [`Server`], which is
/// a [`Stream`].
///
/// It starts without an HPKE context. `Rekey` messages from the
/// client install (or replace) the context, which is then used to
/// decrypt incoming `Data` and to encrypt outgoing items. `Data`
/// received before any `Rekey` cannot be decrypted and yields an
/// error.
#[pin_project]
pub struct ServerConn<S, CS, Item, SinkItem>
where
    CS: CipherSuite,
{
    /// The underlying transport.
    #[pin]
    inner: Transport<S, ClientMsg, ServerMsg, MessagePack<ClientMsg, ServerMsg>>,
    /// The server's secret key.
    sk: Arc<ApiKey<CS>>,
    /// The "info" parameter when rekeying.
    info: Arc<[u8]>,
    /// The HPKE encryption context.
    ///
    /// This is set to `Some` after the client sends the first
    /// `Rekey` message.
    ///
    /// It is periodically updated via rekeying in order to keep
    /// the keys fresh.
    ctx: Option<Ctx<CS>>,
    _marker: PhantomData<fn() -> (Item, SinkItem)>,
}
542
543impl<S, CS, Item, SinkItem> ServerConn<S, CS, Item, SinkItem>
544where
545    CS: CipherSuite,
546    SinkItem: Serialize,
547{
548    /// Serializes `item`, encrypts and authenticates the
549    /// resulting bytes, and returns the ciphertext.
550    ///
551    /// It is an error if `self.ctx` has not yet been
552    /// initialized.
553    fn encrypt(&mut self, item: SinkItem) -> io::Result<Data> {
554        self.ctx
555            .as_mut()
556            .assume("`self.ctx` should be `Some`")
557            .map_err(other)?
558            .encrypt::<Item, SinkItem>(item, Side::Server)
559            .map_err(other)
560    }
561}
562
563impl<S, CS, Item, SinkItem> ServerConn<S, CS, Item, SinkItem>
564where
565    CS: CipherSuite,
566    Item: DeserializeOwned,
567{
568    /// Decrypts and authenticates `data`, then deserializes the
569    /// resulting plaintext and returns the resulting `Item`.
570    ///
571    /// It is an error if `self.ctx` has not yet been
572    /// initialized.
573    fn decrypt(&mut self, data: Data) -> io::Result<Item> {
574        self.ctx
575            .as_mut()
576            .assume("`self.ctx` should be `Some`")
577            .map_err(other)?
578            .decrypt::<Item, SinkItem>(data, Side::Client)
579            .map_err(other)
580    }
581}
582
583impl<S, CS, Item, SinkItem> ServerConn<S, CS, Item, SinkItem>
584where
585    CS: CipherSuite,
586{
587    /// Updates the HPKE encryption context per the peer's
588    /// encapsulation.
589    fn rekey(&mut self, msg: Rekey) -> Result<(), HpkeError> {
590        let ctx = Ctx::server(&self.sk, &self.info, &msg.enc)?;
591        self.ctx = Some(ctx);
592        Ok(())
593    }
594}
595
596impl<S, CS, Item, SinkItem> Stream for ServerConn<S, CS, Item, SinkItem>
597where
598    S: AsyncRead + AsyncWrite + Unpin,
599    CS: CipherSuite,
600    Item: DeserializeOwned,
601{
602    type Item = io::Result<Item>;
603
604    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
605        // Skip past control (i.e., non-`Data`) messages.
606        loop {
607            let Some(msg) = ready!(self.as_mut().project().inner.poll_next(cx)?) else {
608                return Poll::Ready(None);
609            };
610            match msg {
611                ClientMsg::Data(data) => {
612                    let pt = self.decrypt(data)?;
613                    return Poll::Ready(Some(Ok(pt)));
614                }
615                ClientMsg::Rekey(rekey) => self.rekey(rekey).map_err(other)?,
616            }
617        }
618    }
619}
620
621impl<S, CS, Item, SinkItem> Sink<SinkItem> for ServerConn<S, CS, Item, SinkItem>
622where
623    S: AsyncRead + AsyncWrite + Unpin,
624    CS: CipherSuite,
625    SinkItem: Serialize,
626{
627    type Error = io::Error;
628
629    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630        self.project().inner.poll_ready(cx)
631    }
632
633    fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
634        let data = self.encrypt(item)?;
635        self.project().inner.start_send(ServerMsg::Data(data))
636    }
637
638    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
639        self.project().inner.poll_flush(cx)
640    }
641
642    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
643        self.project().inner.poll_close(cx)
644    }
645}
646
647impl<S, CS, Item, SinkItem> fmt::Debug for ServerConn<S, CS, Item, SinkItem>
648where
649    CS: CipherSuite,
650{
651    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
652        f.debug_struct("Server")
653            .field("sk", &self.sk)
654            .field("info", &self.info)
655            .field("ctx", &self.ctx)
656            .finish_non_exhaustive()
657    }
658}
659
/// A message (response) sent by the server to the client.
///
/// Only the client initiates rekeying, so the server sends no
/// control messages; there is only `Data`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
enum ServerMsg {
    /// An encrypted, serialized response.
    Data(Data),
}
666
667/// Unix utilities.
668#[cfg(unix)]
669#[cfg_attr(docsrs, doc(cfg(unix)))]
670pub mod unix {
671    use core::{
672        pin::Pin,
673        task::{Context, Poll},
674    };
675
676    use futures_util::{ready, Stream};
677    use tokio::{
678        io,
679        net::{UnixListener, UnixStream},
680    };
681
682    /// Converts a [`UnixListener`] into a [`Stream`].
683    #[derive(Debug)]
684    pub struct UnixListenerStream(UnixListener);
685
686    impl Stream for UnixListenerStream {
687        type Item = io::Result<UnixStream>;
688
689        fn poll_next(
690            self: Pin<&mut Self>,
691            cx: &mut Context<'_>,
692        ) -> Poll<Option<io::Result<UnixStream>>> {
693            let (stream, _) = ready!(self.0.poll_accept(cx))?;
694            Poll::Ready(Some(Ok(stream)))
695        }
696    }
697
698    impl From<UnixListener> for UnixListenerStream {
699        #[inline]
700        fn from(listener: UnixListener) -> Self {
701            Self(listener)
702        }
703    }
704}
705
#[cfg(test)]
#[cfg(unix)]
#[allow(clippy::arithmetic_side_effects, clippy::panic)]
mod tests {
    use std::panic;

    use aranya_crypto::{
        default::{DefaultCipherSuite, DefaultEngine},
        Rng,
    };
    use backon::{ExponentialBuilder, Retryable as _};
    use futures_util::{SinkExt, TryStreamExt};
    use tokio::{
        net::{UnixListener, UnixStream},
        task::JoinSet,
    };

    use super::*;

    impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        CS: CipherSuite,
    {
        /// Drops the current HPKE context so that the next
        /// `poll_ready` generates a new one.
        fn force_rekey(&mut self) {
            self.ctx = None;
        }
    }

    type CS = DefaultCipherSuite;

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Ping {
        v: usize,
    }

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Pong {
        v: usize,
    }

    /// Returns the codec used by every test client and server.
    fn test_codec() -> LengthDelimitedCodec {
        LengthDelimitedCodec::builder()
            .max_frame_length(usize::MAX)
            .new_codec()
    }

    /// Waits for every task in `set` to finish.
    ///
    /// Re-raises the first task panic. If a task returns an error,
    /// aborts the remaining tasks and panics with that error.
    async fn join_tasks(mut set: JoinSet<io::Result<()>>) {
        while let Some(res) = set.join_next().await {
            match res {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    set.abort_all();
                    panic!("{err}");
                }
                Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
                Err(err) => panic!("{err}"),
            }
        }
    }

    /// Basic one client, one server ping pong test.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_ping_pong() {
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().to_path_buf().join("sock"));
        let info = Arc::from(path.as_os_str().as_encoded_bytes());

        let (eng, _) = DefaultEngine::from_entropy(Rng);
        let sk = ApiKey::<CS>::new(&eng);
        let pk = sk.public().unwrap();

        const MAX_PING_PONGS: usize = 100;

        let mut set = JoinSet::new();

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let listener = UnixListener::bind(&*path)?;
                let mut server = server::<_, _, Ping, Pong>(
                    unix::UnixListenerStream::from(listener),
                    test_codec(),
                    sk,
                    &info,
                );

                let mut conn = server.try_next().await.unwrap().unwrap();
                for v in 0..MAX_PING_PONGS {
                    let got = conn.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    assert_eq!(got, Ping { v });
                    conn.send(Pong {
                        v: got.v.wrapping_add(1),
                    })
                    .await?;
                }
                io::Result::Ok(())
            });
        }

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let sock = (|| UnixStream::connect(&*path))
                    .retry(ExponentialBuilder::default())
                    .await
                    .unwrap();
                let mut client =
                    client::<_, _, _, Pong, Ping>(sock, test_codec(), Rng, pk, &info);
                for v in 0..MAX_PING_PONGS {
                    client.send(Ping { v }).await?;
                    let got = client.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    let want = Pong {
                        v: v.wrapping_add(1),
                    };
                    assert_eq!(got, want)
                }
                Ok(())
            });
        }

        join_tasks(set).await;
    }

    /// One client rekeys each request.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_rekey() {
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().to_path_buf().join("sock"));
        let info = Arc::from(path.as_os_str().as_encoded_bytes());

        let (eng, _) = DefaultEngine::from_entropy(Rng);
        let sk = ApiKey::<CS>::new(&eng);
        let pk = sk.public().unwrap();

        const MAX_PING_PONGS: usize = 100;

        let mut set = JoinSet::new();

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let listener = UnixListener::bind(&*path).unwrap();
                let mut server = server::<_, _, Ping, Pong>(
                    unix::UnixListenerStream::from(listener),
                    test_codec(),
                    sk,
                    &info,
                );
                let mut conn = server.try_next().await.unwrap().unwrap();
                for v in 0..MAX_PING_PONGS {
                    let got = conn.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    // In this test the client rekeys each time
                    // it sends data, so our seq number should
                    // always be zero.
                    let ctx = conn.ctx.as_ref().map(|ctx| &ctx.seal).unwrap();
                    assert_eq!(ctx.seq(), Seq::ZERO);

                    assert_eq!(got, Ping { v });
                    conn.send(Pong {
                        v: got.v.wrapping_add(1),
                    })
                    .await?;

                    // Double check that it actually increments.
                    let ctx = conn.ctx.as_ref().map(|ctx| &ctx.seal).unwrap();
                    assert_eq!(ctx.seq(), Seq::new(1));
                }
                io::Result::Ok(())
            });
        }

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let sock = (|| UnixStream::connect(&*path))
                    .retry(ExponentialBuilder::default())
                    .await
                    .unwrap();
                let mut client =
                    client::<_, _, _, Pong, Ping>(sock, test_codec(), Rng, pk, &info);
                for v in 0..MAX_PING_PONGS {
                    let last = client.rekeys;
                    client.force_rekey();
                    client.send(Ping { v }).await.unwrap();
                    assert_eq!(client.rekeys, last + 1);
                    let got = client.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    let want = Pong {
                        v: v.wrapping_add(1),
                    };
                    assert_eq!(got, want)
                }
                Ok(())
            });
        }

        join_tasks(set).await;
    }

    /// N clients make repeated requests to one server.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_multi_client() {
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().to_path_buf().join("sock"));
        let info = Arc::from(path.as_os_str().as_encoded_bytes());

        let (eng, _) = DefaultEngine::from_entropy(Rng);
        let sk = ApiKey::<CS>::new(&eng);
        let pk = sk.public().unwrap();

        const MAX_PING_PONGS: usize = 2;
        const MAX_CLIENTS: usize = 10;

        let mut set = JoinSet::new();

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let listener = UnixListener::bind(&*path).unwrap();
                let mut server = server::<_, _, Ping, Pong>(
                    unix::UnixListenerStream::from(listener),
                    test_codec(),
                    sk,
                    &info,
                );
                let mut set = JoinSet::new();
                for _ in 0..MAX_CLIENTS {
                    let mut conn = server.try_next().await?.unwrap();
                    set.spawn(async move {
                        for v in 0..MAX_PING_PONGS {
                            let got = conn.try_next().await?.ok_or_else(|| {
                                io::Error::new(
                                    io::ErrorKind::UnexpectedEof,
                                    "client stream finished early",
                                )
                            })?;
                            assert_eq!(got, Ping { v });
                            conn.send(Pong {
                                v: got.v.wrapping_add(1),
                            })
                            .await?;
                        }
                        io::Result::Ok(())
                    });
                }
                set.join_all()
                    .await
                    .into_iter()
                    .find(|v| v.is_err())
                    .unwrap_or(Ok(()))
            });
        }

        // Must match the number of connections the server accepts.
        for _ in 0..MAX_CLIENTS {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            let pk = pk.clone();
            set.spawn(async move {
                let sock = (|| UnixStream::connect(&*path))
                    .retry(ExponentialBuilder::default())
                    .await
                    .unwrap();
                let mut client =
                    client::<_, _, _, Pong, Ping>(sock, test_codec(), Rng, pk, &info);
                for v in 0..MAX_PING_PONGS {
                    client.send(Ping { v }).await?;
                    let got = client.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "server stream finished early")
                    })?;
                    let want = Pong {
                        v: v.wrapping_add(1),
                    };
                    assert_eq!(got, want);
                }
                Ok(())
            });
        }

        join_tasks(set).await;
    }
}