quiche_tokio/
connection.rs

1use super::stream;
2use rand::prelude::*;
3use std::ops::Deref;
4
5#[derive(Clone)]
6pub enum ConnectionError {
7    Quic(quiche::Error),
8    Io(std::io::ErrorKind),
9    Connection(quiche::ConnectionError),
10}
11
12impl ConnectionError {
13    pub fn to_id(&self) -> u64 {
14        match self {
15            Self::Quic(_) => 0,
16            Self::Io(_) => 0,
17            Self::Connection(c) => c.error_code,
18        }
19    }
20}
21
22impl std::fmt::Debug for ConnectionError {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        match self {
25            Self::Quic(q) => f.write_fmt(format_args!("QUIC({:?})", q)),
26            Self::Io(e) => f.write_fmt(format_args!("IO({:?})", e)),
27            Self::Connection(e) => f.write_fmt(format_args!(
28                "Connection(is_app={}, error_code={:x}, reason={})",
29                e.is_app,
30                e.error_code,
31                String::from_utf8_lossy(&e.reason)
32            )),
33        }
34    }
35}
36
37impl std::fmt::Display for ConnectionError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.write_fmt(format_args!("{:?}", self))
40    }
41}
42
43impl std::error::Error for ConnectionError {}
44
45type ConnectionResult<T> = Result<T, ConnectionError>;
46
47impl From<quiche::Error> for ConnectionError {
48    fn from(value: quiche::Error) -> Self {
49        Self::Quic(value)
50    }
51}
52
53impl From<quiche::ConnectionError> for ConnectionError {
54    fn from(value: quiche::ConnectionError) -> Self {
55        Self::Connection(value)
56    }
57}
58
59impl From<std::io::Error> for ConnectionError {
60    fn from(value: std::io::Error) -> Self {
61        Self::Io(value.kind())
62    }
63}
64
65impl From<std::io::ErrorKind> for ConnectionError {
66    fn from(value: std::io::ErrorKind) -> Self {
67        Self::Io(value)
68    }
69}
70
71impl From<ConnectionError> for std::io::Error {
72    fn from(value: ConnectionError) -> Self {
73        match value {
74            ConnectionError::Io(k) => std::io::Error::new(k, ""),
75            o => std::io::Error::new(std::io::ErrorKind::Other, format!("{:?}", o)),
76        }
77    }
78}
79
80pub(super) enum Control {
81    ShouldSend,
82    SendAckEliciting,
83    SetQLog(QLogConfig),
84    Close {
85        app: bool,
86        err: u64,
87        reason: Vec<u8>,
88    },
89    StreamSend {
90        stream_id: u64,
91        data: Vec<u8>,
92        fin: bool,
93        resp: tokio::sync::oneshot::Sender<ConnectionResult<usize>>,
94    },
95    StreamRecv {
96        stream_id: u64,
97        len: usize,
98        resp: tokio::sync::oneshot::Sender<ConnectionResult<(Vec<u8>, bool)>>,
99    },
100    // StreamShutdown {
101    //     stream_id: u64,
102    //     direction: quiche::Shutdown,
103    //     err: u64,
104    //     resp: tokio::sync::oneshot::Sender<ConnectionResult<()>>,
105    // },
106}
107
108#[derive(Debug)]
109pub struct Connection {
110    is_server: bool,
111    control_tx: tokio::sync::mpsc::Sender<Control>,
112    shared_state: std::sync::Arc<SharedConnectionState>,
113    new_stream_rx: Option<tokio::sync::mpsc::Receiver<stream::Stream>>,
114}
115
116pub struct QLogConfig {
117    pub qlog: crate::qlog::QLog,
118    pub title: String,
119    pub description: String,
120    pub level: quiche::QlogLevel,
121}
122
123#[derive(Debug)]
124pub(super) struct SharedConnectionState {
125    connection_established: std::sync::atomic::AtomicBool,
126    connection_established_notify: tokio::sync::Mutex<Vec<std::sync::Arc<tokio::sync::Notify>>>,
127    connection_closed: std::sync::atomic::AtomicBool,
128    connection_closed_notify: tokio::sync::Mutex<Vec<std::sync::Arc<tokio::sync::Notify>>>,
129    pub(super) connection_error: tokio::sync::RwLock<Option<ConnectionError>>,
130}
131
132struct InnerConnectionState {
133    conn: quiche::Connection,
134    socket: tokio::net::UdpSocket,
135    local_addr: std::net::SocketAddr,
136    max_datagram_size: usize,
137    control_rx: tokio::sync::mpsc::Receiver<Control>,
138    control_tx: tokio::sync::mpsc::Sender<Control>,
139    new_stream_tx: tokio::sync::mpsc::Sender<stream::Stream>,
140}
141
142impl Connection {
143    pub async fn connect(
144        peer_addr: std::net::SocketAddr,
145        mut config: quiche::Config,
146        server_name: Option<&str>,
147        qlog: Option<QLogConfig>,
148    ) -> ConnectionResult<Self> {
149        let bind_addr: std::net::SocketAddr = match peer_addr {
150            std::net::SocketAddr::V4(_) => "0.0.0.0:0",
151            std::net::SocketAddr::V6(_) => "[::]:0",
152        }
153        .parse()
154        .unwrap();
155
156        let mut cid = [0; quiche::MAX_CONN_ID_LEN];
157        thread_rng().fill(&mut cid[..]);
158        let cid = quiche::ConnectionId::from_ref(&cid);
159
160        let socket = tokio::net::UdpSocket::bind(bind_addr).await?;
161        let local_addr = socket.local_addr()?;
162        debug!("Connecting to {} from {}", peer_addr, local_addr);
163
164        let mut conn = quiche::connect(server_name, &cid, local_addr, peer_addr, &mut config)?;
165        if let Some(qlog) = qlog {
166            conn.set_qlog_with_level(
167                Box::new(qlog.qlog),
168                qlog.title,
169                qlog.description,
170                qlog.level,
171            );
172        }
173        let max_datagram_size = conn.max_send_udp_payload_size();
174
175        let (control_tx, control_rx) = tokio::sync::mpsc::channel(25);
176        let (new_stream_tx, new_stream_rx) = tokio::sync::mpsc::channel(25);
177
178        let shared_connection_state = std::sync::Arc::new(SharedConnectionState {
179            connection_established: std::sync::atomic::AtomicBool::new(false),
180            connection_established_notify: tokio::sync::Mutex::new(Vec::new()),
181            connection_closed: std::sync::atomic::AtomicBool::new(false),
182            connection_closed_notify: tokio::sync::Mutex::new(Vec::new()),
183            connection_error: tokio::sync::RwLock::new(None),
184        });
185
186        let connection = Connection {
187            is_server: conn.is_server(),
188            control_tx: control_tx.clone(),
189            shared_state: shared_connection_state.clone(),
190            new_stream_rx: Some(new_stream_rx),
191        };
192
193        shared_connection_state.run(InnerConnectionState {
194            conn,
195            socket,
196            local_addr,
197            max_datagram_size,
198            control_rx,
199            control_tx,
200            new_stream_tx,
201        });
202
203        connection.should_send().await.unwrap();
204
205        Ok(connection)
206    }
207
208    async fn send_control(&self, control: Control) -> ConnectionResult<()> {
209        if let Some(err) = self
210            .shared_state
211            .connection_error
212            .read()
213            .await
214            .deref()
215            .clone()
216        {
217            return Err(err);
218        }
219        match self.control_tx.try_send(control) {
220            Ok(_) => {}
221            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {}
222            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
223                if let Some(err) = self
224                    .shared_state
225                    .connection_error
226                    .read()
227                    .await
228                    .deref()
229                    .clone()
230                {
231                    return Err(err);
232                }
233                return Err(std::io::ErrorKind::ConnectionReset.into());
234            }
235        }
236        Ok(())
237    }
238
239    async fn should_send(&self) -> ConnectionResult<()> {
240        self.send_control(Control::ShouldSend).await
241    }
242
243    pub async fn established(&self) -> ConnectionResult<()> {
244        if let Some(err) = self
245            .shared_state
246            .connection_error
247            .read()
248            .await
249            .deref()
250            .clone()
251        {
252            return Err(err);
253        }
254        if self
255            .shared_state
256            .connection_established
257            .load(std::sync::atomic::Ordering::Acquire)
258        {
259            return Ok(());
260        }
261        let notify = std::sync::Arc::new(tokio::sync::Notify::new());
262        self.shared_state
263            .connection_established_notify
264            .lock()
265            .await
266            .push(notify.clone());
267        if let Some(err) = self
268            .shared_state
269            .connection_error
270            .read()
271            .await
272            .deref()
273            .clone()
274        {
275            return Err(err);
276        }
277        notify.notified().await;
278        if let Some(err) = self
279            .shared_state
280            .connection_error
281            .read()
282            .await
283            .deref()
284            .clone()
285        {
286            return Err(err);
287        }
288        Ok(())
289    }
290
291    pub async fn set_qlog(&self, qlog: QLogConfig) -> ConnectionResult<()> {
292        self.send_control(Control::SetQLog(qlog)).await
293    }
294
295    pub async fn send_ack_eliciting(&self) -> ConnectionResult<()> {
296        self.send_control(Control::SendAckEliciting).await
297    }
298
299    pub async fn close(&self, app: bool, err: u64, reason: Vec<u8>) -> ConnectionResult<()> {
300        let notify = std::sync::Arc::new(tokio::sync::Notify::new());
301        self.shared_state
302            .connection_established_notify
303            .lock()
304            .await
305            .push(notify.clone());
306        self.send_control(Control::Close { app, err, reason })
307            .await?;
308        notify.notified().await;
309        if let Some(err) = self
310            .shared_state
311            .connection_error
312            .read()
313            .await
314            .deref()
315            .clone()
316        {
317            return Err(err);
318        }
319        Ok(())
320    }
321
322    pub fn is_server(&self) -> bool {
323        self.is_server
324    }
325
326    pub async fn new_stream(&self, stream_id: u64, bidi: bool) -> ConnectionResult<stream::Stream> {
327        Ok(stream::Stream::new(
328            self.is_server,
329            stream::StreamID::new(stream_id, bidi, self.is_server),
330            self.shared_state.clone(),
331            self.control_tx.clone(),
332        ))
333    }
334
335    pub async fn next_peer_stream(&mut self) -> ConnectionResult<stream::Stream> {
336        match self.new_stream_rx.as_mut().unwrap().recv().await {
337            Some(s) => Ok(s),
338            None => Err(self
339                .shared_state
340                .connection_error
341                .read()
342                .await
343                .clone()
344                .unwrap_or(std::io::ErrorKind::ConnectionReset.into())),
345        }
346    }
347
348    pub fn peer_streams(&mut self) -> ConnectionNewStreams {
349        ConnectionNewStreams {
350            stream_rx: self.new_stream_rx.take().unwrap(),
351            shared_state: self.shared_state.clone(),
352        }
353    }
354}
355
356#[derive(Debug)]
357pub struct ConnectionNewStreams {
358    stream_rx: tokio::sync::mpsc::Receiver<stream::Stream>,
359    shared_state: std::sync::Arc<SharedConnectionState>,
360}
361
362impl ConnectionNewStreams {
363    pub async fn next(&mut self) -> ConnectionResult<stream::Stream> {
364        match self.stream_rx.recv().await {
365            Some(s) => Ok(s),
366            None => Err(self
367                .shared_state
368                .connection_error
369                .read()
370                .await
371                .clone()
372                .unwrap_or(std::io::ErrorKind::ConnectionReset.into())),
373        }
374    }
375}
376
377struct PendingReceive {
378    stream_id: u64,
379    read_len: usize,
380    resp: tokio::sync::oneshot::Sender<ConnectionResult<(Vec<u8>, bool)>>,
381}
382
383impl SharedConnectionState {
384    fn run(self: std::sync::Arc<Self>, mut inner: InnerConnectionState) {
385        let (timeout_tx, mut timeout_rx) = tokio::sync::mpsc::channel(1);
386
387        tokio::task::spawn(async move {
388            let mut buf = [0; 65535];
389            let mut out = vec![0; inner.max_datagram_size];
390            let mut pending_recv: Vec<PendingReceive> = vec![];
391            let mut known_stream_ids = std::collections::HashSet::new();
392
393            'outer: loop {
394                tokio::select! {
395                    res = inner.socket.recv_from(&mut buf) => {
396                        let (len, addr) = match res {
397                            Ok(v) => v,
398                            Err(e) => {
399                                self.set_error(e.into()).await;
400                                break;
401                            }
402                        };
403                        let recv_info = quiche::RecvInfo {
404                            from: addr,
405                            to: inner.local_addr
406                        };
407
408                        let read = match inner.conn.recv(&mut buf[..len], recv_info) {
409                            Ok(v) => v,
410                            Err(quiche::Error::Done) => {
411                                continue;
412                            },
413                            Err(e) => {
414                                self.set_error(e.into()).await;
415                                break;
416                            },
417                        };
418                        trace!("Received {} bytes", read);
419                        if inner.conn.is_established() {
420                            self.set_established().await;
421                        }
422                        inner.control_tx.send(Control::ShouldSend).await.unwrap();
423
424                        let readable = pending_recv
425                            .extract_if(|s| inner.conn.stream_readable(s.stream_id))
426                            .collect::<Vec<_>>();
427                        for s in readable {
428                            let mut buf = vec![0u8; s.read_len];
429                            match inner.conn.stream_recv(s.stream_id, &mut buf) {
430                                Ok((read, fin)) => {
431                                    let out = buf[..read].to_vec();
432                                    let _ = s.resp.send(Ok((out, fin)));
433                                }
434                                Err(e) => {
435                                    let _ = s.resp.send(Err(e.into()));
436                                }
437                            }
438                        }
439
440                        let new_stream_ids = inner.conn.readable().filter(|stream_id| {
441                            let client_flag = stream_id & 1;
442                            if inner.conn.is_server() && client_flag == 1 {
443                                return false;
444                            }
445                            if !inner.conn.is_server() && client_flag == 0 {
446                                return false;
447                            }
448                            if known_stream_ids.contains(stream_id) {
449                                return false;
450                            }
451                            known_stream_ids.insert(*stream_id);
452                            true
453                        }).collect::<Vec<_>>();
454                        for stream in new_stream_ids {
455                            let _ = inner.new_stream_tx.send(stream::Stream::new(
456                                inner.conn.is_server(), stream::StreamID(stream),
457                                self.clone(), inner.control_tx.clone(),
458                            )).await;
459                        }
460                    }
461                    c = inner.control_rx.recv() => {
462                        let c = match c {
463                            Some(c) => c,
464                            None => break
465                        };
466                        match c {
467                            Control::ShouldSend => if !inner.conn.is_draining() {
468                                loop {
469                                    let (write, send_info) = match inner.conn.send(&mut out) {
470                                        Ok(v) => v,
471                                        Err(quiche::Error::Done) => {
472                                            break;
473                                        },
474                                        Err(e) => {
475                                            self.set_error(e.into()).await;
476                                            break 'outer;
477                                        }
478                                    };
479                                    if inner.conn.is_established() {
480                                        self.set_established().await;
481                                    }
482                                    if let Err(e) = inner.socket.send_to(&out[..write], &send_info.to).await {
483                                        self.set_error(e.into()).await;
484                                        break;
485                                    }
486                                    trace!("Sent {} bytes", write);
487                                    if let Some(timeout) = inner.conn.timeout() {
488                                        let inner_timeout_tx = timeout_tx.clone();
489                                        tokio::task::spawn(async move {
490                                            tokio::time::sleep(timeout).await;
491                                            let _ = inner_timeout_tx.send(()).await;
492                                        });
493                                    }
494                                }
495                            },
496                            Control::SendAckEliciting => {
497                                if let Err(e) = inner.conn.send_ack_eliciting() {
498                                    self.set_error(e.into()).await;
499                                    break;
500                                }
501                            }
502                            Control::StreamSend { stream_id, data, fin, resp} => {
503                                let _ = resp.send(
504                                    inner.conn.stream_send(stream_id, &data, fin)
505                                        .map_err(|e| e.into())
506                                );
507                            }
508                            Control::StreamRecv { stream_id, len, resp } => {
509                                let mut buf = vec![0u8; len];
510                                match inner.conn.stream_recv(stream_id, &mut buf) {
511                                    Ok((read, fin)) => {
512                                        let out = buf[..read].to_vec();
513                                        let _ = resp.send(Ok((out, fin)));
514                                    }
515                                    Err(quiche::Error::Done) => {
516                                        pending_recv.push(PendingReceive {
517                                            stream_id,
518                                            read_len: len,
519                                            resp
520                                        });
521                                    }
522                                    Err(e) => {
523                                        let _ = resp.send(Err(e.into()));
524                                    }
525                                }
526                            }
527                            // Control::StreamShutdown { stream_id, direction, err, resp} => {
528                            //     let _ = resp.send(
529                            //         inner.conn.stream_shutdown(stream_id, direction, err)
530                            //             .map_err(|e| e.into())
531                            //     );
532                            // }
533                            Control::SetQLog(qlog) => {
534                                inner.conn.set_qlog_with_level(
535                                    Box::new(qlog.qlog),
536                                    qlog.title,
537                                    qlog.description,
538                                    qlog.level,
539                                );
540                            }
541                            Control::Close { app, err, reason } => {
542                                if let Err(e) = inner.conn.close(app, err, &reason) {
543                                    self.set_error(e.into()).await;
544                                    break;
545                                }
546                            }
547                        }
548                    }
549                    _ = timeout_rx.recv() => {
550                        trace!("On timeout");
551                        inner.conn.on_timeout();
552                        inner.control_tx.send(Control::ShouldSend).await.unwrap();
553                    }
554                }
555
556                if inner.conn.is_closed() {
557                    if let Some(err) = inner.conn.peer_error() {
558                        self.connection_error
559                            .write()
560                            .await
561                            .replace(err.clone().into());
562                    } else if let Some(err) = inner.conn.local_error() {
563                        self.connection_error
564                            .write()
565                            .await
566                            .replace(err.clone().into());
567                    } else if inner.conn.is_timed_out() {
568                        self.connection_error
569                            .write()
570                            .await
571                            .replace(std::io::ErrorKind::TimedOut.into());
572                    } else {
573                        self.connection_error
574                            .write()
575                            .await
576                            .replace(std::io::ErrorKind::ConnectionReset.into());
577                    }
578                    self.set_closed().await;
579                    break;
580                }
581            }
582        });
583    }
584
585    async fn set_error(&self, error: ConnectionError) {
586        self.connection_error.write().await.replace(error);
587        self.notify_connection_established().await;
588    }
589
590    async fn notify_connection_established(&self) {
591        for n in self.connection_established_notify.lock().await.drain(..) {
592            n.notify_one();
593        }
594    }
595
596    async fn notify_connection_closed(&self) {
597        for n in self.connection_closed_notify.lock().await.drain(..) {
598            n.notify_one();
599        }
600        self.notify_connection_established().await;
601    }
602
603    async fn set_established(&self) {
604        self.connection_established
605            .store(true, std::sync::atomic::Ordering::Relaxed);
606        self.notify_connection_established().await;
607    }
608
609    async fn set_closed(&self) {
610        self.connection_closed
611            .store(true, std::sync::atomic::Ordering::Relaxed);
612        self.notify_connection_closed().await;
613    }
614}