use std::sync::Arc;
use axum::body::Body;
use axum::extract::{FromRequestParts, State};
use axum::http::request::Parts;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::Router;
use crate::sql::sqlx::PgPool;
use super::auth_backends::{AuthError, AuthUser, BoxedBackend};
use super::permissions;
/// Identity of the principal a request was authenticated as.
///
/// The auth middleware inserts one of these into the request extensions once a
/// backend accepts the request; handlers read it back via [`CurrentUser`].
#[derive(Clone, Debug)]
pub struct AuthenticatedUser {
    /// User id, as reported by the authenticating backend.
    pub id: i64,
    /// Username, as reported by the authenticating backend.
    pub username: String,
    /// Superuser flag, as reported by the authenticating backend.
    pub is_superuser: bool,
}
impl From<AuthUser> for AuthenticatedUser {
fn from(u: AuthUser) -> Self {
Self { id: u.id, username: u.username, is_superuser: u.is_superuser }
}
}
/// Extractor for the user attached to the request by the authentication middleware.
///
/// Holds `None` when the request extensions contain no [`AuthenticatedUser`]. This
/// happens on routes without an auth layer and for anonymous requests under
/// `optional_auth`. Extraction itself never fails.
pub struct CurrentUser(pub Option<AuthenticatedUser>);

impl<S: Send + Sync> FromRequestParts<S> for CurrentUser {
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let found = parts.extensions.get::<AuthenticatedUser>();
        Ok(Self(found.cloned()))
    }
}
/// Configuration shared by every request passing through `auth_middleware`.
#[derive(Clone)]
struct AuthState {
    /// Backends tried in order; the first to return a user wins. Wrapped in `Arc`
    /// so cloning the state does not copy the backend list.
    backends: Arc<Vec<BoxedBackend>>,
    /// Pool handed to each backend's `authenticate` call.
    pool: PgPool,
    /// `true`: requests that no backend authenticates are rejected with 401.
    /// `false`: such requests are forwarded without a user.
    required: bool,
}
/// Per-route configuration for `perm_middleware`.
#[derive(Clone)]
struct PermState {
    /// Permission codename checked via `permissions::has_perm` and echoed in the 403 body.
    codename: &'static str,
    /// Pool used for the permission lookup.
    pool: PgPool,
}
async fn auth_middleware(
State(state): State<AuthState>,
mut req: Request<Body>,
next: Next,
) -> Response {
let headers = req.headers().clone();
let uri = req.uri().clone();
let method = req.method().clone();
let mut builder = axum::http::Request::builder().method(&method).uri(&uri);
for (k, v) in &headers {
builder = builder.header(k, v);
}
let dummy = builder.body(()).unwrap_or_else(|_| axum::http::Request::new(()));
let (dummy_parts, _) = dummy.into_parts();
let mut authenticated: Option<AuthUser> = None;
let mut error_response: Option<Response> = None;
for backend in state.backends.iter() {
match backend.authenticate(&dummy_parts, &state.pool).await {
Ok(Some(user)) => { authenticated = Some(user); break; }
Ok(None) => {}
Err(AuthError::Inactive) => {
error_response = Some((StatusCode::FORBIDDEN, "account inactive").into_response());
break;
}
Err(e) => {
error_response = Some((StatusCode::UNAUTHORIZED, e.to_string()).into_response());
break;
}
}
}
if let Some(resp) = error_response {
return resp;
}
match authenticated {
Some(user) => {
req.extensions_mut().insert(AuthenticatedUser::from(user));
next.run(req).await
}
None if state.required => {
(StatusCode::UNAUTHORIZED, "authentication required").into_response()
}
None => next.run(req).await,
}
}
/// Middleware that admits a request only if its authenticated user holds `state.codename`.
///
/// Expects an [`AuthenticatedUser`] in the request extensions. The auth middleware
/// inserts it, so that middleware has to run before this one.
///
/// Responses produced here instead of calling `next`:
/// - `401 authentication required` when no authenticated user is present;
/// - `403 permission required: <codename>` when `permissions::has_perm` returns `false`;
/// - `500 permission check failed` when the permission lookup itself errors. The
///   request is still refused (fail closed), but an infrastructure failure is no
///   longer misreported as a permission denial.
async fn perm_middleware(
    State(state): State<PermState>,
    req: Request<Body>,
    next: Next,
) -> Response {
    // Only the id is needed for the lookup, so copy it out instead of cloning the user.
    let Some(user_id) = req
        .extensions()
        .get::<AuthenticatedUser>()
        .map(|user| user.id)
    else {
        return (StatusCode::UNAUTHORIZED, "authentication required").into_response();
    };

    // NOTE(review): `is_superuser` is not consulted here; presumably
    // `permissions::has_perm` handles superusers itself — confirm.
    match permissions::has_perm(user_id, state.codename, &state.pool).await {
        Ok(true) => next.run(req).await,
        Ok(false) => (
            StatusCode::FORBIDDEN,
            format!("permission required: {}", state.codename),
        )
            .into_response(),
        // NOTE(review): the lookup error is discarded; consider logging it once a
        // logging facility is available in this module.
        Err(_) => (StatusCode::INTERNAL_SERVER_ERROR, "permission check failed").into_response(),
    }
}
/// Extension methods that attach this module's authentication and permission
/// middleware to a router.
///
/// For [`Router`], each method wraps the routes registered so far using
/// [`Router::layer`]. Per axum's documentation, routes added afterwards are not
/// wrapped, and the most recently added layer runs first.
///
/// The permission check reads the [`AuthenticatedUser`] that the auth middleware
/// inserts, so call `require_perm` *before* `require_auth` / `optional_auth`. That
/// order makes authentication run first:
///
/// ```ignore
/// let app = Router::new()
///     .route("/admin", get(admin))
///     .require_perm("admin.view", pool.clone())
///     .require_auth(backends, pool);
/// ```
pub trait RouterAuthExt<S> {
    /// Authenticates every request with `backends`.
    ///
    /// Rejects requests that no backend authenticates with 401. Backend errors
    /// also reject the request: 403 for an inactive account, 401 otherwise.
    fn require_auth(self, backends: Vec<BoxedBackend>, pool: PgPool) -> Self;

