sbd_server/
ip_rate.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use tokio::task::JoinHandle;
4
5type Map = HashMap<Arc<std::net::Ipv6Addr>, u64>;
6
7/// Rate limit connections by IP address.
8pub struct IpRate {
9    origin: tokio::time::Instant,
10    map: Arc<Mutex<Map>>,
11    disabled: bool,
12    limit: u64,
13    burst: u64,
14    ip_deny: crate::ip_deny::IpDeny,
15}
16
17impl IpRate {
18    /// Construct a new IpRate limit instance.
19    pub fn new(config: Arc<crate::Config>) -> Self {
20        Self {
21            origin: tokio::time::Instant::now(),
22            map: Arc::new(Mutex::new(HashMap::new())),
23            disabled: config.disable_rate_limiting,
24            limit: config.limit_ip_byte_nanos() as u64,
25            burst: config.limit_ip_byte_burst as u64
26                * config.limit_ip_byte_nanos() as u64,
27            ip_deny: crate::ip_deny::IpDeny::new(config),
28        }
29    }
30
31    /// Prune entries that have tracked backwards 10s or more.
32    /// The 10s just prevents hashtable thrashing if a connection
33    /// is using significantly less than its rate limit.
34    /// This is why the keepalive interval is 5 seconds and
35    /// connections are closed after 10 seconds.
36    pub fn prune(&self) {
37        let now = self.origin.elapsed().as_nanos() as u64;
38        self.map.lock().unwrap().retain(|_, cur| {
39            if now <= *cur {
40                true
41            } else {
42                // examples using seconds:
43                // now:100,cur:120 100-120=-20<10  true=keep
44                // now:100,cur:100 100-100=0<10    true=keep
45                // now:100,cur:80   100-80=20<10  false=prune
46                now - *cur < 10_000_000_000
47            }
48        });
49    }
50
51    /// Return true if this ip is blocked.
52    pub async fn is_blocked(&self, ip: &Arc<std::net::Ipv6Addr>) -> bool {
53        self.ip_deny.is_blocked(ip).await
54    }
55
56    /// Return true if we are not over the rate limit.
57    pub async fn is_ok(
58        &self,
59        ip: &Arc<std::net::Ipv6Addr>,
60        bytes: usize,
61    ) -> bool {
62        if self.disabled {
63            return true;
64        }
65
66        // multiply by our rate allowed per byte
67        let rate_add = bytes as u64 * self.limit;
68
69        // get now
70        let now = self.origin.elapsed().as_nanos() as u64;
71
72        let is_ok = {
73            // lock the map mutex
74            let mut lock = self.map.lock().unwrap();
75
76            // get the entry (default to now)
77            let e = lock.entry(ip.clone()).or_insert(now);
78
79            // if we've already used time greater than now use that,
80            // otherwise consider we're starting from scratch
81            let cur = std::cmp::max(*e, now) + rate_add;
82
83            // update the map with the current limit
84            *e = cur;
85
86            // subtract now back out to see if we're greater than our burst
87            cur - now <= self.burst
88        };
89
90        if !is_ok {
91            tracing::info!("IP rate limit exceeded for {ip}, blocking");
92            self.ip_deny.block(ip).await;
93        }
94
95        is_ok
96    }
97}
98
99/// Spawn a Tokio task to prune the IpRate map.
100pub fn spawn_prune_task(ip_rate: Arc<IpRate>) -> JoinHandle<()> {
101    let ip_rate = Arc::downgrade(&ip_rate);
102    tokio::task::spawn(async move {
103        loop {
104            tokio::time::sleep(std::time::Duration::from_secs(5)).await;
105            if let Some(ip_rate) = ip_rate.upgrade() {
106                ip_rate.prune();
107            } else {
108                break;
109            }
110        }
111    })
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build an enabled limiter with an explicit per-byte cost and burst
    /// (both in nanoseconds), bypassing the config-derived values.
    fn test_new(limit: u64, burst: u64) -> IpRate {
        IpRate {
            origin: tokio::time::Instant::now(),
            map: Arc::new(Mutex::new(HashMap::new())),
            disabled: false,
            limit,
            burst,
            ip_deny: crate::ip_deny::IpDeny::new(Arc::new(
                crate::Config::default(),
            )),
        }
    }

    /// The single address every test charges against.
    fn test_addr() -> Arc<std::net::Ipv6Addr> {
        Arc::new(std::net::Ipv6Addr::new(1, 1, 1, 1, 1, 1, 1, 1))
    }

    /// Number of IPs currently tracked by the limiter.
    fn tracked(rate: &IpRate) -> usize {
        rate.map.lock().unwrap().len()
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_one_to_one() {
        let addr = test_addr();
        let rate = test_new(1, 1);

        // Spending one nanosecond per elapsed nanosecond is sustainable.
        for _ in 0..10 {
            tokio::time::advance(Duration::from_nanos(1)).await;
            assert!(rate.is_ok(&addr, 1).await);
        }

        // A second charge at the same instant exceeds the burst of 1.
        assert!(!rate.is_ok(&addr, 1).await);

        tokio::time::advance(Duration::from_nanos(1)).await;

        // The entry is still ahead of the clock, so it is retained.
        rate.prune();
        assert_eq!(1, tracked(&rate));

        tokio::time::advance(Duration::from_secs(10)).await;

        // Now lagging by one nanosecond less than the window: retained.
        rate.prune();
        assert_eq!(1, tracked(&rate));

        // One more nanosecond reaches the window: pruned.
        tokio::time::advance(Duration::from_nanos(1)).await;
        rate.prune();
        assert_eq!(0, tracked(&rate));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_burst() {
        let addr = test_addr();
        let rate = test_new(1, 5);

        // Five back-to-back charges fit inside the burst of 5.
        for _ in 0..5 {
            assert!(rate.is_ok(&addr, 1).await);
        }

        // The sixth does not, though it is still recorded.
        assert!(!rate.is_ok(&addr, 1).await);

        // Two nanoseconds of refill make room for one more.
        tokio::time::advance(Duration::from_nanos(2)).await;
        assert!(rate.is_ok(&addr, 1).await);

        tokio::time::advance(Duration::from_secs(10)).await;
        tokio::time::advance(Duration::from_nanos(4)).await;

        // Just short of the prune window: retained.
        rate.prune();
        assert_eq!(1, tracked(&rate));

        tokio::time::advance(Duration::from_nanos(1)).await;

        // Window reached: pruned.
        rate.prune();
        assert_eq!(0, tracked(&rate));
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn check_limit_mult() {
        let addr = test_addr();

        // Each 2-byte message costs 2 * 3 = 6ns against a burst of 13.
        let rate = test_new(3, 13);

        assert!(rate.is_ok(&addr, 2).await);
        assert!(rate.is_ok(&addr, 2).await);
        assert!(!rate.is_ok(&addr, 2).await);

        // After a long idle period the full burst is available again.
        tokio::time::advance(Duration::from_secs(10)).await;

        assert!(rate.is_ok(&addr, 2).await);
        assert!(rate.is_ok(&addr, 2).await);
        assert!(!rate.is_ok(&addr, 2).await);
    }
}