cryprot_net/
lib.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    future::Future,
4    io::{Error, IoSlice},
5    mem,
6    pin::{pin, Pin},
7    sync::{
8        atomic::{AtomicU32, Ordering},
9        Arc,
10    },
11    task::{Context, Poll},
12};
13
14use bincode::Options;
15use s2n_quic::{
16    connection::{Handle, StreamAcceptor as QuicStreamAcceptor},
17    stream::{ReceiveStream as QuicRecvStream, SendStream as QuicSendStream},
18};
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use tokio::{
21    io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
22    select,
23    sync::{mpsc, oneshot},
24};
25use tokio_serde::{
26    formats::{Bincode, SymmetricalBincode},
27    SymmetricallyFramed,
28};
29use tokio_util::codec::{length_delimited, FramedRead, FramedWrite, LengthDelimitedCodec};
30use tracing::{debug, error, event, Level};
31
32#[cfg(feature = "metrics")]
33pub mod metrics;
34
35#[doc(hidden)]
36#[cfg(any(test, feature = "__testing"))]
37pub mod testing;
38
/// User-chosen identifier for a stream of a specific [`Connection`].
///
/// A stream opened with a given `Id` is matched with the stream opened with
/// the same `Id` on the corresponding [`Connection`] of the other party.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Id(pub(crate) u64);

