atomr_persistence/
async_snapshot.rs1use std::sync::Arc;
16
17use crate::snapshot::{SnapshotMetadata, SnapshotStore};
18
/// Rule that decides when a snapshot should be taken automatically.
///
/// Consulted by `AsyncSnapshotter::should_snapshot`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SnapshotPolicy {
    /// Snapshot whenever the sequence number is a positive multiple of `every`.
    Periodic {
        /// Interval, in sequence numbers, between snapshots. An interval of
        /// `0` never triggers a snapshot.
        every: u64,
    },
    /// Never snapshot automatically.
    Manual,
}
28
29impl Default for SnapshotPolicy {
30 fn default() -> Self {
31 Self::Periodic { every: 100 }
32 }
33}
34
35pub struct AsyncSnapshotter<S: SnapshotStore + ?Sized> {
37 store: Arc<S>,
38 policy: SnapshotPolicy,
39 keep_last: usize,
41}
42
43impl<S: SnapshotStore + ?Sized> AsyncSnapshotter<S> {
44 pub fn new(store: Arc<S>, policy: SnapshotPolicy) -> Self {
45 Self { store, policy, keep_last: 1 }
46 }
47
48 pub fn with_keep_last(mut self, n: usize) -> Self {
49 assert!(n >= 1, "keep_last must be >= 1");
50 self.keep_last = n;
51 self
52 }
53
54 pub fn should_snapshot(&self, sequence_nr: u64) -> bool {
56 match self.policy {
57 SnapshotPolicy::Manual => false,
58 SnapshotPolicy::Periodic { every: 0 } => false,
59 SnapshotPolicy::Periodic { every } => sequence_nr > 0 && sequence_nr % every == 0,
60 }
61 }
62
63 pub async fn save(&self, persistence_id: impl Into<String>, sequence_nr: u64, payload: Vec<u8>) {
66 let pid = persistence_id.into();
67 let meta = SnapshotMetadata { persistence_id: pid.clone(), sequence_nr, timestamp: now_ms() };
68 self.store.save(meta, payload).await;
69 if self.keep_last >= 1 && sequence_nr >= self.keep_last as u64 {
70 let prune_to = sequence_nr.saturating_sub(self.keep_last as u64);
75 if prune_to > 0 {
76 self.store.delete(&pid, prune_to).await;
77 }
78 }
79 }
80}
81
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 when the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` instead of wrapping if the millisecond count does
/// not fit in a `u64`.
fn now_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| {
            // `as_millis` yields a `u128`; clamp first so the narrowing cast
            // below can never silently wrap.
            elapsed.as_millis().min(u128::from(u64::MAX)) as u64
        })
}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemorySnapshotStore;

    /// A periodic policy fires exactly on positive multiples of its interval.
    #[test]
    fn periodic_policy_fires_on_multiples() {
        let snapshotter =
            AsyncSnapshotter::new(InMemorySnapshotStore::new(), SnapshotPolicy::Periodic { every: 10 });
        let cases: [(u64, bool); 5] = [(0, false), (9, false), (10, true), (11, false), (20, true)];
        for (sequence_nr, expected) in cases {
            // Pair the verdict with its input so a failure names the sequence number.
            assert_eq!(
                (sequence_nr, snapshotter.should_snapshot(sequence_nr)),
                (sequence_nr, expected)
            );
        }
    }

    /// A manual policy never asks for a snapshot.
    #[test]
    fn manual_policy_never_fires() {
        let snapshotter = AsyncSnapshotter::new(InMemorySnapshotStore::new(), SnapshotPolicy::Manual);
        let fired = (0..100u64).find(|&sequence_nr| snapshotter.should_snapshot(sequence_nr));
        assert_eq!(fired, None);
    }

    /// A saved snapshot can be read back from the underlying store.
    #[tokio::test]
    async fn save_writes_to_store_and_loads_back() {
        let store = InMemorySnapshotStore::new();
        let snapshotter = AsyncSnapshotter::new(store.clone(), SnapshotPolicy::Periodic { every: 5 });
        snapshotter.save("a", 5, vec![1, 2, 3]).await;

        let loaded = store.load("a").await;
        let (meta, payload) = loaded.unwrap();
        assert_eq!(meta.sequence_nr, 5);
        assert_eq!(payload, vec![1, 2, 3]);
    }

    /// After several saves with a retention window of 2, the newest snapshot
    /// is still the one that loads.
    ///
    /// NOTE(review): only the latest snapshot is inspected; this does not
    /// assert that older snapshots were actually removed.
    #[tokio::test]
    async fn keep_last_prunes_old_snapshots() {
        let store = InMemorySnapshotStore::new();
        let snapshotter = AsyncSnapshotter::new(store.clone(), SnapshotPolicy::Periodic { every: 1 })
            .with_keep_last(2);
        for sequence_nr in 1..=5 {
            snapshotter.save("a", sequence_nr, vec![sequence_nr as u8]).await;
        }
        let (meta, _payload) = store.load("a").await.unwrap();
        assert_eq!(meta.sequence_nr, 5);
    }
}