amaters_core/compute/
keys.rs

1//! Key management for FHE operations
2//!
3//! This module handles client and server key generation, serialization,
4//! and management for TFHE operations.
5
6use crate::error::{AmateRSError, ErrorContext, Result};
7
8#[cfg(feature = "compute")]
9use tfhe::prelude::*;
10#[cfg(feature = "compute")]
11use tfhe::{ConfigBuilder, FheBool, FheUint8, FheUint16, FheUint32, generate_keys, set_server_key};
12
/// Key pair for FHE operations
///
/// Contains both client key (for encryption/decryption) and server key (for FHE operations).
/// The keys are generated together and must be used as a pair.
///
/// Both fields are private; read access goes through the `client_key()` and
/// `server_key()` accessors. This definition is only compiled when the
/// `compute` feature is enabled.
#[cfg(feature = "compute")]
#[derive(Clone)]
pub struct FheKeyPair {
    // Key handed to `encrypt` / `decrypt` calls.
    client_key: tfhe::ClientKey,
    // Key installed through `set_server_key` before running FHE operations.
    server_key: tfhe::ServerKey,
}
23
#[cfg(feature = "compute")]
impl FheKeyPair {
    /// Create a fresh key pair using TFHE's default configuration.
    ///
    /// Equivalent to calling [`Self::generate_with_config`] with
    /// `ConfigBuilder::default().build()`.
    pub fn generate() -> Result<Self> {
        Self::generate_with_config(ConfigBuilder::default().build())
    }

    /// Create a fresh key pair from a caller-supplied TFHE configuration.
    ///
    /// The client and server keys are produced together by a single
    /// `generate_keys` call.
    pub fn generate_with_config(config: tfhe::Config) -> Result<Self> {
        let (client_key, server_key) = generate_keys(config);
        let pair = Self {
            client_key,
            server_key,
        };

        Ok(pair)
    }

    /// Borrow the client key of this pair.
    pub fn client_key(&self) -> &tfhe::ClientKey {
        let Self { client_key, .. } = self;
        client_key
    }

    /// Borrow the server key of this pair.
    pub fn server_key(&self) -> &tfhe::ServerKey {
        let Self { server_key, .. } = self;
        server_key
    }

    /// Install a clone of this pair's server key via `tfhe::set_server_key`.
    ///
    /// TFHE operations require a server key to be set before they run.
    ///
    /// NOTE(review): tfhe-rs documents `set_server_key` as applying to the
    /// calling thread, so other threads presumably need their own call —
    /// confirm against the tfhe version in use.
    pub fn set_as_global_server_key(&self) {
        let server_key = self.server_key.clone();
        set_server_key(server_key);
    }

    /// Encode the client key with `bincode`.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Serialization` when `bincode` fails.
    pub fn serialize_client_key(&self) -> Result<Vec<u8>> {
        match bincode::serialize(&self.client_key) {
            Ok(bytes) => Ok(bytes),
            Err(e) => Err(AmateRSError::Serialization(ErrorContext::new(format!(
                "Failed to serialize client key: {}",
                e
            )))),
        }
    }

    /// Encode the server key with `bincode`.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Serialization` when `bincode` fails.
    pub fn serialize_server_key(&self) -> Result<Vec<u8>> {
        match bincode::serialize(&self.server_key) {
            Ok(bytes) => Ok(bytes),
            Err(e) => Err(AmateRSError::Serialization(ErrorContext::new(format!(
                "Failed to serialize server key: {}",
                e
            )))),
        }
    }

    /// Decode a client key previously produced by [`Self::serialize_client_key`].
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Deserialization` when `bincode` cannot decode `bytes`.
    pub fn deserialize_client_key(bytes: &[u8]) -> Result<tfhe::ClientKey> {
        match bincode::deserialize::<tfhe::ClientKey>(bytes) {
            Ok(key) => Ok(key),
            Err(e) => Err(AmateRSError::Deserialization(ErrorContext::new(format!(
                "Failed to deserialize client key: {}",
                e
            )))),
        }
    }

    /// Decode a server key previously produced by [`Self::serialize_server_key`].
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Deserialization` when `bincode` cannot decode `bytes`.
    pub fn deserialize_server_key(bytes: &[u8]) -> Result<tfhe::ServerKey> {
        match bincode::deserialize::<tfhe::ServerKey>(bytes) {
            Ok(key) => Ok(key),
            Err(e) => Err(AmateRSError::Deserialization(ErrorContext::new(format!(
                "Failed to deserialize server key: {}",
                e
            )))),
        }
    }

