use crate::error::{Result, TidewayError};
use crate::traits::session::{SessionData, SessionStore};
use async_trait::async_trait;
use cookie::{Cookie, CookieJar, Key, SameSite};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Session store that keeps the whole session inside an encrypted cookie.
///
/// Session data is serialized to JSON and sealed with `key` using the
/// `cookie` crate's private jar; the sealed value is what the client stores.
/// The server keeps no per-session state (see the no-op `save`/`delete` in
/// the `SessionStore` impl).
#[derive(Clone)]
pub struct CookieSessionStore {
    // Encryption key; behind `Arc` so clones of the store share one key.
    key: Arc<Key>,
    // Session settings: cookie name/path/flags/domain and key options.
    config: crate::session::SessionConfig,
}
impl CookieSessionStore {
pub fn new(config: &crate::session::SessionConfig) -> Result<Self> {
let key = if let Some(ref key_str) = config.encryption_key {
let key_bytes = hex::decode(key_str).map_err(|e| {
TidewayError::internal(format!("Invalid encryption key format: {}", e))
})?;
if key_bytes.len() != 64 {
return Err(TidewayError::internal(
"Encryption key must be 64 bytes (128 hex characters). Generate with: openssl rand -hex 64",
));
}
Key::from(&key_bytes)
} else if config.allow_insecure_key {
tracing::error!(
"┌──────────────────────────────────────────────────────────────────────────────┐"
);
tracing::error!(
"│ SECURITY WARNING: Using randomly generated session encryption key! │"
);
tracing::error!(
"│ │"
);
tracing::error!(
"│ This is INSECURE and should NEVER be used in production: │"
);
tracing::error!(
"│ • Sessions will be invalidated on every server restart │"
);
tracing::error!(
"│ • Sessions won't work across multiple server instances │"
);
tracing::error!(
"│ • Session cookies may be vulnerable to forgery │"
);
tracing::error!(
"│ │"
);
tracing::error!(
"│ To fix: Set SESSION_ENCRYPTION_KEY or config.session.encryption_key │"
);
tracing::error!(
"│ Generate a key with: openssl rand -hex 64 │"
);
tracing::error!(
"└──────────────────────────────────────────────────────────────────────────────┘"
);
Key::generate()
} else {
return Err(TidewayError::internal(
"Cookie sessions require an encryption key. \
Set SESSION_ENCRYPTION_KEY environment variable or config.session.encryption_key. \
Generate a key with: openssl rand -hex 64. \
For development only, set SESSION_ALLOW_INSECURE_KEY=true.",
));
};
Ok(Self {
key: Arc::new(key),
config: config.clone(),
})
}
pub fn encrypt(&self, data: &SessionData) -> Result<String> {
let serialized = serde_json::to_string(data)
.map_err(|e| TidewayError::internal(format!("Failed to serialize session: {}", e)))?;
let mut jar = CookieJar::new();
let cookie = Cookie::new(self.config.cookie_name.clone(), serialized);
jar.private_mut(&self.key).add(cookie);
let encrypted_cookie = jar
.get(&self.config.cookie_name)
.ok_or_else(|| TidewayError::internal("Failed to encrypt session cookie"))?;
Ok(encrypted_cookie.value().to_string())
}
pub fn decrypt(&self, encrypted_value: &str) -> Result<Option<SessionData>> {
let mut jar = CookieJar::new();
let cookie = Cookie::new(self.config.cookie_name.clone(), encrypted_value.to_string());
jar.add_original(cookie);
let decrypted = jar.private(&self.key).get(&self.config.cookie_name);
match decrypted {
Some(cookie) => {
let data: SessionData = serde_json::from_str(cookie.value()).map_err(|e| {
TidewayError::internal(format!("Failed to deserialize session: {}", e))
})?;
Ok(Some(data))
}
None => {
Ok(None)
}
}
}
pub fn build_cookie(&self, data: &SessionData) -> Result<Cookie<'static>> {
let encrypted_value = self.encrypt(data)?;
let ttl = self.session_ttl(data);
let mut cookie_builder = Cookie::build((self.config.cookie_name.clone(), encrypted_value))
.path(self.config.cookie_path.clone())
.http_only(self.config.cookie_http_only)
.secure(self.config.cookie_secure)
.same_site(SameSite::Lax)
.max_age(cookie::time::Duration::seconds(ttl.as_secs() as i64));
if let Some(domain) = &self.config.cookie_domain {
cookie_builder = cookie_builder.domain(domain.clone());
}
let cookie = cookie_builder.build();
Ok(cookie)
}
fn session_ttl(&self, data: &SessionData) -> Duration {
data.expires_at
.duration_since(SystemTime::now())
.unwrap_or(Duration::ZERO)
}
}
#[async_trait]
impl SessionStore for CookieSessionStore {
    /// Treats `session_id` as the encrypted cookie value, decrypts it, and
    /// returns the session only if it has not expired.
    ///
    /// A value that fails to decrypt yields `Ok(None)`. Deserialization
    /// errors from `decrypt` are propagated.
    async fn load(&self, session_id: &str) -> Result<Option<SessionData>> {
        match self.decrypt(session_id)? {
            Some(data) if !data.is_expired() => Ok(Some(data)),
            _ => Ok(None),
        }
    }

