1use std::{
2 collections::{hash_map::Entry, HashMap},
3 future::Future,
4 io::{Error, IoSlice},
5 mem,
6 pin::{pin, Pin},
7 sync::{
8 atomic::{AtomicU32, Ordering},
9 Arc,
10 },
11 task::{Context, Poll},
12};
13
14use bincode::Options;
15use s2n_quic::{
16 connection::{Handle, StreamAcceptor as QuicStreamAcceptor},
17 stream::{ReceiveStream as QuicRecvStream, SendStream as QuicSendStream},
18};
19use serde::{de::DeserializeOwned, Deserialize, Serialize};
20use tokio::{
21 io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf},
22 select,
23 sync::{mpsc, oneshot},
24};
25use tokio_serde::{
26 formats::{Bincode, SymmetricalBincode},
27 SymmetricallyFramed,
28};
29use tokio_util::codec::{length_delimited, FramedRead, FramedWrite, LengthDelimitedCodec};
30use tracing::{debug, error, event, Level};
31
// Optional metrics support, compiled only with the `metrics` feature.
#[cfg(feature = "metrics")]
pub mod metrics;

// Test helpers (`init_tracing`, `local_conn`) used by this module's tests; also exposed
// behind the `__testing` feature and hidden from the rendered docs.
#[doc(hidden)]
#[cfg(any(test, feature = "__testing"))]
pub mod testing;
38
/// Explicit, user-chosen stream identifier for the `*_with_id` methods of [`Connection`].
///
/// Streams are matched between the peers by this id together with the sub-connection path of
/// the connection they are opened on, so both peers must pass the same `Id` on corresponding
/// connections.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Id(pub(crate) u64);
42
/// Per-connection part of a stream's [`UniqueId`].
///
/// `Implicit` ids come from the connection's `next_implicit_id` counter, `Explicit` ids from a
/// user-supplied [`Id`]. Because they are distinct variants, an explicit id never equals an
/// implicit one even when the numbers coincide.
///
/// NOTE: this type is bincode-serialized into every stream header; the variant order is part
/// of the wire format and must not change.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
enum StreamId {
    Implicit(u64),
    Explicit(u64),
}
48
/// Identifier of one sub-connection level, allocated from the parent's `next_cid` counter in
/// `Connection::sub_connection`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct ConnectionId(pub(crate) u32);
54
/// Identifies a stream on both peers: the sub-connection path plus the stream id.
///
/// Written at the start of every opened send stream (`write_into`) and read back by the
/// accepting side (`read_from`) so the `StreamManager` can pair the remote stream with the
/// local request.
///
/// NOTE: bincode-serialized on the wire; field order is part of the format.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)]
struct UniqueId {
    /// Sub-connection path, one entry per `sub_connection` level (empty for the root).
    cids: Vec<ConnectionId>,
    id: StreamId,
}
61
/// Delivers an accepted remote stream, plus the number of header bytes already read from it,
/// to the local receive half waiting for it.
type StreamSend = oneshot::Sender<(QuicRecvStream, usize)>;
/// Receiving end of [`StreamSend`], held by a not-yet-connected `ReceiveStreamBytes`.
type StreamRecv = oneshot::Receiver<(QuicRecvStream, usize)>;
64
/// Pairs streams accepted from the remote peer with local stream requests.
///
/// Must be driven with [`StreamManager::start`]; [`Connection`]s send it commands and receive
/// the matching remote stream through a oneshot channel.
pub struct StreamManager {
    /// Source of remote-initiated receive streams.
    acceptor: QuicStreamAcceptor,
    /// Command sender; cloned into every `Connection` and every per-stream accept task.
    cmd_send: mpsc::UnboundedSender<Cmd>,
    /// Commands from `Connection`s and accept tasks.
    cmd_recv: mpsc::UnboundedReceiver<Cmd>,
    /// Local requests whose remote stream has not been accepted yet.
    pending: HashMap<UniqueId, StreamSend>,
    /// Accepted remote streams (with header bytes read) not yet claimed by a local request.
    accepted: HashMap<UniqueId, (QuicRecvStream, usize)>,
}
73
/// A (sub) connection on which byte streams and typed streams can be opened.
///
/// Each stream pairs a locally opened QUIC send stream with the remote peer's send stream that
/// carries the same [`UniqueId`], which is delivered by the [`StreamManager`].
#[derive(Debug)]
pub struct Connection {
    /// Sub-connection path of this connection (empty for the root connection).
    cids: Vec<ConnectionId>,
    /// Counter from which child sub-connection ids are allocated.
    next_cid: Arc<AtomicU32>,
    /// Handle used to open outgoing send streams.
    handle: Handle,
    /// Command channel to the `StreamManager`.
    cmd: mpsc::UnboundedSender<Cmd>,
    /// Next id used by the implicitly numbered stream constructors.
    next_implicit_id: u64,
}
88
/// Sending half of a byte stream; implements [`AsyncWrite`].
pub struct SendStreamBytes {
    inner: QuicSendStream,
}
93
/// Receiving half of a byte stream; implements [`AsyncRead`].
///
/// Reads wait until the `StreamManager` has delivered the matching remote stream.
pub struct ReceiveStreamBytes {
    inner: ReceiveStreamWrapper,
}
98
/// Typed sending half: bincode-serialized `T` values in length-delimited frames.
pub type SendStream<T> = SymmetricallyFramed<
    FramedWrite<SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
