use std::{collections::HashMap, sync::Mutex};
use async_trait::async_trait;
use deadpool_redis::{Connection, Pool};
use redis::AsyncCommands;
use crate::{error::KoraError, sanitize_error};
#[async_trait]
pub trait UsageStore: Send + Sync {
async fn increment(&self, key: &str) -> Result<u32, KoraError>;
async fn get(&self, key: &str) -> Result<u32, KoraError>;
async fn clear(&self) -> Result<(), KoraError>;
}
pub struct RedisUsageStore {
pool: Pool,
}
impl RedisUsageStore {
pub fn new(pool: Pool) -> Self {
Self { pool }
}
async fn get_connection(&self) -> Result<Connection, KoraError> {
self.pool.get().await.map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to get Redis connection: {}",
e
)))
})
}
}
#[async_trait]
impl UsageStore for RedisUsageStore {
async fn increment(&self, key: &str) -> Result<u32, KoraError> {
let mut conn = self.get_connection().await?;
let count: u32 = conn.incr(key, 1).await.map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to increment usage for {}: {}",
key, e
)))
})?;
Ok(count)
}
async fn get(&self, key: &str) -> Result<u32, KoraError> {
let mut conn = self.get_connection().await?;
let count: Option<u32> = conn.get(key).await.map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to get usage for {}: {}",
key, e
)))
})?;
Ok(count.unwrap_or(0))
}
async fn clear(&self) -> Result<(), KoraError> {
let mut conn = self.get_connection().await?;
let _: () = conn.flushdb().await.map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!("Failed to clear Redis: {}", e)))
})?;
Ok(())
}
}
pub struct InMemoryUsageStore {
data: Mutex<HashMap<String, u32>>,
}
impl InMemoryUsageStore {
pub fn new() -> Self {
Self { data: Mutex::new(HashMap::new()) }
}
}
impl Default for InMemoryUsageStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl UsageStore for InMemoryUsageStore {
async fn increment(&self, key: &str) -> Result<u32, KoraError> {
let mut data = self.data.lock().map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to lock usage store: {}",
e
)))
})?;
let count = data.entry(key.to_string()).or_insert(0);
*count += 1;
Ok(*count)
}
async fn get(&self, key: &str) -> Result<u32, KoraError> {
let data = self.data.lock().map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to lock usage store: {}",
e
)))
})?;
Ok(data.get(key).copied().unwrap_or(0))
}
async fn clear(&self) -> Result<(), KoraError> {
let mut data = self.data.lock().map_err(|e| {
KoraError::InternalServerError(sanitize_error!(format!(
"Failed to lock usage store: {}",
e
)))
})?;
data.clear();
Ok(())
}
}
#[cfg(test)]
pub struct ErrorUsageStore {
should_error_get: bool,
should_error_increment: bool,
}
#[cfg(test)]
impl ErrorUsageStore {
pub fn new(should_error_get: bool, should_error_increment: bool) -> Self {
Self { should_error_get, should_error_increment }
}
}
#[cfg(test)]
#[async_trait]
impl UsageStore for ErrorUsageStore {
async fn increment(&self, _key: &str) -> Result<u32, KoraError> {
if self.should_error_increment {
Err(KoraError::InternalServerError("Redis connection failed".to_string()))
} else {
Ok(1)
}
}
async fn get(&self, _key: &str) -> Result<u32, KoraError> {
if self.should_error_get {
Err(KoraError::InternalServerError("Redis connection failed".to_string()))
} else {
Ok(0)
}
}
async fn clear(&self) -> Result<(), KoraError> {
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_in_memory_usage_store() {
let store = InMemoryUsageStore::new();
assert_eq!(store.get("wallet1").await.unwrap(), 0);
assert_eq!(store.increment("wallet1").await.unwrap(), 1);
assert_eq!(store.get("wallet1").await.unwrap(), 1);
assert_eq!(store.increment("wallet1").await.unwrap(), 2);
assert_eq!(store.get("wallet1").await.unwrap(), 2);
assert_eq!(store.increment("wallet2").await.unwrap(), 1);
assert_eq!(store.get("wallet2").await.unwrap(), 1);
assert_eq!(store.get("wallet1").await.unwrap(), 2);
store.clear().await.unwrap();
assert_eq!(store.get("wallet1").await.unwrap(), 0);
assert_eq!(store.get("wallet2").await.unwrap(), 0);
}
}