use axum::body::Body;
use axum::extract::FromRef;
use axum::http::{Request, StatusCode, header::COOKIE};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde_json::json;
use std::sync::Arc;
use allowthem_core::{AuthClient, AuthError, PermissionName, RoleName, User, parse_session_cookie};
/// Axum middleware that only lets requests with a valid session through.
///
/// The session token is taken from the request's `Cookie` header using the
/// cookie name reported by [`AuthClient::session_cookie_name`]. It is then
/// validated with [`AuthClient::validate_session`].
///
/// On success the resolved [`User`] is inserted into the request extensions, so
/// downstream handlers can extract it, and the request is forwarded to `next`.
/// Otherwise `next` is not called and a JSON error response is returned:
///
/// - `401 Unauthorized` when no session cookie is present or the session does
///   not resolve to a user;
/// - `500 Internal Server Error` when the auth backend reports an error.
pub async fn require_auth<S>(
    state: axum::extract::State<S>,
    mut request: Request<Body>,
    next: Next,
) -> Response
where
    Arc<dyn AuthClient>: FromRef<S>,
    S: Send + Sync + Clone,
{
    let client = <Arc<dyn AuthClient>>::from_ref(&*state);
    // Borrow the headers instead of cloning the whole map. The borrow ends once
    // `authenticate` resolves, which is before the request is mutated below.
    let user = match authenticate(&*client, request.headers()).await {
        Ok(u) => u,
        Err(r) => return r,
    };
    request.extensions_mut().insert(user);
    next.run(request).await
}
/// Builds an Axum middleware function that only admits users holding `role`.
///
/// The returned closure is meant to be passed to
/// `axum::middleware::from_fn_with_state`. On every request it authenticates
/// the caller and then checks the role through [`AuthClient::check_role`];
/// `require_role_inner` documents the exact responses.
pub fn require_role<S>(
    role: impl Into<String>,
) -> impl Fn(
    axum::extract::State<S>,
    Request<Body>,
    Next,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
+ Clone
+ Send
+ 'static
where
    Arc<dyn AuthClient>: FromRef<S>,
    S: Send + Sync + Clone + 'static,
{
    // Convert once up front. Each request then gets its own owned copy, because
    // the boxed future must not borrow from the closure.
    let required: String = role.into();
    move |state, request, next| {
        let future = require_role_inner(state, request, next, required.clone());
        Box::pin(future)
    }
}
/// Per-request body of the middleware built by [`require_role`].
///
/// It authenticates the request, then checks that the user holds `role_name`.
/// Without calling `next`, it returns:
///
/// - the `401` / `500` response produced by session authentication;
/// - `403 Forbidden` when [`AuthClient::check_role`] returns `Ok(false)`;
/// - `500 Internal Server Error` when the role check itself fails.
///
/// On success the [`User`] is inserted into the request extensions and the
/// request is forwarded to `next`.
async fn require_role_inner<S>(
    state: axum::extract::State<S>,
    mut request: Request<Body>,
    next: Next,
    role_name: String,
) -> Response
where
    Arc<dyn AuthClient>: FromRef<S>,
    S: Send + Sync + Clone,
{
    let client = <Arc<dyn AuthClient>>::from_ref(&*state);
    // Borrow the headers rather than cloning the whole map for each request.
    let user = match authenticate(&*client, request.headers()).await {
        Ok(u) => u,
        Err(r) => return r,
    };
    let rn = RoleName::new(role_name);
    match client.check_role(&user.id, &rn).await {
        Ok(true) => {}
        Ok(false) => {
            return (
                StatusCode::FORBIDDEN,
                axum::Json(json!({"error": "forbidden"})),
            )
                .into_response();
        }
        Err(e) => return internal_error(e),
    }
    request.extensions_mut().insert(user);
    next.run(request).await
}
/// Builds an Axum middleware function that only admits users holding `permission`.
///
/// The returned closure is meant to be passed to
/// `axum::middleware::from_fn_with_state`. On every request it authenticates
/// the caller and then checks the permission through
/// [`AuthClient::check_permission`]; `require_permission_inner` documents the
/// exact responses.
pub fn require_permission<S>(
    permission: impl Into<String>,
) -> impl Fn(
    axum::extract::State<S>,
    Request<Body>,
    Next,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>>
+ Clone
+ Send
+ 'static
where
    Arc<dyn AuthClient>: FromRef<S>,
    S: Send + Sync + Clone + 'static,
{
    // Convert once up front. Each request then gets its own owned copy, because
    // the boxed future must not borrow from the closure.
    let required: String = permission.into();
    move |state, request, next| {
        let future = require_permission_inner(state, request, next, required.clone());
        Box::pin(future)
    }
}
/// Per-request body of the middleware built by [`require_permission`].
///
/// It authenticates the request, then checks that the user holds `perm_name`.
/// Without calling `next`, it returns:
///
/// - the `401` / `500` response produced by session authentication;
/// - `403 Forbidden` when [`AuthClient::check_permission`] returns `Ok(false)`;
/// - `500 Internal Server Error` when the permission check itself fails.
///
/// On success the [`User`] is inserted into the request extensions and the
/// request is forwarded to `next`.
async fn require_permission_inner<S>(
    state: axum::extract::State<S>,
    mut request: Request<Body>,
    next: Next,
    perm_name: String,
) -> Response
where
    Arc<dyn AuthClient>: FromRef<S>,
    S: Send + Sync + Clone,
{
    let client = <Arc<dyn AuthClient>>::from_ref(&*state);
    // Borrow the headers rather than cloning the whole map for each request.
    let user = match authenticate(&*client, request.headers()).await {
        Ok(u) => u,
        Err(r) => return r,
    };
    let pn = PermissionName::new(perm_name);
    match client.check_permission(&user.id, &pn).await {
        Ok(true) => {}
        Ok(false) => {
            return (
                StatusCode::FORBIDDEN,
                axum::Json(json!({"error": "forbidden"})),
            )
                .into_response();
        }
        Err(e) => return internal_error(e),
    }
    request.extensions_mut().insert(user);
    next.run(request).await
}
/// Resolves the session user from the request's `Cookie` headers.
///
/// Every `Cookie` header field is scanned, not just the first. HTTP/2 clients
/// may split cookies across several fields, so the session cookie can appear in
/// any of them. The first field containing the client's session cookie supplies
/// the token. Fields whose value is not valid visible ASCII are skipped rather
/// than failing the whole lookup.
///
/// # Errors
///
/// Returns a ready-to-send [`Response`]:
/// - `401 Unauthorized` when no session cookie is found, or when
///   [`AuthClient::validate_session`] returns `Ok(None)`;
/// - `500 Internal Server Error` when `validate_session` returns an error.
async fn authenticate(
    client: &dyn AuthClient,
    headers: &axum::http::HeaderMap,
) -> Result<User, Response> {
    let token = headers
        .get_all(COOKIE)
        .iter()
        .filter_map(|v| v.to_str().ok())
        .find_map(|cookie_header| parse_session_cookie(cookie_header, client.session_cookie_name()))
        .ok_or_else(unauthenticated)?;
    client
        .validate_session(&token)
        .await
        .map_err(internal_error)?
        .ok_or_else(unauthenticated)
}
fn unauthenticated() -> Response {
(
StatusCode::UNAUTHORIZED,
axum::Json(json!({"error": "unauthenticated"})),
)
.into_response()
}
fn internal_error(err: AuthError) -> Response {
tracing::error!("auth middleware error: {err}");
(
StatusCode::INTERNAL_SERVER_ERROR,
axum::Json(json!({"error": "internal error"})),
)
.into_response()
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use super::*;
    use allowthem_core::{
        AllowThem, AllowThemBuilder, AuthClient, Email, EmbeddedAuthClient, generate_token,
        hash_token,
    };
    use axum::extract::FromRef;
    use axum::http::StatusCode;
    use axum::routing::get;
    use axum::{Router, middleware};
    use chrono::{Duration, Utc};
    use tower::ServiceExt;

    /// Minimal application state exposing an `Arc<dyn AuthClient>` via `FromRef`.
    #[derive(Clone)]
    struct TestState {
        auth: Arc<dyn AuthClient>,
    }

    impl FromRef<TestState> for Arc<dyn AuthClient> {
        fn from_ref(s: &TestState) -> Self {
            Arc::clone(&s.auth)
        }
    }

    /// Creates an in-memory `AllowThem` with one user and a 24h session.
    /// Returns the instance plus the `name=value` part of the session cookie.
    async fn test_setup() -> (AllowThem, String) {
        let ath = AllowThemBuilder::new("sqlite::memory:")
            .cookie_secure(false)
            .build()
            .await
            .unwrap();
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath
            .db()
            .create_user(email, "password123", None)
            .await
            .unwrap();
        let token = generate_token();
        let token_hash = hash_token(&token);
        let expires = Utc::now() + Duration::hours(24);
        ath.db()
            .create_session(user.id, token_hash, None, None, expires)
            .await
            .unwrap();
        let cookie = ath.session_cookie(&token);
        // Keep only `name=value`; drop attributes such as Path/HttpOnly.
        let cookie_value = cookie.split(';').next().unwrap().to_string();
        (ath, cookie_value)
    }

    async fn ok_handler() -> StatusCode {
        StatusCode::OK
    }

    /// Router whose `/protected` route sits behind `require_auth`.
    fn auth_app(ath: AllowThem) -> Router {
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_auth::<TestState>,
            ))
            .with_state(state)
    }

    /// Router whose `/protected` route sits behind `require_role(role)`.
    fn role_app(ath: AllowThem, role: &str) -> Router {
        let role = role.to_string();
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_role::<TestState>(role),
            ))
            .with_state(state)
    }

    /// Router whose `/protected` route sits behind `require_permission(perm)`.
    fn perm_app(ath: AllowThem, perm: &str) -> Router {
        let perm = perm.to_string();
        let auth: Arc<dyn AuthClient> = Arc::new(EmbeddedAuthClient::new(ath, "/login"));
        let state = TestState { auth };
        Router::new()
            .route("/protected", get(ok_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                require_permission::<TestState>(perm),
            ))
            .with_state(state)
    }

    /// `GET /protected`, optionally carrying a single `Cookie` header.
    fn make_request(cookie: Option<&str>) -> axum::http::Request<Body> {
        let mut builder = axum::http::Request::builder().uri("/protected");
        if let Some(c) = cookie {
            builder = builder.header(COOKIE, c);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn authenticated_request_passes_through() {
        let (ath, cookie) = test_setup().await;
        let app = auth_app(ath);
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn unauthenticated_request_returns_401() {
        let (ath, _) = test_setup().await;
        let app = auth_app(ath);
        let resp = app.oneshot(make_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn cookie_header_without_session_cookie_returns_401() {
        let (ath, _) = test_setup().await;
        let app = auth_app(ath);
        let resp = app
            .oneshot(make_request(Some("unrelated=value")))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_role_with_correct_role_passes() {
        let (ath, cookie) = test_setup().await;
        let rn = allowthem_core::RoleName::new("admin");
        let role = ath.db().create_role(&rn, None).await.unwrap();
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().assign_role(&user.id, &role.id).await.unwrap();
        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_role_with_wrong_role_returns_403() {
        let (ath, cookie) = test_setup().await;
        // The user holds a role, just not the one the route requires.
        let rn = allowthem_core::RoleName::new("viewer");
        let role = ath.db().create_role(&rn, None).await.unwrap();
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db().assign_role(&user.id, &role.id).await.unwrap();
        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_without_any_role_returns_403() {
        let (ath, cookie) = test_setup().await;
        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_role_without_session_returns_401() {
        let (ath, _) = test_setup().await;
        let app = role_app(ath, "admin");
        let resp = app.oneshot(make_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_permission_with_correct_permission_passes() {
        let (ath, cookie) = test_setup().await;
        let pn = allowthem_core::PermissionName::new("posts:write");
        let perm = ath.db().create_permission(&pn, None).await.unwrap();
        let email = Email::new("user@example.com".into()).unwrap();
        let user = ath.db().get_user_by_email(&email).await.unwrap();
        ath.db()
            .assign_permission_to_user(&user.id, &perm.id)
            .await
            .unwrap();
        let app = perm_app(ath, "posts:write");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_permission_with_missing_permission_returns_403() {
        let (ath, cookie) = test_setup().await;
        let app = perm_app(ath, "posts:write");
        let resp = app.oneshot(make_request(Some(&cookie))).await.unwrap();
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn require_permission_without_session_returns_401() {
        let (ath, _) = test_setup().await;
        let app = perm_app(ath, "posts:write");
        let resp = app.oneshot(make_request(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}