use std::{marker::PhantomData, time::Duration};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use http::HeaderValue;
use huskarl::core::crypto::cipher::{
AeadEncryptor, AeadSealer, AeadUnsealer, AeadV1Sealer, AeadV1Unsealer, BoxedAeadCipher,
CipherMatch,
};
use serde::{Deserialize, Serialize};
use crate::{
cookie::{
DEFAULT_COOKIE_MAX_AGE, cookie_attrs, decode_payload, encode_kid, encode_payload,
get_kid_cookie, kid_cookie_name,
},
grant::CompletedLogin,
session::{SessionDriver, SessionError, to_session_err},
session_state::{Session, SessionState},
};
/// Maximum number of base64url characters placed in a single chunk cookie's value.
///
/// NOTE(review): presumably chosen to keep each cookie (name, value and attributes) below the
/// roughly 4 KiB per-cookie limit browsers commonly enforce — confirm.
const CHUNK_SIZE: usize = 3800;
/// Session payload types that can be carried inside the session cookies.
///
/// An implementor must expose its [`SessionState`] through [`Session`], be serializable in both
/// directions, and be constructible from a completed login.
pub trait CookieData:
Session + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static
{
/// Error produced when a session value cannot be built from a completed login.
type Error: std::error::Error + Send + Sync + 'static;
/// Builds the session value from its initial `state` and the `completed` login.
///
/// # Errors
///
/// Returns [`Self::Error`] when the implementor cannot derive its data from `completed`.
fn from_login(state: SessionState, completed: &CompletedLogin) -> Result<Self, Self::Error>;
}
/// Default cookie payload: a bare [`SessionState`] with no additional application data.
///
/// `#[serde(transparent)]` makes it serialize exactly like the wrapped state.
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct CookieSession(SessionState);
impl Session for CookieSession {
    /// Borrows the wrapped session state.
    fn state(&self) -> &SessionState {
        let Self(inner) = self;
        inner
    }

    /// Replaces the wrapped session state with `state`.
    fn set_state(&mut self, state: SessionState) {
        let Self(slot) = self;
        *slot = state;
    }
}
impl CookieData for CookieSession {
    /// Wrapping a state cannot fail.
    type Error = std::convert::Infallible;

    /// Wraps `state` unchanged; nothing from the completed login is stored by this type.
    fn from_login(state: SessionState, _completed: &CompletedLogin) -> Result<Self, Self::Error> {
        Ok(Self(state))
    }
}
/// Session store that keeps the whole sealed session in the client's cookies.
///
/// The serialized session is sealed, base64url-encoded and split across cookies named
/// `{cookie_name}.{index}`. `C` is the payload type and defaults to [`CookieSession`].
pub struct CookieSessionStore<C = CookieSession> {
/// Seals outgoing session payloads.
sealer: AeadV1Sealer<BoxedAeadCipher>,
/// Unseals payloads reassembled from request cookies.
unsealer: AeadV1Unsealer<BoxedAeadCipher>,
/// Base cookie name; chunk cookies append `.{index}` to it.
cookie_name: String,
/// Forwarded to the shared cookie-attribute helper together with `cookie_path`.
secure: bool,
/// Path passed to the shared cookie-attribute helper.
cookie_path: String,
/// Lifetime written (in whole seconds) as the `Max-Age` of live cookies.
max_age: Duration,
/// Ties the store to its payload type without storing a value of it.
_phantom: PhantomData<C>,
}
#[bon::bon]
impl<C> CookieSessionStore<C> {
    /// Creates a store that seals and unseals sessions with `cipher`.
    ///
    /// `cookie_name` is the base name of the emitted cookies, `secure` and `cookie_path` are
    /// forwarded to the shared cookie-attribute helper, and `max_age` (defaulting to
    /// [`DEFAULT_COOKIE_MAX_AGE`]) becomes the `Max-Age` of live cookies.
    #[builder]
    pub fn new(
        cipher: BoxedAeadCipher,
        #[builder(into)] cookie_name: String,
        secure: bool,
        #[builder(into)] cookie_path: String,
        #[builder(default = DEFAULT_COOKIE_MAX_AGE)] max_age: Duration,
    ) -> Self {
        // The sealing and unsealing halves share the same cipher.
        let sealer = AeadV1Sealer::new(cipher.clone());
        let unsealer = AeadV1Unsealer::new(cipher);
        Self {
            sealer,
            unsealer,
            cookie_name,
            secure,
            cookie_path,
            max_age,
            _phantom: PhantomData,
        }
    }

