use crate::tokens::AuthToken;
use axum::extract::Request;
use chrono::{DateTime, Datelike, Timelike, Utc, Weekday};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::net::IpAddr;
/// Snapshot of the caller, request, and environment facts used for
/// conditional authorization decisions.
///
/// Normally produced by [`ContextBuilder::build_context`]. serde serializes
/// fields in declaration order, so reordering fields changes the output of
/// order-sensitive formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationContext {
/// Authenticated user's identifier, copied from the auth token.
pub user_id: String,
/// Roles carried by the auth token.
pub roles: Vec<String>,
/// Session identifier from the auth token's metadata, if any.
pub session_id: Option<String>,
/// HTTP method rendered as a string (e.g. `"GET"`).
pub method: String,
/// Request URI path, without the query string.
pub path: String,
/// Best-effort client IP. `ContextBuilder` may take this from the
/// `X-Forwarded-For` / `X-Real-IP` headers, which clients can forge unless a
/// trusted reverse proxy rewrites them.
pub ip_address: Option<IpAddr>,
/// Raw `User-Agent` header, if present and representable as a string.
pub user_agent: Option<String>,
/// Time (UTC) captured when `ContextBuilder` built this context.
pub request_time: DateTime<Utc>,
/// Business-hours / after-hours / weekend / holiday classification, computed in UTC.
pub time_of_day: TimeOfDay,
/// Weekday / weekend / holiday classification, computed in UTC.
pub day_type: DayType,
/// Device class guessed from the User-Agent header.
pub device_type: DeviceType,
/// Heuristic classification of the network path (VPN, proxy, Tor, corporate, ...).
pub connection_type: ConnectionType,
/// Sensitivity of the requested path, derived from its prefix/segments.
pub security_level: SecurityLevel,
/// Heuristic risk score; `ContextBuilder` clamps it to at most 100.
pub risk_score: u8,
/// Extra attributes taken from `x-auth-*` headers and `ctx_*` query
/// parameters (prefix stripped). These are client-controlled: do not trust
/// them for security decisions.
pub custom_attributes: HashMap<String, String>,
}
/// Coarse classification of when a request arrived.
///
/// `Holiday` and `Weekend` take precedence over the hour-based
/// `BusinessHours` / `AfterHours` split.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TimeOfDay {
    /// A weekday hour inside the configured business-hours window.
    BusinessHours,
    /// A weekday hour outside the business-hours window.
    AfterHours,
    /// Saturday or Sunday (UTC).
    Weekend,
    /// A date listed in the builder's holiday calendar.
    Holiday,
}
/// Calendar classification of the (UTC) day a request arrived on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DayType {
    /// Monday through Friday, not a holiday.
    Weekday,
    /// Saturday or Sunday, not a holiday.
    Weekend,
    /// A date listed in the builder's holiday calendar.
    Holiday,
}
/// Device class guessed heuristically from the `User-Agent` header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DeviceType {
    /// A desktop/laptop browser.
    Desktop,
    /// A phone.
    Mobile,
    /// A tablet.
    Tablet,
    /// No User-Agent, or one that matched no known pattern.
    Unknown,
}
/// Heuristic classification of the network path a request took.
///
/// The detection signals (headers, User-Agent, client IP) are all spoofable,
/// so treat this as a risk hint rather than proof.
// `VPN` is kept upper-case: renaming it would break the public API and the
// serialized representation of existing data.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConnectionType {
    /// No intermediary or special network was detected.
    Direct,
    /// An intermediary identified itself as a VPN.
    VPN,
    /// An intermediary identified itself as a proxy.
    Proxy,
    /// The request appears to come from Tor.
    Tor,
    /// The client IP lies in a configured corporate network.
    Corporate,
    /// The connection type could not be determined.
    Unknown,
}
/// Sensitivity of the resource a request targets.
///
/// Variants are declared from least to most sensitive, so the derived
/// ordering compares sensitivity (`SecurityLevel::High > SecurityLevel::Low`).
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize,
)]
pub enum SecurityLevel {
    /// Public or low-value resources.
    Low,
    /// Regular API resources.
    Medium,
    /// Administrative or key-management resources.
    High,
    /// System administration or secret material.
    Critical,
}
/// Builds [`AuthorizationContext`] values from incoming requests.
///
/// Configure it with the `with_*` methods. All time classification is done
/// against the UTC clock.
#[derive(Debug, Clone)]
pub struct ContextBuilder {
    /// Dates, compared against the request's UTC calendar date, treated as holidays.
    holidays: Vec<chrono::NaiveDate>,
    /// First hour (inclusive, UTC) of the business-hours window.
    business_start: u8,
    /// Hour (exclusive, UTC) at which the business-hours window ends.
    business_end: u8,
    /// Networks whose member addresses are classified as corporate connections.
    corporate_networks: Vec<ipnetwork::IpNetwork>,
}
impl Default for ContextBuilder {
fn default() -> Self {
Self::new()
}
}
impl ContextBuilder {
pub fn new() -> Self {
Self {
holidays: Vec::new(),
business_start: 9,
business_end: 17,
corporate_networks: Vec::new(),
}
}
pub fn with_business_hours(mut self, start: u8, end: u8) -> Self {
self.business_start = start;
self.business_end = end;
self
}
pub fn with_corporate_networks(mut self, networks: Vec<ipnetwork::IpNetwork>) -> Self {
self.corporate_networks = networks;
self
}
pub fn with_holidays(mut self, holidays: Vec<chrono::NaiveDate>) -> Self {
self.holidays = holidays;
self
}
pub fn build_context(&self, request: &Request, auth_token: &AuthToken) -> AuthorizationContext {
let now = Utc::now();
let ip_address = self.extract_ip_address(request);
let user_agent = self.extract_user_agent(request);
AuthorizationContext {
user_id: auth_token.user_id.clone(),
roles: auth_token.roles.clone(),
session_id: auth_token.metadata.session_id.clone(),
method: request.method().to_string(),
path: request.uri().path().to_string(),
ip_address,
user_agent: user_agent.clone(),
request_time: now,
time_of_day: self.classify_time_of_day(now),
day_type: self.classify_day_type(now),
device_type: self.detect_device_type(&user_agent),
connection_type: self.analyze_connection_type(request, &ip_address),
security_level: self.assess_security_level(request),
risk_score: self.calculate_risk_score(request, &ip_address, &user_agent),
custom_attributes: self.extract_custom_attributes(request),
}
}
pub fn to_hashmap(&self, context: &AuthorizationContext) -> HashMap<String, String> {
let mut map = HashMap::new();
map.insert("user_id".to_string(), context.user_id.clone());
map.insert("roles".to_string(), context.roles.join(","));
if let Some(session_id) = &context.session_id {
map.insert("session_id".to_string(), session_id.clone());
}
map.insert("method".to_string(), context.method.clone());
map.insert("path".to_string(), context.path.clone());
if let Some(ip) = &context.ip_address {
map.insert("ip_address".to_string(), ip.to_string());
}
if let Some(ua) = &context.user_agent {
map.insert("user_agent".to_string(), ua.clone());
}
map.insert(
"time_of_day".to_string(),
format!("{:?}", context.time_of_day).to_lowercase(),
);
map.insert(
"day_type".to_string(),
format!("{:?}", context.day_type).to_lowercase(),
);
map.insert(
"request_hour".to_string(),
context.request_time.hour().to_string(),
);
map.insert(
"request_weekday".to_string(),
context.request_time.weekday().to_string(),
);
map.insert(
"device_type".to_string(),
format!("{:?}", context.device_type).to_lowercase(),
);
map.insert(
"connection_type".to_string(),
format!("{:?}", context.connection_type).to_lowercase(),
);
map.insert(
"security_level".to_string(),
format!("{:?}", context.security_level).to_lowercase(),
);
map.insert("risk_score".to_string(), context.risk_score.to_string());
for (key, value) in &context.custom_attributes {
map.insert(format!("custom_{}", key), value.clone());
}
map
}
fn extract_ip_address(&self, request: &Request) -> Option<IpAddr> {
if let Some(forwarded) = request.headers().get("x-forwarded-for")
&& let Ok(forwarded_str) = forwarded.to_str()
{
if let Some(ip_str) = forwarded_str.split(',').next()
&& let Ok(ip) = ip_str.trim().parse()
{
return Some(ip);
}
if let Some(real_ip) = request.headers().get("x-real-ip")
&& let Ok(ip_str) = real_ip.to_str()
&& let Ok(ip) = ip_str.parse()
{
return Some(ip);
}
None
} else {
request
.extensions()
.get::<axum::extract::ConnectInfo<IpAddr>>()
.map(|info| info.0)
}
}
fn extract_user_agent(&self, request: &Request) -> Option<String> {
request
.headers()
.get("user-agent")
.and_then(|ua| ua.to_str().ok())
.map(|s| s.to_string())
}
fn classify_time_of_day(&self, now: DateTime<Utc>) -> TimeOfDay {
let date = now.date_naive();
if self.holidays.contains(&date) {
return TimeOfDay::Holiday;
}
match now.weekday() {
Weekday::Sat | Weekday::Sun => return TimeOfDay::Weekend,
_ => {}
}
let hour = now.hour() as u8;
if hour >= self.business_start && hour < self.business_end {
TimeOfDay::BusinessHours
} else {
TimeOfDay::AfterHours
}
}
fn classify_day_type(&self, now: DateTime<Utc>) -> DayType {
let date = now.date_naive();
if self.holidays.contains(&date) {
DayType::Holiday
} else {
match now.weekday() {
Weekday::Sat | Weekday::Sun => DayType::Weekend,
_ => DayType::Weekday,
}
}
}
fn detect_device_type(&self, user_agent: &Option<String>) -> DeviceType {
let ua = match user_agent {
Some(ua) => ua.to_lowercase(),
None => return DeviceType::Unknown,
};
if ua.contains("mobile") || ua.contains("android") || ua.contains("iphone") {
DeviceType::Mobile
} else if ua.contains("tablet") || ua.contains("ipad") {
DeviceType::Tablet
} else if ua.contains("mozilla") || ua.contains("chrome") || ua.contains("firefox") {
DeviceType::Desktop
} else {
DeviceType::Unknown
}
}
fn analyze_connection_type(
&self,
request: &Request,
ip_address: &Option<IpAddr>,
) -> ConnectionType {
if let Some(via) = request.headers().get("via")
&& let Ok(via_str) = via.to_str()
{
if via_str.to_lowercase().contains("vpn") {
return ConnectionType::VPN;
}
if via_str.to_lowercase().contains("proxy") {
return ConnectionType::Proxy;
}
if let Some(ua) = request.headers().get("user-agent")
&& let Ok(ua_str) = ua.to_str()
&& ua_str.contains("Tor")
{
return ConnectionType::Tor;
}
if let Some(ip) = ip_address {
for network in &self.corporate_networks {
if network.contains(*ip) {
return ConnectionType::Corporate;
}
}
}
return ConnectionType::Direct;
}
ConnectionType::Unknown
}
fn assess_security_level(&self, request: &Request) -> SecurityLevel {
let path = request.uri().path();
match path {
_ if path.starts_with("/admin/system/") => SecurityLevel::Critical,
_ if path.starts_with("/admin/") => SecurityLevel::High,
_ if path.contains("/secrets/") => SecurityLevel::Critical,
_ if path.contains("/keys/") => SecurityLevel::High,
_ if path.starts_with("/api/") => SecurityLevel::Medium,
_ => SecurityLevel::Low,
}
}
fn calculate_risk_score(
&self,
request: &Request,
ip_address: &Option<IpAddr>,
user_agent: &Option<String>,
) -> u8 {
let mut risk_score = 0u8;
let path = request.uri().path();
if path.starts_with("/admin/") {
risk_score += 30;
} else if path.contains("/secrets/") || path.contains("/keys/") {
risk_score += 40;
} else if path.starts_with("/api/") {
risk_score += 10;
}
let connection_type = self.analyze_connection_type(request, ip_address);
match connection_type {
ConnectionType::Tor => risk_score += 50,
ConnectionType::VPN => risk_score += 20,
ConnectionType::Proxy => risk_score += 15,
ConnectionType::Corporate => risk_score = risk_score.saturating_sub(10),
ConnectionType::Direct => {}
ConnectionType::Unknown => risk_score += 10,
}
let device_type = self.detect_device_type(user_agent);
match device_type {
DeviceType::Mobile => risk_score += 5,
DeviceType::Unknown => risk_score += 15,
_ => {}
}
let now = Utc::now();
match self.classify_time_of_day(now) {
TimeOfDay::AfterHours => risk_score += 10,
TimeOfDay::Weekend => risk_score += 5,
_ => {}
}
if user_agent.is_none() {
risk_score += 20;
}
risk_score.min(100)
}
fn extract_custom_attributes(&self, request: &Request) -> HashMap<String, String> {
let mut attributes = HashMap::new();
for (name, value) in request.headers() {
let name_str = name.as_str().to_lowercase();
if let Some(attr_name) = name_str.strip_prefix("x-auth-")
&& let Ok(value_str) = value.to_str()
{
attributes.insert(attr_name.to_string(), value_str.to_string());
}
}
if let Some(query) = request.uri().query() {
for pair in query.split('&') {
if let Some((key, value)) = pair.split_once('=')
&& key.starts_with("ctx_")
{
attributes.insert(
key.strip_prefix("ctx_").unwrap().to_string(),
urlencoding::decode(value).unwrap_or_default().to_string(),
);
}
}
}
attributes
}
pub fn enrich_context(&self, mut context: AuthorizationContext) -> AuthorizationContext {
let current_risk = context.risk_score;
context.risk_score = std::cmp::max(current_risk, 1);
let now = chrono::Utc::now();
context
.custom_attributes
.insert("enriched_timestamp".to_string(), now.to_rfc3339());
context.custom_attributes.insert(
"security_assessment".to_string(),
match context.security_level {
SecurityLevel::Low => "basic".to_string(),
SecurityLevel::Medium => "standard".to_string(),
SecurityLevel::High => "enhanced".to_string(),
SecurityLevel::Critical => "maximum".to_string(),
},
);
context
}
}
/// Checks the string-keyed conditions attached to a permission against an
/// [`AuthorizationContext`].
pub struct ConditionalEvaluator {
/// Builder configuration supplied at construction time.
context_builder: ContextBuilder,
}
impl ConditionalEvaluator {
    /// Creates an evaluator that keeps `context_builder` for later use.
    pub fn new(context_builder: ContextBuilder) -> Self {
        Self { context_builder }
    }

