1use std::{
2 ops::Deref,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6};
7
8use futures::Stream;
9use gm_quic::{StreamId, StreamReader, StreamWriter};
10use h3::quic::{ConnectionErrorIncoming, StreamErrorIncoming};
11
12use crate::{
13 error::{self, convert_quic_error},
14 streams::{BidiStream, RecvStream, SendStream},
15};
16pub struct QuicConnection {
18 connection: Arc<gm_quic::Connection>,
19 accept_bi: AcceptBiStreams,
20 accept_uni: AcceptUniStreams,
21 open_bi: OpenBiStreams,
22 open_uni: OpenUniStreams,
23}
24
25impl Deref for QuicConnection {
26 type Target = Arc<gm_quic::Connection>;
27
28 fn deref(&self) -> &Self::Target {
29 &self.connection
30 }
31}
32
33impl QuicConnection {
34 pub fn new(conn: Arc<gm_quic::Connection>) -> Self {
35 Self {
36 accept_bi: AcceptBiStreams::new(conn.clone()),
37 accept_uni: AcceptUniStreams::new(conn.clone()),
38 open_bi: OpenBiStreams::new(conn.clone()),
39 open_uni: OpenUniStreams::new(conn.clone()),
40 connection: conn,
41 }
42 }
43}
44
45impl<B: bytes::Buf> h3::quic::OpenStreams<B> for QuicConnection {
47 type BidiStream = BidiStream<B>;
48
49 type SendStream = SendStream<B>;
50
51 #[inline]
52 fn poll_open_bidi(
53 &mut self,
54 cx: &mut Context<'_>,
55 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
56 self.open_bi.poll_open(cx)
70
71 }
73
74 #[inline]
75 fn poll_open_send(
76 &mut self,
77 cx: &mut Context<'_>,
78 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
79 self.open_uni.poll_open(cx)
80 }
81
82 #[inline]
83 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
84 let reason = unsafe { String::from_utf8_unchecked(reason.to_vec()) };
85 self.connection.close(reason, code.into());
86 }
87}
88
89impl<B: bytes::Buf> h3::quic::Connection<B> for QuicConnection {
92 type RecvStream = RecvStream;
93
94 type OpenStreams = OpenStreams;
95
96 #[inline]
97 fn poll_accept_recv(
98 &mut self,
99 cx: &mut Context<'_>,
100 ) -> Poll<Result<Self::RecvStream, ConnectionErrorIncoming>> {
101 self.accept_uni.poll_accept(cx)
102 }
103
104 #[inline]
105 fn poll_accept_bidi(
106 &mut self,
107 cx: &mut Context<'_>,
108 ) -> Poll<Result<Self::BidiStream, ConnectionErrorIncoming>> {
109 self.accept_bi.poll_accept(cx)
110 }
111
112 #[inline]
116 fn opener(&self) -> Self::OpenStreams {
117 OpenStreams::new(self.connection.clone())
118 }
119}
120
121pub struct OpenStreams {
123 connection: Arc<gm_quic::Connection>,
124 open_bi: OpenBiStreams,
125 open_uni: OpenUniStreams,
126}
127
128impl OpenStreams {
129 fn new(conn: Arc<gm_quic::Connection>) -> Self {
130 Self {
131 open_bi: OpenBiStreams::new(conn.clone()),
132 open_uni: OpenUniStreams::new(conn.clone()),
133 connection: conn,
134 }
135 }
136}
137
138impl Clone for OpenStreams {
139 fn clone(&self) -> Self {
140 Self {
141 open_bi: OpenBiStreams::new(self.connection.clone()),
142 open_uni: OpenUniStreams::new(self.connection.clone()),
143 connection: self.connection.clone(),
144 }
145 }
146}
147
148impl<B: bytes::Buf> h3::quic::OpenStreams<B> for OpenStreams {
150 type BidiStream = BidiStream<B>;
151
152 type SendStream = SendStream<B>;
153
154 #[inline]
155 fn poll_open_bidi(
156 &mut self,
157 cx: &mut Context<'_>,
158 ) -> Poll<Result<Self::BidiStream, StreamErrorIncoming>> {
159 self.open_bi.poll_open(cx)
160 }
161
162 #[inline]
163 fn poll_open_send(
164 &mut self,
165 cx: &mut Context<'_>,
166 ) -> Poll<Result<Self::SendStream, StreamErrorIncoming>> {
167 self.open_uni.poll_open(cx)
168 }
169
170 #[inline]
171 fn close(&mut self, code: h3::error::Code, reason: &[u8]) {
172 let reason = unsafe { String::from_utf8_unchecked(reason.to_vec()) };
173 self.connection.close(reason, code.into());
174 }
175}
176
177type BoxStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync>>;
178
179fn sid_exceed_limit_error() -> ConnectionErrorIncoming {
180 ConnectionErrorIncoming::Undefined(Arc::from(Box::from(
181 "the stream IDs in the `dir` direction exceed 2^60, this is very very hard to happen.",
182 )) as _)
183}
184
185#[allow(clippy::type_complexity)]
186struct OpenBiStreams(
187 BoxStream<Result<(StreamId, (StreamReader, StreamWriter)), ConnectionErrorIncoming>>,
188);
189
190impl OpenBiStreams {
191 fn new(conn: Arc<gm_quic::Connection>) -> Self {
192 let stream = futures::stream::unfold(conn, |conn| async {
193 let bidi = conn
194 .open_bi_stream()
195 .await
196 .map_err(convert_quic_error)
197 .and_then(|o| o.ok_or_else(sid_exceed_limit_error));
198 Some((bidi, conn))
199 });
200 Self(Box::pin(stream))
201 }
202
203 fn poll_open<B>(
207 &mut self,
208 cx: &mut Context<'_>,
209 ) -> Poll<Result<BidiStream<B>, StreamErrorIncoming>> {
210 self.0
211 .as_mut()
212 .poll_next(cx)
213 .map(Option::unwrap)
214 .map_ok(|(sid, stream)| BidiStream::new(sid, stream))
215 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
216 connection_error: e,
217 })
218 }
219}
220
221struct OpenUniStreams(BoxStream<Result<(StreamId, StreamWriter), ConnectionErrorIncoming>>);
222
223impl OpenUniStreams {
224 fn new(conn: Arc<gm_quic::Connection>) -> Self {
225 let stream = futures::stream::unfold(conn, |conn| async {
226 let send = conn
227 .open_uni_stream()
228 .await
229 .map_err(convert_quic_error)
230 .and_then(|o| o.ok_or_else(sid_exceed_limit_error));
231 Some((send, conn))
232 });
233 Self(Box::pin(stream))
234 }
235
236 fn poll_open<B>(
237 &mut self,
238 cx: &mut Context<'_>,
239 ) -> Poll<Result<SendStream<B>, StreamErrorIncoming>> {
240 self.0
241 .as_mut()
242 .poll_next(cx)
243 .map(Option::unwrap)
244 .map_ok(|(sid, writer)| SendStream::new(sid, writer))
245 .map_err(|e| StreamErrorIncoming::ConnectionErrorIncoming {
246 connection_error: e,
247 })
248 }
249}
250
251#[allow(clippy::type_complexity)]
252struct AcceptBiStreams(
253 BoxStream<Result<(StreamId, (StreamReader, StreamWriter)), ConnectionErrorIncoming>>,
254);
255
256impl AcceptBiStreams {
257 fn new(conn: Arc<gm_quic::Connection>) -> Self {
258 let stream = futures::stream::unfold(conn, |conn| async {
259 Some((
260 conn.accept_bi_stream()
261 .await
262 .map_err(error::convert_quic_error),
263 conn,
264 ))
265 });
266 Self(Box::pin(stream))
267 }
268
269 fn poll_accept<B>(
270 &mut self,
271 cx: &mut Context<'_>,
272 ) -> Poll<Result<BidiStream<B>, ConnectionErrorIncoming>> {
273 self.0
274 .as_mut()
275 .poll_next(cx)
276 .map(Option::unwrap)
277 .map_ok(|(sid, stream)| BidiStream::new(sid, stream))
278 }
279}
280
281struct AcceptUniStreams(BoxStream<Result<(StreamId, StreamReader), ConnectionErrorIncoming>>);
282
283impl AcceptUniStreams {
284 fn new(conn: Arc<gm_quic::Connection>) -> Self {
285 let stream = futures::stream::unfold(conn, |conn| async {
286 let uni = conn
287 .accept_uni_stream()
288 .await
289 .map_err(error::convert_quic_error);
290 Some((uni, conn))
291 });
292 Self(Box::pin(stream))
293 }
294
295 fn poll_accept(
296 &mut self,
297 cx: &mut Context<'_>,
298 ) -> Poll<Result<RecvStream, ConnectionErrorIncoming>> {
299 self.0
300 .as_mut()
301 .poll_next(cx)
302 .map(Option::unwrap)
303 .map_ok(|(sid, reader)| RecvStream::new(sid, reader))
304 }
305}