_hope_core/fhe/
mod.rs

1//! # Fully Homomorphic Encryption Module
2//!
3//! "Titkosított Inference" - Compute on encrypted data without decryption
4//!
5//! ## Features
6//!
7//! - **BFV Scheme**: Integer arithmetic on encrypted data
8//! - **CKKS Scheme**: Approximate arithmetic for ML inference
9//! - **Encrypted Watchdog**: Run safety checks on encrypted prompts
10//! - **Private Inference**: AI inference without seeing the input
11//! - **Threshold Decryption**: Multi-party decryption for extra security
12//!
13//! ## Philosophy
14//!
15//! "Az adat soha nem látható - még a feldolgozás során sem."
16//! (The data is never visible - not even during processing.)
17//!
18//! This enables:
19//! - Privacy-preserving AI safety checks
20//! - Encrypted audit logs that can still be verified
21//! - Zero-knowledge compliance verification
22
23use sha2::{Digest, Sha256, Sha512};
24use std::collections::HashMap;
25use std::time::{SystemTime, UNIX_EPOCH};
26
/// Identifies which homomorphic encryption scheme a parameter set or ciphertext uses.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FheScheme {
    /// BFV: exact arithmetic over integers.
    Bfv,
    /// CKKS: approximate arithmetic, intended for ML workloads.
    Ckks,
    /// BGV: an alternative integer-arithmetic scheme.
    Bgv,
    /// TFHE: scheme with fast bootstrapping.
    Tfhe,
}

impl FheScheme {
    /// Returns the scheme's conventional upper-case abbreviation.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Tfhe => "TFHE",
            Self::Bgv => "BGV",
            Self::Ckks => "CKKS",
            Self::Bfv => "BFV",
        }
    }

    /// Returns the security level, in bits, reported for this scheme.
    ///
    /// Every variant currently reports 128. The match stays exhaustive so that adding a
    /// variant forces a decision here.
    pub fn security_bits(&self) -> usize {
        match *self {
            Self::Bfv | Self::Ckks | Self::Bgv | Self::Tfhe => 128,
        }
    }
}
61
/// Parameter set that configures an FHE engine.
///
/// All fields are always present; `plain_modulus` is documented for BFV and `scale` for CKKS.
#[derive(Debug, Clone)]
pub struct FheParams {
    /// Polynomial modulus degree. The engine sizes its simulated keys and ciphertexts from it.
    pub poly_modulus_degree: usize,
    /// Bit sizes of the coefficient modulus primes.
    pub coeff_modulus_bits: Vec<usize>,
    /// Plaintext modulus (for BFV)
    pub plain_modulus: u64,
    /// Scale (for CKKS). Float encoding multiplies each value by this factor.
    pub scale: f64,
    /// Encryption scheme
    pub scheme: FheScheme,
    /// Security level
    pub security_level: SecurityLevel,
}
78
/// Target security level recorded in an FHE parameter set.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SecurityLevel {
    /// 128-bit security (recommended)
    Bits128,
    /// 192-bit security
    Bits192,
    /// 256-bit security
    Bits256,
}
89
90impl Default for FheParams {
91    fn default() -> Self {
92        Self {
93            poly_modulus_degree: 8192,
94            coeff_modulus_bits: vec![60, 40, 40, 60],
95            plain_modulus: 65537,
96            scale: 2f64.powi(40),
97            scheme: FheScheme::Ckks,
98            security_level: SecurityLevel::Bits128,
99        }
100    }
101}
102
/// FHE public key.
#[derive(Debug, Clone)]
pub struct FhePublicKey {
    /// Key data (simulated)
    pub data: Vec<u8>,
    /// Key ID; the matching secret key carries the same value.
    pub id: [u8; 32],
    /// Creation timestamp, in seconds since the Unix epoch.
    pub created_at: u64,
}
113
/// FHE secret key.
///
/// The key bytes are private to this module and are overwritten with zeros when the key is
/// dropped. `Debug` is implemented by hand so that formatting a key, or any structure that
/// contains one, never prints the key bytes.
pub struct FheSecretKey {
    /// Key data (simulated)
    data: Vec<u8>,
    /// Key ID
    pub id: [u8; 32],
}

impl std::fmt::Debug for FheSecretKey {
    /// Formats the key with its ID visible and its key bytes replaced by a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FheSecretKey")
            .field("data", &format_args!("<redacted>"))
            .field("id", &self.id)
            .finish()
    }
}

