kvarn_quinn/
connection.rs

use std::{
    any::Any,
    fmt,
    future::Future,
    net::{IpAddr, SocketAddr},
    pin::Pin,
    sync::Arc,
    task::{ready, Context, Poll, Waker},
    time::{Duration, Instant},
};

use crate::runtime::{AsyncTimer, AsyncUdpSocket, Runtime};
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use proto::{ConnectionError, ConnectionHandle, ConnectionStats, Dir, StreamEvent, StreamId};
use rustc_hash::FxHashMap;
use thiserror::Error;
use tokio::sync::{futures::Notified, mpsc, oneshot, Notify};
use tracing::{debug_span, Instrument, Span};

use crate::{
    mutex::Mutex,
    recv_stream::RecvStream,
    send_stream::{SendStream, WriteError},
    ConnectionEvent, EndpointEvent, VarInt,
};
use proto::congestion::Controller;

/// In-progress connection attempt future
#[derive(Debug)]
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct Connecting {
    conn: Option<ConnectionRef>,
    connected: oneshot::Receiver<bool>,
    handshake_data_ready: Option<oneshot::Receiver<()>>,
}

impl Connecting {
    pub(crate) fn new(
        handle: ConnectionHandle,
        conn: proto::Connection,
        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        let (on_handshake_data_send, on_handshake_data_recv) = oneshot::channel();
        let (on_connected_send, on_connected_recv) = oneshot::channel();
        let conn = ConnectionRef::new(
            handle,
            conn,
            endpoint_events,
            conn_events,
            on_handshake_data_send,
            on_connected_send,
            socket,
            runtime.clone(),
        );

        runtime.spawn(Box::pin(
            ConnectionDriver(conn.clone()).instrument(Span::current()),
        ));

        Self {
            conn: Some(conn),
            connected: on_connected_recv,
            handshake_data_ready: Some(on_handshake_data_recv),
        }
    }

    /// Convert into a 0-RTT or 0.5-RTT connection at the cost of weakened security
    ///
    /// Opens up the connection for use before the handshake finishes, allowing the API user to
    /// send data with 0-RTT encryption if the necessary key material is available. This is useful
    /// for reducing start-up latency by beginning transmission of application data without waiting
    /// for the handshake's cryptographic security guarantees to be established.
    ///
    /// When the `ZeroRttAccepted` future completes, the connection has been fully established.
    ///
    /// # Security
    ///
    /// On outgoing connections, this enables transmission of 0-RTT data, which might be vulnerable
    /// to replay attacks, and should therefore never invoke non-idempotent operations.
    ///
    /// On incoming connections, this enables transmission of 0.5-RTT data, which might be
    /// intercepted by a man-in-the-middle. If this occurs, the handshake will not complete
    /// successfully.
    ///
    /// # Errors
    ///
    /// Outgoing connections are only 0-RTT-capable when a cryptographic session ticket cached from
    /// a previous connection to the same server is available, and includes a 0-RTT key. If no such
    /// ticket is found, `self` is returned unmodified.
    ///
    /// For incoming connections, a 0.5-RTT connection will always be successfully constructed.
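    ///
    /// # Example
    ///
    /// A minimal client-side sketch (marked `ignore`, so it is not compiled as a doctest); it
    /// assumes `connecting` is a `Connecting` for which a 0-RTT-capable session ticket was
    /// cached:
    ///
    /// ```ignore
    /// match connecting.into_0rtt() {
    ///     Ok((connection, zero_rtt_accepted)) => {
    ///         // Data written here is 0-RTT encrypted: use only for idempotent operations.
    ///         let (mut send, _recv) = connection.open_bi().await?;
    ///         send.write_all(b"idempotent request").await?;
    ///         // Resolves once the handshake finishes; `true` iff the server accepted 0-RTT.
    ///         let _accepted = zero_rtt_accepted.await;
    ///     }
    ///     // No usable key material was cached; fall back to a full handshake.
    ///     Err(connecting) => {
    ///         let _connection = connecting.await?;
    ///     }
    /// }
    /// ```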
    pub fn into_0rtt(mut self) -> Result<(Connection, ZeroRttAccepted), Self> {
        // This lock borrows `self` and would normally be dropped at the end of this scope, so we'll
        // have to release it explicitly before returning `self` by value.
        let conn = (self.conn.as_mut().unwrap()).state.lock("into_0rtt");

        let is_ok = conn.inner.has_0rtt() || conn.inner.side().is_server();
        drop(conn);

        if is_ok {
            let conn = self.conn.take().unwrap();
            Ok((Connection(conn), ZeroRttAccepted(self.connected)))
        } else {
            Err(self)
        }
    }

    /// Parameters negotiated during the handshake
    ///
    /// The dynamic type returned is determined by the configured
    /// [`Session`](proto::crypto::Session). For the default `rustls` session, the return value can
    /// be [`downcast`](Box::downcast) to a
    /// [`crypto::rustls::HandshakeData`](crate::crypto::rustls::HandshakeData).
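    ///
    /// A hedged sketch of inspecting the result with the default `rustls` session (marked
    /// `ignore`; the downcast only applies when that session type is in use):
    ///
    /// ```ignore
    /// let data = connecting.handshake_data().await?;
    /// if let Ok(data) = data.downcast::<crate::crypto::rustls::HandshakeData>() {
    ///     // E.g. the server name (SNI) the client sent, if any.
    ///     println!("server name: {:?}", data.server_name);
    /// }
    /// ```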
    pub async fn handshake_data(&mut self) -> Result<Box<dyn Any>, ConnectionError> {
        // Taking &mut self allows us to use a single oneshot channel rather than dealing with
        // potentially many tasks waiting on the same event. It's a bit of a hack, but keeps things
        // simple.
        if let Some(x) = self.handshake_data_ready.take() {
            let _ = x.await;
        }
        let conn = self.conn.as_ref().unwrap();
        let inner = conn.state.lock("handshake");
        inner
            .inner
            .crypto_session()
            .handshake_data()
            .ok_or_else(|| {
                inner
                    .error
                    .clone()
                    .expect("spurious handshake data ready notification")
            })
    }

    /// The local IP address which was used when the peer established
    /// the connection
    ///
    /// This can be different from the address the endpoint is bound to, in case
    /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
    ///
    /// This will return `None` for clients.
    ///
    /// Retrieving the local IP address is currently supported on the following
    /// platforms:
    /// - Linux
    /// - FreeBSD
    /// - macOS
    ///
    /// On all non-supported platforms the local IP address will not be available,
    /// and the method will return `None`.
    pub fn local_ip(&self) -> Option<IpAddr> {
        let conn = self.conn.as_ref().unwrap();
        let inner = conn.state.lock("local_ip");

        inner.inner.local_ip()
    }

    /// The peer's UDP address.
    ///
    /// Will panic if called after `poll` has returned `Ready`.
    pub fn remote_address(&self) -> SocketAddr {
        let conn_ref: &ConnectionRef = self.conn.as_ref().expect("used after yielding Ready");
        conn_ref.state.lock("remote_address").inner.remote_address()
    }
}

impl Future for Connecting {
    type Output = Result<Connection, ConnectionError>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        Pin::new(&mut self.connected).poll(cx).map(|_| {
            let conn = self.conn.take().unwrap();
            let inner = conn.state.lock("connecting");
            if inner.connected {
                drop(inner);
                Ok(Connection(conn))
            } else {
                Err(inner
                    .error
                    .clone()
                    .expect("connected signaled without connection success or error"))
            }
        })
    }
}

/// Future that completes when a connection is fully established
///
/// For clients, the resulting value indicates if 0-RTT was accepted. For servers, the resulting
/// value is meaningless.
#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"]
pub struct ZeroRttAccepted(oneshot::Receiver<bool>);

impl Future for ZeroRttAccepted {
    type Output = bool;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        Pin::new(&mut self.0).poll(cx).map(|x| x.unwrap_or(false))
    }
}

/// A future that drives protocol logic for a connection
///
/// This future handles the protocol logic for a single connection, routing events from the
/// `Connection` API object to the `Endpoint` task and the related stream interfaces. It also
/// keeps track of outstanding timeouts for the `Connection`.
///
/// This future completes once the connection is drained, whether it was closed locally, closed
/// by the peer, or lost due to an error. Unlike other connection-related futures, it waits for
/// the draining period to complete to ensure that packets still in flight from the peer are
/// handled gracefully.
#[must_use = "connection drivers must be spawned for their connections to function"]
#[derive(Debug)]
struct ConnectionDriver(ConnectionRef);

impl Future for ConnectionDriver {
    type Output = ();

    #[allow(unused_mut)] // MSRV
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let conn = &mut *self.0.state.lock("poll");

        let span = debug_span!("drive", id = conn.handle.0);
        let _guard = span.enter();

        if let Err(e) = conn.process_conn_events(&self.0.shared, cx) {
            conn.terminate(e, &self.0.shared);
            return Poll::Ready(());
        }
        let mut keep_going = conn.drive_transmit();
        // If a timer expires, there might be more to transmit. When we transmit something, we
        // might need to reset a timer. Hence, we must loop until neither happens.
        keep_going |= conn.drive_timer(cx);
        conn.forward_endpoint_events();
        conn.forward_app_events(&self.0.shared);

        if !conn.inner.is_drained() {
            if keep_going {
                // If the connection hasn't processed all tasks, schedule it again
                cx.waker().wake_by_ref();
            } else {
                conn.driver = Some(cx.waker().clone());
            }
            return Poll::Pending;
        }
        if conn.error.is_none() {
            unreachable!("drained connections always have an error");
        }
        Poll::Ready(())
    }
}

/// A QUIC connection.
///
/// If all references to a connection (including every clone of the `Connection` handle, streams
/// accepted from it, and the various stream types) have been dropped, then the connection will be
/// automatically closed with an `error_code` of 0 and an empty `reason`. You can also close the
/// connection explicitly by calling [`Connection::close()`].
///
/// May be cloned to obtain another handle to the same connection.
///
/// [`Connection::close()`]: Connection::close
#[derive(Debug, Clone)]
pub struct Connection(ConnectionRef);

impl Connection {
    /// Initiate a new outgoing unidirectional stream.
    ///
    /// Streams are cheap and instantaneous to open unless blocked by flow control. As a
    /// consequence, the peer won't be notified that a stream has been opened until the stream is
    /// actually used.
    pub fn open_uni(&self) -> OpenUni<'_> {
        OpenUni {
            conn: &self.0,
            notify: self.0.shared.stream_budget_available[Dir::Uni as usize].notified(),
        }
    }

    /// Initiate a new outgoing bidirectional stream.
    ///
    /// Streams are cheap and instantaneous to open unless blocked by flow control. As a
    /// consequence, the peer won't be notified that a stream has been opened until the stream is
    /// actually used. Calling [`open_bi()`] then waiting on the [`RecvStream`] without writing
    /// anything to [`SendStream`] will never succeed.
    ///
    /// [`open_bi()`]: crate::Connection::open_bi
    /// [`SendStream`]: crate::SendStream
    /// [`RecvStream`]: crate::RecvStream
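    ///
    /// A minimal sketch of the required ordering (marked `ignore`; assumes `connection` is an
    /// established `Connection`):
    ///
    /// ```ignore
    /// let (mut send, mut recv) = connection.open_bi().await?;
    /// // Write first: the peer cannot accept the stream until data arrives on it.
    /// send.write_all(b"request").await?;
    /// send.finish().await?;
    /// let response = recv.read_to_end(64 * 1024).await?;
    /// ```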
    pub fn open_bi(&self) -> OpenBi<'_> {
        OpenBi {
            conn: &self.0,
            notify: self.0.shared.stream_budget_available[Dir::Bi as usize].notified(),
        }
    }

    /// Accept the next incoming uni-directional stream
    pub fn accept_uni(&self) -> AcceptUni<'_> {
        AcceptUni {
            conn: &self.0,
            notify: self.0.shared.stream_incoming[Dir::Uni as usize].notified(),
        }
    }

    /// Accept the next incoming bidirectional stream
    ///
    /// **Important Note**: The `Connection` that calls [`open_bi()`] must write to its [`SendStream`]
    /// before the other `Connection` is able to [`accept_bi()`]. Calling [`open_bi()`] then
    /// waiting on the [`RecvStream`] without writing anything to [`SendStream`] will never succeed.
    ///
    /// [`accept_bi()`]: crate::Connection::accept_bi
    /// [`open_bi()`]: crate::Connection::open_bi
    /// [`SendStream`]: crate::SendStream
    /// [`RecvStream`]: crate::RecvStream
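    ///
    /// A sketch of the accepting side (marked `ignore`; assumes `connection` is an established
    /// `Connection` whose peer opens a stream and writes first):
    ///
    /// ```ignore
    /// // Resolves only once the opening peer has sent data on the new stream.
    /// let (mut send, mut recv) = connection.accept_bi().await?;
    /// let request = recv.read_to_end(64 * 1024).await?;
    /// send.write_all(b"response").await?;
    /// send.finish().await?;
    /// ```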
    pub fn accept_bi(&self) -> AcceptBi<'_> {
        AcceptBi {
            conn: &self.0,
            notify: self.0.shared.stream_incoming[Dir::Bi as usize].notified(),
        }
    }

    /// Receive an application datagram
    pub fn read_datagram(&self) -> ReadDatagram<'_> {
        ReadDatagram {
            conn: &self.0,
            notify: self.0.shared.datagrams.notified(),
        }
    }

    /// Wait for the connection to be closed for any reason
    ///
    /// Despite the return type's name, closed connections are often not an error condition at the
    /// application layer. Cases that might be routine include [`ConnectionError::LocallyClosed`]
    /// and [`ConnectionError::ApplicationClosed`].
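    ///
    /// A sketch of treating routine closes as success (marked `ignore`; how errors are
    /// classified is up to the application):
    ///
    /// ```ignore
    /// match connection.closed().await {
    ///     ConnectionError::LocallyClosed | ConnectionError::ApplicationClosed(_) => {}
    ///     e => eprintln!("connection failed: {e}"),
    /// }
    /// ```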
    pub async fn closed(&self) -> ConnectionError {
        {
            let conn = self.0.state.lock("closed");
            if let Some(error) = conn.error.as_ref() {
                return error.clone();
            }
            // Construct the future while the lock is held to ensure we can't miss a wakeup if
            // the `Notify` is signaled immediately after we release the lock. `await` it after
            // the lock guard is out of scope.
            self.0.shared.closed.notified()
        }
        .await;
        self.0
            .state
            .lock("closed")
            .error
            .as_ref()
            .expect("closed without an error")
            .clone()
    }

    /// If the connection is closed, the reason why.
    ///
    /// Returns `None` if the connection is still open.
    pub fn close_reason(&self) -> Option<ConnectionError> {
        self.0.state.lock("close_reason").error.clone()
    }

    /// Close the connection immediately.
    ///
    /// Pending operations will fail immediately with [`ConnectionError::LocallyClosed`]. Delivery
    /// of data on unfinished streams is not guaranteed, so the application must call this only
    /// when all important communications have been completed, e.g. by calling [`finish`] on
    /// outstanding [`SendStream`]s and waiting for the resulting futures to complete.
    ///
    /// `error_code` and `reason` are not interpreted, and are provided directly to the peer.
    ///
    /// `reason` will be truncated to fit in a single packet with overhead; to improve odds that it
    /// is preserved in full, it should be kept under 1KiB.
    ///
    /// [`ConnectionError::LocallyClosed`]: crate::ConnectionError::LocallyClosed
    /// [`finish`]: crate::SendStream::finish
    /// [`SendStream`]: crate::SendStream
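    ///
    /// A sketch of a graceful shutdown (marked `ignore`; assumes `send` is the connection's
    /// last outstanding [`SendStream`]):
    ///
    /// ```ignore
    /// // Wait until the peer has acknowledged all data before tearing the connection down.
    /// send.finish().await?;
    /// connection.close(0u32.into(), b"done");
    /// ```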
    pub fn close(&self, error_code: VarInt, reason: &[u8]) {
        let conn = &mut *self.0.state.lock("close");
        conn.close(error_code, Bytes::copy_from_slice(reason), &self.0.shared);
    }

    /// Transmit `data` as an unreliable, unordered application datagram
    ///
    /// Application datagrams are a low-level primitive. They may be lost or delivered out of order,
    /// and `data` must both fit inside a single QUIC packet and be smaller than the maximum
    /// dictated by the peer.
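    ///
    /// A sketch that checks the current limit before sending (marked `ignore`; assumes
    /// `connection` is an established `Connection` with datagram support negotiated):
    ///
    /// ```ignore
    /// let payload = Bytes::from_static(b"telemetry");
    /// match connection.max_datagram_size() {
    ///     Some(max) if payload.len() <= max => connection.send_datagram(payload)?,
    ///     // Datagrams are unsupported, or the payload exceeds the current limit.
    ///     _ => {}
    /// }
    /// ```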
    pub fn send_datagram(&self, data: Bytes) -> Result<(), SendDatagramError> {
        let conn = &mut *self.0.state.lock("send_datagram");
        if let Some(ref x) = conn.error {
            return Err(SendDatagramError::ConnectionLost(x.clone()));
        }
        use proto::SendDatagramError::*;
        match conn.inner.datagrams().send(data) {
            Ok(()) => {
                conn.wake();
                Ok(())
            }
            Err(e) => Err(match e {
                UnsupportedByPeer => SendDatagramError::UnsupportedByPeer,
                Disabled => SendDatagramError::Disabled,
                TooLarge => SendDatagramError::TooLarge,
            }),
        }
    }

    /// Compute the maximum size of datagrams that may be passed to [`send_datagram()`].
    ///
    /// Returns `None` if datagrams are unsupported by the peer or disabled locally.
    ///
    /// This may change over the lifetime of a connection according to variation in the path MTU
    /// estimate. The peer can also enforce an arbitrarily small fixed limit, but if the peer's
    /// limit is large this is guaranteed to be a little over a kilobyte at minimum.
    ///
    /// Not necessarily the maximum size of received datagrams.
    ///
    /// [`send_datagram()`]: Connection::send_datagram
    pub fn max_datagram_size(&self) -> Option<usize> {
        self.0
            .state
            .lock("max_datagram_size")
            .inner
            .datagrams()
            .max_size()
    }

    /// Bytes available in the outgoing datagram buffer
    ///
    /// When greater than zero, calling [`send_datagram()`](Self::send_datagram) with a datagram of
    /// at most this size is guaranteed not to cause older datagrams to be dropped.
    pub fn datagram_send_buffer_space(&self) -> usize {
        self.0
            .state
            .lock("datagram_send_buffer_space")
            .inner
            .datagrams()
            .send_buffer_space()
    }

    /// The peer's UDP address
    ///
    /// If `ServerConfig::migration` is `true`, clients may change addresses at will, e.g. when
    /// switching to a cellular internet connection.
    pub fn remote_address(&self) -> SocketAddr {
        self.0.state.lock("remote_address").inner.remote_address()
    }

    /// The local IP address which was used when the peer established
    /// the connection
    ///
    /// This can be different from the address the endpoint is bound to, in case
    /// the endpoint is bound to a wildcard address like `0.0.0.0` or `::`.
    ///
    /// This will return `None` for clients.
    ///
    /// Retrieving the local IP address is currently supported on the following
    /// platforms:
    /// - Linux
    /// - FreeBSD
    /// - macOS
    ///
    /// On all non-supported platforms the local IP address will not be available,
    /// and the method will return `None`.
    pub fn local_ip(&self) -> Option<IpAddr> {
        self.0.state.lock("local_ip").inner.local_ip()
    }

    /// Current best estimate of this connection's latency (round-trip-time)
    pub fn rtt(&self) -> Duration {
        self.0.state.lock("rtt").inner.rtt()
    }

    /// Returns connection statistics
    pub fn stats(&self) -> ConnectionStats {
        self.0.state.lock("stats").inner.stats()
    }

    /// Current state of the congestion control algorithm, for debugging purposes
    pub fn congestion_state(&self) -> Box<dyn Controller> {
        self.0
            .state
            .lock("congestion_state")
            .inner
            .congestion_state()
            .clone_box()
    }

    /// Parameters negotiated during the handshake
    ///
    /// Guaranteed to return `Some` on fully established connections or after
    /// [`Connecting::handshake_data()`] succeeds. See that method's documentation for details on
    /// the returned value.
    ///
    /// [`Connecting::handshake_data()`]: crate::Connecting::handshake_data
    pub fn handshake_data(&self) -> Option<Box<dyn Any>> {
        self.0
            .state
            .lock("handshake_data")
            .inner
            .crypto_session()
            .handshake_data()
    }

    /// Cryptographic identity of the peer
    ///
    /// The dynamic type returned is determined by the configured
    /// [`Session`](proto::crypto::Session). For the default `rustls` session, the return value can
    /// be [`downcast`](Box::downcast) to a <code>Vec<[rustls::Certificate]></code>
    pub fn peer_identity(&self) -> Option<Box<dyn Any>> {
        self.0
            .state
            .lock("peer_identity")
            .inner
            .crypto_session()
            .peer_identity()
    }

    /// A stable identifier for this connection
    ///
    /// Peer addresses and connection IDs can change, but this value will remain
    /// fixed for the lifetime of the connection.
    pub fn stable_id(&self) -> usize {
        self.0.stable_id()
    }

    // Update traffic keys spontaneously for testing purposes.
    #[doc(hidden)]
    pub fn force_key_update(&self) {
        self.0
            .state
            .lock("force_key_update")
            .inner
            .initiate_key_update()
    }

    /// Derive keying material from this connection's TLS session secrets.
    ///
    /// When both peers call this method with the same `label` and `context`
    /// arguments and `output` buffers of equal length, they will get the
    /// same sequence of bytes in `output`. These bytes are cryptographically
    /// strong and pseudorandom, and are suitable for use as keying material.
    ///
    /// See [RFC5705](https://tools.ietf.org/html/rfc5705) for more information.
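    ///
    /// A sketch of both peers deriving a shared 32-byte key (marked `ignore`; the label and
    /// context are illustrative values, and any byte strings both peers agree on work):
    ///
    /// ```ignore
    /// let mut key = [0u8; 32];
    /// connection.export_keying_material(&mut key, b"my-app key", b"v1")?;
    /// // The peer makes the identical call and obtains the same 32 bytes.
    /// ```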
    pub fn export_keying_material(
        &self,
        output: &mut [u8],
        label: &[u8],
        context: &[u8],
    ) -> Result<(), proto::crypto::ExportKeyingMaterialError> {
        self.0
            .state
            .lock("export_keying_material")
            .inner
            .crypto_session()
            .export_keying_material(output, label, context)
    }

    /// Modify the number of remotely initiated unidirectional streams that may be concurrently open
    ///
    /// No streams may be opened by the peer unless fewer than `count` are already open. Large
    /// `count`s increase both minimum and worst-case memory consumption.
    pub fn set_max_concurrent_uni_streams(&self, count: VarInt) {
        let mut conn = self.0.state.lock("set_max_concurrent_uni_streams");
        conn.inner.set_max_concurrent_streams(Dir::Uni, count);
        // May need to send MAX_STREAMS to make progress
        conn.wake();
    }

    /// See [`proto::TransportConfig::receive_window()`]
    pub fn set_receive_window(&self, receive_window: VarInt) {
        let mut conn = self.0.state.lock("set_receive_window");
        conn.inner.set_receive_window(receive_window);
        conn.wake();
    }

    /// Modify the number of remotely initiated bidirectional streams that may be concurrently open
    ///
    /// No streams may be opened by the peer unless fewer than `count` are already open. Large
    /// `count`s increase both minimum and worst-case memory consumption.
    pub fn set_max_concurrent_bi_streams(&self, count: VarInt) {
        let mut conn = self.0.state.lock("set_max_concurrent_bi_streams");
        conn.inner.set_max_concurrent_streams(Dir::Bi, count);
        // May need to send MAX_STREAMS to make progress
        conn.wake();
    }
}

pin_project! {
    /// Future produced by [`Connection::open_uni`]
    pub struct OpenUni<'a> {
        conn: &'a ConnectionRef,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for OpenUni<'_> {
    type Output = Result<SendStream, ConnectionError>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let (conn, id, is_0rtt) = ready!(poll_open(ctx, this.conn, this.notify, Dir::Uni))?;
        Poll::Ready(Ok(SendStream::new(conn, id, is_0rtt)))
    }
}

pin_project! {
    /// Future produced by [`Connection::open_bi`]
    pub struct OpenBi<'a> {
        conn: &'a ConnectionRef,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for OpenBi<'_> {
    type Output = Result<(SendStream, RecvStream), ConnectionError>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let (conn, id, is_0rtt) = ready!(poll_open(ctx, this.conn, this.notify, Dir::Bi))?;

        Poll::Ready(Ok((
            SendStream::new(conn.clone(), id, is_0rtt),
            RecvStream::new(conn, id, is_0rtt),
        )))
    }
}

fn poll_open<'a>(
    ctx: &mut Context<'_>,
    conn: &'a ConnectionRef,
    mut notify: Pin<&mut Notified<'a>>,
    dir: Dir,
) -> Poll<Result<(ConnectionRef, StreamId, bool), ConnectionError>> {
    let mut state = conn.state.lock("poll_open");
    if let Some(ref e) = state.error {
        return Poll::Ready(Err(e.clone()));
    } else if let Some(id) = state.inner.streams().open(dir) {
        let is_0rtt = state.inner.side().is_client() && state.inner.is_handshaking();
        drop(state); // Release the lock so clone can take it
        return Poll::Ready(Ok((conn.clone(), id, is_0rtt)));
    }
    loop {
        match notify.as_mut().poll(ctx) {
            // `state` lock ensures we didn't race with readiness
            Poll::Pending => return Poll::Pending,
            // Spurious wakeup, get a new future
            Poll::Ready(()) => {
                notify.set(conn.shared.stream_budget_available[dir as usize].notified())
            }
        }
    }
}

pin_project! {
    /// Future produced by [`Connection::accept_uni`]
    pub struct AcceptUni<'a> {
        conn: &'a ConnectionRef,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for AcceptUni<'_> {
    type Output = Result<RecvStream, ConnectionError>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let (conn, id, is_0rtt) = ready!(poll_accept(ctx, this.conn, this.notify, Dir::Uni))?;
        Poll::Ready(Ok(RecvStream::new(conn, id, is_0rtt)))
    }
}

pin_project! {
    /// Future produced by [`Connection::accept_bi`]
    pub struct AcceptBi<'a> {
        conn: &'a ConnectionRef,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for AcceptBi<'_> {
    type Output = Result<(SendStream, RecvStream), ConnectionError>;

    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let (conn, id, is_0rtt) = ready!(poll_accept(ctx, this.conn, this.notify, Dir::Bi))?;
        Poll::Ready(Ok((
            SendStream::new(conn.clone(), id, is_0rtt),
            RecvStream::new(conn, id, is_0rtt),
        )))
    }
}

fn poll_accept<'a>(
    ctx: &mut Context<'_>,
    conn: &'a ConnectionRef,
    mut notify: Pin<&mut Notified<'a>>,
    dir: Dir,
) -> Poll<Result<(ConnectionRef, StreamId, bool), ConnectionError>> {
    let mut state = conn.state.lock("poll_accept");
    // Check for incoming streams before checking `state.error` so that already-received streams,
    // which are necessarily finite, can be drained from a closed connection.
    if let Some(id) = state.inner.streams().accept(dir) {
        let is_0rtt = state.inner.is_handshaking();
        state.wake(); // To send additional stream ID credit
        drop(state); // Release the lock so clone can take it
        return Poll::Ready(Ok((conn.clone(), id, is_0rtt)));
    } else if let Some(ref e) = state.error {
        return Poll::Ready(Err(e.clone()));
    }
    loop {
        match notify.as_mut().poll(ctx) {
            // `state` lock ensures we didn't race with readiness
            Poll::Pending => return Poll::Pending,
            // Spurious wakeup, get a new future
            Poll::Ready(()) => notify.set(conn.shared.stream_incoming[dir as usize].notified()),
        }
    }
}

pin_project! {
    /// Future produced by [`Connection::read_datagram`]
    pub struct ReadDatagram<'a> {
        conn: &'a ConnectionRef,
        #[pin]
        notify: Notified<'a>,
    }
}

impl Future for ReadDatagram<'_> {
    type Output = Result<Bytes, ConnectionError>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut state = this.conn.state.lock("ReadDatagram::poll");
        // Check for buffered datagrams before checking `state.error` so that already-received
        // datagrams, which are necessarily finite, can be drained from a closed connection.
        if let Some(x) = state.inner.datagrams().recv() {
            return Poll::Ready(Ok(x));
        } else if let Some(ref e) = state.error {
            return Poll::Ready(Err(e.clone()));
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // `state` lock ensures we didn't race with readiness
                Poll::Pending => return Poll::Pending,
                // Spurious wakeup, get a new future
                Poll::Ready(()) => this.notify.set(this.conn.shared.datagrams.notified()),
            }
        }
    }
}

#[derive(Debug)]
pub(crate) struct ConnectionRef(Arc<ConnectionInner>);

impl ConnectionRef {
    #[allow(clippy::too_many_arguments)]
    fn new(
        handle: ConnectionHandle,
        conn: proto::Connection,
        endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
        on_handshake_data: oneshot::Sender<()>,
        on_connected: oneshot::Sender<bool>,
        socket: Arc<dyn AsyncUdpSocket>,
        runtime: Arc<dyn Runtime>,
    ) -> Self {
        Self(Arc::new(ConnectionInner {
            state: Mutex::new(State {
                inner: conn,
                driver: None,
                handle,
                on_handshake_data: Some(on_handshake_data),
                on_connected: Some(on_connected),
                connected: false,
                timer: None,
                timer_deadline: None,
                conn_events,
                endpoint_events,
                blocked_writers: FxHashMap::default(),
                blocked_readers: FxHashMap::default(),
                finishing: FxHashMap::default(),
                stopped: FxHashMap::default(),
                error: None,
                ref_count: 0,
                socket,
                runtime,
            }),
            shared: Shared::default(),
        }))
    }

    fn stable_id(&self) -> usize {
        &*self.0 as *const _ as usize
    }
}

impl Clone for ConnectionRef {
    fn clone(&self) -> Self {
        self.state.lock("clone").ref_count += 1;
        Self(self.0.clone())
    }
}

impl Drop for ConnectionRef {
    fn drop(&mut self) {
        let conn = &mut *self.state.lock("drop");
        if let Some(x) = conn.ref_count.checked_sub(1) {
            conn.ref_count = x;
            if x == 0 && !conn.inner.is_closed() {
                // If the driver is alive, it's just it and us, so we'd better shut it down. If it's
                // not, we can't do any harm. If there were any streams being opened, then either
                // the connection will be closed for an unrelated reason or a fresh reference will
                // be constructed for the newly opened stream.
                conn.implicit_close(&self.shared);
            }
        }
    }
}

impl std::ops::Deref for ConnectionRef {
    type Target = ConnectionInner;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[derive(Debug)]
pub(crate) struct ConnectionInner {
    pub(crate) state: Mutex<State>,
    pub(crate) shared: Shared,
}

#[derive(Debug, Default)]
pub(crate) struct Shared {
    /// Notified when new streams may be locally initiated due to an increase in stream ID flow
    /// control budget
    stream_budget_available: [Notify; 2],
    /// Notified when the peer has initiated a new stream
    stream_incoming: [Notify; 2],
    datagrams: Notify,
    closed: Notify,
}

pub(crate) struct State {
    pub(crate) inner: proto::Connection,
    driver: Option<Waker>,
    handle: ConnectionHandle,
    on_handshake_data: Option<oneshot::Sender<()>>,
    on_connected: Option<oneshot::Sender<bool>>,
    connected: bool,
    timer: Option<Pin<Box<dyn AsyncTimer>>>,
    timer_deadline: Option<Instant>,
    conn_events: mpsc::UnboundedReceiver<ConnectionEvent>,
    endpoint_events: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    pub(crate) blocked_writers: FxHashMap<StreamId, Waker>,
    pub(crate) blocked_readers: FxHashMap<StreamId, Waker>,
    pub(crate) finishing: FxHashMap<StreamId, oneshot::Sender<Option<WriteError>>>,
    pub(crate) stopped: FxHashMap<StreamId, Waker>,
    /// Always set to Some before the connection becomes drained
    pub(crate) error: Option<ConnectionError>,
    /// Number of live handles that can be used to initiate or handle I/O; excludes the driver
    ref_count: usize,
    socket: Arc<dyn AsyncUdpSocket>,
    runtime: Arc<dyn Runtime>,
}

impl State {
    fn drive_transmit(&mut self) -> bool {
        let now = Instant::now();
        let mut transmits = 0;

        let max_datagrams = self.socket.max_transmit_segments();
        let capacity = self.inner.current_mtu();
        let mut buffer = BytesMut::with_capacity(capacity as usize);

        while let Some(t) = self.inner.poll_transmit(now, max_datagrams, &mut buffer) {
            transmits += match t.segment_size {
                None => 1,
                Some(s) => (t.size + s - 1) / s, // round up
            };
            // If the endpoint driver is gone, noop.
            let size = t.size;
            let _ = self.endpoint_events.send((
                self.handle,
                EndpointEvent::Transmit(t, buffer.split_to(size).freeze()),
            ));

            if transmits >= MAX_TRANSMIT_DATAGRAMS {
                // TODO: What isn't ideal here yet is that if we don't poll all
                // datagrams that could be sent we don't go into the `app_limited`
                // state and CWND continues to grow until we get here the next time.
                // See https://github.com/quinn-rs/quinn/issues/1126
                return true;
            }
        }

        false
    }

    fn forward_endpoint_events(&mut self) {
        while let Some(event) = self.inner.poll_endpoint_events() {
            // If the endpoint driver is gone, noop.
            let _ = self
                .endpoint_events
                .send((self.handle, EndpointEvent::Proto(event)));
        }
    }

    /// If this returns `Err`, the endpoint is dead, so the driver should exit immediately.
    fn process_conn_events(
        &mut self,
        shared: &Shared,
        cx: &mut Context,
    ) -> Result<(), ConnectionError> {
        loop {
            match self.conn_events.poll_recv(cx) {
                Poll::Ready(Some(ConnectionEvent::Ping)) => {
                    self.inner.ping();
                }
                Poll::Ready(Some(ConnectionEvent::Proto(event))) => {
                    self.inner.handle_event(event);
                }
                Poll::Ready(Some(ConnectionEvent::Close { reason, error_code })) => {
                    self.close(error_code, reason, shared);
                }
                Poll::Ready(None) => {
                    return Err(ConnectionError::TransportError(proto::TransportError {
                        code: proto::TransportErrorCode::INTERNAL_ERROR,
                        frame: None,
                        reason: "endpoint driver future was dropped".to_string(),
                    }));
                }
                Poll::Pending => {
                    return Ok(());
                }
            }
        }
    }

    fn forward_app_events(&mut self, shared: &Shared) {
        while let Some(event) = self.inner.poll() {
            use proto::Event::*;
            match event {
                HandshakeDataReady => {
                    if let Some(x) = self.on_handshake_data.take() {
                        let _ = x.send(());
                    }
                }
                Connected => {
                    self.connected = true;
                    if let Some(x) = self.on_connected.take() {
                        // We don't care if the on-connected future was dropped
                        let _ = x.send(self.inner.accepted_0rtt());
                    }
                }
                ConnectionLost { reason } => {
                    self.terminate(reason, shared);
                }
                Stream(StreamEvent::Writable { id }) => {
                    if let Some(writer) = self.blocked_writers.remove(&id) {
                        writer.wake();
                    }
                }
                Stream(StreamEvent::Opened { dir: Dir::Uni }) => {
                    shared.stream_incoming[Dir::Uni as usize].notify_waiters();
                }
                Stream(StreamEvent::Opened { dir: Dir::Bi }) => {
                    shared.stream_incoming[Dir::Bi as usize].notify_waiters();
                }
                DatagramReceived => {
                    shared.datagrams.notify_waiters();
                }
                Stream(StreamEvent::Readable { id }) => {
                    if let Some(reader) = self.blocked_readers.remove(&id) {
                        reader.wake();
                    }
                }
                Stream(StreamEvent::Available { dir }) => {
                    // Might mean any number of streams are ready, so we wake up everyone
                    shared.stream_budget_available[dir as usize].notify_waiters();
                }
                Stream(StreamEvent::Finished { id }) => {
                    if let Some(finishing) = self.finishing.remove(&id) {
                        // If the finishing stream was already dropped, there's nothing more to do.
                        let _ = finishing.send(None);
                    }
                    if let Some(stopped) = self.stopped.remove(&id) {
                        stopped.wake();
                    }
                }
                Stream(StreamEvent::Stopped { id, error_code }) => {
                    if let Some(stopped) = self.stopped.remove(&id) {
                        stopped.wake();
                    }
                    if let Some(finishing) = self.finishing.remove(&id) {
                        let _ = finishing.send(Some(WriteError::Stopped(error_code)));
                    }
                    if let Some(writer) = self.blocked_writers.remove(&id) {
                        writer.wake();
                    }
                }
            }
        }
    }

    fn drive_timer(&mut self, cx: &mut Context) -> bool {
        // Check whether we need to (re)set the timer. If so, we must poll again to ensure the
        // timer is registered with the runtime (and check whether it's already
        // expired).
        match self.inner.poll_timeout() {
            Some(deadline) => {
                if let Some(delay) = &mut self.timer {
                    // There is no need to reset the tokio timer if the deadline
                    // did not change
                    if self
                        .timer_deadline
                        .map(|current_deadline| current_deadline != deadline)
                        .unwrap_or(true)
                    {
                        delay.as_mut().reset(deadline);
                    }
                } else {
                    self.timer = Some(self.runtime.new_timer(deadline));
                }
                // Store the actual expiration time of the timer
                self.timer_deadline = Some(deadline);
            }
            None => {
                self.timer_deadline = None;
                return false;
            }
        }

        if self.timer_deadline.is_none() {
            return false;
        }

        let delay = self
            .timer
            .as_mut()
            .expect("timer must exist in this state")
            .as_mut();
        if delay.poll(cx).is_pending() {
            // Since there wasn't a timeout event, there is nothing new
            // for the connection to do
            return false;
        }

        // A timer expired, so the caller needs to check for
        // new transmits, which might cause new timers to be set.
        self.inner.handle_timeout(Instant::now());
        self.timer_deadline = None;
        true
    }

    /// Wake up a blocked `Driver` task to process I/O
    pub(crate) fn wake(&mut self) {
        if let Some(x) = self.driver.take() {
            x.wake();
        }
    }

    /// Used to wake up all blocked futures when the connection becomes closed for any reason
    fn terminate(&mut self, reason: ConnectionError, shared: &Shared) {
        self.error = Some(reason.clone());
        if let Some(x) = self.on_handshake_data.take() {
            let _ = x.send(());
        }
        for (_, writer) in self.blocked_writers.drain() {
            writer.wake()
        }
        for (_, reader) in self.blocked_readers.drain() {
            reader.wake()
        }
        shared.stream_budget_available[Dir::Uni as usize].notify_waiters();
        shared.stream_budget_available[Dir::Bi as usize].notify_waiters();
        shared.stream_incoming[Dir::Uni as usize].notify_waiters();
        shared.stream_incoming[Dir::Bi as usize].notify_waiters();
        shared.datagrams.notify_waiters();
        for (_, x) in self.finishing.drain() {
            let _ = x.send(Some(WriteError::ConnectionLost(reason.clone())));
        }
        if let Some(x) = self.on_connected.take() {
            let _ = x.send(false);
        }
        for (_, waker) in self.stopped.drain() {
            waker.wake();
        }
        shared.closed.notify_waiters();
    }

    fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) {
        self.inner.close(Instant::now(), error_code, reason);
        self.terminate(ConnectionError::LocallyClosed, shared);
        self.wake();
    }

    /// Close for a reason other than the application's explicit request
    pub(crate) fn implicit_close(&mut self, shared: &Shared) {
        self.close(0u32.into(), Bytes::new(), shared);
    }

    pub(crate) fn check_0rtt(&self) -> Result<(), ()> {
        if self.inner.is_handshaking()
            || self.inner.accepted_0rtt()
            || self.inner.side().is_server()
        {
            Ok(())
        } else {
            Err(())
        }
    }
}

impl Drop for State {
    fn drop(&mut self) {
        if !self.inner.is_drained() {
            // Ensure the endpoint can tidy up
            let _ = self.endpoint_events.send((
                self.handle,
                EndpointEvent::Proto(proto::EndpointEvent::drained()),
            ));
        }
    }
}

impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("State").field("inner", &self.inner).finish()
    }
}

/// Errors that can arise when sending a datagram
#[derive(Debug, Error, Clone, Eq, PartialEq)]
pub enum SendDatagramError {
    /// The peer does not support receiving datagram frames
    #[error("datagrams not supported by peer")]
    UnsupportedByPeer,
    /// Datagram support is disabled locally
    #[error("datagram support disabled")]
    Disabled,
    /// The datagram is larger than the connection can currently accommodate
    ///
    /// Indicates that the path MTU minus overhead or the limit advertised by the peer has been
    /// exceeded.
    #[error("datagram too large")]
    TooLarge,
    /// The connection was lost
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
}

/// The maximum number of datagrams that will be produced in a single `drive_transmit` call
///
/// This limits the amount of CPU resources consumed by datagram generation,
/// and allows other tasks (like receiving ACKs) to run in between.
const MAX_TRANSMIT_DATAGRAMS: usize = 20;

/// Error indicating that a stream has already been finished or reset
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[error("unknown stream")]
pub struct UnknownStream {
    _private: (),
}

impl From<proto::UnknownStream> for UnknownStream {
    fn from(_: proto::UnknownStream) -> Self {
        Self { _private: () }
    }
}