1use core::{
6 borrow::Borrow,
7 error, fmt,
8 marker::PhantomData,
9 pin::{pin, Pin},
10 task::{Context, Poll},
11};
12use std::{iter, sync::Arc};
13
14use aranya_crypto::{
15 dangerous::spideroak_crypto::{
16 aead::{Aead, Tag},
17 hpke::{Hpke, HpkeError, Mode, OpenCtx, SealCtx, Seq},
18 import::Import,
19 kem::Kem,
20 },
21 CipherSuite, Csprng,
22};
23use buggy::BugExt;
24use bytes::{Bytes, BytesMut};
25use futures_util::{ready, Sink, Stream, TryStream};
26use pin_project::pin_project;
27use serde::{de::DeserializeOwned, Deserialize, Serialize};
28pub use tarpc::tokio_util::codec::length_delimited::{Builder, LengthDelimitedCodec};
29use tarpc::{
30 serde_transport::{self, Transport},
31 tokio_serde::{formats::MessagePack, Deserializer, Serializer},
32 tokio_util::codec::Framed,
33};
34use tokio::io::{self, AsyncRead, AsyncWrite};
35
36use crate::crypto::{ApiKey, PublicApiKey};
37
/// Converts any boxable error into an [`io::Error`] whose kind is
/// [`io::ErrorKind::Other`], keeping `err` as the wrapped payload.
fn other<E>(err: E) -> io::Error
where
    E: Into<Box<dyn error::Error + Send + Sync>>,
{
    let boxed: Box<dyn error::Error + Send + Sync> = err.into();
    io::Error::other(boxed)
}
44
/// The KEM encapsulation type for cipher suite `CS`: what
/// `Hpke::setup_send` produces on the client and what the server
/// imports in `Ctx::server`.
///
/// Written fully qualified because type aliases do not enforce bounds,
/// so the `CS::Kem` shorthand cannot be resolved here.
type Encap<CS> = <<CS as CipherSuite>::Kem as Kem>::Encap;
46
47struct Ctx<CS: CipherSuite> {
56 seal: SealCtx<<CS as CipherSuite>::Aead>,
57 open: OpenCtx<<CS as CipherSuite>::Aead>,
58}
59
60impl<CS: CipherSuite> Ctx<CS> {
61 const SERVER_KEY_CTX: &[u8] = b"aranya daemon api server seal key";
64 const SERVER_NONCE_CTX: &[u8] = b"aranya daemon api server seal nonce";
65
66 fn client<R: Csprng>(
68 rng: R,
69 pk: &PublicApiKey<CS>,
70 info: &[u8],
71 ) -> Result<(Self, Encap<CS>), HpkeError> {
72 let (enc, send) = Hpke::<CS::Kem, CS::Kdf, CS::Aead>::setup_send(
73 rng,
74 Mode::Base,
75 pk.as_inner(),
76 iter::once(info),
77 )?;
78 let (open_key, open_nonce) = {
80 let key = send.export(Self::SERVER_KEY_CTX)?;
81 let nonce = send.export(Self::SERVER_NONCE_CTX)?;
82 (key, nonce)
83 };
84 let (seal_key, seal_nonce) = send
85 .into_raw_parts()
86 .assume("should be able to decompose `SendCtx`")?;
87
88 let ctx = Self {
89 seal: SealCtx::new(&seal_key, &seal_nonce, Seq::ZERO)?,
90 open: OpenCtx::new(&open_key, &open_nonce, Seq::ZERO)?,
91 };
92 Ok((ctx, enc))
93 }
94
95 fn server(sk: &ApiKey<CS>, info: &[u8], enc: &[u8]) -> Result<Self, HpkeError> {
97 let enc = Encap::<CS>::import(enc)?;
98
99 let recv = Hpke::<CS::Kem, CS::Kdf, CS::Aead>::setup_recv(
100 Mode::Base,
101 &enc,
102 sk.as_inner(),
103 iter::once(info),
104 )?;
105 let (seal_key, seal_nonce) = {
107 let key = recv.export(Self::SERVER_KEY_CTX)?;
108 let nonce = recv.export(Self::SERVER_NONCE_CTX)?;
109 (key, nonce)
110 };
111 let (open_key, open_nonce) = recv
112 .into_raw_parts()
113 .assume("should be able to decompose `SendCtx`")?;
114
115 Ok(Self {
116 seal: SealCtx::new(&seal_key, &seal_nonce, Seq::ZERO)?,
117 open: OpenCtx::new(&open_key, &open_nonce, Seq::ZERO)?,
118 })
119 }
120
121 fn encrypt<Item, SinkItem>(&mut self, item: SinkItem, side: Side) -> io::Result<Data>
127 where
128 SinkItem: Serialize,
129 {
130 let codec = MessagePack::<Item, SinkItem>::default();
131 let mut plaintext = BytesMut::from(pin!(codec).serialize(&item)?);
132 let mut tag = BytesMut::from(&*Tag::<CS::Aead>::default());
133 let ad = auth_data(self.seal.seq(), side);
134 let seq = self
135 .seal
136 .seal_in_place(&mut plaintext, &mut tag, &ad)
137 .map_err(other)?;
138 Ok(Data {
139 seq: seq.to_u64(),
140 ciphertext: plaintext,
141 tag: tag.freeze(),
142 })
143 }
144
145 fn decrypt<Item, SinkItem>(&mut self, data: Data, side: Side) -> io::Result<Item>
150 where
151 Item: DeserializeOwned,
152 {
153 let Data {
154 seq,
155 mut ciphertext,
156 tag,
157 } = data;
158 let ad = auth_data(Seq::new(seq), side);
159 self.open
160 .open_in_place_at(&mut ciphertext, &tag, &ad, Seq::new(seq))
161 .map_err(other)?;
162 let codec = MessagePack::<Item, SinkItem>::default();
163 let item = pin!(codec).deserialize(&ciphertext)?;
164 Ok(item)
165 }
166}
167
168impl<CS: CipherSuite> fmt::Debug for Ctx<CS> {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 f.debug_struct("Ctx").finish_non_exhaustive()
171 }
172}
173
174fn auth_data(seq: Seq, side: Side) -> [u8; 8 + 14] {
181 let base = match side {
182 Side::Server => b"server base ad",
183 Side::Client => b"client base ad",
184 };
185
186 let mut ad = [0; 8 + 14];
188 ad[..8].copy_from_slice(&seq.to_u64().to_le_bytes());
189 ad[8..].copy_from_slice(base);
190 ad
191}
192
/// Which end of the connection produced a message; used to label
/// associated data so a message cannot be reflected back to its sender.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Side {
    /// The `ClientConn` end.
    Client,
    /// The `ServerConn` end.
    Server,
}
198
199pub fn client<S, R, CS, Item, SinkItem>(
201 io: S,
202 codec: LengthDelimitedCodec,
203 rng: R,
204 pk: PublicApiKey<CS>,
205 info: &[u8],
206) -> ClientConn<S, R, CS, Item, SinkItem>
207where
208 S: AsyncRead + AsyncWrite,
209 CS: CipherSuite,
210{
211 ClientConn {
212 inner: serde_transport::new(Framed::new(io, codec), MessagePack::default()),
213 rng,
214 pk,
215 info: Box::from(info),
216 ctx: None,
217 rekeys: 0,
218 _marker: PhantomData,
219 }
220}
221
/// Client side of the encrypted transport, created by [`client`].
///
/// Sends `SinkItem`s to the server and receives `Item`s from it. Each
/// item is serialized with MessagePack and sealed with the current HPKE
/// session. When a new session is needed, it is established lazily and
/// announced to the server with a `Rekey` message.
#[pin_project]
pub struct ClientConn<S, R, CS, Item, SinkItem>
where
    CS: CipherSuite,
{
    /// Framed MessagePack transport over `S`: reads `ServerMsg`s and
    /// writes `ClientMsg`s.
    #[pin]
    inner: Transport<S, ServerMsg, ClientMsg, MessagePack<ServerMsg, ClientMsg>>,
    /// Randomness for HPKE encapsulation when rekeying.
    rng: R,
    /// Public key that each new session is encapsulated to.
    pk: PublicApiKey<CS>,
    /// HPKE `info` bound into every session.
    info: Box<[u8]>,
    /// Current session, or `None` if none has been established yet.
    ctx: Option<Ctx<CS>>,
    /// Number of sessions established so far; incremented by each
    /// successful rekey.
    rekeys: usize,
    /// Ties `Item`/`SinkItem` to the type without storing values.
    _marker: PhantomData<fn() -> (Item, SinkItem)>,
}
252
253impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
254where
255 S: AsyncRead + AsyncWrite,
256 CS: CipherSuite,
257 SinkItem: Serialize,
258{
259 fn encrypt(&mut self, item: SinkItem) -> io::Result<Data> {
265 self.ctx
266 .as_mut()
267 .assume("`self.ctx` should be `Some`")
268 .map_err(other)?
269 .encrypt::<Item, SinkItem>(item, Side::Client)
270 .map_err(other)
271 }
272}
273
274impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
275where
276 CS: CipherSuite,
277 Item: DeserializeOwned,
278{
279 fn decrypt(&mut self, data: Data) -> io::Result<Item> {
285 self.ctx
286 .as_mut()
287 .assume("`self.ctx` should be `Some`")
288 .map_err(other)?
289 .decrypt::<Item, SinkItem>(data, Side::Server)
290 .map_err(other)
291 }
292}
293
294impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
295where
296 R: Csprng,
297 CS: CipherSuite,
298{
299 fn try_rekey(&mut self) -> Result<Option<ClientMsg>, HpkeError> {
302 if !self.need_rekey() {
303 return Ok(None);
304 }
305 let enc = self.rekey()?;
306 let msg = ClientMsg::Rekey(Rekey {
307 enc: Bytes::from(enc.borrow().to_vec()),
308 });
309 Ok(Some(msg))
310 }
311
312 fn need_rekey(&self) -> bool {
315 let Some(ctx) = self.ctx.as_ref() else {
316 return true;
317 };
318 let max = Seq::max::<<CS::Aead as Aead>::NonceSize>();
321 let seq = ctx.seal.seq().to_u64();
322 seq >= max / 2
323 }
324
325 fn rekey(&mut self) -> Result<Encap<CS>, HpkeError> {
328 let (ctx, enc) = Ctx::client(&mut self.rng, &self.pk, &self.info)?;
329 self.ctx = Some(ctx);
330 self.rekeys = self
333 .rekeys
334 .checked_add(1)
335 .assume("rekey count should not overflow")?;
336 Ok(enc)
337 }
338}
339
340impl<S, R, CS, Item, SinkItem> Stream for ClientConn<S, R, CS, Item, SinkItem>
341where
342 S: AsyncRead + AsyncWrite + Unpin,
343 R: Csprng,
344 CS: CipherSuite,
345 Item: DeserializeOwned,
346{
347 type Item = io::Result<Item>;
348
349 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
350 if self.ctx.is_none() {
351 return Poll::Pending;
356 }
357 let Some(msg) = ready!(self.as_mut().project().inner.poll_next(cx)?) else {
358 return Poll::Ready(None);
359 };
360 match msg {
361 ServerMsg::Data(data) => {
362 let pt = self.decrypt(data)?;
363 Poll::Ready(Some(Ok(pt)))
364 }
365 }
366 }
367}
368
369impl<S, R, CS, Item, SinkItem> Sink<SinkItem> for ClientConn<S, R, CS, Item, SinkItem>
370where
371 S: AsyncRead + AsyncWrite + Unpin,
372 R: Csprng,
373 CS: CipherSuite,
374 SinkItem: Serialize,
375{
376 type Error = io::Error;
377
378 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
379 ready!(self.as_mut().project().inner.poll_ready(cx)?);
380
381 if let Some(msg) = self.try_rekey().map_err(other)? {
383 self.as_mut().project().inner.start_send(msg)?;
386
387 ready!(self.as_mut().project().inner.poll_ready(cx)?);
391 }
392
393 Poll::Ready(Ok(()))
394 }
395
396 fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
397 let data = self.encrypt(item)?;
398 self.project().inner.start_send(ClientMsg::Data(data))?;
399 Ok(())
400 }
401
402 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
403 self.project().inner.poll_flush(cx)
404 }
405
406 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
407 self.project().inner.poll_close(cx)
408 }
409}
410
411impl<S, R, CS, Item, SinkItem> fmt::Debug for ClientConn<S, R, CS, Item, SinkItem>
412where
413 CS: CipherSuite,
414{
415 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
416 f.debug_struct("Server")
417 .field("pk", &self.pk)
418 .field("info", &self.info)
419 .field("ctx", &self.ctx)
420 .field("rekeys", &self.rekeys)
421 .finish_non_exhaustive()
422 }
423}
424
/// Messages sent from the client to the server.
///
/// NOTE(review): this is a wire type. Reordering the variants may change
/// its MessagePack encoding, so confirm before changing.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
enum ClientMsg {
    /// An encrypted item sealed under the current session.
    Data(Data),
    /// Starts a new session; sent before any data sealed under it.
    Rekey(Rekey),
}
432
/// An encrypted, MessagePack-serialized item.
///
/// NOTE(review): this is a wire type. Field order may matter to the
/// MessagePack encoding.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct Data {
    /// Sequence number returned by the sealer. The receiver opens at,
    /// and builds its associated data from, this value.
    seq: u64,
    /// The sealed payload; it is decrypted in place on receipt.
    ciphertext: BytesMut,
    /// The AEAD authentication tag for `ciphertext`.
    tag: Bytes,
}
444
/// Starts a new session.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct Rekey {
    /// Encoded HPKE encapsulation from `Ctx::client`. The server imports
    /// it in `Ctx::server`.
    enc: Bytes,
}
451
452pub fn server<L, CS, Item, SinkItem>(
454 listener: L,
455 codec: LengthDelimitedCodec,
456 sk: ApiKey<CS>,
457 info: &[u8],
458) -> Server<L, CS, Item, SinkItem>
459where
460 CS: CipherSuite,
461{
462 Server {
463 listener,
464 codec,
465 sk: Arc::new(sk),
466 info: Arc::from(info),
467 _marker: PhantomData,
468 }
469}
470
/// Accepts connections from `listener` and wraps each one in a
/// [`ServerConn`]. Created by [`server`].
#[derive(Debug)]
#[pin_project]
pub struct Server<L, CS, Item, SinkItem>
where
    CS: CipherSuite,
{
    /// Source of incoming connections.
    #[pin]
    listener: L,
    /// Framing codec, cloned for each accepted connection.
    codec: LengthDelimitedCodec,
    /// Secret key shared by all accepted connections.
    sk: Arc<ApiKey<CS>>,
    /// HPKE `info` shared by all accepted connections.
    info: Arc<[u8]>,
    /// Ties `Item`/`SinkItem` to the type without storing values.
    _marker: PhantomData<fn() -> (Item, SinkItem)>,
}
489
490impl<S, L, CS, Item, SinkItem> Stream for Server<L, CS, Item, SinkItem>
491where
492 S: AsyncRead + AsyncWrite,
493 L: TryStream<Ok = S, Error = io::Error>,
494 CS: CipherSuite,
495{
496 type Item = io::Result<ServerConn<S, CS, Item, SinkItem>>;
497
498 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
499 let Some(io) = ready!(self.as_mut().project().listener.try_poll_next(cx)?) else {
500 return Poll::Ready(None);
501 };
502 let conn = ServerConn {
503 inner: serde_transport::new(
504 Framed::new(io, self.codec.clone()),
505 MessagePack::default(),
506 ),
507 sk: Arc::clone(&self.sk),
508 info: Arc::clone(&self.info),
509 ctx: None,
510 _marker: PhantomData,
511 };
512 Poll::Ready(Some(Ok(conn)))
513 }
514}
515
/// Server side of the encrypted transport; one per accepted connection.
///
/// Receives `Item`s from the client and sends `SinkItem`s back. Sessions
/// are established only by `Rekey` messages from the client.
#[pin_project]
pub struct ServerConn<S, CS, Item, SinkItem>
where
    CS: CipherSuite,
{
    /// Framed MessagePack transport over `S`: reads `ClientMsg`s and
    /// writes `ServerMsg`s.
    #[pin]
    inner: Transport<S, ClientMsg, ServerMsg, MessagePack<ClientMsg, ServerMsg>>,
    /// Secret key used to decapsulate each client `Rekey`.
    sk: Arc<ApiKey<CS>>,
    /// HPKE `info` bound into every session.
    info: Arc<[u8]>,
    /// Current session; `None` until the client's first `Rekey` arrives.
    ctx: Option<Ctx<CS>>,
    /// Ties `Item`/`SinkItem` to the type without storing values.
    _marker: PhantomData<fn() -> (Item, SinkItem)>,
}
542
543impl<S, CS, Item, SinkItem> ServerConn<S, CS, Item, SinkItem>
544where
545 CS: CipherSuite,
546 SinkItem: Serialize,
547{
548 fn encrypt(&mut self, item: SinkItem) -> io::Result<Data> {
554 self.ctx
555 .as_mut()
556 .assume("`self.ctx` should be `Some`")
557 .map_err(other)?
558 .encrypt::<Item, SinkItem>(item, Side::Server)
559 .map_err(other)
560 }
561}
562
563impl<S, CS, Item, SinkItem> ServerConn<S, CS, Item, SinkItem>
564where
565 CS: CipherSuite,
566 Item: DeserializeOwned,
567{
568 fn decrypt(&mut self, data: Data) -> io::Result<Item> {
574 self.ctx
575 .as_mut()
576 .assume("`self.ctx` should be `Some`")
577 .map_err(other)?
578 .decrypt::<Item, SinkItem>(data, Side::Client)
579 .map_err(other)
580 }
581}
582
583impl<S, CS, Item, SinkItem> ServerConn<S, CS, Item, SinkItem>
584where
585 CS: CipherSuite,
586{
587 fn rekey(&mut self, msg: Rekey) -> Result<(), HpkeError> {
590 let ctx = Ctx::server(&self.sk, &self.info, &msg.enc)?;
591 self.ctx = Some(ctx);
592 Ok(())
593 }
594}
595
596impl<S, CS, Item, SinkItem> Stream for ServerConn<S, CS, Item, SinkItem>
597where
598 S: AsyncRead + AsyncWrite + Unpin,
599 CS: CipherSuite,
600 Item: DeserializeOwned,
601{
602 type Item = io::Result<Item>;
603
604 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
605 loop {
607 let Some(msg) = ready!(self.as_mut().project().inner.poll_next(cx)?) else {
608 return Poll::Ready(None);
609 };
610 match msg {
611 ClientMsg::Data(data) => {
612 let pt = self.decrypt(data)?;
613 return Poll::Ready(Some(Ok(pt)));
614 }
615 ClientMsg::Rekey(rekey) => self.rekey(rekey).map_err(other)?,
616 }
617 }
618 }
619}
620
621impl<S, CS, Item, SinkItem> Sink<SinkItem> for ServerConn<S, CS, Item, SinkItem>
622where
623 S: AsyncRead + AsyncWrite + Unpin,
624 CS: CipherSuite,
625 SinkItem: Serialize,
626{
627 type Error = io::Error;
628
629 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
630 self.project().inner.poll_ready(cx)
631 }
632
633 fn start_send(mut self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
634 let data = self.encrypt(item)?;
635 self.project().inner.start_send(ServerMsg::Data(data))
636 }
637
638 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
639 self.project().inner.poll_flush(cx)
640 }
641
642 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
643 self.project().inner.poll_close(cx)
644 }
645}
646
647impl<S, CS, Item, SinkItem> fmt::Debug for ServerConn<S, CS, Item, SinkItem>
648where
649 CS: CipherSuite,
650{
651 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
652 f.debug_struct("Server")
653 .field("sk", &self.sk)
654 .field("info", &self.info)
655 .field("ctx", &self.ctx)
656 .finish_non_exhaustive()
657 }
658}
659
/// Messages sent from the server to the client.
///
/// NOTE(review): this is a wire type. Reordering or adding variants may
/// change its MessagePack encoding.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
enum ServerMsg {
    /// An encrypted item sealed under the client-established session.
    Data(Data),
}
666
667#[cfg(unix)]
669#[cfg_attr(docsrs, doc(cfg(unix)))]
670pub mod unix {
671 use core::{
672 pin::Pin,
673 task::{Context, Poll},
674 };
675
676 use futures_util::{ready, Stream};
677 use tokio::{
678 io,
679 net::{UnixListener, UnixStream},
680 };
681
682 #[derive(Debug)]
684 pub struct UnixListenerStream(UnixListener);
685
686 impl Stream for UnixListenerStream {
687 type Item = io::Result<UnixStream>;
688
689 fn poll_next(
690 self: Pin<&mut Self>,
691 cx: &mut Context<'_>,
692 ) -> Poll<Option<io::Result<UnixStream>>> {
693 let (stream, _) = ready!(self.0.poll_accept(cx))?;
694 Poll::Ready(Some(Ok(stream)))
695 }
696 }
697
698 impl From<UnixListener> for UnixListenerStream {
699 #[inline]
700 fn from(listener: UnixListener) -> Self {
701 Self(listener)
702 }
703 }
704}
705
#[cfg(test)]
#[cfg(unix)]
#[allow(clippy::arithmetic_side_effects, clippy::panic)]
mod tests {
    use std::panic;

