Skip to main content

reinhardt_middleware/
auth.rs

1// The `User` trait is deprecated in favour of the new `#[model]`-based user macro system.
2// Downstream crates still reference it during the transition period.
3#![allow(deprecated)]
4
5#[cfg(feature = "sessions")]
6use async_trait::async_trait;
7#[cfg(feature = "sessions")]
8use std::sync::Arc;
9
10#[cfg(feature = "sessions")]
11use reinhardt_http::{Handler, Middleware, Request, Response, Result};
12
13#[cfg(feature = "sessions")]
14use reinhardt_auth::session::{SESSION_KEY_USER_ID, SessionStore};
15#[cfg(feature = "sessions")]
16use reinhardt_auth::{AnonymousUser, AuthenticationBackend, User};
17
/// Authentication middleware.
///
/// Extracts user information from the session and attaches it to the request
/// extensions.
///
/// This middleware integrates with tower/hyper to provide Django-style authentication
/// for Reinhardt applications. It automatically:
/// - Extracts the session ID from the `sessionid` cookie
/// - Loads the user ID from the session store and resolves the user through the
///   authentication backend
/// - Attaches the user's authentication state to the request extensions, falling
///   back to an anonymous user when no session or user can be resolved
/// - Supports any authentication backend implementing `AuthenticationBackend`
///
/// # Examples
///
/// Basic usage with in-memory session store:
///
/// ```rust,no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// use std::sync::Arc;
/// use reinhardt_middleware::AuthenticationMiddleware;
/// use reinhardt_auth::session::InMemorySessionStore;
/// use reinhardt_http::MiddlewareChain;
/// # use reinhardt_http::{Handler, {Request, Response, Result}};
/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User, SimpleUser};
/// # use async_trait::async_trait;
/// # use uuid::Uuid;
/// #
/// # struct MyHandler;
/// # #[async_trait]
/// # impl Handler for MyHandler {
/// #     async fn handle(&self, _request: Request) -> Result<Response> {
/// #         Ok(Response::ok())
/// #     }
/// # }
/// #
/// # // Simple test authentication backend
/// # struct TestAuthBackend;
/// # #[async_trait]
/// # impl AuthenticationBackend for TestAuthBackend {
/// #     async fn authenticate(&self, _request: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
/// #         Ok(Some(Box::new(SimpleUser {
/// #             id: Uuid::new_v4(),
/// #             username: "testuser".to_string(),
/// #             email: "test@example.com".to_string(),
/// #             is_active: true,
/// #             is_admin: false,
/// #             is_staff: false,
/// #             is_superuser: false,
/// #         })))
/// #     }
/// #     async fn get_user(&self, _user_id: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
/// #         Ok(None)
/// #     }
/// # }
///
/// // Create session store and authentication backend
/// let session_store = Arc::new(InMemorySessionStore::new());
/// let auth_backend = Arc::new(TestAuthBackend);
///
/// // Create authentication middleware
/// let auth_middleware = AuthenticationMiddleware::new(session_store, auth_backend);
///
/// // Wrap your handler with the middleware using MiddlewareChain
/// # let handler = Arc::new(MyHandler);
/// let app = MiddlewareChain::new(handler)
///     .with_middleware(Arc::new(auth_middleware));
/// # Ok(())
/// # }
/// ```
///
/// Accessing authentication state in handlers:
///
/// ```
/// # use reinhardt_http::{Handler, {Request, Response, Result}};
/// # use async_trait::async_trait;
/// struct ProtectedHandler;
///
/// #[async_trait]
/// impl Handler for ProtectedHandler {
///     async fn handle(&self, request: Request) -> Result<Response> {
///         // Extract authentication state from request extensions.
///         // Extensions are looked up by type, so separate `Option<bool>` lookups
///         // cannot tell several boolean flags apart; read the full state
///         // (including admin/active flags) through `AuthState` instead.
///         let is_authenticated: Option<bool> = request.extensions.get();
///         let user_id: Option<String> = request.extensions.get();
///
///         if !is_authenticated.unwrap_or(false) {
///             return Ok(Response::new(hyper::StatusCode::UNAUTHORIZED));
///         }
///
///         Ok(Response::ok().with_body(format!("Welcome user: {:?}", user_id)))
///     }
/// }
/// ```
#[cfg(feature = "sessions")]
pub struct AuthenticationMiddleware<S: SessionStore, A: AuthenticationBackend> {
	/// Store used to load sessions by their ID.
	session_store: Arc<S>,
	/// Backend used to turn a session's stored user ID into a user.
	auth_backend: Arc<A>,
}
115
#[cfg(feature = "sessions")]
impl<S: SessionStore, A: AuthenticationBackend> AuthenticationMiddleware<S, A> {
	/// Create a new authentication middleware
	///
	/// # Arguments
	///
	/// * `session_store` - Session storage backend
	/// * `auth_backend` - Authentication backend for user lookup
	///
	/// # Examples
	///
	/// ```no_run
	/// use std::sync::Arc;
	/// use reinhardt_middleware::AuthenticationMiddleware;
	/// use reinhardt_auth::session::InMemorySessionStore;
	/// # use reinhardt_http::Request;
	/// # use reinhardt_auth::{AuthenticationBackend, AuthenticationError, User, SimpleUser};
	/// # use uuid::Uuid;
	/// #
	/// # // Simple test authentication backend
	/// # struct TestAuthBackend;
	/// # #[async_trait::async_trait]
	/// # impl AuthenticationBackend for TestAuthBackend {
	/// #     async fn authenticate(&self, _request: &Request) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
	/// #         Ok(Some(Box::new(SimpleUser {
	/// #             id: Uuid::new_v4(),
	/// #             username: "testuser".to_string(),
	/// #             email: "test@example.com".to_string(),
	/// #             is_active: true,
	/// #             is_admin: false,
	/// #             is_staff: false,
	/// #             is_superuser: false,
	/// #         })))
	/// #     }
	/// #     async fn get_user(&self, _user_id: &str) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
	/// #         Ok(None)
	/// #     }
	/// # }
	///
	/// let session_store = Arc::new(InMemorySessionStore::new());
	/// let auth_backend = Arc::new(TestAuthBackend);
	/// let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
	/// ```
	pub fn new(session_store: Arc<S>, auth_backend: Arc<A>) -> Self {
		Self {
			session_store,
			auth_backend,
		}
	}

