1use dashmap::DashMap;
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4use std::sync::Arc;
5
6#[derive(Debug)]
18pub struct AccountRegionStore<T: Default + Send + Sync + 'static> {
19 inner: Arc<DashMap<(String, String), Arc<T>>>,
20}
21
22impl<T: Default + Send + Sync + 'static> Clone for AccountRegionStore<T> {
23 fn clone(&self) -> Self {
24 Self {
25 inner: Arc::clone(&self.inner),
26 }
27 }
28}
29
30impl<T: Default + Send + Sync + 'static> AccountRegionStore<T> {
31 pub fn new() -> Self {
32 Self {
33 inner: Arc::new(DashMap::new()),
34 }
35 }
36
37 pub fn get(&self, account_id: &str, region: &str) -> Arc<T> {
39 self.inner
40 .entry((account_id.to_string(), region.to_string()))
41 .or_insert_with(|| Arc::new(T::default()))
42 .clone()
43 }
44
45 pub fn clear(&self) {
47 self.inner.clear();
48 }
49
50 pub fn iter_all(&self) -> Vec<((String, String), Arc<T>)> {
55 self.inner
56 .iter()
57 .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
58 .collect()
59 }
60
61 pub fn set(&self, account_id: &str, region: &str, value: T) {
64 self.inner.insert(
65 (account_id.to_string(), region.to_string()),
66 Arc::new(value),
67 );
68 }
69}
70
71pub trait Snapshottable: Send + Sync + Sized {
72 type Snapshot: Serialize + DeserializeOwned + Send;
73
74 fn to_snapshot(&self, account_id: &str, region: &str) -> Self::Snapshot;
75
76 fn from_snapshot(snapshot: Self::Snapshot) -> (String, String, Self);
77}
78
79impl<T: Snapshottable + Default + 'static> AccountRegionStore<T> {
80 pub fn snapshot_to_bytes(&self) -> Option<Vec<u8>> {
81 let snaps: Vec<T::Snapshot> = self
82 .iter_all()
83 .into_iter()
84 .map(|((acct, region), state)| state.to_snapshot(&acct, ®ion))
85 .collect();
86 serde_json::to_vec(&snaps).ok()
87 }
88
89 pub fn restore_from_bytes(&self, data: &[u8]) -> Result<(), String> {
90 let snaps: Vec<T::Snapshot> = serde_json::from_slice(data).map_err(|e| e.to_string())?;
91 self.clear();
92 for snap in snaps {
93 let (acct, region, state) = T::from_snapshot(snap);
94 self.set(&acct, ®ion, state);
95 }
96 Ok(())
97 }
98}
99
100impl<T: Default + Send + Sync + 'static> Default for AccountRegionStore<T> {
101 fn default() -> Self {
102 Self::new()
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal state with interior mutability, so tests can mutate it through `Arc<T>`.
    #[derive(Debug, Default)]
    struct TestState {
        value: AtomicU32,
    }

    /// Snapshot form of `TestState`; carries its own key as `Snapshottable` requires.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestSnapshot {
        account_id: String,
        region: String,
        value: u32,
    }

    impl Snapshottable for TestState {
        type Snapshot = TestSnapshot;

        fn to_snapshot(&self, account_id: &str, region: &str) -> Self::Snapshot {
            TestSnapshot {
                account_id: account_id.to_string(),
                region: region.to_string(),
                value: self.value.load(Ordering::SeqCst),
            }
        }

        fn from_snapshot(snapshot: Self::Snapshot) -> (String, String, Self) {
            (
                snapshot.account_id,
                snapshot.region,
                TestState {
                    value: AtomicU32::new(snapshot.value),
                },
            )
        }
    }

    /// Collects `(key, value)` pairs sorted by key, so assertions are order-independent.
    fn sorted_values(store: &AccountRegionStore<TestState>) -> Vec<((String, String), u32)> {
        let mut entries: Vec<((String, String), u32)> = store
            .iter_all()
            .into_iter()
            .map(|(k, v)| (k, v.value.load(Ordering::SeqCst)))
            .collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));
        entries
    }

    #[test]
    fn snapshot_round_trip() {
        let store = AccountRegionStore::<TestState>::new();
        store
            .get("111", "us-east-1")
            .value
            .store(7, Ordering::SeqCst);
        store
            .get("222", "us-west-2")
            .value
            .store(42, Ordering::SeqCst);

        let bytes = store.snapshot_to_bytes().expect("snapshot");

        let restored = AccountRegionStore::<TestState>::new();
        restored.restore_from_bytes(&bytes).expect("restore");

        assert_eq!(
            sorted_values(&restored),
            vec![
                (("111".to_string(), "us-east-1".to_string()), 7),
                (("222".to_string(), "us-west-2".to_string()), 42),
            ]
        );
    }

    #[test]
    fn restore_replaces_existing_state() {
        let store = AccountRegionStore::<TestState>::new();
        store
            .get("111", "us-east-1")
            .value
            .store(7, Ordering::SeqCst);

        let bytes = store.snapshot_to_bytes().expect("snapshot");

        store
            .get("111", "us-east-1")
            .value
            .store(99, Ordering::SeqCst);
        store
            .get("999", "eu-west-1")
            .value
            .store(1, Ordering::SeqCst);

        store.restore_from_bytes(&bytes).expect("restore");

        let entries = store.iter_all();
        assert_eq!(entries.len(), 1);
        let ((acct, region), state) = &entries[0];
        assert_eq!(acct, "111");
        assert_eq!(region, "us-east-1");
        assert_eq!(state.value.load(Ordering::SeqCst), 7);
    }

    #[test]
    fn restore_rejects_malformed_input_and_keeps_state() {
        let store = AccountRegionStore::<TestState>::new();
        store
            .get("111", "us-east-1")
            .value
            .store(7, Ordering::SeqCst);

        // Parsing happens before `clear`, so a bad payload must not wipe the store.
        assert!(store.restore_from_bytes(b"not json").is_err());

        assert_eq!(
            sorted_values(&store),
            vec![(("111".to_string(), "us-east-1".to_string()), 7)]
        );
    }

    #[test]
    fn empty_snapshot_restores_to_empty_store() {
        let empty = AccountRegionStore::<TestState>::new();
        let bytes = empty.snapshot_to_bytes().expect("snapshot");

        let target = AccountRegionStore::<TestState>::new();
        target.get("111", "us-east-1");
        target.restore_from_bytes(&bytes).expect("restore");

        assert!(target.iter_all().is_empty());
    }

    #[test]
    fn restore_with_duplicate_keys_keeps_last_entry() {
        let store = AccountRegionStore::<TestState>::new();
        let payload = br#"[
            {"account_id":"111","region":"us-east-1","value":1},
            {"account_id":"111","region":"us-east-1","value":2}
        ]"#;

        store.restore_from_bytes(payload).expect("restore");

        assert_eq!(
            sorted_values(&store),
            vec![(("111".to_string(), "us-east-1".to_string()), 2)]
        );
    }

    #[test]
    fn get_shares_state_per_key_and_across_clones() {
        let store = AccountRegionStore::<TestState>::new();
        let first = store.get("111", "us-east-1");
        let second = store.get("111", "us-east-1");
        let other_region = store.get("111", "us-west-2");

        assert!(Arc::ptr_eq(&first, &second));
        assert!(!Arc::ptr_eq(&first, &other_region));

        // A cloned store is a handle onto the same map, not a copy.
        let alias = store.clone();
        alias
            .get("111", "us-east-1")
            .value
            .store(5, Ordering::SeqCst);
        assert_eq!(first.value.load(Ordering::SeqCst), 5);
    }
}