    /// Attributes shared by every cookie this store emits, without any `Max-Age`.
    fn base_cookie_attrs(&self) -> String {
        cookie_attrs(self.secure, &self.cookie_path)
    }

    /// Attributes for a live cookie: the base attributes plus the configured `Max-Age`.
    fn cookie_attrs(&self) -> String {
        let base = self.base_cookie_attrs();
        let max_age_secs = self.max_age.as_secs();
        format!("{base}; Max-Age={max_age_secs}")
    }
}
impl<C: CookieData> CookieSessionStore<C> {
/// Rebuilds the session carried by the request's cookies.
///
/// Returns `None` when chunk 0 is absent, the reassembled value is not valid base64url,
/// unsealing fails, or the plaintext does not decode into `C`.
pub(crate) async fn load_session(&self, headers: &http::HeaderMap) -> Option<C> {
    let encoded = reassemble_chunks(&self.collect_session_chunks(headers))?;
    let sealed = URL_SAFE_NO_PAD.decode(&encoded).ok()?;

    // The kid sidecar is optional; when present it is passed to the unsealer as a key hint.
    let kid = get_kid_cookie(headers, &self.cookie_name);
    let key_hint = kid
        .as_deref()
        .map(|id| CipherMatch::builder().kid(id).build());

    let opened = self
        .unsealer
        .unseal(key_hint.as_ref(), &sealed, b"session")
        .await;
    match opened {
        Ok(plaintext) => decode_payload(&plaintext).ok(),
        Err(_) => None,
    }
}
/// Gathers every chunk cookie in the request, keyed by chunk index.
///
/// `Cookie` headers that are not valid visible ASCII are skipped. When an index appears more
/// than once, the last occurrence wins.
fn collect_session_chunks(
    &self,
    headers: &http::HeaderMap,
) -> std::collections::HashMap<usize, String> {
    headers
        .get_all(http::header::COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|line| line.split(';'))
        .filter_map(|pair| self.parse_chunk_pair(pair))
        .collect()
}
/// Parses one `name=value` cookie pair into `(chunk index, value)`.
///
/// Returns `None` when the pair has no `=` or the name is not one of this store's chunk cookies.
fn parse_chunk_pair(&self, pair: &str) -> Option<(usize, String)> {
    let (name, value) = pair.trim().split_once('=')?;
    let index = self.parse_chunk_index(name)?;
    Some((index, value.trim().to_owned()))
}
/// Extracts the chunk index from a cookie name of the form `{cookie_name}.{index}`.
///
/// Returns `None` for any other cookie, including the base name without a suffix and
/// non-numeric suffixes such as `.kid`.
///
/// Only the canonical decimal spelling is accepted. `usize::from_str` also parses `+3` and
/// `003`, but this store never emits such names, and clear headers are rebuilt from the parsed
/// index. Accepting them would alias a foreign cookie onto a real chunk slot and produce a
/// clear for a cookie name the client never sent.
fn parse_chunk_index(&self, name: &str) -> Option<usize> {
    let digits = name
        .trim()
        .strip_prefix(self.cookie_name.as_str())?
        .strip_prefix('.')?;
    let is_canonical = digits.bytes().all(|b| b.is_ascii_digit())
        && (digits.len() == 1 || !digits.starts_with('0'));
    if !is_canonical {
        return None;
    }
    // An empty suffix, or one that overflows `usize`, fails to parse and is rejected here.
    digits.parse().ok()
}
/// Calls `f` with the index of every chunk cookie found in the request, in header order.
///
/// Indices are not de-duplicated: a cookie name that appears twice is reported twice.
fn for_each_request_chunk_index(&self, headers: &http::HeaderMap, mut f: impl FnMut(usize)) {
    let indices = headers
        .get_all(http::header::COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|line| line.split(';'))
        .filter_map(|pair| pair.trim().split_once('='))
        .filter_map(|(name, _value)| self.parse_chunk_index(name));
    for idx in indices {
        f(idx);
    }
}
/// Seals `session` and returns the `Set-Cookie` values that store it on the client.
///
/// The result holds one header per chunk, then clears for any chunk cookies in
/// `request_headers` whose index is no longer used, then the kid sidecar (set or cleared).
///
/// # Errors
///
/// Fails when the payload cannot be encoded or sealed, or when a header value is invalid.
pub(crate) async fn save_session(
    &self,
    session: &C,
    request_headers: &http::HeaderMap,
) -> Result<Vec<HeaderValue>, SessionError> {
    let plaintext = encode_payload(session)?;
    let sealed = self
        .sealer
        .seal(&plaintext, b"session")
        .await
        .map_err(to_session_err)?;
    let kid = self.sealer.key_id();

    let encoded = URL_SAFE_NO_PAD.encode(&sealed);
    let pieces = split_into_chunks(&encoded);
    let live_attrs = self.cookie_attrs();

    // Room for every chunk plus the kid sidecar and a little slack for clears.
    let mut set_cookies = Vec::with_capacity(pieces.len() + 2);
    for (index, piece) in pieces.iter().enumerate() {
        set_cookies.push(self.build_chunk_header(index, piece, &live_attrs)?);
    }
    self.append_clears_for_leftover_chunks(&mut set_cookies, pieces.len(), request_headers);
    set_cookies.push(self.build_kid_header(kid.as_deref())?);
    Ok(set_cookies)
}
/// Builds the `Set-Cookie` value for chunk `i`: `{cookie_name}.{i}={chunk}; {attrs}`.
///
/// # Errors
///
/// Fails when the assembled line is not a valid header value.
fn build_chunk_header(
    &self,
    i: usize,
    chunk: &str,
    attrs: &str,
) -> Result<HeaderValue, SessionError> {
    let name = &self.cookie_name;
    let line = format!("{name}.{i}={chunk}; {attrs}");
    HeaderValue::from_str(&line).map_err(to_session_err)
}
/// Builds the `Set-Cookie` value for the kid sidecar cookie.
///
/// With a key id the sidecar is set to its encoded form using the live attributes; without
/// one the sidecar is cleared with `Max-Age=0`.
///
/// # Errors
///
/// Fails when the assembled line is not a valid header value.
fn build_kid_header(&self, kid: Option<&str>) -> Result<HeaderValue, SessionError> {
    let name = kid_cookie_name(&self.cookie_name);
    let line = if let Some(id) = kid {
        let encoded = encode_kid(id);
        let attrs = self.cookie_attrs();
        format!("{name}={encoded}; {attrs}")
    } else {
        let attrs = self.base_cookie_attrs();
        format!("{name}=; {attrs}; Max-Age=0")
    };
    HeaderValue::from_str(&line).map_err(to_session_err)
}
/// Appends a clearing `Set-Cookie` for every request chunk cookie whose index is at or above
/// `num_chunks`, i.e. slots the freshly written session no longer occupies.
///
/// A clear whose line is not a valid header value is silently skipped.
fn append_clears_for_leftover_chunks(
    &self,
    headers: &mut Vec<HeaderValue>,
    num_chunks: usize,
    request_headers: &http::HeaderMap,
) {
    let base = self.base_cookie_attrs();
    let name = &self.cookie_name;
    self.for_each_request_chunk_index(request_headers, |idx| {
        // Slots below `num_chunks` were just overwritten with fresh data.
        if idx < num_chunks {
            return;
        }
        let line = format!("{name}.{idx}=; {base}; Max-Age=0");
        if let Ok(value) = HeaderValue::from_str(&line) {
            headers.push(value);
        }
    });
}
/// Returns the `Set-Cookie` values that remove the session from the client.
///
/// The kid sidecar clear comes first, followed by one clear per chunk cookie present in
/// `request_headers`. Lines that are not valid header values are skipped.
pub(crate) fn delete_headers(&self, request_headers: &http::HeaderMap) -> Vec<HeaderValue> {
    let expire = format!("{}; Max-Age=0", self.base_cookie_attrs());
    let base_name = &self.cookie_name;

    let kid_clear = format!("{}=; {expire}", kid_cookie_name(base_name));
    let mut cleared: Vec<HeaderValue> = HeaderValue::from_str(&kid_clear).into_iter().collect();

    self.for_each_request_chunk_index(request_headers, |idx| {
        // `Result` iterates over its `Ok` value, so invalid lines contribute nothing.
        cleared.extend(HeaderValue::from_str(&format!("{base_name}.{idx}=; {expire}")));
    });
    cleared
}
}
// Marks the cookie store as an implementor of the crate-private `Sealed` marker trait.
// NOTE(review): presumably `SessionDriver` requires `Sealed` as a supertrait so that outside
// crates cannot implement it — confirm in `crate::session`.
impl<C: CookieData> crate::session::sealed::Sealed for CookieSessionStore<C> {}
impl<C: CookieData> SessionDriver for CookieSessionStore<C> {
    type SessionType = C;
    /// Loading never errors: an unreadable cookie is reported as "no session".
    type LoadError = std::convert::Infallible;

