use crate::activity::{ActivityTracker, ActivityType};
use crate::analytics::{AnalyticsTracker, SecurityEvent, SecurityEventType};
use crate::device::{DeviceInfo, DeviceManager};
use crate::error::SessionError;
use axum::{
extract::{ConnectInfo, Request},
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use redis::aio::ConnectionManager;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::{debug, warn};
use uuid::Uuid;
/// Per-request session information placed in the request extensions.
///
/// `auth_middleware` inserts it directly; `optional_auth_middleware` wraps it in
/// [`OptionalSessionContext`]. `activity_tracking_middleware` and
/// `device_tracking_middleware` read it back from the extensions.
#[derive(Debug, Clone)]
pub struct SessionContext {
    /// Session id parsed (as a UUID) from the JWT `session_id` claim.
    pub session_id: Uuid,
    /// User id parsed (as a UUID) from the JWT `sub` claim.
    pub user_id: Uuid,
    /// Client device identifier; the middlewares in this file always set `None`.
    pub device_id: Option<String>,
    /// Client address from `extract_ip_address` (proxy headers first, then the socket
    /// peer address, else `"unknown"`).
    pub ip_address: String,
    /// `User-Agent` header value, if present and convertible to a string.
    pub user_agent: Option<String>,
    /// Free-form extra data; the middlewares in this file always set `None`.
    pub metadata: Option<serde_json::Value>,
}
/// Request extension inserted by `optional_auth_middleware`.
///
/// It is `Some` when the request carried a token that validated and whose `sub` and
/// `session_id` claims both parse as UUIDs, and `None` otherwise.
#[derive(Debug, Clone)]
pub struct OptionalSessionContext(pub Option<SessionContext>);
/// Claims deserialized from a JWT payload by `validate_jwt_token`.
///
/// None of the fields is optional, so a payload missing any of them fails to deserialize.
/// Of the time-based claims, only `exp` is checked in this file; `iat`, `iss`, and `aud`
/// are read but not validated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Subject; the middlewares parse it as the user's UUID.
    pub sub: String,
    /// Session identifier; the middlewares parse it as a UUID.
    pub session_id: String,
    /// Expiry, compared against `chrono::Utc::now().timestamp()` (Unix seconds).
    pub exp: i64,
    /// Issued-at time; presumably Unix seconds — not checked in this file.
    pub iat: i64,
    /// Issuer; not checked in this file.
    pub iss: String,
    /// Audience; not checked in this file.
    pub aud: String,
}
/// Shared dependencies for the session middlewares in this module.
#[derive(Clone)]
pub struct SessionMiddlewareState {
    /// Used to call `record_activity` for authenticated requests.
    pub activity_tracker: Arc<ActivityTracker>,
    /// Not referenced by the middlewares in this file.
    pub analytics_tracker: Arc<AnalyticsTracker>,
    /// Used by `device_tracking_middleware` to call `register_device`.
    pub device_manager: Arc<DeviceManager>,
    /// Redis connection; each middleware clones it before use.
    pub redis: ConnectionManager,
    /// Secret handed to `validate_jwt_token`; the middlewares pass `""` when this is `None`.
    pub jwt_secret: Option<String>,
}
/// Extracts the bearer token from the request's `Authorization` header.
///
/// The authentication scheme is matched case-insensitively, as HTTP requires
/// (RFC 7235 §2.1), so `Bearer`, `bearer` and `BEARER` are all accepted. Surrounding
/// whitespace around the token is ignored.
///
/// Returns `None` in any of these cases:
/// - the header is missing or not representable as a string;
/// - the scheme is not `Bearer`;
/// - the token part is empty.
fn extract_token_from_header(req: &Request) -> Option<String> {
    let value = req.headers().get(header::AUTHORIZATION)?.to_str().ok()?;
    let (scheme, token) = value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    // A header of just "Bearer " carries no credential; treat it like a missing token.
    let token = token.trim();
    (!token.is_empty()).then(|| token.to_string())
}
/// Best-effort client address for logging and session context.
///
/// Sources are tried in this order:
/// 1. The first entry of `X-Forwarded-For`.
/// 2. `X-Real-IP`.
/// 3. The peer address from axum's `ConnectInfo<SocketAddr>` extension.
/// 4. The literal `"unknown"` if none of the above is available.
///
/// Blank header values are skipped, so the result is never an empty string.
///
/// NOTE(review): both proxy headers are client-controlled unless a trusted reverse proxy
/// overwrites them. Only rely on this value if such a proxy is always in front of the
/// service.
fn extract_ip_address(req: &Request) -> String {
    /// Header value as trimmed text, or `None` if missing, non-textual, or blank.
    fn non_empty_header<'a>(req: &'a Request, name: &str) -> Option<&'a str> {
        req.headers()
            .get(name)?
            .to_str()
            .ok()
            .map(str::trim)
            .filter(|value| !value.is_empty())
    }

    // X-Forwarded-For is "client, proxy1, proxy2, ..."; the first entry is the originator.
    let forwarded = non_empty_header(req, "x-forwarded-for")
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());

    if let Some(ip) = forwarded.or_else(|| non_empty_header(req, "x-real-ip")) {
        return ip.to_string();
    }

    req.extensions()
        .get::<ConnectInfo<SocketAddr>>()
        .map(|ci| ci.0.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
/// Returns the `User-Agent` header as an owned string.
///
/// Yields `None` when the header is absent or its bytes cannot be viewed as a string.
fn extract_user_agent(req: &Request) -> Option<String> {
    let raw = req.headers().get(header::USER_AGENT)?;
    raw.to_str().ok().map(String::from)
}
/// Validates an HS256-signed JWT and returns its claims.
///
/// The token is accepted only if all of the following hold:
/// - it has exactly three dot-separated, unpadded base64url segments;
/// - its header declares `"alg": "HS256"`;
/// - its signature equals `HMAC-SHA256(secret, "<header>.<payload>")`;
/// - the payload deserializes into [`JwtClaims`];
/// - `exp` is still in the future.
///
/// # Errors
///
/// - [`SessionError::InvalidSession`] when `secret` is empty, or when the format,
///   encoding, algorithm, signature, or claims are invalid. An empty secret fails closed
///   because an empty HMAC key would let anyone mint valid tokens.
/// - [`SessionError::SessionExpired`] when the current time is at or past `exp`.
///
/// NOTE(review): `iss` and `aud` are still not validated because their expected values
/// are not available in this module.
fn validate_jwt_token(token: &str, secret: &str) -> Result<JwtClaims, SessionError> {
    use base64::{engine::general_purpose, Engine as _};

    let invalid = |msg: &str| SessionError::InvalidSession(msg.to_string());

    if secret.is_empty() {
        return Err(invalid("JWT secret is not configured"));
    }

    let mut segments = token.split('.');
    let (header_b64, payload_b64, signature_b64) =
        match (segments.next(), segments.next(), segments.next(), segments.next()) {
            (Some(h), Some(p), Some(s), None) => (h, p, s),
            _ => return Err(invalid("Invalid JWT format")),
        };

    // Only HS256 is verified below. Rejecting every other `alg`, including "none",
    // prevents algorithm-substitution attacks.
    let header_json = general_purpose::URL_SAFE_NO_PAD
        .decode(header_b64)
        .map_err(|_| invalid("Invalid base64 encoding"))?;
    let header: serde_json::Value =
        serde_json::from_slice(&header_json).map_err(|_| invalid("Invalid JWT header"))?;
    if header.get("alg").and_then(serde_json::Value::as_str) != Some("HS256") {
        return Err(invalid("Unsupported JWT algorithm"));
    }

    // The signing input is the raw "<header>.<payload>" text exactly as sent (RFC 7515 §5.2).
    let signing_input = &token.as_bytes()[..header_b64.len() + 1 + payload_b64.len()];
    let signature = general_purpose::URL_SAFE_NO_PAD
        .decode(signature_b64)
        .map_err(|_| invalid("Invalid base64 encoding"))?;
    if !constant_time_eq(&signature, &hmac_sha256(secret.as_bytes(), signing_input)) {
        return Err(invalid("Invalid JWT signature"));
    }

    // The payload is only trusted after the signature has been verified.
    let payload = general_purpose::URL_SAFE_NO_PAD
        .decode(payload_b64)
        .map_err(|_| invalid("Invalid base64 encoding"))?;
    let claims: JwtClaims =
        serde_json::from_slice(&payload).map_err(|_| invalid("Invalid JWT claims"))?;

    // RFC 7519 §4.1.4: the JWT must not be accepted on or after its expiration time.
    if claims.exp <= chrono::Utc::now().timestamp() {
        return Err(SessionError::SessionExpired);
    }
    Ok(claims)
}

/// Compares two byte strings without short-circuiting on the first mismatch.
///
/// This keeps signature checks from leaking, through timing, how many leading bytes of
/// a forged MAC were correct. Only the (public) length comparison short-circuits.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}

/// HMAC-SHA256 of `message` under `key`, as specified in RFC 2104.
///
/// This is implemented locally because no crypto crate is a dependency of this module.
/// Prefer the `hmac` + `sha2` crates (or `jsonwebtoken`) if they can be added.
fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; 32] {
    const BLOCK_LEN: usize = 64;

    // Keys longer than the block size are hashed first; shorter keys are zero-padded.
    let mut key_block = [0u8; BLOCK_LEN];
    if key.len() > BLOCK_LEN {
        key_block[..32].copy_from_slice(&sha256(key));
    } else {
        key_block[..key.len()].copy_from_slice(key);
    }

    let mut inner = Vec::with_capacity(BLOCK_LEN + message.len());
    inner.extend(key_block.iter().map(|b| b ^ 0x36));
    inner.extend_from_slice(message);
    let inner_digest = sha256(&inner);

    let mut outer = Vec::with_capacity(BLOCK_LEN + inner_digest.len());
    outer.extend(key_block.iter().map(|b| b ^ 0x5c));
    outer.extend_from_slice(&inner_digest);
    sha256(&outer)
}

