use async_trait::async_trait;
use reinhardt_di::params::{ParamContext, ParamError, ParamResult, extract::FromRequest};
use reinhardt_di::{DiError, DiResult, Injectable, InjectionContext};
use reinhardt_http::Request;
use serde::de::DeserializeOwned;
use std::fmt::{self, Debug};
use std::marker::PhantomData;
use std::ops::Deref;
use super::data::{SessionData, USER_ID_SESSION_KEY};
/// Compile-time name of a session entry.
///
/// A type implementing this trait (typically a zero-sized marker such as
/// `UserIdKey`) tells the named extractors (`SessionValueNamed`,
/// `OptionalSessionValueNamed`) which session key to read.
pub trait SessionKey: Send + Sync + 'static {
    /// Session key under which the value is stored.
    const KEY: &'static str;
}
/// `SessionKey` marker for the canonical user-id entry.
///
/// Its `KEY` is exactly `USER_ID_SESSION_KEY`, so
/// `SessionValueNamed<UserIdKey, T>` reads the same entry as `SessionValue<T>`.
#[derive(Debug, Clone, Copy)]
pub struct UserIdKey;

impl SessionKey for UserIdKey {
    const KEY: &'static str = USER_ID_SESSION_KEY;
}
/// Required extractor for the value stored under the user-id session key
/// (`USER_ID_SESSION_KEY`).
///
/// Resolution fails when the session holds no value under that key; use
/// `OptionalSessionValue` to tolerate its absence.
#[derive(Debug, Clone)]
pub struct SessionValue<T>(pub T);

/// Optional counterpart of `SessionValue`: wraps `None` instead of failing
/// when nothing is stored under the user-id session key.
#[derive(Debug, Clone)]
pub struct OptionalSessionValue<T>(pub Option<T>);
/// Required extractor for the value stored under the session key named by `K`.
///
/// The key is fixed at the type level through `SessionKey::KEY`, so several
/// distinct session entries can be extracted side by side without stringly
/// typed lookups.
pub struct SessionValueNamed<K: SessionKey, T> {
    value: T,
    // `fn() -> K` keeps `K` a pure type-level tag: the struct never owns a `K`,
    // and `Send`/`Sync` are decided by `T` alone.
    _phantom: PhantomData<fn() -> K>,
}

impl<K: SessionKey, T> SessionValueNamed<K, T> {
    /// Wraps an already-loaded value.
    pub fn new(value: T) -> Self {
        SessionValueNamed {
            value,
            _phantom: PhantomData,
        }
    }

    /// Consumes the extractor and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let SessionValueNamed { value, .. } = self;
        value
    }
}

impl<K: SessionKey, T> Deref for SessionValueNamed<K, T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.value
    }
}

// Written by hand rather than derived so that only `T` (not the marker `K`)
// must implement `Debug`; the output also names the session key.
impl<K: SessionKey, T: Debug> Debug for SessionValueNamed<K, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut rendered = f.debug_struct("SessionValueNamed");
        rendered.field("key", &K::KEY);
        rendered.field("value", &self.value);
        rendered.finish()
    }
}

// Hand-written for the same reason as `Debug`: `K` need not be `Clone`.
impl<K: SessionKey, T: Clone> Clone for SessionValueNamed<K, T> {
    fn clone(&self) -> Self {
        Self::new(self.value.clone())
    }
}
/// Optional extractor for the value stored under the session key named by `K`.
///
/// Like `SessionValueNamed`, but it holds `None` instead of failing when the
/// entry is absent.
pub struct OptionalSessionValueNamed<K: SessionKey, T> {
    value: Option<T>,
    // Type-level tag only; see `SessionValueNamed` for the `fn() -> K` rationale.
    _phantom: PhantomData<fn() -> K>,
}

impl<K: SessionKey, T> OptionalSessionValueNamed<K, T> {
    /// Wraps an already-loaded (possibly absent) value.
    pub fn new(value: Option<T>) -> Self {
        OptionalSessionValueNamed {
            value,
            _phantom: PhantomData,
        }
    }

    /// Consumes the extractor and returns the wrapped `Option`.
    pub fn into_inner(self) -> Option<T> {
        let OptionalSessionValueNamed { value, .. } = self;
        value
    }
}

impl<K: SessionKey, T> Deref for OptionalSessionValueNamed<K, T> {
    type Target = Option<T>;

    fn deref(&self) -> &Option<T> {
        &self.value
    }
}

