kora_lib/usage_limit/
usage_store.rs1use std::{collections::HashMap, sync::Mutex};
2
3use async_trait::async_trait;
4use deadpool_redis::{Connection, Pool};
5use redis::AsyncCommands;
6
7use crate::{error::KoraError, sanitize_error};
8
9#[async_trait]
11pub trait UsageStore: Send + Sync {
12 async fn increment(&self, key: &str) -> Result<u32, KoraError>;
14
15 async fn increment_with_expiry(&self, key: &str, expires_at: u64) -> Result<u32, KoraError>;
17
18 async fn get(&self, key: &str) -> Result<u32, KoraError>;
20
21 async fn clear(&self) -> Result<(), KoraError>;
23}
24
25pub struct RedisUsageStore {
27 pool: Pool,
28}
29
30impl RedisUsageStore {
31 pub fn new(pool: Pool) -> Self {
32 Self { pool }
33 }
34
35 async fn get_connection(&self) -> Result<Connection, KoraError> {
36 self.pool.get().await.map_err(|e| {
37 KoraError::InternalServerError(sanitize_error!(format!(
38 "Failed to get Redis connection: {}",
39 e
40 )))
41 })
42 }
43}
44
45#[async_trait]
46impl UsageStore for RedisUsageStore {
47 async fn increment(&self, key: &str) -> Result<u32, KoraError> {
48 let mut conn = self.get_connection().await?;
49 let count: u32 = conn.incr(key, 1).await.map_err(|e| {
50 KoraError::InternalServerError(sanitize_error!(format!(
51 "Failed to increment usage for {}: {}",
52 key, e
53 )))
54 })?;
55 Ok(count)
56 }
57
58 async fn increment_with_expiry(&self, key: &str, expires_at: u64) -> Result<u32, KoraError> {
59 let mut conn = self.get_connection().await?;
60
61 let (count,): (u32,) = redis::pipe()
64 .atomic()
65 .incr(key, 1)
66 .cmd("EXPIREAT")
67 .arg(key)
68 .arg(expires_at as i64)
69 .ignore()
70 .query_async(&mut conn)
71 .await
72 .map_err(|e| {
73 KoraError::InternalServerError(sanitize_error!(format!(
74 "Failed to increment with expiry for {}: {}",
75 key, e
76 )))
77 })?;
78
79 Ok(count)
80 }
81
82 async fn get(&self, key: &str) -> Result<u32, KoraError> {
83 let mut conn = self.get_connection().await?;
84 let count: Option<u32> = conn.get(key).await.map_err(|e| {
85 KoraError::InternalServerError(sanitize_error!(format!(
86 "Failed to get usage for {}: {}",
87 key, e
88 )))
89 })?;
90 Ok(count.unwrap_or(0))
91 }
92
93 async fn clear(&self) -> Result<(), KoraError> {
94 let mut conn = self.get_connection().await?;
95 let _: () = conn.flushdb().await.map_err(|e| {
96 KoraError::InternalServerError(sanitize_error!(format!("Failed to clear Redis: {}", e)))
97 })?;
98 Ok(())
99 }
100}
101
/// One counter held by `InMemoryUsageStore`.
///
/// `expiry` is an optional absolute Unix timestamp in seconds. Once the current time
/// reaches it, reads report `0` and the next expiring increment starts over from `1`.
struct UsageEntry {
    expiry: Option<u64>,
    count: u32,
}
107
108pub struct InMemoryUsageStore {
110 data: Mutex<HashMap<String, UsageEntry>>,
111}
112
113impl InMemoryUsageStore {
114 pub fn new() -> Self {
115 Self { data: Mutex::new(HashMap::new()) }
116 }
117
118 fn current_timestamp() -> u64 {
119 std::time::SystemTime::now()
120 .duration_since(std::time::UNIX_EPOCH)
121 .map(|d| d.as_secs())
122 .unwrap_or(0)
123 }
124}
125
126impl Default for InMemoryUsageStore {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132#[async_trait]
133impl UsageStore for InMemoryUsageStore {
134 async fn increment(&self, key: &str) -> Result<u32, KoraError> {
135 let mut data = self.data.lock().map_err(|e| {
136 KoraError::InternalServerError(sanitize_error!(format!(
137 "Failed to lock usage store: {}",
138 e
139 )))
140 })?;
141 let entry = data.entry(key.to_string()).or_insert(UsageEntry { count: 0, expiry: None });
142 entry.count += 1;
143 Ok(entry.count)
144 }
145
146 async fn increment_with_expiry(&self, key: &str, expires_at: u64) -> Result<u32, KoraError> {
147 let mut data = self.data.lock().map_err(|e| {
148 KoraError::InternalServerError(sanitize_error!(format!(
149 "Failed to lock usage store: {}",
150 e
151 )))
152 })?;
153
154 let now = Self::current_timestamp();
155 let entry = data.entry(key.to_string()).or_insert(UsageEntry { count: 0, expiry: None });
156
157 if let Some(expiry) = entry.expiry {
159 if now >= expiry {
160 entry.count = 0;
161 }
162 }
163
164 entry.count += 1;
165 entry.expiry = Some(expires_at);
167
168 Ok(entry.count)
169 }
170
171 async fn get(&self, key: &str) -> Result<u32, KoraError> {
172 let data = self.data.lock().map_err(|e| {
173 KoraError::InternalServerError(sanitize_error!(format!(
174 "Failed to lock usage store: {}",
175 e
176 )))
177 })?;
178
179 if let Some(entry) = data.get(key) {
180 if let Some(expiry) = entry.expiry {
182 if Self::current_timestamp() >= expiry {
183 return Ok(0);
184 }
185 }
186 Ok(entry.count)
187 } else {
188 Ok(0)
189 }
190 }
191
192 async fn clear(&self) -> Result<(), KoraError> {
193 let mut data = self.data.lock().map_err(|e| {
194 KoraError::InternalServerError(sanitize_error!(format!(
195 "Failed to lock usage store: {}",
196 e
197 )))
198 })?;
199 data.clear();
200 Ok(())
201 }
202}
203
/// Test double whose reads and increments can be forced to fail with a fixed error.
#[cfg(test)]
pub struct ErrorUsageStore {
    should_error_get: bool,
    should_error_increment: bool,
}

#[cfg(test)]
impl ErrorUsageStore {
    /// `should_error_get` makes `get` fail. `should_error_increment` makes both
    /// increment methods fail.
    pub fn new(should_error_get: bool, should_error_increment: bool) -> Self {
        ErrorUsageStore { should_error_increment, should_error_get }
    }

    /// Returns the canned connection error when `fail` is set, otherwise `Ok(value)`.
    fn respond<T>(fail: bool, value: T) -> Result<T, KoraError> {
        if !fail {
            return Ok(value);
        }
        Err(KoraError::InternalServerError("Redis connection failed".to_string()))
    }
}

#[cfg(test)]
#[async_trait]
impl UsageStore for ErrorUsageStore {
    async fn increment(&self, _key: &str) -> Result<u32, KoraError> {
        Self::respond(self.should_error_increment, 1)
    }

    async fn increment_with_expiry(&self, _key: &str, _expires_at: u64) -> Result<u32, KoraError> {
        Self::respond(self.should_error_increment, 1)
    }

    async fn get(&self, _key: &str) -> Result<u32, KoraError> {
        Self::respond(self.should_error_get, 0)
    }

    /// Always succeeds, whatever the failure flags say.
    async fn clear(&self) -> Result<(), KoraError> {
        Ok(())
    }
}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_usage_store() {
        let store = InMemoryUsageStore::new();

        // A key that was never incremented reads as zero.
        assert_eq!(store.get("wallet1").await.unwrap(), 0);

        // Each increment returns the new count, and `get` observes it.
        for expected in 1..=2 {
            assert_eq!(store.increment("wallet1").await.unwrap(), expected);
            assert_eq!(store.get("wallet1").await.unwrap(), expected);
        }

        // Keys are counted independently.
        assert_eq!(store.increment("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet1").await.unwrap(), 2);

        // `clear` drops every counter.
        store.clear().await.unwrap();
        for key in ["wallet1", "wallet2"] {
            assert_eq!(store.get(key).await.unwrap(), 0);
        }
    }
}