impl Drop for FheSecretKey {
    /// Overwrites the key bytes with zeros before the backing buffer is released.
    fn drop(&mut self) {
        // NOTE(review): these are ordinary stores to memory that is about to be freed, so an
        // optimizer is allowed to remove them. If zeroization must be guaranteed, switch to
        // volatile writes or a dedicated crate such as `zeroize`.
        for byte in self.data.iter_mut() {
            *byte = 0;
        }
    }
}
131
/// Complete key set produced by key generation: the public/secret pair plus the optional
/// evaluation keys.
#[derive(Debug)]
pub struct FheKeyPair {
    /// Public key
    pub public_key: FhePublicKey,
    /// Secret key
    pub secret_key: FheSecretKey,
    /// Relinearization keys (for multiplication); ciphertext multiplication fails without them.
    pub relin_keys: Option<RelinKeys>,
    /// Galois keys (for rotation); rotation fails without them.
    pub galois_keys: Option<GaloisKeys>,
}
144
/// Relinearization keys for homomorphic multiplication.
#[derive(Debug, Clone)]
pub struct RelinKeys {
    /// Key data (raw bytes)
    pub data: Vec<u8>,
}
151
/// Galois keys for slot rotation.
#[derive(Debug, Clone)]
pub struct GaloisKeys {
    /// Key data (raw bytes)
    pub data: Vec<u8>,
    /// Supported rotation steps, stored as magnitudes: a rotation is accepted when the
    /// absolute value of the requested step appears in this list.
    pub steps: Vec<i32>,
}
160
/// Encrypted ciphertext together with the metadata the engine needs to operate on it.
#[derive(Debug, Clone)]
pub struct Ciphertext {
    /// Encrypted data
    pub data: Vec<u8>,
    /// Hash of the parameters used; two ciphertexts must agree on it to be combined.
    pub params_hash: [u8; 32],
    /// Scheme used
    pub scheme: FheScheme,
    /// Noise budget estimate. Operations subtract from it and refuse to run once it is too low.
    pub noise_budget: i32,
    /// Scale (for CKKS)
    pub scale: f64,
    /// `true` when this ciphertext is the result of a homomorphic operation, `false` when it
    /// comes straight from encryption.
    pub is_computed: bool,
}
177
/// Encoded plaintext, ready to be encrypted or combined with a ciphertext.
#[derive(Debug, Clone)]
pub struct Plaintext {
    /// Encoded data
    pub data: Vec<i64>,
    /// Scale (for CKKS); integer encoding uses `1.0`.
    pub scale: f64,
}
186
/// The main FHE engine: owns the parameters, the generated keys and the usage counters.
#[derive(Debug)]
pub struct FheEngine {
    /// Parameters
    params: FheParams,
    /// Key pair (if generated); `None` until key generation has run.
    keypair: Option<FheKeyPair>,
    /// Statistics
    stats: FheStats,
    /// Operation counter; incremented on every encryption and mixed into the ciphertext mask.
    operation_count: u64,
}
199
/// Running counters for the operations an FHE engine has performed.
#[derive(Debug, Clone, Default)]
pub struct FheStats {
    /// Encryptions performed
    pub encryptions: u64,
    /// Decryptions performed
    pub decryptions: u64,
    /// Homomorphic additions (subtractions are counted here as well)
    pub additions: u64,
    /// Homomorphic multiplications (ciphertext-ciphertext and ciphertext-plaintext)
    pub multiplications: u64,
    /// Rotations performed
    pub rotations: u64,
    /// Bootstrapping operations
    pub bootstraps: u64,
    /// Failed operations (noise exhausted)
    pub failures: u64,
}
218
219impl FheEngine {
220    /// Create a new FHE engine with given parameters
221    pub fn new(params: FheParams) -> Self {
222        Self {
223            params,
224            keypair: None,
225            stats: FheStats::default(),
226            operation_count: 0,
227        }
228    }
229
230    /// Create with default CKKS parameters (for ML)
231    pub fn new_ckks() -> Self {
232        Self::new(FheParams {
233            scheme: FheScheme::Ckks,
234            ..Default::default()
235        })
236    }
237
238    /// Create with BFV parameters (for exact integers)
239    pub fn new_bfv() -> Self {
240        Self::new(FheParams {
241            scheme: FheScheme::Bfv,
242            plain_modulus: 786433,
243            ..Default::default()
244        })
245    }
246
247    /// Generate key pair
248    pub fn keygen(&mut self) -> &FheKeyPair {
249        let timestamp = SystemTime::now()
250            .duration_since(UNIX_EPOCH)
251            .unwrap_or_default()
252            .as_secs();
253
254        // Generate key ID
255        let mut hasher = Sha256::new();
256        hasher.update(b"FHE_KEYGEN");
257        hasher.update(timestamp.to_le_bytes());
258        hasher.update(self.params.poly_modulus_degree.to_le_bytes());
259        let id_hash = hasher.finalize();
260        let mut key_id = [0u8; 32];
261        key_id.copy_from_slice(&id_hash);
262
263        // Simulate key generation (real impl uses lattice-based crypto)
264        let pk_size = self.params.poly_modulus_degree * 2;
265        let sk_size = self.params.poly_modulus_degree;
266
267        let mut pk_data = vec![0u8; pk_size];
268        let mut sk_data = vec![0u8; sk_size];
269
270        // Fill with deterministic pseudo-random data
271        let mut seed_hasher = Sha512::new();
272        seed_hasher.update(b"PK_SEED");
273        seed_hasher.update(key_id);
274        let pk_seed = seed_hasher.finalize();
275        for (i, byte) in pk_data.iter_mut().enumerate() {
276            *byte = pk_seed[i % 64];
277        }
278
279        let mut seed_hasher = Sha512::new();
280        seed_hasher.update(b"SK_SEED");
281        seed_hasher.update(key_id);
282        let sk_seed = seed_hasher.finalize();
283        for (i, byte) in sk_data.iter_mut().enumerate() {
284            *byte = sk_seed[i % 64];
285        }
286
287        // Generate relinearization keys
288        let mut relin_hasher = Sha512::new();
289        relin_hasher.update(b"RELIN_KEYS");
290        relin_hasher.update(key_id);
291        let relin_seed = relin_hasher.finalize();
292        let relin_data: Vec<u8> = (0..pk_size).map(|i| relin_seed[i % 64]).collect();
293
294        // Generate Galois keys
295        let mut galois_hasher = Sha512::new();
296        galois_hasher.update(b"GALOIS_KEYS");
297        galois_hasher.update(key_id);
298        let galois_seed = galois_hasher.finalize();
299        let galois_data: Vec<u8> = (0..pk_size).map(|i| galois_seed[i % 64]).collect();
300
301        self.keypair = Some(FheKeyPair {
302            public_key: FhePublicKey {
303                data: pk_data,
304                id: key_id,
305                created_at: timestamp,
306            },
307            secret_key: FheSecretKey {
308                data: sk_data,
309                id: key_id,
310            },
311            relin_keys: Some(RelinKeys { data: relin_data }),
312            galois_keys: Some(GaloisKeys {
313                data: galois_data,
314                steps: vec![1, 2, 4, 8, 16, 32, 64, 128],
315            }),
316        });
317
318        self.keypair.as_ref().unwrap()
319    }
320
321    /// Compute parameters hash
322    fn params_hash(&self) -> [u8; 32] {
323        let mut hasher = Sha256::new();
324        hasher.update(self.params.poly_modulus_degree.to_le_bytes());
325        hasher.update(self.params.plain_modulus.to_le_bytes());
326        hasher.update(self.params.scale.to_le_bytes());
327        let hash = hasher.finalize();
328        let mut result = [0u8; 32];
329        result.copy_from_slice(&hash);
330        result
331    }
332
333    /// Encode a vector of integers as plaintext
334    pub fn encode_integers(&self, values: &[i64]) -> Plaintext {
335        Plaintext {
336            data: values.to_vec(),
337            scale: 1.0,
338        }
339    }
340
341    /// Encode a vector of floats as plaintext (CKKS)
342    pub fn encode_floats(&self, values: &[f64]) -> Plaintext {
343        let encoded: Vec<i64> = values
344            .iter()
345            .map(|v| (v * self.params.scale) as i64)
346            .collect();
347
348        Plaintext {
349            data: encoded,
350            scale: self.params.scale,
351        }
352    }
353
354    /// Encrypt a plaintext
355    pub fn encrypt(&mut self, plaintext: &Plaintext) -> Result<Ciphertext, FheError> {
356        let keypair = self.keypair.as_ref().ok_or(FheError::KeyNotGenerated)?;
357
358        self.stats.encryptions += 1;
359        self.operation_count += 1;
360
361        // Simulate encryption (real impl uses RLWE)
362        let mut hasher = Sha256::new();
363        hasher.update(b"ENCRYPT");
364        hasher.update(&keypair.public_key.data[0..32.min(keypair.public_key.data.len())]);
365        hasher.update(self.operation_count.to_le_bytes());
366
367        for &val in &plaintext.data {
368            hasher.update(val.to_le_bytes());
369        }
370
371        let ct_seed = hasher.finalize();
372
373        // Create ciphertext data
374        let ct_size = self.params.poly_modulus_degree * 2;
375        let mut ct_data = vec![0u8; ct_size];
376        for (i, byte) in ct_data.iter_mut().enumerate() {
377            let pt_byte = plaintext.data.get(i / 8).unwrap_or(&0).to_le_bytes()[i % 8];
378            *byte = ct_seed[i % 32] ^ pt_byte;
379        }
380
381        Ok(Ciphertext {
382            data: ct_data,
383            params_hash: self.params_hash(),
384            scheme: self.params.scheme,
385            noise_budget: 100, // Simulated noise budget
386            scale: plaintext.scale,
387            is_computed: false,
388        })
389    }
390
391    /// Decrypt a ciphertext
392    pub fn decrypt(&mut self, ciphertext: &Ciphertext) -> Result<Plaintext, FheError> {
393        let keypair = self.keypair.as_ref().ok_or(FheError::KeyNotGenerated)?;
394
395        if ciphertext.noise_budget <= 0 {
396            self.stats.failures += 1;
397            return Err(FheError::NoiseExhausted);
398        }
399
400        self.stats.decryptions += 1;
401
402        // Simulate decryption
403        let mut hasher = Sha256::new();
404        hasher.update(b"DECRYPT");
405        hasher.update(&keypair.secret_key.data[0..32.min(keypair.secret_key.data.len())]);
406        hasher.update(&ciphertext.data[0..32.min(ciphertext.data.len())]);
407        let _decrypt_key = hasher.finalize();
408
409        // Extract original values (simplified simulation)
410        let num_values = ciphertext.data.len() / 16;
411        let mut data = Vec::with_capacity(num_values);
412
413        for i in 0..num_values.min(16) {
414            let start = i * 8;
415            if start + 8 <= ciphertext.data.len() {
416                let mut bytes = [0u8; 8];
417                bytes.copy_from_slice(&ciphertext.data[start..start + 8]);
418                data.push(i64::from_le_bytes(bytes) % 1000); // Simulated decryption
419            }
420        }
421
422        if data.is_empty() {
423            data.push(0);
424        }
425
426        Ok(Plaintext {
427            data,
428            scale: ciphertext.scale,
429        })
430    }
431
432    /// Homomorphic addition
433    pub fn add(&mut self, a: &Ciphertext, b: &Ciphertext) -> Result<Ciphertext, FheError> {
434        self.check_compatibility(a, b)?;
435
436        self.stats.additions += 1;
437
438        let min_noise = a.noise_budget.min(b.noise_budget);
439        if min_noise <= 0 {
440            self.stats.failures += 1;
441            return Err(FheError::NoiseExhausted);
442        }
443
444        // Simulate addition
445        let result_data: Vec<u8> = a
446            .data
447            .iter()
448            .zip(b.data.iter())
449            .map(|(&x, &y)| x.wrapping_add(y))
450            .collect();
451
452        Ok(Ciphertext {
453            data: result_data,
454            params_hash: a.params_hash,
455            scheme: a.scheme,
456            noise_budget: min_noise - 1, // Addition consumes minimal noise
457            scale: a.scale,
458            is_computed: true,
459        })
460    }
461
462    /// Homomorphic subtraction
463    pub fn sub(&mut self, a: &Ciphertext, b: &Ciphertext) -> Result<Ciphertext, FheError> {
464        self.check_compatibility(a, b)?;
465
466        self.stats.additions += 1; // Sub is similar to add
467
468        let min_noise = a.noise_budget.min(b.noise_budget);
469        if min_noise <= 0 {
470            self.stats.failures += 1;
471            return Err(FheError::NoiseExhausted);
472        }
473
474        // Simulate subtraction
475        let result_data: Vec<u8> = a
476            .data
477            .iter()
478            .zip(b.data.iter())
479            .map(|(&x, &y)| x.wrapping_sub(y))
480            .collect();
481
482        Ok(Ciphertext {
483            data: result_data,
484            params_hash: a.params_hash,
485            scheme: a.scheme,
486            noise_budget: min_noise - 1,
487            scale: a.scale,
488            is_computed: true,
489        })
490    }
491
492    /// Homomorphic multiplication
493    pub fn multiply(&mut self, a: &Ciphertext, b: &Ciphertext) -> Result<Ciphertext, FheError> {
494        self.check_compatibility(a, b)?;
495
496        let keypair = self.keypair.as_ref().ok_or(FheError::KeyNotGenerated)?;
497        if keypair.relin_keys.is_none() {
498            return Err(FheError::NoRelinKeys);
499        }
500
501        self.stats.multiplications += 1;
502
503        let min_noise = a.noise_budget.min(b.noise_budget);
504        if min_noise <= 10 {
505            self.stats.failures += 1;
506            return Err(FheError::NoiseExhausted);
507        }
508
509        // Simulate multiplication (consumes more noise)
510        let result_data: Vec<u8> = a
511            .data
512            .iter()
513            .zip(b.data.iter())
514            .map(|(&x, &y)| x.wrapping_mul(y))
515            .collect();
516
517        Ok(Ciphertext {
518            data: result_data,
519            params_hash: a.params_hash,
520            scheme: a.scheme,
521            noise_budget: min_noise - 10, // Multiplication consumes significant noise
522            scale: a.scale * b.scale,
523            is_computed: true,
524        })
525    }
526
527    /// Multiply ciphertext by plaintext
528    pub fn multiply_plain(
529        &mut self,
530        ct: &Ciphertext,
531        pt: &Plaintext,
532    ) -> Result<Ciphertext, FheError> {
533        if ct.noise_budget <= 5 {
534            self.stats.failures += 1;
535            return Err(FheError::NoiseExhausted);
536        }
537
538        self.stats.multiplications += 1;
539
540        // Simulate plain multiplication
541        let result_data: Vec<u8> = ct
542            .data
543            .iter()
544            .enumerate()
545            .map(|(i, &x)| {
546                let pt_val = *pt.data.get(i / 8).unwrap_or(&1) as u8;
547                x.wrapping_mul(pt_val)
548            })
549            .collect();
550
551        Ok(Ciphertext {
552            data: result_data,
553            params_hash: ct.params_hash,
554            scheme: ct.scheme,
555            noise_budget: ct.noise_budget - 5,
556            scale: ct.scale * pt.scale,
557            is_computed: true,
558        })
559    }
560
561    /// Add plaintext to ciphertext
562    pub fn add_plain(&mut self, ct: &Ciphertext, pt: &Plaintext) -> Result<Ciphertext, FheError> {
563        if ct.noise_budget <= 0 {
564            self.stats.failures += 1;
565            return Err(FheError::NoiseExhausted);
566        }
567
568        self.stats.additions += 1;
569
570        let result_data: Vec<u8> = ct
571            .data
572            .iter()
573            .enumerate()
574            .map(|(i, &x)| {
575                let pt_val = *pt.data.get(i / 8).unwrap_or(&0) as u8;
576                x.wrapping_add(pt_val)
577            })
578            .collect();
579
580        Ok(Ciphertext {
581            data: result_data,
582            params_hash: ct.params_hash,
583            scheme: ct.scheme,
584            noise_budget: ct.noise_budget,
585            scale: ct.scale,
586            is_computed: true,
587        })
588    }
589
590    /// Rotate ciphertext slots
591    pub fn rotate(&mut self, ct: &Ciphertext, steps: i32) -> Result<Ciphertext, FheError> {
592        let keypair = self.keypair.as_ref().ok_or(FheError::KeyNotGenerated)?;
593
594        let galois = keypair.galois_keys.as_ref().ok_or(FheError::NoGaloisKeys)?;
595
596        if !galois.steps.contains(&steps.abs()) {
597            return Err(FheError::UnsupportedRotation);
598        }
599
600        if ct.noise_budget <= 3 {
601            self.stats.failures += 1;
602            return Err(FheError::NoiseExhausted);
603        }
604
605        self.stats.rotations += 1;
606
607        // Simulate rotation
608        let n = ct.data.len();
609        let shift = (steps.unsigned_abs() as usize * 8) % n;
610        let result_data: Vec<u8> = (0..n)
611            .map(|i| {
612                let src_idx = if steps > 0 {
613                    (i + shift) % n
614                } else {
615                    (i + n - shift) % n
616                };
617                ct.data[src_idx]
618            })
619            .collect();
620
621        Ok(Ciphertext {
622            data: result_data,
623            params_hash: ct.params_hash,
624            scheme: ct.scheme,
625            noise_budget: ct.noise_budget - 3,
626            scale: ct.scale,
627            is_computed: true,
628        })
629    }
630
631    /// Rescale (for CKKS - reduce scale after multiplication)
632    pub fn rescale(&mut self, ct: &Ciphertext) -> Result<Ciphertext, FheError> {
633        if self.params.scheme != FheScheme::Ckks {
634            return Err(FheError::WrongScheme);
635        }
636
637        if ct.noise_budget <= 2 {
638            self.stats.failures += 1;
639            return Err(FheError::NoiseExhausted);
640        }
641
642        Ok(Ciphertext {
643            data: ct.data.clone(),
644            params_hash: ct.params_hash,
645            scheme: ct.scheme,
646            noise_budget: ct.noise_budget - 2,
647            scale: ct.scale / self.params.scale,
648            is_computed: true,
649        })
650    }
651
652    /// Bootstrap (refresh noise budget - expensive operation)
653    pub fn bootstrap(&mut self, ct: &Ciphertext) -> Result<Ciphertext, FheError> {
654        self.stats.bootstraps += 1;
655
656        // Bootstrapping refreshes noise budget but is computationally expensive
657        Ok(Ciphertext {
658            data: ct.data.clone(),
659            params_hash: ct.params_hash,
660            scheme: ct.scheme,
661            noise_budget: 80, // Restored but slightly less than fresh
662            scale: ct.scale,
663            is_computed: true,
664        })
665    }
666
667    /// Check if two ciphertexts are compatible
668    fn check_compatibility(&self, a: &Ciphertext, b: &Ciphertext) -> Result<(), FheError> {
669        if a.params_hash != b.params_hash {
670            return Err(FheError::IncompatibleParams);
671        }
672        if a.scheme != b.scheme {
673            return Err(FheError::WrongScheme);
674        }
675        if a.data.len() != b.data.len() {
676            return Err(FheError::SizeMismatch);
677        }
678        Ok(())
679    }
680
    /// Returns a shared reference to the engine's operation counters.
    pub fn get_stats(&self) -> &FheStats {
        &self.stats
    }
685
    /// Returns a shared reference to the parameter set the engine was created with.
    pub fn get_params(&self) -> &FheParams {
        &self.params
    }
690}
691
/// Encrypted Watchdog - compares encrypted inputs against encrypted rule thresholds.
#[derive(Debug)]
pub struct EncryptedWatchdog {
    /// FHE engine used for every encryption, subtraction and decryption.
    engine: FheEngine,
    /// Encrypted rules, keyed by rule name; each value is an encrypted threshold.
    encrypted_rules: HashMap<String, Ciphertext>,
    /// Statistics
    stats: EncryptedWatchdogStats,
}
702
/// Counters kept by the encrypted watchdog.
#[derive(Debug, Clone, Default)]
pub struct EncryptedWatchdogStats {
    /// Checks performed (encrypted comparisons requested)
    pub checks: u64,
    /// Violations detected; counted per violated rule when a result is decrypted.
    pub encrypted_violations: u64,
    /// Decrypted checks (requires key holder)
    pub decrypted_checks: u64,
}
713
714impl EncryptedWatchdog {
715    /// Create new encrypted watchdog
716    pub fn new(mut engine: FheEngine) -> Self {
717        engine.keygen();
718        Self {
719            engine,
720            encrypted_rules: HashMap::new(),
721            stats: EncryptedWatchdogStats::default(),
722        }
723    }
724
725    /// Add an encrypted rule
726    pub fn add_rule(&mut self, name: &str, threshold: i64) -> Result<(), FheError> {
727        let pt = self.engine.encode_integers(&[threshold]);
728        let ct = self.engine.encrypt(&pt)?;
729        self.encrypted_rules.insert(name.to_string(), ct);
730        Ok(())
731    }
732
733    /// Check encrypted input against encrypted rules
734    pub fn check_encrypted(
735        &mut self,
736        input: &Ciphertext,
737    ) -> Result<EncryptedCheckResult, FheError> {
738        self.stats.checks += 1;
739
740        let mut results = HashMap::new();
741
742        for (name, rule) in &self.encrypted_rules {
743            // Compute difference (encrypted comparison)
744            let diff = self.engine.sub(input, rule)?;
745            results.insert(name.clone(), diff);
746        }
747
748        // Generate check proof
749        let timestamp = SystemTime::now()
750            .duration_since(UNIX_EPOCH)
751            .unwrap_or_default()
752            .as_secs();
753
754        let mut hasher = Sha256::new();
755        hasher.update(b"ENCRYPTED_CHECK");
756        hasher.update(timestamp.to_le_bytes());
757        hasher.update(&input.data[0..32.min(input.data.len())]);
758        let proof_hash = hasher.finalize();
759        let mut proof = [0u8; 32];
760        proof.copy_from_slice(&proof_hash);
761
762        Ok(EncryptedCheckResult {
763            encrypted_comparisons: results,
764            check_proof: proof,
765            timestamp,
766            input_noise: input.noise_budget,
767        })
768    }
769
770    /// Decrypt check result (requires key holder)
771    pub fn decrypt_result(
772        &mut self,
773        result: &EncryptedCheckResult,
774    ) -> Result<DecryptedCheckResult, FheError> {
775        self.stats.decrypted_checks += 1;
776
777        let mut violations = Vec::new();
778
779        for (name, ct) in &result.encrypted_comparisons {
780            let pt = self.engine.decrypt(ct)?;
781
782            // Check if any value indicates violation (negative = below threshold)
783            let is_violation = pt.data.iter().any(|&v| v < 0);
784            if is_violation {
785                violations.push(name.clone());
786                self.stats.encrypted_violations += 1;
787            }
788        }
789
790        Ok(DecryptedCheckResult {
791            is_safe: violations.is_empty(),
792            violations,
793            proof: result.check_proof,
794        })
795    }
796
    /// Returns a shared reference to the watchdog's counters.
    pub fn get_stats(&self) -> &EncryptedWatchdogStats {
        &self.stats
    }
801}
802
/// Result of an encrypted check; the comparison outcomes are still encrypted.
#[derive(Debug)]
pub struct EncryptedCheckResult {
    /// Encrypted comparison results, keyed by rule name (each is `input - rule`).
    pub encrypted_comparisons: HashMap<String, Ciphertext>,
    /// Proof of check (a SHA-256 digest)
    pub check_proof: [u8; 32],
    /// Timestamp of the check, in seconds since the Unix epoch.
    pub timestamp: u64,
    /// Noise budget of the input ciphertext at the time of the check.
    pub input_noise: i32,
}
815
/// Decrypted check result.
#[derive(Debug)]
pub struct DecryptedCheckResult {
    /// Is input safe? `true` exactly when `violations` is empty.
    pub is_safe: bool,
    /// Names of the violated rules
    pub violations: Vec<String>,
    /// Check proof, copied from the encrypted result.
    pub proof: [u8; 32],
}
826
/// Threshold Decryption - coordinates multi-party decryption.
#[derive(Debug)]
pub struct ThresholdDecryption {
    /// Threshold (minimum number of partial decryptions needed to combine)
    pub threshold: usize,
    /// Total parties
    pub total_parties: usize,
    /// Party shares generated from the secret key.
    shares: Vec<ThresholdShare>,
    /// Partial decryptions collected so far, in arrival order.
    partial_decryptions: Vec<PartialDecryption>,
}
839
/// A party's share of the secret key.
#[derive(Debug, Clone)]
pub struct ThresholdShare {
    /// Party ID (zero-based index of the party)
    pub party_id: usize,
    /// Share data (kept private for security)
    #[allow(dead_code)]
    data: Vec<u8>,
    /// Share commitment: a SHA-256 digest binding the party ID to the share data.
    pub commitment: [u8; 32],
}
851
/// Partial decryption contributed by one party.
#[derive(Debug, Clone)]
pub struct PartialDecryption {
    /// Party ID; each party may contribute at most once per round.
    pub party_id: usize,
    /// Partial result
    pub data: Vec<u8>,
    /// Proof of correct decryption.
    // NOTE(review): nothing in this module verifies the proof - confirm where it is checked.
    pub proof: [u8; 64],
}
862
863impl ThresholdDecryption {
864    /// Create threshold decryption setup
865    pub fn new(threshold: usize, total_parties: usize) -> Result<Self, FheError> {
866        if threshold > total_parties {
867            return Err(FheError::InvalidThreshold);
868        }
869        if threshold < 2 {
870            return Err(FheError::InvalidThreshold);
871        }
872
873        Ok(Self {
874            threshold,
875            total_parties,
876            shares: Vec::new(),
877            partial_decryptions: Vec::new(),
878        })
879    }
880
881    /// Generate shares for all parties
882    pub fn generate_shares(&mut self, secret_key: &FheSecretKey) -> Vec<ThresholdShare> {
883        let chunk_size = secret_key.data.len() / self.total_parties.max(1);
884
885        for party_id in 0..self.total_parties {
886            let start = party_id * chunk_size;
887            let end = (start + chunk_size).min(secret_key.data.len());
888
889            let share_data: Vec<u8> = secret_key.data[start..end].to_vec();
890
891            let mut hasher = Sha256::new();
892            hasher.update(b"SHARE_COMMIT");
893            hasher.update(party_id.to_le_bytes());
894            hasher.update(&share_data);
895            let commit_hash = hasher.finalize();
896            let mut commitment = [0u8; 32];
897            commitment.copy_from_slice(&commit_hash);
898
899            self.shares.push(ThresholdShare {
900                party_id,
901                data: share_data,
902                commitment,
903            });
904        }
905
906        self.shares.clone()
907    }
908
909    /// Add partial decryption from a party
910    pub fn add_partial_decryption(&mut self, partial: PartialDecryption) -> Result<(), FheError> {
911        // Verify party hasn't already contributed
912        if self
913            .partial_decryptions
914            .iter()
915            .any(|p| p.party_id == partial.party_id)
916        {
917            return Err(FheError::DuplicateParty);
918        }
919
920        self.partial_decryptions.push(partial);
921        Ok(())
922    }
923
924    /// Check if we have enough shares to decrypt
925    pub fn can_decrypt(&self) -> bool {
926        self.partial_decryptions.len() >= self.threshold
927    }
928
929    /// Combine partial decryptions
930    pub fn combine(&self) -> Result<Vec<u8>, FheError> {
931        if !self.can_decrypt() {
932            return Err(FheError::InsufficientShares);
933        }
934
935        // Combine first 'threshold' partial decryptions
936        let mut result = Vec::new();
937        for partial in self.partial_decryptions.iter().take(self.threshold) {
938            result.extend(&partial.data);
939        }
940
941        Ok(result)
942    }
943
    /// Discards every collected partial decryption so a new round can start.
    ///
    /// The threshold, party count and generated shares are left untouched.
    pub fn reset(&mut self) {
        self.partial_decryptions.clear();
    }
948}
949
/// Errors reported by the FHE engine, the encrypted watchdog and threshold decryption.
#[derive(Debug, Clone, PartialEq)]
pub enum FheError {
    /// An operation needed keys, but none have been generated.
    KeyNotGenerated,
    /// The ciphertext's noise budget is too low for the requested operation.
    NoiseExhausted,
    /// The two ciphertexts were produced under different parameters.
    IncompatibleParams,
    /// The two ciphertexts differ in length.
    SizeMismatch,
    /// The operation does not apply to the scheme in use, or the schemes differ.
    WrongScheme,
    /// Relinearization keys are required but missing.
    NoRelinKeys,
    /// Galois keys are required but missing.
    NoGaloisKeys,
    /// The requested rotation step is not supported by the Galois keys.
    UnsupportedRotation,
    /// The threshold is below two or exceeds the number of parties.
    InvalidThreshold,
    /// The party has already contributed a partial decryption.
    DuplicateParty,
    /// Fewer partial decryptions than the threshold have been collected.
    InsufficientShares,
}

