Skip to main content

fr_rust/ddos/
ddos.rs

1use actix_web::{
2    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
3    error::{ErrorForbidden, ErrorTooManyRequests},
4    http::header,
5    rt, Error,
6};
7use futures_util::future::LocalBoxFuture;
8use std::{
9    collections::HashMap,
10    future::ready,
11    rc::Rc,
12    sync::{Arc, Mutex},
13    time::{Duration, Instant},
14};
15use tokio::time;
16
/// Per-client bookkeeping for the fixed-window rate limiter.
#[derive(Debug, Clone)]
struct IpStats {
    /// Number of requests counted in the current window.
    count: u32,
    /// Instant at which the current window opened.
    window_start: Instant,
    /// End of the most recent ban; `None` if no ban is recorded.
    banned_until: Option<Instant>,
}

impl IpStats {
    /// Returns `true` once at least `window_secs` seconds have passed since
    /// the window opened.
    ///
    /// The comparison is done on full `Duration`s: truncating to whole seconds
    /// and comparing with `>` would keep the window open for up to one extra
    /// second. A `now` earlier than `window_start` counts as zero elapsed time.
    fn is_expired(&self, now: Instant, window_secs: u64) -> bool {
        now.saturating_duration_since(self.window_start) >= Duration::from_secs(window_secs)
    }

    /// Returns `true` while a recorded ban has not yet run out at `now`.
    fn is_banned(&self, now: Instant) -> bool {
        matches!(self.banned_until, Some(until) if now < until)
    }

    /// Opens a new window at `now`, counting the current request as its first.
    fn reset_window(&mut self, now: Instant) {
        self.count = 1;
        self.window_start = now;
    }

    /// Counts one more request in the current window.
    ///
    /// Saturates instead of overflowing, so a limit of `u32::MAX` can never
    /// wrap the counter back to zero (or panic in debug builds).
    fn increment(&mut self) {
        self.count = self.count.saturating_add(1);
    }

    /// Bans the client for `ban_duration`, measured from the current instant.
    ///
    /// `Instant + Duration` panics when the result is not representable, so
    /// the duration is halved until the addition succeeds. Very large values
    /// (e.g. `u64::MAX` seconds) therefore become the longest representable
    /// ban instead of a panic.
    fn ban(&mut self, ban_duration: Duration) {
        let now = Instant::now();
        let mut span = ban_duration;
        let until = loop {
            match now.checked_add(span) {
                Some(deadline) => break deadline,
                // Terminates: a zero span is always representable.
                None => span /= 2,
            }
        };
        self.banned_until = Some(until);
    }

    /// Forgets any recorded ban.
    fn clear_ban(&mut self) {
        self.banned_until = None;
    }
}
50
/// Tunable limits for [`DdosShield`].
///
/// `DdosShield` stores this behind an `Arc`, so cloning the shield does not
/// copy `blocked_agents`.
#[derive(Debug, Clone)]
pub struct DdosConfig {
    /// Requests allowed per window; a client is banned once its count for the
    /// current window exceeds this value.
    pub max_requests: u32,
    /// Length of the counting window, in seconds.
    pub window_secs: u64,
    /// How long a ban lasts, in seconds.
    pub ban_duration_secs: u64,
    /// When `true`, requests with a missing or empty User-Agent are rejected.
    pub block_missing_ua: bool,
    /// Substrings that cause a request to be rejected when found in its
    /// User-Agent. The header is lowercased before matching, so entries are
    /// expected to be lowercase.
    pub blocked_agents: Vec<String>,
    /// Seconds between runs of the background cleanup task.
    pub cleanup_interval_secs: u64,
    /// Upper bound on tracked clients, enforced when the cleanup task runs.
    pub max_ip_records: usize,
}

