use serde::{Deserialize, Serialize};
/// Backend used to persist session data.
///
/// Variants serialize in snake_case, i.e. as `"memory"` or `"redis"`.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionStorage {
    /// Memory-backed storage. This is the default variant.
    #[default]
    Memory,
    /// Redis-backed storage; presumably connects via `SessionConfig::redis_url` — confirm
    /// against the store implementation.
    Redis,
}
/// Configuration for cookie-based sessions.
///
/// Every field may be omitted when deserializing; missing fields fall back to the same values
/// that [`SessionConfig::default`] produces.
///
/// `Debug` is implemented by hand so that credentials embedded in `redis_url`
/// (e.g. `redis://:password@host:6379`) are masked instead of being written to logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Name of the session cookie (default `"session_id"`).
    #[serde(default = "default_cookie_name")]
    pub cookie_name: String,
    /// Session expiry in seconds (default 86400, i.e. 24 hours).
    #[serde(default = "default_expiry_secs")]
    pub expiry_secs: u64,
    /// Optional inactivity timeout in seconds (default `None`).
    #[serde(default)]
    pub inactivity_timeout_secs: Option<u64>,
    /// Cookie `Path` attribute (default `"/"`).
    #[serde(default = "default_cookie_path")]
    pub cookie_path: String,
    /// Optional cookie `Domain` attribute (default `None`).
    #[serde(default)]
    pub cookie_domain: Option<String>,
    /// Cookie `Secure` flag (default `true`).
    #[serde(default = "default_secure")]
    pub secure: bool,
    /// Cookie `HttpOnly` flag (default `true`).
    #[serde(default = "default_http_only")]
    pub http_only: bool,
    /// Cookie `SameSite` value as a string (default `"lax"`).
    #[serde(default = "default_same_site")]
    pub same_site: String,
    /// Storage backend (default [`SessionStorage::Memory`]).
    #[serde(default)]
    pub storage: SessionStorage,
    /// Redis connection URL; presumably only consulted when `storage` is
    /// [`SessionStorage::Redis`] — confirm against the store implementation.
    #[serde(default)]
    pub redis_url: Option<String>,
    /// CSRF protection settings.
    #[serde(default)]
    pub csrf: CsrfConfig,
}

impl std::fmt::Debug for SessionConfig {
    /// Formats every field like the derived impl would, except that the userinfo part of
    /// `redis_url` is replaced with `***` so passwords never reach log output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionConfig")
            .field("cookie_name", &self.cookie_name)
            .field("expiry_secs", &self.expiry_secs)
            .field("inactivity_timeout_secs", &self.inactivity_timeout_secs)
            .field("cookie_path", &self.cookie_path)
            .field("cookie_domain", &self.cookie_domain)
            .field("secure", &self.secure)
            .field("http_only", &self.http_only)
            .field("same_site", &self.same_site)
            .field("storage", &self.storage)
            .field(
                "redis_url",
                &self.redis_url.as_deref().map(redact_url_credentials),
            )
            .field("csrf", &self.csrf)
            .finish()
    }
}

