foyer_common/
rated_ticket.rs

1// Copyright 2025 foyer Project Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::time::Instant;
16
17use parking_lot::Mutex;
18
19///  A ticket-based rate limiter.
20#[derive(Debug)]
21pub struct RatedTicket {
22    inner: Mutex<Inner>,
23    rate: f64,
24}
25
/// Lock-protected state of a [`RatedTicket`].
#[derive(Debug)]
struct Inner {
    /// Quota currently available. It can drop below zero when more is spent than was available.
    quota: f64,
    /// Instant at which quota was last credited.
    last: Instant,
}
32
33impl RatedTicket {
34    /// Create a ticket-based rate limiter.
35    pub fn new(rate: f64) -> Self {
36        let inner = Inner {
37            quota: 0.0,
38            last: Instant::now(),
39        };
40        Self {
41            rate,
42            inner: Mutex::new(inner),
43        }
44    }
45
46    /// Check if there is still some quota left.
47    pub fn probe(&self) -> bool {
48        let mut inner = self.inner.lock();
49
50        let now = Instant::now();
51        let refill = now.duration_since(inner.last).as_secs_f64() * self.rate;
52        inner.last = now;
53        inner.quota = f64::min(inner.quota + refill, self.rate);
54
55        inner.quota > 0.0
56    }
57
58    /// Reduce some quota manually.
59    pub fn reduce(&self, weight: f64) {
60        self.inner.lock().quota -= weight;
61    }
62
63    /// Consume some quota from the rate limiter.
64    ///
65    /// If there enough quota left, returns `true`; otherwise, returns `false`.
66    pub fn consume(&self, weight: f64) -> bool {
67        let mut inner = self.inner.lock();
68
69        let now = Instant::now();
70        let refill = now.duration_since(inner.last).as_secs_f64() * self.rate;
71        inner.last = now;
72        inner.quota = f64::min(inner.quota + refill, self.rate);
73
74        if inner.quota <= 0.0 {
75            return false;
76        }
77
78        inner.quota -= weight;
79
80        true
81    }
82}
83
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::Duration,
    };

    use itertools::Itertools;
    use rand::{rng, Rng};

    use super::*;

    /// Long-running accuracy check of `RatedTicket::consume`. Ignored by default.
    #[ignore]
    #[test]
    fn test_rated_ticket_consume() {
        test(consume)
    }

    /// Long-running accuracy check of `RatedTicket::probe` + `RatedTicket::reduce`.
    /// Ignored by default.
    #[ignore]
    #[test]
    fn test_rated_ticket_probe_reduce() {
        test(probe_reduce)
    }

    /// Run `CASES` independent cases concurrently and check that each case's relative
    /// throughput error stays below `ERATIO`.
    ///
    /// All error ratios are printed before any assertion, so a failing run still reports
    /// every case instead of stopping at the first bad one.
    fn test<F>(f: F)
    where
        F: Fn(usize, &Arc<AtomicUsize>, &Arc<RatedTicket>) + Send + Sync + Copy + 'static,
    {
        const CASES: usize = 10;
        const ERATIO: f64 = 0.05;

        let handles = (0..CASES).map(|_| std::thread::spawn(move || case(f))).collect_vec();
        let eratios = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect_vec();

        println!("========== RatedTicket error ratio begin ==========");
        for eratio in &eratios {
            println!("eratio: {eratio}");
        }
        println!("=========== RatedTicket error ratio end ===========");

        for eratio in eratios {
            assert!(eratio < ERATIO, "eratio: {eratio} >= ERATIO: {ERATIO}");
        }
    }

    /// Admit `weight` via `consume` and record it in `v` if the limiter allowed it.
    fn consume(weight: usize, v: &Arc<AtomicUsize>, limiter: &Arc<RatedTicket>) {
        if limiter.consume(weight as f64) {
            v.fetch_add(weight, Ordering::Relaxed);
        }
    }

    /// Admit `weight` via `probe` + `reduce` and record it in `v` if the probe succeeded.
    fn probe_reduce(weight: usize, v: &Arc<AtomicUsize>, limiter: &Arc<RatedTicket>) {
        if limiter.probe() {
            limiter.reduce(weight as f64);
            v.fetch_add(weight, Ordering::Relaxed);
        }
    }

    /// Drive one shared limiter from `THREADS` threads for `DURATION`. Each thread uses
    /// its own random per-operation weight.
    ///
    /// Returns the relative error between the total admitted weight and the ideal
    /// `RATE * DURATION`.
    fn case<F>(f: F) -> f64
    where
        F: Fn(usize, &Arc<AtomicUsize>, &Arc<RatedTicket>) + Send + Sync + Copy + 'static,
    {
        const THREADS: usize = 8;
        const RATE: usize = 1000;
        const DURATION: Duration = Duration::from_secs(10);

        let v = Arc::new(AtomicUsize::new(0));
        let limiter = Arc::new(RatedTicket::new(RATE as f64));
        let task = |weight: usize, v: Arc<AtomicUsize>, limiter: Arc<RatedTicket>, f: F| {
            let start = Instant::now();
            let mut rng = rng();
            while start.elapsed() < DURATION {
                std::thread::sleep(Duration::from_millis(rng.random_range(1..10)));
                f(weight, &v, &limiter);
            }
        };
        let mut handles = vec![];
        let mut rng = rng();
        for _ in 0..THREADS {
            let weight = rng.random_range(10..20);
            let handle = std::thread::spawn({
                let v = v.clone();
                let limiter = limiter.clone();
                move || task(weight, v, limiter, f)
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let expected = RATE as f64 * DURATION.as_secs_f64();
        let admitted = v.load(Ordering::Relaxed) as f64;
        (admitted - expected).abs() / expected
    }
}