use crate::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Summary of one registered vault, as produced by `MultiVaultManager::list_vaults`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VaultInfo {
    /// Vault name; `list_vaults` fills this from the registry key.
    pub name: String,
    /// Vault root path; `list_vaults` copies this from `config.path`.
    pub path: std::path::PathBuf,
    /// Whether this vault was the manager's active vault when listed.
    /// NOTE: `list_vaults` sets this by comparing against the active-vault name,
    /// not from `VaultConfig::is_default`.
    pub is_default: bool,
    /// The vault's stored configuration.
    pub config: VaultConfig,
}
/// Registry of named vaults plus the name of the currently active vault.
///
/// Both pieces of state live behind `Arc<RwLock<_>>`, and the `Clone` impl clones
/// those `Arc`s, so clones of a manager share the same registry and active vault.
pub struct MultiVaultManager {
    /// Registered vaults, keyed by vault name.
    vaults: Arc<RwLock<HashMap<String, VaultConfig>>>,
    /// Name of the active vault; an empty string means no vault is active.
    /// When a method holds both locks at once, it acquires `vaults` first and
    /// this lock second; keep that order to avoid lock-order inversions.
    default_vault: Arc<RwLock<String>>,
    /// Server configuration the manager was constructed with.
    config: ServerConfig,
}
impl MultiVaultManager {
pub fn new(config: ServerConfig) -> Result<Self> {
let vaults = Arc::new(RwLock::new(
config
.vaults
.iter()
.map(|v| (v.name.clone(), v.clone()))
.collect(),
));
let default_name = if config.vaults.is_empty() {
String::new() } else {
config
.vaults
.iter()
.find(|v| v.is_default)
.map(|v| v.name.clone())
.or_else(|| config.vaults.first().map(|v| v.name.clone()))
.unwrap_or_default()
};
Ok(Self {
vaults,
default_vault: Arc::new(RwLock::new(default_name)),
config,
})
}
pub fn empty(config: ServerConfig) -> Result<Self> {
Ok(Self {
vaults: Arc::new(RwLock::new(HashMap::new())),
default_vault: Arc::new(RwLock::new(String::new())),
config,
})
}
const MAX_VAULTS: usize = 50;
pub async fn add_vault(&self, vault_config: VaultConfig) -> Result<()> {
let mut vaults = self.vaults.write().await;
if vaults.len() >= Self::MAX_VAULTS {
return Err(Error::config_error(format!(
"Maximum vault limit ({}) reached. Remove a vault before adding a new one.",
Self::MAX_VAULTS
)));
}
if vaults.contains_key(&vault_config.name) {
return Err(Error::invalid_path(format!(
"Vault '{}' already exists",
vault_config.name
)));
}
let is_first_vault = vaults.is_empty();
vaults.insert(vault_config.name.clone(), vault_config.clone());
if is_first_vault {
drop(vaults); *self.default_vault.write().await = vault_config.name;
}
Ok(())
}
pub async fn remove_vault(&self, name: &str) -> Result<()> {
let mut vaults = self.vaults.write().await;
if !vaults.contains_key(name) {
return Err(Error::not_found(format!("Vault '{}' not found", name)));
}
let current_default = self.default_vault.read().await;
if *current_default == name {
drop(current_default); vaults.remove(name);
if let Some((first_name, _)) = vaults.iter().next() {
*self.default_vault.write().await = first_name.clone();
} else {
*self.default_vault.write().await = String::new();
}
} else {
vaults.remove(name);
}
Ok(())
}
pub async fn get_vault_config(&self, name: &str) -> Result<VaultConfig> {
let vaults = self.vaults.read().await;
vaults
.get(name)
.cloned()
.ok_or_else(|| Error::not_found(format!("Vault '{}' not found", name)))
}
pub async fn get_active_vault(&self) -> String {
self.default_vault.read().await.clone()
}
pub async fn set_active_vault(&self, name: &str) -> Result<()> {
let vaults = self.vaults.read().await;
if !vaults.contains_key(name) {
return Err(Error::not_found(format!("Vault '{}' not found", name)));
}
*self.default_vault.write().await = name.to_string();
Ok(())
}
pub async fn list_vaults(&self) -> Result<Vec<VaultInfo>> {
let vaults = self.vaults.read().await;
let default = self.default_vault.read().await.clone();
let infos = vaults
.iter()
.map(|(name, config)| VaultInfo {
name: name.clone(),
path: config.path.clone(),
is_default: name == &default,
config: config.clone(),
})
.collect();
Ok(infos)
}
pub async fn get_effective_vault_settings(&self, vault_name: &str) -> Result<VaultConfig> {
let vault_config = self.get_vault_config(vault_name).await?;
let effective = vault_config.clone();
Ok(effective)
}
pub async fn vault_count(&self) -> usize {
self.vaults.read().await.len()
}
pub async fn vault_exists(&self, name: &str) -> bool {
self.vaults.read().await.contains_key(name)
}
pub async fn get_active_vault_config(&self) -> Result<VaultConfig> {
let active_name = self.default_vault.read().await.clone();
if active_name.is_empty() {
return Err(Error::not_found(
"No vault is currently active. Please add a vault using add_vault tool."
.to_string(),
));
}
self.get_vault_config(&active_name).await
}
}
impl Clone for MultiVaultManager {
fn clone(&self) -> Self {
Self {
vaults: self.vaults.clone(),
default_vault: self.default_vault.clone(),
config: self.config.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal vault config rooted at `/tmp/<name>` with every optional setting unset.
    fn create_test_vault(name: &str, is_default: bool) -> VaultConfig {
        VaultConfig {
            name: name.to_string(),
            path: std::path::PathBuf::from(format!("/tmp/{}", name)),
            is_default,
            watch_for_changes: None,
            max_file_size: None,
            allowed_extensions: None,
            excluded_paths: None,
            enable_caching: None,
            cache_ttl: None,
            template_dirs: None,
            allowed_operations: None,
        }
    }

    /// Server config with two vaults: `vault1` (flagged default) and `vault2`.
    fn create_test_config() -> ServerConfig {
        let mut config = ServerConfig::new();
        config.vaults = vec![
            create_test_vault("vault1", true),
            create_test_vault("vault2", false),
        ];
        config
    }

    #[test]
    fn test_multi_vault_manager_creation() {
        let config = create_test_config();
        let manager = MultiVaultManager::new(config);
        assert!(manager.is_ok());
    }

    #[tokio::test]
    async fn test_can_create_empty_vaults() {
        let manager = MultiVaultManager::new(ServerConfig::new());
        assert!(manager.is_ok());
        let mgr = manager.unwrap();
        assert!(mgr.get_active_vault().await.is_empty());
    }

    #[tokio::test]
    async fn test_get_active_vault() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        assert_eq!(manager.get_active_vault().await, "vault1");
    }

    #[tokio::test]
    async fn test_set_active_vault() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        manager.set_active_vault("vault2").await.unwrap();
        assert_eq!(manager.get_active_vault().await, "vault2");
    }

