1use crate::{
2 Buffer, Conn, Headers, HttpContext, KnownHeaderName, Method, ProtocolSession, ReceivedBody,
3 Status, TypeSet, Version,
4 h2::H2Connection,
5 h3::{Frame, H3Connection, H3Settings},
6 headers::qpack::{FieldSection, PseudoHeaders},
7 received_body::{H3TrailerFuture, ReceivedBodyState, write_chunk},
8 util::encoding,
9};
10use encoding_rs::Encoding;
11use fieldwork::Fieldwork;
12use futures_lite::{
13 AsyncWriteExt,
14 io::{AsyncRead, AsyncWrite},
15};
16use std::{
17 borrow::Cow,
18 fmt::{self, Debug, Formatter},
19 io::{self, IoSlice, Write},
20 net::IpAddr,
21 pin::Pin,
22 str,
23 sync::Arc,
24 task::{Context, Poll, ready},
25 time::Instant,
26};
27
/// Outbound framing applied by `Upgrade`'s `AsyncWrite` implementation.
///
/// Set when the upgrade is built (see `compute_write_state`); every write,
/// flush and close dispatches on it.
#[derive(Debug)]
pub(crate) enum WriteState {
    /// Bytes are passed to the transport unchanged.
    Raw,
    /// HTTP/1.x chunked transfer coding: each write becomes one chunk, and
    /// closing emits the zero-length terminating chunk.
    H1Chunked(H1ChunkedState),
    /// HTTP/3: each write becomes one DATA frame.
    H3Framed(H3FramedState),
}
40
/// Buffering state for HTTP/1.x chunked output.
#[derive(Debug, Default)]
pub(crate) struct H1ChunkedState {
    /// Already-framed bytes (chunk-size line, payload, CRLFs) that the
    /// transport has not accepted yet; drained before any new data is framed.
    pub(crate) pending: Vec<u8>,
    /// Set once the terminating zero-length chunk has been queued (by
    /// `poll_close` or `send_trailers`); later non-empty writes fail with
    /// `BrokenPipe`.
    pub(crate) terminator_written: bool,
}
46
/// Buffering state for HTTP/3 DATA-framed output.
#[derive(Debug, Default)]
pub(crate) struct H3FramedState {
    /// Encoded DATA frame headers plus payload (and, from `send_trailers`, the
    /// encoded trailer field section) not yet accepted by the transport.
    pub(crate) pending: Vec<u8>,
    /// Set once the stream has been finished (by `poll_close` or
    /// `send_trailers`); later non-empty writes fail with `BrokenPipe`.
    pub(crate) terminator_written: bool,
}
52
53fn compute_write_state(version: Version, outbound_headers: &Headers) -> WriteState {
57 match version {
58 Version::Http1_0 | Version::Http1_1 if has_chunked_encoding(outbound_headers) => {
59 WriteState::H1Chunked(H1ChunkedState::default())
60 }
61 Version::Http3 => WriteState::H3Framed(H3FramedState::default()),
62 _ => WriteState::Raw,
63 }
64}
65
66fn has_chunked_encoding(headers: &Headers) -> bool {
69 headers
70 .get_str(KnownHeaderName::TransferEncoding)
71 .is_some_and(|v| {
72 v.split(',')
73 .any(|coding| coding.trim().eq_ignore_ascii_case("chunked"))
74 })
75}
76
77fn parse_content_length(inbound_headers: &Headers) -> Option<u64> {
79 if inbound_headers.has_header(KnownHeaderName::TransferEncoding) {
80 return None;
81 }
82 inbound_headers.content_length()
83}
84
85fn poll_drain_pending<T: AsyncWrite + Unpin>(
87 pending: &mut Vec<u8>,
88 cx: &mut Context<'_>,
89 transport: &mut T,
90) -> Poll<io::Result<()>> {
91 while !pending.is_empty() {
92 match Pin::new(&mut *transport).poll_write(cx, pending) {
93 Poll::Ready(Ok(0)) => return Poll::Ready(Err(io::ErrorKind::WriteZero.into())),
94 Poll::Ready(Ok(n)) => {
95 pending.drain(..n);
96 }
97 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
98 Poll::Pending => return Poll::Pending,
99 }
100 }
101 Poll::Ready(Ok(()))
102}
103
104fn best_effort_drain<T: AsyncWrite + Unpin>(
107 pending: &mut Vec<u8>,
108 cx: &mut Context<'_>,
109 transport: &mut T,
110) -> io::Result<()> {
111 while !pending.is_empty() {
112 match Pin::new(&mut *transport).poll_write(cx, pending) {
113 Poll::Ready(Ok(0)) => return Err(io::ErrorKind::WriteZero.into()),
114 Poll::Ready(Ok(n)) => {
115 pending.drain(..n);
116 }
117 Poll::Ready(Err(e)) => return Err(e),
118 Poll::Pending => break,
119 }
120 }
121 Ok(())
122}
123
124fn encode_h3_data_header(out: &mut Vec<u8>, payload_len: u64) {
127 let frame = Frame::Data(payload_len);
128 let header_len = frame.encoded_len();
129 let start = out.len();
130 out.resize(start + header_len, 0);
131 frame.encode(&mut out[start..]);
132}
133
/// A connection handed off from ordinary request/response handling to direct
/// reading and writing of the underlying stream.
///
/// Reading through [`AsyncRead`] decodes the inbound body (via `ReceivedBody`).
/// Writing through [`AsyncWrite`] applies the framing recorded in `write_state`
/// (raw, HTTP/1.x chunked, or HTTP/3 DATA frames).
#[derive(Fieldwork)]
#[fieldwork(get, get_mut, set, with, take, into_field, rename_predicates)]
pub struct Upgrade<Transport> {
    /// Headers received from the peer (the request headers when built from a
    /// `Conn`). Renamed from `request_headers` in 1.3.0.
    #[field(deprecate(was = "request_headers", since = "1.3.0"))]
    pub(crate) received_headers: Headers,

