use std::convert::Infallible;
use std::sync::Arc;
use axum_core::extract::{FromRequestParts, OptionalFromRequestParts};
use http::request::Parts;
use tokio::sync::Mutex;
use crate::error::SessionRejection;
use crate::state::SessionState;
/// Per-request handle to a session payload of type `T`.
///
/// The handle wraps shared, lock-protected `SessionState<T>`. Cloning a handle,
/// or extracting it more than once from the same request, produces another
/// reference to the same state. Changes made through one handle are therefore
/// visible through all of them.
///
/// Used as an axum extractor, `Session<T>` rejects with
/// `SessionRejection::NotMounted` when no session state for this `T` is in the
/// request extensions. `Option<Session<T>>` yields `None` in that case instead.
pub struct Session<T>(Arc<Mutex<SessionState<T>>>);
impl<T> Clone for Session<T> {
fn clone(&self) -> Self {
Self(Arc::clone(&self.0))
}
}
// Formats only the address of the shared state, never the payload. This is
// why no `T: Debug` bound is needed and no payload data can leak into logs.
// It also never touches the mutex, so formatting works even while another
// guard is held. Handles that share state print the same address.
impl<T> std::fmt::Debug for Session<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_ptr = Arc::as_ptr(&self.0);
        let mut builder = f.debug_struct("Session");
        builder.field("state", &state_ptr);
        builder.finish()
    }
}
// Every accessor awaits the async mutex, so contention yields to the runtime
// rather than blocking the thread. All methods except `get` set the state's
// `mutated` flag. NOTE(review): that flag is presumably consulted by the session
// layer to decide whether the cookie must be re-emitted; confirm.
impl<T> Session<T> {
    /// Returns a clone of the stored payload, or `None` when the session is
    /// empty.
    ///
    /// Read-only: the `mutated` flag is left untouched.
    pub async fn get(&self) -> Option<T>
    where
        T: Clone,
    {
        let state = self.0.lock().await;
        state.payload.clone()
    }

    /// Stores `value` as the payload, replacing and dropping any previous one,
    /// then marks the session as mutated.
    pub async fn insert(&self, value: T) {
        let mut state = self.0.lock().await;
        state.payload = Some(value);
        state.mutated = true;
    }

    /// Removes and returns the payload, leaving the session empty.
    ///
    /// The session is marked as mutated even if it was already empty.
    pub async fn take(&self) -> Option<T> {
        let mut state = self.0.lock().await;
        state.mutated = true;
        state.payload.take()
    }

    /// Drops any stored payload, leaving the session empty.
    ///
    /// The session is marked as mutated unconditionally, including when it was
    /// already empty.
    pub async fn clear(&self) {
        let mut state = self.0.lock().await;
        state.mutated = true;
        state.payload = None;
    }

