1use crate::tokens::AuthToken;
7use axum::extract::Request;
8use chrono::{DateTime, Datelike, Timelike, Utc, Weekday};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::net::IpAddr;
12
/// Snapshot of what is known about a request at authorization time: who is
/// calling, what they are calling, when, from where, and the derived
/// security classification.
///
/// Populated by `ContextBuilder::build_context`; `ContextBuilder::to_hashmap`
/// flattens it into string key/value pairs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthorizationContext {
    // --- Identity (copied from the authenticated token) ---
    /// Identifier of the authenticated user.
    pub user_id: String,
    /// Roles carried by the token.
    pub roles: Vec<String>,
    /// Session identifier from the token metadata, if any.
    pub session_id: Option<String>,

    // --- Request ---
    /// HTTP method as a string, e.g. `"GET"`.
    pub method: String,
    /// Request path, without the query string.
    pub path: String,
    /// Client IP taken from forwarding headers or connection info, when available.
    pub ip_address: Option<IpAddr>,
    /// `User-Agent` header value, if present and representable as a string.
    pub user_agent: Option<String>,

    // --- Time ---
    /// Moment the context was built (UTC).
    pub request_time: DateTime<Utc>,
    /// Business-hours / after-hours / weekend / holiday classification of `request_time`.
    pub time_of_day: TimeOfDay,
    /// Weekday / weekend / holiday classification of `request_time`.
    pub day_type: DayType,

    // --- Client ---
    /// Device class guessed from the user agent.
    pub device_type: DeviceType,
    /// Network-path classification (VPN, proxy, Tor, corporate, ...).
    pub connection_type: ConnectionType,

    // --- Derived risk ---
    /// Sensitivity of the requested path.
    pub security_level: SecurityLevel,
    /// Heuristic risk score, capped at 100.
    pub risk_score: u8,
    /// Extra attributes taken from `x-auth-*` headers and `ctx_*` query
    /// parameters. These values are client-controlled.
    pub custom_attributes: HashMap<String, String>,
}
43
44#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
46pub enum TimeOfDay {
47 BusinessHours,
48 AfterHours,
49 Weekend,
50 Holiday,
51}
52
53#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
55pub enum DayType {
56 Weekday,
57 Weekend,
58 Holiday,
59}
60
61#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
63pub enum DeviceType {
64 Desktop,
65 Mobile,
66 Tablet,
67 Unknown,
68}
69
70#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
72pub enum ConnectionType {
73 Direct,
74 VPN,
75 Proxy,
76 Tor,
77 Corporate,
78 Unknown,
79}
80
81#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
83pub enum SecurityLevel {
84 Low,
85 Medium,
86 High,
87 Critical,
88}
89
90pub struct ContextBuilder {
92 holidays: Vec<chrono::NaiveDate>,
94 business_start: u8,
96 business_end: u8,
97 corporate_networks: Vec<ipnetwork::IpNetwork>,
99}
100
101impl Default for ContextBuilder {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl ContextBuilder {
108 pub fn new() -> Self {
110 Self {
111 holidays: Vec::new(),
112 business_start: 9,
113 business_end: 17,
114 corporate_networks: Vec::new(),
115 }
116 }
117
118 pub fn with_business_hours(mut self, start: u8, end: u8) -> Self {
120 self.business_start = start;
121 self.business_end = end;
122 self
123 }
124
125 pub fn with_corporate_networks(mut self, networks: Vec<ipnetwork::IpNetwork>) -> Self {
127 self.corporate_networks = networks;
128 self
129 }
130
131 pub fn with_holidays(mut self, holidays: Vec<chrono::NaiveDate>) -> Self {
133 self.holidays = holidays;
134 self
135 }
136
137 pub fn build_context(&self, request: &Request, auth_token: &AuthToken) -> AuthorizationContext {
139 let now = Utc::now();
140 let ip_address = self.extract_ip_address(request);
141 let user_agent = self.extract_user_agent(request);
142
143 AuthorizationContext {
144 user_id: auth_token.user_id.clone(),
145 roles: auth_token.roles.to_vec(),
146 session_id: auth_token.metadata.session_id.clone(),
147
148 method: request.method().to_string(),
149 path: request.uri().path().to_string(),
150 ip_address,
151 user_agent: user_agent.clone(),
152
153 request_time: now,
154 time_of_day: self.classify_time_of_day(now),
155 day_type: self.classify_day_type(now),
156
157 device_type: self.detect_device_type(&user_agent),
158 connection_type: self.analyze_connection_type(request, &ip_address),
159
160 security_level: self.assess_security_level(request),
161 risk_score: self.calculate_risk_score(request, &ip_address, &user_agent),
162
163 custom_attributes: self.extract_custom_attributes(request),
164 }
165 }
166
167 pub fn to_hashmap(&self, context: &AuthorizationContext) -> HashMap<String, String> {
169 let mut map = HashMap::new();
170
171 map.insert("user_id".to_string(), context.user_id.clone());
173 map.insert("roles".to_string(), context.roles.join(","));
174 if let Some(session_id) = &context.session_id {
175 map.insert("session_id".to_string(), session_id.clone());
176 }
177
178 map.insert("method".to_string(), context.method.clone());
180 map.insert("path".to_string(), context.path.clone());
181 if let Some(ip) = &context.ip_address {
182 map.insert("ip_address".to_string(), ip.to_string());
183 }
184 if let Some(ua) = &context.user_agent {
185 map.insert("user_agent".to_string(), ua.clone());
186 }
187
188 map.insert(
190 "time_of_day".to_string(),
191 format!("{:?}", context.time_of_day).to_lowercase(),
192 );
193 map.insert(
194 "day_type".to_string(),
195 format!("{:?}", context.day_type).to_lowercase(),
196 );
197 map.insert(
198 "request_hour".to_string(),
199 context.request_time.hour().to_string(),
200 );
201 map.insert(
202 "request_weekday".to_string(),
203 context.request_time.weekday().to_string(),
204 );
205
206 map.insert(
208 "device_type".to_string(),
209 format!("{:?}", context.device_type).to_lowercase(),
210 );
211 map.insert(
212 "connection_type".to_string(),
213 format!("{:?}", context.connection_type).to_lowercase(),
214 );
215
216 map.insert(
218 "security_level".to_string(),
219 format!("{:?}", context.security_level).to_lowercase(),
220 );
221 map.insert("risk_score".to_string(), context.risk_score.to_string());
222
223 for (key, value) in &context.custom_attributes {
225 map.insert(format!("custom_{}", key), value.clone());
226 }
227
228 map
229 }
230
231 fn extract_ip_address(&self, request: &Request) -> Option<IpAddr> {
233 if let Some(forwarded) = request.headers().get("x-forwarded-for")
235 && let Ok(forwarded_str) = forwarded.to_str()
236 {
237 if let Some(ip_str) = forwarded_str.split(',').next()
238 && let Ok(ip) = ip_str.trim().parse()
239 {
240 return Some(ip);
241 }
242
243 if let Some(real_ip) = request.headers().get("x-real-ip")
245 && let Ok(ip_str) = real_ip.to_str()
246 && let Ok(ip) = ip_str.parse()
247 {
248 return Some(ip);
249 }
250
251 None
253 } else {
254 request
256 .extensions()
257 .get::<axum::extract::ConnectInfo<IpAddr>>()
258 .map(|info| info.0)
259 }
260 }
261
262 fn extract_user_agent(&self, request: &Request) -> Option<String> {
264 request
265 .headers()
266 .get("user-agent")
267 .and_then(|ua| ua.to_str().ok())
268 .map(|s| s.to_string())
269 }
270
271 fn classify_time_of_day(&self, now: DateTime<Utc>) -> TimeOfDay {
273 let date = now.date_naive();
274
275 if self.holidays.contains(&date) {
277 return TimeOfDay::Holiday;
278 }
279
280 match now.weekday() {
282 Weekday::Sat | Weekday::Sun => return TimeOfDay::Weekend,
283 _ => {}
284 }
285
286 let hour = now.hour() as u8;
288 if hour >= self.business_start && hour < self.business_end {
289 TimeOfDay::BusinessHours
290 } else {
291 TimeOfDay::AfterHours
292 }
293 }
294
295 fn classify_day_type(&self, now: DateTime<Utc>) -> DayType {
297 let date = now.date_naive();
298
299 if self.holidays.contains(&date) {
300 DayType::Holiday
301 } else {
302 match now.weekday() {
303 Weekday::Sat | Weekday::Sun => DayType::Weekend,
304 _ => DayType::Weekday,
305 }
306 }
307 }
308
309 fn detect_device_type(&self, user_agent: &Option<String>) -> DeviceType {
311 let ua = match user_agent {
312 Some(ua) => ua.to_lowercase(),
313 None => return DeviceType::Unknown,
314 };
315
316 if ua.contains("mobile") || ua.contains("android") || ua.contains("iphone") {
317 DeviceType::Mobile
318 } else if ua.contains("tablet") || ua.contains("ipad") {
319 DeviceType::Tablet
320 } else if ua.contains("mozilla") || ua.contains("chrome") || ua.contains("firefox") {
321 DeviceType::Desktop
322 } else {
323 DeviceType::Unknown
324 }
325 }
326
327 fn analyze_connection_type(
329 &self,
330 request: &Request,
331 ip_address: &Option<IpAddr>,
332 ) -> ConnectionType {
333 if let Some(via) = request.headers().get("via")
335 && let Ok(via_str) = via.to_str()
336 {
337 if via_str.to_lowercase().contains("vpn") {
338 return ConnectionType::VPN;
339 }
340 if via_str.to_lowercase().contains("proxy") {
341 return ConnectionType::Proxy;
342 }
343
344 if let Some(ua) = request.headers().get("user-agent")
346 && let Ok(ua_str) = ua.to_str()
347 && ua_str.contains("Tor")
348 {
349 return ConnectionType::Tor;
350 }
351
352 if let Some(ip) = ip_address {
354 for network in &self.corporate_networks {
355 if network.contains(*ip) {
356 return ConnectionType::Corporate;
357 }
358 }
359 }
360
361 return ConnectionType::Direct;
362 }
363 ConnectionType::Unknown
365 }
366
367 fn assess_security_level(&self, request: &Request) -> SecurityLevel {
369 let path = request.uri().path();
370
371 match path {
372 _ if path.starts_with("/admin/system/") => SecurityLevel::Critical,
373 _ if path.starts_with("/admin/") => SecurityLevel::High,
374 _ if path.contains("/secrets/") => SecurityLevel::Critical,
375 _ if path.contains("/keys/") => SecurityLevel::High,
376 _ if path.starts_with("/api/") => SecurityLevel::Medium,
377 _ => SecurityLevel::Low,
378 }
379 }
380
381 fn calculate_risk_score(
383 &self,
384 request: &Request,
385 ip_address: &Option<IpAddr>,
386 user_agent: &Option<String>,
387 ) -> u8 {
388 let mut risk_score = 0u8;
389
390 let path = request.uri().path();
392 if path.starts_with("/admin/") {
393 risk_score += 30;
394 } else if path.contains("/secrets/") || path.contains("/keys/") {
395 risk_score += 40;
396 } else if path.starts_with("/api/") {
397 risk_score += 10;
398 }
399
400 let connection_type = self.analyze_connection_type(request, ip_address);
402 match connection_type {
403 ConnectionType::Tor => risk_score += 50,
404 ConnectionType::VPN => risk_score += 20,
405 ConnectionType::Proxy => risk_score += 15,
406 ConnectionType::Corporate => risk_score = risk_score.saturating_sub(10),
407 ConnectionType::Direct => {}
408 ConnectionType::Unknown => risk_score += 10,
409 }
410
411 let device_type = self.detect_device_type(user_agent);
413 match device_type {
414 DeviceType::Mobile => risk_score += 5,
415 DeviceType::Unknown => risk_score += 15,
416 _ => {}
417 }
418
419 let now = Utc::now();
421 match self.classify_time_of_day(now) {
422 TimeOfDay::AfterHours => risk_score += 10,
423 TimeOfDay::Weekend => risk_score += 5,
424 _ => {}
425 }
426
427 if user_agent.is_none() {
429 risk_score += 20;
430 }
431
432 risk_score.min(100)
434 }
435
436 fn extract_custom_attributes(&self, request: &Request) -> HashMap<String, String> {
438 let mut attributes = HashMap::new();
439
440 for (name, value) in request.headers() {
442 let name_str = name.as_str().to_lowercase();
443 if let Some(attr_name) = name_str.strip_prefix("x-auth-")
444 && let Ok(value_str) = value.to_str()
445 {
446 attributes.insert(attr_name.to_string(), value_str.to_string());
447 }
448 }
449
450 if let Some(query) = request.uri().query() {
452 for pair in query.split('&') {
453 if let Some((key, value)) = pair.split_once('=')
454 && key.starts_with("ctx_")
455 {
456 attributes.insert(
457 key.strip_prefix("ctx_")
458 .expect("guarded by starts_with check")
459 .to_string(),
460 urlencoding::decode(value).unwrap_or_default().to_string(),
461 );
462 }
463 }
464 }
465
466 attributes
467 }
468
469 pub fn enrich_context(&self, mut context: AuthorizationContext) -> AuthorizationContext {
471 let current_risk = context.risk_score;
473 context.risk_score = std::cmp::max(current_risk, 1); let now = chrono::Utc::now();
477 context
478 .custom_attributes
479 .insert("enriched_timestamp".to_string(), now.to_rfc3339());
480
481 context.custom_attributes.insert(
483 "security_assessment".to_string(),
484 match context.security_level {
485 SecurityLevel::Low => "basic".to_string(),
486 SecurityLevel::Medium => "standard".to_string(),
487 SecurityLevel::High => "enhanced".to_string(),
488 SecurityLevel::Critical => "maximum".to_string(),
489 },
490 );
491
492 context
493 }
494}
495
/// Evaluates per-permission conditions (time, network location, device and
/// risk) against an [`AuthorizationContext`].
pub struct ConditionalEvaluator {
    // Builder supplied to `ConditionalEvaluator::new`.
    context_builder: ContextBuilder,
}
500
501impl ConditionalEvaluator {
502 pub fn new(context_builder: ContextBuilder) -> Self {
504 Self { context_builder }
505 }
506
507 pub fn evaluate_time_conditions(
509 &self,
510 context: &AuthorizationContext,
511 conditions: &HashMap<String, String>,
512 ) -> bool {
513 if let Some(require_business_hours) = conditions.get("require_business_hours")
515 && require_business_hours == "true"
516 {
517 match context.time_of_day {
518 TimeOfDay::BusinessHours => {}
519 _ => return false,
520 }
521 }
522
523 if let Some(require_weekday) = conditions.get("require_weekday")
525 && require_weekday == "true"
526 {
527 match context.day_type {
528 DayType::Weekday => {}
529 _ => return false,
530 }
531 }
532
533 true
534 }
535
536 pub fn evaluate_location_conditions(
538 &self,
539 context: &AuthorizationContext,
540 conditions: &HashMap<String, String>,
541 ) -> bool {
542 if let Some(require_corporate) = conditions.get("require_corporate_network")
544 && require_corporate == "true"
545 {
546 match context.connection_type {
547 ConnectionType::Corporate => {}
548 _ => return false,
549 }
550 }
551
552 if let Some(block_vpn) = conditions.get("block_vpn")
554 && block_vpn == "true"
555 {
556 match context.connection_type {
557 ConnectionType::VPN | ConnectionType::Tor => return false,
558 _ => {}
559 }
560 }
561
562 true
563 }
564
565 pub fn evaluate_device_conditions(
567 &self,
568 context: &AuthorizationContext,
569 conditions: &HashMap<String, String>,
570 ) -> bool {
571 if let Some(allowed_devices) = conditions.get("allowed_device_types") {
573 let allowed: Vec<&str> = allowed_devices.split(',').collect();
574 let device_str = format!("{:?}", context.device_type).to_lowercase();
575
576 if !allowed.contains(&device_str.as_str()) {
577 return false;
578 }
579 }
580
581 true
582 }
583
584 pub fn evaluate_risk_conditions(
586 &self,
587 context: &AuthorizationContext,
588 conditions: &HashMap<String, String>,
589 ) -> bool {
590 if let Some(max_risk_str) = conditions.get("max_risk_score")
592 && let Ok(max_risk) = max_risk_str.parse::<u8>()
593 && context.risk_score > max_risk
594 {
595 return false;
596 }
597
598 true
599 }
600
601 pub fn evaluate_conditional_permission(
604 &self,
605 context: &AuthorizationContext,
606 permission_conditions: &HashMap<String, String>,
607 ) -> bool {
608 tracing::debug!(
610 "Evaluating conditional permission with conditions: {:?}",
611 permission_conditions
612 );
613
614 if permission_conditions.is_empty() {
616 return true;
617 }
618
619 let _enriched_context = self.context_builder.enrich_context(context.clone());
621
622 let time_check = self.evaluate_time_conditions(context, permission_conditions);
624 let location_check = self.evaluate_location_conditions(context, permission_conditions);
625 let device_check = self.evaluate_device_conditions(context, permission_conditions);
626 let risk_check = self.evaluate_risk_conditions(context, permission_conditions);
627
628 let result = time_check && location_check && device_check && risk_check;
629
630 tracing::info!(
631 "Conditional evaluation result: {} (time: {}, location: {}, device: {}, risk: {})",
632 result,
633 time_check,
634 location_check,
635 device_check,
636 risk_check
637 );
638
639 result
640 }
641
642 pub fn evaluate_all_conditions(
644 &self,
645 context: &AuthorizationContext,
646 conditions: &HashMap<String, String>,
647 ) -> bool {
648 self.evaluate_time_conditions(context, conditions)
649 && self.evaluate_location_conditions(context, conditions)
650 && self.evaluate_device_conditions(context, conditions)
651 && self.evaluate_risk_conditions(context, conditions)
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;

    /// UTC timestamp at the top of the given hour.
    fn utc(year: i32, month: u32, day: u32, hour: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(year, month, day, hour, 0, 0).unwrap()
    }

    /// Body-less request for `path`.
    fn request_for(path: &str) -> Request {
        axum::http::Request::builder()
            .uri(path)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Minimal weekday business-hours context over a direct connection.
    fn sample_context(device_type: DeviceType, risk_score: u8) -> AuthorizationContext {
        AuthorizationContext {
            user_id: "user-1".to_string(),
            roles: vec!["reader".to_string()],
            session_id: None,
            method: "GET".to_string(),
            path: "/api/items".to_string(),
            ip_address: None,
            user_agent: None,
            request_time: utc(2024, 6, 5, 14),
            time_of_day: TimeOfDay::BusinessHours,
            day_type: DayType::Weekday,
            device_type,
            connection_type: ConnectionType::Direct,
            security_level: SecurityLevel::Medium,
            risk_score,
            custom_attributes: HashMap::new(),
        }
    }

    fn conditions(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_context_builder_creation() {
        let builder = ContextBuilder::new()
            .with_business_hours(8, 18)
            .with_holidays(vec![chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap()]);

        assert_eq!(builder.business_start, 8);
        assert_eq!(builder.business_end, 18);
        assert_eq!(builder.holidays.len(), 1);
    }

    #[test]
    fn test_time_classification() {
        let builder = ContextBuilder::new()
            .with_holidays(vec![chrono::NaiveDate::from_ymd_opt(2024, 12, 25).unwrap()]);

        // Fixed dates keep the test independent of when it runs:
        // 2024-06-05 is a Wednesday, 2024-06-08 a Saturday.
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 14)),
            TimeOfDay::BusinessHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 9)),
            TimeOfDay::BusinessHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 17)),
            TimeOfDay::AfterHours
        ));
        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 5, 3)),
            TimeOfDay::AfterHours
        ));
        assert!(matches!(
            builder.classify_day_type(utc(2024, 6, 5, 14)),
            DayType::Weekday
        ));

        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 6, 8, 14)),
            TimeOfDay::Weekend
        ));
        assert!(matches!(
            builder.classify_day_type(utc(2024, 6, 8, 14)),
            DayType::Weekend
        ));

        assert!(matches!(
            builder.classify_time_of_day(utc(2024, 12, 25, 14)),
            TimeOfDay::Holiday
        ));
        assert!(matches!(
            builder.classify_day_type(utc(2024, 12, 25, 14)),
            DayType::Holiday
        ));
    }

    #[test]
    fn test_device_detection() {
        let builder = ContextBuilder::new();

        let mobile_ua = Some("Mozilla/5.0 (iPhone; CPU iPhone OS 14_0 like Mac OS X)".to_string());
        assert!(matches!(
            builder.detect_device_type(&mobile_ua),
            DeviceType::Mobile
        ));

        let desktop_ua =
            Some("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36".to_string());
        assert!(matches!(
            builder.detect_device_type(&desktop_ua),
            DeviceType::Desktop
        ));

        assert!(matches!(
            builder.detect_device_type(&None),
            DeviceType::Unknown
        ));
    }

    #[test]
    fn test_security_level_assessment() {
        let builder = ContextBuilder::new();
        let level = |path: &str| builder.assess_security_level(&request_for(path));

        assert!(matches!(level("/admin/system/reboot"), SecurityLevel::Critical));
        assert!(matches!(level("/v1/secrets/db"), SecurityLevel::Critical));
        assert!(matches!(level("/admin/users"), SecurityLevel::High));
        assert!(matches!(level("/v1/keys/list"), SecurityLevel::High));
        assert!(matches!(level("/api/v1/items"), SecurityLevel::Medium));
        assert!(matches!(level("/health"), SecurityLevel::Low));
    }

    #[test]
    fn test_risk_calculation() {
        let builder = ContextBuilder::new();

        // Enrichment raises a zero risk score to 1 and records an assessment.
        let enriched = builder.enrich_context(sample_context(DeviceType::Desktop, 0));
        assert_eq!(enriched.risk_score, 1);
        assert_eq!(
            enriched
                .custom_attributes
                .get("security_assessment")
                .map(String::as_str),
            Some("standard")
        );
        assert!(enriched.custom_attributes.contains_key("enriched_timestamp"));

        // Non-zero scores are preserved.
        let enriched = builder.enrich_context(sample_context(DeviceType::Desktop, 42));
        assert_eq!(enriched.risk_score, 42);
    }

    #[test]
    fn test_conditional_evaluation() {
        let evaluator = ConditionalEvaluator::new(ContextBuilder::new());
        let context = sample_context(DeviceType::Mobile, 60);

        assert!(evaluator.evaluate_all_conditions(&context, &HashMap::new()));
        assert!(evaluator.evaluate_conditional_permission(&context, &HashMap::new()));

        assert!(!evaluator
            .evaluate_risk_conditions(&context, &conditions(&[("max_risk_score", "50")])));
        assert!(evaluator
            .evaluate_risk_conditions(&context, &conditions(&[("max_risk_score", "80")])));

        assert!(evaluator.evaluate_device_conditions(
            &context,
            &conditions(&[("allowed_device_types", "desktop,mobile")])
        ));
        assert!(!evaluator.evaluate_device_conditions(
            &context,
            &conditions(&[("allowed_device_types", "tablet")])
        ));

        assert!(evaluator.evaluate_time_conditions(
            &context,
            &conditions(&[("require_business_hours", "true"), ("require_weekday", "true")])
        ));
        assert!(!evaluator.evaluate_location_conditions(
            &context,
            &conditions(&[("require_corporate_network", "true")])
        ));
    }
}