amaters_core/compute/
keys.rs1use crate::error::{AmateRSError, ErrorContext, Result};
7
8#[cfg(feature = "compute")]
9use tfhe::prelude::*;
10#[cfg(feature = "compute")]
11use tfhe::{ConfigBuilder, FheBool, FheUint8, FheUint16, FheUint32, generate_keys, set_server_key};
12
/// A TFHE client/server key pair, available when the `compute` feature is on.
///
/// The client key is what the tests in this file use to encrypt and decrypt
/// values; the server key is the one handed to `tfhe::set_server_key`. Both
/// fields are private and exposed only through the accessor methods.
#[cfg(feature = "compute")]
#[derive(Clone)]
pub struct FheKeyPair {
    /// Key used for encryption and decryption; treat as secret.
    client_key: tfhe::ClientKey,
    /// Key installed with `tfhe::set_server_key` before evaluation.
    server_key: tfhe::ServerKey,
}

// Manual, opaque `Debug`: the stub compiled without `compute` derives `Debug`,
// so this keeps the type `Debug` in every feature configuration while making
// sure key material can never end up in logs or panic messages.
#[cfg(feature = "compute")]
impl std::fmt::Debug for FheKeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FheKeyPair").finish_non_exhaustive()
    }
}
23
#[cfg(feature = "compute")]
impl FheKeyPair {
    /// Generates a new key pair from `ConfigBuilder::default()`.
    ///
    /// # Errors
    ///
    /// Never fails today; the `Result` return type is part of the public
    /// signature.
    pub fn generate() -> Result<Self> {
        Self::generate_with_config(ConfigBuilder::default().build())
    }

    /// Generates a new key pair from a caller-supplied TFHE configuration.
    ///
    /// # Errors
    ///
    /// Never fails today; the `Result` return type is part of the public
    /// signature.
    pub fn generate_with_config(config: tfhe::Config) -> Result<Self> {
        let (client, server) = generate_keys(config);

        Ok(Self {
            client_key: client,
            server_key: server,
        })
    }

    /// Borrows the client key.
    pub fn client_key(&self) -> &tfhe::ClientKey {
        &self.client_key
    }

    /// Borrows the server key.
    pub fn server_key(&self) -> &tfhe::ServerKey {
        &self.server_key
    }

    /// Passes a clone of the server key to `tfhe::set_server_key`.
    ///
    /// NOTE(review): tfhe appears to install the key for the calling thread
    /// only — confirm whether worker threads need their own call.
    pub fn set_as_global_server_key(&self) {
        let key = self.server_key.clone();
        set_server_key(key);
    }

