reinhardt_middleware/
auth.rs1#![allow(deprecated)]
4
5#[cfg(feature = "sessions")]
6use async_trait::async_trait;
7#[cfg(feature = "sessions")]
8use std::sync::Arc;
9
10#[cfg(feature = "sessions")]
11use reinhardt_http::{
12 Handler, IsActive, IsAdmin, IsAuthenticated, Middleware, Request, Response, Result,
13};
14
15#[cfg(feature = "sessions")]
16use reinhardt_auth::session::{SESSION_KEY_USER_ID, SessionStore};
17#[cfg(feature = "sessions")]
18use reinhardt_auth::{AnonymousUser, AuthenticationBackend, User};
19
/// Middleware state used to resolve the current user from a session cookie.
///
/// Holds shared handles to the session store (queried with the session id
/// taken from the request) and to the authentication backend (queried with
/// the user id found in the session).
#[cfg(feature = "sessions")]
pub struct AuthenticationMiddleware<S, A>
where
    S: SessionStore,
    A: AuthenticationBackend,
{
    /// Store from which session data is loaded by session id.
    session_store: Arc<S>,
    /// Backend that turns a user id into a user object.
    auth_backend: Arc<A>,
}
117
#[cfg(feature = "sessions")]
impl<S: SessionStore, A: AuthenticationBackend> AuthenticationMiddleware<S, A> {
    /// Creates a middleware that loads sessions from `session_store` and
    /// resolves users through `auth_backend`.
    pub fn new(session_store: Arc<S>, auth_backend: Arc<A>) -> Self {
        Self {
            session_store,
            auth_backend,
        }
    }

    /// Extracts the session id from the request's `cookie` header.
    ///
    /// Returns the value of the first `sessionid` cookie that has a value,
    /// provided that value passes [`Self::is_valid_session_id`]. Returns
    /// `None` when the header is absent, is not valid header text, contains
    /// no `sessionid=<value>` pair, or the value fails validation.
    fn extract_session_id(&self, request: &Request) -> Option<String> {
        const SESSION_COOKIE_NAME: &str = "sessionid";

        let header = request.headers.get("cookie")?.to_str().ok()?;
        header
            .split(';')
            .find_map(|pair| {
                // Split on the FIRST `=` only: cookie values may themselves
                // contain `=`, and the whole value must reach validation
                // rather than a silently truncated prefix.
                let (name, value) = pair.trim().split_once('=')?;
                (name == SESSION_COOKIE_NAME).then(|| value.to_string())
            })
            .filter(|id| Self::is_valid_session_id(id))
    }

    /// Returns `true` when `id` is non-empty, at most 128 bytes long, and
    /// parses as a UUID via `uuid::Uuid::parse_str`.
    fn is_valid_session_id(id: &str) -> bool {
        // Cheap length bounds first, so oversized input never reaches the parser.
        if id.is_empty() || id.len() > 128 {
            return false;
        }
        uuid::Uuid::parse_str(id).is_ok()
    }

