use bytes::Bytes;
use http::header::{HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use http_body_util::Full;
use hyper::{Method, Request, Response, StatusCode};
use hyperlite::{parse_json_body, serve, success, BoxBody, BoxError, Router};
use serde::Serialize;
use serde_json::Value;
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
use tower::util::BoxCloneSyncService;
use tower::{Layer, Service, ServiceBuilder};
use tower_http::cors::{AllowOrigin, CorsLayer};
use tower_http::request_id::{
MakeRequestUuid, PropagateRequestIdLayer, RequestId, SetRequestIdLayer,
};
use tracing::{info, warn};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
/// Application state shared between the router and the middleware stack.
///
/// Cloning copies `app_name` but only bumps the reference count of `request_count`,
/// so every clone reads and writes the same counter.
#[derive(Clone)]
struct AppState {
    /// Human-readable service name (reported as `service` by the health handler).
    app_name: String,
    /// Running total of requests, shared by all clones of this state.
    request_count: Arc<Mutex<u64>>,
}

impl AppState {
    /// Creates state labelled `name` with the request counter starting at zero.
    fn new(name: impl Into<String>) -> Self {
        let app_name: String = name.into();
        let request_count = Arc::new(Mutex::new(0_u64));
        Self {
            app_name,
            request_count,
        }
    }
}
/// Tower layer that wraps a service in `RequestCounterService`.
///
/// Each wrapped service receives its own handle to the same shared `AppState`.
#[derive(Clone)]
struct RequestCounterLayer {
    state: Arc<AppState>,
}

impl RequestCounterLayer {
    /// Creates the layer around an already-shared application state.
    fn new(state: Arc<AppState>) -> Self {
        RequestCounterLayer { state }
    }
}
/// Service produced by `RequestCounterLayer`.
///
/// Its `call` increments `state.request_count` and then delegates to `inner`.
#[derive(Clone)]
struct RequestCounterService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
    /// Shared state whose `request_count` is bumped once per request.
    state: Arc<AppState>,
}
impl<S> Layer<S> for RequestCounterLayer
where
S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Service = RequestCounterService<S>;
fn layer(&self, inner: S) -> Self::Service {
RequestCounterService {
inner,
state: self.state.clone(),
}
}
}
/// Counts every request that reaches this service, then forwards it unchanged.
///
/// The counter is updated synchronously inside `call`, before the inner service runs.
/// Nothing remains to do once the inner future resolves, so that future is returned
/// as-is rather than boxed, which avoids an allocation per request.
impl<S> Service<Request<BoxBody>> for RequestCounterService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        {
            // A poisoned lock only means some earlier holder panicked. A bare counter has
            // no invariant that could be left half-updated, so keep counting rather than
            // turning one panic into a panic on every subsequent request.
            let mut total = self
                .state
                .request_count
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            *total += 1;
        } // Guard dropped here: the lock is never held while the inner service runs.
        self.inner.call(req)
    }
}
/// Payload returned by the health handler.
#[derive(Serialize)]
struct HealthResponse {
    /// Fixed liveness marker (the handler always sets `"ok"`).
    status: &'static str,
    /// Snapshot of the shared request counter at the time of the call.
    total_requests: u64,
    /// Service name copied from `AppState::app_name`.
    service: String,
}
/// Payload returned by the `/protected` handler.
#[derive(Serialize)]
struct ProtectedResponse {
    /// Static confirmation text.
    message: &'static str,
    /// Request id taken from the `RequestId` extension, or `None` when it was
    /// missing or not representable as a string.
    request_id: Option<String>,
}
/// Payload returned by the `/echo` handler.
#[derive(Serialize)]
struct EchoResponse {
    /// The request body, parsed as arbitrary JSON and sent back verbatim.
    received: Value,
}
/// `GET /health`: reports liveness, the service name and the running request total.
///
/// The total is read from `state.request_count`, the counter that
/// `RequestCounterService` increments.
///
/// # Errors
/// Never returns `Err`; the `Result` only matches the router's handler signature.
async fn health_handler(
    _req: Request<BoxBody>,
    state: Arc<AppState>,
) -> Result<Response<Full<Bytes>>, BoxError> {
    // Copy the count out in a single statement so the guard is released immediately.
    // A poisoned lock still holds a valid `u64`, so recover it instead of panicking.
    let total = *state
        .request_count
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let response = HealthResponse {
        status: "ok",
        total_requests: total,
        service: state.app_name.clone(),
    };
    Ok(success(StatusCode::OK, response))
}
async fn protected_handler(
req: Request<BoxBody>,
_state: Arc<AppState>,
) -> Result<Response<Full<Bytes>>, BoxError> {
let request_id = req
.extensions()
.get::<RequestId>()
.and_then(|id| id.header_value().to_str().ok())
.map(|value| value.to_string());
if request_id.is_none() {
warn!("request id missing from protected handler");
}
let payload = ProtectedResponse {
message: "Authenticated request received",
request_id,
};
Ok(success(StatusCode::OK, payload))
}
/// `POST /echo`: parses the request body as arbitrary JSON and returns it wrapped in
/// an `EchoResponse`.
///
/// # Errors
/// Propagates any error from `parse_json_body` (e.g. a malformed body). How that error
/// becomes an HTTP response is decided by hyperlite's router; confirm there.
async fn echo_handler(
    req: Request<BoxBody>,
    _state: Arc<AppState>,
) -> Result<Response<Full<Bytes>>, BoxError> {
    let received = parse_json_body::<Value>(req).await?;
    Ok(success(StatusCode::OK, EchoResponse { received }))
}
/// Layer that applies the example's full middleware stack to an inner service:
/// request ids, tracing, CORS and request counting.
#[derive(Clone)]
struct MiddlewareStack {
    /// Shared state handed to `RequestCounterLayer` when the stack is built.
    state: Arc<AppState>,
}

