1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io, mem, net,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11use actix_codec::{Framed, FramedParts};
12use actix_rt::time::sleep_until;
13use actix_service::Service;
14use bitflags::bitflags;
15use bytes::{Buf, BytesMut};
16use futures_core::ready;
17use pin_project_lite::pin_project;
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio_util::codec::{Decoder as _, Encoder as _};
20use tracing::{error, trace};
21
22use super::{
23 codec::Codec,
24 decoder::MAX_BUFFER_SIZE,
25 payload::{Payload, PayloadSender, PayloadStatus},
26 timer::TimerState,
27 Message, MessageType,
28};
29use crate::{
30 body::{BodySize, BoxBody, MessageBody},
31 config::ServiceConfig,
32 error::{DispatchError, ParseError, PayloadError},
33 service::HttpFlow,
34 ConnectionType, Error, Extensions, HttpMessage, OnConnectData, Request, Response, StatusCode,
35};
36
/// Minimum spare capacity kept in the read buffer before another read is attempted; when less
/// than this is free, the buffer is grown.
const LW_BUFFER_SIZE: usize = 1024;
/// Initial capacity of the read/write buffers and the target spare capacity after growing.
const HW_BUFFER_SIZE: usize = 8 * LW_BUFFER_SIZE;
/// Upper bound on decoded-but-unhandled (pipelined) messages; request decoding pauses at this limit.
const MAX_PIPELINED_MESSAGES: usize = 16;
40
41bitflags! {
42 #[derive(Debug, Clone, Copy)]
43 pub struct Flags: u8 {
44 const STARTED = 0b0000_0001;
46
47 const FINISHED = 0b0000_0010;
49
50 const KEEP_ALIVE = 0b0000_0100;
52
53 const SHUTDOWN = 0b0000_1000;
55
56 const READ_DISCONNECT = 0b0001_0000;
58
59 const WRITE_DISCONNECT = 0b0010_0000;
61
62 const LINGER = 0b0100_0000;
64 }
65}
66
// Non-test build of the dispatcher: only wraps the state machine. The test build below adds a
// poll counter.
#[cfg(not(test))]
pin_project! {
    /// HTTP/1.x connection dispatcher future.
    ///
    /// Starts in the normal request/response state and may switch to driving an upgrade
    /// service future.
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current dispatcher state (normal HTTP/1 processing or upgrade)
        #[pin]
        inner: DispatcherState<T, S, B, X, U>,
    }
}
92
// Test build of the dispatcher: exposes the inner state and counts polls for unit tests.
#[cfg(test)]
pin_project! {
    /// HTTP/1.x connection dispatcher future (test build).
    pub struct Dispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // current dispatcher state (normal HTTP/1 processing or upgrade)
        #[pin]
        pub(super) inner: DispatcherState<T, S, B, X, U>,

        // number of times `Future::poll` has been called on this dispatcher
        pub(super) poll_count: u64,
    }
}
116
pin_project! {
    /// Top-level dispatcher mode.
    #[project = DispatcherStateProj]
    pub(super) enum DispatcherState<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // regular HTTP/1 request/response processing
        Normal { #[pin] inner: InnerDispatcher<T, S, B, X, U> },
        // the connection was handed to the upgrade service; only its future is polled
        Upgrade { #[pin] fut: U::Future },
    }
}
136
pin_project! {
    /// State of a connection during normal HTTP/1 processing.
    #[project = InnerDispatcherProj]
    pub(super) struct InnerDispatcher<T, S, B, X, U>
    where
        S: Service<Request>,
        S::Error: Into<Response<BoxBody>>,

        B: MessageBody,

        X: Service<Request, Response = Request>,
        X::Error: Into<Response<BoxBody>>,

        U: Service<(Request, Framed<T, Codec>), Response = ()>,
        U::Error: fmt::Display,
    {
        // main service, expect service and optional upgrade service
        flow: Rc<HttpFlow<S, X, U>>,
        pub(super) flags: Flags,
        // copied into the head of every decoded request
        peer_addr: Option<net::SocketAddr>,
        // on-connect data, cloned into every decoded request
        conn_data: Option<Rc<Extensions>>,
        config: ServiceConfig,
        // deferred error, returned once no response is in flight and write_buf is empty
        error: Option<DispatchError>,

        // response-side state: expect/service call in progress or body being streamed
        #[pin]
        pub(super) state: State<S, B, X>,
        // Some while a request body is still being received
        payload: Option<PayloadSender>,
        // whether an unread (dropped) request body may be drained instead of closing
        payload_drainable: bool,
        // decoded messages waiting for the current response to finish (pipelining)
        messages: VecDeque<DispatcherMessage>,

        // deadline for receiving the first request head
        head_timer: TimerState,
        // idle keep-alive deadline
        ka_timer: TimerState,
        // deadline for lingering / shutting down the connection
        shutdown_timer: TimerState,

        // Option so that `upgrade` can move the stream out
        pub(super) io: Option<T>,
        read_buf: BytesMut,
        write_buf: BytesMut,
        // body chunks are encoded until write_buf reaches this size
        h1_write_buffer_size: usize,
        codec: Codec,
    }
}
178
/// A decoded item waiting in the dispatcher's message queue.
enum DispatcherMessage {
    /// A request to pass to the expect/main service.
    Item(Request),
    /// A request that should be handed to the upgrade service.
    Upgrade(Request),
    /// A body-less error response generated by the dispatcher itself (e.g. 400, 431, 500).
    Error(Response<()>),
}
184
pin_project! {
    /// Response-side state of the dispatcher.
    #[project = StateProj]
    pub(super) enum State<S, B, X>
    where
        S: Service<Request>,
        X: Service<Request, Response = Request>,
        B: MessageBody,
    {
        // nothing in flight
        None,
        // waiting on the expect service (request had an `Expect` header)
        ExpectCall { #[pin] fut: X::Future },
        // waiting on the main service
        ServiceCall { #[pin] fut: S::Future },
        // streaming a successful response body
        SendPayload { #[pin] body: B },
        // streaming an error response body
        SendErrorPayload { #[pin] body: BoxBody },
    }
}
200
201impl<S, B, X> State<S, B, X>
202where
203 S: Service<Request>,
204 X: Service<Request, Response = Request>,
205 B: MessageBody,
206{
207 pub(super) fn is_none(&self) -> bool {
208 matches!(self, State::None)
209 }
210}
211
212impl<S, B, X> fmt::Debug for State<S, B, X>
213where
214 S: Service<Request>,
215 X: Service<Request, Response = Request>,
216 B: MessageBody,
217{
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 match self {
220 Self::None => write!(f, "State::None"),
221 Self::ExpectCall { .. } => f.debug_struct("State::ExpectCall").finish_non_exhaustive(),
222 Self::ServiceCall { .. } => {
223 f.debug_struct("State::ServiceCall").finish_non_exhaustive()
224 }
225 Self::SendPayload { .. } => {
226 f.debug_struct("State::SendPayload").finish_non_exhaustive()
227 }
228 Self::SendErrorPayload { .. } => f
229 .debug_struct("State::SendErrorPayload")
230 .finish_non_exhaustive(),
231 }
232 }
233}
234
/// Outcome of `InnerDispatcher::poll_response`, telling `Dispatcher::poll` what to do next.
#[derive(Debug)]
enum PollResponse {
    /// A queued request must be handed to the upgrade service.
    Upgrade(Request),
    /// No further response progress can be made right now.
    DoNothing,
    /// The write buffer reached `h1_write_buffer_size`; flush it before polling the body again.
    DrainWriteBuf,
}
241
242impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
243where
244 T: AsyncRead + AsyncWrite + Unpin,
245
246 S: Service<Request>,
247 S::Error: Into<Response<BoxBody>>,
248 S::Response: Into<Response<B>>,
249
250 B: MessageBody,
251
252 X: Service<Request, Response = Request>,
253 X::Error: Into<Response<BoxBody>>,
254
255 U: Service<(Request, Framed<T, Codec>), Response = ()>,
256 U::Error: fmt::Display,
257{
258 pub(crate) fn new(
260 io: T,
261 flow: Rc<HttpFlow<S, X, U>>,
262 config: ServiceConfig,
263 peer_addr: Option<net::SocketAddr>,
264 conn_data: OnConnectData,
265 ) -> Self {
266 Dispatcher {
267 inner: DispatcherState::Normal {
268 inner: InnerDispatcher {
269 flow,
270 flags: Flags::empty(),
271 peer_addr,
272 conn_data: conn_data.0.map(Rc::new),
273 config: config.clone(),
274 error: None,
275
276 state: State::None,
277 payload: None,
278 payload_drainable: false,
279 messages: VecDeque::new(),
280
281 head_timer: TimerState::new(config.client_request_deadline().is_some()),
282 ka_timer: TimerState::new(config.keep_alive().enabled()),
283 shutdown_timer: TimerState::new(config.client_disconnect_deadline().is_some()),
284
285 io: Some(io),
286 read_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
287 write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
288 h1_write_buffer_size: config.h1_write_buffer_size(),
289 codec: Codec::new(config),
290 },
291 },
292
293 #[cfg(test)]
294 poll_count: 0,
295 }
296 }
297}
298
299impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
300where
301 T: AsyncRead + AsyncWrite + Unpin,
302
303 S: Service<Request>,
304 S::Error: Into<Response<BoxBody>>,
305 S::Response: Into<Response<B>>,
306
307 B: MessageBody,
308
309 X: Service<Request, Response = Request>,
310 X::Error: Into<Response<BoxBody>>,
311
312 U: Service<(Request, Framed<T, Codec>), Response = ()>,
313 U::Error: fmt::Display,
314{
315 fn can_read(&self, cx: &mut Context<'_>) -> bool {
316 if self.flags.contains(Flags::READ_DISCONNECT) {
317 false
318 } else if let Some(ref info) = self.payload {
319 matches!(
320 info.need_read(cx),
321 PayloadStatus::Read | PayloadStatus::Dropped
322 )
323 } else {
324 true
325 }
326 }
327
328 fn client_disconnected(self: Pin<&mut Self>) {
329 let this = self.project();
330
331 this.flags
332 .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
333
334 if let Some(mut payload) = this.payload.take() {
335 payload.set_error(PayloadError::Incomplete(None));
336 }
337 }
338
339 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
340 let InnerDispatcherProj { io, write_buf, .. } = self.project();
341 let mut io = Pin::new(io.as_mut().unwrap());
342
343 let len = write_buf.len();
344 let mut written = 0;
345
346 while written < len {
347 match io.as_mut().poll_write(cx, &write_buf[written..])? {
348 Poll::Ready(0) => {
349 error!("write zero; closing");
350 return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, "")));
351 }
352
353 Poll::Ready(n) => written += n,
354
355 Poll::Pending => {
356 write_buf.advance(written);
357 return Poll::Pending;
358 }
359 }
360 }
361
362 write_buf.clear();
364
365 io.poll_flush(cx)
367 }
368
369 fn enter_linger(flags: &mut Flags) {
370 flags.remove(Flags::KEEP_ALIVE);
371 flags.insert(Flags::LINGER | Flags::FINISHED);
372 }
373
374 fn ensure_linger_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool {
375 let this = self.as_mut().project();
376
377 if matches!(this.shutdown_timer, TimerState::Active { .. }) {
378 return true;
379 }
380
381 if let Some(deadline) = this.config.client_disconnect_deadline() {
382 this.shutdown_timer
383 .set_and_init(cx, sleep_until(deadline.into()), line!());
384 true
385 } else {
386 false
387 }
388 }
389
390 fn poll_linger(
391 mut self: Pin<&mut Self>,
392 cx: &mut Context<'_>,
393 ) -> Result<Poll<()>, DispatchError> {
394 if self.as_mut().poll_flush(cx)?.is_pending() {
395 return Ok(Poll::Pending);
396 }
397
398 if !self.as_mut().ensure_linger_timer(cx) {
399 let this = self.as_mut().project();
400 this.flags.remove(Flags::LINGER);
401 this.flags.insert(Flags::SHUTDOWN);
402 return Ok(Poll::Ready(()));
403 }
404
405 loop {
406 let should_disconnect = self.as_mut().read_available(cx)?;
407 let this = self.as_mut().project();
408 let mut progressed = false;
409
410 if !this.read_buf.is_empty() {
411 this.read_buf.clear();
412 progressed = true;
413 }
414
415 if should_disconnect {
416 this.flags.remove(Flags::LINGER);
417 this.flags.insert(Flags::READ_DISCONNECT | Flags::SHUTDOWN);
418 return Ok(Poll::Ready(()));
419 }
420
421 if !progressed {
422 return Ok(Poll::Pending);
423 }
424 }
425 }
426
427 fn send_response_inner(
428 self: Pin<&mut Self>,
429 res: Response<()>,
430 body: &impl MessageBody,
431 ) -> Result<BodySize, DispatchError> {
432 let this = self.project();
433
434 let size = body.size();
435
436 this.codec
437 .encode(Message::Item((res, size)), this.write_buf)
438 .map_err(|err| {
439 if let Some(mut payload) = this.payload.take() {
440 payload.set_error(PayloadError::Incomplete(None));
441 }
442
443 DispatchError::Io(err)
444 })?;
445
446 Ok(size)
447 }
448
449 fn send_response(
450 mut self: Pin<&mut Self>,
451 mut res: Response<()>,
452 body: B,
453 ) -> Result<(), DispatchError> {
454 let close_after_response = !res.upgrade() && {
455 let this = self.as_mut().project();
456 should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
457 };
458
459 if close_after_response {
460 res.head_mut().set_connection_type(ConnectionType::Close);
461 }
462
463 let size = self.as_mut().send_response_inner(res, &body)?;
464 match size {
465 BodySize::None | BodySize::Sized(0) => {
466 let mut this = self.as_mut().project();
467
468 if close_after_response {
469 if this.config.client_disconnect_deadline().is_some() {
470 Self::enter_linger(this.flags);
471 } else {
472 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
473 }
474 } else {
475 this.flags.insert(Flags::FINISHED);
476 }
477
478 this.state.set(State::None);
479 }
480 _ => self
481 .as_mut()
482 .project()
483 .state
484 .set(State::SendPayload { body }),
485 }
486
487 Ok(())
488 }
489
490 fn send_error_response(
491 mut self: Pin<&mut Self>,
492 mut res: Response<()>,
493 body: BoxBody,
494 ) -> Result<(), DispatchError> {
495 let close_after_response = !res.upgrade() && {
496 let this = self.as_mut().project();
497 should_close_after_response(this.payload.as_ref(), *this.payload_drainable)
498 };
499
500 if close_after_response {
501 res.head_mut().set_connection_type(ConnectionType::Close);
502 }
503
504 let size = self.as_mut().send_response_inner(res, &body)?;
505 match size {
506 BodySize::None | BodySize::Sized(0) => {
507 let mut this = self.as_mut().project();
508
509 if close_after_response {
510 if this.config.client_disconnect_deadline().is_some() {
511 Self::enter_linger(this.flags);
512 } else {
513 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
514 }
515 } else {
516 this.flags.insert(Flags::FINISHED);
517 }
518
519 this.state.set(State::None);
520 }
521 _ => self
522 .as_mut()
523 .project()
524 .state
525 .set(State::SendErrorPayload { body }),
526 }
527
528 Ok(())
529 }
530
531 fn send_continue(self: Pin<&mut Self>) {
532 self.project()
533 .write_buf
534 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
535 }
536
537 fn poll_response(
538 mut self: Pin<&mut Self>,
539 cx: &mut Context<'_>,
540 ) -> Result<PollResponse, DispatchError> {
541 'res: loop {
542 let mut this = self.as_mut().project();
543 match this.state.as_mut().project() {
544 StateProj::None => match this.messages.pop_front() {
546 Some(DispatcherMessage::Item(req)) => {
548 if req.head().expect() {
550 let fut = this.flow.expect.call(req);
552 this.state.set(State::ExpectCall { fut });
553 } else {
554 let fut = this.flow.service.call(req);
556 this.state.set(State::ServiceCall { fut });
557 };
558 }
559
560 Some(DispatcherMessage::Error(res)) => {
562 self.as_mut().send_error_response(res, BoxBody::new(()))?;
566 }
567
568 Some(DispatcherMessage::Upgrade(req)) => return Ok(PollResponse::Upgrade(req)),
570
571 None => {
573 this.flags.set(
575 Flags::KEEP_ALIVE,
576 this.payload.is_none() && this.codec.keep_alive(),
577 );
578
579 return Ok(PollResponse::DoNothing);
580 }
581 },
582
583 StateProj::ServiceCall { fut } => {
584 match fut.poll(cx) {
585 Poll::Ready(Ok(res)) => {
587 let (res, body) = res.into().replace_body(());
588 self.as_mut().send_response(res, body)?;
589 }
590
591 Poll::Ready(Err(err)) => {
593 let res: Response<BoxBody> = err.into();
594 let (res, body) = res.replace_body(());
595 self.as_mut().send_error_response(res, body)?;
596 }
597
598 Poll::Pending => {
601 if !self.as_mut().poll_request(cx)? {
604 return Ok(PollResponse::DoNothing);
605 }
606 }
608 }
609 }
610
611 StateProj::SendPayload { mut body } => {
612 while this.write_buf.len() < *this.h1_write_buffer_size {
615 match body.as_mut().poll_next(cx) {
616 Poll::Ready(Some(Ok(item))) => {
617 this.codec
618 .encode(Message::Chunk(Some(item)), this.write_buf)?;
619 }
620
621 Poll::Ready(None) => {
622 this.codec.encode(Message::Chunk(None), this.write_buf)?;
623
624 let close_after_response = should_close_after_response(
629 this.payload.as_ref(),
630 *this.payload_drainable,
631 );
632 let not_pipelined = this.messages.is_empty();
633
634 this.state.set(State::None);
637
638 if not_pipelined && close_after_response {
639 if this.config.client_disconnect_deadline().is_some() {
640 Self::enter_linger(this.flags);
641 } else {
642 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
643 }
644 } else {
645 this.flags.insert(Flags::FINISHED);
646 }
647
648 continue 'res;
649 }
650
651 Poll::Ready(Some(Err(err))) => {
652 let err = err.into();
653 tracing::error!("Response payload stream error: {err:?}");
654 this.flags.insert(Flags::FINISHED);
655 return Err(DispatchError::Body(err));
656 }
657
658 Poll::Pending => return Ok(PollResponse::DoNothing),
659 }
660 }
661
662 return Ok(PollResponse::DrainWriteBuf);
665 }
666
667 StateProj::SendErrorPayload { mut body } => {
668 while this.write_buf.len() < *this.h1_write_buffer_size {
673 match body.as_mut().poll_next(cx) {
674 Poll::Ready(Some(Ok(item))) => {
675 this.codec
676 .encode(Message::Chunk(Some(item)), this.write_buf)?;
677 }
678
679 Poll::Ready(None) => {
680 this.codec.encode(Message::Chunk(None), this.write_buf)?;
681
682 let close_after_response = should_close_after_response(
687 this.payload.as_ref(),
688 *this.payload_drainable,
689 );
690 let not_pipelined = this.messages.is_empty();
691
692 this.state.set(State::None);
695
696 if not_pipelined && close_after_response {
697 if this.config.client_disconnect_deadline().is_some() {
698 Self::enter_linger(this.flags);
699 } else {
700 this.flags.insert(Flags::SHUTDOWN | Flags::FINISHED);
701 }
702 } else {
703 this.flags.insert(Flags::FINISHED);
704 }
705
706 continue 'res;
707 }
708
709 Poll::Ready(Some(Err(err))) => {
710 tracing::error!("Response payload stream error: {err:?}");
711 this.flags.insert(Flags::FINISHED);
712 return Err(DispatchError::Body(
713 Error::new_body().with_cause(err).into(),
714 ));
715 }
716
717 Poll::Pending => return Ok(PollResponse::DoNothing),
718 }
719 }
720
721 return Ok(PollResponse::DrainWriteBuf);
724 }
725
726 StateProj::ExpectCall { fut } => {
727 trace!(" calling expect service");
728
729 match fut.poll(cx) {
730 Poll::Ready(Ok(req)) => {
733 this.write_buf
734 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
735 let fut = this.flow.service.call(req);
736 this.state.set(State::ServiceCall { fut });
737 }
738
739 Poll::Ready(Err(err)) => {
741 let res: Response<BoxBody> = err.into();
742 let (res, body) = res.replace_body(());
743 self.as_mut().send_error_response(res, body)?;
744 }
745
746 Poll::Pending => return Ok(PollResponse::DoNothing),
748 }
749 }
750 }
751 }
752 }
753
754 fn handle_request(
755 mut self: Pin<&mut Self>,
756 req: Request,
757 cx: &mut Context<'_>,
758 ) -> Result<(), DispatchError> {
759 {
761 let mut this = self.as_mut().project();
762
763 if req.head().expect() {
765 let fut = this.flow.expect.call(req);
767 this.state.set(State::ExpectCall { fut });
768 } else {
769 let fut = this.flow.service.call(req);
771 this.state.set(State::ServiceCall { fut });
772 };
773 };
774
775 loop {
777 match self.as_mut().project().state.project() {
778 StateProj::ExpectCall { fut } => {
779 match fut.poll(cx) {
780 Poll::Ready(Ok(req)) => {
782 self.as_mut().send_continue();
783
784 let mut this = self.as_mut().project();
785 let fut = this.flow.service.call(req);
786 this.state.set(State::ServiceCall { fut });
787
788 continue;
789 }
790
791 Poll::Ready(Err(err)) => {
795 let res: Response<BoxBody> = err.into();
796 let (res, body) = res.replace_body(());
797 return self.send_error_response(res, body);
798 }
799
800 Poll::Pending => return Ok(()),
803 }
804 }
805
806 StateProj::ServiceCall { fut } => {
807 return match fut.poll(cx) {
809 Poll::Ready(Ok(res)) => {
813 let (res, body) = res.into().replace_body(());
814 self.as_mut().send_response(res, body)
815 }
816
817 Poll::Pending => Ok(()),
819
820 Poll::Ready(Err(err)) => {
822 let res: Response<BoxBody> = err.into();
823 let (res, body) = res.replace_body(());
824 self.as_mut().send_error_response(res, body)
825 }
826 };
827 }
828
829 _ => {
830 unreachable!("State must be set to ServiceCall or ExceptCall in handle_request")
831 }
832 }
833 }
834 }
835
836 fn poll_request(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
840 let pipeline_queue_full = self.messages.len() >= MAX_PIPELINED_MESSAGES;
841 let can_not_read = !self.can_read(cx);
842
843 if pipeline_queue_full || can_not_read {
845 return Ok(false);
846 }
847
848 let mut this = self.as_mut().project();
849
850 let mut updated = false;
851
852 loop {
854 match this.codec.decode(this.read_buf) {
855 Ok(Some(msg)) => {
856 updated = true;
857
858 match msg {
859 Message::Item(mut req) => {
860 this.head_timer.clear(line!());
862
863 req.head_mut().peer_addr = *this.peer_addr;
864
865 req.conn_data.clone_from(this.conn_data);
866
867 match this.codec.message_type() {
868 MessageType::None => *this.payload_drainable = false,
870
871 MessageType::Stream if this.flow.upgrade.is_some() => {
875 *this.payload_drainable = false;
876 this.messages.push_back(DispatcherMessage::Upgrade(req));
877 break;
878 }
879
880 MessageType::Payload | MessageType::Stream => {
882 let (sender, payload) = Payload::create(false);
888 *req.payload() = crate::Payload::H1 { payload };
889 *this.payload = Some(sender);
890 *this.payload_drainable = req.chunked().unwrap_or(false);
891 }
892 }
893
894 if this.state.is_none() {
896 self.as_mut().handle_request(req, cx)?;
897 this = self.as_mut().project();
898 } else {
899 this.messages.push_back(DispatcherMessage::Item(req));
900 }
901 }
902
903 Message::Chunk(Some(chunk)) => {
904 if let Some(ref mut payload) = this.payload {
905 payload.feed_data(chunk);
906 } else {
907 error!("Internal server error: unexpected payload chunk");
908 this.flags.insert(Flags::READ_DISCONNECT);
909 this.messages.push_back(DispatcherMessage::Error(
910 Response::internal_server_error().drop_body(),
911 ));
912 *this.error = Some(DispatchError::InternalError);
913 break;
914 }
915 }
916
917 Message::Chunk(None) => {
918 if let Some(mut payload) = this.payload.take() {
919 payload.feed_eof();
920 *this.payload_drainable = false;
921 } else {
922 error!("Internal server error: unexpected eof");
923 this.flags.insert(Flags::READ_DISCONNECT);
924 this.messages.push_back(DispatcherMessage::Error(
925 Response::internal_server_error().drop_body(),
926 ));
927 *this.error = Some(DispatchError::InternalError);
928 break;
929 }
930 }
931 }
932 }
933
934 Ok(None) => break,
937
938 Err(ParseError::Io(err)) => {
939 trace!("I/O error: {}", &err);
940 self.as_mut().client_disconnected();
941 this = self.as_mut().project();
942 *this.error = Some(DispatchError::Io(err));
943 break;
944 }
945
946 Err(ParseError::TooLarge) => {
947 trace!("request head was too big; returning 431 response");
948
949 if let Some(mut payload) = this.payload.take() {
950 payload.set_error(PayloadError::Overflow);
951 }
952
953 this.messages
955 .push_back(DispatcherMessage::Error(Response::with_body(
956 StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
957 (),
958 )));
959
960 this.flags.insert(Flags::READ_DISCONNECT);
961 *this.error = Some(ParseError::TooLarge.into());
962
963 break;
964 }
965
966 Err(err) => {
967 trace!("parse error {}", &err);
968
969 if let Some(mut payload) = this.payload.take() {
970 payload.set_error(PayloadError::EncodingCorrupted);
971 }
972
973 this.messages.push_back(DispatcherMessage::Error(
975 Response::bad_request().drop_body(),
976 ));
977
978 this.flags.insert(Flags::READ_DISCONNECT);
979 *this.error = Some(err.into());
980 break;
981 }
982 }
983 }
984
985 Ok(updated)
986 }
987
988 fn poll_head_timer(
989 mut self: Pin<&mut Self>,
990 cx: &mut Context<'_>,
991 ) -> Result<(), DispatchError> {
992 let this = self.as_mut().project();
993
994 if let TimerState::Active { timer } = this.head_timer {
995 if timer.as_mut().poll(cx).is_ready() {
996 trace!("timed out on slow request; replying with 408 and closing connection");
999
1000 let _ = self.as_mut().send_error_response(
1001 Response::with_body(StatusCode::REQUEST_TIMEOUT, ()),
1002 BoxBody::new(()),
1003 );
1004
1005 self.project().flags.insert(Flags::SHUTDOWN);
1006 }
1007 };
1008
1009 Ok(())
1010 }
1011
1012 fn poll_ka_timer(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
1013 let this = self.as_mut().project();
1014 if let TimerState::Active { timer } = this.ka_timer {
1015 debug_assert!(
1016 this.flags.contains(Flags::KEEP_ALIVE),
1017 "keep-alive flag should be set when timer is active",
1018 );
1019 debug_assert!(
1020 this.state.is_none(),
1021 "dispatcher should not be in keep-alive phase if state is not none: {:?}",
1022 this.state,
1023 );
1024
1025 if timer.as_mut().poll(cx).is_ready() {
1037 trace!("timer timed out; closing connection");
1039 this.flags.insert(Flags::SHUTDOWN);
1040
1041 if let Some(deadline) = this.config.client_disconnect_deadline() {
1042 this.shutdown_timer
1044 .set_and_init(cx, sleep_until(deadline.into()), line!());
1045 } else {
1046 this.flags.insert(Flags::WRITE_DISCONNECT);
1048 }
1049 }
1050 }
1051
1052 Ok(())
1053 }
1054
1055 fn poll_shutdown_timer(
1056 mut self: Pin<&mut Self>,
1057 cx: &mut Context<'_>,
1058 ) -> Result<(), DispatchError> {
1059 let this = self.as_mut().project();
1060 if let TimerState::Active { timer } = this.shutdown_timer {
1061 debug_assert!(
1062 this.flags.intersects(Flags::LINGER | Flags::SHUTDOWN),
1063 "shutdown or linger flag should be set when timer is active",
1064 );
1065
1066 if timer.as_mut().poll(cx).is_ready() {
1067 if this.flags.contains(Flags::LINGER) {
1068 trace!("timed-out during linger; shutting down connection");
1069 this.flags.remove(Flags::LINGER);
1070 this.flags.insert(Flags::SHUTDOWN);
1071 this.shutdown_timer.clear(line!());
1072 } else {
1073 trace!("timed-out during shutdown");
1074 return Err(DispatchError::DisconnectTimeout);
1075 }
1076 }
1077 }
1078
1079 Ok(())
1080 }
1081
1082 fn poll_timers(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<(), DispatchError> {
1084 self.as_mut().poll_head_timer(cx)?;
1085 self.as_mut().poll_ka_timer(cx)?;
1086 self.as_mut().poll_shutdown_timer(cx)?;
1087
1088 Ok(())
1089 }
1090
1091 #[inline(always)] fn read_available(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
1098 let this = self.project();
1099
1100 if this.flags.contains(Flags::READ_DISCONNECT) {
1101 return Ok(false);
1102 };
1103
1104 let mut io = Pin::new(this.io.as_mut().unwrap());
1105
1106 let mut read_some = false;
1107
1108 loop {
1109 if this.read_buf.len() >= MAX_BUFFER_SIZE {
1111 match this.payload.as_ref().map(|p| p.need_read(cx)) {
1130 Some(PayloadStatus::Pause) => {}
1132
1133 Some(PayloadStatus::Dropped) | Some(PayloadStatus::Read) | None => {
1135 cx.waker().wake_by_ref()
1136 }
1137 }
1138
1139 return Ok(false);
1140 }
1141
1142 let remaining = this.read_buf.capacity() - this.read_buf.len();
1144 if remaining < LW_BUFFER_SIZE {
1145 this.read_buf.reserve(HW_BUFFER_SIZE - remaining);
1146 }
1147
1148 match tokio_util::io::poll_read_buf(io.as_mut(), cx, this.read_buf) {
1149 Poll::Ready(Ok(n)) => {
1150 if !this.payload.as_ref().is_some_and(|pl| pl.is_dropped()) {
1153 this.flags.remove(Flags::FINISHED);
1154 }
1155
1156 if n == 0 {
1157 return Ok(true);
1158 }
1159
1160 read_some = true;
1161 }
1162
1163 Poll::Pending => {
1164 return Ok(false);
1165 }
1166
1167 Poll::Ready(Err(err)) => {
1168 return match err.kind() {
1169 io::ErrorKind::WouldBlock => Ok(false),
1171
1172 io::ErrorKind::ConnectionReset if read_some => Ok(true),
1174
1175 _ => Err(DispatchError::Io(err)),
1176 };
1177 }
1178 }
1179 }
1180 }
1181
1182 fn upgrade(self: Pin<&mut Self>, req: Request) -> U::Future {
1184 let this = self.project();
1185 let mut parts = FramedParts::with_read_buf(
1186 this.io.take().unwrap(),
1187 mem::take(this.codec),
1188 mem::take(this.read_buf),
1189 );
1190 parts.write_buf = mem::take(this.write_buf);
1191 let framed = Framed::from_parts(parts);
1192 this.flow.upgrade.as_ref().unwrap().call((req, framed))
1193 }
1194}
1195
/// Drives an HTTP/1 connection to completion.
///
/// Resolves with `Ok(())` when the connection has been closed cleanly. Resolves with an error
/// for I/O, timeout, body or deferred protocol errors.
impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
where
    T: AsyncRead + AsyncWrite + Unpin,

    S: Service<Request>,
    S::Error: Into<Response<BoxBody>>,
    S::Response: Into<Response<B>>,

    B: MessageBody,

    X: Service<Request, Response = Request>,
    X::Error: Into<Response<BoxBody>>,

    U: Service<(Request, Framed<T, Codec>), Response = ()>,
    U::Error: fmt::Display,
{
    type Output = Result<(), DispatchError>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.as_mut().project();

        #[cfg(test)]
        {
            *this.poll_count += 1;
        }

        match this.inner.project() {
            // after an upgrade only the upgrade service's future is driven
            DispatcherStateProj::Upgrade { fut: upgrade } => upgrade.poll(cx).map_err(|err| {
                error!("Upgrade handler error: {}", err);
                DispatchError::Upgrade
            }),

            DispatcherStateProj::Normal { mut inner } => {
                trace!("start flags: {:?}", &inner.flags);

                trace_timer_states(
                    "start",
                    &inner.head_timer,
                    &inner.ka_timer,
                    &inner.shutdown_timer,
                );

                inner.as_mut().poll_timers(cx)?;

                let poll = if inner.flags.contains(Flags::LINGER) {
                    // Drain input after a closing response. When lingering ends, poll_linger has
                    // set SHUTDOWN; wake so the next poll takes the shutdown branch.
                    match inner.as_mut().poll_linger(cx)? {
                        Poll::Ready(()) => {
                            cx.waker().wake_by_ref();
                            Poll::Pending
                        }
                        Poll::Pending => Poll::Pending,
                    }
                } else if inner.flags.contains(Flags::SHUTDOWN) {
                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        // nothing can be written any more; finish right away
                        Poll::Ready(Ok(()))
                    } else {
                        // flush pending output, then shut down the write half
                        ready!(inner.as_mut().poll_flush(cx))?;
                        Pin::new(inner.as_mut().project().io.as_mut().unwrap())
                            .poll_shutdown(cx)
                            .map_err(DispatchError::from)
                    }
                } else {
                    // Normal operation. `should_disconnect` is true on EOF or on a reset
                    // after a partial read.
                    let should_disconnect = inner.as_mut().read_available(cx)?;

                    // new bytes while idling: leave keep-alive and stop its timer
                    if !inner.read_buf.is_empty() && inner.flags.contains(Flags::KEEP_ALIVE) {
                        let inner = inner.as_mut().project();
                        inner.flags.remove(Flags::KEEP_ALIVE);
                        inner.ka_timer.clear(line!());
                    }

                    // first poll: arm the request-head timer if configured
                    if !inner.flags.contains(Flags::STARTED) {
                        inner.as_mut().project().flags.insert(Flags::STARTED);

                        if let Some(deadline) = inner.config.client_request_deadline() {
                            inner.as_mut().project().head_timer.set_and_init(
                                cx,
                                sleep_until(deadline.into()),
                                line!(),
                            );
                        }
                    }

                    inner.as_mut().poll_request(cx)?;

                    // peer closed its side: no more reading, and any unfinished request body
                    // is failed
                    if should_disconnect {
                        let inner = inner.as_mut().project();
                        inner.flags.insert(Flags::READ_DISCONNECT);
                        if let Some(mut payload) = inner.payload.take() {
                            payload.set_error(PayloadError::Incomplete(None));
                            payload.feed_eof();
                        }
                    };

                    // Alternate between producing response bytes and flushing them. Continue
                    // only while the buffer filled up (DrainWriteBuf) and the flush completed.
                    loop {
                        let drain = match inner.as_mut().poll_response(cx)? {
                            PollResponse::DrainWriteBuf => true,

                            PollResponse::DoNothing => {
                                // idle after a completed response: start the keep-alive timer
                                if inner.flags.contains(Flags::KEEP_ALIVE | Flags::FINISHED) {
                                    if let Some(timer) = inner.config.keep_alive_deadline() {
                                        inner.as_mut().project().ka_timer.set_and_init(
                                            cx,
                                            sleep_until(timer.into()),
                                            line!(),
                                        );
                                    }
                                }

                                false
                            }

                            PollResponse::Upgrade(req) => {
                                // switch to upgrade mode and poll the upgrade future immediately
                                let upgrade = inner.upgrade(req);
                                self.as_mut()
                                    .project()
                                    .inner
                                    .set(DispatcherState::Upgrade { fut: upgrade });
                                return self.poll(cx);
                            }
                        };

                        let flush_was_ready = inner.as_mut().poll_flush(cx)?.is_ready();

                        if !flush_was_ready || !drain {
                            break;
                        }
                    }

                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
                        trace!("client is gone; disconnecting");
                        return Poll::Ready(Ok(()));
                    }

                    let inner_p = inner.as_mut().project();
                    let state_is_none = inner_p.state.is_none();

                    // Read half closed: shut down unless half-closed connections are allowed and
                    // a response is still in progress.
                    if inner_p.flags.contains(Flags::READ_DISCONNECT)
                        && (!inner_p.config.h1_allow_half_closed() || state_is_none)
                    {
                        trace!("read half closed; start shutdown");
                        inner_p.flags.insert(Flags::SHUTDOWN);
                    }

                    // nothing in flight and nothing left to write
                    if state_is_none && inner_p.write_buf.is_empty() {
                        // surface any deferred error now that its response has been written
                        if let Some(err) = inner_p.error.take() {
                            error!("stream error: {}", &err);
                            return Poll::Ready(Err(err));
                        }

                        // finished response on a non-keep-alive connection: shut down
                        if inner_p.flags.contains(Flags::FINISHED)
                            && !inner_p.flags.contains(Flags::KEEP_ALIVE)
                            && inner_p.payload.is_none()
                        {
                            inner_p.flags.remove(Flags::FINISHED);
                            inner_p.flags.insert(Flags::SHUTDOWN);
                            return self.poll(cx);
                        }

                        if inner_p.flags.contains(Flags::SHUTDOWN) {
                            return self.poll(cx);
                        }
                    }

                    trace_timer_states(
                        "end",
                        inner_p.head_timer,
                        inner_p.ka_timer,
                        inner_p.shutdown_timer,
                    );

                    // linger/shutdown progress is made in the branches above; request a re-poll
                    if inner_p.flags.intersects(Flags::LINGER | Flags::SHUTDOWN) {
                        cx.waker().wake_by_ref();
                    }
                    Poll::Pending
                };

                trace!("end flags: {:?}", &inner.flags);

                poll
            }
        }
    }
}
1410
1411fn should_close_after_response(payload: Option<&PayloadSender>, payload_drainable: bool) -> bool {
1412 let payload_unfinished = payload.is_some();
1413 let drain_payload = payload.is_some_and(|pl| pl.is_dropped()) && payload_drainable;
1414
1415 payload_unfinished && !drain_payload
1416}
1417
1418#[allow(dead_code)]
1419fn trace_timer_states(
1420 label: &str,
1421 head_timer: &TimerState,
1422 ka_timer: &TimerState,
1423 shutdown_timer: &TimerState,
1424) {
1425 trace!("{} timers:", label);
1426
1427 if head_timer.is_enabled() {
1428 trace!(" head {}", &head_timer);
1429 }
1430
1431 if ka_timer.is_enabled() {
1432 trace!(" keep-alive {}", &ka_timer);
1433 }
1434
1435 if shutdown_timer.is_enabled() {
1436 trace!(" shutdown {}", &shutdown_timer);
1437 }
1438}