    /// Headers sent to the peer (the response headers when built from a
    /// `Conn`). Renamed from `response_headers` in 1.3.0.
    #[field(deprecate(was = "response_headers", since = "1.3.0"))]
    pub(crate) sent_headers: Headers,

    /// Request target including any `?query`. The generated getter is disabled;
    /// the hand-written `path()` and `querystring()` methods split it instead.
    #[field(get = false)]
    pub(crate) path: Cow<'static, str>,

    /// Request method.
    #[field(copy)]
    pub(crate) method: Method,

    /// Typed per-connection state carried over from the `Conn`.
    pub(crate) state: TypeSet,

    /// The underlying transport. Reading or writing it directly bypasses the
    /// body decoding and outbound framing done by `Upgrade`'s own
    /// `AsyncRead`/`AsyncWrite` impls.
    pub(crate) transport: Transport,

    /// Read-ahead bytes handed to `ReceivedBody` together with the transport
    /// when reading (presumably consumed before new transport reads — confirm
    /// in `ReceivedBody`).
    #[field(deref = "[u8]", into_field = false, set = false, with = false)]
    pub(crate) buffer: Buffer,

    /// Shared server context; supplies `shared_state()` and the config used
    /// when decoding the inbound body.
    #[field(deref = false)]
    pub(crate) context: Arc<HttpContext>,

    /// Peer IP address, if known.
    #[field(copy)]
    pub(crate) peer_ip: Option<IpAddr>,

    /// Start time carried over from the `Conn`, or the construction time for
    /// `new`/`from_parts`.
    #[field(copy)]
    pub(crate) start_time: Instant,

    /// Request authority, if known.
    pub(crate) authority: Option<Cow<'static, str>>,

    /// Request scheme, if known.
    pub(crate) scheme: Option<Cow<'static, str>>,

    /// Per-protocol session (HTTP/1, or an HTTP/2 / HTTP/3 connection plus
    /// stream id). Exposed through `h2_connection`, `h2_stream_id`,
    /// `h3_connection` and `h3_stream_id` rather than generated accessors.
    #[field = false]
    pub(crate) protocol_session: ProtocolSession,

    /// Protocol value carried over from the `Conn`, if any.
    pub(crate) protocol: Option<Cow<'static, str>>,

    /// HTTP version of the connection; accessors are generated under the name
    /// `http_version`.
    #[field = "http_version"]
    pub(crate) version: Version,

    /// Response status, if one was set.
    #[field(copy)]
    pub(crate) status: Option<Status>,

    /// Whether the connection is flagged as secure (`new` defaults this to
    /// `false`).
    pub(crate) secure: bool,

