1use crate::error::{Result, SdkError};
7use amaters_core::{CipherBlob, Key};
8use std::path::Path;
9
10#[derive(Clone)]
15pub struct FheKeys {
16 #[cfg(feature = "fhe")]
17 _keys: tfhe::ClientKey,
18 #[cfg(not(feature = "fhe"))]
19 _placeholder: (),
20}
21
22impl FheKeys {
23 pub fn generate() -> Result<Self> {
28 #[cfg(feature = "fhe")]
29 {
30 let config = tfhe::ConfigBuilder::default().build();
31 let client_key = tfhe::ClientKey::generate(config);
32 Ok(Self { _keys: client_key })
33 }
34 #[cfg(not(feature = "fhe"))]
35 {
36 Ok(Self { _placeholder: () })
37 }
38 }
39
40 pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
45 #[cfg(feature = "fhe")]
46 {
47 let bytes = std::fs::read(path.as_ref())
48 .map_err(|e| SdkError::Fhe(format!("failed to read key file: {}", e)))?;
49 #[cfg(feature = "serialization")]
50 {
51 let client_key: tfhe::ClientKey = oxicode::serde::decode_serde(&bytes)
52 .map_err(|e| SdkError::Fhe(format!("failed to deserialize keys: {}", e)))?;
53 Ok(Self { _keys: client_key })
54 }
55 #[cfg(not(feature = "serialization"))]
56 {
57 let _ = bytes;
58 Err(SdkError::Fhe(
59 "serialization feature required for key file loading".to_string(),
60 ))
61 }
62 }
63 #[cfg(not(feature = "fhe"))]
64 {
65 let _ = path;
66 Ok(Self { _placeholder: () })
67 }
68 }
69
70 pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
75 #[cfg(feature = "fhe")]
76 {
77 #[cfg(feature = "serialization")]
78 {
79 let bytes = oxicode::serde::encode_serde(&self._keys)
80 .map_err(|e| SdkError::Fhe(format!("failed to serialize keys: {}", e)))?;
81 std::fs::write(path.as_ref(), &bytes)
82 .map_err(|e| SdkError::Fhe(format!("failed to write key file: {}", e)))?;
83 Ok(())
84 }
85 #[cfg(not(feature = "serialization"))]
86 {
87 let _ = path;
88 Err(SdkError::Fhe(
89 "serialization feature required for key file saving".to_string(),
90 ))
91 }
92 }
93 #[cfg(not(feature = "fhe"))]
94 {
95 let _ = path;
96 Ok(())
97 }
98 }
99
100 #[cfg(feature = "serialization")]
102 pub fn to_bytes(&self) -> Result<Vec<u8>> {
103 #[cfg(feature = "fhe")]
104 {
105 oxicode::serde::encode_serde(&self._keys).map_err(|e| {
106 SdkError::Serialization(format!("failed to serialize FHE keys: {}", e))
107 })
108 }
109 #[cfg(not(feature = "fhe"))]
110 {
111 Ok(Vec::new())
112 }
113 }
114
115 #[cfg(feature = "serialization")]
117 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
118 #[cfg(feature = "fhe")]
119 {
120 let client_key: tfhe::ClientKey = oxicode::serde::decode_serde(bytes).map_err(|e| {
121 SdkError::Serialization(format!("failed to deserialize FHE keys: {}", e))
122 })?;
123 Ok(Self { _keys: client_key })
124 }
125 #[cfg(not(feature = "fhe"))]
126 {
127 let _ = bytes;
128 Ok(Self { _placeholder: () })
129 }
130 }
131}
132
133pub struct FheEncryptor {
135 keys: FheKeys,
136}
137
138impl FheEncryptor {
139 pub fn new() -> Result<Self> {
141 Ok(Self {
142 keys: FheKeys::generate()?,
143 })
144 }
145
146 pub fn with_keys(keys: FheKeys) -> Self {
148 Self { keys }
149 }
150
151 pub fn keys(&self) -> &FheKeys {
153 &self.keys
154 }
155
156 pub fn encrypt(&self, plaintext: &[u8]) -> Result<CipherBlob> {
162 #[cfg(feature = "fhe")]
163 {
164 use tfhe::prelude::FheTryEncrypt;
165
166 let mut encrypted_parts: Vec<Vec<u8>> = Vec::with_capacity(plaintext.len());
168 for &byte in plaintext {
169 let encrypted: tfhe::FheUint8 = tfhe::FheUint8::try_encrypt(byte, &self.keys._keys)
170 .map_err(|e| SdkError::Fhe(format!("failed to encrypt byte: {}", e)))?;
171 #[cfg(feature = "serialization")]
173 {
174 let serialized = oxicode::serde::encode_serde(&encrypted).map_err(|e| {
175 SdkError::Fhe(format!("failed to serialize encrypted byte: {}", e))
176 })?;
177 encrypted_parts.push(serialized);
178 }
179 #[cfg(not(feature = "serialization"))]
180 {
181 let _ = encrypted;
182 return Err(SdkError::Fhe(
183 "serialization feature required for FHE encryption".to_string(),
184 ));
185 }
186 }
187 let count = plaintext.len() as u64;
189 let total_size = 8 + encrypted_parts.iter().map(|p| 8 + p.len()).sum::<usize>();
190 let mut blob_data = Vec::with_capacity(total_size);
191 blob_data.extend_from_slice(&count.to_le_bytes());
192 for part in &encrypted_parts {
193 let len = part.len() as u64;
194 blob_data.extend_from_slice(&len.to_le_bytes());
195 blob_data.extend_from_slice(part);
196 }
197 Ok(CipherBlob::new(blob_data))
198 }
199 #[cfg(not(feature = "fhe"))]
200 {
201 Ok(CipherBlob::new(plaintext.to_vec()))
204 }
205 }
206
207 pub fn decrypt(&self, ciphertext: &CipherBlob) -> Result<Vec<u8>> {
212 #[cfg(feature = "fhe")]
213 {
214 use tfhe::prelude::FheDecrypt;
215
216 let data = ciphertext.to_vec();
217 if data.len() < 8 {
218 return Err(SdkError::Fhe("ciphertext too short".to_string()));
219 }
220 let count = u64::from_le_bytes(
221 data[..8]
222 .try_into()
223 .map_err(|_| SdkError::Fhe("invalid ciphertext header".to_string()))?,
224 ) as usize;
225
226 let mut offset = 8usize;
227 let mut plaintext = Vec::with_capacity(count);
228
229 for _ in 0..count {
230 if offset + 8 > data.len() {
231 return Err(SdkError::Fhe(
232 "ciphertext truncated: missing length field".to_string(),
233 ));
234 }
235 let part_len = u64::from_le_bytes(
236 data[offset..offset + 8]
237 .try_into()
238 .map_err(|_| SdkError::Fhe("invalid ciphertext part length".to_string()))?,
239 ) as usize;
240 offset += 8;
241
242 if offset + part_len > data.len() {
243 return Err(SdkError::Fhe(
244 "ciphertext truncated: insufficient data".to_string(),
245 ));
246 }
247
248 #[cfg(feature = "serialization")]
249 {
250 let encrypted: tfhe::FheUint8 = oxicode::serde::decode_serde(
251 &data[offset..offset + part_len],
252 )
253 .map_err(|e| {
254 SdkError::Fhe(format!("failed to deserialize encrypted byte: {}", e))
255 })?;
256 let byte: u8 = encrypted.decrypt(&self.keys._keys);
257 plaintext.push(byte);
258 }
259 #[cfg(not(feature = "serialization"))]
260 {
261 return Err(SdkError::Fhe(
262 "serialization feature required for FHE decryption".to_string(),
263 ));
264 }
265
266 offset += part_len;
267 }
268
269 Ok(plaintext)
270 }
271 #[cfg(not(feature = "fhe"))]
272 {
273 Ok(ciphertext.to_vec())
276 }
277 }
278
279 pub fn encrypt_key(&self, key: &Key) -> Result<CipherBlob> {
281 #[cfg(feature = "fhe")]
282 {
283 self.encrypt(key.as_bytes())
284 }
285 #[cfg(not(feature = "fhe"))]
286 {
287 Ok(CipherBlob::new(key.to_vec()))
289 }
290 }
291
292 pub fn encrypt_batch(&self, plaintexts: &[&[u8]]) -> Result<Vec<CipherBlob>> {
294 plaintexts.iter().map(|p| self.encrypt(p)).collect()
295 }
296}
297
298impl Default for FheEncryptor {
299 fn default() -> Self {
300 Self::new().expect("failed to create default encryptor")
301 }
302}
303
// These tests exercise the stub backend only. The module is gated on
// `not(feature = "fhe")` so that `--features fhe` builds do not report them as
// passing without running any assertions.
#[cfg(all(test, not(feature = "fhe")))]
mod tests {
    use super::*;