    /// Resolves the user stored in the session identified by `session_id`.
    ///
    /// Returns `None` when the session cannot be loaded, when it has no
    /// `SESSION_KEY_USER_ID` entry, when that entry is not a string, or when
    /// the backend returns an error or no user. Backend errors are
    /// deliberately not propagated: the caller falls back to an anonymous
    /// user in every `None` case.
    async fn get_user_from_session(&self, session_id: &str) -> Option<Box<dyn User>> {
        let session = self.session_store.load(session_id).await?;
        let user_id_value = session.get(SESSION_KEY_USER_ID)?;
        let user_id = user_id_value.as_str()?;
        self.auth_backend.get_user(user_id).await.ok().flatten()
    }
}
215
#[cfg(feature = "sessions")]
#[async_trait]
impl<S, A> Middleware for AuthenticationMiddleware<S, A>
where
    S: SessionStore + 'static,
    A: AuthenticationBackend + 'static,
{
    /// Attaches authentication data to the request and forwards it to `next`.
    ///
    /// The user is resolved from the session cookie; when there is no usable
    /// session id, or no user can be resolved from it, `AnonymousUser` is
    /// used instead. The following values are inserted into the request
    /// extensions before `next` is invoked: the user id, `IsAuthenticated`,
    /// `IsAdmin`, `IsActive`, and an `AuthState`.
    ///
    /// Returns whatever `next.handle` returns.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let resolved = match self.extract_session_id(&request) {
            Some(session_id) => self.get_user_from_session(&session_id).await,
            None => None,
        };
        let user: Box<dyn User> = resolved.unwrap_or_else(|| Box::new(AnonymousUser));

        // Snapshot everything needed from the user up front.
        let is_authenticated = user.is_authenticated();
        let is_admin = user.is_admin();
        let is_active = user.is_active();
        let user_id = user.id();

        // Individual values, each stored under its own type.
        request.extensions.insert(user_id.clone());
        request.extensions.insert(IsAuthenticated(is_authenticated));
        request.extensions.insert(IsAdmin(is_admin));
        request.extensions.insert(IsActive(is_active));

        // Combined state object; anonymous requests carry no user id in it.
        let auth_state = if is_authenticated {
            AuthState::authenticated(user_id, is_admin, is_active)
        } else {
            AuthState::anonymous()
        };
        request.extensions.insert(auth_state);

        next.handle(request).await
    }
}
252
253pub use reinhardt_http::AuthState;
256
#[cfg(all(test, feature = "sessions"))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_auth::AuthenticationError;
    use reinhardt_auth::SimpleUser;
    use reinhardt_auth::session::{InMemorySessionStore, Session};
    use uuid::Uuid;

    /// Terminal handler that reports the auth data found in the request
    /// extensions as a JSON body.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let user_id = request.extensions.get::<String>().unwrap_or_default();
            let authenticated = match request.extensions.get::<IsAuthenticated>() {
                Some(flag) => flag.0,
                None => false,
            };

            let payload = serde_json::json!({
                "user_id": user_id,
                "is_authenticated": authenticated
            });
            Ok(Response::ok().with_json(&payload)?)
        }
    }

    /// Backend that returns a clone of its configured user, if any,
    /// regardless of the request or user id it is asked about.
    struct TestAuthBackend {
        user: Option<SimpleUser>,
    }

    impl TestAuthBackend {
        /// Clones the configured user into a boxed trait object.
        fn boxed_user(&self) -> Option<Box<dyn User>> {
            self.user
                .clone()
                .map(|user| Box::new(user) as Box<dyn User>)
        }
    }

    #[async_trait]
    impl AuthenticationBackend for TestAuthBackend {
        async fn authenticate(
            &self,
            _request: &Request,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self.boxed_user())
        }

        async fn get_user(
            &self,
            _user_id: &str,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self.boxed_user())
        }
    }

    /// Builds a `GET /test` HTTP/1.1 request with an empty body and the
    /// given headers.
    fn build_request(headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_auth_middleware_with_valid_session() {
        let store = Arc::new(InMemorySessionStore::new());
        let backend = Arc::new(TestAuthBackend {
            user: Some(SimpleUser {
                id: Uuid::now_v7(),
                username: "testuser".to_string(),
                email: "test@example.com".to_string(),
                is_active: true,
                is_admin: false,
                is_staff: false,
                is_superuser: false,
            }),
        });

        // Persist a session that points at a user id.
        let session_id = store.create_session_id();
        let mut session = Session::new();
        session.set(SESSION_KEY_USER_ID, serde_json::json!("user123"));
        store.save(&session_id, &session).await;

        let middleware = AuthenticationMiddleware::new(store, backend);
        let handler = Arc::new(TestHandler);

        // Present the session id the way a browser would.
        let cookie = format!("sessionid={}", session_id);
        let mut headers = HeaderMap::new();
        headers.insert("cookie", cookie.parse().unwrap());

        let response = middleware
            .process(build_request(headers), handler)
            .await
            .unwrap();

        // NOTE(review): only the status is checked here; the body's
        // `is_authenticated` value is not asserted — consider adding that.
        assert_eq!(response.status, reinhardt_http::Response::ok().status);
    }

    #[tokio::test]
    async fn test_auth_middleware_without_session() {
        let store = Arc::new(InMemorySessionStore::new());
        let backend = Arc::new(TestAuthBackend { user: None });

        let middleware = AuthenticationMiddleware::new(store, backend);
        let handler = Arc::new(TestHandler);

        let response = middleware
            .process(build_request(HeaderMap::new()), handler)
            .await
            .unwrap();
        assert_eq!(response.status, reinhardt_http::Response::ok().status);

        // Without a cookie the handler must see an unauthenticated request.
        let body = String::from_utf8(response.body.to_vec()).unwrap();
        assert!(body.contains("\"is_authenticated\":false"));
    }

    #[test]
    fn test_auth_state_from_extensions() {
        let extensions = reinhardt_http::Extensions::new();
        extensions.insert("user123".to_string());
        extensions.insert(IsAuthenticated(true));

        let state = AuthState::from_extensions(&extensions);
        assert!(state.is_some());
        assert!(!state.unwrap().is_anonymous());
    }

    #[test]
    fn test_auth_state_is_anonymous() {
        let anonymous = AuthState::anonymous();
        assert!(anonymous.is_anonymous());

        let authenticated = AuthState::authenticated("user123", false, true);
        assert!(!authenticated.is_anonymous());
    }
}