use axum::{
    extract::{Request, State},
    middleware::Next,
    response::Response,
};
use http::StatusCode;
use std::sync::Arc;
use subtle::ConstantTimeEq;
use tracing::{info_span, Instrument};

use crate::server::AppState;
/// Runs the rest of the middleware stack inside an `info`-level `"HTTP request"`
/// span that records the request method and URI.
///
/// The span is attached with [`Instrument::instrument`] rather than `span.enter()`.
/// An `Entered` guard held across an `.await` stays entered on the thread where
/// the future was suspended. Other tasks polled on that thread would then be
/// recorded as if they ran inside this request's span.
pub async fn trace_request(request: Request, next: Next) -> Response {
    let span = info_span!(
        "HTTP request",
        method = %request.method(),
        uri = %request.uri(),
    );
    next.run(request).instrument(span).await
}
/// Requires the configured API key in the `X-API-Key` request header.
///
/// When `state.api_key` is `None`, authentication is disabled and every request
/// is forwarded to the next layer unchanged.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] when the header is absent or its value
/// does not match the configured key.
pub async fn auth_middleware(
    State(state): State<Arc<AppState>>,
    request: Request,
    next: Next,
) -> std::result::Result<Response, StatusCode> {
    const API_KEY_HEADER: &str = "X-API-Key";

    let expected_key = match &state.api_key {
        Some(key) => key,
        // No key configured: authentication is turned off.
        None => return Ok(next.run(request).await),
    };

    // Use the raw header bytes instead of `HeaderValue::to_str`. `to_str` fails on
    // any non-visible-ASCII byte. That caused a present-but-unusual header to be
    // reported as missing, and a key containing non-ASCII characters could never match.
    let provided_key = request.headers().get(API_KEY_HEADER).map(|h| h.as_bytes());

    match provided_key {
        // `ct_eq` avoids leaking, via timing, how many leading bytes matched.
        // For slices of different lengths it short-circuits, so only the key
        // length is observable.
        Some(key) if key.ct_eq(expected_key.as_bytes()).into() => {
            Ok(next.run(request).await)
        }
        Some(_) => {
            tracing::warn!("Invalid API key provided");
            Err(StatusCode::UNAUTHORIZED)
        }
        None => {
            tracing::warn!("No API key provided");
            Err(StatusCode::UNAUTHORIZED)
        }
    }
}
/// Builds a [`tower::limit::ConcurrencyLimitLayer`] using this module's default
/// limit of 100 concurrently in-flight requests.
pub fn create_concurrency_limit_layer() -> tower::limit::ConcurrencyLimitLayer {
    const DEFAULT_MAX_IN_FLIGHT: usize = 100;
    tower::limit::ConcurrencyLimitLayer::new(DEFAULT_MAX_IN_FLIGHT)
}
/// Builds a [`tower::limit::ConcurrencyLimitLayer`] that admits at most `max`
/// concurrently in-flight requests.
///
/// # Panics
///
/// Panics if `max` is zero. A limit with zero permits never becomes ready, so
/// every request routed through the layer would wait forever. Failing at
/// construction surfaces the misconfiguration instead of producing a hung server.
pub fn create_custom_concurrency_limit_layer(max: usize) -> tower::limit::ConcurrencyLimitLayer {
    assert!(max > 0, "concurrency limit must be greater than zero");
    tower::limit::ConcurrencyLimitLayer::new(max)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the default-sized concurrency layer can be constructed.
    #[test]
    fn test_create_concurrency_limit_layer() {
        let default_layer = create_concurrency_limit_layer();
        drop(default_layer);
    }

    /// Smoke test: a concurrency layer with a caller-chosen limit can be constructed.
    #[test]
    fn test_create_custom_concurrency_limit_layer() {
        let custom_layer = create_custom_concurrency_limit_layer(50);
        drop(custom_layer);
    }
}