1use crate::api::{ApiResponse, ApiState};
6use axum::{
7 Json,
8 extract::{ConnectInfo, State},
9 http::{Request, StatusCode},
10 middleware::Next,
11 response::Response,
12};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::net::{IpAddr, SocketAddr};
16use std::sync::Arc;
17use std::time::{Duration, Instant};
18use tokio::sync::RwLock;
19
/// Fixed-window, per-IP request rate limit settings.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Requests allowed from one IP per window; the request that exceeds this
    /// count is rejected and starts a penalty.
    pub max_requests: u32,
    /// Length of the counting window. The count resets on the first request that
    /// arrives after a full window has elapsed since the window started.
    pub window_duration: Duration,
    /// How long an IP is refused after exceeding `max_requests`.
    pub penalty_duration: Duration,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window, with a 5-minute penalty.
    fn default() -> Self {
        let window_duration = Duration::from_secs(60);
        let penalty_duration = Duration::from_secs(5 * 60);
        Self {
            max_requests: 100,
            window_duration,
            penalty_duration,
        }
    }
}
40
/// Per-IP request-rate ceiling used to detect and block request floods.
#[derive(Debug, Clone)]
pub struct DosProtectionConfig {
    /// Maximum average requests per second, measured over `monitor_duration`,
    /// before the IP is blocked.
    pub max_rate: f64,
    /// Sliding window over which request timestamps are kept to compute the rate.
    /// A zero duration makes the computed rate infinite, so every request would
    /// exceed any finite `max_rate`.
    pub monitor_duration: Duration,
    /// How long an IP is refused once it exceeds `max_rate`.
    pub block_duration: Duration,
}

impl Default for DosProtectionConfig {
    /// 50 requests/second averaged over 10 seconds, with a 10-minute block.
    fn default() -> Self {
        let monitor_duration = Duration::from_secs(10);
        let block_duration = Duration::from_secs(10 * 60);
        Self {
            max_rate: 50.0,
            monitor_duration,
            block_duration,
        }
    }
}
61
/// Settings for automatically blacklisting IPs that accumulate failures.
#[derive(Debug, Clone)]
pub struct IpBlacklistConfig {
    /// How long an IP stays blacklisted once it reaches `max_failed_attempts`.
    pub blacklist_duration: Duration,
    /// Number of recorded failures within one `attempt_window` that triggers
    /// a blacklist (the threshold is inclusive).
    pub max_failed_attempts: u32,
    /// Window over which failures are counted. The count resets on the first
    /// failure recorded after the window has elapsed.
    pub attempt_window: Duration,
}

