actori_http/h1/dispatcher.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::{fmt, io, net};
6
7use actori_codec::{AsyncRead, AsyncWrite, Decoder, Encoder, Framed, FramedParts};
8use actori_rt::time::{delay_until, Delay, Instant};
9use actori_service::Service;
10use bitflags::bitflags;
11use bytes::{Buf, BytesMut};
12use log::{error, trace};
13
14use crate::body::{Body, BodySize, MessageBody, ResponseBody};
15use crate::cloneable::CloneableService;
16use crate::config::ServiceConfig;
17use crate::error::{DispatchError, Error};
18use crate::error::{ParseError, PayloadError};
19use crate::helpers::DataFactory;
20use crate::httpmessage::HttpMessage;
21use crate::request::Request;
22use crate::response::Response;
23
24use super::codec::Codec;
25use super::payload::{Payload, PayloadSender, PayloadStatus};
26use super::{Message, MessageType};
27
/// When a buffer's spare capacity drops below this many bytes it is topped up again
/// (by reserving `HW_BUFFER_SIZE - spare` more bytes).
const LW_BUFFER_SIZE: usize = 4 * 1024;
/// Initial capacity of the read/write buffers, and the write-buffer fill level at which
/// body streaming pauses so the buffer can be flushed.
const HW_BUFFER_SIZE: usize = 32 * 1024;
/// Maximum number of decoded-but-unprocessed messages queued before decoding pauses.
const MAX_PIPELINED_MESSAGES: usize = 16;
31
32bitflags! {
33    pub struct Flags: u8 {
34        const STARTED            = 0b0000_0001;
35        const KEEPALIVE          = 0b0000_0010;
36        const POLLED             = 0b0000_0100;
37        const SHUTDOWN           = 0b0000_1000;
38        const READ_DISCONNECT    = 0b0001_0000;
39        const WRITE_DISCONNECT   = 0b0010_0000;
40        const UPGRADE            = 0b0100_0000;
41    }
42}
43
44/// Dispatcher for HTTP/1.1 protocol
45pub struct Dispatcher<T, S, B, X, U>
46where
47    S: Service<Request = Request>,
48    S::Error: Into<Error>,
49    B: MessageBody,
50    X: Service<Request = Request, Response = Request>,
51    X::Error: Into<Error>,
52    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
53    U::Error: fmt::Display,
54{
55    inner: DispatcherState<T, S, B, X, U>,
56}
57
58enum DispatcherState<T, S, B, X, U>
59where
60    S: Service<Request = Request>,
61    S::Error: Into<Error>,
62    B: MessageBody,
63    X: Service<Request = Request, Response = Request>,
64    X::Error: Into<Error>,
65    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
66    U::Error: fmt::Display,
67{
68    Normal(InnerDispatcher<T, S, B, X, U>),
69    Upgrade(U::Future),
70    None,
71}
72
73struct InnerDispatcher<T, S, B, X, U>
74where
75    S: Service<Request = Request>,
76    S::Error: Into<Error>,
77    B: MessageBody,
78    X: Service<Request = Request, Response = Request>,
79    X::Error: Into<Error>,
80    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
81    U::Error: fmt::Display,
82{
83    service: CloneableService<S>,
84    expect: CloneableService<X>,
85    upgrade: Option<CloneableService<U>>,
86    on_connect: Option<Box<dyn DataFactory>>,
87    flags: Flags,
88    peer_addr: Option<net::SocketAddr>,
89    error: Option<DispatchError>,
90
91    state: State<S, B, X>,
92    payload: Option<PayloadSender>,
93    messages: VecDeque<DispatcherMessage>,
94
95    ka_expire: Instant,
96    ka_timer: Option<Delay>,
97
98    io: T,
99    read_buf: BytesMut,
100    write_buf: BytesMut,
101    codec: Codec,
102}
103
104enum DispatcherMessage {
105    Item(Request),
106    Upgrade(Request),
107    Error(Response<()>),
108}
109
110enum State<S, B, X>
111where
112    S: Service<Request = Request>,
113    X: Service<Request = Request, Response = Request>,
114    B: MessageBody,
115{
116    None,
117    ExpectCall(X::Future),
118    ServiceCall(S::Future),
119    SendPayload(ResponseBody<B>),
120}
121
122impl<S, B, X> State<S, B, X>
123where
124    S: Service<Request = Request>,
125    X: Service<Request = Request, Response = Request>,
126    B: MessageBody,
127{
128    fn is_empty(&self) -> bool {
129        if let State::None = self {
130            true
131        } else {
132            false
133        }
134    }
135
136    fn is_call(&self) -> bool {
137        if let State::ServiceCall(_) = self {
138            true
139        } else {
140            false
141        }
142    }
143}
144
145enum PollResponse {
146    Upgrade(Request),
147    DoNothing,
148    DrainWriteBuf,
149}
150
151impl PartialEq for PollResponse {
152    fn eq(&self, other: &PollResponse) -> bool {
153        match self {
154            PollResponse::DrainWriteBuf => match other {
155                PollResponse::DrainWriteBuf => true,
156                _ => false,
157            },
158            PollResponse::DoNothing => match other {
159                PollResponse::DoNothing => true,
160                _ => false,
161            },
162            _ => false,
163        }
164    }
165}
166
167impl<T, S, B, X, U> Dispatcher<T, S, B, X, U>
168where
169    T: AsyncRead + AsyncWrite + Unpin,
170    S: Service<Request = Request>,
171    S::Error: Into<Error>,
172    S::Response: Into<Response<B>>,
173    B: MessageBody,
174    X: Service<Request = Request, Response = Request>,
175    X::Error: Into<Error>,
176    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
177    U::Error: fmt::Display,
178{
179    /// Create http/1 dispatcher.
180    pub(crate) fn new(
181        stream: T,
182        config: ServiceConfig,
183        service: CloneableService<S>,
184        expect: CloneableService<X>,
185        upgrade: Option<CloneableService<U>>,
186        on_connect: Option<Box<dyn DataFactory>>,
187        peer_addr: Option<net::SocketAddr>,
188    ) -> Self {
189        Dispatcher::with_timeout(
190            stream,
191            Codec::new(config.clone()),
192            config,
193            BytesMut::with_capacity(HW_BUFFER_SIZE),
194            None,
195            service,
196            expect,
197            upgrade,
198            on_connect,
199            peer_addr,
200        )
201    }
202
203    /// Create http/1 dispatcher with slow request timeout.
204    pub(crate) fn with_timeout(
205        io: T,
206        codec: Codec,
207        config: ServiceConfig,
208        read_buf: BytesMut,
209        timeout: Option<Delay>,
210        service: CloneableService<S>,
211        expect: CloneableService<X>,
212        upgrade: Option<CloneableService<U>>,
213        on_connect: Option<Box<dyn DataFactory>>,
214        peer_addr: Option<net::SocketAddr>,
215    ) -> Self {
216        let keepalive = config.keep_alive_enabled();
217        let flags = if keepalive {
218            Flags::KEEPALIVE
219        } else {
220            Flags::empty()
221        };
222
223        // keep-alive timer
224        let (ka_expire, ka_timer) = if let Some(delay) = timeout {
225            (delay.deadline(), Some(delay))
226        } else if let Some(delay) = config.keep_alive_timer() {
227            (delay.deadline(), Some(delay))
228        } else {
229            (config.now(), None)
230        };
231
232        Dispatcher {
233            inner: DispatcherState::Normal(InnerDispatcher {
234                write_buf: BytesMut::with_capacity(HW_BUFFER_SIZE),
235                payload: None,
236                state: State::None,
237                error: None,
238                messages: VecDeque::new(),
239                io,
240                codec,
241                read_buf,
242                service,
243                expect,
244                upgrade,
245                on_connect,
246                flags,
247                peer_addr,
248                ka_expire,
249                ka_timer,
250            }),
251        }
252    }
253}
254
255impl<T, S, B, X, U> InnerDispatcher<T, S, B, X, U>
256where
257    T: AsyncRead + AsyncWrite + Unpin,
258    S: Service<Request = Request>,
259    S::Error: Into<Error>,
260    S::Response: Into<Response<B>>,
261    B: MessageBody,
262    X: Service<Request = Request, Response = Request>,
263    X::Error: Into<Error>,
264    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
265    U::Error: fmt::Display,
266{
267    fn can_read(&self, cx: &mut Context<'_>) -> bool {
268        if self
269            .flags
270            .intersects(Flags::READ_DISCONNECT | Flags::UPGRADE)
271        {
272            false
273        } else if let Some(ref info) = self.payload {
274            info.need_read(cx) == PayloadStatus::Read
275        } else {
276            true
277        }
278    }
279
280    // if checked is set to true, delay disconnect until all tasks have finished.
281    fn client_disconnected(&mut self) {
282        self.flags
283            .insert(Flags::READ_DISCONNECT | Flags::WRITE_DISCONNECT);
284        if let Some(mut payload) = self.payload.take() {
285            payload.set_error(PayloadError::Incomplete(None));
286        }
287    }
288
289    /// Flush stream
290    ///
291    /// true - got whouldblock
292    /// false - didnt get whouldblock
293    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<bool, DispatchError> {
294        if self.write_buf.is_empty() {
295            return Ok(false);
296        }
297
298        let len = self.write_buf.len();
299        let mut written = 0;
300        while written < len {
301            match unsafe { Pin::new_unchecked(&mut self.io) }
302                .poll_write(cx, &self.write_buf[written..])
303            {
304                Poll::Ready(Ok(0)) => {
305                    return Err(DispatchError::Io(io::Error::new(
306                        io::ErrorKind::WriteZero,
307                        "",
308                    )));
309                }
310                Poll::Ready(Ok(n)) => {
311                    written += n;
312                }
313                Poll::Pending => {
314                    if written > 0 {
315                        self.write_buf.advance(written);
316                    }
317                    return Ok(true);
318                }
319                Poll::Ready(Err(err)) => return Err(DispatchError::Io(err)),
320            }
321        }
322        if written == self.write_buf.len() {
323            unsafe { self.write_buf.set_len(0) }
324        } else {
325            self.write_buf.advance(written);
326        }
327        Ok(false)
328    }
329
330    fn send_response(
331        &mut self,
332        message: Response<()>,
333        body: ResponseBody<B>,
334    ) -> Result<State<S, B, X>, DispatchError> {
335        self.codec
336            .encode(Message::Item((message, body.size())), &mut self.write_buf)
337            .map_err(|err| {
338                if let Some(mut payload) = self.payload.take() {
339                    payload.set_error(PayloadError::Incomplete(None));
340                }
341                DispatchError::Io(err)
342            })?;
343
344        self.flags.set(Flags::KEEPALIVE, self.codec.keepalive());
345        match body.size() {
346            BodySize::None | BodySize::Empty => Ok(State::None),
347            _ => Ok(State::SendPayload(body)),
348        }
349    }
350
351    fn send_continue(&mut self) {
352        self.write_buf
353            .extend_from_slice(b"HTTP/1.1 100 Continue\r\n\r\n");
354    }
355
356    fn poll_response(
357        &mut self,
358        cx: &mut Context<'_>,
359    ) -> Result<PollResponse, DispatchError> {
360        loop {
361            let state = match self.state {
362                State::None => match self.messages.pop_front() {
363                    Some(DispatcherMessage::Item(req)) => {
364                        Some(self.handle_request(req, cx)?)
365                    }
366                    Some(DispatcherMessage::Error(res)) => {
367                        Some(self.send_response(res, ResponseBody::Other(Body::Empty))?)
368                    }
369                    Some(DispatcherMessage::Upgrade(req)) => {
370                        return Ok(PollResponse::Upgrade(req));
371                    }
372                    None => None,
373                },
374                State::ExpectCall(ref mut fut) => {
375                    match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
376                        Poll::Ready(Ok(req)) => {
377                            self.send_continue();
378                            self.state = State::ServiceCall(self.service.call(req));
379                            continue;
380                        }
381                        Poll::Ready(Err(e)) => {
382                            let res: Response = e.into().into();
383                            let (res, body) = res.replace_body(());
384                            Some(self.send_response(res, body.into_body())?)
385                        }
386                        Poll::Pending => None,
387                    }
388                }
389                State::ServiceCall(ref mut fut) => {
390                    match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
391                        Poll::Ready(Ok(res)) => {
392                            let (res, body) = res.into().replace_body(());
393                            self.state = self.send_response(res, body)?;
394                            continue;
395                        }
396                        Poll::Ready(Err(e)) => {
397                            let res: Response = e.into().into();
398                            let (res, body) = res.replace_body(());
399                            Some(self.send_response(res, body.into_body())?)
400                        }
401                        Poll::Pending => None,
402                    }
403                }
404                State::SendPayload(ref mut stream) => {
405                    loop {
406                        if self.write_buf.len() < HW_BUFFER_SIZE {
407                            match stream.poll_next(cx) {
408                                Poll::Ready(Some(Ok(item))) => {
409                                    self.codec.encode(
410                                        Message::Chunk(Some(item)),
411                                        &mut self.write_buf,
412                                    )?;
413                                    continue;
414                                }
415                                Poll::Ready(None) => {
416                                    self.codec.encode(
417                                        Message::Chunk(None),
418                                        &mut self.write_buf,
419                                    )?;
420                                    self.state = State::None;
421                                }
422                                Poll::Ready(Some(Err(_))) => {
423                                    return Err(DispatchError::Unknown)
424                                }
425                                Poll::Pending => return Ok(PollResponse::DoNothing),
426                            }
427                        } else {
428                            return Ok(PollResponse::DrainWriteBuf);
429                        }
430                        break;
431                    }
432                    continue;
433                }
434            };
435
436            // set new state
437            if let Some(state) = state {
438                self.state = state;
439                if !self.state.is_empty() {
440                    continue;
441                }
442            } else {
443                // if read-backpressure is enabled and we consumed some data.
444                // we may read more data and retry
445                if self.state.is_call() {
446                    if self.poll_request(cx)? {
447                        continue;
448                    }
449                } else if !self.messages.is_empty() {
450                    continue;
451                }
452            }
453            break;
454        }
455
456        Ok(PollResponse::DoNothing)
457    }
458
459    fn handle_request(
460        &mut self,
461        req: Request,
462        cx: &mut Context<'_>,
463    ) -> Result<State<S, B, X>, DispatchError> {
464        // Handle `EXPECT: 100-Continue` header
465        let req = if req.head().expect() {
466            let mut task = self.expect.call(req);
467            match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
468                Poll::Ready(Ok(req)) => {
469                    self.send_continue();
470                    req
471                }
472                Poll::Pending => return Ok(State::ExpectCall(task)),
473                Poll::Ready(Err(e)) => {
474                    let e = e.into();
475                    let res: Response = e.into();
476                    let (res, body) = res.replace_body(());
477                    return self.send_response(res, body.into_body());
478                }
479            }
480        } else {
481            req
482        };
483
484        // Call service
485        let mut task = self.service.call(req);
486        match unsafe { Pin::new_unchecked(&mut task) }.poll(cx) {
487            Poll::Ready(Ok(res)) => {
488                let (res, body) = res.into().replace_body(());
489                self.send_response(res, body)
490            }
491            Poll::Pending => Ok(State::ServiceCall(task)),
492            Poll::Ready(Err(e)) => {
493                let res: Response = e.into().into();
494                let (res, body) = res.replace_body(());
495                self.send_response(res, body.into_body())
496            }
497        }
498    }
499
500    /// Process one incoming requests
501    pub(self) fn poll_request(
502        &mut self,
503        cx: &mut Context<'_>,
504    ) -> Result<bool, DispatchError> {
505        // limit a mount of non processed requests
506        if self.messages.len() >= MAX_PIPELINED_MESSAGES || !self.can_read(cx) {
507            return Ok(false);
508        }
509
510        let mut updated = false;
511        loop {
512            match self.codec.decode(&mut self.read_buf) {
513                Ok(Some(msg)) => {
514                    updated = true;
515                    self.flags.insert(Flags::STARTED);
516
517                    match msg {
518                        Message::Item(mut req) => {
519                            let pl = self.codec.message_type();
520                            req.head_mut().peer_addr = self.peer_addr;
521
522                            // set on_connect data
523                            if let Some(ref on_connect) = self.on_connect {
524                                on_connect.set(&mut req.extensions_mut());
525                            }
526
527                            if pl == MessageType::Stream && self.upgrade.is_some() {
528                                self.messages.push_back(DispatcherMessage::Upgrade(req));
529                                break;
530                            }
531                            if pl == MessageType::Payload || pl == MessageType::Stream {
532                                let (ps, pl) = Payload::create(false);
533                                let (req1, _) =
534                                    req.replace_payload(crate::Payload::H1(pl));
535                                req = req1;
536                                self.payload = Some(ps);
537                            }
538
539                            // handle request early
540                            if self.state.is_empty() {
541                                self.state = self.handle_request(req, cx)?;
542                            } else {
543                                self.messages.push_back(DispatcherMessage::Item(req));
544                            }
545                        }
546                        Message::Chunk(Some(chunk)) => {
547                            if let Some(ref mut payload) = self.payload {
548                                payload.feed_data(chunk);
549                            } else {
550                                error!(
551                                    "Internal server error: unexpected payload chunk"
552                                );
553                                self.flags.insert(Flags::READ_DISCONNECT);
554                                self.messages.push_back(DispatcherMessage::Error(
555                                    Response::InternalServerError().finish().drop_body(),
556                                ));
557                                self.error = Some(DispatchError::InternalError);
558                                break;
559                            }
560                        }
561                        Message::Chunk(None) => {
562                            if let Some(mut payload) = self.payload.take() {
563                                payload.feed_eof();
564                            } else {
565                                error!("Internal server error: unexpected eof");
566                                self.flags.insert(Flags::READ_DISCONNECT);
567                                self.messages.push_back(DispatcherMessage::Error(
568                                    Response::InternalServerError().finish().drop_body(),
569                                ));
570                                self.error = Some(DispatchError::InternalError);
571                                break;
572                            }
573                        }
574                    }
575                }
576                Ok(None) => break,
577                Err(ParseError::Io(e)) => {
578                    self.client_disconnected();
579                    self.error = Some(DispatchError::Io(e));
580                    break;
581                }
582                Err(e) => {
583                    if let Some(mut payload) = self.payload.take() {
584                        payload.set_error(PayloadError::EncodingCorrupted);
585                    }
586
587                    // Malformed requests should be responded with 400
588                    self.messages.push_back(DispatcherMessage::Error(
589                        Response::BadRequest().finish().drop_body(),
590                    ));
591                    self.flags.insert(Flags::READ_DISCONNECT);
592                    self.error = Some(e.into());
593                    break;
594                }
595            }
596        }
597
598        if updated && self.ka_timer.is_some() {
599            if let Some(expire) = self.codec.config().keep_alive_expire() {
600                self.ka_expire = expire;
601            }
602        }
603        Ok(updated)
604    }
605
606    /// keep-alive timer
607    fn poll_keepalive(&mut self, cx: &mut Context<'_>) -> Result<(), DispatchError> {
608        if self.ka_timer.is_none() {
609            // shutdown timeout
610            if self.flags.contains(Flags::SHUTDOWN) {
611                if let Some(interval) = self.codec.config().client_disconnect_timer() {
612                    self.ka_timer = Some(delay_until(interval));
613                } else {
614                    self.flags.insert(Flags::READ_DISCONNECT);
615                    if let Some(mut payload) = self.payload.take() {
616                        payload.set_error(PayloadError::Incomplete(None));
617                    }
618                    return Ok(());
619                }
620            } else {
621                return Ok(());
622            }
623        }
624
625        match Pin::new(&mut self.ka_timer.as_mut().unwrap()).poll(cx) {
626            Poll::Ready(()) => {
627                // if we get timeout during shutdown, drop connection
628                if self.flags.contains(Flags::SHUTDOWN) {
629                    return Err(DispatchError::DisconnectTimeout);
630                } else if self.ka_timer.as_mut().unwrap().deadline() >= self.ka_expire {
631                    // check for any outstanding tasks
632                    if self.state.is_empty() && self.write_buf.is_empty() {
633                        if self.flags.contains(Flags::STARTED) {
634                            trace!("Keep-alive timeout, close connection");
635                            self.flags.insert(Flags::SHUTDOWN);
636
637                            // start shutdown timer
638                            if let Some(deadline) =
639                                self.codec.config().client_disconnect_timer()
640                            {
641                                if let Some(mut timer) = self.ka_timer.as_mut() {
642                                    timer.reset(deadline);
643                                    let _ = Pin::new(&mut timer).poll(cx);
644                                }
645                            } else {
646                                // no shutdown timeout, drop socket
647                                self.flags.insert(Flags::WRITE_DISCONNECT);
648                                return Ok(());
649                            }
650                        } else {
651                            // timeout on first request (slow request) return 408
652                            if !self.flags.contains(Flags::STARTED) {
653                                trace!("Slow request timeout");
654                                let _ = self.send_response(
655                                    Response::RequestTimeout().finish().drop_body(),
656                                    ResponseBody::Other(Body::Empty),
657                                );
658                            } else {
659                                trace!("Keep-alive connection timeout");
660                            }
661                            self.flags.insert(Flags::STARTED | Flags::SHUTDOWN);
662                            self.state = State::None;
663                        }
664                    } else if let Some(deadline) =
665                        self.codec.config().keep_alive_expire()
666                    {
667                        if let Some(mut timer) = self.ka_timer.as_mut() {
668                            timer.reset(deadline);
669                            let _ = Pin::new(&mut timer).poll(cx);
670                        }
671                    }
672                } else if let Some(mut timer) = self.ka_timer.as_mut() {
673                    timer.reset(self.ka_expire);
674                    let _ = Pin::new(&mut timer).poll(cx);
675                }
676            }
677            Poll::Pending => (),
678        }
679
680        Ok(())
681    }
682}
683
684impl<T, S, B, X, U> Unpin for Dispatcher<T, S, B, X, U>
685where
686    T: AsyncRead + AsyncWrite + Unpin,
687    S: Service<Request = Request>,
688    S::Error: Into<Error>,
689    S::Response: Into<Response<B>>,
690    B: MessageBody,
691    X: Service<Request = Request, Response = Request>,
692    X::Error: Into<Error>,
693    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
694    U::Error: fmt::Display,
695{
696}
697
698impl<T, S, B, X, U> Future for Dispatcher<T, S, B, X, U>
699where
700    T: AsyncRead + AsyncWrite + Unpin,
701    S: Service<Request = Request>,
702    S::Error: Into<Error>,
703    S::Response: Into<Response<B>>,
704    B: MessageBody,
705    X: Service<Request = Request, Response = Request>,
706    X::Error: Into<Error>,
707    U: Service<Request = (Request, Framed<T, Codec>), Response = ()>,
708    U::Error: fmt::Display,
709{
710    type Output = Result<(), DispatchError>;
711
712    #[inline]
713    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
714        match self.as_mut().inner {
715            DispatcherState::Normal(ref mut inner) => {
716                inner.poll_keepalive(cx)?;
717
718                if inner.flags.contains(Flags::SHUTDOWN) {
719                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
720                        Poll::Ready(Ok(()))
721                    } else {
722                        // flush buffer
723                        inner.poll_flush(cx)?;
724                        if !inner.write_buf.is_empty() {
725                            Poll::Pending
726                        } else {
727                            match Pin::new(&mut inner.io).poll_shutdown(cx) {
728                                Poll::Ready(res) => {
729                                    Poll::Ready(res.map_err(DispatchError::from))
730                                }
731                                Poll::Pending => Poll::Pending,
732                            }
733                        }
734                    }
735                } else {
736                    // read socket into a buf
737                    let should_disconnect =
738                        if !inner.flags.contains(Flags::READ_DISCONNECT) {
739                            read_available(cx, &mut inner.io, &mut inner.read_buf)?
740                        } else {
741                            None
742                        };
743
744                    inner.poll_request(cx)?;
745                    if let Some(true) = should_disconnect {
746                        inner.flags.insert(Flags::READ_DISCONNECT);
747                        if let Some(mut payload) = inner.payload.take() {
748                            payload.feed_eof();
749                        }
750                    };
751
752                    loop {
753                        let remaining =
754                            inner.write_buf.capacity() - inner.write_buf.len();
755                        if remaining < LW_BUFFER_SIZE {
756                            inner.write_buf.reserve(HW_BUFFER_SIZE - remaining);
757                        }
758                        let result = inner.poll_response(cx)?;
759                        let drain = result == PollResponse::DrainWriteBuf;
760
761                        // switch to upgrade handler
762                        if let PollResponse::Upgrade(req) = result {
763                            if let DispatcherState::Normal(inner) =
764                                std::mem::replace(&mut self.inner, DispatcherState::None)
765                            {
766                                let mut parts = FramedParts::with_read_buf(
767                                    inner.io,
768                                    inner.codec,
769                                    inner.read_buf,
770                                );
771                                parts.write_buf = inner.write_buf;
772                                let framed = Framed::from_parts(parts);
773                                self.inner = DispatcherState::Upgrade(
774                                    inner.upgrade.unwrap().call((req, framed)),
775                                );
776                                return self.poll(cx);
777                            } else {
778                                panic!()
779                            }
780                        }
781
782                        // we didnt get WouldBlock from write operation,
783                        // so data get written to kernel completely (OSX)
784                        // and we have to write again otherwise response can get stuck
785                        if inner.poll_flush(cx)? || !drain {
786                            break;
787                        }
788                    }
789
790                    // client is gone
791                    if inner.flags.contains(Flags::WRITE_DISCONNECT) {
792                        return Poll::Ready(Ok(()));
793                    }
794
795                    let is_empty = inner.state.is_empty();
796
797                    // read half is closed and we do not processing any responses
798                    if inner.flags.contains(Flags::READ_DISCONNECT) && is_empty {
799                        inner.flags.insert(Flags::SHUTDOWN);
800                    }
801
802                    // keep-alive and stream errors
803                    if is_empty && inner.write_buf.is_empty() {
804                        if let Some(err) = inner.error.take() {
805                            Poll::Ready(Err(err))
806                        }
807                        // disconnect if keep-alive is not enabled
808                        else if inner.flags.contains(Flags::STARTED)
809                            && !inner.flags.intersects(Flags::KEEPALIVE)
810                        {
811                            inner.flags.insert(Flags::SHUTDOWN);
812                            self.poll(cx)
813                        }
814                        // disconnect if shutdown
815                        else if inner.flags.contains(Flags::SHUTDOWN) {
816                            self.poll(cx)
817                        } else {
818                            Poll::Pending
819                        }
820                    } else {
821                        Poll::Pending
822                    }
823                }
824            }
825            DispatcherState::Upgrade(ref mut fut) => {
826                unsafe { Pin::new_unchecked(fut) }.poll(cx).map_err(|e| {
827                    error!("Upgrade handler error: {}", e);
828                    DispatchError::Upgrade
829                })
830            }
831            DispatcherState::None => panic!(),
832        }
833    }
834}
835
836fn read_available<T>(
837    cx: &mut Context<'_>,
838    io: &mut T,
839    buf: &mut BytesMut,
840) -> Result<Option<bool>, io::Error>
841where
842    T: AsyncRead + Unpin,
843{
844    let mut read_some = false;
845    loop {
846        let remaining = buf.capacity() - buf.len();
847        if remaining < LW_BUFFER_SIZE {
848            buf.reserve(HW_BUFFER_SIZE - remaining);
849        }
850
851        match read(cx, io, buf) {
852            Poll::Pending => {
853                return if read_some { Ok(Some(false)) } else { Ok(None) };
854            }
855            Poll::Ready(Ok(n)) => {
856                if n == 0 {
857                    return Ok(Some(true));
858                } else {
859                    read_some = true;
860                }
861            }
862            Poll::Ready(Err(e)) => {
863                return if e.kind() == io::ErrorKind::WouldBlock {
864                    if read_some {
865                        Ok(Some(false))
866                    } else {
867                        Ok(None)
868                    }
869                } else if e.kind() == io::ErrorKind::ConnectionReset && read_some {
870                    Ok(Some(true))
871                } else {
872                    Err(e)
873                }
874            }
875        }
876    }
877}
878
879fn read<T>(
880    cx: &mut Context<'_>,
881    io: &mut T,
882    buf: &mut BytesMut,
883) -> Poll<Result<usize, io::Error>>
884where
885    T: AsyncRead + Unpin,
886{
887    Pin::new(io).poll_read_buf(cx, buf)
888}
889
#[cfg(test)]
mod tests {
    use actori_service::IntoService;
    use futures_util::future::{lazy, ok};

