1use std::sync::Arc;
7use std::time::Duration;
8
9use redis::AsyncCommands;
10use tracing::instrument;
11
12use crate::{retry_call, ErrorTypes, RedisObjects};
13
14
/// Lua script that atomically tries to reserve one quota slot in a sorted set.
///
/// Calling convention (must match how the script is invoked):
/// * `KEYS[1]` - name of the sorted set holding the active entries
/// * `ARGV[1]` - maximum number of entries allowed in the set
/// * `ARGV[2]` - entry timeout in whole seconds
///
/// The script reads the Redis server time as a microsecond timestamp, drops
/// every entry whose score is older than `now - timeout`, and then, if the set
/// holds fewer than the maximum number of entries, adds a new entry scored
/// (and named) with the current timestamp. It returns true when a slot was
/// reserved and false when the set is already full.
///
/// The set name is read from `KEYS` rather than `ARGV` so that the key touched
/// by the script is declared to Redis, and so that the argument indices line
/// up with an invocation that supplies one key followed by two arguments.
const BEGIN_SCRIPT: &str = r#"
local t = redis.call('time')
local key = tonumber(t[1] .. string.format("%06d", t[2]))

local name = KEYS[1]
local max = tonumber(ARGV[1])
local timeout = tonumber(ARGV[2] .. "000000")

redis.call('zremrangebyscore', name, 0, key - timeout)
if redis.call('zcard', name) < max then
    redis.call('zadd', name, key, key)
    return true
else
    return false
end
"#;
46
/// Tracks, per user, a bounded set of in-flight entries stored in Redis.
///
/// Each user gets a sorted set whose entries are scored by the time they were
/// added; entries older than `timeout` are discarded whenever a new slot is
/// requested.
#[derive(Clone)]
pub struct UserQuotaTracker {
    /// Shared handle to the Redis store; its connection pool is used to run commands.
    store: Arc<RedisObjects>,
    /// Prefix prepended (followed by `-`) to the user name to form the Redis key.
    prefix: String,
    /// Script object built from `BEGIN_SCRIPT`, used to reserve a slot.
    begin: redis::Script,
    /// Age after which an entry is dropped; handed to the script in whole seconds.
    timeout: Duration,
}
55
56impl std::fmt::Debug for UserQuotaTracker {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("UserQuotaTracker").field("store", &self.store).field("prefix", &self.prefix).finish()
59 }
60}
61
62impl UserQuotaTracker {
63 pub (crate) fn new(store: Arc<RedisObjects>, prefix: String) -> Self {
64 Self {
65 store,
66 prefix,
67 begin: redis::Script::new(BEGIN_SCRIPT),
68 timeout: Duration::from_secs(120)
69 }
70 }
71
72 pub fn set_timeout(mut self, timeout: Duration) -> Self {
74 self.timeout = timeout;
75 self
76 }
77
78 fn queue_name(&self, user: &str) -> String {
79 format!("{}-{user}", self.prefix)
80 }
81
82 #[instrument]
84 pub async fn begin(&self, user: &str, max_quota: u32) -> Result<bool, ErrorTypes> {
85 let mut call = self.begin.key(self.queue_name(user));
86 let call = call.arg(max_quota).arg(self.timeout.as_secs());
87 Ok(retry_call!(method, self.store.pool, call, invoke_async)?)
88 }
89
90 #[instrument]
92 pub async fn end(&self, user: &str) -> Result<(), ErrorTypes> {
93 let _: () = retry_call!(self.store.pool, zpopmin, &self.queue_name(user), 1)?;
94 Ok(())
95 }
96}