use std::{
net::SocketAddr,
sync::{Arc, Mutex},
time::Duration,
};
use async_trait::async_trait;
use http_body_util::BodyExt;
use hyper::{
body::{self, Body, Bytes},
server::conn::http1,
service::service_fn,
Request, Uri,
};
use tokio::{net::TcpListener, select, sync::watch, time::Instant};
use crate::Error;
use crate::{run_handler, Handler};
/// Handle to an HTTP server running on spawned Tokio tasks.
///
/// The handle is cheap to clone: every field is either `Copy` or behind an `Arc`,
/// so all clones observe the same counters and share the same shutdown channel.
#[derive(Debug, Clone)]
pub struct Server {
/// Sending half of the watch channel used by `close` to tell the spawned tasks to stop.
close_tx: Arc<watch::Sender<u8>>,
/// Local address the TCP listener reported after binding.
addr: SocketAddr,
/// Number of requests whose handler has run to completion.
req_count: Arc<Mutex<u64>>,
/// Number of requests whose handler is currently running.
concurrent_req_count: Arc<Mutex<u64>>,
}
impl Server {
/// Starts an HTTP/1 server that passes every incoming request to `handler`.
///
/// The listener is bound to `127.0.0.1` with port `0`, so the operating system picks
/// a free port; the resulting address is stored in the returned handle. Accepting
/// and serving connections happens on tasks created with `tokio::spawn`, so this
/// must be called from inside a Tokio runtime.
///
/// Before the handler runs, the peer's `SocketAddr` is inserted into the request's
/// extensions. A clone of `handler` is made for every request.
///
/// # Errors
///
/// * `Error::BindTCPListener` if the listener cannot be bound.
/// * `Error::GetTCPListenerAddress` if the bound address cannot be read back.
pub async fn new<H: Handler + Clone + Send + Sync + 'static>(
    handler: H,
) -> Result<Self, Error> {
    // Marks one request as "in flight" for as long as the value is alive.
    //
    // Decrementing in `Drop` (rather than with a plain statement after the handler
    // returns) keeps the gauge accurate when the request future is dropped before
    // the handler finishes, e.g. when the connection task returns on the close
    // signal while a handler is still running.
    struct InFlightGuard<'a>(&'a Mutex<u64>);

    impl<'a> InFlightGuard<'a> {
        fn enter(gauge: &'a Mutex<u64>) -> Self {
            *gauge.lock().expect("lock poisoned") += 1;
            Self(gauge)
        }
    }

    impl Drop for InFlightGuard<'_> {
        fn drop(&mut self) {
            // Recover the value from a poisoned lock instead of panicking in `drop`.
            let mut in_flight = self
                .0
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            *in_flight -= 1;
        }
    }

    let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
    let tcp_listener = TcpListener::bind(addr)
        .await
        .map_err(Error::BindTCPListener)?;
    // Port 0 was requested, so ask the listener which port it actually got.
    let addr = tcp_listener
        .local_addr()
        .map_err(Error::GetTCPListenerAddress)?;
    let (close_tx, mut close_rx) = watch::channel::<u8>(0);
    let req_count = Arc::new(Mutex::new(0));
    let concurrent_req_count = Arc::new(Mutex::new(0));

    {
        let req_count = Arc::clone(&req_count);
        let concurrent_req_count = Arc::clone(&concurrent_req_count);
        // Accept loop: runs until the close signal fires or `accept` fails.
        tokio::spawn(async move {
            loop {
                let (tcp_stream, remote_addr) = select! {
                    _ = close_rx.changed() => {
                        return;
                    }
                    res = tcp_listener.accept() => {
                        match res {
                            Ok(accepted) => accepted,
                            Err(err) => {
                                // Only this accept task stops; connection tasks
                                // that were already spawned keep running.
                                eprintln!("Error while accepting TCP connection: {}", err);
                                return;
                            }
                        }
                    }
                };

                let handler = handler.clone();
                let mut close_rx = close_rx.clone();
                let req_count = Arc::clone(&req_count);
                let concurrent_req_count = Arc::clone(&concurrent_req_count);
                // One task per connection, so that connections are served concurrently.
                tokio::spawn(async move {
                    // Shared references are `Copy`, which lets the service closure
                    // hand them to a fresh `async move` block on every request.
                    let handler = &handler;
                    let req_count: &Mutex<u64> = &req_count;
                    let concurrent_req_count: &Mutex<u64> = &concurrent_req_count;
                    let service = service_fn(|mut req: Request<body::Incoming>| async move {
                        let in_flight = InFlightGuard::enter(concurrent_req_count);
                        req.extensions_mut().insert(remote_addr);
                        let res = run_handler(handler.clone(), req).await;
                        // Leave the in-flight gauge first, then count the request
                        // as completed.
                        drop(in_flight);
                        *req_count.lock().expect("lock poisoned") += 1;
                        res
                    });
                    let res = select! {
                        _ = close_rx.changed() => {
                            // Returning drops the connection future and with it
                            // any request that is still being handled.
                            return;
                        }
                        res = http1::Builder::new()
                            .keep_alive(true)
                            .serve_connection(tcp_stream, service) => res,
                    };
                    if let Err(http_err) = res {
                        eprintln!("Error while serving HTTP connection: {}", http_err);
                    }
                });
            }
        });
    }

    Ok(Self {
        close_tx: Arc::new(close_tx),
        addr,
        req_count,
        concurrent_req_count,
    })
}
pub fn addr(&self) -> SocketAddr {
self.addr
}
/// Builds an `http://<addr><path_and_query>` URI pointing at this server.
///
/// # Panics
///
/// Panics if the URI builder rejects the parts, e.g. because `path_and_query`
/// is not a valid path-and-query component.
pub fn url(&self, path_and_query: &str) -> Uri {
    let authority = self.addr.to_string();
    let built = Uri::builder()
        .scheme("http")
        .authority(authority.as_str())
        .path_and_query(path_and_query)
        .build();
    built.expect("should be a valid URL")
}
/// Returns the current value of the completed-request counter.
///
/// # Panics
///
/// Panics if the counter's mutex is poisoned.
pub fn req_count(&self) -> u64 {
    let completed = self.req_count.lock().expect("lock poisoned");
    *completed
}
/// Polls [`Server::req_count`] every 10 ms until it equals `target_count`.
///
/// The comparison is exact: a counter that is already past `target_count` is not
/// treated as a match.
///
/// # Errors
///
/// Returns `Error::AwaitReqCountTimeout` (carrying the last observed count) once
/// more than `timeout` has elapsed without the counter matching.
pub async fn await_req_count(&self, target_count: u64, timeout: Duration) -> Result<(), Error> {
    let poll_interval = Duration::from_millis(10);
    let started_at = Instant::now();
    loop {
        match self.req_count() {
            current_count if current_count == target_count => return Ok(()),
            current_count if started_at.elapsed() > timeout => {
                return Err(Error::AwaitReqCountTimeout {
                    current_count,
                    target_count,
                    timeout,
                });
            }
            _ => tokio::time::sleep(poll_interval).await,
        }
    }
}
/// Returns the current value of the in-flight request gauge.
///
/// # Panics
///
/// Panics if the gauge's mutex is poisoned.
pub fn concurrent_req_count(&self) -> u64 {
    let in_flight = self.concurrent_req_count.lock().expect("lock poisoned");
    *in_flight
}
/// Polls [`Server::concurrent_req_count`] every 10 ms until it equals `target_count`.
///
/// # Errors
///
/// Returns `Error::AwaitConcurrentReqCountTimeout` (carrying the last observed
/// count) once more than `timeout` has elapsed without the gauge matching.
pub async fn await_concurrent_req_count(
    &self,
    target_count: u64,
    timeout: Duration,
) -> Result<(), Error> {
    let started_at = Instant::now();
    let mut current_count = self.concurrent_req_count();
    while current_count != target_count {
        if started_at.elapsed() > timeout {
            return Err(Error::AwaitConcurrentReqCountTimeout {
                current_count,
                target_count,
                timeout,
            });
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
        current_count = self.concurrent_req_count();
    }
    Ok(())
}
/// Signals the accept loop and every connection task to stop.
///
/// The tasks wait on the watch channel's `changed()` and return as soon as a value
/// is sent. Calling this more than once, or after the tasks have already exited,
/// is a no-op.
pub fn close(&self) {
    // `watch::Sender::send` only fails when no receiver is left, which means every
    // server task has already finished. There is nothing to shut down in that case,
    // so the error is ignored instead of panicking (this also runs from `Drop`).
    let _ = self.close_tx.send(1);
}
}
impl Drop for Server {
    /// Closes the server when the last handle goes away.
    ///
    /// `close_tx` is shared between all clones through an `Arc`, so a strong count
    /// of one means no other clone of this handle is left.
    fn drop(&mut self) {
        let remaining_handles = Arc::strong_count(&self.close_tx);
        if remaining_handles == 1 {
            self.close();
        }
    }
}
/// Extension trait for reading a whole request body in one call.
#[async_trait]
pub trait GetRequestBody {
/// Consumes `self` and resolves to the body as a single `Bytes` value, or to the
/// `hyper::Error` produced while reading it.
async fn body_bytes(self) -> Result<Bytes, hyper::Error>;
}
#[async_trait]
impl<B> GetRequestBody for Request<B>
where
    B: Body<Data = Bytes> + Send + Sync + 'static,
    <B as Body>::Error: Into<hyper::Error>,
{
    /// Collects every frame of the request body and returns the data as one `Bytes`.
    ///
    /// A body error is converted into `hyper::Error` through the `Into` bound above.
    async fn body_bytes(self) -> Result<Bytes, hyper::Error> {
        let collected = self.into_body().collect().await;
        match collected {
            Ok(buffered) => Ok(buffered.to_bytes()),
            Err(err) => Err(err.into()),
        }
    }
}
#[cfg(test)]
mod test {
use http_body_util::Full;
use hyper::{body::Bytes, Response};
use super::*;
use crate::handle_ok;
#[tokio::test]
async fn server_ok() {
async fn handler(
req: Request<body::Incoming>,
) -> Result<Response<Full<Bytes>>, hyper::Error> {
let body = req.body_bytes().await?;
Ok(Response::new(Full::new(body)))
}
let server = Server::new(handler).await.expect("create server");
let client = reqwest::Client::new();
static ITERATIONS: u64 = 10;
for i in 0..ITERATIONS {
let res = client
.post(server.url("/").to_string())
.body(format!("hello world {}", i))
.send()
.await
.expect("send request");
assert_eq!(res.status(), 200);
assert_eq!(
res.text().await.expect("read response"),
format!("hello world {}", i)
);
assert_eq!(server.req_count(), i + 1);
}
assert_eq!(server.req_count(), ITERATIONS);
}
// A `move` closure that captures a `Copy` value can be used as the handler.
#[tokio::test]
async fn server_move_closure_copy() {
    const ITERATIONS: u64 = 10;
    let val = 1234;
    let server = Server::new(move |_: Request<body::Incoming>| async move {
        handle_ok(Response::new(val.to_string().into()))
    })
    .await
    .expect("create server");
    let client = reqwest::Client::new();
    let endpoint = server.url("/").to_string();
    let expected_body = val.to_string();

    for completed in 0..ITERATIONS {
        let res = client
            .get(endpoint.as_str())
            .send()
            .await
            .expect("send request");
        assert_eq!(res.status(), 200);
        let body = res.text().await.expect("read response");
        assert_eq!(body, expected_body);
        assert_eq!(server.req_count(), completed + 1);
    }
    assert_eq!(server.req_count(), ITERATIONS);
}
// A `move` closure that captures shared state (`Arc<Mutex<_>>`) can be used as the
// handler, and its mutations are visible to the test through the other `Arc`.
#[tokio::test]
async fn server_move_closure_arc() {
    const ITERATIONS: u64 = 10;
    let shared = Arc::new(Mutex::new(1234));
    let handler_state = Arc::clone(&shared);
    let server = Server::new(move |_: Request<body::Incoming>| async move {
        let mut counter = handler_state.lock().expect("lock poisoned");
        *counter += 1;
        handle_ok(Response::new(counter.to_string().into()))
    })
    .await
    .expect("create server");
    let client = reqwest::Client::new();
    let endpoint = server.url("/").to_string();

    for completed in 0..ITERATIONS {
        let res = client
            .get(endpoint.as_str())
            .send()
            .await
            .expect("send request");
        assert_eq!(res.status(), 200);
        let body = res.text().await.expect("read response");
        // Read the shared value only after the response arrived, as the handler
        // bumps it before answering.
        let expected_body = shared.lock().expect("lock poisoned").to_string();
        assert_eq!(body, expected_body);
        assert_eq!(server.req_count(), completed + 1);
    }
    assert_eq!(server.req_count(), ITERATIONS);
}
// A handler error is answered with status 500 and the error text as body, and the
// request is still counted.
#[tokio::test]
async fn server_failure() {
    async fn always_fails(
        _: Request<body::Incoming>,
    ) -> Result<Response<Full<Bytes>>, String> {
        Err("Internal Server Error".to_string())
    }

    const ITERATIONS: u64 = 10;
    let server = Server::new(always_fails).await.expect("create server");
    let client = reqwest::Client::new();
    let endpoint = server.url("/").to_string();

    for completed in 0..ITERATIONS {
        let res = client
            .get(endpoint.as_str())
            .send()
            .await
            .expect("send request");
        assert_eq!(res.status(), 500);
        let body = res.text().await.expect("read response");
        assert_eq!(body, "Internal Server Error");
        assert_eq!(server.req_count(), completed + 1);
    }
    assert_eq!(server.req_count(), ITERATIONS);
}
// `await_req_count` must return once all concurrently issued requests completed.
#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn server_await_req_count() {
    async fn hello(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
        Ok(Response::new("hello world".into()))
    }