    /// Like `require_auth`, but forwards requests that no backend authenticates
    /// without a user. Backend errors still reject the request.
    fn optional_auth(self, backends: Vec<BoxedBackend>, pool: PgPool) -> Self;

    /// Rejects requests with 401 when no authenticated user is present. Otherwise
    /// rejects them unless `permissions::has_perm` grants `codename` to that user.
    fn require_perm(self, codename: &'static str, pool: PgPool) -> Self;
}
impl<S: Clone + Send + Sync + 'static> RouterAuthExt<S> for Router<S> {
    fn require_auth(self, backends: Vec<BoxedBackend>, pool: PgPool) -> Self {
        with_auth_layer(self, backends, pool, true)
    }

    fn optional_auth(self, backends: Vec<BoxedBackend>, pool: PgPool) -> Self {
        with_auth_layer(self, backends, pool, false)
    }

    fn require_perm(self, codename: &'static str, pool: PgPool) -> Self {
        let perm_layer =
            axum::middleware::from_fn_with_state(PermState { codename, pool }, perm_middleware);
        self.layer(perm_layer)
    }
}

/// Layers `auth_middleware` over `router`.
///
/// `required` selects between the `require_auth` mode (reject anonymous requests)
/// and the `optional_auth` mode (forward anonymous requests).
fn with_auth_layer<S>(
    router: Router<S>,
    backends: Vec<BoxedBackend>,
    pool: PgPool,
    required: bool,
) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    let state = AuthState {
        backends: Arc::new(backends),
        pool,
        required,
    };
    router.layer(axum::middleware::from_fn_with_state(state, auth_middleware))
}