    /// Rebuild a key pair from separately serialized client and server keys.
    ///
    /// The client key is decoded first; the first decoding failure is returned
    /// and the remaining input is not inspected. No check is made that the two
    /// keys actually belong together.
    pub fn from_serialized(client_key_bytes: &[u8], server_key_bytes: &[u8]) -> Result<Self> {
        // Struct fields are evaluated in the order written: client, then server.
        Ok(Self {
            client_key: Self::deserialize_client_key(client_key_bytes)?,
            server_key: Self::deserialize_server_key(server_key_bytes)?,
        })
    }
}
117
/// Stub implementation when compute feature is disabled
///
/// Carries no key material: the only field is a zero-sized marker. Its
/// `generate` constructor always returns an error, so this type exists to keep
/// the `FheKeyPair` name available when the `compute` feature is off.
#[cfg(not(feature = "compute"))]
#[derive(Clone, Debug)]
pub struct FheKeyPair {
    // Zero-sized placeholder so the struct has a private field.
    _phantom: std::marker::PhantomData<()>,
}
124
125#[cfg(not(feature = "compute"))]
126impl FheKeyPair {
127    pub fn generate() -> Result<Self> {
128        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
129            "FHE compute feature is not enabled".to_string(),
130        )))
131    }
132
133    pub fn serialize_client_key(&self) -> Result<Vec<u8>> {
134        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
135            "FHE compute feature is not enabled".to_string(),
136        )))
137    }
138
139    pub fn serialize_server_key(&self) -> Result<Vec<u8>> {
140        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
141            "FHE compute feature is not enabled".to_string(),
142        )))
143    }
144}
145
/// Key storage interface for managing FHE keys
///
/// This trait defines how keys are stored and retrieved.
/// Implementations can use file system, memory, or remote key management services.
///
/// Keys are passed and returned as opaque byte buffers addressed by a string
/// `key_id`; client and server keys are stored and fetched through separate
/// methods. All methods take `&self`, and implementors must be `Send + Sync`.
pub trait KeyStorage: Send + Sync {
    /// Store the serialized client key `key` under `key_id`.
    fn store_client_key(&self, key_id: &str, key: &[u8]) -> Result<()>;

    /// Store the serialized server key `key` under `key_id`.
    fn store_server_key(&self, key_id: &str, key: &[u8]) -> Result<()>;

    /// Retrieve the serialized client key stored under `key_id`.
    fn retrieve_client_key(&self, key_id: &str) -> Result<Vec<u8>>;

    /// Retrieve the serialized server key stored under `key_id`.
    fn retrieve_server_key(&self, key_id: &str) -> Result<Vec<u8>>;

    /// Delete the keys stored under `key_id`.
    fn delete_keys(&self, key_id: &str) -> Result<()>;

