// auth_framework/api/security.rs

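//! Security Features - Rate Limiting, DoS Protection, IP Blacklisting
//!
//! Per-IP protections for the API: fixed-window rate limiting, request-rate
//! based DoS blocking, and automatic (failure-driven) or manual IP blacklisting.
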
use crate::api::{ApiResponse, ApiState};
use axum::{
    Json,
    extract::{ConnectInfo, State},
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Rate limiter configuration.
///
/// Defaults: at most 100 requests per 60-second window; exceeding the limit
/// triggers a 5-minute penalty.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests allowed within one window.
    pub max_requests: u32,
    /// Length of the fixed counting window.
    pub window_duration: Duration,
    /// How long an IP is rejected after exceeding the limit.
    pub penalty_duration: Duration,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        let window_duration = Duration::from_secs(60);
        Self {
            max_requests: 100,
            penalty_duration: window_duration * 5, // 5 minutes
            window_duration,
        }
    }
}

/// DoS protection configuration.
///
/// Defaults: a threshold of 50 requests per second, measured over a
/// 10-second monitoring window, with offenders blocked for 10 minutes.
#[derive(Debug, Clone)]
pub struct DosProtectionConfig {
    /// Maximum request rate (requests per second) before triggering protection.
    pub max_rate: f64,
    /// Sliding window over which the request rate is measured.
    pub monitor_duration: Duration,
    /// How long a suspected DoS source is blocked.
    pub block_duration: Duration,
}

impl Default for DosProtectionConfig {
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        Self {
            max_rate: 50.0,
            monitor_duration: Duration::from_secs(10),
            block_duration: Duration::from_secs(10 * MINUTE_SECS),
        }
    }
}

/// IP blacklist configuration.
///
/// Defaults: 10 failed attempts within 5 minutes blacklists an IP for 1 hour.
#[derive(Debug, Clone)]
pub struct IpBlacklistConfig {
    /// How long an automatically blacklisted IP stays blocked.
    pub blacklist_duration: Duration,
    /// Number of failed attempts within `attempt_window` that triggers blacklisting.
    pub max_failed_attempts: u32,
    /// Window over which failed attempts are counted.
    pub attempt_window: Duration,
}

impl Default for IpBlacklistConfig {
    fn default() -> Self {
        let minutes = |n: u64| Duration::from_secs(n * 60);
        Self {
            blacklist_duration: minutes(60),
            max_failed_attempts: 10,
            attempt_window: minutes(5),
        }
    }
}

/// Per-IP state for the fixed-window rate limiter.
#[derive(Debug, Clone)]
struct RequestInfo {
    /// Requests counted in the current window.
    count: u32,
    /// Start of the current window.
    first_request: Instant,
    /// Time of the most recently counted request.
    last_request: Instant,
    /// End of the active penalty, if the IP exceeded the limit.
    penalty_until: Option<Instant>,
}

/// Per-IP state for DoS detection.
#[derive(Debug, Clone)]
struct DosInfo {
    /// Timestamps of recent requests, oldest first.
    request_times: Vec<Instant>,
    /// End of the active block, if the IP tripped the DoS threshold.
    blocked_until: Option<Instant>,
}

/// Per-IP state for automatic blacklisting based on failed attempts.
#[derive(Debug, Clone)]
struct FailureInfo {
    /// Failed attempts counted in the current window.
    attempts: u32,
    /// Start of the current counting window.
    first_attempt: Instant,
    /// End of the automatic blacklist, if one is in force.
    blacklisted_until: Option<Instant>,
}

107/// Security manager for handling rate limiting, DoS protection, and IP blacklisting
108pub struct SecurityManager {
109    rate_limit_config: RateLimitConfig,
110    dos_config: DosProtectionConfig,
111    blacklist_config: IpBlacklistConfig,
112
113    // Rate limiting state
114    rate_limits: Arc<RwLock<HashMap<IpAddr, RequestInfo>>>,
115
116    // DoS protection state
117    dos_tracking: Arc<RwLock<HashMap<IpAddr, DosInfo>>>,
118
119    // IP blacklisting state
120    failure_tracking: Arc<RwLock<HashMap<IpAddr, FailureInfo>>>,
121    manual_blacklist: Arc<RwLock<Vec<IpAddr>>>,
122}
123
124impl SecurityManager {
125    /// Create a new security manager with default configuration
126    pub fn new() -> Self {
127        Self::with_config(
128            RateLimitConfig::default(),
129            DosProtectionConfig::default(),
130            IpBlacklistConfig::default(),
131        )
132    }
133
134    /// Create a new security manager with custom configuration
135    pub fn with_config(
136        rate_limit_config: RateLimitConfig,
137        dos_config: DosProtectionConfig,
138        blacklist_config: IpBlacklistConfig,
139    ) -> Self {
140        Self {
141            rate_limit_config,
142            dos_config,
143            blacklist_config,
144            rate_limits: Arc::new(RwLock::new(HashMap::new())),
145            dos_tracking: Arc::new(RwLock::new(HashMap::new())),
146            failure_tracking: Arc::new(RwLock::new(HashMap::new())),
147            manual_blacklist: Arc::new(RwLock::new(Vec::new())),
148        }
149    }
150
151    /// Check if IP is rate limited
152    pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
153        let now = Instant::now();
154        let mut rate_limits = self.rate_limits.write().await;
155
156        // Clean expired entries
157        rate_limits.retain(|_, info| {
158            now.duration_since(info.first_request) < self.rate_limit_config.window_duration * 2
159        });
160
161        let info = rate_limits.entry(ip).or_insert_with(|| RequestInfo {
162            count: 0,
163            first_request: now,
164            last_request: now,
165            penalty_until: None,
166        });
167
168        // Check if still under penalty
169        if let Some(penalty_until) = info.penalty_until {
170            if now < penalty_until {
171                return false; // Still penalized
172            } else {
173                info.penalty_until = None; // Clear expired penalty
174            }
175        }
176
177        // Reset window if needed
178        if now.duration_since(info.first_request) > self.rate_limit_config.window_duration {
179            info.count = 0;
180            info.first_request = now;
181        }
182
183        info.count += 1;
184        info.last_request = now;
185
186        // Check if limit exceeded
187        if info.count > self.rate_limit_config.max_requests {
188            info.penalty_until = Some(now + self.rate_limit_config.penalty_duration);
189            return false;
190        }
191
192        true
193    }
194
195    /// Check for DoS attacks
196    pub async fn check_dos_protection(&self, ip: IpAddr) -> bool {
197        let now = Instant::now();
198        let mut dos_tracking = self.dos_tracking.write().await;
199
200        // Clean expired entries
201        dos_tracking.retain(|_, info| {
202            if let Some(blocked_until) = info.blocked_until {
203                now < blocked_until
204            } else {
205                info.request_times.first().is_some_and(|first| {
206                    now.duration_since(*first) < self.dos_config.monitor_duration * 2
207                })
208            }
209        });
210
211        let info = dos_tracking.entry(ip).or_insert_with(|| DosInfo {
212            request_times: Vec::new(),
213            blocked_until: None,
214        });
215
216        // Check if still blocked
217        if let Some(blocked_until) = info.blocked_until {
218            if now < blocked_until {
219                return false; // Still blocked
220            } else {
221                info.blocked_until = None; // Clear expired block
222                info.request_times.clear(); // Reset tracking
223            }
224        }
225
226        // Add current request
227        info.request_times.push(now);
228
229        // Remove old requests outside monitor window
230        info.request_times
231            .retain(|&time| now.duration_since(time) <= self.dos_config.monitor_duration);
232
233        // Check if DoS threshold exceeded
234        let rate = info.request_times.len() as f64 / self.dos_config.monitor_duration.as_secs_f64();
235        if rate > self.dos_config.max_rate {
236            info.blocked_until = Some(now + self.dos_config.block_duration);
237            return false;
238        }
239
240        true
241    }
242
243    /// Check if IP is blacklisted
244    pub async fn check_blacklist(&self, ip: IpAddr) -> bool {
245        // Check manual blacklist
246        let manual_blacklist = self.manual_blacklist.read().await;
247        if manual_blacklist.contains(&ip) {
248            return false;
249        }
250        drop(manual_blacklist);
251
252        // Check automatic blacklist
253        let now = Instant::now();
254        let mut failure_tracking = self.failure_tracking.write().await;
255
256        // Clean expired entries
257        failure_tracking.retain(|_, info| {
258            if let Some(blacklisted_until) = info.blacklisted_until {
259                now < blacklisted_until
260            } else {
261                now.duration_since(info.first_attempt) < self.blacklist_config.attempt_window * 2
262            }
263        });
264
265        if let Some(info) = failure_tracking.get(&ip)
266            && let Some(blacklisted_until) = info.blacklisted_until
267        {
268            return now >= blacklisted_until;
269        }
270
271        true
272    }
273
274    /// Record a failed authentication attempt
275    pub async fn record_failure(&self, ip: IpAddr) {
276        let now = Instant::now();
277        let mut failure_tracking = self.failure_tracking.write().await;
278
279        let info = failure_tracking.entry(ip).or_insert_with(|| FailureInfo {
280            attempts: 0,
281            first_attempt: now,
282            blacklisted_until: None,
283        });
284
285        // Reset window if needed
286        if now.duration_since(info.first_attempt) > self.blacklist_config.attempt_window {
287            info.attempts = 0;
288            info.first_attempt = now;
289        }
290
291        info.attempts += 1;
292
293        // Check if should blacklist
294        if info.attempts >= self.blacklist_config.max_failed_attempts {
295            info.blacklisted_until = Some(now + self.blacklist_config.blacklist_duration);
296        }
297    }
298
299    /// Manually add IP to blacklist
300    pub async fn add_to_blacklist(&self, ip: IpAddr) {
301        let mut manual_blacklist = self.manual_blacklist.write().await;
302        if !manual_blacklist.contains(&ip) {
303            manual_blacklist.push(ip);
304        }
305    }
306
307    /// Remove IP from manual blacklist
308    pub async fn remove_from_blacklist(&self, ip: IpAddr) {
309        let mut manual_blacklist = self.manual_blacklist.write().await;
310        manual_blacklist.retain(|&x| x != ip);
311    }
312
313    /// Get security statistics
314    pub async fn get_stats(&self) -> SecurityStats {
315        let rate_limits = self.rate_limits.read().await;
316        let dos_tracking = self.dos_tracking.read().await;
317        let failure_tracking = self.failure_tracking.read().await;
318        let manual_blacklist = self.manual_blacklist.read().await;
319
320        let now = Instant::now();
321
322        SecurityStats {
323            total_rate_limited_ips: rate_limits.len(),
324            currently_penalized_ips: rate_limits
325                .values()
326                .filter(|info| info.penalty_until.is_some_and(|until| now < until))
327                .count(),
328            total_dos_tracked_ips: dos_tracking.len(),
329            currently_blocked_ips: dos_tracking
330                .values()
331                .filter(|info| info.blocked_until.is_some_and(|until| now < until))
332                .count(),
333            total_failure_tracked_ips: failure_tracking.len(),
334            currently_blacklisted_ips: failure_tracking
335                .values()
336                .filter(|info| info.blacklisted_until.is_some_and(|until| now < until))
337                .count()
338                + manual_blacklist.len(),
339            manual_blacklist_size: manual_blacklist.len(),
340        }
341    }
342}
343
344impl Default for SecurityManager {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350/// Security statistics
351#[derive(Debug, Serialize)]
352pub struct SecurityStats {
353    pub total_rate_limited_ips: usize,
354    pub currently_penalized_ips: usize,
355    pub total_dos_tracked_ips: usize,
356    pub currently_blocked_ips: usize,
357    pub total_failure_tracked_ips: usize,
358    pub currently_blacklisted_ips: usize,
359    pub manual_blacklist_size: usize,
360}
361
362/// Security middleware
363pub async fn security_middleware(
364    ConnectInfo(addr): ConnectInfo<SocketAddr>,
365    State(state): State<ApiState>,
366    request: Request<axum::body::Body>,
367    next: Next,
368) -> Result<Response, StatusCode> {
369    let ip = addr.ip();
370
371    // Get security manager
372    let security_manager = match state.auth_framework.security_manager() {
373        Some(manager) => manager,
374        None => return Ok(next.run(request).await), // No security manager, allow request
375    };
376
377    // Check blacklist first
378    if !security_manager.check_blacklist(ip).await {
379        return Err(StatusCode::FORBIDDEN);
380    }
381
382    // Check DoS protection
383    if !security_manager.check_dos_protection(ip).await {
384        return Err(StatusCode::TOO_MANY_REQUESTS);
385    }
386
387    // Check rate limiting
388    if !security_manager.check_rate_limit(ip).await {
389        return Err(StatusCode::TOO_MANY_REQUESTS);
390    }
391
392    // Allow request
393    Ok(next.run(request).await)
394}
395
// ============================================================================
// API Endpoints
// ============================================================================

400/// Get security statistics
401pub async fn get_security_stats(
402    State(state): State<ApiState>,
403) -> Result<Json<ApiResponse<SecurityStats>>, StatusCode> {
404    let security_manager = match state.auth_framework.security_manager() {
405        Some(manager) => manager,
406        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
407    };
408
409    let stats = security_manager.get_stats().await;
410    Ok(Json(ApiResponse::success(stats)))
411}
412
413/// Blacklist management request
414#[derive(Debug, Deserialize)]
415pub struct BlacklistRequest {
416    pub ip: IpAddr,
417}
418
419/// Add IP to blacklist
420pub async fn add_to_blacklist(
421    State(state): State<ApiState>,
422    Json(request): Json<BlacklistRequest>,
423) -> Result<Json<ApiResponse<()>>, StatusCode> {
424    let security_manager = match state.auth_framework.security_manager() {
425        Some(manager) => manager,
426        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
427    };
428
429    security_manager.add_to_blacklist(request.ip).await;
430    Ok(Json(ApiResponse::success_with_message(
431        (),
432        "IP added to blacklist",
433    )))
434}
435
436/// Remove IP from blacklist
437pub async fn remove_from_blacklist(
438    State(state): State<ApiState>,
439    Json(request): Json<BlacklistRequest>,
440) -> Result<Json<ApiResponse<()>>, StatusCode> {
441    let security_manager = match state.auth_framework.security_manager() {
442        Some(manager) => manager,
443        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
444    };
445
446    security_manager.remove_from_blacklist(request.ip).await;
447    Ok(Json(ApiResponse::success_with_message(
448        (),
449        "IP removed from blacklist",
450    )))
451}