use crate::AuthenticationError;
use crate::BaseUser;
use async_trait::async_trait;
use reinhardt_db::orm::{DatabaseConnection, Model};
#[cfg(not(feature = "params"))]
use reinhardt_di::DiError;
use reinhardt_di::{DiResult, Injectable, InjectionContext};
use reinhardt_http::AuthState;
use std::sync::Arc;
use uuid::Uuid;
/// The user resolved for the current request, or an anonymous placeholder.
///
/// The constructors in this module fill both fields together
/// ([`CurrentUser::authenticated`]) or leave both empty together
/// ([`CurrentUser::anonymous`]).
#[deprecated(
    since = "0.1.0-rc.12",
    note = "Use `AuthUser<U>` instead. `CurrentUser<U>` will become a type alias for `AuthUser<U>` in 0.2.0."
)]
pub struct CurrentUser<U>
where
    U: BaseUser + Clone,
{
    /// The loaded user record; `None` for anonymous requests.
    user: Option<U>,
    /// Identifier of the authenticated user; `None` for anonymous requests.
    user_id: Option<Uuid>,
}
// The struct itself is deprecated; silence the lint for its own trait impls.
#[allow(deprecated)]
impl<U: BaseUser + Clone> Clone for CurrentUser<U> {
    /// Clones the user record via `U::clone` and copies the id.
    fn clone(&self) -> Self {
        let Self { user, user_id } = self;
        Self {
            user: user.clone(),
            user_id: *user_id,
        }
    }
}
// The struct itself is deprecated; silence the lint for its own inherent impl.
#[allow(deprecated)]
impl<U: BaseUser + Clone> CurrentUser<U> {
    /// Creates an authenticated `CurrentUser` holding `user` and its `user_id`.
    pub fn authenticated(user: U, user_id: Uuid) -> Self {
        let (user, user_id) = (Some(user), Some(user_id));
        Self { user, user_id }
    }

    /// Creates an anonymous `CurrentUser` with neither a user record nor an id.
    pub fn anonymous() -> Self {
        Self {
            user: None,
            user_id: None,
        }
    }

    /// Returns `true` when a user record is present.
    pub fn is_authenticated(&self) -> bool {
        matches!(self.user, Some(_))
    }

    /// Borrows the user record.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticationError::NotAuthenticated`] for an anonymous user.
    pub fn user(&self) -> Result<&U, AuthenticationError> {
        match &self.user {
            Some(found) => Ok(found),
            None => Err(AuthenticationError::NotAuthenticated),
        }
    }

    /// Returns the authenticated user's id.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticationError::NotAuthenticated`] when no id is stored.
    pub fn id(&self) -> Result<Uuid, AuthenticationError> {
        if let Some(id) = self.user_id {
            Ok(id)
        } else {
            Err(AuthenticationError::NotAuthenticated)
        }
    }

    /// Consumes `self` and returns the owned user record.
    ///
    /// # Errors
    ///
    /// Returns [`AuthenticationError::NotAuthenticated`] for an anonymous user.
    pub fn into_user(self) -> Result<U, AuthenticationError> {
        let Self { user, .. } = self;
        match user {
            Some(owned) => Ok(owned),
            None => Err(AuthenticationError::NotAuthenticated),
        }
    }

