selium_protocol/
bistream.rs

1use crate::traits::{ShutdownSink, ShutdownStream};
2use crate::{error_codes, Frame, MessageCodec};
3use futures::{Sink, SinkExt, Stream, StreamExt};
4use quinn::VarInt;
5use quinn::{Connection, RecvStream, SendStream, StreamId};
6use selium_std::errors::{QuicError, Result, SeliumError};
7use std::{
8    ops::{Deref, DerefMut},
9    pin::Pin,
10    task::{Context, Poll},
11};
12use tokio_util::codec::{FramedRead, FramedWrite};
13
14pub struct WriteHalf(FramedWrite<SendStream, MessageCodec>);
15
16pub struct ReadHalf(FramedRead<RecvStream, MessageCodec>);
17
18pub struct BiStream {
19    write: WriteHalf,
20    read: ReadHalf,
21}
22
23impl Sink<Frame> for WriteHalf {
24    type Error = SeliumError;
25
26    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
27        self.0.poll_ready_unpin(cx)
28    }
29
30    fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
31        self.0.start_send_unpin(item)
32    }
33
34    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
35        self.0.poll_flush_unpin(cx)
36    }
37
38    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
39        self.0.poll_close_unpin(cx)
40    }
41}
42
43impl ShutdownSink for WriteHalf {
44    fn shutdown_sink(&mut self) {
45        let _ = self
46            .0
47            .get_mut()
48            .reset(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
49    }
50}
51
52impl From<SendStream> for WriteHalf {
53    fn from(send: SendStream) -> Self {
54        Self(FramedWrite::new(send, MessageCodec))
55    }
56}
57
58impl Deref for WriteHalf {
59    type Target = SendStream;
60
61    fn deref(&self) -> &Self::Target {
62        self.0.get_ref()
63    }
64}
65
66impl DerefMut for WriteHalf {
67    fn deref_mut(&mut self) -> &mut Self::Target {
68        self.0.get_mut()
69    }
70}
71
72impl Stream for ReadHalf {
73    type Item = Result<Frame, SeliumError>;
74
75    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
76        self.0.poll_next_unpin(cx)
77    }
78
79    fn size_hint(&self) -> (usize, Option<usize>) {
80        self.0.size_hint()
81    }
82}
83
84impl ShutdownStream for ReadHalf {
85    fn shutdown_stream(&mut self) {
86        let _ = self
87            .0
88            .get_mut()
89            .stop(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
90    }
91}
92
93impl From<RecvStream> for ReadHalf {
94    fn from(recv: RecvStream) -> Self {
95        Self(FramedRead::new(recv, MessageCodec))
96    }
97}
98
99impl Deref for ReadHalf {
100    type Target = RecvStream;
101
102    fn deref(&self) -> &Self::Target {
103        self.0.get_ref()
104    }
105}
106
107impl DerefMut for ReadHalf {
108    fn deref_mut(&mut self) -> &mut Self::Target {
109        self.0.get_mut()
110    }
111}
112
113impl BiStream {
114    pub async fn try_from_connection(connection: &Connection) -> Result<Self> {
115        let stream = connection
116            .open_bi()
117            .await
118            .map_err(QuicError::ConnectionError)?;
119        Ok(Self::from(stream))
120    }
121
122    pub fn split(self) -> (WriteHalf, ReadHalf) {
123        (self.write, self.read)
124    }
125
126    pub fn get_recv_stream_id(&self) -> StreamId {
127        self.read.id()
128    }
129
130    pub fn get_send_stream_id(&self) -> StreamId {
131        self.write.id()
132    }
133
134    pub fn read(&mut self) -> &mut RecvStream {
135        &mut self.read
136    }
137
138    pub fn write(&mut self) -> &mut SendStream {
139        &mut self.write
140    }
141
142    pub async fn finish(&mut self) -> Result<()> {
143        self.write.finish().await.map_err(QuicError::WriteError)?;
144        Ok(())
145    }
146}
147
148impl From<(SendStream, RecvStream)> for BiStream {
149    fn from((send, recv): (SendStream, RecvStream)) -> Self {
150        Self {
151            write: send.into(),
152            read: recv.into(),
153        }
154    }
155}
156
157impl Sink<Frame> for BiStream {
158    type Error = SeliumError;
159
160    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161        self.write.poll_ready_unpin(cx)
162    }
163
164    fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
165        self.write.start_send_unpin(item)
166    }
167
168    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        self.write.poll_flush_unpin(cx)
170    }
171
172    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173        self.write.poll_close_unpin(cx)
174    }
175}
176
177impl Stream for BiStream {
178    type Item = Result<Frame, SeliumError>;
179
180    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181        self.read.poll_next_unpin(cx)
182    }
183
184    fn size_hint(&self) -> (usize, Option<usize>) {
185        self.read.size_hint()
186    }
187}
188
189impl ShutdownSink for BiStream {
190    fn shutdown_sink(&mut self) {
191        let _ = self
192            .write()
193            .reset(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
194    }
195}
196
197impl ShutdownStream for BiStream {
198    fn shutdown_stream(&mut self) {
199        let _ = self
200            .read()
201            .stop(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
202    }
203}