/// Masks the userinfo (`user:password@`) portion of a connection URL so it can be logged.
///
/// The authority starts after `://`, or at the beginning of the string when there is no
/// scheme, and ends at the first `/`, `?` or `#`. Everything in the authority before its last
/// `@` is replaced with `***`. Strings without an `@` in the authority are returned unchanged.
fn redact_url_credentials(url: &str) -> std::borrow::Cow<'_, str> {
    let authority_start = url.find("://").map_or(0, |idx| idx + 3);
    let rest = &url[authority_start..];
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    // `rfind` so that an unescaped '@' inside the password still gets fully masked.
    match rest[..authority_end].rfind('@') {
        Some(at) => std::borrow::Cow::Owned(format!(
            "{}***{}",
            &url[..authority_start],
            &rest[at..]
        )),
        None => std::borrow::Cow::Borrowed(url),
    }
}
impl Default for SessionConfig {
    /// Builds the same configuration that deserializing an empty object yields: each field
    /// takes the value named by its `#[serde(default ...)]` attribute.
    fn default() -> Self {
        let storage = SessionStorage::default();
        let csrf = CsrfConfig::default();

        Self {
            // Cookie identity and scope.
            cookie_name: default_cookie_name(),
            cookie_path: default_cookie_path(),
            cookie_domain: None,
            // Lifetime.
            expiry_secs: default_expiry_secs(),
            inactivity_timeout_secs: None,
            // Cookie attributes.
            secure: default_secure(),
            http_only: default_http_only(),
            same_site: default_same_site(),
            // Backend and protection.
            storage,
            redis_url: None,
            csrf,
        }
    }
}
/// Settings for CSRF (cross-site request forgery) token protection.
///
/// Any field may be omitted when deserializing; missing fields fall back to the same values
/// that [`CsrfConfig::default`] produces.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CsrfConfig {
    /// Whether CSRF protection is enabled (default `true`).
    #[serde(default = "default_csrf_enabled")]
    pub enabled: bool,
    /// Token length (default 32); whether this counts bytes or characters depends on the
    /// token generator — confirm.
    #[serde(default = "default_token_length")]
    pub token_length: usize,
    /// Header name presumably used to submit the token (default `"X-CSRF-Token"`).
    #[serde(default = "default_header_name")]
    pub header_name: String,
    /// Form field name presumably used to submit the token (default `"_csrf"`).
    #[serde(default = "default_form_field_name")]
    pub form_field_name: String,
    /// HTTP methods treated as safe (default GET, HEAD, OPTIONS, TRACE).
    #[serde(default = "default_safe_methods")]
    pub safe_methods: Vec<String>,
}
impl Default for CsrfConfig {
fn default() -> Self {
Self {
enabled: default_csrf_enabled(),
token_length: default_token_length(),
header_name: default_header_name(),
form_field_name: default_form_field_name(),
safe_methods: default_safe_methods(),
}
}
}
/// Default for `SessionConfig::cookie_name`.
fn default_cookie_name() -> String {
    String::from("session_id")
}
/// Default for `SessionConfig::expiry_secs`: one day, in seconds.
fn default_expiry_secs() -> u64 {
    24 * 60 * 60
}
/// Default for `SessionConfig::cookie_path`: the site root.
fn default_cookie_path() -> String {
    "/".to_owned()
}
/// Default for `SessionConfig::secure`: the `Secure` cookie flag is on.
fn default_secure() -> bool {
    true
}
/// Default for `SessionConfig::http_only`: the `HttpOnly` cookie flag is on.
fn default_http_only() -> bool {
    true
}
/// Default for `SessionConfig::same_site`: the lowercase string `"lax"`.
fn default_same_site() -> String {
    String::from("lax")
}
/// Default for `CsrfConfig::enabled`: CSRF protection is on.
fn default_csrf_enabled() -> bool {
    true
}
/// Default for `CsrfConfig::token_length`.
fn default_token_length() -> usize {
    32
}
/// Default for `CsrfConfig::header_name`.
fn default_header_name() -> String {
    String::from("X-CSRF-Token")
}
/// Default for `CsrfConfig::form_field_name`.
fn default_form_field_name() -> String {
    "_csrf".to_owned()
}
/// Default for `CsrfConfig::safe_methods`.
///
/// These are the HTTP methods that RFC 9110 §9.2.1 classifies as safe, in this order.
fn default_safe_methods() -> Vec<String> {
    const SAFE: [&str; 4] = ["GET", "HEAD", "OPTIONS", "TRACE"];
    SAFE.iter().map(|&method| method.to_owned()).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_config_defaults() {
        let SessionConfig {
            cookie_name,
            expiry_secs,
            secure,
            http_only,
            same_site,
            storage,
            ..
        } = SessionConfig::default();

        assert_eq!(cookie_name, "session_id");
        assert_eq!(expiry_secs, 86400);
        assert!(secure);
        assert!(http_only);
        assert_eq!(same_site, "lax");
        assert_eq!(storage, SessionStorage::Memory);
    }

    #[test]
    fn test_csrf_config_defaults() {
        let CsrfConfig {
            enabled,
            token_length,
            header_name,
            form_field_name,
            safe_methods,
            ..
        } = CsrfConfig::default();

        assert!(enabled);
        assert_eq!(token_length, 32);
        assert_eq!(header_name, "X-CSRF-Token");
        assert_eq!(form_field_name, "_csrf");
        assert!(safe_methods.iter().any(|method| method == "GET"));
    }

    #[test]
    fn test_session_storage_serialization() {
        // Each variant must serialize to its snake_case name as a JSON string.
        let cases = [
            (SessionStorage::Memory, "\"memory\""),
            (SessionStorage::Redis, "\"redis\""),
        ];
        for (variant, expected) in &cases {
            assert_eq!(serde_json::to_string(variant).unwrap(), *expected);
        }
    }
}