use crate::core::AdminUser;
use async_trait::async_trait;
use reinhardt_auth::BaseUser;
use reinhardt_db::orm::{DatabaseConnection, Model};
use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
use reinhardt_http::AuthState;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Type-erased async loader that resolves an admin user from its string id.
///
/// Called with the `user_id` read from the request's `AuthState` and a database
/// connection. Resolves to the user as an `Arc<dyn AdminUser>`, or to a `DiError`.
/// Erasing the concrete user type lets `AdminAuthenticatedUser` load users without
/// being generic over the user model; `create_admin_user_loader` builds one for a
/// concrete model type.
pub(crate) type AdminUserLoaderFn = Arc<
dyn Fn(
String,
Arc<DatabaseConnection>,
) -> Pin<Box<dyn Future<Output = Result<Arc<dyn AdminUser>, DiError>> + Send>>
+ Send
+ Sync,
>;
/// DI-registrable wrapper around an [`AdminUserLoaderFn`].
///
/// `AdminAuthenticatedUser::inject` looks this up as a singleton and calls the inner
/// function to load the user identified by the request's `AuthState`.
#[derive(Clone)]
pub(crate) struct AdminUserLoader(pub(crate) AdminUserLoaderFn);
/// The admin user for the current request, obtained through dependency injection.
///
/// Values produced by the `Injectable` impl are only returned when the request is
/// authenticated and the loaded user is both active and has staff privileges.
/// The wrapped field is public, so constructing this type directly bypasses those
/// checks.
#[derive(Clone)]
pub struct AdminAuthenticatedUser(pub Arc<dyn AdminUser>);
impl std::fmt::Debug for AdminAuthenticatedUser {
    /// Renders the wrapper as `AdminAuthenticatedUser { username: .. }`.
    ///
    /// Only the username is printed; no other user fields appear in debug output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let username = self.0.get_username();
        let mut builder = formatter.debug_struct("AdminAuthenticatedUser");
        builder.field("username", &username);
        builder.finish()
    }
}
#[async_trait]
impl Injectable for AdminAuthenticatedUser {
    /// Resolves the authenticated admin user for the current HTTP request.
    ///
    /// Each step can short-circuit with an error:
    /// 1. Fetch the HTTP request from `ctx`.
    /// 2. Read the `AuthState` stored in the request extensions and require it to be
    ///    authenticated.
    /// 3. Fetch the `AdminUserLoader` singleton and a `DatabaseConnection`
    ///    (singleton scope first, then request scope).
    /// 4. Load the user through the loader and require it to be active and staff.
    ///
    /// # Errors
    /// - `DiError::Authentication`: no request, no `AuthState`, unauthenticated request,
    ///   inactive user, or user without staff privileges.
    /// - `DiError::NotRegistered`: no `AdminUserLoader` singleton is registered.
    /// - `DiError::Internal`: no `DatabaseConnection` in either scope.
    /// - Any error returned by the loader itself is propagated unchanged.
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        // 1. The current HTTP request.
        let Some(request) = ctx.get_http_request() else {
            return Err(DiError::Authentication(
                "AdminAuthenticatedUser: No HTTP request available in InjectionContext".to_string(),
            ));
        };

        // 2. Authentication state stored in the request extensions.
        let stored_state: Option<AuthState> = request.extensions.get();
        let Some(auth_state) = stored_state else {
            return Err(DiError::Authentication(
                "AdminAuthenticatedUser: No AuthState found in request extensions".to_string(),
            ));
        };
        if !auth_state.is_authenticated() {
            return Err(DiError::Authentication(
                "AdminAuthenticatedUser: User is not authenticated".to_string(),
            ));
        }
        let subject_id = auth_state.user_id().to_string();

        // 3. Collaborators needed to load the full user record.
        let Some(loader) = ctx.get_singleton::<AdminUserLoader>() else {
            return Err(DiError::NotRegistered {
                type_name: "AdminUserLoader".into(),
                hint: "Call AdminSite::set_user_type::<U>() before building admin routes, \
                       or use the default by calling admin_routes_with_di() which \
                       registers AdminDefaultUser as a fallback."
                    .into(),
            });
        };
        let connection = ctx
            .get_singleton::<DatabaseConnection>()
            .or_else(|| ctx.get_request::<DatabaseConnection>());
        let Some(db) = connection else {
            ::tracing::warn!(
                "AdminAuthenticatedUser: DatabaseConnection not available for user resolution"
            );
            return Err(DiError::Internal {
                message: "AdminAuthenticatedUser: DatabaseConnection not registered in DI context"
                    .to_string(),
            });
        };

