Skip to main content

awsim_core/
state.rs

1use dashmap::DashMap;
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4use std::sync::Arc;
5
6/// A thread-safe, account+region-namespaced state store.
7///
8/// Each AWS service uses this to store its state, ensuring that
9/// resources in different accounts/regions are isolated.
10///
11/// Example:
12/// ```ignore
13/// let store = AccountRegionStore::<SqsState>::new();
14/// let state = store.get("000000000000", "us-east-1");
15/// state.queues.insert("my-queue".into(), queue);
16/// ```
17#[derive(Debug)]
18pub struct AccountRegionStore<T: Default + Send + Sync + 'static> {
19    inner: Arc<DashMap<(String, String), Arc<T>>>,
20}
21
22impl<T: Default + Send + Sync + 'static> Clone for AccountRegionStore<T> {
23    fn clone(&self) -> Self {
24        Self {
25            inner: Arc::clone(&self.inner),
26        }
27    }
28}
29
30impl<T: Default + Send + Sync + 'static> AccountRegionStore<T> {
31    pub fn new() -> Self {
32        Self {
33            inner: Arc::new(DashMap::new()),
34        }
35    }
36
37    /// Get or create the state for a given account+region pair.
38    pub fn get(&self, account_id: &str, region: &str) -> Arc<T> {
39        self.inner
40            .entry((account_id.to_string(), region.to_string()))
41            .or_insert_with(|| Arc::new(T::default()))
42            .clone()
43    }
44
45    /// Clear all state (useful for testing).
46    pub fn clear(&self) {
47        self.inner.clear();
48    }
49
50    /// Iterate over all (account_id, region) → state entries.
51    ///
52    /// Returns a snapshot of the keys paired with the `Arc<T>` values so the
53    /// caller can read state without holding any DashMap locks long-term.
54    pub fn iter_all(&self) -> Vec<((String, String), Arc<T>)> {
55        self.inner
56            .iter()
57            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
58            .collect()
59    }
60
61    /// Insert a state value for the given (account_id, region), replacing any
62    /// existing entry.
63    pub fn set(&self, account_id: &str, region: &str, value: T) {
64        self.inner.insert(
65            (account_id.to_string(), region.to_string()),
66            Arc::new(value),
67        );
68    }
69}
70
71pub trait Snapshottable: Send + Sync + Sized {
72    type Snapshot: Serialize + DeserializeOwned + Send;
73
74    fn to_snapshot(&self, account_id: &str, region: &str) -> Self::Snapshot;
75
76    fn from_snapshot(snapshot: Self::Snapshot) -> (String, String, Self);
77}
78
79impl<T: Snapshottable + Default + 'static> AccountRegionStore<T> {
80    pub fn snapshot_to_bytes(&self) -> Option<Vec<u8>> {
81        let snaps: Vec<T::Snapshot> = self
82            .iter_all()
83            .into_iter()
84            .map(|((acct, region), state)| state.to_snapshot(&acct, &region))
85            .collect();
86        serde_json::to_vec(&snaps).ok()
87    }
88
89    pub fn restore_from_bytes(&self, data: &[u8]) -> Result<(), String> {
90        let snaps: Vec<T::Snapshot> = serde_json::from_slice(data).map_err(|e| e.to_string())?;
91        self.clear();
92        for snap in snaps {
93            let (acct, region, state) = T::from_snapshot(snap);
94            self.set(&acct, &region, state);
95        }
96        Ok(())
97    }
98}
99
100impl<T: Default + Send + Sync + 'static> Default for AccountRegionStore<T> {
101    fn default() -> Self {
102        Self::new()
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal state type: a single counter with interior mutability.
    #[derive(Debug, Default)]
    struct TestState {
        value: AtomicU32,
    }

    /// Serializable form of `TestState`, carrying its own account/region.
    #[derive(Debug, Serialize, Deserialize)]
    struct TestSnapshot {
        account_id: String,
        region: String,
        value: u32,
    }

    impl Snapshottable for TestState {
        type Snapshot = TestSnapshot;

        fn to_snapshot(&self, account_id: &str, region: &str) -> Self::Snapshot {
            TestSnapshot {
                account_id: account_id.to_string(),
                region: region.to_string(),
                value: self.value.load(Ordering::SeqCst),
            }
        }

        fn from_snapshot(snapshot: Self::Snapshot) -> (String, String, Self) {
            (
                snapshot.account_id,
                snapshot.region,
                TestState {
                    value: AtomicU32::new(snapshot.value),
                },
            )
        }
    }

    #[test]
    fn get_creates_default_and_shares_state() {
        let store = AccountRegionStore::<TestState>::new();
        let first = store.get("111", "us-east-1");
        assert_eq!(first.value.load(Ordering::SeqCst), 0);
        first.value.store(5, Ordering::SeqCst);

        // A second lookup returns the same Arc, not a fresh default.
        let second = store.get("111", "us-east-1");
        assert!(Arc::ptr_eq(&first, &second));
        assert_eq!(second.value.load(Ordering::SeqCst), 5);

        // A different region is isolated and starts from the default.
        let other = store.get("111", "us-west-2");
        assert_eq!(other.value.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn clones_share_storage() {
        let store = AccountRegionStore::<TestState>::new();
        let handle = store.clone();
        handle
            .get("111", "us-east-1")
            .value
            .store(3, Ordering::SeqCst);

        let seen = store.get("111", "us-east-1");
        assert_eq!(seen.value.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn snapshot_round_trip() {
        let store = AccountRegionStore::<TestState>::new();
        store
            .get("111", "us-east-1")
            .value
            .store(7, Ordering::SeqCst);
        store
            .get("222", "us-west-2")
            .value
            .store(42, Ordering::SeqCst);

        let bytes = store.snapshot_to_bytes().expect("snapshot");

        let restored = AccountRegionStore::<TestState>::new();
        restored.restore_from_bytes(&bytes).expect("restore");

        let mut entries: Vec<((String, String), u32)> = restored
            .iter_all()
            .into_iter()
            .map(|(k, v)| (k, v.value.load(Ordering::SeqCst)))
            .collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        assert_eq!(
            entries,
            vec![
                (("111".to_string(), "us-east-1".to_string()), 7),
                (("222".to_string(), "us-west-2".to_string()), 42),
            ]
        );
    }

    #[test]
    fn restore_replaces_existing_state() {
        let store = AccountRegionStore::<TestState>::new();
        store
            .get("111", "us-east-1")
            .value
            .store(7, Ordering::SeqCst);

        let bytes = store.snapshot_to_bytes().expect("snapshot");

        store
            .get("111", "us-east-1")
            .value
            .store(99, Ordering::SeqCst);
        store
            .get("999", "eu-west-1")
            .value
            .store(1, Ordering::SeqCst);

        store.restore_from_bytes(&bytes).expect("restore");

        let entries = store.iter_all();
        assert_eq!(entries.len(), 1);
        let ((acct, region), state) = &entries[0];
        assert_eq!(acct, "111");
        assert_eq!(region, "us-east-1");
        assert_eq!(state.value.load(Ordering::SeqCst), 7);
    }

    #[test]
    fn restore_rejects_malformed_input_without_clobbering_state() {
        let store = AccountRegionStore::<TestState>::new();
        store
            .get("111", "us-east-1")
            .value
            .store(7, Ordering::SeqCst);

        assert!(store.restore_from_bytes(b"not json").is_err());

        // Decoding failed before anything was cleared.
        let entries = store.iter_all();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].1.value.load(Ordering::SeqCst), 7);
    }

    #[test]
    fn restore_from_empty_snapshot_clears_store() {
        let empty = AccountRegionStore::<TestState>::new();
        let bytes = empty.snapshot_to_bytes().expect("snapshot");

        let store = AccountRegionStore::<TestState>::new();
        let _ = store.get("111", "us-east-1");
        store.restore_from_bytes(&bytes).expect("restore");
        assert!(store.iter_all().is_empty());
    }
}