use async_trait::async_trait;
#[allow(deprecated)]
use reinhardt_conf::Settings;
use reinhardt_http::{Handler, Middleware, MiddlewareDiRegistration, Request, Response, Result};
use std::any::TypeId;
use std::sync::Arc;
use super::config::SessionConfig;
use super::cookie::find_cookie_value;
use super::data::SessionData;
use super::id::{ActiveSessionId, SessionCookieName, SessionId};
use super::store::SessionStore;
/// Middleware that attaches a server-side session to each request.
///
/// For every request it looks up the session named by the configured cookie in
/// `store`. If no usable session is found, it starts a new one. The session's
/// identifiers are exposed to the handler through request extensions, and the
/// response gets a `Set-Cookie` header carrying the active session id.
pub struct SessionMiddleware {
    /// Cookie attributes (name, path, domain, flags) and the session TTL.
    config: SessionConfig,
    /// Backing store, held in an `Arc` so the same instance can be shared
    /// (e.g. via `from_arc` or the DI registrations).
    store: Arc<SessionStore>,
}
impl SessionMiddleware {
    /// Creates a middleware that uses `config` and a newly constructed
    /// [`SessionStore`] (built with `SessionStore::new()`).
    pub fn new(config: SessionConfig) -> Self {
        let store = Arc::new(SessionStore::new());
        Self { config, store }
    }

    /// Creates a middleware whose configuration is derived from `settings`
    /// via `SessionConfig::from_settings`.
    // `Settings` is imported under `allow(deprecated)`; the same allow keeps
    // this compatibility constructor free of deprecation warnings.
    #[allow(deprecated)]
    pub fn from_settings(settings: &Settings) -> Self {
        let config = SessionConfig::from_settings(settings);
        Self::new(config)
    }

    /// Creates a middleware with `SessionConfig::default()` and a fresh store.
    pub fn with_defaults() -> Self {
        let config = SessionConfig::default();
        Self::new(config)
    }

    /// Creates a middleware backed by an already-shared `store`, so other
    /// components holding the same `Arc` see the sessions this middleware writes.
    pub fn from_arc(config: SessionConfig, store: Arc<SessionStore>) -> Self {
        Self { store, config }
    }

    /// Borrows the session store this middleware reads from and writes to.
    pub fn store(&self) -> &SessionStore {
        self.store.as_ref()
    }

    /// Returns a new `Arc` handle to the same store allocation (no deep copy).
    pub fn store_arc(&self) -> Arc<SessionStore> {
        Arc::clone(&self.store)
    }

    /// Reads the value of the configured session cookie from `request`.
    /// The lookup is delegated to `find_cookie_value`. It presumably yields
    /// `None` when the cookie is absent.
    fn get_session_id(&self, request: &Request) -> Option<String> {
        let cookie_name = &self.config.cookie_name;
        find_cookie_value(request, cookie_name)
    }

    /// Builds the `Set-Cookie` header value for `session_id`.
    ///
    /// Attributes are emitted in this fixed order, joined by `"; "`:
    /// `name=id`, `Path`, optional `Domain`, optional `HttpOnly`, optional `Secure`,
    /// optional `SameSite`, then `Max-Age` (the TTL in whole seconds).
    /// No escaping or validation is done here.
    fn build_cookie_header(&self, session_id: &str) -> String {
        let cfg = &self.config;
        let attributes: Vec<String> = vec![
            Some(format!("{}={}", cfg.cookie_name, session_id)),
            Some(format!("Path={}", cfg.path)),
            cfg.domain.as_ref().map(|domain| format!("Domain={}", domain)),
            cfg.http_only.then(|| "HttpOnly".to_string()),
            cfg.secure.then(|| "Secure".to_string()),
            cfg.same_site
                .as_ref()
                .map(|same_site| format!("SameSite={}", same_site)),
            Some(format!("Max-Age={}", cfg.ttl.as_secs())),
        ]
        .into_iter()
        .flatten()
        .collect();
        attributes.join("; ")
    }
}
impl Default for SessionMiddleware {
fn default() -> Self {
Self::with_defaults()
}
}
#[async_trait]
impl Middleware for SessionMiddleware {
    /// Exposes the middleware's store to dependency injection.
    ///
    /// The store is registered under the `TypeId` of `Arc<SessionStore>`. The
    /// registered value is an `Arc` wrapping a clone of the store `Arc`, so
    /// resolvers get the same allocation this middleware writes to.
    fn di_registrations(&self) -> Vec<MiddlewareDiRegistration> {
        let shared_store: Arc<dyn std::any::Any + Send + Sync> =
            Arc::new(Arc::clone(&self.store));
        vec![(TypeId::of::<Arc<SessionStore>>(), shared_store)]
    }