impl Default for IpBlacklistConfig {
    /// 10 failures within 5 minutes blacklists an IP for 1 hour.
    fn default() -> Self {
        let blacklist_duration = Duration::from_secs(60 * 60);
        let attempt_window = Duration::from_secs(5 * 60);
        Self {
            blacklist_duration,
            max_failed_attempts: 10,
            attempt_window,
        }
    }
}
82
/// Rate-limit bookkeeping for a single IP.
#[derive(Clone, Debug)]
struct RequestInfo {
    /// Requests counted in the current window.
    count: u32,
    /// Start of the current counting window.
    first_request: Instant,
    /// Time of the most recent counted request.
    last_request: Instant,
    /// End of the active penalty, if any.
    penalty_until: Option<Instant>,
}
91
/// DoS-detection bookkeeping for a single IP.
#[derive(Clone, Debug)]
struct DosInfo {
    /// Timestamps of recent requests, oldest first.
    request_times: Vec<Instant>,
    /// End of the active block, if any.
    blocked_until: Option<Instant>,
}
98
/// Failure-count bookkeeping for a single IP.
#[derive(Clone, Debug)]
struct FailureInfo {
    /// Failures counted in the current window.
    attempts: u32,
    /// Start of the current counting window.
    first_attempt: Instant,
    /// End of the automatic blacklist, if one is active.
    blacklisted_until: Option<Instant>,
}
106
107pub struct SecurityManager {
109 rate_limit_config: RateLimitConfig,
110 dos_config: DosProtectionConfig,
111 blacklist_config: IpBlacklistConfig,
112
113 rate_limits: Arc<RwLock<HashMap<IpAddr, RequestInfo>>>,
115
116 dos_tracking: Arc<RwLock<HashMap<IpAddr, DosInfo>>>,
118
119 failure_tracking: Arc<RwLock<HashMap<IpAddr, FailureInfo>>>,
121 manual_blacklist: Arc<RwLock<Vec<IpAddr>>>,
122}
123
124impl SecurityManager {
125 pub fn new() -> Self {
127 Self::with_config(
128 RateLimitConfig::default(),
129 DosProtectionConfig::default(),
130 IpBlacklistConfig::default(),
131 )
132 }
133
134 pub fn with_config(
136 rate_limit_config: RateLimitConfig,
137 dos_config: DosProtectionConfig,
138 blacklist_config: IpBlacklistConfig,
139 ) -> Self {
140 Self {
141 rate_limit_config,
142 dos_config,
143 blacklist_config,
144 rate_limits: Arc::new(RwLock::new(HashMap::new())),
145 dos_tracking: Arc::new(RwLock::new(HashMap::new())),
146 failure_tracking: Arc::new(RwLock::new(HashMap::new())),
147 manual_blacklist: Arc::new(RwLock::new(Vec::new())),
148 }
149 }
150
151 pub async fn check_rate_limit(&self, ip: IpAddr) -> bool {
153 let now = Instant::now();
154 let mut rate_limits = self.rate_limits.write().await;
155
156 rate_limits.retain(|_, info| {
158 now.duration_since(info.first_request) < self.rate_limit_config.window_duration * 2
159 });
160
161 let info = rate_limits.entry(ip).or_insert_with(|| RequestInfo {
162 count: 0,
163 first_request: now,
164 last_request: now,
165 penalty_until: None,
166 });
167
168 if let Some(penalty_until) = info.penalty_until {
170 if now < penalty_until {
171 return false; } else {
173 info.penalty_until = None; }
175 }
176
177 if now.duration_since(info.first_request) > self.rate_limit_config.window_duration {
179 info.count = 0;
180 info.first_request = now;
181 }
182
183 info.count += 1;
184 info.last_request = now;
185
186 if info.count > self.rate_limit_config.max_requests {
188 info.penalty_until = Some(now + self.rate_limit_config.penalty_duration);
189 return false;
190 }
191
192 true
193 }
194
195 pub async fn check_dos_protection(&self, ip: IpAddr) -> bool {
197 let now = Instant::now();
198 let mut dos_tracking = self.dos_tracking.write().await;
199
200 dos_tracking.retain(|_, info| {
202 if let Some(blocked_until) = info.blocked_until {
203 now < blocked_until
204 } else {
205 info.request_times.first().is_some_and(|first| {
206 now.duration_since(*first) < self.dos_config.monitor_duration * 2
207 })
208 }
209 });
210
211 let info = dos_tracking.entry(ip).or_insert_with(|| DosInfo {
212 request_times: Vec::new(),
213 blocked_until: None,
214 });
215
216 if let Some(blocked_until) = info.blocked_until {
218 if now < blocked_until {
219 return false; } else {
221 info.blocked_until = None; info.request_times.clear(); }
224 }
225
226 info.request_times.push(now);
228
229 info.request_times
231 .retain(|&time| now.duration_since(time) <= self.dos_config.monitor_duration);
232
233 let rate = info.request_times.len() as f64 / self.dos_config.monitor_duration.as_secs_f64();
235 if rate > self.dos_config.max_rate {
236 info.blocked_until = Some(now + self.dos_config.block_duration);
237 return false;
238 }
239
240 true
241 }
242
243 pub async fn check_blacklist(&self, ip: IpAddr) -> bool {
245 let manual_blacklist = self.manual_blacklist.read().await;
247 if manual_blacklist.contains(&ip) {
248 return false;
249 }
250 drop(manual_blacklist);
251
252 let now = Instant::now();
254 let mut failure_tracking = self.failure_tracking.write().await;
255
256 failure_tracking.retain(|_, info| {
258 if let Some(blacklisted_until) = info.blacklisted_until {
259 now < blacklisted_until
260 } else {
261 now.duration_since(info.first_attempt) < self.blacklist_config.attempt_window * 2
262 }
263 });
264
265 if let Some(info) = failure_tracking.get(&ip)
266 && let Some(blacklisted_until) = info.blacklisted_until
267 {
268 return now >= blacklisted_until;
269 }
270
271 true
272 }
273
274 pub async fn record_failure(&self, ip: IpAddr) {
276 let now = Instant::now();
277 let mut failure_tracking = self.failure_tracking.write().await;
278
279 let info = failure_tracking.entry(ip).or_insert_with(|| FailureInfo {
280 attempts: 0,
281 first_attempt: now,
282 blacklisted_until: None,
283 });
284
285 if now.duration_since(info.first_attempt) > self.blacklist_config.attempt_window {
287 info.attempts = 0;
288 info.first_attempt = now;
289 }
290
291 info.attempts += 1;
292
293 if info.attempts >= self.blacklist_config.max_failed_attempts {
295 info.blacklisted_until = Some(now + self.blacklist_config.blacklist_duration);
296 }
297 }
298
299 pub async fn add_to_blacklist(&self, ip: IpAddr) {
301 let mut manual_blacklist = self.manual_blacklist.write().await;
302 if !manual_blacklist.contains(&ip) {
303 manual_blacklist.push(ip);
304 }
305 }
306
307 pub async fn remove_from_blacklist(&self, ip: IpAddr) {
309 let mut manual_blacklist = self.manual_blacklist.write().await;
310 manual_blacklist.retain(|&x| x != ip);
311 }
312
313 pub async fn get_stats(&self) -> SecurityStats {
315 let rate_limits = self.rate_limits.read().await;
316 let dos_tracking = self.dos_tracking.read().await;
317 let failure_tracking = self.failure_tracking.read().await;
318 let manual_blacklist = self.manual_blacklist.read().await;
319
320 let now = Instant::now();
321
322 SecurityStats {
323 total_rate_limited_ips: rate_limits.len(),
324 currently_penalized_ips: rate_limits
325 .values()
326 .filter(|info| info.penalty_until.is_some_and(|until| now < until))
327 .count(),
328 total_dos_tracked_ips: dos_tracking.len(),
329 currently_blocked_ips: dos_tracking
330 .values()
331 .filter(|info| info.blocked_until.is_some_and(|until| now < until))
332 .count(),
333 total_failure_tracked_ips: failure_tracking.len(),
334 currently_blacklisted_ips: failure_tracking
335 .values()
336 .filter(|info| info.blacklisted_until.is_some_and(|until| now < until))
337 .count()
338 + manual_blacklist.len(),
339 manual_blacklist_size: manual_blacklist.len(),
340 }
341 }
342}
343
344impl Default for SecurityManager {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
350#[derive(Debug, Serialize)]
352pub struct SecurityStats {
353 pub total_rate_limited_ips: usize,
354 pub currently_penalized_ips: usize,
355 pub total_dos_tracked_ips: usize,
356 pub currently_blocked_ips: usize,
357 pub total_failure_tracked_ips: usize,
358 pub currently_blacklisted_ips: usize,
359 pub manual_blacklist_size: usize,
360}
361
362pub async fn security_middleware(
364 ConnectInfo(addr): ConnectInfo<SocketAddr>,
365 State(state): State<ApiState>,
366 request: Request<axum::body::Body>,
367 next: Next,
368) -> Result<Response, StatusCode> {
369 let ip = addr.ip();
370
371 let security_manager = match state.auth_framework.security_manager() {
373 Some(manager) => manager,
374 None => return Ok(next.run(request).await), };
376
377 if !security_manager.check_blacklist(ip).await {
379 return Err(StatusCode::FORBIDDEN);
380 }
381
382 if !security_manager.check_dos_protection(ip).await {
384 return Err(StatusCode::TOO_MANY_REQUESTS);
385 }
386
387 if !security_manager.check_rate_limit(ip).await {
389 return Err(StatusCode::TOO_MANY_REQUESTS);
390 }
391
392 Ok(next.run(request).await)
394}
395
396pub async fn get_security_stats(
402 State(state): State<ApiState>,
403) -> Result<Json<ApiResponse<SecurityStats>>, StatusCode> {
404 let security_manager = match state.auth_framework.security_manager() {
405 Some(manager) => manager,
406 None => return Err(StatusCode::SERVICE_UNAVAILABLE),
407 };
408
409 let stats = security_manager.get_stats().await;
410 Ok(Json(ApiResponse::success(stats)))
411}
412
413#[derive(Debug, Deserialize)]
415pub struct BlacklistRequest {
416 pub ip: IpAddr,
417}
418
419pub async fn add_to_blacklist(
421 State(state): State<ApiState>,
422 Json(request): Json<BlacklistRequest>,
423) -> Result<Json<ApiResponse<()>>, StatusCode> {
424 let security_manager = match state.auth_framework.security_manager() {
425 Some(manager) => manager,
426 None => return Err(StatusCode::SERVICE_UNAVAILABLE),
427 };
428
429 security_manager.add_to_blacklist(request.ip).await;
430 Ok(Json(ApiResponse::success_with_message(
431 (),
432 "IP added to blacklist",
433 )))
434}
435
436pub async fn remove_from_blacklist(
438 State(state): State<ApiState>,
439 Json(request): Json<BlacklistRequest>,
440) -> Result<Json<ApiResponse<()>>, StatusCode> {
441 let security_manager = match state.auth_framework.security_manager() {
442 Some(manager) => manager,
443 None => return Err(StatusCode::SERVICE_UNAVAILABLE),
444 };
445
446 security_manager.remove_from_blacklist(request.ip).await;
447 Ok(Json(ApiResponse::success_with_message(
448 (),
449 "IP removed from blacklist",
450 )))
451}