use std::time::Duration;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use http::HeaderValue;
use huskarl::core::crypto::cipher::{
AeadEncryptor, AeadSealer, AeadUnsealer, AeadV1Sealer, AeadV1Unsealer, BoxedAeadCipher,
CipherMatch,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::{
cookie::{
DEFAULT_COOKIE_MAX_AGE, cookie_attrs, encode_kid, get_cookie, get_kid_cookie,
kid_cookie_name,
},
session::{SessionDriver, SessionError, to_session_err},
session_state::{Session, SessionState},
};
/// Server-side persistence backend used by [`StoreBackedSessionStore`].
///
/// Session state lives in the implementor's storage. The client only ever holds a sealed
/// cookie containing the session key. Every method returns a `Send` future.
pub trait ExternalSessionStore: Send + Sync {
    /// Session value handed back to callers. It must expose its
    /// [`PersistedSessionState`] through [`PersistedSession`] so the cookie layer can read
    /// the session key.
    type SessionType: Session + PersistedSession + Send + Sync + 'static;
    /// Storage error type. The cookie layer converts it to `SessionError` for
    /// create/save/touch/delete. Loads return it unchanged.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Stores a newly established session.
    ///
    /// `persisted` carries a freshly generated key and the initial state derived from
    /// `completed`. The pointer cookie is issued for the key of the session *returned*
    /// here, so implementations should normally return a session whose key equals
    /// `persisted.session_key`.
    fn create(
        &self,
        persisted: PersistedSessionState,
        completed: &crate::grant::CompletedLogin,
    ) -> impl Future<Output = Result<Self::SessionType, Self::Error>> + Send;
    /// Looks up the session stored under `session_key`.
    ///
    /// Return `Ok(None)` when no session exists for that key.
    fn load(
        &self,
        session_key: Uuid,
    ) -> impl Future<Output = Result<Option<Self::SessionType>, Self::Error>> + Send;
    /// Persists `session`. The cookie layer does not re-issue any cookie afterwards.
    /// NOTE(review): presumably called after the session state was modified; confirm
    /// against the `SessionDriver` callers.
    fn save(
        &self,
        session: &Self::SessionType,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
    /// Records activity on `session`. The cookie layer does not re-issue any cookie
    /// afterwards. NOTE(review): presumably refreshes last-activity / expiry bookkeeping;
    /// confirm the intended semantics.
    fn touch(
        &self,
        session: &Self::SessionType,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
    /// Removes `session` from storage. Once this succeeds, the cookie layer expires the
    /// pointer and key-id cookies.
    fn delete(
        &self,
        session: &Self::SessionType,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
/// Minimal server-side session record: the lookup key plus the session state.
///
/// Because of `#[non_exhaustive]`, code outside this crate cannot use a struct literal.
/// It must construct the record through the generated `bon` builder instead. The serde
/// derives let stores serialize the record.
#[non_exhaustive]
#[derive(Clone, Serialize, Deserialize, bon::Builder)]
pub struct PersistedSessionState {
    /// Key the session is stored under. This is the value sealed into the pointer cookie.
    pub session_key: Uuid,
    /// The session's current state.
    pub state: SessionState,
}
/// A bare persisted record is itself a usable session. Reads and writes of the state go
/// straight to the embedded `state` field.
impl Session for PersistedSessionState {
    fn state(&self) -> &SessionState {
        let Self { state, .. } = self;
        state
    }

    fn set_state(&mut self, state: SessionState) {
        let Self { state: slot, .. } = self;
        *slot = state;
    }
}
/// Access to the [`PersistedSessionState`] embedded in a store's session type.
///
/// [`StoreBackedSessionStore`] uses this to read the session key of an arbitrary
/// [`ExternalSessionStore::SessionType`].
pub trait PersistedSession {
    /// Shared view of the persisted key and state.
    fn persisted(&self) -> &PersistedSessionState;
    /// Mutable view of the persisted key and state.
    fn persisted_mut(&mut self) -> &mut PersistedSessionState;
}
// The plain record is its own persisted view, so both accessors simply hand back `self`.
// This lets `PersistedSessionState` be used directly as an `ExternalSessionStore::SessionType`.
impl PersistedSession for PersistedSessionState {
    fn persisted(&self) -> &PersistedSessionState {
        self
    }
    fn persisted_mut(&mut self) -> &mut PersistedSessionState {
        self
    }
}
/// Generates the key for a new session.
///
/// `Uuid::now_v7` yields a UUIDv7, whose leading bits encode the creation time, so keys
/// sort roughly by age. That timestamp is not exposed in clear by this module: the key
/// reaches the client only inside the AEAD-sealed pointer cookie.
fn generate_session_key() -> Uuid {
    Uuid::now_v7()
}
/// Session driver that keeps session data in an [`ExternalSessionStore`] and gives the
/// client only a sealed "pointer" cookie holding the session key.
///
/// Next to the pointer cookie it maintains a key-id sidecar cookie (named via
/// `kid_cookie_name`). When a request comes in, the sidecar selects the matching key
/// for unsealing.
pub struct StoreBackedSessionStore<E> {
    /// Backing store that owns the session records.
    external: E,
    /// Seals the session key into the pointer cookie value.
    sealer: AeadV1Sealer<BoxedAeadCipher>,
    /// Opens pointer cookies found on incoming requests.
    unsealer: AeadV1Unsealer<BoxedAeadCipher>,
    /// Name of the pointer cookie; the key-id sidecar name is derived from it.
    cookie_name: String,
    /// Forwarded to `cookie::cookie_attrs` (presumably controls the `Secure` attribute).
    secure: bool,
    /// Cookie path forwarded to `cookie::cookie_attrs`.
    cookie_path: String,
    /// Lifetime emitted as `Max-Age` on newly issued cookies.
    max_age: Duration,
}
#[bon::bon]
impl<E: ExternalSessionStore> StoreBackedSessionStore<E> {
    /// Creates the store. Callers reach this through the generated `builder()`.
    ///
    /// * `external` – backing store that owns the session records.
    /// * `cipher` – AEAD cipher; both the sealer and the unsealer are built from it.
    /// * `cookie_name` – pointer cookie name; the key-id sidecar name is derived from it.
    /// * `secure`, `cookie_path` – forwarded to `cookie::cookie_attrs`.
    /// * `max_age` – `Max-Age` of issued cookies; defaults to [`DEFAULT_COOKIE_MAX_AGE`].
    #[builder]
    pub fn new(
        external: E,
        cipher: BoxedAeadCipher,
        #[builder(into)] cookie_name: String,
        secure: bool,
        #[builder(into)] cookie_path: String,
        #[builder(default = DEFAULT_COOKIE_MAX_AGE)]
        max_age: Duration,
    ) -> Self {
        let sealer = AeadV1Sealer::new(cipher.clone());
        let unsealer = AeadV1Unsealer::new(cipher);
        Self {
            external,
            sealer,
            unsealer,
            cookie_name,
            secure,
            cookie_path,
            max_age,
        }
    }

    /// Attribute string produced by the free `cookie::cookie_attrs` for this store's
    /// secure flag and path. Callers append `Max-Age` themselves.
    fn base_cookie_attrs(&self) -> String {
        cookie_attrs(self.secure, &self.cookie_path)
    }

    /// Base attributes followed by `Max-Age=<max_age in whole seconds>`.
    fn cookie_attrs(&self) -> String {
        let base = self.base_cookie_attrs();
        let lifetime_secs = self.max_age.as_secs();
        format!("{base}; Max-Age={lifetime_secs}")
    }

    /// Builds the two `Set-Cookie` values that point a client at `session_key`.
    ///
    /// The first is the pointer cookie: the key sealed with the `b"session_ptr"` context
    /// bytes, then base64url-encoded without padding. The second is the key-id sidecar,
    /// which is set when the sealer reports a key id and expired otherwise.
    ///
    /// # Errors
    /// Fails if sealing fails or a header value contains characters that are invalid in
    /// HTTP headers.
    async fn pointer_cookie_headers(
        &self,
        session_key: Uuid,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        let sealed = self
            .sealer
            .seal(session_key.as_bytes(), b"session_ptr")
            .await
            .map_err(to_session_err)?;
        let key_id = self.sealer.key_id();
        let set_pointer = format!(
            "{}={}; {}",
            self.cookie_name,
            URL_SAFE_NO_PAD.encode(&sealed),
            self.cookie_attrs()
        );
        let pointer_header = HeaderValue::from_str(&set_pointer).map_err(to_session_err)?;
        let sidecar_header = self.build_kid_header(key_id.as_deref())?;
        Ok(vec![pointer_header, sidecar_header])
    }

    /// Builds the key-id sidecar `Set-Cookie` value.
    ///
    /// With `Some(kid)` the sidecar carries the encoded kid and the normal `Max-Age`.
    /// With `None` it is emitted empty with `Max-Age=0`, which clears any stale sidecar.
    fn build_kid_header(&self, kid: Option<&str>) -> Result<HeaderValue, SessionError> {
        let sidecar_name = kid_cookie_name(&self.cookie_name);
        let header_text = if let Some(id) = kid {
            format!("{sidecar_name}={}; {}", encode_kid(id), self.cookie_attrs())
        } else {
            format!("{sidecar_name}=; {}; Max-Age=0", self.base_cookie_attrs())
        };
        HeaderValue::from_str(&header_text).map_err(to_session_err)
    }

    /// Recovers the session key from the request's pointer cookie.
    ///
    /// If a key-id sidecar is present, it narrows which key is used to unseal. Returns
    /// `None` in any of these cases:
    /// * the cookie is missing;
    /// * the value is not valid base64url;
    /// * unsealing fails;
    /// * the plaintext is not exactly 16 bytes.
    async fn read_pointer_cookie(&self, headers: &http::HeaderMap) -> Option<Uuid> {
        let raw_cookie = get_cookie(headers, &self.cookie_name)?;
        let sealed = URL_SAFE_NO_PAD.decode(raw_cookie).ok()?;
        let kid_hint = get_kid_cookie(headers, &self.cookie_name);
        let selector = kid_hint
            .as_deref()
            .map(|id| CipherMatch::builder().kid(id).build());
        let opened = self
            .unsealer
            .unseal(selector.as_ref(), &sealed, b"session_ptr")
            .await
            .ok()?;
        let key_bytes: [u8; 16] = opened.try_into().ok()?;
        Some(Uuid::from_bytes(key_bytes))
    }
}
impl<E: ExternalSessionStore> StoreBackedSessionStore<E> {
    /// Creates a session in the external store and returns it with the cookies that
    /// point the client at it.
    ///
    /// A fresh key from `generate_session_key` and the state derived from `completed`
    /// are passed to [`ExternalSessionStore::create`]. The pointer cookie is then sealed
    /// for the key of the session the store *returned*.
    ///
    /// # Errors
    /// Fails if the store rejects the creation or the pointer cookies cannot be built.
    pub(crate) async fn create_session(
        &self,
        completed: &crate::grant::CompletedLogin,
        default_lifetime: std::time::Duration,
    ) -> Result<(E::SessionType, Vec<HeaderValue>), SessionError> {
        let persisted = PersistedSessionState {
            session_key: generate_session_key(),
            state: SessionState::from_completed(completed, default_lifetime),
        };
        let session = self
            .external
            .create(persisted, completed)
            .await
            .map_err(to_session_err)?;
        // NOTE(review): if sealing fails here, the record just created stays in the store
        // with no cookie referencing it. If `create` always inserts a fresh record,
        // consider a best-effort compensating `delete`.
        let cookies = self
            .pointer_cookie_headers(session.persisted().session_key)
            .await?;
        Ok((session, cookies))
    }

    /// Loads the session referenced by the request's pointer cookie.
    ///
    /// Returns `Ok(None)` without consulting the store when the cookie is absent or
    /// cannot be decoded and unsealed. Otherwise it returns whatever
    /// [`ExternalSessionStore::load`] yields.
    pub(crate) async fn load_session(
        &self,
        headers: &http::HeaderMap,
    ) -> Result<Option<E::SessionType>, E::Error> {
        let Some(session_key) = self.read_pointer_cookie(headers).await else {
            return Ok(None);
        };
        self.external.load(session_key).await
    }

    /// Persists `session`. No cookies are emitted; the existing pointer is not re-issued.
    pub(crate) async fn save_session(
        &self,
        session: &E::SessionType,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        self.external.save(session).await.map_err(to_session_err)?;
        Ok(vec![])
    }

    /// Records activity on `session`. No cookies are emitted; the pointer is not re-issued.
    pub(crate) async fn touch_session(
        &self,
        session: &E::SessionType,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        self.external.touch(session).await.map_err(to_session_err)?;
        Ok(vec![])
    }

    /// Deletes `session` from the store and returns `Set-Cookie` values that expire
    /// both the pointer cookie and the key-id sidecar (empty value, `Max-Age=0`).
    ///
    /// # Errors
    /// Fails if the store rejects the deletion, in which case no cookies are cleared.
    /// Also fails if a clearing header cannot be built. Previously that header was
    /// silently omitted; now the error is surfaced, as in `pointer_cookie_headers`.
    pub(crate) async fn delete_session(
        &self,
        session: &E::SessionType,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        self.external
            .delete(session)
            .await
            .map_err(to_session_err)?;
        let clear_pointer = HeaderValue::from_str(&format!(
            "{}=; {}; Max-Age=0",
            self.cookie_name,
            self.base_cookie_attrs()
        ))
        .map_err(to_session_err)?;
        // `build_kid_header(None)` already owns the sidecar-expiry format. Reusing it keeps
        // the issue and clear paths from drifting apart.
        let clear_kid = self.build_kid_header(None)?;
        Ok(vec![clear_pointer, clear_kid])
    }
}
// Marker impl for the crate's `session::sealed::Sealed` trait. NOTE(review): presumably
// this is the sealed-trait guard that restricts `SessionDriver` impls to this crate; confirm.
impl<E: ExternalSessionStore> crate::session::sealed::Sealed for StoreBackedSessionStore<E> {}
/// [`SessionDriver`] adapter that forwards each operation to the matching inherent
/// `*_session` method. Request headers are consulted only by `load`, which reads the
/// pointer cookie from them.
impl<E: ExternalSessionStore> SessionDriver for StoreBackedSessionStore<E> {
    type SessionType = E::SessionType;
    type LoadError = E::Error;

    async fn create(
        &self,
        completed: crate::grant::CompletedLogin,
        default_lifetime: std::time::Duration,
        _request_headers: &http::HeaderMap,
    ) -> Result<(E::SessionType, Vec<HeaderValue>), SessionError> {
        Self::create_session(self, &completed, default_lifetime).await
    }

    async fn load(&self, headers: &http::HeaderMap) -> Result<Option<E::SessionType>, E::Error> {
        Self::load_session(self, headers).await
    }

    async fn save(
        &self,
        session: &E::SessionType,
        _request_headers: &http::HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        Self::save_session(self, session).await
    }

    async fn touch(
        &self,
        session: &E::SessionType,
        _request_headers: &http::HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        Self::touch_session(self, session).await
    }

    async fn delete(
        &self,
        session: &E::SessionType,
        _request_headers: &http::HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        Self::delete_session(self, session).await
    }
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use huskarl::core::{
        crypto::cipher::BoxedAeadCipher,
        secrets::{Secret, SecretBytes, SecretOutput},
    };
    use huskarl_crypto_native::aead::{AesGcmKey, AesGcmKeyType};
    use super::*;
    use crate::session_state::{Session, SessionState};

    /// In-memory secret so the tests need no external secret backend.
    #[derive(Clone)]
    struct TestSecret(SecretBytes);
    impl Secret for TestSecret {
        type Output = SecretBytes;
        type Error = Infallible;
        async fn get_secret_value(&self) -> Result<SecretOutput<SecretBytes>, Infallible> {
            Ok(SecretOutput {
                value: self.0.clone(),
                identity: None,
            })
        }
    }

    /// AES-256-GCM cipher over an all-zero key. If `kid` is given, the cipher reports it
    /// as its key identity.
    async fn test_cipher(kid: Option<&str>) -> BoxedAeadCipher {
        let kid_owned = kid.map(str::to_owned);
        let key = AesGcmKey::from_secret(
            AesGcmKeyType::Aes256,
            TestSecret(SecretBytes::new(vec![0u8; 32])),
            move |_| kid_owned.clone(),
        )
        .await
        .unwrap();
        BoxedAeadCipher::new(key)
    }

    #[derive(Clone)]
    struct MinimalSession {
        persisted: PersistedSessionState,
    }
    impl Session for MinimalSession {
        fn state(&self) -> &SessionState {
            self.persisted.state()
        }
        fn set_state(&mut self, s: SessionState) {
            self.persisted.set_state(s);
        }
    }
    impl PersistedSession for MinimalSession {
        fn persisted(&self) -> &PersistedSessionState {
            &self.persisted
        }
        fn persisted_mut(&mut self) -> &mut PersistedSessionState {
            &mut self.persisted
        }
    }

    /// Stub store that always hands back the one session it was built with.
    struct MinimalExternalStore(MinimalSession);
    impl ExternalSessionStore for MinimalExternalStore {
        type SessionType = MinimalSession;
        type Error = Infallible;
        async fn create(
            &self,
            _: PersistedSessionState,
            _: &crate::grant::CompletedLogin,
        ) -> Result<MinimalSession, Infallible> {
            Ok(self.0.clone())
        }
        async fn load(&self, _: Uuid) -> Result<Option<MinimalSession>, Infallible> {
            Ok(Some(self.0.clone()))
        }
        async fn save(&self, _: &MinimalSession) -> Result<(), Infallible> {
            Ok(())
        }
        async fn touch(&self, _: &MinimalSession) -> Result<(), Infallible> {
            Ok(())
        }
        async fn delete(&self, _: &MinimalSession) -> Result<(), Infallible> {
            Ok(())
        }
    }

    fn test_session() -> MinimalSession {
        let now = std::time::SystemTime::now();
        MinimalSession {
            persisted: PersistedSessionState {
                session_key: Uuid::now_v7(),
                state: SessionState::builder()
                    .token_expiry(now + std::time::Duration::from_hours(1))
                    .created_at(now)
                    .last_active(now)
                    .build(),
            },
        }
    }

    /// Store named "session" (secure, path "/") backed by a stub returning `session`.
    fn test_store(
        session: &MinimalSession,
        cipher: BoxedAeadCipher,
    ) -> StoreBackedSessionStore<MinimalExternalStore> {
        StoreBackedSessionStore::builder()
            .external(MinimalExternalStore(session.clone()))
            .cipher(cipher)
            .cookie_name("session")
            .secure(true)
            .cookie_path("/")
            .build()
    }

    /// Value of the first `Set-Cookie` header that sets cookie `name` to a non-empty value.
    fn set_cookie_value(headers: &[HeaderValue], name: &str) -> Option<String> {
        headers.iter().find_map(|h| {
            let (n, v) = h.to_str().ok()?.split(';').next()?.split_once('=')?;
            (n.trim() == name && !v.is_empty()).then(|| v.to_owned())
        })
    }

    /// Request headers carrying `cookie` as the `Cookie` header.
    fn request_with_cookie(cookie: &str) -> http::HeaderMap {
        let mut req_headers = http::HeaderMap::new();
        req_headers.insert(http::header::COOKIE, cookie.parse().unwrap());
        req_headers
    }

    #[tokio::test]
    async fn touch_returns_no_cookies() {
        let session = test_session();
        let store = test_store(&session, test_cipher(None).await);
        let headers = store.touch_session(&session).await.unwrap();
        assert!(
            headers.is_empty(),
            "touch should not re-emit the pointer cookie"
        );
    }

    #[tokio::test]
    async fn save_returns_no_cookies() {
        let session = test_session();
        let store = test_store(&session, test_cipher(None).await);
        let headers = store.save_session(&session).await.unwrap();
        assert!(
            headers.is_empty(),
            "save should not re-emit the pointer cookie"
        );
    }

    #[tokio::test]
    async fn pointer_cookie_roundtrips_uuid() {
        let session = test_session();
        let original_key = session.persisted.session_key;
        let store = test_store(&session, test_cipher(None).await);
        let headers_out = store.pointer_cookie_headers(original_key).await.unwrap();
        let cookie_value =
            set_cookie_value(&headers_out, "session").expect("pointer cookie present");
        let req_headers = request_with_cookie(&format!("session={cookie_value}"));
        let recovered = store
            .read_pointer_cookie(&req_headers)
            .await
            .expect("decodes");
        assert_eq!(recovered, original_key);
    }

    #[tokio::test]
    async fn pointer_cookie_roundtrips_with_kid_sidecar() {
        let session = test_session();
        let original_key = session.persisted.session_key;
        let store = test_store(&session, test_cipher(Some("kid-7")).await);
        let headers_out = store.pointer_cookie_headers(original_key).await.unwrap();
        let pointer = set_cookie_value(&headers_out, "session").expect("pointer cookie present");
        let kid = set_cookie_value(&headers_out, "session.kid").expect("kid sidecar present");
        let req_headers = request_with_cookie(&format!("session={pointer}; session.kid={kid}"));
        let recovered = store
            .read_pointer_cookie(&req_headers)
            .await
            .expect("decodes with kid hint");
        assert_eq!(recovered, original_key);
    }

    #[tokio::test]
    async fn tampered_pointer_cookie_is_rejected() {
        let session = test_session();
        let store = test_store(&session, test_cipher(None).await);
        let headers_out = store
            .pointer_cookie_headers(session.persisted.session_key)
            .await
            .unwrap();
        let cookie_value =
            set_cookie_value(&headers_out, "session").expect("pointer cookie present");
        let mut sealed = URL_SAFE_NO_PAD.decode(cookie_value.as_str()).unwrap();
        // Flipping a single bit must break AEAD authentication.
        if let Some(last) = sealed.last_mut() {
            *last ^= 0x01;
        }
        let tampered = URL_SAFE_NO_PAD.encode(&sealed);
        let req_headers = request_with_cookie(&format!("session={tampered}"));
        assert!(
            store.read_pointer_cookie(&req_headers).await.is_none(),
            "tampered pointer cookie must not decode"
        );
    }

    #[tokio::test]
    async fn pointer_cookie_emits_kid_sidecar_when_cipher_has_identity() {
        let session = test_session();
        let store = test_store(&session, test_cipher(Some("kid-7")).await);
        let headers_out = store
            .pointer_cookie_headers(session.persisted.session_key)
            .await
            .unwrap();
        let expected_value = URL_SAFE_NO_PAD.encode("kid-7".as_bytes());
        let sidecar_set = headers_out.iter().any(|h| {
            let s = h.to_str().unwrap();
            s.starts_with(&format!("session.kid={expected_value};"))
        });
        assert!(
            sidecar_set,
            "expected kid sidecar set to base64url(identity)"
        );
    }

    #[tokio::test]
    async fn delete_clears_pointer_and_kid_sidecar() {
        let session = test_session();
        let store = test_store(&session, test_cipher(None).await);
        let clears = store.delete_session(&session).await.unwrap();
        let bare = clears.iter().any(|h| {
            let s = h.to_str().unwrap();
            s.starts_with("session=;") && s.contains("Max-Age=0")
        });
        let kid = clears.iter().any(|h| {
            let s = h.to_str().unwrap();
            s.starts_with("session.kid=;") && s.contains("Max-Age=0")
        });
        assert!(bare, "expected pointer cookie clear");
        assert!(kid, "expected kid sidecar clear");
    }
}