actori_http/h2/
dispatcher.rs

1use std::convert::TryFrom;
2use std::future::Future;
3use std::marker::PhantomData;
4use std::net;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8use actori_codec::{AsyncRead, AsyncWrite};
9use actori_rt::time::{Delay, Instant};
10use actori_service::Service;
11use bytes::{Bytes, BytesMut};
12use h2::server::{Connection, SendResponse};
13use h2::SendStream;
14use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
15use log::{error, trace};
16
17use crate::body::{BodySize, MessageBody, ResponseBody};
18use crate::cloneable::CloneableService;
19use crate::config::ServiceConfig;
20use crate::error::{DispatchError, Error};
21use crate::helpers::DataFactory;
22use crate::httpmessage::HttpMessage;
23use crate::message::ResponseHead;
24use crate::payload::Payload;
25use crate::request::Request;
26use crate::response::Response;
27
/// Upper bound, in bytes, on the send capacity reserved on an h2 stream for a
/// single write of response body data.
const CHUNK_SIZE: usize = 16 * 1024;
29
30/// Dispatcher for HTTP/2 protocol
31#[pin_project::pin_project]
32pub struct Dispatcher<T, S: Service<Request = Request>, B: MessageBody>
33where
34    T: AsyncRead + AsyncWrite + Unpin,
35{
36    service: CloneableService<S>,
37    connection: Connection<T, Bytes>,
38    on_connect: Option<Box<dyn DataFactory>>,
39    config: ServiceConfig,
40    peer_addr: Option<net::SocketAddr>,
41    ka_expire: Instant,
42    ka_timer: Option<Delay>,
43    _t: PhantomData<B>,
44}
45
46impl<T, S, B> Dispatcher<T, S, B>
47where
48    T: AsyncRead + AsyncWrite + Unpin,
49    S: Service<Request = Request>,
50    S::Error: Into<Error>,
51    // S::Future: 'static,
52    S::Response: Into<Response<B>>,
53    B: MessageBody,
54{
55    pub(crate) fn new(
56        service: CloneableService<S>,
57        connection: Connection<T, Bytes>,
58        on_connect: Option<Box<dyn DataFactory>>,
59        config: ServiceConfig,
60        timeout: Option<Delay>,
61        peer_addr: Option<net::SocketAddr>,
62    ) -> Self {
63        // let keepalive = config.keep_alive_enabled();
64        // let flags = if keepalive {
65        // Flags::KEEPALIVE | Flags::KEEPALIVE_ENABLED
66        // } else {
67        //     Flags::empty()
68        // };
69
70        // keep-alive timer
71        let (ka_expire, ka_timer) = if let Some(delay) = timeout {
72            (delay.deadline(), Some(delay))
73        } else if let Some(delay) = config.keep_alive_timer() {
74            (delay.deadline(), Some(delay))
75        } else {
76            (config.now(), None)
77        };
78
79        Dispatcher {
80            service,
81            config,
82            peer_addr,
83            connection,
84            on_connect,
85            ka_expire,
86            ka_timer,
87            _t: PhantomData,
88        }
89    }
90}
91
92impl<T, S, B> Future for Dispatcher<T, S, B>
93where
94    T: AsyncRead + AsyncWrite + Unpin,
95    S: Service<Request = Request>,
96    S::Error: Into<Error> + 'static,
97    S::Future: 'static,
98    S::Response: Into<Response<B>> + 'static,
99    B: MessageBody + 'static,
100{
101    type Output = Result<(), DispatchError>;
102
103    #[inline]
104    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105        let this = self.get_mut();
106
107        loop {
108            match Pin::new(&mut this.connection).poll_accept(cx) {
109                Poll::Ready(None) => return Poll::Ready(Ok(())),
110                Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
111                Poll::Ready(Some(Ok((req, res)))) => {
112                    // update keep-alive expire
113                    if this.ka_timer.is_some() {
114                        if let Some(expire) = this.config.keep_alive_expire() {
115                            this.ka_expire = expire;
116                        }
117                    }
118
119                    let (parts, body) = req.into_parts();
120                    let mut req = Request::with_payload(Payload::<
121                        crate::payload::PayloadStream,
122                    >::H2(
123                        crate::h2::Payload::new(body)
124                    ));
125
126                    let head = &mut req.head_mut();
127                    head.uri = parts.uri;
128                    head.method = parts.method;
129                    head.version = parts.version;
130                    head.headers = parts.headers.into();
131                    head.peer_addr = this.peer_addr;
132
133                    // set on_connect data
134                    if let Some(ref on_connect) = this.on_connect {
135                        on_connect.set(&mut req.extensions_mut());
136                    }
137
138                    actori_rt::spawn(ServiceResponse::<
139                        S::Future,
140                        S::Response,
141                        S::Error,
142                        B,
143                    > {
144                        state: ServiceResponseState::ServiceCall(
145                            this.service.call(req),
146                            Some(res),
147                        ),
148                        config: this.config.clone(),
149                        buffer: None,
150                        _t: PhantomData,
151                    });
152                }
153                Poll::Pending => return Poll::Pending,
154            }
155        }
156    }
157}
158
159#[pin_project::pin_project]
160struct ServiceResponse<F, I, E, B> {
161    state: ServiceResponseState<F, B>,
162    config: ServiceConfig,
163    buffer: Option<Bytes>,
164    _t: PhantomData<(I, E)>,
165}
166
167enum ServiceResponseState<F, B> {
168    ServiceCall(F, Option<SendResponse<Bytes>>),
169    SendPayload(SendStream<Bytes>, ResponseBody<B>),
170}
171
172impl<F, I, E, B> ServiceResponse<F, I, E, B>
173where
174    F: Future<Output = Result<I, E>>,
175    E: Into<Error>,
176    I: Into<Response<B>>,
177    B: MessageBody,
178{
179    fn prepare_response(
180        &self,
181        head: &ResponseHead,
182        size: &mut BodySize,
183    ) -> http::Response<()> {
184        let mut has_date = false;
185        let mut skip_len = size != &BodySize::Stream;
186
187        let mut res = http::Response::new(());
188        *res.status_mut() = head.status;
189        *res.version_mut() = http::Version::HTTP_2;
190
191        // Content length
192        match head.status {
193            http::StatusCode::NO_CONTENT
194            | http::StatusCode::CONTINUE
195            | http::StatusCode::PROCESSING => *size = BodySize::None,
196            http::StatusCode::SWITCHING_PROTOCOLS => {
197                skip_len = true;
198                *size = BodySize::Stream;
199            }
200            _ => (),
201        }
202        let _ = match size {
203            BodySize::None | BodySize::Stream => None,
204            BodySize::Empty => res
205                .headers_mut()
206                .insert(CONTENT_LENGTH, HeaderValue::from_static("0")),
207            BodySize::Sized(len) => res.headers_mut().insert(
208                CONTENT_LENGTH,
209                HeaderValue::try_from(format!("{}", len)).unwrap(),
210            ),
211            BodySize::Sized64(len) => res.headers_mut().insert(
212                CONTENT_LENGTH,
213                HeaderValue::try_from(format!("{}", len)).unwrap(),
214            ),
215        };
216
217        // copy headers
218        for (key, value) in head.headers.iter() {
219            match *key {
220                CONNECTION | TRANSFER_ENCODING => continue, // http2 specific
221                CONTENT_LENGTH if skip_len => continue,
222                DATE => has_date = true,
223                _ => (),
224            }
225            res.headers_mut().append(key, value.clone());
226        }
227
228        // set date header
229        if !has_date {
230            let mut bytes = BytesMut::with_capacity(29);
231            self.config.set_date_header(&mut bytes);
232            res.headers_mut().insert(DATE, unsafe {
233                HeaderValue::from_maybe_shared_unchecked(bytes.freeze())
234            });
235        }
236
237        res
238    }
239}
240
241impl<F, I, E, B> Future for ServiceResponse<F, I, E, B>
242where
243    F: Future<Output = Result<I, E>>,
244    E: Into<Error>,
245    I: Into<Response<B>>,
246    B: MessageBody,
247{
248    type Output = ();
249
250    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
251        let mut this = self.as_mut().project();
252
253        match this.state {
254            ServiceResponseState::ServiceCall(ref mut call, ref mut send) => {
255                match unsafe { Pin::new_unchecked(call) }.poll(cx) {
256                    Poll::Ready(Ok(res)) => {
257                        let (res, body) = res.into().replace_body(());
258
259                        let mut send = send.take().unwrap();
260                        let mut size = body.size();
261                        let h2_res =
262                            self.as_mut().prepare_response(res.head(), &mut size);
263                        this = self.as_mut().project();
264
265                        let stream = match send.send_response(h2_res, size.is_eof()) {
266                            Err(e) => {
267                                trace!("Error sending h2 response: {:?}", e);
268                                return Poll::Ready(());
269                            }
270                            Ok(stream) => stream,
271                        };
272
273                        if size.is_eof() {
274                            Poll::Ready(())
275                        } else {
276                            *this.state =
277                                ServiceResponseState::SendPayload(stream, body);
278                            self.poll(cx)
279                        }
280                    }
281                    Poll::Pending => Poll::Pending,
282                    Poll::Ready(Err(e)) => {
283                        let res: Response = e.into().into();
284                        let (res, body) = res.replace_body(());
285
286                        let mut send = send.take().unwrap();
287                        let mut size = body.size();
288                        let h2_res =
289                            self.as_mut().prepare_response(res.head(), &mut size);
290                        this = self.as_mut().project();
291
292                        let stream = match send.send_response(h2_res, size.is_eof()) {
293                            Err(e) => {
294                                trace!("Error sending h2 response: {:?}", e);
295                                return Poll::Ready(());
296                            }
297                            Ok(stream) => stream,
298                        };
299
300                        if size.is_eof() {
301                            Poll::Ready(())
302                        } else {
303                            *this.state = ServiceResponseState::SendPayload(
304                                stream,
305                                body.into_body(),
306                            );
307                            self.poll(cx)
308                        }
309                    }
310                }
311            }
312            ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop {
313                loop {
314                    if let Some(ref mut buffer) = this.buffer {
315                        match stream.poll_capacity(cx) {
316                            Poll::Pending => return Poll::Pending,
317                            Poll::Ready(None) => return Poll::Ready(()),
318                            Poll::Ready(Some(Ok(cap))) => {
319                                let len = buffer.len();
320                                let bytes = buffer.split_to(std::cmp::min(cap, len));
321
322                                if let Err(e) = stream.send_data(bytes, false) {
323                                    warn!("{:?}", e);
324                                    return Poll::Ready(());
325                                } else if !buffer.is_empty() {
326                                    let cap = std::cmp::min(buffer.len(), CHUNK_SIZE);
327                                    stream.reserve_capacity(cap);
328                                } else {
329                                    this.buffer.take();
330                                }
331                            }
332                            Poll::Ready(Some(Err(e))) => {
333                                warn!("{:?}", e);
334                                return Poll::Ready(());
335                            }
336                        }
337                    } else {
338                        match body.poll_next(cx) {
339                            Poll::Pending => return Poll::Pending,
340                            Poll::Ready(None) => {
341                                if let Err(e) = stream.send_data(Bytes::new(), true) {
342                                    warn!("{:?}", e);
343                                }
344                                return Poll::Ready(());
345                            }
346                            Poll::Ready(Some(Ok(chunk))) => {
347                                stream.reserve_capacity(std::cmp::min(
348                                    chunk.len(),
349                                    CHUNK_SIZE,
350                                ));
351                                *this.buffer = Some(chunk);
352                            }
353                            Poll::Ready(Some(Err(e))) => {
354                                error!("Response payload stream error: {:?}", e);
355                                return Poll::Ready(());
356                            }
357                        }
358                    }
359                }
360            },
361        }
362    }
363}