use super::{NonceEntry, NonceStorage, StorageStats};
use crate::NonceError;
use crate::nonce::time_utils;
use async_trait::async_trait;
use std::collections::HashMap;
use std::time::Duration;
#[derive(Debug)]
pub struct MemoryStorage {
data: std::sync::Arc<tokio::sync::RwLock<HashMap<String, NonceEntry>>>,
}
impl MemoryStorage {
pub fn new() -> Self {
Self {
data: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
}
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
data: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::with_capacity(capacity))),
}
}
fn make_key(nonce: &str, context: Option<&str>) -> String {
match context {
Some(ctx) => {
let mut key = String::with_capacity(nonce.len() + ctx.len() + 1);
key.push_str(nonce);
key.push(':');
key.push_str(ctx);
key
}
None => {
let mut key = String::with_capacity(nonce.len() + 1);
key.push_str(nonce);
key.push(':');
key
}
}
}
}
#[async_trait]
impl NonceStorage for MemoryStorage {
async fn get(
&self,
nonce: &str,
context: Option<&str>,
) -> Result<Option<NonceEntry>, NonceError> {
let key = Self::make_key(nonce, context);
let data = self.data.read().await;
Ok(data.get(&key).cloned())
}
async fn set(
&self,
nonce: &str,
context: Option<&str>,
_ttl: Duration,
) -> Result<(), NonceError> {
let key = Self::make_key(nonce, context);
let entry = NonceEntry {
nonce: nonce.to_string(),
created_at: time_utils::current_timestamp()?,
context: context.map(|s| s.to_string()),
};
let mut data = self.data.write().await;
if data.contains_key(&key) {
return Err(NonceError::DuplicateNonce);
}
data.insert(key, entry);
Ok(())
}
async fn exists(&self, nonce: &str, context: Option<&str>) -> Result<bool, NonceError> {
let key = Self::make_key(nonce, context);
let data = self.data.read().await;
Ok(data.contains_key(&key))
}
async fn cleanup_expired(&self, cutoff_time: i64) -> Result<usize, NonceError> {
let mut data = self.data.write().await;
let initial_count = data.len();
data.retain(|_, entry| entry.created_at > cutoff_time);
Ok(initial_count - data.len())
}
async fn get_stats(&self) -> Result<StorageStats, NonceError> {
let data = self.data.read().await;
let base_entry_size = std::mem::size_of::<NonceEntry>();
let mut total_memory = data.len() * base_entry_size;
for (key, entry) in data.iter() {
total_memory += key.len(); total_memory += entry.nonce.len(); if let Some(ctx) = &entry.context {
total_memory += ctx.len(); }
}
total_memory += data.capacity() * std::mem::size_of::<(String, NonceEntry)>();
Ok(StorageStats {
total_records: data.len(),
backend_info: format!(
"In-memory HashMap storage (~{} bytes, capacity: {})",
total_memory,
data.capacity()
),
})
}
}
impl MemoryStorage {
pub async fn batch_set(
&self,
nonces: Vec<(&str, Option<&str>)>,
_ttl: Duration,
) -> Result<usize, NonceError> {
let created_at = time_utils::current_timestamp()?;
let mut data = self.data.write().await;
let mut success_count = 0;
for (nonce, context) in nonces {
let key = Self::make_key(nonce, context);
if let std::collections::hash_map::Entry::Vacant(e) = data.entry(key) {
let entry = NonceEntry {
nonce: nonce.to_string(),
created_at,
context: context.map(|s| s.to_string()),
};
e.insert(entry);
success_count += 1;
}
}
Ok(success_count)
}
pub async fn batch_exists(
&self,
nonces: Vec<(&str, Option<&str>)>,
) -> Result<Vec<bool>, NonceError> {
let data = self.data.read().await;
let mut results = Vec::with_capacity(nonces.len());
for (nonce, context) in nonces {
let key = Self::make_key(nonce, context);
results.push(data.contains_key(&key));
}
Ok(results)
}
pub async fn batch_get(
&self,
nonces: Vec<(&str, Option<&str>)>,
) -> Result<Vec<Option<NonceEntry>>, NonceError> {
let data = self.data.read().await;
let mut results = Vec::with_capacity(nonces.len());
for (nonce, context) in nonces {
let key = Self::make_key(nonce, context);
results.push(data.get(&key).cloned());
}
Ok(results)
}
}
impl Default for MemoryStorage {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::{SystemTime, UNIX_EPOCH};
#[tokio::test]
async fn test_memory_storage_basic_operations() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
storage
.set("test-nonce", None, Duration::from_secs(300))
.await?;
assert!(storage.exists("test-nonce", None).await?);
let entry = storage.get("test-nonce", None).await?;
assert!(entry.is_some());
let entry = entry.unwrap();
assert_eq!(entry.nonce, "test-nonce");
assert!(entry.context.is_none());
Ok(())
}
#[tokio::test]
async fn test_memory_storage_duplicate_nonce() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
storage
.set("test-nonce", None, Duration::from_secs(300))
.await?;
let result = storage
.set("test-nonce", None, Duration::from_secs(300))
.await;
assert!(matches!(result, Err(NonceError::DuplicateNonce)));
Ok(())
}
#[tokio::test]
async fn test_memory_storage_context_isolation() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
storage
.set("test-nonce", Some("context1"), Duration::from_secs(300))
.await?;
storage
.set("test-nonce", Some("context2"), Duration::from_secs(300))
.await?;
assert!(storage.exists("test-nonce", Some("context1")).await?);
assert!(storage.exists("test-nonce", Some("context2")).await?);
assert!(!storage.exists("test-nonce", Some("context3")).await?);
Ok(())
}
#[tokio::test]
async fn test_memory_storage_cleanup() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
storage
.set("old-nonce", None, Duration::from_secs(300))
.await?;
storage
.set("new-nonce", None, Duration::from_secs(300))
.await?;
let future_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs() as i64
+ 3600;
let removed = storage.cleanup_expired(future_time).await?;
assert_eq!(removed, 2);
assert!(!storage.exists("old-nonce", None).await?);
assert!(!storage.exists("new-nonce", None).await?);
Ok(())
}
#[tokio::test]
async fn test_memory_storage_stats() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
let stats = storage.get_stats().await?;
assert_eq!(stats.total_records, 0);
assert!(stats.backend_info.contains("In-memory"));
storage
.set("nonce1", None, Duration::from_secs(300))
.await?;
storage
.set("nonce2", Some("context"), Duration::from_secs(300))
.await?;
let stats = storage.get_stats().await?;
assert_eq!(stats.total_records, 2);
assert!(stats.backend_info.contains("In-memory"));
assert!(stats.backend_info.contains("bytes"));
assert!(stats.backend_info.contains("capacity"));
Ok(())
}
#[tokio::test]
async fn test_memory_storage_with_capacity() -> Result<(), NonceError> {
let storage = MemoryStorage::with_capacity(100);
storage.set("test1", None, Duration::from_secs(300)).await?;
storage
.set("test2", Some("ctx"), Duration::from_secs(300))
.await?;
let stats = storage.get_stats().await?;
assert_eq!(stats.total_records, 2);
assert!(stats.backend_info.contains("capacity"));
Ok(())
}
#[tokio::test]
async fn test_batch_operations() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
let nonces = vec![
("batch1", None),
("batch2", Some("ctx1")),
("batch3", Some("ctx2")),
("batch1", None), ];
let inserted = storage.batch_set(nonces, Duration::from_secs(300)).await?;
assert_eq!(inserted, 3);
let check_nonces = vec![
("batch1", None),
("batch2", Some("ctx1")),
("batch3", Some("ctx2")),
("batch4", None), ];
let exists_results = storage.batch_exists(check_nonces).await?;
assert_eq!(exists_results, vec![true, true, true, false]);
let get_nonces = vec![
("batch1", None),
("batch4", None), ];
let get_results = storage.batch_get(get_nonces).await?;
assert!(get_results[0].is_some());
assert!(get_results[1].is_none());
Ok(())
}
#[tokio::test]
async fn test_memory_storage_key_generation() {
assert_eq!(MemoryStorage::make_key("nonce1", None), "nonce1:");
assert_eq!(MemoryStorage::make_key("nonce1", Some("ctx")), "nonce1:ctx");
assert_eq!(MemoryStorage::make_key("nonce1", Some("")), "nonce1:");
}
#[tokio::test]
async fn test_memory_storage_concurrent_access() -> Result<(), NonceError> {
let storage = std::sync::Arc::new(MemoryStorage::new());
let mut handles = vec![];
for i in 0..10 {
let storage_clone = std::sync::Arc::clone(&storage);
let handle = tokio::spawn(async move {
storage_clone
.set(&format!("nonce-{i}"), None, Duration::from_secs(300))
.await
});
handles.push(handle);
}
for handle in handles {
assert!(handle.await.unwrap().is_ok());
}
let stats = storage.get_stats().await?;
assert_eq!(stats.total_records, 10);
Ok(())
}
#[tokio::test]
async fn test_memory_usage_calculation() -> Result<(), NonceError> {
let storage = MemoryStorage::new();
storage.set("short", None, Duration::from_secs(300)).await?;
storage
.set(
"very_long_nonce_name_for_testing",
Some("long_context_name"),
Duration::from_secs(300),
)
.await?;
let stats = storage.get_stats().await?;
assert_eq!(stats.total_records, 2);
assert!(stats.backend_info.contains("bytes"));
assert!(stats.backend_info.contains("capacity"));
Ok(())
}
}