1use std::{
2 collections::BTreeMap,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Context, Poll, Waker},
6 time::Instant,
7};
8
9use crate::{EndpointInner, QuicStream};
10use async_io::Timer;
11use futures::{channel::mpsc::Sender, prelude::*};
12
13pub struct QuicConnection {
14 inner: Arc<ConnectionInner>,
15}
16
17impl QuicConnection {
18 pub(crate) fn new(
19 handle: quinn_proto::ConnectionHandle,
20 conn: quinn_proto::Connection,
21 endpoint: Arc<EndpointInner>,
22 transmit_sender: Sender<quinn_proto::Transmit>,
23 ) -> Self {
24 let state = Mutex::new(ConnectionState {
25 conn,
26 timer: None,
27 stream_wakers: BTreeMap::new(),
28 conn_waker: None,
29 transmit_sender,
30 });
31 let inner = Arc::new(ConnectionInner {
32 state,
33 endpoint,
34 handle,
35 });
36 Self { inner }
37 }
38 pub(crate) fn inner(&self) -> Arc<ConnectionInner> {
39 self.inner.clone()
40 }
41}
42
/// State shared between the [`QuicConnection`] handle and the streams it hands out
/// (each accepted `QuicStream` receives an `Arc` clone of this).
pub(crate) struct ConnectionInner {
    /// The protocol state machine together with its timer and wakers, behind one lock.
    state: Mutex<ConnectionState>,
    /// Endpoint that endpoint events are forwarded to; also queried for the UDP GSO limit.
    endpoint: Arc<EndpointInner>,
    /// Handle passed to the endpoint alongside this connection's endpoint events.
    handle: quinn_proto::ConnectionHandle,
}
48
49impl ConnectionInner {
50 pub(crate) fn handle_event(&self, event: quinn_proto::ConnectionEvent) {
51 let mut guard = self.state.lock().unwrap();
52 guard.conn.handle_event(event);
53 if let Some(waker) = guard.conn_waker.take() {
54 waker.wake()
55 }
56 }
57 fn poll(self: &Arc<ConnectionInner>, cx: &mut Context<'_>) -> Poll<QuicConnectionEvent> {
58 let mut guard = self.state.lock().unwrap();
59 guard.conn_waker = None;
60 let mgs = self.endpoint.udp_state().max_gso_segments();
61 while let Some(t) = guard.conn.poll_transmit(Instant::now(), mgs) {
62 if let Poll::Ready(Ok(())) = guard.transmit_sender.poll_ready(cx) {
63 guard.transmit_sender.start_send(t).unwrap()
64 }
65 }
66 loop {
67 guard.timer = guard.conn.poll_timeout().map(Timer::at);
68 if let Some(timer) = &mut guard.timer {
69 match timer.poll_unpin(cx) {
70 Poll::Ready(_) => guard.conn.handle_timeout(Instant::now()),
71 Poll::Pending => break,
72 }
73 }
74 }
75 while let Some(event) = guard.conn.poll_endpoint_events() {
76 if let Some(event) = self.endpoint.handle_enpoint_event(self.handle, event) {
77 guard.conn.handle_event(event);
78 }
79 }
80 while let Some(event) = guard.conn.poll() {
81 match event {
82 quinn_proto::Event::HandshakeDataReady => log::info!("handshake data ready"),
83 quinn_proto::Event::Connected => log::info!("connected"),
84 quinn_proto::Event::ConnectionLost { reason } => {
85 log::error!("connection lost: {:?}", reason)
86 }
87 quinn_proto::Event::DatagramReceived => log::error!("ignoring datagram"),
88 quinn_proto::Event::Stream(event) => match event {
89 quinn_proto::StreamEvent::Opened { .. } => {} quinn_proto::StreamEvent::Readable { id } => guard.wake(id, true, false),
91 quinn_proto::StreamEvent::Writable { id } => guard.wake(id, false, true),
92 quinn_proto::StreamEvent::Finished { id } => guard.wake(id, false, true),
93 quinn_proto::StreamEvent::Stopped { id, .. } => guard.wake(id, true, false),
94 quinn_proto::StreamEvent::Available { dir } => todo!("available: {}", dir),
95 },
96 }
97 }
98 let mut streams = guard.conn.streams();
99 if let Some(id) = streams.accept(quinn_proto::Dir::Uni) {
100 guard.stream_wakers.insert(id, [None, None]);
101 return Poll::Ready(QuicConnectionEvent::StreamR(QuicStream::new(
102 self.clone(),
103 id,
104 )));
105 }
106 if let Some(id) = streams.accept(quinn_proto::Dir::Bi) {
107 guard.stream_wakers.insert(id, [None, None]);
108 return Poll::Ready(QuicConnectionEvent::StreamRW(QuicStream::new(
109 self.clone(),
110 id,
111 )));
112 }
113 guard.conn_waker = Some(cx.waker().clone());
114 Poll::Pending
115 }
116 pub(crate) fn poll_read(
117 &self,
118 id: quinn_proto::StreamId,
119 cx: &mut Context<'_>,
120 buf: &mut [u8],
121 ) -> Poll<(usize, Option<quinn_proto::VarInt>)> {
122 let mut guard = self.state.lock().unwrap();
123 guard.stream_wakers.get_mut(&id).unwrap()[0] = None;
124 let mut recv_stream = guard.conn.recv_stream(id);
125 let mut chunks = match recv_stream.read(true) {
126 Ok(chunks) => chunks,
127 Err(_) => return Poll::Ready((0, None)),
128 };
129 let mut n = 0usize;
130 let (blocked, err_code) = loop {
131 if buf.len() == n {
132 break (false, None);
133 }
134 match chunks.next(buf.len() - n) {
135 Ok(Some(chunk)) => {
136 let m = n + chunk.bytes.len();
137 buf[n..m].copy_from_slice(&chunk.bytes);
138 n = m;
139 }
140 Ok(None) => break (false, None),
141 Err(quinn_proto::ReadError::Blocked) => break (true, None),
142 Err(quinn_proto::ReadError::Reset(err)) => break (false, Some(err)),
143 }
144 };
145 if chunks.finalize().should_transmit() {
146 if let Some(w) = guard.conn_waker.take() {
147 w.wake();
148 }
149 }
150 if n == 0 && blocked {
151 guard.stream_wakers.get_mut(&id).unwrap()[0] = Some(cx.waker().clone());
152 return Poll::Pending;
153 }
154 return Poll::Ready((n, err_code));
155 }
156 pub(crate) fn poll_write(
157 &self,
158 id: quinn_proto::StreamId,
159 cx: &mut Context<'_>,
160 buf: &[u8],
161 ) -> Poll<Result<usize, Option<quinn_proto::VarInt>>> {
162 let mut guard = self.state.lock().unwrap();
163 guard.stream_wakers.get_mut(&id).unwrap()[1] = None;
164 let mut send_stream = guard.conn.send_stream(id);
165 match send_stream.write(buf) {
166 Ok(n) => Poll::Ready(Ok(n)),
167 Err(quinn_proto::WriteError::Blocked) => {
168 guard.stream_wakers.get_mut(&id).unwrap()[1] = Some(cx.waker().clone());
169 Poll::Pending
170 }
171 Err(quinn_proto::WriteError::Stopped(err_code)) => Poll::Ready(Err(Some(err_code))),
172 Err(quinn_proto::WriteError::UnknownStream) => Poll::Ready(Err(None)),
173 }
174 }
175 pub(crate) fn close(
176 &self,
177 id: quinn_proto::StreamId,
178 _cx: &mut Context<'_>,
179 ) -> Result<(), Option<quinn_proto::VarInt>> {
180 let mut guard = self.state.lock().unwrap();
181 guard.stream_wakers.get_mut(&id).unwrap()[1] = None;
182 let mut send_stream = guard.conn.send_stream(id);
183 match send_stream.finish() {
184 Ok(()) => Ok(()),
185 Err(quinn_proto::FinishError::Stopped(err_code)) => Err(Some(err_code)),
186 Err(quinn_proto::FinishError::UnknownStream) => Err(None),
187 }
188 }
189}
190
191impl Stream for QuicConnection {
192 type Item = QuicConnectionEvent;
193
194 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195 self.inner.poll(cx).map(Option::Some)
196 }
197}
198
/// Mutable per-connection state, kept behind the `Mutex` in `ConnectionInner::state`.
struct ConnectionState {
    /// The `quinn_proto` connection state machine.
    conn: quinn_proto::Connection,
    /// Timer set from the connection's `poll_timeout()` deadline; `None` when there is none.
    timer: Option<Timer>,
    /// Waker of the task polling the connection, stored when that poll returns `Pending`.
    conn_waker: Option<Waker>,
    /// Per-stream wakers: index 0 for a parked reader, index 1 for a parked writer.
    stream_wakers: BTreeMap<quinn_proto::StreamId, [Option<Waker>; 2]>,
    /// Channel into which the connection's outgoing `Transmit`s are pushed.
    transmit_sender: Sender<quinn_proto::Transmit>,
}
206
207impl ConnectionState {
208 fn wake(&mut self, id: quinn_proto::StreamId, r: bool, w: bool) {
209 if let Some(wakers) = self.stream_wakers.get_mut(&id) {
210 if r {
211 if let Some(waker) = wakers[0].take() {
212 waker.wake();
213 }
214 }
215 if w {
216 if let Some(waker) = wakers[1].take() {
217 waker.wake();
218 }
219 }
220 }
221 }
222}
223pub enum QuicConnectionEvent {
224 StreamR(QuicStream<true, false>),
225 StreamRW(QuicStream<true, true>),
226}
227
228impl QuicConnectionEvent {
229 pub fn stream_r(self) -> Option<QuicStream<true, false>> {
230 match self {
231 Self::StreamR(stream) => Some(stream),
232 _ => None,
233 }
234 }
235 pub fn stream_rw(self) -> Option<QuicStream<true, true>> {
236 match self {
237 Self::StreamRW(stream) => Some(stream),
238 _ => None,
239 }
240 }
241}