    /// Returns the builder this evaluator was created with. Callers can use it
    /// to build contexts with the same holiday, business-hour and network settings.
    pub fn context_builder(&self) -> &ContextBuilder {
        &self.context_builder
    }

    /// Checks the time-based conditions:
    /// - `require_business_hours = "true"` demands [`TimeOfDay::BusinessHours`];
    /// - `require_weekday = "true"` demands [`DayType::Weekday`].
    ///
    /// Absent keys, or values other than `"true"`, impose no restriction.
    pub fn evaluate_time_conditions(
        &self,
        context: &AuthorizationContext,
        conditions: &HashMap<String, String>,
    ) -> bool {
        let business_hours_ok = !Self::flag_enabled(conditions, "require_business_hours")
            || matches!(context.time_of_day, TimeOfDay::BusinessHours);
        let weekday_ok = !Self::flag_enabled(conditions, "require_weekday")
            || matches!(context.day_type, DayType::Weekday);
        business_hours_ok && weekday_ok
    }

    /// Checks the network conditions:
    /// - `require_corporate_network = "true"` demands [`ConnectionType::Corporate`];
    /// - `block_vpn = "true"` rejects both VPN and Tor connections.
    pub fn evaluate_location_conditions(
        &self,
        context: &AuthorizationContext,
        conditions: &HashMap<String, String>,
    ) -> bool {
        let corporate_ok = !Self::flag_enabled(conditions, "require_corporate_network")
            || matches!(context.connection_type, ConnectionType::Corporate);
        let vpn_ok = !Self::flag_enabled(conditions, "block_vpn")
            || !matches!(
                context.connection_type,
                ConnectionType::VPN | ConnectionType::Tor
            );
        corporate_ok && vpn_ok
    }

