// kora_lib/usage_limit/usage_store.rs

1use std::{collections::HashMap, sync::Mutex};
2
3use async_trait::async_trait;
4use deadpool_redis::{Connection, Pool};
5use redis::AsyncCommands;
6
7use crate::{error::KoraError, sanitize_error};
8
9/// Trait for storing and retrieving usage counts
10#[async_trait]
11pub trait UsageStore: Send + Sync {
12    /// Increment usage count for a key and return the new value
13    async fn increment(&self, key: &str) -> Result<u32, KoraError>;
14
15    /// Get current usage count for a key (returns 0 if not found)
16    async fn get(&self, key: &str) -> Result<u32, KoraError>;
17
18    /// Clear all usage data (mainly for testing)
19    async fn clear(&self) -> Result<(), KoraError>;
20}
21
22/// Redis-based implementation for production
23pub struct RedisUsageStore {
24    pool: Pool,
25}
26
27impl RedisUsageStore {
28    pub fn new(pool: Pool) -> Self {
29        Self { pool }
30    }
31
32    async fn get_connection(&self) -> Result<Connection, KoraError> {
33        self.pool.get().await.map_err(|e| {
34            KoraError::InternalServerError(sanitize_error!(format!(
35                "Failed to get Redis connection: {}",
36                e
37            )))
38        })
39    }
40}
41
42#[async_trait]
43impl UsageStore for RedisUsageStore {
44    async fn increment(&self, key: &str) -> Result<u32, KoraError> {
45        let mut conn = self.get_connection().await?;
46        let count: u32 = conn.incr(key, 1).await.map_err(|e| {
47            KoraError::InternalServerError(sanitize_error!(format!(
48                "Failed to increment usage for {}: {}",
49                key, e
50            )))
51        })?;
52        Ok(count)
53    }
54
55    async fn get(&self, key: &str) -> Result<u32, KoraError> {
56        let mut conn = self.get_connection().await?;
57        let count: Option<u32> = conn.get(key).await.map_err(|e| {
58            KoraError::InternalServerError(sanitize_error!(format!(
59                "Failed to get usage for {}: {}",
60                key, e
61            )))
62        })?;
63        Ok(count.unwrap_or(0))
64    }
65
66    async fn clear(&self) -> Result<(), KoraError> {
67        let mut conn = self.get_connection().await?;
68        let _: () = conn.flushdb().await.map_err(|e| {
69            KoraError::InternalServerError(sanitize_error!(format!("Failed to clear Redis: {}", e)))
70        })?;
71        Ok(())
72    }
73}
74
/// In-memory [`UsageStore`] intended for tests.
///
/// Counters are kept in a `HashMap` behind a `std::sync::Mutex` and live only as long
/// as the store value itself.
#[derive(Default)]
pub struct InMemoryUsageStore {
    data: Mutex<HashMap<String, u32>>,
}

impl InMemoryUsageStore {
    /// Creates a store that holds no counters yet.
    pub fn new() -> Self {
        Self::default()
    }
}
91
92#[async_trait]
93impl UsageStore for InMemoryUsageStore {
94    async fn increment(&self, key: &str) -> Result<u32, KoraError> {
95        let mut data = self.data.lock().map_err(|e| {
96            KoraError::InternalServerError(sanitize_error!(format!(
97                "Failed to lock usage store: {}",
98                e
99            )))
100        })?;
101        let count = data.entry(key.to_string()).or_insert(0);
102        *count += 1;
103        Ok(*count)
104    }
105
106    async fn get(&self, key: &str) -> Result<u32, KoraError> {
107        let data = self.data.lock().map_err(|e| {
108            KoraError::InternalServerError(sanitize_error!(format!(
109                "Failed to lock usage store: {}",
110                e
111            )))
112        })?;
113        Ok(data.get(key).copied().unwrap_or(0))
114    }
115
116    async fn clear(&self) -> Result<(), KoraError> {
117        let mut data = self.data.lock().map_err(|e| {
118            KoraError::InternalServerError(sanitize_error!(format!(
119                "Failed to lock usage store: {}",
120                e
121            )))
122        })?;
123        data.clear();
124        Ok(())
125    }
126}
127
/// Test double that can be told to fail `get` and/or `increment`, used to exercise
/// callers' handling of store (e.g. Redis) errors.
#[cfg(test)]
pub struct ErrorUsageStore {
    /// When `true`, `get` returns an error instead of a count.
    should_error_get: bool,
    /// When `true`, `increment` returns an error instead of a count.
    should_error_increment: bool,
}

#[cfg(test)]
impl ErrorUsageStore {
    /// Builds a store whose `get` / `increment` fail according to the two flags.
    pub fn new(should_error_get: bool, should_error_increment: bool) -> Self {
        Self { should_error_increment, should_error_get }
    }
}
141
#[cfg(test)]
#[async_trait]
impl UsageStore for ErrorUsageStore {
    /// Reports a count of `1`, unless `should_error_increment` is set, in which case it
    /// fails with a simulated Redis connection error.
    async fn increment(&self, _key: &str) -> Result<u32, KoraError> {
        if !self.should_error_increment {
            return Ok(1);
        }
        Err(KoraError::InternalServerError(String::from("Redis connection failed")))
    }

    /// Reports a count of `0`, unless `should_error_get` is set, in which case it fails
    /// with a simulated Redis connection error.
    async fn get(&self, _key: &str) -> Result<u32, KoraError> {
        if !self.should_error_get {
            return Ok(0);
        }
        Err(KoraError::InternalServerError(String::from("Redis connection failed")))
    }

    /// Always succeeds; this double keeps no data to clear.
    async fn clear(&self) -> Result<(), KoraError> {
        Ok(())
    }
}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_usage_store() {
        let store = InMemoryUsageStore::new();

        // A key that was never incremented reads as zero.
        assert_eq!(store.get("wallet1").await.unwrap(), 0);

        // Every increment returns the running total, and `get` observes the same value.
        for expected in 1..=2 {
            assert_eq!(store.increment("wallet1").await.unwrap(), expected);
            assert_eq!(store.get("wallet1").await.unwrap(), expected);
        }

        // Counters are tracked independently per key.
        assert_eq!(store.increment("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet2").await.unwrap(), 1);
        assert_eq!(store.get("wallet1").await.unwrap(), 2);

        // `clear` drops every counter.
        store.clear().await.unwrap();
        for key in ["wallet1", "wallet2"].iter() {
            assert_eq!(store.get(key).await.unwrap(), 0);
        }
    }
}