    /// Builds a session from a completed login and returns it with the cookies that persist it.
    async fn create(
        &self,
        completed: CompletedLogin,
        default_lifetime: std::time::Duration,
        headers: &http::HeaderMap,
    ) -> Result<(C, Vec<HeaderValue>), SessionError> {
        let initial = SessionState::from_completed(&completed, default_lifetime);
        let session = C::from_login(initial, &completed).map_err(to_session_err)?;
        let set_cookies = self.save_session(&session, headers).await?;
        Ok((session, set_cookies))
    }

    /// Loads the session from the request cookies, if a readable one is present.
    async fn load(&self, headers: &http::HeaderMap) -> Result<Option<C>, std::convert::Infallible> {
        let loaded = self.load_session(headers).await;
        Ok(loaded)
    }

    /// Re-seals `session` and returns the cookies that store it.
    async fn save(
        &self,
        session: &C,
        headers: &http::HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        let set_cookies = self.save_session(session, headers).await?;
        Ok(set_cookies)
    }

    /// Same as [`save`](Self::save): the whole session lives in the cookie, so touching it
    /// means rewriting it.
    async fn touch(
        &self,
        session: &C,
        headers: &http::HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        let set_cookies = self.save_session(session, headers).await?;
        Ok(set_cookies)
    }