impl MiddlewareStack {
    /// Creates a stack that will count requests into `state`.
    fn new(state: Arc<AppState>) -> Self {
        MiddlewareStack { state }
    }
}
/// Stateless tower layer that wraps a service in `RequestTracingService`, which logs
/// one `info!` line per completed request.
#[derive(Clone, Default)]
struct RequestTracingLayer;
/// Service produced by `RequestTracingLayer`.
///
/// It times the wrapped `inner` service and logs the method, path, status, latency and
/// request id once the response is ready.
#[derive(Clone)]
struct RequestTracingService<S> {
    /// The wrapped service being timed.
    inner: S,
}
impl<S> Layer<S> for MiddlewareStack
where
S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
+ Clone
+ Send
+ Sync
+ 'static,
S::Future: Send + 'static,
{
type Service = BoxCloneSyncService<Request<BoxBody>, Response<Full<Bytes>>, Infallible>;
fn layer(&self, inner: S) -> Self::Service {
let cors_layer = CorsLayer::new()
.allow_origin(AllowOrigin::list(vec![
HeaderValue::from_static("http://localhost:3000"),
HeaderValue::from_static("http://localhost:5173"),
]))
.allow_methods([
Method::GET,
Method::POST,
Method::PUT,
Method::DELETE,
Method::OPTIONS,
])
.allow_headers([ACCEPT, AUTHORIZATION, CONTENT_TYPE])
.allow_credentials(true);
let request_id_layer = SetRequestIdLayer::x_request_id(MakeRequestUuid);
let propagate_layer = PropagateRequestIdLayer::x_request_id();
let trace_layer = RequestTracingLayer;
let service = ServiceBuilder::new()
.layer(request_id_layer)
.layer(propagate_layer)
.layer(trace_layer)
.layer(cors_layer)
.layer(RequestCounterLayer::new(self.state.clone()))
.service(inner);
BoxCloneSyncService::new(service)
}
}
/// Wraps any infallible `Request<BoxBody>` → `Response<Full<Bytes>>` service in a
/// `RequestTracingService`.
impl<S> Layer<S> for RequestTracingLayer
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Clone
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Service = RequestTracingService<S>;

    /// The layer carries no configuration, so wrapping just moves `inner` into the service.
    fn layer(&self, inner: S) -> Self::Service {
        RequestTracingService { inner }
    }
}
/// Logs one `info!` event per completed request.
///
/// The event records method, path, status and elapsed milliseconds, plus the
/// `x-request-id` header when it is present and valid UTF-8/ASCII.
impl<S> Service<Request<BoxBody>> for RequestTracingService<S>
where
    S: Service<Request<BoxBody>, Response = Response<Full<Bytes>>, Error = Infallible>
        + Send
        + 'static,
    S::Future: Send + 'static,
{
    type Response = Response<Full<Bytes>>;
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        // Capture what we log up front: `req` is moved into the inner service below.
        let (method, path) = (req.method().clone(), req.uri().path().to_owned());
        let request_id: Option<String> = req
            .headers()
            .get("x-request-id")
            .and_then(|header| header.to_str().ok())
            .map(str::to_owned);

        // Latency covers the inner service's call and its future, not the capture above.
        let start = Instant::now();
        let pending = self.inner.call(req);

        Box::pin(async move {
            let response = pending.await?;
            let status = response.status();
            let elapsed_ms = start.elapsed().as_millis();
            // Local names double as tracing field names, so they must stay as-is.
            match request_id {
                Some(id) => info!(%method, %path, %status, elapsed_ms, %id, "request completed"),
                None => info!(%method, %path, %status, elapsed_ms, "request completed"),
            }
            Ok(response)
        })
    }
}
/// Returns the middleware stack that `main` wraps around the router.
///
/// `state` is shared with the stack's request counter.
fn build_middleware_stack(state: Arc<AppState>) -> MiddlewareStack {
    let stack = MiddlewareStack::new(state);
    stack
}
/// Entry point: installs tracing, builds the router and middleware stack, and serves
/// on `127.0.0.1:3000` until `serve` returns.
#[tokio::main]
async fn main() -> Result<(), BoxError> {
    // Use the default env filter (`RUST_LOG`) when set; otherwise log at `info`.
    tracing_subscriber::registry()
        .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let base_state = AppState::new("Hyperlite Middleware Demo");
    // The middleware gets its own `Arc<AppState>`. Cloning `AppState` shares the
    // `request_count` Arc, so the middleware's increments land in the same counter that
    // `base_state` (handed to the router) holds.
    let state = Arc::new(base_state.clone());

    let router = Router::new(base_state)
        .route(
            "/health",
            Method::GET,
            Arc::new(|req, state| Box::pin(health_handler(req, state))),
        )
        .route(
            "/protected",
            Method::GET,
            Arc::new(|req, state| Box::pin(protected_handler(req, state))),
        )
        .route(
            "/echo",
            Method::POST,
            Arc::new(|req, state| Box::pin(echo_handler(req, state))),
        );

    // `state` is not used again, so move it rather than cloning.
    let service = build_middleware_stack(state).layer(router);

    // Built from parts instead of parsed from a string, so there is no panic path.
    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    info!("Middleware example running on http://{addr}");
    serve(addr, service).await
}