amaters_core/compute/key_manager.rs
1//! FHE Server Key Management
2//!
3//! This module provides centralized management of TFHE server keys for multiple clients.
4//! Each client can register their server key, and the system can execute FHE operations
5//! using the appropriate key.
6
7use crate::error::{AmateRSError, ErrorContext, Result};
8use dashmap::DashMap;
9use parking_lot::RwLock;
10use std::sync::Arc;
11
12#[cfg(feature = "compute")]
13use tfhe::ServerKey;
14
/// Identifier under which a client's server key is registered and looked up.
///
/// This is a plain [`String`] alias, so any owned string can serve as an id.
pub type ClientId = String;
17
18/// Manages FHE server keys for multiple clients
19///
20/// This structure provides thread-safe storage and retrieval of TFHE server keys.
21/// It supports both multi-client scenarios (where each client has their own key)
22/// and single-client scenarios (where a global key is used).
23///
24/// # Example
25///
26/// ```rust,ignore
27/// use amaters_core::compute::KeyManager;
28/// use tfhe::ServerKey;
29///
30/// let manager = KeyManager::new();
31///
32/// // Register a key for a client
33/// let server_key = ServerKey::new(&client_key);
34/// manager.register_key("client_1".to_string(), server_key);
35///
36/// // Retrieve the key
37/// let key = manager.get_key("client_1").expect("Client key should exist");
38///
39/// // Set as global for convenience
40/// manager.set_global("client_1")?;
41/// let global_key = manager.get_global().expect("Global key should be set");
42/// ```
43#[derive(Default)]
44pub struct KeyManager {
45 /// Map of client IDs to their server keys
46 server_keys: DashMap<ClientId, Arc<ServerKey>>,
47
48 /// Optional global server key (for single-client scenarios)
49 global_key: RwLock<Option<Arc<ServerKey>>>,
50}
51
52impl KeyManager {
53 /// Create a new empty key manager
54 ///
55 /// # Example
56 ///
57 /// ```rust
58 /// use amaters_core::compute::KeyManager;
59 ///
60 /// let manager = KeyManager::new();
61 /// assert_eq!(manager.key_count(), 0);
62 /// ```
63 pub fn new() -> Self {
64 Self {
65 server_keys: DashMap::new(),
66 global_key: RwLock::new(None),
67 }
68 }
69
70 /// Register a server key for a specific client
71 ///
72 /// If a key already exists for this client, it will be replaced.
73 ///
74 /// # Arguments
75 ///
76 /// * `client_id` - Unique identifier for the client
77 /// * `key` - The TFHE server key to register
78 ///
79 /// # Example
80 ///
81 /// ```rust,ignore
82 /// use amaters_core::compute::KeyManager;
83 ///
84 /// let manager = KeyManager::new();
85 /// let server_key = ServerKey::new(&client_key);
86 /// manager.register_key("client_1".to_string(), server_key);
87 /// ```
88 #[cfg(feature = "compute")]
89 pub fn register_key(&self, client_id: ClientId, key: ServerKey) {
90 self.server_keys.insert(client_id, Arc::new(key));
91 }
92
93 /// Stub for register_key when compute feature is disabled
94 #[cfg(not(feature = "compute"))]
95 pub fn register_key(&self, _client_id: ClientId, _key: ()) {
96 // No-op when compute is disabled
97 }
98
99 /// Get server key for a specific client
100 ///
101 /// Returns `None` if no key is registered for the given client.
102 ///
103 /// # Arguments
104 ///
105 /// * `client_id` - The client identifier
106 ///
107 /// # Returns
108 ///
109 /// An `Arc<ServerKey>` if found, `None` otherwise
110 ///
111 /// # Example
112 ///
113 /// ```rust,ignore
114 /// use amaters_core::compute::KeyManager;
115 ///
116 /// let manager = KeyManager::new();
117 /// manager.register_key("client_1".to_string(), server_key);
118 ///
119 /// let key = manager.get_key("client_1").expect("Key should exist");
120 /// ```
121 #[cfg(feature = "compute")]
122 pub fn get_key(&self, client_id: &str) -> Option<Arc<ServerKey>> {
123 self.server_keys
124 .get(client_id)
125 .map(|entry| entry.value().clone())
126 }
127
128 /// Stub for get_key when compute feature is disabled
129 #[cfg(not(feature = "compute"))]
130 pub fn get_key(&self, _client_id: &str) -> Option<()> {
131 None
132 }
133
134 /// Set a global server key from a registered client key
135 ///
136 /// This is a convenience method for single-client scenarios where you want
137 /// to use one client's key as the default for all operations.
138 ///
139 /// # Arguments
140 ///
141 /// * `client_id` - The client whose key should become the global key
142 ///
143 /// # Errors
144 ///
145 /// Returns an error if no key is registered for the specified client.
146 ///
147 /// # Example
148 ///
149 /// ```rust,ignore
150 /// use amaters_core::compute::KeyManager;
151 ///
152 /// let manager = KeyManager::new();
153 /// manager.register_key("default".to_string(), server_key);
154 /// manager.set_global("default")?;
155 ///
156 /// let global = manager.get_global().expect("Global key should be set");
157 /// ```
158 pub fn set_global(&self, client_id: &str) -> Result<()> {
159 #[cfg(feature = "compute")]
160 {
161 let key = self.get_key(client_id).ok_or_else(|| {
162 AmateRSError::FheComputation(ErrorContext::new(format!(
163 "No server key found for client: {}",
164 client_id
165 )))
166 })?;
167
168 let mut global = self.global_key.write();
169 *global = Some(key);
170 Ok(())
171 }
172
173 #[cfg(not(feature = "compute"))]
174 {
175 let _ = client_id;
176 Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
177 "FHE compute feature is not enabled".to_string(),
178 )))
179 }
180 }
181
182 /// Get the global server key
183 ///
184 /// Returns `None` if no global key has been set.
185 ///
186 /// # Returns
187 ///
188 /// An `Arc<ServerKey>` if a global key is set, `None` otherwise
189 ///
190 /// # Example
191 ///
192 /// ```rust,ignore
193 /// use amaters_core::compute::KeyManager;
194 ///
195 /// let manager = KeyManager::new();
196 /// assert!(manager.get_global().is_none());
197 ///
198 /// manager.register_key("default".to_string(), server_key);
199 /// manager.set_global("default")?;
200 /// assert!(manager.get_global().is_some());
201 /// ```
202 #[cfg(feature = "compute")]
203 pub fn get_global(&self) -> Option<Arc<ServerKey>> {
204 self.global_key.read().clone()
205 }
206
207 /// Stub for get_global when compute feature is disabled
208 #[cfg(not(feature = "compute"))]
209 pub fn get_global(&self) -> Option<()> {
210 None
211 }
212
213 /// Remove a client's key
214 ///
215 /// Returns `true` if the key was found and removed, `false` otherwise.
216 ///
217 /// # Arguments
218 ///
219 /// * `client_id` - The client identifier
220 ///
221 /// # Returns
222 ///
223 /// `true` if a key was removed, `false` if no key was found
224 ///
225 /// # Example
226 ///
227 /// ```rust,ignore
228 /// use amaters_core::compute::KeyManager;
229 ///
230 /// let manager = KeyManager::new();
231 /// manager.register_key("client_1".to_string(), server_key);
232 ///
233 /// assert!(manager.remove_key("client_1"));
234 /// assert!(!manager.remove_key("client_1")); // Already removed
235 /// ```
236 pub fn remove_key(&self, client_id: &str) -> bool {
237 self.server_keys.remove(client_id).is_some()
238 }
239
240 /// Get the number of registered keys
241 ///
242 /// # Returns
243 ///
244 /// The count of currently registered client keys
245 ///
246 /// # Example
247 ///
248 /// ```rust,ignore
249 /// use amaters_core::compute::KeyManager;
250 ///
251 /// let manager = KeyManager::new();
252 /// assert_eq!(manager.key_count(), 0);
253 ///
254 /// manager.register_key("client_1".to_string(), server_key);
255 /// assert_eq!(manager.key_count(), 1);
256 /// ```
257 pub fn key_count(&self) -> usize {
258 self.server_keys.len()
259 }
260
261 /// Clear all registered keys (including global)
262 ///
263 /// This removes all client keys and clears the global key.
264 ///
265 /// # Example
266 ///
267 /// ```rust,ignore
268 /// use amaters_core::compute::KeyManager;
269 ///
270 /// let manager = KeyManager::new();
271 /// manager.register_key("client_1".to_string(), server_key);
272 /// manager.set_global("client_1")?;
273 ///
274 /// manager.clear();
275 /// assert_eq!(manager.key_count(), 0);
276 /// assert!(manager.get_global().is_none());
277 /// ```
278 pub fn clear(&self) {
279 self.server_keys.clear();
280 let mut global = self.global_key.write();
281 *global = None;
282 }
283}
284
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;
    use crate::compute::FheKeyPair;

    /// Generate a fresh key pair and hand back an owned copy of its server key.
    fn generate_server_key() -> Result<ServerKey> {
        let keypair = FheKeyPair::generate()?;
        Ok(keypair.server_key().clone())
    }

    /// A new manager starts with no client keys and no global key.
    #[test]
    fn test_key_manager_new() {
        let fresh = KeyManager::new();
        assert_eq!(fresh.key_count(), 0);
        assert!(fresh.get_global().is_none());
    }

    /// A registered key can be fetched back and is counted.
    #[test]
    fn test_register_and_get_key() -> Result<()> {
        let manager = KeyManager::new();
        manager.register_key("client_1".to_string(), generate_server_key()?);

        assert!(manager.get_key("client_1").is_some());
        assert_eq!(manager.key_count(), 1);

        Ok(())
    }

    /// Looking up an id that was never registered yields `None`.
    #[test]
    fn test_get_nonexistent_key() {
        let manager = KeyManager::new();
        assert!(manager.get_key("nonexistent").is_none());
    }

    /// Promoting a registered key makes it available as the global key.
    #[test]
    fn test_set_and_get_global() -> Result<()> {
        let manager = KeyManager::new();
        manager.register_key("default".to_string(), generate_server_key()?);
        manager.set_global("default")?;

        assert!(manager.get_global().is_some());

        Ok(())
    }

    /// Promoting an unknown client is reported as an error.
    #[test]
    fn test_set_global_nonexistent_client() {
        let manager = KeyManager::new();
        assert!(manager.set_global("nonexistent").is_err());
    }

    /// Removal succeeds exactly once per registered key.
    #[test]
    fn test_remove_key() -> Result<()> {
        let manager = KeyManager::new();
        manager.register_key("client_1".to_string(), generate_server_key()?);
        assert_eq!(manager.key_count(), 1);

        assert!(manager.remove_key("client_1"));
        assert_eq!(manager.key_count(), 0);

        // A second removal finds nothing left to remove.
        assert!(!manager.remove_key("client_1"));

        Ok(())
    }

    /// The count follows each registration under a distinct id.
    #[test]
    fn test_key_count() -> Result<()> {
        let manager = KeyManager::new();
        assert_eq!(manager.key_count(), 0);

        let first = generate_server_key()?;
        let second = generate_server_key()?;

        manager.register_key("client_1".to_string(), first);
        assert_eq!(manager.key_count(), 1);

        manager.register_key("client_2".to_string(), second);
        assert_eq!(manager.key_count(), 2);

        Ok(())
    }

    /// Registering twice under one id replaces the key instead of adding one.
    #[test]
    fn test_replace_existing_key() -> Result<()> {
        let manager = KeyManager::new();
        let original = generate_server_key()?;
        let replacement = generate_server_key()?;

        manager.register_key("client_1".to_string(), original);
        manager.register_key("client_1".to_string(), replacement);

        assert_eq!(manager.key_count(), 1);

        Ok(())
    }

    /// `clear` empties the client map and unsets the global key.
    #[test]
    fn test_clear() -> Result<()> {
        let manager = KeyManager::new();
        let first = generate_server_key()?;
        let second = generate_server_key()?;

        manager.register_key("client_1".to_string(), first);
        manager.register_key("client_2".to_string(), second);
        manager.set_global("client_1")?;

        assert_eq!(manager.key_count(), 2);
        assert!(manager.get_global().is_some());

        manager.clear();

        assert_eq!(manager.key_count(), 0);
        assert!(manager.get_global().is_none());

        Ok(())
    }

    /// Ten threads each register and read back their own key on a shared manager.
    #[test]
    fn test_concurrent_access() -> Result<()> {
        use std::thread;

        let manager = Arc::new(KeyManager::new());

        // `collect` spawns every worker before any of them is joined.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&manager);
                thread::spawn(move || -> Result<()> {
                    let client_id = format!("client_{}", i);
                    shared.register_key(client_id.clone(), generate_server_key()?);

                    // The key must be visible to the thread that just stored it.
                    assert!(shared.get_key(&client_id).is_some());

                    Ok(())
                })
            })
            .collect();

        for worker in workers {
            worker
                .join()
                .expect("Thread panicked")
                .expect("Thread failed");
        }

        assert_eq!(manager.key_count(), 10);

        Ok(())
    }
}
447}