use crate::error::{AmateRSError, ErrorContext, Result};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::Arc;
#[cfg(feature = "compute")]
use tfhe::ServerKey;
/// Identifier under which a client's server key is registered.
pub type ClientId = String;

/// Placeholder for `tfhe::ServerKey` when the `compute` feature is disabled.
///
/// `tfhe::ServerKey` is only imported under `feature = "compute"`, but the fields of
/// [`KeyManager`] name `ServerKey` unconditionally; without this alias the struct
/// fails to compile (unresolved type) when the feature is turned off.
#[cfg(not(feature = "compute"))]
type ServerKey = ();

/// Concurrent registry of FHE server keys keyed by [`ClientId`], plus an optional
/// "global" key selected from among the registered ones.
#[derive(Default)]
pub struct KeyManager {
    /// Server keys registered per client. Values are wrapped in `Arc` so lookups
    /// can hand out cheap shared clones instead of copying the key.
    server_keys: DashMap<ClientId, Arc<ServerKey>>,
    /// The key most recently selected via `set_global`, if any.
    global_key: RwLock<Option<Arc<ServerKey>>>,
}
impl KeyManager {
    /// Creates an empty manager: no client keys registered and no global key selected.
    pub fn new() -> Self {
        let server_keys = DashMap::new();
        let global_key = RwLock::new(None);
        Self {
            server_keys,
            global_key,
        }
    }

    /// Registers `key` for `client_id`, replacing any key previously stored under
    /// that id. The replaced key (if any) is dropped.
    #[cfg(feature = "compute")]
    pub fn register_key(&self, client_id: ClientId, key: ServerKey) {
        let shared = Arc::new(key);
        self.server_keys.insert(client_id, shared);
    }

    /// No-op stand-in used when the `compute` feature is disabled; nothing is stored.
    #[cfg(not(feature = "compute"))]
    pub fn register_key(&self, _client_id: ClientId, _key: ()) {}

    /// Returns a shared handle to the key registered for `client_id`, or `None`
    /// when no key is registered under that id.
    #[cfg(feature = "compute")]
    pub fn get_key(&self, client_id: &str) -> Option<Arc<ServerKey>> {
        let entry = self.server_keys.get(client_id)?;
        let key = Arc::clone(entry.value());
        Some(key)
    }

    /// Always `None` when the `compute` feature is disabled.
    #[cfg(not(feature = "compute"))]
    pub fn get_key(&self, _client_id: &str) -> Option<()> {
        None
    }

    /// Selects the key registered for `client_id` as the global key.
    ///
    /// The global slot stores its own `Arc` handle, so later calls to
    /// `register_key` or `remove_key` for the same client do not change it.
    ///
    /// # Errors
    ///
    /// * `AmateRSError::FheComputation` if no key is registered for `client_id`.
    /// * `AmateRSError::FeatureNotEnabled` when built without the `compute` feature.
    pub fn set_global(&self, client_id: &str) -> Result<()> {
        #[cfg(feature = "compute")]
        {
            match self.get_key(client_id) {
                Some(key) => {
                    *self.global_key.write() = Some(key);
                    Ok(())
                }
                None => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                    "No server key found for client: {}",
                    client_id
                )))),
            }
        }
        #[cfg(not(feature = "compute"))]
        {
            // The id is only needed by the compute build; silence the unused warning.
            let _ = client_id;
            Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
                "FHE compute feature is not enabled".to_string(),
            )))
        }
    }

    /// Returns a shared handle to the currently selected global key, if one is set.
    #[cfg(feature = "compute")]
    pub fn get_global(&self) -> Option<Arc<ServerKey>> {
        let slot = self.global_key.read();
        match &*slot {
            Some(key) => Some(Arc::clone(key)),
            None => None,
        }
    }

    /// Always `None` when the `compute` feature is disabled.
    #[cfg(not(feature = "compute"))]
    pub fn get_global(&self) -> Option<()> {
        None
    }

    /// Removes the key registered for `client_id`.
    ///
    /// Returns `true` if a key was present. The global key is not touched, even if
    /// it was selected from this client.
    pub fn remove_key(&self, client_id: &str) -> bool {
        let removed = self.server_keys.remove(client_id);
        removed.is_some()
    }

    /// Number of clients that currently have a registered key. The global key
    /// slot is not counted.
    pub fn key_count(&self) -> usize {
        self.server_keys.len()
    }

    /// Drops every registered client key, then clears the global key selection.
    pub fn clear(&self) {
        self.server_keys.clear();
        *self.global_key.write() = None;
    }
}
#[cfg(all(test, feature = "compute"))]
mod tests {
    //! Unit tests for `KeyManager`.
    //!
    //! FHE key generation is expensive, so tests that only need *some* server key
    //! generate a single keypair and clone its server key, rather than generating
    //! one keypair per registration.

