1extern crate time;
2
3pub mod store;
4
5use error::CellError;
6
/// Bound used by `RateLimiter::rate_limit` to stop re-reading and retrying its
/// conditional store write (set-if-not-exists / compare-and-swap) when that
/// write keeps failing, e.g. because the stored value changed between the
/// read and the write. The value is also reported in the resulting error.
const MAX_CAS_ATTEMPTS: i64 = 5;
10
11#[derive(Debug, PartialEq)]
12pub struct Rate {
13 pub period: time::Duration,
14}
15
16impl Rate {
17 pub fn per_day(n: i64) -> Rate {
18 Rate::per_period(n, time::Duration::days(1))
19 }
20
21 pub fn per_hour(n: i64) -> Rate {
22 Rate::per_period(n, time::Duration::hours(1))
23 }
24
25 pub fn per_minute(n: i64) -> Rate {
26 Rate::per_period(n, time::Duration::minutes(1))
27 }
28
29 pub fn per_period(n: i64, period: time::Duration) -> Rate {
33 let ns: i64 = period.num_nanoseconds().unwrap();
34 let period = time::Duration::nanoseconds(((ns as f64) / (n as f64)) as i64);
35 Rate { period }
36 }
37
38 pub fn per_second(n: i64) -> Rate {
39 Rate::per_period(n, time::Duration::seconds(1))
40 }
41}
42
43#[derive(Debug, PartialEq)]
44pub struct RateLimitResult {
45 pub limit: i64,
46 pub remaining: i64,
47 pub reset_after: time::Duration,
48 pub retry_after: time::Duration,
49}
50
51pub struct RateLimiter<'a, T: 'a + store::Store> {
52 pub store: &'a mut T,
53
54 delay_variation_tolerance: time::Duration,
58
59 emission_interval: time::Duration,
63
64 limit: i64,
65}
66
67impl<'a, T: 'a + store::Store> RateLimiter<'a, T> {
68 pub fn new(store: &'a mut T, quota: &RateQuota) -> RateLimiter<'a, T> {
69 RateLimiter {
70 delay_variation_tolerance: time::Duration::nanoseconds(
71 quota.max_rate.period.num_nanoseconds().unwrap() * (quota.max_burst + 1),
72 ),
73 emission_interval: quota.max_rate.period,
74 limit: quota.max_burst + 1,
75 store,
76 }
77 }
78
79 pub fn rate_limit(
90 &mut self,
91 key: &str,
92 quantity: i64,
93 ) -> Result<(bool, RateLimitResult), CellError> {
94 let mut rlc = RateLimitResult {
95 limit: self.limit,
96 remaining: 0,
97 retry_after: time::Duration::seconds(-1),
98 reset_after: time::Duration::seconds(-1),
99 };
100
101 let increment = time::Duration::nanoseconds(
102 self.emission_interval.num_nanoseconds().unwrap() * quantity,
103 );
104 self.log_start(key, quantity, increment);
105
106 let limited: bool;
110
111 let mut ttl: time::Duration;
112
113 let mut i = 0;
122 loop {
123 log_debug!(self.store, "iteration = {}", i);
124
125 let (tat_val, now) = self.store.get_with_time(key)?;
128
129 let tat = match tat_val {
130 -1 => now,
131 _ => from_nanoseconds(tat_val),
132 };
133 log_debug!(
134 self.store,
135 "tat = {} (from store = {})",
136 tat.rfc3339(),
137 tat_val
138 );
139
140 let new_tat = if now > tat {
141 now + increment
142 } else {
143 tat + increment
144 };
145 log_debug!(self.store, "new_tat = {}", new_tat.rfc3339());
146
147 let allow_at = new_tat - self.delay_variation_tolerance;
149 let diff = now - allow_at;
150 log_debug!(
151 self.store,
152 "diff = {}ms (now - allow_at)",
153 diff.num_milliseconds()
154 );
155
156 if diff < time::Duration::zero() {
157 log_debug!(
158 self.store,
159 "BLOCKED retry_after = {}ms",
160 -diff.num_milliseconds()
161 );
162
163 if increment <= self.delay_variation_tolerance {
164 rlc.retry_after = -diff;
165 }
166
167 limited = true;
168 ttl = tat - now;
169 break;
170 }
171
172 let new_tat_ns = nanoseconds(new_tat);
173 ttl = new_tat - now;
174 log_debug!(self.store, "ALLOWED");
175
176 let updated = if tat_val == -1 {
182 self.store.set_if_not_exists_with_ttl(key, new_tat_ns, ttl)?
183 } else {
184 self.store
185 .compare_and_swap_with_ttl(key, tat_val, new_tat_ns, ttl)?
186 };
187
188 if updated {
189 limited = false;
190 break;
191 }
192
193 i += 1;
194 if i > MAX_CAS_ATTEMPTS {
195 return Err(error!(
196 "Failed to update rate limit after \
197 {} attempts",
198 MAX_CAS_ATTEMPTS
199 ));
200 }
201 }
202
203 let next = self.delay_variation_tolerance - ttl;
204 if next > -self.emission_interval {
205 rlc.remaining = (next.num_microseconds().unwrap() as f64
206 / self.emission_interval.num_microseconds().unwrap() as f64)
207 as i64;
208 }
209 rlc.reset_after = ttl;
210
211 self.log_end(&rlc);
212 Ok((limited, rlc))
213 }
214
215 fn log_end(&self, rlc: &RateLimitResult) {
216 log_debug!(
217 self.store,
218 "limit = {} remaining = {}",
219 self.limit,
220 rlc.remaining
221 );
222 log_debug!(
223 self.store,
224 "retry_after = {}ms",
225 rlc.retry_after.num_milliseconds()
226 );
227 log_debug!(
228 self.store,
229 "reset_after = {}ms (ttl)",
230 rlc.reset_after.num_milliseconds()
231 );
232 }
233
234 fn log_start(&self, key: &str, quantity: i64, increment: time::Duration) {
235 log_debug!(self.store, "");
236 log_debug!(self.store, "-----");
237 log_debug!(self.store, "key = {}", key);
238 log_debug!(self.store, "quantity = {}", quantity);
239 log_debug!(
240 self.store,
241 "delay_variation_tolerance = {}ms",
242 self.delay_variation_tolerance.num_milliseconds()
243 );
244 log_debug!(
245 self.store,
246 "emission_interval = {}ms",
247 self.emission_interval.num_milliseconds()
248 );
249 log_debug!(
250 self.store,
251 "tat_increment = {}ms (emission_interval * quantity)",
252 increment.num_milliseconds()
253 );
254 }
255}
256
257#[derive(Debug, PartialEq)]
258pub struct RateQuota {
259 pub max_burst: i64,
260 pub max_rate: Rate,
261}
262
263fn from_nanoseconds(x: i64) -> time::Tm {
264 let ns = (10 as i64).pow(9);
265 time::at(time::Timespec {
266 sec: x / ns,
267 nsec: (x % ns) as i32,
268 })
269}
270
271fn nanoseconds(x: time::Tm) -> i64 {
272 let ts = x.to_timespec();
273 ts.sec * (10 as i64).pow(9) + i64::from(ts.nsec)
274}
275
#[cfg(test)]
mod tests {
    extern crate time;

