tocket/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3//! # Tocket
4//!
//! This library provides an implementation of the token bucket algorithm along with several storage implementations.
6//!
7//! ## Available storages:
8//! - [`InMemoryStorage`]
9//! - [`RedisStorage`]
10//! - [`DistributedStorage`]
11//!
12//! You can implement your own [storage] (e.g. Postgres).
13//!
14//! ## Features
15//! - `redis-impl` - redis storage implementation
16//! - `distributed-impl` - distributed storage implementation
17//!
18//! [`InMemoryStorage`]: crate::in_memory::InMemoryStorage
19//! [`RedisStorage`]: crate::in_redis::RedisStorage
20//! [`DistributedStorage`]: crate::distributed::DistributedStorage
21//! [storage]: crate::Storage
22
23pub mod in_memory;
24
25#[cfg(feature = "distributed-impl")]
26#[cfg_attr(docsrs, doc(cfg(feature = "distributed-impl")))]
27pub mod distributed;
28
29#[cfg(feature = "redis-impl")]
30#[cfg_attr(docsrs, doc(cfg(feature = "redis-impl")))]
31pub mod in_redis;
32
33pub use in_memory::*;
34
35#[cfg(feature = "distributed-impl")]
36#[cfg_attr(docsrs, doc(cfg(feature = "distributed-impl")))]
37pub use distributed::*;
38
39#[cfg(feature = "redis-impl")]
40#[cfg_attr(docsrs, doc(cfg(feature = "redis-impl")))]
41pub use in_redis::*;
42
43/// Trait that provides function for tokens acquiring.
44///
45/// Object that implements this trait should load state, execute provided algorithm
46/// and save updated state.
47pub trait Storage {
48    type Error: From<RateLimitExceededError>;
49
50    fn try_acquire(&self, alg: TokenBucketAlgorithm, permits: u32) -> Result<(), Self::Error>;
51}
52
53/// State of token bucket.
54#[derive(Debug, Clone, Eq, PartialEq)]
55pub struct State {
56    pub cap: u32,
57    pub available_tokens: u32,
58    pub last_refill: time::OffsetDateTime,
59    pub refill_tick: time::Duration,
60}
61
/// Rate limiter that implements the token bucket algorithm on top of a [`Storage`].
///
/// `Debug` is derived (and therefore available whenever `S: Debug`) so the limiter can
/// be embedded in other types that derive `Debug`.
#[derive(Debug)]
pub struct TokenBucket<S> {
    /// Backend that loads, updates and saves the bucket state.
    storage: S,
}
66
67impl<S> TokenBucket<S>
68where
69    S: Storage,
70{
71    /// Creates new token bucket rate limiter with provided storage.
72    pub fn new(storage: S) -> Self {
73        Self { storage }
74    }
75
76    /// Tries to acquire N tokens.
77    ///
78    /// # Errors
79    ///
80    /// Will return `Err` if there are not enough tokens or if the storage could not save/load state.
81    pub fn try_acquire(&self, permits: u32) -> Result<(), S::Error> {
82        self.storage
83            .try_acquire(TokenBucketAlgorithm { mode: Mode::N }, permits)
84    }
85
86    /// Tries to acquire 1 token.
87    ///
88    /// # Errors
89    ///
90    /// Will return `Err` if there are not enough tokens or if the storage could not save/load state.
91    pub fn try_acquire_one(&self) -> Result<(), S::Error> {
92        self.try_acquire(1)
93    }
94
95    /// Tries to acquire N or all available tokens if `available < N`.
96    ///
97    /// # Errors
98    ///
99    /// Will return `Err` if the storage could not save/load state.
100    pub fn try_acquire_n_or_all(&self, permits: u32) -> Result<(), S::Error> {
101        self.storage
102            .try_acquire(TokenBucketAlgorithm { mode: Mode::All }, permits)
103    }
104}
105
106/// Struct that implements token bucket algorithm.
107#[derive(Debug)]
108pub struct TokenBucketAlgorithm {
109    mode: Mode,
110}
111
/// How `TokenBucketAlgorithm::try_acquire` treats a request for `permits` tokens.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Mode {
    /// Take exactly `permits` tokens, failing if fewer are available.
    N,
    /// Take `permits` tokens or every available token, whichever is fewer; never fails.
    All,
}
117
118impl TokenBucketAlgorithm {
119    pub fn try_acquire(
120        &self,
121        state: &mut State,
122        permits: u32,
123    ) -> Result<(), RateLimitExceededError> {
124        self.refill_state(state);
125
126        match self.mode {
127            Mode::N => {
128                if state.available_tokens >= permits {
129                    state.available_tokens -= permits;
130                    Ok(())
131                } else {
132                    Err(RateLimitExceededError(()))
133                }
134            }
135            Mode::All => {
136                state.available_tokens -= u32::min(permits, state.available_tokens);
137                Ok(())
138            }
139        }
140    }
141
142    fn refill_state(&self, state: &mut State) {
143        let now = time::OffsetDateTime::now_utc();
144        let since_last_refill = now - state.last_refill;
145
146        if since_last_refill <= state.refill_tick {
147            return;
148        }
149
150        let tokens_since_last_refill = {
151            let mut tokens_count = 0u32;
152            let mut k = since_last_refill;
153            loop {
154                k -= state.refill_tick;
155                if k <= time::Duration::ZERO {
156                    break;
157                }
158                tokens_count += 1;
159            }
160            tokens_count
161        };
162
163        state.available_tokens =
164            u32::min(state.available_tokens + tokens_since_last_refill, state.cap);
165        state.last_refill += state.refill_tick * tokens_since_last_refill;
166    }
167}
168
169#[derive(Debug, thiserror::Error)]
170#[error("rate limit exceeded")]
171pub struct RateLimitExceededError(());