	/// Extract the session ID from the request's cookies.
	///
	/// Every `Cookie` header field is inspected, not only the first one: HTTP/2
	/// clients may split cookies across several header fields (RFC 9113 §8.2.3).
	/// The first `sessionid` cookie found is returned only if it passes
	/// [`Self::is_valid_session_id`]; otherwise `None` is returned.
	fn extract_session_id(&self, request: &Request) -> Option<String> {
		const SESSION_COOKIE_NAME: &str = "sessionid";
		request
			.headers
			.get_all("cookie")
			.iter()
			// Header values that are not valid visible ASCII cannot hold our cookie.
			.filter_map(|value| value.to_str().ok())
			.flat_map(|cookies| cookies.split(';'))
			.find_map(|pair| {
				// `split_once` keeps any further `=` inside the value instead of
				// silently truncating it.
				let (name, value) = pair.split_once('=')?;
				(name.trim() == SESSION_COOKIE_NAME).then(|| value.trim().to_string())
			})
			.filter(|id| Self::is_valid_session_id(id))
	}

	/// Validate that a session ID is non-empty and well-formed.
	///
	/// The ID must parse as a UUID via [`uuid::Uuid::parse_str`], which accepts the
	/// hyphenated form (`xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`) as well as the
	/// simple (32 hex digits), braced and URN forms. The length cap is a cheap
	/// guard that rejects oversized input before parsing. Together these prevent
	/// accepting arbitrary strings as session identifiers.
	fn is_valid_session_id(id: &str) -> bool {
		if id.is_empty() || id.len() > 128 {
			return false;
		}
		uuid::Uuid::parse_str(id).is_ok()
	}

