Skip to main content

reinhardt_auth/
auth_user.rs

1//! Authenticated user extractor that loads the full user model from database.
2//!
3//! Wraps the user model `U` as a tuple struct for destructuring, consistent
4//! with `Path`, `Json`, and other Reinhardt extractors.
5
6use crate::BaseUser;
7use async_trait::async_trait;
8use reinhardt_db::orm::{DatabaseConnection, Model};
9use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
10use reinhardt_http::AuthState;
11use std::sync::Arc;
12
13/// Authenticated user extractor that loads the full user model from database.
14///
15/// Wraps the user model `U` as a tuple struct for destructuring, consistent
16/// with `Path<T>`, `Json<T>`, and other Reinhardt extractors.
17///
18/// Requires `feature = "params"` to access request data from `InjectionContext`.
19///
20/// # Usage
21///
22/// ```rust,ignore
23/// use reinhardt_auth::AuthUser;
24/// use reinhardt_auth::DefaultUser;
25///
26/// #[get("/profile/")]
27/// pub async fn profile(
28///     #[inject] AuthUser(user): AuthUser<DefaultUser>,
29/// ) -> ViewResult<Response> {
30///     let username = user.get_username();
31///     // ...
32/// }
33/// ```
34///
35/// # Failure
36///
37/// Returns an injection error when:
38/// - No `AuthState` in request extensions (HTTP 401)
39/// - `user_id` parse failure (HTTP 401, not nil UUID fallback)
40/// - `DatabaseConnection` not registered in DI (HTTP 503)
41/// - Database query failure (HTTP 500)
42#[derive(Debug, Clone)]
43pub struct AuthUser<U: BaseUser>(pub U);
44
#[cfg(feature = "params")]
#[async_trait]
impl<U> Injectable for AuthUser<U>
where
	U: BaseUser + Model + Clone + Send + Sync + 'static,
	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
{
	/// Resolves the authenticated user model for the current request.
	///
	/// Runs these steps in order and returns early at the first failure:
	/// 1. take the HTTP request out of `ctx`;
	/// 2. read the `AuthState` from the request extensions and require it to
	///    be authenticated;
	/// 3. parse its `user_id` into `U`'s primary key;
	/// 4. find a `DatabaseConnection`, trying the singleton scope first and
	///    the request scope second;
	/// 5. load the user with that primary key.
	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
		// Step 1: the HTTP request carried by the injection context.
		let request = match ctx.get_http_request() {
			Some(req) => req,
			None => {
				return Err(DiError::NotFound(
					"AuthUser: No HTTP request available in InjectionContext".to_string(),
				));
			}
		};

		// Step 2: authentication state stored in the request extensions.
		let maybe_state: Option<AuthState> = request.extensions.get();
		let auth_state = match maybe_state {
			Some(state) => state,
			None => {
				return Err(DiError::NotFound(
					"AuthUser: No AuthState found in request extensions".to_string(),
				));
			}
		};
		if !auth_state.is_authenticated() {
			return Err(DiError::NotFound(
				"AuthUser: User is not authenticated".to_string(),
			));
		}

		// Step 3: an unparseable id is rejected outright rather than being
		// replaced by a placeholder such as the nil UUID (#2430).
		let user_pk = match auth_state
			.user_id()
			.parse::<<U as BaseUser>::PrimaryKey>()
		{
			Ok(pk) => pk,
			Err(parse_err) => {
				::tracing::warn!(
					user_id = %auth_state.user_id(),
					error = ?parse_err,
					"AuthUser: failed to parse user_id from AuthState"
				);
				return Err(DiError::NotFound(
					"AuthUser: Invalid user_id format in AuthState".to_string(),
				));
			}
		};
		let model_pk = <U as Model>::PrimaryKey::from(user_pk);

		// Step 4: DatabaseConnection is looked up directly in the singleton
		// scope (then the request scope) instead of via ctx.resolve(), because
		// it is pre-seeded into the singleton scope at server startup rather
		// than registered in the global DependencyRegistry.
		let db: Arc<DatabaseConnection> =
			if let Some(conn) = ctx.get_singleton::<DatabaseConnection>() {
				conn
			} else if let Some(conn) = ctx.get_request::<DatabaseConnection>() {
				conn
			} else {
				::tracing::warn!("AuthUser: DatabaseConnection not available for user resolution");
				return Err(DiError::Internal {
					message: "AuthUser: DatabaseConnection not registered in DI context"
						.to_string(),
				});
			};

		// Step 5: a failed query is Internal; a missing row is NotFound.
		let lookup = match U::objects().get(model_pk).first_with_db(&db).await {
			Ok(row) => row,
			Err(query_err) => {
				::tracing::warn!(error = ?query_err, "AuthUser: Failed to load user from database");
				return Err(DiError::Internal {
					message: "AuthUser: Database query failed".to_string(),
				});
			}
		};

		match lookup {
			Some(user) => Ok(AuthUser(user)),
			None => {
				::tracing::warn!(
					user_id = %auth_state.user_id(),
					"AuthUser: User not found in database"
				);
				Err(DiError::NotFound("AuthUser: User not found".to_string()))
			}
		}
	}
}
118			})?;
119
120		Ok(AuthUser(user))
121	}
122}
123
124#[cfg(not(feature = "params"))]
125#[async_trait]
126impl<U> Injectable for AuthUser<U>
127where
128	U: BaseUser + Model + Clone + Send + Sync + 'static,
129	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
130	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
131	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
132{
133	async fn inject(_ctx: &InjectionContext) -> DiResult<Self> {
134		Err(DiError::NotFound(
135			"AuthUser requires the 'params' feature to be enabled".to_string(),
136		))
137	}
138}