// Hand-written so that `K` need not implement `Debug`; includes the key name.
impl<K: SessionKey, T: Debug> Debug for OptionalSessionValueNamed<K, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut rendered = f.debug_struct("OptionalSessionValueNamed");
        rendered.field("key", &K::KEY);
        rendered.field("value", &self.value);
        rendered.finish()
    }
}

// Hand-written so that `K` need not implement `Clone`.
impl<K: SessionKey, T: Clone> Clone for OptionalSessionValueNamed<K, T> {
    fn clone(&self) -> Self {
        Self::new(self.value.clone())
    }
}
/// Loads the value stored under `key` from the `SessionData` resolved through
/// `ctx`.
///
/// # Errors
///
/// * Any error from resolving `SessionData` is propagated as-is.
/// * `DiError::Authentication` when `SessionData::get` yields nothing for `key`.
///   NOTE(review): presumably this also covers a stored value that fails to
///   deserialize into `T` — confirm against `SessionData::get`.
async fn load_session_value_via_di<T>(ctx: &InjectionContext, key: &str) -> DiResult<T>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    let session = SessionData::inject(ctx).await?;
    let Some(stored) = session.get::<T>(key) else {
        return Err(DiError::Authentication(format!(
            "SessionValue<{}>: no value stored under session key '{}'",
            std::any::type_name::<T>(),
            key,
        )));
    };
    Ok(stored)
}
async fn load_session_value_via_request<T>(req: &Request, key: &str) -> ParamResult<T>
where
T: DeserializeOwned + Send + Sync + 'static,
{
let di_ctx = req.get_di_context::<InjectionContext>().ok_or_else(|| {
ParamError::Internal(
"SessionValue: DI context not available on the request. \
Ensure the router is configured with `.with_di_context()` and \
`SessionMiddleware` is installed in the middleware chain."
.to_string(),
)
})?;
load_session_value_via_di::<T>(&di_ctx, key)
.await
.map_err(di_error_to_param_error)
}
/// Maps a DI failure onto the parameter-extraction error space.
///
/// `DiError::Authentication` and `DiError::NotFound` both become
/// `ParamError::Authentication` and keep their original message. Every other
/// variant becomes `ParamError::Internal` carrying its `Display` text.
fn di_error_to_param_error(err: DiError) -> ParamError {
    let auth_message = match err {
        DiError::Authentication(msg) => msg,
        DiError::NotFound(msg) => msg,
        other => return ParamError::Internal(other.to_string()),
    };
    ParamError::Authentication(auth_message)
}
#[async_trait]
impl<T> Injectable for SessionValue<T>
where
T: DeserializeOwned + Send + Sync + 'static,
{
async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
load_session_value_via_di::<T>(ctx, USER_ID_SESSION_KEY)
.await
.map(SessionValue)
}
}
/// Resolves the value stored under `USER_ID_SESSION_KEY`, yielding `None`
/// when the entry is absent or `SessionData` reports `DiError::NotFound`.
/// Every other DI error is propagated.
#[async_trait]
impl<T> Injectable for OptionalSessionValue<T>
where
    T: DeserializeOwned + Send + Sync + 'static,
{
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        let maybe_value = match SessionData::inject(ctx).await {
            Ok(session) => session.get::<T>(USER_ID_SESSION_KEY),
            Err(DiError::NotFound(_)) => None,
            Err(other) => return Err(other),
        };
        Ok(OptionalSessionValue(maybe_value))
    }
}
/// Resolves the value stored under `K::KEY`; fails with
/// `DiError::Authentication` when nothing is stored there.
#[async_trait]
impl<K, T> Injectable for SessionValueNamed<K, T>
where
    K: SessionKey,
    T: DeserializeOwned + Send + Sync + 'static,
{
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        let named_value = load_session_value_via_di::<T>(ctx, K::KEY).await?;
        Ok(Self::new(named_value))
    }
}
/// Resolves the value stored under `K::KEY`, yielding `None` when the entry
/// is absent or `SessionData` reports `DiError::NotFound`. Every other DI
/// error is propagated.
#[async_trait]
impl<K, T> Injectable for OptionalSessionValueNamed<K, T>
where
    K: SessionKey,
    T: DeserializeOwned + Send + Sync + 'static,
{
    async fn inject(ctx: &InjectionContext) -> DiResult<Self> {
        let maybe_value = match SessionData::inject(ctx).await {
            Ok(session) => session.get::<T>(K::KEY),
            Err(DiError::NotFound(_)) => None,
            Err(other) => return Err(other),
        };
        Ok(Self::new(maybe_value))
    }
}
#[async_trait]
impl<T> FromRequest for SessionValue<T>
where
T: DeserializeOwned + Send + Sync + 'static,
{
async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
load_session_value_via_request::<T>(req, USER_ID_SESSION_KEY)
.await
.map(SessionValue)
}
}
#[async_trait]
impl<T> FromRequest for OptionalSessionValue<T>
where
T: DeserializeOwned + Send + Sync + 'static,
{
async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
let di_ctx = match req.get_di_context::<InjectionContext>() {
Some(c) => c,
None => return Ok(OptionalSessionValue(None)),
};
match SessionData::inject(&di_ctx).await {
Ok(session) => Ok(OptionalSessionValue(session.get::<T>(USER_ID_SESSION_KEY))),
Err(_) => Ok(OptionalSessionValue(None)),
}
}
}
/// Extracts the value stored under `K::KEY` from the request's DI context.
/// Errors are those of `load_session_value_via_request`.
#[async_trait]
impl<K, T> FromRequest for SessionValueNamed<K, T>
where
    K: SessionKey,
    T: DeserializeOwned + Send + Sync + 'static,
{
    async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
        let named_value = load_session_value_via_request::<T>(req, K::KEY).await?;
        Ok(Self::new(named_value))
    }
}
#[async_trait]
impl<K, T> FromRequest for OptionalSessionValueNamed<K, T>
where
K: SessionKey,
T: DeserializeOwned + Send + Sync + 'static,
{
async fn from_request(req: &Request, _ctx: &ParamContext) -> ParamResult<Self> {
let di_ctx = match req.get_di_context::<InjectionContext>() {
Some(c) => c,
None => return Ok(Self::new(None)),
};
match SessionData::inject(&di_ctx).await {
Ok(session) => Ok(Self::new(session.get::<T>(K::KEY))),
Err(_) => Ok(Self::new(None)),
}
}
}
#[cfg(test)]
mod tests {
    use super::super::test_support::TenantIdKey;
    use super::*;
    use rstest::rstest;

