Skip to main content

amaters_sdk_rust/
fhe.rs

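//! FHE (Fully Homomorphic Encryption) key management and operations.
//!
//! This module provides client-side encryption and decryption capabilities.
//! Real FHE operations are feature-gated and require the `fhe` feature; without it,
//! encryption and decryption fall back to an insecure plaintext passthrough that is
//! intended for development and testing only.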
5
6use crate::error::{Result, SdkError};
7use amaters_core::{CipherBlob, Key};
8use std::path::Path;
9
/// FHE client keys for encryption/decryption.
///
/// When the `fhe` feature is enabled, this wraps a TFHE `tfhe::ClientKey` used for
/// real homomorphic encryption. Without the feature it is a zero-sized placeholder
/// and the encryption API built on top of it is an insecure passthrough stub.
#[derive(Clone)]
pub struct FheKeys {
    #[cfg(feature = "fhe")]
    _keys: tfhe::ClientKey,
    #[cfg(not(feature = "fhe"))]
    _placeholder: (),
}

/// Redacted `Debug`: the client key is secret material, so its contents are never
/// printed (this keeps keys out of logs and panic messages).
impl std::fmt::Debug for FheKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FheKeys").finish_non_exhaustive()
    }
}
21
22impl FheKeys {
23    /// Generate new FHE keys
24    ///
25    /// This is a computationally expensive operation (can take several seconds)
26    /// when the `fhe` feature is enabled.
27    pub fn generate() -> Result<Self> {
28        #[cfg(feature = "fhe")]
29        {
30            let config = tfhe::ConfigBuilder::default().build();
31            let client_key = tfhe::ClientKey::generate(config);
32            Ok(Self { _keys: client_key })
33        }
34        #[cfg(not(feature = "fhe"))]
35        {
36            Ok(Self { _placeholder: () })
37        }
38    }
39
40    /// Load keys from a file
41    ///
42    /// Reads serialized key data from the given path and deserializes
43    /// using oxicode (when `fhe` + `serialization` features are enabled).
44    pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
45        #[cfg(feature = "fhe")]
46        {
47            let bytes = std::fs::read(path.as_ref())
48                .map_err(|e| SdkError::Fhe(format!("failed to read key file: {}", e)))?;
49            #[cfg(feature = "serialization")]
50            {
51                let client_key: tfhe::ClientKey = oxicode::serde::decode_serde(&bytes)
52                    .map_err(|e| SdkError::Fhe(format!("failed to deserialize keys: {}", e)))?;
53                Ok(Self { _keys: client_key })
54            }
55            #[cfg(not(feature = "serialization"))]
56            {
57                let _ = bytes;
58                Err(SdkError::Fhe(
59                    "serialization feature required for key file loading".to_string(),
60                ))
61            }
62        }
63        #[cfg(not(feature = "fhe"))]
64        {
65            let _ = path;
66            Ok(Self { _placeholder: () })
67        }
68    }
69
70    /// Save keys to a file
71    ///
72    /// Serializes keys using oxicode and writes to the given path
73    /// (when `fhe` + `serialization` features are enabled).
74    pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
75        #[cfg(feature = "fhe")]
76        {
77            #[cfg(feature = "serialization")]
78            {
79                let bytes = oxicode::serde::encode_serde(&self._keys)
80                    .map_err(|e| SdkError::Fhe(format!("failed to serialize keys: {}", e)))?;
81                std::fs::write(path.as_ref(), &bytes)
82                    .map_err(|e| SdkError::Fhe(format!("failed to write key file: {}", e)))?;
83                Ok(())
84            }
85            #[cfg(not(feature = "serialization"))]
86            {
87                let _ = path;
88                Err(SdkError::Fhe(
89                    "serialization feature required for key file saving".to_string(),
90                ))
91            }
92        }
93        #[cfg(not(feature = "fhe"))]
94        {
95            let _ = path;
96            Ok(())
97        }
98    }
99
100    /// Serialize keys to bytes using oxicode
101    #[cfg(feature = "serialization")]
102    pub fn to_bytes(&self) -> Result<Vec<u8>> {
103        #[cfg(feature = "fhe")]
104        {
105            oxicode::serde::encode_serde(&self._keys).map_err(|e| {
106                SdkError::Serialization(format!("failed to serialize FHE keys: {}", e))
107            })
108        }
109        #[cfg(not(feature = "fhe"))]
110        {
111            Ok(Vec::new())
112        }
113    }
114
115    /// Deserialize keys from bytes using oxicode
116    #[cfg(feature = "serialization")]
117    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
118        #[cfg(feature = "fhe")]
119        {
120            let client_key: tfhe::ClientKey = oxicode::serde::decode_serde(bytes).map_err(|e| {
121                SdkError::Serialization(format!("failed to deserialize FHE keys: {}", e))
122            })?;
123            Ok(Self { _keys: client_key })
124        }
125        #[cfg(not(feature = "fhe"))]
126        {
127            let _ = bytes;
128            Ok(Self { _placeholder: () })
129        }
130    }
131}
132
133/// FHE encryptor for client-side encryption
134pub struct FheEncryptor {
135    keys: FheKeys,
136}
137
138impl FheEncryptor {
139    /// Create a new encryptor with generated keys
140    pub fn new() -> Result<Self> {
141        Ok(Self {
142            keys: FheKeys::generate()?,
143        })
144    }
145
146    /// Create an encryptor with existing keys
147    pub fn with_keys(keys: FheKeys) -> Self {
148        Self { keys }
149    }
150
151    /// Get a reference to the keys
152    pub fn keys(&self) -> &FheKeys {
153        &self.keys
154    }
155
156    /// Encrypt a value
157    ///
158    /// When `fhe` is enabled, encrypts each byte using TFHE FheUint8 and
159    /// serializes the resulting ciphertexts into a single CipherBlob.
160    /// Without `fhe`, this is a passthrough that wraps plaintext as-is (NOT secure).
161    pub fn encrypt(&self, plaintext: &[u8]) -> Result<CipherBlob> {
162        #[cfg(feature = "fhe")]
163        {
164            use tfhe::prelude::FheTryEncrypt;
165
166            // Encrypt each byte as an FheUint8 and collect serialized ciphertexts
167            let mut encrypted_parts: Vec<Vec<u8>> = Vec::with_capacity(plaintext.len());
168            for &byte in plaintext {
169                let encrypted: tfhe::FheUint8 = tfhe::FheUint8::try_encrypt(byte, &self.keys._keys)
170                    .map_err(|e| SdkError::Fhe(format!("failed to encrypt byte: {}", e)))?;
171                // Serialize each encrypted value
172                #[cfg(feature = "serialization")]
173                {
174                    let serialized = oxicode::serde::encode_serde(&encrypted).map_err(|e| {
175                        SdkError::Fhe(format!("failed to serialize encrypted byte: {}", e))
176                    })?;
177                    encrypted_parts.push(serialized);
178                }
179                #[cfg(not(feature = "serialization"))]
180                {
181                    let _ = encrypted;
182                    return Err(SdkError::Fhe(
183                        "serialization feature required for FHE encryption".to_string(),
184                    ));
185                }
186            }
187            // Pack: [count(u64)] [len1(u64)][data1] [len2(u64)][data2] ...
188            let count = plaintext.len() as u64;
189            let total_size = 8 + encrypted_parts.iter().map(|p| 8 + p.len()).sum::<usize>();
190            let mut blob_data = Vec::with_capacity(total_size);
191            blob_data.extend_from_slice(&count.to_le_bytes());
192            for part in &encrypted_parts {
193                let len = part.len() as u64;
194                blob_data.extend_from_slice(&len.to_le_bytes());
195                blob_data.extend_from_slice(part);
196            }
197            Ok(CipherBlob::new(blob_data))
198        }
199        #[cfg(not(feature = "fhe"))]
200        {
201            // For testing: just wrap the plaintext as-is
202            // WARNING: This is NOT secure - only for development
203            Ok(CipherBlob::new(plaintext.to_vec()))
204        }
205    }
206
207    /// Decrypt a ciphertext
208    ///
209    /// When `fhe` is enabled, deserializes FheUint8 ciphertexts from the blob
210    /// and decrypts each one. Without `fhe`, returns the raw blob data (NOT secure).
211    pub fn decrypt(&self, ciphertext: &CipherBlob) -> Result<Vec<u8>> {
212        #[cfg(feature = "fhe")]
213        {
214            use tfhe::prelude::FheDecrypt;
215
216            let data = ciphertext.to_vec();
217            if data.len() < 8 {
218                return Err(SdkError::Fhe("ciphertext too short".to_string()));
219            }
220            let count = u64::from_le_bytes(
221                data[..8]
222                    .try_into()
223                    .map_err(|_| SdkError::Fhe("invalid ciphertext header".to_string()))?,
224            ) as usize;
225
226            let mut offset = 8usize;
227            let mut plaintext = Vec::with_capacity(count);
228
229            for _ in 0..count {
230                if offset + 8 > data.len() {
231                    return Err(SdkError::Fhe(
232                        "ciphertext truncated: missing length field".to_string(),
233                    ));
234                }
235                let part_len = u64::from_le_bytes(
236                    data[offset..offset + 8]
237                        .try_into()
238                        .map_err(|_| SdkError::Fhe("invalid ciphertext part length".to_string()))?,
239                ) as usize;
240                offset += 8;
241
242                if offset + part_len > data.len() {
243                    return Err(SdkError::Fhe(
244                        "ciphertext truncated: insufficient data".to_string(),
245                    ));
246                }
247
248                #[cfg(feature = "serialization")]
249                {
250                    let encrypted: tfhe::FheUint8 = oxicode::serde::decode_serde(
251                        &data[offset..offset + part_len],
252                    )
253                    .map_err(|e| {
254                        SdkError::Fhe(format!("failed to deserialize encrypted byte: {}", e))
255                    })?;
256                    let byte: u8 = encrypted.decrypt(&self.keys._keys);
257                    plaintext.push(byte);
258                }
259                #[cfg(not(feature = "serialization"))]
260                {
261                    return Err(SdkError::Fhe(
262                        "serialization feature required for FHE decryption".to_string(),
263                    ));
264                }
265
266                offset += part_len;
267            }
268
269            Ok(plaintext)
270        }
271        #[cfg(not(feature = "fhe"))]
272        {
273            // For testing: just unwrap the data as-is
274            // WARNING: This is NOT secure - only for development
275            Ok(ciphertext.to_vec())
276        }
277    }
278
279    /// Encrypt a key
280    pub fn encrypt_key(&self, key: &Key) -> Result<CipherBlob> {
281        #[cfg(feature = "fhe")]
282        {
283            self.encrypt(key.as_bytes())
284        }
285        #[cfg(not(feature = "fhe"))]
286        {
287            // For testing: just wrap the key bytes
288            Ok(CipherBlob::new(key.to_vec()))
289        }
290    }
291
292    /// Batch encrypt multiple values
293    pub fn encrypt_batch(&self, plaintexts: &[&[u8]]) -> Result<Vec<CipherBlob>> {
294        plaintexts.iter().map(|p| self.encrypt(p)).collect()
295    }
296}
297
298impl Default for FheEncryptor {
299    fn default() -> Self {
300        Self::new().expect("failed to create default encryptor")
301    }
302}
303
/// Tests for the insecure passthrough stub used when the `fhe` feature is disabled.
///
/// The whole module is compiled out when `fhe` is enabled. Previously each test body
/// was gated individually, so with `fhe` on the tests compiled to empty functions
/// that reported success without checking anything.
#[cfg(all(test, not(feature = "fhe")))]
mod tests {
    use super::*;

    #[test]
    fn test_fhe_keys_generate_no_fhe() {
        let keys = FheKeys::generate().expect("generate keys should succeed");
        // save_to_file is a no-op in stub mode; verify it still reports success.
        let path = std::env::temp_dir().join("test_fhe_keys_generate");
        keys.save_to_file(&path)
            .expect("save_to_file should succeed in stub mode");
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let plaintext = b"hello world roundtrip test";
        let ciphertext = encryptor.encrypt(plaintext).expect("encrypt");
        let decrypted = encryptor.decrypt(&ciphertext).expect("decrypt");

        // In stub mode, encrypt/decrypt is the identity.
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_file_save_load_roundtrip_no_fhe() {
        let path = std::env::temp_dir().join("test_fhe_keys_save_load");

        let keys = FheKeys::generate().expect("generate keys");
        keys.save_to_file(&path).expect("save keys");

        // In stub mode both are placeholder values, so only verify there is no error.
        let _loaded = FheKeys::load_from_file(&path).expect("load keys");
    }

    #[cfg(feature = "serialization")]
    #[test]
    fn test_serialization_roundtrip_no_fhe() {
        let keys = FheKeys::generate().expect("generate keys");
        let bytes = keys.to_bytes().expect("serialize keys");
        // In stub mode, to_bytes returns an empty vec and from_bytes accepts anything.
        assert!(bytes.is_empty());
        let _restored = FheKeys::from_bytes(&bytes).expect("deserialize keys");
    }

    #[test]
    fn test_batch_encrypt_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let data: Vec<&[u8]> = vec![b"one", b"two", b"three"];

        let encrypted = encryptor.encrypt_batch(&data).expect("batch encrypt");
        assert_eq!(encrypted.len(), 3);

        // Verify each ciphertext decrypts back to its source value.
        for (ct, expected) in encrypted.iter().zip(&data) {
            let decrypted = encryptor.decrypt(ct).expect("decrypt");
            assert_eq!(decrypted, *expected);
        }
    }

    #[test]
    fn test_encrypt_key_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let key = Key::new(b"test-key-data".to_vec());
        let cipher = encryptor.encrypt_key(&key).expect("encrypt key");
        let decrypted = encryptor.decrypt(&cipher).expect("decrypt");
        assert_eq!(decrypted, key.as_bytes());
    }

    #[test]
    fn test_encryptor_with_keys() {
        let keys = FheKeys::generate().expect("generate keys");
        let encryptor = FheEncryptor::with_keys(keys);
        let _keys_ref = encryptor.keys();
        let plaintext = b"test with_keys";
        let ciphertext = encryptor.encrypt(plaintext).expect("encrypt");
        let decrypted = encryptor.decrypt(&ciphertext).expect("decrypt");
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_plaintext_no_fhe() {
        let encryptor = FheEncryptor::new().expect("create encryptor");
        let plaintext = b"";
        let ciphertext = encryptor.encrypt(plaintext).expect("encrypt empty");
        let decrypted = encryptor.decrypt(&ciphertext).expect("decrypt empty");
        assert_eq!(decrypted, plaintext);
    }
}