/// SHA-256 digest of `data`, as specified in FIPS 180-4 §6.2.
fn sha256(data: &[u8]) -> [u8; 32] {
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: a 0x80 byte, zeros up to 56 mod 64, then the bit length (big-endian u64).
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    while message.len() % 64 != 56 {
        message.push(0);
    }
    message.extend_from_slice(&bit_len.to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule.
        let mut w = [0u32; 64];
        for (slot, word) in w.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for i in 16..64 {
            let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
            let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
            w[i] = w[i - 16]
                .wrapping_add(s0)
                .wrapping_add(w[i - 7])
                .wrapping_add(s1);
        }

        // Compression.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state;
        for (&k, &wi) in K.iter().zip(w.iter()) {
            let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let ch = (e & f) ^ (!e & g);
            let t1 = h
                .wrapping_add(big_s1)
                .wrapping_add(ch)
                .wrapping_add(k)
                .wrapping_add(wi);
            let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let maj = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_s0.wrapping_add(maj);
            h = g;
            g = f;
            f = e;
            e = d.wrapping_add(t1);
            d = c;
            c = b;
            b = a;
            a = t1.wrapping_add(t2);
        }
        for (word, add) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
            *word = word.wrapping_add(add);
        }
    }

    let mut digest = [0u8; 32];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
