use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use uvb_storage_api::{TransactionError, TransactionRecord, TransactionStatus, TransactionStore};
/// A [`TransactionStore`] that keeps every record in process memory.
///
/// Records are stored in a `HashMap` keyed by transaction id and guarded by a single
/// `tokio::sync::RwLock`. Readers may proceed concurrently; writers take the lock
/// exclusively.
pub struct InMemoryTransactionStore {
    /// Transaction records, keyed by each record's `id`.
    transactions: Arc<RwLock<HashMap<String, TransactionRecord>>>,
}

impl InMemoryTransactionStore {
    /// Builds a store that contains no transactions.
    pub fn new() -> Self {
        let empty: HashMap<String, TransactionRecord> = HashMap::new();
        let guarded = RwLock::new(empty);
        Self {
            transactions: Arc::new(guarded),
        }
    }
}

impl Default for InMemoryTransactionStore {
    /// Same as [`InMemoryTransactionStore::new`]: an empty store.
    fn default() -> Self {
        InMemoryTransactionStore::new()
    }
}
#[async_trait]
impl TransactionStore for InMemoryTransactionStore {
    /// Stores `record` under its `id`.
    ///
    /// NOTE(review): if a record with the same id already exists it is silently
    /// replaced; confirm whether the `TransactionStore` contract expects an error
    /// for duplicate ids instead.
    async fn create(&self, record: TransactionRecord) -> Result<(), TransactionError> {
        let mut txns = self.transactions.write().await;
        txns.insert(record.id.clone(), record);
        Ok(())
    }

    /// Returns a clone of the record stored under `id`, or `None` if there is none.
    async fn get(&self, id: &str) -> Result<Option<TransactionRecord>, TransactionError> {
        let txns = self.transactions.read().await;
        Ok(txns.get(id).cloned())
    }

    /// Replaces the stored record that has the same `id` as `record`.
    ///
    /// # Errors
    /// Returns [`TransactionError::NotFound`] if no record with that id exists.
    async fn update(&self, record: TransactionRecord) -> Result<(), TransactionError> {
        let mut txns = self.transactions.write().await;
        // One hash lookup: overwrite the existing value in place.
        let slot = txns
            .get_mut(&record.id)
            .ok_or(TransactionError::NotFound)?;
        *slot = record;
        Ok(())
    }

    /// Removes the record stored under `id`.
    ///
    /// # Errors
    /// Returns [`TransactionError::NotFound`] if no record with that id exists.
    async fn delete(&self, id: &str) -> Result<(), TransactionError> {
        let mut txns = self.transactions.write().await;
        txns.remove(id).ok_or(TransactionError::NotFound)?;
        Ok(())
    }

    /// Returns up to `limit` records whose `subject.user_id` equals `user_id`,
    /// newest first (descending `created_at`).
    async fn list_by_user(
        &self,
        user_id: &str,
        limit: usize,
    ) -> Result<Vec<TransactionRecord>, TransactionError> {
        let txns = self.transactions.read().await;
        let mut matching: Vec<&TransactionRecord> = txns
            .values()
            .filter(|t| t.subject.user_id == user_id)
            .collect();
        // Sort borrowed records so only the `limit` survivors are cloned.
        matching.sort_by_key(|t| std::cmp::Reverse(t.created_at));
        let page: Vec<TransactionRecord> = matching.into_iter().take(limit).cloned().collect();
        Ok(page)
    }

    /// Marks every record whose `expires_at` is at or before the current time as
    /// [`TransactionStatus::Expired`]. Records are not removed from the store.
    ///
    /// Records that have already `Succeeded` are left untouched. Records already
    /// marked `Expired` are skipped, so repeated calls do not re-count them.
    ///
    /// Returns the number of records whose status this call changed.
    async fn cleanup_expired(&self) -> Result<usize, TransactionError> {
        let mut txns = self.transactions.write().await;
        let now = SystemTime::now();
        let mut expired = 0usize;
        for txn in txns.values_mut() {
            let settled = matches!(
                txn.status,
                TransactionStatus::Succeeded | TransactionStatus::Expired
            );
            if txn.expires_at <= now && !settled {
                txn.status = TransactionStatus::Expired;
                expired += 1;
            }
        }
        Ok(expired)
    }
}