    #[test]
    fn test_fhe_keys_generate_no_fhe() {
        let keys = FheKeys::generate().expect("generate keys should succeed");
        let path = std::env::temp_dir().join("test_fhe_keys_generate");
        keys.save_to_file(&path)
            .expect("save_to_file should succeed in stub mode");
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let plaintext = b"hello world roundtrip test";
        let ciphertext = encryptor.encrypt(plaintext).expect("encrypt");
        let decrypted = encryptor.decrypt(&ciphertext).expect("decrypt");

        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_file_save_load_roundtrip_no_fhe() {
        let path = std::env::temp_dir().join("test_fhe_keys_save_load");

        let keys = FheKeys::generate().expect("generate keys");
        keys.save_to_file(&path).expect("save keys");

        let _loaded = FheKeys::load_from_file(&path).expect("load keys");
    }

    #[cfg(feature = "serialization")]
    #[test]
    fn test_serialization_roundtrip_no_fhe() {
        let keys = FheKeys::generate().expect("generate keys");
        let bytes = keys.to_bytes().expect("serialize keys");
        assert!(bytes.is_empty(), "stub keys serialize to no bytes");
        let _restored = FheKeys::from_bytes(&bytes).expect("deserialize keys");
    }

    #[test]
    fn test_batch_encrypt_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let data: Vec<&[u8]> = vec![b"one", b"two", b"three"];

        let encrypted = encryptor.encrypt_batch(&data).expect("batch encrypt");
        assert_eq!(encrypted.len(), 3);

        for (ct, expected) in encrypted.iter().zip(&data) {
            let decrypted = encryptor.decrypt(ct).expect("decrypt");
            assert_eq!(decrypted, *expected);
        }
    }

    #[test]
    fn test_encrypt_key_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let key = Key::new(b"test-key-data".to_vec());
        let cipher = encryptor.encrypt_key(&key).expect("encrypt key");
        let decrypted = encryptor.decrypt(&cipher).expect("decrypt");
        assert_eq!(decrypted, key.as_bytes());
    }

    #[test]
    fn test_encryptor_with_keys() {
        let keys = FheKeys::generate().expect("generate keys");
        let encryptor = FheEncryptor::with_keys(keys);
        let _keys_ref = encryptor.keys();
        let plaintext = b"test with_keys";
        let ciphertext = encryptor.encrypt(plaintext).expect("encrypt");
        let decrypted = encryptor.decrypt(&ciphertext).expect("decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let plaintext = b"";
        let ciphertext = encryptor.encrypt(plaintext).expect("encrypt empty");
        let decrypted = encryptor.decrypt(&ciphertext).expect("decrypt empty");
        assert_eq!(decrypted, plaintext);
    }
}