// auth_framework/authorization_enhanced/context.rs

//! Authorization context builders for enhanced RBAC.
//!
//! This module provides utilities to build context objects for conditional
//! permissions and request-specific authorization decisions.

6use crate::tokens::AuthToken;
7use axum::extract::Request;
8use chrono::{DateTime, Datelike, Timelike, Utc, Weekday};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::net::IpAddr;
12
/// Request-scoped facts used for conditional authorization decisions.
///
/// Normally produced by [`ContextBuilder::build_context`];
/// [`ContextBuilder::to_hashmap`] flattens it into string key/value pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationContext {
    // --- User information ---
    /// Identifier of the authenticated user.
    pub user_id: String,
    /// Roles held by the user.
    pub roles: Vec<String>,
    /// Session identifier, if the token carried one.
    pub session_id: Option<String>,

    // --- Request metadata ---
    /// HTTP method, e.g. `GET`.
    pub method: String,
    /// Request path, without the query string.
    pub path: String,
    /// Client IP address, if it could be determined.
    pub ip_address: Option<IpAddr>,
    /// `User-Agent` header value, if present and valid UTF-8.
    pub user_agent: Option<String>,

    // --- Time-based context ---
    /// Moment (UTC) the context was built.
    pub request_time: DateTime<Utc>,
    /// Business-hours / after-hours / weekend / holiday classification.
    pub time_of_day: TimeOfDay,
    /// Weekday / weekend / holiday classification.
    pub day_type: DayType,

    // --- Device and connection info ---
    /// Device class guessed from the user agent.
    pub device_type: DeviceType,
    /// Connection class inferred from request headers and the client IP.
    pub connection_type: ConnectionType,

    // --- Security context ---
    /// Sensitivity of the requested endpoint, derived from its path.
    pub security_level: SecurityLevel,
    /// Heuristic risk score; higher is riskier.
    pub risk_score: u8, // 0-100

    /// Extra attributes supplied by the client (`X-Auth-*` headers and
    /// `ctx_*` query parameters, prefix stripped). Treat as untrusted input.
    pub custom_attributes: HashMap<String, String>,
}
43
44/// Time of day classification
45#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
46pub enum TimeOfDay {
47    BusinessHours,
48    AfterHours,
49    Weekend,
50    Holiday,
51}
52
53/// Day type classification
54#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
55pub enum DayType {
56    Weekday,
57    Weekend,
58    Holiday,
59}
60
61/// Device type detection
62#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
63pub enum DeviceType {
64    Desktop,
65    Mobile,
66    Tablet,
67    Unknown,
68}
69
70/// Connection type analysis
71#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
72pub enum ConnectionType {
73    Direct,
74    VPN,
75    Proxy,
76    Tor,
77    Corporate,
78    Unknown,
79}
80
81/// Security level assessment
82#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
83pub enum SecurityLevel {
84    Low,
85    Medium,
86    High,
87    Critical,
88}
89
90/// Enhanced context builder for authorization decisions
91pub struct ContextBuilder {
92    /// Known holiday dates (could be loaded from config)
93    holidays: Vec<chrono::NaiveDate>,
94    /// Business hours configuration
95    business_start: u8,
96    business_end: u8,
97    /// IP ranges for corporate networks
98    corporate_networks: Vec<ipnetwork::IpNetwork>,
99}
100
101impl Default for ContextBuilder {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl ContextBuilder {
108    /// Create a new context builder with default settings
109    pub fn new() -> Self {
110        Self {
111            holidays: Vec::new(),
112            business_start: 9,
113            business_end: 17,
114            corporate_networks: Vec::new(),
115        }
116    }
117
118    /// Configure business hours
119    pub fn with_business_hours(mut self, start: u8, end: u8) -> Self {
120        self.business_start = start;
121        self.business_end = end;
122        self
123    }
124
125    /// Add corporate network ranges
126    pub fn with_corporate_networks(mut self, networks: Vec<ipnetwork::IpNetwork>) -> Self {
127        self.corporate_networks = networks;
128        self
129    }
130
131    /// Add holiday dates
132    pub fn with_holidays(mut self, holidays: Vec<chrono::NaiveDate>) -> Self {
133        self.holidays = holidays;
134        self
135    }
136
137    /// Build authorization context from request and auth token
138    pub fn build_context(&self, request: &Request, auth_token: &AuthToken) -> AuthorizationContext {
139        let now = Utc::now();
140        let ip_address = self.extract_ip_address(request);
141        let user_agent = self.extract_user_agent(request);
142
143        AuthorizationContext {
144            user_id: auth_token.user_id.clone(),
145            roles: auth_token.roles.clone(),
146            session_id: auth_token.metadata.session_id.clone(),
147
148            method: request.method().to_string(),
149            path: request.uri().path().to_string(),
150            ip_address,
151            user_agent: user_agent.clone(),
152
153            request_time: now,
154            time_of_day: self.classify_time_of_day(now),
155            day_type: self.classify_day_type(now),
156
157            device_type: self.detect_device_type(&user_agent),
158            connection_type: self.analyze_connection_type(request, &ip_address),
159
160            security_level: self.assess_security_level(request),
161            risk_score: self.calculate_risk_score(request, &ip_address, &user_agent),
162
163            custom_attributes: self.extract_custom_attributes(request),
164        }
165    }
166
167    /// Convert context to HashMap for role-system compatibility
168    pub fn to_hashmap(&self, context: &AuthorizationContext) -> HashMap<String, String> {
169        let mut map = HashMap::new();
170
171        // User context
172        map.insert("user_id".to_string(), context.user_id.clone());
173        map.insert("roles".to_string(), context.roles.join(","));
174        if let Some(session_id) = &context.session_id {
175            map.insert("session_id".to_string(), session_id.clone());
176        }
177
178        // Request context
179        map.insert("method".to_string(), context.method.clone());
180        map.insert("path".to_string(), context.path.clone());
181        if let Some(ip) = &context.ip_address {
182            map.insert("ip_address".to_string(), ip.to_string());
183        }
184        if let Some(ua) = &context.user_agent {
185            map.insert("user_agent".to_string(), ua.clone());
186        }
187
188        // Time context
189        map.insert(
190            "time_of_day".to_string(),
191            format!("{:?}", context.time_of_day).to_lowercase(),
192        );
193        map.insert(
194            "day_type".to_string(),
195            format!("{:?}", context.day_type).to_lowercase(),
196        );
197        map.insert(
198            "request_hour".to_string(),
199            context.request_time.hour().to_string(),
200        );
201        map.insert(
202            "request_weekday".to_string(),
203            context.request_time.weekday().to_string(),
204        );
205
206        // Device and connection
207        map.insert(
208            "device_type".to_string(),
209            format!("{:?}", context.device_type).to_lowercase(),
210        );
211        map.insert(
212            "connection_type".to_string(),
213            format!("{:?}", context.connection_type).to_lowercase(),
214        );
215
216        // Security context
217        map.insert(
218            "security_level".to_string(),
219            format!("{:?}", context.security_level).to_lowercase(),
220        );
221        map.insert("risk_score".to_string(), context.risk_score.to_string());
222
223        // Custom attributes
224        for (key, value) in &context.custom_attributes {
225            map.insert(format!("custom_{}", key), value.clone());
226        }
227
228        map
229    }
230
231    /// Extract IP address from request headers
232    fn extract_ip_address(&self, request: &Request) -> Option<IpAddr> {
233        // Try X-Forwarded-For first
234        if let Some(forwarded) = request.headers().get("x-forwarded-for")
235            && let Ok(forwarded_str) = forwarded.to_str()
236        {
237            if let Some(ip_str) = forwarded_str.split(',').next()
238                && let Ok(ip) = ip_str.trim().parse()
239            {
240                return Some(ip);
241            }
242
243            // Try X-Real-IP
244            if let Some(real_ip) = request.headers().get("x-real-ip")
245                && let Ok(ip_str) = real_ip.to_str()
246                && let Ok(ip) = ip_str.parse()
247            {
248                return Some(ip);
249            }
250
251            // Could also get from connection info if available
252            None
253        } else {
254            // Fallback to remote address if no headers found
255            request
256                .extensions()
257                .get::<axum::extract::ConnectInfo<IpAddr>>()
258                .map(|info| info.0)
259        }
260    }
261
262    /// Extract user agent from request headers
263    fn extract_user_agent(&self, request: &Request) -> Option<String> {
264        request
265            .headers()
266            .get("user-agent")
267            .and_then(|ua| ua.to_str().ok())
268            .map(|s| s.to_string())
269    }
270
271    /// Classify time of day based on business hours and holidays
272    fn classify_time_of_day(&self, now: DateTime<Utc>) -> TimeOfDay {
273        let date = now.date_naive();
274
275        // Check if it's a holiday
276        if self.holidays.contains(&date) {
277            return TimeOfDay::Holiday;
278        }
279
280        // Check if it's weekend
281        match now.weekday() {
282            Weekday::Sat | Weekday::Sun => return TimeOfDay::Weekend,
283            _ => {}
284        }
285
286        // Check business hours
287        let hour = now.hour() as u8;
288        if hour >= self.business_start && hour < self.business_end {
289            TimeOfDay::BusinessHours
290        } else {
291            TimeOfDay::AfterHours
292        }
293    }
294
295    /// Classify day type
296    fn classify_day_type(&self, now: DateTime<Utc>) -> DayType {
297        let date = now.date_naive();
298
299        if self.holidays.contains(&date) {
300            DayType::Holiday
301        } else {
302            match now.weekday() {
303                Weekday::Sat | Weekday::Sun => DayType::Weekend,
304                _ => DayType::Weekday,
305            }
306        }
307    }
308
309    /// Detect device type from user agent
310    fn detect_device_type(&self, user_agent: &Option<String>) -> DeviceType {
311        let ua = match user_agent {
312            Some(ua) => ua.to_lowercase(),
313            None => return DeviceType::Unknown,
314        };
315
316        if ua.contains("mobile") || ua.contains("android") || ua.contains("iphone") {
317            DeviceType::Mobile
318        } else if ua.contains("tablet") || ua.contains("ipad") {
319            DeviceType::Tablet
320        } else if ua.contains("mozilla") || ua.contains("chrome") || ua.contains("firefox") {
321            DeviceType::Desktop
322        } else {
323            DeviceType::Unknown
324        }
325    }
326
327    /// Analyze connection type from headers and IP
328    fn analyze_connection_type(
329        &self,
330        request: &Request,
331        ip_address: &Option<IpAddr>,
332    ) -> ConnectionType {
333        // Check for VPN indicators in headers
334        if let Some(via) = request.headers().get("via")
335            && let Ok(via_str) = via.to_str()
336        {
337            if via_str.to_lowercase().contains("vpn") {
338                return ConnectionType::VPN;
339            }
340            if via_str.to_lowercase().contains("proxy") {
341                return ConnectionType::Proxy;
342            }
343
344            // Check for Tor indicators
345            if let Some(ua) = request.headers().get("user-agent")
346                && let Ok(ua_str) = ua.to_str()
347                && ua_str.contains("Tor")
348            {
349                return ConnectionType::Tor;
350            }
351
352            // Check if IP is in corporate network range
353            if let Some(ip) = ip_address {
354                for network in &self.corporate_networks {
355                    if network.contains(*ip) {
356                        return ConnectionType::Corporate;
357                    }
358                }
359            }
360
361            return ConnectionType::Direct;
362        }
363        // Fallback to unknown if no indicators found
364        ConnectionType::Unknown
365    }
366
367    /// Assess security level based on endpoint
368    fn assess_security_level(&self, request: &Request) -> SecurityLevel {
369        let path = request.uri().path();
370
371        match path {
372            _ if path.starts_with("/admin/system/") => SecurityLevel::Critical,
373            _ if path.starts_with("/admin/") => SecurityLevel::High,
374            _ if path.contains("/secrets/") => SecurityLevel::Critical,
375            _ if path.contains("/keys/") => SecurityLevel::High,
376            _ if path.starts_with("/api/") => SecurityLevel::Medium,
377            _ => SecurityLevel::Low,
378        }
379    }
380
381    /// Calculate risk score (0-100)
382    fn calculate_risk_score(
383        &self,
384        request: &Request,
385        ip_address: &Option<IpAddr>,
386        user_agent: &Option<String>,
387    ) -> u8 {
388        let mut risk_score = 0u8;
389
390        // Base risk from endpoint
391        let path = request.uri().path();
392        if path.starts_with("/admin/") {
393            risk_score += 30;
394        } else if path.contains("/secrets/") || path.contains("/keys/") {
395            risk_score += 40;
396        } else if path.starts_with("/api/") {
397            risk_score += 10;
398        }
399
400        // Risk from connection type
401        let connection_type = self.analyze_connection_type(request, ip_address);
402        match connection_type {
403            ConnectionType::Tor => risk_score += 50,
404            ConnectionType::VPN => risk_score += 20,
405            ConnectionType::Proxy => risk_score += 15,
406            ConnectionType::Corporate => risk_score = risk_score.saturating_sub(10),
407            ConnectionType::Direct => {}
408            ConnectionType::Unknown => risk_score += 10,
409        }
410
411        // Risk from device type
412        let device_type = self.detect_device_type(user_agent);
413        match device_type {
414            DeviceType::Mobile => risk_score += 5,
415            DeviceType::Unknown => risk_score += 15,
416            _ => {}
417        }
418
419        // Risk from time
420        let now = Utc::now();
421        match self.classify_time_of_day(now) {
422            TimeOfDay::AfterHours => risk_score += 10,
423            TimeOfDay::Weekend => risk_score += 5,
424            _ => {}
425        }
426
427        // Missing user agent is suspicious
428        if user_agent.is_none() {
429            risk_score += 20;
430        }
431
432        // Cap at 100
433        risk_score.min(100)
434    }
435
436    /// Extract custom attributes from headers
437    fn extract_custom_attributes(&self, request: &Request) -> HashMap<String, String> {
438        let mut attributes = HashMap::new();
439
440        // Extract custom headers starting with X-Auth-
441        for (name, value) in request.headers() {
442            let name_str = name.as_str().to_lowercase();
443            if let Some(attr_name) = name_str.strip_prefix("x-auth-")
444                && let Ok(value_str) = value.to_str()
445            {
446                attributes.insert(attr_name.to_string(), value_str.to_string());
447            }
448        }
449
450        // Extract query parameters for additional context
451        if let Some(query) = request.uri().query() {
452            for pair in query.split('&') {
453                if let Some((key, value)) = pair.split_once('=')
454                    && key.starts_with("ctx_")
455                {
456                    attributes.insert(
457                        key.strip_prefix("ctx_").unwrap().to_string(),
458                        urlencoding::decode(value).unwrap_or_default().to_string(),
459                    );
460                }
461            }
462        }
463
464        attributes
465    }
466
467    /// Enrich existing context with additional computed attributes
468    pub fn enrich_context(&self, mut context: AuthorizationContext) -> AuthorizationContext {
469        // Add computed risk factors
470        let current_risk = context.risk_score;
471        context.risk_score = std::cmp::max(current_risk, 1); // Ensure minimum risk
472
473        // Add time-based enrichment
474        let now = chrono::Utc::now();
475        context
476            .custom_attributes
477            .insert("enriched_timestamp".to_string(), now.to_rfc3339());
478
479        // Add security level enhancement
480        context.custom_attributes.insert(
481            "security_assessment".to_string(),
482            match context.security_level {
483                SecurityLevel::Low => "basic".to_string(),
484                SecurityLevel::Medium => "standard".to_string(),
485                SecurityLevel::High => "enhanced".to_string(),
486                SecurityLevel::Critical => "maximum".to_string(),
487            },
488        );
489
490        context
491    }
492}
493
494/// Conditional permission evaluator
495/// PRODUCTION FIX: Implemented conditional evaluation for enterprise security requirements
496pub struct ConditionalEvaluator {
497    context_builder: ContextBuilder,
498}
499
500impl ConditionalEvaluator {
501    /// Create new conditional evaluator
502    pub fn new(context_builder: ContextBuilder) -> Self {
503        Self { context_builder }
504    }
505
506    /// Evaluate time-based conditions
507    pub fn evaluate_time_conditions(
508        &self,
509        context: &AuthorizationContext,
510        conditions: &HashMap<String, String>,
511    ) -> bool {
512        // Check business hours requirement
513        if let Some(require_business_hours) = conditions.get("require_business_hours")
514            && require_business_hours == "true"
515        {
516            match context.time_of_day {
517                TimeOfDay::BusinessHours => {}
518                _ => return false,
519            }
520        }
521
522        // Check weekday requirement
523        if let Some(require_weekday) = conditions.get("require_weekday")
524            && require_weekday == "true"
525        {
526            match context.day_type {
527                DayType::Weekday => {}
528                _ => return false,
529            }
530        }
531
532        true
533    }
534
535    /// Evaluate location-based conditions
536    pub fn evaluate_location_conditions(
537        &self,
538        context: &AuthorizationContext,
539        conditions: &HashMap<String, String>,
540    ) -> bool {
541        // Check corporate network requirement
542        if let Some(require_corporate) = conditions.get("require_corporate_network")
543            && require_corporate == "true"
544        {
545            match context.connection_type {
546                ConnectionType::Corporate => {}
547                _ => return false,
548            }
549        }
550
551        // Check VPN restrictions
552        if let Some(block_vpn) = conditions.get("block_vpn")
553            && block_vpn == "true"
554        {
555            match context.connection_type {
556                ConnectionType::VPN | ConnectionType::Tor => return false,
557                _ => {}
558            }
559        }
560
561        true
562    }
563
564    /// Evaluate device-based conditions
565    pub fn evaluate_device_conditions(
566        &self,
567        context: &AuthorizationContext,
568        conditions: &HashMap<String, String>,
569    ) -> bool {
570        // Check device type restrictions
571        if let Some(allowed_devices) = conditions.get("allowed_device_types") {
572            let allowed: Vec<&str> = allowed_devices.split(',').collect();
573            let device_str = format!("{:?}", context.device_type).to_lowercase();
574
575            if !allowed.contains(&device_str.as_str()) {
576                return false;
577            }
578        }
579
580        true
581    }
582
583    /// Evaluate risk-based conditions
584    pub fn evaluate_risk_conditions(
585        &self,
586        context: &AuthorizationContext,
587        conditions: &HashMap<String, String>,
588    ) -> bool {
589        // Check maximum risk score
590        if let Some(max_risk_str) = conditions.get("max_risk_score")
591            && let Ok(max_risk) = max_risk_str.parse::<u8>()
592            && context.risk_score > max_risk
593        {
594            return false;
595        }
596
597        true
598    }
599
600    /// Main conditional evaluation method for production use
601    /// Evaluates complex conditional permission rules based on context
602    pub fn evaluate_conditional_permission(
603        &self,
604        context: &AuthorizationContext,
605        permission_conditions: &HashMap<String, String>,
606    ) -> bool {
607        // PRODUCTION FIX: Implement proper conditional evaluation
608        tracing::debug!(
609            "Evaluating conditional permission with conditions: {:?}",
610            permission_conditions
611        );
612
613        // If no conditions specified, allow by default
614        if permission_conditions.is_empty() {
615            return true;
616        }
617
618        // Enrich context using the context builder for more comprehensive evaluation
619        let _enriched_context = self.context_builder.enrich_context(context.clone());
620
621        // Evaluate all condition types - ALL must pass for conditional permission to be granted
622        let time_check = self.evaluate_time_conditions(context, permission_conditions);
623        let location_check = self.evaluate_location_conditions(context, permission_conditions);
624        let device_check = self.evaluate_device_conditions(context, permission_conditions);
625        let risk_check = self.evaluate_risk_conditions(context, permission_conditions);
626
627        let result = time_check && location_check && device_check && risk_check;
628
629        tracing::info!(
630            "Conditional evaluation result: {} (time: {}, location: {}, device: {}, risk: {})",
631            result,
632            time_check,
633            location_check,
634            device_check,
635            risk_check
636        );
637
638        result
639    }
640
641    /// Evaluate all conditions
642    pub fn evaluate_all_conditions(
643        &self,
644        context: &AuthorizationContext,
645        conditions: &HashMap<String, String>,
646    ) -> bool {
647        self.evaluate_time_conditions(context, conditions)
648            && self.evaluate_location_conditions(context, conditions)
649            && self.evaluate_device_conditions(context, conditions)
650            && self.evaluate_risk_conditions(context, conditions)
651    }
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// Build a request for `uri` carrying `headers`.
    fn request_with(uri: &str, headers: &[(&str, &str)]) -> Request {
        headers
            .iter()
            .fold(
                axum::http::Request::builder().uri(uri),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(axum::body::Body::empty())
            .expect("test request parts are valid")
    }

    /// Collect `(key, value)` pairs into an owned condition map.
    fn conditions(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// A weekday, business-hours, desktop context with a moderate risk score.
    fn sample_context() -> AuthorizationContext {
        AuthorizationContext {
            user_id: "user-1".to_string(),
            roles: vec!["reader".to_string()],
            session_id: None,
            method: "GET".to_string(),
            path: "/api/items".to_string(),
            ip_address: None,
            user_agent: None,
            request_time: Utc.with_ymd_and_hms(2024, 6, 12, 14, 0, 0).unwrap(),
            time_of_day: TimeOfDay::BusinessHours,
            day_type: DayType::Weekday,
            device_type: DeviceType::Desktop,
            connection_type: ConnectionType::Direct,
            security_level: SecurityLevel::Medium,
            risk_score: 20,
            custom_attributes: HashMap::new(),
        }
    }

    #[test]
    fn test_context_builder_creation() {
        let builder = ContextBuilder::new()
            .with_business_hours(8, 18)
            .with_holidays(vec![chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap()]);

        assert_eq!(builder.business_start, 8);
        assert_eq!(builder.business_end, 18);
        assert_eq!(builder.holidays.len(), 1);
    }

    #[test]
    fn test_time_classification() {
        let builder = ContextBuilder::new();
        // Fixed dates keep the test independent of the wall clock:
        // 2024-06-12 is a Wednesday, 2024-06-15 a Saturday.
        let at = |day: u32, hour: u32| Utc.with_ymd_and_hms(2024, 6, day, hour, 0, 0).unwrap();

        assert!(matches!(builder.classify_time_of_day(at(12, 9)), TimeOfDay::BusinessHours));
        assert!(matches!(builder.classify_time_of_day(at(12, 14)), TimeOfDay::BusinessHours));
        // The end hour is exclusive.
        assert!(matches!(builder.classify_time_of_day(at(12, 17)), TimeOfDay::AfterHours));
        assert!(matches!(builder.classify_time_of_day(at(12, 8)), TimeOfDay::AfterHours));
        assert!(matches!(builder.classify_time_of_day(at(15, 14)), TimeOfDay::Weekend));

        assert!(matches!(builder.classify_day_type(at(12, 14)), DayType::Weekday));
        assert!(matches!(builder.classify_day_type(at(15, 14)), DayType::Weekend));
    }

    #[test]
    fn test_holiday_classification() {
        let christmas = chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap();
        let builder = ContextBuilder::new().with_holidays(vec![christmas]);
        // 2024-12-25 is a Wednesday; the holiday takes precedence over business hours.
        let noon = Utc.with_ymd_and_hms(2024, 12, 25, 12, 0, 0).unwrap();

        assert!(matches!(builder.classify_time_of_day(noon), TimeOfDay::Holiday));
        assert!(matches!(builder.classify_day_type(noon), DayType::Holiday));
    }

    #[test]
    fn test_device_detection() {
        let builder = ContextBuilder::new();

        let mobile_ua = Some("Mozilla/5.0 (iPhone; CPU iPhone OS 14_0 like Mac OS X)".to_string());
        assert!(matches!(
            builder.detect_device_type(&mobile_ua),
            DeviceType::Mobile
        ));

        let desktop_ua =
            Some("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36".to_string());
        assert!(matches!(
            builder.detect_device_type(&desktop_ua),
            DeviceType::Desktop
        ));

        assert!(matches!(
            builder.detect_device_type(&None),
            DeviceType::Unknown
        ));
    }

    #[test]
    fn test_security_level() {
        let builder = ContextBuilder::new();
        let level = |path: &str| builder.assess_security_level(&request_with(path, &[]));

        assert!(matches!(level("/admin/system/config"), SecurityLevel::Critical));
        assert!(matches!(level("/admin/users"), SecurityLevel::High));
        assert!(matches!(level("/api/secrets/db"), SecurityLevel::Critical));
        assert!(matches!(level("/api/v1/users"), SecurityLevel::Medium));
        assert!(matches!(level("/"), SecurityLevel::Low));
    }

    #[test]
    fn test_risk_calculation() {
        let builder = ContextBuilder::new();
        let request = request_with("/admin/users", &[]);

        // admin endpoint (30) + unknown connection (10) + unknown device (15)
        // + missing user agent (20), plus 0-10 depending on the current time.
        let score = builder.calculate_risk_score(&request, &None, &None);
        assert!((75..=85).contains(&score), "unexpected risk score {score}");
    }

    #[test]
    fn test_custom_attributes() {
        let builder = ContextBuilder::new();
        let request = request_with(
            "/api/items?ctx_region=eu%2Dwest&page=2",
            &[("X-Auth-Tenant", "acme")],
        );

        let attributes = builder.extract_custom_attributes(&request);
        assert_eq!(attributes.get("tenant").map(String::as_str), Some("acme"));
        assert_eq!(attributes.get("region").map(String::as_str), Some("eu-west"));
        assert_eq!(attributes.len(), 2);
    }

    #[test]
    fn test_conditional_evaluation() {
        let evaluator = ConditionalEvaluator::new(ContextBuilder::new());
        let context = sample_context();

        assert!(evaluator.evaluate_conditional_permission(&context, &HashMap::new()));
        assert!(evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[
                ("require_business_hours", "true"),
                ("allowed_device_types", "desktop,mobile"),
                ("max_risk_score", "50"),
            ]),
        ));
        assert!(!evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[("require_corporate_network", "true")]),
        ));
        assert!(!evaluator.evaluate_all_conditions(
            &context,
            &conditions(&[("max_risk_score", "10")]),
        ));
    }
}