    /// Checks `allowed_device_types`, a comma-separated list of device names
    /// (e.g. `"desktop, mobile"`). Entries are trimmed and compared
    /// case-insensitively against the context's [`DeviceType`] variant name.
    pub fn evaluate_device_conditions(
        &self,
        context: &AuthorizationContext,
        conditions: &HashMap<String, String>,
    ) -> bool {
        let Some(allowed_devices) = conditions.get("allowed_device_types") else {
            return true;
        };
        let device_name = format!("{:?}", context.device_type);
        allowed_devices
            .split(',')
            .map(str::trim)
            .any(|entry| entry.eq_ignore_ascii_case(&device_name))
    }

    /// Checks `max_risk_score`: the context's risk score must not exceed it.
    ///
    /// A value that is not a non-negative integer fails closed (returns `false`
    /// and logs a warning), so a typo cannot silently disable the check.
    pub fn evaluate_risk_conditions(
        &self,
        context: &AuthorizationContext,
        conditions: &HashMap<String, String>,
    ) -> bool {
        let Some(max_risk_str) = conditions.get("max_risk_score") else {
            return true;
        };
        match max_risk_str.trim().parse::<u32>() {
            Ok(max_risk) => u32::from(context.risk_score) <= max_risk,
            Err(_) => {
                tracing::warn!(
                    "Denying: invalid max_risk_score condition value {:?}",
                    max_risk_str
                );
                false
            }
        }
    }

