use std::{collections::HashMap, convert::Infallible, future::Future, pin::Pin, sync::Arc};
use axum::{
extract::{FromRequestParts, MatchedPath, Request},
http::request::Parts,
response::{IntoResponse, Redirect, Response},
RequestExt,
};
use quokka::{
handler::html::TemplateDataLoader,
state::{FromState, ProvideState},
};
use crate::{service::page_loader::AdminPageLoader, state::AdminState};
/// The action being attempted on the admin area, as checked by authorization providers.
///
/// When extracted from a request, `verb` is the HTTP method and `resource` is the
/// route path the request was dispatched to.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PermissionContext {
    /// HTTP method name, e.g. `GET`.
    pub verb: String,
    /// Route path identifying the target resource.
    pub resource: String,
}
/// A user identity produced by an [`AdminAuthProvider`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct AuthenticatedUser {
    /// Display / login name of the user.
    pub name: String,
    /// Group memberships; membership in the configured super-admin group bypasses
    /// per-provider authorization in the auth middleware.
    pub groups: Vec<String>,
    /// Free-form extra data attached by the provider (not interpreted in this module).
    pub context: HashMap<String, serde_json::Value>,
}
/// Pluggable authentication + authorization backend for the admin area.
///
/// `S` is the application state used when extracting [`Self::AuthParams`] from the
/// incoming request.
pub trait AdminAuthProvider<S> {
    /// Request data this provider needs, extracted from the request parts with state `S`.
    type AuthParams: FromRequestParts<S>;

    /// Tries to identify the user from the extracted `params`.
    ///
    /// `Ok(None)` means this provider did not recognize a user; the auth middleware then
    /// consults the next registered provider.
    fn authenticate(
        &self,
        params: Self::AuthParams,
    ) -> impl Future<Output = quokka::Result<Option<AuthenticatedUser>>> + Send;

    /// Decides whether `user` may perform `permission`.
    ///
    /// The auth middleware grants access as soon as any provider resolves to `Ok(true)`.
    fn authorize(
        &self,
        user: &AuthenticatedUser,
        permission: &PermissionContext,
    ) -> impl Future<Output = quokka::Result<bool>> + Send;

    /// Name used in log output; defaults to the implementing type's name.
    fn provider_name(&self) -> &str {
        std::any::type_name::<Self>()
    }
}
/// Ordered list of type-erased auth providers; the auth middleware consults them in
/// vector order.
pub struct AuthProviders<S> {
    pub(crate) providers: Vec<Arc<dyn InnerAuthProvider<S>>>,
}

// Written by hand: `#[derive(Clone)]` would require `S: Clone`, although only the `Arc`
// handles are cloned and `S` itself is never stored.
impl<S> Clone for AuthProviders<S> {
    fn clone(&self) -> Self {
        Self {
            providers: self.providers.clone(),
        }
    }
}
/// Ordered list of type-erased login providers (see [`AdminLoginProvider`]).
#[derive(Default, Clone)]
pub struct LoginProviders {
    pub(crate) providers: Vec<Arc<dyn InnerLoginProvider + Send + Sync>>,
}
/// Tower layer that wraps an inner service in an `AdminAuthLayer`, which authenticates
/// and authorizes every request before forwarding it.
#[derive(Clone)]
pub struct AdminAuthMiddleware<S> {
    /// Application state the auth configuration and page loader are obtained from.
    state: S,
}
/// Credentials passed to [`AdminLoginProvider::do_login`].
///
/// `password` is never serialized, and the `Debug` output redacts it so credentials
/// cannot leak through `?login_data`-style log statements.
#[derive(Clone, serde::Deserialize, serde::Serialize)]
pub struct LoginData {
    /// Name the user logs in with.
    pub login_name: String,
    /// Plaintext password; deserialized but never serialized.
    #[serde(skip_serializing)]
    pub password: String,
}

// Manual impl instead of `#[derive(Debug)]` so the password is never printed.
impl std::fmt::Debug for LoginData {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginData")
            .field("login_name", &self.login_name)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Outcome of a successful login.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LoginResult {
    /// Identifier of the user that logged in.
    pub user_identifier: String,
}
/// Pluggable credential check for the admin login.
pub trait AdminLoginProvider {
    /// Verifies `login_data`; `Ok(None)` means the credentials were not accepted.
    fn do_login(
        &self,
        login_data: &LoginData,
    ) -> impl Future<Output = quokka::Result<Option<LoginResult>>> + Send;

