1use crate::{
2 Buffer, Conn, Headers, HttpContext, KnownHeaderName, Method, ProtocolSession, ReceivedBody,
3 Status, TypeSet, Version,
4 h2::H2Connection,
5 h3::{Frame, H3Connection, H3Settings},
6 headers::qpack::{FieldSection, PseudoHeaders},
7 received_body::{H3TrailerFuture, ReceivedBodyState, write_chunk},
8 util::encoding,
9};
10use encoding_rs::Encoding;
11use fieldwork::Fieldwork;
12use futures_lite::{
13 AsyncWriteExt,
14 io::{AsyncRead, AsyncWrite},
15};
16use std::{
17 borrow::Cow,
18 fmt::{self, Debug, Formatter},
19 io::{self, IoSlice, Write},
20 net::IpAddr,
21 pin::Pin,
22 str,
23 sync::Arc,
24 task::{Context, Poll, ready},
25 time::Instant,
26};
27
28#[derive(Debug)]
31pub(crate) enum WriteState {
32 Raw,
35 H1Chunked(H1ChunkedState),
37 H3Framed(H3FramedState),
39}
40
/// Bookkeeping for the HTTP/1.x chunked write path.
#[derive(Debug, Default)]
pub(crate) struct H1ChunkedState {
    /// Encoded bytes (chunk headers, payload, terminators) that have been
    /// accepted from the caller but not yet written to the transport.
    pub(crate) pending: Vec<u8>,

    /// Set once the terminating chunk has been queued; later non-empty writes
    /// are rejected.
    pub(crate) terminator_written: bool,
}
46
/// Bookkeeping for the HTTP/3 DATA-framed write path.
#[derive(Debug, Default)]
pub(crate) struct H3FramedState {
    /// Encoded bytes (frame headers and payload) that have been accepted from
    /// the caller but not yet written to the transport.
    pub(crate) pending: Vec<u8>,

