use async_trait::async_trait;
use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
use reinhardt_http::Request;
use std::sync::Arc;
use super::cookie::find_cookie_value;
use super::data::SessionData;
use super::id::{ActiveSessionId, SessionCookieName, SessionId};
use super::store::SessionStore;
/// Cookie name used for the session id when the request carries no
/// [`SessionCookieName`] extension.
const DEFAULT_SESSION_COOKIE_NAME: &str = "sessionid";

/// Reads the session id stored in the cookie named `cookie_name` on `request`.
///
/// # Errors
///
/// Returns [`DiError::NotFound`] when `find_cookie_value` yields no value for
/// `cookie_name`.
fn extract_session_id_from_request(request: &Request, cookie_name: &str) -> DiResult<String> {
    match find_cookie_value(request, cookie_name) {
        Some(session_id) => Ok(session_id),
        None => Err(DiError::NotFound(format!(
            "Session cookie '{cookie_name}' not found in Cookie header"
        ))),
    }
}
#[async_trait]
impl Injectable for SessionData {
async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
let store = ctx.get_singleton::<Arc<SessionStore>>().ok_or_else(|| {
DiError::NotFound(
concat!(
"SessionStore not found in SingletonScope. ",
"Ensure SessionMiddleware is configured and its store is registered."
)
.to_string(),
)
})?;
let request = ctx.get_request::<Request>().ok_or_else(|| {
DiError::NotFound("Request not found in InjectionContext".to_string())
})?;
let ext_cookie_name = request.extensions.get::<SessionCookieName>();
let cookie_name = ext_cookie_name
.as_ref()
.map(|cn| cn.as_str())
.unwrap_or(DEFAULT_SESSION_COOKIE_NAME);
let session_id = if let Some(sid) = request.extensions.get::<SessionId>() {
sid.as_ref().to_string()
} else {
extract_session_id_from_request(&request, cookie_name)?
};
let id_holder = request.extensions.get::<ActiveSessionId>();
let mut session = store
.get(&session_id)
.filter(|s| s.is_valid())
.ok_or_else(|| {
DiError::NotFound("Valid session not found. Session may have expired.".to_string())
})?;
session.id_holder = id_holder;
Ok(session)
}
}
/// Cloneable handle to a shared [`SessionStore`].
///
/// Its `Injectable` impl builds it from the `Arc<SessionStore>` singleton registered
/// in the injection context.
#[derive(Clone)]
pub struct SessionStoreRef(pub Arc<SessionStore>);

impl SessionStoreRef {
    /// Borrows the underlying [`SessionStore`].
    pub fn inner(&self) -> &SessionStore {
        let Self(shared) = self;
        shared
    }

    /// Returns another reference-counted pointer to the same [`SessionStore`].
    pub fn arc(&self) -> Arc<SessionStore> {
        let Self(shared) = self;
        Arc::clone(shared)
    }
}
#[async_trait]
impl Injectable for SessionStoreRef {
    /// Fetches the `Arc<SessionStore>` singleton and wraps a clone of it.
    ///
    /// # Errors
    ///
    /// Returns [`DiError::NotFound`] when no `Arc<SessionStore>` singleton is
    /// available from the injection context.
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        match ctx.get_singleton::<Arc<SessionStore>>() {
            Some(shared) => Ok(Self(Arc::clone(&*shared))),
            None => Err(DiError::NotFound(
                concat!(
                    "SessionStore not found in SingletonScope. ",
                    "Ensure SessionMiddleware is configured and its store is registered."
                )
                .to_string(),
            )),
        }
    }
}