use hyper::body::Incoming;
use hyper::{Request, Response};
use crate::auth::{AuthConfig, CurrentUser, PublicRoutes};
use crate::context::RequestContext;
use crate::error::Error;
use crate::middleware::{BoxFuture, Middleware, Next};
use crate::response::{BoxBody, IntoResponse};
/// Middleware that requires a bearer token on every route not listed as public.
///
/// When the token decodes successfully, the claims are handed to downstream
/// handlers as a [`CurrentUser`] request extension.
pub struct AuthMiddleware {
/// Settings used to decode bearer tokens (via `AuthConfig::decode`).
config: AuthConfig,
/// Method/path combinations that bypass authentication entirely.
public_routes: PublicRoutes,
}
impl AuthMiddleware {
    /// Creates an auth middleware starting from `PublicRoutes::new()`.
    ///
    /// NOTE(review): `PublicRoutes::new()` is presumably an empty table, which would mean
    /// every route requires a token. Confirm against `PublicRoutes`.
    pub fn new(config: AuthConfig) -> Self {
        Self::with_public_routes(config, PublicRoutes::new())
    }

    /// Creates an auth middleware that lets requests matching `public_routes` through
    /// without checking for a token.
    pub fn with_public_routes(config: AuthConfig, public_routes: PublicRoutes) -> Self {
        Self {
            config,
            public_routes,
        }
    }

    /// Returns the bearer token carried in the request's `Authorization` header.
    ///
    /// Returns `None` in any of these cases:
    /// - the header is absent;
    /// - its value is not visible ASCII;
    /// - it uses a scheme other than `Bearer`;
    /// - the token part is empty.
    fn extract_bearer_token(req: &Request<Incoming>) -> Option<&str> {
        let value = req
            .headers()
            .get(http::header::AUTHORIZATION)?
            .to_str()
            .ok()?;
        Self::parse_bearer(value)
    }

    /// Parses an `Authorization` header value of the form `Bearer <token>`.
    ///
    /// The auth-scheme is compared case-insensitively, as required by RFC 7235 §2.1.
    /// Surrounding whitespace around the token is discarded, so `Bearer   abc` yields
    /// `abc`. An empty token is reported as `None` rather than `Some("")`.
    fn parse_bearer(value: &str) -> Option<&str> {
        let (scheme, rest) = value.split_once(' ')?;
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return None;
        }
        let token = rest.trim();
        if token.is_empty() {
            None
        } else {
            Some(token)
        }
    }
}
impl Middleware for AuthMiddleware {
fn handle<'a>(
&'a self,
mut req: Request<Incoming>,
_ctx: &'a RequestContext,
next: Next<'a>,
) -> BoxFuture<'a, Response<BoxBody>> {
Box::pin(async move {
let method = req.method().as_str();
let path = req.uri().path();
if self.public_routes.is_public(method, path) {
return next.run(req).await;
}
let token = match Self::extract_bearer_token(&req) {
Some(t) => t,
None => {
return Error::unauthorized("missing authorization header").into_response();
}
};
let claims = match self.config.decode(token) {
Ok(c) => c,
Err(e) => {
return e.into_response();
}
};
let current_user = CurrentUser {
id: claims.sub.clone(),
claims,
};
req.extensions_mut().insert(current_user);
next.run(req).await
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_middleware_new() {
        let config = AuthConfig::new("secret", 3600);
        let middleware = AuthMiddleware::new(config);
        // Without explicit registration, even a health check should require a token.
        assert!(!middleware.public_routes.is_public("GET", "/health"));
    }

    #[test]
    fn test_auth_middleware_with_public_routes() {
        let config = AuthConfig::new("secret", 3600);
        let mut public = PublicRoutes::new();
        public.add("GET", "/health");
        let middleware = AuthMiddleware::with_public_routes(config, public);
        assert!(middleware.public_routes.is_public("GET", "/health"));
        // Only the registered route is exempted from authentication.
        assert!(!middleware.public_routes.is_public("GET", "/admin"));
    }
}