use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use uuid::Uuid;
use zeroize::Zeroize;
/// Entry limit applied by `WalletUnlockCacheConfig::default()`.
const DEFAULT_MAX_ENTRIES: usize = 10_000;
/// One cache slot: a 32-byte key together with the instant it was last used.
///
/// The key bytes are wiped by this type's `Drop` implementation.
struct CachedKey {
    /// Raw key bytes held for the session.
    key: [u8; 32],
    /// Set on construction and refreshed by `touch`; the cache evicts the
    /// entry with the smallest value here when it is full.
    last_accessed: Instant,
}
impl CachedKey {
    /// Wraps `key` in a cache slot whose access time is the current instant.
    fn new(key: [u8; 32]) -> Self {
        let last_accessed = Instant::now();
        Self { key, last_accessed }
    }

    /// Records that the entry was just used by resetting its access time to
    /// the current instant.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_accessed = now;
    }
}
impl Drop for CachedKey {
fn drop(&mut self) {
self.key.zeroize();
}
}
/// Tunable settings for [`WalletUnlockCache`].
#[derive(Debug, Clone)]
pub struct WalletUnlockCacheConfig {
    /// Maximum number of entries the cache keeps; storing a new session once
    /// this many are present evicts the least recently accessed entry.
    pub max_entries: usize,
}
impl Default for WalletUnlockCacheConfig {
    /// Returns a configuration whose entry limit is `DEFAULT_MAX_ENTRIES`.
    fn default() -> Self {
        let max_entries = DEFAULT_MAX_ENTRIES;
        Self { max_entries }
    }
}
/// In-memory map from session id to a cached 32-byte key, bounded by
/// `config.max_entries` with least-recently-accessed eviction.
///
/// Entries have no time-based expiry; they leave the cache only through
/// eviction or an explicit removal call.
pub struct WalletUnlockCache {
    /// Session id -> cached key, guarded by an async read/write lock.
    entries: RwLock<HashMap<Uuid, CachedKey>>,
    /// Settings fixed at construction time.
    config: WalletUnlockCacheConfig,
}
impl WalletUnlockCache {
pub fn new() -> Self {
Self::with_config(WalletUnlockCacheConfig::default())
}
pub fn with_config(config: WalletUnlockCacheConfig) -> Self {
Self {
entries: RwLock::new(HashMap::new()),
config,
}
}
pub async fn store(&self, session_id: Uuid, key: [u8; 32]) {
let mut entries = self.entries.write().await;
let is_new_entry = !entries.contains_key(&session_id);
if is_new_entry && entries.len() >= self.config.max_entries {
if let Some(lru_id) = entries
.iter()
.min_by_key(|(_, v)| v.last_accessed)
.map(|(k, _)| *k)
{
entries.remove(&lru_id);
}
}
entries.insert(session_id, CachedKey::new(key));
}
pub async fn get(&self, session_id: Uuid) -> Option<[u8; 32]> {
let mut entries = self.entries.write().await;
entries.get_mut(&session_id).map(|entry| {
entry.touch();
entry.key
})
}
pub async fn is_unlocked(&self, session_id: Uuid) -> bool {
self.entries.read().await.contains_key(&session_id)
}
pub async fn remove(&self, session_id: Uuid) {
let mut entries = self.entries.write().await;
entries.remove(&session_id);
}
pub async fn remove_all_for_sessions(&self, session_ids: &[Uuid]) {
let mut entries = self.entries.write().await;
for session_id in session_ids {
entries.remove(session_id);
}
}
pub async fn len(&self) -> usize {
self.entries.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.entries.read().await.is_empty()
}
}
impl Default for WalletUnlockCache {
fn default() -> Self {
Self::new()
}
}
/// Builds a cache with `WalletUnlockCache::new()` and wraps it in an `Arc`
/// so it can be shared between owners.
pub fn create_wallet_unlock_cache() -> Arc<WalletUnlockCache> {
    let cache = WalletUnlockCache::new();
    Arc::new(cache)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Key material for tests that only need some arbitrary key.
    const SAMPLE_KEY: [u8; 32] = [0x42u8; 32];

    /// Builds a cache limited to three entries.
    fn three_slot_cache() -> WalletUnlockCache {
        WalletUnlockCache::with_config(WalletUnlockCacheConfig { max_entries: 3 })
    }

    /// Sleeps for 1 ms to space out the access timestamps that the LRU
    /// ordering depends on.
    async fn tick() {
        tokio::time::sleep(Duration::from_millis(1)).await;
    }

    /// A stored key can be read back unchanged.
    #[tokio::test]
    async fn test_store_and_get() {
        let cache = WalletUnlockCache::new();
        let id = Uuid::new_v4();

        cache.store(id, SAMPLE_KEY).await;

        let fetched = cache.get(id).await;
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap(), SAMPLE_KEY);
    }

    /// `is_unlocked` flips from false to true once a key is stored.
    #[tokio::test]
    async fn test_is_unlocked() {
        let cache = WalletUnlockCache::new();
        let id = Uuid::new_v4();

        assert!(!cache.is_unlocked(id).await);
        cache.store(id, SAMPLE_KEY).await;
        assert!(cache.is_unlocked(id).await);
    }

    /// Removing a session clears its entry.
    #[tokio::test]
    async fn test_remove() {
        let cache = WalletUnlockCache::new();
        let id = Uuid::new_v4();

        cache.store(id, SAMPLE_KEY).await;
        assert!(cache.is_unlocked(id).await);

        cache.remove(id).await;
        assert!(!cache.is_unlocked(id).await);
    }

    /// Entries are still present after time passes.
    #[tokio::test]
    async fn test_persists_without_expiration() {
        let cache = WalletUnlockCache::new();
        let id = Uuid::new_v4();

        cache.store(id, SAMPLE_KEY).await;
        assert!(cache.is_unlocked(id).await);

        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(cache.is_unlocked(id).await);
        assert!(cache.get(id).await.is_some());
    }

    /// When full, the least recently accessed session is the one evicted.
    #[tokio::test]
    async fn test_lru_eviction() {
        let cache = three_slot_cache();
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        let third = Uuid::new_v4();
        let fourth = Uuid::new_v4();

        cache.store(first, [1u8; 32]).await;
        tick().await;
        cache.store(second, [2u8; 32]).await;
        tick().await;
        cache.store(third, [3u8; 32]).await;
        assert_eq!(cache.len().await, 3);

        // Reading `first` refreshes it, leaving `second` as the oldest entry.
        let _ = cache.get(first).await;
        tick().await;

        cache.store(fourth, [4u8; 32]).await;
        assert_eq!(cache.len().await, 3);

        assert!(cache.is_unlocked(first).await);
        assert!(!cache.is_unlocked(second).await);
        assert!(cache.is_unlocked(third).await);
        assert!(cache.is_unlocked(fourth).await);
    }

    /// Bulk removal drops only the listed sessions.
    #[tokio::test]
    async fn test_remove_all_for_sessions() {
        let cache = WalletUnlockCache::new();
        let doomed: Vec<Uuid> = std::iter::repeat_with(Uuid::new_v4).take(3).collect();
        let survivor = Uuid::new_v4();

        for id in doomed.iter() {
            cache.store(*id, [0u8; 32]).await;
        }
        cache.store(survivor, [0u8; 32]).await;
        assert_eq!(cache.len().await, 4);

        cache.remove_all_for_sessions(&doomed).await;

        assert_eq!(cache.len().await, 1);
        assert!(cache.is_unlocked(survivor).await);
    }

    /// Overwriting an existing session in a full cache evicts nothing.
    #[tokio::test]
    async fn test_update_existing_entry_no_eviction() {
        let cache = three_slot_cache();
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        let third = Uuid::new_v4();

        cache.store(first, [1u8; 32]).await;
        tick().await;
        cache.store(second, [2u8; 32]).await;
        tick().await;
        cache.store(third, [3u8; 32]).await;
        assert_eq!(cache.len().await, 3);

        cache.store(first, [11u8; 32]).await;
        assert_eq!(cache.len().await, 3);

        assert!(cache.is_unlocked(first).await);
        assert!(cache.is_unlocked(second).await);
        assert!(cache.is_unlocked(third).await);

        let updated = cache.get(first).await.unwrap();
        assert_eq!(updated, [11u8; 32]);
    }
}