        // 4. Load the user, then vet the account state (active first, then staff).
        let user = (loader.0)(subject_id, db).await?;
        if !user.is_active() {
            return Err(DiError::Authentication(
                "User account is not active".to_string(),
            ));
        }
        if !user.is_staff() {
            return Err(DiError::Authentication(
                "User does not have staff privileges".to_string(),
            ));
        }
        Ok(AdminAuthenticatedUser(user))
    }
}
/// Builds an `AdminUserLoader` that fetches users of model `U` by primary key.
///
/// The returned loader:
/// 1. parses the incoming `user_id` into `U`'s `BaseUser` primary-key type,
/// 2. converts that key into `U`'s `Model` primary-key type, and
/// 3. fetches the first matching row over the supplied database connection.
///
/// It yields the user as an `Arc<dyn AdminUser>`.
///
/// # Errors (from the loader future)
/// - `DiError::Authentication` if `user_id` does not parse as a primary key.
/// - `DiError::Internal` if the database query fails.
/// - `DiError::NotFound` if no user has that primary key.
pub(crate) fn create_admin_user_loader<U>() -> AdminUserLoader
where
    U: BaseUser + AdminUser + Model + Clone + Send + Sync + 'static,
    <U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
    <<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
    <U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
{
    let load_user: AdminUserLoaderFn = Arc::new(move |user_id, db| {
        Box::pin(async move {
            // Step 1: string id -> BaseUser primary key. The parse error is only
            // logged; callers receive a generic message.
            let parsed_pk = match user_id.parse::<<U as BaseUser>::PrimaryKey>() {
                Ok(pk) => pk,
                Err(e) => {
                    ::tracing::warn!(
                        user_id = %user_id,
                        error = ?e,
                        "AdminUserLoader: failed to parse user_id"
                    );
                    return Err(DiError::Authentication(
                        "AdminUserLoader: Invalid user_id format".to_string(),
                    ));
                }
            };

            // Steps 2 and 3: convert to the model key and query the database.
            let model_pk = <U as Model>::PrimaryKey::from(parsed_pk);
            let user = match U::objects().get(model_pk).first_with_db(&db).await {
                Ok(Some(found)) => found,
                Ok(None) => {
                    ::tracing::warn!(
                        user_id = %user_id,
                        "AdminUserLoader: User not found in database"
                    );
                    return Err(DiError::NotFound(
                        "AdminUserLoader: User not found".to_string(),
                    ));
                }
                Err(e) => {
                    ::tracing::warn!(error = ?e, "AdminUserLoader: Database query failed");
                    return Err(DiError::Internal {
                        message: "AdminUserLoader: Database query failed".to_string(),
                    });
                }
            };

            let shared: Arc<dyn AdminUser> = Arc::new(user);
            Ok(shared)
        })
    });
    AdminUserLoader(load_user)
}
/// Summary of a user that passed admin login, as returned by the login authenticator.
pub(crate) struct AuthenticatedUserInfo {
/// The user's primary key rendered with `to_string`.
pub(crate) user_id: String,
/// The user's username.
pub(crate) username: String,
/// Whether the user has staff privileges.
pub(crate) is_staff: bool,
/// Whether the user is a superuser.
pub(crate) is_superuser: bool,
}
/// Type-erased async function that checks admin login credentials.
///
/// Called with `(username, password, db)`. As built by
/// `create_admin_login_authenticator`, it resolves to:
/// - `Ok(Some(info))` on success,
/// - `Ok(None)` when the credentials or account are rejected,
/// - `Err` on infrastructure failures.
pub(crate) type AdminLoginAuthenticatorFn = Arc<
dyn Fn(
String,
String,
Arc<DatabaseConnection>,
)
-> Pin<Box<dyn Future<Output = Result<Option<AuthenticatedUserInfo>, DiError>> + Send>>
+ Send
+ Sync,
>;
/// DI-registrable wrapper around the admin login credential checker.
///
/// Expected to be registered as a singleton; its `Injectable` impl hands out clones
/// of that singleton. Cloning only bumps the reference count of the inner `Arc`.
#[derive(Clone)]
pub struct AdminLoginAuthenticator(pub(crate) AdminLoginAuthenticatorFn);
#[async_trait]
impl Injectable for AdminLoginAuthenticator {
async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
ctx.get_singleton::<AdminLoginAuthenticator>()
.map(|arc| (*arc).clone())
.ok_or_else(|| DiError::NotRegistered {
type_name: "AdminLoginAuthenticator".into(),
hint: "Call AdminSite::set_user_type::<U>() or use admin_routes_with_di() \
which registers AdminDefaultUser as a fallback."
.into(),
})
}
}
/// Builds an `AdminLoginAuthenticator` that checks admin credentials against model `U`.
///
/// The returned authenticator:
/// 1. looks up the first `U` whose `username` column equals the submitted username,
/// 2. verifies the submitted password with the user's `check_password`, and
/// 3. requires the user to be active and to have staff privileges.
///
/// # Returns (from the authenticator future)
/// - `Ok(Some(info))` for valid credentials of an active staff user.
/// - `Ok(None)` if the user is missing, the password is wrong, or the user is inactive
///   or not staff. Every rejection yields the same value, so callers cannot tell which
///   check failed.
///
/// # Errors
/// `DiError::Internal` in any of these cases:
/// - the database query fails,
/// - password verification itself errors,
/// - the authenticated user has no usable primary key (an empty `user_id` could not
///   identify the user afterwards).
pub(crate) fn create_admin_login_authenticator<U>() -> AdminLoginAuthenticator
where
    U: BaseUser + AdminUser + Model + Clone + Send + Sync + 'static,
    <U as BaseUser>::PrimaryKey: ToString + Send + Sync,
{
    use reinhardt_db::orm::{Filter, FilterOperator, FilterValue};
    let authenticator: AdminLoginAuthenticatorFn = Arc::new(move |username, password, db| {
        Box::pin(async move {
            let user: Option<U> = U::objects()
                .filter_by(Filter::new(
                    "username",
                    FilterOperator::Eq,
                    FilterValue::String(username.clone()),
                ))
                .first_with_db(&db)
                .await
                .map_err(|e| {
                    ::tracing::warn!(error = ?e, "AdminLoginAuthenticator: Database query failed");
                    DiError::Internal {
                        message: "AdminLoginAuthenticator: Database query failed".to_string(),
                    }
                })?;
            // NOTE(review): this early return skips the password check entirely, so
            // response timing may reveal whether a username exists. Consider verifying
            // against a dummy hash here if username enumeration is a concern.
            let Some(user) = user else {
                ::tracing::debug!(username = %username, "AdminLoginAuthenticator: User not found");
                return Ok(None);
            };
            let password_valid = user.check_password(&password).map_err(|e| {
                ::tracing::warn!(error = ?e, "AdminLoginAuthenticator: Password check failed");
                DiError::Internal {
                    message: "AdminLoginAuthenticator: Password verification error".to_string(),
                }
            })?;
            if !password_valid {
                ::tracing::debug!(username = %username, "AdminLoginAuthenticator: Invalid password");
                return Ok(None);
            }
            // `AdminUser::` is spelled out because `is_active` is also reachable via another
            // trait bound on `U`.
            if !AdminUser::is_active(&user) {
                ::tracing::debug!(username = %username, "AdminLoginAuthenticator: User is not active");
                return Ok(None);
            }
            if !user.is_staff() {
                ::tracing::debug!(username = %username, "AdminLoginAuthenticator: User is not staff");
                return Ok(None);
            }
            // A missing primary key, or one that renders as "", cannot identify the user
            // later. Refuse to report a successful login rather than hand back an empty id.
            let user_id = user
                .primary_key()
                .map(|pk| pk.to_string())
                .unwrap_or_default();
            if user_id.is_empty() {
                ::tracing::warn!(
                    username = %username,
                    "AdminLoginAuthenticator: Authenticated user has no primary key"
                );
                return Err(DiError::Internal {
                    message: "AdminLoginAuthenticator: Authenticated user has no primary key"
                        .to_string(),
                });
            }
            Ok(Some(AuthenticatedUserInfo {
                user_id,
                username: AdminUser::get_username(&user).to_string(),
                is_staff: user.is_staff(),
                is_superuser: user.is_superuser(),
            }))
        })
    });
    AdminLoginAuthenticator(authenticator)
}
#[cfg(test)]
mod tests {
    use super::*;
    use reinhardt_di::SingletonScope;
    use rstest::rstest;

