1use core::cmp;
2
3use http::{
4 header::{HeaderName, HeaderValue},
5 Response,
6};
7
8use crate::{nanos::Nanos, quota::Quota};
9
10#[derive(Clone, PartialEq, Eq, Debug)]
12pub struct RateSnapshot {
13 t: Nanos,
15 tau: Nanos,
17 pub(crate) time_of_measurement: Nanos,
19 pub(crate) tat: Nanos,
21}
22
23const X_RT_LIMIT: HeaderName = HeaderName::from_static("x-ratelimit-limit");
24const X_RT_REMAINING: HeaderName = HeaderName::from_static("x-ratelimit-remaining");
25
26impl RateSnapshot {
27 pub fn extend_response<Ext>(&self, res: &mut Response<Ext>) {
31 let burst_size = self.quota().burst_size().get();
32 let remaining_burst_capacity = self.remaining_burst_capacity();
33 let headers = res.headers_mut();
34 headers.insert(X_RT_LIMIT, HeaderValue::from(burst_size));
35 headers.insert(X_RT_REMAINING, HeaderValue::from(remaining_burst_capacity));
36 }
37
38 pub(crate) const fn new(t: Nanos, tau: Nanos, time_of_measurement: Nanos, tat: Nanos) -> Self {
39 Self {
40 t,
41 tau,
42 time_of_measurement,
43 tat,
44 }
45 }
46
47 pub(crate) fn quota(&self) -> Quota {
49 Quota::from_gcra_parameters(self.t, self.tau)
50 }
51
52 fn remaining_burst_capacity(&self) -> u32 {
53 let t0 = self.time_of_measurement + self.t;
54 (cmp::min((t0 + self.tau).saturating_sub(self.tat).as_u64(), self.tau.as_u64()) / self.t.as_u64()) as u32
55 }
56}
57
58#[cfg(test)]
59mod test {
60 use core::time::Duration;
61
62 use crate::{quota::Quota, state::RateLimiter, timer::FakeRelativeClock};
63
64 #[test]
65 fn state_information() {
66 let clock = FakeRelativeClock::default();
67 let lim = RateLimiter::direct_with_clock(Quota::per_second(4), &clock);
68 assert_eq!(Ok(3), lim.check().map(|outcome| outcome.remaining_burst_capacity()));
69 assert_eq!(Ok(2), lim.check().map(|outcome| outcome.remaining_burst_capacity()));
70 assert_eq!(Ok(1), lim.check().map(|outcome| outcome.remaining_burst_capacity()));
71 assert_eq!(Ok(0), lim.check().map(|outcome| outcome.remaining_burst_capacity()));
72 assert!(lim.check().is_err());
73 }
74
75 #[test]
76 fn state_snapshot_tracks_quota_accurately() {
77 let period = Duration::from_millis(90);
78 let quota = Quota::with_period(period).unwrap().allow_burst(2);
79
80 let clock = FakeRelativeClock::default();
81
82 let lim = RateLimiter::direct_with_clock(quota, &clock);
84
85 assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 1);
86 assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 0);
87 assert_eq!(lim.check().map_err(|_| ()), Err(()), "should rate limit");
88
89 clock.advance(Duration::from_secs(120));
90 assert_eq!(lim.check().map(|s| s.remaining_burst_capacity()), Ok(2));
91 assert_eq!(lim.check().map(|s| s.remaining_burst_capacity()), Ok(1));
92 assert_eq!(lim.check().map(|s| s.remaining_burst_capacity()), Ok(0));
93 assert_eq!(lim.check().map_err(|_| ()), Err(()), "should rate limit");
94 }
95
96 #[test]
97 fn state_snapshot_tracks_quota_accurately_with_real_clock() {
98 let period = Duration::from_millis(90);
99 let quota = Quota::with_period(period).unwrap().allow_burst(2);
100 let lim = RateLimiter::direct(quota);
101
102 assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 1);
103 assert_eq!(lim.check().unwrap().remaining_burst_capacity(), 0);
104 assert_eq!(lim.check().map_err(|_| ()), Err(()), "should rate limit");
105 }
106}