Skip to main content

reinhardt_auth/
current_user.rs

//! CurrentUser Injectable for dependency injection
//!
//! Provides access to the authenticated user in endpoint handlers.
//!
//! This module integrates with the authentication middleware to provide
//! type-safe access to the currently authenticated user via dependency injection.
//!
//! # How it works
//!
//! 1. The authentication middleware (e.g., `AuthenticationMiddleware`) validates
//!    the user and stores an `AuthState` in the request extensions.
//! 2. When a handler requests `CurrentUser<U>`, the injectable implementation:
//!    - Extracts `AuthState` from the request extensions
//!    - Parses the `user_id` into the model's primary key type (and, because
//!      `CurrentUser` stores the id as a `Uuid`, also requires it to be a valid UUID)
//!    - Loads the user from the database using `Model::objects().get(pk)`
//! 3. Returns `CurrentUser::authenticated(user, user_id)` on success, or
//!    `CurrentUser::anonymous()` if any of the steps above fails. Without the
//!    `params` feature, injection fails with an error instead.

18use crate::AuthenticationError;
19use crate::BaseUser;
20use async_trait::async_trait;
21use reinhardt_db::orm::{DatabaseConnection, Model};
22#[cfg(not(feature = "params"))]
23use reinhardt_di::DiError;
24use reinhardt_di::{DiResult, Injectable, InjectionContext};
25use reinhardt_http::AuthState;
26use std::sync::Arc;
27use uuid::Uuid;
28
29/// Wrapper type representing the currently authenticated user for DI.
30///
31/// This type provides access to the authenticated user within endpoint handlers
32/// through dependency injection. It wraps an optional user instance and user ID,
33/// allowing handlers to check authentication status and access user data.
34///
35/// # Example
36///
37/// ```rust,ignore
38/// use reinhardt_auth::CurrentUser;
39/// use reinhardt_auth::DefaultUser;
40/// use reinhardt_http::Response;
41///
42/// async fn my_handler(
43///     current_user: CurrentUser<DefaultUser>,
44/// ) -> Result<Response, Box<dyn std::error::Error>> {
45///     if current_user.is_authenticated() {
46///         let user = current_user.user()?;
47///         let user_id = current_user.id()?;
48///         println!("Authenticated user: {} (ID: {})", user.get_username(), user_id);
49///     }
50///     Ok(Response::ok())
51/// }
52/// ```
53#[deprecated(
54	since = "0.1.0-rc.12",
55	note = "Use `AuthUser<U>` instead. `CurrentUser<U>` will become a type alias for `AuthUser<U>` in 0.2.0."
56)]
57pub struct CurrentUser<U: BaseUser + Clone> {
58	user: Option<U>,
59	user_id: Option<Uuid>,
60}
61
62#[allow(deprecated)] // Implementing Clone for deprecated CurrentUser
63impl<U: BaseUser + Clone> Clone for CurrentUser<U> {
64	fn clone(&self) -> Self {
65		Self {
66			user: self.user.clone(),
67			user_id: self.user_id,
68		}
69	}
70}
71
72#[allow(deprecated)] // Methods for deprecated CurrentUser
73impl<U: BaseUser + Clone> CurrentUser<U> {
74	/// Creates a new authenticated CurrentUser.
75	///
76	/// # Arguments
77	///
78	/// * `user` - The authenticated user instance
79	/// * `user_id` - The user's unique identifier
80	pub fn authenticated(user: U, user_id: Uuid) -> Self {
81		Self {
82			user: Some(user),
83			user_id: Some(user_id),
84		}
85	}
86
87	/// Creates an anonymous (unauthenticated) CurrentUser.
88	pub fn anonymous() -> Self {
89		Self {
90			user: None,
91			user_id: None,
92		}
93	}
94
95	/// Returns whether the current user is authenticated.
96	pub fn is_authenticated(&self) -> bool {
97		self.user.is_some()
98	}
99
100	/// Returns a reference to the user if authenticated.
101	///
102	/// # Errors
103	///
104	/// Returns `AuthenticationError::NotAuthenticated` if the user is not authenticated.
105	pub fn user(&self) -> Result<&U, AuthenticationError> {
106		self.user
107			.as_ref()
108			.ok_or(AuthenticationError::NotAuthenticated)
109	}
110
111	/// Returns the user ID if authenticated.
112	///
113	/// # Errors
114	///
115	/// Returns `AuthenticationError::NotAuthenticated` if the user is not authenticated.
116	pub fn id(&self) -> Result<Uuid, AuthenticationError> {
117		self.user_id.ok_or(AuthenticationError::NotAuthenticated)
118	}
119
120	/// Consumes this wrapper and returns the user if authenticated.
121	///
122	/// # Errors
123	///
124	/// Returns `AuthenticationError::NotAuthenticated` if the user is not authenticated.
125	pub fn into_user(self) -> Result<U, AuthenticationError> {
126		self.user.ok_or(AuthenticationError::NotAuthenticated)
127	}
128
129	/// Returns the user as a trait object for permission checking.
130	///
131	/// This method is used to pass the user to `ModelAdmin` permission methods
132	/// that accept `&(dyn Any + Send + Sync)`.
133	///
134	/// # Returns
135	///
136	/// Returns `Some` with a reference to the user as a trait object if authenticated,
137	/// or `None` if the user is anonymous.
138	pub fn as_any(&self) -> Option<&(dyn std::any::Any + Send + Sync)>
139	where
140		U: 'static,
141	{
142		self.user
143			.as_ref()
144			.map(|u| u as &(dyn std::any::Any + Send + Sync))
145	}
146}
147
148#[allow(deprecated)] // Injectable impl for deprecated CurrentUser
149#[async_trait]
150impl<U> Injectable for CurrentUser<U>
151where
152	U: BaseUser + Model + Clone + Send + Sync + 'static,
153	// Ensure BaseUser::PrimaryKey and Model::PrimaryKey are the same type
154	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
155	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
156	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
157{
158	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
159		// 1. Get HTTP request from context
160		#[cfg(feature = "params")]
161		let request = match ctx.get_http_request() {
162			Some(req) => req,
163			None => return Ok(Self::anonymous()),
164		};
165
166		#[cfg(not(feature = "params"))]
167		return Err(DiError::NotFound(
168			"CurrentUser requires the 'params' feature to be enabled".to_string(),
169		));
170
171		// 2. Get AuthState from request extensions
172		#[cfg(feature = "params")]
173		let auth_state: AuthState = match request.extensions.get() {
174			Some(state) => state,
175			None => return Ok(Self::anonymous()),
176		};
177
178		// 3. Check if authenticated
179		#[cfg(feature = "params")]
180		if !auth_state.is_authenticated() {
181			return Ok(Self::anonymous());
182		}
183
184		// 4. Parse user_id to PrimaryKey type
185		#[cfg(feature = "params")]
186		let base_pk: <U as BaseUser>::PrimaryKey = match auth_state.user_id().parse() {
187			Ok(pk) => pk,
188			Err(_) => return Ok(Self::anonymous()),
189		};
190
191		// Convert BaseUser::PrimaryKey to Model::PrimaryKey
192		#[cfg(feature = "params")]
193		let model_pk: <U as Model>::PrimaryKey = base_pk.into();
194
195		// 5. Get DatabaseConnection from DI context
196		// Resolve DatabaseConnection from DI (singleton-first, request-scope fallback)
197		// Uses get_singleton/get_request directly instead of ctx.resolve() because
198		// DatabaseConnection is pre-seeded into the singleton scope at server startup,
199		// not registered in the global DependencyRegistry.
200		#[cfg(feature = "params")]
201		let db: Arc<DatabaseConnection> = match ctx
202			.get_singleton::<DatabaseConnection>()
203			.or_else(|| ctx.get_request::<DatabaseConnection>())
204		{
205			Some(conn) => conn,
206			None => {
207				::tracing::warn!(
208					"DatabaseConnection not registered in DI context. \
209					 CurrentUser will be anonymous. \
210					 Hint: Register DatabaseConnection as a singleton in InjectionContext."
211				);
212				return Ok(Self::anonymous());
213			}
214		};
215
216		// 6. Load user from database using Model::objects() (Django-style ORM)
217		#[cfg(feature = "params")]
218		let user: U = match U::objects().get(model_pk).first_with_db(&db).await {
219			Ok(Some(u)) => u,
220			Ok(None) | Err(_) => return Ok(Self::anonymous()),
221		};
222
223		// 7. Parse UUID for CurrentUser (Uuid is commonly used for user IDs)
224		#[cfg(feature = "params")]
225		let user_id = match Uuid::parse_str(auth_state.user_id()) {
226			Ok(id) => id,
227			Err(e) => {
228				::tracing::warn!(
229					user_id = %auth_state.user_id(),
230					error = ?e,
231					"CurrentUser: failed to parse user_id as UUID"
232				);
233				return Ok(Self::anonymous());
234			}
235		};
236
237		#[cfg(feature = "params")]
238		Ok(Self::authenticated(user, user_id))
239	}
240}
241
#[allow(deprecated)] // Tests for deprecated CurrentUser
#[cfg(test)]
mod tests {
	use super::*;
	use crate::PasswordHasher;
	use chrono::{DateTime, Utc};
	use serde::{Deserialize, Serialize};

