atomr_persistence/
async_snapshot.rs1use std::sync::Arc;
14
15use crate::snapshot::{SnapshotMetadata, SnapshotStore};
16
/// Decides when [`AsyncSnapshotter::should_snapshot`] recommends taking a snapshot.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SnapshotPolicy {
    /// Recommend a snapshot at every positive multiple of `every`.
    /// An interval of `0` never recommends one.
    Periodic { every: u64 },
    /// Never recommend a snapshot; callers decide when to call `save` themselves.
    Manual,
}
26
27impl Default for SnapshotPolicy {
28 fn default() -> Self {
29 Self::Periodic { every: 100 }
30 }
31}
32
33pub struct AsyncSnapshotter<S: SnapshotStore + ?Sized> {
35 store: Arc<S>,
36 policy: SnapshotPolicy,
37 keep_last: usize,
39}
40
41impl<S: SnapshotStore + ?Sized> AsyncSnapshotter<S> {
42 pub fn new(store: Arc<S>, policy: SnapshotPolicy) -> Self {
43 Self { store, policy, keep_last: 1 }
44 }
45
46 pub fn with_keep_last(mut self, n: usize) -> Self {
47 assert!(n >= 1, "keep_last must be >= 1");
48 self.keep_last = n;
49 self
50 }
51
52 pub fn should_snapshot(&self, sequence_nr: u64) -> bool {
54 match self.policy {
55 SnapshotPolicy::Manual => false,
56 SnapshotPolicy::Periodic { every: 0 } => false,
57 SnapshotPolicy::Periodic { every } => sequence_nr > 0 && sequence_nr % every == 0,
58 }
59 }
60
61 pub async fn save(&self, persistence_id: impl Into<String>, sequence_nr: u64, payload: Vec<u8>) {
64 let pid = persistence_id.into();
65 let meta = SnapshotMetadata { persistence_id: pid.clone(), sequence_nr, timestamp: now_ms() };
66 self.store.save(meta, payload).await;
67 if self.keep_last >= 1 && sequence_nr >= self.keep_last as u64 {
68 let prune_to = sequence_nr.saturating_sub(self.keep_last as u64);
73 if prune_to > 0 {
74 self.store.delete(&pid, prune_to).await;
75 }
76 }
77 }
78}
79
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    match std::time::UNIX_EPOCH.elapsed() {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        // Clock set before 1970: fall back to zero rather than panicking.
        Err(_) => 0,
    }
}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::InMemorySnapshotStore;

    #[test]
    fn periodic_policy_fires_on_multiples() {
        let snapshotter = AsyncSnapshotter::new(
            InMemorySnapshotStore::new(),
            SnapshotPolicy::Periodic { every: 10 },
        );
        // (sequence number, expected decision)
        let cases = [(0, false), (9, false), (10, true), (11, false), (20, true)];
        for (sequence_nr, expected) in cases {
            assert_eq!(
                snapshotter.should_snapshot(sequence_nr),
                expected,
                "sequence_nr = {}",
                sequence_nr
            );
        }
    }

    #[test]
    fn manual_policy_never_fires() {
        let snapshotter =
            AsyncSnapshotter::new(InMemorySnapshotStore::new(), SnapshotPolicy::Manual);
        let first_firing = (0..100).find(|&sequence_nr| snapshotter.should_snapshot(sequence_nr));
        assert_eq!(first_firing, None);
    }

    #[tokio::test]
    async fn save_writes_to_store_and_loads_back() {
        let store = InMemorySnapshotStore::new();
        let snapshotter =
            AsyncSnapshotter::new(store.clone(), SnapshotPolicy::Periodic { every: 5 });
        snapshotter.save("a", 5, vec![1, 2, 3]).await;

        let (meta, payload) = store.load("a").await.expect("snapshot should have been saved");
        assert_eq!(meta.sequence_nr, 5);
        assert_eq!(payload, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn keep_last_prunes_old_snapshots() {
        let store = InMemorySnapshotStore::new();
        let snapshotter =
            AsyncSnapshotter::new(store.clone(), SnapshotPolicy::Periodic { every: 1 })
                .with_keep_last(2);
        for sequence_nr in 1..=5 {
            snapshotter.save("a", sequence_nr, vec![sequence_nr as u8]).await;
        }

        let (latest, _) = store.load("a").await.expect("latest snapshot should survive pruning");
        assert_eq!(latest.sequence_nr, 5);
    }
}