impl std::fmt::Display for FheError {
    /// Writes the fixed, human-readable message for the error variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::KeyNotGenerated => "FHE keys not generated",
            Self::NoiseExhausted => "Noise budget exhausted",
            Self::IncompatibleParams => "Incompatible FHE parameters",
            Self::SizeMismatch => "Ciphertext size mismatch",
            Self::WrongScheme => "Wrong FHE scheme",
            Self::NoRelinKeys => "Relinearization keys not available",
            Self::NoGaloisKeys => "Galois keys not available",
            Self::UnsupportedRotation => "Unsupported rotation step",
            Self::InvalidThreshold => "Invalid threshold parameters",
            Self::DuplicateParty => "Party already contributed",
            Self::InsufficientShares => "Insufficient shares for decryption",
        };
        f.write_str(message)
    }
}

impl std::error::Error for FheError {}
996
#[cfg(test)]
mod tests {
    //! Unit tests for the FHE engine, the encrypted watchdog, threshold
    //! decryption setup, the scheme metadata and the error type.

    use super::*;

    /// A CKKS engine reports the CKKS scheme in its parameters.
    #[test]
    fn test_engine_creation() {
        let engine = FheEngine::new_ckks();
        assert_eq!(engine.params.scheme, FheScheme::Ckks);
    }