    use super::*;
    use crate::compute::FheKeyPair;

    #[test]
    fn test_key_manager_new() {
        let manager = KeyManager::new();
        assert_eq!(manager.key_count(), 0);
        assert!(manager.get_global().is_none());
    }

    #[test]
    fn test_register_and_get_key() -> Result<()> {
        let manager = KeyManager::new();
        let keypair = FheKeyPair::generate()?;
        manager.register_key("client_1".to_string(), keypair.server_key().clone());
        let retrieved = manager.get_key("client_1");
        assert!(retrieved.is_some());
        assert_eq!(manager.key_count(), 1);
        Ok(())
    }

    #[test]
    fn test_get_nonexistent_key() {
        let manager = KeyManager::new();
        let result = manager.get_key("nonexistent");
        assert!(result.is_none());
    }

    #[test]
    fn test_set_and_get_global() -> Result<()> {
        let manager = KeyManager::new();
        let keypair = FheKeyPair::generate()?;
        manager.register_key("default".to_string(), keypair.server_key().clone());
        manager.set_global("default")?;
        let global = manager.get_global().expect("global key was just set");
        let registered = manager.get_key("default").expect("key was just registered");
        // The global slot must share the registered key, not hold a separate copy.
        assert!(Arc::ptr_eq(&global, &registered));
        Ok(())
    }

    #[test]
    fn test_set_global_nonexistent_client() {
        let manager = KeyManager::new();
        let result = manager.set_global("nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_remove_key() -> Result<()> {
        let manager = KeyManager::new();
        let keypair = FheKeyPair::generate()?;
        manager.register_key("client_1".to_string(), keypair.server_key().clone());
        assert_eq!(manager.key_count(), 1);
        let removed = manager.remove_key("client_1");
        assert!(removed);
        assert_eq!(manager.key_count(), 0);
        let removed_again = manager.remove_key("client_1");
        assert!(!removed_again);
        Ok(())
    }

    #[test]
    fn test_key_count() -> Result<()> {
        let manager = KeyManager::new();
        assert_eq!(manager.key_count(), 0);
        let keypair = FheKeyPair::generate()?;
        manager.register_key("client_1".to_string(), keypair.server_key().clone());
        assert_eq!(manager.key_count(), 1);
        manager.register_key("client_2".to_string(), keypair.server_key().clone());
        assert_eq!(manager.key_count(), 2);
        Ok(())
    }

    #[test]
    fn test_replace_existing_key() -> Result<()> {
        let manager = KeyManager::new();
        let keypair = FheKeyPair::generate()?;
        manager.register_key("client_1".to_string(), keypair.server_key().clone());
        let first = manager.get_key("client_1").expect("key was just registered");
        manager.register_key("client_1".to_string(), keypair.server_key().clone());
        let second = manager
            .get_key("client_1")
            .expect("key was just re-registered");
        assert_eq!(manager.key_count(), 1);
        // `first` keeps the old allocation alive, so a genuine replacement must
        // yield a different `Arc`; checking only the count would miss a no-op insert.
        assert!(
            !Arc::ptr_eq(&first, &second),
            "re-registering must replace the stored key"
        );
        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let manager = KeyManager::new();
        let keypair = FheKeyPair::generate()?;
        manager.register_key("client_1".to_string(), keypair.server_key().clone());
        manager.register_key("client_2".to_string(), keypair.server_key().clone());
        manager.set_global("client_1")?;
        assert_eq!(manager.key_count(), 2);
        assert!(manager.get_global().is_some());
        manager.clear();
        assert_eq!(manager.key_count(), 0);
        assert!(manager.get_global().is_none());
        Ok(())
    }

    #[test]
    fn test_concurrent_access() -> Result<()> {
        use std::thread;

        // Generate once up front; the test targets KeyManager's concurrency,
        // not concurrent key generation.
        let keypair = FheKeyPair::generate()?;
        let manager = Arc::new(KeyManager::new());

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let manager = Arc::clone(&manager);
                let key = keypair.server_key().clone();
                thread::spawn(move || {
                    let client_id = format!("client_{}", i);
                    manager.register_key(client_id.clone(), key);
                    assert!(manager.get_key(&client_id).is_some());
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(manager.key_count(), 10);
        Ok(())
    }
}