llm_config_security/
policy.rs

1//! Security policy enforcement
2
3use crate::errors::{SecurityError, SecurityResult};
4use crate::SecurityContext;
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
8/// Security policy configuration
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SecurityPolicy {
11    /// Allowed IP ranges (CIDR notation)
12    pub allowed_ip_ranges: Vec<String>,
13    /// Blocked IP addresses
14    pub blocked_ips: Vec<String>,
15    /// Require TLS
16    pub require_tls: bool,
17    /// Minimum TLS version
18    pub min_tls_version: String,
19    /// Allowed origins for CORS
20    pub allowed_origins: Vec<String>,
21    /// Maximum request size in bytes
22    pub max_request_size: usize,
23    /// Session timeout in seconds
24    pub session_timeout: u64,
25    /// Require MFA for sensitive operations
26    pub require_mfa: bool,
27    /// Allowed API endpoints
28    pub allowed_endpoints: Vec<String>,
29    /// Blocked endpoints
30    pub blocked_endpoints: Vec<String>,
31    /// Enable audit logging
32    pub enable_audit: bool,
33    /// Data classification levels
34    pub data_classifications: Vec<DataClassification>,
35}
36
37impl Default for SecurityPolicy {
38    fn default() -> Self {
39        Self {
40            allowed_ip_ranges: vec!["0.0.0.0/0".to_string()],
41            blocked_ips: vec![],
42            require_tls: true,
43            min_tls_version: "1.2".to_string(),
44            allowed_origins: vec![],
45            max_request_size: 10 * 1024 * 1024, // 10MB
46            session_timeout: 3600,               // 1 hour
47            require_mfa: false,
48            allowed_endpoints: vec![],
49            blocked_endpoints: vec![],
50            enable_audit: true,
51            data_classifications: vec![
52                DataClassification::Public,
53                DataClassification::Internal,
54                DataClassification::Confidential,
55                DataClassification::Secret,
56            ],
57        }
58    }
59}
60
61/// Data classification levels
62#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
63pub enum DataClassification {
64    Public,
65    Internal,
66    Confidential,
67    Secret,
68}
69
70/// Policy enforcer
71pub struct PolicyEnforcer {
72    policy: SecurityPolicy,
73    blocked_ips: HashSet<String>,
74}
75
76impl PolicyEnforcer {
77    /// Create a new policy enforcer
78    pub fn new(policy: SecurityPolicy) -> Self {
79        let blocked_ips = policy.blocked_ips.iter().cloned().collect();
80        Self {
81            policy,
82            blocked_ips,
83        }
84    }
85
86    /// Create with default policy
87    pub fn default() -> Self {
88        Self::new(SecurityPolicy::default())
89    }
90
91    /// Check if an IP is allowed
92    pub fn check_ip(&self, ip: &str) -> SecurityResult<()> {
93        // Check if IP is blocked
94        if self.blocked_ips.contains(ip) {
95            return Err(SecurityError::PolicyViolation(format!(
96                "IP address {} is blocked",
97                ip
98            )));
99        }
100
101        // Check if IP is in allowed ranges
102        if !self.policy.allowed_ip_ranges.is_empty()
103            && !self.policy.allowed_ip_ranges.contains(&"0.0.0.0/0".to_string())
104        {
105            // In a real implementation, we would use proper CIDR matching
106            // For now, just check if IP is in the list
107            if !self.policy.allowed_ip_ranges.contains(&ip.to_string()) {
108                return Err(SecurityError::PolicyViolation(format!(
109                    "IP address {} is not in allowed ranges",
110                    ip
111                )));
112            }
113        }
114
115        Ok(())
116    }
117
118    /// Check if TLS is required
119    pub fn check_tls(&self, is_tls: bool, version: &str) -> SecurityResult<()> {
120        if self.policy.require_tls && !is_tls {
121            return Err(SecurityError::InsecureProtocol(
122                "TLS is required".to_string(),
123            ));
124        }
125
126        if is_tls {
127            let min_version = self.parse_tls_version(&self.policy.min_tls_version);
128            let actual_version = self.parse_tls_version(version);
129
130            if actual_version < min_version {
131                return Err(SecurityError::InsecureProtocol(format!(
132                    "TLS version {} is below minimum {}",
133                    version, self.policy.min_tls_version
134                )));
135            }
136        }
137
138        Ok(())
139    }
140
141    /// Parse TLS version string to number for comparison
142    fn parse_tls_version(&self, version: &str) -> u32 {
143        match version {
144            "1.0" => 10,
145            "1.1" => 11,
146            "1.2" => 12,
147            "1.3" => 13,
148            _ => 0,
149        }
150    }
151
152    /// Check CORS origin
153    pub fn check_origin(&self, origin: &str) -> SecurityResult<()> {
154        if self.policy.allowed_origins.is_empty() {
155            return Ok(()); // No CORS restrictions
156        }
157
158        if self.policy.allowed_origins.contains(&origin.to_string())
159            || self.policy.allowed_origins.contains(&"*".to_string())
160        {
161            Ok(())
162        } else {
163            Err(SecurityError::PolicyViolation(format!(
164                "Origin {} is not allowed",
165                origin
166            )))
167        }
168    }
169
170    /// Check request size
171    pub fn check_request_size(&self, size: usize) -> SecurityResult<()> {
172        if size > self.policy.max_request_size {
173            return Err(SecurityError::RequestTooLarge(size));
174        }
175        Ok(())
176    }
177
178    /// Check if endpoint is allowed
179    pub fn check_endpoint(&self, endpoint: &str) -> SecurityResult<()> {
180        // Check if endpoint is blocked
181        if self.is_endpoint_blocked(endpoint) {
182            return Err(SecurityError::PolicyViolation(format!(
183                "Endpoint {} is blocked",
184                endpoint
185            )));
186        }
187
188        // Check if endpoint is in allowed list (if list is not empty)
189        if !self.policy.allowed_endpoints.is_empty()
190            && !self.is_endpoint_allowed(endpoint)
191        {
192            return Err(SecurityError::PolicyViolation(format!(
193                "Endpoint {} is not in allowed list",
194                endpoint
195            )));
196        }
197
198        Ok(())
199    }
200
201    /// Check if endpoint matches allowed patterns
202    fn is_endpoint_allowed(&self, endpoint: &str) -> bool {
203        self.policy
204            .allowed_endpoints
205            .iter()
206            .any(|pattern| self.matches_pattern(endpoint, pattern))
207    }
208
209    /// Check if endpoint matches blocked patterns
210    fn is_endpoint_blocked(&self, endpoint: &str) -> bool {
211        self.policy
212            .blocked_endpoints
213            .iter()
214            .any(|pattern| self.matches_pattern(endpoint, pattern))
215    }
216
217    /// Simple pattern matching (supports wildcards)
218    fn matches_pattern(&self, text: &str, pattern: &str) -> bool {
219        if pattern == "*" {
220            return true;
221        }
222
223        if pattern.ends_with('*') {
224            let prefix = &pattern[..pattern.len() - 1];
225            text.starts_with(prefix)
226        } else if pattern.starts_with('*') {
227            let suffix = &pattern[1..];
228            text.ends_with(suffix)
229        } else {
230            text == pattern
231        }
232    }
233
234    /// Check if MFA is required
235    pub fn check_mfa(&self, has_mfa: bool, is_sensitive: bool) -> SecurityResult<()> {
236        if self.policy.require_mfa && is_sensitive && !has_mfa {
237            return Err(SecurityError::PolicyViolation(
238                "MFA is required for sensitive operations".to_string(),
239            ));
240        }
241        Ok(())
242    }
243
244    /// Check session validity
245    pub fn check_session(
246        &self,
247        created_at: chrono::DateTime<chrono::Utc>,
248    ) -> SecurityResult<()> {
249        let elapsed = chrono::Utc::now()
250            .signed_duration_since(created_at)
251            .num_seconds() as u64;
252
253        if elapsed > self.policy.session_timeout {
254            return Err(SecurityError::InvalidSession(
255                "Session expired".to_string(),
256            ));
257        }
258
259        Ok(())
260    }
261
262    /// Validate data classification
263    pub fn check_data_classification(
264        &self,
265        classification: &DataClassification,
266    ) -> SecurityResult<()> {
267        if !self.policy.data_classifications.contains(classification) {
268            return Err(SecurityError::PolicyViolation(format!(
269                "Data classification {:?} is not allowed",
270                classification
271            )));
272        }
273        Ok(())
274    }
275
276    /// Comprehensive security check
277    pub fn check_request(&self, context: &SecurityContext) -> SecurityResult<()> {
278        // Check IP
279        self.check_ip(&context.ip_address)?;
280
281        // Check session if present
282        if let Some(ref session_id) = context.session_id {
283            if !session_id.is_empty() {
284                self.check_session(context.timestamp)?;
285            }
286        }
287
288        Ok(())
289    }
290
291    /// Add an IP to the blocklist
292    pub fn block_ip(&mut self, ip: String) {
293        self.blocked_ips.insert(ip.clone());
294        if !self.policy.blocked_ips.contains(&ip) {
295            self.policy.blocked_ips.push(ip);
296        }
297    }
298
299    /// Remove an IP from the blocklist
300    pub fn unblock_ip(&mut self, ip: &str) {
301        self.blocked_ips.remove(ip);
302        self.policy.blocked_ips.retain(|x| x != ip);
303    }
304
305    /// Get the current policy
306    pub fn get_policy(&self) -> &SecurityPolicy {
307        &self.policy
308    }
309
310    /// Update the policy
311    pub fn update_policy(&mut self, policy: SecurityPolicy) {
312        self.blocked_ips = policy.blocked_ips.iter().cloned().collect();
313        self.policy = policy;
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the enforcer under test from a customised policy.
    fn enforcer_for(policy: SecurityPolicy) -> PolicyEnforcer {
        PolicyEnforcer::new(policy)
    }

    #[test]
    fn test_ip_blocking() {
        let enforcer = enforcer_for(SecurityPolicy {
            blocked_ips: vec![String::from("192.168.1.100")],
            ..SecurityPolicy::default()
        });

        let neighbour = enforcer.check_ip("192.168.1.1");
        let blocked = enforcer.check_ip("192.168.1.100");

        assert!(neighbour.is_ok(), "unlisted address must pass");
        assert!(blocked.is_err(), "listed address must be rejected");
    }

    #[test]
    fn test_tls_check() {
        let enforcer = enforcer_for(SecurityPolicy {
            require_tls: true,
            min_tls_version: String::from("1.2"),
            ..SecurityPolicy::default()
        });

        // At or above the minimum.
        for accepted in ["1.2", "1.3"].iter() {
            assert!(
                enforcer.check_tls(true, accepted).is_ok(),
                "TLS {} should be accepted",
                accepted
            );
        }

        // Below the minimum, and no TLS at all.
        assert!(enforcer.check_tls(true, "1.1").is_err());
        assert!(enforcer.check_tls(false, "1.2").is_err());
    }

    #[test]
    fn test_origin_check() {
        let enforcer = enforcer_for(SecurityPolicy {
            allowed_origins: vec![String::from("https://example.com")],
            ..SecurityPolicy::default()
        });

        let listed = enforcer.check_origin("https://example.com");
        let unlisted = enforcer.check_origin("https://evil.com");

        assert!(listed.is_ok());
        assert!(unlisted.is_err());
    }

    #[test]
    fn test_request_size() {
        let enforcer = enforcer_for(SecurityPolicy {
            max_request_size: 1024,
            ..SecurityPolicy::default()
        });

        assert!(enforcer.check_request_size(512).is_ok(), "under the cap");
        assert!(enforcer.check_request_size(2048).is_err(), "over the cap");
    }

    #[test]
    fn test_endpoint_patterns() {
        let enforcer = enforcer_for(SecurityPolicy {
            allowed_endpoints: vec![String::from("/api/*")],
            blocked_endpoints: vec![String::from("/api/admin/*")],
            ..SecurityPolicy::default()
        });

        let public_route = enforcer.check_endpoint("/api/users");
        let admin_route = enforcer.check_endpoint("/api/admin/users");

        assert!(public_route.is_ok());
        assert!(admin_route.is_err(), "block-list wins over allow-list");
    }

    #[test]
    fn test_mfa_requirement() {
        let enforcer = enforcer_for(SecurityPolicy {
            require_mfa: true,
            ..SecurityPolicy::default()
        });

        // (has_mfa, is_sensitive)
        assert!(enforcer.check_mfa(true, true).is_ok());
        assert!(enforcer.check_mfa(false, false).is_ok());
        assert!(enforcer.check_mfa(false, true).is_err());
    }

    #[test]
    fn test_session_timeout() {
        let enforcer = enforcer_for(SecurityPolicy {
            session_timeout: 3600,
            ..SecurityPolicy::default()
        });

        // Half the timeout old: still valid.
        let half_hour_ago = chrono::Utc::now() - chrono::Duration::seconds(1800);
        assert!(enforcer.check_session(half_hour_ago).is_ok());

        // Twice the timeout old: expired.
        let two_hours_ago = chrono::Utc::now() - chrono::Duration::seconds(7200);
        assert!(enforcer.check_session(two_hours_ago).is_err());
    }

    #[test]
    fn test_data_classification() {
        let enforcer = enforcer_for(SecurityPolicy {
            data_classifications: vec![
                DataClassification::Public,
                DataClassification::Internal,
            ],
            ..SecurityPolicy::default()
        });

        let permitted = enforcer.check_data_classification(&DataClassification::Public);
        let forbidden = enforcer.check_data_classification(&DataClassification::Secret);

        assert!(permitted.is_ok());
        assert!(forbidden.is_err());
    }

    #[test]
    fn test_dynamic_blocking() {
        let target = "10.0.0.1";
        let mut enforcer = PolicyEnforcer::default();

        enforcer.block_ip(target.to_string());
        assert!(enforcer.check_ip(target).is_err(), "blocked after block_ip");

        enforcer.unblock_ip(target);
        assert!(enforcer.check_ip(target).is_ok(), "allowed after unblock_ip");
    }

    #[test]
    fn test_comprehensive_check() {
        let enforcer = PolicyEnforcer::default();
        let context = SecurityContext::new("user123", "192.168.1.1").with_session("sess_abc");

        let outcome = enforcer.check_request(&context);

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_pattern_matching() {
        let enforcer = PolicyEnforcer::default();
        let path = "/api/users";

        // Prefix wildcard, suffix wildcard and exact match.
        assert!(enforcer.matches_pattern(path, "/api/*"));
        assert!(enforcer.matches_pattern(path, "*/users"));
        assert!(enforcer.matches_pattern(path, "/api/users"));

        // Different prefix must not match.
        assert!(!enforcer.matches_pattern(path, "/admin/*"));
    }
}