prob_rate_limiter/
lib.rs

1//! prob-rate-limiter
2//! =================
3//! [![crates.io version](https://img.shields.io/crates/v/prob-rate-limiter.svg)](https://crates.io/crates/prob-rate-limiter)
4//! [![license: Apache 2.0](https://gitlab.com/leonhard-llc/ops/-/raw/main/license-apache-2.0.svg)](https://gitlab.com/leonhard-llc/ops/-/raw/main/prob-rate-limiter/LICENSE)
5//! [![unsafe forbidden](https://gitlab.com/leonhard-llc/ops/-/raw/main/unsafe-forbidden.svg)](https://github.com/rust-secure-code/safety-dance/)
6//! [![pipeline status](https://gitlab.com/leonhard-llc/ops/badges/main/pipeline.svg)](https://gitlab.com/leonhard-llc/ops/-/pipelines)
7//!
8//! `ProbRateLimiter` is a *probabilistic* rate limiter.
9//! When load approaches the configured limit,
10//! the struct chooses randomly whether to accept or reject each request.
11//! It adjusts the probability of rejection so throughput is steady around the limit.
12//!
13//! # Use Cases
14//! - Shed load to prevent overload
15//! - Avoid overloading the services you depend on
16//! - Control costs
17//!
18//! # Features
19//! - Tiny, uses 44 bytes
20//! - 100% test coverage
21//! - Optimized: 32ns per check, 31M checks per second on an i5-8259U
22//! - No `unsafe` or unsafe deps
23//!
24//! # Limitations
25//! - Requires a mutable reference.
26//! - Not fair.  Treats all requests equally, regardless of source.
27//!   A client that overloads the server will consume most of the throughput.
28//!
29//! # Alternatives
30//! - [r8limit](https://crates.io/crates/r8limit)
31//!   - Uses a sliding window
32//!   - No `unsafe` or deps
33//!   - Optimized: 48ns per check, 21M checks per second on an i5-8259U
34//!   - Requires a mutable reference.
35//! - [governor](https://crates.io/crates/governor)
36//!   - Uses a non-mutable reference, easy to share between threads
37//!   - Popular
38//!   - Good docs
39//!   - Optimized: 29ns per check on an i5-8259U.
40//!   - Unnecessary `unsafe`
41//!   - Uses non-standard mutex library [`parking_lot`](https://crates.io/crates/parking_lot)
42//!   - Uses a complicated algorithm
43//! - [leaky-bucket](https://crates.io/crates/leaky-bucket)
44//!   - Async tasks can wait for their turn to use a resource.
45//!   - Unsuitable for load shedding because there is no `try_acquire`.
46//!
47//! # Related Crates
48//! - [safe-dns](https://crates.io/crates/safe-dns) uses this
49//!
50//! # Example
51//! ```
52//! # use prob_rate_limiter::ProbRateLimiter;
53//! # use std::time::{Duration, Instant};
//! let mut limiter = ProbRateLimiter::new(10);
55//! let mut now = Instant::now();
56//! assert!(limiter.check(5, now));
57//! assert!(limiter.check(5, now));
58//! now += Duration::from_secs(1);
59//! assert!(limiter.check(5, now));
60//! assert!(limiter.check(5, now));
61//! now += Duration::from_secs(1);
62//! assert!(limiter.check(5, now));
63//! assert!(limiter.check(5, now));
64//! now += Duration::from_secs(1);
65//! assert!(limiter.check(5, now));
66//! assert!(limiter.check(5, now));
67//! now += Duration::from_secs(1);
68//! assert!(limiter.check(5, now));
69//! assert!(limiter.check(5, now));
70//! assert!(!limiter.check(5, now));
71//! ```
72//!
73//! # Cargo Geiger Safety Report
74//!
75//! # Changelog
76//! - v0.1.1 - Simplify `new`.  Add more docs.
77//! - v0.1.0 - Initial version
78//!
79//! # TO DO
80//! - Publish
81//! - Add graph from the benchmark.
82#![forbid(unsafe_code)]
83
84use core::time::Duration;
85use oorandom::Rand32;
86use std::time::Instant;
87
/// In-place counterpart of `saturating_add`: adds `rhs` to `self`,
/// clamping at the numeric bound instead of overflowing.
trait SaturatingAddAssign<T> {
    /// Replaces `self` with the saturating sum of `self` and `rhs`.
    fn saturating_add_assign(&mut self, rhs: T);
}

