use std::{
any::{Any, TypeId},
collections::HashMap,
sync::{Arc, OnceLock, RwLock},
};
use bitwarden_error::bitwarden_error;
use thiserror::Error;
use crate::{
repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
sdk_managed::{Database, DatabaseConfiguration, MemoryDatabase, SystemDatabase},
settings::{Key, Setting, SettingItem},
};
/// Registry that resolves [`Repository`] implementations by item type.
///
/// Repositories come from two sources: ones registered by the client through
/// [`StateRegistry::register_client_managed`], and ones served by an SDK-managed
/// [`SystemDatabase`]. Lookups through [`StateRegistry::get`] prefer the client-managed
/// repository when one is registered for the requested type.
pub struct StateRegistry {
    /// Repository items taken from the migrations passed to `initialize_database`.
    // NOTE(review): not read anywhere in this file — confirm it is still needed.
    sdk_managed: RwLock<Vec<RepositoryItemData>>,
    /// Client-provided repositories keyed by the item type's `TypeId`; each value is a
    /// type-erased `Arc<dyn Repository<T>>` for the matching `T`.
    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,
    /// SDK-managed database; a `OnceLock`, so it can be set at most once.
    database: OnceLock<SystemDatabase>,
}
impl std::fmt::Debug for StateRegistry {
    /// Writes only the type name; registered repositories and the database are not
    /// included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("StateRegistry")
    }
}
/// Errors returned by [`StateRegistry`].
// NOTE(review): the enum and its variants are documented; this allow may only still be
// needed for items generated by `bitwarden_error` — confirm before removing it.
#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum StateRegistryError {
    /// `initialize_database` was called when a database had already been set.
    #[error("Database is already initialized")]
    DatabaseAlreadyInitialized,
    /// An SDK-managed repository was requested before any database was set.
    #[error("Database is not initialized")]
    DatabaseNotInitialized,
    /// An error reported by the SDK-managed database layer.
    #[error(transparent)]
    Database(#[from] crate::sdk_managed::DatabaseError),
}
impl StateRegistry {
    /// Creates an empty registry.
    ///
    /// No client-managed repositories are registered and no database is set. Until
    /// [`StateRegistry::register_client_managed`] or [`StateRegistry::initialize_database`]
    /// is called, [`StateRegistry::get`] returns [`StateRegistryError::DatabaseNotInitialized`].
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        StateRegistry {
            client_managed: RwLock::new(HashMap::new()),
            database: OnceLock::new(),
            sdk_managed: RwLock::new(Vec::new()),
        }
    }

    /// Creates a registry whose database is already set to an in-memory
    /// `SystemDatabase::Memory`, so SDK-managed lookups work without calling
    /// [`StateRegistry::initialize_database`].
    pub fn new_with_memory_db() -> Self {
        let registry = Self::new();
        // Cannot fail: the `OnceLock` was created empty on the line above.
        let _ = registry
            .database
            .set(SystemDatabase::Memory(MemoryDatabase::new()));
        registry
    }

    /// Initializes the SDK-managed database and records the repository items declared by
    /// `migrations`.
    ///
    /// # Errors
    ///
    /// - [`StateRegistryError::DatabaseAlreadyInitialized`] if a database is already set,
    ///   including when a concurrent call finished initializing first.
    /// - [`StateRegistryError::Database`] if `SystemDatabase::initialize` fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal `RwLock` is poisoned.
    pub async fn initialize_database(
        &self,
        configuration: DatabaseConfiguration,
        migrations: RepositoryMigrations,
    ) -> Result<(), StateRegistryError> {
        // Cheap early exit so we don't initialize a database we would immediately discard.
        if self.database.get().is_some() {
            return Err(StateRegistryError::DatabaseAlreadyInitialized);
        }

        let database = SystemDatabase::initialize(configuration, migrations.clone()).await?;

        // The check above is not atomic with this `set`: another caller may have finished
        // initializing while we were awaiting. Report that instead of silently dropping our
        // database and then overwriting the winner's repository item list below.
        self.database
            .set(database)
            .map_err(|_| StateRegistryError::DatabaseAlreadyInitialized)?;

        *self
            .sdk_managed
            .write()
            .expect("RwLock should not be poisoned") = migrations.into_repository_items();
        Ok(())
    }

    /// Returns a [`Setting`] handle for `key`, backed by the `SettingItem` repository
    /// resolved through [`StateRegistry::get`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`StateRegistry::get`].
    pub fn setting<T>(&self, key: Key<T>) -> Result<Setting<T>, StateRegistryError> {
        let repo = self.get::<SettingItem>()?;
        Ok(Setting::new(repo, key))
    }

    /// Registers a client-provided repository for item type `T`.
    ///
    /// It replaces any repository previously registered for the same `T`, and it takes
    /// precedence over the SDK-managed database in [`StateRegistry::get`].
    ///
    /// # Panics
    ///
    /// Panics if the internal `RwLock` is poisoned.
    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
        self.client_managed
            .write()
            .expect("RwLock should not be poisoned")
            .insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Looks up the client-managed repository for `T`, if one was registered.
    fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
        self.client_managed
            .read()
            .expect("RwLock should not be poisoned")
            .get(&TypeId::of::<T>())
            // Entries are inserted under `TypeId::of::<T>()` with an `Arc<dyn Repository<T>>`
            // value, so the downcast matches whenever the key does.
            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
            .map(Arc::clone)
    }

    /// Returns the SDK-managed repository for `T` from the database.
    ///
    /// Returns [`StateRegistryError::DatabaseNotInitialized`] if no database has been set.
    fn get_sdk_managed<T: RepositoryItem>(
        &self,
    ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
        self.database
            .get()
            .map(|db| db.get_repository::<T>())
            .ok_or(StateRegistryError::DatabaseNotInitialized)
    }

    /// Resolves the repository for item type `T`.
    ///
    /// A client-managed repository wins if one is registered. Otherwise the SDK-managed
    /// database is used.
    ///
    /// # Errors
    ///
    /// Returns [`StateRegistryError::DatabaseNotInitialized`] when no client-managed repository
    /// is registered for `T` and no database has been set.
    pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
    where
        T: RepositoryItem,
    {
        if let Some(repo) = self.get_client_managed::<T>() {
            return Ok(repo);
        }
        self.get_sdk_managed::<T>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    /// Implements `Repository<$ty>` for the test newtype `$name`.
    ///
    /// Only `get` is functional: it ignores the key and returns a `TestItem` wrapping a clone
    /// of `self.0`. Every other method is `unimplemented!()` and panics if called.
    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    use serde::{Deserialize, Serialize};

    /// Test repository serving `TestItem<usize>`.
    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    /// Test repository serving `TestItem<String>`.
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    /// Test repository serving `TestItem<Vec<u8>>`.
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    /// Generic item type; each payload type below is a distinct repository item.
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    // Register each `TestItem<_>` instantiation as a repository item under a unique name.
    // NOTE(review): the `String =>` prefix presumably sets the item's key type — confirm
    // against `register_repository_item!`.
    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    // Wire each test newtype to the `TestItem<_>` instantiation it serves.
    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    /// Registers three client-managed repositories one at a time. Checks that each
    /// registration makes only its own item type resolvable and leaves earlier
    /// registrations intact.
    #[tokio::test]
    async fn test_state_registry() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));
        let map = StateRegistry::new();

        /// Fetches the default key from the client-managed repository for `T`. `T` is
        /// inferred from the expected value at each call site. Panics if no repository is
        /// registered for `T` or the lookup returns an error.
        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        // Nothing registered yet.
        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    /// `get` returns the client-managed repository when one is registered, even though no
    /// database has been set.
    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let test_repo = Arc::new(TestA(12345));
        registry.register_client_managed(test_repo.clone());
        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();
        assert_eq!(result, Some(TestItem(12345)));
    }

    /// With nothing registered and no database set, `get` reports `DatabaseNotInitialized`.
    #[tokio::test]
    async fn test_fallback_neither_available() {
        let registry = StateRegistry::new();
        let result = registry.get::<TestItem<usize>>();
        assert!(matches!(
            result,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }

    /// `new_with_memory_db` sets a database synchronously, so SDK-managed lookups succeed.
    #[tokio::test]
    async fn test_new_with_memory_db_sync() {
        let registry = StateRegistry::new_with_memory_db();
        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await;
        assert!(result.is_ok());
    }

    /// Round-trips a setting against the in-memory database:
    /// get → update → get → delete → get.
    #[tokio::test]
    async fn test_setting_on_memory_db() {
        use crate::register_setting_key;
        register_setting_key!(const TEST_SETTING: String = "test_registry_setting_key");
        let registry = StateRegistry::new_with_memory_db();
        let setting = registry.setting(TEST_SETTING).unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
        setting.update("hello".to_string()).await.unwrap();
        assert_eq!(setting.get().await.unwrap(), Some("hello".to_string()));
        setting.delete().await.unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
    }
}