#![allow(deprecated)]
#[cfg(feature = "sessions")]
use async_trait::async_trait;
#[cfg(feature = "sessions")]
use std::sync::Arc;
#[cfg(feature = "sessions")]
use reinhardt_http::{
Handler, IsActive, IsAdmin, IsAuthenticated, Middleware, Request, Response, Result,
};
#[cfg(feature = "sessions")]
use reinhardt_auth::session::{SESSION_KEY_USER_ID, SessionStore};
#[cfg(feature = "sessions")]
use reinhardt_auth::{AnonymousUser, AuthenticationBackend, User};
#[cfg(feature = "sessions")]
/// Middleware that resolves the requesting user from the `sessionid` cookie
/// and publishes the authentication result into the request extensions.
///
/// The struct itself carries no trait bounds on `S` and `A`; the bounds that
/// the middleware logic needs are declared on its `impl` blocks instead.
pub struct AuthenticationMiddleware<S, A> {
    /// Store from which the session named by the cookie is loaded.
    session_store: Arc<S>,
    /// Backend that turns the session's stored user id into a user object.
    auth_backend: Arc<A>,
}
#[cfg(feature = "sessions")]
impl<S: SessionStore, A: AuthenticationBackend> AuthenticationMiddleware<S, A> {
    /// Creates a middleware that loads sessions from `session_store` and
    /// resolves users through `auth_backend`.
    pub fn new(session_store: Arc<S>, auth_backend: Arc<A>) -> Self {
        Self {
            session_store,
            auth_backend,
        }
    }

    /// Returns the value of the `sessionid` cookie if it is present and
    /// well-formed.
    ///
    /// Every `cookie` header field is inspected, not just the first one,
    /// because HTTP/2 (RFC 9113 §8.2.3) allows a client to split the cookie
    /// list across several header fields. The first `sessionid` pair found is
    /// the one used; if that value fails `is_valid_session_id`, `None` is
    /// returned.
    fn extract_session_id(&self, request: &Request) -> Option<String> {
        const SESSION_COOKIE_NAME: &str = "sessionid";
        request
            .headers
            .get_all("cookie")
            .iter()
            // Header values that are not visible ASCII are skipped.
            .filter_map(|value| value.to_str().ok())
            .flat_map(|cookies| cookies.split(';'))
            .find_map(|pair| {
                // `split_once` keeps any `=` that appears inside the value
                // instead of truncating at it.
                let (name, value) = pair.trim().split_once('=')?;
                (name == SESSION_COOKIE_NAME).then(|| value.to_string())
            })
            .filter(|id| Self::is_valid_session_id(id))
    }

    /// Returns `true` when `id` is non-empty, at most 128 bytes long and
    /// parses as a UUID.
    ///
    /// The length check is a cheap guard that rejects oversized input before
    /// it reaches the UUID parser.
    fn is_valid_session_id(id: &str) -> bool {
        if id.is_empty() || id.len() > 128 {
            return false;
        }
        uuid::Uuid::parse_str(id).is_ok()
    }

