use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// State kept separately for each account inside a `MultiAccountState`.
///
/// Implementors decide how a brand-new account's state is built and, optionally,
/// what it should copy from another account's state.
pub trait AccountState
where
    Self: Sized,
{
    /// Builds a fresh state for `account_id` using the given `region` and `endpoint`.
    fn new_for_account(account_id: &str, region: &str, endpoint: &str) -> Self;

    /// Optional hook that lets a newly built state copy data from `_sibling`.
    ///
    /// The provided implementation leaves `self` unchanged.
    fn inherit_from(&mut self, _sibling: &Self) {
        // Intentionally empty: nothing is inherited unless an implementor overrides this.
    }
}
/// Holds one state value of type `T` per account id, plus the region/endpoint
/// shared by every account.
///
/// Field declaration order is also the field order produced by the derived
/// `Serialize` and `Debug` impls, so reordering fields changes their output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiAccountState<T> {
    // Account whose state `new`/`reset` build up front; `get_or_create` passes its
    // state to `inherit_from` when creating other accounts.
    default_account_id: String,
    // Shared region handed to `new_for_account` for every account created.
    region: String,
    // Shared endpoint handed to `new_for_account` for every account created.
    endpoint: String,
    // Account id -> that account's state.
    // NOTE(review): the derived `Deserialize` does not check that this map contains
    // `default_account_id`; input lacking it yields a value without a default account.
    accounts: HashMap<String, T>,
}
impl<T: AccountState> MultiAccountState<T> {
    /// Creates a container that holds a single freshly built state for
    /// `default_account_id`.
    ///
    /// `region` and `endpoint` are stored and passed to
    /// [`AccountState::new_for_account`] for every account created later.
    pub fn new(default_account_id: &str, region: &str, endpoint: &str) -> Self {
        let mut accounts = HashMap::new();
        accounts.insert(
            default_account_id.to_string(),
            T::new_for_account(default_account_id, region, endpoint),
        );
        Self {
            default_account_id: default_account_id.to_string(),
            region: region.to_string(),
            endpoint: endpoint.to_string(),
            accounts,
        }
    }

    /// Returns the state for `account_id`, creating it on first access.
    ///
    /// A new state is built with [`AccountState::new_for_account`] and then, if the
    /// default account is present, receives the default account's state through
    /// [`AccountState::inherit_from`]. An existing state is returned unchanged.
    pub fn get_or_create(&mut self, account_id: &str) -> &mut T {
        if !self.accounts.contains_key(account_id) {
            let mut state = T::new_for_account(account_id, &self.region, &self.endpoint);
            // The shared borrow of the default account must end before the insert below.
            if let Some(sibling) = self.accounts.get(&self.default_account_id) {
                state.inherit_from(sibling);
            }
            // Return the inserted entry directly instead of hashing the key again.
            return self.accounts.entry(account_id.to_string()).or_insert(state);
        }
        self.accounts
            .get_mut(account_id)
            .expect("account presence was checked just above")
    }

    /// Returns the state for `account_id`, creating it with `init` on first access.
    ///
    /// Unlike [`Self::get_or_create`], a new state does NOT inherit from the default
    /// account. `init` is the only customization applied after
    /// [`AccountState::new_for_account`]. `init` is not called when the account
    /// already exists.
    pub fn get_or_create_with<F>(&mut self, account_id: &str, init: F) -> &mut T
    where
        F: FnOnce(&mut T),
    {
        if !self.accounts.contains_key(account_id) {
            let mut state = T::new_for_account(account_id, &self.region, &self.endpoint);
            init(&mut state);
            return self.accounts.entry(account_id.to_string()).or_insert(state);
        }
        self.accounts
            .get_mut(account_id)
            .expect("account presence was checked just above")
    }

    /// Returns the state for `account_id`, or `None` if that account was never created.
    pub fn get(&self, account_id: &str) -> Option<&T> {
        self.accounts.get(account_id)
    }

    /// Mutable counterpart of [`Self::get`]; never creates an account.
    pub fn get_mut(&mut self, account_id: &str) -> Option<&mut T> {
        self.accounts.get_mut(account_id)
    }

