// bucket_limiter/lib.rs

1extern crate redis;
2extern crate chrono;
3
4use std::default::Default;
5
6use chrono::{DateTime, Utc};
7use redis::{
8    Client as RedisClient,
9    Script as RedisScript,
10    Commands,
11};
12
13const LUA_SCRIPT: &str = include_str!("limiter.lua");
14const KEY_PREFIX: &str = "limiter";
15const REDIS_HOST: &str = "localhost";
16const REDIS_PORT: u16 = 6379;
17const REDIS_DB: u16 = 0;
18
19fn timestamp_ms(t: DateTime<Utc>) -> i64 {
20    t.timestamp() * 1000 + i64::from(t.timestamp_subsec_millis())
21}
22
23fn now_ms() -> i64 {
24    timestamp_ms(Utc::now())
25}
26
27pub trait Limiter {
28    fn get_token_count<'a>(&self, key: &'a str, interval: u32) -> Option<u32>;
29    fn consume<'a>(&self, args: Vec<(&'a str, u32, u32, u32)>)
30                   -> Result<(), RedisConsumeError>;
31    fn consume_one<'a>(&self, key: &'a str, interval: u32, capacity: u32, n: u32)
32                       -> Result<(), RedisConsumeError> {
33        self.consume(vec![(key, interval, capacity, n)])
34    }
35}
36
37#[derive(Debug)]
38pub enum RedisConsumeError {
39    Denied {
40        redis_key: String,
41        interval: u32,
42        capacity: u32,
43        current_tokens: u32,
44        last_fill_at: i64,
45    },
46    BadArg(String),
47    Redis(redis::RedisError)
48}
49
50pub struct RedisLimiter {
51    redis_cli: RedisClient,
52    key_prefix: String,
53    script: RedisScript,
54}
55
56#[derive(Default)]
57pub struct RedisLimiterBuilder<'a> {
58    redis_cli: Option<RedisClient>,
59    host: Option<&'a str>,
60    port: Option<u16>,
61    db: Option<u16>,
62    key_prefix: Option<&'a str>,
63    script_str: Option<&'a str>,
64}
65
66impl<'a> RedisLimiterBuilder<'a> {
67    pub fn new() -> Self {
68        RedisLimiterBuilder{
69            redis_cli: None,
70            host: None,
71            port: None,
72            db: None,
73            key_prefix: None,
74            script_str: None,
75        }
76    }
77    pub fn build(self) -> RedisLimiter {
78        let script_str = self.script_str.unwrap_or(LUA_SCRIPT);
79        let key_prefix = self.key_prefix.unwrap_or(KEY_PREFIX);
80        if let Some(redis_cli) = self.redis_cli {
81            RedisLimiter::new(redis_cli, key_prefix, script_str)
82        } else {
83            let url = format!(
84                "redis://{}:{}/{}",
85                self.host.unwrap_or(REDIS_HOST),
86                self.port.unwrap_or(REDIS_PORT),
87                self.db.unwrap_or(REDIS_DB)
88            );
89            let client = RedisClient::open(url.as_str()).unwrap();
90            RedisLimiter::new(client, key_prefix, script_str)
91        }
92    }
93
94    pub fn redis_cli(&mut self, client: RedisClient) -> &mut Self {
95        self.redis_cli = Some(client);
96        self
97    }
98    pub fn host(&mut self, value: &'a str) -> &mut Self {
99        self.host = Some(value);
100        self
101    }
102    pub fn port(&mut self, value: u16) -> &mut Self {
103        self.port = Some(value);
104        self
105    }
106    pub fn db(&mut self, value: u16) -> &mut Self {
107        self.db = Some(value);
108        self
109    }
110    pub fn key_prefix(&mut self, value: &'a str) -> &mut Self {
111        self.key_prefix = Some(value);
112        self
113    }
114    pub fn script_str(&mut self, value: &'a str) -> &mut Self {
115        self.script_str = Some(value);
116        self
117    }
118}
119
120impl RedisLimiter {
121    pub fn new<'a>(
122        redis_cli: RedisClient,
123        key_prefix: &'a str,
124        script_str: &'a str,
125    ) -> Self {
126        let key_prefix = key_prefix.to_owned();
127        let script = RedisScript::new(script_str);
128        RedisLimiter{ redis_cli, key_prefix, script }
129    }
130
131    pub fn get_redis_key<'a>(&self, key: &'a str, interval: u32) -> String {
132        format!("{}:{}:{}", self.key_prefix, key, interval)
133    }
134}
135
136impl Default for RedisLimiter {
137    fn default() -> Self { RedisLimiterBuilder::new().build() }
138}
139
140impl Limiter for RedisLimiter {
141    fn get_token_count<'a>(&self, key: &'a str, interval: u32) -> Option<u32> {
142        self.redis_cli
143            .get_connection()
144            .unwrap()
145            .hget(self.get_redis_key(key, interval), "tokens")
146            .ok()
147    }
148
149    fn consume<'a>(&self, args: Vec<(&'a str, u32, u32, u32)>)
150                   -> Result<(), RedisConsumeError> {
151        let now_ms = now_ms();
152        let mut invocation = self.script.prepare_invoke();
153        for (key, interval, capacity, n) in args {
154            if key.len() < 1 || n < 1 || interval < 1 || capacity < 1 {
155                return Err(RedisConsumeError::BadArg(format!(
156                    "[BadArg]: key={}, interval={}, capacity={}, n={}",
157                    key, interval, capacity, n
158                )));
159            }
160            let redis_key = self.get_redis_key(key, interval);
161            let expire = interval * 2 + 15;
162            let interval_ms = interval * 1000;
163            invocation
164                .key(redis_key)
165                .arg(interval_ms)
166                .arg(capacity)
167                .arg(n)
168                .arg(now_ms)
169                .arg(expire);
170        }
171        let conn = try!{
172            self.redis_cli
173                .get_connection()
174                .map_err(RedisConsumeError::Redis)
175        };
176        match invocation.invoke(&conn) {
177            Ok((_, 0, 0, 0, 0)) => Ok(()),
178            Ok((redis_key, interval_ms, capacity,
179                current_tokens, last_fill_at)) => {
180                let interval = interval_ms / 1000;
181                Err(RedisConsumeError::Denied{
182                    redis_key, interval, capacity,
183                    current_tokens, last_fill_at
184                })
185            }
186            Err(e) => Err(RedisConsumeError::Redis(e))
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    // Integration tests: each one talks to the Redis server at REDIS_HOST:REDIS_PORT
    // (database REDIS_DB) through `RedisLimiter::default()`.
    use std::time::Duration;
    use std::thread;
    use super::*;

    /// Opens a client for the same server the default limiter connects to.
    fn redis_client() -> RedisClient {
        let url = format!("redis://{}:{}/{}", REDIS_HOST, REDIS_PORT, REDIS_DB);
        RedisClient::open(url.as_str()).unwrap()
    }

    /// Takes one token `n` times from a single bucket. After each call it asserts that the
    /// first `capacity` calls succeed and later ones are denied, and that the stored token
    /// count follows.
    fn consume_many<'a>(
        limiter: &RedisLimiter,
        key: &'a str, interval: u32, capacity: u32, n: u32) {
        for i in 0..n {
            let (success, count) = if i >= capacity {
                (false, Some(0))
            } else {
                (true, Some(capacity - i - 1))
            };
            assert_eq!(limiter.consume_one(key, interval, capacity, 1).is_ok(), success);
            assert_eq!(limiter.get_token_count(key, interval), count);
        }
    }

    /// Deletes the Redis keys backing the given `(key, interval)` buckets.
    fn del_keys<'a>(limiter: &RedisLimiter, args: Vec<(&'a str, u32)>) {
        let client = redis_client();
        for (key, interval) in args {
            let _: () = client
                .del(limiter.get_redis_key(key, interval))
                .unwrap();
        }
    }

    #[test]
    fn test_basic() {
        let limiter = RedisLimiter::default();
        let key = "test_basic";
        let interval = 10;
        let capacity = 6;
        // Start from a clean bucket even if an earlier failed run left the key behind.
        del_keys(&limiter, vec![(key, interval)]);

        assert_eq!(limiter.get_token_count(key, interval), None);
        consume_many(&limiter, key, interval, capacity, 12);

        del_keys(&limiter, vec![(key, interval)]);
    }

    #[test]
    fn test_refill() {
        let limiter = RedisLimiter::default();
        let key = "test_refill";
        let interval = 1;
        let capacity = 5;
        del_keys(&limiter, vec![(key, interval)]);

        assert_eq!(limiter.get_token_count(key, interval), None);
        consume_many(&limiter, key, interval, capacity, 6);
        assert!(limiter.consume_one(key, interval, capacity, 1).is_err());
        assert_eq!(limiter.get_token_count(key, interval), Some(0));

        // Wait just past one interval so the bucket is refilled before the next request.
        thread::sleep(Duration::from_millis((interval * 1000 + 2) as u64));
        assert!(limiter.consume_one(key, interval, capacity, 1).is_ok());
        assert_eq!(limiter.get_token_count(key, interval), Some(capacity - 1));

        del_keys(&limiter, vec![(key, interval)]);
    }

    #[test]
    fn test_multiple() {
        let limiter = RedisLimiter::default();
        let key = "test_multiple";

        let (key_1, interval_1, capacity_1, n_1) = (format!("{}-1", key), 2, 3, 1);
        let (key_2, interval_2, capacity_2, n_2) = (format!("{}-2", key), 4, 4, 1);
        del_keys(&limiter, vec![(key_1.as_str(), interval_1), (key_2.as_str(), interval_2)]);

        // [Step.prepare]: Consume all tokens in key_1
        for _ in 0..capacity_1 {
            assert!(limiter.consume_one(key_1.as_str(), interval_1, capacity_1, n_1).is_ok());
        }
        for (sleep_ms, args, should_ok, token_count_1, token_count_2) in vec![
            // [Step.1]: key_1 is fully consumed, so the batch is denied; key_2 is not touched yet.
            (0,
             vec![
                 (key_1.as_str(), interval_1, capacity_1, n_1),
                 (key_2.as_str(), interval_2, capacity_2, n_2),
             ],
             false,
             Some(0), None),
            // [Step.2]: key_2 is listed first, so it is created with token_count=${capacity_2};
            //           the batch is still denied because key_1 is empty.
            (0,
             vec![
                 (key_2.as_str(), interval_2, capacity_2, n_2),
                 (key_1.as_str(), interval_1, capacity_1, n_1),
             ],
             false,
             Some(0), Some(capacity_2)),
            // [Step.3]: Sleep for longer than interval_1, then consume.
            //           key_1's token_count becomes (capacity_1 - 1) and key_2's becomes (capacity_2 - 1).
            ((interval_1 * 1000 + 2) as u64,
             vec![
                 (key_2.as_str(), interval_2, capacity_2, n_2),
                 (key_1.as_str(), interval_1, capacity_1, n_1),
             ],
             true,
             Some(capacity_1 - 1), Some(capacity_2 - 1)),
        ] {
            if sleep_ms > 0 {
                thread::sleep(Duration::from_millis(sleep_ms));
            }
            let rv = limiter.consume(args);
            if should_ok {
                assert!(rv.is_ok(), "expected the batch to be accepted, got {:?}", rv);
            } else {
                match rv {
                    Err(RedisConsumeError::Denied {
                        redis_key, interval, capacity, current_tokens, ..
                    }) => {
                        // key_1 is the empty bucket, so it must be the one reported.
                        assert_eq!(redis_key, limiter.get_redis_key(key_1.as_str(), interval_1));
                        assert_eq!(interval, interval_1);
                        assert_eq!(capacity, capacity_1);
                        assert_eq!(current_tokens, 0);
                    }
                    other => panic!("Invalid RedisConsumeError: {:?}", other),
                }
            }
            assert_eq!(limiter.get_token_count(key_1.as_str(), interval_1), token_count_1);
            assert_eq!(limiter.get_token_count(key_2.as_str(), interval_2), token_count_2);
        }

        del_keys(&limiter, vec![(key_1.as_str(), interval_1), (key_2.as_str(), interval_2)]);
    }
}