impl Default for DdosConfig {
    /// 50 requests per 60-second window, 24-hour bans, no User-Agent
    /// filtering, cleanup every 5 minutes, at most 10 000 tracked clients.
    fn default() -> Self {
        Self {
            max_requests: 50,
            window_secs: 60,
            ban_duration_secs: 86400,
            block_missing_ua: false,
            blocked_agents: Vec::new(),
            cleanup_interval_secs: 300,
            max_ip_records: 10000,
        }
    }
}
75
76#[derive(Clone)]
77pub struct DdosShield {
78    config: Arc<DdosConfig>,
79    // Switched to std::sync::Mutex for faster, synchronous non-blocking updates
80    ip_records: Arc<Mutex<HashMap<String, IpStats>>>, 
81}
82
83impl DdosShield {
84    pub fn new() -> Self {
85        let shield = Self {
86            config: Arc::new(DdosConfig::default()),
87            ip_records: Arc::new(Mutex::new(HashMap::with_capacity(1024))),
88        };
89        shield.start_cleanup_task();
90        shield
91    }
92
93    pub fn builder() -> DdosShieldBuilder {
94        DdosShieldBuilder::default()
95    }
96
97    fn start_cleanup_task(&self) {
98        let ip_records = self.ip_records.clone();
99        let config = self.config.clone();
100
101        // Fixed actix_rt -> actix_web::rt
102        rt::spawn(async move {
103            let mut interval = time::interval(Duration::from_secs(config.cleanup_interval_secs));
104            loop {
105                interval.tick().await;
106                // We no longer need .await here because we are using a standard Mutex
107                Self::cleanup_old_records(&ip_records, &config);
108            }
109        });
110    }
111
112    // This is now synchronous and heavily optimized
113    fn cleanup_old_records(ip_records: &Arc<Mutex<HashMap<String, IpStats>>>, config: &DdosConfig) {
114        let mut records = ip_records.lock().unwrap();
115        let now = Instant::now();
116        let ban_duration = Duration::from_secs(config.ban_duration_secs);
117        let window_duration = Duration::from_secs(config.window_secs);
118
119        // O(N) cleanup in-place without cloning strings
120        records.retain(|_, stats| {
121            let ban_expired = stats.banned_until.map_or(false, |until| now >= until);
122            let window_expired = now.duration_since(stats.window_start) > window_duration + ban_duration;
123            !((ban_expired || stats.banned_until.is_none()) && window_expired)
124        });
125
126        // Fast O(N) enforcement: If over capacity, arbitrarily drop elements to prevent OOM
127        // Much faster than sorting the entire map by timestamp inside a lock
128        if records.len() > config.max_ip_records {
129            let overage = records.len() - config.max_ip_records;
130            let keys_to_remove: Vec<String> = records.keys().take(overage).cloned().collect();
131            for key in keys_to_remove {
132                records.remove(&key);
133            }
134        }
135    }
136
137    fn check_user_agent(&self, req: &ServiceRequest) -> Result<(), Error> {
138        let user_agent = req
139            .headers()
140            .get(header::USER_AGENT)
141            .and_then(|h| h.to_str().ok())
142            .unwrap_or("")
143            .to_lowercase();
144
145        if self.config.block_missing_ua && user_agent.is_empty() {
146            return Err(ErrorForbidden("Blocked: Missing User-Agent"));
147        }
148
149        if self.config.blocked_agents.iter().any(|bot| user_agent.contains(bot)) {
150            return Err(ErrorForbidden("Blocked: Malicious Actor Detected"));
151        }
152
153        Ok(())
154    }
155
156    fn check_rate_limit(&self, ip: &str) -> Result<(), String> {
157        // Fast, synchronous locking
158        let mut records = self.ip_records.lock().unwrap();
159        let now = Instant::now();
160
161        let stats = records
162            .entry(ip.to_string())
163            .or_insert_with(|| IpStats {
164                count: 0,
165                window_start: now,
166                banned_until: None,
167            });
168
169        if stats.is_banned(now) {
170            return Err("Your IP is banned due to previous abuse.".to_string());
171        }
172
173        if stats.banned_until.is_some() {
174            stats.clear_ban();
175            stats.reset_window(now);
176            return Ok(());
177        }
178
179        if stats.is_expired(now, self.config.window_secs) {
180            stats.reset_window(now);
181            return Ok(());
182        }
183
184        stats.increment();
185
186        if stats.count > self.config.max_requests {
187            stats.ban(Duration::from_secs(self.config.ban_duration_secs));
188            Err("Rate limit exceeded. Your IP has been temporarily banned.".to_string())
189        } else {
190            Ok(())
191        }
192    }
193}
194
195impl Default for DdosShield {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201#[derive(Default)]
202pub struct DdosShieldBuilder {
203    config: DdosConfig,
204}
205
206impl DdosShieldBuilder {
207    // ... (Keep builder methods exactly the same as your original) ...
208    pub fn max_requests(mut self, reqs: u32) -> Self { self.config.max_requests = reqs; self }
209    pub fn window_secs(mut self, secs: u64) -> Self { self.config.window_secs = secs; self }
210    pub fn ban_duration_secs(mut self, secs: u64) -> Self { self.config.ban_duration_secs = secs; self }
211    pub fn block_agent(mut self, agent: &str) -> Self { self.config.blocked_agents.push(agent.to_lowercase()); self }
212    pub fn allow_missing_ua(mut self, allow: bool) -> Self { self.config.block_missing_ua = !allow; self }
213    pub fn cleanup_interval_secs(mut self, secs: u64) -> Self { self.config.cleanup_interval_secs = secs; self }
214    pub fn max_ip_records(mut self, max: usize) -> Self { self.config.max_ip_records = max; self }
215
216    pub fn build(self) -> DdosShield {
217        let shield = DdosShield {
218            config: Arc::new(self.config),
219            ip_records: Arc::new(Mutex::new(HashMap::with_capacity(1024))),
220        };
221        shield.start_cleanup_task();
222        shield
223    }
224}
225
226impl<S, B> Transform<S, ServiceRequest> for DdosShield
227where
228    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static, // Removed Clone requirement
229    S::Future: 'static,
230    B: 'static,
231{
232    type Response = ServiceResponse<B>;
233    type Error = Error;
234    type InitError = ();
235    type Transform = DdosShieldMiddleware<S>;
236    type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
237
238    fn new_transform(&self, service: S) -> Self::Future {
239        ready(Ok(DdosShieldMiddleware {
240            service: Rc::new(service), // Wrap in Rc here
241            shield: self.clone(),
242        }))
243    }
244}
245
246pub struct DdosShieldMiddleware<S> {
247    service: Rc<S>, // Rc instead of plain S
248    shield: DdosShield,
249}
250
251impl<S, B> Service<ServiceRequest> for DdosShieldMiddleware<S>
252where
253    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
254    S::Future: 'static,
255    B: 'static,
256{
257    type Response = ServiceResponse<B>;
258    type Error = Error;
259    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
260
261    forward_ready!(service);
262
263    fn call(&self, req: ServiceRequest) -> Self::Future {
264        let shield = self.shield.clone();
265        let service = Rc::clone(&self.service); // Cheap pointer clone
266
267        Box::pin(async move {
268            // Check UA synchronously
269            if let Err(err) = shield.check_user_agent(&req) {
270                return Err(err);
271            }
272
273            // Warning: ensure your app is behind a proxy and configured to trust it, 
274            // otherwise users can spoof this header.
275            let ip = req
276                .connection_info()
277                .realip_remote_addr()
278                .unwrap_or("unknown_ip")
279                .to_string();
280
281            // Check limits synchronously
282            match shield.check_rate_limit(&ip) {
283                Ok(()) => service.call(req).await,
284                Err(msg) => Err(ErrorTooManyRequests(msg)),
285            }
286        })
287    }
288}