use std::convert::TryFrom;
use std::future::Future;
use std::marker::PhantomData;
use std::net;
use std::pin::Pin;
use std::task::{Context, Poll};

use actori_codec::{AsyncRead, AsyncWrite};
use actori_rt::time::{Delay, Instant};
use actori_service::Service;
use bytes::{Bytes, BytesMut};
use h2::server::{Connection, SendResponse};
use h2::SendStream;
use http::header::{HeaderValue, CONNECTION, CONTENT_LENGTH, DATE, TRANSFER_ENCODING};
use log::{error, trace, warn};

use crate::body::{BodySize, MessageBody, ResponseBody};
use crate::cloneable::CloneableService;
use crate::config::ServiceConfig;
use crate::error::{DispatchError, Error};
use crate::helpers::DataFactory;
use crate::httpmessage::HttpMessage;
use crate::message::ResponseHead;
use crate::payload::Payload;
use crate::request::Request;
use crate::response::Response;
27
/// Upper bound, in bytes, on the flow-control capacity reserved on an h2
/// stream at one time while writing a response body (16 KiB).
const CHUNK_SIZE: usize = 16 * 1024;
29
30#[pin_project::pin_project]
32pub struct Dispatcher<T, S: Service<Request = Request>, B: MessageBody>
33where
34 T: AsyncRead + AsyncWrite + Unpin,
35{
36 service: CloneableService<S>,
37 connection: Connection<T, Bytes>,
38 on_connect: Option<Box<dyn DataFactory>>,
39 config: ServiceConfig,
40 peer_addr: Option<net::SocketAddr>,
41 ka_expire: Instant,
42 ka_timer: Option<Delay>,
43 _t: PhantomData<B>,
44}
45
46impl<T, S, B> Dispatcher<T, S, B>
47where
48 T: AsyncRead + AsyncWrite + Unpin,
49 S: Service<Request = Request>,
50 S::Error: Into<Error>,
51 S::Response: Into<Response<B>>,
53 B: MessageBody,
54{
55 pub(crate) fn new(
56 service: CloneableService<S>,
57 connection: Connection<T, Bytes>,
58 on_connect: Option<Box<dyn DataFactory>>,
59 config: ServiceConfig,
60 timeout: Option<Delay>,
61 peer_addr: Option<net::SocketAddr>,
62 ) -> Self {
63 let (ka_expire, ka_timer) = if let Some(delay) = timeout {
72 (delay.deadline(), Some(delay))
73 } else if let Some(delay) = config.keep_alive_timer() {
74 (delay.deadline(), Some(delay))
75 } else {
76 (config.now(), None)
77 };
78
79 Dispatcher {
80 service,
81 config,
82 peer_addr,
83 connection,
84 on_connect,
85 ka_expire,
86 ka_timer,
87 _t: PhantomData,
88 }
89 }
90}
91
92impl<T, S, B> Future for Dispatcher<T, S, B>
93where
94 T: AsyncRead + AsyncWrite + Unpin,
95 S: Service<Request = Request>,
96 S::Error: Into<Error> + 'static,
97 S::Future: 'static,
98 S::Response: Into<Response<B>> + 'static,
99 B: MessageBody + 'static,
100{
101 type Output = Result<(), DispatchError>;
102
103 #[inline]
104 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
105 let this = self.get_mut();
106
107 loop {
108 match Pin::new(&mut this.connection).poll_accept(cx) {
109 Poll::Ready(None) => return Poll::Ready(Ok(())),
110 Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(err.into())),
111 Poll::Ready(Some(Ok((req, res)))) => {
112 if this.ka_timer.is_some() {
114 if let Some(expire) = this.config.keep_alive_expire() {
115 this.ka_expire = expire;
116 }
117 }
118
119 let (parts, body) = req.into_parts();
120 let mut req = Request::with_payload(Payload::<
121 crate::payload::PayloadStream,
122 >::H2(
123 crate::h2::Payload::new(body)
124 ));
125
126 let head = &mut req.head_mut();
127 head.uri = parts.uri;
128 head.method = parts.method;
129 head.version = parts.version;
130 head.headers = parts.headers.into();
131 head.peer_addr = this.peer_addr;
132
133 if let Some(ref on_connect) = this.on_connect {
135 on_connect.set(&mut req.extensions_mut());
136 }
137
138 actori_rt::spawn(ServiceResponse::<
139 S::Future,
140 S::Response,
141 S::Error,
142 B,
143 > {
144 state: ServiceResponseState::ServiceCall(
145 this.service.call(req),
146 Some(res),
147 ),
148 config: this.config.clone(),
149 buffer: None,
150 _t: PhantomData,
151 });
152 }
153 Poll::Pending => return Poll::Pending,
154 }
155 }
156 }
157}
158
159#[pin_project::pin_project]
160struct ServiceResponse<F, I, E, B> {
161 state: ServiceResponseState<F, B>,
162 config: ServiceConfig,
163 buffer: Option<Bytes>,
164 _t: PhantomData<(I, E)>,
165}
166
167enum ServiceResponseState<F, B> {
168 ServiceCall(F, Option<SendResponse<Bytes>>),
169 SendPayload(SendStream<Bytes>, ResponseBody<B>),
170}
171
172impl<F, I, E, B> ServiceResponse<F, I, E, B>
173where
174 F: Future<Output = Result<I, E>>,
175 E: Into<Error>,
176 I: Into<Response<B>>,
177 B: MessageBody,
178{
179 fn prepare_response(
180 &self,
181 head: &ResponseHead,
182 size: &mut BodySize,
183 ) -> http::Response<()> {
184 let mut has_date = false;
185 let mut skip_len = size != &BodySize::Stream;
186
187 let mut res = http::Response::new(());
188 *res.status_mut() = head.status;
189 *res.version_mut() = http::Version::HTTP_2;
190
191 match head.status {
193 http::StatusCode::NO_CONTENT
194 | http::StatusCode::CONTINUE
195 | http::StatusCode::PROCESSING => *size = BodySize::None,
196 http::StatusCode::SWITCHING_PROTOCOLS => {
197 skip_len = true;
198 *size = BodySize::Stream;
199 }
200 _ => (),
201 }
202 let _ = match size {
203 BodySize::None | BodySize::Stream => None,
204 BodySize::Empty => res
205 .headers_mut()
206 .insert(CONTENT_LENGTH, HeaderValue::from_static("0")),
207 BodySize::Sized(len) => res.headers_mut().insert(
208 CONTENT_LENGTH,
209 HeaderValue::try_from(format!("{}", len)).unwrap(),
210 ),
211 BodySize::Sized64(len) => res.headers_mut().insert(
212 CONTENT_LENGTH,
213 HeaderValue::try_from(format!("{}", len)).unwrap(),
214 ),
215 };
216
217 for (key, value) in head.headers.iter() {
219 match *key {
220 CONNECTION | TRANSFER_ENCODING => continue, CONTENT_LENGTH if skip_len => continue,
222 DATE => has_date = true,
223 _ => (),
224 }
225 res.headers_mut().append(key, value.clone());
226 }
227
228 if !has_date {
230 let mut bytes = BytesMut::with_capacity(29);
231 self.config.set_date_header(&mut bytes);
232 res.headers_mut().insert(DATE, unsafe {
233 HeaderValue::from_maybe_shared_unchecked(bytes.freeze())
234 });
235 }
236
237 res
238 }
239}
240
241impl<F, I, E, B> Future for ServiceResponse<F, I, E, B>
242where
243 F: Future<Output = Result<I, E>>,
244 E: Into<Error>,
245 I: Into<Response<B>>,
246 B: MessageBody,
247{
248 type Output = ();
249
250 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
251 let mut this = self.as_mut().project();
252
253 match this.state {
254 ServiceResponseState::ServiceCall(ref mut call, ref mut send) => {
255 match unsafe { Pin::new_unchecked(call) }.poll(cx) {
256 Poll::Ready(Ok(res)) => {
257 let (res, body) = res.into().replace_body(());
258
259 let mut send = send.take().unwrap();
260 let mut size = body.size();
261 let h2_res =
262 self.as_mut().prepare_response(res.head(), &mut size);
263 this = self.as_mut().project();
264
265 let stream = match send.send_response(h2_res, size.is_eof()) {
266 Err(e) => {
267 trace!("Error sending h2 response: {:?}", e);
268 return Poll::Ready(());
269 }
270 Ok(stream) => stream,
271 };
272
273 if size.is_eof() {
274 Poll::Ready(())
275 } else {
276 *this.state =
277 ServiceResponseState::SendPayload(stream, body);
278 self.poll(cx)
279 }
280 }
281 Poll::Pending => Poll::Pending,
282 Poll::Ready(Err(e)) => {
283 let res: Response = e.into().into();
284 let (res, body) = res.replace_body(());
285
286 let mut send = send.take().unwrap();
287 let mut size = body.size();
288 let h2_res =
289 self.as_mut().prepare_response(res.head(), &mut size);
290 this = self.as_mut().project();
291
292 let stream = match send.send_response(h2_res, size.is_eof()) {
293 Err(e) => {
294 trace!("Error sending h2 response: {:?}", e);
295 return Poll::Ready(());
296 }
297 Ok(stream) => stream,
298 };
299
300 if size.is_eof() {
301 Poll::Ready(())
302 } else {
303 *this.state = ServiceResponseState::SendPayload(
304 stream,
305 body.into_body(),
306 );
307 self.poll(cx)
308 }
309 }
310 }
311 }
312 ServiceResponseState::SendPayload(ref mut stream, ref mut body) => loop {
313 loop {
314 if let Some(ref mut buffer) = this.buffer {
315 match stream.poll_capacity(cx) {
316 Poll::Pending => return Poll::Pending,
317 Poll::Ready(None) => return Poll::Ready(()),
318 Poll::Ready(Some(Ok(cap))) => {
319 let len = buffer.len();
320 let bytes = buffer.split_to(std::cmp::min(cap, len));
321
322 if let Err(e) = stream.send_data(bytes, false) {
323 warn!("{:?}", e);
324 return Poll::Ready(());
325 } else if !buffer.is_empty() {
326 let cap = std::cmp::min(buffer.len(), CHUNK_SIZE);
327 stream.reserve_capacity(cap);
328 } else {
329 this.buffer.take();
330 }
331 }
332 Poll::Ready(Some(Err(e))) => {
333 warn!("{:?}", e);
334 return Poll::Ready(());
335 }
336 }
337 } else {
338 match body.poll_next(cx) {
339 Poll::Pending => return Poll::Pending,
340 Poll::Ready(None) => {
341 if let Err(e) = stream.send_data(Bytes::new(), true) {
342 warn!("{:?}", e);
343 }
344 return Poll::Ready(());
345 }
346 Poll::Ready(Some(Ok(chunk))) => {
347 stream.reserve_capacity(std::cmp::min(
348 chunk.len(),
349 CHUNK_SIZE,
350 ));
351 *this.buffer = Some(chunk);
352 }
353 Poll::Ready(Some(Err(e))) => {
354 error!("Response payload stream error: {:?}", e);
355 return Poll::Ready(());
356 }
357 }
358 }
359 }
360 },
361 }
362 }
363}