// kora_lib/usage_limit/usage_store.rs
use std::{collections::HashMap, sync::Mutex};

use async_trait::async_trait;
use deadpool_redis::{Connection, Pool};
use redis::AsyncCommands;

use crate::{error::KoraError, sanitize_error};

9#[async_trait]
11pub trait UsageStore: Send + Sync {
12 async fn increment(&self, key: &str) -> Result<u32, KoraError>;
14
15 async fn get(&self, key: &str) -> Result<u32, KoraError>;
17
18 async fn clear(&self) -> Result<(), KoraError>;
20}
21
22pub struct RedisUsageStore {
24 pool: Pool,
25}
26
27impl RedisUsageStore {
28 pub fn new(pool: Pool) -> Self {
29 Self { pool }
30 }
31
32 async fn get_connection(&self) -> Result<Connection, KoraError> {
33 self.pool.get().await.map_err(|e| {
34 KoraError::InternalServerError(sanitize_error!(format!(
35 "Failed to get Redis connection: {}",
36 e
37 )))
38 })
39 }
40}
41
42#[async_trait]
43impl UsageStore for RedisUsageStore {
44 async fn increment(&self, key: &str) -> Result<u32, KoraError> {
45 let mut conn = self.get_connection().await?;
46 let count: u32 = conn.incr(key, 1).await.map_err(|e| {
47 KoraError::InternalServerError(sanitize_error!(format!(
48 "Failed to increment usage for {}: {}",
49 key, e
50 )))
51 })?;
52 Ok(count)
53 }
54
55 async fn get(&self, key: &str) -> Result<u32, KoraError> {
56 let mut conn = self.get_connection().await?;
57 let count: Option<u32> = conn.get(key).await.map_err(|e| {
58 KoraError::InternalServerError(sanitize_error!(format!(
59 "Failed to get usage for {}: {}",
60 key, e
61 )))
62 })?;
63 Ok(count.unwrap_or(0))
64 }
65
66 async fn clear(&self) -> Result<(), KoraError> {
67 let mut conn = self.get_connection().await?;
68 let _: () = conn.flushdb().await.map_err(|e| {
69 KoraError::InternalServerError(sanitize_error!(format!("Failed to clear Redis: {}", e)))
70 })?;
71 Ok(())
72 }
73}
74
/// Process-local [`UsageStore`] that keeps counters in a mutex-guarded map.
pub struct InMemoryUsageStore {
    /// Counter value per key.
    data: Mutex<HashMap<String, u32>>,
}

80impl InMemoryUsageStore {
81 pub fn new() -> Self {
82 Self { data: Mutex::new(HashMap::new()) }
83 }
84}
85
86impl Default for InMemoryUsageStore {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92#[async_trait]
93impl UsageStore for InMemoryUsageStore {
94 async fn increment(&self, key: &str) -> Result<u32, KoraError> {
95 let mut data = self.data.lock().map_err(|e| {
96 KoraError::InternalServerError(sanitize_error!(format!(
97 "Failed to lock usage store: {}",
98 e
99 )))
100 })?;
101 let count = data.entry(key.to_string()).or_insert(0);
102 *count += 1;
103 Ok(*count)
104 }
105
106 async fn get(&self, key: &str) -> Result<u32, KoraError> {
107 let data = self.data.lock().map_err(|e| {
108 KoraError::InternalServerError(sanitize_error!(format!(
109 "Failed to lock usage store: {}",
110 e
111 )))
112 })?;
113 Ok(data.get(key).copied().unwrap_or(0))
114 }
115
116 async fn clear(&self) -> Result<(), KoraError> {
117 let mut data = self.data.lock().map_err(|e| {
118 KoraError::InternalServerError(sanitize_error!(format!(
119 "Failed to lock usage store: {}",
120 e
121 )))
122 })?;
123 data.clear();
124 Ok(())
125 }
126}
127
/// Test-only [`UsageStore`] that can be configured to fail `get` and/or `increment`.
#[cfg(test)]
pub struct ErrorUsageStore {
    /// When true, `get` returns an error instead of 0.
    should_error_get: bool,
    /// When true, `increment` returns an error instead of 1.
    should_error_increment: bool,
}

#[cfg(test)]
impl ErrorUsageStore {
    /// Builds a store whose `get` / `increment` fail according to the given flags.
    pub fn new(should_error_get: bool, should_error_increment: bool) -> Self {
        ErrorUsageStore { should_error_increment, should_error_get }
    }
}

#[cfg(test)]
#[async_trait]
impl UsageStore for ErrorUsageStore {
    /// Fails when `should_error_increment` is set; otherwise always reports a count of 1.
    async fn increment(&self, _key: &str) -> Result<u32, KoraError> {
        if !self.should_error_increment {
            return Ok(1);
        }
        Err(KoraError::InternalServerError("Redis connection failed".to_string()))
    }

    /// Fails when `should_error_get` is set; otherwise always reports a count of 0.
    async fn get(&self, _key: &str) -> Result<u32, KoraError> {
        match self.should_error_get {
            true => Err(KoraError::InternalServerError("Redis connection failed".to_string())),
            false => Ok(0),
        }
    }

    /// Always succeeds without doing anything.
    async fn clear(&self) -> Result<(), KoraError> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_usage_store() {
        let store = InMemoryUsageStore::new();

        // A key that was never touched reads as zero.
        assert_eq!(store.get("wallet1").await.unwrap(), 0);

        // Each increment returns the new count, and `get` observes the same value.
        for expected in 1..=2 {
            assert_eq!(store.increment("wallet1").await.unwrap(), expected);
            assert_eq!(store.get("wallet1").await.unwrap(), expected);
        }

        // Counters for different keys are independent.
        assert_eq!(store.increment("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet1").await.unwrap(), 2);

        // Clearing resets every key back to zero.
        store.clear().await.unwrap();
        for wallet in ["wallet1", "wallet2"] {
            assert_eq!(store.get(wallet).await.unwrap(), 0);
        }
    }
}