    /// Iterates over `(account_id, state)` pairs in unspecified (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &T)> {
        self.accounts.iter().map(|(k, v)| (k.as_str(), v))
    }

    /// Iterates over `(account_id, state)` pairs mutably, in unspecified order.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (&str, &mut T)> {
        self.accounts.iter_mut().map(|(k, v)| (k.as_str(), v))
    }

    /// The account id given to [`Self::new`].
    pub fn default_account_id(&self) -> &str {
        &self.default_account_id
    }

    /// Returns the default account's state, rebuilding it if it is absent.
    ///
    /// The default account is inserted by [`Self::new`] and [`Self::reset`]. A value
    /// obtained some other way (e.g. via `Deserialize`) may lack it. In that case a
    /// fresh state is built with [`AccountState::new_for_account`] instead of panicking.
    pub fn default_mut(&mut self) -> &mut T {
        if !self.accounts.contains_key(&self.default_account_id) {
            let state =
                T::new_for_account(&self.default_account_id, &self.region, &self.endpoint);
            return self
                .accounts
                .entry(self.default_account_id.clone())
                .or_insert(state);
        }
        self.accounts
            .get_mut(&self.default_account_id)
            .expect("default account presence was checked just above")
    }

    /// Returns the default account's state.
    ///
    /// # Panics
    ///
    /// Panics if the default account is missing, which can happen for a value obtained
    /// other than through [`Self::new`]/[`Self::reset`] (e.g. via `Deserialize`).
    /// Call [`Self::default_mut`] first to rebuild it.
    pub fn default_ref(&self) -> &T {
        self.accounts
            .get(&self.default_account_id)
            .unwrap_or_else(|| panic!("default account `{}` is missing", self.default_account_id))
    }

    /// Drops every account's state and rebuilds only the default account from
    /// scratch (no inheritance is applied).
    pub fn reset(&mut self) {
        self.accounts.clear();
        self.accounts.insert(
            self.default_account_id.clone(),
            T::new_for_account(&self.default_account_id, &self.region, &self.endpoint),
        );
    }

    /// Returns the id of an account whose state satisfies `predicate`, if any.
    ///
    /// Accounts are visited in unspecified (hash-map) order, so when several match,
    /// which id is returned is unspecified. `predicate` may be `FnMut`, which allows
    /// stateful closures such as counters.
    pub fn find_account<F>(&self, mut predicate: F) -> Option<&str>
    where
        F: FnMut(&T) -> bool,
    {
        self.accounts
            .iter()
            .find(|(_, v)| predicate(v))
            .map(|(k, _)| k.as_str())
    }

    /// Number of accounts that currently have state, including the default account.
    pub fn account_count(&self) -> usize {
        self.accounts.len()
    }

    /// Region shared by every account.
    pub fn region(&self) -> &str {
        &self.region
    }

    /// Endpoint shared by every account.
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Account id passed to `MultiAccountState::new` by every test.
    const DEFAULT_ID: &str = "111111111111";

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestState {
        account_id: String,
        items: Vec<String>,
    }

    impl AccountState for TestState {
        fn new_for_account(account_id: &str, _region: &str, _endpoint: &str) -> Self {
            Self {
                account_id: account_id.to_string(),
                items: Vec::new(),
            }
        }

        /// Copies the sibling's items so tests can observe whether inheritance ran.
        fn inherit_from(&mut self, sibling: &Self) {
            self.items.clone_from(&sibling.items);
        }
    }

    /// Builds the container every test starts from: only `DEFAULT_ID` exists.
    fn fresh() -> MultiAccountState<TestState> {
        MultiAccountState::new(DEFAULT_ID, "us-east-1", "http://localhost:4566")
    }

    #[test]
    fn default_account_exists_on_creation() {
        let mas = fresh();
        assert_eq!(mas.account_count(), 1);
        assert!(mas.get(DEFAULT_ID).is_some());
        assert_eq!(mas.default_account_id(), DEFAULT_ID);
    }

    #[test]
    fn get_or_create_makes_new_account() {
        let mut mas = fresh();
        let state = mas.get_or_create("222222222222");
        assert_eq!(state.account_id, "222222222222");
        assert_eq!(mas.account_count(), 2);
    }

    #[test]
    fn get_or_create_inherits_from_default() {
        let mut mas = fresh();
        mas.default_mut().items.push("seed".to_string());
        let state = mas.get_or_create("222222222222");
        assert_eq!(state.account_id, "222222222222");
        assert_eq!(state.items, vec!["seed".to_string()]);
    }

    #[test]
    fn get_or_create_keeps_existing_state() {
        let mut mas = fresh();
        mas.get_or_create("222222222222").items.push("own".to_string());
        // Later changes to the default must not leak into an existing account.
        mas.default_mut().items.push("later".to_string());
        assert_eq!(
            mas.get_or_create("222222222222").items,
            vec!["own".to_string()]
        );
        assert_eq!(mas.account_count(), 2);
    }

    #[test]
    fn get_or_create_with_initializes_without_inheriting() {
        let mut mas = fresh();
        mas.default_mut().items.push("seed".to_string());
        let state = mas.get_or_create_with("333333333333", |s| s.items.push("init".to_string()));
        assert_eq!(state.items, vec!["init".to_string()]);
    }

    #[test]
    fn get_or_create_with_skips_init_for_existing_account() {
        let mut mas = fresh();
        let mut calls = 0;
        mas.get_or_create_with(DEFAULT_ID, |_| calls += 1);
        assert_eq!(calls, 0);
        assert_eq!(mas.account_count(), 1);
    }

    #[test]
    fn get_returns_none_for_unknown() {
        let mas = fresh();
        assert!(mas.get("999999999999").is_none());
    }

    #[test]
    fn find_account_returns_matching_id() {
        let mut mas = fresh();
        mas.get_or_create("222222222222").items.push("x".to_string());
        assert_eq!(
            mas.find_account(|s| !s.items.is_empty()),
            Some("222222222222")
        );
        assert_eq!(mas.find_account(|s| s.account_id == "nope"), None);
    }

    #[test]
    fn reset_clears_all_but_default() {
        let mut mas = fresh();
        mas.default_mut().items.push("x".to_string());
        mas.get_or_create("222222222222");
        mas.get_or_create("333333333333");
        assert_eq!(mas.account_count(), 3);
        mas.reset();
        assert_eq!(mas.account_count(), 1);
        assert!(mas.get(DEFAULT_ID).is_some());
        assert!(mas.get("222222222222").is_none());
        // The default account is rebuilt from scratch, not preserved.
        assert!(mas.default_ref().items.is_empty());
    }

    #[test]
    fn iter_visits_all_accounts() {
        let mut mas = fresh();
        mas.get_or_create("222222222222");
        let ids: Vec<&str> = mas.iter().map(|(id, _)| id).collect();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&DEFAULT_ID));
        assert!(ids.contains(&"222222222222"));
    }

    #[test]
    fn region_and_endpoint_are_stored() {
        let mas = fresh();
        assert_eq!(mas.region(), "us-east-1");
        assert_eq!(mas.endpoint(), "http://localhost:4566");
    }
}