use super::error::StoreError;
use crate::store::Device;
use crate::store::traits::Backend;
use async_lock::RwLock;
use event_listener::Event;
use futures::FutureExt;
use log::{debug, error};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use wacore::runtime::{AbortHandle, Runtime, ShutdownSignal, wait_for_shutdown};
pub struct PersistenceManager {
device: Arc<RwLock<Device>>,
backend: Arc<dyn Backend>,
dirty: Arc<AtomicBool>,
save_notify: Arc<Event>,
saver_halted: Arc<AtomicBool>,
}
impl PersistenceManager {
pub async fn new(backend: Arc<dyn Backend>) -> Result<Self, StoreError> {
debug!("PersistenceManager: Ensuring device row exists.");
let exists = backend.exists().await?;
if !exists {
debug!("PersistenceManager: No device row found. Creating new device row.");
let id = backend.create().await?;
debug!("PersistenceManager: Created device row with id={id}.");
}
debug!("PersistenceManager: Attempting to load device data via Backend.");
let device_data_opt = backend.load().await?;
let device = if let Some(serializable_device) = device_data_opt {
debug!(
"PersistenceManager: Loaded existing device data (PushName: '{}'). Initializing Device.",
serializable_device.push_name
);
let mut dev = Device::new(backend.clone());
dev.load_from_serializable(serializable_device);
dev
} else {
debug!("PersistenceManager: No data yet; initializing default Device in memory.");
Device::new(backend.clone())
};
Ok(Self {
device: Arc::new(RwLock::new(device)),
backend,
dirty: Arc::new(AtomicBool::new(false)),
save_notify: Arc::new(Event::new()),
saver_halted: Arc::new(AtomicBool::new(false)),
})
}
pub async fn get_device_arc(&self) -> Arc<RwLock<Device>> {
self.device.clone()
}
pub async fn get_device_snapshot(&self) -> Device {
self.device.read().await.clone()
}
pub fn backend(&self) -> Arc<dyn Backend> {
self.backend.clone()
}
pub fn is_saver_halted(&self) -> bool {
self.saver_halted.load(Ordering::Acquire)
}
pub async fn modify_device<F, R>(&self, modifier: F) -> R
where
F: FnOnce(&mut Device) -> R,
{
let mut device_guard = self.device.write().await;
let result = modifier(&mut device_guard);
self.dirty.store(true, Ordering::Relaxed);
self.save_notify.notify(1);
result
}
pub async fn flush(&self) -> Result<(), StoreError> {
self.save_to_disk().await
}
async fn save_to_disk(&self) -> Result<(), StoreError> {
if self.dirty.swap(false, Ordering::AcqRel) {
debug!("Device state is dirty, saving to disk.");
let device_guard = self.device.read().await;
let serializable_device = device_guard.to_serializable();
drop(device_guard);
if let Err(e) = self.backend.save(&serializable_device).await {
self.dirty.store(true, Ordering::Release);
return Err(e);
}
debug!("Device state saved successfully.");
}
Ok(())
}
pub async fn create_snapshot(
&self,
name: &str,
extra_content: Option<&[u8]>,
) -> Result<(), StoreError> {
#[cfg(feature = "debug-snapshots")]
{
self.save_to_disk().await?;
self.backend.snapshot_db(name, extra_content).await
}
#[cfg(not(feature = "debug-snapshots"))]
{
let _ = name;
let _ = extra_content;
log::warn!("Snapshot requested but 'debug-snapshots' feature is disabled");
Ok(())
}
}
pub fn run_background_saver(
self: Arc<Self>,
runtime: Arc<dyn Runtime>,
interval: Duration,
shutdown: ShutdownSignal,
) -> AbortHandle {
const MAX_CONSECUTIVE_FAILURES: u32 = 10;
let rt = runtime.clone();
let weak = Arc::downgrade(&self);
drop(self);
debug!("Background saver started (interval {interval:?})");
runtime.spawn(Box::pin(async move {
let mut consecutive_failures: u32 = 0;
if let Some(this) = weak.upgrade()
&& let Err(e) = this.save_to_disk().await
{
error!("Background saver: initial flush failed: {e}");
consecutive_failures = 1;
}
loop {
let Some(this) = weak.upgrade() else {
debug!("PersistenceManager dropped, exiting background saver.");
return;
};
let save_listener = this.save_notify.listen();
drop(this);
let should_exit = futures::select! {
_ = save_listener.fuse() => false,
_ = rt.sleep(interval).fuse() => false,
_ = wait_for_shutdown(&shutdown).fuse() => true,
};
let Some(this) = weak.upgrade() else {
debug!("PersistenceManager dropped, exiting background saver.");
return;
};
let flush_result = this.save_to_disk().await;
if should_exit {
match &flush_result {
Err(e) => {
error!("Background saver: final flush on shutdown failed: {e}");
}
Ok(()) => {
debug!("Background saver received shutdown; final flush complete.");
}
}
return;
}
if let Err(e) = flush_result {
consecutive_failures += 1;
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES {
this.saver_halted.store(true, Ordering::Release);
error!(
"Background saver: {consecutive_failures} consecutive flush failures, \
halting to prevent silent data loss. Last error: {e}"
);
return;
}
error!(
"Background saver flush failed ({consecutive_failures}/{MAX_CONSECUTIVE_FAILURES}): {e}"
);
} else {
consecutive_failures = 0;
}
}
}))
}
}
use super::commands::{DeviceCommand, apply_command_to_device};
impl PersistenceManager {
    /// Applies `command` to the in-memory device.
    ///
    /// The change is routed through `modify_device`, so it runs under the device
    /// write lock, marks the state dirty and notifies the background saver.
    pub async fn process_command(&self, command: DeviceCommand) {
        let update = self.modify_device(move |dev| {
            apply_command_to_device(dev, command);
        });
        update.await;
    }
}
/// Sender-key bookkeeping. Every method here forwards straight to the backend
/// and returns its result unchanged; none of them touch the in-memory device.
impl PersistenceManager {
    /// Fetches the backend's sender-key device entries for `group_jid`.
    ///
    /// Entries are `(String, bool)` pairs as returned by the backend —
    /// presumably a device JID and its sender-key status; confirm against the
    /// `Backend` trait.
    pub async fn get_sender_key_devices(
        &self,
        group_jid: &str,
    ) -> Result<Vec<(String, bool)>, StoreError> {
        let store = &self.backend;
        store.get_sender_key_devices(group_jid).await
    }