	/// Mock password hasher for testing
	#[derive(Default)]
	struct MockHasher;

	impl PasswordHasher for MockHasher {
		fn hash(&self, password: &str) -> Result<String, reinhardt_core::exception::Error> {
			Ok(format!("hashed:{}", password))
		}

		fn verify(
			&self,
			password: &str,
			hash: &str,
		) -> Result<bool, reinhardt_core::exception::Error> {
			Ok(hash == format!("hashed:{}", password))
		}
	}

	// Test user implementation for unit tests
	#[derive(Clone, Serialize, Deserialize)]
	struct TestUser {
		id: Uuid,
		username: String,
		is_active: bool,
	}

	impl BaseUser for TestUser {
		type PrimaryKey = Uuid;
		type Hasher = MockHasher;

		fn get_username_field() -> &'static str {
			"username"
		}

		fn get_username(&self) -> &str {
			&self.username
		}

		fn password_hash(&self) -> Option<&str> {
			None
		}

		fn set_password_hash(&mut self, _hash: String) {}

		fn last_login(&self) -> Option<DateTime<Utc>> {
			None
		}

		fn set_last_login(&mut self, _time: DateTime<Utc>) {}

		fn is_active(&self) -> bool {
			self.is_active
		}
	}

	/// Builds an active `TestUser` named "testuser" with the given id.
	fn make_user(id: Uuid) -> TestUser {
		TestUser {
			id,
			username: "testuser".to_string(),
			is_active: true,
		}
	}

	#[test]
	fn test_authenticated_user() {
		let user_id = Uuid::now_v7();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);

		assert!(current_user.is_authenticated());
		assert_eq!(current_user.id().unwrap(), user_id);
		assert_eq!(current_user.user().unwrap().get_username(), "testuser");
	}

	#[test]
	fn test_anonymous_user() {
		let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();

		assert!(!current_user.is_authenticated());
		assert!(current_user.id().is_err());
		assert!(current_user.user().is_err());
	}

	#[test]
	fn test_into_user_authenticated() {
		let user_id = Uuid::now_v7();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);
		let extracted = current_user.into_user().unwrap();

		assert_eq!(extracted.get_username(), "testuser");
	}

	#[test]
	fn test_into_user_anonymous() {
		let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();
		let result = current_user.into_user();

		assert!(result.is_err());
	}

	#[test]
	fn test_as_any_authenticated_downcasts_to_user_type() {
		let user_id = Uuid::now_v7();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);

		let any = current_user
			.as_any()
			.expect("authenticated CurrentUser must expose the user via as_any");
		let user = any
			.downcast_ref::<TestUser>()
			.expect("as_any must downcast back to the concrete user type");

		assert_eq!(user.id, user_id);
		assert_eq!(user.get_username(), "testuser");
	}

	#[test]
	fn test_as_any_anonymous_is_none() {
		let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();

		assert!(current_user.as_any().is_none());
	}

	#[test]
	fn test_clone_preserves_state() {
		let user_id = Uuid::now_v7();
		let original = CurrentUser::authenticated(make_user(user_id), user_id);
		let cloned = original.clone();

		assert!(cloned.is_authenticated());
		assert_eq!(cloned.id().unwrap(), user_id);
		assert_eq!(cloned.user().unwrap().get_username(), "testuser");
		// The original stays fully usable after cloning.
		assert_eq!(original.id().unwrap(), user_id);

		let anonymous: CurrentUser<TestUser> = CurrentUser::anonymous();
		let anonymous_clone = anonymous.clone();
		assert!(!anonymous_clone.is_authenticated());
		assert!(anonymous_clone.id().is_err());
	}
}