Skip to main content

actix_http/h1/
dispatcher.rs

1use std::{
2    collections::VecDeque,
3    fmt,
4    future::Future,
5    io, mem, net,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11use actix_codec::{Framed, FramedParts};
12use actix_rt::time::sleep_until;
13use actix_service::Service;
14use bitflags::bitflags;
15use bytes::{Buf, BytesMut};
16use futures_core::ready;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::{Decoder as _, Encoder as _};
20use tracing::{error, trace};
21
22use super::{
23    codec::Codec,
24    decoder::MAX_BUFFER_SIZE,
25    payload::{Payload, PayloadSender, PayloadStatus},
26    timer::TimerState,
27    Message, MessageType,
28};
29use crate::{
30    body::{BodySize, BoxBody, MessageBody},
31    config::ServiceConfig,
32    error::{DispatchError, ParseError, PayloadError},
33    service::HttpFlow,
34    ConnectionType, Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode,
35};
36
/// Low-water mark for buffer sizing, in bytes (1 KiB).
///
/// NOTE(review): not referenced in the visible part of this file; presumably used when
/// reserving read-buffer space — confirm against `read_available`.
const LW_BUFFER_SIZE: usize = 1 << 10;

/// High-water mark / initial capacity of the read and write buffers, in bytes (8 KiB).
const HW_BUFFER_SIZE: usize = 8 << 10;

/// Maximum number of decoded-but-unprocessed requests; once reached, no more requests are
/// decoded until the queue shrinks.
const MAX_PIPELINED_MESSAGES: usize = 16;
40
41bitflags! {
42    #[derive(Debug, Clone, Copy)]
43    pub struct Flags: u8 {
44        /// Set when stream is read for first time.
45        const STARTED          = 0b0000_0001;
46
47        /// Set when full request-response cycle has occurred.
48        const FINISHED         = 0b0000_0010;
49
50        /// Set if connection is in keep-alive (inactive) state.
51        const KEEP_ALIVE       = 0b0000_0100;
52
53        /// Set if in shutdown procedure.
54        const SHUTDOWN         = 0b0000_1000;
55
56        /// Set if read-half is disconnected.
57        const READ_DISCONNECT  = 0b0001_0000;
58
59        /// Set if write-half is disconnected.
60        const WRITE_DISCONNECT = 0b0010_0000;
61
62        /// Set while gracefully closing a connection after an early response.
63        const LINGER           = 0b0100_0000;
64    }
65}
66
// There are two versions of the `Dispatcher` struct (the test build adds a `poll_count` field)
// because of: https://github.com/taiki-e/pin-project-lite/issues/3
//
// tl;dr: pin-project-lite doesn't play well with other attribute macros
71
#[cfg(not(test))]
pin_project! {
    /// Dispatcher for HTTP/1.1 protocol.
    ///
    /// Wraps the connection state machine: either regular request/response handling or, after
    /// an upgrade request, the future returned by the upgrade service.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current dispatcher state; pinned because it holds service/upgrade futures
        #[pin]
        inner: DispatcherState<T, S, B, X, U>,
    }
}
92
#[cfg(test)]
pin_project! {
    /// Dispatcher for HTTP/1.1 protocol.
    ///
    /// Test-build variant: identical to the regular dispatcher apart from the extra
    /// `poll_count` field and `pub(super)` visibility of `inner`, both for use by tests.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current dispatcher state; pinned because it holds service/upgrade futures
        #[pin]
        pub(super) inner: DispatcherState<T, S, B, X, U>,

        // used in tests
        // NOTE(review): presumably counts how often the dispatcher is polled; the code that
        // updates it is outside this view — confirm.
        pub(super) poll_count: u64,
    }
}
116
// Top-level state of an HTTP/1 connection: regular request/response handling, or the
// upgrade service's future once the connection has been handed off.
pin_project! {
    #[project = DispatcherStateProj]
    pub(super) enum DispatcherState<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // regular HTTP/1 processing driven by `InnerDispatcher`
        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
        // connection upgraded; only the upgrade service's future is polled
        Upgrade { #[pin] fut: U::Future },
    }
}
136
// Per-connection dispatcher used for regular (non-upgraded) HTTP/1 traffic: handlers,
// connection flags, in-flight request/response state, timers, buffers and codec.
pin_project! {
    #[project = InnerDispatcherProj]
    pub(super) struct InnerDispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // expect handler, main service and optional upgrade handler
        flow: Rc<HttpFlow<S, X, U>>,
        // connection state bits (see `Flags`)
        pub(super) flags: Flags,
        // copied into the head of every decoded request
        peer_addr: Option<net::SocketAddr>,
        // connection-level data (from `OnConnectData`), shared with every decoded request
        conn_data: Option<Rc<Extensions>>,
        // service configuration: deadlines, keep-alive, write buffer size
        config: ServiceConfig,
        // NOTE(review): presumably an error held for later reporting; its use is outside this
        // view — confirm.
        error: Option<DispatchError>,

        // work currently in progress: expect/service call or response body streaming
        #[pin]
        pub(super) state: State<S, B, X>,
        // when Some(_) dispatcher is in state of receiving request payload
        payload: Option<PayloadSender>,
        // true when current request uses chunked transfer encoding (drainable when payload is dropped)
        payload_drainable: bool,
        // decoded requests / error responses waiting to be processed (pipelining queue)
        messages: VecDeque<DispatcherMessage>,

        // deadline for the first request head; cleared once a request is decoded
        head_timer: TimerState,
        // keep-alive timer; enabled when keep-alive is configured
        ka_timer: TimerState,
        // client disconnect (linger/shutdown) timer
        shutdown_timer: TimerState,

        // the connection; NOTE(review): presumably taken out on upgrade — confirm
        pub(super) io: Option<T>,
        read_buf: BytesMut,
        write_buf: BytesMut,
        // fill limit for `write_buf` while streaming a response body
        h1_write_buffer_size: usize,
        codec: Codec,
    }
}
178
179enum DispatcherMessage {
180    Item(Request),
181    Upgrade(Request),
182    Error(Response<()>),
183}
184
// What the dispatcher is currently doing for the request at the front of the queue.
pin_project! {
    #[project = StateProj]
    pub(super) enum State<S, B, X>
    where
        S: Service<Request>,
        X: Service<Request, Response = Request>,
        B: MessageBody,
    {
        // idle: no handler call and no response body in progress
        None,
        // waiting on the expect handler (request carried `Expect: 100-continue`)
        ExpectCall { #[pin] fut: X::Future },
        // waiting on the main service for a response
        ServiceCall { #[pin] fut: S::Future },
        // streaming the body of a service response into the write buffer
        SendPayload { #[pin] body: B },
        // streaming the body of an error response into the write buffer
        SendErrorPayload { #[pin] body: BoxBody },
    }
}
200
201impl<S, B, X> State<S, B, X>
202where
203    S: Service<Request>,
204    X: Service<Request, Response = Request>,
205    B: MessageBody,
206{
207    pub(super) fn is_none(&self) -> bool {
208        matches!(self, State::None)
209    }
210}
211
212impl<S, B, X> fmt::Debug for State<S, B, X>
213where
214    S: Service<Request>,
215    X: Service<Request, Response = Request>,
216    B: MessageBody,
217{
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        match self {
220            Self::None => write!(f, "State::None"),
221            Self::ExpectCall { .. } => f.debug_struct("State::ExpectCall").finish_non_exhaustive(),
222            Self::ServiceCall { .. } => {
223                f.debug_struct("State::ServiceCall").finish_non_exhaustive()
224            }
225            Self::SendPayload { .. } => {
226                f.debug_struct("State::SendPayload").finish_non_exhaustive()
227            }
228            Self::SendErrorPayload { .. } => f
229                .debug_struct("State::SendErrorPayload")
230                .finish_non_exhaustive(),
231        }
232    }
233}
234
235#[derive(Debug)]
236enum PollResponse {
237    Upgrade(Request),
238    DoNothing,
239    DrainWriteBuf,
240}
241
242impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
243where
244    T: AsyncRead + AsyncWrite + Unpin,
245
246    S: Service<Request>,
247    S::Error: Into<Response<BoxBody>>,
248    S::Response: Into<Response<B>>,
249
250    B: MessageBody,
251
252    X: Service<Request, Response = Request>,
253    X::Error: Into<Response<BoxBody>>,
254
255    U: Service<(Request, Framed<T, Codec>), Response = ()>,
256    U::Error: fmt::Display,
257{
258    /// Create HTTP/1 dispatcher.
259    pub(crate) fn new(
260        io: T,
261        flow: Rc<HttpFlow<S, X, U>>,
262        config: ServiceConfig,
263        peer_addr: Option<net::SocketAddr>,
264        conn_data: OnConnectData,
265    ) -> Self {
266        Dispatcher {
267            inner: DispatcherState::Normal {
268                inner: InnerDispatcher {
269                    flow,
270                    flags: Flags::empty(),
271                    peer_addr,
272                    conn_data: conn_data.0.map(Rc::new),
273                    config: config.clone(),
274                    error: None,
275
276                    state: State::None,
277                    payload: None,
278                    payload_drainable: false,
279                    messages: VecDeque::new(),
280
281                    head_timer: TimerState::new(config.client_request_deadline().is_some()),
282                    ka_timer: TimerState::new(config.keep_alive().enabled()),
283                    shutdown_timer: TimerState::new(config.client_disconnect_deadline().is_some()),
284
285                    io: Some(io),
286                    read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
287                    write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
288                    h1_write_buffer_size: config.h1_write_buffer_size(),
289                    codec: Codec::new(config),
290                },
291            },
292
293            #[cfg(test)]
294            poll_count: 0,
295        }
296    }
297}
298
299impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
300where
301    T: AsyncRead + AsyncWrite + Unpin,
302
303    S: Service<Request>,
304    S::Error: Into<Response<BoxBody>>,
305    S::Response: Into<Response<B>>,
306
307    B: MessageBody,
308
309    X: Service<Request, Response = Request>,
310    X::Error: Into<Response<BoxBody>>,
311
312    U: Service<(Request, Framed<T, Codec>), Response = ()>,
313    U::Error: fmt::Display,
314{
315    fn can_read(&self, cx: &mut Context<'_>) -> bool {
316        if self.flags.contains(Flags::READ_DISCONNECT) {
317            false
318        } else if let Some(ref info) = self.payload {
319            matches!(
320                info.need_read(cx),
321                PayloadStatus::Read | PayloadStatus::Dropped
322            )
323        } else {
324            true
325        }
326    }
327
328    fn client_disconnected(self: Pin<&mut Self>) {
329        let this = self.project();
330
331        this.flags
332            .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
333
334        if let Some(mut payload) = this.payload.take() {
335            payload.set_error(PayloadError::Incomplete(None));
336        }
337    }
338
339    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
340        let InnerDispatcherProj { io, write_buf, .. } = self.project();
341        let mut io = Pin::new(io.as_mut().unwrap());
342
343        let len = write_buf.len();
344        let mut written = 0;
345
346        while written < len {
347            match io.as_mut().poll_write(cx, &write_buf[written..])? {
348                Poll::Ready(0) => {
349                    error!("write zero; closing");
350                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
351                }
352
353                Poll::Ready(n) => written += n,
354
355                Poll::Pending => {
356                    write_buf.advance(written);
357                    return Poll::Pending;
358                }
359            }
360        }
361
362        // everything has written to I/O; clear buffer
363        write_buf.clear();
364
365        // flush the I/O and check if get blocked
366        io.poll_flush(cx)
367    }
368
369    fn enter_linger(flags: &mut Flags) {
370        flags.remove(Flags::KEEP_ALIVE);
371        flags.insert(Flags::LINGER | Flags::FINISHED);
372    }
373
374    fn ensure_linger_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
375        let this = self.as_mut().project();
376
377        if matches!(this.shutdown_timer, TimerState::Active { .. }) {
378            return true;
379        }
380
381        if let Some(deadline) = this.config.client_disconnect_deadline() {
382            this.shutdown_timer
383                .set_and_init(cx, sleep_until(deadline.into()), line!());
384            true
385        } else {
386            false
387        }
388    }
389
390    fn poll_linger(
391        mut self: Pin<&mut Self>,
392        cx: &mut Context<'_>,
393    ) -> Result<Poll<()>, DispatchError> {
394        if self.as_mut().poll_flush(cx)?.is_pending() {
395            return Ok(Poll::Pending);
396        }
397
398        if !self.as_mut().ensure_linger_timer(cx) {
399            let this = self.as_mut().project();
400            this.flags.remove(Flags::LINGER);
401            this.flags.insert(Flags::SHUTDOWN);
402            return Ok(Poll::Ready(()));
403        }
404
405        loop {
406            let should_disconnect = self.as_mut().read_available(cx)?;
407            let this = self.as_mut().project();
408            let mut progressed = false;
409
410            if !this.read_buf.is_empty() {
411                this.read_buf.clear();
412                progressed = true;
413            }
414
415            if should_disconnect {
416                this.flags.remove(Flags::LINGER);
417                this.flags.insert(Flags::READ_DISCONNECT | Flags::SHUTDOWN);
418                return Ok(Poll::Ready(()));
419            }
420
421            if !progressed {
422                return Ok(Poll::Pending);
423            }
424        }
425    }
426
427    fn send_response_inner(
428        self: Pin<&mut Self>,
429        res: Response<()>,
430        body: &impl MessageBody,
431    ) -> Result<BodySize, DispatchError> {
432        let this = self.project();
433
434        let size = body.size();
435
436        this.codec
437            .encode(Message::Item((res, size)), this.write_buf)
438            .map_err(|err| {
439                if let Some(mut payload) = this.payload.take() {
440                    payload.set_error(PayloadError::Incomplete(None));
441                }
442
443                DispatchError::Io(err)
444            })?;
445
446        Ok(size)
447    }
448
449    fn send_response(
450        mut self: Pin<&mut Self>,
451        mut res: Response<()>,
452        body: B,
453    ) -> Result<(), DispatchError> {
454        let close_after_response = {
455            let this = self.as_mut().project();
456            should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
457        };
458
459        if close_after_response {
460            res.head_mut().set_connection_type(ConnectionType::Close);
461        }
462
463        let size = self.as_mut().send_response_inner(res, &body)?;
464        match size {
465            BodySize::None | BodySize::Sized(0) => {
466                let mut this = self.as_mut().project();
467
468                if close_after_response {
469                    if this.config.client_disconnect_deadline().is_some() {
470                        Self::enter_linger(this.flags);
471                    } else {
472                        this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
473                    }
474                } else {
475                    this.flags.insert(Flags::FINISHED);
476                }
477
478                this.state.set(State::None);
479            }
480            _ => self
481                .as_mut()
482                .project()
483                .state
484                .set(State::SendPayload { body }),
485        }
486
487        Ok(())
488    }
489
490    fn send_error_response(
491        mut self: Pin<&mut Self>,
492        mut res: Response<()>,
493        body: BoxBody,
494    ) -> Result<(), DispatchError> {
495        let close_after_response = {
496            let this = self.as_mut().project();
497            should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
498        };
499
500        if close_after_response {
501            res.head_mut().set_connection_type(ConnectionType::Close);
502        }
503
504        let size = self.as_mut().send_response_inner(res, &body)?;
505        match size {
506            BodySize::None | BodySize::Sized(0) => {
507                let mut this = self.as_mut().project();
508
509                if close_after_response {
510                    if this.config.client_disconnect_deadline().is_some() {
511                        Self::enter_linger(this.flags);
512                    } else {
513                        this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
514                    }
515                } else {
516                    this.flags.insert(Flags::FINISHED);
517                }
518
519                this.state.set(State::None);
520            }
521            _ => self
522                .as_mut()
523                .project()
524                .state
525                .set(State::SendErrorPayload { body }),
526        }
527
528        Ok(())
529    }
530
531    fn send_continue(self: Pin<&mut Self>) {
532        self.project()
533            .write_buf
534            .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
535    }
536
537    fn poll_response(
538        mut self: Pin<&mut Self>,
539        cx: &mut Context<'_>,
540    ) -> Result<PollResponse, DispatchError> {
541        'res: loop {
542            let mut this = self.as_mut().project();
543            match this.state.as_mut().project() {
544                // no future is in InnerDispatcher state; pop next message
545                StateProj::None => match this.messages.pop_front() {
546                    // handle request message
547                    Some(DispatcherMessage::Item(req)) => {
548                        // Handle `EXPECT: 100-Continue` header
549                        if req.head().expect() {
550                            // set InnerDispatcher state and continue loop to poll it
551                            let fut = this.flow.expect.call(req);
552                            this.state.set(State::ExpectCall { fut });
553                        } else {
554                            // set InnerDispatcher state and continue loop to poll it
555                            let fut = this.flow.service.call(req);
556                            this.state.set(State::ServiceCall { fut });
557                        };
558                    }
559
560                    // handle error message
561                    Some(DispatcherMessage::Error(res)) => {
562                        // send_response would update InnerDispatcher state to SendPayload or None
563                        // (If response body is empty)
564                        // continue loop to poll it
565                        self.as_mut().send_error_response(res, BoxBody::new(()))?;
566                    }
567
568                    // return with upgrade request and poll it exclusively
569                    Some(DispatcherMessage::Upgrade(req)) => return Ok(PollResponse::Upgrade(req)),
570
571                    // all messages are dealt with
572                    None => {
573                        // start keep-alive only if request payload is fully read/drained
574                        this.flags.set(
575                            Flags::KEEP_ALIVE,
576                            this.payload.is_none() && this.codec.keep_alive(),
577                        );
578
579                        return Ok(PollResponse::DoNothing);
580                    }
581                },
582
583                StateProj::ServiceCall { fut } => {
584                    match fut.poll(cx) {
585                        // service call resolved. send response.
586                        Poll::Ready(Ok(res)) => {
587                            let (res, body) = res.into().replace_body(());
588                            self.as_mut().send_response(res, body)?;
589                        }
590
591                        // send service call error as response
592                        Poll::Ready(Err(err)) => {
593                            let res: Response<BoxBody> = err.into();
594                            let (res, body) = res.replace_body(());
595                            self.as_mut().send_error_response(res, body)?;
596                        }
597
598                        // service call pending and could be waiting for more chunk messages
599                        // (pipeline message limit and/or payload can_read limit)
600                        Poll::Pending => {
601                            // no new message is decoded and no new payload is fed
602                            // nothing to do except waiting for new incoming data from client
603                            if !self.as_mut().poll_request(cx)? {
604                                return Ok(PollResponse::DoNothing);
605                            }
606                            // else loop
607                        }
608                    }
609                }
610
611                StateProj::SendPayload { mut body } => {
612                    // keep populate writer buffer until buffer size limit hit,
613                    // get blocked or finished.
614                    while this.write_buf.len() < *this.h1_write_buffer_size {
615                        match body.as_mut().poll_next(cx) {
616                            Poll::Ready(Some(Ok(item))) => {
617                                this.codec
618                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
619                            }
620
621                            Poll::Ready(None) => {
622                                this.codec.encode(Message::Chunk(None), this.write_buf)?;
623
624                                // if we have not yet pipelined to the next request, then
625                                // this.payload was the payload for the request we just finished
626                                // responding to. We can check to see if we finished reading it
627                                // yet, and if not, shutdown the connection.
628                                let close_after_response = should_close_after_response(
629                                    this.payload.as_ref(),
630                                    *this.payload_drainable,
631                                );
632                                let not_pipelined = this.messages.is_empty();
633
634                                // payload stream finished.
635                                // set state to None and handle next message
636                                this.state.set(State::None);
637
638                                if not_pipelined && close_after_response {
639                                    if this.config.client_disconnect_deadline().is_some() {
640                                        Self::enter_linger(this.flags);
641                                    } else {
642                                        this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
643                                    }
644                                } else {
645                                    this.flags.insert(Flags::FINISHED);
646                                }
647
648                                continue 'res;
649                            }
650
651                            Poll::Ready(Some(Err(err))) => {
652                                let err = err.into();
653                                tracing::error!("Response payload stream error: {err:?}");
654                                this.flags.insert(Flags::FINISHED);
655                                return Err(DispatchError::Body(err));
656                            }
657
658                            Poll::Pending => return Ok(PollResponse::DoNothing),
659                        }
660                    }
661
662                    // buffer is beyond max size
663                    // return and try to write the whole buffer to I/O stream.
664                    return Ok(PollResponse::DrainWriteBuf);
665                }
666
667                StateProj::SendErrorPayload { mut body } => {
668                    // TODO: de-dupe impl with SendPayload
669
670                    // keep populate writer buffer until buffer size limit hit,
671                    // get blocked or finished.
672                    while this.write_buf.len() < *this.h1_write_buffer_size {
673                        match body.as_mut().poll_next(cx) {
674                            Poll::Ready(Some(Ok(item))) => {
675                                this.codec
676                                    .encode(Message::Chunk(Some(item)), this.write_buf)?;
677                            }
678
679                            Poll::Ready(None) => {
680                                this.codec.encode(Message::Chunk(None), this.write_buf)?;
681
682                                // if we have not yet pipelined to the next request, then
683                                // this.payload was the payload for the request we just finished
684                                // responding to. We can check to see if we finished reading it
685                                // yet, and if not, shutdown the connection.
686                                let close_after_response = should_close_after_response(
687                                    this.payload.as_ref(),
688                                    *this.payload_drainable,
689                                );
690                                let not_pipelined = this.messages.is_empty();
691
692                                // payload stream finished.
693                                // set state to None and handle next message
694                                this.state.set(State::None);
695
696                                if not_pipelined && close_after_response {
697                                    if this.config.client_disconnect_deadline().is_some() {
698                                        Self::enter_linger(this.flags);
699                                    } else {
700                                        this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
701                                    }
702                                } else {
703                                    this.flags.insert(Flags::FINISHED);
704                                }
705
706                                continue 'res;
707                            }
708
709                            Poll::Ready(Some(Err(err))) => {
710                                tracing::error!("Response payload stream error: {err:?}");
711                                this.flags.insert(Flags::FINISHED);
712                                return Err(DispatchError::Body(
713                                    Error::new_body().with_cause(err).into(),
714                                ));
715                            }
716
717                            Poll::Pending => return Ok(PollResponse::DoNothing),
718                        }
719                    }
720
721                    // buffer is beyond max size
722                    // return and try to write the whole buffer to stream
723                    return Ok(PollResponse::DrainWriteBuf);
724                }
725
726                StateProj::ExpectCall { fut } => {
727                    trace!("  calling expect service");
728
729                    match fut.poll(cx) {
730                        // expect resolved. write continue to buffer and set InnerDispatcher state
731                        // to service call.
732                        Poll::Ready(Ok(req)) => {
733                            this.write_buf
734                                .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
735                            let fut = this.flow.service.call(req);
736                            this.state.set(State::ServiceCall { fut });
737                        }
738
739                        // send expect error as response
740                        Poll::Ready(Err(err)) => {
741                            let res: Response<BoxBody> = err.into();
742                            let (res, body) = res.replace_body(());
743                            self.as_mut().send_error_response(res, body)?;
744                        }
745
746                        // expect must be solved before progress can be made.
747                        Poll::Pending => return Ok(PollResponse::DoNothing),
748                    }
749                }
750            }
751        }
752    }
753
754    fn handle_request(
755        mut self: Pin<&mut Self>,
756        req: Request,
757        cx: &mut Context<'_>,
758    ) -> Result<(), DispatchError> {
759        // initialize dispatcher state
760        {
761            let mut this = self.as_mut().project();
762
763            // Handle `EXPECT: 100-Continue` header
764            if req.head().expect() {
765                // set dispatcher state to call expect handler
766                let fut = this.flow.expect.call(req);
767                this.state.set(State::ExpectCall { fut });
768            } else {
769                // set dispatcher state to call service handler
770                let fut = this.flow.service.call(req);
771                this.state.set(State::ServiceCall { fut });
772            };
773        };
774
775        // eagerly poll the future once (or twice if expect is resolved immediately).
776        loop {
777            match self.as_mut().project().state.project() {
778                StateProj::ExpectCall { fut } => {
779                    match fut.poll(cx) {
780                        // expect is resolved; continue loop and poll the service call branch.
781                        Poll::Ready(Ok(req)) => {
782                            self.as_mut().send_continue();
783
784                            let mut this = self.as_mut().project();
785                            let fut = this.flow.service.call(req);
786                            this.state.set(State::ServiceCall { fut });
787
788                            continue;
789                        }
790
791                        // future is error; send response and return a result
792                        // on success to notify the dispatcher a new state is set and the outer loop
793                        // should be continued
794                        Poll::Ready(Err(err)) => {
795                            let res: Response<BoxBody> = err.into();
796                            let (res, body) = res.replace_body(());
797                            return self.send_error_response(res, body);
798                        }
799
800                        // future is pending; return Ok(()) to notify that a new state is
801                        // set and the outer loop should be continue.
802                        Poll::Pending => return Ok(()),
803                    }
804                }
805
806                StateProj::ServiceCall { fut } => {
807                    // return no matter the service call future's result.
808                    return match fut.poll(cx) {
809                        // Future is resolved. Send response and return a result. On success
810                        // to notify the dispatcher a new state is set and the outer loop
811                        // should be continue.
812                        Poll::Ready(Ok(res)) => {
813                            let (res, body) = res.into().replace_body(());
814                            self.as_mut().send_response(res, body)
815                        }
816
817                        // see the comment on ExpectCall state branch's Pending
818                        Poll::Pending => Ok(()),
819
820                        // see the comment on ExpectCall state branch's Ready(Err(_))
821                        Poll::Ready(Err(err)) => {
822                            let res: Response<BoxBody> = err.into();
823                            let (res, body) = res.replace_body(());
824                            self.as_mut().send_error_response(res, body)
825                        }
826                    };
827                }
828
829                _ => {
830                    unreachable!("State must be set to ServiceCall or ExceptCall in handle_request")
831                }
832            }
833        }
834    }
835
    /// Process one incoming request.
    ///
    /// Decodes as many messages as possible from the read buffer and routes each one.
    ///
    /// Nothing is decoded, and `Ok(false)` is returned, when `messages` already holds
    /// `MAX_PIPELINED_MESSAGES` entries or when `can_read` reports that reading should not happen
    /// right now.
    ///
    /// - A request head is passed straight to `handle_request` when no service call is in flight
    ///   (`state` is none); otherwise it is queued in `messages`. An upgradable streaming request
    ///   is queued as `DispatcherMessage::Upgrade` and decoding stops, leaving the rest of the
    ///   read buffer for the upgrade handler.
    /// - Payload chunks and end-of-payload markers are forwarded to the active `PayloadSender`.
    ///   Receiving either with no payload attached queues a 500 response.
    /// - Decode failures queue a 431 (head too large) or 400 (malformed) response, mark the read
    ///   half disconnected, and store the error in `error`. Codec I/O errors instead call
    ///   `client_disconnected` and store the error without queuing a response.
    ///
    /// Returns `Ok(true)` if at least one message was decoded.
    ///
    /// # Errors
    /// Propagates errors returned by `handle_request`.
    fn poll_request(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
        let pipeline_queue_full = self.messages.len() >= MAX_PIPELINED_MESSAGES;
        let can_not_read = !self.can_read(cx);

        // limit amount of non-processed requests
        if pipeline_queue_full || can_not_read {
            return Ok(false);
        }

        let mut this = self.as_mut().project();

        let mut updated = false;

        // decode from read buf as many full requests as possible
        loop {
            match this.codec.decode(this.read_buf) {
                Ok(Some(msg)) => {
                    updated = true;

                    match msg {
                        Message::Item(mut req) => {
                            // head timer only applies to first request on connection
                            this.head_timer.clear(line!());

                            req.head_mut().peer_addr = *this.peer_addr;

                            req.conn_data.clone_from(this.conn_data);

                            // `payload_drainable` is consulted by `should_close_after_response`.
                            // It is only true while a chunked request body is attached.
                            match this.codec.message_type() {
                                // request has no payload
                                MessageType::None => *this.payload_drainable = false,

                                // Request is upgradable. Add upgrade message and break.
                                // Everything remaining in read buffer will be handed to
                                // upgraded Request.
                                MessageType::Stream if this.flow.upgrade.is_some() => {
                                    *this.payload_drainable = false;
                                    this.messages.push_back(DispatcherMessage::Upgrade(req));
                                    break;
                                }

                                // request is not upgradable
                                MessageType::Payload | MessageType::Stream => {
                                    // PayloadSender and Payload are smart pointers that share the
                                    // same state. PayloadSender is attached to the dispatcher and
                                    // used to sink incoming request body data into that state.
                                    // Payload is attached to the Request and passed to
                                    // Service::call, where the state can be collected and consumed.
                                    let (sender, payload) = Payload::create(false);
                                    *req.payload() = crate::Payload::H1 { payload };
                                    *this.payload = Some(sender);
                                    // a failed `chunked()` check counts as "not drainable"
                                    *this.payload_drainable = req.chunked().unwrap_or(false);
                                }
                            }

                            // handle request early when no future in InnerDispatcher state.
                            if this.state.is_none() {
                                self.as_mut().handle_request(req, cx)?;
                                // re-project: `handle_request` needed the whole pinned dispatcher
                                this = self.as_mut().project();
                            } else {
                                this.messages.push_back(DispatcherMessage::Item(req));
                            }
                        }

                        Message::Chunk(Some(chunk)) => {
                            if let Some(ref mut payload) = this.payload {
                                payload.feed_data(chunk);
                            } else {
                                // body data arrived with no request payload to receive it
                                error!("Internal server error: unexpected payload chunk");
                                this.flags.insert(Flags::READ_DISCONNECT);
                                this.messages.push_back(DispatcherMessage::Error(
                                    Response::internal_server_error().drop_body(),
                                ));
                                *this.error = Some(DispatchError::InternalError);
                                break;
                            }
                        }

                        Message::Chunk(None) => {
                            // end of the request body: detach the sender and signal EOF
                            if let Some(mut payload) = this.payload.take() {
                                payload.feed_eof();
                                *this.payload_drainable = false;
                            } else {
                                error!("Internal server error: unexpected eof");
                                this.flags.insert(Flags::READ_DISCONNECT);
                                this.messages.push_back(DispatcherMessage::Error(
                                    Response::internal_server_error().drop_body(),
                                ));
                                *this.error = Some(DispatchError::InternalError);
                                break;
                            }
                        }
                    }
                }

                // decode is partial and buffer is not full yet
                // break and wait for more read
                Ok(None) => break,

                Err(ParseError::Io(err)) => {
                    trace!("I/O error: {}", &err);
                    self.as_mut().client_disconnected();
                    this = self.as_mut().project();
                    *this.error = Some(DispatchError::Io(err));
                    break;
                }

                Err(ParseError::TooLarge) => {
                    trace!("request head was too big; returning 431 response");

                    if let Some(mut payload) = this.payload.take() {
                        payload.set_error(PayloadError::Overflow);
                    }

                    // request heads that overflow buffer size return a 431 error
                    this.messages
                        .push_back(DispatcherMessage::Error(Response::with_body(
                            StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
                            (),
                        )));

                    this.flags.insert(Flags::READ_DISCONNECT);
                    *this.error = Some(ParseError::TooLarge.into());

                    break;
                }

                Err(err) => {
                    trace!("parse error {}", &err);

                    if let Some(mut payload) = this.payload.take() {
                        payload.set_error(PayloadError::EncodingCorrupted);
                    }

                    // malformed requests should be responded with 400
                    this.messages.push_back(DispatcherMessage::Error(
                        Response::bad_request().drop_body(),
                    ));

                    this.flags.insert(Flags::READ_DISCONNECT);
                    *this.error = Some(err.into());
                    break;
                }
            }
        }

        Ok(updated)
    }
987
988    fn poll_head_timer(
989        mut self: Pin<&mut Self>,
990        cx: &mut Context<'_>,
991    ) -> Result<(), DispatchError> {
992        let this = self.as_mut().project();
993
994        if let TimerState::Active { timer } = this.head_timer {
995            if timer.as_mut().poll(cx).is_ready() {
996                // timeout on first request (slow request) return 408
997
998                trace!("timed out on slow request; replying with 408 and closing connection");
999
1000                let _ = self.as_mut().send_error_response(
1001                    Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
1002                    BoxBody::new(()),
1003                );
1004
1005                self.project().flags.insert(Flags::SHUTDOWN);
1006            }
1007        };
1008
1009        Ok(())
1010    }
1011
1012    fn poll_ka_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
1013        let this = self.as_mut().project();
1014        if let TimerState::Active { timer } = this.ka_timer {
1015            debug_assert!(
1016                this.flags.contains(Flags::KEEP_ALIVE),
1017                "keep-alive flag should be set when timer is active",
1018            );
1019            debug_assert!(
1020                this.state.is_none(),
1021                "dispatcher should not be in keep-alive phase if state is not none: {:?}",
1022                this.state,
1023            );
1024
1025            // Assert removed by @robjtede on account of issue #2655. There are cases where an I/O
1026            // flush can be pending after entering the keep-alive state causing the subsequent flush
1027            // wake up to panic here. This appears to be a Linux-only problem. Leaving original code
1028            // below for posterity because a simple and reliable test could not be found to trigger
1029            // the behavior.
1030            // debug_assert!(
1031            //     this.write_buf.is_empty(),
1032            //     "dispatcher should not be in keep-alive phase if write_buf is not empty",
1033            // );
1034
1035            // keep-alive timer has timed out
1036            if timer.as_mut().poll(cx).is_ready() {
1037                // no tasks at hand
1038                trace!("timer timed out; closing connection");
1039                this.flags.insert(Flags::SHUTDOWN);
1040
1041                if let Some(deadline) = this.config.client_disconnect_deadline() {
1042                    // start shutdown timeout if enabled
1043                    this.shutdown_timer
1044                        .set_and_init(cx, sleep_until(deadline.into()), line!());
1045                } else {
1046                    // no shutdown timeout, drop socket
1047                    this.flags.insert(Flags::WRITE_DISCONNECT);
1048                }
1049            }
1050        }
1051
1052        Ok(())
1053    }
1054
1055    fn poll_shutdown_timer(
1056        mut self: Pin<&mut Self>,
1057        cx: &mut Context<'_>,
1058    ) -> Result<(), DispatchError> {
1059        let this = self.as_mut().project();
1060        if let TimerState::Active { timer } = this.shutdown_timer {
1061            debug_assert!(
1062                this.flags.intersects(Flags::LINGER | Flags::SHUTDOWN),
1063                "shutdown or linger flag should be set when timer is active",
1064            );
1065
1066            if timer.as_mut().poll(cx).is_ready() {
1067                if this.flags.contains(Flags::LINGER) {
1068                    trace!("timed-out during linger; shutting down connection");
1069                    this.flags.remove(Flags::LINGER);
1070                    this.flags.insert(Flags::SHUTDOWN);
1071                    this.shutdown_timer.clear(line!());
1072                } else {
1073                    trace!("timed-out during shutdown");
1074                    return Err(DispatchError::DisconnectTimeout);
1075                }
1076            }
1077        }
1078
1079        Ok(())
1080    }
1081
1082    /// Poll head, keep-alive, and disconnect timer.
1083    fn poll_timers(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
1084        self.as_mut().poll_head_timer(cx)?;
1085        self.as_mut().poll_ka_timer(cx)?;
1086        self.as_mut().poll_shutdown_timer(cx)?;
1087
1088        Ok(())
1089    }
1090
    /// Reads whatever is currently available from the I/O stream into `read_buf`.
    ///
    /// Returns `Ok(true)` when the I/O stream can be disconnected after writing to it, which
    /// covers these conditions:
    /// - all data has been read (a read returned 0 bytes, i.e. EOF);
    /// - `std::io::ErrorKind::ConnectionReset` after at least one non-empty read in this call.
    ///
    /// Returns `Ok(false)` when the read half is already disconnected (`READ_DISCONNECT`), when the
    /// stream has no more data right now (`Pending` / `WouldBlock`), or when `read_buf` has reached
    /// `MAX_BUFFER_SIZE`.
    ///
    /// # Errors
    /// Any other I/O error is returned as `DispatchError::Io`.
    ///
    /// # Panics
    /// Panics if the I/O stream has already been taken (`io` is `None`), as happens in `upgrade`.
    #[inline(always)] // TODO: bench this inline
    fn read_available(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
        let this = self.project();

        if this.flags.contains(Flags::READ_DISCONNECT) {
            return Ok(false);
        };

        let mut io = Pin::new(this.io.as_mut().unwrap());

        // tracks whether this call has read any bytes; used for the ConnectionReset case below
        let mut read_some = false;

        loop {
            // Return early when read buf exceed decoder's max buffer size.
            if this.read_buf.len() >= MAX_BUFFER_SIZE {
                // At this point it's not known IO stream is still scheduled to be waked up so
                // force wake up dispatcher just in case.
                //
                // Reason:
                // AsyncRead mostly would only have guarantee wake up when the poll_read
                // return Poll::Pending.
                //
                // Case:
                // When read_buf is beyond max buffer size the early return could be successfully
                // be parsed as a new Request. This case would not generate ParseError::TooLarge and
                // at this point IO stream is not fully read to Pending and would result in
                // dispatcher stuck until timeout (keep-alive).
                //
                // Note:
                // This is a perf choice to reduce branch on <Request as MessageType>::decode.
                //
                // A Request head too large to parse is only checked on `httparse::Status::Partial`.

                match this.payload.as_ref().map(|p| p.need_read(cx)) {
                    // Payload consumer is alive but applying backpressure. Wait for its waker.
                    Some(PayloadStatus::Pause) => {}

                    // Consumer dropped means drain/discard mode; keep polling to make progress.
                    Some(PayloadStatus::Dropped) | Some(PayloadStatus::Read) | None => {
                        cx.waker().wake_by_ref()
                    }
                }

                return Ok(false);
            }

            // grow buffer if necessary.
            // `remaining < LW_BUFFER_SIZE <= HW_BUFFER_SIZE`, so the subtraction cannot underflow.
            let remaining = this.read_buf.capacity() - this.read_buf.len();
            if remaining < LW_BUFFER_SIZE {
                this.read_buf.reserve(HW_BUFFER_SIZE - remaining);
            }

            match tokio_util::io::poll_read_buf(io.as_mut(), cx, this.read_buf) {
                Poll::Ready(Ok(n)) => {
                    // When draining a dropped request payload, keep FINISHED set so the
                    // disconnect/keep-alive decision can be made once the payload is fully drained.
                    if !this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) {
                        this.flags.remove(Flags::FINISHED);
                    }

                    // zero-byte read: peer closed its write half (EOF)
                    if n == 0 {
                        return Ok(true);
                    }

                    read_some = true;
                }

                Poll::Pending => {
                    return Ok(false);
                }

                Poll::Ready(Err(err)) => {
                    return match err.kind() {
                        // convert WouldBlock error to the same as Pending return
                        io::ErrorKind::WouldBlock => Ok(false),

                        // connection reset after partial read
                        io::ErrorKind::ConnectionReset if read_some => Ok(true),

                        _ => Err(DispatchError::Io(err)),
                    };
                }
            }
        }
    }
