reinhardt_middleware/
auth.rs1#![allow(deprecated)]
4
5#[cfg(feature = "sessions")]
6use async_trait::async_trait;
7#[cfg(feature = "sessions")]
8use std::sync::Arc;
9
10#[cfg(feature = "sessions")]
11use reinhardt_http::{Handler, Middleware, Request, Response, Result};
12
13#[cfg(feature = "sessions")]
14use reinhardt_auth::session::{SESSION_KEY_USER_ID, SessionStore};
15#[cfg(feature = "sessions")]
16use reinhardt_auth::{AnonymousUser, AuthenticationBackend, User};
17
/// Middleware that attaches the current user's authentication state to each
/// request.
///
/// The session ID is read from the `sessionid` cookie, the session is loaded
/// from `session_store`, and the user ID stored in it is resolved through
/// `auth_backend`.
#[cfg(feature = "sessions")]
pub struct AuthenticationMiddleware<S, A>
where
    S: SessionStore,
    A: AuthenticationBackend,
{
    /// Store used to load sessions by ID.
    session_store: Arc<S>,
    /// Backend used to look up the user whose ID is stored in the session.
    auth_backend: Arc<A>,
}
115
116#[cfg(feature = "sessions")]
117impl<S: SessionStore, A: AuthenticationBackend> AuthenticationMiddleware<S, A> {
118 pub fn new(session_store: Arc<S>, auth_backend: Arc<A>) -> Self {
160 Self {
161 session_store,
162 auth_backend,
163 }
164 }
165
166 fn extract_session_id(&self, request: &Request) -> Option<String> {
171 const SESSION_COOKIE_NAME: &str = "sessionid";
172 request
173 .headers
174 .get("cookie")
175 .and_then(|v| v.to_str().ok())
176 .and_then(|cookies| {
177 cookies.split(';').find_map(|cookie| {
178 let mut parts = cookie.trim().split('=');
179 if parts.next()? == SESSION_COOKIE_NAME {
180 Some(parts.next()?.to_string())
181 } else {
182 None
183 }
184 })
185 })
186 .filter(|id| Self::is_valid_session_id(id))
187 }
188
189 fn is_valid_session_id(id: &str) -> bool {
194 if id.is_empty() || id.len() > 128 {
195 return false;
196 }
197 uuid::Uuid::parse_str(id).is_ok()
199 }
200
201 async fn get_user_from_session(&self, session_id: &String) -> Option<Box<dyn User>> {
203 if let Some(session) = self.session_store.load(session_id).await
204 && let Some(user_id_value) = session.get(SESSION_KEY_USER_ID)
205 && let Some(user_id) = user_id_value.as_str()
206 && let Ok(Some(user)) = self.auth_backend.get_user(user_id).await
207 {
208 return Some(user);
209 }
210 None
211 }
212}
213
#[cfg(feature = "sessions")]
#[async_trait]
impl<S: SessionStore + 'static, A: AuthenticationBackend + 'static> Middleware
    for AuthenticationMiddleware<S, A>
{
    /// Resolves the current user from the session cookie and records the
    /// result in the request extensions before calling `next`.
    ///
    /// Requests without a valid session ID, or whose session does not resolve
    /// to a user, are treated as [`AnonymousUser`].
    ///
    /// Inserted extensions:
    /// - `String`: the user's ID (as returned by `User::id`);
    /// - `bool`: whether the user is authenticated;
    /// - [`AuthState`]: the full state, including the admin and active flags.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let user: Box<dyn User> = match self.extract_session_id(&request) {
            Some(session_id) => self
                .get_user_from_session(&session_id)
                .await
                .unwrap_or_else(|| Box::new(AnonymousUser)),
            None => Box::new(AnonymousUser),
        };

        let is_authenticated = user.is_authenticated();
        let is_admin = user.is_admin();
        let is_active = user.is_active();
        let user_id = user.id();

        request.extensions.insert(user_id.clone());
        // Extensions are keyed by type, so there is room for exactly one bare
        // `bool`. Readers of `get::<bool>()` treat it as `is_authenticated`;
        // inserting `is_admin`/`is_active` as bare bools as well would
        // overwrite it. Those flags are carried by `AuthState` below.
        request.extensions.insert(is_authenticated);

        let auth_state = if is_authenticated {
            AuthState::authenticated(user_id, is_admin, is_active)
        } else {
            AuthState::anonymous()
        };
        request.extensions.insert(auth_state);

        next.handle(request).await
    }
}
250
251pub use reinhardt_http::AuthState;
254
#[cfg(all(test, feature = "sessions"))]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::{HeaderMap, Method, Version};
    use reinhardt_auth::AuthenticationError;
    use reinhardt_auth::SimpleUser;
    use reinhardt_auth::session::{InMemorySessionStore, Session};
    use uuid::Uuid;

    /// Handler that echoes the `user_id` (`String`) and `is_authenticated`
    /// (`bool`) extensions as a JSON body.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let user_id: Option<String> = request.extensions.get();
            let is_authenticated: Option<bool> = request.extensions.get();

            Ok(Response::ok().with_json(&serde_json::json!({
                "user_id": user_id.unwrap_or_default(),
                "is_authenticated": is_authenticated.unwrap_or(false)
            }))?)
        }
    }

    /// Backend that returns `user` for every lookup, whatever ID is requested.
    struct TestAuthBackend {
        user: Option<SimpleUser>,
    }

    #[async_trait::async_trait]
    impl AuthenticationBackend for TestAuthBackend {
        async fn authenticate(
            &self,
            _request: &Request,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self
                .user
                .as_ref()
                .map(|u| Box::new(u.clone()) as Box<dyn User>))
        }

        async fn get_user(
            &self,
            _user_id: &str,
        ) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
            Ok(self
                .user
                .as_ref()
                .map(|u| Box::new(u.clone()) as Box<dyn User>))
        }
    }

    /// Builds an active, non-admin user with a fresh UUID.
    fn test_user() -> SimpleUser {
        SimpleUser {
            id: Uuid::new_v4(),
            username: "testuser".to_string(),
            email: "test@example.com".to_string(),
            is_active: true,
            is_admin: false,
            is_staff: false,
            is_superuser: false,
        }
    }

    /// Builds headers holding a single `Cookie` field with the given value.
    fn cookie_headers(cookie: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert("cookie", cookie.parse().unwrap());
        headers
    }

    /// Builds a `GET /test` HTTP/1.1 request with the given headers.
    fn build_request(headers: HeaderMap) -> Request {
        Request::builder()
            .method(Method::GET)
            .uri("/test")
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap()
    }

    /// Decodes a response body as UTF-8.
    fn body_string(response: &Response) -> String {
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_auth_middleware_with_valid_session() {
        let session_store = Arc::new(InMemorySessionStore::new());
        let user = test_user();
        let user_id = user.id.to_string();
        let auth_backend = Arc::new(TestAuthBackend { user: Some(user) });

        let session_id = session_store.create_session_id();
        let mut session = Session::new();
        session.set(SESSION_KEY_USER_ID, serde_json::json!(user_id));
        session_store.save(&session_id, &session).await;

        let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
        let request = build_request(cookie_headers(&format!("sessionid={session_id}")));

        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.status, reinhardt_http::Response::ok().status);
        assert!(body_string(&response).contains("\"is_authenticated\":true"));
    }

    #[tokio::test]
    async fn test_auth_middleware_finds_session_cookie_among_others() {
        let session_store = Arc::new(InMemorySessionStore::new());
        let user = test_user();
        let user_id = user.id.to_string();
        let auth_backend = Arc::new(TestAuthBackend { user: Some(user) });

        let session_id = session_store.create_session_id();
        let mut session = Session::new();
        session.set(SESSION_KEY_USER_ID, serde_json::json!(user_id));
        session_store.save(&session_id, &session).await;

        let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
        let request = build_request(cookie_headers(&format!(
            "csrftoken=abc123; sessionid={session_id}; theme=dark"
        )));

        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert!(body_string(&response).contains("\"is_authenticated\":true"));
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_non_uuid_session_id() {
        let session_store = Arc::new(InMemorySessionStore::new());
        let auth_backend = Arc::new(TestAuthBackend {
            user: Some(test_user()),
        });

        // A session really exists under this key, so only session-ID
        // validation stands between the cookie and an authenticated user.
        let session_id = "not-a-uuid".to_string();
        let mut session = Session::new();
        session.set(SESSION_KEY_USER_ID, serde_json::json!("user123"));
        session_store.save(&session_id, &session).await;

        let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
        let request = build_request(cookie_headers(&format!("sessionid={session_id}")));

        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.status, reinhardt_http::Response::ok().status);
        assert!(body_string(&response).contains("\"is_authenticated\":false"));
    }

    #[tokio::test]
    async fn test_auth_middleware_without_session() {
        let session_store = Arc::new(InMemorySessionStore::new());
        let auth_backend = Arc::new(TestAuthBackend { user: None });

        let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
        let request = build_request(HeaderMap::new());

        let response = middleware
            .process(request, Arc::new(TestHandler))
            .await
            .unwrap();
        assert_eq!(response.status, reinhardt_http::Response::ok().status);
        assert!(body_string(&response).contains("\"is_authenticated\":false"));
    }

    #[test]
    fn test_auth_state_from_extensions() {
        let extensions = reinhardt_http::Extensions::new();
        extensions.insert("user123".to_string());
        extensions.insert(true);

        let auth_state = AuthState::from_extensions(&extensions);
        assert!(auth_state.is_some());
        assert!(!auth_state.unwrap().is_anonymous());
    }

    #[test]
    fn test_auth_state_is_anonymous() {
        let anon_state = AuthState::anonymous();

        assert!(anon_state.is_anonymous());

        let auth_state = AuthState::authenticated("user123", false, true);

        assert!(!auth_state.is_anonymous());
    }
}