selium_protocol/
bistream.rs1use crate::traits::{ShutdownSink, ShutdownStream};
2use crate::{error_codes, Frame, MessageCodec};
3use futures::{Sink, SinkExt, Stream, StreamExt};
4use quinn::VarInt;
5use quinn::{Connection, RecvStream, SendStream, StreamId};
6use selium_std::errors::{QuicError, Result, SeliumError};
7use std::{
8 ops::{Deref, DerefMut},
9 pin::Pin,
10 task::{Context, Poll},
11};
12use tokio_util::codec::{FramedRead, FramedWrite};
13
14pub struct WriteHalf(FramedWrite<SendStream, MessageCodec>);
15
16pub struct ReadHalf(FramedRead<RecvStream, MessageCodec>);
17
18pub struct BiStream {
19 write: WriteHalf,
20 read: ReadHalf,
21}
22
23impl Sink<Frame> for WriteHalf {
24 type Error = SeliumError;
25
26 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
27 self.0.poll_ready_unpin(cx)
28 }
29
30 fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
31 self.0.start_send_unpin(item)
32 }
33
34 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
35 self.0.poll_flush_unpin(cx)
36 }
37
38 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
39 self.0.poll_close_unpin(cx)
40 }
41}
42
43impl ShutdownSink for WriteHalf {
44 fn shutdown_sink(&mut self) {
45 let _ = self
46 .0
47 .get_mut()
48 .reset(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
49 }
50}
51
52impl From<SendStream> for WriteHalf {
53 fn from(send: SendStream) -> Self {
54 Self(FramedWrite::new(send, MessageCodec))
55 }
56}
57
58impl Deref for WriteHalf {
59 type Target = SendStream;
60
61 fn deref(&self) -> &Self::Target {
62 self.0.get_ref()
63 }
64}
65
66impl DerefMut for WriteHalf {
67 fn deref_mut(&mut self) -> &mut Self::Target {
68 self.0.get_mut()
69 }
70}
71
72impl Stream for ReadHalf {
73 type Item = Result<Frame, SeliumError>;
74
75 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
76 self.0.poll_next_unpin(cx)
77 }
78
79 fn size_hint(&self) -> (usize, Option<usize>) {
80 self.0.size_hint()
81 }
82}
83
84impl ShutdownStream for ReadHalf {
85 fn shutdown_stream(&mut self) {
86 let _ = self
87 .0
88 .get_mut()
89 .stop(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
90 }
91}
92
93impl From<RecvStream> for ReadHalf {
94 fn from(recv: RecvStream) -> Self {
95 Self(FramedRead::new(recv, MessageCodec))
96 }
97}
98
99impl Deref for ReadHalf {
100 type Target = RecvStream;
101
102 fn deref(&self) -> &Self::Target {
103 self.0.get_ref()
104 }
105}
106
107impl DerefMut for ReadHalf {
108 fn deref_mut(&mut self) -> &mut Self::Target {
109 self.0.get_mut()
110 }
111}
112
113impl BiStream {
114 pub async fn try_from_connection(connection: &Connection) -> Result<Self> {
115 let stream = connection
116 .open_bi()
117 .await
118 .map_err(QuicError::ConnectionError)?;
119 Ok(Self::from(stream))
120 }
121
122 pub fn split(self) -> (WriteHalf, ReadHalf) {
123 (self.write, self.read)
124 }
125
126 pub fn get_recv_stream_id(&self) -> StreamId {
127 self.read.id()
128 }
129
130 pub fn get_send_stream_id(&self) -> StreamId {
131 self.write.id()
132 }
133
134 pub fn read(&mut self) -> &mut RecvStream {
135 &mut self.read
136 }
137
138 pub fn write(&mut self) -> &mut SendStream {
139 &mut self.write
140 }
141
142 pub async fn finish(&mut self) -> Result<()> {
143 self.write.finish().await.map_err(QuicError::WriteError)?;
144 Ok(())
145 }
146}
147
148impl From<(SendStream, RecvStream)> for BiStream {
149 fn from((send, recv): (SendStream, RecvStream)) -> Self {
150 Self {
151 write: send.into(),
152 read: recv.into(),
153 }
154 }
155}
156
157impl Sink<Frame> for BiStream {
158 type Error = SeliumError;
159
160 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 self.write.poll_ready_unpin(cx)
162 }
163
164 fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> {
165 self.write.start_send_unpin(item)
166 }
167
168 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169 self.write.poll_flush_unpin(cx)
170 }
171
172 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
173 self.write.poll_close_unpin(cx)
174 }
175}
176
177impl Stream for BiStream {
178 type Item = Result<Frame, SeliumError>;
179
180 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181 self.read.poll_next_unpin(cx)
182 }
183
184 fn size_hint(&self) -> (usize, Option<usize>) {
185 self.read.size_hint()
186 }
187}
188
189impl ShutdownSink for BiStream {
190 fn shutdown_sink(&mut self) {
191 let _ = self
192 .write()
193 .reset(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
194 }
195}
196
197impl ShutdownStream for BiStream {
198 fn shutdown_stream(&mut self) {
199 let _ = self
200 .read()
201 .stop(VarInt::from_u32(error_codes::SHUTDOWN_IN_PROGRESS));
202 }
203}