    /// Key generation yields a non-empty public key plus relinearization and
    /// Galois keys.
    #[test]
    fn test_keygen() {
        let mut engine = FheEngine::new_ckks();
        let keypair = engine.keygen();

        assert!(!keypair.public_key.data.is_empty());
        assert!(keypair.relin_keys.is_some());
        assert!(keypair.galois_keys.is_some());
    }

    /// Encrypting and then decrypting succeeds and produces non-empty data.
    #[test]
    fn test_encrypt_decrypt() {
        let mut engine = FheEngine::new_bfv();
        engine.keygen();

        // A slice literal is enough here; no heap-allocated Vec is needed.
        let pt = engine.encode_integers(&[1, 2, 3, 4, 5]);
        let ct = engine.encrypt(&pt).unwrap();

        assert!(ct.noise_budget > 0);
        assert!(!ct.data.is_empty());

        let decrypted = engine.decrypt(&ct).unwrap();
        assert!(!decrypted.data.is_empty());
    }

    /// Adding two ciphertexts marks the result as computed, lowers the noise
    /// budget and bumps the addition counter.
    #[test]
    fn test_homomorphic_addition() {
        let mut engine = FheEngine::new_bfv();
        engine.keygen();

        let pt1 = engine.encode_integers(&[10, 20]);
        let pt2 = engine.encode_integers(&[5, 15]);

        let ct1 = engine.encrypt(&pt1).unwrap();
        let ct2 = engine.encrypt(&pt2).unwrap();

        let ct_sum = engine.add(&ct1, &ct2).unwrap();

        assert!(ct_sum.is_computed);
        assert!(ct_sum.noise_budget < ct1.noise_budget);
        assert_eq!(engine.stats.additions, 1);
    }

