use crate::{
providers::session::{SessionId, SessionProvider, SessionData},
traits::{Authenticatable, SessionStorage, UserContext},
AuthError, AuthResult,
};
/// Settings for the session middleware: how the session cookie is named and
/// attributed, and which paths `SessionMiddleware::should_skip_path` reports
/// as skippable.
#[derive(Debug, Clone)]
pub struct SessionMiddlewareConfig {
/// Name of the cookie that carries the session id.
pub cookie_name: String,
/// Value of the cookie's `Domain` attribute; the attribute is omitted when `None`.
pub cookie_domain: Option<String>,
/// Value of the cookie's `Path` attribute.
pub cookie_path: String,
/// Whether the `HttpOnly` attribute is added to the cookie.
pub cookie_http_only: bool,
/// Whether the `Secure` attribute is added to the cookie.
pub cookie_secure: bool,
/// Value of the cookie's `SameSite` attribute.
pub cookie_same_site: CookieSameSite,
/// Presumably whether CSRF validation is enforced; nothing in this file reads
/// the flag, so confirm against the request-handling code.
pub require_csrf: bool,
/// Paths consulted by `SessionMiddleware::should_skip_path`.
pub skip_paths: Vec<String>,
/// Presumably lets requests without a valid session through; nothing in this
/// file reads the flag, so confirm against the request-handling code.
pub optional: bool,
}
/// Value of the `SameSite` attribute written into the session cookie.
///
/// `PartialEq`/`Eq` are derived so configured values can be compared directly
/// (for example in assertions) instead of only through pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CookieSameSite {
    /// Rendered as `SameSite=Strict`.
    Strict,
    /// Rendered as `SameSite=Lax`.
    Lax,
    /// Rendered as `SameSite=None`.
    None,
}
/// Formats the variant as the literal text used for the cookie's `SameSite`
/// attribute value.
impl std::fmt::Display for CookieSameSite {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            CookieSameSite::Strict => "Strict",
            CookieSameSite::Lax => "Lax",
            CookieSameSite::None => "None",
        };
        f.write_str(label)
    }
}
/// Baseline configuration: cookie `session_id` on path `/`, `HttpOnly` set,
/// `Secure` not set, `SameSite=Lax`, CSRF required, `/health` and `/metrics`
/// as skip paths, and `optional` off.
impl Default for SessionMiddlewareConfig {
    fn default() -> Self {
        let skip_paths = ["/health", "/metrics"]
            .iter()
            .map(|&path| String::from(path))
            .collect();

        Self {
            cookie_name: String::from("session_id"),
            cookie_domain: None,
            cookie_path: String::from("/"),
            cookie_http_only: true,
            cookie_secure: false,
            cookie_same_site: CookieSameSite::Lax,
            require_csrf: true,
            skip_paths,
            optional: false,
        }
    }
}
impl SessionMiddlewareConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the name of the session cookie.
    pub fn cookie_name(self, name: impl Into<String>) -> Self {
        Self {
            cookie_name: name.into(),
            ..self
        }
    }

    /// Sets the cookie's `Domain` attribute.
    pub fn cookie_domain(self, domain: impl Into<String>) -> Self {
        Self {
            cookie_domain: Some(domain.into()),
            ..self
        }
    }

    /// Sets the cookie's `Path` attribute.
    pub fn cookie_path(self, path: impl Into<String>) -> Self {
        Self {
            cookie_path: path.into(),
            ..self
        }
    }

    /// Chooses whether the cookie carries the `HttpOnly` attribute.
    pub fn cookie_http_only(self, http_only: bool) -> Self {
        Self {
            cookie_http_only: http_only,
            ..self
        }
    }

    /// Chooses whether the cookie carries the `Secure` attribute.
    pub fn cookie_secure(self, secure: bool) -> Self {
        Self {
            cookie_secure: secure,
            ..self
        }
    }

    /// Sets the cookie's `SameSite` attribute.
    pub fn cookie_same_site(self, same_site: CookieSameSite) -> Self {
        Self {
            cookie_same_site: same_site,
            ..self
        }
    }

    /// Sets the `require_csrf` flag.
    pub fn require_csrf(self, require: bool) -> Self {
        Self {
            require_csrf: require,
            ..self
        }
    }

    /// Replaces the whole list of skip paths.
    pub fn skip_paths(self, paths: Vec<String>) -> Self {
        Self {
            skip_paths: paths,
            ..self
        }
    }

    /// Appends one entry to the list of skip paths.
    pub fn skip_path(mut self, path: impl Into<String>) -> Self {
        let entry = path.into();
        self.skip_paths.push(entry);
        self
    }

    /// Sets the `optional` flag.
    pub fn optional(self, optional: bool) -> Self {
        Self { optional, ..self }
    }

    /// Preset: defaults with `Secure` on, `SameSite=Strict` and CSRF required.
    pub fn production() -> Self {
        Self::default()
            .cookie_secure(true)
            .cookie_same_site(CookieSameSite::Strict)
            .require_csrf(true)
    }

    /// Preset: defaults with `Secure` off and CSRF not required.
    pub fn development() -> Self {
        Self::default().cookie_secure(false).require_csrf(false)
    }
}
/// Pairs a [`SessionProvider`] with a [`SessionMiddlewareConfig`] and offers
/// helpers for reading and writing the session cookie.
///
/// `S` is the session storage type and `U` the user type that the provider is
/// parameterised over.
pub struct SessionMiddleware<S, U>
where
S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
U: Authenticatable,
{
/// Provider that `validate_session` delegates to.
provider: SessionProvider<S, U>,
/// Cookie attributes and skip paths used by the helper methods.
config: SessionMiddlewareConfig,
}
impl<S, U> SessionMiddleware<S, U>
where
    S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
    U: Authenticatable + Clone,
{
    /// Creates a middleware from a session provider and an explicit configuration.
    pub fn new(provider: SessionProvider<S, U>, config: SessionMiddlewareConfig) -> Self {
        Self { provider, config }
    }

    /// Creates a middleware that uses `SessionMiddlewareConfig::default()`.
    pub fn with_default_config(provider: SessionProvider<S, U>) -> Self {
        Self::new(provider, SessionMiddlewareConfig::default())
    }

    /// Returns the fixed identifier of this middleware, `"session"`.
    pub fn name(&self) -> &str {
        "session"
    }

    /// Finds the session id in the value of a `Cookie` request header.
    ///
    /// The header is split on `;`, each pair is trimmed and checked for the
    /// `<cookie_name>=` prefix. The first matching pair whose value
    /// `SessionId::from_string` accepts is returned; matching pairs with a
    /// rejected value are skipped. Returns `None` when no pair qualifies.
    pub fn extract_session_id_from_cookie(&self, cookie_header: &str) -> Option<SessionId> {
        // Built once instead of once per cookie pair.
        let prefix = format!("{}=", self.config.cookie_name);

        cookie_header
            .split(';')
            .filter_map(|pair| pair.trim().strip_prefix(prefix.as_str()))
            .find_map(|value| SessionId::from_string(value.to_string()).ok())
    }

    /// Reports whether `path` is covered by one of the configured skip paths.
    ///
    /// A skip path covers `path` when it is equal to it or is a prefix that
    /// ends on a segment boundary: the remainder starts with `/` (a sub-path)
    /// or `?` (a query string, in case the raw request target is passed), or
    /// the skip path itself ends with `/`. A bare prefix test would also
    /// exempt unrelated routes such as `/healthcheck-admin` for the skip path
    /// `/health`.
    pub fn should_skip_path(&self, path: &str) -> bool {
        self.config.skip_paths.iter().any(|skip_path| {
            path.strip_prefix(skip_path.as_str()).map_or(false, |rest| {
                rest.is_empty()
                    || skip_path.ends_with('/')
                    || rest.starts_with(|c: char| c == '/' || c == '?')
            })
        })
    }

    /// Builds the value of a `Set-Cookie` header for `session_id`.
    ///
    /// Attributes are emitted in this order, separated by `"; "`:
    /// `<cookie_name>=<session_id>`, `Domain` (only when configured), `Path`,
    /// `HttpOnly` and `Secure` (only when enabled), `SameSite`, and `Max-Age`
    /// (only when `max_age` is `Some`).
    ///
    /// NOTE(review): browsers reject `SameSite=None` cookies that lack
    /// `Secure`; this method emits the attributes exactly as configured and
    /// does not enforce that pairing.
    pub fn create_cookie_header(&self, session_id: &SessionId, max_age: Option<i64>) -> String {
        let config = &self.config;
        let mut attributes = vec![format!("{}={}", config.cookie_name, session_id)];

        if let Some(domain) = &config.cookie_domain {
            attributes.push(format!("Domain={}", domain));
        }
        attributes.push(format!("Path={}", config.cookie_path));
        if config.cookie_http_only {
            attributes.push("HttpOnly".to_string());
        }
        if config.cookie_secure {
            attributes.push("Secure".to_string());
        }
        attributes.push(format!("SameSite={}", config.cookie_same_site));
        if let Some(max_age) = max_age {
            attributes.push(format!("Max-Age={}", max_age));
        }

        attributes.join("; ")
    }

    /// Validates `session_id` by delegating to the provider's
    /// `validate_session` and returns its result unchanged.
    pub async fn validate_session(&self, session_id: &SessionId) -> AuthResult<SessionData> {
        self.provider.validate_session(session_id).await
    }

    /// Builds a [`UserContext`] from stored session data.
    ///
    /// Identity, roles and permissions are cloned from the session;
    /// `auth_provider` is always `"session"`, `authenticated_at` is the
    /// session's `created_at`, `expires_at` wraps the session's `expires_at`,
    /// and `additional_data` is a clone of the session metadata.
    pub fn create_user_context(&self, session_data: &SessionData) -> UserContext {
        UserContext {
            user_id: session_data.user_id.clone(),
            username: session_data.username.clone(),
            roles: session_data.roles.clone(),
            permissions: session_data.permissions.clone(),
            auth_provider: "session".to_string(),
            authenticated_at: session_data.created_at,
            expires_at: Some(session_data.expires_at),
            additional_data: session_data.metadata.clone(),
        }
    }

    /// Returns the wrapped session provider.
    pub fn provider(&self) -> &SessionProvider<S, U> {
        &self.provider
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &SessionMiddlewareConfig {
        &self.config
    }
}
/// Step-by-step constructor for [`SessionMiddleware`].
///
/// Collects a session provider and a configuration; `build` fails when no
/// provider has been supplied.
pub struct SessionMiddlewareBuilder<S, U>
where
S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
U: Authenticatable,
{
/// Provider to hand to the middleware; `None` until `provider` is called.
provider: Option<SessionProvider<S, U>>,
/// Configuration accumulated by the builder methods.
config: SessionMiddlewareConfig,
}
impl<S, U> SessionMiddlewareBuilder<S, U>
where
S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
U: Authenticatable,
{
pub fn new() -> Self {
Self {
provider: None,
config: SessionMiddlewareConfig::default(),
}
}
pub fn provider(mut self, provider: SessionProvider<S, U>) -> Self {
self.provider = Some(provider);
self
}
pub fn cookie_name(mut self, name: impl Into<String>) -> Self {
self.config.cookie_name = name.into();
self
}
pub fn optional(mut self) -> Self {
self.config.optional = true;
self
}
pub fn skip_paths(mut self, paths: Vec<String>) -> Self {
self.config.skip_paths = paths;
self
}
pub fn skip_path(mut self, path: impl Into<String>) -> Self {
self.config.skip_paths.push(path.into());
self
}
pub fn production(mut self) -> Self {
self.config = SessionMiddlewareConfig::production();
self
}
pub fn development(mut self) -> Self {
self.config = SessionMiddlewareConfig::development();
self
}
pub fn build(self) -> AuthResult<SessionMiddleware<S, U>>
where
U: Clone,
{
let provider = self.provider
.ok_or_else(|| AuthError::generic_error("Session provider is required"))?;
Ok(SessionMiddleware::new(provider, self.config))
}
}
impl<S, U> Default for SessionMiddlewareBuilder<S, U>
where
S: SessionStorage<SessionId = SessionId, SessionData = SessionData>,
U: Authenticatable,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::session::MemorySessionStorage;

    /// Stand-in user type; it only fills the middleware's `U` parameter.
    #[derive(Debug, Clone)]
    struct MockUser {
        id: String,
        username: String,
        roles: Vec<String>,
        permissions: Vec<String>,
    }

    #[async_trait::async_trait]
    impl Authenticatable for MockUser {
        type Id = String;
        type Credentials = String;

        fn id(&self) -> &Self::Id {
            &self.id
        }

        fn username(&self) -> &str {
            self.username.as_str()
        }

        fn roles(&self) -> Vec<String> {
            self.roles.to_vec()
        }

        fn permissions(&self) -> Vec<String> {
            self.permissions.to_vec()
        }

        async fn verify_credentials(&self, _credentials: &Self::Credentials) -> AuthResult<bool> {
            Ok(true)
        }
    }

    /// Provider type shared by every test in this module.
    type TestProvider = SessionProvider<MemorySessionStorage, MockUser>;

    /// Builds a provider over fresh in-memory storage with default settings.
    fn memory_provider() -> TestProvider {
        SessionProvider::with_default_config(MemorySessionStorage::new())
    }

    /// Builds a middleware with the default configuration.
    fn default_middleware() -> SessionMiddleware<MemorySessionStorage, MockUser> {
        SessionMiddleware::with_default_config(memory_provider())
    }

    #[test]
    fn test_session_middleware_config() {
        let config = SessionMiddlewareConfig::new()
            .cookie_name("test_session")
            .cookie_domain("example.com")
            .cookie_secure(true)
            .require_csrf(false)
            .optional(true);

        assert_eq!(config.cookie_name, "test_session");
        assert_eq!(config.cookie_domain, Some(String::from("example.com")));
        assert!(config.cookie_secure);
        assert!(!config.require_csrf);
        assert!(config.optional);
    }

    #[test]
    fn test_cookie_same_site_display() {
        let cases = [
            (CookieSameSite::Strict, "Strict"),
            (CookieSameSite::Lax, "Lax"),
            (CookieSameSite::None, "None"),
        ];
        for (variant, expected) in cases.iter() {
            assert_eq!(variant.to_string(), *expected);
        }
    }

    #[test]
    fn test_session_middleware_builder() {
        let middleware = SessionMiddlewareBuilder::new()
            .provider(memory_provider())
            .cookie_name("test_session")
            .optional()
            .skip_path("/health")
            .build()
            .unwrap();

        assert_eq!(middleware.name(), "session");
    }

    #[tokio::test]
    async fn test_cookie_header_creation() {
        let middleware =
            SessionMiddleware::new(memory_provider(), SessionMiddlewareConfig::production());
        let session_id = SessionId::generate();

        let header = middleware.create_cookie_header(&session_id, Some(3600));
        let name_value = format!("session_id={}", session_id);

        assert!(header.contains(&name_value));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("Secure"));
        assert!(header.contains("SameSite=Strict"));
        assert!(header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_session_id_extraction() {
        let middleware = default_middleware();

        // A value that `SessionId::from_string` rejects yields no session id.
        let rejected =
            middleware.extract_session_id_from_cookie("other_cookie=value; session_id=short; another=value");
        assert!(rejected.is_none());

        let valid_value = "a".repeat(32);

        // The session cookie on its own.
        let lone_header = format!("session_id={}", valid_value);
        assert!(middleware
            .extract_session_id_from_cookie(&lone_header)
            .is_some());

        // The session cookie surrounded by unrelated cookies.
        let mixed_header = format!("first=value1; session_id={}; last=value2", valid_value);
        assert!(middleware
            .extract_session_id_from_cookie(&mixed_header)
            .is_some());
    }

    #[test]
    fn test_path_skipping() {
        let middleware = default_middleware();

        for skipped in ["/health", "/metrics"].iter() {
            assert!(middleware.should_skip_path(skipped));
        }
        assert!(!middleware.should_skip_path("/api/users"));
    }
}