1use crate::state::StateStore;
2use crate::InsufficientCapacity;
3use crate::{clock, middleware::StateSnapshot, Quota};
4use crate::{middleware::RateLimitingMiddleware, nanos::Nanos};
5use core::num::NonZeroU32;
6use core::time::Duration;
7use core::{cmp, fmt};
8
9#[cfg(feature = "std")]
10use crate::Jitter;
11
12#[derive(Debug, PartialEq, Eq)]
17pub struct NotUntil<P: clock::Reference> {
18 state: StateSnapshot,
19 start: P,
20}
21
22impl<P: clock::Reference> NotUntil<P> {
23 #[inline]
25 pub(crate) fn new(state: StateSnapshot, start: P) -> Self {
26 Self { state, start }
27 }
28
29 #[inline]
33 pub fn earliest_possible(&self) -> P {
34 let tat: Nanos = self.state.tat;
35 self.start + tat
36 }
37
38 #[inline]
45 pub fn wait_time_from(&self, from: P) -> Duration {
46 let earliest = self.earliest_possible();
47 earliest.duration_since(earliest.min(from)).into()
48 }
49
50 #[inline]
52 pub fn quota(&self) -> Quota {
53 self.state.quota()
54 }
55
56 #[cfg(feature = "std")] #[inline]
58 pub(crate) fn earliest_possible_with_offset(&self, jitter: Jitter) -> P {
59 let tat = jitter + self.state.tat;
60 self.start + tat
61 }
62
63 #[cfg(feature = "std")] #[inline]
65 pub(crate) fn wait_time_with_offset(&self, from: P, jitter: Jitter) -> Duration {
66 let earliest = self.earliest_possible_with_offset(jitter);
67 earliest.duration_since(earliest.min(from)).into()
68 }
69}
70
71impl<P: clock::Reference> fmt::Display for NotUntil<P> {
72 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
73 write!(f, "rate-limited until {:?}", self.start + self.state.tat)
74 }
75}
76
/// Parameters of the Generic Cell Rate Algorithm (GCRA) that drives rate-limiting decisions.
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct Gcra {
    /// Emission interval: the time it takes to replenish a single cell
    /// (the quota's replenishment interval, clamped to at least 1ns by `Gcra::new`).
    t: Nanos,

    /// Burst tolerance: how far the theoretical arrival time may run ahead of "now" while
    /// cells are still admitted. `Gcra::new` sets it to `t * (max_burst - 1)`.
    tau: Nanos,
}
87
88impl Gcra {
89 pub(crate) fn new(quota: Quota) -> Self {
90 let t: Nanos = cmp::max(quota.replenish_1_per, Duration::from_nanos(1)).into();
91 let tau: Nanos = t * (quota.max_burst.get() - 1).into();
92 Gcra { t, tau }
93 }
94
95 pub(crate) fn t(&self) -> Nanos {
96 self.t
97 }
98
99 pub(crate) fn test_and_update<
101 K,
102 P: clock::Reference,
103 S: StateStore<Key = K>,
104 MW: RateLimitingMiddleware<P>,
105 >(
106 &self,
107 start: P,
108 key: &K,
109 state: &S,
110 t0: P,
111 ) -> Result<MW::PositiveOutcome, MW::NegativeOutcome> {
112 let t0 = t0.duration_since(start);
113 let tau = self.tau;
114 let t = self.t;
115 state.measure_and_replace(key, |tat| {
116 let tat = tat.unwrap_or(t0);
117 let earliest_time = tat.saturating_sub(tau);
118 if t0 < earliest_time {
119 Err(MW::disallow(
120 key,
121 StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
122 start,
123 ))
124 } else {
125 let next = cmp::max(tat, t0) + t;
126 Ok((
127 MW::allow(key, StateSnapshot::new(self.t, self.tau, t0, next)),
128 next,
129 ))
130 }
131 })
132 }
133
134 pub(crate) fn test_n_all_and_update<
136 K,
137 P: clock::Reference,
138 S: StateStore<Key = K>,
139 MW: RateLimitingMiddleware<P>,
140 >(
141 &self,
142 start: P,
143 key: &K,
144 n: NonZeroU32,
145 state: &S,
146 t0: P,
147 ) -> Result<Result<MW::PositiveOutcome, MW::NegativeOutcome>, InsufficientCapacity> {
148 let t0 = t0.duration_since(start);
149 let tau = self.tau;
150 let t = self.t;
151 let additional_weight = t * (n.get() - 1) as u64;
152
153 if additional_weight > tau {
156 return Err(InsufficientCapacity(
157 1 + (self.tau.as_u64() / t.as_u64()) as u32,
158 ));
159 }
160 Ok(state.measure_and_replace(key, |tat| {
161 let tat = tat.unwrap_or(t0);
162 let earliest_time = (tat + additional_weight).saturating_sub(tau);
163 if t0 < earliest_time {
164 Err(MW::disallow(
165 key,
166 StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
167 start,
168 ))
169 } else {
170 let next = cmp::max(tat, t0) + t + additional_weight;
171 Ok((
172 MW::allow(key, StateSnapshot::new(self.t, self.tau, t0, next)),
173 next,
174 ))
175 }
176 }))
177 }
178}
179
#[cfg(test)]
mod test {
    use super::*;
    use crate::Quota;
    use core::num::NonZeroU32;

