1use std::{
38 collections::HashSet,
39 convert::TryFrom,
40 sync::Arc,
41 task::{Context, Poll},
42};
43
44use bytes::{Buf, BytesMut};
45use futures_util::future;
46use http::{response, HeaderMap, Request, Response, StatusCode};
47use quic::StreamId;
48use tokio::sync::mpsc;
49
50use crate::{
51 connection::{self, ConnectionInner, ConnectionState, SharedStateRef},
52 error::{Code, Error},
53 frame::FrameStream,
54 proto::{frame::Frame, headers::Header, varint::VarInt},
55 qpack,
56 quic::{self, RecvStream as _, SendStream as _},
57 stream,
58};
59use tracing::{error, trace, warn};
60
61pub fn builder() -> Builder {
63 Builder::new()
64}
65
66pub struct Connection<C, B>
72where
73 C: quic::Connection<B>,
74 B: Buf,
75{
76 inner: ConnectionInner<C, B>,
77 max_field_section_size: u64,
78 ongoing_streams: HashSet<StreamId>,
80 request_end_recv: mpsc::UnboundedReceiver<StreamId>,
82 request_end_send: mpsc::UnboundedSender<StreamId>,
83}
84
85impl<C, B> ConnectionState for Connection<C, B>
86where
87 C: quic::Connection<B>,
88 B: Buf,
89{
90 fn shared_state(&self) -> &SharedStateRef {
91 &self.inner.shared
92 }
93}
94
95impl<C, B> Connection<C, B>
96where
97 C: quic::Connection<B>,
98 B: Buf,
99{
100 pub async fn new(conn: C) -> Result<Self, Error> {
104 Ok(builder().build(conn).await?)
105 }
106}
107
108impl<C, B> Connection<C, B>
109where
110 C: quic::Connection<B>,
111 B: Buf,
112{
113 pub async fn accept(
118 &mut self,
119 ) -> Result<Option<(Request<()>, RequestStream<C::BidiStream, B>)>, Error> {
120 let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await {
121 Ok(Some(s)) => FrameStream::new(s),
122 Ok(None) => {
123 self.inner.shutdown(0).await?;
126 return Ok(None);
127 }
128 Err(e) => {
129 if e.is_closed() {
130 return Ok(None);
131 }
132 return Err(e);
133 }
134 };
135
136 let frame = future::poll_fn(|cx| stream.poll_next(cx)).await;
137
138 let mut encoded = match frame {
139 Ok(Some(Frame::Headers(h))) => h,
140 Ok(None) => {
141 return Err(
142 Code::H3_REQUEST_INCOMPLETE.with_reason("request stream closed before headers")
143 )
144 }
145 Ok(Some(_)) => {
146 return Err(
147 Code::H3_FRAME_UNEXPECTED.with_reason("first request frame is not headers")
148 )
149 }
150 Err(e) => {
151 let err: Error = e.into();
152 if err.is_closed() {
153 return Ok(None);
154 }
155 return Err(err);
156 }
157 };
158
159 let mut request_stream = RequestStream {
160 request_end: Arc::new(RequestEnd {
161 request_end: self.request_end_send.clone(),
162 stream_id: stream.id(),
163 }),
164 inner: connection::RequestStream::new(
165 stream,
166 self.max_field_section_size,
167 self.inner.shared.clone(),
168 self.inner.send_grease_frame,
169 ),
170 };
171
172 let qpack::Decoded { fields, .. } =
173 match qpack::decode_stateless(&mut encoded, self.max_field_section_size) {
174 Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => {
175 request_stream
176 .send_response(
177 http::Response::builder()
178 .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
179 .body(())
180 .expect("header too big response"),
181 )
182 .await?;
183 return Err(Error::header_too_big(
184 cancel_size,
185 self.max_field_section_size,
186 ));
187 }
188 Ok(decoded) => decoded,
189 Err(e) => return Err(e.into()),
190 };
191
192 let (method, uri, headers) = Header::try_from(fields)?.into_request_parts()?;
193
194 let mut req = http::Request::new(());
195 *req.method_mut() = method;
196 *req.uri_mut() = uri;
197 *req.headers_mut() = headers;
198 *req.version_mut() = http::Version::HTTP_3;
199 self.inner.send_grease_frame = false;
201
202 Ok(Some((req, request_stream)))
203 }
204
205 pub async fn shutdown(&mut self, max_requests: usize) -> Result<(), Error> {
208 self.inner.shutdown(max_requests).await
209 }
210
211 fn poll_accept_request(
212 &mut self,
213 cx: &mut Context<'_>,
214 ) -> Poll<Result<Option<C::BidiStream>, Error>> {
215 let _ = self.poll_control(cx)?;
216 let _ = self.poll_requests_completion(cx);
217
218 let closing = self.shared_state().read("server accept").closing;
219
220 loop {
221 match self.inner.poll_accept_request(cx) {
222 Poll::Ready(Err(x)) => break Poll::Ready(Err(x)),
223 Poll::Ready(Ok(None)) => {
224 if self.poll_requests_completion(cx).is_ready() {
225 break Poll::Ready(Ok(None));
226 } else {
227 break Poll::Pending;
230 }
231 }
232 Poll::Pending => {
233 if closing.is_some() && self.poll_requests_completion(cx).is_ready() {
234 break Poll::Ready(Ok(None));
236 } else {
237 return Poll::Pending;
238 }
239 }
240 Poll::Ready(Ok(Some(mut s))) => {
241 if let Some(max_id) = closing {
245 if s.id() > max_id {
246 s.stop_sending(Code::H3_REQUEST_REJECTED.value());
247 s.reset(Code::H3_REQUEST_REJECTED.value());
248 if self.poll_requests_completion(cx).is_ready() {
249 break Poll::Ready(Ok(None));
250 }
251 continue;
252 }
253 }
254 self.inner.start_stream(s.id());
255 self.ongoing_streams.insert(s.id());
256 break Poll::Ready(Ok(Some(s)));
257 }
258 };
259 }
260 }
261
262 fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
263 while let Poll::Ready(frame) = self.inner.poll_control(cx)? {
264 match frame {
265 Frame::Settings(_) => trace!("Got settings"),
266 Frame::Goaway(id) => {
267 if !id.is_push() {
268 return Poll::Ready(Err(Code::H3_ID_ERROR
269 .with_reason(format!("non-push StreamId in a GoAway frame: {}", id))));
270 }
271 }
272 f @ Frame::MaxPushId(_) | f @ Frame::CancelPush(_) => {
273 warn!("Control frame ignored {:?}", f);
274 }
275 frame => {
276 return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED
277 .with_reason(format!("on server control stream: {:?}", frame))))
278 }
279 }
280 }
281 Poll::Pending
282 }
283
284 fn poll_requests_completion(&mut self, cx: &mut Context<'_>) -> Poll<()> {
285 loop {
286 match self.request_end_recv.poll_recv(cx) {
287 Poll::Ready(None) => return Poll::Ready(()),
289 Poll::Ready(Some(id)) => {
291 self.ongoing_streams.remove(&id);
292 }
293 Poll::Pending => {
294 if self.ongoing_streams.is_empty() {
295 return Poll::Ready(());
298 } else {
299 return Poll::Pending;
300 }
301 }
302 }
303 }
304 }
305}
306
307impl<C, B> Drop for Connection<C, B>
308where
309 C: quic::Connection<B>,
310 B: Buf,
311{
312 fn drop(&mut self) {
313 self.inner.close(Code::H3_NO_ERROR, "");
314 }
315}
316
317pub struct Builder {
338 pub(super) max_field_section_size: u64,
339 pub(super) send_grease: bool,
340}
341
342impl Builder {
343 pub(super) fn new() -> Self {
345 Builder {
346 max_field_section_size: VarInt::MAX.0,
347 send_grease: true,
348 }
349 }
350
351 pub fn max_field_section_size(&mut self, value: u64) -> &mut Self {
354 self.max_field_section_size = value;
355 self
356 }
357
358 pub fn send_grease(&mut self, value: bool) -> &mut Self {
361 self.send_grease = value;
362 self
363 }
364}
365
366impl Builder {
367 pub async fn build<C, B>(&self, conn: C) -> Result<Connection<C, B>, Error>
369 where
370 C: quic::Connection<B>,
371 B: Buf,
372 {
373 let (sender, receiver) = mpsc::unbounded_channel();
374 Ok(Connection {
375 inner: ConnectionInner::new(
376 conn,
377 self.max_field_section_size,
378 SharedStateRef::default(),
379 self.send_grease,
380 )
381 .await?,
382 max_field_section_size: self.max_field_section_size,
383 request_end_send: sender,
384 request_end_recv: receiver,
385 ongoing_streams: HashSet::new(),
386 })
387 }
388}
389
390pub struct RequestEnd {
391 request_end: mpsc::UnboundedSender<StreamId>,
392 stream_id: StreamId,
393}
394
395pub struct RequestStream<S, B> {
397 inner: connection::RequestStream<S, B>,
398 request_end: Arc<RequestEnd>,
399}
400
401impl<S, B> AsMut<connection::RequestStream<S, B>> for RequestStream<S, B> {
402 fn as_mut(&mut self) -> &mut connection::RequestStream<S, B> {
403 &mut self.inner
404 }
405}
406
407impl<S, B> ConnectionState for RequestStream<S, B> {
408 fn shared_state(&self) -> &SharedStateRef {
409 &self.inner.conn_state
410 }
411}
412
413impl<S, B> RequestStream<S, B>
414where
415 S: quic::RecvStream,
416{
417 pub async fn recv_data(&mut self) -> Result<Option<impl Buf>, Error> {
419 self.inner.recv_data().await
420 }
421
422 pub fn stop_sending(&mut self, error_code: crate::error::Code) {
423 self.inner.stream.stop_sending(error_code)
424 }
425}
426
427impl<S, B> RequestStream<S, B>
428where
429 S: quic::SendStream<B>,
430 B: Buf,
431{
432 pub async fn send_response(&mut self, resp: Response<()>) -> Result<(), Error> {
434 let (parts, _) = resp.into_parts();
435 let response::Parts {
436 status, headers, ..
437 } = parts;
438 let headers = Header::response(status, headers);
439
440 let mut block = BytesMut::new();
441 let mem_size = qpack::encode_stateless(&mut block, headers)?;
442
443 let max_mem_size = self
444 .inner
445 .conn_state
446 .read("send_response")
447 .peer_max_field_section_size;
448 if mem_size > max_mem_size {
449 return Err(Error::header_too_big(mem_size, max_mem_size));
450 }
451
452 stream::write(&mut self.inner.stream, Frame::Headers(block.freeze()))
453 .await
454 .map_err(|e| self.maybe_conn_err(e))?;
455
456 Ok(())
457 }
458
459 pub async fn send_data(&mut self, buf: B) -> Result<(), Error> {
461 self.inner.send_data(buf).await
462 }
463
464 pub async fn send_trailers(&mut self, trailers: HeaderMap) -> Result<(), Error> {
466 self.inner.send_trailers(trailers).await
467 }
468
469 pub async fn finish(&mut self) -> Result<(), Error> {
471 self.inner.finish().await
472 }
473}
474
475impl<S, B> RequestStream<S, B>
476where
477 S: quic::RecvStream + quic::SendStream<B>,
478 B: Buf,
479{
480 pub async fn recv_trailers(&mut self) -> Result<Option<HeaderMap>, Error> {
482 let res = self.inner.recv_trailers().await;
483 if let Err(ref e) = res {
484 if e.is_header_too_big() {
485 self.send_response(
486 http::Response::builder()
487 .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
488 .body(())
489 .expect("header too big response"),
490 )
491 .await?;
492 }
493 }
494 res
495 }
496}
497
498impl<S, B> RequestStream<S, B>
499where
500 S: quic::BidiStream<B>,
501 B: Buf,
502{
503 pub fn split(
506 self,
507 ) -> (
508 RequestStream<S::SendStream, B>,
509 RequestStream<S::RecvStream, B>,
510 ) {
511 let (send, recv) = self.inner.split();
512 (
513 RequestStream {
514 inner: send,
515 request_end: self.request_end.clone(),
516 },
517 RequestStream {
518 inner: recv,
519 request_end: self.request_end,
520 },
521 )
522 }
523}
524
525impl Drop for RequestEnd {
526 fn drop(&mut self) {
527 if let Err(e) = self.request_end.send(self.stream_id) {
528 error!(
529 "failed to notify connection of request end: {} {}",
530 self.stream_id, e
531 );
532 }
533 }
534}