amaters_core/compute/
key_manager.rs

1//! FHE Server Key Management
2//!
3//! This module provides centralized management of TFHE server keys for multiple clients.
4//! Each client can register their server key, and the system can execute FHE operations
5//! using the appropriate key.
6
7use crate::error::{AmateRSError, ErrorContext, Result};
8use dashmap::DashMap;
9use parking_lot::RwLock;
10use std::sync::Arc;
11
12#[cfg(feature = "compute")]
13use tfhe::ServerKey;
14
/// Identifier under which a client's server key is stored and looked up.
///
/// This is a plain alias for [`String`]; it introduces no distinct type, so
/// any `String` can be used wherever a `ClientId` is expected.
pub type ClientId = String;
17
18/// Manages FHE server keys for multiple clients
19///
20/// This structure provides thread-safe storage and retrieval of TFHE server keys.
21/// It supports both multi-client scenarios (where each client has their own key)
22/// and single-client scenarios (where a global key is used).
23///
24/// # Example
25///
26/// ```rust,ignore
27/// use amaters_core::compute::KeyManager;
28/// use tfhe::ServerKey;
29///
30/// let manager = KeyManager::new();
31///
32/// // Register a key for a client
33/// let server_key = ServerKey::new(&client_key);
34/// manager.register_key("client_1".to_string(), server_key);
35///
36/// // Retrieve the key
37/// let key = manager.get_key("client_1").expect("Client key should exist");
38///
39/// // Set as global for convenience
40/// manager.set_global("client_1")?;
41/// let global_key = manager.get_global().expect("Global key should be set");
42/// ```
43#[derive(Default)]
44pub struct KeyManager {
45    /// Map of client IDs to their server keys
46    server_keys: DashMap<ClientId, Arc<ServerKey>>,
47
48    /// Optional global server key (for single-client scenarios)
49    global_key: RwLock<Option<Arc<ServerKey>>>,
50}
51
52impl KeyManager {
53    /// Create a new empty key manager
54    ///
55    /// # Example
56    ///
57    /// ```rust
58    /// use amaters_core::compute::KeyManager;
59    ///
60    /// let manager = KeyManager::new();
61    /// assert_eq!(manager.key_count(), 0);
62    /// ```
63    pub fn new() -> Self {
64        Self {
65            server_keys: DashMap::new(),
66            global_key: RwLock::new(None),
67        }
68    }
69
70    /// Register a server key for a specific client
71    ///
72    /// If a key already exists for this client, it will be replaced.
73    ///
74    /// # Arguments
75    ///
76    /// * `client_id` - Unique identifier for the client
77    /// * `key` - The TFHE server key to register
78    ///
79    /// # Example
80    ///
81    /// ```rust,ignore
82    /// use amaters_core::compute::KeyManager;
83    ///
84    /// let manager = KeyManager::new();
85    /// let server_key = ServerKey::new(&client_key);
86    /// manager.register_key("client_1".to_string(), server_key);
87    /// ```
88    #[cfg(feature = "compute")]
89    pub fn register_key(&self, client_id: ClientId, key: ServerKey) {
90        self.server_keys.insert(client_id, Arc::new(key));
91    }
92
93    /// Stub for register_key when compute feature is disabled
94    #[cfg(not(feature = "compute"))]
95    pub fn register_key(&self, _client_id: ClientId, _key: ()) {
96        // No-op when compute is disabled
97    }
98
99    /// Get server key for a specific client
100    ///
101    /// Returns `None` if no key is registered for the given client.
102    ///
103    /// # Arguments
104    ///
105    /// * `client_id` - The client identifier
106    ///
107    /// # Returns
108    ///
109    /// An `Arc<ServerKey>` if found, `None` otherwise
110    ///
111    /// # Example
112    ///
113    /// ```rust,ignore
114    /// use amaters_core::compute::KeyManager;
115    ///
116    /// let manager = KeyManager::new();
117    /// manager.register_key("client_1".to_string(), server_key);
118    ///
119    /// let key = manager.get_key("client_1").expect("Key should exist");
120    /// ```
121    #[cfg(feature = "compute")]
122    pub fn get_key(&self, client_id: &str) -> Option<Arc<ServerKey>> {
123        self.server_keys
124            .get(client_id)
125            .map(|entry| entry.value().clone())
126    }
127
128    /// Stub for get_key when compute feature is disabled
129    #[cfg(not(feature = "compute"))]
130    pub fn get_key(&self, _client_id: &str) -> Option<()> {
131        None
132    }
133
134    /// Set a global server key from a registered client key
135    ///
136    /// This is a convenience method for single-client scenarios where you want
137    /// to use one client's key as the default for all operations.
138    ///
139    /// # Arguments
140    ///
141    /// * `client_id` - The client whose key should become the global key
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if no key is registered for the specified client.
146    ///
147    /// # Example
148    ///
149    /// ```rust,ignore
150    /// use amaters_core::compute::KeyManager;
151    ///
152    /// let manager = KeyManager::new();
153    /// manager.register_key("default".to_string(), server_key);
154    /// manager.set_global("default")?;
155    ///
156    /// let global = manager.get_global().expect("Global key should be set");
157    /// ```
158    pub fn set_global(&self, client_id: &str) -> Result<()> {
159        #[cfg(feature = "compute")]
160        {
161            let key = self.get_key(client_id).ok_or_else(|| {
162                AmateRSError::FheComputation(ErrorContext::new(format!(
163                    "No server key found for client: {}",
164                    client_id
165                )))
166            })?;
167
168            let mut global = self.global_key.write();
169            *global = Some(key);
170            Ok(())
171        }
172
173        #[cfg(not(feature = "compute"))]
174        {
175            let _ = client_id;
176            Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
177                "FHE compute feature is not enabled".to_string(),
178            )))
179        }
180    }
181
182    /// Get the global server key
183    ///
184    /// Returns `None` if no global key has been set.
185    ///
186    /// # Returns
187    ///
188    /// An `Arc<ServerKey>` if a global key is set, `None` otherwise
189    ///
190    /// # Example
191    ///
192    /// ```rust,ignore
193    /// use amaters_core::compute::KeyManager;
194    ///
195    /// let manager = KeyManager::new();
196    /// assert!(manager.get_global().is_none());
197    ///
198    /// manager.register_key("default".to_string(), server_key);
199    /// manager.set_global("default")?;
200    /// assert!(manager.get_global().is_some());
201    /// ```
202    #[cfg(feature = "compute")]
203    pub fn get_global(&self) -> Option<Arc<ServerKey>> {
204        self.global_key.read().clone()
205    }
206
207    /// Stub for get_global when compute feature is disabled
208    #[cfg(not(feature = "compute"))]
209    pub fn get_global(&self) -> Option<()> {
210        None
211    }
212
213    /// Remove a client's key
214    ///
215    /// Returns `true` if the key was found and removed, `false` otherwise.
216    ///
217    /// # Arguments
218    ///
219    /// * `client_id` - The client identifier
220    ///
221    /// # Returns
222    ///
223    /// `true` if a key was removed, `false` if no key was found
224    ///
225    /// # Example
226    ///
227    /// ```rust,ignore
228    /// use amaters_core::compute::KeyManager;
229    ///
230    /// let manager = KeyManager::new();
231    /// manager.register_key("client_1".to_string(), server_key);
232    ///
233    /// assert!(manager.remove_key("client_1"));
234    /// assert!(!manager.remove_key("client_1")); // Already removed
235    /// ```
236    pub fn remove_key(&self, client_id: &str) -> bool {
237        self.server_keys.remove(client_id).is_some()
238    }
239
240    /// Get the number of registered keys
241    ///
242    /// # Returns
243    ///
244    /// The count of currently registered client keys
245    ///
246    /// # Example
247    ///
248    /// ```rust,ignore
249    /// use amaters_core::compute::KeyManager;
250    ///
251    /// let manager = KeyManager::new();
252    /// assert_eq!(manager.key_count(), 0);
253    ///
254    /// manager.register_key("client_1".to_string(), server_key);
255    /// assert_eq!(manager.key_count(), 1);
256    /// ```
257    pub fn key_count(&self) -> usize {
258        self.server_keys.len()
259    }
260
261    /// Clear all registered keys (including global)
262    ///
263    /// This removes all client keys and clears the global key.
264    ///
265    /// # Example
266    ///
267    /// ```rust,ignore
268    /// use amaters_core::compute::KeyManager;
269    ///
270    /// let manager = KeyManager::new();
271    /// manager.register_key("client_1".to_string(), server_key);
272    /// manager.set_global("client_1")?;
273    ///
274    /// manager.clear();
275    /// assert_eq!(manager.key_count(), 0);
276    /// assert!(manager.get_global().is_none());
277    /// ```
278    pub fn clear(&self) {
279        self.server_keys.clear();
280        let mut global = self.global_key.write();
281        *global = None;
282    }
283}
284
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;
    use crate::compute::FheKeyPair;

    /// Generate a fresh key pair and register its server key under `client_id`.
    fn register_fresh_key(manager: &KeyManager, client_id: &str) -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        manager.register_key(client_id.to_string(), keypair.server_key().clone());
        Ok(())
    }

    /// A new manager holds neither client keys nor a global key.
    #[test]
    fn test_key_manager_new() {
        let manager = KeyManager::new();

        assert_eq!(manager.key_count(), 0);
        assert!(manager.get_global().is_none());
    }

    /// A registered key can be fetched back and is counted.
    #[test]
    fn test_register_and_get_key() -> Result<()> {
        let manager = KeyManager::new();
        register_fresh_key(&manager, "client_1")?;

        assert!(manager.get_key("client_1").is_some());
        assert_eq!(manager.key_count(), 1);

        Ok(())
    }

    /// Looking up an unknown client yields `None`.
    #[test]
    fn test_get_nonexistent_key() {
        let manager = KeyManager::new();

        assert!(manager.get_key("nonexistent").is_none());
    }

    /// Promoting a registered key makes the global key available.
    #[test]
    fn test_set_and_get_global() -> Result<()> {
        let manager = KeyManager::new();
        register_fresh_key(&manager, "default")?;

        manager.set_global("default")?;
        assert!(manager.get_global().is_some());

        Ok(())
    }

    /// Promoting an unknown client is reported as an error.
    #[test]
    fn test_set_global_nonexistent_client() {
        let manager = KeyManager::new();

        assert!(manager.set_global("nonexistent").is_err());
    }

    /// Removal succeeds once and reports `false` afterwards.
    #[test]
    fn test_remove_key() -> Result<()> {
        let manager = KeyManager::new();
        register_fresh_key(&manager, "client_1")?;
        assert_eq!(manager.key_count(), 1);

        assert!(manager.remove_key("client_1"));
        assert_eq!(manager.key_count(), 0);

        // A second removal finds nothing left to remove.
        assert!(!manager.remove_key("client_1"));

        Ok(())
    }

    /// The count follows the number of distinct registered clients.
    #[test]
    fn test_key_count() -> Result<()> {
        let manager = KeyManager::new();
        assert_eq!(manager.key_count(), 0);

        register_fresh_key(&manager, "client_1")?;
        assert_eq!(manager.key_count(), 1);

        register_fresh_key(&manager, "client_2")?;
        assert_eq!(manager.key_count(), 2);

        Ok(())
    }

    /// Registering twice under one identifier keeps a single entry.
    #[test]
    fn test_replace_existing_key() -> Result<()> {
        let manager = KeyManager::new();

        register_fresh_key(&manager, "client_1")?;
        register_fresh_key(&manager, "client_1")?;

        // The second registration replaced the first one.
        assert_eq!(manager.key_count(), 1);

        Ok(())
    }

    /// `clear` drops all client keys and the global key.
    #[test]
    fn test_clear() -> Result<()> {
        let manager = KeyManager::new();
        for client_id in ["client_1", "client_2"] {
            register_fresh_key(&manager, client_id)?;
        }
        manager.set_global("client_1")?;

        assert_eq!(manager.key_count(), 2);
        assert!(manager.get_global().is_some());

        manager.clear();

        assert_eq!(manager.key_count(), 0);
        assert!(manager.get_global().is_none());

        Ok(())
    }

    /// Ten threads each register and read back their own key.
    #[test]
    fn test_concurrent_access() -> Result<()> {
        use std::thread;

        let manager = Arc::new(KeyManager::new());

        // Collecting eagerly starts every worker before any of them is joined.
        let workers: Vec<_> = (0..10)
            .map(|i| {
                let shared = Arc::clone(&manager);
                thread::spawn(move || -> Result<()> {
                    let client_id = format!("client_{}", i);
                    register_fresh_key(&shared, &client_id)?;

                    // The key must be visible to the thread that stored it.
                    assert!(shared.get_key(&client_id).is_some());

                    Ok(())
                })
            })
            .collect();

        for worker in workers {
            worker
                .join()
                .expect("Thread panicked")
                .expect("Thread failed");
        }

        assert_eq!(manager.key_count(), 10);

        Ok(())
    }
}