conduit_hyper/service/
blocking_handler.rs1use crate::adaptor::ConduitRequest;
2use crate::file_stream::FileStream;
3use crate::service::ServiceError;
4use crate::{ConduitResponse, HyperResponse};
5
6use std::net::SocketAddr;
7use std::sync::Arc;
8
9use conduit::header::CONTENT_LENGTH;
10use conduit::{Handler, StartInstant, StatusCode};
11use hyper::body::HttpBody;
12use hyper::{Body, Request, Response};
13use tracing::{error, warn};
14
15const MAX_CONTENT_LENGTH: u64 = 128 * 1024 * 1024; #[derive(Debug)]
22pub struct BlockingHandler<H: Handler> {
23 handler: Arc<H>,
24}
25
26impl<H: Handler> BlockingHandler<H> {
27 pub fn new(handler: H) -> Self {
28 Self {
29 handler: Arc::new(handler),
30 }
31 }
32
33 pub(crate) async fn blocking_handler(
35 self: Arc<Self>,
36 request: Request<Body>,
37 remote_addr: SocketAddr,
38 ) -> Result<HyperResponse, ServiceError> {
39 if let Err(response) = check_content_length(&request) {
40 return Ok(response);
41 }
42
43 let (parts, body) = request.into_parts();
44 let now = StartInstant::now();
45
46 let full_body = hyper::body::to_bytes(body).await?;
47 let request = Request::from_parts(parts, full_body);
48
49 let handler = self.handler.clone();
50 tokio::task::spawn_blocking(move || {
51 let mut request = ConduitRequest::new(request, remote_addr, now);
52 handler
53 .call(&mut request)
54 .map(conduit_into_hyper)
55 .unwrap_or_else(|e| server_error_response(&e.to_string()))
56 })
57 .await
58 .map_err(Into::into)
59 }
60}
61
62fn conduit_into_hyper(response: ConduitResponse) -> HyperResponse {
64 use conduit::Body::*;
65
66 let (parts, body) = response.into_parts();
67 let body = match body {
68 Static(slice) => slice.into(),
69 Owned(vec) => vec.into(),
70 File(file) => FileStream::from_std(file).into_streamed_body(),
71 };
72 HyperResponse::from_parts(parts, body)
73}
74
75fn server_error_response(message: &str) -> HyperResponse {
77 error!("Internal Server Error: {}", message);
78 let body = hyper::Body::from("Internal Server Error");
79 Response::builder()
80 .status(StatusCode::INTERNAL_SERVER_ERROR)
81 .body(body)
82 .expect("Unexpected invalid header")
83}
84
85fn check_content_length(request: &Request<Body>) -> Result<(), HyperResponse> {
94 fn bad_request(message: &str) -> HyperResponse {
95 warn!("Bad request: Content-Length {}", message);
96
97 Response::builder()
98 .status(StatusCode::BAD_REQUEST)
99 .body(Body::empty())
100 .expect("Unexpected invalid header")
101 }
102
103 if let Some(content_length) = request.headers().get(CONTENT_LENGTH) {
104 let content_length = match content_length.to_str() {
105 Ok(some) => some,
106 Err(_) => return Err(bad_request("not ASCII")),
107 };
108
109 let content_length = match content_length.parse::<u64>() {
110 Ok(some) => some,
111 Err(_) => return Err(bad_request("not a u64")),
112 };
113
114 if content_length > MAX_CONTENT_LENGTH {
115 return Err(bad_request("too large"));
116 }
117 }
118
119 if request.size_hint().lower() > MAX_CONTENT_LENGTH {
122 return Err(bad_request("size_hint().lower() too large"));
123 }
124
125 Ok(())
126}