    const ITERATIONS: u64 = 10;
    let server = Server::new(hello).await.expect("create server");
    let client = reqwest::Client::new();
    let url = server.url("/").to_string();

    // Fire all requests at once, each from its own task.
    let mut requests: Vec<tokio::task::JoinHandle<()>> = Vec::new();
    for _ in 0..ITERATIONS {
        let client = client.clone();
        let url = url.clone();
        requests.push(tokio::spawn(async move {
            let res = client.get(url).send().await.expect("send request");
            assert_eq!(res.status(), 200);
        }));
    }

    let waited = server
        .await_req_count(ITERATIONS, Duration::from_secs(1))
        .await;
    waited.expect("requests finished");
    assert_eq!(server.req_count(), ITERATIONS);

    for request in requests {
        request.await.unwrap();
    }
}
// Dropping the server while slow handlers are still running must make the pending
// client requests fail quickly instead of waiting for the handlers.
#[tokio::test(flavor = "multi_thread", worker_threads = 16)]
async fn server_long_requests_cancellation() {
    async fn slow(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
        tokio::time::sleep(Duration::from_secs(10)).await;
        Ok(Response::new("hello world".into()))
    }

    const ITERATIONS: u64 = 10;
    let server = Server::new(slow).await.expect("create server");
    let client = reqwest::Client::new();
    let url = server.url("/").to_string();

    let mut requests: Vec<tokio::task::JoinHandle<Result<(), String>>> = Vec::new();
    for _ in 0..ITERATIONS {
        let client = client.clone();
        let url = url.clone();
        requests.push(tokio::spawn(async move {
            // Only a failed request counts as success here.
            if client.get(url).send().await.is_ok() {
                Err("expected request to be canceled".to_string())
            } else {
                Ok(())
            }
        }));
    }

    // Wait until every request has reached its handler before pulling the plug.
    let started = server
        .await_concurrent_req_count(ITERATIONS, Duration::from_secs(1))
        .await;
    started.expect("requests start");
    assert_eq!(server.concurrent_req_count(), ITERATIONS);

    let dropped_at = Instant::now();
    drop(server);
    for request in requests {
        request.await.unwrap().expect("request canceled");
    }
    // The handlers sleep for 10 s, so finishing within 1 s shows they were cut short.
    assert!(dropped_at.elapsed() < Duration::from_secs(1));
}
// Sequential requests from one client must all arrive from the same peer address,
// i.e. over the same kept-alive connection.
#[tokio::test]
async fn server_keep_alive() {
    const ITERATIONS: u64 = 10;
    let last_socket_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
    let server = Server::new(move |req: Request<body::Incoming>| async move {
        let socket_addr = req.extensions().get::<SocketAddr>().unwrap();
        let mut seen = last_socket_addr.lock().expect("lock poisoned");
        // The first request records its peer address; every later request is
        // compared against that recorded address.
        let first_seen = *seen.get_or_insert(*socket_addr);
        assert_eq!(&first_seen, socket_addr);
        handle_ok(Response::new("hello world".into()))
    })
    .await
    .expect("create server");
    let client = reqwest::Client::new();
    let endpoint = server.url("/").to_string();

    for completed in 0..ITERATIONS {
        let res = client
            .get(endpoint.as_str())
            .send()
            .await
            .expect("send request");
        assert_eq!(res.status(), 200);
        let body = res.text().await.expect("read response");
        assert_eq!(body, "hello world");
        assert_eq!(server.req_count(), completed + 1);
    }
    assert_eq!(server.req_count(), ITERATIONS);
}
}