    /// Encodes the client key with `bincode`.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Serialization` if `bincode` reports a failure.
    pub fn serialize_client_key(&self) -> Result<Vec<u8>> {
        match bincode::serialize(&self.client_key) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(AmateRSError::Serialization(ErrorContext::new(format!(
                "Failed to serialize client key: {}",
                err
            )))),
        }
    }

    /// Encodes the server key with `bincode`.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Serialization` if `bincode` reports a failure.
    pub fn serialize_server_key(&self) -> Result<Vec<u8>> {
        match bincode::serialize(&self.server_key) {
            Ok(bytes) => Ok(bytes),
            Err(err) => Err(AmateRSError::Serialization(ErrorContext::new(format!(
                "Failed to serialize server key: {}",
                err
            )))),
        }
    }

    /// Decodes a client key from `bincode` bytes.
    ///
    /// NOTE(review): no size limit is applied while decoding — confirm the
    /// bytes always come from a trusted source.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Deserialization` if the bytes cannot be decoded.
    pub fn deserialize_client_key(bytes: &[u8]) -> Result<tfhe::ClientKey> {
        match bincode::deserialize::<tfhe::ClientKey>(bytes) {
            Ok(key) => Ok(key),
            Err(err) => Err(AmateRSError::Deserialization(ErrorContext::new(format!(
                "Failed to deserialize client key: {}",
                err
            )))),
        }
    }

    /// Decodes a server key from `bincode` bytes.
    ///
    /// NOTE(review): no size limit is applied while decoding — confirm the
    /// bytes always come from a trusted source.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Deserialization` if the bytes cannot be decoded.
    pub fn deserialize_server_key(bytes: &[u8]) -> Result<tfhe::ServerKey> {
        match bincode::deserialize::<tfhe::ServerKey>(bytes) {
            Ok(key) => Ok(key),
            Err(err) => Err(AmateRSError::Deserialization(ErrorContext::new(format!(
                "Failed to deserialize server key: {}",
                err
            )))),
        }
    }

    /// Rebuilds a key pair from separately serialized client and server keys.
    ///
    /// The client key is decoded first, so its error wins when both inputs
    /// are invalid. No check is made that the two keys belong together.
    ///
    /// # Errors
    ///
    /// Returns `AmateRSError::Deserialization` if either input fails to
    /// decode.
    pub fn from_serialized(client_key_bytes: &[u8], server_key_bytes: &[u8]) -> Result<Self> {
        Ok(Self {
            client_key: Self::deserialize_client_key(client_key_bytes)?,
            server_key: Self::deserialize_server_key(server_key_bytes)?,
        })
    }
}
117
/// Stand-in for the FHE key pair in builds without the `compute` feature.
///
/// It holds no key material. The field is private, so code outside this
/// module cannot construct a value directly.
#[cfg(not(feature = "compute"))]
#[derive(Clone, Debug)]
pub struct FheKeyPair {
    /// Zero-sized marker that keeps the struct non-constructible from outside.
    _phantom: std::marker::PhantomData<()>,
}
124
125#[cfg(not(feature = "compute"))]
126impl FheKeyPair {
127 pub fn generate() -> Result<Self> {
128 Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
129 "FHE compute feature is not enabled".to_string(),
130 )))
131 }
132
133 pub fn serialize_client_key(&self) -> Result<Vec<u8>> {
134 Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
135 "FHE compute feature is not enabled".to_string(),
136 )))
137 }
138
139 pub fn serialize_server_key(&self) -> Result<Vec<u8>> {
140 Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
141 "FHE compute feature is not enabled".to_string(),
142 )))
143 }
144}
145
146pub trait KeyStorage: Send + Sync {
151 fn store_client_key(&self, key_id: &str, key: &[u8]) -> Result<()>;
153
154 fn store_server_key(&self, key_id: &str, key: &[u8]) -> Result<()>;
156
157 fn retrieve_client_key(&self, key_id: &str) -> Result<Vec<u8>>;
159
160 fn retrieve_server_key(&self, key_id: &str) -> Result<Vec<u8>>;
162
163 fn delete_keys(&self, key_id: &str) -> Result<()>;
165
166 fn list_key_ids(&self) -> Result<Vec<String>>;
168}
169
170#[derive(Default)]
172pub struct InMemoryKeyStorage {
173 client_keys: std::sync::Arc<dashmap::DashMap<String, Vec<u8>>>,
174 server_keys: std::sync::Arc<dashmap::DashMap<String, Vec<u8>>>,
175}
176
177impl InMemoryKeyStorage {
178 pub fn new() -> Self {
179 Self::default()
180 }
181}
182
183impl KeyStorage for InMemoryKeyStorage {
184 fn store_client_key(&self, key_id: &str, key: &[u8]) -> Result<()> {
185 self.client_keys.insert(key_id.to_string(), key.to_vec());
186 Ok(())
187 }
188
189 fn store_server_key(&self, key_id: &str, key: &[u8]) -> Result<()> {
190 self.server_keys.insert(key_id.to_string(), key.to_vec());
191 Ok(())
192 }
193
194 fn retrieve_client_key(&self, key_id: &str) -> Result<Vec<u8>> {
195 self.client_keys
196 .get(key_id)
197 .map(|entry| entry.value().clone())
198 .ok_or_else(|| {
199 AmateRSError::KeyNotFound(ErrorContext::new(format!(
200 "Client key not found: {}",
201 key_id
202 )))
203 })
204 }
205
206 fn retrieve_server_key(&self, key_id: &str) -> Result<Vec<u8>> {
207 self.server_keys
208 .get(key_id)
209 .map(|entry| entry.value().clone())
210 .ok_or_else(|| {
211 AmateRSError::KeyNotFound(ErrorContext::new(format!(
212 "Server key not found: {}",
213 key_id
214 )))
215 })
216 }
217
218 fn delete_keys(&self, key_id: &str) -> Result<()> {
219 self.client_keys.remove(key_id);
220 self.server_keys.remove(key_id);
221 Ok(())
222 }
223
224 fn list_key_ids(&self) -> Result<Vec<String>> {
225 Ok(self
226 .client_keys
227 .iter()
228 .map(|entry| entry.key().clone())
229 .collect())
230 }
231}
232
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    /// Key generation with the default configuration succeeds.
    #[test]
    fn test_key_generation() -> Result<()> {
        let _keypair = FheKeyPair::generate()?;
        Ok(())
    }

    /// Both keys serialize to a non-empty byte buffer.
    #[test]
    fn test_key_serialization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;

        let client_bytes = keypair.serialize_client_key()?;
        let server_bytes = keypair.serialize_server_key()?;

        assert!(!client_bytes.is_empty());
        assert!(!server_bytes.is_empty());

        Ok(())
    }

    /// A key pair rebuilt from its serialized form can still encrypt and
    /// decrypt a value.
    #[test]
    fn test_key_deserialization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;

        let client_bytes = keypair.serialize_client_key()?;
        let server_bytes = keypair.serialize_server_key()?;

        let restored = FheKeyPair::from_serialized(&client_bytes, &server_bytes)?;

        let value = true;
        let encrypted = FheBool::encrypt(value, restored.client_key());
        let decrypted: bool = encrypted.decrypt(restored.client_key());
        assert_eq!(value, decrypted);

        Ok(())
    }

    /// Store, retrieve, list and delete round-trip through the in-memory
    /// storage.
    #[test]
    fn test_key_storage() -> Result<()> {
        let storage = InMemoryKeyStorage::new();
        let keypair = FheKeyPair::generate()?;

        let client_bytes = keypair.serialize_client_key()?;
        let server_bytes = keypair.serialize_server_key()?;

        storage.store_client_key("test_key", &client_bytes)?;
        storage.store_server_key("test_key", &server_bytes)?;

        let retrieved_client = storage.retrieve_client_key("test_key")?;
        let retrieved_server = storage.retrieve_server_key("test_key")?;

        assert_eq!(client_bytes, retrieved_client);
        assert_eq!(server_bytes, retrieved_server);

        let key_ids = storage.list_key_ids()?;
        assert_eq!(key_ids.len(), 1);
        assert_eq!(key_ids[0], "test_key");

        storage.delete_keys("test_key")?;
        // Deletion must clear both halves of the pair, not just the client key.
        assert!(storage.retrieve_client_key("test_key").is_err());
        assert!(storage.retrieve_server_key("test_key").is_err());
        assert!(storage.list_key_ids()?.is_empty());

        Ok(())
    }
}