use parking_lot::Mutex;
use std::collections::HashMap;
use std::fmt;
use std::fmt::Write as _;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use super::extract::Request;
use super::handler::Handler;
use super::response::{Response, StatusCode};
/// Cookie name used unless overridden via `SessionLayer::cookie_name`.
const DEFAULT_COOKIE_NAME: &str = "session_id";
/// Length, in hex characters, of a session ID. CSRF tokens come from the same
/// generator and therefore have the same length. Enforced by `is_valid_session_id`.
const SESSION_ID_HEX_LEN: usize = 32;
/// Reserved session key holding the per-session CSRF token.
const CSRF_TOKEN_KEY: &str = "__asupersync.csrf_token";
/// Reserved session key holding the last-access time as decimal Unix seconds;
/// read by the idle-TTL check.
const LAST_ACCESSED_KEY: &str = "__asupersync.last_accessed_unix_secs";
/// Reserved session key whose presence asks the middleware to rotate the
/// session ID after the inner handler returns. The middleware strips it
/// before persisting.
const REGENERATE_FLAG_KEY: &str = "__asupersync.regenerate";
/// Drop guard held across the inner handler call.
///
/// If the handler requested regeneration (set `REGENERATE_FLAG_KEY`) but the
/// middleware never reaches the point where it disarms this guard — the
/// handler panicked, the request future was dropped, or an early return
/// happened — the guard deletes the *old* session ID from the store, so a
/// pre-regeneration ID cannot stay valid (fail closed).
struct RegenerateGuard<'a, S: SessionStore + ?Sized> {
    // Cleared by `disarm` once the middleware has handled regeneration itself.
    armed: bool,
    store: &'a S,
    // Same data the handler's `Session` handle points at, so flags set during
    // the handler call are visible here.
    session_handle: Arc<Mutex<SessionData>>,
    // The session ID the request arrived with (before any rotation).
    session_id: String,
    // A session created for this request has nothing in the store to delete.
    is_new: bool,
}
impl<S: SessionStore + ?Sized> RegenerateGuard<'_, S> {
    /// Marks regeneration as handled so that `Drop` becomes a no-op.
    fn disarm(&mut self) {
        self.armed = false;
    }
}
impl<S: SessionStore + ?Sized> Drop for RegenerateGuard<'_, S> {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        // Read the flag in its own scope so the session lock is released
        // before calling into the store.
        let regenerate_requested = {
            let guard = self.session_handle.lock();
            guard.get(REGENERATE_FLAG_KEY).is_some()
        };
        if regenerate_requested && !self.is_new {
            if std::thread::panicking() {
                // Panicking again while already unwinding aborts the process,
                // so a panic from the store backend is contained here and its
                // outcome deliberately discarded.
                let _delete_outcome =
                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                        self.store.delete(&self.session_id);
                    }));
            } else {
                self.store.delete(&self.session_id);
            }
        }
    }
}
/// Returns `true` for HTTP methods that may mutate server state and are
/// therefore subject to CSRF checks (POST, PUT, PATCH, DELETE).
///
/// The comparison is ASCII case-insensitive and allocation-free.
fn is_state_changing_method(method: &str) -> bool {
    const MUTATING: [&str; 4] = ["POST", "PUT", "PATCH", "DELETE"];
    MUTATING
        .iter()
        .any(|&verb| method.eq_ignore_ascii_case(verb))
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn now_unix_secs() -> u64 {
    let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) else {
        return 0;
    };
    since_epoch.as_secs()
}
/// Backend that persists session data keyed by session ID.
///
/// The middleware keeps the store in an `Arc` and calls it from request
/// futures, hence the `Send + Sync + 'static` bounds. The methods return no
/// errors; a backend must decide for itself how to handle failures.
pub trait SessionStore: Send + Sync + 'static {
    /// Returns the stored data for `id`, or `None` if no such session exists.
    fn load(&self, id: &str) -> Option<SessionData>;
    /// Stores `data` under `id` (`MemoryStore` replaces any previous entry).
    fn save(&self, id: &str, data: &SessionData);
    /// Removes session `id`. The middleware may call this more than once for
    /// the same ID, so deleting a missing ID should be harmless.
    fn delete(&self, id: &str);
}
/// String key/value contents of one session plus a dirty flag.
///
/// Every mutating call (`insert`, `remove`, `clear`) sets the dirty flag, even
/// when the contents end up unchanged. The middleware uses the flag to decide
/// whether to persist the session and re-issue its cookie.
#[derive(Debug, Clone, Default)]
pub struct SessionData {
    values: HashMap<String, String>,
    modified: bool,
}
impl SessionData {
    /// Creates an empty, unmodified session.
    #[must_use]
    pub fn new() -> Self {
        Self {
            values: HashMap::new(),
            modified: false,
        }
    }
    /// Borrows the value stored under `key`, if any.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<&str> {
        self.values.get(key).map(String::as_str)
    }
    /// Stores `value` under `key` and returns the value it replaced.
    pub fn insert(&mut self, key: impl Into<String>, value: impl Into<String>) -> Option<String> {
        self.mark_dirty();
        let (owned_key, owned_value) = (key.into(), value.into());
        self.values.insert(owned_key, owned_value)
    }
    /// Removes `key` and returns its value; marks the session dirty regardless.
    pub fn remove(&mut self, key: &str) -> Option<String> {
        self.mark_dirty();
        self.values.remove(key)
    }
    /// Whether any mutating call has happened since creation/load.
    #[must_use]
    pub fn is_modified(&self) -> bool {
        self.modified
    }
    /// `true` when the session holds no entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Number of stored entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.values.len()
    }
    /// All keys, in unspecified order.
    #[must_use]
    pub fn keys(&self) -> Vec<&str> {
        let mut out = Vec::with_capacity(self.values.len());
        out.extend(self.values.keys().map(String::as_str));
        out
    }
    /// Drops every entry and marks the session dirty.
    pub fn clear(&mut self) {
        self.mark_dirty();
        self.values.clear();
    }
    /// Records that the contents were (potentially) changed.
    fn mark_dirty(&mut self) {
        self.modified = true;
    }
}
#[derive(Clone, Default)]
pub struct MemoryStore {
sessions: Arc<Mutex<HashMap<String, SessionData>>>,
}
impl MemoryStore {
#[must_use]
pub fn new() -> Self {
Self {
sessions: Arc::new(Mutex::new(HashMap::new())),
}
}
#[must_use]
pub fn len(&self) -> usize {
self.sessions.lock().len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.sessions.lock().is_empty()
}
}
impl fmt::Debug for MemoryStore {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let count = self.sessions.lock().len();
f.debug_struct("MemoryStore")
.field("sessions", &count)
.finish()
}
}
impl SessionStore for MemoryStore {
fn load(&self, id: &str) -> Option<SessionData> {
self.sessions.lock().get(id).cloned()
}
fn save(&self, id: &str, data: &SessionData) {
let mut stored = data.clone();
stored.modified = false;
self.sessions.lock().insert(id.to_string(), stored);
}
fn delete(&self, id: &str) {
self.sessions.lock().remove(id);
}
}
fn generate_session_id() -> Option<String> {
let mut buf = [0u8; 16];
getrandom::fill(&mut buf).ok()?;
let mut hex = String::with_capacity(32);
for b in &buf {
let _ = write!(hex, "{b:02x}");
}
Some(hex)
}
/// Accepts exactly `SESSION_ID_HEX_LEN` ASCII hex digits (either case).
///
/// A cookie value failing this check is ignored and a new session is started.
fn is_valid_session_id(id: &str) -> bool {
    let bytes = id.as_bytes();
    if bytes.len() != SESSION_ID_HEX_LEN {
        return false;
    }
    bytes.iter().all(u8::is_ascii_hexdigit)
}
/// Returns the (trimmed) value of cookie `name` from the request's `Cookie`
/// header, or `None` if the header or the cookie is absent.
///
/// The header name is matched ASCII case-insensitively. Pairs are split on `;`
/// and the first `=`, and the first pair whose trimmed name equals `name` wins.
fn get_cookie(req: &Request, name: &str) -> Option<String> {
    let (_, raw) = req
        .headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case("cookie"))?;
    raw.split(';')
        .filter_map(|segment| segment.trim().split_once('='))
        .find(|(cookie_name, _)| cookie_name.trim() == name)
        .map(|(_, cookie_value)| cookie_value.trim().to_string())
}
/// Returns `true` if `s` can be embedded in a `Set-Cookie` header without
/// starting a new attribute or a new header line.
///
/// Rejects any control byte below 0x20 (which includes CR and LF), DEL (0x7f),
/// `;` and `,`. `=` is rejected as well unless `allow_eq` is set, since it is
/// legal inside values and paths but not in cookie names.
fn is_cookie_token_safe(s: &str, allow_eq: bool) -> bool {
    let is_forbidden = |byte: u8| match byte {
        0x00..=0x1f | 0x7f | b';' | b',' => true,
        b'=' => !allow_eq,
        _ => false,
    };
    !s.bytes().any(is_forbidden)
}
/// Builds a `Set-Cookie` header value `name=value` followed by the attributes
/// from `config`, in the order `Path`, `HttpOnly`, `Secure`, `SameSite`,
/// and then `Max-Age` when configured.
///
/// # Panics
/// Panics if `name`, `value`, or `config.cookie_path` contains a byte that
/// could break out of the header or inject an attribute (see
/// `is_cookie_token_safe`). This guards against header injection.
fn set_cookie_header(name: &str, value: &str, config: &SessionConfig) -> String {
    // (input, whether `=` is permitted, panic message)
    let injection_checks = [
        (
            name,
            false,
            "br-asupersync-uz7oxb: cookie name contains forbidden byte (;,=,CR,LF,control,DEL)",
        ),
        (
            value,
            true,
            "br-asupersync-uz7oxb: cookie value contains forbidden byte (;,,CR,LF,control,DEL)",
        ),
        (
            config.cookie_path.as_str(),
            true,
            "br-asupersync-uz7oxb: cookie_path contains forbidden byte (;,,CR,LF,control,DEL)",
        ),
    ];
    for (input, allow_eq, message) in injection_checks {
        assert!(is_cookie_token_safe(input, allow_eq), "{message}");
    }

    let same_site_attr = match config.same_site {
        SameSite::Strict => "; SameSite=Strict",
        SameSite::Lax => "; SameSite=Lax",
        SameSite::None => "; SameSite=None",
    };
    let mut header = format!("{name}={value}; Path={}", config.cookie_path);
    if config.http_only {
        header.push_str("; HttpOnly");
    }
    if config.secure {
        header.push_str("; Secure");
    }
    header.push_str(same_site_attr);
    if let Some(max_age) = config.max_age {
        // Writing into a `String` cannot fail.
        let _ = write!(header, "; Max-Age={max_age}");
    }
    header
}
/// Value written to the session cookie's `SameSite` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SameSite {
    /// Emits `SameSite=Strict`.
    Strict,
    /// Emits `SameSite=Lax`; the `SessionConfig` default.
    Lax,
    /// Emits `SameSite=None`; `SessionConfig::validate` requires `secure`
    /// to be set alongside it.
    None,
}
/// Reasons `SessionConfig::validate` can reject a configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionConfigError {
    /// `same_site` is `SameSite::None` while `secure` is `false`.
    SameSiteNoneWithoutSecure,
}
impl fmt::Display for SessionConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::SameSiteNoneWithoutSecure => {
                "session: SameSite=None requires Secure (browsers reject cross-site cookies otherwise)"
            }
        };
        f.write_str(message)
    }
}
impl std::error::Error for SessionConfigError {}
/// Cookie, expiry and CSRF settings for the session middleware.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Name of the session cookie.
    pub cookie_name: String,
    /// `Path` attribute of the session cookie.
    pub cookie_path: String,
    /// Whether to emit the `HttpOnly` attribute.
    pub http_only: bool,
    /// Whether to emit the `Secure` attribute.
    pub secure: bool,
    /// `SameSite` attribute value.
    pub same_site: SameSite,
    /// `Max-Age` attribute in seconds; `None` omits the attribute.
    pub max_age: Option<u64>,
    /// Server-side idle timeout in seconds. A stored session not accessed for
    /// longer than this is discarded and replaced. `None` disables the check.
    pub idle_ttl_seconds: Option<u64>,
    /// Enables per-session CSRF tokens (checked via the `x-csrf-token` header
    /// on state-changing methods) and gates the Origin/Referer check.
    pub csrf_protection: bool,
    /// Origins accepted on state-changing requests when CSRF protection is on.
    /// An empty list disables the Origin/Referer check.
    pub allowed_origins: Vec<String>,
}
impl Default for SessionConfig {
    /// Restrictive defaults: cookie `session_id` on path `/`, `HttpOnly`,
    /// `Secure`, `SameSite=Lax`, no `Max-Age`, no idle timeout, CSRF
    /// protection on, and no Origin allow-list.
    fn default() -> Self {
        let cookie_name = String::from(DEFAULT_COOKIE_NAME);
        let cookie_path = String::from("/");
        Self {
            cookie_name,
            cookie_path,
            http_only: true,
            secure: true,
            same_site: SameSite::Lax,
            max_age: None,
            idle_ttl_seconds: None,
            csrf_protection: true,
            allowed_origins: Vec::new(),
        }
    }
}
impl SessionConfig {
    /// Checks cross-field consistency.
    ///
    /// # Errors
    /// Returns [`SessionConfigError::SameSiteNoneWithoutSecure`] when
    /// `same_site` is `SameSite::None` but `secure` is `false`.
    pub fn validate(&self) -> Result<(), SessionConfigError> {
        let cross_site_without_secure = matches!(self.same_site, SameSite::None) && !self.secure;
        if cross_site_without_secure {
            Err(SessionConfigError::SameSiteNoneWithoutSecure)
        } else {
            Ok(())
        }
    }
}
/// Builder that configures and produces a [`SessionMiddleware`].
pub struct SessionLayer<S: SessionStore> {
    store: Arc<S>,
    config: SessionConfig,
}
impl<S: SessionStore> SessionLayer<S> {
    /// Creates a layer over `store` using `SessionConfig::default()`.
    ///
    /// # Panics
    /// Only if the default configuration fails validation, which would be a bug.
    pub fn new(store: S) -> Self {
        let layer = Self {
            store: Arc::new(store),
            config: SessionConfig::default(),
        };
        layer
            .config
            .validate()
            .expect("default SessionConfig must validate");
        layer
    }
    /// Applies `edit` to the configuration and hands the builder back.
    fn map_config(mut self, edit: impl FnOnce(&mut SessionConfig)) -> Self {
        edit(&mut self.config);
        self
    }
    /// Sets the session cookie name.
    #[must_use]
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        let name = name.into();
        self.map_config(|config| config.cookie_name = name)
    }
    /// Sets the session cookie `Path`.
    #[must_use]
    pub fn cookie_path(self, path: impl Into<String>) -> Self {
        let path = path.into();
        self.map_config(|config| config.cookie_path = path)
    }
    /// Replaces the Origin allow-list used for state-changing requests.
    #[must_use]
    pub fn allowed_origins<I, S2>(self, origins: I) -> Self
    where
        I: IntoIterator<Item = S2>,
        S2: Into<String>,
    {
        let origins: Vec<String> = origins.into_iter().map(Into::into).collect();
        self.map_config(|config| config.allowed_origins = origins)
    }
    /// Toggles the `HttpOnly` attribute.
    #[must_use]
    pub fn http_only(self, value: bool) -> Self {
        self.map_config(|config| config.http_only = value)
    }
    /// Toggles the `Secure` attribute.
    #[must_use]
    pub fn secure(self, value: bool) -> Self {
        self.map_config(|config| config.secure = value)
    }
    /// Sets the `SameSite` attribute and validates immediately.
    ///
    /// # Panics
    /// If `value` is `SameSite::None` while `Secure` is currently disabled.
    #[must_use]
    pub fn same_site(self, value: SameSite) -> Self {
        let layer = self.map_config(|config| config.same_site = value);
        layer
            .config
            .validate()
            .expect("SessionConfig validation failed (SameSite=None requires Secure)");
        layer
    }
    /// Sets the cookie `Max-Age` in seconds.
    #[must_use]
    pub fn max_age(self, seconds: u64) -> Self {
        self.map_config(|config| config.max_age = Some(seconds))
    }
    /// Sets the server-side idle timeout in seconds.
    #[must_use]
    pub fn idle_ttl_seconds(self, seconds: u64) -> Self {
        self.map_config(|config| config.idle_ttl_seconds = Some(seconds))
    }
    /// Enables or disables CSRF protection.
    #[must_use]
    pub fn csrf_protection(self, enabled: bool) -> Self {
        self.map_config(|config| config.csrf_protection = enabled)
    }
    /// Wraps `inner` in a session middleware using this configuration.
    ///
    /// # Panics
    /// If the final configuration fails `SessionConfig::validate`.
    pub fn wrap<H: Handler>(self, inner: H) -> SessionMiddleware<S, H> {
        let Self { store, config } = self;
        config
            .validate()
            .expect("SessionConfig validation failed before wrap");
        SessionMiddleware {
            inner,
            store,
            config,
        }
    }
}
impl<S: SessionStore> fmt::Debug for SessionLayer<S> {
    /// Shows the configuration; the store is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SessionLayer");
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
/// Handler that loads or creates a session around an inner handler, enforces
/// CSRF checks, and persists the session afterwards.
///
/// Produced by `SessionLayer::wrap`, which validates the configuration first.
pub struct SessionMiddleware<S: SessionStore, H: Handler> {
    // The wrapped application handler.
    inner: H,
    // Shared session backend.
    store: Arc<S>,
    // Configuration already checked by `SessionConfig::validate` in `wrap`.
    config: SessionConfig,
}
impl<S: SessionStore, H: Handler> Handler for SessionMiddleware<S, H> {
    /// Runs one request through the session lifecycle:
    ///
    /// 1. Resolve the session ID from the cookie. A malformed, unknown, or
    ///    idle-expired ID is replaced by a freshly generated one, so a
    ///    client-chosen ID is never adopted.
    /// 2. Stamp the last-access time and make sure a CSRF token exists.
    /// 3. For state-changing methods with CSRF protection on, check
    ///    Origin/Referer against `allowed_origins` (when non-empty). For
    ///    sessions that already existed, also compare the `x-csrf-token`
    ///    header with the session token. Failures return `403` without
    ///    calling the inner handler.
    /// 4. Expose the session to the inner handler as a `Session` extension,
    ///    protected by a `RegenerateGuard`.
    /// 5. After the handler, honour a regenerate request, then persist or
    ///    delete the session and append the matching `Set-Cookie`.
    ///
    /// Every failure of the OS entropy source yields a `500` response.
    fn call(
        &self,
        cx: &crate::Cx,
        mut req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        let cx = cx.clone();
        Box::pin(async move {
            // Only a well-formed ID from the cookie is considered; anything
            // else starts a brand-new session.
            let (mut session_id, mut is_new) = match get_cookie(&req, &self.config.cookie_name) {
                Some(id) if is_valid_session_id(&id) => (id, false),
                _ => {
                    let Some(id) = generate_session_id() else {
                        return Response::new(
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "Session initialization failed: OS entropy unavailable".to_string(),
                        );
                    };
                    (id, true)
                }
            };
            let mut session_data = if is_new {
                SessionData::new()
            } else if let Some(data) = self.store.load(&session_id) {
                if self.is_idle_expired(&data) {
                    // Idle too long: drop the stored session and start over
                    // under a new ID.
                    self.store.delete(&session_id);
                    let Some(new_id) = generate_session_id() else {
                        return Response::new(
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "Session renewal failed: OS entropy unavailable".to_string(),
                        );
                    };
                    session_id = new_id;
                    is_new = true;
                    SessionData::new()
                } else {
                    data
                }
            } else {
                // Well-formed but unknown ID: never adopt it (session-fixation
                // defence); issue a fresh one instead.
                let Some(new_id) = generate_session_id() else {
                    return Response::new(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Session creation failed: OS entropy unavailable".to_string(),
                    );
                };
                session_id = new_id;
                is_new = true;
                SessionData::new()
            };
            // This insert always marks the session modified, so unless the
            // handler clears it, the session is saved and its cookie re-sent
            // on every request.
            session_data.insert(LAST_ACCESSED_KEY, now_unix_secs().to_string());
            if self.config.csrf_protection && session_data.get(CSRF_TOKEN_KEY).is_none() {
                let Some(csrf_token) = generate_session_id() else {
                    return Response::new(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "CSRF token generation failed: OS entropy unavailable".to_string(),
                    );
                };
                session_data.insert(CSRF_TOKEN_KEY, csrf_token);
            }
            // Origin/Referer check: only active when an allow-list is configured.
            if self.config.csrf_protection
                && is_state_changing_method(&req.method)
                && !self.config.allowed_origins.is_empty()
            {
                match request_origin(&req) {
                    None => {
                        return Response::new(
                            StatusCode::FORBIDDEN,
                            crate::bytes::Bytes::from_static(
                                b"CSRF: missing Origin/Referer header on state-changing request",
                            ),
                        )
                        .header("content-type", "text/plain; charset=utf-8");
                    }
                    Some(origin) => {
                        if !origin_is_allowed(&origin, &self.config.allowed_origins) {
                            return Response::new(
                                StatusCode::FORBIDDEN,
                                crate::bytes::Bytes::from_static(
                                    b"CSRF: Origin/Referer not in allow-list",
                                ),
                            )
                            .header("content-type", "text/plain; charset=utf-8");
                        }
                    }
                }
            }
            if self.config.csrf_protection && is_state_changing_method(&req.method) {
                // Sessions created for this request skip the token check.
                // NOTE(review): that includes requests whose cookie ID was
                // unknown or idle-expired, since `is_new` is set in those paths too.
                if !is_new {
                    let header_token = req
                        .headers
                        .iter()
                        .find(|(k, _)| k.eq_ignore_ascii_case("x-csrf-token"))
                        .map_or("", |(_, v)| v.as_str());
                    let session_token = session_data.get(CSRF_TOKEN_KEY).unwrap_or("");
                    // An empty stored token must never match, even an empty header.
                    if !constant_time_eq_str(header_token, session_token)
                        || session_token.is_empty()
                    {
                        return Response::new(
                            StatusCode::FORBIDDEN,
                            crate::bytes::Bytes::from_static(b"CSRF token missing or invalid"),
                        )
                        .header("content-type", "text/plain; charset=utf-8");
                    }
                }
            }
            // Share the data with the handler through the `Session` extension;
            // the middleware reads it back from the same handle afterwards.
            let session_handle = Arc::new(Mutex::new(session_data));
            req.extensions
                .insert_typed(Session(Arc::clone(&session_handle)));
            // Deletes the old ID if a regenerate request is left unhandled
            // (panic, cancellation, or an early return below).
            let mut regenerate_guard = RegenerateGuard {
                armed: true,
                store: self.store.as_ref(),
                session_handle: Arc::clone(&session_handle),
                session_id: session_id.clone(),
                is_new,
            };
            let mut resp = self.inner.call(&cx, req).await;
            // Snapshot under a short-lived lock.
            session_data = {
                let guard = session_handle.lock();
                guard.clone()
            };
            let regenerate_requested = session_data.get(REGENERATE_FLAG_KEY).is_some();
            if regenerate_requested {
                // The flag is internal and must not be persisted.
                session_data.remove(REGENERATE_FLAG_KEY);
                if !is_new {
                    self.store.delete(&session_id);
                }
                // On this early return the guard is still armed, and the shared
                // handle still carries the flag, so the guard's drop repeats the
                // delete of the old ID.
                let Some(new_id) = generate_session_id() else {
                    return Response::new(
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Session regeneration failed: OS entropy unavailable".to_string(),
                    );
                };
                session_id = new_id;
                is_new = true;
                session_data.insert(LAST_ACCESSED_KEY, now_unix_secs().to_string());
            }
            regenerate_guard.disarm();
            // Empty + modified means the handler cleared the session: the
            // middleware itself always leaves at least the last-access key.
            let session_cleared = session_data.is_empty() && session_data.is_modified();
            if session_cleared {
                if !is_new {
                    self.store.delete(&session_id);
                }
            } else if session_data.is_modified() || regenerate_requested {
                self.store.save(&session_id, &session_data);
            }
            if session_cleared {
                // Only a cookie the client actually holds needs expiring.
                if !is_new {
                    let mut expire_config = self.config.clone();
                    expire_config.max_age = Some(0);
                    let cookie_val =
                        set_cookie_header(&self.config.cookie_name, "", &expire_config);
                    resp.append_set_cookie(cookie_val);
                }
            } else if session_data.is_modified() || regenerate_requested {
                let cookie_val =
                    set_cookie_header(&self.config.cookie_name, &session_id, &self.config);
                resp.append_set_cookie(cookie_val);
            }
            resp
        })
    }
}
impl<S: SessionStore, H: Handler> SessionMiddleware<S, H> {
    /// Returns `true` when an idle timeout is configured and the session's
    /// stored last-access time lies more than `idle_ttl_seconds` in the past.
    ///
    /// A missing or unparseable timestamp counts as NOT expired.
    /// NOTE(review): that is fail-open for sessions written without the
    /// last-access key (e.g. by another writer); confirm this is intended.
    /// A clock that moved backwards also counts as not expired.
    fn is_idle_expired(&self, data: &SessionData) -> bool {
        let last_access = data
            .get(LAST_ACCESSED_KEY)
            .and_then(|raw| raw.parse::<u64>().ok());
        match (self.config.idle_ttl_seconds, last_access) {
            (Some(ttl), Some(last)) => now_unix_secs().saturating_sub(last) > ttl,
            _ => false,
        }
    }
}
fn request_origin(req: &Request) -> Option<String> {
if let Some((_, origin)) = req
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("origin"))
{
let trimmed = origin.trim();
if !trimmed.is_empty() && !trimmed.eq_ignore_ascii_case("null") {
return Some(trimmed.to_string());
}
}
if let Some((_, referer)) = req
.headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("referer"))
{
let r = referer.trim();
let scheme_end = r.find("://")?;
let after_scheme_idx = scheme_end + 3;
let rest = &r[after_scheme_idx..];
let path_start = rest.find('/').unwrap_or(rest.len());
return Some(r[..after_scheme_idx + path_start].to_string());
}
None
}
/// Reports whether `origin` matches an entry of `allowed`.
///
/// Both sides are compared after trimming surrounding whitespace, stripping
/// trailing `/` characters, and ASCII-lowercasing; otherwise the match is exact.
fn origin_is_allowed(origin: &str, allowed: &[String]) -> bool {
    let canonical = |raw: &str| raw.trim().trim_end_matches('/').to_ascii_lowercase();
    let wanted = canonical(origin);
    allowed
        .iter()
        .any(|candidate| canonical(candidate.as_str()) == wanted)
}
/// Compares two strings without stopping at the first differing byte.
///
/// The comparison time therefore does not reveal how long a matching prefix
/// is. Strings of different lengths return `false` immediately, so only the
/// length is observable.
fn constant_time_eq_str(a: &str, b: &str) -> bool {
    let (left, right) = (a.as_bytes(), b.as_bytes());
    left.len() == right.len()
        && left
            .iter()
            .zip(right)
            .fold(0u8, |acc, (x, y)| acc | (x ^ y))
            == 0
}
/// Request-scoped handle to the current session.
///
/// The middleware inserts one into the request extensions before calling the
/// inner handler. All clones share the same `SessionData` behind an
/// `Arc<Mutex<_>>`. After the handler returns, the middleware reads that shared
/// data back to decide what to persist.
#[derive(Clone)]
pub struct Session(Arc<Mutex<SessionData>>);
impl Session {
    /// Returns an owned copy of the value stored under `key`.
    #[must_use]
    pub fn get(&self, key: &str) -> Option<String> {
        self.0.lock().get(key).map(ToString::to_string)
    }
    /// Stores `value` under `key`, marking the session modified.
    pub fn insert(&self, key: impl Into<String>, value: impl Into<String>) {
        self.0.lock().insert(key, value);
    }
    /// Removes `key`, returning its previous value.
    #[must_use]
    pub fn remove(&self, key: &str) -> Option<String> {
        self.0.lock().remove(key)
    }
    /// Removes every entry, including the CSRF token and last-access stamp.
    ///
    /// A session left empty is not persisted. If it existed before this
    /// request, the middleware deletes it from the store and expires the cookie.
    pub fn clear(&self) {
        self.0.lock().clear();
    }
    /// `true` if `key` is present.
    #[must_use]
    pub fn contains(&self, key: &str) -> bool {
        self.0.lock().get(key).is_some()
    }
    /// Current CSRF token, if any. Clients echo it back in the `x-csrf-token`
    /// header on state-changing requests.
    #[must_use]
    pub fn csrf_token(&self) -> Option<String> {
        self.0.lock().get(CSRF_TOKEN_KEY).map(ToString::to_string)
    }
    /// Requests a new session ID (e.g. after login) and rotates the CSRF token.
    ///
    /// The ID itself is rotated by the middleware after the handler returns;
    /// the old ID is removed from the store even if the handler later panics.
    ///
    /// Returns `None` if OS entropy is unavailable for the new CSRF token. The
    /// regeneration request still stands, and the previous CSRF token is
    /// *removed*, so a token known before regeneration is never accepted
    /// afterwards. The middleware issues a fresh token on the next request.
    pub fn regenerate(&self) -> Option<()> {
        // Draw randomness before locking so the syscall never runs while the
        // session is locked.
        let fresh_token = generate_session_id();
        let mut guard = self.0.lock();
        guard.insert(REGENERATE_FLAG_KEY, "1");
        if let Some(token) = fresh_token {
            guard.insert(CSRF_TOKEN_KEY, token);
            Some(())
        } else {
            // Fail closed: never carry a pre-regeneration token across the
            // ID change.
            guard.remove(CSRF_TOKEN_KEY);
            None
        }
    }
    /// Replaces the CSRF token with a fresh one and returns it. Returns `None`
    /// and leaves the token unchanged if OS entropy is unavailable.
    pub fn rotate_csrf_token(&self) -> Option<String> {
        let token = generate_session_id()?;
        self.0.lock().insert(CSRF_TOKEN_KEY, token.clone());
        Some(token)
    }
}
impl fmt::Debug for Session {
    /// Shows only the entry count and modified flag, never values.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let data = self.0.lock();
        f.debug_struct("Session")
            .field("len", &data.len())
            .field("modified", &data.is_modified())
            .finish()
    }
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::super::handler::Handler;
use super::super::response::StatusCode;
use super::*;
fn call_sync<H: Handler + ?Sized>(handler: &H, req: Request) -> Response {
futures_lite::future::block_on(Handler::call(handler, &crate::Cx::for_testing(), req))
}
impl<S, H> SessionMiddleware<S, H>
where
S: SessionStore,
H: Handler,
{
fn call(&self, req: Request) -> Response {
call_sync(self, req)
}
}
#[test]
fn session_data_insert_get() {
let mut data = SessionData::new();
assert!(data.is_empty());
assert_eq!(data.len(), 0);
data.insert("user", "alice");
assert_eq!(data.get("user"), Some("alice"));
assert_eq!(data.len(), 1);
assert!(!data.is_empty());
assert!(data.is_modified());
}
#[test]
fn session_data_remove() {
let mut data = SessionData::new();
data.insert("key", "val");
let removed = data.remove("key");
assert_eq!(removed.as_deref(), Some("val"));
assert!(data.is_empty());
}
#[test]
fn session_data_clear() {
let mut data = SessionData::new();
data.insert("a", "1");
data.insert("b", "2");
data.clear();
assert!(data.is_empty());
assert!(data.is_modified());
}
#[test]
fn session_data_keys() {
let mut data = SessionData::new();
data.insert("x", "1");
data.insert("y", "2");
let mut keys = data.keys();
keys.sort_unstable();
assert_eq!(keys, vec!["x", "y"]);
}
#[test]
fn session_data_not_modified_initially() {
let data = SessionData::new();
assert!(!data.is_modified());
}
#[test]
fn session_data_debug_clone() {
let mut data = SessionData::new();
data.insert("k", "v");
let dbg = format!("{data:?}");
assert!(dbg.contains("SessionData"));
let cloned = data.clone();
assert_eq!(cloned.get("k"), Some("v"));
}
#[test]
fn memory_store_save_load() {
let store = MemoryStore::new();
let mut data = SessionData::new();
data.insert("user", "bob");
store.save("sess1", &data);
assert_eq!(store.len(), 1);
let loaded = store.load("sess1").unwrap();
assert_eq!(loaded.get("user"), Some("bob"));
}
#[test]
fn memory_store_delete() {
let store = MemoryStore::new();
store.save("sess1", &SessionData::new());
assert_eq!(store.len(), 1);
store.delete("sess1");
assert!(store.is_empty());
assert!(store.load("sess1").is_none());
}
#[test]
fn memory_store_load_missing() {
let store = MemoryStore::new();
assert!(store.load("nonexistent").is_none());
}
#[test]
fn memory_store_debug_clone() {
let store = MemoryStore::new();
let dbg = format!("{store:?}");
assert!(dbg.contains("MemoryStore"));
}
#[test]
fn memory_store_default() {
let store = MemoryStore::default();
assert!(store.is_empty());
}
#[test]
fn generate_id_is_valid() {
let id = generate_session_id().expect("OS entropy must be available in session-id test");
assert!(is_valid_session_id(&id));
assert_eq!(id.len(), SESSION_ID_HEX_LEN);
}
#[test]
fn generate_id_uniqueness() {
let id1 = generate_session_id().expect("OS entropy must be available in session-id test");
let id2 = generate_session_id().expect("OS entropy must be available in session-id test");
assert_ne!(id1, id2);
}
#[test]
fn validate_session_id() {
assert!(is_valid_session_id("0123456789abcdef0123456789abcdef"));
assert!(!is_valid_session_id("short"));
assert!(!is_valid_session_id("zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"));
assert!(!is_valid_session_id(""));
}
#[test]
fn get_cookie_basic() {
let mut req = Request::new("GET", "/");
req.headers
.insert("cookie".to_string(), "session_id=abc123".to_string());
assert_eq!(get_cookie(&req, "session_id"), Some("abc123".to_string()));
}
#[test]
fn get_cookie_multiple() {
let mut req = Request::new("GET", "/");
req.headers.insert(
"cookie".to_string(),
"foo=bar; session_id=xyz; other=val".to_string(),
);
assert_eq!(get_cookie(&req, "session_id"), Some("xyz".to_string()));
}
#[test]
fn get_cookie_missing() {
let req = Request::new("GET", "/");
assert!(get_cookie(&req, "session_id").is_none());
}
#[test]
fn set_cookie_default_config() {
let config = SessionConfig::default();
let header = set_cookie_header("sid", "val123", &config);
assert!(header.contains("sid=val123"));
assert!(header.contains("Path=/"));
assert!(header.contains("HttpOnly"));
assert!(header.contains("SameSite=Lax"));
assert!(header.contains("Secure"));
}
#[test]
fn set_cookie_secure_strict() {
let config = SessionConfig {
secure: true,
same_site: SameSite::Strict,
max_age: Some(3600),
..Default::default()
};
let header = set_cookie_header("sid", "val", &config);
assert!(header.contains("Secure"));
assert!(header.contains("SameSite=Strict"));
assert!(header.contains("Max-Age=3600"));
}
#[test]
fn session_layer_builder() {
let layer = SessionLayer::new(MemoryStore::new())
.cookie_name("my_session")
.cookie_path("/app")
.http_only(false)
.secure(true)
.same_site(SameSite::None)
.max_age(7200);
assert_eq!(layer.config.cookie_name, "my_session");
assert_eq!(layer.config.cookie_path, "/app");
assert!(!layer.config.http_only);
assert!(layer.config.secure);
assert_eq!(layer.config.same_site, SameSite::None);
assert_eq!(layer.config.max_age, Some(7200));
}
#[test]
fn session_layer_debug() {
let layer = SessionLayer::new(MemoryStore::new());
let dbg = format!("{layer:?}");
assert!(dbg.contains("SessionLayer"));
}
struct TestHandler;
impl Handler for TestHandler {
fn call(
&self,
_cx: &crate::Cx,
req: Request,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
Box::pin(async move {
req.extensions.get_typed::<Session>().map_or_else(
|| Response::new(StatusCode::OK, b"no session".to_vec()),
|session| {
let count = session
.get("count")
.and_then(|s| s.parse::<u32>().ok())
.unwrap_or(0);
session.insert("count", (count + 1).to_string());
let body = format!("count={}", count + 1);
Response::new(StatusCode::OK, body.into_bytes())
},
)
})
}
}
#[test]
fn middleware_creates_session_on_first_request() {
let store = MemoryStore::new();
let layer = SessionLayer::new(store.clone());
let handler = layer.wrap(TestHandler);
let req = Request::new("GET", "/");
let resp = handler.call(req);
assert_eq!(resp.status, StatusCode::OK);
assert!(!resp.set_cookies.is_empty());
let cookie = resp.set_cookies.first().unwrap();
assert!(cookie.contains("session_id="));
assert_eq!(store.len(), 1);
}
#[test]
fn middleware_loads_existing_session() {
let store = MemoryStore::new();
let layer = SessionLayer::new(store);
let handler = layer.wrap(TestHandler);
let req1 = Request::new("GET", "/");
let resp1 = handler.call(req1);
let cookie_header = resp1.set_cookies.first().unwrap().clone();
let session_id = cookie_header
.split('=')
.nth(1)
.unwrap()
.split(';')
.next()
.unwrap();
let mut req2 = Request::new("GET", "/");
req2.headers
.insert("cookie".to_string(), format!("session_id={session_id}"));
let resp2 = handler.call(req2);
let body2 = std::str::from_utf8(&resp2.body).unwrap();
assert_eq!(body2, "count=2");
}
#[test]
fn middleware_invalid_session_id_creates_new() {
let store = MemoryStore::new();
let layer = SessionLayer::new(store.clone());
let handler = layer.wrap(TestHandler);
let mut req = Request::new("GET", "/");
req.headers
.insert("cookie".to_string(), "session_id=bad!".to_string());
let resp = handler.call(req);
assert!(!resp.set_cookies.is_empty());
assert_eq!(store.len(), 1);
}
#[test]
fn middleware_fixation_unknown_id_regenerated() {
let store = MemoryStore::new();
let layer = SessionLayer::new(store.clone());
let handler = layer.wrap(TestHandler);
let unknown_attacker_id = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa0"; let mut req = Request::new("GET", "/");
req.headers.insert(
"cookie".to_string(),
format!("session_id={unknown_attacker_id}"),
);
let resp = handler.call(req);
let cookie = resp.set_cookies.first().unwrap();
assert!(
!cookie.contains(unknown_attacker_id),
"must not reuse attacker-supplied ID"
);
assert_eq!(store.len(), 1);
}
#[test]
fn middleware_regenerate_rotates_id_and_preserves_data() {
struct LoginHandler;
impl Handler for LoginHandler {
fn call(
&self,
_cx: &crate::Cx,
req: Request,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>>
{
Box::pin(async move {
if let Some(session) = req.extensions.get_typed::<Session>() {
session.insert("user_id", "alice");
session.regenerate();
}
Response::new(StatusCode::OK, b"logged in".to_vec())
})
}
}
let store = MemoryStore::new();
let attacker_planted_id = "1234567890abcdef1234567890abcdef";
let mut pre_auth = SessionData::new();
pre_auth.insert("pre_auth_marker", "still here");
store.save(attacker_planted_id, &pre_auth);
let layer = SessionLayer::new(store.clone()).csrf_protection(false);
let handler = layer.wrap(LoginHandler);
let mut req = Request::new("POST", "/login");
req.headers.insert(
"cookie".to_string(),
format!("session_id={attacker_planted_id}"),
);
let resp = handler.call(req);
assert_eq!(resp.status, StatusCode::OK);
let cookie = resp
.set_cookies
.first()
.expect("middleware must issue Set-Cookie after regenerate");
assert!(
!cookie.contains(attacker_planted_id),
"middleware reused the attacker's planted ID after regenerate(); fixation is OPEN. Set-Cookie: {cookie}"
);
let new_id = cookie
.split('=')
.nth(1)
.expect("malformed Set-Cookie")
.split(';')
.next()
.expect("missing cookie value")
.to_string();
assert_ne!(
new_id, attacker_planted_id,
"regenerate() did not actually rotate the ID"
);
assert_eq!(new_id.len(), 32, "new ID must be 32-char hex");
assert!(
store.load(attacker_planted_id).is_none(),
"old session id must be deleted from store after regenerate()"
);
let migrated = store
.load(&new_id)
.expect("new session id must be persisted");
assert_eq!(
migrated.get("user_id"),
Some("alice"),
"post-login user_id was not preserved across regenerate()"
);
assert!(
migrated.get(REGENERATE_FLAG_KEY).is_none(),
"REGENERATE_FLAG_KEY leaked into persisted session data"
);
}
#[test]
fn middleware_clear_session_expires_cookie() {
struct ClearHandler;
impl Handler for ClearHandler {
fn call(
&self,
_cx: &crate::Cx,
req: Request,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>>
{
Box::pin(async move {
if let Some(session) = req.extensions.get_typed::<Session>() {
session.insert("data", "value"); session.clear();
}
Response::new(StatusCode::OK, b"cleared".to_vec())
})
}
}
let store = MemoryStore::new();
let mut seed = SessionData::new();
seed.insert("data", "value");
store.save("abcdef01234567890abcdef012345678", &seed);
let layer = SessionLayer::new(store.clone());
let handler = layer.wrap(ClearHandler);
let mut req = Request::new("GET", "/");
req.headers.insert(
"cookie".to_string(),
"session_id=abcdef01234567890abcdef012345678".to_string(),
);
let resp = handler.call(req);
let cookie = resp.set_cookies.first().unwrap();
assert!(
cookie.contains("Max-Age=0"),
"cookie must be expired on clear"
);
assert!(store.is_empty(), "server-side data must be deleted");
}
#[test]
fn generate_id_uses_crypto_randomness() {
let ids: Vec<String> = (0..100)
.map(|_| {
generate_session_id().expect("OS entropy must be available in session-id test")
})
.collect();
for id in &ids {
assert!(is_valid_session_id(id));
}
let set: std::collections::HashSet<_> = ids.iter().collect();
assert_eq!(set.len(), 100);
}
#[test]
fn session_handle_operations() {
let session = Session(Arc::new(Mutex::new(SessionData::new())));
session.insert("key", "value");
assert!(session.contains("key"));
assert_eq!(session.get("key"), Some("value".to_string()));
let _ = session.remove("key");
assert!(!session.contains("key"));
}
#[test]
fn session_handle_clear() {
let session = Session(Arc::new(Mutex::new(SessionData::new())));
session.insert("a", "1");
session.insert("b", "2");
session.clear();
assert!(!session.contains("a"));
}
#[test]
fn session_handle_debug() {
let session = Session(Arc::new(Mutex::new(SessionData::new())));
let dbg = format!("{session:?}");
assert!(dbg.contains("Session"));
}
#[test]
fn same_site_variants() {
let config_none = SessionConfig {
same_site: SameSite::None,
..Default::default()
};
let header = set_cookie_header("s", "v", &config_none);
assert!(header.contains("SameSite=None"));
}
#[test]
#[should_panic(expected = "br-asupersync-uz7oxb")]
fn cookie_name_with_semicolon_panics() {
let cfg = SessionConfig::default();
let _ = set_cookie_header("evil; HttpOnly=false; X", "v", &cfg);
}
#[test]
#[should_panic(expected = "br-asupersync-uz7oxb")]
fn cookie_value_with_semicolon_panics() {
let cfg = SessionConfig::default();
let _ = set_cookie_header("s", "v; Domain=attacker.com", &cfg);
}
#[test]
#[should_panic(expected = "br-asupersync-uz7oxb")]
fn cookie_path_with_crlf_panics() {
let cfg = SessionConfig {
cookie_path: "/foo\r\nX-Injected: 1".to_string(),
..Default::default()
};
let _ = set_cookie_header("s", "v", &cfg);
}
#[test]
fn cookie_helper_accepts_safe_inputs() {
let cfg = SessionConfig::default();
let h = set_cookie_header("session", "abcd1234", &cfg);
assert!(h.starts_with("session=abcd1234; Path=/"));
}
#[test]
fn rotate_csrf_token_changes_token() {
let session = Session(Arc::new(Mutex::new(SessionData::new())));
session.insert(CSRF_TOKEN_KEY, "old-token");
let new = session
.rotate_csrf_token()
.expect("OS entropy must be available in CSRF rotation test");
assert_ne!(new, "old-token");
assert_eq!(session.csrf_token().as_deref(), Some(new.as_str()));
}
#[test]
fn regenerate_rotates_csrf_and_sets_flag() {
let session = Session(Arc::new(Mutex::new(SessionData::new())));
session.insert(CSRF_TOKEN_KEY, "old-token");
session.regenerate();
let inner = session.0.lock();
assert!(inner.get(REGENERATE_FLAG_KEY).is_some());
let new_csrf = inner.get(CSRF_TOKEN_KEY).unwrap();
assert_ne!(new_csrf, "old-token");
}
fn make_request_with_headers(method: &str, headers: &[(&str, &str)]) -> Request {
let mut h = HashMap::new();
for (k, v) in headers {
h.insert((*k).to_string(), (*v).to_string());
}
Request {
method: method.to_string(),
path: "/api/x".to_string(),
query: None,
headers: h,
body: crate::bytes::Bytes::new(),
path_params: HashMap::new(),
extensions: crate::web::extract::Extensions::new(),
}
}
#[test]
fn referer_origin_strips_path() {
let req = make_request_with_headers(
"POST",
&[("Referer", "https://app.example.com/foo/bar?q=1")],
);
let origin = request_origin(&req);
assert_eq!(origin.as_deref(), Some("https://app.example.com"));
}
#[test]
fn origin_allow_list_match_is_case_insensitive_and_trim_slash() {
let allowed = vec!["https://App.Example.Com/".to_string()];
assert!(origin_is_allowed("https://app.example.com", &allowed));
assert!(origin_is_allowed("HTTPS://APP.EXAMPLE.COM", &allowed));
assert!(!origin_is_allowed("https://attacker.com", &allowed));
}
#[test]
fn origin_header_takes_precedence_over_referer() {
let req = make_request_with_headers(
"POST",
&[
("Origin", "https://app.example.com"),
("Referer", "https://other.example.com/"),
],
);
assert_eq!(
request_origin(&req).as_deref(),
Some("https://app.example.com")
);
}
#[test]
fn null_origin_falls_back_to_referer() {
let req = make_request_with_headers(
"POST",
&[
("Origin", "null"),
("Referer", "https://app.example.com/foo"),
],
);
assert_eq!(
request_origin(&req).as_deref(),
Some("https://app.example.com")
);
}
fn extract_set_cookie_id(cookie_header: &str) -> &str {
cookie_header
.split(';')
.next()
.unwrap()
.split_once('=')
.unwrap()
.1
}
struct PanicAfterRegenerateHandler;
impl Handler for PanicAfterRegenerateHandler {
fn call(
&self,
_cx: &crate::Cx,
req: Request,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
Box::pin(async move {
let session = req
.extensions
.get_typed::<Session>()
.expect("middleware injects Session");
session.regenerate();
panic!("simulated handler panic after regenerate");
})
}
}
/// Fail-closed check: a handler that requests rotation and then panics must not
/// leave the pre-rotation session behind. The armed `RegenerateGuard` deletes it
/// while the panic unwinds.
#[test]
fn regenerate_guard_fails_closed_when_handler_panics() {
    let store = MemoryStore::new();
    let layer = SessionLayer::new(store.clone()).csrf_protection(false);
    let victim_id = String::from("0123456789abcdef0123456789abcdef");

    let mut alice = SessionData::new();
    alice.insert("authed_user", "alice");
    store.save(&victim_id, &alice);
    assert_eq!(store.len(), 1);

    let wrapped = layer.wrap(PanicAfterRegenerateHandler);
    let mut request = Request::new("POST", "/login");
    let cookie_value = format!("session_id={victim_id}");
    request.headers.insert(String::from("cookie"), cookie_value);

    let handler_panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        wrapped.call(request);
    }))
    .is_err();
    assert!(
        handler_panicked,
        "handler must propagate the panic — the test relies on it"
    );
    assert_eq!(
        store.len(),
        0,
        "RegenerateGuard must invalidate the OLD session on the cancel/panic path"
    );
}
/// Test handler that panics as soon as its future is polled, without ever
/// touching the session.
struct PanicNoRegenerateHandler;
impl Handler for PanicNoRegenerateHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        _request: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        // The request is deliberately ignored; only the panic matters.
        Box::pin(async {
            panic!("simulated handler panic without regenerate")
        })
    }
}
/// When the handler panics WITHOUT having requested `regenerate`, the armed
/// `RegenerateGuard` must leave the stored session untouched. Its `Drop` only
/// deletes when the regenerate flag is present.
#[test]
fn regenerate_guard_drop_is_noop_without_pending_flag() {
    let store = MemoryStore::new();
    let layer = SessionLayer::new(store.clone());
    let original_id = "fedcba9876543210fedcba9876543210".to_string();
    let mut seeded = SessionData::new();
    seeded.insert("k", "v");
    store.save(&original_id, &seeded);
    // Pre-condition, matching the sibling guard tests.
    assert_eq!(store.len(), 1);
    let handler = layer.wrap(PanicNoRegenerateHandler);
    // GET is not a state-changing method, so the default CSRF setting is not
    // expected to intercept the request before the handler runs.
    let mut req = Request::new("GET", "/");
    req.headers
        .insert("cookie".to_string(), format!("session_id={original_id}"));
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        handler.call(req);
    }));
    assert!(
        outcome.is_err(),
        "handler must propagate the panic — the test relies on it"
    );
    assert_eq!(
        store.len(),
        1,
        "guard must NOT invalidate sessions that did not request regenerate"
    );
    // The count alone would also pass if the original entry were deleted and a
    // different one saved in its place. Pin the original id and its contents.
    let survivor = store
        .load(&original_id)
        .expect("the original session id must still resolve after the handler panic");
    assert_eq!(
        survivor.get("k"),
        Some("v"),
        "the original session data must be left intact"
    );
}
/// `SessionStore` stub for the guard tests.
///
/// `load` never finds anything, `save` discards its input, and `delete` always
/// panics with a `&'static str` payload.
struct PanicOnDeleteStore;
impl SessionStore for PanicOnDeleteStore {
    fn load(&self, _session_id: &str) -> Option<SessionData> {
        // Nothing is ever persisted, so nothing can be found.
        None
    }
    fn save(&self, _session_id: &str, _session: &SessionData) {
        // Intentionally a no-op.
    }
    fn delete(&self, _session_id: &str) {
        // `panic_any` with a plain string literal keeps the payload a
        // `&'static str`, which the guard tests downcast to.
        let reason: &'static str = "delete backend unavailable";
        std::panic::panic_any(reason);
    }
}
/// If the guard drops while a handler panic is already unwinding, a panic from
/// `store.delete` must be contained. The payload that escapes must still be the
/// handler's original panic.
#[test]
fn regenerate_guard_drop_suppresses_store_delete_panic_during_handler_unwind() {
    let failing_store = PanicOnDeleteStore;
    let pending_rotation = {
        let mut data = SessionData::new();
        data.insert(REGENERATE_FLAG_KEY, "1");
        Arc::new(Mutex::new(data))
    };
    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        // A named binding (not `_`) keeps the guard alive until the panic
        // below starts unwinding.
        let _armed_guard = RegenerateGuard {
            store: &failing_store,
            session_id: "00112233445566778899aabbccddeeff".to_string(),
            session_handle: pending_rotation,
            is_new: false,
            armed: true,
        };
        std::panic::panic_any("handler failed after requesting regenerate");
    }));
    let payload = result.expect_err("the original handler unwind should propagate");
    let message = payload.downcast_ref::<&str>();
    assert_eq!(
        message,
        Some(&"handler failed after requesting regenerate"),
        "the guard must preserve the original handler panic"
    );
}
/// If the guard drops outside of any unwind, it takes the non-panicking branch.
/// A panic from `store.delete` must then propagate to the caller rather than
/// being swallowed.
#[test]
fn regenerate_guard_drop_surfaces_store_delete_panic_without_handler_unwind() {
    let failing_store = PanicOnDeleteStore;
    let pending_rotation = {
        let mut data = SessionData::new();
        data.insert(REGENERATE_FLAG_KEY, "1");
        Arc::new(Mutex::new(data))
    };
    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        // Explicitly dropping the armed guard runs its `Drop` with no panic in flight.
        drop(RegenerateGuard {
            store: &failing_store,
            session_id: "00112233445566778899aabbccddeeff".to_string(),
            session_handle: pending_rotation,
            is_new: false,
            armed: true,
        });
    }));
    let payload = result.expect_err("normal-path store.delete panic should propagate");
    let message = payload.downcast_ref::<&str>();
    assert_eq!(
        message,
        Some(&"delete backend unavailable"),
        "normal-path store.delete panics should not be silently swallowed"
    );
}
/// Test handler that requests a session-ID rotation and then completes
/// normally with `200 OK` and body `ok`.
struct RegenerateAndReturnHandler;
impl Handler for RegenerateAndReturnHandler {
    fn call(
        &self,
        _cx: &crate::Cx,
        req: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>> {
        Box::pin(async move {
            let extensions = &req.extensions;
            let session = extensions.get_typed::<Session>().expect("middleware injects Session");
            session.regenerate();
            let body = b"ok".to_vec();
            Response::new(StatusCode::OK, body)
        })
    }
}
/// On the non-panicking path the middleware performs the rotation itself. The
/// checks are:
/// - the `Set-Cookie` carries a fresh id;
/// - the old entry is gone;
/// - exactly one entry, under the new id, remains.
#[test]
fn regenerate_guard_disarmed_on_happy_path_rotates_normally() {
    let store = MemoryStore::new();
    let layer = SessionLayer::new(store.clone()).csrf_protection(false);
    let old_id = String::from("1111222233334444aaaabbbbccccdddd");

    let mut bob = SessionData::new();
    bob.insert("authed_user", "bob");
    store.save(&old_id, &bob);
    assert_eq!(store.len(), 1);

    let wrapped = layer.wrap(RegenerateAndReturnHandler);
    let mut request = Request::new("POST", "/login");
    let cookie_value = format!("session_id={old_id}");
    request.headers.insert(String::from("cookie"), cookie_value);

    let response = wrapped.call(request);
    assert_eq!(response.status, StatusCode::OK);

    let set_cookie = response.set_cookies.first().expect("Set-Cookie present");
    let rotated_id = extract_set_cookie_id(set_cookie);
    assert_ne!(rotated_id, old_id, "ID must rotate");
    assert_eq!(store.len(), 1, "exactly one entry under the new ID");

    let old_entry = store.load(&old_id);
    assert!(
        old_entry.is_none(),
        "original session must be deleted after rotation"
    );
    let new_entry = store.load(rotated_id);
    assert!(new_entry.is_some(), "new session must be present");
}
/// The session layer must append its own `Set-Cookie` rather than overwrite
/// one emitted by the inner handler. Both cookies have to reach the client.
#[test]
fn middleware_preserves_inner_handler_set_cookie() {
    /// Inner handler that emits its own CSRF cookie.
    struct CsrfEmittingHandler;
    impl Handler for CsrfEmittingHandler {
        fn call(
            &self,
            _cx: &crate::Cx,
            _request: Request,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send + '_>>
        {
            Box::pin(async {
                let mut response = Response::new(StatusCode::OK, b"ok".to_vec());
                response.append_set_cookie("csrf_token=abc123; HttpOnly; Path=/");
                response
            })
        }
    }

    let layer = SessionLayer::new(MemoryStore::new());
    let wrapped = layer.wrap(CsrfEmittingHandler);
    let response = wrapped.call(Request::new("GET", "/"));
    assert_eq!(response.status, StatusCode::OK);

    let cookies = &response.set_cookies;
    assert_eq!(
        cookies.len(),
        2,
        "expected BOTH the handler's CSRF cookie and the middleware's session cookie; \
         pre-fix the HashMap-backed header store collapsed them to one. \
         Got: {:?}",
        cookies
    );
    let has_csrf = cookies.iter().any(|c| c.contains("csrf_token=abc123"));
    assert!(
        has_csrf,
        "inner handler's CSRF cookie must survive the middleware layer; \
         got {:?}",
        cookies
    );
    let has_session = cookies.iter().any(|c| c.contains("session_id="));
    assert!(
        has_session,
        "session cookie must still be emitted alongside CSRF; got {:?}",
        cookies
    );
}
}