    use super::*;
    use crate::error::Error;
    use crate::h1::{ExpectHandler, UpgradeHandler};
    use crate::test::TestBuffer;

    /// A malformed request must make the dispatcher finish with an error, mark
    /// the read half as disconnected, and answer with `400 Bad Request`.
    #[actori_rt::test]
    async fn test_req_parse_err() {
        lazy(|cx| {
            // The request line carries an incomplete version token (`HTTP/1`),
            // which this test expects to be rejected as a parse error.
            let buf = TestBuffer::new("GET /test HTTP/1\r\n\r\n");

            let mut h1 = Dispatcher::<_, _, _, _, UpgradeHandler<TestBuffer>>::new(
                buf,
                ServiceConfig::default(),
                CloneableService::new(
                    (|_| ok::<_, Error>(Response::Ok().finish())).into_service(),
                ),
                CloneableService::new(ExpectHandler),
                None,
                None,
                None,
            );
            match Pin::new(&mut h1).poll(cx) {
                Poll::Pending => panic!("dispatcher should complete on a parse error"),
                Poll::Ready(res) => assert!(res.is_err()),
            }

            // Fail loudly if the state is not `Normal`; a silent `if let` would
            // let the assertions below be skipped and the test pass vacuously.
            match h1.inner {
                DispatcherState::Normal(ref inner) => {
                    assert!(inner.flags.contains(Flags::READ_DISCONNECT));
                    assert_eq!(&inner.io.write_buf[..26], b"HTTP/1.1 400 Bad Request\r\n");
                }
                _ => panic!("dispatcher left the Normal state after a parse error"),
            }
        })
        .await;
    }
}