// kora_lib/usage_limit/usage_store.rs

1use std::{collections::HashMap, sync::Mutex};
2
3use async_trait::async_trait;
4use deadpool_redis::{Connection, Pool};
5use redis::AsyncCommands;
6
7use crate::{error::KoraError, sanitize_error};
8
9/// Trait for storing and retrieving usage counts
10#[async_trait]
11pub trait UsageStore: Send + Sync {
12    /// Increment usage count for a key and return the new value
13    async fn increment(&self, key: &str) -> Result<u32, KoraError>;
14
15    /// Increment usage count with absolute expiration (key expires at unix timestamp)
16    async fn increment_with_expiry(&self, key: &str, expires_at: u64) -> Result<u32, KoraError>;
17
18    /// Get current usage count for a key (returns 0 if not found)
19    async fn get(&self, key: &str) -> Result<u32, KoraError>;
20
21    /// Clear all usage data (mainly for testing)
22    async fn clear(&self) -> Result<(), KoraError>;
23}
24
25/// Redis-based implementation for production
26pub struct RedisUsageStore {
27    pool: Pool,
28}
29
30impl RedisUsageStore {
31    pub fn new(pool: Pool) -> Self {
32        Self { pool }
33    }
34
35    async fn get_connection(&self) -> Result<Connection, KoraError> {
36        self.pool.get().await.map_err(|e| {
37            KoraError::InternalServerError(sanitize_error!(format!(
38                "Failed to get Redis connection: {}",
39                e
40            )))
41        })
42    }
43}
44
45#[async_trait]
46impl UsageStore for RedisUsageStore {
47    async fn increment(&self, key: &str) -> Result<u32, KoraError> {
48        let mut conn = self.get_connection().await?;
49        let count: u32 = conn.incr(key, 1).await.map_err(|e| {
50            KoraError::InternalServerError(sanitize_error!(format!(
51                "Failed to increment usage for {}: {}",
52                key, e
53            )))
54        })?;
55        Ok(count)
56    }
57
58    async fn increment_with_expiry(&self, key: &str, expires_at: u64) -> Result<u32, KoraError> {
59        let mut conn = self.get_connection().await?;
60
61        // Use Redis pipeline for atomic INCR + EXPIREAT
62        // EXPIREAT sets absolute expiration timestamp, so repeated calls are idempotent
63        let (count,): (u32,) = redis::pipe()
64            .atomic()
65            .incr(key, 1)
66            .cmd("EXPIREAT")
67            .arg(key)
68            .arg(expires_at as i64)
69            .ignore()
70            .query_async(&mut conn)
71            .await
72            .map_err(|e| {
73                KoraError::InternalServerError(sanitize_error!(format!(
74                    "Failed to increment with expiry for {}: {}",
75                    key, e
76                )))
77            })?;
78
79        Ok(count)
80    }
81
82    async fn get(&self, key: &str) -> Result<u32, KoraError> {
83        let mut conn = self.get_connection().await?;
84        let count: Option<u32> = conn.get(key).await.map_err(|e| {
85            KoraError::InternalServerError(sanitize_error!(format!(
86                "Failed to get usage for {}: {}",
87                key, e
88            )))
89        })?;
90        Ok(count.unwrap_or(0))
91    }
92
93    async fn clear(&self) -> Result<(), KoraError> {
94        let mut conn = self.get_connection().await?;
95        let _: () = conn.flushdb().await.map_err(|e| {
96            KoraError::InternalServerError(sanitize_error!(format!("Failed to clear Redis: {}", e)))
97        })?;
98        Ok(())
99    }
100}
101
/// A single counter held by [`InMemoryUsageStore`].
struct UsageEntry {
    /// Number of increments recorded for the key.
    count: u32,
    /// Unix timestamp (seconds) at or after which the entry counts as expired;
    /// `None` means the entry never expires.
    expiry: Option<u64>,
}

/// [`UsageStore`] kept entirely in process memory behind a mutex; intended for tests.
pub struct InMemoryUsageStore {
    data: Mutex<HashMap<String, UsageEntry>>,
}

impl InMemoryUsageStore {
    /// Creates a store with no recorded usage.
    pub fn new() -> Self {
        let empty: HashMap<String, UsageEntry> = HashMap::new();
        InMemoryUsageStore { data: Mutex::new(empty) }
    }

    /// Returns the current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// Falls back to `0` if the system clock reports a time before the epoch.
    fn current_timestamp() -> u64 {
        use std::time::{SystemTime, UNIX_EPOCH};

        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs(),
            Err(_) => 0,
        }
    }
}

