1extern crate redis;
2extern crate chrono;
3
4use std::default::Default;
5
6use chrono::{DateTime, Utc};
7use redis::{
8 Client as RedisClient,
9 Script as RedisScript,
10 Commands,
11};
12
13const LUA_SCRIPT: &str = include_str!("limiter.lua");
14const KEY_PREFIX: &str = "limiter";
15const REDIS_HOST: &str = "localhost";
16const REDIS_PORT: u16 = 6379;
17const REDIS_DB: u16 = 0;
18
19fn timestamp_ms(t: DateTime<Utc>) -> i64 {
20 t.timestamp() * 1000 + i64::from(t.timestamp_subsec_millis())
21}
22
23fn now_ms() -> i64 {
24 timestamp_ms(Utc::now())
25}
26
27pub trait Limiter {
28 fn get_token_count<'a>(&self, key: &'a str, interval: u32) -> Option<u32>;
29 fn consume<'a>(&self, args: Vec<(&'a str, u32, u32, u32)>)
30 -> Result<(), RedisConsumeError>;
31 fn consume_one<'a>(&self, key: &'a str, interval: u32, capacity: u32, n: u32)
32 -> Result<(), RedisConsumeError> {
33 self.consume(vec![(key, interval, capacity, n)])
34 }
35}
36
37#[derive(Debug)]
38pub enum RedisConsumeError {
39 Denied {
40 redis_key: String,
41 interval: u32,
42 capacity: u32,
43 current_tokens: u32,
44 last_fill_at: i64,
45 },
46 BadArg(String),
47 Redis(redis::RedisError)
48}
49
50pub struct RedisLimiter {
51 redis_cli: RedisClient,
52 key_prefix: String,
53 script: RedisScript,
54}
55
56#[derive(Default)]
57pub struct RedisLimiterBuilder<'a> {
58 redis_cli: Option<RedisClient>,
59 host: Option<&'a str>,
60 port: Option<u16>,
61 db: Option<u16>,
62 key_prefix: Option<&'a str>,
63 script_str: Option<&'a str>,
64}
65
66impl<'a> RedisLimiterBuilder<'a> {
67 pub fn new() -> Self {
68 RedisLimiterBuilder{
69 redis_cli: None,
70 host: None,
71 port: None,
72 db: None,
73 key_prefix: None,
74 script_str: None,
75 }
76 }
77 pub fn build(self) -> RedisLimiter {
78 let script_str = self.script_str.unwrap_or(LUA_SCRIPT);
79 let key_prefix = self.key_prefix.unwrap_or(KEY_PREFIX);
80 if let Some(redis_cli) = self.redis_cli {
81 RedisLimiter::new(redis_cli, key_prefix, script_str)
82 } else {
83 let url = format!(
84 "redis://{}:{}/{}",
85 self.host.unwrap_or(REDIS_HOST),
86 self.port.unwrap_or(REDIS_PORT),
87 self.db.unwrap_or(REDIS_DB)
88 );
89 let client = RedisClient::open(url.as_str()).unwrap();
90 RedisLimiter::new(client, key_prefix, script_str)
91 }
92 }
93
94 pub fn redis_cli(&mut self, client: RedisClient) -> &mut Self {
95 self.redis_cli = Some(client);
96 self
97 }
98 pub fn host(&mut self, value: &'a str) -> &mut Self {
99 self.host = Some(value);
100 self
101 }
102 pub fn port(&mut self, value: u16) -> &mut Self {
103 self.port = Some(value);
104 self
105 }
106 pub fn db(&mut self, value: u16) -> &mut Self {
107 self.db = Some(value);
108 self
109 }
110 pub fn key_prefix(&mut self, value: &'a str) -> &mut Self {
111 self.key_prefix = Some(value);
112 self
113 }
114 pub fn script_str(&mut self, value: &'a str) -> &mut Self {
115 self.script_str = Some(value);
116 self
117 }
118}
119
120impl RedisLimiter {
121 pub fn new<'a>(
122 redis_cli: RedisClient,
123 key_prefix: &'a str,
124 script_str: &'a str,
125 ) -> Self {
126 let key_prefix = key_prefix.to_owned();
127 let script = RedisScript::new(script_str);
128 RedisLimiter{ redis_cli, key_prefix, script }
129 }
130
131 pub fn get_redis_key<'a>(&self, key: &'a str, interval: u32) -> String {
132 format!("{}:{}:{}", self.key_prefix, key, interval)
133 }
134}
135
136impl Default for RedisLimiter {
137 fn default() -> Self { RedisLimiterBuilder::new().build() }
138}
139
140impl Limiter for RedisLimiter {
141 fn get_token_count<'a>(&self, key: &'a str, interval: u32) -> Option<u32> {
142 self.redis_cli
143 .get_connection()
144 .unwrap()
145 .hget(self.get_redis_key(key, interval), "tokens")
146 .ok()
147 }
148
149 fn consume<'a>(&self, args: Vec<(&'a str, u32, u32, u32)>)
150 -> Result<(), RedisConsumeError> {
151 let now_ms = now_ms();
152 let mut invocation = self.script.prepare_invoke();
153 for (key, interval, capacity, n) in args {
154 if key.len() < 1 || n < 1 || interval < 1 || capacity < 1 {
155 return Err(RedisConsumeError::BadArg(format!(
156 "[BadArg]: key={}, interval={}, capacity={}, n={}",
157 key, interval, capacity, n
158 )));
159 }
160 let redis_key = self.get_redis_key(key, interval);
161 let expire = interval * 2 + 15;
162 let interval_ms = interval * 1000;
163 invocation
164 .key(redis_key)
165 .arg(interval_ms)
166 .arg(capacity)
167 .arg(n)
168 .arg(now_ms)
169 .arg(expire);
170 }
171 let conn = try!{
172 self.redis_cli
173 .get_connection()
174 .map_err(RedisConsumeError::Redis)
175 };
176 match invocation.invoke(&conn) {
177 Ok((_, 0, 0, 0, 0)) => Ok(()),
178 Ok((redis_key, interval_ms, capacity,
179 current_tokens, last_fill_at)) => {
180 let interval = interval_ms / 1000;
181 Err(RedisConsumeError::Denied{
182 redis_key, interval, capacity,
183 current_tokens, last_fill_at
184 })
185 }
186 Err(e) => Err(RedisConsumeError::Redis(e))
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use std::thread;
    use std::time::Duration;

    use super::*;

    // NOTE: these tests need a live Redis at REDIS_HOST:REDIS_PORT/REDIS_DB
    // and sleep in real time, so they behave as integration tests.

    /// Opens a client for the same default Redis instance the limiter uses.
    fn redis_client() -> RedisClient {
        let url = format!("redis://{}:{}/{}", REDIS_HOST, REDIS_PORT, REDIS_DB);
        RedisClient::open(url.as_str()).unwrap()
    }

    /// Calls `consume_one(key, interval, capacity, 1)` `n` times.
    ///
    /// The first `capacity` calls must succeed, each leaving one token fewer.
    /// Every later call must be denied, with the count pinned at zero.
    fn consume_many<'a>(
        limiter: &RedisLimiter,
        key: &'a str,
        interval: u32,
        capacity: u32,
        n: u32,
    ) {
        for attempt in 0..n {
            let (success, count) = if attempt < capacity {
                (true, Some(capacity - attempt - 1))
            } else {
                (false, Some(0))
            };
            assert_eq!(
                limiter.consume_one(key, interval, capacity, 1).is_ok(),
                success,
                "attempt {} on {}",
                attempt,
                key
            );
            assert_eq!(limiter.get_token_count(key, interval), count, "attempt {} on {}", attempt, key);
        }
    }

    /// Deletes the bucket keys for the given `(key, interval)` pairs.
    fn del_keys<'a>(limiter: &RedisLimiter, args: Vec<(&'a str, u32)>) {
        let client = redis_client();
        for (key, interval) in args {
            let _: () = client
                .del(limiter.get_redis_key(key, interval))
                .unwrap();
        }
    }

    #[test]
    fn test_basic() {
        let limiter = RedisLimiter::default();
        let key = "test_basic";
        let interval = 10;
        let capacity = 6;

        assert_eq!(limiter.get_token_count(key, interval), None);
        consume_many(&limiter, key, interval, capacity, 12);

        del_keys(&limiter, vec![(key, interval)]);
    }

    #[test]
    fn test_refill() {
        let limiter = RedisLimiter::default();
        let key = "test_refill";
        let interval = 1;
        let capacity = 5;

        assert_eq!(limiter.get_token_count(key, interval), None);
        consume_many(&limiter, key, interval, capacity, 6);
        assert!(limiter.consume_one(key, interval, capacity, 1).is_err());
        assert_eq!(limiter.get_token_count(key, interval), Some(0));

        thread::sleep(Duration::from_millis((interval * 1000 + 2) as u64));
        assert!(limiter.consume_one(key, interval, capacity, 1).is_ok());
        assert_eq!(limiter.get_token_count(key, interval), Some(capacity - 1));

        del_keys(&limiter, vec![(key, interval)]);
    }

    #[test]
    fn test_multiple() {
        let limiter = RedisLimiter::default();
        let key = "test_multiple";

        let (key_1, interval_1, capacity_1, n_1) = (format!("{}-1", key), 2, 3, 1);
        let (key_2, interval_2, capacity_2, n_2) = (format!("{}-2", key), 4, 4, 1);

        // Drain key_1 so any batch that includes it is denied.
        for _ in 0..capacity_1 {
            assert!(limiter.consume_one(key_1.as_str(), interval_1, capacity_1, n_1).is_ok());
        }

        let cases = vec![
            // [key_1, key_2]: denied on key_1; key_2 is never created.
            (0,
             vec![
                 (key_1.as_str(), interval_1, capacity_1, n_1),
                 (key_2.as_str(), interval_2, capacity_2, n_2),
             ],
             false,
             Some(0), None),
            // [key_2, key_1]: still denied on key_1. The test expects key_2's
            // bucket to exist at full capacity without being debited.
            (0,
             vec![
                 (key_2.as_str(), interval_2, capacity_2, n_2),
                 (key_1.as_str(), interval_1, capacity_1, n_1),
             ],
             false,
             Some(0), Some(capacity_2)),
            // After key_1's interval elapses, the batch succeeds and debits
            // each bucket once.
            ((interval_1 * 1000 + 2) as u64,
             vec![
                 (key_2.as_str(), interval_2, capacity_2, n_2),
                 (key_1.as_str(), interval_1, capacity_1, n_1),
             ],
             true,
             Some(capacity_1 - 1), Some(capacity_2 - 1)),
        ];

        for (sleep_ms, args, should_ok, token_count_1, token_count_2) in cases {
            if sleep_ms > 0 {
                thread::sleep(Duration::from_millis(sleep_ms));
            }
            let rv = limiter.consume(args);
            if should_ok {
                assert!(rv.is_ok(), "expected batch to be accepted, got {:?}", rv);
            } else {
                match rv {
                    Err(RedisConsumeError::Denied {
                        redis_key, interval, capacity, current_tokens, ..
                    }) => {
                        assert_eq!(redis_key, limiter.get_redis_key(key_1.as_str(), interval_1));
                        assert_eq!(interval, interval_1);
                        assert_eq!(capacity, capacity_1);
                        assert_eq!(current_tokens, 0);
                    }
                    other => panic!("expected Denied on key_1, got {:?}", other),
                }
            }
            assert_eq!(limiter.get_token_count(key_1.as_str(), interval_1), token_count_1);
            assert_eq!(limiter.get_token_count(key_2.as_str(), interval_2), token_count_2);
        }

        del_keys(&limiter, vec![(key_1.as_str(), interval_1), (key_2.as_str(), interval_2)]);
    }
}