use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use anyhow::Result;
use tokio::sync::Mutex;
use wacore_ng::store::traits::SignalStore;
/// In-memory cache placed in front of a [`SignalStore`] backend.
///
/// Reads are answered from the cache when an entry exists; writes and
/// deletions are recorded here and only reach the backend when `flush` is
/// called. Each record category has its own independently locked state.
pub struct SignalStoreCache {
/// Cached session records and their pending writes/deletions, keyed by address.
sessions: Mutex<StoreState>,
/// Cached identity records and their pending writes/deletions, keyed by address.
identities: Mutex<StoreState>,
/// Cached sender-key records and their pending writes, keyed by the caller-supplied key string.
sender_keys: Mutex<StoreState>,
}
/// Cache contents and pending-change bookkeeping for one category of records.
struct StoreState {
    /// Cached value per key. `None` records that the key is known to have no
    /// value (either the backend returned nothing or it was deleted locally).
    cache: HashMap<String, Option<Arc<[u8]>>>,
    /// Keys whose cached value was written locally and is awaiting persistence.
    dirty: HashSet<String>,
    /// Keys that were deleted locally and are awaiting removal from the backend.
    deleted: HashSet<String>,
}

impl StoreState {
    /// Creates an empty state: nothing cached, nothing pending.
    fn new() -> Self {
        Self {
            cache: Default::default(),
            dirty: Default::default(),
            deleted: Default::default(),
        }
    }

    /// Resets the state to empty, discarding cached values and pending changes.
    ///
    /// The collections are replaced by fresh ones rather than emptied in
    /// place, so the previous allocations are dropped instead of retained.
    fn clear(&mut self) {
        *self = Self::new();
    }
}
impl Default for SignalStoreCache {
fn default() -> Self {
Self::new()
}
}
impl SignalStoreCache {
pub fn new() -> Self {
Self {
sessions: Mutex::new(StoreState::new()),
identities: Mutex::new(StoreState::new()),
sender_keys: Mutex::new(StoreState::new()),
}
}
pub async fn get_session(
&self,
address: &str,
backend: &dyn SignalStore,
) -> Result<Option<Arc<[u8]>>> {
let mut state = self.sessions.lock().await;
if let Some(cached) = state.cache.get(address) {
return Ok(cached.clone());
}
let data = backend.get_session(address).await?;
let arc_data = data.map(Arc::from);
state.cache.insert(address.to_string(), arc_data.clone());
Ok(arc_data)
}
pub async fn put_session(&self, address: &str, data: &[u8]) {
let mut state = self.sessions.lock().await;
let addr = address.to_string();
state.cache.insert(addr.clone(), Some(Arc::from(data)));
state.dirty.insert(addr);
state.deleted.remove(address);
}
pub async fn delete_session(&self, address: &str) {
let mut state = self.sessions.lock().await;
let addr = address.to_string();
state.cache.insert(addr.clone(), None);
state.deleted.insert(addr);
state.dirty.remove(address);
}
pub async fn has_session(&self, address: &str, backend: &dyn SignalStore) -> Result<bool> {
Ok(self.get_session(address, backend).await?.is_some())
}
pub async fn get_identity(
&self,
address: &str,
backend: &dyn SignalStore,
) -> Result<Option<Arc<[u8]>>> {
let mut state = self.identities.lock().await;
if let Some(cached) = state.cache.get(address) {
return Ok(cached.clone());
}
let data = backend.load_identity(address).await?;
let arc_data = data.map(Arc::from);
state.cache.insert(address.to_string(), arc_data.clone());
Ok(arc_data)
}
pub async fn put_identity(&self, address: &str, data: &[u8]) {
let mut state = self.identities.lock().await;
let addr = address.to_string();
state.cache.insert(addr.clone(), Some(Arc::from(data)));
state.dirty.insert(addr);
state.deleted.remove(address);
}
pub async fn delete_identity(&self, address: &str) {
let mut state = self.identities.lock().await;
let addr = address.to_string();
state.cache.insert(addr.clone(), None);
state.deleted.insert(addr);
state.dirty.remove(address);
}
pub async fn get_sender_key(
&self,
address: &str,
backend: &dyn SignalStore,
) -> Result<Option<Arc<[u8]>>> {
let mut state = self.sender_keys.lock().await;
if let Some(cached) = state.cache.get(address) {
return Ok(cached.clone());
}
let data = backend.get_sender_key(address).await?;
let arc_data = data.map(Arc::from);
state.cache.insert(address.to_string(), arc_data.clone());
Ok(arc_data)
}
pub async fn put_sender_key(&self, address: &str, data: &[u8]) {
let mut state = self.sender_keys.lock().await;
let addr = address.to_string();
state.cache.insert(addr.clone(), Some(Arc::from(data)));
state.dirty.insert(addr);
state.deleted.remove(address);
}
pub async fn flush(&self, backend: &dyn SignalStore) -> Result<()> {
let mut sessions = self.sessions.lock().await;
let mut identities = self.identities.lock().await;
let mut sender_keys = self.sender_keys.lock().await;
let session_dirty: Vec<_> = sessions.dirty.iter().cloned().collect();
let session_deleted: Vec<_> = sessions.deleted.iter().cloned().collect();
let identity_dirty: Vec<_> = identities.dirty.iter().cloned().collect();
let identity_deleted: Vec<_> = identities.deleted.iter().cloned().collect();
let sender_key_dirty: Vec<_> = sender_keys.dirty.iter().cloned().collect();
for address in &session_dirty {
if let Some(Some(data)) = sessions.cache.get(address) {
backend.put_session(address, data).await?;
}
}
for address in &session_deleted {
backend.delete_session(address).await?;
}
for address in &identity_dirty {
if let Some(Some(data)) = identities.cache.get(address) {
let key: [u8; 32] = data.as_ref().try_into().map_err(|_| {
anyhow::anyhow!(
"Corrupted identity key for {address}: expected 32 bytes, got {}",
data.len()
)
})?;
backend.put_identity(address, key).await?;
}
}
for address in &identity_deleted {
backend.delete_identity(address).await?;
}
for name in &sender_key_dirty {
if let Some(Some(data)) = sender_keys.cache.get(name) {
backend.put_sender_key(name, data).await?;
}
}
sessions.dirty.clear();
sessions.deleted.clear();
identities.dirty.clear();
identities.deleted.clear();
sender_keys.dirty.clear();
Ok(())
}
pub async fn clear(&self) {
self.sessions.lock().await.clear();
self.identities.lock().await.clear();
self.sender_keys.lock().await.clear();
}
}