async_quic/
connection.rs

1use std::{
2    collections::BTreeMap,
3    pin::Pin,
4    sync::{Arc, Mutex},
5    task::{Context, Poll, Waker},
6    time::Instant,
7};
8
9use crate::{EndpointInner, QuicStream};
10use async_io::Timer;
11use futures::{channel::mpsc::Sender, prelude::*};
12
13pub struct QuicConnection {
14    inner: Arc<ConnectionInner>,
15}
16
17impl QuicConnection {
18    pub(crate) fn new(
19        handle: quinn_proto::ConnectionHandle,
20        conn: quinn_proto::Connection,
21        endpoint: Arc<EndpointInner>,
22        transmit_sender: Sender<quinn_proto::Transmit>,
23    ) -> Self {
24        let state = Mutex::new(ConnectionState {
25            conn,
26            timer: None,
27            stream_wakers: BTreeMap::new(),
28            conn_waker: None,
29            transmit_sender,
30        });
31        let inner = Arc::new(ConnectionInner {
32            state,
33            endpoint,
34            handle,
35        });
36        Self { inner }
37    }
38    pub(crate) fn inner(&self) -> Arc<ConnectionInner> {
39        self.inner.clone()
40    }
41}
42
43pub(crate) struct ConnectionInner {
44    state: Mutex<ConnectionState>,
45    endpoint: Arc<EndpointInner>,
46    handle: quinn_proto::ConnectionHandle,
47}
48
49impl ConnectionInner {
50    pub(crate) fn handle_event(&self, event: quinn_proto::ConnectionEvent) {
51        let mut guard = self.state.lock().unwrap();
52        guard.conn.handle_event(event);
53        if let Some(waker) = guard.conn_waker.take() {
54            waker.wake()
55        }
56    }
57    fn poll(self: &Arc<ConnectionInner>, cx: &mut Context<'_>) -> Poll<QuicConnectionEvent> {
58        let mut guard = self.state.lock().unwrap();
59        guard.conn_waker = None;
60        let mgs = self.endpoint.udp_state().max_gso_segments();
61        while let Some(t) = guard.conn.poll_transmit(Instant::now(), mgs) {
62            if let Poll::Ready(Ok(())) = guard.transmit_sender.poll_ready(cx) {
63                guard.transmit_sender.start_send(t).unwrap()
64            }
65        }
66        loop {
67            guard.timer = guard.conn.poll_timeout().map(Timer::at);
68            if let Some(timer) = &mut guard.timer {
69                match timer.poll_unpin(cx) {
70                    Poll::Ready(_) => guard.conn.handle_timeout(Instant::now()),
71                    Poll::Pending => break,
72                }
73            }
74        }
75        while let Some(event) = guard.conn.poll_endpoint_events() {
76            if let Some(event) = self.endpoint.handle_enpoint_event(self.handle, event) {
77                guard.conn.handle_event(event);
78            }
79        }
80        while let Some(event) = guard.conn.poll() {
81            match event {
82                quinn_proto::Event::HandshakeDataReady => log::info!("handshake data ready"),
83                quinn_proto::Event::Connected => log::info!("connected"),
84                quinn_proto::Event::ConnectionLost { reason } => {
85                    log::error!("connection lost: {:?}", reason)
86                }
87                quinn_proto::Event::DatagramReceived => log::error!("ignoring datagram"),
88                quinn_proto::Event::Stream(event) => match event {
89                    quinn_proto::StreamEvent::Opened { .. } => {} // ignore, because we check anyway
90                    quinn_proto::StreamEvent::Readable { id } => guard.wake(id, true, false),
91                    quinn_proto::StreamEvent::Writable { id } => guard.wake(id, false, true),
92                    quinn_proto::StreamEvent::Finished { id } => guard.wake(id, false, true),
93                    quinn_proto::StreamEvent::Stopped { id, .. } => guard.wake(id, true, false),
94                    quinn_proto::StreamEvent::Available { dir } => todo!("available: {}", dir),
95                },
96            }
97        }
98        let mut streams = guard.conn.streams();
99        if let Some(id) = streams.accept(quinn_proto::Dir::Uni) {
100            guard.stream_wakers.insert(id, [None, None]);
101            return Poll::Ready(QuicConnectionEvent::StreamR(QuicStream::new(
102                self.clone(),
103                id,
104            )));
105        }
106        if let Some(id) = streams.accept(quinn_proto::Dir::Bi) {
107            guard.stream_wakers.insert(id, [None, None]);
108            return Poll::Ready(QuicConnectionEvent::StreamRW(QuicStream::new(
109                self.clone(),
110                id,
111            )));
112        }
113        guard.conn_waker = Some(cx.waker().clone());
114        Poll::Pending
115    }
116    pub(crate) fn poll_read(
117        &self,
118        id: quinn_proto::StreamId,
119        cx: &mut Context<'_>,
120        buf: &mut [u8],
121    ) -> Poll<(usize, Option<quinn_proto::VarInt>)> {
122        let mut guard = self.state.lock().unwrap();
123        guard.stream_wakers.get_mut(&id).unwrap()[0] = None;
124        let mut recv_stream = guard.conn.recv_stream(id);
125        let mut chunks = match recv_stream.read(true) {
126            Ok(chunks) => chunks,
127            Err(_) => return Poll::Ready((0, None)),
128        };
129        let mut n = 0usize;
130        let (blocked, err_code) = loop {
131            if buf.len() == n {
132                break (false, None);
133            }
134            match chunks.next(buf.len() - n) {
135                Ok(Some(chunk)) => {
136                    let m = n + chunk.bytes.len();
137                    buf[n..m].copy_from_slice(&chunk.bytes);
138                    n = m;
139                }
140                Ok(None) => break (false, None),
141                Err(quinn_proto::ReadError::Blocked) => break (true, None),
142                Err(quinn_proto::ReadError::Reset(err)) => break (false, Some(err)),
143            }
144        };
145        if chunks.finalize().should_transmit() {
146            if let Some(w) = guard.conn_waker.take() {
147                w.wake();
148            }
149        }
150        if n == 0 && blocked {
151            guard.stream_wakers.get_mut(&id).unwrap()[0] = Some(cx.waker().clone());
152            return Poll::Pending;
153        }
154        return Poll::Ready((n, err_code));
155    }
156    pub(crate) fn poll_write(
157        &self,
158        id: quinn_proto::StreamId,
159        cx: &mut Context<'_>,
160        buf: &[u8],
161    ) -> Poll<Result<usize, Option<quinn_proto::VarInt>>> {
162        let mut guard = self.state.lock().unwrap();
163        guard.stream_wakers.get_mut(&id).unwrap()[1] = None;
164        let mut send_stream = guard.conn.send_stream(id);
165        match send_stream.write(buf) {
166            Ok(n) => Poll::Ready(Ok(n)),
167            Err(quinn_proto::WriteError::Blocked) => {
168                guard.stream_wakers.get_mut(&id).unwrap()[1] = Some(cx.waker().clone());
169                Poll::Pending
170            }
171            Err(quinn_proto::WriteError::Stopped(err_code)) => Poll::Ready(Err(Some(err_code))),
172            Err(quinn_proto::WriteError::UnknownStream) => Poll::Ready(Err(None)),
173        }
174    }
175    pub(crate) fn close(
176        &self,
177        id: quinn_proto::StreamId,
178        _cx: &mut Context<'_>,
179    ) -> Result<(), Option<quinn_proto::VarInt>> {
180        let mut guard = self.state.lock().unwrap();
181        guard.stream_wakers.get_mut(&id).unwrap()[1] = None;
182        let mut send_stream = guard.conn.send_stream(id);
183        match send_stream.finish() {
184            Ok(()) => Ok(()),
185            Err(quinn_proto::FinishError::Stopped(err_code)) => Err(Some(err_code)),
186            Err(quinn_proto::FinishError::UnknownStream) => Err(None),
187        }
188    }
189}
190
191impl Stream for QuicConnection {
192    type Item = QuicConnectionEvent;
193
194    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195        self.inner.poll(cx).map(Option::Some)
196    }
197}
198
199struct ConnectionState {
200    conn: quinn_proto::Connection,
201    timer: Option<Timer>,
202    conn_waker: Option<Waker>,
203    stream_wakers: BTreeMap<quinn_proto::StreamId, [Option<Waker>; 2]>,
204    transmit_sender: Sender<quinn_proto::Transmit>,
205}
206
207impl ConnectionState {
208    fn wake(&mut self, id: quinn_proto::StreamId, r: bool, w: bool) {
209        if let Some(wakers) = self.stream_wakers.get_mut(&id) {
210            if r {
211                if let Some(waker) = wakers[0].take() {
212                    waker.wake();
213                }
214            }
215            if w {
216                if let Some(waker) = wakers[1].take() {
217                    waker.wake();
218                }
219            }
220        }
221    }
222}
223pub enum QuicConnectionEvent {
224    StreamR(QuicStream<true, false>),
225    StreamRW(QuicStream<true, true>),
226}
227
228impl QuicConnectionEvent {
229    pub fn stream_r(self) -> Option<QuicStream<true, false>> {
230        match self {
231            Self::StreamR(stream) => Some(stream),
232            _ => None,
233        }
234    }
235    pub fn stream_rw(self) -> Option<QuicStream<true, true>> {
236        match self {
237            Self::StreamRW(stream) => Some(stream),
238            _ => None,
239        }
240    }
241}