1use actix_web::{
2 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
3 error::{ErrorForbidden, ErrorTooManyRequests},
4 http::header,
5 Error,
6};
7use futures_util::future::LocalBoxFuture;
8use std::{
9 collections::HashMap,
10 future::ready,
11 rc::Rc,
12 sync::{Arc, RwLock},
13 time::{Duration, Instant},
14};
/// Rate-limit bookkeeping for a single client address.
#[derive(Debug)]
struct IpStats {
    /// Requests counted in the current window.
    count: u32,
    /// When the current counting window began.
    window_start: Instant,
    /// End of an active ban; `None` when the address is not banned.
    banned_until: Option<Instant>,
}
/// Policy applied by the DDoS shield middleware.
///
/// Derives `Debug` so the active policy can be logged, and `PartialEq`/`Eq` so two
/// configurations can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DdosConfig {
    /// Requests one IP may make within a window; the request that exceeds this starts a ban.
    pub max_requests: u32,
    /// Length of the request-counting window, in seconds.
    pub window_secs: u64,
    /// How long an IP stays banned after exceeding `max_requests`, in seconds.
    pub ban_duration_secs: u64,
    /// When `true`, requests with a missing or unreadable User-Agent header get 403 Forbidden.
    pub block_missing_ua: bool,
    /// User-Agent substrings to reject with 403 Forbidden. The request's User-Agent is
    /// lowercased before matching, so entries are expected to be lowercase.
    pub blocked_agents: Vec<String>,
}
29#[derive(Clone)]
30pub struct DdosShield {
31 config: DdosConfig,
32 ip_records: Arc<RwLock<HashMap<String, IpStats>>>,
33}
34impl DdosShield {
35 pub fn builder() -> Self {
36 Self {
37 config: DdosConfig {
38 max_requests: 50,
39 window_secs: 60,
40 ban_duration_secs: 86400,
41 block_missing_ua: false,
42 blocked_agents: vec!["curl".into()],
43 },
44 ip_records: Arc::new(RwLock::new(HashMap::new())),
45 }
46 }
47 pub fn max_requests(mut self, reqs: u32) -> Self {
48 self.config.max_requests = reqs;
49 self
50 }
51 pub fn window_secs(mut self, secs: u64) -> Self {
52 self.config.window_secs = secs;
53 self
54 }
55 pub fn ban_duration_secs(mut self, secs: u64) -> Self {
56 self.config.ban_duration_secs = secs;
57 self
58 }
59 pub fn block_agent(mut self, agent: &str) -> Self {
60 self.config.blocked_agents.push(agent.to_lowercase());
61 self
62 }
63 pub fn allow_missing_ua(mut self, allow: bool) -> Self {
64 self.config.block_missing_ua = !allow;
65 self
66 }
67 pub fn build(self) -> Self {
68 self
69 }
70}
71impl<S, B> Transform<S, ServiceRequest> for DdosShield
72where
73 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
74 S::Future: 'static,
75 B: 'static,
76{
77 type Response = ServiceResponse<B>;
78 type Error = Error;
79 type InitError = ();
80 type Transform = DdosShieldMiddleware<S>;
81 type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
82 fn new_transform(&self, service: S) -> Self::Future {
83 ready(Ok(DdosShieldMiddleware {
84 service: Rc::new(service),
85 config: self.config.clone(),
86 ip_records: self.ip_records.clone(),
87 }))
88 }
89}
90pub struct DdosShieldMiddleware<S> {
91 service: Rc<S>,
92 config: DdosConfig,
93 ip_records: Arc<RwLock<HashMap<String, IpStats>>>,
94}
95impl<S, B> Service<ServiceRequest> for DdosShieldMiddleware<S>
96where
97 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
98 S::Future: 'static,
99 B: 'static,
100{
101 type Response = ServiceResponse<B>;
102 type Error = Error;
103 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
104 forward_ready!(service);
105 fn call(&self, req: ServiceRequest) -> Self::Future {
106 let user_agent = req
107 .headers()
108 .get(header::USER_AGENT)
109 .and_then(|h| h.to_str().ok())
110 .unwrap_or("")
111 .to_lowercase();
112 if self.config.block_missing_ua && user_agent.is_empty() {
113 return Box::pin(ready(Err(ErrorForbidden("Blocked: Missing User-Agent"))));
114 }
115 if self.config.blocked_agents.iter().any(|bot| user_agent.contains(bot)) {
116 return Box::pin(ready(Err(ErrorForbidden("Blocked: Malicious Actor Detected"))));
117 }
118 let ip = req
119 .connection_info()
120 .realip_remote_addr()
121 .unwrap_or("unknown_ip")
122 .to_string();
123 let mut is_banned = false;
124 let mut triggered_ban = false;
125 {
126 let mut records = self.ip_records.write().unwrap();
127 let now = Instant::now();
128 let stats = records.entry(ip).or_insert(IpStats {
129 count: 0,
130 window_start: now,
131 banned_until: None,
132 });
133 if let Some(banned_time) = stats.banned_until {
134 if now < banned_time {
135 is_banned = true;
136 } else {
137 stats.banned_until = None;
138 stats.count = 1;
139 stats.window_start = now;
140 }
141 } else {
142 if now.duration_since(stats.window_start).as_secs() > self.config.window_secs {
143 stats.count = 1;
144 stats.window_start = now;
145 } else {
146 stats.count += 1;
147 if stats.count > self.config.max_requests {
148 stats.banned_until = Some(now + Duration::from_secs(self.config.ban_duration_secs));
149 triggered_ban = true;
150 is_banned = true;
151 }
152 }
153 }
154 }
155 if is_banned {
156 let msg = if triggered_ban {
157 "Rate limit exceeded. Your IP has been temporarily banned."
158 } else {
159 "Your IP is banned due to previous abuse."
160 };
161 return Box::pin(ready(Err(ErrorTooManyRequests(msg))));
162 }
163 let fut = self.service.call(req);
164 Box::pin(async move {
165 let res = fut.await?;
166 Ok(res)
167 })
168 }
169}