Skip to main content

reinhardt_auth/
current_user.rs

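//! `CurrentUser` injectable for dependency injection.
//!
//! Provides access to the authenticated user in endpoint handlers.
//!
//! This module integrates with the authentication middleware to provide
//! type-safe access to the currently authenticated user via dependency injection.
//!
//! # How it works
//!
//! 1. The authentication middleware (e.g., `AuthenticationMiddleware`) validates
//!    the user and stores an `AuthState` in the request extensions.
//! 2. When a handler requests `CurrentUser<U>`, the injectable implementation:
//!    - Extracts `AuthState` from the request extensions and checks that it is authenticated
//!    - Parses the user_id to the model's primary key type
//!    - Loads the user from the database using `Model::objects().get(pk)`, with the
//!      `DatabaseConnection` registered in the DI context
//! 3. Returns `CurrentUser::authenticated(user, user_id)`. If any step fails, it returns
//!    `CurrentUser::anonymous()` instead of an error. This covers a missing request or
//!    `AuthState`, an unparsable user_id, a missing `DatabaseConnection` and a user that
//!    cannot be loaded. It also happens when the `params` feature is disabled.
//!
//! `CurrentUser` stores the ID as a `Uuid`, so a user_id that is not a valid UUID
//! also yields an anonymous user.
//!
//! `CurrentUser<U>` is deprecated since `0.1.0-rc.12`; use `AuthUser<U>` instead.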
17
18use crate::AuthenticationError;
19use crate::BaseUser;
20use async_trait::async_trait;
21use reinhardt_db::orm::{DatabaseConnection, Model};
22use reinhardt_di::{DiResult, Injectable, InjectionContext};
23use reinhardt_http::AuthState;
24use std::sync::Arc;
25use uuid::Uuid;
26
27/// Wrapper type representing the currently authenticated user for DI.
28///
29/// This type provides access to the authenticated user within endpoint handlers
30/// through dependency injection. It wraps an optional user instance and user ID,
31/// allowing handlers to check authentication status and access user data.
32///
33/// # Example
34///
35/// ```rust,ignore
36/// use reinhardt_auth::CurrentUser;
37/// use reinhardt_auth::DefaultUser;
38/// use reinhardt_http::Response;
39///
40/// async fn my_handler(
41///     current_user: CurrentUser<DefaultUser>,
42/// ) -> Result<Response, Box<dyn std::error::Error>> {
43///     if current_user.is_authenticated() {
44///         let user = current_user.user()?;
45///         let user_id = current_user.id()?;
46///         println!("Authenticated user: {} (ID: {})", user.get_username(), user_id);
47///     }
48///     Ok(Response::ok())
49/// }
50/// ```
51#[deprecated(
52	since = "0.1.0-rc.12",
53	note = "Use `AuthUser<U>` instead. `CurrentUser<U>` will become a type alias for `AuthUser<U>` in 0.2.0."
54)]
55pub struct CurrentUser<U: BaseUser + Clone> {
56	user: Option<U>,
57	user_id: Option<Uuid>,
58}
59
60#[allow(deprecated)] // Implementing Clone for deprecated CurrentUser
61impl<U: BaseUser + Clone> Clone for CurrentUser<U> {
62	fn clone(&self) -> Self {
63		Self {
64			user: self.user.clone(),
65			user_id: self.user_id,
66		}
67	}
68}
69
70#[allow(deprecated)] // Methods for deprecated CurrentUser
71impl<U: BaseUser + Clone> CurrentUser<U> {
72	/// Creates a new authenticated CurrentUser.
73	///
74	/// # Arguments
75	///
76	/// * `user` - The authenticated user instance
77	/// * `user_id` - The user's unique identifier
78	pub fn authenticated(user: U, user_id: Uuid) -> Self {
79		Self {
80			user: Some(user),
81			user_id: Some(user_id),
82		}
83	}
84
85	/// Creates an anonymous (unauthenticated) CurrentUser.
86	pub fn anonymous() -> Self {
87		Self {
88			user: None,
89			user_id: None,
90		}
91	}
92
93	/// Returns whether the current user is authenticated.
94	pub fn is_authenticated(&self) -> bool {
95		self.user.is_some()
96	}
97
98	/// Returns a reference to the user if authenticated.
99	///
100	/// # Errors
101	///
102	/// Returns `AuthenticationError::NotAuthenticated` if the user is not authenticated.
103	pub fn user(&self) -> Result<&U, AuthenticationError> {
104		self.user
105			.as_ref()
106			.ok_or(AuthenticationError::NotAuthenticated)
107	}
108
109	/// Returns the user ID if authenticated.
110	///
111	/// # Errors
112	///
113	/// Returns `AuthenticationError::NotAuthenticated` if the user is not authenticated.
114	pub fn id(&self) -> Result<Uuid, AuthenticationError> {
115		self.user_id.ok_or(AuthenticationError::NotAuthenticated)
116	}
117
118	/// Consumes this wrapper and returns the user if authenticated.
119	///
120	/// # Errors
121	///
122	/// Returns `AuthenticationError::NotAuthenticated` if the user is not authenticated.
123	pub fn into_user(self) -> Result<U, AuthenticationError> {
124		self.user.ok_or(AuthenticationError::NotAuthenticated)
125	}
126
127	/// Returns the user as a trait object for permission checking.
128	///
129	/// This method is used to pass the user to `ModelAdmin` permission methods
130	/// that accept `&(dyn Any + Send + Sync)`.
131	///
132	/// # Returns
133	///
134	/// Returns `Some` with a reference to the user as a trait object if authenticated,
135	/// or `None` if the user is anonymous.
136	pub fn as_any(&self) -> Option<&(dyn std::any::Any + Send + Sync)>
137	where
138		U: 'static,
139	{
140		self.user
141			.as_ref()
142			.map(|u| u as &(dyn std::any::Any + Send + Sync))
143	}
144}
145
146#[allow(deprecated)] // Injectable impl for deprecated CurrentUser
147#[async_trait]
148impl<U> Injectable for CurrentUser<U>
149where
150	U: BaseUser + Model + Clone + Send + Sync + 'static,
151	// Ensure BaseUser::PrimaryKey and Model::PrimaryKey are the same type
152	<U as BaseUser>::PrimaryKey: std::str::FromStr + ToString + Send + Sync,
153	<<U as BaseUser>::PrimaryKey as std::str::FromStr>::Err: std::fmt::Debug,
154	<U as Model>::PrimaryKey: From<<U as BaseUser>::PrimaryKey>,
155{
156	async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
157		// 1. Get HTTP request from context
158		#[cfg(feature = "params")]
159		let request = match ctx.get_http_request() {
160			Some(req) => req,
161			None => return Ok(Self::anonymous()),
162		};
163
164		#[cfg(not(feature = "params"))]
165		return Ok(Self::anonymous());
166
167		// 2. Get AuthState from request extensions
168		#[cfg(feature = "params")]
169		let auth_state: AuthState = match request.extensions.get() {
170			Some(state) => state,
171			None => return Ok(Self::anonymous()),
172		};
173
174		// 3. Check if authenticated
175		#[cfg(feature = "params")]
176		if !auth_state.is_authenticated() {
177			return Ok(Self::anonymous());
178		}
179
180		// 4. Parse user_id to PrimaryKey type
181		#[cfg(feature = "params")]
182		let base_pk: <U as BaseUser>::PrimaryKey = match auth_state.user_id().parse() {
183			Ok(pk) => pk,
184			Err(_) => return Ok(Self::anonymous()),
185		};
186
187		// Convert BaseUser::PrimaryKey to Model::PrimaryKey
188		#[cfg(feature = "params")]
189		let model_pk: <U as Model>::PrimaryKey = base_pk.into();
190
191		// 5. Get DatabaseConnection from DI context
192		// Resolve DatabaseConnection from DI (singleton-first, request-scope fallback)
193		// Uses get_singleton/get_request directly instead of ctx.resolve() because
194		// DatabaseConnection is pre-seeded into the singleton scope at server startup,
195		// not registered in the global DependencyRegistry.
196		#[cfg(feature = "params")]
197		let db: Arc<DatabaseConnection> = match ctx
198			.get_singleton::<DatabaseConnection>()
199			.or_else(|| ctx.get_request::<DatabaseConnection>())
200		{
201			Some(conn) => conn,
202			None => {
203				::tracing::warn!(
204					"DatabaseConnection not registered in DI context. \
205					 CurrentUser will be anonymous. \
206					 Hint: Register DatabaseConnection as a singleton in InjectionContext."
207				);
208				return Ok(Self::anonymous());
209			}
210		};
211
212		// 6. Load user from database using Model::objects() (Django-style ORM)
213		#[cfg(feature = "params")]
214		let user: U = match U::objects().get(model_pk).first_with_db(&db).await {
215			Ok(Some(u)) => u,
216			Ok(None) | Err(_) => return Ok(Self::anonymous()),
217		};
218
219		// 7. Parse UUID for CurrentUser (Uuid is commonly used for user IDs)
220		#[cfg(feature = "params")]
221		let user_id = match Uuid::parse_str(auth_state.user_id()) {
222			Ok(id) => id,
223			Err(e) => {
224				::tracing::warn!(
225					user_id = %auth_state.user_id(),
226					error = ?e,
227					"CurrentUser: failed to parse user_id as UUID"
228				);
229				return Ok(Self::anonymous());
230			}
231		};
232
233		#[cfg(feature = "params")]
234		Ok(Self::authenticated(user, user_id))
235	}
236}
237
#[allow(deprecated)] // Tests for deprecated CurrentUser
#[cfg(test)]
mod tests {
	use super::*;
	use crate::PasswordHasher;
	use chrono::{DateTime, Utc};
	use serde::{Deserialize, Serialize};

	/// Mock password hasher for testing
	#[derive(Default)]
	struct MockHasher;

	impl PasswordHasher for MockHasher {
		fn hash(&self, password: &str) -> Result<String, reinhardt_core::exception::Error> {
			Ok(format!("hashed:{}", password))
		}

		fn verify(
			&self,
			password: &str,
			hash: &str,
		) -> Result<bool, reinhardt_core::exception::Error> {
			Ok(hash == format!("hashed:{}", password))
		}
	}

	// Test user implementation for unit tests
	#[derive(Clone, Serialize, Deserialize)]
	struct TestUser {
		id: Uuid,
		username: String,
		is_active: bool,
	}

	impl BaseUser for TestUser {
		type PrimaryKey = Uuid;
		type Hasher = MockHasher;

		fn get_username_field() -> &'static str {
			"username"
		}

		fn get_username(&self) -> &str {
			&self.username
		}

		fn password_hash(&self) -> Option<&str> {
			None
		}

		fn set_password_hash(&mut self, _hash: String) {}

		fn last_login(&self) -> Option<DateTime<Utc>> {
			None
		}

		fn set_last_login(&mut self, _time: DateTime<Utc>) {}

		fn is_active(&self) -> bool {
			self.is_active
		}
	}

	/// Builds an active `TestUser` named "testuser" with the given ID.
	fn make_user(id: Uuid) -> TestUser {
		TestUser {
			id,
			username: "testuser".to_string(),
			is_active: true,
		}
	}

	#[test]
	fn test_authenticated_user() {
		let user_id = Uuid::new_v4();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);

		assert!(current_user.is_authenticated());
		assert_eq!(current_user.id().unwrap(), user_id);
		assert_eq!(current_user.user().unwrap().get_username(), "testuser");
	}

	#[test]
	fn test_anonymous_user() {
		let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();

		assert!(!current_user.is_authenticated());
		assert!(current_user.id().is_err());
		assert!(current_user.user().is_err());
	}

	#[test]
	fn test_into_user_authenticated() {
		let user_id = Uuid::new_v4();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);
		let extracted = current_user.into_user().unwrap();

		assert_eq!(extracted.get_username(), "testuser");
	}

	#[test]
	fn test_into_user_anonymous() {
		let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();
		let result = current_user.into_user();

		assert!(result.is_err());
	}

	#[test]
	fn test_as_any_authenticated_downcasts_to_user() {
		let user_id = Uuid::new_v4();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);

		let erased = current_user
			.as_any()
			.expect("authenticated user should be exposed as Any");
		let user = erased
			.downcast_ref::<TestUser>()
			.expect("erased user should downcast back to TestUser");

		assert_eq!(user.id, user_id);
		assert_eq!(user.get_username(), "testuser");
	}

	#[test]
	fn test_as_any_anonymous() {
		let current_user: CurrentUser<TestUser> = CurrentUser::anonymous();

		assert!(current_user.as_any().is_none());
	}

	#[test]
	fn test_clone_preserves_state() {
		let user_id = Uuid::new_v4();
		let current_user = CurrentUser::authenticated(make_user(user_id), user_id);
		let cloned = current_user.clone();

		assert!(cloned.is_authenticated());
		assert_eq!(cloned.id().unwrap(), user_id);
		assert_eq!(cloned.user().unwrap().get_username(), "testuser");
		// The original is still usable after cloning.
		assert_eq!(current_user.id().unwrap(), user_id);

		let anonymous: CurrentUser<TestUser> = CurrentUser::anonymous();
		let anonymous_clone = anonymous.clone();
		assert!(!anonymous_clone.is_authenticated());
		assert!(anonymous_clone.id().is_err());
	}
}