use std::sync::Arc;
use conduit::{box_error, Body, Handler, HandlerResult, RequestExt, Response, StatusCode};
use futures_util::future::Future;
use hyper::{body::to_bytes, service::Service};
use tokio::{sync::oneshot, task::JoinHandle};
use super::service::{BlockingHandler, ServiceError};
use super::HyperResponse;
/// Handler that always succeeds: one `ok: value` header plus the static body
/// `Hello, world!`, with the status left at the builder's default.
struct OkResult;

impl Handler for OkResult {
    fn call(&self, _request: &mut dyn RequestExt) -> HandlerResult {
        let builder = Response::builder().header("ok", "value");
        builder
            .body(Body::from_static(b"Hello, world!"))
            .map_err(box_error)
    }
}
/// Handler that always fails with an `io::Error`.
///
/// The concrete error (whatever `last_os_error` reports at call time) is
/// irrelevant here; callers only check that a handler error is turned into the
/// generic failure response.
struct ErrorResult;

impl Handler for ErrorResult {
    fn call(&self, _request: &mut dyn RequestExt) -> HandlerResult {
        Err(Box::new(std::io::Error::last_os_error()))
    }
}
/// Handler that panics unconditionally; used by `recover_from_panic`.
struct Panic;

impl Handler for Panic {
    fn call(&self, _request: &mut dyn RequestExt) -> HandlerResult {
        // Deliberate: exercises how the service copes with a panicking handler.
        panic!()
    }
}
/// Handler whose response builder is given `\r\n` as a header value.
///
/// CR/LF are not permitted in header values, so `body()` yields an error, which
/// is boxed and returned; the body text is never sent.
struct InvalidHeader;

impl Handler for InvalidHeader {
    fn call(&self, _request: &mut dyn RequestExt) -> HandlerResult {
        let builder = Response::builder().header("invalid-value", "\r\n");
        builder
            .body(Body::from_static(b"discarded"))
            .map_err(box_error)
    }
}
/// Handler that asks the response builder for status `1000`.
///
/// That code is outside the range the builder accepts, so `body()` yields an
/// error, which is boxed and returned.
struct InvalidStatus;

impl Handler for InvalidStatus {
    fn call(&self, _request: &mut dyn RequestExt) -> HandlerResult {
        let builder = Response::builder().status(1000);
        builder.body(Body::empty()).map_err(box_error)
    }
}
/// Handler that blocks its thread for 100ms, then responds exactly like
/// [`OkResult`].
struct Sleep;

impl Handler for Sleep {
    fn call(&self, request: &mut dyn RequestExt) -> HandlerResult {
        // A *thread* sleep on purpose: this stands in for a synchronous handler that
        // blocks, which the service must not allow to stall other requests.
        let delay = std::time::Duration::from_millis(100);
        std::thread::sleep(delay);
        OkResult.call(request)
    }
}
/// Handler that succeeds only when the path arrives percent-decoded (`/:`) while
/// the query string is left raw (`%3a`); otherwise it fails like [`ErrorResult`].
struct AssertPercentDecodedPath;