impl Default for InMemoryUsageStore {
    /// Equivalent to [`InMemoryUsageStore::new`].
    fn default() -> Self {
        InMemoryUsageStore::new()
    }
}
131
132#[async_trait]
133impl UsageStore for InMemoryUsageStore {
134    async fn increment(&self, key: &str) -> Result<u32, KoraError> {
135        let mut data = self.data.lock().map_err(|e| {
136            KoraError::InternalServerError(sanitize_error!(format!(
137                "Failed to lock usage store: {}",
138                e
139            )))
140        })?;
141        let entry = data.entry(key.to_string()).or_insert(UsageEntry { count: 0, expiry: None });
142        entry.count += 1;
143        Ok(entry.count)
144    }
145
146    async fn increment_with_expiry(&self, key: &str, expires_at: u64) -> Result<u32, KoraError> {
147        let mut data = self.data.lock().map_err(|e| {
148            KoraError::InternalServerError(sanitize_error!(format!(
149                "Failed to lock usage store: {}",
150                e
151            )))
152        })?;
153
154        let now = Self::current_timestamp();
155        let entry = data.entry(key.to_string()).or_insert(UsageEntry { count: 0, expiry: None });
156
157        // Check if expired, reset if so
158        if let Some(expiry) = entry.expiry {
159            if now >= expiry {
160                entry.count = 0;
161            }
162        }
163
164        entry.count += 1;
165        // Always set to the same absolute expiry (idempotent like EXPIREAT)
166        entry.expiry = Some(expires_at);
167
168        Ok(entry.count)
169    }
170
171    async fn get(&self, key: &str) -> Result<u32, KoraError> {
172        let data = self.data.lock().map_err(|e| {
173            KoraError::InternalServerError(sanitize_error!(format!(
174                "Failed to lock usage store: {}",
175                e
176            )))
177        })?;
178
179        if let Some(entry) = data.get(key) {
180            // Check if expired
181            if let Some(expiry) = entry.expiry {
182                if Self::current_timestamp() >= expiry {
183                    return Ok(0);
184                }
185            }
186            Ok(entry.count)
187        } else {
188            Ok(0)
189        }
190    }
191
192    async fn clear(&self) -> Result<(), KoraError> {
193        let mut data = self.data.lock().map_err(|e| {
194            KoraError::InternalServerError(sanitize_error!(format!(
195                "Failed to lock usage store: {}",
196                e
197            )))
198        })?;
199        data.clear();
200        Ok(())
201    }
202}
203
/// Test double that fails on demand, used to exercise error handling around
/// Redis failures.
#[cfg(test)]
pub struct ErrorUsageStore {
    /// When `true`, `get` returns an error.
    should_error_get: bool,
    /// When `true`, both `increment` and `increment_with_expiry` return an error.
    should_error_increment: bool,
}

#[cfg(test)]
impl ErrorUsageStore {
    /// Builds a double that fails `get` and/or the increment operations as requested.
    pub fn new(should_error_get: bool, should_error_increment: bool) -> Self {
        ErrorUsageStore { should_error_increment, should_error_get }
    }
}
217
/// Builds the error that every failing `ErrorUsageStore` operation returns.
#[cfg(test)]
fn simulated_redis_failure() -> KoraError {
    KoraError::InternalServerError("Redis connection failed".to_string())
}

#[cfg(test)]
#[async_trait]
impl UsageStore for ErrorUsageStore {
    /// Fails when `should_error_increment` is set; otherwise always reports a count of 1.
    async fn increment(&self, _key: &str) -> Result<u32, KoraError> {
        if !self.should_error_increment {
            return Ok(1);
        }
        Err(simulated_redis_failure())
    }

    /// Same failure switch as `increment`; the expiry argument is ignored.
    async fn increment_with_expiry(&self, _key: &str, _expires_at: u64) -> Result<u32, KoraError> {
        if !self.should_error_increment {
            return Ok(1);
        }
        Err(simulated_redis_failure())
    }

    /// Fails when `should_error_get` is set; otherwise always reports a count of 0.
    async fn get(&self, _key: &str) -> Result<u32, KoraError> {
        if !self.should_error_get {
            return Ok(0);
        }
        Err(simulated_redis_failure())
    }

    /// Never fails; there is no state to clear.
    async fn clear(&self) -> Result<(), KoraError> {
        Ok(())
    }
}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_usage_store() {
        let store = InMemoryUsageStore::new();

        // Initial count should be 0
        assert_eq!(store.get("wallet1").await.unwrap(), 0);

        // Increment should return 1
        assert_eq!(store.increment("wallet1").await.unwrap(), 1);
        assert_eq!(store.get("wallet1").await.unwrap(), 1);

        // Increment again should return 2
        assert_eq!(store.increment("wallet1").await.unwrap(), 2);
        assert_eq!(store.get("wallet1").await.unwrap(), 2);

        // Different key should be independent
        assert_eq!(store.increment("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet1").await.unwrap(), 2);

        // Clear should reset everything
        store.clear().await.unwrap();
        assert_eq!(store.get("wallet1").await.unwrap(), 0);
        assert_eq!(store.get("wallet2").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_increment_with_future_expiry() {
        let store = InMemoryUsageStore::new();
        // A deadline an hour out, so the entry cannot expire mid-test.
        let expires_at = InMemoryUsageStore::current_timestamp() + 3600;

        assert_eq!(store.increment_with_expiry("wallet1", expires_at).await.unwrap(), 1);
        assert_eq!(store.increment_with_expiry("wallet1", expires_at).await.unwrap(), 2);
        assert_eq!(store.get("wallet1").await.unwrap(), 2);

        // Other keys are unaffected.
        assert_eq!(store.get("wallet2").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_in_memory_increment_with_past_expiry() {
        let store = InMemoryUsageStore::new();
        // A deadline in the past: reads see nothing and each call starts a fresh window.
        let expired_at = 1;

        assert_eq!(store.increment_with_expiry("wallet1", expired_at).await.unwrap(), 1);
        assert_eq!(store.get("wallet1").await.unwrap(), 0);
        assert_eq!(store.increment_with_expiry("wallet1", expired_at).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_error_usage_store() {
        let failing_get = ErrorUsageStore::new(true, false);
        assert!(failing_get.get("wallet1").await.is_err());
        assert_eq!(failing_get.increment("wallet1").await.unwrap(), 1);
        assert_eq!(failing_get.increment_with_expiry("wallet1", 0).await.unwrap(), 1);

        let failing_increment = ErrorUsageStore::new(false, true);
        assert_eq!(failing_increment.get("wallet1").await.unwrap(), 0);
        assert!(failing_increment.increment("wallet1").await.is_err());
        assert!(failing_increment.increment_with_expiry("wallet1", 0).await.is_err());
        assert!(failing_increment.clear().await.is_ok());
    }
}