// ray/rate_limiter.rs

1use std::sync::{Arc, Mutex};
2use std::time::{Duration, Instant};
3
/// Mutable bookkeeping that a [`RateLimiterHandle`] shares behind an `Arc<Mutex<_>>`.
#[derive(Default, Debug)]
pub(crate) struct RateLimiterState {
    /// Cap on the total number of recorded hits; `None` means no cap.
    max_calls: Option<u64>,
    /// Cap on hits recorded within the last second; `None` means no cap.
    max_per_second: Option<u64>,
    /// The instant of each recorded hit, appended oldest first.
    hits: Vec<Instant>,
    /// Whether a blocked result has already been reported with `notify: true`.
    notified: bool,
}
11
12#[derive(Debug, Clone)]
13pub struct RateLimiterHandle {
14    state: Arc<Mutex<RateLimiterState>>,
15}
16
/// Result of asking the rate limiter whether a payload may go out.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum RateLimitStatus {
    /// No configured limit is currently reached.
    Allowed,
    /// A configured limit is reached. `notify` is `true` when the limiter had not already
    /// flagged this blocked state.
    Blocked { notify: bool },
}
22
23impl RateLimiterHandle {
24    pub(crate) fn new(state: Arc<Mutex<RateLimiterState>>) -> Self {
25        Self { state }
26    }
27
28    /// Limit the total number of payloads sent.
29    pub fn max(self, max_calls: i64) -> Self {
30        {
31            let mut state = self.state.lock().unwrap();
32            state.max_calls = normalize_limit(max_calls);
33            if state.max_calls.is_none() {
34                state.notified = false;
35            }
36        }
37        self
38    }
39
40    /// Limit the number of payloads sent per second.
41    pub fn per_second(self, max_per_second: i64) -> Self {
42        {
43            let mut state = self.state.lock().unwrap();
44            state.max_per_second = normalize_limit(max_per_second);
45            if state.max_per_second.is_none() {
46                state.notified = false;
47            }
48        }
49        self
50    }
51
52    /// Clear all rate limits.
53    pub fn clear(self) -> Self {
54        {
55            let mut state = self.state.lock().unwrap();
56            state.clear();
57        }
58        self
59    }
60
61    pub(crate) fn status(&self) -> RateLimitStatus {
62        let mut state = self.state.lock().unwrap();
63        let max_reached = state.is_max_reached();
64        let per_second_reached = state.is_max_per_second_reached();
65
66        if max_reached || per_second_reached {
67            let notify = !state.notified;
68            if notify {
69                state.notified = true;
70            }
71            return RateLimitStatus::Blocked { notify };
72        }
73
74        RateLimitStatus::Allowed
75    }
76
77    pub(crate) fn hit(&self) {
78        let mut state = self.state.lock().unwrap();
79        state.hit();
80    }
81}
82
83impl RateLimiterState {
84    fn hit(&mut self) {
85        self.hits.push(Instant::now());
86    }
87
88    fn clear(&mut self) {
89        self.max_calls = None;
90        self.max_per_second = None;
91        self.hits.clear();
92        self.notified = false;
93    }
94
95    fn is_max_reached(&mut self) -> bool {
96        let Some(max_calls) = self.max_calls else {
97            return false;
98        };
99
100        let reached = self.hits.len() as u64 >= max_calls;
101        if !reached {
102            self.notified = false;
103        }
104        reached
105    }
106
107    fn is_max_per_second_reached(&mut self) -> bool {
108        let Some(max_per_second) = self.max_per_second else {
109            return false;
110        };
111
112        let reached = self.count_last_second() >= max_per_second;
113        if !reached {
114            self.notified = false;
115        }
116        reached
117    }
118
119    fn count_last_second(&self) -> u64 {
120        let now = Instant::now();
121        self.hits
122            .iter()
123            .filter(|hit| now.duration_since(**hit) <= Duration::from_secs(1))
124            .count() as u64
125    }
126}
127
/// Convert a caller-supplied limit into its stored form.
///
/// Zero and negative values disable the limit (`None`); positive values are kept unchanged.
fn normalize_limit(value: i64) -> Option<u64> {
    match value {
        // The guard guarantees a positive i64, which always fits in a u64.
        limit if limit > 0 => Some(limit as u64),
        _ => None,
    }
}