/// Axum middleware that requires a valid bearer JWT on every request.
///
/// On success it does three things:
/// - records an `ApiRequest` activity for the session (best-effort: a tracking failure
///   is logged, not surfaced);
/// - inserts a [`SessionContext`] into the request extensions;
/// - forwards the request to the next layer.
///
/// # Errors
///
/// Responds with `401 Unauthorized` and a short text body in any of these cases:
/// - the `Authorization` bearer token is missing;
/// - the token fails `validate_jwt_token`;
/// - the `sub` or `session_id` claim is not a UUID.
///
/// NOTE(review): nothing in this function consults a session store, so a revoked session
/// may keep working until its token expires unless `record_activity` rejects it. Verify.
pub async fn auth_middleware(
    state: Arc<SessionMiddlewareState>,
    mut req: Request,
    next: Next,
) -> Result<Response, impl IntoResponse> {
    let Some(token) = extract_token_from_header(&req) else {
        return Err((
            StatusCode::UNAUTHORIZED,
            "Missing Authorization header".to_string(),
        ));
    };

    let jwt_secret = state.jwt_secret.as_deref().unwrap_or("");
    let claims = match validate_jwt_token(&token, jwt_secret) {
        Ok(claims) => claims,
        Err(e) => {
            warn!("JWT validation failed: {:?}", e);
            return Err((StatusCode::UNAUTHORIZED, "Invalid token".to_string()));
        }
    };
    let user_id = match Uuid::parse_str(&claims.sub) {
        Ok(id) => id,
        Err(_) => return Err((StatusCode::UNAUTHORIZED, "Invalid user ID".to_string())),
    };
    let session_id = match Uuid::parse_str(&claims.session_id) {
        Ok(id) => id,
        Err(_) => return Err((StatusCode::UNAUTHORIZED, "Invalid session ID".to_string())),
    };

    let ip_address = extract_ip_address(&req);
    let user_agent = extract_user_agent(&req);

    // Activity tracking is best-effort: a Redis hiccup must not fail an authenticated request.
    let mut redis = state.redis.clone();
    let description = format!("{} {}", req.method(), req.uri().path());
    if let Err(e) = state
        .activity_tracker
        .record_activity(
            session_id,
            ActivityType::ApiRequest,
            Some(description),
            &mut redis,
        )
        .await
    {
        warn!("Failed to record activity: {:?}", e);
    }

    // The context is moved into the extensions; nothing below needs it, so no clone is made.
    req.extensions_mut().insert(SessionContext {
        session_id,
        user_id,
        device_id: None,
        ip_address,
        user_agent,
        metadata: None,
    });
    debug!("Session validated for user {} (session {})", user_id, session_id);
    Ok(next.run(req).await)
}
/// Axum middleware that authenticates the caller when it can but never rejects a request.
///
/// It always inserts an [`OptionalSessionContext`] into the request extensions:
/// - `Some(context)` when the bearer token validates and both the `sub` and `session_id`
///   claims parse as UUIDs. In that case an `ApiRequest` activity is also recorded
///   (best-effort).
/// - `None` otherwise.
///
/// Only token-validation failures are logged; a missing header or non-UUID claims fall
/// back to an anonymous request silently.
pub async fn optional_auth_middleware(
    state: Arc<SessionMiddlewareState>,
    mut req: Request,
    next: Next,
) -> Response {
    // `(user_id, session_id)` for an authenticated caller, `None` for an anonymous one.
    let identity = match extract_token_from_header(&req) {
        None => None,
        Some(token) => {
            let secret = state.jwt_secret.as_deref().unwrap_or("");
            match validate_jwt_token(&token, secret) {
                Ok(claims) => Uuid::parse_str(&claims.sub).ok().and_then(|user_id| {
                    Uuid::parse_str(&claims.session_id)
                        .ok()
                        .map(|session_id| (user_id, session_id))
                }),
                Err(e) => {
                    warn!("JWT validation failed (optional): {:?}", e);
                    None
                }
            }
        }
    };

    let Some((user_id, session_id)) = identity else {
        req.extensions_mut().insert(OptionalSessionContext(None));
        return next.run(req).await;
    };

    let ip_address = extract_ip_address(&req);
    let user_agent = extract_user_agent(&req);
    let endpoint = format!("{} {}", req.method(), req.uri().path());
    let mut conn = state.redis.clone();
    let recorded = state
        .activity_tracker
        .record_activity(session_id, ActivityType::ApiRequest, Some(endpoint), &mut conn)
        .await;
    if let Err(e) = recorded {
        warn!("Failed to record activity: {:?}", e);
    }

    let session = SessionContext {
        session_id,
        user_id,
        device_id: None,
        ip_address,
        user_agent,
        metadata: None,
    };
    req.extensions_mut()
        .insert(OptionalSessionContext(Some(session)));
    next.run(req).await
}
/// Records an `ApiRequest` activity for requests that already carry a [`SessionContext`].
///
/// Requests without a context pass straight through. Tracking failures are logged and
/// otherwise ignored, so they never block the request.
///
/// NOTE(review): `auth_middleware` records the same activity, so layering both on one
/// route would likely count each request twice. Confirm how the router composes them.
pub async fn activity_tracking_middleware(
    state: Arc<SessionMiddlewareState>,
    req: Request,
    next: Next,
) -> Response {
    // Copy out just the id in its own statement so no borrow of `req` spans the await.
    let session_id = req
        .extensions()
        .get::<SessionContext>()
        .map(|ctx| ctx.session_id);

    if let Some(session_id) = session_id {
        let mut conn = state.redis.clone();
        let endpoint = format!("{} {}", req.method(), req.uri().path());
        let outcome = state
            .activity_tracker
            .record_activity(session_id, ActivityType::ApiRequest, Some(endpoint), &mut conn)
            .await;
        if let Err(e) = outcome {
            warn!("Failed to record activity: {:?}", e);
        }
    }

    next.run(req).await
}
pub async fn device_tracking_middleware(
state: Arc<SessionMiddlewareState>,
req: Request,
next: Next,
) -> Response {
let context = req.extensions().get::<SessionContext>().cloned();
if let Some(ctx) = context {
if let Some(user_agent) = &ctx.user_agent {
let mut redis = state.redis.clone();
let device_info = DeviceInfo::new(
ctx.device_id.clone().unwrap_or_else(|| "unknown".to_string()),
user_agent.clone(),
ctx.ip_address.clone(),
"unknown".to_string(), );
if let Err(e) = state
.device_manager
.register_device(
device_info.device_id.clone(),
ctx.user_id,
device_info,
&mut redis,
)
.await
{
warn!("Failed to register device: {:?}", e);
}
}
}
next.run(req).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;

    /// Builds an axum `Request` (body type `Body`) with the given headers.
    ///
    /// The previous tests built `http::Request<()>`, which does not match the `&Request`
    /// parameter of the helpers under test, so they failed to compile.
    fn request_with_headers(headers: &[(&str, &str)]) -> Request {
        let mut builder = axum::http::Request::builder().uri("/");
        for (name, value) in headers {
            builder = builder.header(*name, *value);
        }
        builder.body(Body::empty()).unwrap()
    }

    #[test]
    fn test_extract_token() {
        let req = request_with_headers(&[("authorization", "Bearer test_token_123")]);
        assert_eq!(
            extract_token_from_header(&req),
            Some("test_token_123".to_string())
        );
    }

    #[test]
    fn test_extract_token_missing_or_wrong_scheme() {
        assert_eq!(extract_token_from_header(&request_with_headers(&[])), None);
        let basic = request_with_headers(&[("authorization", "Basic dXNlcjpwYXNz")]);
        assert_eq!(extract_token_from_header(&basic), None);
    }

    #[test]
    fn test_extract_user_agent() {
        let req = request_with_headers(&[("user-agent", "Mozilla/5.0")]);
        assert_eq!(extract_user_agent(&req), Some("Mozilla/5.0".to_string()));
        assert_eq!(extract_user_agent(&request_with_headers(&[])), None);
    }

    #[test]
    fn test_extract_ip_prefers_first_forwarded_entry() {
        let req = request_with_headers(&[
            ("x-forwarded-for", "203.0.113.7, 10.0.0.1"),
            ("x-real-ip", "198.51.100.2"),
        ]);
        assert_eq!(extract_ip_address(&req), "203.0.113.7");
    }

    #[test]
    fn test_extract_ip_falls_back_to_real_ip_then_unknown() {
        let req = request_with_headers(&[("x-real-ip", "198.51.100.2")]);
        assert_eq!(extract_ip_address(&req), "198.51.100.2");
        assert_eq!(extract_ip_address(&request_with_headers(&[])), "unknown");
    }

    #[test]
    fn test_validate_rejects_malformed_token() {
        assert!(validate_jwt_token("not-a-jwt", "secret").is_err());
        assert!(validate_jwt_token("a.b", "secret").is_err());
    }
}