    /// Builds the bare `/admin/test` request shared by the injection tests.
    fn admin_test_request() -> reinhardt_http::Request {
        reinhardt_http::Request::builder()
            .uri("/admin/test")
            .build()
            .expect("Failed to build test request")
    }

    /// Asserts that `result` failed with an error whose message contains `needle`.
    ///
    /// On mismatch, the failure message is `"<expectation>, got: <error>"`.
    fn assert_error_mentions(
        result: DiResult<AdminAuthenticatedUser>,
        needle: &str,
        expectation: &str,
    ) {
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains(needle),
            "{}, got: {}",
            expectation,
            err
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_inject_returns_error_when_no_http_request() {
        // The context carries no request at all.
        let ctx = InjectionContext::builder(Arc::new(SingletonScope::new())).build();
        let outcome = AdminAuthenticatedUser::inject(&ctx).await;
        assert_error_mentions(outcome, "No HTTP request", "Expected 'No HTTP request' error");
    }

    #[rstest]
    #[tokio::test]
    async fn test_inject_returns_error_when_no_auth_state() {
        // A request is present, but nothing was inserted into its extensions.
        let ctx = InjectionContext::builder(Arc::new(SingletonScope::new()))
            .with_request(admin_test_request())
            .build();
        let outcome = AdminAuthenticatedUser::inject(&ctx).await;
        assert_error_mentions(outcome, "No AuthState", "Expected 'No AuthState' error");
    }

    #[rstest]
    #[tokio::test]
    async fn test_inject_returns_error_when_not_authenticated() {
        let anonymous_request = admin_test_request();
        anonymous_request.extensions.insert(AuthState::anonymous());
        let ctx = InjectionContext::builder(Arc::new(SingletonScope::new()))
            .with_request(anonymous_request)
            .build();
        let outcome = AdminAuthenticatedUser::inject(&ctx).await;
        assert_error_mentions(
            outcome,
            "not authenticated",
            "Expected 'not authenticated' error",
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_inject_returns_error_when_no_loader_registered() {
        let authed_request = admin_test_request();
        authed_request
            .extensions
            .insert(AuthState::authenticated("user-123", true, true));
        // The singleton scope is left empty, so no AdminUserLoader can be found.
        let ctx = InjectionContext::builder(Arc::new(SingletonScope::new()))
            .with_request(authed_request)
            .build();
        let outcome = AdminAuthenticatedUser::inject(&ctx).await;
        assert_error_mentions(outcome, "AdminUserLoader", "Expected 'AdminUserLoader' error");
    }

    #[rstest]
    #[tokio::test]
    async fn test_inject_returns_error_when_no_database_connection() {
        // A loader is registered, but no DatabaseConnection is, so the loader must
        // never run.
        let scope = Arc::new(SingletonScope::new());
        let stub_loader = AdminUserLoader(Arc::new(|_user_id, _db| {
            Box::pin(async { Err(DiError::NotFound("should not be called".to_string())) })
        }));
        scope.set_arc(Arc::new(stub_loader));
        let authed_request = admin_test_request();
        authed_request
            .extensions
            .insert(AuthState::authenticated("user-123", true, true));
        let ctx = InjectionContext::builder(scope)
            .with_request(authed_request)
            .build();
        let outcome = AdminAuthenticatedUser::inject(&ctx).await;
        assert_error_mentions(
            outcome,
            "DatabaseConnection",
            "Expected error mentioning DatabaseConnection",
        );
    }

    #[rstest]
    fn test_admin_user_loader_can_be_stored_in_singleton_scope() {
        let scope = SingletonScope::new();
        let stub_loader = AdminUserLoader(Arc::new(|_user_id, _db| {
            Box::pin(async { Err(DiError::NotFound("test loader".to_string())) })
        }));
        scope.set_arc(Arc::new(stub_loader));
        let found = scope.get::<AdminUserLoader>();
        assert!(
            found.is_some(),
            "AdminUserLoader should be retrievable from singleton scope"
        );
    }
}