	/// Resolve the user attached to a session.
	///
	/// Returns `None` in any of these cases:
	/// - the session does not exist in the store;
	/// - the session has no string value under `SESSION_KEY_USER_ID`;
	/// - the backend does not know the user;
	/// - the backend lookup returns an error.
	///
	/// Backend errors are treated the same as a missing user.
	async fn get_user_from_session(&self, session_id: &str) -> Option<Box<dyn User>> {
		let session = self.session_store.load(session_id).await?;
		let user_id_value = session.get(SESSION_KEY_USER_ID)?;
		let user_id = user_id_value.as_str()?;
		self.auth_backend.get_user(user_id).await.ok().flatten()
	}
}
213
#[cfg(feature = "sessions")]
#[async_trait]
impl<S: SessionStore + 'static, A: AuthenticationBackend + 'static> Middleware
	for AuthenticationMiddleware<S, A>
{
	/// Resolve the current user and record it in the request extensions, then
	/// delegate to `next`.
	///
	/// A request without a valid session cookie, or whose session does not
	/// resolve to a user, is processed as an [`AnonymousUser`].
	async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
		let user: Box<dyn User> = match self.extract_session_id(&request) {
			Some(session_id) => self
				.get_user_from_session(&session_id)
				.await
				.unwrap_or_else(|| Box::new(AnonymousUser)),
			None => Box::new(AnonymousUser),
		};

		let is_authenticated = user.is_authenticated();
		let is_admin = user.is_admin();
		let is_active = user.is_active();
		let user_id = user.id();

		// Legacy extensions, kept for backward compatibility. Extensions are
		// looked up by type (`get::<bool>()`), so only one `bool` can be stored.
		// Legacy readers expect the authentication flag there. Also inserting
		// `is_admin` and `is_active` as bare `bool`s would overwrite it with
		// whichever was inserted last. Those flags are carried by `AuthState` below.
		request.extensions.insert(user_id.clone());
		request.extensions.insert(is_authenticated);

		// Full authentication state for `CurrentUser` and new code.
		let auth_state = if is_authenticated {
			AuthState::authenticated(user_id, is_admin, is_active)
		} else {
			AuthState::anonymous()
		};
		request.extensions.insert(auth_state);

		next.handle(request).await
	}
}
250
251// Re-export AuthState from reinhardt-http for backward compatibility.
252// AuthState is the canonical type for storing authentication state in extensions.
253pub use reinhardt_http::AuthState;
254
#[cfg(all(test, feature = "sessions"))]
mod tests {
	use super::*;
	use bytes::Bytes;
	use hyper::{HeaderMap, Method, Version};
	use reinhardt_auth::AuthenticationError;
	use reinhardt_auth::SimpleUser;
	use reinhardt_auth::session::{InMemorySessionStore, Session};
	use uuid::Uuid;

	/// Handler that echoes the authentication state found in the request
	/// extensions back as JSON.
	struct TestHandler;

	#[async_trait]
	impl Handler for TestHandler {
		async fn handle(&self, request: Request) -> Result<Response> {
			let user_id: Option<String> = request.extensions.get();
			let is_authenticated: Option<bool> = request.extensions.get();

			Ok(Response::ok().with_json(&serde_json::json!({
				"user_id": user_id.unwrap_or_default(),
				"is_authenticated": is_authenticated.unwrap_or(false)
			}))?)
		}
	}

	/// Backend that returns the configured user (if any) for every lookup.
	struct TestAuthBackend {
		user: Option<SimpleUser>,
	}

	#[async_trait::async_trait]
	impl AuthenticationBackend for TestAuthBackend {
		async fn authenticate(
			&self,
			_request: &Request,
		) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
			Ok(self
				.user
				.as_ref()
				.map(|u| Box::new(u.clone()) as Box<dyn User>))
		}