    /// Evaluates every condition family for a permission and logs the outcome.
    ///
    /// An empty condition map always grants. All four checks run (no
    /// short-circuit) so that each result can be reported in the log line.
    pub fn evaluate_conditional_permission(
        &self,
        context: &AuthorizationContext,
        permission_conditions: &HashMap<String, String>,
    ) -> bool {
        tracing::debug!(
            "Evaluating conditional permission with conditions: {:?}",
            permission_conditions
        );
        if permission_conditions.is_empty() {
            return true;
        }
        let time_check = self.evaluate_time_conditions(context, permission_conditions);
        let location_check = self.evaluate_location_conditions(context, permission_conditions);
        let device_check = self.evaluate_device_conditions(context, permission_conditions);
        let risk_check = self.evaluate_risk_conditions(context, permission_conditions);
        let result = time_check && location_check && device_check && risk_check;
        tracing::info!(
            "Conditional evaluation result: {} (time: {}, location: {}, device: {}, risk: {})",
            result,
            time_check,
            location_check,
            device_check,
            risk_check
        );
        result
    }

    /// Returns `true` only if every condition family passes, stopping at the first failure.
    pub fn evaluate_all_conditions(
        &self,
        context: &AuthorizationContext,
        conditions: &HashMap<String, String>,
    ) -> bool {
        self.evaluate_time_conditions(context, conditions)
            && self.evaluate_location_conditions(context, conditions)
            && self.evaluate_device_conditions(context, conditions)
            && self.evaluate_risk_conditions(context, conditions)
    }