    /// Does not persist anything. The data is encrypted and then discarded,
    /// so the only observable effect is surfacing an encryption or
    /// serialization error to the caller.
    async fn save(&self, _session_id: &str, data: SessionData) -> Result<()> {
        self.encrypt(&data).map(|_| ())
    }

    /// No-op: this store keeps no server-side state to remove.
    async fn delete(&self, _session_id: &str) -> Result<()> {
        Ok(())
    }

    /// No-op: there is no server-side state to sweep, so always reports 0.
    async fn cleanup_expired(&self) -> Result<usize> {
        Ok(0)
    }

    /// Always healthy: the store has no external backend to check.
    fn is_healthy(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::SessionConfig;
    use std::time::Duration;

    /// Valid 64-byte key (128 hex characters) used by most tests.
    const TEST_KEY: &str = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
    /// A second valid 64-byte key, distinct from `TEST_KEY`.
    const OTHER_KEY: &str = "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210";

    /// Config with the given hex key and the insecure fallback disabled.
    fn config_with_key(key: &str) -> SessionConfig {
        SessionConfig {
            encryption_key: Some(key.to_string()),
            allow_insecure_key: false,
            ..Default::default()
        }
    }

    fn test_config() -> SessionConfig {
        config_with_key(TEST_KEY)
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let store = CookieSessionStore::new(&test_config()).unwrap();
        let mut data = SessionData::new(Duration::from_secs(3600));
        data.set("user_id".to_string(), "12345".to_string());
        data.set("role".to_string(), "admin".to_string());
        let encrypted = store.encrypt(&data).unwrap();
        // Plaintext values must not leak into the cookie value.
        assert!(!encrypted.contains("12345"));
        assert!(!encrypted.contains("admin"));
        let decrypted = store
            .decrypt(&encrypted)
            .unwrap()
            .expect("freshly encrypted value should decrypt");
        assert_eq!(decrypted.get("user_id"), Some(&"12345".to_string()));
        assert_eq!(decrypted.get("role"), Some(&"admin".to_string()));
    }

    #[test]
    fn test_tampered_cookie_rejected() {
        let store = CookieSessionStore::new(&test_config()).unwrap();
        let mut data = SessionData::new(Duration::from_secs(3600));
        data.set("user_id".to_string(), "12345".to_string());
        let encrypted = store.encrypt(&data).unwrap();
        let mut chars: Vec<char> = encrypted.chars().collect();
        // Fail loudly instead of silently skipping the tamper step.
        assert!(chars.len() > 10, "encrypted value unexpectedly short");
        chars[5] = if chars[5] == 'a' { 'b' } else { 'a' };
        let tampered: String = chars.into_iter().collect();
        assert_ne!(tampered, encrypted);
        let result = store.decrypt(&tampered).unwrap();
        assert!(result.is_none(), "Tampered cookie should not decrypt");
    }

    #[test]
    fn test_different_key_cannot_decrypt() {
        let store1 = CookieSessionStore::new(&config_with_key(TEST_KEY)).unwrap();
        let store2 = CookieSessionStore::new(&config_with_key(OTHER_KEY)).unwrap();
        let mut data = SessionData::new(Duration::from_secs(3600));
        data.set("secret".to_string(), "sensitive_data".to_string());
        let encrypted = store1.encrypt(&data).unwrap();
        let result = store2.decrypt(&encrypted).unwrap();
        assert!(result.is_none(), "Different key should not decrypt");
    }

    #[test]
    fn test_invalid_key_length_rejected() {
        // Valid hex of the wrong decoded length. The length check, not the
        // hex decoder, must be what rejects these.
        let too_short = config_with_key("0123456789abcdef"); // 8 bytes
        assert!(CookieSessionStore::new(&too_short).is_err());

        let too_long = config_with_key(&format!("{}00", TEST_KEY)); // 65 bytes
        assert!(CookieSessionStore::new(&too_long).is_err());
    }

    #[test]
    fn test_invalid_hex_rejected() {
        let bad_char = config_with_key(
            "0123456789abcdefg123456789abcdef0123456789abcdef0123456789abcdef",
        );
        assert!(CookieSessionStore::new(&bad_char).is_err());

        let not_hex = config_with_key("too_short");
        assert!(CookieSessionStore::new(&not_hex).is_err());
    }

    #[test]
    fn test_no_key_without_insecure_flag_rejected() {
        let config = SessionConfig {
            encryption_key: None,
            allow_insecure_key: false,
            ..Default::default()
        };
        assert!(CookieSessionStore::new(&config).is_err());
    }

    #[test]
    fn test_no_key_with_insecure_flag_generates_usable_key() {
        let config = SessionConfig {
            encryption_key: None,
            allow_insecure_key: true,
            ..Default::default()
        };
        let store = CookieSessionStore::new(&config).expect("insecure key should be generated");
        let data = SessionData::new(Duration::from_secs(60));
        let encrypted = store.encrypt(&data).unwrap();
        assert!(store.decrypt(&encrypted).unwrap().is_some());
    }

    #[test]
    fn test_garbage_input_returns_none() {
        let store = CookieSessionStore::new(&test_config()).unwrap();
        assert!(store.decrypt("not_a_valid_encrypted_cookie").unwrap().is_none());
        assert!(store.decrypt("").unwrap().is_none());
    }

    #[test]
    fn test_build_cookie_has_correct_attributes() {
        let config = test_config();
        let store = CookieSessionStore::new(&config).unwrap();
        let mut data = SessionData::new(Duration::from_secs(3600));
        data.set("test".to_string(), "value".to_string());
        let cookie = store.build_cookie(&data).unwrap();
        assert_eq!(cookie.name(), config.cookie_name);
        assert_eq!(cookie.path(), Some(config.cookie_path.as_str()));
        assert_eq!(cookie.http_only(), Some(config.cookie_http_only));
        assert_eq!(cookie.secure(), Some(config.cookie_secure));
        assert_eq!(cookie.same_site(), Some(SameSite::Lax));
    }

    #[test]
    fn test_build_cookie_uses_session_expiry_for_max_age() {
        let store = CookieSessionStore::new(&test_config()).unwrap();
        let data = SessionData::new(Duration::from_secs(120));
        let cookie = store.build_cookie(&data).unwrap();
        let max_age = cookie.max_age().expect("cookie max age should be set");
        assert!(max_age.whole_seconds() <= 120);
        assert!(max_age.whole_seconds() > 0);
    }

    #[tokio::test]
    async fn test_session_store_trait() {
        let store = CookieSessionStore::new(&test_config()).unwrap();
        let mut data = SessionData::new(Duration::from_secs(3600));
        data.set("session_key".to_string(), "session_value".to_string());
        store.save("unused", data.clone()).await.unwrap();
        let encrypted = store.encrypt(&data).unwrap();
        let loaded = store
            .load(&encrypted)
            .await
            .unwrap()
            .expect("valid session should load");
        assert_eq!(
            loaded.get("session_key"),
            Some(&"session_value".to_string())
        );
        assert!(store.load("invalid").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_session_store_trait_filters_expired_sessions() {
        let store = CookieSessionStore::new(&test_config()).unwrap();
        let data = SessionData::new(Duration::from_millis(5));
        let encrypted = store.encrypt(&data).unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(store.load(&encrypted).await.unwrap().is_none());
    }
}