    /// Loads or creates the session, runs the handler, and sets the session cookie.
    ///
    /// Steps:
    /// 1. Reuse the stored session named by the cookie if `is_valid()` holds.
    ///    Otherwise start a new `SessionData`.
    /// 2. Refresh it with `touch(ttl)` and save it back before the handler runs.
    /// 3. Publish `SessionId`, `SessionCookieName` and `ActiveSessionId` as
    ///    request extensions.
    /// 4. Run the handler. A handler error is converted into a `Response`
    ///    instead of being propagated.
    /// 5. Append `Set-Cookie` using the id held by `ActiveSessionId` after the
    ///    handler ran.
    ///
    /// # Errors
    /// Returns an internal error if the cookie string is not a valid header value.
    async fn process(&self, request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
        let ttl = self.config.ttl;

        // An unknown or invalid cookie id never gets reused; a fresh session
        // (with its own id) is created instead.
        let mut session = match self.get_session_id(&request) {
            Some(id) => self
                .store
                .get(&id)
                .filter(|existing| existing.is_valid())
                .unwrap_or_else(|| SessionData::new(ttl)),
            None => SessionData::new(ttl),
        };
        session.touch(ttl);
        self.store.save(session.clone());

        let current_id = session.id.clone();
        request.extensions.insert(SessionId::new(current_id.clone()));
        request
            .extensions
            .insert(SessionCookieName::new(self.config.cookie_name.clone()));
        // Keep a handle so an id change made during handling is reflected in
        // the cookie below.
        let active_id = ActiveSessionId::new(current_id);
        request.extensions.insert(active_id.clone());

        let mut response = handler.handle(request).await.unwrap_or_else(Response::from);

        let cookie = self.build_cookie_header(&active_id.get());
        let header_value = hyper::header::HeaderValue::from_str(&cookie).map_err(|err| {
            reinhardt_core::exception::Error::Internal(format!(
                "Failed to create cookie header: {}",
                err
            ))
        })?;
        response
            .headers
            .append(hyper::header::SET_COOKIE, header_value);
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;

    /// Builds a middleware from `SessionConfig::new("sessionid", 1 hour)`.
    fn make_middleware() -> SessionMiddleware {
        SessionMiddleware::new(SessionConfig::new(
            "sessionid".to_string(),
            Duration::from_secs(3600),
        ))
    }

    /// Checks that the single DI registration is keyed by `Arc<SessionStore>`
    /// and points at the middleware's own store allocation.
    #[rstest]
    fn test_session_middleware_di_registrations_returns_store() {
        let middleware = make_middleware();
        let store_arc = middleware.store_arc();
        let registrations = middleware.di_registrations();
        assert_eq!(registrations.len(), 1);
        // Borrow the first registration (this line was previously corrupted
        // by an HTML-entity decoding artifact and did not compile).
        let (type_id, value) = &registrations[0];
        assert_eq!(*type_id, TypeId::of::<Arc<SessionStore>>());
        let downcast = value
            .clone()
            .downcast::<Arc<SessionStore>>()
            .expect("registered Arc must downcast to Arc<SessionStore>");
        assert!(
            Arc::ptr_eq(&*downcast, &store_arc),
            "middleware DI registration must expose the same Arc<SessionStore> the middleware writes to"
        );
    }

    /// Checks that applying the registrations to a `SingletonScope` makes the
    /// same store resolvable.
    #[rstest]
    fn test_session_middleware_di_registrations_apply_to_singleton_scope() {
        let middleware = make_middleware();
        let store_arc = middleware.store_arc();
        let scope = reinhardt_di::SingletonScope::new();
        let mut list = reinhardt_di::DiRegistrationList::new();
        for (type_id, value) in middleware.di_registrations() {
            list.register_arc_any(type_id, value);
        }
        list.apply_to(&scope);
        let resolved = scope
            .get::<Arc<SessionStore>>()
            .expect("SingletonScope must resolve Arc<SessionStore> after applying middleware DI");
        assert!(
            Arc::ptr_eq(&*resolved, &store_arc),
            "resolved Arc<SessionStore> must point at the same allocation the middleware owns"
        );
    }

    /// End-to-end check: a session seeded in the middleware's store is returned
    /// by `SessionData::inject` when the request carries its `SessionId`.
    #[tokio::test]
    async fn test_session_data_inject_resolves_via_middleware_di_registrations() {
        use crate::session::data::SessionData;
        use bytes::Bytes;
        use hyper::{Method, Version};
        use reinhardt_di::{Injectable, InjectionContext, SingletonScope};
        use reinhardt_http::Request;

        let middleware = make_middleware();
        let store_arc = middleware.store_arc();

        let mut seeded = SessionData::new(Duration::from_secs(3600));
        seeded
            .set("user_id".to_string(), "alice".to_string())
            .unwrap();
        let seeded_id = seeded.id.clone();
        store_arc.save(seeded.clone());

        let scope = SingletonScope::new();
        let mut list = reinhardt_di::DiRegistrationList::new();
        for (type_id, value) in middleware.di_registrations() {
            list.register_arc_any(type_id, value);
        }
        list.apply_to(&scope);

        let request = Request::builder()
            .method(Method::GET)
            .uri("/")
            .version(Version::HTTP_11)
            .body(Bytes::new())
            .build()
            .unwrap();
        request.extensions.insert(SessionId::new(seeded_id.clone()));

        let ctx = InjectionContext::builder(Arc::new(scope)).build();
        ctx.set_request(request);
        let resolved = SessionData::inject(&ctx)
            .await
            .expect("SessionData::inject must succeed when middleware DI is registered");
        assert_eq!(resolved.id, seeded_id);
        assert_eq!(resolved.get::<String>("user_id").as_deref(), Some("alice"));
    }
}