Skip to main content

fakecloud_core/
multi_account.rs

//! Generic multi-account state container.
//!
//! Wraps a `HashMap<String, T>` keyed by account id so each AWS account gets
//! its own isolated state instance. Accounts are created lazily via
//! [`MultiAccountState::get_or_create`] the first time a request targets
//! them — matching the design in #381, where "an account exists because a
//! credential resolves to it."
7
8use std::collections::HashMap;
9
10use serde::{Deserialize, Serialize};
11
/// Per-service state that can be partitioned by AWS account.
///
/// Implemented by the state structs stored inside a [`MultiAccountState`];
/// every account id maps to its own independent value of the implementing
/// type.
pub trait AccountState
where
    Self: Sized,
{
    /// Build a fresh, empty state for `account_id`, given the shared
    /// `region` and `endpoint`.
    fn new_for_account(account_id: &str, region: &str, endpoint: &str) -> Self;

    /// Hook invoked on a newly created state by
    /// [`MultiAccountState::get_or_create`], receiving an already existing
    /// sibling state (the default account's). Override it to share resources
    /// such as body caches with the new state.
    ///
    /// The provided implementation does nothing.
    fn inherit_from(&mut self, _sibling: &Self) {
        // Intentionally empty: by default nothing is shared between accounts.
    }
}
23
24/// Account-partitioned state container.
25///
26/// Holds one `T` per account id. The `default_account_id` is pre-created at
27/// startup so unauthenticated requests (which fall back to `--account-id`)
28/// always have a state to land in.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct MultiAccountState<T> {
31    default_account_id: String,
32    region: String,
33    endpoint: String,
34    accounts: HashMap<String, T>,
35}
36
37impl<T: AccountState> MultiAccountState<T> {
38    /// Create a new container, pre-populating the default account.
39    pub fn new(default_account_id: &str, region: &str, endpoint: &str) -> Self {
40        let mut accounts = HashMap::new();
41        accounts.insert(
42            default_account_id.to_string(),
43            T::new_for_account(default_account_id, region, endpoint),
44        );
45        Self {
46            default_account_id: default_account_id.to_string(),
47            region: region.to_string(),
48            endpoint: endpoint.to_string(),
49            accounts,
50        }
51    }
52
53    /// Get or lazily create the state for `account_id`.
54    ///
55    /// When a new account is created, [`AccountState::inherit_from`] is called
56    /// with the default account's state so services can propagate shared
57    /// resources (e.g. body caches).
58    pub fn get_or_create(&mut self, account_id: &str) -> &mut T {
59        if !self.accounts.contains_key(account_id) {
60            let mut state = T::new_for_account(account_id, &self.region, &self.endpoint);
61            // Let the new state inherit shared resources from the default account.
62            if let Some(sibling) = self.accounts.get(&self.default_account_id) {
63                state.inherit_from(sibling);
64            }
65            self.accounts.insert(account_id.to_string(), state);
66        }
67        self.accounts.get_mut(account_id).unwrap()
68    }
69
70    /// Get or lazily create the state for `account_id`, then run `init` on
71    /// the newly created state. The callback is only invoked when the account
72    /// is freshly created, not on subsequent lookups.
73    pub fn get_or_create_with<F>(&mut self, account_id: &str, init: F) -> &mut T
74    where
75        F: FnOnce(&mut T),
76    {
77        if !self.accounts.contains_key(account_id) {
78            let mut state = T::new_for_account(account_id, &self.region, &self.endpoint);
79            init(&mut state);
80            self.accounts.insert(account_id.to_string(), state);
81        }
82        self.accounts.get_mut(account_id).unwrap()
83    }
84
85    /// Read-only lookup. Returns `None` if the account has never been seen.
86    pub fn get(&self, account_id: &str) -> Option<&T> {
87        self.accounts.get(account_id)
88    }
89
90    /// Mutable lookup without auto-creation.
91    pub fn get_mut(&mut self, account_id: &str) -> Option<&mut T> {
92        self.accounts.get_mut(account_id)
93    }
94
95    /// Iterate over all account states (read-only).
96    pub fn iter(&self) -> impl Iterator<Item = (&str, &T)> {
97        self.accounts.iter().map(|(k, v)| (k.as_str(), v))
98    }
99
100    /// Iterate over all account states (mutable).
101    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut T)> {
102        self.accounts.iter_mut().map(|(k, v)| (k.as_str(), v))
103    }
104
105    /// The default account id configured via `--account-id`.
106    pub fn default_account_id(&self) -> &str {
107        &self.default_account_id
108    }
109
110    /// Mutable reference to the default account's state (always exists).
111    pub fn default_mut(&mut self) -> &mut T {
112        self.accounts.get_mut(&self.default_account_id).unwrap()
113    }
114
115    /// Reference to the default account's state (always exists).
116    pub fn default_ref(&self) -> &T {
117        self.accounts.get(&self.default_account_id).unwrap()
118    }
119
120    /// Reset all accounts back to empty state. The default account is
121    /// recreated; all other accounts are dropped.
122    pub fn reset(&mut self) {
123        self.accounts.clear();
124        self.accounts.insert(
125            self.default_account_id.clone(),
126            T::new_for_account(&self.default_account_id, &self.region, &self.endpoint),
127        );
128    }
129
130    /// Find the first account whose state satisfies `predicate` and return
131    /// the account id. Useful for resolving globally-unique resources (e.g.
132    /// S3 bucket names) back to their owning account.
133    pub fn find_account<F>(&self, predicate: F) -> Option<&str>
134    where
135        F: Fn(&T) -> bool,
136    {
137        self.accounts
138            .iter()
139            .find(|(_, v)| predicate(v))
140            .map(|(k, _)| k.as_str())
141    }
142
143    /// Number of accounts with state.
144    pub fn account_count(&self) -> usize {
145        self.accounts.len()
146    }
147
148    /// Region shared by all accounts.
149    pub fn region(&self) -> &str {
150        &self.region
151    }
152
153    /// Endpoint shared by all accounts.
154    pub fn endpoint(&self) -> &str {
155        &self.endpoint
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal per-account state used to exercise the container.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestState {
        account_id: String,
        items: Vec<String>,
    }

    impl AccountState for TestState {
        fn new_for_account(account_id: &str, _region: &str, _endpoint: &str) -> Self {
            Self {
                account_id: account_id.to_string(),
                items: Vec::new(),
            }
        }

        /// Copy the sibling's items so inheritance is observable in tests.
        fn inherit_from(&mut self, sibling: &Self) {
            self.items.extend(sibling.items.iter().cloned());
        }
    }

    /// Container with default account `111111111111`.
    fn new_mas() -> MultiAccountState<TestState> {
        MultiAccountState::new("111111111111", "us-east-1", "http://localhost:4566")
    }

    #[test]
    fn default_account_exists_on_creation() {
        let mas = new_mas();
        assert_eq!(mas.account_count(), 1);
        assert!(mas.get("111111111111").is_some());
        assert_eq!(mas.default_account_id(), "111111111111");
    }

    #[test]
    fn get_or_create_makes_new_account() {
        let mut mas = new_mas();
        let state = mas.get_or_create("222222222222");
        assert_eq!(state.account_id, "222222222222");
        assert_eq!(mas.account_count(), 2);
    }

    #[test]
    fn get_or_create_returns_existing_state() {
        let mut mas = new_mas();
        mas.get_or_create("222222222222").items.push("x".to_string());
        let state = mas.get_or_create("222222222222");
        assert_eq!(state.items, vec!["x".to_string()]);
        assert_eq!(mas.account_count(), 2);
    }

    #[test]
    fn get_or_create_inherits_from_default_once() {
        let mut mas = new_mas();
        mas.default_mut().items.push("shared".to_string());
        assert_eq!(
            mas.get_or_create("222222222222").items,
            vec!["shared".to_string()]
        );
        // Inheritance only happens on creation, not on later lookups.
        mas.default_mut().items.push("later".to_string());
        assert_eq!(
            mas.get_or_create("222222222222").items,
            vec!["shared".to_string()]
        );
    }

    #[test]
    fn get_or_create_with_runs_init_only_on_creation() {
        let mut mas = new_mas();
        let mut calls = 0;
        mas.get_or_create_with("222222222222", |s| {
            calls += 1;
            s.items.push("init".to_string());
        });
        mas.get_or_create_with("222222222222", |_| calls += 1);
        assert_eq!(calls, 1);
        assert_eq!(
            mas.get("222222222222").unwrap().items,
            vec!["init".to_string()]
        );
    }

    #[test]
    fn get_returns_none_for_unknown() {
        let mas = new_mas();
        assert!(mas.get("999999999999").is_none());
    }

    #[test]
    fn default_accessors_target_default_account() {
        let mut mas = new_mas();
        mas.default_mut().items.push("d".to_string());
        assert_eq!(mas.default_ref().account_id, "111111111111");
        assert_eq!(mas.default_ref().items, vec!["d".to_string()]);
    }

    #[test]
    fn find_account_locates_owner() {
        let mut mas = new_mas();
        mas.get_or_create("222222222222")
            .items
            .push("bucket".to_string());
        assert_eq!(
            mas.find_account(|s| s.items.iter().any(|i| i == "bucket")),
            Some("222222222222")
        );
        assert_eq!(
            mas.find_account(|s| s.items.iter().any(|i| i == "missing")),
            None
        );
    }

    #[test]
    fn reset_clears_all_but_default() {
        let mut mas = new_mas();
        mas.get_or_create("222222222222");
        mas.get_or_create("333333333333");
        assert_eq!(mas.account_count(), 3);
        mas.reset();
        assert_eq!(mas.account_count(), 1);
        assert!(mas.get("111111111111").is_some());
        assert!(mas.get("222222222222").is_none());
    }

    #[test]
    fn reset_recreates_empty_default() {
        let mut mas = new_mas();
        mas.default_mut().items.push("stale".to_string());
        mas.reset();
        assert!(mas.default_ref().items.is_empty());
    }

    #[test]
    fn iter_visits_all_accounts() {
        let mut mas = new_mas();
        mas.get_or_create("222222222222");
        let ids: Vec<&str> = mas.iter().map(|(id, _)| id).collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"111111111111"));
        assert!(ids.contains(&"222222222222"));
    }
}