    /// Returns the cookies that clear the session; the session value itself is not consulted.
    async fn delete(
        &self,
        _session: &C,
        headers: &http::HeaderMap,
    ) -> Result<Vec<HeaderValue>, SessionError> {
        let clears = self.delete_headers(headers);
        Ok(clears)
    }
}
fn split_into_chunks(cookie_value: &str) -> Vec<&str> {
cookie_value
.as_bytes()
.chunks(CHUNK_SIZE)
.map(|c| std::str::from_utf8(c).expect("base64 output is ASCII"))
.collect()
}
/// Concatenates chunks `0, 1, 2, …` in index order, stopping at the first missing index.
///
/// Returns `None` when chunk 0 is absent. Chunks beyond the first gap are ignored; they are
/// treated as stale leftovers from an earlier, longer session value.
///
/// The output buffer is sized from the chunks that are actually concatenated. The map is built
/// from client-supplied cookies, so sizing it from the number of entries would let a request
/// with many tiny or sparse chunk cookies force a disproportionately large allocation.
fn reassemble_chunks(chunks: &std::collections::HashMap<usize, String>) -> Option<String> {
    // Chunk 0 is mandatory; without it there is nothing to reassemble.
    chunks.get(&0)?;

    // Yields chunk 0, 1, 2, … and ends at the first index that is not present.
    let contiguous = || (0usize..).map_while(|index| chunks.get(&index));

    let total_len: usize = contiguous().map(String::len).sum();
    let mut raw_encoded = String::with_capacity(total_len);
    raw_encoded.extend(contiguous().map(String::as_str));
    Some(raw_encoded)
}
#[cfg(test)]
mod tests {
use std::{
convert::Infallible,
time::{Duration, SystemTime},
};
use http::HeaderMap;
use huskarl::core::secrets::{Secret, SecretBytes, SecretOutput};
use huskarl_crypto_native::aead::{AesGcmKey, AesGcmKeyType};
use super::*;
use crate::session_state::SessionState;
/// Test-only secret that hands back a fixed byte string.
#[derive(Clone)]
struct TestSecret(SecretBytes);
impl Secret for TestSecret {
    type Output = SecretBytes;
    type Error = Infallible;