    /// Runs `f` with mutable access to the payload while the lock is held and
    /// returns whatever `f` returns.
    ///
    /// `f` is a plain, non-async closure, so it runs to completion without
    /// yielding while the lock is held. The session is marked as mutated after
    /// `f` returns, whether or not `f` actually changed anything. If `f`
    /// panics, the flag is not set.
    pub async fn modify<R>(&self, f: impl FnOnce(&mut Option<T>) -> R) -> R {
        let mut state = self.0.lock().await;
        let outcome = f(&mut state.payload);
        state.mutated = true;
        outcome
    }
}
impl<S, T> FromRequestParts<S> for Session<T>
where
S: Send + Sync,
T: Send + 'static,
{
type Rejection = SessionRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
parts
.extensions
.get::<Arc<Mutex<SessionState<T>>>>()
.cloned()
.map(Session)
.ok_or(SessionRejection::NotMounted)
}
}
impl<S, T> OptionalFromRequestParts<S> for Session<T>
where
S: Send + Sync,
T: Send + 'static,
{
type Rejection = Infallible;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Option<Self>, Self::Rejection> {
Ok(parts
.extensions
.get::<Arc<Mutex<SessionState<T>>>>()
.cloned()
.map(Session))
}
}
// Unit tests for `Session` and its two extractor impls. Session state is built
// directly, or planted into request extensions by hand, so no middleware stack
// or HTTP server is involved. NOTE(review): the `seshcookie_rs_ac5_*` name
// suffixes presumably reference acceptance criteria in an external spec; confirm.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::SystemTime;
    use http::Request;

    // Minimal payload type used as `T` throughout these tests.
    #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
    struct UserPayload {
        id: u64,
        name: String,
    }

    // Canonical payload. Its `name` ("alice") also acts as the sentinel that
    // `debug_does_not_lock_mutex` checks never appears in Debug output.
    fn sample_payload() -> UserPayload {
        UserPayload {
            id: 42,
            name: "alice".into(),
        }
    }

    // Standalone handle around fresh state created at the Unix epoch. The
    // payload is assigned directly rather than through `insert`, so `mutated`
    // keeps whatever value `SessionState::new_empty` gives it.
    fn empty_session_with_payload(payload: Option<UserPayload>) -> Session<UserPayload> {
        let now = SystemTime::UNIX_EPOCH;
        let mut state = SessionState::new_empty(now);
        state.payload = payload;
        Session(Arc::new(Mutex::new(state)))
    }

    // Request parts whose extensions hold the shared state, which is exactly
    // what the extractors look up. Also returns the caller's own reference to
    // that same `Arc`.
    fn parts_with_state(
        payload: Option<UserPayload>,
    ) -> (Parts, Arc<Mutex<SessionState<UserPayload>>>) {
        let now = SystemTime::UNIX_EPOCH;
        let mut state = SessionState::new_empty(now);
        state.payload = payload;
        let arc = Arc::new(Mutex::new(state));
        let req = Request::builder().body(()).unwrap();
        let (mut parts, ()) = req.into_parts();
        parts.extensions.insert(Arc::clone(&arc));
        (parts, arc)
    }

    // `get` returns a clone of a pre-loaded payload.
    #[tokio::test]
    async fn session_get_is_async() {
        let payload = sample_payload();
        let session = empty_session_with_payload(Some(payload.clone()));
        assert_eq!(session.get().await, Some(payload));
    }

    // `insert` stores the value and raises the mutated flag.
    #[tokio::test]
    async fn session_insert_is_async() {
        let session = empty_session_with_payload(None);
        let payload = sample_payload();
        session.insert(payload.clone()).await;
        assert_eq!(session.get().await, Some(payload));
        assert!(session.0.lock().await.mutated, "insert must set mutated");
    }

    // `take` hands back the payload, empties the session, and raises the flag.
    #[tokio::test]
    async fn session_take_is_async() {
        let payload = sample_payload();
        let session = empty_session_with_payload(Some(payload.clone()));
        assert_eq!(session.take().await, Some(payload));
        assert_eq!(session.get().await, None);
        assert!(session.0.lock().await.mutated, "take must set mutated");
    }

    // `clear` empties the session and raises the flag.
    #[tokio::test]
    async fn session_clear_is_async() {
        let session = empty_session_with_payload(Some(sample_payload()));
        session.clear().await;
        assert_eq!(session.get().await, None);
        assert!(session.0.lock().await.mutated, "clear must set mutated");
    }

    // `modify` sees the existing payload, returns the closure's result,
    // persists the in-place change, and raises the flag.
    #[tokio::test]
    async fn session_modify_is_async() {
        let payload = sample_payload();
        let session = empty_session_with_payload(Some(payload.clone()));
        let observed_id: u64 = session
            .modify(|p| {
                let user = p.as_mut().expect("modify sees the prior payload");
                user.id += 1;
                user.id
            })
            .await;
        assert_eq!(observed_id, payload.id + 1);
        assert_eq!(
            session.get().await.map(|u| u.id),
            Some(payload.id + 1),
            "modify must persist its mutation"
        );
        assert!(session.0.lock().await.mutated, "modify must set mutated");
    }

    // Required extraction succeeds when the state is in the extensions.
    #[tokio::test]
    async fn from_request_parts_returns_session_when_extension_present_seshcookie_rs_ac5_1() {
        let payload = sample_payload();
        let (mut parts, _arc) = parts_with_state(Some(payload.clone()));
        let session =
            <Session<UserPayload> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("extension is present, extraction must succeed");
        assert_eq!(session.get().await, Some(payload));
    }

    // Required extraction rejects with `NotMounted` when the state is absent.
    #[tokio::test]
    async fn from_request_parts_returns_not_mounted_when_extension_missing_seshcookie_rs_ac5_1() {
        let req = Request::builder().body(()).unwrap();
        let (mut parts, ()) = req.into_parts();
        let result =
            <Session<UserPayload> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await;
        assert_eq!(result.unwrap_err(), SessionRejection::NotMounted);
    }

    // Optional extraction yields `None`, not an error, when the state is absent.
    #[tokio::test]
    async fn optional_from_request_parts_returns_none_when_extension_missing_seshcookie_rs_ac5_3() {
        let req = Request::builder().body(()).unwrap();
        let (mut parts, ()) = req.into_parts();
        let result = <Session<UserPayload> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await;
        match result {
            Ok(None) => {}
            Ok(Some(_)) => panic!("missing extension must yield None"),
            // The empty match proves at compile time that the error type is
            // uninhabited.
            Err(infallible) => match infallible {},
        }
    }

    // Optional extraction yields `Some` handle when the state is present.
    #[tokio::test]
    async fn optional_from_request_parts_returns_some_when_extension_present_seshcookie_rs_ac5_3() {
        let payload = sample_payload();
        let (mut parts, _arc) = parts_with_state(Some(payload.clone()));
        let result = <Session<UserPayload> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .expect("Infallible cannot be constructed");
        let session = result.expect("extension was inserted, must yield Some");
        assert_eq!(session.get().await, Some(payload));
    }

    // Two extractions from the same request share one state: writes through
    // either handle are visible through the other.
    #[tokio::test]
    async fn two_extractions_share_underlying_state_seshcookie_rs_ac5_5() {
        let (mut parts, _arc) = parts_with_state(None);
        let s1 =
            <Session<UserPayload> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("first extraction must succeed");
        let s2 =
            <Session<UserPayload> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("second extraction must succeed");
        let payload = sample_payload();
        s1.insert(payload.clone()).await;
        assert_eq!(
            s2.get().await,
            Some(payload),
            "second handle must see s1's insert"
        );
        s2.take().await;
        assert_eq!(s1.get().await, None, "first handle must see s2's take");
    }

    // `Clone` produces another handle to the same state, not a copy of it.
    #[tokio::test]
    async fn clone_produces_shared_state_handle_seshcookie_rs_ac5_5() {
        let s1 = empty_session_with_payload(None);
        let s2 = s1.clone();
        let payload = sample_payload();
        s1.insert(payload.clone()).await;
        assert_eq!(s2.get().await, Some(payload));
    }

    // Formatting happens while this test holds the lock. If `Debug` tried to
    // acquire the mutex, it could not succeed here. The output must also never
    // contain payload data.
    #[tokio::test]
    async fn debug_does_not_lock_mutex() {
        let session = empty_session_with_payload(Some(sample_payload()));
        let guard = session.0.lock().await;
        let rendered = format!("{session:?}");
        assert!(
            rendered.contains("Session"),
            "Debug must name the type: got {rendered:?}"
        );
        assert!(
            !rendered.contains("alice"),
            "payload must never appear in Debug output: got {rendered:?}"
        );
        drop(guard);
    }
}