1use std::collections::VecDeque;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, io, net};
6
7use actori_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
8use actori_rt::time::{delay_until, Delay, Instant};
9use actori_service::Service;
10use bitflags::bitflags;
11use bytes::{Buf, BytesMut};
12use log::{error, trace};
13
14use crate::body::{Body, BodySize, MessageBody, ResponseBody};
15use crate::cloneable::CloneableService;
16use crate::config::ServiceConfig;
17use crate::error::{DispatchError, Error};
18use crate::error::{ParseError, PayloadError};
19use crate::helpers::DataFactory;
20use crate::httpmessage::HttpMessage;
21use crate::request::Request;
22use crate::response::Response;
23
24use super::codec::Codec;
25use super::payload::{Payload, PayloadSender, PayloadStatus};
26use super::{Message, MessageType};
27
/// Low-water mark: when a buffer's spare capacity drops below this many bytes,
/// more space is reserved before the next read or response poll.
const LW_BUFFER_SIZE: usize = 4 * 1024;
/// High-water mark: initial capacity of the read/write buffers, and the write
/// buffer length at which response payload streaming pauses for a drain.
const HW_BUFFER_SIZE: usize = 32 * 1024;
/// Once this many decoded messages are queued, `poll_request` stops decoding.
const MAX_PIPELINED_MESSAGES: usize = 16;
31
32bitflags! {
33 pub struct Flags: u8 {
34 const STARTED = 0b0000_0001;
35 const KEEPALIVE = 0b0000_0010;
36 const POLLED = 0b0000_0100;
37 const SHUTDOWN = 0b0000_1000;
38 const READ_DISCONNECT = 0b0001_0000;
39 const WRITE_DISCONNECT = 0b0010_0000;
40 const UPGRADE = 0b0100_0000;
41 }
42}
43
44pub struct Dispatcher<T, S, B, X, U>
46where
47 S: Service<Request = Request>,
48 S::Error: Into<Error>,
49 B: MessageBody,
50 X: Service<Request = Request, Response = Request>,
51 X::Error: Into<Error>,
52 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
53 U::Error: fmt::Display,
54{
55 inner: DispatcherState<T, S, B, X, U>,
56}
57
58enum DispatcherState<T, S, B, X, U>
59where
60 S: Service<Request = Request>,
61 S::Error: Into<Error>,
62 B: MessageBody,
63 X: Service<Request = Request, Response = Request>,
64 X::Error: Into<Error>,
65 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
66 U::Error: fmt::Display,
67{
68 Normal(InnerDispatcher<T, S, B, X, U>),
69 Upgrade(U::Future),
70 None,
71}
72
73struct InnerDispatcher<T, S, B, X, U>
74where
75 S: Service<Request = Request>,
76 S::Error: Into<Error>,
77 B: MessageBody,
78 X: Service<Request = Request, Response = Request>,
79 X::Error: Into<Error>,
80 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
81 U::Error: fmt::Display,
82{
83 service: CloneableService<S>,
84 expect: CloneableService<X>,
85 upgrade: Option<CloneableService<U>>,
86 on_connect: Option<Box<dyn DataFactory>>,
87 flags: Flags,
88 peer_addr: Option<net::SocketAddr>,
89 error: Option<DispatchError>,
90
91 state: State<S, B, X>,
92 payload: Option<PayloadSender>,
93 messages: VecDeque<DispatcherMessage>,
94
95 ka_expire: Instant,
96 ka_timer: Option<Delay>,
97
98 io: T,
99 read_buf: BytesMut,
100 write_buf: BytesMut,
101 codec: Codec,
102}
103
104enum DispatcherMessage {
105 Item(Request),
106 Upgrade(Request),
107 Error(Response<()>),
108}
109
110enum State<S, B, X>
111where
112 S: Service<Request = Request>,
113 X: Service<Request = Request, Response = Request>,
114 B: MessageBody,
115{
116 None,
117 ExpectCall(X::Future),
118 ServiceCall(S::Future),
119 SendPayload(ResponseBody<B>),
120}
121
122impl<S, B, X> State<S, B, X>
123where
124 S: Service<Request = Request>,
125 X: Service<Request = Request, Response = Request>,
126 B: MessageBody,
127{
128 fn is_empty(&self) -> bool {
129 if let State::None = self {
130 true
131 } else {
132 false
133 }
134 }
135
136 fn is_call(&self) -> bool {
137 if let State::ServiceCall(_) = self {
138 true
139 } else {
140 false
141 }
142 }
143}
144
145enum PollResponse {
146 Upgrade(Request),
147 DoNothing,
148 DrainWriteBuf,
149}
150
151impl PartialEq for PollResponse {
152 fn eq(&self, other: &PollResponse) -> bool {
153 match self {
154 PollResponse::DrainWriteBuf => match other {
155 PollResponse::DrainWriteBuf => true,
156 _ => false,
157 },
158 PollResponse::DoNothing => match other {
159 PollResponse::DoNothing => true,
160 _ => false,
161 },
162 _ => false,
163 }
164 }
165}
166
167impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
168where
169 T: AsyncRead + AsyncWrite + Unpin,
170 S: Service<Request = Request>,
171 S::Error: Into<Error>,
172 S::Response: Into<Response<B>>,
173 B: MessageBody,
174 X: Service<Request = Request, Response = Request>,
175 X::Error: Into<Error>,
176 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
177 U::Error: fmt::Display,
178{
179 pub(crate) fn new(
181 stream: T,
182 config: ServiceConfig,
183 service: CloneableService<S>,
184 expect: CloneableService<X>,
185 upgrade: Option<CloneableService<U>>,
186 on_connect: Option<Box<dyn DataFactory>>,
187 peer_addr: Option<net::SocketAddr>,
188 ) -> Self {
189 Dispatcher::with_timeout(
190 stream,
191 Codec::new(config.clone()),
192 config,
193 BytesMut::with_capacity(HW_BUFFER_SIZE),
194 None,
195 service,
196 expect,
197 upgrade,
198 on_connect,
199 peer_addr,
200 )
201 }
202
203 pub(crate) fn with_timeout(
205 io: T,
206 codec: Codec,
207 config: ServiceConfig,
208 read_buf: BytesMut,
209 timeout: Option<Delay>,
210 service: CloneableService<S>,
211 expect: CloneableService<X>,
212 upgrade: Option<CloneableService<U>>,
213 on_connect: Option<Box<dyn DataFactory>>,
214 peer_addr: Option<net::SocketAddr>,
215 ) -> Self {
216 let keepalive = config.keep_alive_enabled();
217 let flags = if keepalive {
218 Flags::KEEPALIVE
219 } else {
220 Flags::empty()
221 };
222
223 let (ka_expire, ka_timer) = if let Some(delay) = timeout {
225 (delay.deadline(), Some(delay))
226 } else if let Some(delay) = config.keep_alive_timer() {
227 (delay.deadline(), Some(delay))
228 } else {
229 (config.now(), None)
230 };
231
232 Dispatcher {
233 inner: DispatcherState::Normal(InnerDispatcher {
234 write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
235 payload: None,
236 state: State::None,
237 error: None,
238 messages: VecDeque::new(),
239 io,
240 codec,
241 read_buf,
242 service,
243 expect,
244 upgrade,
245 on_connect,
246 flags,
247 peer_addr,
248 ka_expire,
249 ka_timer,
250 }),
251 }
252 }
253}
254
255impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
256where
257 T: AsyncRead + AsyncWrite + Unpin,
258 S: Service<Request = Request>,
259 S::Error: Into<Error>,
260 S::Response: Into<Response<B>>,
261 B: MessageBody,
262 X: Service<Request = Request, Response = Request>,
263 X::Error: Into<Error>,
264 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
265 U::Error: fmt::Display,
266{
267 fn can_read(&self, cx: &mut Context<'_>) -> bool {
268 if self
269 .flags
270 .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
271 {
272 false
273 } else if let Some(ref info) = self.payload {
274 info.need_read(cx) == PayloadStatus::Read
275 } else {
276 true
277 }
278 }
279
280 fn client_disconnected(&mut self) {
282 self.flags
283 .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
284 if let Some(mut payload) = self.payload.take() {
285 payload.set_error(PayloadError::Incomplete(None));
286 }
287 }
288
289 fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
294 if self.write_buf.is_empty() {
295 return Ok(false);
296 }
297
298 let len = self.write_buf.len();
299 let mut written = 0;
300 while written < len {
301 match unsafe { Pin::new_unchecked(&mut self.io) }
302 .poll_write(cx, &self.write_buf[written..])
303 {
304 Poll::Ready(Ok(0)) => {
305 return Err(DispatchError::Io(io::Error::new(
306 io::ErrorKind::WriteZero,
307 "",
308 )));
309 }
310 Poll::Ready(Ok(n)) => {
311 written += n;
312 }
313 Poll::Pending => {
314 if written > 0 {
315 self.write_buf.advance(written);
316 }
317 return Ok(true);
318 }
319 Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)),
320 }
321 }
322 if written == self.write_buf.len() {
323 unsafe { self.write_buf.set_len(0) }
324 } else {
325 self.write_buf.advance(written);
326 }
327 Ok(false)
328 }
329
330 fn send_response(
331 &mut self,
332 message: Response<()>,
333 body: ResponseBody<B>,
334 ) -> Result<State<S, B, X>, DispatchError> {
335 self.codec
336 .encode(Message::Item((message, body.size())), &mut self.write_buf)
337 .map_err(|err| {
338 if let Some(mut payload) = self.payload.take() {
339 payload.set_error(PayloadError::Incomplete(None));
340 }
341 DispatchError::Io(err)
342 })?;
343
344 self.flags.set(Flags::KEEPALIVE, self.codec.keepalive());
345 match body.size() {
346 BodySize::None | BodySize::Empty => Ok(State::None),
347 _ => Ok(State::SendPayload(body)),
348 }
349 }
350
351 fn send_continue(&mut self) {
352 self.write_buf
353 .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
354 }
355
356 fn poll_response(
357 &mut self,
358 cx: &mut Context<'_>,
359 ) -> Result<PollResponse, DispatchError> {
360 loop {
361 let state = match self.state {
362 State::None => match self.messages.pop_front() {
363 Some(DispatcherMessage::Item(req)) => {
364 Some(self.handle_request(req, cx)?)
365 }
366 Some(DispatcherMessage::Error(res)) => {
367 Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
368 }
369 Some(DispatcherMessage::Upgrade(req)) => {
370 return Ok(PollResponse::Upgrade(req));
371 }
372 None => None,
373 },
374 State::ExpectCall(ref mut fut) => {
375 match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
376 Poll::Ready(Ok(req)) => {
377 self.send_continue();
378 self.state = State::ServiceCall(self.service.call(req));
379 continue;
380 }
381 Poll::Ready(Err(e)) => {
382 let res: Response = e.into().into();
383 let (res, body) = res.replace_body(());
384 Some(self.send_response(res, body.into_body())?)
385 }
386 Poll::Pending => None,
387 }
388 }
389 State::ServiceCall(ref mut fut) => {
390 match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
391 Poll::Ready(Ok(res)) => {
392 let (res, body) = res.into().replace_body(());
393 self.state = self.send_response(res, body)?;
394 continue;
395 }
396 Poll::Ready(Err(e)) => {
397 let res: Response = e.into().into();
398 let (res, body) = res.replace_body(());
399 Some(self.send_response(res, body.into_body())?)
400 }
401 Poll::Pending => None,
402 }
403 }
404 State::SendPayload(ref mut stream) => {
405 loop {
406 if self.write_buf.len() < HW_BUFFER_SIZE {
407 match stream.poll_next(cx) {
408 Poll::Ready(Some(Ok(item))) => {
409 self.codec.encode(
410 Message::Chunk(Some(item)),
411 &mut self.write_buf,
412 )?;
413 continue;
414 }
415 Poll::Ready(None) => {
416 self.codec.encode(
417 Message::Chunk(None),
418 &mut self.write_buf,
419 )?;
420 self.state = State::None;
421 }
422 Poll::Ready(Some(Err(_))) => {
423 return Err(DispatchError::Unknown)
424 }
425 Poll::Pending => return Ok(PollResponse::DoNothing),
426 }
427 } else {
428 return Ok(PollResponse::DrainWriteBuf);
429 }
430 break;
431 }
432 continue;
433 }
434 };
435
436 if let Some(state) = state {
438 self.state = state;
439 if !self.state.is_empty() {
440 continue;
441 }
442 } else {
443 if self.state.is_call() {
446 if self.poll_request(cx)? {
447 continue;
448 }
449 } else if !self.messages.is_empty() {
450 continue;
451 }
452 }
453 break;
454 }
455
456 Ok(PollResponse::DoNothing)
457 }
458
459 fn handle_request(
460 &mut self,
461 req: Request,
462 cx: &mut Context<'_>,
463 ) -> Result<State<S, B, X>, DispatchError> {
464 let req = if req.head().expect() {
466 let mut task = self.expect.call(req);
467 match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
468 Poll::Ready(Ok(req)) => {
469 self.send_continue();
470 req
471 }
472 Poll::Pending => return Ok(State::ExpectCall(task)),
473 Poll::Ready(Err(e)) => {
474 let e = e.into();
475 let res: Response = e.into();
476 let (res, body) = res.replace_body(());
477 return self.send_response(res, body.into_body());
478 }
479 }
480 } else {
481 req
482 };
483
484 let mut task = self.service.call(req);
486 match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
487 Poll::Ready(Ok(res)) => {
488 let (res, body) = res.into().replace_body(());
489 self.send_response(res, body)
490 }
491 Poll::Pending => Ok(State::ServiceCall(task)),
492 Poll::Ready(Err(e)) => {
493 let res: Response = e.into().into();
494 let (res, body) = res.replace_body(());
495 self.send_response(res, body.into_body())
496 }
497 }
498 }
499
500 pub(self) fn poll_request(
502 &mut self,
503 cx: &mut Context<'_>,
504 ) -> Result<bool, DispatchError> {
505 if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) {
507 return Ok(false);
508 }
509
510 let mut updated = false;
511 loop {
512 match self.codec.decode(&mut self.read_buf) {
513 Ok(Some(msg)) => {
514 updated = true;
515 self.flags.insert(Flags::STARTED);
516
517 match msg {
518 Message::Item(mut req) => {
519 let pl = self.codec.message_type();
520 req.head_mut().peer_addr = self.peer_addr;
521
522 if let Some(ref on_connect) = self.on_connect {
524 on_connect.set(&mut req.extensions_mut());
525 }
526
527 if pl == MessageType::Stream && self.upgrade.is_some() {
528 self.messages.push_back(DispatcherMessage::Upgrade(req));
529 break;
530 }
531 if pl == MessageType::Payload || pl == MessageType::Stream {
532 let (ps, pl) = Payload::create(false);
533 let (req1, _) =
534 req.replace_payload(crate::Payload::H1(pl));
535 req = req1;
536 self.payload = Some(ps);
537 }
538
539 if self.state.is_empty() {
541 self.state = self.handle_request(req, cx)?;
542 } else {
543 self.messages.push_back(DispatcherMessage::Item(req));
544 }
545 }
546 Message::Chunk(Some(chunk)) => {
547 if let Some(ref mut payload) = self.payload {
548 payload.feed_data(chunk);
549 } else {
550 error!(
551 "Internal server error: unexpected payload chunk"
552 );
553 self.flags.insert(Flags::READ_DISCONNECT);
554 self.messages.push_back(DispatcherMessage::Error(
555 Response::InternalServerError().finish().drop_body(),
556 ));
557 self.error = Some(DispatchError::InternalError);
558 break;
559 }
560 }
561 Message::Chunk(None) => {
562 if let Some(mut payload) = self.payload.take() {
563 payload.feed_eof();
564 } else {
565 error!("Internal server error: unexpected eof");
566 self.flags.insert(Flags::READ_DISCONNECT);
567 self.messages.push_back(DispatcherMessage::Error(
568 Response::InternalServerError().finish().drop_body(),
569 ));
570 self.error = Some(DispatchError::InternalError);
571 break;
572 }
573 }
574 }
575 }
576 Ok(None) => break,
577 Err(ParseError::Io(e)) => {
578 self.client_disconnected();
579 self.error = Some(DispatchError::Io(e));
580 break;
581 }
582 Err(e) => {
583 if let Some(mut payload) = self.payload.take() {
584 payload.set_error(PayloadError::EncodingCorrupted);
585 }
586
587 self.messages.push_back(DispatcherMessage::Error(
589 Response::BadRequest().finish().drop_body(),
590 ));
591 self.flags.insert(Flags::READ_DISCONNECT);
592 self.error = Some(e.into());
593 break;
594 }
595 }
596 }
597
598 if updated && self.ka_timer.is_some() {
599 if let Some(expire) = self.codec.config().keep_alive_expire() {
600 self.ka_expire = expire;
601 }
602 }
603 Ok(updated)
604 }
605
606 fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> {
608 if self.ka_timer.is_none() {
609 if self.flags.contains(Flags::SHUTDOWN) {
611 if let Some(interval) = self.codec.config().client_disconnect_timer() {
612 self.ka_timer = Some(delay_until(interval));
613 } else {
614 self.flags.insert(Flags::READ_DISCONNECT);
615 if let Some(mut payload) = self.payload.take() {
616 payload.set_error(PayloadError::Incomplete(None));
617 }
618 return Ok(());
619 }
620 } else {
621 return Ok(());
622 }
623 }
624
625 match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) {
626 Poll::Ready(()) => {
627 if self.flags.contains(Flags::SHUTDOWN) {
629 return Err(DispatchError::DisconnectTimeout);
630 } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire {
631 if self.state.is_empty() && self.write_buf.is_empty() {
633 if self.flags.contains(Flags::STARTED) {
634 trace!("Keep-alive timeout, close connection");
635 self.flags.insert(Flags::SHUTDOWN);
636
637 if let Some(deadline) =
639 self.codec.config().client_disconnect_timer()
640 {
641 if let Some(mut timer) = self.ka_timer.as_mut() {
642 timer.reset(deadline);
643 let _ = Pin::new(&mut timer).poll(cx);
644 }
645 } else {
646 self.flags.insert(Flags::WRITE_DISCONNECT);
648 return Ok(());
649 }
650 } else {
651 if !self.flags.contains(Flags::STARTED) {
653 trace!("Slow request timeout");
654 let _ = self.send_response(
655 Response::RequestTimeout().finish().drop_body(),
656 ResponseBody::Other(Body::Empty),
657 );
658 } else {
659 trace!("Keep-alive connection timeout");
660 }
661 self.flags.insert(Flags::STARTED | Flags::SHUTDOWN);
662 self.state = State::None;
663 }
664 } else if let Some(deadline) =
665 self.codec.config().keep_alive_expire()
666 {
667 if let Some(mut timer) = self.ka_timer.as_mut() {
668 timer.reset(deadline);
669 let _ = Pin::new(&mut timer).poll(cx);
670 }
671 }
672 } else if let Some(mut timer) = self.ka_timer.as_mut() {
673 timer.reset(self.ka_expire);
674 let _ = Pin::new(&mut timer).poll(cx);
675 }
676 }
677 Poll::Pending => (),
678 }
679
680 Ok(())
681 }
682}
683
684impl<T, S, B, X, U> Unpin for Dispatcher<T, S, B, X, U>
685where
686 T: AsyncRead + AsyncWrite + Unpin,
687 S: Service<Request = Request>,
688 S::Error: Into<Error>,
689 S::Response: Into<Response<B>>,
690 B: MessageBody,
691 X: Service<Request = Request, Response = Request>,
692 X::Error: Into<Error>,
693 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
694 U::Error: fmt::Display,
695{
696}
697
698impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
699where
700 T: AsyncRead + AsyncWrite + Unpin,
701 S: Service<Request = Request>,
702 S::Error: Into<Error>,
703 S::Response: Into<Response<B>>,
704 B: MessageBody,
705 X: Service<Request = Request, Response = Request>,
706 X::Error: Into<Error>,
707 U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
708 U::Error: fmt::Display,
709{
710 type Output = Result<(), DispatchError>;
711
712 #[inline]
713 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
714 match self.as_mut().inner {
715 DispatcherState::Normal(ref mut inner) => {
716 inner.poll_keepalive(cx)?;
717
718 if inner.flags.contains(Flags::SHUTDOWN) {
719 if inner.flags.contains(Flags::WRITE_DISCONNECT) {
720 Poll::Ready(Ok(()))
721 } else {
722 inner.poll_flush(cx)?;
724 if !inner.write_buf.is_empty() {
725 Poll::Pending
726 } else {
727 match Pin::new(&mut inner.io).poll_shutdown(cx) {
728 Poll::Ready(res) => {
729 Poll::Ready(res.map_err(DispatchError::from))
730 }
731 Poll::Pending => Poll::Pending,
732 }
733 }
734 }
735 } else {
736 let should_disconnect =
738 if !inner.flags.contains(Flags::READ_DISCONNECT) {
739 read_available(cx, &mut inner.io, &mut inner.read_buf)?
740 } else {
741 None
742 };
743
744 inner.poll_request(cx)?;
745 if let Some(true) = should_disconnect {
746 inner.flags.insert(Flags::READ_DISCONNECT);
747 if let Some(mut payload) = inner.payload.take() {
748 payload.feed_eof();
749 }
750 };
751
752 loop {
753 let remaining =
754 inner.write_buf.capacity() - inner.write_buf.len();
755 if remaining < LW_BUFFER_SIZE {
756 inner.write_buf.reserve(HW_BUFFER_SIZE - remaining);
757 }
758 let result = inner.poll_response(cx)?;
759 let drain = result == PollResponse::DrainWriteBuf;
760
761 if let PollResponse::Upgrade(req) = result {
763 if let DispatcherState::Normal(inner) =
764 std::mem::replace(&mut self.inner, DispatcherState::None)
765 {
766 let mut parts = FramedParts::with_read_buf(
767 inner.io,
768 inner.codec,
769 inner.read_buf,
770 );
771 parts.write_buf = inner.write_buf;
772 let framed = Framed::from_parts(parts);
773 self.inner = DispatcherState::Upgrade(
774 inner.upgrade.unwrap().call((req, framed)),
775 );
776 return self.poll(cx);
777 } else {
778 panic!()
779 }
780 }
781
782 if inner.poll_flush(cx)? || !drain {
786 break;
787 }
788 }
789
790 if inner.flags.contains(Flags::WRITE_DISCONNECT) {
792 return Poll::Ready(Ok(()));
793 }
794
795 let is_empty = inner.state.is_empty();
796
797 if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty {
799 inner.flags.insert(Flags::SHUTDOWN);
800 }
801
802 if is_empty && inner.write_buf.is_empty() {
804 if let Some(err) = inner.error.take() {
805 Poll::Ready(Err(err))
806 }
807 else if inner.flags.contains(Flags::STARTED)
809 && !inner.flags.intersects(Flags::KEEPALIVE)
810 {
811 inner.flags.insert(Flags::SHUTDOWN);
812 self.poll(cx)
813 }
814 else if inner.flags.contains(Flags::SHUTDOWN) {
816 self.poll(cx)
817 } else {
818 Poll::Pending
819 }
820 } else {
821 Poll::Pending
822 }
823 }
824 }
825 DispatcherState::Upgrade(ref mut fut) => {
826 unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
827 error!("Upgrade handler error: {}", e);
828 DispatchError::Upgrade
829 })
830 }
831 DispatcherState::None => panic!(),
832 }
833 }
834}
835
836fn read_available<T>(
837 cx: &mut Context<'_>,
838 io: &mut T,
839 buf: &mut BytesMut,
840) -> Result<Option<bool>, io::Error>
841where
842 T: AsyncRead + Unpin,
843{
844 let mut read_some = false;
845 loop {
846 let remaining = buf.capacity() - buf.len();
847 if remaining < LW_BUFFER_SIZE {
848 buf.reserve(HW_BUFFER_SIZE - remaining);
849 }
850
851 match read(cx, io, buf) {
852 Poll::Pending => {
853 return if read_some { Ok(Some(false)) } else { Ok(None) };
854 }
855 Poll::Ready(Ok(n)) => {
856 if n == 0 {
857 return Ok(Some(true));
858 } else {
859 read_some = true;
860 }
861 }
862 Poll::Ready(Err(e)) => {
863 return if e.kind() == io::ErrorKind::WouldBlock {
864 if read_some {
865 Ok(Some(false))
866 } else {
867 Ok(None)
868 }
869 } else if e.kind() == io::ErrorKind::ConnectionReset && read_some {
870 Ok(Some(true))
871 } else {
872 Err(e)
873 }
874 }
875 }
876 }
877}
878
879fn read<T>(
880 cx: &mut Context<'_>,
881 io: &mut T,
882 buf: &mut BytesMut,
883) -> Poll<Result<usize, io::Error>>
884where
885 T: AsyncRead + Unpin,
886{
887 Pin::new(io).poll_read_buf(cx, buf)
888}
889
#[cfg(test)]
mod tests {
    use actori_service::IntoService;
    use futures_util::future::{lazy, ok};

    use super::*;
    use crate::error::Error;
    use crate::h1::{ExpectHandler, UpgradeHandler};
    use crate::test::TestBuffer;

    /// A request line with an invalid HTTP version must make the dispatcher
    /// finish with an error, stop reading, and write a `400 Bad Request`.
    #[actori_rt::test]
    async fn test_req_parse_err() {
        lazy(|cx| {
            let io = TestBuffer::new("GET /test HTTP/1\r\n\r\n");
            let service = CloneableService::new(
                (|_| ok::<_, Error>(Response::Ok().finish())).into_service(),
            );
            let expect = CloneableService::new(ExpectHandler);

            let mut dispatcher = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new(
                io,
                ServiceConfig::default(),
                service,
                expect,
                None,
                None,
                None,
            );

            match Pin::new(&mut dispatcher).poll(cx) {
                Poll::Ready(res) => assert!(res.is_err()),
                Poll::Pending => panic!(),
            }

            if let DispatcherState::Normal(ref inner) = dispatcher.inner {
                assert!(inner.flags.contains(Flags::READ_DISCONNECT));
                assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
            }
        })
        .await;
    }
}