1use std::sync::Arc;
7use std::time::Duration;
8
9use redis::AsyncCommands;
10
11use crate::{retry_call, ErrorTypes, RedisObjects};
12
13
/// Lua script that atomically tries to reserve one quota slot for a user.
///
/// * `KEYS[1]` — name of the user's sorted set.
/// * `ARGV[1]` — maximum number of live entries allowed in that set.
/// * `ARGV[2]` — entry timeout in whole seconds.
///
/// Each entry's score and member is the Redis server time in microseconds
/// (seconds from `TIME` followed by the zero-padded microsecond part). The script
/// first removes entries whose score is at least `ARGV[2]` seconds older than now.
/// If fewer than `ARGV[1]` entries remain, it adds a new entry and returns `true`.
/// Otherwise it returns `false` without adding anything.
///
/// The set name is read from `KEYS[1]` because callers pass it with
/// `Script::key`, which leaves only two `ARGV` values. Reading the name from
/// `ARGV[1]` shifted every argument by one: `ARGV[3]` was nil, and the timeout
/// concatenation raised a Lua error on every call. Passing key names through
/// `KEYS` is also what Redis expects of scripts.
///
/// NOTE(review): two invocations that observe the same microsecond produce the
/// same member. The second `zadd` then updates the existing entry instead of
/// adding one, while still returning `true`. Confirm whether this matters at the
/// expected call rates.
const BEGIN_SCRIPT: &str = r#"
local t = redis.call('time')
local key = tonumber(t[1] .. string.format("%06d", t[2]))

local name = KEYS[1]
local max = tonumber(ARGV[1])
local timeout = tonumber(ARGV[2] .. "000000")

redis.call('zremrangebyscore', name, 0, key - timeout)
if redis.call('zcard', name) < max then
    redis.call('zadd', name, key, key)
    return true
else
    return false
end
"#;
45
46#[derive(Clone)]
48pub struct UserQuotaTracker {
49 store: Arc<RedisObjects>,
50 prefix: String,
51 begin: redis::Script,
52 timeout: Duration,
53}
54
55impl UserQuotaTracker {
56 pub (crate) fn new(store: Arc<RedisObjects>, prefix: String) -> Self {
57 Self {
58 store,
59 prefix,
60 begin: redis::Script::new(BEGIN_SCRIPT),
61 timeout: Duration::from_secs(120)
62 }
63 }
64
65 pub fn set_timeout(mut self, timeout: Duration) -> Self {
67 self.timeout = timeout;
68 self
69 }
70
71 fn queue_name(&self, user: &str) -> String {
72 format!("{}-{user}", self.prefix)
73 }
74
75 pub async fn begin(&self, user: &str, max_quota: u32) -> Result<bool, ErrorTypes> {
77 let mut call = self.begin.key(self.queue_name(user));
78 let call = call.arg(max_quota).arg(self.timeout.as_secs());
79 Ok(retry_call!(method, self.store.pool, call, invoke_async)?)
80 }
81
82 pub async fn end(&self, user: &str) -> Result<(), ErrorTypes> {
84 let _: () = retry_call!(self.store.pool, zpopmin, &self.queue_name(user), 1)?;
85 Ok(())
86 }
87}