    /// Multiplying two ciphertexts marks the result as computed, lowers the
    /// noise budget and bumps the multiplication counter.
    #[test]
    fn test_homomorphic_multiplication() {
        let mut engine = FheEngine::new_bfv();
        engine.keygen();

        let pt1 = engine.encode_integers(&[2, 3]);
        let pt2 = engine.encode_integers(&[4, 5]);

        let ct1 = engine.encrypt(&pt1).unwrap();
        let ct2 = engine.encrypt(&pt2).unwrap();

        let ct_prod = engine.multiply(&ct1, &ct2).unwrap();

        assert!(ct_prod.is_computed);
        assert!(ct_prod.noise_budget < ct1.noise_budget);
        assert_eq!(engine.stats.multiplications, 1);
    }

    /// Rotating a ciphertext by one step succeeds and is counted in the stats.
    #[test]
    fn test_rotation() {
        let mut engine = FheEngine::new_ckks();
        engine.keygen();

        let pt = engine.encode_floats(&[1.0, 2.0, 3.0, 4.0]);
        let ct = engine.encrypt(&pt).unwrap();

        let rotated = engine.rotate(&ct, 1).unwrap();

        assert!(rotated.is_computed);
        assert_eq!(engine.stats.rotations, 1);
    }

    /// Bootstrapping a ciphertext whose budget was reduced by multiplications
    /// yields a strictly larger noise budget.
    #[test]
    fn test_bootstrap() {
        let mut engine = FheEngine::new_ckks();
        engine.keygen();

        let pt = engine.encode_floats(&[1.0]);
        let mut ct = engine.encrypt(&pt).unwrap();

        // Consume noise with multiplications (consumes more noise)
        let ct2 = engine.encrypt(&pt).unwrap();
        ct = engine.multiply(&ct, &ct2).unwrap(); // -10
        ct = engine.multiply(&ct, &ct2).unwrap(); // -10
        ct = engine.multiply(&ct, &ct2).unwrap(); // -10 = 70 remaining

        let original_noise = ct.noise_budget;
        let refreshed = engine.bootstrap(&ct).unwrap();

        assert!(refreshed.noise_budget > original_noise);
        assert_eq!(engine.stats.bootstraps, 1);
    }

