use std::sync::Arc;
use crate::snapshot::{SnapshotMetadata, SnapshotStore};
/// Schedule that decides at which event sequence numbers a snapshot is due.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SnapshotPolicy {
    /// Snapshot at every positive multiple of `every`; an `every` of 0 never schedules one.
    Periodic {
        /// Distance, in sequence numbers, between scheduled snapshots.
        every: u64,
    },
    /// Never schedule snapshots automatically; callers decide when to save.
    Manual,
}
impl Default for SnapshotPolicy {
fn default() -> Self {
Self::Periodic { every: 100 }
}
}
/// Decides when snapshots are due and writes them to a [`SnapshotStore`], pruning older
/// snapshots so that only the most recent ones are retained.
///
/// `S` is `?Sized` because the store is only ever held behind an [`Arc`].
pub struct AsyncSnapshotter<S: SnapshotStore + ?Sized> {
    /// Store that receives snapshot saves and prune (delete) requests.
    store: Arc<S>,
    /// Schedule consulted by [`AsyncSnapshotter::should_snapshot`]; it also determines the
    /// snapshot spacing used when pruning in [`AsyncSnapshotter::save`].
    policy: SnapshotPolicy,
    /// How many of the most recent snapshots to retain per persistence id. Always `>= 1`:
    /// `new` sets it to 1 and `with_keep_last` rejects 0.
    keep_last: usize,
}

impl<S: SnapshotStore + ?Sized> AsyncSnapshotter<S> {
    /// Creates a snapshotter over `store` that follows `policy` and retains only the
    /// latest snapshot (`keep_last == 1`).
    pub fn new(store: Arc<S>, policy: SnapshotPolicy) -> Self {
        Self { store, policy, keep_last: 1 }
    }

    /// Returns this snapshotter configured to retain the `n` most recent snapshots.
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    #[must_use]
    pub fn with_keep_last(mut self, n: usize) -> Self {
        assert!(n >= 1, "keep_last must be >= 1");
        self.keep_last = n;
        self
    }

    /// Returns `true` if the policy schedules a snapshot at `sequence_nr`.
    ///
    /// `Periodic { every }` fires on every positive multiple of `every` and never fires when
    /// `every` is 0. `Manual` never fires.
    pub fn should_snapshot(&self, sequence_nr: u64) -> bool {
        match self.policy {
            SnapshotPolicy::Manual => false,
            // Guards the modulo below against division by zero.
            SnapshotPolicy::Periodic { every: 0 } => false,
            SnapshotPolicy::Periodic { every } => sequence_nr > 0 && sequence_nr % every == 0,
        }
    }

    /// Saves `payload` as the snapshot of `persistence_id` at `sequence_nr`. It then asks the
    /// store to delete snapshots that fall outside the retention window.
    ///
    /// With `Periodic { every }` (`every > 0`), scheduled snapshots are `every` sequence
    /// numbers apart. Retaining `keep_last` of them therefore means keeping the trailing
    /// `keep_last * every` sequence numbers.
    ///
    /// Under `Manual` (or `every == 0`), the spacing is unknown and the window falls back to
    /// `keep_last` sequence numbers. NOTE(review): with manual saves spaced more than one
    /// sequence number apart, this can retain fewer than `keep_last` snapshots.
    ///
    /// NOTE(review): assumes `SnapshotStore::delete(pid, to)` removes every snapshot of `pid`
    /// with a sequence number `<= to`. Confirm against the store implementations.
    pub async fn save(&self, persistence_id: impl Into<String>, sequence_nr: u64, payload: Vec<u8>) {
        let pid = persistence_id.into();
        let meta = SnapshotMetadata { persistence_id: pid.clone(), sequence_nr, timestamp: now_ms() };
        self.store.save(meta, payload).await;
        // A zero bound skips the delete call entirely. A snapshot stored at sequence number 0
        // is therefore never pruned by this method.
        let prune_to = sequence_nr.saturating_sub(self.retention_window());
        if prune_to > 0 {
            self.store.delete(&pid, prune_to).await;
        }
    }

    /// Number of trailing sequence numbers, ending at the newest snapshot, whose snapshots
    /// must be kept so that `keep_last` snapshots survive under the current policy.
    fn retention_window(&self) -> u64 {
        // usize -> u64 is lossless on every target whose pointers are at most 64 bits wide.
        let keep = self.keep_last as u64;
        match self.policy {
            SnapshotPolicy::Periodic { every } if every > 0 => keep.saturating_mul(every),
            // No known interval: treat every sequence number as a potential snapshot.
            SnapshotPolicy::Periodic { .. } | SnapshotPolicy::Manual => keep,
        }
    }
}
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reads earlier than the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemorySnapshotStore;

    #[test]
    fn periodic_policy_fires_on_multiples() {
        let snapshotter =
            AsyncSnapshotter::new(InMemorySnapshotStore::new(), SnapshotPolicy::Periodic { every: 10 });
        let cases = [(0, false), (9, false), (10, true), (11, false), (20, true)];
        for (sequence_nr, expected) in cases {
            assert_eq!(snapshotter.should_snapshot(sequence_nr), expected);
        }
    }

    #[test]
    fn manual_policy_never_fires() {
        let snapshotter = AsyncSnapshotter::new(InMemorySnapshotStore::new(), SnapshotPolicy::Manual);
        assert!((0..100).all(|sequence_nr| !snapshotter.should_snapshot(sequence_nr)));
    }

    #[tokio::test]
    async fn save_writes_to_store_and_loads_back() {
        let store = InMemorySnapshotStore::new();
        let snapshotter = AsyncSnapshotter::new(store.clone(), SnapshotPolicy::Periodic { every: 5 });
        snapshotter.save("a", 5, vec![1, 2, 3]).await;
        let (meta, payload) = store.load("a").await.unwrap();
        assert_eq!(meta.sequence_nr, 5);
        assert_eq!(payload, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn keep_last_prunes_old_snapshots() {
        let store = InMemorySnapshotStore::new();
        let snapshotter = AsyncSnapshotter::new(store.clone(), SnapshotPolicy::Periodic { every: 1 })
            .with_keep_last(2);
        for sequence_nr in 1..=5u64 {
            snapshotter.save("a", sequence_nr, vec![sequence_nr as u8]).await;
        }
        let (latest, _) = store.load("a").await.unwrap();
        assert_eq!(latest.sequence_nr, 5);
    }
}