Skip to main content

reinhardt_auth/
auth_user.rs

1//! Authenticated user extractor that loads the full user model from database.
2//!
3//! Wraps the user model `U` as a tuple struct for destructuring, consistent
4//! with `Path`, `Json`, and other Reinhardt extractors.
5
6use crate::BaseUser;
7use async_trait::async_trait;
8use reinhardt_db::orm::{CustomManager, DatabaseConnection, Model};
9use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
10use reinhardt_http::AuthState;
11use std::sync::Arc;
12
13/// Authenticated user extractor that loads the full user model from database.
14///
15/// Wraps the user model `U` as a tuple struct for destructuring, consistent
16/// with `Path<T>`, `Json<T>`, and other Reinhardt extractors.
17///
18/// Requires `feature = "params"` to access request data from `InjectionContext`.
19///
20/// # Usage
21///
22/// ```rust,ignore
23/// use reinhardt_auth::CurrentUser;
24/// use reinhardt_auth::DefaultUser;
25///
26/// #[get("/profile/")]
27/// pub async fn profile(
28///     #[inject] CurrentUser(user): CurrentUser<DefaultUser>,
29/// ) -> ViewResult<Response> {
30///     let username = user.get_username();
31///     // ...
32/// }
33/// ```
34///
35/// # Failure
36///
37/// Returns an injection error when:
38/// - No `AuthState` in request extensions (HTTP 401)
39/// - `user_id` parse failure (HTTP 401, not nil UUID fallback)
40/// - `DatabaseConnection` not registered in DI (HTTP 503)
41/// - Database query failure (HTTP 500)
42#[derive(Debug, Clone)]
43pub struct CurrentUser<U: BaseUser>(pub U);
44
45/// Deprecated compatibility name for [`CurrentUser`].
46///
47/// Use [`CurrentUser`] in new code. This wrapper has the same tuple-struct
48/// shape and injection behavior so existing `AuthUser(user): AuthUser<U>`
49/// patterns continue to compile until the 0.3 removal.
50#[deprecated(
51	since = "0.2.0",
52	note = "use CurrentUser; AuthUser will be removed in 0.3"
53)]
54#[derive(Debug, Clone)]
55pub struct AuthUser<U: BaseUser>(pub U);
56
/// Shared resolution routine behind the `CurrentUser` and `AuthUser` extractors.
///
/// Steps, in order:
/// 1. Take the HTTP request from the `InjectionContext`.
/// 2. Read the `AuthState` from the request extensions and require it to be authenticated.
/// 3. Parse `AuthState::user_id()` into `<U as BaseUser>::PrimaryKey` (no nil-UUID
///    fallback, #2430), then convert it into `<U as Model>::PrimaryKey`.
/// 4. Fetch the `DatabaseConnection`, trying the singleton scope before the request scope.
/// 5. Load the user row by primary key.
///
/// # Errors
///
/// - `DiError::NotFound`: no HTTP request in the context, no `AuthState` in the
///   request extensions, or no user row for the parsed key.
/// - `DiError::Authentication`: the `AuthState` is unauthenticated, or its
///   `user_id` does not parse as the user's primary-key type.
/// - `DiError::Internal`: no `DatabaseConnection` in either scope, or the query failed.
#[cfg(feature = "params")]
async fn resolve_current_user<U>(ctx: &InjectionContext) -> DiResult<U>
where
	U: BaseUser + Model + Clone + Send + Sync + 'static,
	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
{
	let request = match ctx.get_http_request() {
		Some(req) => req,
		None => {
			return Err(DiError::NotFound(
				"CurrentUser: No HTTP request available in InjectionContext".to_string(),
			));
		}
	};

	let auth_state: AuthState = match request.extensions.get() {
		Some(state) => state,
		None => {
			return Err(DiError::NotFound(
				"CurrentUser: No AuthState found in request extensions".to_string(),
			));
		}
	};

	if !auth_state.is_authenticated() {
		return Err(DiError::Authentication(
			"CurrentUser: User is not authenticated".to_string(),
		));
	}

	// A malformed id is rejected outright; it must never silently become a nil UUID (#2430).
	let raw_user_id = auth_state.user_id();
	let user_pk = match raw_user_id.parse::<<U as BaseUser>::PrimaryKey>() {
		Ok(pk) => pk,
		Err(parse_err) => {
			::tracing::warn!(
				user_id = %raw_user_id,
				error = ?parse_err,
				"CurrentUser: failed to parse user_id from AuthState"
			);
			return Err(DiError::Authentication(
				"CurrentUser: Invalid user_id format in AuthState".to_string(),
			));
		}
	};
	let model_pk: <U as Model>::PrimaryKey = From::from(user_pk);

	// DatabaseConnection is pre-seeded into the singleton scope at server startup
	// rather than registered in the global DependencyRegistry, so it is looked up
	// directly: singleton scope first, request scope as the fallback.
	let db: Arc<DatabaseConnection> = match ctx.get_singleton::<DatabaseConnection>() {
		Some(conn) => conn,
		None => match ctx.get_request::<DatabaseConnection>() {
			Some(conn) => conn,
			None => {
				::tracing::warn!("CurrentUser: DatabaseConnection not available for user resolution");
				return Err(DiError::Internal {
					message: "CurrentUser: DatabaseConnection not registered in DI context".to_string(),
				});
			}
		},
	};

	let query_result = U::objects().get(model_pk).first_with_db(&db).await;
	let found = match query_result {
		Ok(row) => row,
		Err(db_err) => {
			::tracing::warn!(error = ?db_err, "CurrentUser: Failed to load user from database");
			return Err(DiError::Internal {
				message: "CurrentUser: Database query failed".to_string(),
			});
		}
	};

	match found {
		Some(user) => Ok(user),
		None => {
			::tracing::warn!(
				user_id = %raw_user_id,
				"CurrentUser: User not found in database"
			);
			Err(DiError::NotFound("CurrentUser: User not found".to_string()))
		}
	}
}
128
// With `params` enabled, `CurrentUser` runs the shared resolution routine and
// wraps the loaded model; any `DiError` from resolution is passed through unchanged.
#[cfg(feature = "params")]
#[async_trait]
impl<U> Injectable for CurrentUser<U>
where
	U: BaseUser + Model + Clone + Send + Sync + 'static,
	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
{
	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
		let user = resolve_current_user::<U>(ctx).await?;
		Ok(Self(user))
	}
}
142
143#[cfg(not(feature = "params"))]
144#[async_trait]
145impl<U> Injectable for CurrentUser<U>
146where
147	U: BaseUser + Model + Clone + Send + Sync + 'static,
148	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
149	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
150	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
151{
152	async fn inject(_ctx: &InjectionContext) -> DiResult<Self> {
153		Err(DiError::NotFound(
154			"CurrentUser requires the 'params' feature to be enabled".to_string(),
155		))
156	}
157}
158
// The deprecated `AuthUser` reuses exactly the same resolution routine as
// `CurrentUser`; only the wrapper type differs.
#[cfg(feature = "params")]
#[allow(deprecated)]
#[async_trait]
impl<U> Injectable for AuthUser<U>
where
	U: BaseUser + Model + Clone + Send + Sync + 'static,
	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
{
	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
		let user = resolve_current_user::<U>(ctx).await?;
		Ok(Self(user))
	}
}
173
174#[cfg(not(feature = "params"))]
175#[allow(deprecated)]
176#[async_trait]
177impl<U> Injectable for AuthUser<U>
178where
179	U: BaseUser + Model + Clone + Send + Sync + 'static,
180	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
181	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
182	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
183{
184	async fn inject(_ctx: &InjectionContext) -> DiResult<Self> {
185		Err(DiError::NotFound(
186			"AuthUser requires the 'params' feature to be enabled".to_string(),
187		))
188	}
189}
190
#[cfg(test)]
#[allow(deprecated)] // the deprecated `AuthUser` is exercised on purpose below
mod tests {
	use super::{AuthUser, CurrentUser};
	use crate::{BaseUser, PasswordHasher};
	use chrono::{DateTime, Utc};
	use serde::{Deserialize, Serialize};

