1use crate::{HttpOrWsIncoming, IsTls, TcpIncoming, TcpOrTlsIncoming, TcpOrTlsStream, TcpStream};
2use async_http_codec::internal::buffer_decode::BufferDecode;
3use async_http_codec::{
4 BodyDecodeWithContinue, BodyDecodeWithContinueState, BodyEncode, RequestHead, ResponseHead,
5};
6use futures::prelude::*;
7use futures::stream::{FusedStream, FuturesUnordered};
8use futures::StreamExt;
9use http::header::{IntoHeaderName, HOST, LOCATION, TRANSFER_ENCODING};
10use http::uri::{Authority, Parts, Scheme};
11use http::{HeaderMap, HeaderValue, Method, Request, StatusCode, Uri, Version};
12use log::debug;
13use std::borrow::Cow;
14use std::convert::TryFrom;
15use std::io;
16use std::pin::Pin;
17use std::task::{Context, Poll};
18
19pub struct HttpIncoming<
20 IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream,
21 T: Stream<Item = IO> + Unpin = TcpOrTlsIncoming,
22> {
23 incoming: Option<T>,
24 decoding: FuturesUnordered<BufferDecode<IO, RequestHead<'static>>>,
25}
26
27impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> HttpIncoming<IO, T> {
28 pub fn new(transport_incoming: T) -> Self {
29 HttpIncoming {
30 incoming: Some(transport_incoming),
31 decoding: FuturesUnordered::new(),
32 }
33 }
34 pub fn or_ws(self) -> HttpOrWsIncoming<IO, Self> {
35 HttpOrWsIncoming::new(self)
36 }
37}
38
39impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> Stream
40 for HttpIncoming<IO, T>
41{
42 type Item = HttpRequest<IO>;
43
44 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
45 loop {
46 match self.decoding.poll_next_unpin(cx) {
47 Poll::Ready(Some(Ok((transport, head)))) => {
48 match BodyDecodeWithContinueState::from_head(&head) {
49 Ok(state) => {
50 return Poll::Ready(Some(HttpRequest {
51 head,
52 state,
53 transport,
54 }))
55 }
56 Err(err) => log::debug!("http head error: {:?}", err),
57 };
58 }
59 Poll::Ready(Some(Err(err))) => log::debug!("http head decode error: {:?}", err),
60 Poll::Ready(None) | Poll::Pending => match &mut self.incoming {
61 Some(incoming) => match incoming.poll_next_unpin(cx) {
62 Poll::Ready(Some(transport)) => {
63 self.decoding.push(RequestHead::decode(transport))
64 }
65 Poll::Ready(None) => drop(self.incoming.take()),
66 Poll::Pending => return Poll::Pending,
67 },
68 None => match self.is_terminated() {
69 true => return Poll::Ready(None),
70 false => return Poll::Pending,
71 },
72 },
73 }
74 }
75 }
76}
77
78impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> FusedStream
79 for HttpIncoming<IO, T>
80{
81 fn is_terminated(&self) -> bool {
82 self.incoming.is_none() && self.decoding.is_terminated()
83 }
84}
85
86impl<IO: AsyncRead + AsyncWrite + Unpin, T: Stream<Item = IO> + Unpin> Unpin
87 for HttpIncoming<IO, T>
88{
89}
90
91pub struct HttpRequest<IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream> {
92 pub(crate) head: RequestHead<'static>,
93 pub(crate) state: BodyDecodeWithContinueState,
94 pub(crate) transport: IO,
95}
96
97impl core::fmt::Debug for HttpRequest {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("Point")
100 .field("head", &self.head)
101 .finish_non_exhaustive()
102 }
103}
104
105impl<IO: AsyncRead + AsyncWrite + Unpin + IsTls> IsTls for HttpRequest<IO> {
106 fn is_tls(&self) -> bool {
107 self.transport.is_tls()
108 }
109}
110
111impl<IO: AsyncRead + AsyncWrite + Unpin> HttpRequest<IO> {
112 pub fn into_inner(self) -> Request<BodyDecodeWithContinue<BodyDecodeWithContinueState, IO>> {
131 Request::from_parts(self.head.into(), self.state.into_async_read(self.transport))
132 }
133 pub fn from_inner(
135 request: Request<BodyDecodeWithContinue<BodyDecodeWithContinueState, IO>>,
136 ) -> Self {
137 let (head, body) = request.into_parts();
138 let head = head.into();
139 let (state, transport) = body.into_inner();
140 Self {
141 head,
142 state,
143 transport,
144 }
145 }
146 pub async fn response(mut self) -> io::Result<HttpResponse<IO>> {
148 while 0 < self.body().read(&mut [0u8; 1 << 14]).await? {}
149 let Self {
150 head,
151 state: _,
152 transport,
153 } = self;
154 let request_head = http::request::Parts::from(head);
155 let request_headers = request_head.headers;
156 let request_method = request_head.method;
157 let request_uri = request_head.uri;
158 let headers = Cow::Owned(HeaderMap::with_capacity(128));
159 Ok(HttpResponse {
160 request_headers,
161 request_uri,
162 request_method,
163 head: ResponseHead::new(StatusCode::OK, request_head.version, headers),
164 transport,
165 })
166 }
167
168 pub fn body(&mut self) -> BodyDecodeWithContinue<&mut BodyDecodeWithContinueState, &mut IO> {
172 self.state.as_async_read(&mut self.transport)
173 }
174 pub async fn body_string(&mut self, limit: usize) -> io::Result<String> {
176 let mut body = String::new();
177 self.body()
178 .take(limit as u64)
179 .read_to_string(&mut body)
180 .await?;
181 if body.len() == limit && self.body().read(&mut [0u8]).await? > 0 {
182 return Err(io::Error::new(
183 io::ErrorKind::OutOfMemory,
184 "body size exceeds limit",
185 ));
186 }
187 Ok(body)
188 }
189 pub async fn body_vec(&mut self, limit: usize) -> io::Result<Vec<u8>> {
191 let mut body = Vec::new();
192 self.body()
193 .take(limit as u64)
194 .read_to_end(&mut body)
195 .await?;
196 if body.len() == limit && self.body().read(&mut [0u8]).await? > 0 {
197 return Err(io::Error::new(
198 io::ErrorKind::OutOfMemory,
199 "body size exceeds limit",
200 ));
201 }
202 Ok(body)
203 }
204 pub fn headers(&self) -> &HeaderMap {
206 self.head.headers()
207 }
208 pub fn uri(&self) -> &Uri {
210 &self.head.uri()
211 }
212 pub fn method(&self) -> Method {
214 self.head.method().clone()
215 }
216 pub fn version(&self) -> Version {
218 self.head.version()
219 }
220}
221
222pub struct HttpResponse<IO: AsyncRead + AsyncWrite + Unpin = TcpOrTlsStream> {
223 request_uri: Uri,
224 request_headers: HeaderMap,
225 request_method: Method,
226 head: ResponseHead<'static>,
227 transport: IO,
228}
229
230impl core::fmt::Debug for HttpResponse {
231 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232 f.debug_struct("Point")
233 .field("head", &self.head)
234 .finish_non_exhaustive()
235 }
236}
237
238impl<IO: AsyncRead + AsyncWrite + Unpin> HttpResponse<IO> {
239 pub fn request_headers(&self) -> &HeaderMap {
241 &self.request_headers
242 }
243 pub fn uri(&self) -> &Uri {
245 &self.request_uri
246 }
247 pub fn method(&self) -> Method {
249 self.request_method.clone()
250 }
251 pub fn version(&self) -> Version {
253 self.head.version()
254 }
255 pub fn headers(&self) -> &HeaderMap {
257 self.head.headers()
258 }
259 pub fn headers_mut(&mut self) -> &mut HeaderMap {
261 self.head.headers_mut()
262 }
263 pub fn insert_header(&mut self, key: impl IntoHeaderName, value: HeaderValue) -> &mut Self {
265 self.headers_mut().insert(key, value);
266 self
267 }
268 pub fn status(&self) -> StatusCode {
270 self.head.status()
271 }
272 pub fn status_mut(&mut self) -> &mut StatusCode {
274 self.head.status_mut()
275 }
276 pub fn set_status(&mut self, status: StatusCode) -> &mut Self {
278 *self.status_mut() = status;
279 self
280 }
281 pub async fn send(self, body: impl AsRef<[u8]>) -> io::Result<()> {
284 let mut encoder = self.body().await?;
285 encoder.write_all(body.as_ref()).await?;
286 encoder.close().await?;
287 Ok(())
288 }
289 pub async fn body(mut self) -> io::Result<BodyEncode<IO>> {
291 self.insert_header(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));
292 self.head.encode(&mut self.transport).await?;
293 Ok(BodyEncode::new(self.transport, None))
294 }
295}
296
297impl HttpIncoming<TcpStream, TcpIncoming> {
298 pub fn redirect_https(self) -> RedirectHttps {
299 RedirectHttps {
300 incoming: self,
301 redirecting: FuturesUnordered::new(),
302 }
303 }
304}
305
306pub struct RedirectHttps {
307 incoming: HttpIncoming<TcpStream, TcpIncoming>,
308 redirecting: FuturesUnordered<Pin<Box<dyn Future<Output = ()> + Send + Sync>>>,
310}
311
312impl RedirectHttps {
313 fn set_location_header(resp: &mut HttpResponse<TcpStream>) -> http::Result<()> {
314 let authority = match resp.request_headers().get(HOST) {
315 Some(host) => Some(Authority::try_from(host.as_bytes())?),
316 None => None,
317 };
318 let mut parts: Parts = Default::default();
319 parts.scheme = Some(Scheme::HTTPS);
320 parts.authority = authority;
321 parts.path_and_query = resp.uri().path_and_query().cloned();
322 let header_value = HeaderValue::try_from(Uri::from_parts(parts)?.to_string())?;
323 resp.insert_header(LOCATION, header_value);
324 Ok(())
325 }
326 async fn send(req: HttpRequest<TcpStream>) {
327 match req.response().await {
328 Err(err) => debug!("error reading body of request to be redirected: {:?}", err),
329 Ok(mut resp) => match Self::set_location_header(&mut resp) {
330 Err(err) => debug!("error constructing redirect location header: {:?}", err),
331 Ok(()) => {
332 resp.set_status(StatusCode::TEMPORARY_REDIRECT);
333 if let Err(err) = resp.send(&[]).await {
334 debug!("error sending redirect response: {:?}", err)
335 }
336 }
337 },
338 }
339 }
340}
341
342impl Future for RedirectHttps {
343 type Output = ();
344
345 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
346 loop {
347 if let Poll::Ready(Some(())) = Pin::new(&mut self.redirecting).poll_next(cx) {
348 continue;
349 }
350 if !self.incoming.is_terminated() {
351 if let Poll::Ready(Some(req)) = Pin::new(&mut self.incoming).poll_next(cx) {
352 self.redirecting.push(Box::pin(Self::send(req)));
353 continue;
354 }
355 }
356 return match self.incoming.is_terminated() && self.redirecting.is_terminated() {
357 true => Poll::Ready(()),
358 false => Poll::Pending,
359 };
360 }
361 }
362}