bitwarden-state 3.0.0

Internal crate for the bitwarden crate. Do not use.
Documentation
use std::{
    any::{Any, TypeId},
    collections::HashMap,
    sync::{Arc, OnceLock, RwLock},
};

use bitwarden_error::bitwarden_error;
use thiserror::Error;

use crate::{
    repository::{Repository, RepositoryItem, RepositoryItemData, RepositoryMigrations},
    sdk_managed::{Database, DatabaseConfiguration, MemoryDatabase, SystemDatabase},
    settings::{Key, Setting, SettingItem},
};

/// A registry that contains repositories for different types of items.
/// These repositories can be either managed by the client or by the SDK itself.
pub struct StateRegistry {
    sdk_managed: RwLock<Vec<RepositoryItemData>>,
    client_managed: RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>,

    database: OnceLock<SystemDatabase>,
}

impl std::fmt::Debug for StateRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateRegistry").finish()
    }
}

#[allow(missing_docs)]
#[bitwarden_error(flat)]
#[derive(Debug, Error)]
pub enum StateRegistryError {
    #[error("Database is already initialized")]
    DatabaseAlreadyInitialized,
    #[error("Database is not initialized")]
    DatabaseNotInitialized,

    #[error(transparent)]
    Database(#[from] crate::sdk_managed::DatabaseError),
}

impl StateRegistry {
    /// Creates a new empty `StateRegistry`.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        StateRegistry {
            client_managed: RwLock::new(HashMap::new()),
            database: OnceLock::new(),
            sdk_managed: RwLock::new(Vec::new()),
        }
    }

    /// Creates a new `StateRegistry` backed by an in-memory database.
    pub fn new_with_memory_db() -> Self {
        let registry = Self::new();
        // OnceLock::set returns Err only if already set.
        // new() guarantees the OnceLock is unset. We ignore the result
        // because there is no failure scenario here.
        let _ = registry
            .database
            .set(SystemDatabase::Memory(MemoryDatabase::new()));
        registry
    }

    // TODO: Ideally we'd do this in new, but that would mean making the client initialization
    // async.
    // TODO: This function needs to be provided some configuration to know where to open the
    // database. For Sqlite:
    // - A folder path where the files will be stored.
    // - A user ID to create a unique database file per user?
    //
    // For WASM indexedDB:
    // - A database name to use for the indexedDB (Some prefix to avoid conflicts + user ID?)

    /// Initializes the database used for sdk-managed repositories.
    pub async fn initialize_database(
        &self,
        configuration: DatabaseConfiguration,
        migrations: RepositoryMigrations,
    ) -> Result<(), StateRegistryError> {
        if self.database.get().is_some() {
            return Err(StateRegistryError::DatabaseAlreadyInitialized);
        }
        let _ = self
            .database
            .set(SystemDatabase::initialize(configuration, migrations.clone()).await?);

        *self
            .sdk_managed
            .write()
            .expect("RwLock should not be poisoned") = migrations.into_repository_items();

        Ok(())
    }

    /// Get a handle to a setting by its type-safe key.
    pub fn setting<T>(&self, key: Key<T>) -> Result<Setting<T>, StateRegistryError> {
        let repo = self.get::<SettingItem>()?;
        Ok(Setting::new(repo, key))
    }

    /// Registers a client-managed repository into the map, associating it with its type.
    pub fn register_client_managed<T: RepositoryItem>(&self, value: Arc<dyn Repository<T>>) {
        self.client_managed
            .write()
            .expect("RwLock should not be poisoned")
            .insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Retrieves a client-managed repository from the map given its type.
    fn get_client_managed<T: RepositoryItem>(&self) -> Option<Arc<dyn Repository<T>>> {
        self.client_managed
            .read()
            .expect("RwLock should not be poisoned")
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<Arc<dyn Repository<T>>>())
            .map(Arc::clone)
    }

    /// Retrieves a SDK-managed repository from the database.
    fn get_sdk_managed<T: RepositoryItem>(
        &self,
    ) -> Result<Arc<dyn Repository<T>>, StateRegistryError> {
        self.database
            .get()
            .map(|db| db.get_repository::<T>())
            .ok_or(StateRegistryError::DatabaseNotInitialized)
    }

    /// Get a repository with fallback: prefer client-managed, fall back to SDK-managed.
    ///
    /// This method first attempts to retrieve a client-managed repository. If not found,
    /// it falls back to an SDK-managed repository. Both are returned as `Arc<dyn Repository<T>>`.
    ///
    /// # Type Requirements
    /// - `T` must implement `RepositoryItem` (for both types)
    ///
    /// # Errors
    /// Returns `StateRegistryError` when:
    /// - Client-managed repository is not registered, AND
    /// - SDK-managed repository cannot be retrieved (e.g., database not initialized)
    pub fn get<T>(&self) -> Result<Arc<dyn Repository<T>>, StateRegistryError>
    where
        T: RepositoryItem,
    {
        if let Some(repo) = self.get_client_managed::<T>() {
            return Ok(repo);
        }

        self.get_sdk_managed::<T>()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        register_repository_item,
        repository::{RepositoryError, RepositoryItem},
    };

    macro_rules! impl_repository {
        ($name:ident, $ty:ty) => {
            #[async_trait::async_trait]
            impl Repository<$ty> for $name {
                async fn get(&self, _key: String) -> Result<Option<$ty>, RepositoryError> {
                    Ok(Some(TestItem(self.0.clone())))
                }
                async fn list(&self) -> Result<Vec<$ty>, RepositoryError> {
                    unimplemented!()
                }
                async fn set(&self, _key: String, _value: $ty) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn set_bulk(
                    &self,
                    _values: Vec<(String, $ty)>,
                ) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove(&self, _key: String) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_bulk(&self, _keys: Vec<String>) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
                async fn remove_all(&self) -> Result<(), RepositoryError> {
                    unimplemented!()
                }
            }
        };
    }

    use serde::{Deserialize, Serialize};

    #[derive(PartialEq, Eq, Debug)]
    struct TestA(usize);
    #[derive(PartialEq, Eq, Debug)]
    struct TestB(String);
    #[derive(PartialEq, Eq, Debug)]
    struct TestC(Vec<u8>);
    #[derive(PartialEq, Eq, Debug, Serialize, Deserialize)]
    struct TestItem<T>(T);

    register_repository_item!(String => TestItem<usize>, "TestItem_usize");
    register_repository_item!(String => TestItem<String>, "TestItem_String");
    register_repository_item!(String => TestItem<Vec<u8>>, "TestItem_Vec");

    impl_repository!(TestA, TestItem<usize>);
    impl_repository!(TestB, TestItem<String>);
    impl_repository!(TestC, TestItem<Vec<u8>>);

    #[tokio::test]
    async fn test_state_registry() {
        let a = Arc::new(TestA(145832));
        let b = Arc::new(TestB("test".to_string()));
        let c = Arc::new(TestC(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]));

        let map = StateRegistry::new();

        async fn get<T: RepositoryItem>(map: &StateRegistry) -> Option<T>
        where
            T::Key: Default,
        {
            map.get_client_managed::<T>()
                .unwrap()
                .get(Default::default())
                .await
                .unwrap()
        }

        assert!(map.get_client_managed::<TestItem<usize>>().is_none());
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(a.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert!(map.get_client_managed::<TestItem<String>>().is_none());
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(b.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert!(map.get_client_managed::<TestItem<Vec<u8>>>().is_none());

        map.register_client_managed(c.clone());
        assert_eq!(get(&map).await, Some(TestItem(a.0)));
        assert_eq!(get(&map).await, Some(TestItem(b.0.clone())));
        assert_eq!(get(&map).await, Some(TestItem(c.0.clone())));
    }

    #[tokio::test]
    async fn test_fallback_client_managed_found() {
        let registry = StateRegistry::new();
        let test_repo = Arc::new(TestA(12345));

        registry.register_client_managed(test_repo.clone());

        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await.unwrap();

        assert_eq!(result, Some(TestItem(12345)));
    }

    #[tokio::test]
    async fn test_fallback_neither_available() {
        let registry = StateRegistry::new();
        // Don't register client-managed or initialize database

        let result = registry.get::<TestItem<usize>>();
        assert!(matches!(
            result,
            Err(StateRegistryError::DatabaseNotInitialized)
        ));
    }

    #[tokio::test]
    async fn test_new_with_memory_db_sync() {
        // Construct in sync context (no .await on the constructor itself)
        let registry = StateRegistry::new_with_memory_db();
        // Database must be accessible via async get after sync construction
        let repo = registry.get::<TestItem<usize>>().unwrap();
        let result = repo.get(String::new()).await;
        // Should return Ok(None) — key not found, not an error
        // (Note: TestItem<usize> is registered in this test module already)
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_setting_on_memory_db() {
        use crate::register_setting_key;
        register_setting_key!(const TEST_SETTING: String = "test_registry_setting_key");

        let registry = StateRegistry::new_with_memory_db();
        let setting = registry.setting(TEST_SETTING).unwrap();

        // Value must not exist initially
        assert_eq!(setting.get().await.unwrap(), None::<String>);

        // Update and read back
        setting.update("hello".to_string()).await.unwrap();
        assert_eq!(setting.get().await.unwrap(), Some("hello".to_string()));

        // Delete and confirm gone
        setting.delete().await.unwrap();
        assert_eq!(setting.get().await.unwrap(), None::<String>);
    }
}