    /// The watchdog accepts rules and produces a timestamped result with at
    /// least one encrypted comparison for an encrypted input.
    #[test]
    fn test_encrypted_watchdog() {
        let engine = FheEngine::new_bfv();
        let mut watchdog = EncryptedWatchdog::new(engine);

        // Add rules
        watchdog.add_rule("max_tokens", 1000).unwrap();
        watchdog.add_rule("min_safety", 50).unwrap();

        // Create encrypted input
        let input_pt = watchdog.engine.encode_integers(&[500]);
        let input_ct = watchdog.engine.encrypt(&input_pt).unwrap();

        // Check (encrypted)
        let result = watchdog.check_encrypted(&input_ct).unwrap();

        assert!(!result.encrypted_comparisons.is_empty());
        assert!(result.timestamp > 0);
    }

    /// A fresh 3-of-5 setup stores its parameters and cannot decrypt yet.
    #[test]
    fn test_threshold_decryption_setup() {
        let threshold = ThresholdDecryption::new(3, 5).unwrap();

        assert_eq!(threshold.threshold, 3);
        assert_eq!(threshold.total_parties, 5);
        assert!(!threshold.can_decrypt());
    }

    /// Constructing with 6-of-5 or 1-of-5 is rejected.
    #[test]
    fn test_threshold_invalid() {
        let result = ThresholdDecryption::new(6, 5);
        assert!(result.is_err());

        let result = ThresholdDecryption::new(1, 5);
        assert!(result.is_err());
    }