    use proptest::prelude::*;

    /// Exercises the derived `PartialEq` and `Debug` impls on `Gcra`.
    #[cfg(feature = "std")]
    #[test]
    fn gcra_derives() {
        use assertables::assert_gt;
        use nonzero_ext::nonzero;

        let g = Gcra::new(Quota::per_second(nonzero!(1u32)));
        let g2 = Gcra::new(Quota::per_second(nonzero!(2u32)));
        assert_eq!(g, g);
        assert_ne!(g, g2);
        assert_gt!(format!("{g:?}").len(), 0);
    }

    /// Exercises `NotUntil`'s derived `PartialEq`/`Debug` impls, its `Display` output and
    /// `quota()`, using a `FakeRelativeClock` so the reported instant is fixed.
    #[cfg(feature = "std")]
    #[test]
    fn notuntil_impls() {
        use crate::RateLimiter;
        use assertables::assert_gt;
        use clock::FakeRelativeClock;
        use nonzero_ext::nonzero;

        let clock = FakeRelativeClock::default();
        let quota = Quota::per_second(nonzero!(1u32));
        let lb = RateLimiter::direct_with_clock(quota, clock);
        // With a 1/s quota the first check passes; the immediate second one must be refused.
        assert!(lb.check().is_ok());
        assert!(lb
            .check()
            .map_err(|nu| {
                assert_eq!(nu, nu);
                assert_gt!(format!("{nu:?}").len(), 0);
                assert_eq!(format!("{nu}"), "rate-limited until Nanos(1s)");
                assert_eq!(nu.quota(), quota);
            })
            .is_err());
    }

    /// Proptest input wrapper yielding a `NonZeroU32` drawn from `1..10000`.
    #[derive(Debug)]
    struct Count(NonZeroU32);
    impl Arbitrary for Count {
        type Parameters = ();
        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (1..10000u32)
                .prop_map(|x| Count(NonZeroU32::new(x).unwrap()))
                .boxed()
        }

        type Strategy = BoxedStrategy<Count>;
    }

    /// Covers the derived `Debug` impl on `Count`.
    #[cfg(feature = "std")]
    #[test]
    fn cover_count_derives() {
        assert_eq!(
            format!("{:?}", Count(nonzero_ext::nonzero!(1_u32))),
            "Count(1)"
        );
    }

    /// Converting a quota into GCRA parameters and back must reproduce the original quota.
    #[test]
    fn roundtrips_quota() {
        proptest!(ProptestConfig::default(), |(per_second: Count, burst: Count)| {
            let quota = Quota::per_second(per_second.0).allow_burst(burst.0);
            let gcra = Gcra::new(quota);
            let back = Quota::from_gcra_parameters(gcra.t, gcra.tau);
            assert_eq!(quota, back);
        })
    }
}