// fr_rust/ddos/ddos.rs

use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    error::{ErrorForbidden, ErrorTooManyRequests},
    http::header,
    Error,
};
use futures_util::future::LocalBoxFuture;
use std::{
    collections::HashMap,
    future::ready,
    rc::Rc,
    sync::{Arc, PoisonError, RwLock},
    time::{Duration, Instant},
};
/// Per-client rate-limiting record stored in the shared IP table.
#[derive(Debug)]
struct IpStats {
    /// Requests counted in the current window.
    count: u32,
    /// When the current counting window began.
    window_start: Instant,
    /// While set and in the future, every request from this client is rejected.
    banned_until: Option<Instant>,
}
/// Tunable policy for [`DdosShield`].
#[derive(Clone, Debug)]
pub struct DdosConfig {
    /// Requests an IP may make within one window; the request that pushes the
    /// count above this value starts a ban.
    pub max_requests: u32,
    /// Length of the counting window, in seconds.
    pub window_secs: u64,
    /// How long a ban lasts once triggered, in seconds.
    pub ban_duration_secs: u64,
    /// Reject requests whose User-Agent header is absent or cannot be read as
    /// a string (both are treated as an empty User-Agent).
    pub block_missing_ua: bool,
    /// Lowercase substrings; a request whose lowercased User-Agent contains
    /// any of them is rejected.
    pub blocked_agents: Vec<String>,
}
29#[derive(Clone)]
30pub struct DdosShield {
31    config: DdosConfig,
32    ip_records: Arc<RwLock<HashMap<String, IpStats>>>,
33}
34impl DdosShield {
35    pub fn builder() -> Self {
36        Self {
37            config: DdosConfig {
38                max_requests: 50,
39                window_secs: 60,
40                ban_duration_secs: 86400,
41                block_missing_ua: false,
42                blocked_agents: vec!["curl".into()],
43            },
44            ip_records: Arc::new(RwLock::new(HashMap::new())),
45        }
46    }
47    pub fn max_requests(mut self, reqs: u32) -> Self {
48        self.config.max_requests = reqs;
49        self
50    }
51    pub fn window_secs(mut self, secs: u64) -> Self {
52        self.config.window_secs = secs;
53        self
54    }
55    pub fn ban_duration_secs(mut self, secs: u64) -> Self {
56        self.config.ban_duration_secs = secs;
57        self
58    }
59    pub fn block_agent(mut self, agent: &str) -> Self {
60        self.config.blocked_agents.push(agent.to_lowercase());
61        self
62    }
63    pub fn allow_missing_ua(mut self, allow: bool) -> Self {
64        self.config.block_missing_ua = !allow;
65        self
66    }
67    pub fn build(self) -> Self {
68        self
69    }
70}
71impl<S, B> Transform<S, ServiceRequest> for DdosShield
72where
73    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
74    S::Future: 'static,
75    B: 'static,
76{
77    type Response = ServiceResponse<B>;
78    type Error = Error;
79    type InitError = ();
80    type Transform = DdosShieldMiddleware<S>;
81    type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
82    fn new_transform(&self, service: S) -> Self::Future {
83        ready(Ok(DdosShieldMiddleware {
84            service: Rc::new(service),
85            config: self.config.clone(),
86            ip_records: self.ip_records.clone(),
87        }))
88    }
89}
90pub struct DdosShieldMiddleware<S> {
91    service: Rc<S>,
92    config: DdosConfig,
93    ip_records: Arc<RwLock<HashMap<String, IpStats>>>,
94}
95impl<S, B> Service<ServiceRequest> for DdosShieldMiddleware<S>
96where
97    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
98    S::Future: 'static,
99    B: 'static,
100{
101    type Response = ServiceResponse<B>;
102    type Error = Error;
103    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
104    forward_ready!(service);
105    fn call(&self, req: ServiceRequest) -> Self::Future {
106        let user_agent = req
107            .headers()
108            .get(header::USER_AGENT)
109            .and_then(|h| h.to_str().ok())
110            .unwrap_or("")
111            .to_lowercase();
112        if self.config.block_missing_ua && user_agent.is_empty() {
113            return Box::pin(ready(Err(ErrorForbidden("Blocked: Missing User-Agent"))));
114        }
115        if self.config.blocked_agents.iter().any(|bot| user_agent.contains(bot)) {
116            return Box::pin(ready(Err(ErrorForbidden("Blocked: Malicious Actor Detected"))));
117        }
118        let ip = req
119            .connection_info()
120            .realip_remote_addr()
121            .unwrap_or("unknown_ip")
122            .to_string();
123        let mut is_banned = false;
124        let mut triggered_ban = false;
125        {
126            let mut records = self.ip_records.write().unwrap();
127            let now = Instant::now();
128            let stats = records.entry(ip).or_insert(IpStats {
129                count: 0,
130                window_start: now,
131                banned_until: None,
132            });
133            if let Some(banned_time) = stats.banned_until {
134                if now < banned_time {
135                    is_banned = true;
136                } else {
137                    stats.banned_until = None;
138                    stats.count = 1;
139                    stats.window_start = now;
140                }
141            } else {
142                if now.duration_since(stats.window_start).as_secs() > self.config.window_secs {
143                    stats.count = 1;
144                    stats.window_start = now;
145                } else {
146                    stats.count += 1;
147                    if stats.count > self.config.max_requests {
148                        stats.banned_until = Some(now + Duration::from_secs(self.config.ban_duration_secs));
149                        triggered_ban = true;
150                        is_banned = true;
151                    }
152                }
153            }
154        }
155        if is_banned {
156            let msg = if triggered_ban {
157                "Rate limit exceeded. Your IP has been temporarily banned."
158            } else {
159                "Your IP is banned due to previous abuse."
160            };
161            return Box::pin(ready(Err(ErrorTooManyRequests(msg))));
162        }
163        let fut = self.service.call(req);
164        Box::pin(async move {
165            let res = fut.await?;
166            Ok(res)
167        })
168    }
169}