use crate::auth::session::{FlashMessage, SessionData, SessionId};
use axum::{
extract::FromRequestParts,
http::{request::Parts, StatusCode},
};
use std::convert::Infallible;
/// Extractor yielding the current request's session id and session data.
///
/// Both values are read from the request extensions, where they are expected to
/// have been inserted by upstream session middleware (NOTE(review): the
/// middleware is not visible here — confirm it always runs before handlers
/// that use this extractor).
///
/// # Errors
///
/// Rejects with `500 Internal Server Error` when either extension is missing.
/// The [`SessionId`] is checked first, so a request lacking both yields
/// `"Session not initialized"`.
#[derive(Clone, Debug)]
pub struct SessionExtractor(pub SessionId, pub SessionData);

impl<S> FromRequestParts<S> for SessionExtractor
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let ext = &parts.extensions;

        // Missing session state means the middleware did not run: a server-side
        // misconfiguration, not a client error.
        let Some(id) = ext.get::<SessionId>() else {
            return Err((StatusCode::INTERNAL_SERVER_ERROR, "Session not initialized"));
        };
        let Some(data) = ext.get::<SessionData>() else {
            return Err((StatusCode::INTERNAL_SERVER_ERROR, "Session data not found"));
        };

        // Clone only after both lookups succeed; the extensions keep their copies.
        Ok(Self(id.clone(), data.clone()))
    }
}
/// Extractor that drains the pending flash messages from the request's session.
///
/// The messages are moved out of the [`SessionData`] stored in the request
/// extensions, which is left holding an empty list. Any extractor that reads
/// `SessionData` later in the same request therefore sees no flash messages.
/// NOTE(review): whether the emptied list is written back to the session store
/// depends on the session middleware, which is not visible here — confirm.
///
/// Never rejects. When no `SessionData` extension is present, the result is
/// an empty list.
#[derive(Clone, Debug, Default)]
pub struct FlashExtractor(pub Vec<FlashMessage>);

impl<S> FromRequestParts<S> for FlashExtractor
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let drained = match parts.extensions.get_mut::<SessionData>() {
            // Swap the stored list for an empty one and hand the old one to the handler.
            Some(session) => std::mem::take(&mut session.flash_messages),
            None => Vec::new(),
        };
        Ok(Self(drained))
    }
}
/// Extractor yielding the session id and data if both are present.
///
/// Unlike `SessionExtractor`, this never rejects. It produces `None` unless
/// both a [`SessionId`] and a [`SessionData`] extension are found on the
/// request.
#[derive(Clone, Debug)]
pub struct OptionalSession(pub Option<(SessionId, SessionData)>);

impl<S> FromRequestParts<S> for OptionalSession
where
    S: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let ext = &parts.extensions;

        // Pair the two borrowed lookups; either one missing collapses to `None`.
        let pair = ext
            .get::<SessionId>()
            .zip(ext.get::<SessionData>())
            .map(|(id, data)| (id.clone(), data.clone()));

        Ok(Self(pair))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A defaulted `FlashExtractor` carries no messages.
    #[test]
    fn test_flash_extractor_default() {
        let FlashExtractor(messages) = FlashExtractor::default();
        assert_eq!(messages.len(), 0);
    }

    /// An `OptionalSession` built with `None` reports no session.
    #[test]
    fn test_optional_session_default() {
        let absent = OptionalSession(None);
        assert!(matches!(absent, OptionalSession(None)));
    }
}