remote/
streams.rs

1use futures::SinkExt;
2use tokio::io::AsyncWriteExt;
3
4#[derive(Debug)]
5pub struct SendStream {
6    framed:
7        tokio_util::codec::FramedWrite<quinn::SendStream, tokio_util::codec::LengthDelimitedCodec>,
8}
9
10impl SendStream {
11    pub async fn new(stream: quinn::SendStream) -> anyhow::Result<Self> {
12        let framed = tokio_util::codec::FramedWrite::new(
13            stream,
14            tokio_util::codec::LengthDelimitedCodec::new(),
15        );
16        Ok(Self { framed })
17    }
18
19    pub async fn send_batch_message<T: serde::Serialize>(&mut self, obj: &T) -> anyhow::Result<()> {
20        let bytes = bincode::serialize(obj)?;
21        self.framed.send(bytes::Bytes::from(bytes)).await?;
22        Ok(())
23    }
24
25    pub async fn send_control_message<T: serde::Serialize>(
26        &mut self,
27        obj: &T,
28    ) -> anyhow::Result<()> {
29        self.send_batch_message(obj).await?;
30        self.framed.flush().await?;
31        Ok(())
32    }
33
34    pub async fn send_message_with_data<T: serde::Serialize, R: tokio::io::AsyncRead + Unpin>(
35        &mut self,
36        obj: &T,
37        reader: &mut R,
38    ) -> anyhow::Result<u64> {
39        self.send_control_message(obj).await?;
40        let mut data_stream = self.framed.get_mut();
41        let bytes_copied = tokio::io::copy(reader, &mut data_stream).await?;
42        Ok(bytes_copied)
43    }
44
45    pub async fn close(&mut self) -> anyhow::Result<()> {
46        self.framed.close().await?;
47        Ok(())
48    }
49}
50
51pub type SharedSendStream = std::sync::Arc<tokio::sync::Mutex<SendStream>>;
52
53#[derive(Debug)]
54pub struct RecvStream {
55    framed:
56        tokio_util::codec::FramedRead<quinn::RecvStream, tokio_util::codec::LengthDelimitedCodec>,
57}
58
59impl RecvStream {
60    pub async fn new(stream: quinn::RecvStream) -> anyhow::Result<Self> {
61        let framed = tokio_util::codec::FramedRead::new(
62            stream,
63            tokio_util::codec::LengthDelimitedCodec::new(),
64        );
65        Ok(Self { framed })
66    }
67
68    pub async fn recv_object<T: serde::de::DeserializeOwned>(
69        &mut self,
70    ) -> anyhow::Result<Option<T>> {
71        if let Some(frame) = futures::StreamExt::next(&mut self.framed).await {
72            let bytes = frame?;
73            let obj = bincode::deserialize(&bytes)?;
74            Ok(Some(obj))
75        } else {
76            Ok(None)
77        }
78    }
79
80    pub async fn copy_to<W: tokio::io::AsyncWrite + Unpin>(
81        &mut self,
82        writer: &mut W,
83    ) -> anyhow::Result<u64> {
84        let read_buffer = self.framed.read_buffer();
85        let buffer_size = read_buffer.len() as u64;
86        writer.write_all(read_buffer).await?;
87        let data_stream = self.framed.get_mut();
88        let stream_bytes = tokio::io::copy(data_stream, writer).await?;
89        Ok(buffer_size + stream_bytes)
90    }
91
92    pub async fn close(&mut self) {
93        let recv_stream = self.framed.get_mut();
94        // copied from QUIC documentation: https://docs.rs/quinn/0.10.2/quinn/struct.RecvStream.html
95        if recv_stream.read_to_end(0).await.is_err() {
96            // discard unexpected data and notify the peer to stop sending it
97            let _ = recv_stream.stop(0u8.into());
98        }
99    }
100}
101
102/// Connection wrapper that provides framed stream creation
103#[derive(Clone, Debug)]
104pub struct Connection {
105    inner: quinn::Connection,
106}
107
108impl Connection {
109    pub fn new(conn: quinn::Connection) -> Self {
110        Self { inner: conn }
111    }
112
113    pub async fn open_bi(&self) -> anyhow::Result<(SharedSendStream, RecvStream)> {
114        let (send_stream, recv_stream) = self.inner.open_bi().await?;
115        let send_stream = SendStream::new(send_stream).await?;
116        let recv_stream = RecvStream::new(recv_stream).await?;
117        Ok((
118            std::sync::Arc::new(tokio::sync::Mutex::new(send_stream)),
119            recv_stream,
120        ))
121    }
122
123    pub async fn open_uni(&self) -> anyhow::Result<SendStream> {
124        let send_stream = self.inner.open_uni().await?;
125        SendStream::new(send_stream).await
126    }
127
128    pub async fn accept_bi(&self) -> anyhow::Result<(SharedSendStream, RecvStream)> {
129        let (send_stream, recv_stream) = self.inner.accept_bi().await?;
130        let send_stream = SendStream::new(send_stream).await?;
131        let recv_stream = RecvStream::new(recv_stream).await?;
132        Ok((
133            std::sync::Arc::new(tokio::sync::Mutex::new(send_stream)),
134            recv_stream,
135        ))
136    }
137
138    pub async fn accept_uni(&self) -> anyhow::Result<RecvStream> {
139        let recv_stream = self.inner.accept_uni().await?;
140        RecvStream::new(recv_stream).await
141    }
142
143    pub fn close(&self) {
144        self.inner.close(0u32.into(), b"done");
145    }
146}