    /// Forwards `entries` for `group_jid` to the backend's `set_sender_key_status`.
    pub async fn set_sender_key_status(
        &self,
        group_jid: &str,
        entries: &[(&str, bool)],
    ) -> Result<(), StoreError> {
        let store = &self.backend;
        store.set_sender_key_status(group_jid, entries).await
    }

    /// Asks the backend to clear the sender-key device entries of `group_jid`.
    pub async fn clear_sender_key_devices(&self, group_jid: &str) -> Result<(), StoreError> {
        let store = &self.backend;
        store.clear_sender_key_devices(group_jid).await
    }

    /// Asks the backend to delete the sender-key rows for the given `device_jids`.
    pub async fn delete_sender_key_device_rows(
        &self,
        device_jids: &[&str],
    ) -> Result<(), StoreError> {
        let store = &self.backend;
        store.delete_sender_key_device_rows(device_jids).await
    }
}
// Tests for the background saver's shutdown/abort behaviour and for the drop
// semantics of an `AbortHandle` stored inside an `Arc`.
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime_impl::TokioRuntime;
use wacore::time::Instant;
// A modification made right before shutdown must still reach the backend:
// the saver performs one last flush when the shutdown signal fires.
#[tokio::test]
async fn saver_flushes_and_exits_on_shutdown() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(
PersistenceManager::new(backend.clone())
.await
.expect("pm init"),
);
let notifier = wacore::runtime::ShutdownNotifier::new();
let shutdown_signal = notifier.subscribe();
let runtime: Arc<dyn Runtime> = Arc::new(TokioRuntime);
// One-hour interval: the periodic timer cannot be what triggers the save
// within this test's 2 s deadline.
let handle =
pm.clone()
.run_background_saver(runtime, Duration::from_secs(3600), shutdown_signal);
// Give the spawned task time to start and run its initial flush.
tokio::time::sleep(Duration::from_millis(50)).await;
pm.modify_device(|d| {
d.push_name = "shutdown-flush".to_string();
})
.await;
notifier.notify();
// Poll the backend until the new push name shows up or the deadline passes.
let deadline = Instant::now() + Duration::from_secs(2);
loop {
if let Ok(Some(d)) = backend.load().await
&& d.push_name == "shutdown-flush"
{
break;
}
if Instant::now() > deadline {
panic!("final flush did not reach backend after shutdown");
}
tokio::time::sleep(Duration::from_millis(20)).await;
}
// Keep the handle alive until here so the task is not aborted mid-test.
drop(handle);
}
// Dropping the `AbortHandle` alone (no shutdown signal) must stop the saver.
// The task's liveness is observed indirectly: the spawned future captures a
// clone of the runtime `Arc`, so the strong count rises while it is alive and
// falls back to the baseline once the task has been torn down.
#[tokio::test]
async fn saver_exits_when_abort_handle_dropped_without_signal() {
let backend = crate::test_utils::create_test_backend().await;
let pm = Arc::new(PersistenceManager::new(backend).await.expect("pm init"));
let runtime: Arc<dyn Runtime> = Arc::new(TokioRuntime);
// Must be read before spawning; the comparison below depends on it.
let baseline = Arc::strong_count(&runtime);
// NOTE(review): `ShutdownSignal::never()` presumably yields a signal that
// never fires, leaving the handle drop as the only way to stop the task.
let handle = pm.clone().run_background_saver(
Arc::clone(&runtime),
Duration::from_secs(3600),
ShutdownSignal::never(),
);
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(
Arc::strong_count(&runtime) > baseline,
"running saver should hold a captured runtime Arc"
);
drop(handle);
// Wait (up to 1 s) for the aborted task to release its runtime clone.
let deadline = Instant::now() + Duration::from_secs(1);
while Arc::strong_count(&runtime) > baseline {
if Instant::now() > deadline {
panic!(
"saver task did not release the runtime Arc within 1s of AbortHandle drop \
(strong_count={}, baseline={})",
Arc::strong_count(&runtime),
baseline
);
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
// An `AbortHandle` stored in a struct behind an `Arc` must fire its abort
// callback only when the last `Arc` reference goes away, not when an earlier
// clone is dropped.
#[tokio::test]
async fn abort_handle_in_arc_drops_only_when_last_ref_released() {
use std::sync::atomic::{AtomicBool, Ordering};
// Minimal owner type: holds the handle in a `OnceLock`.
struct Owner(std::sync::OnceLock<AbortHandle>);
let owner = Arc::new(Owner(std::sync::OnceLock::new()));
// Flipped by the abort callback so the test can observe when it runs.
let aborted = Arc::new(AtomicBool::new(false));
let aborted_clone = Arc::clone(&aborted);
owner
.0
.set(AbortHandle::new(move || {
aborted_clone.store(true, Ordering::SeqCst);
}))
.ok()
.expect("first set");
let owner_clone = Arc::clone(&owner);
// First reference gone; `owner_clone` still keeps the handle alive.
drop(owner);
assert!(
!aborted.load(Ordering::SeqCst),
"handle must survive while another Arc ref is held"
);
// Last reference gone; the handle is dropped and the callback must have run.
drop(owner_clone);
assert!(
aborted.load(Ordering::SeqCst),
"last Arc drop must release the handle and fire abort"
);
}
}