1use crate::errors::{SecurityError, SecurityResult};
4use crate::SecurityContext;
5use serde::{Deserialize, Serialize};
6use std::collections::HashSet;
7
/// Declarative security settings consumed by [`PolicyEnforcer`].
///
/// The struct is plain data: it can be cloned, printed with `Debug`, and
/// (de)serialized through serde. All enforcement logic lives in
/// [`PolicyEnforcer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityPolicy {
    /// Allow-list consulted by `PolicyEnforcer::check_ip`. An empty list, or
    /// one containing `"0.0.0.0/0"`, places no restriction on the caller
    /// address.
    pub allowed_ip_ranges: Vec<String>,
    /// Addresses that `PolicyEnforcer::check_ip` always rejects (matched as
    /// exact strings). The block-list is checked before the allow-list.
    pub blocked_ips: Vec<String>,
    /// When `true`, `PolicyEnforcer::check_tls` rejects non-TLS connections.
    pub require_tls: bool,
    /// Lowest TLS version accepted for TLS connections, written as
    /// `"1.0"`, `"1.1"`, `"1.2"` or `"1.3"`.
    pub min_tls_version: String,
    /// Origins accepted by `PolicyEnforcer::check_origin`. An empty list
    /// disables the origin check; an entry of `"*"` accepts any origin.
    pub allowed_origins: Vec<String>,
    /// Largest request size accepted by `PolicyEnforcer::check_request_size`
    /// (a request of exactly this size is still accepted).
    pub max_request_size: usize,
    /// Maximum session age, in seconds, before
    /// `PolicyEnforcer::check_session` reports the session as expired.
    pub session_timeout: u64,
    /// When `true`, `PolicyEnforcer::check_mfa` demands MFA for operations
    /// flagged as sensitive.
    pub require_mfa: bool,
    /// Endpoint patterns that may be called. An empty list allows every
    /// endpoint that is not blocked. Patterns are exact paths or use `*`
    /// wildcards as interpreted by the enforcer's pattern matcher.
    pub allowed_endpoints: Vec<String>,
    /// Endpoint patterns that are always rejected; checked before
    /// `allowed_endpoints`.
    pub blocked_endpoints: Vec<String>,
    /// Audit switch. NOTE(review): not read by any enforcement method visible
    /// here; presumably consumed by audit logging elsewhere, confirm.
    pub enable_audit: bool,
    /// Data classifications accepted by
    /// `PolicyEnforcer::check_data_classification`.
    pub data_classifications: Vec<DataClassification>,
}
36
37impl Default for SecurityPolicy {
38 fn default() -> Self {
39 Self {
40 allowed_ip_ranges: vec!["0.0.0.0/0".to_string()],
41 blocked_ips: vec![],
42 require_tls: true,
43 min_tls_version: "1.2".to_string(),
44 allowed_origins: vec![],
45 max_request_size: 10 * 1024 * 1024, session_timeout: 3600, require_mfa: false,
48 allowed_endpoints: vec![],
49 blocked_endpoints: vec![],
50 enable_audit: true,
51 data_classifications: vec![
52 DataClassification::Public,
53 DataClassification::Internal,
54 DataClassification::Confidential,
55 DataClassification::Secret,
56 ],
57 }
58 }
59}
60
/// Sensitivity label attached to a piece of data.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so
/// `Public < Internal < Confidential < Secret`; reordering the variants
/// changes every comparison between them.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum DataClassification {
    /// Lowest-ranked classification.
    Public,
    /// Ranked above `Public`.
    Internal,
    /// Ranked above `Internal`.
    Confidential,
    /// Highest-ranked classification.
    Secret,
}
69
/// Evaluates requests against a [`SecurityPolicy`].
pub struct PolicyEnforcer {
    /// The policy currently being enforced.
    policy: SecurityPolicy,
    /// Set copy of `policy.blocked_ips`, used for the block-list lookup in
    /// `check_ip`. It is built in `new`, rebuilt in `update_policy`, and kept
    /// in step with the policy by `block_ip` / `unblock_ip`.
    blocked_ips: HashSet<String>,
}
75
76impl PolicyEnforcer {
77 pub fn new(policy: SecurityPolicy) -> Self {
79 let blocked_ips = policy.blocked_ips.iter().cloned().collect();
80 Self {
81 policy,
82 blocked_ips,
83 }
84 }
85
86 pub fn default() -> Self {
88 Self::new(SecurityPolicy::default())
89 }
90
91 pub fn check_ip(&self, ip: &str) -> SecurityResult<()> {
93 if self.blocked_ips.contains(ip) {
95 return Err(SecurityError::PolicyViolation(format!(
96 "IP address {} is blocked",
97 ip
98 )));
99 }
100
101 if !self.policy.allowed_ip_ranges.is_empty()
103 && !self.policy.allowed_ip_ranges.contains(&"0.0.0.0/0".to_string())
104 {
105 if !self.policy.allowed_ip_ranges.contains(&ip.to_string()) {
108 return Err(SecurityError::PolicyViolation(format!(
109 "IP address {} is not in allowed ranges",
110 ip
111 )));
112 }
113 }
114
115 Ok(())
116 }
117
118 pub fn check_tls(&self, is_tls: bool, version: &str) -> SecurityResult<()> {
120 if self.policy.require_tls && !is_tls {
121 return Err(SecurityError::InsecureProtocol(
122 "TLS is required".to_string(),
123 ));
124 }
125
126 if is_tls {
127 let min_version = self.parse_tls_version(&self.policy.min_tls_version);
128 let actual_version = self.parse_tls_version(version);
129
130 if actual_version < min_version {
131 return Err(SecurityError::InsecureProtocol(format!(
132 "TLS version {} is below minimum {}",
133 version, self.policy.min_tls_version
134 )));
135 }
136 }
137
138 Ok(())
139 }
140
141 fn parse_tls_version(&self, version: &str) -> u32 {
143 match version {
144 "1.0" => 10,
145 "1.1" => 11,
146 "1.2" => 12,
147 "1.3" => 13,
148 _ => 0,
149 }
150 }
151
152 pub fn check_origin(&self, origin: &str) -> SecurityResult<()> {
154 if self.policy.allowed_origins.is_empty() {
155 return Ok(()); }
157
158 if self.policy.allowed_origins.contains(&origin.to_string())
159 || self.policy.allowed_origins.contains(&"*".to_string())
160 {
161 Ok(())
162 } else {
163 Err(SecurityError::PolicyViolation(format!(
164 "Origin {} is not allowed",
165 origin
166 )))
167 }
168 }
169
170 pub fn check_request_size(&self, size: usize) -> SecurityResult<()> {
172 if size > self.policy.max_request_size {
173 return Err(SecurityError::RequestTooLarge(size));
174 }
175 Ok(())
176 }
177
178 pub fn check_endpoint(&self, endpoint: &str) -> SecurityResult<()> {
180 if self.is_endpoint_blocked(endpoint) {
182 return Err(SecurityError::PolicyViolation(format!(
183 "Endpoint {} is blocked",
184 endpoint
185 )));
186 }
187
188 if !self.policy.allowed_endpoints.is_empty()
190 && !self.is_endpoint_allowed(endpoint)
191 {
192 return Err(SecurityError::PolicyViolation(format!(
193 "Endpoint {} is not in allowed list",
194 endpoint
195 )));
196 }
197
198 Ok(())
199 }
200
201 fn is_endpoint_allowed(&self, endpoint: &str) -> bool {
203 self.policy
204 .allowed_endpoints
205 .iter()
206 .any(|pattern| self.matches_pattern(endpoint, pattern))
207 }
208
209 fn is_endpoint_blocked(&self, endpoint: &str) -> bool {
211 self.policy
212 .blocked_endpoints
213 .iter()
214 .any(|pattern| self.matches_pattern(endpoint, pattern))
215 }
216
217 fn matches_pattern(&self, text: &str, pattern: &str) -> bool {
219 if pattern == "*" {
220 return true;
221 }
222
223 if pattern.ends_with('*') {
224 let prefix = &pattern[..pattern.len() - 1];
225 text.starts_with(prefix)
226 } else if pattern.starts_with('*') {
227 let suffix = &pattern[1..];
228 text.ends_with(suffix)
229 } else {
230 text == pattern
231 }
232 }
233
234 pub fn check_mfa(&self, has_mfa: bool, is_sensitive: bool) -> SecurityResult<()> {
236 if self.policy.require_mfa && is_sensitive && !has_mfa {
237 return Err(SecurityError::PolicyViolation(
238 "MFA is required for sensitive operations".to_string(),
239 ));
240 }
241 Ok(())
242 }
243
244 pub fn check_session(
246 &self,
247 created_at: chrono::DateTime<chrono::Utc>,
248 ) -> SecurityResult<()> {
249 let elapsed = chrono::Utc::now()
250 .signed_duration_since(created_at)
251 .num_seconds() as u64;
252
253 if elapsed > self.policy.session_timeout {
254 return Err(SecurityError::InvalidSession(
255 "Session expired".to_string(),
256 ));
257 }
258
259 Ok(())
260 }
261
262 pub fn check_data_classification(
264 &self,
265 classification: &DataClassification,
266 ) -> SecurityResult<()> {
267 if !self.policy.data_classifications.contains(classification) {
268 return Err(SecurityError::PolicyViolation(format!(
269 "Data classification {:?} is not allowed",
270 classification
271 )));
272 }
273 Ok(())
274 }
275
276 pub fn check_request(&self, context: &SecurityContext) -> SecurityResult<()> {
278 self.check_ip(&context.ip_address)?;
280
281 if let Some(ref session_id) = context.session_id {
283 if !session_id.is_empty() {
284 self.check_session(context.timestamp)?;
285 }
286 }
287
288 Ok(())
289 }
290
291 pub fn block_ip(&mut self, ip: String) {
293 self.blocked_ips.insert(ip.clone());
294 if !self.policy.blocked_ips.contains(&ip) {
295 self.policy.blocked_ips.push(ip);
296 }
297 }
298
299 pub fn unblock_ip(&mut self, ip: &str) {
301 self.blocked_ips.remove(ip);
302 self.policy.blocked_ips.retain(|x| x != ip);
303 }
304
305 pub fn get_policy(&self) -> &SecurityPolicy {
307 &self.policy
308 }
309
310 pub fn update_policy(&mut self, policy: SecurityPolicy) {
312 self.blocked_ips = policy.blocked_ips.iter().cloned().collect();
313 self.policy = policy;
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a customised policy in an enforcer.
    fn enforcer_for(policy: SecurityPolicy) -> PolicyEnforcer {
        PolicyEnforcer::new(policy)
    }

    /// Returns the instant `secs` seconds before now.
    fn seconds_ago(secs: i64) -> chrono::DateTime<chrono::Utc> {
        chrono::Utc::now() - chrono::Duration::seconds(secs)
    }

    /// A blocked address is rejected while its neighbours pass.
    #[test]
    fn test_ip_blocking() {
        let enforcer = enforcer_for(SecurityPolicy {
            blocked_ips: vec![String::from("192.168.1.100")],
            ..SecurityPolicy::default()
        });

        assert!(enforcer.check_ip("192.168.1.1").is_ok());
        assert!(enforcer.check_ip("192.168.1.100").is_err());
    }

    /// TLS must be present and at least the configured minimum version.
    #[test]
    fn test_tls_check() {
        let enforcer = enforcer_for(SecurityPolicy {
            require_tls: true,
            min_tls_version: String::from("1.2"),
            ..SecurityPolicy::default()
        });

        let cases = [
            (true, "1.2", true),
            (true, "1.3", true),
            (true, "1.1", false),
            (false, "1.2", false),
        ];
        for &(is_tls, version, accepted) in cases.iter() {
            assert!(enforcer.check_tls(is_tls, version).is_ok() == accepted);
        }
    }

    /// Only listed origins are accepted once the allow-list is non-empty.
    #[test]
    fn test_origin_check() {
        let enforcer = enforcer_for(SecurityPolicy {
            allowed_origins: vec![String::from("https://example.com")],
            ..SecurityPolicy::default()
        });

        let trusted = enforcer.check_origin("https://example.com");
        let untrusted = enforcer.check_origin("https://evil.com");

        assert!(trusted.is_ok());
        assert!(untrusted.is_err());
    }

    /// Requests above the configured size limit are rejected.
    #[test]
    fn test_request_size() {
        let enforcer = enforcer_for(SecurityPolicy {
            max_request_size: 1024,
            ..SecurityPolicy::default()
        });

        assert!(enforcer.check_request_size(512).is_ok());
        assert!(enforcer.check_request_size(2048).is_err());
    }

    /// Blocked patterns take precedence over allowed patterns.
    #[test]
    fn test_endpoint_patterns() {
        let enforcer = enforcer_for(SecurityPolicy {
            allowed_endpoints: vec![String::from("/api/*")],
            blocked_endpoints: vec![String::from("/api/admin/*")],
            ..SecurityPolicy::default()
        });

        assert!(enforcer.check_endpoint("/api/users").is_ok());
        assert!(enforcer.check_endpoint("/api/admin/users").is_err());
    }

    /// MFA is demanded only for sensitive operations.
    #[test]
    fn test_mfa_requirement() {
        let enforcer = enforcer_for(SecurityPolicy {
            require_mfa: true,
            ..SecurityPolicy::default()
        });

        let cases = [
            (true, true, true),
            (false, false, true),
            (false, true, false),
        ];
        for &(has_mfa, is_sensitive, accepted) in cases.iter() {
            assert!(enforcer.check_mfa(has_mfa, is_sensitive).is_ok() == accepted);
        }
    }

    /// Sessions older than the timeout are reported as expired.
    #[test]
    fn test_session_timeout() {
        let enforcer = enforcer_for(SecurityPolicy {
            session_timeout: 3600,
            ..SecurityPolicy::default()
        });

        // Half the timeout: still valid.
        assert!(enforcer.check_session(seconds_ago(1800)).is_ok());

        // Twice the timeout: expired.
        assert!(enforcer.check_session(seconds_ago(7200)).is_err());
    }

    /// Classifications missing from the policy are rejected.
    #[test]
    fn test_data_classification() {
        let permitted = vec![DataClassification::Public, DataClassification::Internal];
        let enforcer = enforcer_for(SecurityPolicy {
            data_classifications: permitted,
            ..SecurityPolicy::default()
        });

        let public = enforcer.check_data_classification(&DataClassification::Public);
        let secret = enforcer.check_data_classification(&DataClassification::Secret);

        assert!(public.is_ok());
        assert!(secret.is_err());
    }

    /// Addresses can be blocked and unblocked at runtime.
    #[test]
    fn test_dynamic_blocking() {
        let address = "10.0.0.1";
        let mut enforcer = PolicyEnforcer::default();

        enforcer.block_ip(address.to_string());
        assert!(enforcer.check_ip(address).is_err());

        enforcer.unblock_ip(address);
        assert!(enforcer.check_ip(address).is_ok());
    }

    /// A fresh context with a session passes the combined request check.
    #[test]
    fn test_comprehensive_check() {
        let context = SecurityContext::new("user123", "192.168.1.1").with_session("sess_abc");
        let enforcer = PolicyEnforcer::default();

        assert!(enforcer.check_request(&context).is_ok());
    }

    /// Prefix, suffix and literal patterns behave as documented.
    #[test]
    fn test_pattern_matching() {
        let enforcer = PolicyEnforcer::default();
        let endpoint = "/api/users";

        let matching = ["/api/*", "*/users", "/api/users"];
        for pattern in matching.iter() {
            assert!(enforcer.matches_pattern(endpoint, pattern));
        }

        assert!(!enforcer.matches_pattern(endpoint, "/admin/*"));
    }
}