1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use crate::adaptor::ConduitRequest;
use crate::file_stream::FileStream;
use crate::service::ServiceError;
use crate::{ConduitResponse, HyperResponse};
use std::net::SocketAddr;
use std::sync::Arc;
use conduit::header::CONTENT_LENGTH;
use conduit::{Handler, StartInstant, StatusCode};
use hyper::body::HttpBody;
use hyper::{Body, Request, Response};
use tracing::{error, warn};
/// Largest request body, in bytes, this module accepts (128 MiB).
const MAX_CONTENT_LENGTH: u64 = 128 << 20;

/// Adapts a synchronous conduit [`Handler`] for use from hyper.
///
/// The handler is stored behind an `Arc` so each request can hand its own
/// reference to the blocking task that runs it.
#[derive(Debug)]
pub struct BlockingHandler<H: Handler> {
    handler: Arc<H>,
}
impl<H: Handler> BlockingHandler<H> {
pub fn new(handler: H) -> Self {
Self {
handler: Arc::new(handler),
}
}
pub(crate) async fn blocking_handler(
self: Arc<Self>,
request: Request<Body>,
remote_addr: SocketAddr,
) -> Result<HyperResponse, ServiceError> {
if let Err(response) = check_content_length(&request) {
return Ok(response);
}
let (parts, body) = request.into_parts();
let now = StartInstant::now();
let full_body = hyper::body::to_bytes(body).await?;
let request = Request::from_parts(parts, full_body);
let handler = self.handler.clone();
tokio::task::spawn_blocking(move || {
let mut request = ConduitRequest::new(request, remote_addr, now);
handler
.call(&mut request)
.map(conduit_into_hyper)
.unwrap_or_else(|e| server_error_response(&e.to_string()))
})
.await
.map_err(Into::into)
}
}
fn conduit_into_hyper(response: ConduitResponse) -> HyperResponse {
use conduit::Body::*;
let (parts, body) = response.into_parts();
let body = match body {
Static(slice) => slice.into(),
Owned(vec) => vec.into(),
File(file) => FileStream::from_std(file).into_streamed_body(),
};
HyperResponse::from_parts(parts, body)
}
fn server_error_response(message: &str) -> HyperResponse {
error!("Internal Server Error: {}", message);
let body = hyper::Body::from("Internal Server Error");
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(body)
.expect("Unexpected invalid header")
}
/// Rejects requests whose declared body size is malformed or larger than
/// `MAX_CONTENT_LENGTH`.
///
/// Two things are inspected: the `Content-Length` header, if present, must be
/// ASCII, parse as a `u64`, and not exceed the limit; and the request's
/// `size_hint().lower()` must not exceed the limit. On rejection a
/// `400 Bad Request` response with an empty body is returned as the `Err`
/// value.
fn check_content_length(request: &Request<Body>) -> Result<(), HyperResponse> {
    /// Logs a warning and builds an empty `400 Bad Request` response.
    fn bad_request(message: &str) -> HyperResponse {
        warn!("Bad request: Content-Length {}", message);
        Response::builder()
            .status(StatusCode::BAD_REQUEST)
            .body(Body::empty())
            .expect("Unexpected invalid header")
    }

    if let Some(header) = request.headers().get(CONTENT_LENGTH) {
        let text = header.to_str().map_err(|_| bad_request("not ASCII"))?;
        let declared = text
            .parse::<u64>()
            .map_err(|_| bad_request("not a u64"))?;
        if declared > MAX_CONTENT_LENGTH {
            return Err(bad_request("too large"));
        }
    }

    let minimum = request.size_hint().lower();
    if minimum > MAX_CONTENT_LENGTH {
        Err(bad_request("size_hint().lower() too large"))
    } else {
        Ok(())
    }
}