    /// Loads session `session_id` and resolves the user id stored under
    /// `SESSION_KEY_USER_ID` through the authentication backend.
    ///
    /// Returns `None` in any of these cases:
    /// - the session cannot be loaded;
    /// - the session has no string value under that key;
    /// - the backend finds no user;
    /// - the backend returns an error.
    ///
    /// Backend errors are deliberately treated like a missing user, so the
    /// request proceeds as anonymous.
    async fn get_user_from_session(&self, session_id: &str) -> Option<Box<dyn User>> {
        let session = self.session_store.load(session_id).await?;
        // Bind the stored value first so the `&str` borrowed from it stays
        // valid across the backend call below.
        let user_id_value = session.get(SESSION_KEY_USER_ID)?;
        let user_id = user_id_value.as_str()?;
        self.auth_backend.get_user(user_id).await.ok().flatten()
    }
}
#[cfg(feature = "sessions")]
#[async_trait]
impl<S: SessionStore + 'static, A: AuthenticationBackend + 'static> Middleware
    for AuthenticationMiddleware<S, A>
{
    /// Resolves the requesting user and records the outcome in the request
    /// extensions before delegating to `next`.
    ///
    /// A request without a usable session cookie, or whose session does not
    /// resolve to a user, is treated as `AnonymousUser`. Downstream handlers
    /// can read the following from the extensions:
    /// - the user id (`String`);
    /// - `IsAuthenticated`, `IsAdmin` and `IsActive`;
    /// - an `AuthState` summarising them.
    async fn process(&self, request: Request, next: Arc<dyn Handler>) -> Result<Response> {
        let resolved = match self.extract_session_id(&request) {
            Some(session_id) => self.get_user_from_session(&session_id).await,
            None => None,
        };
        let user: Box<dyn User> = resolved.unwrap_or_else(|| Box::new(AnonymousUser));

        // Snapshot the user's attributes once; they feed both the individual
        // extension markers and the combined `AuthState`.
        let (authenticated, admin, active, id) = (
            user.is_authenticated(),
            user.is_admin(),
            user.is_active(),
            user.id(),
        );

        let extensions = &request.extensions;
        extensions.insert(id.clone());
        extensions.insert(IsAuthenticated(authenticated));
        extensions.insert(IsAdmin(admin));
        extensions.insert(IsActive(active));
        extensions.insert(if authenticated {
            AuthState::authenticated(id, admin, active)
        } else {
            AuthState::anonymous()
        });

        next.handle(request).await
    }
}
pub use reinhardt_http::AuthState;
#[cfg(all(test, feature = "sessions"))]
mod tests {
use super::*;
use bytes::Bytes;
use hyper::{HeaderMap, Method, Version};
use reinhardt_auth::AuthenticationError;
use reinhardt_auth::SimpleUser;
use reinhardt_auth::session::{InMemorySessionStore, Session};
use uuid::Uuid;
struct TestHandler;
#[async_trait]
impl Handler for TestHandler {
async fn handle(&self, request: Request) -> Result<Response> {
let user_id: Option<String> = request.extensions.get();
let is_authenticated = request
.extensions
.get::<IsAuthenticated>()
.map(|v| v.0)
.unwrap_or(false);
Ok(Response::ok().with_json(&serde_json::json!({
"user_id": user_id.unwrap_or_default(),
"is_authenticated": is_authenticated
}))?)
}
}
struct TestAuthBackend {
user: Option<SimpleUser>,
}
#[async_trait::async_trait]
impl AuthenticationBackend for TestAuthBackend {
async fn authenticate(
&self,
_request: &Request,
) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
Ok(self
.user
.as_ref()
.map(|u| Box::new(u.clone()) as Box<dyn User>))
}
async fn get_user(
&self,
_user_id: &str,
) -> std::result::Result<Option<Box<dyn User>>, AuthenticationError> {
Ok(self
.user
.as_ref()
.map(|u| Box::new(u.clone()) as Box<dyn User>))
}
}
#[tokio::test]
async fn test_auth_middleware_with_valid_session() {
let session_store = Arc::new(InMemorySessionStore::new());
let user = SimpleUser {
id: Uuid::now_v7(),
username: "testuser".to_string(),
email: "test@example.com".to_string(),
is_active: true,
is_admin: false,
is_staff: false,
is_superuser: false,
};
let auth_backend = Arc::new(TestAuthBackend { user: Some(user) });
let session_id = session_store.create_session_id();
let mut session = Session::new();
session.set(SESSION_KEY_USER_ID, serde_json::json!("user123"));
session_store.save(&session_id, &session).await;
let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
let handler = Arc::new(TestHandler);
let mut headers = HeaderMap::new();
headers.insert(
"cookie",
format!("sessionid={}", session_id).parse().unwrap(),
);
let request = Request::builder()
.method(Method::GET)
.uri("/test")
.version(Version::HTTP_11)
.headers(headers)
.body(Bytes::new())
.build()
.unwrap();
let response = middleware.process(request, handler).await.unwrap();
assert_eq!(response.status, reinhardt_http::Response::ok().status);
}
#[tokio::test]
async fn test_auth_middleware_without_session() {
let session_store = Arc::new(InMemorySessionStore::new());
let auth_backend = Arc::new(TestAuthBackend { user: None });
let middleware = AuthenticationMiddleware::new(session_store, auth_backend);
let handler = Arc::new(TestHandler);
let request = Request::builder()
.method(Method::GET)
.uri("/test")
.version(Version::HTTP_11)
.headers(HeaderMap::new())
.body(Bytes::new())
.build()
.unwrap();
let response = middleware.process(request, handler).await.unwrap();
assert_eq!(response.status, reinhardt_http::Response::ok().status);
let body_str = String::from_utf8(response.body.to_vec()).unwrap();
assert!(body_str.contains("\"is_authenticated\":false"));
}
#[test]
fn test_auth_state_from_extensions() {
let extensions = reinhardt_http::Extensions::new();
extensions.insert("user123".to_string());
extensions.insert(IsAuthenticated(true));
let auth_state = AuthState::from_extensions(&extensions);
assert!(auth_state.is_some());
assert!(!auth_state.unwrap().is_anonymous());
}
#[test]
fn test_auth_state_is_anonymous() {
let anon_state = AuthState::anonymous();
assert!(anon_state.is_anonymous());
let auth_state = AuthState::authenticated("user123", false, true);
assert!(!auth_state.is_anonymous());
}
}