    use aranya_crypto::{
        default::{DefaultCipherSuite, DefaultEngine},
        Rng,
    };
    use backon::{ExponentialBuilder, Retryable as _};
    use futures_util::{SinkExt, TryStreamExt};
    use tokio::{
        net::{UnixListener, UnixStream},
        task::JoinSet,
    };

    use super::*;

    impl<S, R, CS, Item, SinkItem> ClientConn<S, R, CS, Item, SinkItem>
    where
        S: AsyncRead + AsyncWrite + Unpin,
        CS: CipherSuite,
    {
        /// Drops the current session so that the next send must rekey.
        fn force_rekey(&mut self) {
            self.ctx = None;
        }
    }

    type CS = DefaultCipherSuite;

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Ping {
        v: usize,
    }

    #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Pong {
        v: usize,
    }

    /// A single client and server exchange `MAX_PING_PONGS` round trips.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_ping_pong() {
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().to_path_buf().join("sock"));
        let info = Arc::from(path.as_os_str().as_encoded_bytes());

        let (eng, _) = DefaultEngine::from_entropy(Rng);
        let sk = ApiKey::<CS>::new(&eng);
        let pk = sk.public().unwrap();

        const MAX_PING_PONGS: usize = 100;

        let mut set = JoinSet::new();

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let listener = UnixListener::bind(&*path)?;
                let codec = LengthDelimitedCodec::builder()
                    .max_frame_length(usize::MAX)
                    .new_codec();
                let mut server = server::<_, _, Ping, Pong>(
                    unix::UnixListenerStream::from(listener),
                    codec.clone(),
                    sk,
                    &info,
                );

                let mut conn = server.try_next().await.unwrap().unwrap();
                for v in 0..MAX_PING_PONGS {
                    let got = conn.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    assert_eq!(got, Ping { v });
                    conn.send(Pong {
                        v: got.v.wrapping_add(1),
                    })
                    .await?;
                }
                io::Result::Ok(())
            });
        }

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let codec = LengthDelimitedCodec::builder()
                    .max_frame_length(usize::MAX)
                    .new_codec();
                let sock = (|| UnixStream::connect(&*path))
                    .retry(ExponentialBuilder::default())
                    .await
                    .unwrap();
                let mut client = client::<_, _, _, Pong, Ping>(sock, codec, Rng, pk, &info);
                for v in 0..MAX_PING_PONGS {
                    client.send(Ping { v }).await?;
                    let got = client.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    let want = Pong {
                        v: v.wrapping_add(1),
                    };
                    assert_eq!(got, want)
                }
                Ok(())
            });
        }

        while let Some(res) = set.join_next().await {
            match res {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    set.abort_all();
                    panic!("{err}");
                }
                Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
                Err(err) => panic!("{err}"),
            }
        }
    }

    /// The client rekeys before every send. The server must follow along,
    /// and each new session starts at sequence number zero.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_rekey() {
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().to_path_buf().join("sock"));
        let info = Arc::from(path.as_os_str().as_encoded_bytes());

        let (eng, _) = DefaultEngine::from_entropy(Rng);
        let sk = ApiKey::<CS>::new(&eng);
        let pk = sk.public().unwrap();

        const MAX_PING_PONGS: usize = 100;

        let mut set = JoinSet::new();

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let listener = UnixListener::bind(&*path).unwrap();
                let codec = LengthDelimitedCodec::builder()
                    .max_frame_length(usize::MAX)
                    .new_codec();
                let mut server = server::<_, _, Ping, Pong>(
                    unix::UnixListenerStream::from(listener),
                    codec.clone(),
                    sk,
                    &info,
                );
                let mut conn = server.try_next().await.unwrap().unwrap();
                for v in 0..MAX_PING_PONGS {
                    let got = conn.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    // A fresh session has not sealed anything yet.
                    let ctx = conn.ctx.as_ref().map(|ctx| &ctx.seal).unwrap();
                    assert_eq!(ctx.seq(), Seq::ZERO);

                    assert_eq!(got, Ping { v });
                    conn.send(Pong {
                        v: got.v.wrapping_add(1),
                    })
                    .await?;

                    // Exactly one message has been sealed in this session.
                    let ctx = conn.ctx.as_ref().map(|ctx| &ctx.seal).unwrap();
                    assert_eq!(ctx.seq(), Seq::new(1));
                }
                io::Result::Ok(())
            });
        }

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let codec = LengthDelimitedCodec::builder()
                    .max_frame_length(usize::MAX)
                    .new_codec();
                let sock = (|| UnixStream::connect(&*path))
                    .retry(ExponentialBuilder::default())
                    .await
                    .unwrap();
                let mut client = client::<_, _, _, Pong, Ping>(sock, codec, Rng, pk, &info);
                for v in 0..MAX_PING_PONGS {
                    let last = client.rekeys;
                    client.force_rekey();
                    client.send(Ping { v }).await.unwrap();
                    assert_eq!(client.rekeys, last + 1);
                    let got = client.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "stream finished early")
                    })?;
                    let want = Pong {
                        v: v.wrapping_add(1),
                    };
                    assert_eq!(got, want)
                }
                Ok(())
            });
        }

        while let Some(res) = set.join_next().await {
            match res {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    set.abort_all();
                    panic!("{err}");
                }
                Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
                Err(err) => panic!("{err}"),
            }
        }
    }

    /// One server handles `MAX_CLIENTS` concurrent clients.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_multi_client() {
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().to_path_buf().join("sock"));
        let info = Arc::from(path.as_os_str().as_encoded_bytes());

        let (eng, _) = DefaultEngine::from_entropy(Rng);
        let sk = ApiKey::<CS>::new(&eng);
        let pk = sk.public().unwrap();

        const MAX_PING_PONGS: usize = 2;
        const MAX_CLIENTS: usize = 10;

        let mut set = JoinSet::new();

        {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            set.spawn(async move {
                let listener = UnixListener::bind(&*path).unwrap();
                let codec = LengthDelimitedCodec::builder()
                    .max_frame_length(usize::MAX)
                    .new_codec();
                let mut server = server::<_, _, Ping, Pong>(
                    unix::UnixListenerStream::from(listener),
                    codec.clone(),
                    sk,
                    &info,
                );
                let mut set = JoinSet::new();
                for _ in 0..MAX_CLIENTS {
                    let mut conn = server.try_next().await?.unwrap();
                    set.spawn(async move {
                        for v in 0..MAX_PING_PONGS {
                            let got = conn.try_next().await?.ok_or_else(|| {
                                io::Error::new(
                                    io::ErrorKind::UnexpectedEof,
                                    "client stream finished early",
                                )
                            })?;
                            assert_eq!(got, Ping { v });
                            conn.send(Pong {
                                v: got.v.wrapping_add(1),
                            })
                            .await?;
                        }
                        io::Result::Ok(())
                    });
                }
                set.join_all()
                    .await
                    .into_iter()
                    .find(|v| v.is_err())
                    .unwrap_or(Ok(()))
            });
        }

        // Spawn exactly as many clients as the server's accept loop
        // expects. A mismatch would leave the server waiting forever, or
        // leave some clients unserved.
        for _ in 0..MAX_CLIENTS {
            let path = Arc::clone(&path);
            let info = Arc::clone(&info);
            let pk = pk.clone();
            set.spawn(async move {
                let codec = LengthDelimitedCodec::builder()
                    .max_frame_length(usize::MAX)
                    .new_codec();
                let sock = (|| UnixStream::connect(&*path))
                    .retry(ExponentialBuilder::default())
                    .await
                    .unwrap();
                let mut client = client::<_, _, _, Pong, Ping>(sock, codec, Rng, pk, &info);
                for v in 0..MAX_PING_PONGS {
                    client.send(Ping { v }).await?;
                    let got = client.try_next().await?.ok_or_else(|| {
                        io::Error::new(io::ErrorKind::UnexpectedEof, "server stream finished early")
                    })?;
                    let want = Pong {
                        v: v.wrapping_add(1),
                    };
                    assert_eq!(got, want);
                }
                Ok(())
            });
        }

        while let Some(res) = set.join_next().await {
            match res {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    set.abort_all();
                    panic!("{err}");
                }
                Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
                Err(err) => panic!("{err}"),
            }
        }
    }
}