		async fn get_user(
			&self,
			_user_id: &str,
		) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
			Ok(self
				.user
				.as_ref()
				.map(|u| Box::new(u.clone()) as Box<dyn User>))
		}
	}

	/// An active, non-admin user fixture.
	fn active_user() -> SimpleUser {
		SimpleUser {
			id: Uuid::new_v4(),
			username: "testuser".to_string(),
			email: "test@example.com".to_string(),
			is_active: true,
			is_admin: false,
			is_staff: false,
			is_superuser: false,
		}
	}

	/// Headers containing a single `cookie` field with the given value.
	fn cookie_headers(cookie: &str) -> HeaderMap {
		let mut headers = HeaderMap::new();
		headers.insert("cookie", cookie.parse().unwrap());
		headers
	}

	/// Send a GET `/test` request with `headers` through `middleware` into
	/// `TestHandler`, returning the response.
	async fn run(
		middleware: &AuthenticationMiddleware<InMemorySessionStore, TestAuthBackend>,
		headers: HeaderMap,
	) -> Response {
		let request = Request::builder()
			.method(Method::GET)
			.uri("/test")
			.version(Version::HTTP_11)
			.headers(headers)
			.body(Bytes::new())
			.build()
			.unwrap();
		middleware
			.process(request, Arc::new(TestHandler))
			.await
			.unwrap()
	}

	/// Decode the response body as UTF-8.
	fn body_string(response: &Response) -> String {
		String::from_utf8(response.body.to_vec()).unwrap()
	}

	/// Store a session that references a user and return its (UUID) ID.
	async fn store_user_session(session_store: &InMemorySessionStore) -> String {
		// Use an explicit UUID so the test does not depend on the store's own
		// ID format matching the middleware's validation.
		let session_id = Uuid::new_v4().to_string();
		let mut session = Session::new();
		session.set(SESSION_KEY_USER_ID, serde_json::json!("user123"));
		session_store.save(&session_id, &session).await;
		session_id
	}

	#[tokio::test]
	async fn test_auth_middleware_with_valid_session() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let session_id = store_user_session(&session_store).await;
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(active_user()),
		});
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		let response = run(
			&middleware,
			cookie_headers(&format!("sessionid={}", session_id)),
		)
		.await;

		assert_eq!(response.status, reinhardt_http::Response::ok().status);
		assert!(body_string(&response).contains("\"is_authenticated\":true"));
	}

	#[tokio::test]
	async fn test_auth_middleware_finds_session_among_other_cookies() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let session_id = store_user_session(&session_store).await;
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(active_user()),
		});
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		let response = run(
			&middleware,
			cookie_headers(&format!("theme=dark; sessionid={}; lang=en", session_id)),
		)
		.await;

		assert!(body_string(&response).contains("\"is_authenticated\":true"));
	}

	#[tokio::test]
	async fn test_auth_middleware_without_session() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let auth_backend = Arc::new(TestAuthBackend { user: None });
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		let response = run(&middleware, HeaderMap::new()).await;

		assert_eq!(response.status, reinhardt_http::Response::ok().status);
		assert!(body_string(&response).contains("\"is_authenticated\":false"));
	}

	#[tokio::test]
	async fn test_auth_middleware_rejects_malformed_session_id() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(active_user()),
		});
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		let response = run(&middleware, cookie_headers("sessionid=not-a-uuid")).await;

		assert_eq!(response.status, reinhardt_http::Response::ok().status);
		assert!(body_string(&response).contains("\"is_authenticated\":false"));
	}

	#[tokio::test]
	async fn test_auth_middleware_unknown_session_is_anonymous() {
		let session_store = Arc::new(InMemorySessionStore::new());
		let auth_backend = Arc::new(TestAuthBackend {
			user: Some(active_user()),
		});
		let middleware = AuthenticationMiddleware::new(session_store, auth_backend);

		// Well-formed ID that was never saved to the store.
		let response = run(
			&middleware,
			cookie_headers(&format!("sessionid={}", Uuid::new_v4())),
		)
		.await;

		assert!(body_string(&response).contains("\"is_authenticated\":false"));
	}

	#[test]
	fn test_auth_state_from_extensions() {
		let extensions = reinhardt_http::Extensions::new();
		extensions.insert("user123".to_string());
		extensions.insert(true);

		let auth_state = AuthState::from_extensions(&extensions);
		assert!(auth_state.is_some());
		assert!(!auth_state.unwrap().is_anonymous());
	}

	#[test]
	fn test_auth_state_is_anonymous() {
		let anon_state = AuthState::anonymous();

		assert!(anon_state.is_anonymous());

		let auth_state = AuthState::authenticated("user123", false, true);

		assert!(!auth_state.is_anonymous());
	}
}