    /// Name used in log output; defaults to the implementing type's name.
    fn type_name(&self) -> &str {
        std::any::type_name::<Self>()
    }
}
/// Object-safe form of [`AdminAuthProvider`], so providers with different
/// `AuthParams` types can be stored together as trait objects.
#[doc(hidden)]
pub trait InnerAuthProvider<S>: Send + Sync {
    /// Extracts the provider's parameters from `request` and authenticates the user.
    fn authenticate<'req>(
        &'req self,
        request: &'req mut Request,
        state: &'req S,
    ) -> Pin<Box<dyn Future<Output = quokka::Result<Option<AuthenticatedUser>>> + Send + 'req>>;

    /// Boxed form of [`AdminAuthProvider::authorize`].
    fn authorize<'req>(
        &'req self,
        user: &'req AuthenticatedUser,
        permission: &'req PermissionContext,
    ) -> Pin<Box<dyn Future<Output = quokka::Result<bool>> + Send + 'req>>;

    /// Name used in log output.
    fn provider_name(&self) -> &str;
}
/// Object-safe form of [`AdminLoginProvider`], so login providers can be stored as
/// trait objects.
#[doc(hidden)]
pub trait InnerLoginProvider {
    /// Boxed form of [`AdminLoginProvider::do_login`].
    fn login<'call>(
        &'call self,
        login_data: &'call LoginData,
    ) -> Pin<Box<dyn Future<Output = quokka::Result<Option<LoginResult>>> + Send + 'call>>;

    /// Name used in log output.
    fn provider_name(&self) -> &str;
}
/// Service produced by [`AdminAuthMiddleware`]: checks auth, then forwards to `inner`.
#[derive(Clone)]
#[doc(hidden)]
pub struct AdminAuthLayer<S, I> {
    /// Application state handed to providers when extracting their parameters.
    state: S,
    /// The wrapped service that receives authorized requests.
    inner: I,
    /// Admin configuration (providers, login URL, super-admin group) obtained from `state`.
    admin: AdminState<S>,
    /// Used to render the error page when authorization is denied.
    page_loader: AdminPageLoader,
}
impl<T: AdminLoginProvider> InnerLoginProvider for T {
fn login<'a>(
&'a self,
login_data: &'a LoginData,
) -> Pin<Box<dyn Future<Output = quokka::Result<Option<LoginResult>>> + Send + 'a>> {
Box::pin(self.do_login(login_data))
}
fn provider_name(&self) -> &str {
self.type_name()
}
}
impl<S, T> InnerAuthProvider<S> for T
where
S: Send + Sync + 'static,
T: AdminAuthProvider<S> + Send + Sync,
T::AuthParams: 'static,
<T::AuthParams as FromRequestParts<S>>::Rejection: std::fmt::Debug,
{
fn authenticate<'a>(
&'a self,
request: &'a mut Request,
state: &'a S,
) -> Pin<Box<dyn Future<Output = quokka::Result<Option<AuthenticatedUser>>> + Send + 'a>> {
Box::pin(async move {
let params = request
.extract_parts_with_state::<T::AuthParams, S>(state)
.await
.inspect_err(|error| tracing::error!(?error, "Unable to extract request params"))
.map_err(|_| quokka::Error::status("Unable authenticate user", 500))?;
<T as AdminAuthProvider<S>>::authenticate(self, params).await
})
}
fn authorize<'a>(
&'a self,
user: &'a AuthenticatedUser,
permission: &'a PermissionContext,
) -> Pin<Box<dyn Future<Output = quokka::Result<bool>> + Send + 'a>> {
Box::pin(<T as AdminAuthProvider<S>>::authorize(
self, user, permission,
))
}
fn provider_name(&self) -> &str {
<T as AdminAuthProvider<S>>::provider_name(self)
}
}
// Wrapping a service snapshots the admin configuration and page loader from the state.
impl<S, I> tower_layer::Layer<I> for AdminAuthMiddleware<S>
where
    S: Send + Sync + Clone + ProvideState<AdminState<S>> + ProvideState<AdminPageLoader>,
{
    type Service = AdminAuthLayer<S, I>;

    fn layer(&self, inner: I) -> Self::Service {
        let state = self.state.clone();
        let admin: AdminState<S> = self.state.provide();
        let page_loader: AdminPageLoader = self.state.provide();
        AdminAuthLayer {
            state,
            inner,
            admin,
            page_loader,
        }
    }
}
impl<S, I> tower_service::Service<Request> for AdminAuthLayer<S, I>
where
I: tower_service::Service<Request, Response = Response, Error = Infallible>
+ Clone
+ Send
+ 'static,
I::Future: Send,
S: Send + Sync + Clone + 'static,
S: ProvideState<AdminState<S>>,
S: ProvideState<AdminPageLoader>,
{
type Response = Response;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
_: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn call(&mut self, mut request: Request) -> Self::Future {
let state = self.state.clone();
let admin = self.admin.clone();
let page_loader = self.page_loader.clone();
let mut inner = self.inner.clone();
Box::pin(async move {
let mut user: Option<AuthenticatedUser> = None;
let permission: PermissionContext = request.extract_parts().await.unwrap();
for provider in &admin.auth_providers.providers {
match provider.authenticate(&mut request, &state).await {
Ok(Some(authenticated_user)) => {
user = Some(authenticated_user);
break;
}
Err(error) => {
tracing::error!(
?error,
provider = provider.provider_name(),
"Error while authenticating user"
)
}
_ => {}
}
}
let Some(user) = user else {
return Ok(Redirect::to(&admin.login_url).into_response());
};
let span = tracing::info_span!("authenticated user", ?user, ?permission);
let _ = span.enter();
if let Some(admin_group) = &admin.super_admin_group {
if user.groups.contains(admin_group) {
tracing::debug!(?user, ?permission, "Granted permission for super_admin");
let span = tracing::info_span!("super_admin user", ?user);
let _ = span.enter();
request.extensions_mut().insert(user);
request.extensions_mut().insert(permission);
return inner.call(request).await;
}
}
for provider in &admin.auth_providers.providers {
match provider.authorize(&user, &permission).await {
Ok(true) => {
tracing::debug!(
provider = provider.provider_name(),
"Granted permissions for user"
);
let span = tracing::info_span!("authorized user", ?user);
let _ = span.enter();
request.extensions_mut().insert(user);
request.extensions_mut().insert(permission);
return inner.call(request).await;
}
Err(error) => {
tracing::error!(
?error,
provider = provider.provider_name(),
"Error while checking authorization of user"
)
}
_ => {}
}
}
Ok(<AdminPageLoader as TemplateDataLoader<S>>::render_error(
&page_loader,
quokka::Error::status("Forbidden", 403),
)
.await
.into_response())
})
}
}
impl<S: Send + Sync> FromRequestParts<S> for PermissionContext {
    type Rejection = Infallible;

