1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tokio::task::JoinHandle;
4
/// Per-IP state: each address maps to a timestamp in nanoseconds measured
/// from the owning [`IpRate`]'s `origin`.
type Map = HashMap<Arc<std::net::Ipv6Addr>, u64>;

/// Per-IP byte rate limiter.
///
/// Every accepted or rejected request pushes the address's timestamp forward
/// by `bytes * limit` nanoseconds (starting from "now" if the timestamp is
/// already in the past). A request is accepted while that timestamp stays no
/// more than `burst` nanoseconds ahead of the current time; otherwise the
/// address is handed to `ip_deny` to be blocked.
pub struct IpRate {
    // Reference instant; all timestamps in `map` are nanoseconds since this.
    origin: tokio::time::Instant,
    // Address -> timestamp table, shared with the background prune task.
    map: Arc<Mutex<Map>>,
    // When true, `is_ok` accepts everything without touching `map`.
    disabled: bool,
    // Nanoseconds charged per byte.
    limit: u64,
    // How far (in nanoseconds) an address's timestamp may run ahead of now
    // before requests are refused.
    burst: u64,
    // Deny list consulted by `is_blocked` and fed by `is_ok` on violation.
    ip_deny: crate::ip_deny::IpDeny,
}
16
17impl IpRate {
18 pub fn new(config: Arc<crate::Config>) -> Self {
20 Self {
21 origin: tokio::time::Instant::now(),
22 map: Arc::new(Mutex::new(HashMap::new())),
23 disabled: config.disable_rate_limiting,
24 limit: config.limit_ip_byte_nanos() as u64,
25 burst: config.limit_ip_byte_burst as u64
26 * config.limit_ip_byte_nanos() as u64,
27 ip_deny: crate::ip_deny::IpDeny::new(config),
28 }
29 }
30
31 pub fn prune(&self) {
37 let now = self.origin.elapsed().as_nanos() as u64;
38 self.map.lock().unwrap().retain(|_, cur| {
39 if now <= *cur {
40 true
41 } else {
42 now - *cur < 10_000_000_000
47 }
48 });
49 }
50
51 pub async fn is_blocked(&self, ip: &Arc<std::net::Ipv6Addr>) -> bool {
53 self.ip_deny.is_blocked(ip).await
54 }
55
56 pub async fn is_ok(
58 &self,
59 ip: &Arc<std::net::Ipv6Addr>,
60 bytes: usize,
61 ) -> bool {
62 if self.disabled {
63 return true;
64 }
65
66 let rate_add = bytes as u64 * self.limit;
68
69 let now = self.origin.elapsed().as_nanos() as u64;
71
72 let is_ok = {
73 let mut lock = self.map.lock().unwrap();
75
76 let e = lock.entry(ip.clone()).or_insert(now);
78
79 let cur = std::cmp::max(*e, now) + rate_add;
82
83 *e = cur;
85
86 cur - now <= self.burst
88 };
89
90 if !is_ok {
91 tracing::info!("IP rate limit exceeded for {ip}, blocking");
92 self.ip_deny.block(ip).await;
93 }
94
95 is_ok
96 }
97}
98
99pub fn spawn_prune_task(ip_rate: Arc<IpRate>) -> JoinHandle<()> {
101 let ip_rate = Arc::downgrade(&ip_rate);
102 tokio::task::spawn(async move {
103 loop {
104 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
105 if let Some(ip_rate) = ip_rate.upgrade() {
106 ip_rate.prune();
107 } else {
108 break;
109 }
110 }
111 })
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Construct a limiter with explicit `limit` (ns per byte) and `burst`
    /// (ns), bypassing the config-derived values used by `IpRate::new`.
    fn test_new(limit: u64, burst: u64) -> IpRate {
        IpRate {
            origin: tokio::time::Instant::now(),
            map: Arc::new(Mutex::new(HashMap::new())),
            disabled: false,
            limit,
            burst,
            ip_deny: crate::ip_deny::IpDeny::new(Arc::new(crate::Config::default())),
        }
    }

    /// The single address every case exercises.
    fn test_addr() -> Arc<std::net::Ipv6Addr> {
        Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1))
    }

    /// Number of addresses the limiter is currently tracking.
    fn tracked(rate: &IpRate) -> usize {
        rate.map.lock().unwrap().len()
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_one_to_one() {
        let addr = test_addr();
        let rate = test_new(1, 1);

        // One byte per elapsed nanosecond stays within a 1 ns burst.
        for _ in 0..10 {
            tokio::time::advance(Duration::from_nanos(1)).await;
            assert!(rate.is_ok(&addr, 1).await);
        }

        // A second byte within the same nanosecond exceeds it.
        assert!(!rate.is_ok(&addr, 1).await);

        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(1, tracked(&rate));

        // One nanosecond short of the idle window: still tracked.
        tokio::time::advance(Duration::from_secs(10)).await;
        rate.prune();
        assert_eq!(1, tracked(&rate));

        // Reaching the idle window: pruned.
        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(0, tracked(&rate));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_burst() {
        let addr = test_addr();
        let rate = test_new(1, 5);

        // Five bytes fit into the burst at t = 0; the sixth does not.
        for _ in 0..5 {
            assert!(rate.is_ok(&addr, 1).await);
        }
        assert!(!rate.is_ok(&addr, 1).await);

        // The refused byte was still charged; 2 ns of recovery admits one more.
        tokio::time::advance(Duration::from_nanos(2)).await;
        assert!(rate.is_ok(&addr, 1).await);

        tokio::time::advance(Duration::from_secs(10)).await;
        tokio::time::advance(Duration::from_nanos(4)).await;
        rate.prune();
        assert_eq!(1, tracked(&rate));

        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(0, tracked(&rate));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_limit_mult() {
        let addr = test_addr();
        let rate = test_new(3, 13);

        // Each call charges 2 bytes * 3 ns = 6 ns: 6 and 12 fit in 13, 18 does not.
        for expected in [true, true, false] {
            assert_eq!(expected, rate.is_ok(&addr, 2).await);
        }

        // After a long idle period the same pattern repeats.
        tokio::time::advance(Duration::from_secs(10)).await;

        for expected in [true, true, false] {
            assert_eq!(expected, rate.is_ok(&addr, 2).await);
        }
    }
}