use std::convert::Infallible;
use std::time::Duration;
use http::{Request, Response, StatusCode};
use http_body_util::Full;
use hyper::body::{Bytes, Incoming};
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use tokio::net::TcpListener;
use tower::{Service, ServiceBuilder, ServiceExt};
use tower_acc::{ConcurrencyLimitLayer, Vegas};
/// Capacity of the request queue created by `ServiceBuilder::buffer` in
/// front of the concurrency limiter. When this queue is full, the outer
/// `load_shed` layer rejects new requests instead of waiting for a slot.
const BUFFER_SIZE: usize = 16;
/// Demo endpoint: pauses for 200 ms to stand in for real work, then replies
/// with a fixed greeting.
///
/// The incoming request is not inspected, and this handler never fails.
async fn handler(_req: Request<Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
    const SIMULATED_WORK: Duration = Duration::from_millis(200);

    tokio::time::sleep(SIMULATED_WORK).await;
    let greeting = Full::new(Bytes::from("Hello, world!\n"));
    Ok(Response::new(greeting))
}
#[tokio::main]
async fn main() {
let algorithm = Vegas::builder().initial_limit(5).max_limit(20).build();
let svc = ServiceBuilder::new()
.load_shed()
.buffer(BUFFER_SIZE)
.layer(ConcurrencyLimitLayer::new(algorithm))
.service_fn(handler);
let listener = TcpListener::bind("0.0.0.0:3000").await.unwrap();
println!("Listening on http://localhost:3000");
println!("Buffer size: {BUFFER_SIZE}, initial concurrency limit: 5, max: 20");
loop {
let (stream, _addr) = listener.accept().await.unwrap();
let svc = svc.clone();
tokio::spawn(async move {
let hyper_svc = hyper::service::service_fn(move |req: Request<Incoming>| {
let mut svc = svc.clone();
async move {
match svc.ready().await {
Ok(svc) => match svc.call(req).await {
Ok(resp) => Ok::<_, Infallible>(resp),
Err(err) => Ok(error_response(err)),
},
Err(err) => Ok(error_response(err)),
}
}
});
let result = Builder::new(TokioExecutor::new())
.serve_connection(TokioIo::new(stream), hyper_svc)
.await;
if let Err(err) = result {
eprintln!("Connection error: {err}");
}
});
}
}
fn error_response(err: tower::BoxError) -> Response<Full<Bytes>> {
if err.is::<tower::load_shed::error::Overloaded>() {
Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.body(Full::new(Bytes::from("service unavailable")))
.unwrap()
} else {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::from(format!(
"internal server error: {err}"
))))
.unwrap()
}
}