impl SaturatingAddAssign<u32> for u32 {
    fn saturating_add_assign(&mut self, rhs: u32) {
        let clamped_sum = u32::saturating_add(*self, rhs);
        *self = clamped_sum;
    }
}
96
/// Decides whether a request should be accepted at the current load.
///
/// - Rejects when `max_cost` is zero or `recent_cost` has reached `max_cost`.
/// - Accepts, without calling `rand_float`, while load is at most 75% of `max_cost`.
/// - Above 75%, the rejection probability grows quadratically towards 1.0 and the
///   request is accepted only when the drawn value is greater than that probability.
///
/// `rand_float` is called at most once; it presumably yields values in `[0.0, 1.0)`.
fn decide(recent_cost: u32, max_cost: u32, mut rand_float: impl FnMut() -> f32) -> bool {
    // At or over the limit (or no capacity at all): always reject.
    if max_cost == 0 || recent_cost >= max_cost {
        return false;
    }
    // In [0.0, 1.0) because `recent_cost < max_cost` here.
    let load = f64::from(recent_cost) / f64::from(max_cost);
    // Maps the top quarter of the load range onto (0.0, 1.0); overall range is [-3.0, 1.0).
    let scaled = (load - 0.75) * 4.0;
    if scaled <= 0.0 {
        return true;
    }
    let reject_prob = scaled.powi(2);
    let draw = f64::from(rand_float());
    draw > reject_prob
}
112
#[cfg(test)]
#[test]
#[allow(clippy::unreadable_literal)]
fn test_decide() {
    // Cases that must be settled without drawing a random number:
    // zero capacity, load at or below 75%, and load at or above 100%.
    let no_draw_cases: [(u32, u32, bool); 6] = [
        (0, 0, false),
        (0, 100, true),
        (50, 100, true),
        (75, 100, true),
        (100, 100, false),
        (101, 100, false),
    ];
    for &(recent_cost, max_cost, expected) in &no_draw_cases {
        assert_eq!(
            expected,
            decide(recent_cost, max_cost, || unreachable!()),
            "recent_cost={} max_cost={}",
            recent_cost,
            max_cost
        );
    }

    // Cases between 75% and 100% load, where the outcome depends on the drawn value.
    // Each pair brackets the rejection probability from below and above.
    let draw_cases: [(u32, f32, bool); 10] = [
        (76, 0.999999, true),
        (76, 0.0, false),
        (85, 0.15, false),
        (85, 0.17, true),
        (90, 0.35, false),
        (90, 0.37, true),
        (95, 0.63, false),
        (95, 0.65, true),
        (99, 0.92, false),
        (99, 0.93, true),
    ];
    for &(recent_cost, draw, expected) in &draw_cases {
        assert_eq!(
            expected,
            decide(recent_cost, 100, || draw),
            "recent_cost={} draw={}",
            recent_cost,
            draw
        );
    }
}
134
135/// A probabilistic rate-limiter.
136/// - When not overloaded, accepts all requests
137/// - As load approaches limit, probabilistically rejects more and more requests.
138/// - Onset of overload does not trigger a sudden total outage.
139#[derive(Clone, Debug)]
140pub struct ProbRateLimiter {
141    tick_duration: Duration,
142    max_cost: u32,
143    cost: u32,
144    last: Instant,
145    prng: Rand32,
146}
147impl ProbRateLimiter {
148    /// Makes a new rate limiter that accepts `max_cost_per_tick` every `tick_duration`.
149    ///
150    /// # Errors
151    /// Returns an error when `tick_duration` is less than 1 microsecond.
152    pub fn new_custom(
153        tick_duration: Duration,
154        max_cost_per_tick: u32,
155        now: Instant,
156        prng: Rand32,
157    ) -> Result<Self, String> {
158        if tick_duration.as_micros() == 0 {
159            return Err(format!("tick_duration too small: {:?}", tick_duration));
160        }
161        Ok(Self {
162            tick_duration,
163            max_cost: max_cost_per_tick * 2,
164            cost: 0_u32,
165            last: now,
166            prng,
167        })
168    }
169
170    /// Makes a new rate limiter that accepts `max_cost_per_sec` cost every second.
171    #[allow(clippy::missing_panics_doc)]
172    #[must_use]
173    pub fn new(max_cost_per_sec: u32) -> Self {
174        Self::new_custom(
175            Duration::from_secs(1),
176            max_cost_per_sec,
177            Instant::now(),
178            Rand32::new(0),
179        )
180        .unwrap()
181    }
182
183    /// Try a request.  Returns `true` when the request should be accepted.
184    pub fn attempt(&mut self, now: Instant) -> bool {
185        if self.max_cost == 0 {
186            return false;
187        }
188        let elapsed = now.saturating_duration_since(self.last);
189        #[allow(clippy::cast_possible_truncation)]
190        let elapsed_ticks = (elapsed.as_micros() / self.tick_duration.as_micros()) as u32;
191        self.last += self.tick_duration * elapsed_ticks;
192        self.cost = self.cost.wrapping_shr(elapsed_ticks);
193        decide(self.cost, self.max_cost, || self.prng.rand_float())
194    }
195
196    /// Record the cost of a request.
197    pub fn record(&mut self, cost: u32) {
198        self.cost.saturating_add_assign(cost);
199    }
200
201    /// A convenience method that calls [`attempt`] and [`record`].
202    /// Use this when the cost of each request is fixed or cheap to calculate.
203    pub fn check(&mut self, cost: u32, now: Instant) -> bool {
204        if self.attempt(now) {
205            self.record(cost);
206            true
207        } else {
208            false
209        }
210    }
211}