use std::sync::Arc;
use axum::Json;
use axum::extract::Request;
use axum::extract::{FromRef, State};
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum_extra::extract::CookieJar;
use serde_json::json;
use crate::cookies::AUTH_COOKIE_NAME;
use crate::jwt::{AccessToken, JwtConfig};
fn extract_token(headers: &HeaderMap) -> Option<String> {
let jar = CookieJar::from_headers(headers);
jar.get(AUTH_COOKIE_NAME)
.map(|c| c.value().to_string())
.or_else(|| {
headers
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "))
.map(|t| t.to_string())
})
}
pub async fn jwt_auth<S>(State(state): State<S>, req: Request, next: Next) -> Response
where
S: Send + Sync,
Arc<JwtConfig>: FromRef<S>,
{
let jwt_config = Arc::<JwtConfig>::from_ref(&state);
let token = match extract_token(req.headers()) {
Some(t) => t,
None => {
return (
StatusCode::UNAUTHORIZED,
Json(json!({
"error": {
"code": "MISSING_TOKEN",
"message": "No authentication token provided",
}
})),
)
.into_response();
}
};
match AccessToken::decode(&token, &jwt_config) {
Ok(_) => next.run(req).await,
Err(_) => (
StatusCode::UNAUTHORIZED,
Json(json!({
"error": {
"code": "INVALID_TOKEN",
"message": "Invalid or expired authentication token",
}
})),
)
.into_response(),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::Request;
    use axum::middleware;
    use axum::routing::get;
    use http_body_util::BodyExt;
    use tower::ServiceExt;
    use uuid::Uuid;

    /// Config with a fixed secret, so tokens minted in a test verify inside the app under test.
    fn test_jwt_config() -> Arc<JwtConfig> {
        Arc::new(JwtConfig {
            secret: "test-secret-key".to_string(),
            access_token_ttl_secs: 900,
            refresh_token_ttl_secs: 604800,
            cookie_domain: None,
            cookie_secure: false,
        })
    }

    async fn protected_handler() -> &'static str {
        "protected"
    }

    /// Builds a router whose only route, `/protected`, is guarded by [`jwt_auth`].
    fn app(jwt_config: Arc<JwtConfig>) -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(jwt_config.clone(), jwt_auth))
            .with_state(jwt_config)
    }

    /// Sends `GET /protected`, optionally with one extra header, and returns the response.
    async fn get_protected(jwt_config: Arc<JwtConfig>, header: Option<(&str, String)>) -> Response {
        let mut builder = Request::builder().uri("/protected");
        if let Some((name, value)) = header {
            builder = builder.header(name, value);
        }
        let req = builder.body(Body::empty()).unwrap();
        app(jwt_config).oneshot(req).await.unwrap()
    }

    /// Reads the response body as JSON and returns its `error.code` field.
    async fn error_code(resp: Response) -> serde_json::Value {
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        json["error"]["code"].clone()
    }

    /// Asserts that the response is `200 OK` from the protected handler.
    async fn assert_allowed(resp: Response) {
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"protected");
    }

    #[tokio::test]
    async fn rejects_missing_token() {
        let resp = get_protected(test_jwt_config(), None).await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "MISSING_TOKEN");
    }

    #[tokio::test]
    async fn rejects_invalid_token() {
        let resp = get_protected(
            test_jwt_config(),
            Some(("authorization", "Bearer not.a.valid.token".to_string())),
        )
        .await;
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(error_code(resp).await, "INVALID_TOKEN");
    }

    #[tokio::test]
    async fn allows_valid_bearer_token() {
        let jwt_config = test_jwt_config();
        let token =
            AccessToken::for_user(Uuid::now_v7(), "testuser", false, &jwt_config).unwrap();
        let resp = get_protected(
            jwt_config,
            Some(("authorization", format!("Bearer {}", token.0))),
        )
        .await;
        assert_allowed(resp).await;
    }

    #[tokio::test]
    async fn allows_valid_cookie_token() {
        let jwt_config = test_jwt_config();
        let token =
            AccessToken::for_user(Uuid::now_v7(), "testuser", false, &jwt_config).unwrap();
        let resp = get_protected(
            jwt_config,
            Some(("cookie", format!("{}={}", AUTH_COOKIE_NAME, token.0))),
        )
        .await;
        assert_allowed(resp).await;
    }
}