    /// Set once the stream has been finished (trailers sent or close
    /// requested); later non-empty writes are rejected.
    pub(crate) terminator_written: bool,
}
52
53fn compute_write_state(version: Version, outbound_headers: &Headers) -> WriteState {
57 match version {
58 Version::Http1_0 | Version::Http1_1 if has_chunked_encoding(outbound_headers) => {
59 WriteState::H1Chunked(H1ChunkedState::default())
60 }
61 Version::Http3 => WriteState::H3Framed(H3FramedState::default()),
62 _ => WriteState::Raw,
63 }
64}
65
66fn has_chunked_encoding(headers: &Headers) -> bool {
69 headers
70 .token_iter(KnownHeaderName::TransferEncoding)
71 .any(|coding| coding.eq_ignore_ascii_case("chunked"))
72}
73
74fn parse_content_length(inbound_headers: &Headers) -> Option<u64> {
76 if inbound_headers.has_header(KnownHeaderName::TransferEncoding) {
77 return None;
78 }
79 inbound_headers.content_length()
80}
81
82fn poll_drain_pending<T: AsyncWrite + Unpin>(
84 pending: &mut Vec<u8>,
85 cx: &mut Context<'_>,
86 transport: &mut T,
87) -> Poll<io::Result<()>> {
88 while !pending.is_empty() {
89 match Pin::new(&mut *transport).poll_write(cx, pending) {
90 Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
91 Poll::Ready(Ok(n)) => {
92 pending.drain(..n);
93 }
94 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
95 Poll::Pending => return Poll::Pending,
96 }
97 }
98 Poll::Ready(Ok(()))
99}
100
101fn best_effort_drain<T: AsyncWrite + Unpin>(
104 pending: &mut Vec<u8>,
105 cx: &mut Context<'_>,
106 transport: &mut T,
107) -> io::Result<()> {
108 while !pending.is_empty() {
109 match Pin::new(&mut *transport).poll_write(cx, pending) {
110 Poll::Ready(Ok(0)) => return Err(io::ErrorKind::WriteZero.into()),
111 Poll::Ready(Ok(n)) => {
112 pending.drain(..n);
113 }
114 Poll::Ready(Err(e)) => return Err(e),
115 Poll::Pending => break,
116 }
117 }
118 Ok(())
119}
120
121fn encode_h3_data_header(out: &mut Vec<u8>, payload_len: u64) {
124 let frame = Frame::Data(payload_len);
125 let header_len = frame.encoded_len();
126 let start = out.len();
127 out.resize(start + header_len, 0);
128 frame.encode(&mut out[start..]);
129}
130
131#[derive(Fieldwork)]
137#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
138pub struct Upgrade<Transport> {
139 #[field(deprecate(was = "request_headers", since = "1.3.0"))]
141 pub(crate) received_headers: Headers,
142
143 #[field(deprecate(was = "response_headers", since = "1.3.0"))]
146 pub(crate) sent_headers: Headers,
147
148 #[field(get = false)]
150 pub(crate) path: Cow<'static, str>,
151
152 #[field(copy)]
154 pub(crate) method: Method,
155
156 pub(crate) state: TypeSet,
158
159 pub(crate) transport: Transport,
161
162 #[field(deref = "[u8]", into_field = false, set = false, with = false)]
167 pub(crate) buffer: Buffer,
168
169 #[field(deref = false)]
171 pub(crate) context: Arc<HttpContext>,
172
173 #[field(copy)]
175 pub(crate) peer_ip: Option<IpAddr>,
176
177 #[field(copy)]
179 pub(crate) start_time: Instant,
180
181 pub(crate) authority: Option<Cow<'static, str>>,
183
184 pub(crate) scheme: Option<Cow<'static, str>>,
186
187 #[field = false]
190 pub(crate) protocol_session: ProtocolSession,
191
192 pub(crate) protocol: Option<Cow<'static, str>>,
194
195 #[field = "http_version"]
197 pub(crate) version: Version,
198
199 #[field(copy)]
202 pub(crate) status: Option<Status>,
203
204 pub(crate) secure: bool,
206
207 #[field = false]
211 pub(crate) received_body_state: ReceivedBodyState,
212
213 #[field(get, get_mut, take, set = false, with = false, into_field = false)]
216 pub(crate) received_trailers: Option<Headers>,
217
218 #[field = false]
220 pub(crate) content_length_in: Option<u64>,
221
222 #[field = false]
224 pub(crate) write_state: WriteState,
225
226 #[field = false]
229 pub(crate) inbound_encoding: &'static Encoding,
230
231 #[field = false]
235 pub(crate) h3_trailer_decode_in: Option<H3TrailerFuture>,
236
237 #[field = false]
241 pub(crate) h3_trailer_payload_in: Vec<u8>,
242}
243
244impl<Transport> Upgrade<Transport> {
245 #[doc(hidden)]
246 pub fn new(
247 received_headers: Headers,
248 path: impl Into<Cow<'static, str>>,
249 method: Method,
250 transport: Transport,
251 buffer: Buffer,
252 version: Version,
253 ) -> Self {
254 Self {
255 received_headers,
256 sent_headers: Headers::new(),
257 path: path.into(),
258 method,
259 transport,
260 buffer,
261 state: TypeSet::new(),
262 context: Arc::default(),
263 peer_ip: None,
264 start_time: Instant::now(),
265 authority: None,
266 scheme: None,
267 protocol_session: ProtocolSession::Http1,
268 protocol: None,
269 secure: false,
270 version,
271 status: None,
272 received_body_state: ReceivedBodyState::Raw { total: 0 },
273 received_trailers: None,
274 content_length_in: None,
275 write_state: WriteState::Raw,
276 inbound_encoding: encoding_rs::WINDOWS_1252,
277 h3_trailer_decode_in: None,
278 h3_trailer_payload_in: Vec::new(),
279 }
280 }
281
282 #[cfg(feature = "unstable")]
283 #[doc(hidden)]
284 #[allow(clippy::too_many_arguments)]
285 pub fn from_parts(
286 received_headers: Headers,
287 sent_headers: Headers,
288 path: Cow<'static, str>,
289 method: Method,
290 transport: Transport,
291 buffer: Buffer,
292 state: TypeSet,
293 context: Arc<HttpContext>,
294 peer_ip: Option<IpAddr>,
295 authority: Option<Cow<'static, str>>,
296 scheme: Option<Cow<'static, str>>,
297 protocol_session: ProtocolSession,
298 protocol: Option<Cow<'static, str>>,
299 version: Version,
300 status: Option<Status>,
301 secure: bool,
302 received_body_state: ReceivedBodyState,
303 received_trailers: Option<Headers>,
304 ) -> Self {
305 let write_state = compute_write_state(version, &sent_headers);
306 let content_length_in = parse_content_length(&received_headers);
307 let inbound_encoding = encoding(&received_headers);
308
309 Self {
310 received_headers,
311 sent_headers,
312 path,
313 method,
314 state,
315 transport,
316 buffer,
317 context,
318 peer_ip,
319 start_time: Instant::now(),
320 authority,
321 scheme,
322 protocol_session,
323 protocol,
324 version,
325 status,
326 secure,
327 received_body_state,
328 received_trailers,
329 content_length_in,
330 write_state,
331 inbound_encoding,
332 h3_trailer_decode_in: None,
333 h3_trailer_payload_in: Vec::new(),
334 }
335 }
336
337 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
339 self.protocol_session.h2_connection()
340 }
341
342 pub fn h2_stream_id(&self) -> Option<u32> {
344 self.protocol_session.h2_stream_id()
345 }
346
347 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
349 self.protocol_session.h3_connection()
350 }
351
352 pub fn h3_stream_id(&self) -> Option<u64> {
354 self.protocol_session.h3_stream_id()
355 }
356
357 pub fn take_buffer(&mut self) -> Vec<u8> {
359 std::mem::take(&mut self.buffer).into()
360 }
361
362 #[doc(hidden)]
363 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
364 (&mut self.buffer, &mut self.transport)
365 }
366
367 pub fn shared_state(&self) -> &TypeSet {
369 self.context.shared_state()
370 }
371
372 pub fn path(&self) -> &str {
374 match self.path.split_once('?') {
375 Some((path, _)) => path,
376 None => &self.path,
377 }
378 }
379
380 pub fn querystring(&self) -> &str {
382 self.path
383 .split_once('?')
384 .map(|(_, query)| query)
385 .unwrap_or_default()
386 }
387
388 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
392 self,
393 f: impl Fn(Transport) -> T,
394 ) -> Upgrade<T> {
395 Upgrade {
399 transport: f(self.transport),
400 path: self.path,
401 method: self.method,
402 state: self.state,
403 buffer: self.buffer,
404 received_headers: self.received_headers,
405 sent_headers: self.sent_headers,
406 context: self.context,
407 peer_ip: self.peer_ip,
408 start_time: self.start_time,
409 authority: self.authority,
410 scheme: self.scheme,
411 protocol_session: self.protocol_session,
412 protocol: self.protocol,
413 version: self.version,
414 status: self.status,
415 secure: self.secure,
416 received_body_state: self.received_body_state,
417 received_trailers: self.received_trailers,
418 content_length_in: self.content_length_in,
419 write_state: self.write_state,
420 inbound_encoding: self.inbound_encoding,
421 h3_trailer_decode_in: self.h3_trailer_decode_in,
422 h3_trailer_payload_in: self.h3_trailer_payload_in,
423 }
424 }
425}
426
427impl<Transport: AsyncWrite + Unpin> Upgrade<Transport> {
428 pub async fn send_trailers(self, trailers: Headers) -> io::Result<()> {
448 let Self {
449 mut transport,
450 mut write_state,
451 context,
452 protocol_session,
453 ..
454 } = self;
455
456 match &mut write_state {
457 WriteState::H1Chunked(state) => {
458 if state.terminator_written {
459 return Err(io::ErrorKind::BrokenPipe.into());
460 }
461 state.pending.extend_from_slice(b"0\r\n");
462 crate::conn::write_headers_or_trailers(&mut state.pending, &trailers, &context)
463 .map_err(io::Error::other)?;
464 state.pending.extend_from_slice(b"\r\n");
465 state.terminator_written = true;
466
467 transport.write_all(&state.pending).await?;
468 state.pending.clear();
469 transport.close().await
470 }
471 WriteState::H3Framed(state) => {
472 if state.terminator_written {
473 return Err(io::ErrorKind::BrokenPipe.into());
474 }
475 let Some((h3, stream_id)) = protocol_session.as_h3() else {
476 return Err(io::ErrorKind::NotConnected.into());
477 };
478 let max_field_section = h3
479 .peer_settings()
480 .and_then(H3Settings::max_field_section_size);
481 let field_section = FieldSection::new(PseudoHeaders::default(), &trailers);
482 crate::conn::encode_field_section_h3(
483 &h3,
484 &field_section,
485 max_field_section,
486 &mut state.pending,
487 stream_id,
488 )?;
489 state.terminator_written = true;
490
491 transport.write_all(&state.pending).await?;
492 state.pending.clear();
493 transport.close().await
494 }
495 WriteState::Raw => {
496 if let Some((h2, stream_id)) = protocol_session.as_h2() {
497 h2.submit_trailers(stream_id, trailers)
498 } else {
499 log::warn!(
500 "Upgrade::send_trailers called on a raw upgrade with no per-stream \
501 framing; trailers dropped. Set `Transfer-Encoding: chunked` on the \
502 outbound headers if you intend to emit trailers over HTTP/1.1."
503 );
504 Ok(())
505 }
506 }
507 }
508 }
509}
510
511impl<Transport> Debug for Upgrade<Transport> {
512 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
513 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
514 .field("received_headers", &self.received_headers)
515 .field("sent_headers", &self.sent_headers)
516 .field("path", &self.path)
517 .field("method", &self.method)
518 .field("buffer", &self.buffer)
519 .field("context", &self.context)
520 .field("state", &self.state)
521 .field("transport", &format_args!(".."))
522 .field("peer_ip", &self.peer_ip)
523 .field("start_time", &self.start_time)
524 .field("authority", &self.authority)
525 .field("scheme", &self.scheme)
526 .field("protocol_session", &self.protocol_session)
527 .field("protocol", &self.protocol)
528 .field("version", &self.version)
529 .field("status", &self.status)
530 .field("secure", &self.secure)
531 .field("received_body_state", &self.received_body_state)
532 .field("received_trailers", &self.received_trailers)
533 .field("content_length_in", &self.content_length_in)
534 .field("write_state", &self.write_state)
535 .field("inbound_encoding", &self.inbound_encoding.name())
536 .field(
537 "h3_trailer_decode_in",
538 &self
539 .h3_trailer_decode_in
540 .as_ref()
541 .map(|_| format_args!("..")),
542 )
543 .field(
544 "h3_trailer_payload_in_len",
545 &self.h3_trailer_payload_in.len(),
546 )
547 .finish()
548 }
549}
550
551impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
552 fn from(conn: Conn<Transport>) -> Self {
553 let Conn {
556 request_headers,
557 response_headers,
558 path,
559 method,
560 state,
561 transport,
562 buffer,
563 context,
564 peer_ip,
565 start_time,
566 authority,
567 scheme,
568 protocol_session,
569 protocol,
570 version,
571 status,
572 secure,
573 request_body_state,
574 request_trailers,
575 response_body,
576 after_send: _,
578 upgrade: _,
579 } = conn;
580
581 if let Some(body) = &response_body
582 && !body.is_empty()
583 {
584 log::warn!(
585 "Conn::upgrade() and a non-empty response body are both set; body is being \
586 discarded. The upgrade path is mutually exclusive with serving a response body."
587 );
588 }
589
590 let write_state = compute_write_state(version, &response_headers);
592 let content_length_in = parse_content_length(&request_headers);
593 let inbound_encoding = encoding(&request_headers);
594 let received_body_state = request_body_state;
595 let received_trailers = request_trailers.filter(|t| !t.is_empty());
596
597 Self {
598 received_headers: request_headers,
599 sent_headers: response_headers,
600 path,
601 method,
602 state,
603 transport,
604 buffer,
605 context,
606 peer_ip,
607 start_time,
608 authority,
609 scheme,
610 protocol_session,
611 protocol,
612 version,
613 status,
614 secure,
615 received_body_state,
616 received_trailers,
617 content_length_in,
618 write_state,
619 inbound_encoding,
620 h3_trailer_decode_in: None,
621 h3_trailer_payload_in: Vec::new(),
622 }
623 }
624}
625
626#[cfg(test)]
627mod tests;
628
629impl<Transport> AsyncRead for Upgrade<Transport>
630where
631 Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
632{
633 fn poll_read(
634 mut self: Pin<&mut Self>,
635 cx: &mut Context<'_>,
636 buf: &mut [u8],
637 ) -> Poll<io::Result<usize>> {
638 let Self {
639 transport,
640 buffer,
641 received_body_state,
642 content_length_in,
643 context,
644 protocol_session,
645 received_trailers,
646 h3_trailer_decode_in,
647 h3_trailer_payload_in,
648 inbound_encoding,
649 ..
650 } = &mut *self;
651
652 let protocol_session = protocol_session.clone();
653 let mut body: ReceivedBody<'_, Transport> = ReceivedBody::new_with_config(
654 *content_length_in,
655 buffer,
656 transport,
657 received_body_state,
658 None,
659 inbound_encoding,
660 &context.config,
661 )
662 .with_trailers(received_trailers)
663 .with_protocol_session(protocol_session)
664 .with_h3_trailer_future(h3_trailer_decode_in)
665 .with_h3_trailer_payload_buffer(h3_trailer_payload_in);
666
667 Pin::new(&mut body).poll_read(cx, buf)
668 }
669}
670
671impl<Transport: AsyncWrite + Unpin> AsyncWrite for Upgrade<Transport> {
672 fn poll_write(
673 mut self: Pin<&mut Self>,
674 cx: &mut Context<'_>,
675 buf: &[u8],
676 ) -> Poll<io::Result<usize>> {
677 let Self {
678 transport,
679 write_state,
680 ..
681 } = &mut *self;
682 match write_state {
683 WriteState::Raw => Pin::new(transport).poll_write(cx, buf),
684 WriteState::H1Chunked(state) => {
685 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
686
687 if buf.is_empty() {
689 return Poll::Ready(Ok(0));
690 }
691
692 if state.terminator_written {
693 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
694 }
695
696 write_chunk(&mut state.pending, buf);
697 best_effort_drain(&mut state.pending, cx, transport)?;
698 Poll::Ready(Ok(buf.len()))
699 }
700 WriteState::H3Framed(state) => {
701 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
702
703 if buf.is_empty() {
704 return Poll::Ready(Ok(0));
705 }
706
707 if state.terminator_written {
708 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
709 }
710
711 encode_h3_data_header(&mut state.pending, buf.len() as u64);
712 state.pending.extend_from_slice(buf);
713 best_effort_drain(&mut state.pending, cx, transport)?;
714 Poll::Ready(Ok(buf.len()))
715 }
716 }
717 }
718
719 fn poll_write_vectored(
720 mut self: Pin<&mut Self>,
721 cx: &mut Context<'_>,
722 bufs: &[IoSlice<'_>],
723 ) -> Poll<io::Result<usize>> {
724 let Self {
725 transport,
726 write_state,
727 ..
728 } = &mut *self;
729 match write_state {
730 WriteState::Raw => Pin::new(transport).poll_write_vectored(cx, bufs),
731 WriteState::H1Chunked(state) => {
732 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
733 let total: usize = bufs.iter().map(|b| b.len()).sum();
734 if total == 0 {
735 return Poll::Ready(Ok(0));
736 }
737 if state.terminator_written {
738 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
739 }
740 let _ = write!(state.pending, "{total:X}\r\n");
743 for b in bufs {
744 state.pending.extend_from_slice(b);
745 }
746 state.pending.extend_from_slice(b"\r\n");
747 best_effort_drain(&mut state.pending, cx, transport)?;
748 Poll::Ready(Ok(total))
749 }
750 WriteState::H3Framed(state) => {
751 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
752 let total: usize = bufs.iter().map(|b| b.len()).sum();
753 if total == 0 {
754 return Poll::Ready(Ok(0));
755 }
756 if state.terminator_written {
757 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
758 }
759 encode_h3_data_header(&mut state.pending, total as u64);
762 for b in bufs {
763 state.pending.extend_from_slice(b);
764 }
765 best_effort_drain(&mut state.pending, cx, transport)?;
766 Poll::Ready(Ok(total))
767 }
768 }
769 }
770
771 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
772 let Self {
773 transport,
774 write_state,
775 ..
776 } = &mut *self;
777 match write_state {
778 WriteState::Raw => Pin::new(transport).poll_flush(cx),
779 WriteState::H1Chunked(state) => {
780 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
781 Pin::new(transport).poll_flush(cx)
782 }
783 WriteState::H3Framed(state) => {
784 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
785 Pin::new(transport).poll_flush(cx)
786 }
787 }
788 }
789
790 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
791 let Self {
792 transport,
793 write_state,
794 ..
795 } = &mut *self;
796 match write_state {
797 WriteState::Raw => Pin::new(transport).poll_close(cx),
798 WriteState::H1Chunked(state) => {
799 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
800 if !state.terminator_written {
801 state.pending.extend_from_slice(b"0\r\n\r\n");
802 state.terminator_written = true;
804 }
805 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
806 Pin::new(transport).poll_close(cx)
807 }
808 WriteState::H3Framed(state) => {
809 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
811 state.terminator_written = true;
812 Pin::new(transport).poll_close(cx)
813 }
814 }
815 }
816}