1use actix_web::{
2 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
3 error::{ErrorForbidden, ErrorTooManyRequests},
4 http::header,
5 rt, Error,
6};
7use futures_util::future::LocalBoxFuture;
8use std::{
9 collections::HashMap,
10 future::ready,
11 rc::Rc,
12 sync::{Arc, Mutex},
13 time::{Duration, Instant},
14};
15use tokio::time;
16
/// Per-IP bookkeeping for the fixed-window rate limiter.
#[derive(Clone, Debug)]
struct IpStats {
    /// Requests counted in the current window.
    count: u32,
    /// Moment the current counting window opened.
    window_start: Instant,
    /// End of the most recent ban, if one was issued and has not been cleared yet.
    banned_until: Option<Instant>,
}
23
24impl IpStats {
25 fn is_expired(&self, now: Instant, window_secs: u64) -> bool {
26 now.duration_since(self.window_start).as_secs() > window_secs
27 }
28
29 fn is_banned(&self, now: Instant) -> bool {
30 matches!(self.banned_until, Some(until) if now < until)
31 }
32
33 fn reset_window(&mut self, now: Instant) {
34 self.count = 1;
35 self.window_start = now;
36 }
37
38 fn increment(&mut self) {
39 self.count += 1;
40 }
41
42 fn ban(&mut self, ban_duration: Duration) {
43 self.banned_until = Some(Instant::now() + ban_duration);
44 }
45
46 fn clear_ban(&mut self) {
47 self.banned_until = None;
48 }
49}
50
/// Tunable thresholds for [`DdosShield`].
#[derive(Debug, Clone)]
pub struct DdosConfig {
    /// Requests an IP may make within one window; the request that exceeds this bans the IP.
    pub max_requests: u32,
    /// Length of the counting window, in seconds.
    pub window_secs: u64,
    /// How long a ban lasts, in seconds.
    pub ban_duration_secs: u64,
    /// When `true`, requests without a usable `User-Agent` header are rejected.
    pub block_missing_ua: bool,
    /// Substrings matched against the lower-cased User-Agent; a match rejects the request.
    /// Entries should be lower-case (the builder lowercases them).
    pub blocked_agents: Vec<String>,
    /// Seconds between background cleanup passes.
    pub cleanup_interval_secs: u64,
    /// Soft cap on the number of tracked IPs, enforced during cleanup passes.
    pub max_ip_records: usize,
}
61
62impl Default for DdosConfig {
63 fn default() -> Self {
64 Self {
65 max_requests: 50,
66 window_secs: 60,
67 ban_duration_secs: 86400,
68 block_missing_ua: false,
69 blocked_agents: vec![],
70 cleanup_interval_secs: 300,
71 max_ip_records: 10000,
72 }
73 }
74}
75
76#[derive(Clone)]
77pub struct DdosShield {
78 config: Arc<DdosConfig>,
79 ip_records: Arc<Mutex<HashMap<String, IpStats>>>,
81}
82
83impl DdosShield {
84 pub fn new() -> Self {
85 let shield = Self {
86 config: Arc::new(DdosConfig::default()),
87 ip_records: Arc::new(Mutex::new(HashMap::with_capacity(1024))),
88 };
89 shield.start_cleanup_task();
90 shield
91 }
92
93 pub fn builder() -> DdosShieldBuilder {
94 DdosShieldBuilder::default()
95 }
96
97 fn start_cleanup_task(&self) {
98 let ip_records = self.ip_records.clone();
99 let config = self.config.clone();
100
101 rt::spawn(async move {
103 let mut interval = time::interval(Duration::from_secs(config.cleanup_interval_secs));
104 loop {
105 interval.tick().await;
106 Self::cleanup_old_records(&ip_records, &config);
108 }
109 });
110 }
111
112 fn cleanup_old_records(ip_records: &Arc<Mutex<HashMap<String, IpStats>>>, config: &DdosConfig) {
114 let mut records = ip_records.lock().unwrap();
115 let now = Instant::now();
116 let ban_duration = Duration::from_secs(config.ban_duration_secs);
117 let window_duration = Duration::from_secs(config.window_secs);
118
119 records.retain(|_, stats| {
121 let ban_expired = stats.banned_until.map_or(false, |until| now >= until);
122 let window_expired = now.duration_since(stats.window_start) > window_duration + ban_duration;
123 !((ban_expired || stats.banned_until.is_none()) && window_expired)
124 });
125
126 if records.len() > config.max_ip_records {
129 let overage = records.len() - config.max_ip_records;
130 let keys_to_remove: Vec<String> = records.keys().take(overage).cloned().collect();
131 for key in keys_to_remove {
132 records.remove(&key);
133 }
134 }
135 }
136
137 fn check_user_agent(&self, req: &ServiceRequest) -> Result<(), Error> {
138 let user_agent = req
139 .headers()
140 .get(header::USER_AGENT)
141 .and_then(|h| h.to_str().ok())
142 .unwrap_or("")
143 .to_lowercase();
144
145 if self.config.block_missing_ua && user_agent.is_empty() {
146 return Err(ErrorForbidden("Blocked: Missing User-Agent"));
147 }
148
149 if self.config.blocked_agents.iter().any(|bot| user_agent.contains(bot)) {
150 return Err(ErrorForbidden("Blocked: Malicious Actor Detected"));
151 }
152
153 Ok(())
154 }
155
156 fn check_rate_limit(&self, ip: &str) -> Result<(), String> {
157 let mut records = self.ip_records.lock().unwrap();
159 let now = Instant::now();
160
161 let stats = records
162 .entry(ip.to_string())
163 .or_insert_with(|| IpStats {
164 count: 0,
165 window_start: now,
166 banned_until: None,
167 });
168
169 if stats.is_banned(now) {
170 return Err("Your IP is banned due to previous abuse.".to_string());
171 }
172
173 if stats.banned_until.is_some() {
174 stats.clear_ban();
175 stats.reset_window(now);
176 return Ok(());
177 }
178
179 if stats.is_expired(now, self.config.window_secs) {
180 stats.reset_window(now);
181 return Ok(());
182 }
183
184 stats.increment();
185
186 if stats.count > self.config.max_requests {
187 stats.ban(Duration::from_secs(self.config.ban_duration_secs));
188 Err("Rate limit exceeded. Your IP has been temporarily banned.".to_string())
189 } else {
190 Ok(())
191 }
192 }
193}
194
195impl Default for DdosShield {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201#[derive(Default)]
202pub struct DdosShieldBuilder {
203 config: DdosConfig,
204}
205
206impl DdosShieldBuilder {
207 pub fn max_requests(mut self, reqs: u32) -> Self { self.config.max_requests = reqs; self }
209 pub fn window_secs(mut self, secs: u64) -> Self { self.config.window_secs = secs; self }
210 pub fn ban_duration_secs(mut self, secs: u64) -> Self { self.config.ban_duration_secs = secs; self }
211 pub fn block_agent(mut self, agent: &str) -> Self { self.config.blocked_agents.push(agent.to_lowercase()); self }
212 pub fn allow_missing_ua(mut self, allow: bool) -> Self { self.config.block_missing_ua = !allow; self }
213 pub fn cleanup_interval_secs(mut self, secs: u64) -> Self { self.config.cleanup_interval_secs = secs; self }
214 pub fn max_ip_records(mut self, max: usize) -> Self { self.config.max_ip_records = max; self }
215
216 pub fn build(self) -> DdosShield {
217 let shield = DdosShield {
218 config: Arc::new(self.config),
219 ip_records: Arc::new(Mutex::new(HashMap::with_capacity(1024))),
220 };
221 shield.start_cleanup_task();
222 shield
223 }
224}
225
226impl<S, B> Transform<S, ServiceRequest> for DdosShield
227where
228 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static, S::Future: 'static,
230 B: 'static,
231{
232 type Response = ServiceResponse<B>;
233 type Error = Error;
234 type InitError = ();
235 type Transform = DdosShieldMiddleware<S>;
236 type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
237
238 fn new_transform(&self, service: S) -> Self::Future {
239 ready(Ok(DdosShieldMiddleware {
240 service: Rc::new(service), shield: self.clone(),
242 }))
243 }
244}
245
246pub struct DdosShieldMiddleware<S> {
247 service: Rc<S>, shield: DdosShield,
249}
250
251impl<S, B> Service<ServiceRequest> for DdosShieldMiddleware<S>
252where
253 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
254 S::Future: 'static,
255 B: 'static,
256{
257 type Response = ServiceResponse<B>;
258 type Error = Error;
259 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
260
261 forward_ready!(service);
262
263 fn call(&self, req: ServiceRequest) -> Self::Future {
264 let shield = self.shield.clone();
265 let service = Rc::clone(&self.service); Box::pin(async move {
268 if let Err(err) = shield.check_user_agent(&req) {
270 return Err(err);
271 }
272
273 let ip = req
276 .connection_info()
277 .realip_remote_addr()
278 .unwrap_or("unknown_ip")
279 .to_string();
280
281 match shield.check_rate_limit(&ip) {
283 Ok(()) => service.call(req).await,
284 Err(msg) => Err(ErrorTooManyRequests(msg)),
285 }
286 })
287 }
288}