105
/// Typed sending view borrowing a [`SendStreamBytes`], returned by
/// [`SendStreamBytes::as_stream`].
pub type TempSendStream<'a, T> = SymmetricallyFramed<
    FramedWrite<&'a mut SendStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
111
/// Typed receiving half: length-delimited frames decoded as bincode-serialized `T` values.
pub type ReceiveStream<T> = SymmetricallyFramed<
    FramedRead<ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
118
/// Typed receiving view borrowing a [`ReceiveStreamBytes`], returned by
/// [`ReceiveStreamBytes::as_stream`].
pub type ReceiveStreamTemp<'a, T> = SymmetricallyFramed<
    FramedRead<&'a mut ReceiveStreamBytes, LengthDelimitedCodec>,
    T,
    SymmetricalBincode<T>,
>;
124
/// State of a [`ReceiveStreamBytes`].
enum ReceiveStreamWrapper {
    /// Still waiting for the `StreamManager` to deliver the remote stream.
    Channel { stream_recv: StreamRecv },
    /// Remote stream delivered; reads go directly to it.
    Stream { recv_stream: QuicRecvStream },
}
129
/// Messages processed by [`StreamManager::start`].
#[derive(Debug)]
enum Cmd {
    /// A local `Connection` opened a stream identified by `uid` and waits for the matching
    /// remote stream on `stream_return`.
    NewStream {
        uid: UniqueId,
        stream_return: StreamSend,
    },
    /// A remote stream was accepted and its header (identifying it as `uid`, `bytes_read`
    /// bytes long) has been read.
    AcceptedStream {
        uid: UniqueId,
        stream: QuicRecvStream,
        bytes_read: usize,
    },
}
142
143impl StreamManager {
144 pub fn new(acceptor: QuicStreamAcceptor) -> Self {
145 let (cmd_send, cmd_recv) = mpsc::unbounded_channel();
146 Self {
147 acceptor,
148 cmd_send,
149 cmd_recv,
150 pending: Default::default(),
151 accepted: Default::default(),
152 }
153 }
154
155 #[tracing::instrument(skip_all)]
159 pub async fn start(mut self) {
160 loop {
161 let mut receive_stream = pin!(self.acceptor.accept_receive_stream());
163 select! {
164 res = &mut receive_stream => {
165 match res {
166 Ok(Some(stream)) => {
167 debug!("accepted stream");
168 Self::accepted(stream, self.cmd_send.clone());
169 }
170 Ok(None) => {
171 debug!("remote closed");
172 return;
173 }
174 Err(err) => {
175 error!(%err, "unable to accept stream");
176 return;
177 }
178 }
179 }
180 Some(cmd) = self.cmd_recv.recv() => { debug!(?cmd, "received cmd");
182 match cmd {
183 Cmd::NewStream {uid, stream_return} => {
184 if let Some(accepted) = self.accepted.remove(&uid) {
185 if stream_return.send(accepted).is_err() {
186 debug!("accepted remote stream but local receiver is closed");
187 }
188 debug!("sending new stream to receiver");
189 continue;
190 }
191 match self.pending.entry(uid) {
192 Entry::Occupied(occupied_entry) => {
193 panic!("Duplicate unique id: {:?}", occupied_entry.key())
194 },
195 Entry::Vacant(vacant_entry) => {vacant_entry.insert(stream_return);},
196 }
197 }
198 Cmd::AcceptedStream {uid, stream, bytes_read} => {
199 if let Some(stream_ret) = self.pending.remove(&uid) {
200 if stream_ret.send((stream, bytes_read)).is_err() {
201 debug!("accepted remote stream but local receiver is closed");
202 }
203 } else {
204 debug!("accepted stream but no pending");
205 self.accepted.insert(uid, (stream, bytes_read));
206 }
207 }
208 }
209 }
210 }
211 }
212 }
213
214 fn accepted(mut stream: QuicRecvStream, cmd_send: mpsc::UnboundedSender<Cmd>) {
216 tokio::spawn(async move {
217 let (uid, bytes_read) = match UniqueId::read_from(&mut stream).await {
218 Ok(ret) => ret,
219 Err(err) => {
220 error!(?err, "unable to read stream unique id");
221 return;
222 }
223 };
224 cmd_send
225 .send(Cmd::AcceptedStream {
226 uid,
227 stream,
228 bytes_read,
229 })
230 .expect("cmd_rcv is owned by StreamManager")
231 });
232 }
233}
234
/// Errors that can occur while establishing a stream on a [`Connection`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionError {
    /// Opening the underlying QUIC send stream failed.
    #[error("Unable to open stream")]
    OpenStream(#[source] s2n_quic::connection::Error),
    /// Reading or writing the stream's unique-id header failed.
    #[error("io error during stream establishment")]
    IoError(#[source] io::Error),
    /// The command channel to the [`StreamManager`] is closed.
    #[error("StreamManager is dropped and not accepting connections")]
    StreamManagerDropped,
    /// The unique-id header received from the peer could not be decoded.
    #[error("Stream unique id deserialization failed")]
    UniqueIdDeserialization(#[source] bincode::Error),
    /// The stream's unique id could not be encoded.
    #[error("Stream unique id serialization failed")]
    UniqueIdSerialization(#[source] bincode::Error),
    /// The serialized unique id, which grows with every sub-connection level, is too large to
    /// be sent as a stream header.
    #[error("Reached maximum number of sub connections")]
    SubConnectionLimitReached,
}
250
251impl Connection {
252 pub fn new(quic_conn: s2n_quic::Connection) -> (Self, StreamManager) {
253 let (handle, acceptor) = quic_conn.split();
254 let stream_manager = StreamManager::new(acceptor);
255 let conn = Self {
256 cids: vec![],
257 next_cid: Arc::new(AtomicU32::new(0)),
258 handle,
259 cmd: stream_manager.cmd_send.clone(),
260 next_implicit_id: 0,
261 };
262 (conn, stream_manager)
263 }
264
265 #[tracing::instrument(level = Level::DEBUG, skip(self), ret)]
270 pub fn sub_connection(&mut self) -> Self {
271 let cid = self.next_cid.fetch_add(1, Ordering::Relaxed);
272 let mut cids = self.cids.clone();
273 cids.push(ConnectionId(cid));
274 Self {
275 cids,
276 next_cid: Arc::new(AtomicU32::new(0)),
277 handle: self.handle.clone(),
278 cmd: self.cmd.clone(),
279 next_implicit_id: 0,
280 }
281 }
282
283 async fn internal_byte_stream(
284 &self,
285 stream_id: StreamId,
286 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
287 let uid = UniqueId::new(self.cids.clone(), stream_id);
288 let mut snd = self
289 .handle
290 .clone()
291 .open_send_stream()
292 .await
293 .map_err(ConnectionError::OpenStream)?;
294 let bytes_written = uid.write_into(&mut snd).await?;
295 event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes_written);
296 let (stream_return, stream_recv) = oneshot::channel();
297 self.cmd
298 .send(Cmd::NewStream { uid, stream_return })
299 .map_err(|_| ConnectionError::StreamManagerDropped)?;
300 let snd = SendStreamBytes { inner: snd };
301 let recv = ReceiveStreamBytes {
302 inner: ReceiveStreamWrapper::Channel { stream_recv },
303 };
304 Ok((snd, recv))
305 }
306
307 pub async fn byte_stream(
309 &mut self,
310 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
311 self.next_implicit_id += 1;
312 self.internal_byte_stream(StreamId::Implicit(self.next_implicit_id - 1))
313 .await
314 }
315
316 pub async fn byte_stream_with_id(
318 &self,
319 id: Id,
320 ) -> Result<(SendStreamBytes, ReceiveStreamBytes), ConnectionError> {
321 self.internal_byte_stream(StreamId::Explicit(id.0)).await
322 }
323
324 async fn internal_stream<T: Serialize + DeserializeOwned>(
326 &self,
327 id: StreamId,
328 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
329 let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
330 let mut ld_codec = LengthDelimitedCodec::builder();
331 const MB: usize = 1024 * 1024;
333 ld_codec.max_frame_length(256 * MB);
334 let framed_send = ld_codec.new_write(send_bytes);
335 let framed_read = ld_codec.new_read(recv_bytes);
336 let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
337 let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
338 Ok((serde_send, serde_read))
339 }
340
341 pub async fn stream<T: Serialize + DeserializeOwned>(
343 &mut self,
344 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
345 self.next_implicit_id += 1;
346 self.internal_stream(StreamId::Implicit(self.next_implicit_id - 1))
347 .await
348 }
349
350 pub async fn stream_with_id<T: Serialize + DeserializeOwned>(
353 &self,
354 id: Id,
355 ) -> Result<(SendStream<T>, ReceiveStream<T>), ConnectionError> {
356 self.internal_stream(StreamId::Explicit(id.0)).await
357 }
358
359 async fn internal_request_response_stream<T: Serialize, S: DeserializeOwned>(
360 &self,
361 id: StreamId,
362 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
363 let (send_bytes, recv_bytes) = self.internal_byte_stream(id).await?;
364 let framed_send = default_codec().new_write(send_bytes);
365 let framed_read = default_codec().new_read(recv_bytes);
366 let serde_send = SymmetricallyFramed::new(framed_send, Bincode::default());
367 let serde_read = SymmetricallyFramed::new(framed_read, Bincode::default());
368 Ok((serde_send, serde_read))
369 }
370
371 pub async fn request_response_stream<T: Serialize, S: DeserializeOwned>(
374 &mut self,
375 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
376 self.next_implicit_id += 1;
377 self.internal_request_response_stream(StreamId::Implicit(self.next_implicit_id - 1))
378 .await
379 }
380
381 pub async fn request_response_stream_with_id<T: Serialize, S: DeserializeOwned>(
384 &self,
385 id: Id,
386 ) -> Result<(SendStream<T>, ReceiveStream<S>), ConnectionError> {
387 self.internal_request_response_stream(StreamId::Explicit(id.0))
388 .await
389 }
390}
391
392impl Id {
393 pub fn new(id: u64) -> Self {
394 Self(id)
395 }
396}
397
398fn bincode_opts() -> impl bincode::Options {
399 bincode::options().with_big_endian().with_varint_encoding()
400}
401
402impl UniqueId {
403 fn new(cids: Vec<ConnectionId>, id: StreamId) -> Self {
404 Self { cids, id }
405 }
406
407 async fn write_into<W: AsyncWrite>(&self, write: W) -> Result<usize, ConnectionError> {
408 let mut write = pin!(write);
409 let mut options = bincode_opts();
410 let serialized = (&mut options)
411 .serialize(self)
412 .map_err(ConnectionError::UniqueIdSerialization)?;
413 write
414 .write_u32(
415 serialized
416 .len()
417 .try_into()
418 .map_err(|_| ConnectionError::SubConnectionLimitReached)?,
419 )
420 .await
421 .map_err(ConnectionError::IoError)?;
422 write
423 .write_all(&serialized)
424 .await
425 .map_err(ConnectionError::IoError)?;
426 Ok(mem::size_of::<u32>() + serialized.len())
427 }
428
429 async fn read_from<R: AsyncRead>(reader: R) -> Result<(Self, usize), ConnectionError> {
430 let mut reader = pin!(reader);
431 let len = reader.read_u32().await.map_err(ConnectionError::IoError)?;
432 let mut buf = vec![0; len as usize];
433 reader
434 .read_exact(&mut buf)
435 .await
436 .map_err(ConnectionError::IoError)?;
437 let uid = bincode_opts()
438 .deserialize(&buf)
439 .map_err(ConnectionError::UniqueIdDeserialization)?;
440 Ok((uid, mem::size_of::<u32>() + len as usize))
441 }
442}
443
/// Errors returned by the explicit control methods of [`SendStreamBytes`].
#[derive(thiserror::Error, Debug)]
pub enum StreamError {
    /// [`SendStreamBytes::flush`] failed.
    #[error("unable to flush stream")]
    Flush(#[source] s2n_quic::stream::Error),
    /// [`SendStreamBytes::close`] failed.
    #[error("unable to close stream")]
    Close(#[source] s2n_quic::stream::Error),
    /// [`SendStreamBytes::finish`] failed.
    #[error("unable to finish stream")]
    Finish(#[source] s2n_quic::stream::Error),
}
453
454impl SendStreamBytes {
455 pub async fn flush(&mut self) -> Result<(), StreamError> {
456 self.inner.flush().await.map_err(StreamError::Flush)
457 }
458
459 pub fn finish(&mut self) -> Result<(), StreamError> {
460 self.inner.finish().map_err(StreamError::Finish)
461 }
462
463 pub async fn close(&mut self) -> Result<(), StreamError> {
464 self.inner.close().await.map_err(StreamError::Close)
465 }
466
467 pub fn as_stream<T: Serialize>(&mut self) -> TempSendStream<T> {
468 let framed_send = default_codec().new_write(self);
469 SymmetricallyFramed::new(framed_send, Bincode::default())
470 }
471}
472
473impl AsyncWrite for SendStreamBytes {
474 fn poll_write(
475 mut self: Pin<&mut Self>,
476 cx: &mut Context<'_>,
477 buf: &[u8],
478 ) -> Poll<Result<usize, Error>> {
479 let inner = Pin::new(&mut self.inner);
480 trace_poll(inner.poll_write(cx, buf))
481 }
482
483 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
484 let inner = Pin::new(&mut self.inner);
485 AsyncWrite::poll_flush(inner, cx)
486 }
487
488 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
489 let inner = Pin::new(&mut self.inner);
490 inner.poll_shutdown(cx)
491 }
492
493 fn poll_write_vectored(
494 mut self: Pin<&mut Self>,
495 cx: &mut Context<'_>,
496 bufs: &[IoSlice<'_>],
497 ) -> Poll<Result<usize, Error>> {
498 let inner = Pin::new(&mut self.inner);
499 trace_poll(inner.poll_write_vectored(cx, bufs))
500 }
501
502 fn is_write_vectored(&self) -> bool {
503 self.inner.is_write_vectored()
504 }
505}
506
507fn trace_poll(p: Poll<io::Result<usize>>) -> Poll<io::Result<usize>> {
508 if let Poll::Ready(Ok(bytes)) = p {
509 event!(target: "cryprot_metrics", Level::TRACE, bytes_written = bytes);
510 }
511 p
512}
513
514impl ReceiveStreamBytes {
515 pub fn as_stream<T: DeserializeOwned>(&mut self) -> ReceiveStreamTemp<T> {
516 let framed_read = default_codec().new_read(self);
517 SymmetricallyFramed::new(framed_read, Bincode::default())
518 }
519}
520
521impl AsyncRead for ReceiveStreamBytes {
524 fn poll_read(
525 mut self: Pin<&mut Self>,
526 cx: &mut Context<'_>,
527 buf: &mut ReadBuf<'_>,
528 ) -> Poll<std::io::Result<()>> {
529 match &mut self.inner {
530 ReceiveStreamWrapper::Channel { stream_recv } => match Pin::new(stream_recv).poll(cx) {
531 Poll::Pending => Poll::Pending,
532 Poll::Ready(Ok((recv_stream, bytes_read))) => {
533 event!(target: "cryprot_metrics", Level::TRACE, bytes_read);
536 self.inner = ReceiveStreamWrapper::Stream { recv_stream };
537 self.poll_read(cx, buf)
538 }
539 Poll::Ready(Err(err)) => Poll::Ready(Err(std::io::Error::other(Box::new(err)))),
540 },
541 ReceiveStreamWrapper::Stream { recv_stream } => {
542 let len = buf.filled().len();
543 let poll = Pin::new(recv_stream).poll_read(cx, buf);
544 if let Poll::Ready(Ok(())) = poll {
545 let bytes = buf.filled().len() - len;
546 if bytes > 0 {
547 event!(target: "cryprot_metrics", Level::TRACE, bytes_read = bytes);
548 }
549 }
550 poll
551 }
552 }
553 }
554}
555
556fn default_codec() -> length_delimited::Builder {
557 let mut ld_codec = LengthDelimitedCodec::builder();
558 const MB: usize = 1024 * 1024;
559 ld_codec.max_frame_length(20 * MB);
560 ld_codec
561}
562
#[cfg(test)]
mod tests {
    use anyhow::{Context, Result};
    use futures::{SinkExt, StreamExt};
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        task::JoinSet,
    };
    use tracing::debug;

    use crate::{
        testing::{init_tracing, local_conn},
        Id,
    };

    #[tokio::test]
    async fn create_local_conn() -> Result<()> {
        let _g = init_tracing();
        let _ = local_conn().await?;
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        let send_buf = b"hello there";
        s_send.write_all(send_buf).await?;
        let mut buf = [0; 11];
        c_recv.read_exact(&mut buf).await?;
        assert_eq!(send_buf, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_explicit_implicit_id() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        // An explicit id beyond the u32 range must not clash with implicit id 0.
        let (mut s_send1, _) = s.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (mut s_send2, _) = s.byte_stream().await?;
        let (_, mut c_recv1) = c.byte_stream_with_id(Id::new(u32::MAX as u64 + 42)).await?;
        let (_, mut c_recv2) = c.byte_stream().await?;
        let send_buf1 = b"hello there";
        s_send1.write_all(send_buf1).await?;
        let mut buf = [0; 11];
        c_recv1.read_exact(&mut buf).await?;
        assert_eq!(send_buf1, &buf);

        let send_buf2 = b"general kenobi";
        s_send2.write_all(send_buf2).await?;
        let mut buf = [0; 14];
        c_recv2.read_exact(&mut buf).await?;
        assert_eq!(send_buf2, &buf);
        Ok(())
    }

    #[tokio::test]
    async fn byte_stream_different_order() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, mut s_recv) = s.byte_stream().await?;
        let s_send_buf = b"hello there";
        s_send.write_all(s_send_buf).await?;
        let mut s_recv_buf = [0; 2];
        // The server may start reading before the client has opened its half of the stream,
        // so this read has to wait for the remote stream to be delivered.
        let jh = tokio::spawn(async move {
            s_recv.read_exact(&mut s_recv_buf).await.unwrap();
            s_recv_buf
        });
        let (mut c_send, mut c_recv) = c.byte_stream().await?;
        let mut c_recv_buf = [0; 11];
        c_recv.read_exact(&mut c_recv_buf).await?;
        assert_eq!(s_send_buf, &c_recv_buf);
        let c_send_buf = b"42";
        c_send.write_all(c_send_buf).await?;
        let s_recv_buf = jh.await?;
        assert_eq!(c_send_buf, &s_recv_buf);
        Ok(())
    }

    #[tokio::test]
    async fn many_parallel_byte_streams() -> Result<()> {
        let _g = init_tracing();
        let (mut c1, mut c2) = local_conn().await?;
        let mut jhs = JoinSet::new();
        for i in 0..10 {
            let ((mut s, _), (_, mut r)) = tokio::try_join!(c1.byte_stream(), c2.byte_stream())?;
            jhs.spawn(async move {
                let buf = vec![0; 10 * 1024 * 1024];
                s.write_all(&buf).await.unwrap();
                debug!("wrote buf {i}");
            });
            jhs.spawn(async move {
                let mut buf = vec![0; 10 * 1024 * 1024];
                r.read_exact(&mut buf).await.unwrap();
                debug!("received buf {i}");
            });
        }
        // `join_all` re-raises the panic of any task that failed.
        jhs.join_all().await;
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c.stream::<Vec<i32>>().await?;
        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        drop(snd);
        let ret = recv.next().await.map(|res| res.map_err(|_| ()));
        assert_eq!(None, ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_stream_block() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd, _) = s.stream().await?;
        let (_, mut recv) = c.stream().await?;
        snd.send(vec![u8::MAX; 16]).await?;
        let ret: Vec<_> = recv.next().await.context("recv")??;
        assert_eq!(vec![u8::MAX; 16], ret);
        Ok(())
    }

    #[tokio::test]
    async fn serde_byte_stream_as_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut s_send, _) = s.byte_stream().await?;
        let (_, mut c_recv) = c.byte_stream().await?;
        {
            let mut send_ser1 = s_send.as_stream::<i32>();
            let mut recv_ser1 = c_recv.as_stream::<i32>();
            send_ser1.send(42).await?;
            let ret = recv_ser1.next().await.context("recv")??;
            assert_eq!(42, ret);
        }
        {
            let mut send_ser2 = s_send.as_stream::<Vec<i32>>();
            let mut recv_ser2 = c_recv.as_stream::<Vec<i32>>();
            send_ser2.send(vec![1, 2, 3]).await?;
            let ret = recv_ser2.next().await.context("recv")??;
            assert_eq!(vec![1, 2, 3], ret);
        }
        Ok(())
    }

    #[tokio::test]
    async fn serde_request_response_stream() -> Result<()> {
        let _g = init_tracing();
        let (mut s, mut c) = local_conn().await?;
        let (mut snd1, mut recv1) = s.request_response_stream::<Vec<i32>, String>().await?;
        let (mut snd2, mut recv2) = c.request_response_stream::<String, Vec<i32>>().await?;
        snd1.send(vec![1, 2, 3]).await?;
        let ret = recv2.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        snd2.send("hello there".to_string()).await?;
        let ret = recv1.next().await.context("recv2")??;
        assert_eq!("hello there", &ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        // Open a stream on each parent so that implicit id 0 is in use on both the parent and
        // the sub connection; the streams are kept alive until the end of the test.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let (mut snd, _) = s2.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c2.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }

    #[tokio::test]
    async fn sub_sub_connection() -> Result<()> {
        let _g = init_tracing();
        let (mut s1, mut c1) = local_conn().await?;
        let mut s2 = s1.sub_connection();
        let mut c2 = c1.sub_connection();
        let mut s3 = s2.sub_connection();
        let mut c3 = c2.sub_connection();
        // Occupy implicit id 0 on every ancestor level as well.
        let _s1_streams = s1.byte_stream().await?;
        let _c1_streams = c1.byte_stream().await?;
        let _s2_streams = s2.byte_stream().await?;
        let _c2_streams = c2.byte_stream().await?;
        let (mut snd, _) = s3.stream::<Vec<i32>>().await?;
        let (_, mut recv) = c3.stream::<Vec<i32>>().await?;

        snd.send(vec![1, 2, 3]).await?;
        let ret = recv.next().await.context("recv")??;
        assert_eq!(vec![1, 2, 3], ret);
        Ok(())
    }
}