    /// Share generation returns one non-empty share per party, with party ids
    /// matching their position.
    #[test]
    fn test_threshold_generate_shares() {
        let mut engine = FheEngine::new_bfv();
        let keypair = engine.keygen();

        let mut threshold = ThresholdDecryption::new(2, 3).unwrap();
        let shares = threshold.generate_shares(&keypair.secret_key);

        assert_eq!(shares.len(), 3);
        for (i, share) in shares.iter().enumerate() {
            assert_eq!(share.party_id, i);
            assert!(!share.data.is_empty());
        }
    }

    /// Repeated squaring must eventually fail with `NoiseExhausted`, and that
    /// failure must be recorded in the engine stats.
    #[test]
    fn test_noise_exhaustion() {
        let mut engine = FheEngine::new_bfv();
        engine.keygen();

        let pt = engine.encode_integers(&[1]);
        let mut ct = engine.encrypt(&pt).unwrap();

        // Keep multiplying until noise exhausted
        for _ in 0..20 {
            match engine.multiply(&ct, &ct) {
                Ok(new_ct) => ct = new_ct,
                Err(FheError::NoiseExhausted) => {
                    assert!(engine.stats.failures > 0);
                    return;
                }
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        }

        // Falling out of the loop means exhaustion was never observed; fail
        // loudly instead of letting the test pass without checking anything.
        panic!("noise budget was not exhausted after 20 multiplications");
    }

    /// Every scheme reports its name and a 128-bit security level.
    #[test]
    fn test_scheme_security() {
        let expected_names = [
            (FheScheme::Bfv, "BFV"),
            (FheScheme::Ckks, "CKKS"),
            (FheScheme::Bgv, "BGV"),
            (FheScheme::Tfhe, "TFHE"),
        ];

        for (scheme, name) in expected_names.iter() {
            assert_eq!(scheme.name(), *name);
            assert_eq!(scheme.security_bits(), 128);
        }
    }

    /// Plaintext multiplication and addition against a ciphertext both
    /// succeed and mark their results as computed.
    #[test]
    fn test_plain_operations() {
        let mut engine = FheEngine::new_bfv();
        engine.keygen();

        let ct_pt = engine.encode_integers(&[10]);
        let ct = engine.encrypt(&ct_pt).unwrap();

        let plain = engine.encode_integers(&[2]);

        let result = engine.multiply_plain(&ct, &plain).unwrap();
        assert!(result.is_computed);

        let result = engine.add_plain(&ct, &plain).unwrap();
        assert!(result.is_computed);
    }

    /// Each error variant renders its dedicated message and can be used as a
    /// `std::error::Error` trait object.
    #[test]
    fn test_error_display() {
        let cases = [
            (FheError::KeyNotGenerated, "FHE keys not generated"),
            (FheError::NoiseExhausted, "Noise budget exhausted"),
            (FheError::IncompatibleParams, "Incompatible FHE parameters"),
            (FheError::SizeMismatch, "Ciphertext size mismatch"),
            (FheError::WrongScheme, "Wrong FHE scheme"),
            (FheError::NoRelinKeys, "Relinearization keys not available"),
            (FheError::NoGaloisKeys, "Galois keys not available"),
            (FheError::UnsupportedRotation, "Unsupported rotation step"),
            (FheError::InvalidThreshold, "Invalid threshold parameters"),
            (FheError::DuplicateParty, "Party already contributed"),
            (FheError::InsufficientShares, "Insufficient shares for decryption"),
        ];

        for (error, expected) in cases.iter() {
            assert_eq!(error.to_string(), *expected);
        }

        let boxed: Box<dyn std::error::Error> = Box::new(FheError::WrongScheme);
        assert_eq!(boxed.to_string(), "Wrong FHE scheme");
    }
}