#![allow(clippy::unused_async, reason = "Middleware functions need to be async")]
use super::{
errors::AuthError,
handlers::get_login,
state::StateProvider,
};
use crate::app::{
errors::AppError,
state::StateProvider as AppStateProvider
};
use axum::{
Extension,
body::Body,
extract::{FromRequestParts, State, rejection::ExtensionRejection},
http::{Request, StatusCode, Uri, request::Parts},
middleware::Next,
response::{IntoResponse as _, Response},
};
use core::fmt::{Debug, Display};
use rubedo::sugar::s;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use std::sync::Arc;
use tower_sessions::Session;
use tracing::info;
/// Session key under which the ID of the logged-in user is stored.
const SESSION_USER_ID_KEY: &str = "_user_id";
/// Authentication context for a request: the current user (if any) together
/// with the session handle used to persist the login.
#[derive(Clone, Debug)]
pub struct Context<U: User> {
/// The user associated with this context, or `None` when nobody is logged in.
pub current_user: Option<U>,
/// Session handle used to read, store, and clear the user ID.
session: Session,
}
impl<U: User> Context<U> {
    /// Creates a context around `session` with no current user set.
    #[must_use]
    pub const fn new(session: Session) -> Self {
        Self {
            current_user: None,
            session,
        }
    }

    /// Resolves the user referenced by the session, if any.
    ///
    /// Reads the user ID stored in the session and asks the user provider `UP`
    /// to look it up. When the session holds an ID that the provider cannot
    /// resolve, the session is cleared via [`logout`](Self::logout) so the
    /// stale ID is not retried. Returns `None` when the session has no user ID,
    /// the ID cannot be read, or the provider does not find the user.
    ///
    /// This does not modify `current_user`; the caller decides what to do with
    /// the returned value.
    pub async fn get_user<SP, UP>(&self, state: &SP) -> Option<U>
    where
        SP: StateProvider,
        UP: UserProvider<User = U>,
    {
        if let Ok(Some(user_id)) = self.session.get(SESSION_USER_ID_KEY).await {
            if let Some(user) = UP::find_by_id(state, &user_id) {
                return Some(user);
            }
            // The session refers to a user that can no longer be found.
            self.logout().await;
        }
        None
    }

    /// Logs `user` in by storing their ID in the session and recording them as
    /// the current user.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] if the session ID cannot be cycled or the user
    /// ID cannot be written to the session. In either case `current_user` is
    /// left untouched.
    pub async fn login(&mut self, user: &U) -> Result<(), AuthError> {
        // Rotate the session ID before the session becomes authenticated, so
        // an ID that was known (or planted) before login cannot be used to
        // ride the authenticated session (session-fixation mitigation).
        self.session.cycle_id().await?;
        self.session.insert(SESSION_USER_ID_KEY, user.id()).await?;
        self.current_user = Some(user.clone());
        Ok(())
    }

