1use futures::SinkExt;
2use tokio::io::AsyncWriteExt;
3
4#[derive(Debug)]
5pub struct SendStream {
6 framed:
7 tokio_util::codec::FramedWrite<quinn::SendStream, tokio_util::codec::LengthDelimitedCodec>,
8}
9
10impl SendStream {
11 pub async fn new(stream: quinn::SendStream) -> anyhow::Result<Self> {
12 let framed = tokio_util::codec::FramedWrite::new(
13 stream,
14 tokio_util::codec::LengthDelimitedCodec::new(),
15 );
16 Ok(Self { framed })
17 }
18
19 pub async fn send_batch_message<T: serde::Serialize>(&mut self, obj: &T) -> anyhow::Result<()> {
20 let bytes = bincode::serialize(obj)?;
21 self.framed.send(bytes::Bytes::from(bytes)).await?;
22 Ok(())
23 }
24
25 pub async fn send_control_message<T: serde::Serialize>(
26 &mut self,
27 obj: &T,
28 ) -> anyhow::Result<()> {
29 self.send_batch_message(obj).await?;
30 self.framed.flush().await?;
31 Ok(())
32 }
33
34 pub async fn send_message_with_data<T: serde::Serialize, R: tokio::io::AsyncRead + Unpin>(
35 &mut self,
36 obj: &T,
37 reader: &mut R,
38 ) -> anyhow::Result<u64> {
39 self.send_control_message(obj).await?;
40 let mut data_stream = self.framed.get_mut();
41 let bytes_copied = tokio::io::copy(reader, &mut data_stream).await?;
42 Ok(bytes_copied)
43 }
44
45 pub async fn close(&mut self) -> anyhow::Result<()> {
46 self.framed.close().await?;
47 Ok(())
48 }
49}
50
51pub type SharedSendStream = std::sync::Arc<tokio::sync::Mutex<SendStream>>;
52
53#[derive(Debug)]
54pub struct RecvStream {
55 framed:
56 tokio_util::codec::FramedRead<quinn::RecvStream, tokio_util::codec::LengthDelimitedCodec>,
57}
58
59impl RecvStream {
60 pub async fn new(stream: quinn::RecvStream) -> anyhow::Result<Self> {
61 let framed = tokio_util::codec::FramedRead::new(
62 stream,
63 tokio_util::codec::LengthDelimitedCodec::new(),
64 );
65 Ok(Self { framed })
66 }
67
68 pub async fn recv_object<T: serde::de::DeserializeOwned>(
69 &mut self,
70 ) -> anyhow::Result<Option<T>> {
71 if let Some(frame) = futures::StreamExt::next(&mut self.framed).await {
72 let bytes = frame?;
73 let obj = bincode::deserialize(&bytes)?;
74 Ok(Some(obj))
75 } else {
76 Ok(None)
77 }
78 }
79
80 pub async fn copy_to<W: tokio::io::AsyncWrite + Unpin>(
81 &mut self,
82 writer: &mut W,
83 ) -> anyhow::Result<u64> {
84 let read_buffer = self.framed.read_buffer();
85 let buffer_size = read_buffer.len() as u64;
86 writer.write_all(read_buffer).await?;
87 let data_stream = self.framed.get_mut();
88 let stream_bytes = tokio::io::copy(data_stream, writer).await?;
89 Ok(buffer_size + stream_bytes)
90 }
91
92 pub async fn close(&mut self) {
93 let recv_stream = self.framed.get_mut();
94 if recv_stream.read_to_end(0).await.is_err() {
96 let _ = recv_stream.stop(0u8.into());
98 }
99 }
100}
101
102#[derive(Clone, Debug)]
104pub struct Connection {
105 inner: quinn::Connection,
106}
107
108impl Connection {
109 pub fn new(conn: quinn::Connection) -> Self {
110 Self { inner: conn }
111 }
112
113 pub async fn open_bi(&self) -> anyhow::Result<(SharedSendStream, RecvStream)> {
114 let (send_stream, recv_stream) = self.inner.open_bi().await?;
115 let send_stream = SendStream::new(send_stream).await?;
116 let recv_stream = RecvStream::new(recv_stream).await?;
117 Ok((
118 std::sync::Arc::new(tokio::sync::Mutex::new(send_stream)),
119 recv_stream,
120 ))
121 }
122
123 pub async fn open_uni(&self) -> anyhow::Result<SendStream> {
124 let send_stream = self.inner.open_uni().await?;
125 SendStream::new(send_stream).await
126 }
127
128 pub async fn accept_bi(&self) -> anyhow::Result<(SharedSendStream, RecvStream)> {
129 let (send_stream, recv_stream) = self.inner.accept_bi().await?;
130 let send_stream = SendStream::new(send_stream).await?;
131 let recv_stream = RecvStream::new(recv_stream).await?;
132 Ok((
133 std::sync::Arc::new(tokio::sync::Mutex::new(send_stream)),
134 recv_stream,
135 ))
136 }
137
138 pub async fn accept_uni(&self) -> anyhow::Result<RecvStream> {
139 let recv_stream = self.inner.accept_uni().await?;
140 RecvStream::new(recv_stream).await
141 }
142
143 pub fn close(&self) {
144 self.inner.close(0u32.into(), b"done");
145 }
146}