    use cell::*;
    use error::CellError;
    use std::error::Error;

    #[test]
    fn it_creates_rates_from_days() {
        assert_eq!(
            Rate {
                period: time::Duration::hours(1),
            },
            Rate::per_day(24)
        )
    }

    #[test]
    fn it_creates_rates_from_hours() {
        assert_eq!(
            Rate {
                period: time::Duration::minutes(10),
            },
            Rate::per_hour(6)
        )
    }

    #[test]
    fn it_creates_rates_from_minutes() {
        assert_eq!(
            Rate {
                period: time::Duration::seconds(10),
            },
            Rate::per_minute(6)
        )
    }

    #[test]
    fn it_creates_rates_from_periods() {
        assert_eq!(
            Rate {
                period: time::Duration::seconds(20),
            },
            Rate::per_period(6, time::Duration::minutes(2))
        )
    }

    #[test]
    fn it_creates_rates_from_seconds() {
        assert_eq!(
            Rate {
                period: time::Duration::milliseconds(200),
            },
            Rate::per_second(5)
        )
    }

    /// Drives one limiter through a scripted sequence of requests, moving the
    /// test clock between them, and checks every reported field.
    #[cfg_attr(rustfmt, rustfmt_skip)]
    #[test]
    fn it_rate_limits() {
        let limit = 5;
        let quota = RateQuota {
            max_burst: limit - 1,
            max_rate: Rate::per_second(1),
        };
        let start = time::now_utc();
        let mut memory_store = store::MemoryStore::new_verbose();
        let mut test_store = TestStore::new(&mut memory_store);
        let mut limiter = RateLimiter::new(&mut test_store, &quota);

        let cases = [
            // A single request larger than the whole burst is refused.
            RateLimitCase::new(0, start, 6, 5, time::Duration::zero(),
                time::Duration::seconds(-1), true),

            // Consume the burst one unit at a time, then get refused.
            RateLimitCase::new(1, start, 1, 4, time::Duration::seconds(1),
                time::Duration::seconds(-1), false),
            RateLimitCase::new(2, start, 1, 3, time::Duration::seconds(2),
                time::Duration::seconds(-1), false),
            RateLimitCase::new(3, start, 1, 2, time::Duration::seconds(3),
                time::Duration::seconds(-1), false),
            RateLimitCase::new(4, start, 1, 1, time::Duration::seconds(4),
                time::Duration::seconds(-1), false),
            RateLimitCase::new(5, start, 1, 0, time::Duration::seconds(5),
                time::Duration::seconds(-1), false),
            RateLimitCase::new(6, start, 1, 0, time::Duration::seconds(5),
                time::Duration::seconds(1), true),

            // Capacity is regained as the clock advances.
            RateLimitCase::new(7, start + time::Duration::milliseconds(3000), 1, 2,
                time::Duration::milliseconds(3000), time::Duration::seconds(-1), false),
            RateLimitCase::new(8, start + time::Duration::milliseconds(3100), 1, 1,
                time::Duration::milliseconds(3900), time::Duration::seconds(-1), false),
            RateLimitCase::new(9, start + time::Duration::milliseconds(4000), 1, 1,
                time::Duration::milliseconds(4000), time::Duration::seconds(-1), false),
            RateLimitCase::new(10, start + time::Duration::milliseconds(8000), 1, 4,
                time::Duration::milliseconds(1000), time::Duration::seconds(-1), false),
            RateLimitCase::new(11, start + time::Duration::milliseconds(9500), 1, 4,
                time::Duration::milliseconds(1000), time::Duration::seconds(-1), false),

            // A zero-quantity request only reports state.
            RateLimitCase::new(12, start + time::Duration::milliseconds(9500), 0, 4,
                time::Duration::seconds(1), time::Duration::seconds(-1), false),

            // Multi-unit requests: one that fits, then one that does not.
            RateLimitCase::new(13, start + time::Duration::milliseconds(9500), 2, 2,
                time::Duration::seconds(3), time::Duration::seconds(-1), false),

            RateLimitCase::new(14, start + time::Duration::milliseconds(9500), 5, 2,
                time::Duration::seconds(3), time::Duration::seconds(3), true),
        ];

        for case in cases.iter() {
            println!("starting test case = {:?}", case.num);
            println!("{:?}", case);

            limiter.store.clock = case.now;
            let (limited, results) = limiter.rate_limit("foo", case.volume).unwrap();

            println!("limited = {:?}", limited);
            println!("{:?}", results);
            println!("");

            assert_eq!(case.limited, limited);
            assert_eq!(limit, results.limit);
            assert_eq!(case.remaining, results.remaining);
            assert_eq!(case.reset_after, results.reset_after);
            assert_eq!(case.retry_after, results.retry_after);
        }
    }

