Skip to main content

irontide_session/
rate_limiter.rs

1#![allow(
2    clippy::cast_possible_truncation,
3    reason = "M175: token-bucket rate limiter — u128 → u64 truncation harmless at realistic rate budgets"
4)]
5
6//! Token bucket rate limiter and local network detection.
7
8use std::net::IpAddr;
9use std::time::Duration;
10
11use serde::{Deserialize, Serialize};
12
/// Token bucket rate limiter.
///
/// Tokens represent bytes and `rate` is bytes/second (0 = unlimited).
/// Tokens are added via [`TokenBucket::refill`] (called on a timer) and the
/// bucket never holds more than `capacity` tokens, where `capacity == rate`
/// (a one-second burst).
///
/// Because `capacity == rate`, a single [`TokenBucket::try_consume`] request
/// for more than `rate` bytes is never granted on a limited bucket.
#[allow(dead_code)] // consumed by torrent/session modules (wired in later tasks)
pub(crate) struct TokenBucket {
    rate: u64,     // bytes/sec, 0 = unlimited
    tokens: u64,   // current available tokens
    capacity: u64, // max tokens (= rate, i.e., 1 second burst)
    // Fractional token left over from the previous refill, in billionths of a
    // token (always < `TokenBucket::NANOS_PER_SEC`). Carrying it forward stops
    // frequent refills at low rates from rounding every tick down to zero.
    carry: u64,
}

#[allow(dead_code)]
impl TokenBucket {
    /// Refill arithmetic works in nanoseconds so sub-millisecond ticks count.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Create a bucket limited to `rate` bytes/sec (0 = unlimited).
    ///
    /// The bucket starts empty: nothing can be consumed from a limited bucket
    /// until the first [`TokenBucket::refill`].
    pub fn new(rate: u64) -> Self {
        Self {
            rate,
            tokens: 0,
            capacity: rate, // 1 second burst
            carry: 0,
        }
    }

    /// Create a bucket that never throttles (same as `TokenBucket::new(0)`).
    pub fn unlimited() -> Self {
        Self::new(0)
    }

    /// Returns true when the bucket has no limit (`rate == 0`).
    pub fn is_unlimited(&self) -> bool {
        self.rate == 0
    }

    /// Current rate limit in bytes/sec (0 = unlimited).
    pub fn rate(&self) -> u64 {
        self.rate
    }

    /// Add `rate * elapsed` tokens, clamped to `capacity`.
    ///
    /// Sub-token remainders are carried into the next call instead of being
    /// discarded. For example, a 500 B/s bucket refilled every 1 ms still gains
    /// 500 tokens per second. All arithmetic saturates, so even
    /// `Duration::MAX` simply fills the bucket. No-op for unlimited buckets.
    pub fn refill(&mut self, elapsed: Duration) {
        if self.rate == 0 {
            return;
        }
        // Token-nanoseconds owed, including the remainder from the last call.
        let owed = u128::from(self.rate)
            .saturating_mul(elapsed.as_nanos())
            .saturating_add(u128::from(self.carry));
        let whole = owed / Self::NANOS_PER_SEC;
        // Clamp instead of truncating when the whole-token count exceeds u64.
        let add = whole.min(u128::from(u64::MAX)) as u64;
        let filled = self.tokens.saturating_add(add);
        if filled >= self.capacity {
            // Full bucket: a leftover fraction could only push past the burst cap.
            self.tokens = self.capacity;
            self.carry = 0;
        } else {
            self.tokens = filled;
            // The remainder is < NANOS_PER_SEC, so this cast cannot truncate.
            self.carry = (owed % Self::NANOS_PER_SEC) as u64;
        }
    }

    /// Try to consume `amount` tokens. Returns true if allowed.
    ///
    /// Unlimited buckets always allow. A limited bucket allows only when it
    /// currently holds at least `amount` tokens; on refusal nothing is taken.
    pub fn try_consume(&mut self, amount: u64) -> bool {
        if self.rate == 0 {
            return true;
        }
        if self.tokens >= amount {
            self.tokens -= amount;
            true
        } else {
            false
        }
    }