43#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
44enum StreamId {
45    Implicit(u64),
46    Explicit(u64),
47}
48
49/// Id of a [`Connection`]. Does not include parent Ids of this connection.
50/// It is only unique with respect to its sibling connections created by
51/// [`Connection::sub_connection`].
52#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
53struct ConnectionId(pub(crate) u32);
54
55/// Unique id of a stream and all its parent [`ConnectionId`]s.
56#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
57struct UniqueId {
58    cids: Vec<ConnectionId>,
59    id: StreamId,
60}
61
62type StreamSend = oneshot::Sender<(QuicRecvStream, usize)>;
63type StreamRecv = oneshot::Receiver<(QuicRecvStream, usize)>;
64
65/// Manages accepting of new streams.
66pub struct StreamManager {
67    acceptor: QuicStreamAcceptor,
68    cmd_send: mpsc::UnboundedSender<Cmd>,
69    cmd_recv: mpsc::UnboundedReceiver<Cmd>,
70    pending: HashMap<UniqueId, StreamSend>,
71    accepted: HashMap<UniqueId, (QuicRecvStream, usize)>,
72}
73
74/// Used to create grouped sub-streams.
75///
76/// Connections can have sub-connections. Streams created via
77/// [`Connection::byte_stream`] and [`Connection::stream`] are tied to their
78/// connection. Streams created with the same [`Id`] but for different
79/// connections will not conflict with each other.
80#[derive(Debug)]
81pub struct Connection {
82    cids: Vec<ConnectionId>,
83    next_cid: Arc<AtomicU32>,
84    handle: Handle,
85    cmd: mpsc::UnboundedSender<Cmd>,
86    next_implicit_id: u64,
87}
88
89/// Send part of the bytes stream.
90pub struct SendStreamBytes {
91    inner: QuicSendStream,
92}
93
94/// Receive part of the bytes stream.
95pub struct ReceiveStreamBytes {
96    inner: ReceiveStreamWrapper,
97}
98
99/// Send part of the serialized stream.
100pub type SendStream<T> = SymmetricallyFramed<
101    FramedWrite<SendStreamBytes, LengthDelimitedCodec>,
102    T,
103    SymmetricalBincode<T>,
104>;
105
106pub type TempSendStream<'a, T> = SymmetricallyFramed<
107    FramedWrite<&'a mut SendStreamBytes, LengthDelimitedCodec>,
108    T,
109    SymmetricalBincode<T>,
110>;
111
112/// Receive part of the serialized stream.
113pub type ReceiveStream<T> = SymmetricallyFramed<
114    FramedRead<ReceiveStreamBytes, LengthDelimitedCodec>,
115    T,
116    SymmetricalBincode<T>,
117>;
118
119pub type ReceiveStreamTemp<'a, T> = SymmetricallyFramed<
120    FramedRead<&'a mut ReceiveStreamBytes, LengthDelimitedCodec>,
121    T,
122    SymmetricalBincode<T>,
123>;
124
125enum ReceiveStreamWrapper {
126    Channel { stream_recv: StreamRecv },
127    Stream { recv_stream: QuicRecvStream },
128}
129
130#[derive(Debug)]
131enum Cmd {
132    NewStream {
133        uid: UniqueId,
134        stream_return: StreamSend,
135    },
136    AcceptedStream {
137        uid: UniqueId,
138        stream: QuicRecvStream,
139        bytes_read: usize,
140    },
141}
142
143impl StreamManager {
144    pub fn new(acceptor: QuicStreamAcceptor) -> Self {
145        let (cmd_send, cmd_recv) = mpsc::unbounded_channel();
146        Self {
147            acceptor,
148            cmd_send,
149            cmd_recv,
150            pending: Default::default(),
151            accepted: Default::default(),
152        }
153    }
154
155    /// Start the StreamManager to accept streams.
156    ///
157    /// This method needs to be continually polled to establish new streams.
158    #[tracing::instrument(skip_all)]
159    pub async fn start(mut self) {
160        loop {
161            // Guard against possible cancellation unsafety of `accept_receive_stream`
162            let mut receive_stream = pin!(self.acceptor.accept_receive_stream());
163            select! {
164                res = &mut receive_stream => {
165                    match res {
166                        Ok(Some(stream)) => {
167                            debug!("accepted stream");
168                            Self::accepted(stream, self.cmd_send.clone());
169                        }
170                        Ok(None) => {
171                            debug!("remote closed");
172                            return;
173                        }
174                        Err(err) => {
175                            error!(%err, "unable to accept stream");
176                            return;
177                        }
178                    }
179                }
180                Some(cmd) = self.cmd_recv.recv() => {   // recv() is cancel safe
181                    debug!(?cmd, "received cmd");
182                    match cmd {
183                        Cmd::NewStream {uid, stream_return} => {
184                            if let Some(accepted) = self.accepted.remove(&uid) {
185                                if stream_return.send(accepted).is_err() {
186                                    debug!("accepted remote stream but local receiver is closed");
187                                }
188                                debug!("sending new stream to receiver");
189                                continue;
190                            }
191                            match self.pending.entry(uid) {
192                                Entry::Occupied(occupied_entry) => {
193                                    panic!("Duplicate unique id: {:?}", occupied_entry.key())
194                                },
195                                Entry::Vacant(vacant_entry) => {vacant_entry.insert(stream_return);},
196                            }
197                        }
198                        Cmd::AcceptedStream {uid, stream, bytes_read} => {
199                            if let Some(stream_ret) = self.pending.remove(&uid) {
200                               if stream_ret.send((stream, bytes_read)).is_err() {
201                                debug!("accepted remote stream but local receiver is closed");
202                               }
203                            } else {
204                                debug!("accepted stream but no pending");
205                                self.accepted.insert(uid, (stream, bytes_read));
206                            }
207                        }
208                    }
209                }
210            }
211        }
212    }
213
214    // not taking &self to work around borrow issue
215    fn accepted(mut stream: QuicRecvStream, cmd_send: mpsc::UnboundedSender<Cmd>) {
216        tokio::spawn(async move {
217            let (uid, bytes_read) = match UniqueId::read_from(&mut stream).await {
218                Ok(ret) => ret,
219                Err(err) => {
220                    error!(?err, "unable to read stream unique id");
221                    return;
222                }
223            };
224            cmd_send
225                .send(Cmd::AcceptedStream {
226                    uid,
227                    stream,
228                    bytes_read,
229                })
230                .expect("cmd_rcv is owned by StreamManager")
231        });
232    }
233}
234
235#[derive(thiserror::Error, Debug)]
236pub enum ConnectionError {
237    #[error("Unable to open stream")]
238    OpenStream(#[source] s2n_quic::connection::Error),
239    #[error("io error during stream establishment")]
240    IoError(#[source] io::Error),
241    #[error("StreamManager is dropped and not accepting connections")]
242    StreamManagerDropped,
243    #[error("Stream unique id deserialization failed")]
244    UniqueIdDeserialization(#[source] bincode::Error),
245    #[error("Stream unique id serialization failed")]
246    UniqueIdSerialization(#[source] bincode::Error),
247    #[error("Reached maximum number of sub connections")]
248    SubConnectionLimitReached,
249}
250
251impl Connection {
252    pub fn new(quic_conn: s2n_quic::Connection) -> (Self, StreamManager) {
253        let (handle, acceptor) = quic_conn.split();
254        let stream_manager = StreamManager::new(acceptor);
255        let conn = Self {
256            cids: vec![],
257            next_cid: Arc::new(AtomicU32::new(0)),
258            handle,
259            cmd: stream_manager.cmd_send.clone(),
260            next_implicit_id: 0,
261        };
262        (conn, stream_manager)
263    }
264
265    /// Create a sub-connection. The n'th call to sub_connection
266    /// is paired with the n'th call to `sub_connection` on the corresponding
267    /// [`Connection`] of the other party. Creating a sub-connection results
268    /// in no immediate communication and is a fast synchronous operation.
269    #[tracing::instrument(level = Level::DEBUG, skip(self), ret)]
270    pub fn sub_connection(&mut self) -> Self {
271        let cid = self.next_cid.fetch_add(1, Ordering::Relaxed);
272        let mut cids = self.cids.clone();
273        cids.push(ConnectionId(cid));
274        Self {
275            cids,
276            next_cid: Arc::new(AtomicU32::new(0)),
277            handle: self.handle.clone(),
278            cmd: self.cmd.clone(),
279            next_implicit_id: 0,
280        }
281    }
282
283    async fn internal_byte_stream(
284        &self,
285        stream_id: StreamId,
286    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
287        let uid = UniqueId::new(self.cids.clone(), stream_id);
288        let mut snd = self
289            .handle
290            .clone()
291            .open_send_stream()
292            .await
293            .map_err(ConnectionError::OpenStream)?;
294        let bytes_written = uid.write_into(&mut snd).await?;
295        event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes_written);
296        let (stream_return, stream_recv) = oneshot::channel();
297        self.cmd
298            .send(Cmd::NewStream { uid, stream_return })
299            .map_err(|_| ConnectionError::StreamManagerDropped)?;
300        let snd = SendStreamBytes { inner: snd };
301        let recv = ReceiveStreamBytes {
302            inner: ReceiveStreamWrapper::Channel { stream_recv },
303        };
304        Ok((snd, recv))
305    }
306
307    /// Establish a byte stream over this connection with the provided Id.
308    pub async fn byte_stream(
309        &mut self,
310    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
311        self.next_implicit_id += 1;
312        self.internal_byte_stream(StreamId::Implicit(self.next_implicit_id - 1))
313            .await
314    }
315
316    /// Establish a byte stream over this connection with the provided Id.
317    pub async fn byte_stream_with_id(
318        &self,
319        id: Id,
320    ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
321        self.internal_byte_stream(StreamId::Explicit(id.0)).await
322    }
323
324    /// Establish a typed stream over this connection.
325    async fn internal_stream<T: Serialize + DeserializeOwned>(
326        &self,
327        id: StreamId,
328    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
329        let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
330        let mut ld_codec = LengthDelimitedCodec::builder();
331        // TODO what is a sensible max length?
332        const MB: usize = 1024 * 1024;
333        ld_codec.max_frame_length(256 * MB);
334        let framed_send = ld_codec.new_write(send_bytes);
335        let framed_read = ld_codec.new_read(recv_bytes);
336        let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
337        let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
338        Ok((serde_send, serde_read))
339    }
340
341    /// Establish a typed stream over this connection.
342    pub async fn stream<T: Serialize + DeserializeOwned>(
343        &mut self,
344    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
345        self.next_implicit_id += 1;
346        self.internal_stream(StreamId::Implicit(self.next_implicit_id - 1))
347            .await
348    }
349
350    /// Establish a typed stream over this connection with the provided explicit
351    /// Id.
352    pub async fn stream_with_id<T: Serialize + DeserializeOwned>(
353        &self,
354        id: Id,
355    ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
356        self.internal_stream(StreamId::Explicit(id.0)).await
357    }
358
359    async fn internal_request_response_stream<T: Serialize, S: DeserializeOwned>(
360        &self,
361        id: StreamId,
362    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
363        let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
364        let framed_send = default_codec().new_write(send_bytes);
365        let framed_read = default_codec().new_read(recv_bytes);
366        let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
367        let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
368        Ok((serde_send, serde_read))
369    }
370
371    /// Establish a typed request-response stream over this connection with
372    /// differing types for the request and response.
373    pub async fn request_response_stream<T: Serialize, S: DeserializeOwned>(
374        &mut self,
375    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
376        self.next_implicit_id += 1;
377        self.internal_request_response_stream(StreamId::Implicit(self.next_implicit_id - 1))
378            .await
379    }
380
381    /// Establish a typed request-response stream over this connection with
382    /// differing types for the request and response.
383    pub async fn request_response_stream_with_id<T: Serialize, S: DeserializeOwned>(
384        &self,
385        id: Id,
386    ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
387        self.internal_request_response_stream(StreamId::Explicit(id.0))
388            .await
389    }
390}
391
392impl Id {
393    pub fn new(id: u64) -> Self {
394        Self(id)
395    }
396}
397
398fn bincode_opts() -> impl bincode::Options {
399    bincode::options().with_big_endian().with_varint_encoding()
400}
401
402impl UniqueId {
403    fn new(cids: Vec<ConnectionId>, id: StreamId) -> Self {
404        Self { cids, id }
405    }
406
407    async fn write_into<W: AsyncWrite>(&self, write: W) -> Result<usize, ConnectionError> {
408        let mut write = pin!(write);
409        let mut options = bincode_opts();
410        let serialized = (&mut options)
411            .serialize(self)
412            .map_err(ConnectionError::UniqueIdSerialization)?;
413        write
414            .write_u32(
415                serialized
416                    .len()
417                    .try_into()
418                    .map_err(|_| ConnectionError::SubConnectionLimitReached)?,
419            )
420            .await
421            .map_err(ConnectionError::IoError)?;
422        write
423            .write_all(&serialized)
424            .await
425            .map_err(ConnectionError::IoError)?;
426        Ok(mem::size_of::<u32>() + serialized.len())
427    }
428
429    async fn read_from<R: AsyncRead>(reader: R) -> Result<(Self, usize), ConnectionError> {
430        let mut reader = pin!(reader);
431        let len = reader.read_u32().await.map_err(ConnectionError::IoError)?;
432        let mut buf = vec![0; len as usize];
433        reader
434            .read_exact(&mut buf)
435            .await
436            .map_err(ConnectionError::IoError)?;
437        let uid = bincode_opts()
438            .deserialize(&buf)
439            .map_err(ConnectionError::UniqueIdDeserialization)?;
440        Ok((uid, mem::size_of::<u32>() + len as usize))
441    }
442}
443
444#[derive(thiserror::Error, Debug)]
445pub enum StreamError {
446    #[error("unable to flush stream")]
447    Flush(#[source] s2n_quic::stream::Error),
448    #[error("unable to close stream")]
449    Close(#[source] s2n_quic::stream::Error),
450    #[error("unable to finish stream")]
451    Finish(#[source] s2n_quic::stream::Error),
452}
453
454impl SendStreamBytes {
455    pub async fn flush(&mut self) -> Result<(), StreamError> {
456        self.inner.flush().await.map_err(StreamError::Flush)
457    }
458
459    pub fn finish(&mut self) -> Result<(), StreamError> {
460        self.inner.finish().map_err(StreamError::Finish)
461    }
462
463    pub async fn close(&mut self) -> Result<(), StreamError> {
464        self.inner.close().await.map_err(StreamError::Close)
465    }
466
467    pub fn as_stream<T: Serialize>(&mut self) -> TempSendStream<T> {
468        let framed_send = default_codec().new_write(self);
469        SymmetricallyFramed::new(framed_send, Bincode::default())
470    }
471}
472
473impl AsyncWrite for SendStreamBytes {
474    fn poll_write(
475        mut self: Pin<&mut Self>,
476        cx: &mut Context<'_>,
477        buf: &[u8],
478    ) -> Poll<Result<usize, Error>> {
479        let inner = Pin::new(&mut self.inner);
480        trace_poll(inner.poll_write(cx, buf))
481    }
482
483    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
484        let inner = Pin::new(&mut self.inner);
485        AsyncWrite::poll_flush(inner, cx)
486    }
487
488    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
489        let inner = Pin::new(&mut self.inner);
490        inner.poll_shutdown(cx)
491    }
492
493    fn poll_write_vectored(
494        mut self: Pin<&mut Self>,
495        cx: &mut Context<'_>,
496        bufs: &[IoSlice<'_>],
497    ) -> Poll<Result<usize, Error>> {
498        let inner = Pin::new(&mut self.inner);
499        trace_poll(inner.poll_write_vectored(cx, bufs))
500    }
501
502    fn is_write_vectored(&self) -> bool {
503        self.inner.is_write_vectored()
504    }
505}
506
507fn trace_poll(p: Poll<io::Result<usize>>) -> Poll<io::Result<usize>> {
508    if let Poll::Ready(Ok(bytes)) = p {
509        event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes);
510    }
511    p
512}
513
514impl ReceiveStreamBytes {
515    pub fn as_stream<T: DeserializeOwned>(&mut self) -> ReceiveStreamTemp<T> {
516        let framed_read = default_codec().new_read(self);
517        SymmetricallyFramed::new(framed_read, Bincode::default())
518    }
519}
520
521// Implement AsyncRead for ReceiveStream to poll the oneshot Receiver first if
522// there is not already a channel.
523impl AsyncRead for ReceiveStreamBytes {
524    fn poll_read(
525        mut self: Pin<&mut Self>,
526        cx: &mut Context<'_>,
527        buf: &mut ReadBuf<'_>,
528    ) -> Poll<std::io::Result<()>> {
529        match &mut self.inner {
530            ReceiveStreamWrapper::Channel { stream_recv } => match Pin::new(stream_recv).poll(cx) {
531                Poll::Pending => Poll::Pending,
532                Poll::Ready(Ok((recv_stream, bytes_read))) => {
533                    // We know we read those bytes in the StreamManager, so we emit
534                    // the corresponding event here in the correct span.
535                    event!(target: "cryprot_metrics", Level::TRACE, bytes_read);
536                    self.inner = ReceiveStreamWrapper::Stream { recv_stream };
537                    self.poll_read(cx, buf)
538                }
539                Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::other(Box::new(err)))),
540            },
541            ReceiveStreamWrapper::Stream { recv_stream } => {
542                let len = buf.filled().len();
543                let poll = Pin::new(recv_stream).poll_read(cx, buf);
544                if let Poll::Ready(Ok(())) = poll {
545                    let bytes = buf.filled().len() - len;
546                    if bytes > 0 {
547                        event!(target: "cryprot_metrics", Level::TRACE, bytes_read = bytes);
548                    }
549                }
550                poll
551            }
552        }
553    }
554}
555
556fn default_codec() -> length_delimited::Builder {
557    let mut ld_codec = LengthDelimitedCodec::builder();
558    const MB: usize = 1024 * 1024;
559    ld_codec.max_frame_length(20 * MB);
560    ld_codec
561}
562
#[cfg(test)]
mod tests {
    use anyhow::{Context, Result};
    use futures::{SinkExt, StreamExt};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tracing::debug;

    use crate::{
        testing::{init_tracing, local_conn},
        Id,
    };

    #[tokio::test]
    async fn create_local_conn() -> Result<()> {
        let _g = init_tracing();
        let _ = local_conn().await?;
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        let send_buf = b"hello there";
        s_send.write_all(send_buf).await?;
        let mut buf = [0; 11];
        c_recv.read_exact(&mut buf).await?;
        assert_eq!(send_buf, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_explicit_implicit_id() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send1, _) = s.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (mut s_send2, _) = s.byte_stream().await?;
        let (_, mut c_recv1) = c.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (_, mut c_recv2) = c.byte_stream().await?;
        let send_buf1 = b"hello there";
        s_send1.write_all(send_buf1).await?;
        let mut buf = [0; 11];
        c_recv1.read_exact(&mut buf).await?;
        assert_eq!(send_buf1, &buf);

        let send_buf2 = b"general kenobi";
        s_send2.write_all(send_buf2).await?;
        let mut buf = [0; 14];
        c_recv2.read_exact(&mut buf).await?;
        assert_eq!(send_buf2, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_different_order() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, mut s_recv) = s.byte_stream().await?;
        let s_send_buf = b"hello there";
        s_send.write_all(s_send_buf).await?;
        let mut s_recv_buf = [0; 2];
        // Spawning the read task before the client calls `c.byte_stream` checks
        // that the switch from the channel to the s2n stream works.
        let jh = tokio::spawn(async move {
            s_recv.read_exact(&mut s_recv_buf).await.unwrap();
            s_recv_buf
        });
        let (mut c_send, mut c_recv) = c.byte_stream().await?;
        let mut c_recv_buf = [0; 11];
        c_recv.read_exact(&mut c_recv_buf).await?;
        assert_eq!(s_send_buf, &c_recv_buf);
        let c_send_buf = b"42";
        c_send.write_all(c_send_buf).await?;
        let s_recv_buf = jh.await?;
        assert_eq!(c_send_buf, &s_recv_buf);
        Ok(())
    }

    #[tokio::test]
    async fn many_parallel_byte_streams() -> Result<()> {
        let _g = init_tracing();
        let (mut c1, mut c2) = local_conn().await?;
        let mut jhs = JoinSet::new();
        for i in 0..10 {
            let ((mut s, _), (_, mut r)) =
                tokio::try_join!(c1.byte_stream(), c2.byte_stream()).unwrap();

            let jh = tokio::spawn(async move {
                let buf = vec![0; 10 * 1024 * 1024];
                s.write_all(&buf).await.unwrap();
                debug!("wrote buf {i}");
            });
            jhs.spawn(jh);
            let jh = tokio::spawn(async move {
                let mut buf = vec![0; 10 * 1024 * 1024];
                r.read_exact(&mut buf).await.unwrap();
                debug!("received buf {i}");
            });
            jhs.spawn(jh);
        }
        let res = jhs.join_all().await;
        for res in res {
            res.unwrap();
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c.stream::<Vec<i32>>().await?;
        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        drop(snd);
        let ret = recv.next().await.map(|res| res.map_err(|_| ()));
        assert_eq!(None, ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream_block() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream().await?;
        let (_, mut recv) = c.stream().await?;
        snd.send(vec![u8::MAX; 16]).await?;
        let ret: Vec<_> = recv.next().await.context("recv")??;
        assert_eq!(vec![u8::MAX; 16], ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_byte_stream_as_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        {
            let mut send_ser1 = s_send.as_stream::<i32>();
            let mut recv_ser1 = c_recv.as_stream::<i32>();
            send_ser1.send(42).await?;
            let ret = recv_ser1.next().await.context("recv")??;
            assert_eq!(42, ret);
        }
        {
            let mut send_ser2 = s_send.as_stream::<Vec<i32>>();
            let mut recv_ser2 = c_recv.as_stream::<Vec<i32>>();
            send_ser2.send(vec![1, 2, 3]).await?;
            let ret = recv_ser2.next().await.context("recv")??;
            assert_eq!(vec![1, 2, 3], ret);
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_request_response_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd1, mut recv1) = s.request_response_stream::<Vec<i32>, String>().await?;
        let (mut snd2, mut recv2) = c.request_response_stream::<String, Vec<i32>>().await?;
        snd1.send(vec![1, 2, 3]).await?;
        let ret = recv2.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        snd2.send("hello there".to_string()).await?;
        let ret = recv1.next().await.context("recv2")??;
        assert_eq!("hello there", &ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        // Actually open (and keep alive) streams with the same implicit id on
        // the parent connections. Futures are lazy, so merely creating them
        // would not open anything.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let (mut snd, _) = s2.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c2.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        let mut s3 = s2.sub_connection();
        let mut c3 = c2.sub_connection();
        // Open (and keep alive) streams with the same implicit id on every
        // ancestor connection, so the innermost streams must be told apart by
        // their connection ids.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let _s2_streams = s2.byte_stream().await?;
        let _c2_streams = c2.byte_stream().await?;
        let (mut snd, _) = s3.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c3.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }
}