	/// Identity "hasher": the stored hash is the plaintext, and verification is
	/// plain string equality. Only suitable for tests.
	#[derive(Default)]
	struct TestHasher;

	impl PasswordHasher for TestHasher {
		fn hash(&self, password: &str) -> Result<String, reinhardt_core::exception::Error> {
			Ok(String::from(password))
		}

		fn verify(
			&self,
			password: &str,
			hash: &str,
		) -> Result<bool, reinhardt_core::exception::Error> {
			Ok(hash == password)
		}
	}

	/// Minimal in-memory user used to check the extractor wrappers' shape.
	#[derive(Clone, Serialize, Deserialize)]
	struct TestUser {
		username: String,
		password_hash: Option<String>,
		last_login: Option<DateTime<Utc>>,
		is_active: bool,
	}

	impl BaseUser for TestUser {
		type PrimaryKey = String;
		type Hasher = TestHasher;

		fn get_username_field() -> &'static str {
			"username"
		}

		fn get_username(&self) -> &str {
			self.username.as_str()
		}

		fn password_hash(&self) -> Option<&str> {
			self.password_hash.as_deref()
		}

		fn set_password_hash(&mut self, hash: String) {
			self.password_hash = Some(hash);
		}

		fn last_login(&self) -> Option<DateTime<Utc>> {
			self.last_login
		}

		fn set_last_login(&mut self, time: DateTime<Utc>) {
			self.last_login = Some(time);
		}

		fn is_active(&self) -> bool {
			self.is_active
		}
	}

	/// Builds an active user with no password and no recorded login.
	fn user_named(username: &str) -> TestUser {
		TestUser {
			username: username.to_owned(),
			password_hash: None,
			last_login: None,
			is_active: true,
		}
	}

	#[test]
	fn current_user_supports_tuple_struct_destructuring() {
		let wrapped: CurrentUser<TestUser> = CurrentUser(user_named("alice"));
		let CurrentUser(user) = wrapped;

		assert_eq!("alice", user.get_username());
	}

	#[test]
	fn deprecated_auth_user_keeps_tuple_struct_destructuring() {
		let wrapped: AuthUser<TestUser> = AuthUser(user_named("bob"));
		let AuthUser(user) = wrapped;

		assert_eq!("bob", user.get_username());
	}
}