1181
1182    /// call upgrade service with request.
1183    fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future {
1184        let this = self.project();
1185        let mut parts = FramedParts::with_read_buf(
1186            this.io.take().unwrap(),
1187            mem::take(this.codec),
1188            mem::take(this.read_buf),
1189        );
1190        parts.write_buf = mem::take(this.write_buf);
1191        let framed = Framed::from_parts(parts);
1192        this.flow.upgrade.as_ref().unwrap().call((req, framed))
1193    }
1194}
1195
/// Drives an HTTP/1 connection to completion.
///
/// The dispatcher is either in the `Upgrade` state, where it only polls the upgrade service's
/// future, or in the `Normal` state, where it runs the read → decode → service → write cycle and
/// manages the head, keep-alive and shutdown timers.
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    type Output = Result<(), DispatchError>;

    /// Advances the connection as far as possible without blocking.
    ///
    /// Resolves to `Ok(())` when the connection is finished. That happens when the write half is
    /// gone, when the I/O stream has been shut down during `SHUTDOWN`, or when the upgrade future
    /// completes. Resolves to `Err` on timer, I/O, stored stream, or upgrade errors.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();

        // poll counter only exists in test builds
        #[cfg(test)]
        {
            *this.poll_count += 1;
        }

        match this.inner.project() {
            // connection was handed to the upgrade service; only its future is driven now
            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
                error!("Upgrade handler error: {}", err);
                DispatchError::Upgrade
            }),

            DispatcherStateProj::Normal { mut inner } => {
                trace!("start flags: {:?}", &inner.flags);

                trace_timer_states(
                    "start",
                    &inner.head_timer,
                    &inner.ka_timer,
                    &inner.shutdown_timer,
                );

                // timer expiry may set SHUTDOWN/LINGER flags or fail with a disconnect timeout
                inner.as_mut().poll_timers(cx)?;

                let poll = if inner.flags.contains(Flags::LINGER) {
                    match inner.as_mut().poll_linger(cx)? {
                        // linger step completed: schedule an immediate re-poll instead of
                        // resolving the connection future here
                        Poll::Ready(()) => {
                            cx.waker().wake_by_ref();
                            Poll::Pending
                        }
                        Poll::Pending => Poll::Pending,
                    }
                } else if inner.flags.contains(Flags::SHUTDOWN) {
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        // write half is gone; nothing more can be sent, so finish now
                        Poll::Ready(Ok(()))
                    } else {
                        // flush buffer and wait on blocked
                        ready!(inner.as_mut().poll_flush(cx))?;
                        Pin::new(inner.as_mut().project().io.as_mut().unwrap())
                            .poll_shutdown(cx)
                            .map_err(DispatchError::from)
                    }
                } else {
                    // read from I/O stream and fill read buffer
                    let should_disconnect = inner.as_mut().read_available(cx)?;

                    // after reading something from stream, clear keep-alive timer
                    if !inner.read_buf.is_empty() && inner.flags.contains(Flags::KEEP_ALIVE) {
                        let inner = inner.as_mut().project();
                        inner.flags.remove(Flags::KEEP_ALIVE);
                        inner.ka_timer.clear(line!());
                    }

                    // first poll of this connection: arm the request-head timer if configured
                    if !inner.flags.contains(Flags::STARTED) {
                        inner.as_mut().project().flags.insert(Flags::STARTED);

                        if let Some(deadline) = inner.config.client_request_deadline() {
                            inner.as_mut().project().head_timer.set_and_init(
                                cx,
                                sleep_until(deadline.into()),
                                line!(),
                            );
                        }
                    }

                    inner.as_mut().poll_request(cx)?;

                    if should_disconnect {
                        // I/O stream should be closed; any request body still being received
                        // is incomplete
                        let inner = inner.as_mut().project();
                        inner.flags.insert(Flags::READ_DISCONNECT);
                        if let Some(mut payload) = inner.payload.take() {
                            payload.set_error(PayloadError::Incomplete(None));
                            payload.feed_eof();
                        }
                    };

                    loop {
                        // poll response to populate write buffer
                        // drain indicates whether write buffer should be emptied before next run
                        let drain = match inner.as_mut().poll_response(cx)? {
                            PollResponse::DrainWriteBuf => true,

                            PollResponse::DoNothing => {
                                // KEEP_ALIVE is set in send_response_inner if client allows it
                                // FINISHED is set after writing last chunk of response
                                if inner.flags.contains(Flags::KEEP_ALIVE | Flags::FINISHED) {
                                    if let Some(timer) = inner.config.keep_alive_deadline() {
                                        inner.as_mut().project().ka_timer.set_and_init(
                                            cx,
                                            sleep_until(timer.into()),
                                            line!(),
                                        );
                                    }
                                }

                                false
                            }

                            // upgrade the request and switch to the Upgrade variant of
                            // DispatcherState, then poll the new state immediately
                            PollResponse::Upgrade(req) => {
                                let upgrade = inner.upgrade(req);
                                self.as_mut()
                                    .project()
                                    .inner
                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                return self.poll(cx);
                            }
                        };

                        // we didn't get WouldBlock from write operation, so data get written to
                        // kernel completely (macOS) and we have to write again otherwise response
                        // can get stuck
                        //
                        // TODO: want to find a reference for this behavior
                        // see introduced commit: 3872d3ba
                        let flush_was_ready = inner.as_mut().poll_flush(cx)?.is_ready();

                        // this assert seems to always be true but not willing to commit to it until
                        // we understand what Nikolay meant when writing the above comment
                        // debug_assert!(flush_was_ready);

                        if !flush_was_ready || !drain {
                            break;
                        }
                    }

                    // client is gone
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        trace!("client is gone; disconnecting");
                        return Poll::Ready(Ok(()));
                    }

                    let inner_p = inner.as_mut().project();
                    let state_is_none = inner_p.state.is_none();

                    // If the read-half is closed, we start the shutdown procedure if either is
                    // true:
                    //
                    // - state is [`State::None`], which means that we're done with request
                    //   processing, so if the client closed its writer-side it means that it won't
                    //   send more requests.
                    // - The user requested to not allow half-closures
                    if inner_p.flags.contains(Flags::READ_DISCONNECT)
                        && (!inner_p.config.h1_allow_half_closed() || state_is_none)
                    {
                        trace!("read half closed; start shutdown");
                        inner_p.flags.insert(Flags::SHUTDOWN);
                    }

                    // keep-alive and stream errors
                    if state_is_none && inner_p.write_buf.is_empty() {
                        // a stored stream error is only surfaced once all output is flushed
                        if let Some(err) = inner_p.error.take() {
                            error!("stream error: {}", &err);
                            return Poll::Ready(Err(err));
                        }

                        // disconnect if keep-alive is not enabled
                        if inner_p.flags.contains(Flags::FINISHED)
                            && !inner_p.flags.contains(Flags::KEEP_ALIVE)
                            && inner_p.payload.is_none()
                        {
                            inner_p.flags.remove(Flags::FINISHED);
                            inner_p.flags.insert(Flags::SHUTDOWN);
                            return self.poll(cx);
                        }

                        // disconnect if shutdown
                        if inner_p.flags.contains(Flags::SHUTDOWN) {
                            return self.poll(cx);
                        }
                    }

                    trace_timer_states(
                        "end",
                        inner_p.head_timer,
                        inner_p.ka_timer,
                        inner_p.shutdown_timer,
                    );

                    // linger/shutdown states need another poll to make progress; request one
                    if inner_p.flags.intersects(Flags::LINGER | Flags::SHUTDOWN) {
                        cx.waker().wake_by_ref();
                    }
                    Poll::Pending
                };

                trace!("end flags: {:?}", &inner.flags);

                poll
            }
        }
    }
}
1410
1411fn should_close_after_response(payload: Option<&PayloadSender>, payload_drainable: bool) -> bool {
1412    let payload_unfinished = payload.is_some();
1413    let drain_payload = payload.is_some_and(|pl| pl.is_dropped()) && payload_drainable;
1414
1415    payload_unfinished && !drain_payload
1416}
1417
1418#[allow(dead_code)]
1419fn trace_timer_states(
1420    label: &str,
1421    head_timer: &TimerState,
1422    ka_timer: &TimerState,
1423    shutdown_timer: &TimerState,
1424) {
1425    trace!("{} timers:", label);
1426
1427    if head_timer.is_enabled() {
1428        trace!("  head {}", &head_timer);
1429    }
1430
1431    if ka_timer.is_enabled() {
1432        trace!("  keep-alive {}", &ka_timer);
1433    }
1434
1435    if shutdown_timer.is_enabled() {
1436        trace!("  shutdown {}", &shutdown_timer);
1437    }
1438}