    // `UserIdKey` must address the same entry the unnamed extractors read.
    #[rstest]
    fn user_id_key_resolves_to_canonical_session_key() {
        assert_eq!(UserIdKey::KEY, USER_ID_SESSION_KEY);
    }

    #[rstest]
    fn session_value_named_constructor_and_deref_roundtrip() {
        let extractor = SessionValueNamed::<TenantIdKey, i64>::new(42);
        // Deref is checked before `into_inner` consumes the extractor.
        assert_eq!(*extractor, 42);
        assert_eq!(extractor.into_inner(), 42);
    }

    #[rstest]
    fn optional_session_value_named_constructor_and_deref_roundtrip_some() {
        let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(7));
        assert_eq!(*extractor, Some(7));
        assert_eq!(extractor.into_inner(), Some(7));
    }

    #[rstest]
    fn optional_session_value_named_constructor_and_deref_roundtrip_none() {
        let extractor = OptionalSessionValueNamed::<TenantIdKey, i64>::new(None);
        assert_eq!(*extractor, None);
        assert_eq!(extractor.into_inner(), None);
    }

    #[rstest]
    fn optional_session_value_named_debug_includes_key_name() {
        let rendered = format!(
            "{:?}",
            OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(99))
        );
        assert!(
            rendered.contains("OptionalSessionValueNamed"),
            "Debug output should name the struct, got {rendered:?}"
        );
        assert!(
            rendered.contains("tenant_id"),
            "Debug output should include the session key name, got {rendered:?}"
        );
    }

    #[rstest]
    fn optional_session_value_named_clone_preserves_inner_some() {
        let original = OptionalSessionValueNamed::<TenantIdKey, i64>::new(Some(123));
        let cloned = original.clone();
        // Both the clone and the source must still carry the value.
        assert_eq!((*cloned, *original), (Some(123), Some(123)));
    }

    #[rstest]
    fn di_error_authentication_maps_to_param_authentication() {
        let param_err = di_error_to_param_error(DiError::Authentication("nope".to_string()));
        match param_err {
            ParamError::Authentication(msg) => assert_eq!(msg, "nope"),
            other => panic!("expected ParamError::Authentication, got {other:?}"),
        }
    }

    #[rstest]
    fn di_error_not_found_maps_to_param_authentication() {
        let param_err = di_error_to_param_error(DiError::NotFound("missing session".to_string()));
        assert!(matches!(param_err, ParamError::Authentication(_)));
    }
}