    /// A store whose conditional writes always fail must make `rate_limit`
    /// give up with an error rather than loop forever.
    #[test]
    fn it_handles_rate_limit_update_failures() {
        let quota = RateQuota {
            max_burst: 1,
            max_rate: Rate::per_second(1),
        };
        let mut memory_store = store::MemoryStore::new_verbose();
        let mut test_store = TestStore::new(&mut memory_store);
        test_store.fail_updates = true;

        let mut limiter = RateLimiter::new(&mut test_store, &quota);

        let err = error!("Failed to update rate limit after 5 attempts");

        assert_eq!(
            err.description(),
            limiter.rate_limit("foo", 1).unwrap_err().description()
        );
    }

    /// One scripted request for `it_rate_limits` and its expected outcome.
    #[derive(Debug, PartialEq)]
    struct RateLimitCase {
        num: i64,
        now: time::Tm,
        volume: i64,
        remaining: i64,
        reset_after: time::Duration,
        retry_after: time::Duration,
        limited: bool,
    }

    impl RateLimitCase {
        fn new(
            num: i64,
            now: time::Tm,
            volume: i64,
            remaining: i64,
            reset_after: time::Duration,
            retry_after: time::Duration,
            limited: bool,
        ) -> RateLimitCase {
            RateLimitCase {
                num,
                now,
                volume,
                remaining,
                reset_after,
                retry_after,
                limited,
            }
        }
    }

    /// Wraps a `MemoryStore`, replacing its clock with a settable one and
    /// optionally forcing every conditional write to fail.
    struct TestStore<'a> {
        clock: time::Tm,
        fail_updates: bool,
        store: &'a mut store::MemoryStore,
    }

    impl<'a> TestStore<'a> {
        fn new(store: &'a mut store::MemoryStore) -> TestStore<'a> {
            TestStore {
                clock: time::empty_tm(),
                fail_updates: false,
                store,
            }
        }
    }

    impl<'a> store::Store for TestStore<'a> {
        fn compare_and_swap_with_ttl(
            &mut self,
            key: &str,
            old: i64,
            new: i64,
            ttl: time::Duration,
        ) -> Result<bool, CellError> {
            if self.fail_updates {
                Ok(false)
            } else {
                self.store.compare_and_swap_with_ttl(key, old, new, ttl)
            }
        }

        fn get_with_time(&self, key: &str) -> Result<(i64, time::Tm), CellError> {
            let tup = self.store.get_with_time(key)?;
            Ok((tup.0, self.clock))
        }

        fn log_debug(&self, message: &str) {
            self.store.log_debug(message)
        }

        fn set_if_not_exists_with_ttl(
            &mut self,
            key: &str,
            value: i64,
            ttl: time::Duration,
        ) -> Result<bool, CellError> {
            if self.fail_updates {
                Ok(false)
            } else {
                self.store.set_if_not_exists_with_ttl(key, value, ttl)
            }
        }
    }
}