    /// Returns the user record as a type-erased `Any` reference, or `None`
    /// when anonymous. Callers can recover the concrete type with `downcast_ref`.
    pub fn as_any(&self) -> Option<&(dyn std::any::Any + Send + Sync)>
    where
        U: 'static,
    {
        let found = self.user.as_ref()?;
        let erased: &(dyn std::any::Any + Send + Sync) = found;
        Some(erased)
    }
}
// The struct itself is deprecated; silence the lint for its own trait impls.
#[allow(deprecated)]
#[async_trait]
impl<U> Injectable for CurrentUser<U>
where
    U: BaseUser + Model + Clone + Send + Sync + 'static,
    <U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
    <<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
    <U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
{
    /// Resolves the current user from the HTTP request held by `ctx`.
    ///
    /// Resolution steps:
    /// 1. Read the request's `AuthState` extension.
    /// 2. Parse its user id both as the model primary key and as a `Uuid`.
    /// 3. Load the user via a `DatabaseConnection` from the DI context (singleton scope
    ///    first, then request scope).
    ///
    /// Any missing piece along the way (no request, no auth state, unauthenticated
    /// state, unparsable id, no database connection, missing row, or a database error)
    /// yields [`CurrentUser::anonymous`]. Unexpected failures are logged via `tracing`.
    ///
    /// # Errors
    ///
    /// Without the `params` feature this always returns `DiError::NotFound`.
    /// With it, this never returns `Err`.
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        #[cfg(not(feature = "params"))]
        {
            // There is no request to inspect without `params`; consume `ctx`
            // explicitly so the unused-variable lint stays quiet.
            let _ = ctx;
            return Err(DiError::NotFound(
                "CurrentUser requires the 'params' feature to be enabled".to_string(),
            ));
        }

        #[cfg(feature = "params")]
        {
            let request = match ctx.get_http_request() {
                Some(req) => req,
                None => return Ok(Self::anonymous()),
            };

            let auth_state: AuthState = match request.extensions.get() {
                Some(state) => state,
                None => return Ok(Self::anonymous()),
            };

            if !auth_state.is_authenticated() {
                return Ok(Self::anonymous());
            }

            let base_pk: <U as BaseUser>::PrimaryKey = match auth_state.user_id().parse() {
                Ok(pk) => pk,
                Err(e) => {
                    // A malformed id in an authenticated session is unexpected;
                    // surface it instead of silently degrading to anonymous.
                    ::tracing::warn!(
                        user_id = %auth_state.user_id(),
                        error = ?e,
                        "CurrentUser: failed to parse user_id as primary key"
                    );
                    return Ok(Self::anonymous());
                }
            };

            // CurrentUser stores its id as a Uuid. Validate that before touching
            // the database so a non-UUID id never costs a query whose result would
            // be discarded anyway.
            let user_id = match Uuid::parse_str(auth_state.user_id()) {
                Ok(id) => id,
                Err(e) => {
                    ::tracing::warn!(
                        user_id = %auth_state.user_id(),
                        error = ?e,
                        "CurrentUser: failed to parse user_id as UUID"
                    );
                    return Ok(Self::anonymous());
                }
            };

            let model_pk: <U as Model>::PrimaryKey = base_pk.into();

            let db: Arc<DatabaseConnection> = match ctx
                .get_singleton::<DatabaseConnection>()
                .or_else(|| ctx.get_request::<DatabaseConnection>())
            {
                Some(conn) => conn,
                None => {
                    ::tracing::warn!(
                        "DatabaseConnection not registered in DI context. \
                         CurrentUser will be anonymous. \
                         Hint: Register DatabaseConnection as a singleton in InjectionContext."
                    );
                    return Ok(Self::anonymous());
                }
            };

            let user: U = match U::objects().get(model_pk).first_with_db(&db).await {
                Ok(Some(u)) => u,
                Ok(None) => {
                    // Session refers to a user that no longer exists.
                    ::tracing::debug!(
                        user_id = %user_id,
                        "CurrentUser: no user found for authenticated user_id"
                    );
                    return Ok(Self::anonymous());
                }
                Err(e) => {
                    // Keep the anonymous fallback, but do not hide database failures.
                    ::tracing::warn!(
                        user_id = %user_id,
                        error = ?e,
                        "CurrentUser: failed to load user from database"
                    );
                    return Ok(Self::anonymous());
                }
            };

            Ok(Self::authenticated(user, user_id))
        }
    }
}
// CurrentUser is deprecated; the tests still exercise it until it becomes an alias.
#[allow(deprecated)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PasswordHasher;
    use chrono::{DateTime, Utc};
    use serde::{Deserialize, Serialize};

    /// Deterministic hasher stub: "hashes" by prefixing `hashed:`.
    #[derive(Default)]
    struct MockHasher;

    impl PasswordHasher for MockHasher {
        fn hash(&self, password: &str) -> Result<String, reinhardt_core::exception::Error> {
            Ok(format!("hashed:{}", password))
        }

        fn verify(
            &self,
            password: &str,
            hash: &str,
        ) -> Result<bool, reinhardt_core::exception::Error> {
            Ok(hash == format!("hashed:{}", password))
        }
    }

    #[derive(Clone, Serialize, Deserialize)]
    struct TestUser {
        id: Uuid,
        username: String,
        is_active: bool,
    }

    impl BaseUser for TestUser {
        type PrimaryKey = Uuid;
        type Hasher = MockHasher;

        fn get_username_field() -> &'static str {
            "username"
        }

        fn get_username(&self) -> &str {
            &self.username
        }

        fn password_hash(&self) -> Option<&str> {
            None
        }

        fn set_password_hash(&mut self, _hash: String) {}

        fn last_login(&self) -> Option<DateTime<Utc>> {
            None
        }

        fn set_last_login(&mut self, _time: DateTime<Utc>) {}

        fn is_active(&self) -> bool {
            self.is_active
        }
    }

    /// Builds an active `TestUser` named "testuser" whose `id` is `user_id`.
    fn make_user(user_id: Uuid) -> TestUser {
        TestUser {
            id: user_id,
            username: "testuser".to_string(),
            is_active: true,
        }
    }

    #[test]
    fn test_authenticated_user() {
        let user_id = Uuid::now_v7();
        let current_user = CurrentUser::authenticated(make_user(user_id), user_id);
        assert!(current_user.is_authenticated());
        assert_eq!(current_user.id().unwrap(), user_id);
        let user = current_user.user().unwrap();
        assert_eq!(user.get_username(), "testuser");
        assert_eq!(user.id, user_id);
    }

    #[test]
    fn test_anonymous_user() {
        let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();
        assert!(!current_user.is_authenticated());
        assert!(current_user.id().is_err());
        assert!(current_user.user().is_err());
    }

    #[test]
    fn test_into_user_authenticated() {
        let user_id = Uuid::now_v7();
        let current_user = CurrentUser::authenticated(make_user(user_id), user_id);
        let extracted = current_user.into_user().unwrap();
        assert_eq!(extracted.get_username(), "testuser");
        assert_eq!(extracted.id, user_id);
    }

    #[test]
    fn test_into_user_anonymous() {
        let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();
        let result = current_user.into_user();
        assert!(result.is_err());
    }

    #[test]
    fn test_clone_preserves_state() {
        let user_id = Uuid::now_v7();
        let original = CurrentUser::authenticated(make_user(user_id), user_id);
        let cloned = original.clone();
        assert!(cloned.is_authenticated());
        assert_eq!(cloned.id().unwrap(), user_id);
        assert_eq!(cloned.user().unwrap().get_username(), "testuser");

        let anonymous: CurrentUser<TestUser> = CurrentUser::anonymous();
        assert!(!anonymous.clone().is_authenticated());
    }

    #[test]
    fn test_as_any_downcasts_to_user() {
        let user_id = Uuid::now_v7();
        let current_user = CurrentUser::authenticated(make_user(user_id), user_id);
        let downcast = current_user
            .as_any()
            .and_then(|any| any.downcast_ref::<TestUser>())
            .expect("authenticated user should downcast to TestUser");
        assert_eq!(downcast.id, user_id);
    }

    #[test]
    fn test_as_any_anonymous_is_none() {
        let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();
        assert!(current_user.as_any().is_none());
    }
}