    /// List all key IDs known to this storage.
    fn list_key_ids(&self) -> Result<Vec<String>>;
}
169
/// In-memory key storage for testing and development
///
/// Serialized keys are held in two `DashMap`s keyed by key ID, one for client
/// keys and one for server keys. Nothing is persisted.
#[derive(Default)]
pub struct InMemoryKeyStorage {
    // key ID -> serialized client key bytes
    client_keys: std::sync::Arc<dashmap::DashMap<String, Vec<u8>>>,
    // key ID -> serialized server key bytes
    server_keys: std::sync::Arc<dashmap::DashMap<String, Vec<u8>>>,
}
176
177impl InMemoryKeyStorage {
178    pub fn new() -> Self {
179        Self::default()
180    }
181}
182
183impl KeyStorage for InMemoryKeyStorage {
184    fn store_client_key(&self, key_id: &str, key: &[u8]) -> Result<()> {
185        self.client_keys.insert(key_id.to_string(), key.to_vec());
186        Ok(())
187    }
188
189    fn store_server_key(&self, key_id: &str, key: &[u8]) -> Result<()> {
190        self.server_keys.insert(key_id.to_string(), key.to_vec());
191        Ok(())
192    }
193
194    fn retrieve_client_key(&self, key_id: &str) -> Result<Vec<u8>> {
195        self.client_keys
196            .get(key_id)
197            .map(|entry| entry.value().clone())
198            .ok_or_else(|| {
199                AmateRSError::KeyNotFound(ErrorContext::new(format!(
200                    "Client key not found: {}",
201                    key_id
202                )))
203            })
204    }
205
206    fn retrieve_server_key(&self, key_id: &str) -> Result<Vec<u8>> {
207        self.server_keys
208            .get(key_id)
209            .map(|entry| entry.value().clone())
210            .ok_or_else(|| {
211                AmateRSError::KeyNotFound(ErrorContext::new(format!(
212                    "Server key not found: {}",
213                    key_id
214                )))
215            })
216    }
217
218    fn delete_keys(&self, key_id: &str) -> Result<()> {
219        self.client_keys.remove(key_id);
220        self.server_keys.remove(key_id);
221        Ok(())
222    }
223
224    fn list_key_ids(&self) -> Result<Vec<String>> {
225        Ok(self
226            .client_keys
227            .iter()
228            .map(|entry| entry.key().clone())
229            .collect())
230    }
231}
232
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    /// Key generation with the default configuration succeeds.
    #[test]
    fn test_key_generation() -> Result<()> {
        let _keypair = FheKeyPair::generate()?;
        // Successfully generated keys - test passes
        Ok(())
    }

    /// Both keys serialize to non-empty byte buffers.
    #[test]
    fn test_key_serialization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;

        let client_bytes = keypair.serialize_client_key()?;
        let server_bytes = keypair.serialize_server_key()?;

        assert!(!client_bytes.is_empty());
        assert!(!server_bytes.is_empty());

        Ok(())
    }

    /// A key pair survives a serialize/deserialize round trip.
    #[test]
    fn test_key_deserialization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;

        let client_bytes = keypair.serialize_client_key()?;
        let server_bytes = keypair.serialize_server_key()?;

        let restored = FheKeyPair::from_serialized(&client_bytes, &server_bytes)?;

        // Encrypt with the ORIGINAL client key and decrypt with the RESTORED
        // one: this only succeeds if the restored key is the same key, not
        // merely a self-consistent one.
        let value = true;
        let encrypted = FheBool::encrypt(value, keypair.client_key());
        let decrypted: bool = encrypted.decrypt(restored.client_key());
        assert_eq!(value, decrypted);

        // The restored key must also work on its own.
        let encrypted = FheBool::encrypt(value, restored.client_key());
        let decrypted: bool = encrypted.decrypt(restored.client_key());
        assert_eq!(value, decrypted);

        Ok(())
    }

    /// Store, retrieve, list and delete keys through `InMemoryKeyStorage`.
    #[test]
    fn test_key_storage() -> Result<()> {
        let storage = InMemoryKeyStorage::new();
        let keypair = FheKeyPair::generate()?;

        let client_bytes = keypair.serialize_client_key()?;
        let server_bytes = keypair.serialize_server_key()?;

        storage.store_client_key("test_key", &client_bytes)?;
        storage.store_server_key("test_key", &server_bytes)?;

        let retrieved_client = storage.retrieve_client_key("test_key")?;
        let retrieved_server = storage.retrieve_server_key("test_key")?;

        assert_eq!(client_bytes, retrieved_client);
        assert_eq!(server_bytes, retrieved_server);

        let key_ids = storage.list_key_ids()?;
        assert_eq!(key_ids.len(), 1);
        assert_eq!(key_ids[0], "test_key");

        // Deleting must remove both halves and leave nothing listed.
        storage.delete_keys("test_key")?;
        assert!(storage.retrieve_client_key("test_key").is_err());
        assert!(storage.retrieve_server_key("test_key").is_err());
        assert!(storage.list_key_ids()?.is_empty());

        Ok(())
    }
}