    /// Returns `true` when `conditions[key]` is exactly the string `"true"`.
    fn flag_enabled(conditions: &HashMap<String, String>, key: &str) -> bool {
        conditions.get(key).is_some_and(|value| value == "true")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// Builds a fixed UTC instant so time-based tests do not depend on the wall clock.
    fn utc(year: i32, month: u32, day: u32, hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, 0, 0).unwrap()
    }

    /// Builds a body-less request for `path`.
    fn request_for(path: &str) -> Request {
        axum::http::Request::builder()
            .uri(path)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Builds a condition map from literal pairs.
    fn conditions(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// A fully populated context for the hash-map, enrichment and evaluator tests.
    fn sample_context() -> AuthorizationContext {
        AuthorizationContext {
            user_id: "user-1".to_string(),
            roles: vec!["reader".to_string(), "writer".to_string()],
            session_id: None,
            method: "GET".to_string(),
            path: "/api/items".to_string(),
            ip_address: None,
            user_agent: None,
            // 2024-06-05 is a Wednesday.
            request_time: utc(2024, 6, 5, 14),
            time_of_day: TimeOfDay::BusinessHours,
            day_type: DayType::Weekday,
            device_type: DeviceType::Desktop,
            connection_type: ConnectionType::VPN,
            security_level: SecurityLevel::Low,
            risk_score: 0,
            custom_attributes: HashMap::from([("team".to_string(), "ops".to_string())]),
        }
    }

    #[test]
    fn test_context_builder_creation() {
        let builder = ContextBuilder::new()
            .with_business_hours(8, 18)
            .with_holidays(vec![chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap()]);
        assert_eq!(builder.business_start, 8);
        assert_eq!(builder.business_end, 18);
        assert_eq!(builder.holidays.len(), 1);
    }

    #[test]
    fn test_time_classification() {
        let builder = ContextBuilder::new();
        // 2024-06-05 is a Wednesday, 2024-06-08 a Saturday; default window is 9..17.
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 14)),
            TimeOfDay::BusinessHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 9)),
            TimeOfDay::BusinessHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 17)),
            TimeOfDay::AfterHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 3)),
            TimeOfDay::AfterHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 8, 14)),
            TimeOfDay::Weekend
        ));
        let with_holiday = ContextBuilder::new()
            .with_holidays(vec![chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap()]);
        assert!(matches!(
            with_holiday.classify_time_of_day(utc(2024, 12, 25, 14)),
            TimeOfDay::Holiday
        ));
    }

    #[test]
    fn test_day_classification() {
        let builder = ContextBuilder::new()
            .with_holidays(vec![chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap()]);
        assert!(matches!(
            builder.classify_day_type(utc(2024, 6, 5, 14)),
            DayType::Weekday
        ));
        assert!(matches!(
            builder.classify_day_type(utc(2024, 6, 8, 14)),
            DayType::Weekend
        ));
        assert!(matches!(
            builder.classify_day_type(utc(2024, 12, 25, 14)),
            DayType::Holiday
        ));
    }

    #[test]
    fn test_device_detection() {
        let builder = ContextBuilder::new();
        let mobile_ua = Some("Mozilla/5.0 (iPhone; CPU iPhone OS 14_0 like Mac OS X)".to_string());
        assert!(matches!(
            builder.detect_device_type(&mobile_ua),
            DeviceType::Mobile
        ));
        let desktop_ua =
            Some("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36".to_string());
        assert!(matches!(
            builder.detect_device_type(&desktop_ua),
            DeviceType::Desktop
        ));
        assert!(matches!(
            builder.detect_device_type(&None),
            DeviceType::Unknown
        ));
    }

    #[test]
    fn test_security_level_assessment() {
        let builder = ContextBuilder::new();
        let level = |path: &str| builder.assess_security_level(&request_for(path));
        assert!(matches!(level("/admin/system/reboot"), SecurityLevel::Critical));
        assert!(matches!(level("/vault/secrets/db"), SecurityLevel::Critical));
        assert!(matches!(level("/admin/users"), SecurityLevel::High));
        assert!(matches!(level("/api/keys/1"), SecurityLevel::High));
        assert!(matches!(level("/api/items"), SecurityLevel::Medium));
        assert!(matches!(level("/"), SecurityLevel::Low));
    }

    #[test]
    fn test_to_hashmap() {
        let map = ContextBuilder::new().to_hashmap(&sample_context());
        assert_eq!(map["user_id"], "user-1");
        assert_eq!(map["roles"], "reader,writer");
        assert_eq!(map["time_of_day"], "businesshours");
        assert_eq!(map["connection_type"], "vpn");
        assert_eq!(map["request_hour"], "14");
        assert_eq!(map["request_weekday"], "Wed");
        assert_eq!(map["risk_score"], "0");
        assert_eq!(map["custom_team"], "ops");
        assert!(!map.contains_key("session_id"));
        assert!(!map.contains_key("ip_address"));
        assert!(!map.contains_key("user_agent"));
    }

    #[test]
    fn test_enrich_context_floors_risk() {
        let enriched = ContextBuilder::new().enrich_context(sample_context());
        assert_eq!(enriched.risk_score, 1);
        assert_eq!(
            enriched
                .custom_attributes
                .get("security_assessment")
                .map(String::as_str),
            Some("basic")
        );
        assert!(enriched.custom_attributes.contains_key("enriched_timestamp"));
        assert_eq!(
            enriched.custom_attributes.get("team").map(String::as_str),
            Some("ops")
        );
    }

    #[test]
    fn test_conditional_evaluation() {
        let evaluator = ConditionalEvaluator::new(ContextBuilder::new());
        let context = sample_context();

        assert!(evaluator.evaluate_conditional_permission(&context, &HashMap::new()));
        assert!(evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[("require_business_hours", "true")])
        ));
        assert!(!evaluator.evaluate_all_conditions(&context, &conditions(&[("block_vpn", "true")])));
        assert!(!evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[("require_corporate_network", "true")])
        ));
        assert!(evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[("allowed_device_types", "desktop,mobile")])
        ));
        assert!(!evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[("allowed_device_types", "mobile")])
        ));
        assert!(evaluator.evaluate_all_conditions(&context, &conditions(&[("max_risk_score", "10")])));

        let risky = AuthorizationContext {
            risk_score: 50,
            time_of_day: TimeOfDay::AfterHours,
            ..sample_context()
        };
        assert!(!evaluator.evaluate_all_conditions(&risky, &conditions(&[("max_risk_score", "10")])));
        assert!(!evaluator.evaluate_conditional_permission(
            &risky,
            &conditions(&[("require_business_hours", "true")])
        ));
    }
}