    /// Returns a clone of the wrapped bytes with no identity attached.
    async fn get_secret_value(&self) -> Result<SecretOutput<SecretBytes>, Infallible> {
        let Self(bytes) = self;
        let value = bytes.clone();
        Ok(SecretOutput {
            value,
            identity: None,
        })
    }
}
/// Builds an AES-256-GCM cipher over an all-zero key whose key-id callback returns `None`.
async fn test_cipher() -> BoxedAeadCipher {
    let zero_key = TestSecret(SecretBytes::new(vec![0u8; 32]));
    let built = AesGcmKey::from_secret(AesGcmKeyType::Aes256, zero_key, |_| None).await;
    BoxedAeadCipher::new(built.unwrap())
}
/// Session state created "now", last active "now", with a token expiring in one hour.
fn test_state() -> SessionState {
    let created = SystemTime::now();
    let expires = created + Duration::from_hours(1);
    SessionState::builder()
        .token_expiry(expires)
        .created_at(created)
        .last_active(created)
        .build()
}
/// Store named `huskarl_session` at path `/`, secure, using the identity-less test cipher.
async fn test_store() -> CookieSessionStore<CookieSession> {
    let cipher = test_cipher().await;
    CookieSessionStore::builder()
        .cipher(cipher)
        .cookie_name("huskarl_session")
        .secure(true)
        .cookie_path("/")
        .build()
}
/// Like `test_cipher`, but the key-id callback always returns `kid`.
async fn test_cipher_with_kid(kid: &str) -> BoxedAeadCipher {
    let identity = kid.to_owned();
    let zero_key = TestSecret(SecretBytes::new(vec![0u8; 32]));
    let key = AesGcmKey::from_secret(AesGcmKeyType::Aes256, zero_key, move |_| {
        Some(identity.clone())
    })
    .await
    .unwrap();
    BoxedAeadCipher::new(key)
}
/// Simulates the browser: turns `Set-Cookie` values into the `Cookie` header of a follow-up
/// request, dropping cookies whose value is empty (i.e. clears).
fn request_cookies_from_set_cookies(set_cookies: &[HeaderValue]) -> HeaderMap {
    let pairs: Vec<&str> = set_cookies
        .iter()
        .map(|header| {
            // Only the leading `name=value` part is echoed back; attributes are dropped.
            let line = header.to_str().unwrap();
            line.split(';').next().unwrap()
        })
        .filter(|pair| {
            let (_name, value) = pair.split_once('=').unwrap();
            !value.is_empty()
        })
        .collect();

    let mut headers = HeaderMap::new();
    if !pairs.is_empty() {
        headers.insert(http::header::COOKIE, pairs.join("; ").parse().unwrap());
    }
    headers
}
/// Request carrying placeholder chunk cookies `huskarl_session.0` … `huskarl_session.{n-1}`.
fn request_with_chunk_slots(n: usize) -> HeaderMap {
    let mut headers = HeaderMap::new();
    if n == 0 {
        return headers;
    }
    let cookie_line = (0..n)
        .map(|i| format!("huskarl_session.{i}=x"))
        .collect::<Vec<String>>()
        .join("; ");
    headers.insert(http::header::COOKIE, cookie_line.parse().unwrap());
    headers
}
/// Chunk 0 holds the raw base64url payload with no delimiter-separated prefix.
#[tokio::test]
async fn save_emits_chunk_zero_with_raw_base64_value() {
    let store = test_store().await;
    let session = CookieSession(test_state());
    let no_cookies = HeaderMap::new();
    let set_cookies = store.save_session(&session, &no_cookies).await.unwrap();

    let first = set_cookies[0].to_str().unwrap();
    assert!(first.starts_with("huskarl_session.0="), "got: {first}");

    // The text between the first and second `=`, cut at the first `;`, is the cookie value.
    let after_name = first.split('=').nth(1).unwrap();
    let value = after_name.split(';').next().unwrap();
    assert!(
        !value.contains(':'),
        "chunk 0 must not carry a delimiter prefix: {value}"
    );
    assert!(!value.is_empty(), "chunk 0 must carry payload data");
}
/// The first emitted cookie carries the expected hardening attributes.
#[tokio::test]
async fn save_sets_security_attributes() {
    let store = test_store().await;
    let session = CookieSession(test_state());
    let no_cookies = HeaderMap::new();
    let set_cookies = store.save_session(&session, &no_cookies).await.unwrap();

    let first = set_cookies[0].to_str().unwrap();
    assert!(first.contains("HttpOnly"));
    assert!(first.contains("SameSite=Lax"));
    assert!(first.contains("Secure"));
    assert!(first.contains("Path=/"));
}
#[tokio::test]
async fn save_emits_no_chunk_clears_when_request_has_none() {
let store = test_store().await;
let session = CookieSession(test_state());
let cookies = store
.save_session(&session, &HeaderMap::new())
.await
.unwrap();
let chunk_clears = cookies
.iter()
.filter(|c| {
let s = c.to_str().unwrap();
s.contains("huskarl_session.")
&& !s.starts_with("huskarl_session.kid=")
&& s.contains("Max-Age=0")
})
.count();
assert_eq!(
chunk_clears, 0,
"no chunk slots to clear without prior chunks"
);
}
/// A cipher with a key identity yields a kid sidecar holding the base64url-encoded identity.
#[tokio::test]
async fn save_emits_kid_set_when_cipher_has_identity() {
    let identity = "arn:aws:kms:us-east-1:111:key/abc";
    let store = CookieSessionStore::<CookieSession>::builder()
        .cipher(test_cipher_with_kid(identity).await)
        .cookie_name("huskarl_session")
        .secure(true)
        .cookie_path("/")
        .build();
    let session = CookieSession(test_state());
    let set_cookies = store
        .save_session(&session, &HeaderMap::new())
        .await
        .unwrap();

    let expected_prefix = format!(
        "huskarl_session.kid={};",
        URL_SAFE_NO_PAD.encode(identity.as_bytes())
    );
    let kid_set = set_cookies
        .iter()
        .any(|c| c.to_str().unwrap().starts_with(&expected_prefix));
    assert!(kid_set, "expected kid sidecar set to base64url(identity)");
}
/// A session saved with a kid sidecar loads again when the sidecar is echoed back.
#[tokio::test]
async fn save_then_load_roundtrips_with_kid_sidecar() {
    let cipher = test_cipher_with_kid("test-kid").await;
    let store = CookieSessionStore::<CookieSession>::builder()
        .cipher(cipher)
        .cookie_name("huskarl_session")
        .secure(true)
        .cookie_path("/")
        .build();
    let session = CookieSession(test_state());
    let set_cookies = store
        .save_session(&session, &HeaderMap::new())
        .await
        .unwrap();

    let request = request_cookies_from_set_cookies(&set_cookies);
    let kid_seen = get_kid_cookie(&request, "huskarl_session");
    assert_eq!(kid_seen.as_deref(), Some("test-kid"));

    let loaded = store.load_session(&request).await;
    assert!(
        loaded.is_some(),
        "session should load with kid sidecar present"
    );
}
/// An undecodable kid sidecar must not prevent the session from loading.
#[tokio::test]
async fn load_falls_back_when_kid_sidecar_is_garbage() {
    let cipher = test_cipher_with_kid("test-kid").await;
    let store = CookieSessionStore::<CookieSession>::builder()
        .cipher(cipher)
        .cookie_name("huskarl_session")
        .secure(true)
        .cookie_path("/")
        .build();
    let session = CookieSession(test_state());
    let set_cookies = store
        .save_session(&session, &HeaderMap::new())
        .await
        .unwrap();
    let mut request = request_cookies_from_set_cookies(&set_cookies);

    // Replace the genuine sidecar with a value that is not valid base64url.
    let cookie_line = request
        .get(http::header::COOKIE)
        .unwrap()
        .to_str()
        .unwrap()
        .to_owned();
    let kept: Vec<&str> = cookie_line
        .split(';')
        .map(str::trim)
        .filter(|pair| !pair.starts_with("huskarl_session.kid="))
        .collect();
    let tampered = format!("{}; huskarl_session.kid=!!!", kept.join("; "));
    request.insert(http::header::COOKIE, tampered.parse().unwrap());

    let loaded = store.load_session(&request).await;
    assert!(loaded.is_some());
}
/// A cipher without a key identity clears the kid sidecar instead of setting it.
#[tokio::test]
async fn save_emits_kid_clear_when_cipher_has_no_identity() {
    let store = test_store().await;
    let session = CookieSession(test_state());
    let no_cookies = HeaderMap::new();
    let set_cookies = store.save_session(&session, &no_cookies).await.unwrap();

    let mut kid_clear = false;
    for header in &set_cookies {
        let line = header.to_str().unwrap();
        if line.starts_with("huskarl_session.kid=;") && line.contains("Max-Age=0") {
            kid_clear = true;
            break;
        }
    }
    assert!(
        kid_clear,
        "expected kid sidecar clear with no-identity cipher"
    );
}
/// Saving a one-chunk session over a five-chunk request clears slots 1–4 but never slot 0.
#[tokio::test]
async fn save_clears_only_request_chunks_above_new_count() {
    let store = test_store().await;
    let session = CookieSession(test_state());
    let request = request_with_chunk_slots(5);
    let set_cookies = store.save_session(&session, &request).await.unwrap();

    let is_cleared = |slot: usize| {
        let prefix = format!("huskarl_session.{slot}=;");
        set_cookies.iter().any(|c| {
            let line = c.to_str().unwrap();
            line.starts_with(&prefix) && line.contains("Max-Age=0")
        })
    };

    for stale in 1..5 {
        assert!(is_cleared(stale), "expected clear for stale slot .{stale}");
    }
    assert!(
        !is_cleared(0),
        "slot .0 must not be cleared — it's overwritten with new data",
    );
}
/// CBOR must encode the session in at most 85% of the bytes JSON needs.
#[test]
fn cbor_payload_is_smaller_than_json() {
    let session = CookieSession(test_state());

    let json_len = serde_json::to_vec(&session).unwrap().len();
    let mut cbor = Vec::new();
    ciborium::into_writer(&session, &mut cbor).unwrap();
    let cbor_len = cbor.len();

    assert!(
        cbor_len < json_len,
        "CBOR ({}) should be smaller than JSON ({})",
        cbor_len,
        json_len
    );

    // Integer percentage, rounded down.
    let percent = cbor_len * 100 / json_len;
    assert!(
        percent <= 85,
        "expected CBOR <=85% of JSON size, got {}% ({} / {})",
        percent,
        cbor_len,
        json_len
    );
}
/// Saving then loading preserves all three timestamps to whole-second precision.
#[tokio::test]
async fn save_then_load_roundtrips_state() {
    let store = test_store().await;
    let original = test_state();
    let session = CookieSession(original.clone());
    let set_cookies = store
        .save_session(&session, &HeaderMap::new())
        .await
        .unwrap();

    let request = request_cookies_from_set_cookies(&set_cookies);
    let loaded = store.load_session(&request).await.expect("session loads");
    let restored = loaded.state();

    let secs = |t: SystemTime| t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
    let timestamps = [
        (restored.token_expiry, original.token_expiry),
        (restored.created_at, original.created_at),
        (restored.last_active, original.last_active),
    ];
    for (got, want) in timestamps {
        assert_eq!(secs(got), secs(want));
    }
}
/// A request without any cookies carries no session.
#[tokio::test]
async fn load_returns_none_when_no_cookies() {
    let store = test_store().await;
    let request = HeaderMap::new();
    let loaded = store.load_session(&request).await;
    assert!(loaded.is_none());
}
#[tokio::test]
async fn load_returns_none_for_unrelated_cookies() {
let store = test_store().await;
let mut headers = HeaderMap::new();
headers.insert(
http::header::COOKIE,
"other=value; another=42".parse().unwrap(),
);
assert!(store.load_session(&headers).await.is_none());
}
#[tokio::test]
async fn load_returns_none_when_continuation_chunk_missing() {
let store = test_store().await;
let mut headers = HeaderMap::new();
headers.insert(
http::header::COOKIE,
"huskarl_session.0=AAAA; huskarl_session.2=BBBB"
.parse()
.unwrap(),
);
assert!(store.load_session(&headers).await.is_none());
}
#[tokio::test]
async fn load_returns_none_when_decryption_fails() {
let store = test_store().await;
let mut headers = HeaderMap::new();
headers.insert(
http::header::COOKIE,
"huskarl_session.0=AAAAAAAAAAAA".parse().unwrap(),
);
assert!(store.load_session(&headers).await.is_none());
}
/// Deleting clears the kid sidecar plus every chunk slot present in the request.
#[tokio::test]
async fn delete_emits_clears_for_every_chunk_slot_the_request_sent() {
    let store = test_store().await;
    let request = request_with_chunk_slots(5);
    let clears = store.delete_headers(&request);

    // Five chunk slots plus the kid sidecar.
    assert_eq!(clears.len(), 6);
    for header in &clears {
        assert!(header.to_str().unwrap().contains("Max-Age=0"));
    }

    let has_clear = |suffix: &str| {
        let prefix = format!("huskarl_session.{suffix}=;");
        clears
            .iter()
            .any(|c| c.to_str().unwrap().starts_with(&prefix))
    };
    for i in 0..5 {
        assert!(has_clear(&i.to_string()), "expected clear for slot .{i}");
    }
    assert!(has_clear("kid"), "expected kid sidecar clear");
}
/// With no chunk cookies in the request, only the kid sidecar is cleared.
#[tokio::test]
async fn delete_emits_only_kid_clear_when_request_has_no_chunks() {
    let store = test_store().await;
    let no_cookies = HeaderMap::new();
    let clears = store.delete_headers(&no_cookies);
    assert_eq!(clears.len(), 1);

    let mut kid = false;
    for header in &clears {
        let line = header.to_str().unwrap();
        if line.starts_with("huskarl_session.kid=;") && line.contains("Max-Age=0") {
            kid = true;
            break;
        }
    }
    assert!(kid, "expected kid sidecar clear");
}
/// A `name.index=value` pair yields the index and the value.
#[tokio::test]
async fn parse_chunk_pair_matches_indexed_cookie() {
    let store = test_store().await;
    let parsed = store.parse_chunk_pair("huskarl_session.3=abc");
    let expected = Some((3, "abc".to_owned()));
    assert_eq!(parsed, expected);
}
/// A cookie with a different name is not a chunk.
#[tokio::test]
async fn parse_chunk_pair_rejects_unrelated_cookie() {
    let store = test_store().await;
    let parsed = store.parse_chunk_pair("other=value");
    assert_eq!(parsed, None);
}
/// The bare cookie name, without a `.index` suffix, is not a chunk.
#[tokio::test]
async fn parse_chunk_pair_rejects_base_name_without_index() {
    let store = test_store().await;
    let parsed = store.parse_chunk_pair("huskarl_session=foo");
    assert_eq!(parsed, None);
}
/// A suffix that is not a number is not a chunk index.
#[tokio::test]
async fn parse_chunk_pair_rejects_non_numeric_suffix() {
    let store = test_store().await;
    let parsed = store.parse_chunk_pair("huskarl_session.abc=foo");
    assert_eq!(parsed, None);
}
/// The parser imposes no upper bound on the index beyond what `usize` can hold.
#[tokio::test]
async fn parse_chunk_pair_accepts_any_index_within_usize() {
    let store = test_store().await;
    let cases = [
        ("huskarl_session.42=foo", 42),
        ("huskarl_session.1000000=foo", 1_000_000),
    ];
    for (pair, index) in cases {
        assert_eq!(store.parse_chunk_pair(pair), Some((index, "foo".to_owned())));
    }
}
/// Without chunk 0 there is nothing to reassemble, even if later chunks exist.
#[test]
fn reassemble_returns_none_when_chunk_zero_missing() {
    let chunks = std::collections::HashMap::from([(1, "c1".to_owned())]);
    let joined = reassemble_chunks(&chunks);
    assert!(joined.is_none());
}
/// Contiguous chunks are joined in index order.
#[test]
fn reassemble_concatenates_contiguous_chunks() {
    let chunks = std::collections::HashMap::from([
        (0, "c0".to_owned()),
        (1, "c1".to_owned()),
        (2, "c2".to_owned()),
    ]);
    let joined = reassemble_chunks(&chunks);
    assert_eq!(joined.as_deref(), Some("c0c1c2"));
}
/// Chunks after a missing index are treated as stale and left out.
#[test]
fn reassemble_stops_at_first_gap() {
    let chunks = std::collections::HashMap::from([
        (0, "c0".to_owned()),
        (1, "c1".to_owned()),
        (3, "stale".to_owned()),
    ]);
    let joined = reassemble_chunks(&chunks);
    assert_eq!(joined.as_deref(), Some("c0c1"));
}
/// Reassembly is not limited to a handful of chunks.
#[test]
fn reassemble_handles_many_chunks() {
    let chunks: std::collections::HashMap<usize, String> =
        (0..64).map(|i| (i, format!("c{i}"))).collect();
    let out = reassemble_chunks(&chunks).expect("contiguous chunks reassemble");
    assert!(out.starts_with("c0"));
    assert!(out.ends_with("c63"));
}
}