    /// Returns the permission context stored in the request extensions, or builds one
    /// from the HTTP method and the route path.
    ///
    /// The route path is the matched route pattern when available. If the request did
    /// not match a route (e.g. it reached a fallback handler), `MatchedPath` is absent and
    /// the literal request path is used instead of panicking.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Reuse a previously inserted context so handlers see exactly what was checked.
        if let Some(permission) = parts.extensions.get::<PermissionContext>() {
            return Ok(permission.clone());
        }
        let resource = match MatchedPath::from_request_parts(parts, state).await {
            Ok(matched_path) => matched_path.as_str().to_owned(),
            Err(_) => parts.uri.path().to_owned(),
        };
        Ok(PermissionContext {
            verb: parts.method.to_string(),
            resource,
        })
    }
}
// Hand-written so that `AuthProviders<S>: Default` does not require `S: Default`.
impl<S> Default for AuthProviders<S> {
    /// An empty provider list.
    fn default() -> Self {
        let providers = Vec::new();
        Self { providers }
    }
}
// The middleware keeps its own copy of the application state.
impl<S> FromState<S> for AdminAuthMiddleware<S>
where
    S: Clone,
{
    fn from_state(state: &S) -> Self {
        let owned_state = S::clone(state);
        Self { state: owned_state }
    }
}