impl Handler for AssertPercentDecodedPath {
    fn call(&self, request: &mut dyn RequestExt) -> HandlerResult {
        // The query string is only inspected once the path has matched.
        if request.path() != "/:" || request.query_string() != Some("%3a") {
            return ErrorResult.call(request);
        }
        OkResult.call(request)
    }
}
fn make_service<H: Handler>(
handler: H,
) -> impl Service<
hyper::Request<hyper::Body>,
Response = HyperResponse,
Future = impl Future<Output = Result<HyperResponse, ServiceError>> + Send + 'static,
Error = ServiceError,
> {
use hyper::service::service_fn;
let handler = std::sync::Arc::new(BlockingHandler::new(handler));
service_fn(move |request: hyper::Request<hyper::Body>| {
let remote_addr = ([0, 0, 0, 0], 0).into();
handler.clone().blocking_handler(request, remote_addr)
})
}
/// Sends one default request to `handler` via [`make_service`] and returns the
/// response, panicking if the service call itself fails.
async fn simulate_request<H: Handler>(handler: H) -> HyperResponse {
    let request: hyper::Request<hyper::Body> = hyper::Request::default();
    let mut svc = make_service(handler);
    svc.call(request).await.unwrap()
}
/// Asserts that `resp` is the generic failure response: status 500, no headers,
/// and the body `Internal Server Error`.
async fn assert_generic_err(resp: HyperResponse) {
    let expected_status = StatusCode::INTERNAL_SERVER_ERROR;
    assert_eq!(resp.status(), expected_status);
    assert!(resp.headers().is_empty());

    let body = resp.into_body();
    let collected = to_bytes(body).await.unwrap();
    assert_eq!(&*collected, b"Internal Server Error");
}
/// A successful handler's response reaches the client with status 200, exactly
/// the one `ok: value` header the handler set, and its body.
#[tokio::test]
async fn valid_ok_response() {
    let resp = simulate_request(OkResult).await;
    assert_eq!(resp.status(), StatusCode::OK);
    assert_eq!(resp.headers().len(), 1);
    // The count alone would accept any single header; pin the actual name/value.
    assert_eq!(resp.headers()["ok"], "value");
    let full_body = to_bytes(resp.into_body()).await.unwrap();
    assert_eq!(&*full_body, b"Hello, world!");
}
/// Handlers that return `Ok` but whose response could not be built (bad header
/// value, bad status) must both surface as the generic failure response.
#[tokio::test]
async fn invalid_ok_responses() {
    let bad_header = simulate_request(InvalidHeader).await;
    assert_generic_err(bad_header).await;

    let bad_status = simulate_request(InvalidStatus).await;
    assert_generic_err(bad_status).await;
}
/// A handler returning `Err` must surface as the generic failure response.
#[tokio::test]
async fn err_responses() {
    let response = simulate_request(ErrorResult).await;
    assert_generic_err(response).await;
}
/// A panicking handler is expected to surface as the generic failure response.
///
/// NOTE(review): this test is `#[ignore]`d and no reason is recorded here;
/// confirm whether panic recovery is currently expected to work before
/// re-enabling it.
#[ignore]
#[tokio::test]
async fn recover_from_panic() {
    let response = simulate_request(Panic).await;
    assert_generic_err(response).await;
}
/// Two requests to a handler that blocks its thread for 100ms must run
/// concurrently rather than back to back.
///
/// The clock starts *before* the requests are issued. If it started after the
/// `call`s, a service that ran the handler synchronously inside `call` would
/// finish both sleeps before timing began and pass the check it should fail.
#[tokio::test]
async fn sleeping_doesnt_block_another_request() {
    let mut service = make_service(Sleep);
    let start = std::time::Instant::now();
    let first = service.call(hyper::Request::default());
    let second = service.call(hyper::Request::default());
    let (first, second) = futures_util::join!(first, second);
    // Sequential execution needs at least ~200ms (2 x 100ms); 150ms leaves headroom.
    assert!(start.elapsed() < std::time::Duration::from_millis(150));
    assert_eq!(first.unwrap().status(), StatusCode::OK);
    assert_eq!(second.unwrap().status(), StatusCode::OK);
}
/// `/%3a?%3a` must reach the handler with the path decoded to `/:` and the query
/// string left as the raw `%3a` (checked inside `AssertPercentDecodedPath`).
#[tokio::test]
async fn path_is_percent_decoded_but_not_query_string() {
    let request = hyper::Request::put("/%3a?%3a")
        .body(hyper::Body::default())
        .unwrap();
    let mut service = make_service(AssertPercentDecodedPath);
    let response = service.call(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);
}
/// Starts a real hyper server on an ephemeral `127.0.0.1` port serving
/// [`OkResult`].
///
/// Returns the server's base URL, the spawned task's handle, and a sender that
/// triggers graceful shutdown when sent to (or dropped).
async fn spawn_http_server() -> (
    String,
    JoinHandle<Result<(), hyper::Error>>,
    oneshot::Sender<()>,
) {
    let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
    let bind_addr = ([127, 0, 0, 1], 0).into();

    // A fresh BlockingHandler per connection, with a placeholder remote address.
    let make_svc = hyper::service::make_service_fn(move |_| {
        let handler = Arc::new(BlockingHandler::new(OkResult));
        let remote_addr = ([0, 0, 0, 0], 0).into();
        async move { crate::Service::from_blocking(handler, remote_addr) }
    });
    let server = hyper::Server::bind(&bind_addr).serve(make_svc);

    // Read the bound address before wrapping, since port 0 picks a random port.
    let url = format!("http://{}", server.local_addr());
    let graceful = server.with_graceful_shutdown(async move {
        let _ = shutdown_rx.await;
    });
    let handle = tokio::spawn(graceful);
    (url, handle, shutdown_tx)
}
/// A request whose `Content-Length` claims far more data than is actually sent
/// must be rejected with 400 by the real server, and the server must still shut
/// down cleanly afterwards.
#[tokio::test]
async fn content_length_too_large() {
    const ACTUAL_BODY_SIZE: usize = 10_000;
    const CLAIMED_CONTENT_LENGTH: u64 = 11_111_111_111_111_111_111;

    let (url, server, quit_tx) = spawn_http_server().await;
    let client = hyper::Client::new();

    // `sender` stays alive for the whole test, so the body never signals its end
    // after the first chunk; only the header overstates the length.
    let (mut sender, body) = hyper::Body::channel();
    let payload = vec![0u8; ACTUAL_BODY_SIZE];
    sender.send_data(payload.into()).await.unwrap();

    let request = hyper::Request::put(url)
        .header(hyper::header::CONTENT_LENGTH, CLAIMED_CONTENT_LENGTH)
        .body(body)
        .unwrap();
    let response = client
        .request(request)
        .await
        .expect("should be a valid response");

    quit_tx.send(()).unwrap();
    server.await.unwrap().unwrap();
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}