    /// How many bytes can be consumed right now (`u64::MAX` when unlimited).
    pub fn available(&self) -> u64 {
        if self.rate == 0 {
            u64::MAX
        } else {
            self.tokens
        }
    }

    /// Update the rate limit.
    ///
    /// Capacity is reset to the new rate. When the new rate is limited, current
    /// tokens are clamped to it; switching to unlimited (0) leaves them as-is.
    pub fn set_rate(&mut self, rate: u64) {
        self.rate = rate;
        self.capacity = rate;
        if rate > 0 {
            self.tokens = self.tokens.min(self.capacity);
        }
    }
}
92
/// Which transport a peer connection runs over.
///
/// Used to pick the per-class (TCP vs uTP) bucket when rate limiting.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum PeerTransport {
    /// A TCP connection.
    Tcp,
    /// A uTP connection.
    Utp,
}
100
101/// Mixed-mode bandwidth allocation algorithm for TCP/uTP coexistence.
102#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
103pub enum MixedModeAlgorithm {
104    /// Throttle uTP upload when any TCP peer is connected.
105    /// uTP gets at most 10% of the global upload rate when TCP peers are present.
106    PreferTcp,
107    /// Allocate bandwidth proportional to the number of TCP vs uTP peers.
108    PeerProportional,
109}
110
111/// Per-class rate limiter set (BEP 40 / libtorrent parity).
112///
113/// Maintains separate upload/download buckets for TCP and uTP, plus global
114/// upload/download buckets. Uses check-before-consume pattern to avoid
115/// partial consumption when one bucket has capacity but another doesn't.
116#[allow(dead_code)]
117pub(crate) struct RateLimiterSet {
118    tcp_upload: TokenBucket,
119    tcp_download: TokenBucket,
120    utp_upload: TokenBucket,
121    utp_download: TokenBucket,
122    global_upload: TokenBucket,
123    global_download: TokenBucket,
124}
125
126#[allow(dead_code)]
127impl RateLimiterSet {
128    /// Create a new rate limiter set. Rate of 0 = unlimited.
129    pub fn new(
130        tcp_upload_rate: u64,
131        tcp_download_rate: u64,
132        utp_upload_rate: u64,
133        utp_download_rate: u64,
134        global_upload_rate: u64,
135        global_download_rate: u64,
136    ) -> Self {
137        Self {
138            tcp_upload: TokenBucket::new(tcp_upload_rate),
139            tcp_download: TokenBucket::new(tcp_download_rate),
140            utp_upload: TokenBucket::new(utp_upload_rate),
141            utp_download: TokenBucket::new(utp_download_rate),
142            global_upload: TokenBucket::new(global_upload_rate),
143            global_download: TokenBucket::new(global_download_rate),
144        }
145    }
146
147    /// Refill all buckets proportional to elapsed time.
148    pub fn refill(&mut self, elapsed: Duration) {
149        self.tcp_upload.refill(elapsed);
150        self.tcp_download.refill(elapsed);
151        self.utp_upload.refill(elapsed);
152        self.utp_download.refill(elapsed);
153        self.global_upload.refill(elapsed);
154        self.global_download.refill(elapsed);
155    }
156
157    /// Try to consume upload tokens for the given transport class.
158    ///
159    /// Checks both the class bucket and global bucket *before* consuming
160    /// either, to avoid partial consumption without refund.
161    pub fn try_consume_upload(&mut self, amount: u64, transport: PeerTransport) -> bool {
162        let class = match transport {
163            PeerTransport::Tcp => &self.tcp_upload,
164            PeerTransport::Utp => &self.utp_upload,
165        };
166        // Check both before consuming either
167        if !class.is_unlimited() && class.available() < amount {
168            return false;
169        }
170        if !self.global_upload.is_unlimited() && self.global_upload.available() < amount {
171            return false;
172        }
173        // Both have capacity — consume from both
174        let class = match transport {
175            PeerTransport::Tcp => &mut self.tcp_upload,
176            PeerTransport::Utp => &mut self.utp_upload,
177        };
178        class.try_consume(amount);
179        self.global_upload.try_consume(amount);
180        true
181    }
182
183    /// Try to consume download tokens for the given transport class.
184    pub fn try_consume_download(&mut self, amount: u64, transport: PeerTransport) -> bool {
185        let class = match transport {
186            PeerTransport::Tcp => &self.tcp_download,
187            PeerTransport::Utp => &self.utp_download,
188        };
189        if !class.is_unlimited() && class.available() < amount {
190            return false;
191        }
192        if !self.global_download.is_unlimited() && self.global_download.available() < amount {
193            return false;
194        }
195        let class = match transport {
196            PeerTransport::Tcp => &mut self.tcp_download,
197            PeerTransport::Utp => &mut self.utp_download,
198        };
199        class.try_consume(amount);
200        self.global_download.try_consume(amount);
201        true
202    }
203
204    /// Update per-class rates at runtime (e.g., from `apply_settings`).
205    pub fn set_rates(
206        &mut self,
207        tcp_upload: u64,
208        tcp_download: u64,
209        utp_upload: u64,
210        utp_download: u64,
211        global_upload: u64,
212        global_download: u64,
213    ) {
214        self.tcp_upload.set_rate(tcp_upload);
215        self.tcp_download.set_rate(tcp_download);
216        self.utp_upload.set_rate(utp_upload);
217        self.utp_download.set_rate(utp_download);
218        self.global_upload.set_rate(global_upload);
219        self.global_download.set_rate(global_download);
220    }
221
222    /// Apply mixed-mode bandwidth allocation based on peer transport composition.
223    /// Only adjusts upload — download is not throttled by transport type.
224    pub fn apply_mixed_mode(
225        &mut self,
226        algorithm: MixedModeAlgorithm,
227        tcp_peers: usize,
228        utp_peers: usize,
229        global_upload_rate: u64,
230    ) {
231        if global_upload_rate == 0 {
232            self.tcp_upload.set_rate(0);
233            self.utp_upload.set_rate(0);
234            return;
235        }
236        if tcp_peers == 0 && utp_peers == 0 {
237            self.tcp_upload.set_rate(0);
238            self.utp_upload.set_rate(0);
239            return;
240        }
241        match algorithm {
242            MixedModeAlgorithm::PreferTcp => {
243                if tcp_peers > 0 && utp_peers > 0 {
244                    let tcp_rate = global_upload_rate * 9 / 10;
245                    let utp_rate = global_upload_rate / 10;
246                    self.tcp_upload.set_rate(tcp_rate.max(1));
247                    self.utp_upload.set_rate(utp_rate.max(1));
248                } else {
249                    self.tcp_upload.set_rate(0);
250                    self.utp_upload.set_rate(0);
251                }
252            }
253            MixedModeAlgorithm::PeerProportional => {
254                let total = tcp_peers + utp_peers;
255                let tcp_rate = global_upload_rate * tcp_peers as u64 / total as u64;
256                let utp_rate = global_upload_rate * utp_peers as u64 / total as u64;
257                self.tcp_upload
258                    .set_rate(if tcp_peers > 0 { tcp_rate.max(1) } else { 0 });
259                self.utp_upload
260                    .set_rate(if utp_peers > 0 { utp_rate.max(1) } else { 0 });
261            }
262        }
263    }
264}
265
/// Report whether `addr` belongs to a local or private network.
///
/// Treated as local:
/// - IPv4: loopback, RFC 1918 private ranges, link-local (`169.254.0.0/16`)
///   and the unspecified address `0.0.0.0`.
/// - IPv6: loopback (`::1`), link-local (`fe80::/10`), unique-local / ULA
///   (`fc00::/7`) and the unspecified address `::`.
///
/// IPv4-mapped IPv6 addresses (`::ffff:a.b.c.d`) are first converted to their
/// IPv4 form and judged by the IPv4 rules. Without that step only the literal
/// `::1` would count as IPv6 loopback, and `::ffff:127.0.0.1`,
/// `::ffff:192.168.1.1` and similar addresses would slip past SSRF checks.
///
/// The unspecified addresses are included because they are another SSRF
/// route. On Linux, a connect to `0.0.0.0:p` reaches `127.0.0.1:p`, so
/// `http://0.0.0.0/` is effectively a loopback request. `is_loopback()` does
/// not catch it because it only matches `127.0.0.0/8`.
#[allow(dead_code)] // consumed by torrent module (wired in later tasks)
pub(crate) fn is_local_network(addr: IpAddr) -> bool {
    // Fold `::ffff:a.b.c.d` onto plain IPv4 so one set of v4 rules covers both.
    let normalized = match addr {
        IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
            Some(v4) => IpAddr::V4(v4),
            None => addr,
        },
        IpAddr::V4(_) => addr,
    };

    match normalized {
        IpAddr::V4(v4) => {
            v4.is_unspecified() || v4.is_loopback() || v4.is_link_local() || v4.is_private()
        }
        IpAddr::V6(v6) => {
            let [first, second, ..] = v6.octets();
            // fe80::/10: the leading 10 bits are 1111_1110_10.
            let link_local = first == 0xfe && (second & 0xc0) == 0x80;
            // fc00::/7: the leading 7 bits are 1111_110.
            let unique_local = (first & 0xfe) == 0xfc;
            v6.is_unspecified() || v6.is_loopback() || link_local || unique_local
        }
    }
}
310
// Unit tests for `TokenBucket`, `RateLimiterSet` and `is_local_network`.
//
// Buckets are created empty, so every limited-bucket test calls `refill`
// before consuming. Elapsed times are chosen so expected token counts are
// exact (`rate * elapsed` divides evenly).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unlimited_bucket_always_allows() {
        let mut tb = TokenBucket::unlimited();
        assert!(tb.try_consume(1_000_000));
        assert!(tb.is_unlimited());
        // Unlimited buckets report u64::MAX as their available budget.
        assert_eq!(tb.available(), u64::MAX);
    }

    #[test]
    fn limited_bucket_allows_up_to_capacity() {
        let mut tb = TokenBucket::new(1000); // 1000 bytes/sec
        tb.refill(Duration::from_millis(100)); // +100 tokens
        assert!(tb.try_consume(100));
        assert!(!tb.try_consume(1)); // exhausted
    }

    #[test]
    fn refill_adds_tokens_proportionally() {
        let mut tb = TokenBucket::new(10_000); // 10 KB/s
        tb.refill(Duration::from_millis(100)); // +1000 tokens
        assert!(tb.try_consume(1000));
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn tokens_cap_at_one_second_burst() {
        let mut tb = TokenBucket::new(1000);
        tb.refill(Duration::from_secs(5)); // would be 5000, but capped at 1000
        assert!(tb.try_consume(1000));
        assert!(!tb.try_consume(1));
    }

    #[test]
    fn try_consume_partial() {
        let mut tb = TokenBucket::new(1000);
        tb.refill(Duration::from_millis(100)); // +100
        assert_eq!(tb.available(), 100);
        assert!(tb.try_consume(50));
        assert_eq!(tb.available(), 50);
    }

    #[test]
    fn set_rate_clamps_tokens() {
        let mut tb = TokenBucket::new(1000);
        tb.refill(Duration::from_secs(1)); // 1000 tokens
        assert_eq!(tb.available(), 1000);
        tb.set_rate(500); // capacity now 500, tokens clamped
        assert_eq!(tb.available(), 500);
    }

    #[test]
    fn local_network_detection() {
        assert!(is_local_network("127.0.0.1".parse().unwrap()));
        assert!(is_local_network("192.168.1.1".parse().unwrap()));
        assert!(is_local_network("10.0.0.1".parse().unwrap()));
        assert!(is_local_network("172.16.0.1".parse().unwrap()));
        assert!(is_local_network("169.254.1.1".parse().unwrap()));
        assert!(is_local_network("::1".parse().unwrap()));
        assert!(!is_local_network("8.8.8.8".parse().unwrap()));
        assert!(!is_local_network("1.2.3.4".parse().unwrap()));
    }

    #[test]
    fn ipv6_local_network_detection() {
        // Loopback
        assert!(is_local_network("::1".parse().unwrap()));
        // Link-local (fe80::/10)
        assert!(is_local_network("fe80::1".parse().unwrap()));
        assert!(is_local_network("fe80::abcd:1234".parse().unwrap()));
        // Unique-local / ULA (fc00::/7)
        assert!(is_local_network("fc00::1".parse().unwrap()));
        assert!(is_local_network("fd00::1".parse().unwrap()));
        assert!(is_local_network("fd12:3456:789a::1".parse().unwrap()));
        // Global unicast — not local
        assert!(!is_local_network("2001:db8::1".parse().unwrap()));
        assert!(!is_local_network(
            "2607:f8b0:4004:800::200e".parse().unwrap()
        ));
    }

    #[test]
    fn unspecified_v4_is_local() {
        // 0.0.0.0 on Linux connects to 127.0.0.1 — must be treated as local.
        assert!(is_local_network("0.0.0.0".parse().unwrap()));
    }

    #[test]
    fn unspecified_v6_is_local() {
        // :: (all-zeros IPv6) is the v6 equivalent — same SSRF concern.
        assert!(is_local_network("::".parse().unwrap()));
    }

    #[test]
    fn ipv4_mapped_v6_loopback_is_local() {
        // ::ffff:127.0.0.1 is loopback expressed as IPv4-mapped IPv6.
        // Ipv6Addr::is_loopback() only matches literal ::1, so without
        // to_ipv4_mapped() normalisation this would slip past the SSRF check.
        assert!(is_local_network("::ffff:127.0.0.1".parse().unwrap()));
        assert!(is_local_network("::ffff:7f00:1".parse().unwrap()));
    }

    #[test]
    fn ipv4_mapped_v6_private_is_local() {
        // Same gap for RFC 1918 ranges expressed as IPv4-mapped IPv6.
        assert!(is_local_network("::ffff:192.168.1.1".parse().unwrap()));
        assert!(is_local_network("::ffff:10.0.0.1".parse().unwrap()));
        assert!(is_local_network("::ffff:172.16.0.1".parse().unwrap()));
        // Public IPv4 mapped into IPv6 is still public.
        assert!(!is_local_network("::ffff:8.8.8.8".parse().unwrap()));
    }

    #[test]
    fn rate_limiter_set_all_unlimited() {
        let mut rls = RateLimiterSet::new(0, 0, 0, 0, 0, 0);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(1_000_000, PeerTransport::Tcp));
        assert!(rls.try_consume_upload(1_000_000, PeerTransport::Utp));
        assert!(rls.try_consume_download(1_000_000, PeerTransport::Tcp));
        assert!(rls.try_consume_download(1_000_000, PeerTransport::Utp));
    }

    #[test]
    fn rate_limiter_set_class_limited() {
        // Global buckets are unlimited here, so only the class buckets bind.
        let mut rls = RateLimiterSet::new(1000, 1000, 500, 500, 0, 0);
        rls.refill(Duration::from_secs(1));
        // TCP: 1000 capacity
        assert!(rls.try_consume_upload(1000, PeerTransport::Tcp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Tcp)); // exhausted
        // uTP: 500 capacity, independent
        assert!(rls.try_consume_upload(500, PeerTransport::Utp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Utp));
    }

    #[test]
    fn rate_limiter_set_global_limits() {
        // Global upload limit = 500, class limit = 1000 each
        let mut rls = RateLimiterSet::new(1000, 0, 1000, 0, 500, 0);
        rls.refill(Duration::from_secs(1));
        // TCP class has 1000, but global only has 500
        assert!(rls.try_consume_upload(500, PeerTransport::Tcp));
        // Now global is exhausted — uTP should also be blocked
        assert!(!rls.try_consume_upload(1, PeerTransport::Utp));
    }

    #[test]
    fn rate_limiter_set_check_before_consume_no_partial() {
        // If global allows but class doesn't, no partial consumption
        let mut rls = RateLimiterSet::new(100, 0, 0, 0, 1000, 0);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(100, PeerTransport::Tcp));
        // Class exhausted, global still has 900 — should fail cleanly
        assert!(!rls.try_consume_upload(1, PeerTransport::Tcp));
        // uTP is unlimited, global has 900
        assert!(rls.try_consume_upload(900, PeerTransport::Utp));
    }

    #[test]
    fn rate_limiter_set_refill_all() {
        let mut rls = RateLimiterSet::new(1000, 2000, 500, 750, 5000, 10000);
        rls.refill(Duration::from_millis(100));
        // Each bucket should have 10% of its rate
        assert!(rls.try_consume_upload(100, PeerTransport::Tcp));
        assert!(rls.try_consume_download(200, PeerTransport::Tcp));
        assert!(rls.try_consume_upload(50, PeerTransport::Utp));
        assert!(rls.try_consume_download(75, PeerTransport::Utp));
    }

    #[test]
    fn rate_limiter_set_runtime_update() {
        let mut rls = RateLimiterSet::new(1000, 1000, 1000, 1000, 0, 0);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(1000, PeerTransport::Tcp));
        // Update TCP upload to 500
        rls.set_rates(500, 1000, 1000, 1000, 0, 0);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(500, PeerTransport::Tcp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Tcp));
    }

    #[test]
    fn mixed_mode_prefer_tcp_both_present() {
        // Both transports present: TCP class gets 90%, uTP class 10% of 10000.
        let mut rls = RateLimiterSet::new(0, 0, 0, 0, 10000, 0);
        rls.apply_mixed_mode(MixedModeAlgorithm::PreferTcp, 3, 5, 10000);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(9000, PeerTransport::Tcp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Tcp));
        // Second refill tops the global bucket back up before testing uTP.
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(1000, PeerTransport::Utp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Utp));
    }

    #[test]
    fn mixed_mode_prefer_tcp_only_utp() {
        // When only uTP peers exist, per-class rate is set to unlimited (0),
        // so uTP can consume up to the full global limit without per-class throttling.
        let mut rls = RateLimiterSet::new(0, 0, 0, 0, 10000, 0);
        rls.apply_mixed_mode(MixedModeAlgorithm::PreferTcp, 0, 5, 10000);
        rls.refill(Duration::from_secs(1));
        // uTP per-class bucket is unlimited, so full global capacity is available
        assert!(rls.try_consume_upload(10000, PeerTransport::Utp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Utp));
    }

    #[test]
    fn mixed_mode_proportional() {
        // 3 TCP vs 7 uTP peers: classes get 3000 and 7000 of the 10000 budget.
        let mut rls = RateLimiterSet::new(0, 0, 0, 0, 10000, 0);
        rls.apply_mixed_mode(MixedModeAlgorithm::PeerProportional, 3, 7, 10000);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(3000, PeerTransport::Tcp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Tcp));
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(7000, PeerTransport::Utp));
        assert!(!rls.try_consume_upload(1, PeerTransport::Utp));
    }

    #[test]
    fn mixed_mode_unlimited_global_noop() {
        // A global rate of 0 leaves both class buckets unlimited.
        let mut rls = RateLimiterSet::new(0, 0, 0, 0, 0, 0);
        rls.apply_mixed_mode(MixedModeAlgorithm::PeerProportional, 3, 7, 0);
        rls.refill(Duration::from_secs(1));
        assert!(rls.try_consume_upload(1_000_000, PeerTransport::Tcp));
        assert!(rls.try_consume_upload(1_000_000, PeerTransport::Utp));
    }
}
534        rls.refill(Duration::from_secs(1));
535        assert!(rls.try_consume_upload(1_000_000, PeerTransport::Tcp));
536        assert!(rls.try_consume_upload(1_000_000, PeerTransport::Utp));
537    }
538}