    /// Logs out by clearing the session.
    ///
    /// Takes `&self`, so `current_user` on this context is not reset.
    pub async fn logout(&self) {
        self.session.clear().await;
    }
}
/// Allows handlers to take a [`Context`] directly as an extractor argument.
impl<S, U> FromRequestParts<S> for Context<U>
where
    S: Send + Sync,
    U: User,
{
    type Rejection = ExtensionRejection;

    /// Pulls the context out of the request extensions.
    ///
    /// Delegates to the [`Extension`] extractor and unwraps the result, so the
    /// rejection is whatever `Extension` produces when no context is present.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let Extension(auth_cx) = Extension::<Self>::from_request_parts(parts, state).await?;
        Ok(auth_cx)
    }
}
/// Login credentials accepted by a [`UserProvider`].
///
/// Implementors must be deserialisable, cloneable, and safe to share across
/// threads, as required by the supertrait bounds.
pub trait Credentials: Clone + Debug + for<'de> Deserialize<'de> + Send + Sync + 'static {
/// Returns a string representation of the credentials for use in logs.
// NOTE(review): presumably this should leave out secrets such as passwords —
// confirm against the implementors.
fn to_loggable_string(&self) -> String;
}
/// A user that can be held in an authentication [`Context`].
pub trait User: Clone + Debug + Send + Sync + 'static {
/// Identifier type. It must be serialisable and deserialisable because the
/// ID is what gets stored in, and read back from, the session.
type Id: Clone + Debug + DeserializeOwned + Display + Serialize + Send + Sync + 'static;
/// Returns a reference to this user's identifier.
fn id(&self) -> &Self::Id;
/// Returns a string representation of the user for use in logs.
fn to_loggable_string(&self) -> String;
}
/// Source of users for authentication.
///
/// Both lookups are associated functions (no `self`) that receive the
/// application state, so an implementation is selected purely by type.
pub trait UserProvider: Debug + 'static {
/// Credentials type accepted by [`find_by_credentials`](Self::find_by_credentials).
type Credentials: Credentials;
/// User type returned by the lookups.
type User: User;
/// Looks up a user by `credentials`, returning `None` when no user is found.
fn find_by_credentials<SP: StateProvider>(
state: &SP,
credentials: &Self::Credentials,
) -> Option<Self::User>;
/// Looks up a user by `id`, returning `None` when no user is found.
fn find_by_id<SP: StateProvider>(
state: &SP,
id: &<Self::User as User>::Id,
) -> Option<Self::User>;
}
/// Middleware that builds the authentication [`Context`] for a request.
///
/// Creates a context from the request's session, resolves the current user
/// through the user provider `UP`, logs the user's ID (or `none`), stores the
/// user on the context, and inserts the context into the request extensions
/// before running the rest of the stack.
///
/// Requires a [`Session`] to already be present in the request extensions.
pub async fn auth_layer<SP, U, UP>(
    State(state): State<Arc<SP>>,
    Extension(session): Extension<Session>,
    mut request: Request<Body>,
    next: Next,
) -> Response
where
    SP: StateProvider,
    U: User,
    UP: UserProvider<User = U>,
{
    let mut auth_cx = Context::<U>::new(session);
    let user = auth_cx.get_user::<SP, UP>(&state).await;
    // `map_or_else` builds the fallback string only when there is no user,
    // instead of allocating it on every request.
    info!(
        "Current user: {}",
        user.as_ref().map_or_else(|| s!("none"), |u| u.id().to_string())
    );
    auth_cx.current_user = user;
    // `insert` returns any previously-stored value; it is discarded explicitly.
    drop(request.extensions_mut().insert(auth_cx));
    next.run(request).await
}
/// Middleware that only lets authenticated requests through.
///
/// When the [`Context`] in the request extensions has a current user, the
/// request is passed on to the rest of the stack. Otherwise the response is
/// the output of `get_login` for the requested URI, sent with a
/// `401 Unauthorized` status.
pub async fn protect<SP, U>(
    State(state): State<Arc<SP>>,
    Extension(auth_cx): Extension<Context<U>>,
    uri: Uri,
    request: Request<Body>,
    next: Next,
) -> Response
where
    SP: StateProvider,
    U: User,
{
    if auth_cx.current_user.is_some() {
        return next.run(request).await;
    }
    let login_page = get_login(State(state), uri).await;
    (StatusCode::UNAUTHORIZED, login_page).into_response()
}
/// Middleware that turns a "protected" 404 into a login prompt for anonymous
/// requests.
///
/// Runs the rest of the stack first. If the response is `404 Not Found`,
/// carries a `protected` header, and there is no current user, the response is
/// replaced: the `content-length`, `content-type`, and `protected` headers are
/// removed, the status becomes `401 Unauthorized`, and the body is the output
/// of `get_login` for the requested URI. Every other response is passed
/// through with its parts and body unchanged.
///
/// # Errors
///
/// The return type allows an [`AppError`], but no path in this function
/// currently produces one.
pub async fn protected_error_layer<SP, U>(
    State(state): State<Arc<SP>>,
    Extension(auth_cx): Extension<Context<U>>,
    uri: Uri,
    request: Request<Body>,
    next: Next,
) -> Result<Response, AppError>
where
    SP: AppStateProvider,
    U: User,
{
    let (mut parts, body) = next.run(request).await.into_parts();

    let show_login = parts.status == StatusCode::NOT_FOUND
        && parts.headers.contains_key("protected")
        && auth_cx.current_user.is_none();
    if !show_login {
        return Ok((parts, body).into_response());
    }

    // These headers describe the original body (or are internal markers) and
    // must not be carried over onto the login response.
    for name in ["content-length", "content-type", "protected"] {
        drop(parts.headers.remove(name));
    }
    parts.status = StatusCode::UNAUTHORIZED;
    Ok((parts, get_login(State(state), uri).await).into_response())
}