1use super::stream;
2use rand::prelude::*;
3use std::ops::Deref;
4
5#[derive(Clone)]
6pub enum ConnectionError {
7 Quic(quiche::Error),
8 Io(std::io::ErrorKind),
9 Connection(quiche::ConnectionError),
10}
11
12impl ConnectionError {
13 pub fn to_id(&self) -> u64 {
14 match self {
15 Self::Quic(_) => 0,
16 Self::Io(_) => 0,
17 Self::Connection(c) => c.error_code,
18 }
19 }
20}
21
22impl std::fmt::Debug for ConnectionError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::Quic(q) => f.write_fmt(format_args!("QUIC({:?})", q)),
26 Self::Io(e) => f.write_fmt(format_args!("IO({:?})", e)),
27 Self::Connection(e) => f.write_fmt(format_args!(
28 "Connection(is_app={}, error_code={:x}, reason={})",
29 e.is_app,
30 e.error_code,
31 String::from_utf8_lossy(&e.reason)
32 )),
33 }
34 }
35}
36
37impl std::fmt::Display for ConnectionError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.write_fmt(format_args!("{:?}", self))
40 }
41}
42
43impl std::error::Error for ConnectionError {}
44
45type ConnectionResult<T> = Result<T, ConnectionError>;
46
47impl From<quiche::Error> for ConnectionError {
48 fn from(value: quiche::Error) -> Self {
49 Self::Quic(value)
50 }
51}
52
53impl From<quiche::ConnectionError> for ConnectionError {
54 fn from(value: quiche::ConnectionError) -> Self {
55 Self::Connection(value)
56 }
57}
58
59impl From<std::io::Error> for ConnectionError {
60 fn from(value: std::io::Error) -> Self {
61 Self::Io(value.kind())
62 }
63}
64
65impl From<std::io::ErrorKind> for ConnectionError {
66 fn from(value: std::io::ErrorKind) -> Self {
67 Self::Io(value)
68 }
69}
70
71impl From<ConnectionError> for std::io::Error {
72 fn from(value: ConnectionError) -> Self {
73 match value {
74 ConnectionError::Io(k) => std::io::Error::new(k, ""),
75 o => std::io::Error::new(std::io::ErrorKind::Other, format!("{:?}", o)),
76 }
77 }
78}
79
80pub(super) enum Control {
81 ShouldSend,
82 SendAckEliciting,
83 SetQLog(QLogConfig),
84 Close {
85 app: bool,
86 err: u64,
87 reason: Vec<u8>,
88 },
89 StreamSend {
90 stream_id: u64,
91 data: Vec<u8>,
92 fin: bool,
93 resp: tokio::sync::oneshot::Sender<ConnectionResult<usize>>,
94 },
95 StreamRecv {
96 stream_id: u64,
97 len: usize,
98 resp: tokio::sync::oneshot::Sender<ConnectionResult<(Vec<u8>, bool)>>,
99 },
100 }
107
108#[derive(Debug)]
109pub struct Connection {
110 is_server: bool,
111 control_tx: tokio::sync::mpsc::Sender<Control>,
112 shared_state: std::sync::Arc<SharedConnectionState>,
113 new_stream_rx: Option<tokio::sync::mpsc::Receiver<stream::Stream>>,
114}
115
116pub struct QLogConfig {
117 pub qlog: crate::qlog::QLog,
118 pub title: String,
119 pub description: String,
120 pub level: quiche::QlogLevel,
121}
122
123#[derive(Debug)]
124pub(super) struct SharedConnectionState {
125 connection_established: std::sync::atomic::AtomicBool,
126 connection_established_notify: tokio::sync::Mutex<Vec<std::sync::Arc<tokio::sync::Notify>>>,
127 connection_closed: std::sync::atomic::AtomicBool,
128 connection_closed_notify: tokio::sync::Mutex<Vec<std::sync::Arc<tokio::sync::Notify>>>,
129 pub(super) connection_error: tokio::sync::RwLock<Option<ConnectionError>>,
130}
131
132struct InnerConnectionState {
133 conn: quiche::Connection,
134 socket: tokio::net::UdpSocket,
135 local_addr: std::net::SocketAddr,
136 max_datagram_size: usize,
137 control_rx: tokio::sync::mpsc::Receiver<Control>,
138 control_tx: tokio::sync::mpsc::Sender<Control>,
139 new_stream_tx: tokio::sync::mpsc::Sender<stream::Stream>,
140}
141
142impl Connection {
143 pub async fn connect(
144 peer_addr: std::net::SocketAddr,
145 mut config: quiche::Config,
146 server_name: Option<&str>,
147 qlog: Option<QLogConfig>,
148 ) -> ConnectionResult<Self> {
149 let bind_addr: std::net::SocketAddr = match peer_addr {
150 std::net::SocketAddr::V4(_) => "0.0.0.0:0",
151 std::net::SocketAddr::V6(_) => "[::]:0",
152 }
153 .parse()
154 .unwrap();
155
156 let mut cid = [0; quiche::MAX_CONN_ID_LEN];
157 thread_rng().fill(&mut cid[..]);
158 let cid = quiche::ConnectionId::from_ref(&cid);
159
160 let socket = tokio::net::UdpSocket::bind(bind_addr).await?;
161 let local_addr = socket.local_addr()?;
162 debug!("Connecting to {} from {}", peer_addr, local_addr);
163
164 let mut conn = quiche::connect(server_name, &cid, local_addr, peer_addr, &mut config)?;
165 if let Some(qlog) = qlog {
166 conn.set_qlog_with_level(
167 Box::new(qlog.qlog),
168 qlog.title,
169 qlog.description,
170 qlog.level,
171 );
172 }
173 let max_datagram_size = conn.max_send_udp_payload_size();
174
175 let (control_tx, control_rx) = tokio::sync::mpsc::channel(25);
176 let (new_stream_tx, new_stream_rx) = tokio::sync::mpsc::channel(25);
177
178 let shared_connection_state = std::sync::Arc::new(SharedConnectionState {
179 connection_established: std::sync::atomic::AtomicBool::new(false),
180 connection_established_notify: tokio::sync::Mutex::new(Vec::new()),
181 connection_closed: std::sync::atomic::AtomicBool::new(false),
182 connection_closed_notify: tokio::sync::Mutex::new(Vec::new()),
183 connection_error: tokio::sync::RwLock::new(None),
184 });
185
186 let connection = Connection {
187 is_server: conn.is_server(),
188 control_tx: control_tx.clone(),
189 shared_state: shared_connection_state.clone(),
190 new_stream_rx: Some(new_stream_rx),
191 };
192
193 shared_connection_state.run(InnerConnectionState {
194 conn,
195 socket,
196 local_addr,
197 max_datagram_size,
198 control_rx,
199 control_tx,
200 new_stream_tx,
201 });
202
203 connection.should_send().await.unwrap();
204
205 Ok(connection)
206 }
207
208 async fn send_control(&self, control: Control) -> ConnectionResult<()> {
209 if let Some(err) = self
210 .shared_state
211 .connection_error
212 .read()
213 .await
214 .deref()
215 .clone()
216 {
217 return Err(err);
218 }
219 match self.control_tx.try_send(control) {
220 Ok(_) => {}
221 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {}
222 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
223 if let Some(err) = self
224 .shared_state
225 .connection_error
226 .read()
227 .await
228 .deref()
229 .clone()
230 {
231 return Err(err);
232 }
233 return Err(std::io::ErrorKind::ConnectionReset.into());
234 }
235 }
236 Ok(())
237 }
238
239 async fn should_send(&self) -> ConnectionResult<()> {
240 self.send_control(Control::ShouldSend).await
241 }
242
243 pub async fn established(&self) -> ConnectionResult<()> {
244 if let Some(err) = self
245 .shared_state
246 .connection_error
247 .read()
248 .await
249 .deref()
250 .clone()
251 {
252 return Err(err);
253 }
254 if self
255 .shared_state
256 .connection_established
257 .load(std::sync::atomic::Ordering::Acquire)
258 {
259 return Ok(());
260 }
261 let notify = std::sync::Arc::new(tokio::sync::Notify::new());
262 self.shared_state
263 .connection_established_notify
264 .lock()
265 .await
266 .push(notify.clone());
267 if let Some(err) = self
268 .shared_state
269 .connection_error
270 .read()
271 .await
272 .deref()
273 .clone()
274 {
275 return Err(err);
276 }
277 notify.notified().await;
278 if let Some(err) = self
279 .shared_state
280 .connection_error
281 .read()
282 .await
283 .deref()
284 .clone()
285 {
286 return Err(err);
287 }
288 Ok(())
289 }
290
291 pub async fn set_qlog(&self, qlog: QLogConfig) -> ConnectionResult<()> {
292 self.send_control(Control::SetQLog(qlog)).await
293 }
294
295 pub async fn send_ack_eliciting(&self) -> ConnectionResult<()> {
296 self.send_control(Control::SendAckEliciting).await
297 }
298
299 pub async fn close(&self, app: bool, err: u64, reason: Vec<u8>) -> ConnectionResult<()> {
300 let notify = std::sync::Arc::new(tokio::sync::Notify::new());
301 self.shared_state
302 .connection_established_notify
303 .lock()
304 .await
305 .push(notify.clone());
306 self.send_control(Control::Close { app, err, reason })
307 .await?;
308 notify.notified().await;
309 if let Some(err) = self
310 .shared_state
311 .connection_error
312 .read()
313 .await
314 .deref()
315 .clone()
316 {
317 return Err(err);
318 }
319 Ok(())
320 }
321
322 pub fn is_server(&self) -> bool {
323 self.is_server
324 }
325
326 pub async fn new_stream(&self, stream_id: u64, bidi: bool) -> ConnectionResult<stream::Stream> {
327 Ok(stream::Stream::new(
328 self.is_server,
329 stream::StreamID::new(stream_id, bidi, self.is_server),
330 self.shared_state.clone(),
331 self.control_tx.clone(),
332 ))
333 }
334
335 pub async fn next_peer_stream(&mut self) -> ConnectionResult<stream::Stream> {
336 match self.new_stream_rx.as_mut().unwrap().recv().await {
337 Some(s) => Ok(s),
338 None => Err(self
339 .shared_state
340 .connection_error
341 .read()
342 .await
343 .clone()
344 .unwrap_or(std::io::ErrorKind::ConnectionReset.into())),
345 }
346 }
347
348 pub fn peer_streams(&mut self) -> ConnectionNewStreams {
349 ConnectionNewStreams {
350 stream_rx: self.new_stream_rx.take().unwrap(),
351 shared_state: self.shared_state.clone(),
352 }
353 }
354}
355
356#[derive(Debug)]
357pub struct ConnectionNewStreams {
358 stream_rx: tokio::sync::mpsc::Receiver<stream::Stream>,
359 shared_state: std::sync::Arc<SharedConnectionState>,
360}
361
362impl ConnectionNewStreams {
363 pub async fn next(&mut self) -> ConnectionResult<stream::Stream> {
364 match self.stream_rx.recv().await {
365 Some(s) => Ok(s),
366 None => Err(self
367 .shared_state
368 .connection_error
369 .read()
370 .await
371 .clone()
372 .unwrap_or(std::io::ErrorKind::ConnectionReset.into())),
373 }
374 }
375}
376
377struct PendingReceive {
378 stream_id: u64,
379 read_len: usize,
380 resp: tokio::sync::oneshot::Sender<ConnectionResult<(Vec<u8>, bool)>>,
381}
382
383impl SharedConnectionState {
384 fn run(self: std::sync::Arc<Self>, mut inner: InnerConnectionState) {
385 let (timeout_tx, mut timeout_rx) = tokio::sync::mpsc::channel(1);
386
387 tokio::task::spawn(async move {
388 let mut buf = [0; 65535];
389 let mut out = vec![0; inner.max_datagram_size];
390 let mut pending_recv: Vec<PendingReceive> = vec![];
391 let mut known_stream_ids = std::collections::HashSet::new();
392
393 'outer: loop {
394 tokio::select! {
395 res = inner.socket.recv_from(&mut buf) => {
396 let (len, addr) = match res {
397 Ok(v) => v,
398 Err(e) => {
399 self.set_error(e.into()).await;
400 break;
401 }
402 };
403 let recv_info = quiche::RecvInfo {
404 from: addr,
405 to: inner.local_addr
406 };
407
408 let read = match inner.conn.recv(&mut buf[..len], recv_info) {
409 Ok(v) => v,
410 Err(quiche::Error::Done) => {
411 continue;
412 },
413 Err(e) => {
414 self.set_error(e.into()).await;
415 break;
416 },
417 };
418 trace!("Received {} bytes", read);
419 if inner.conn.is_established() {
420 self.set_established().await;
421 }
422 inner.control_tx.send(Control::ShouldSend).await.unwrap();
423
424 let readable = pending_recv
425 .extract_if(|s| inner.conn.stream_readable(s.stream_id))
426 .collect::<Vec<_>>();
427 for s in readable {
428 let mut buf = vec![0u8; s.read_len];
429 match inner.conn.stream_recv(s.stream_id, &mut buf) {
430 Ok((read, fin)) => {
431 let out = buf[..read].to_vec();
432 let _ = s.resp.send(Ok((out, fin)));
433 }
434 Err(e) => {
435 let _ = s.resp.send(Err(e.into()));
436 }
437 }
438 }
439
440 let new_stream_ids = inner.conn.readable().filter(|stream_id| {
441 let client_flag = stream_id & 1;
442 if inner.conn.is_server() && client_flag == 1 {
443 return false;
444 }
445 if !inner.conn.is_server() && client_flag == 0 {
446 return false;
447 }
448 if known_stream_ids.contains(stream_id) {
449 return false;
450 }
451 known_stream_ids.insert(*stream_id);
452 true
453 }).collect::<Vec<_>>();
454 for stream in new_stream_ids {
455 let _ = inner.new_stream_tx.send(stream::Stream::new(
456 inner.conn.is_server(), stream::StreamID(stream),
457 self.clone(), inner.control_tx.clone(),
458 )).await;
459 }
460 }
461 c = inner.control_rx.recv() => {
462 let c = match c {
463 Some(c) => c,
464 None => break
465 };
466 match c {
467 Control::ShouldSend => if !inner.conn.is_draining() {
468 loop {
469 let (write, send_info) = match inner.conn.send(&mut out) {
470 Ok(v) => v,
471 Err(quiche::Error::Done) => {
472 break;
473 },
474 Err(e) => {
475 self.set_error(e.into()).await;
476 break 'outer;
477 }
478 };
479 if inner.conn.is_established() {
480 self.set_established().await;
481 }
482 if let Err(e) = inner.socket.send_to(&out[..write], &send_info.to).await {
483 self.set_error(e.into()).await;
484 break;
485 }
486 trace!("Sent {} bytes", write);
487 if let Some(timeout) = inner.conn.timeout() {
488 let inner_timeout_tx = timeout_tx.clone();
489 tokio::task::spawn(async move {
490 tokio::time::sleep(timeout).await;
491 let _ = inner_timeout_tx.send(()).await;
492 });
493 }
494 }
495 },
496 Control::SendAckEliciting => {
497 if let Err(e) = inner.conn.send_ack_eliciting() {
498 self.set_error(e.into()).await;
499 break;
500 }
501 }
502 Control::StreamSend { stream_id, data, fin, resp} => {
503 let _ = resp.send(
504 inner.conn.stream_send(stream_id, &data, fin)
505 .map_err(|e| e.into())
506 );
507 }
508 Control::StreamRecv { stream_id, len, resp } => {
509 let mut buf = vec![0u8; len];
510 match inner.conn.stream_recv(stream_id, &mut buf) {
511 Ok((read, fin)) => {
512 let out = buf[..read].to_vec();
513 let _ = resp.send(Ok((out, fin)));
514 }
515 Err(quiche::Error::Done) => {
516 pending_recv.push(PendingReceive {
517 stream_id,
518 read_len: len,
519 resp
520 });
521 }
522 Err(e) => {
523 let _ = resp.send(Err(e.into()));
524 }
525 }
526 }
527 Control::SetQLog(qlog) => {
534 inner.conn.set_qlog_with_level(
535 Box::new(qlog.qlog),
536 qlog.title,
537 qlog.description,
538 qlog.level,
539 );
540 }
541 Control::Close { app, err, reason } => {
542 if let Err(e) = inner.conn.close(app, err, &reason) {
543 self.set_error(e.into()).await;
544 break;
545 }
546 }
547 }
548 }
549 _ = timeout_rx.recv() => {
550 trace!("On timeout");
551 inner.conn.on_timeout();
552 inner.control_tx.send(Control::ShouldSend).await.unwrap();
553 }
554 }
555
556 if inner.conn.is_closed() {
557 if let Some(err) = inner.conn.peer_error() {
558 self.connection_error
559 .write()
560 .await
561 .replace(err.clone().into());
562 } else if let Some(err) = inner.conn.local_error() {
563 self.connection_error
564 .write()
565 .await
566 .replace(err.clone().into());
567 } else if inner.conn.is_timed_out() {
568 self.connection_error
569 .write()
570 .await
571 .replace(std::io::ErrorKind::TimedOut.into());
572 } else {
573 self.connection_error
574 .write()
575 .await
576 .replace(std::io::ErrorKind::ConnectionReset.into());
577 }
578 self.set_closed().await;
579 break;
580 }
581 }
582 });
583 }
584
585 async fn set_error(&self, error: ConnectionError) {
586 self.connection_error.write().await.replace(error);
587 self.notify_connection_established().await;
588 }
589
590 async fn notify_connection_established(&self) {
591 for n in self.connection_established_notify.lock().await.drain(..) {
592 n.notify_one();
593 }
594 }
595
596 async fn notify_connection_closed(&self) {
597 for n in self.connection_closed_notify.lock().await.drain(..) {
598 n.notify_one();
599 }
600 self.notify_connection_established().await;
601 }
602
603 async fn set_established(&self) {
604 self.connection_established
605 .store(true, std::sync::atomic::Ordering::Relaxed);
606 self.notify_connection_established().await;
607 }
608
609 async fn set_closed(&self) {
610 self.connection_closed
611 .store(true, std::sync::atomic::Ordering::Relaxed);
612 self.notify_connection_closed().await;
613 }
614}