    #[tokio::test]
    async fn test_set_invalid_active_vault() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        let result = manager.set_active_vault("nonexistent").await;
        assert!(result.is_err());
        assert_eq!(manager.get_active_vault().await, "vault1");
    }

    #[tokio::test]
    async fn test_add_vault() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        manager
            .add_vault(create_test_vault("vault3", false))
            .await
            .unwrap();
        assert!(manager.vault_exists("vault3").await);
    }

    #[tokio::test]
    async fn test_add_duplicate_vault_fails() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        let result = manager.add_vault(create_test_vault("vault1", false)).await;
        assert!(result.is_err());
        assert_eq!(manager.vault_count().await, 2);
    }

    #[tokio::test]
    async fn test_add_first_vault_becomes_active() {
        let manager = MultiVaultManager::empty(ServerConfig::new()).unwrap();
        manager
            .add_vault(create_test_vault("solo", false))
            .await
            .unwrap();
        assert_eq!(manager.get_active_vault().await, "solo");
    }

    #[tokio::test]
    async fn test_remove_vault() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        manager.remove_vault("vault2").await.unwrap();
        assert!(!manager.vault_exists("vault2").await);
    }

    #[tokio::test]
    async fn test_remove_missing_vault_fails() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        assert!(manager.remove_vault("nonexistent").await.is_err());
        assert_eq!(manager.vault_count().await, 2);
        assert_eq!(manager.get_active_vault().await, "vault1");
    }

    #[tokio::test]
    async fn test_remove_default_vault_reassigns() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        assert_eq!(manager.get_active_vault().await, "vault1");
        let result = manager.remove_vault("vault1").await;
        assert!(result.is_ok());
        assert!(!manager.vault_exists("vault1").await);
        assert_eq!(manager.get_active_vault().await, "vault2");
    }

    #[tokio::test]
    async fn test_get_vault_config_missing_fails() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        assert!(manager.get_vault_config("nonexistent").await.is_err());
    }

    #[tokio::test]
    async fn test_list_vaults() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        let vaults = manager.list_vaults().await.unwrap();
        assert_eq!(vaults.len(), 2);
        assert!(vaults.iter().any(|v| v.name == "vault1" && v.is_default));
        assert!(vaults.iter().any(|v| v.name == "vault2" && !v.is_default));
    }

    #[tokio::test]
    async fn test_vault_count() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        assert_eq!(manager.vault_count().await, 2);
        // Unwrap rather than `.ok()` so an add failure is reported directly
        // instead of surfacing as a confusing count mismatch.
        manager
            .add_vault(create_test_vault("vault3", false))
            .await
            .unwrap();
        assert_eq!(manager.vault_count().await, 3);
    }

    #[tokio::test]
    async fn test_get_effective_vault_settings() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        let settings = manager
            .get_effective_vault_settings("vault1")
            .await
            .unwrap();
        assert_eq!(settings.name, "vault1");
    }

    #[tokio::test]
    async fn test_active_vault_config_requires_active_vault() {
        let manager = MultiVaultManager::empty(ServerConfig::new()).unwrap();
        assert!(manager.get_active_vault_config().await.is_err());
        manager
            .add_vault(create_test_vault("solo", false))
            .await
            .unwrap();
        let active = manager.get_active_vault_config().await.unwrap();
        assert_eq!(active.name, "solo");
    }

    #[tokio::test]
    async fn test_clone() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        let manager2 = manager.clone();
        assert_eq!(manager.vault_count().await, manager2.vault_count().await);
    }

    #[tokio::test]
    async fn test_remove_active_vault_reassigns() {
        let manager = MultiVaultManager::new(create_test_config()).unwrap();
        manager.set_active_vault("vault2").await.unwrap();
        assert_eq!(manager.get_active_vault().await, "vault2");
        manager.remove_vault("vault1").await.unwrap();
        assert!(!manager.vault_exists("vault1").await);
        assert_eq!(
            manager.get_active_vault().await,
            "vault2",
            "active vault should remain vault2 after vault1 is removed"
        );
        manager.remove_vault("vault2").await.unwrap();
        assert!(!manager.vault_exists("vault2").await);
        let active_after = manager.get_active_vault().await;
        assert!(
            active_after.is_empty(),
            "active vault should be empty when last vault is removed, got: {:?}",
            active_after
        );
    }

    #[tokio::test]
    async fn test_add_vault_at_max_limit() {
        let manager = MultiVaultManager::new(ServerConfig::new()).unwrap();
        for i in 0..MultiVaultManager::MAX_VAULTS {
            let vault = create_test_vault(&format!("vault{}", i), i == 0);
            manager.add_vault(vault).await.unwrap();
        }
        assert_eq!(manager.vault_count().await, MultiVaultManager::MAX_VAULTS);
        let result = manager.add_vault(create_test_vault("overflow", false)).await;
        assert!(
            result.is_err(),
            "adding the {}th vault should fail",
            MultiVaultManager::MAX_VAULTS + 1
        );
    }
}