Skip to main content

reinhardt_middleware/
auth.rs

1// The `User` trait is deprecated in favour of the new `#[model]`-based user macro system.
2// Downstream crates still reference it during the transition period.
3#![allow(deprecated)]
4
5#[cfg(feature = "sessions")]
6use async_trait::async_trait;
7#[cfg(feature = "sessions")]
8use std::sync::Arc;
9
10#[cfg(feature = "sessions")]
11use reinhardt_http::{
12	Handler, IsActive, IsAdmin, IsAuthenticated, Middleware, Request, Response, Result,
13};
14
15#[cfg(feature = "sessions")]
16use reinhardt_auth::session::{SESSION_KEY_USER_ID, SessionStore};
17#[cfg(feature = "sessions")]
18use reinhardt_auth::{AnonymousUser, AuthenticationBackend, User};
19
/// Authentication middleware.
///
/// Extracts user information from the session and attaches it to request extensions.
///
/// This middleware plugs into Reinhardt's [`Middleware`] / `MiddlewareChain` pipeline to
/// provide Django-style authentication. For every request it:
/// - reads the session ID from the `sessionid` cookie (only values that parse as a UUID
///   are accepted),
/// - loads the session from the session store and reads the user ID stored under
///   [`SESSION_KEY_USER_ID`],
/// - resolves that user ID through the authentication backend's `get_user`,
/// - falls back to [`AnonymousUser`] when any of these steps yields no user,
/// - inserts the resulting state into the request extensions: the user ID as a `String`,
///   the [`IsAuthenticated`], [`IsAdmin`] and [`IsActive`] flags, and an [`AuthState`].
///
/// Any store implementing [`SessionStore`] and any backend implementing
/// [`AuthenticationBackend`] can be used.
///
/// # Examples
///
/// Basic usage with in-memory session store:
///
/// ```rust,no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use std::sync::Arc;
/// use reinhardt_middleware::AuthenticationMiddleware;
/// use reinhardt_auth::session::InMemorySessionStore;
/// use reinhardt_http::MiddlewareChain;
/// # use reinhardt_http::{Handler, {Request, Response, Result}};
/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User, SimpleUser};
/// # use async_trait::async_trait;
/// # use uuid::Uuid;
/// #
/// # struct MyHandler;
/// # #[async_trait]
/// # impl Handler for MyHandler {
/// #     async fn handle(&self, _request: Request) -> Result<Response> {
/// #         Ok(Response::ok())
/// #     }
/// # }
/// #
/// # // Simple test authentication backend
/// # struct TestAuthBackend;
/// # #[async_trait]
/// # impl AuthenticationBackend for TestAuthBackend {
/// #     async fn authenticate(&self, _request: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
/// #         Ok(Some(Box::new(SimpleUser {
/// #             id: Uuid::now_v7(),
/// #             username: "testuser".to_string(),
/// #             email: "test@example.com".to_string(),
/// #             is_active: true,
/// #             is_admin: false,
/// #             is_staff: false,
/// #             is_superuser: false,
/// #         })))
/// #     }
/// #     async fn get_user(&self, _user_id: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
/// #         Ok(None)
/// #     }
/// # }
///
/// // Create session store and authentication backend
/// let session_store = Arc::new(InMemorySessionStore::new());
/// let auth_backend = Arc::new(TestAuthBackend);
///
/// // Create authentication middleware
/// let auth_middleware = AuthenticationMiddleware::new(session_store, auth_backend);
///
/// // Wrap your handler with the middleware using MiddlewareChain
/// # let handler = Arc::new(MyHandler);
/// let app = MiddlewareChain::new(handler)
///     .with_middleware(Arc::new(auth_middleware));
/// # Ok(())
/// # }
/// ```
///
/// Accessing authentication state in handlers. The flags are stored as the
/// `IsAuthenticated` / `IsAdmin` newtypes, not as bare `bool`s, so they must be
/// looked up by those types:
///
/// ```
/// # use reinhardt_http::{Handler, IsAdmin, IsAuthenticated, Request, Response, Result};
/// # use async_trait::async_trait;
/// struct ProtectedHandler;
///
/// #[async_trait]
/// impl Handler for ProtectedHandler {
///     async fn handle(&self, request: Request) -> Result<Response> {
///         // Extract authentication state from request extensions
///         let is_authenticated = request
///             .extensions
///             .get::<IsAuthenticated>()
///             .map(|flag| flag.0)
///             .unwrap_or(false);
///         let is_admin = request
///             .extensions
///             .get::<IsAdmin>()
///             .map(|flag| flag.0)
///             .unwrap_or(false);
///         let user_id: Option<String> = request.extensions.get();
///
///         if !is_authenticated {
///             return Ok(Response::new(hyper::StatusCode::UNAUTHORIZED));
///         }
///
///         Ok(Response::ok().with_body(format!(
///             "Welcome user: {:?} (admin: {})",
///             user_id, is_admin
///         )))
///     }
/// }
/// ```
#[cfg(feature = "sessions")]
pub struct AuthenticationMiddleware<S: SessionStore, A: AuthenticationBackend> {
	/// Store used to load sessions by their ID.
	session_store: Arc<S>,
	/// Backend used to resolve the user ID stored in a session into a user.
	auth_backend: Arc<A>,
}
117
#[cfg(feature = "sessions")]
impl<S: SessionStore, A: AuthenticationBackend> AuthenticationMiddleware<S, A> {
	/// Create a new authentication middleware
	///
	/// # Arguments
	///
	/// * `session_store` - Session storage backend
	/// * `auth_backend` - Authentication backend for user lookup
	///
	/// # Examples
	///
	/// ```no_run
	/// use std::sync::Arc;
	/// use reinhardt_middleware::AuthenticationMiddleware;
	/// use reinhardt_auth::session::InMemorySessionStore;
	/// # use reinhardt_http::Request;
	/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User, SimpleUser};
	/// # use uuid::Uuid;
	/// #
	/// # // Simple test authentication backend
	/// # struct TestAuthBackend;
	/// # #[async_trait::async_trait]
	/// # impl AuthenticationBackend for TestAuthBackend {
	/// #     async fn authenticate(&self, _request: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
	/// #         Ok(Some(Box::new(SimpleUser {
	/// #             id: Uuid::now_v7(),
	/// #             username: "testuser".to_string(),
	/// #             email: "test@example.com".to_string(),
	/// #             is_active: true,
	/// #             is_admin: false,
	/// #             is_staff: false,
	/// #             is_superuser: false,
	/// #         })))
	/// #     }
	/// #     async fn get_user(&self, _user_id: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
	/// #         Ok(None)
	/// #     }
	/// # }
	///
	/// let session_store = Arc::new(InMemorySessionStore::new());
	/// let auth_backend = Arc::new(TestAuthBackend);
	/// let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
	/// ```
	pub fn new(session_store: Arc<S>, auth_backend: Arc<A>) -> Self {
		Self {
			session_store,
			auth_backend,
		}
	}

	/// Extract the session ID from the `sessionid` cookie.
	///
	/// Every `Cookie` header field is scanned and split into `name=value` pairs on `;`.
	/// The value of a pair is everything after its first `=`. The first `sessionid` pair
	/// found is returned only if it passes [`Self::is_valid_session_id`]; otherwise `None`
	/// is returned.
	fn extract_session_id(&self, request: &Request) -> Option<String> {
		const SESSION_COOKIE_NAME: &str = "sessionid";

		// HTTP/2 clients may split cookies across several `Cookie` header fields
		// (RFC 9113 §8.2.3), so look at all of them rather than only the first.
		let session_id = request
			.headers
			.get_all("cookie")
			.iter()
			.filter_map(|header| header.to_str().ok())
			.flat_map(|header| header.split(';'))
			.find_map(|pair| {
				// A cookie value may itself contain `=`; only the first one separates name and value.
				let (name, value) = pair.trim().split_once('=')?;
				(name == SESSION_COOKIE_NAME).then_some(value)
			})?;

		Self::is_valid_session_id(session_id).then(|| session_id.to_owned())
	}

	/// Validate that a session ID is non-empty and well-formed.
	///
	/// A session ID is accepted only if it parses with [`uuid::Uuid::parse_str`]. That
	/// parser accepts the hyphenated form (`xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`) as well
	/// as the simple (32 hex digits), braced (`{...}`) and URN (`urn:uuid:...`) forms.
	/// This keeps arbitrary strings from being used as session-store keys.
	fn is_valid_session_id(id: &str) -> bool {
		// Cheap guard before parsing; every textual UUID form is far shorter than 128 bytes.
		if id.is_empty() || id.len() > 128 {
			return false;
		}
		uuid::Uuid::parse_str(id).is_ok()
	}

	/// Resolve the user attached to a session.
	///
	/// Loads the session and reads the value stored under `SESSION_KEY_USER_ID`, which
	/// must be a string (`as_str()` returns `Some`). It then looks the user up with the
	/// backend's `get_user`.
	///
	/// Returns `None` in any of these cases:
	/// - the session does not exist,
	/// - it has no string user ID,
	/// - the backend finds no user,
	/// - the backend returns an error. Errors are deliberately treated as "no user", so
	///   the request proceeds as anonymous.
	async fn get_user_from_session(&self, session_id: &str) -> Option<Box<dyn User>> {
		let session = self.session_store.load(session_id).await?;
		let user_id_value = session.get(SESSION_KEY_USER_ID)?;
		let user_id = user_id_value.as_str()?;
		self.auth_backend.get_user(user_id).await.ok().flatten()
	}
}
215
#[cfg(feature = "sessions")]
#[async_trait]
impl<S: SessionStore + 'static, A: AuthenticationBackend + 'static> Middleware
	for AuthenticationMiddleware<S, A>
{
	/// Resolve the requesting user and record their authentication state.
	///
	/// The user is taken from the session referenced by the `sessionid` cookie. When no
	/// valid cookie is present, or the session does not resolve to a user, an
	/// [`AnonymousUser`] is used instead.
	///
	/// Before the request is forwarded to `next`, `request.extensions` receives:
	/// - the user ID,
	/// - the `IsAuthenticated`, `IsAdmin` and `IsActive` flags (kept for backward
	///   compatibility),
	/// - an [`AuthState`]: authenticated when the user reports `is_authenticated()`,
	///   anonymous otherwise.
	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
		let session_user = match self.extract_session_id(&request) {
			Some(session_id) => self.get_user_from_session(&session_id).await,
			None => None,
		};
		let user: Box<dyn User> = session_user.unwrap_or_else(|| Box::new(AnonymousUser));

		let is_authenticated = user.is_authenticated();
		let is_admin = user.is_admin();
		let is_active = user.is_active();
		let user_id = user.id();

		let extensions = &request.extensions;

		// Legacy per-value entries for handlers that read each piece individually.
		extensions.insert(user_id.clone());
		extensions.insert(IsAuthenticated(is_authenticated));
		extensions.insert(IsAdmin(is_admin));
		extensions.insert(IsActive(is_active));

		// Consolidated state consumed by `CurrentUser` and newer code.
		let state = if is_authenticated {
			AuthState::authenticated(user_id, is_admin, is_active)
		} else {
			AuthState::anonymous()
		};
		extensions.insert(state);

		next.handle(request).await
	}
}
252
253// Re-export AuthState from reinhardt-http for backward compatibility.
254// AuthState is the canonical type for storing authentication state in extensions.
255pub use reinhardt_http::AuthState;
256
#[cfg(all(test, feature = "sessions"))]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};
	use reinhardt_auth::AuthenticationError;
	use reinhardt_auth::SimpleUser;
	use reinhardt_auth::session::{InMemorySessionStore, Session};
	use uuid::Uuid;

	/// Handler that echoes the user ID and `IsAuthenticated` flag found in the extensions.
	struct TestHandler;

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			let user_id: Option<String> = request.extensions.get();
			let is_authenticated = request
				.extensions
				.get::<IsAuthenticated>()
				.map(|v| v.0)
				.unwrap_or(false);

			Ok(Response::ok().with_json(&serde_json::json!({
				"user_id": user_id.unwrap_or_default(),
				"is_authenticated": is_authenticated
			}))?)
		}
	}

	/// Backend that returns the configured user (if any) for every lookup, regardless of ID.
	struct TestAuthBackend {
		user: Option<SimpleUser>,
	}

	#[async_trait::async_trait]
	impl AuthenticationBackend for TestAuthBackend {
		async fn authenticate(
			&self,
			_request: &Request,
		) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
			Ok(self
				.user
				.as_ref()
				.map(|u| Box::new(u.clone()) as Box<dyn User>))
		}

		async fn get_user(
			&self,
			_user_id: &str,
		) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
			Ok(self
				.user
				.as_ref()
				.map(|u| Box::new(u.clone()) as Box<dyn User>))
		}
	}

	/// An active, non-admin user for tests.
	fn test_user() -> SimpleUser {
		SimpleUser {
			id: Uuid::now_v7(),
			username: "testuser".to_string(),
			email: "test@example.com".to_string(),
			is_active: true,
			is_admin: false,
			is_staff: false,
			is_superuser: false,
		}
	}

	/// Build a `GET /test` request, optionally carrying a `Cookie` header.
	fn build_request(cookie: Option<&str>) -> Request {
		let mut headers = HeaderMap::new();
		if let Some(cookie) = cookie {
			headers.insert("cookie", cookie.parse().unwrap());
		}

		Request::builder()
			.method(Method::GET)
			.uri("/test")
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap()
	}

	/// Response body decoded as UTF-8.
	fn response_body(response: &Response) -> String {
		String::from_utf8(response.body.to_vec()).unwrap()
	}

	#[tokio::test]
	async fn test_auth_middleware_with_valid_session() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(test_user()),
		});

		let session_id = session_store.create_session_id();
		let mut session = Session::new();
		session.set(SESSION_KEY_USER_ID, serde_json::json!("user123"));
		session_store.save(&session_id, &session).await;

		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
		let request = build_request(Some(&format!("sessionid={}", session_id)));

		let response = middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap();
		assert_eq!(response.status, reinhardt_http::Response::ok().status);
		// TestHandler always answers 200, so the body is what proves the user was resolved.
		assert!(response_body(&response).contains("\"is_authenticated\":true"));
	}

	#[tokio::test]
	async fn test_auth_middleware_without_session() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let auth_backend = Arc::new(TestAuthBackend { user: None });

		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		let response = middleware
			.process(build_request(None), Arc::new(TestHandler))
			.await
			.unwrap();
		assert_eq!(response.status, reinhardt_http::Response::ok().status);
		assert!(response_body(&response).contains("\"is_authenticated\":false"));
	}

	#[tokio::test]
	async fn test_auth_middleware_rejects_malformed_session_id() {
		let session_store = Arc::new(InMemorySessionStore::new());
		// The backend would hand out a user for any ID, so only session-ID validation
		// can keep this request anonymous.
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(test_user()),
		});
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		let response = middleware
			.process(build_request(Some("sessionid=not-a-uuid")), Arc::new(TestHandler))
			.await
			.unwrap();
		assert!(response_body(&response).contains("\"is_authenticated\":false"));
	}

	#[tokio::test]
	async fn test_auth_middleware_with_unknown_session() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(test_user()),
		});
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		// Well-formed ID that was never saved to the store.
		let cookie = format!("sessionid={}", Uuid::now_v7());
		let response = middleware
			.process(build_request(Some(&cookie)), Arc::new(TestHandler))
			.await
			.unwrap();
		assert!(response_body(&response).contains("\"is_authenticated\":false"));
	}

	#[tokio::test]
	async fn test_extract_session_id_among_other_cookies() {
		let middleware = AuthenticationMiddleware::new(
			Arc::new(InMemorySessionStore::new()),
			Arc::new(TestAuthBackend { user: None }),
		);

		let session_id = Uuid::now_v7().to_string();
		let cookie = format!("theme=dark; sessionid={}; lang=en", session_id);
		assert_eq!(
			middleware.extract_session_id(&build_request(Some(&cookie))),
			Some(session_id)
		);
		assert_eq!(
			middleware.extract_session_id(&build_request(Some("theme=dark"))),
			None
		);
		assert_eq!(middleware.extract_session_id(&build_request(None)), None);
	}

	#[test]
	fn test_is_valid_session_id() {
		type Mw = AuthenticationMiddleware<InMemorySessionStore, TestAuthBackend>;

		assert!(Mw::is_valid_session_id(&Uuid::now_v7().to_string()));
		assert!(!Mw::is_valid_session_id(""));
		assert!(!Mw::is_valid_session_id("not-a-uuid"));
		assert!(!Mw::is_valid_session_id(&"a".repeat(200)));
	}

	#[test]
	fn test_auth_state_from_extensions() {
		let extensions = reinhardt_http::Extensions::new();
		extensions.insert("user123".to_string());
		extensions.insert(IsAuthenticated(true));

		let auth_state = AuthState::from_extensions(&extensions);
		assert!(auth_state.is_some());
		assert!(!auth_state.unwrap().is_anonymous());
	}

	#[test]
	fn test_auth_state_is_anonymous() {
		let anon_state = AuthState::anonymous();

		assert!(anon_state.is_anonymous());

		let auth_state = AuthState::authenticated("user123", false, true);

		assert!(!auth_state.is_anonymous());
	}
}