    /// Progress of inbound body decoding, advanced by `poll_read`.
    #[field = false]
    pub(crate) received_body_state: ReceivedBodyState,

    /// Trailers received from the peer, once available. Only get/get_mut/take
    /// accessors are generated.
    #[field(get, get_mut, take, set = false, with = false, into_field = false)]
    pub(crate) received_trailers: Option<Headers>,

    /// Inbound `Content-Length`. Computed by `parse_content_length` in
    /// `from_parts` and `From<Conn>`, so it is `None` when absent or overridden
    /// by `Transfer-Encoding`.
    #[field = false]
    pub(crate) content_length_in: Option<u64>,

    /// Outbound framing applied by the `AsyncWrite` impl.
    #[field = false]
    pub(crate) write_state: WriteState,

    /// Character encoding handed to `ReceivedBody` for the inbound body.
    #[field = false]
    pub(crate) inbound_encoding: &'static Encoding,

    /// In-progress HTTP/3 trailer decode, kept here so it survives across
    /// `poll_read` calls.
    #[field = false]
    pub(crate) h3_trailer_decode_in: Option<H3TrailerFuture>,

    /// Scratch buffer handed to `ReceivedBody` for HTTP/3 trailer payload
    /// bytes, kept across `poll_read` calls.
    #[field = false]
    pub(crate) h3_trailer_payload_in: Vec<u8>,
}
246
247impl<Transport> Upgrade<Transport> {
248 #[doc(hidden)]
249 pub fn new(
250 received_headers: Headers,
251 path: impl Into<Cow<'static, str>>,
252 method: Method,
253 transport: Transport,
254 buffer: Buffer,
255 version: Version,
256 ) -> Self {
257 Self {
258 received_headers,
259 sent_headers: Headers::new(),
260 path: path.into(),
261 method,
262 transport,
263 buffer,
264 state: TypeSet::new(),
265 context: Arc::default(),
266 peer_ip: None,
267 start_time: Instant::now(),
268 authority: None,
269 scheme: None,
270 protocol_session: ProtocolSession::Http1,
271 protocol: None,
272 secure: false,
273 version,
274 status: None,
275 received_body_state: ReceivedBodyState::Raw { total: 0 },
276 received_trailers: None,
277 content_length_in: None,
278 write_state: WriteState::Raw,
279 inbound_encoding: encoding_rs::WINDOWS_1252,
280 h3_trailer_decode_in: None,
281 h3_trailer_payload_in: Vec::new(),
282 }
283 }
284
285 #[cfg(feature = "unstable")]
286 #[doc(hidden)]
287 #[allow(clippy::too_many_arguments)]
288 pub fn from_parts(
289 received_headers: Headers,
290 sent_headers: Headers,
291 path: Cow<'static, str>,
292 method: Method,
293 transport: Transport,
294 buffer: Buffer,
295 state: TypeSet,
296 context: Arc<HttpContext>,
297 peer_ip: Option<IpAddr>,
298 authority: Option<Cow<'static, str>>,
299 scheme: Option<Cow<'static, str>>,
300 protocol_session: ProtocolSession,
301 protocol: Option<Cow<'static, str>>,
302 version: Version,
303 status: Option<Status>,
304 secure: bool,
305 received_body_state: ReceivedBodyState,
306 received_trailers: Option<Headers>,
307 ) -> Self {
308 let write_state = compute_write_state(version, &sent_headers);
309 let content_length_in = parse_content_length(&received_headers);
310 let inbound_encoding = encoding(&received_headers);
311
312 Self {
313 received_headers,
314 sent_headers,
315 path,
316 method,
317 state,
318 transport,
319 buffer,
320 context,
321 peer_ip,
322 start_time: Instant::now(),
323 authority,
324 scheme,
325 protocol_session,
326 protocol,
327 version,
328 status,
329 secure,
330 received_body_state,
331 received_trailers,
332 content_length_in,
333 write_state,
334 inbound_encoding,
335 h3_trailer_decode_in: None,
336 h3_trailer_payload_in: Vec::new(),
337 }
338 }
339
340 pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
342 self.protocol_session.h2_connection()
343 }
344
345 pub fn h2_stream_id(&self) -> Option<u32> {
347 self.protocol_session.h2_stream_id()
348 }
349
350 pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
352 self.protocol_session.h3_connection()
353 }
354
355 pub fn h3_stream_id(&self) -> Option<u64> {
357 self.protocol_session.h3_stream_id()
358 }
359
360 pub fn take_buffer(&mut self) -> Vec<u8> {
362 std::mem::take(&mut self.buffer).into()
363 }
364
365 #[doc(hidden)]
366 pub fn buffer_and_transport_mut(&mut self) -> (&mut Buffer, &mut Transport) {
367 (&mut self.buffer, &mut self.transport)
368 }
369
370 pub fn shared_state(&self) -> &TypeSet {
372 self.context.shared_state()
373 }
374
375 pub fn path(&self) -> &str {
377 match self.path.split_once('?') {
378 Some((path, _)) => path,
379 None => &self.path,
380 }
381 }
382
383 pub fn querystring(&self) -> &str {
385 self.path
386 .split_once('?')
387 .map(|(_, query)| query)
388 .unwrap_or_default()
389 }
390
391 pub fn map_transport<T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static>(
395 self,
396 f: impl Fn(Transport) -> T,
397 ) -> Upgrade<T> {
398 Upgrade {
402 transport: f(self.transport),
403 path: self.path,
404 method: self.method,
405 state: self.state,
406 buffer: self.buffer,
407 received_headers: self.received_headers,
408 sent_headers: self.sent_headers,
409 context: self.context,
410 peer_ip: self.peer_ip,
411 start_time: self.start_time,
412 authority: self.authority,
413 scheme: self.scheme,
414 protocol_session: self.protocol_session,
415 protocol: self.protocol,
416 version: self.version,
417 status: self.status,
418 secure: self.secure,
419 received_body_state: self.received_body_state,
420 received_trailers: self.received_trailers,
421 content_length_in: self.content_length_in,
422 write_state: self.write_state,
423 inbound_encoding: self.inbound_encoding,
424 h3_trailer_decode_in: self.h3_trailer_decode_in,
425 h3_trailer_payload_in: self.h3_trailer_payload_in,
426 }
427 }
428}
429
430impl<Transport: AsyncWrite + Unpin> Upgrade<Transport> {
431 pub async fn send_trailers(self, trailers: Headers) -> io::Result<()> {
451 let Self {
452 mut transport,
453 mut write_state,
454 context,
455 protocol_session,
456 ..
457 } = self;
458
459 match &mut write_state {
460 WriteState::H1Chunked(state) => {
461 if state.terminator_written {
462 return Err(io::ErrorKind::BrokenPipe.into());
463 }
464 state.pending.extend_from_slice(b"0\r\n");
465 crate::conn::write_headers_or_trailers(&mut state.pending, &trailers, &context)
466 .map_err(io::Error::other)?;
467 state.pending.extend_from_slice(b"\r\n");
468 state.terminator_written = true;
469
470 transport.write_all(&state.pending).await?;
471 state.pending.clear();
472 transport.close().await
473 }
474 WriteState::H3Framed(state) => {
475 if state.terminator_written {
476 return Err(io::ErrorKind::BrokenPipe.into());
477 }
478 let Some((h3, stream_id)) = protocol_session.as_h3() else {
479 return Err(io::ErrorKind::NotConnected.into());
480 };
481 let max_field_section = h3
482 .peer_settings()
483 .and_then(H3Settings::max_field_section_size);
484 let field_section = FieldSection::new(PseudoHeaders::default(), &trailers);
485 crate::conn::encode_field_section_h3(
486 &h3,
487 &field_section,
488 max_field_section,
489 &mut state.pending,
490 stream_id,
491 )?;
492 state.terminator_written = true;
493
494 transport.write_all(&state.pending).await?;
495 state.pending.clear();
496 transport.close().await
497 }
498 WriteState::Raw => {
499 if let Some((h2, stream_id)) = protocol_session.as_h2() {
500 h2.submit_trailers(stream_id, trailers)
501 } else {
502 log::warn!(
503 "Upgrade::send_trailers called on a raw upgrade with no per-stream \
504 framing; trailers dropped. Set `Transfer-Encoding: chunked` on the \
505 outbound headers if you intend to emit trailers over HTTP/1.1."
506 );
507 Ok(())
508 }
509 }
510 }
511 }
512}
513
514impl<Transport> Debug for Upgrade<Transport> {
515 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
516 f.debug_struct(&format!("Upgrade<{}>", std::any::type_name::<Transport>()))
517 .field("received_headers", &self.received_headers)
518 .field("sent_headers", &self.sent_headers)
519 .field("path", &self.path)
520 .field("method", &self.method)
521 .field("buffer", &self.buffer)
522 .field("context", &self.context)
523 .field("state", &self.state)
524 .field("transport", &format_args!(".."))
525 .field("peer_ip", &self.peer_ip)
526 .field("start_time", &self.start_time)
527 .field("authority", &self.authority)
528 .field("scheme", &self.scheme)
529 .field("protocol_session", &self.protocol_session)
530 .field("protocol", &self.protocol)
531 .field("version", &self.version)
532 .field("status", &self.status)
533 .field("secure", &self.secure)
534 .field("received_body_state", &self.received_body_state)
535 .field("received_trailers", &self.received_trailers)
536 .field("content_length_in", &self.content_length_in)
537 .field("write_state", &self.write_state)
538 .field("inbound_encoding", &self.inbound_encoding.name())
539 .field(
540 "h3_trailer_decode_in",
541 &self
542 .h3_trailer_decode_in
543 .as_ref()
544 .map(|_| format_args!("..")),
545 )
546 .field(
547 "h3_trailer_payload_in_len",
548 &self.h3_trailer_payload_in.len(),
549 )
550 .finish()
551 }
552}
553
554impl<Transport> From<Conn<Transport>> for Upgrade<Transport> {
555 fn from(conn: Conn<Transport>) -> Self {
556 let Conn {
559 request_headers,
560 response_headers,
561 path,
562 method,
563 state,
564 transport,
565 buffer,
566 context,
567 peer_ip,
568 start_time,
569 authority,
570 scheme,
571 protocol_session,
572 protocol,
573 version,
574 status,
575 secure,
576 request_body_state,
577 request_trailers,
578 response_body,
579 after_send: _,
581 upgrade: _,
582 } = conn;
583
584 if let Some(body) = &response_body
585 && !body.is_empty()
586 {
587 log::warn!(
588 "Conn::upgrade() and a non-empty response body are both set; body is being \
589 discarded. The upgrade path is mutually exclusive with serving a response body."
590 );
591 }
592
593 let write_state = compute_write_state(version, &response_headers);
595 let content_length_in = parse_content_length(&request_headers);
596 let inbound_encoding = encoding(&request_headers);
597 let received_body_state = request_body_state;
598 let received_trailers = request_trailers.filter(|t| !t.is_empty());
599
600 Self {
601 received_headers: request_headers,
602 sent_headers: response_headers,
603 path,
604 method,
605 state,
606 transport,
607 buffer,
608 context,
609 peer_ip,
610 start_time,
611 authority,
612 scheme,
613 protocol_session,
614 protocol,
615 version,
616 status,
617 secure,
618 received_body_state,
619 received_trailers,
620 content_length_in,
621 write_state,
622 inbound_encoding,
623 h3_trailer_decode_in: None,
624 h3_trailer_payload_in: Vec::new(),
625 }
626 }
627}
628
629#[cfg(test)]
630mod tests;
631
632impl<Transport> AsyncRead for Upgrade<Transport>
633where
634 Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
635{
636 fn poll_read(
637 mut self: Pin<&mut Self>,
638 cx: &mut Context<'_>,
639 buf: &mut [u8],
640 ) -> Poll<io::Result<usize>> {
641 let Self {
642 transport,
643 buffer,
644 received_body_state,
645 content_length_in,
646 context,
647 protocol_session,
648 received_trailers,
649 h3_trailer_decode_in,
650 h3_trailer_payload_in,
651 inbound_encoding,
652 ..
653 } = &mut *self;
654
655 let protocol_session = protocol_session.clone();
656 let mut body: ReceivedBody<'_, Transport> = ReceivedBody::new_with_config(
657 *content_length_in,
658 buffer,
659 transport,
660 received_body_state,
661 None,
662 inbound_encoding,
663 &context.config,
664 )
665 .with_trailers(received_trailers)
666 .with_protocol_session(protocol_session)
667 .with_h3_trailer_future(h3_trailer_decode_in)
668 .with_h3_trailer_payload_buffer(h3_trailer_payload_in);
669
670 Pin::new(&mut body).poll_read(cx, buf)
671 }
672}
673
674impl<Transport: AsyncWrite + Unpin> AsyncWrite for Upgrade<Transport> {
675 fn poll_write(
676 mut self: Pin<&mut Self>,
677 cx: &mut Context<'_>,
678 buf: &[u8],
679 ) -> Poll<io::Result<usize>> {
680 let Self {
681 transport,
682 write_state,
683 ..
684 } = &mut *self;
685 match write_state {
686 WriteState::Raw => Pin::new(transport).poll_write(cx, buf),
687 WriteState::H1Chunked(state) => {
688 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
689
690 if buf.is_empty() {
692 return Poll::Ready(Ok(0));
693 }
694
695 if state.terminator_written {
696 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
697 }
698
699 write_chunk(&mut state.pending, buf);
700 best_effort_drain(&mut state.pending, cx, transport)?;
701 Poll::Ready(Ok(buf.len()))
702 }
703 WriteState::H3Framed(state) => {
704 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
705
706 if buf.is_empty() {
707 return Poll::Ready(Ok(0));
708 }
709
710 if state.terminator_written {
711 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
712 }
713
714 encode_h3_data_header(&mut state.pending, buf.len() as u64);
715 state.pending.extend_from_slice(buf);
716 best_effort_drain(&mut state.pending, cx, transport)?;
717 Poll::Ready(Ok(buf.len()))
718 }
719 }
720 }
721
722 fn poll_write_vectored(
723 mut self: Pin<&mut Self>,
724 cx: &mut Context<'_>,
725 bufs: &[IoSlice<'_>],
726 ) -> Poll<io::Result<usize>> {
727 let Self {
728 transport,
729 write_state,
730 ..
731 } = &mut *self;
732 match write_state {
733 WriteState::Raw => Pin::new(transport).poll_write_vectored(cx, bufs),
734 WriteState::H1Chunked(state) => {
735 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
736 let total: usize = bufs.iter().map(|b| b.len()).sum();
737 if total == 0 {
738 return Poll::Ready(Ok(0));
739 }
740 if state.terminator_written {
741 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
742 }
743 let _ = write!(state.pending, "{total:X}\r\n");
746 for b in bufs {
747 state.pending.extend_from_slice(b);
748 }
749 state.pending.extend_from_slice(b"\r\n");
750 best_effort_drain(&mut state.pending, cx, transport)?;
751 Poll::Ready(Ok(total))
752 }
753 WriteState::H3Framed(state) => {
754 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
755 let total: usize = bufs.iter().map(|b| b.len()).sum();
756 if total == 0 {
757 return Poll::Ready(Ok(0));
758 }
759 if state.terminator_written {
760 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
761 }
762 encode_h3_data_header(&mut state.pending, total as u64);
765 for b in bufs {
766 state.pending.extend_from_slice(b);
767 }
768 best_effort_drain(&mut state.pending, cx, transport)?;
769 Poll::Ready(Ok(total))
770 }
771 }
772 }
773
774 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
775 let Self {
776 transport,
777 write_state,
778 ..
779 } = &mut *self;
780 match write_state {
781 WriteState::Raw => Pin::new(transport).poll_flush(cx),
782 WriteState::H1Chunked(state) => {
783 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
784 Pin::new(transport).poll_flush(cx)
785 }
786 WriteState::H3Framed(state) => {
787 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
788 Pin::new(transport).poll_flush(cx)
789 }
790 }
791 }
792
793 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
794 let Self {
795 transport,
796 write_state,
797 ..
798 } = &mut *self;
799 match write_state {
800 WriteState::Raw => Pin::new(transport).poll_close(cx),
801 WriteState::H1Chunked(state) => {
802 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
803 if !state.terminator_written {
804 state.pending.extend_from_slice(b"0\r\n\r\n");
805 state.terminator_written = true;
807 }
808 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
809 Pin::new(transport).poll_close(cx)
810 }
811 WriteState::H3Framed(state) => {
812 ready!(poll_drain_pending(&mut state.pending, cx, transport))?;
814 state.terminator_written = true;
815 Pin::new(transport).poll_close(cx)
816 }
817 }
818 }
819}