amaters_core/compute/
mod.rs

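//! Compute engine module (Yata - The Eight-Span Mirror)
//!
//! This module provides FHE circuit execution on encrypted data using TFHE.
//!
//! # Architecture
//!
//! The compute engine is organized into the following components:
//!
//! - **Key Management** (`keys`, `key_manager`): Client and server key generation and
//!   serialization (`FheKeyPair`, `KeyStorage`), plus `KeyManager` / `ClientId`
//! - **FHE Operations** (`operations`): Encrypted boolean and integer operations
//! - **Circuit Compilation** (`circuit`): AST representation and type inference
//! - **Circuit Optimization** (`optimizer`): Rewrites circuits before execution and exposes
//!   a dependency graph and optimization statistics
//! - **Predicates** (`predicate`): `PredicateCompiler` and `compile_predicate`
//! - **FHE Executor** (`FheExecutor`): Circuit execution engine
//!
//! # Example
//!
//! ```rust,ignore
//! use std::collections::HashMap;
//! use amaters_core::compute::{
//!     CircuitBuilder, EncryptedType, EncryptedU8, FheExecutor, FheKeyPair,
//! };
//!
//! // Generate keys
//! let keypair = FheKeyPair::generate()?;
//! keypair.set_as_global_server_key();
//!
//! // Build a circuit: a + b
//! let mut builder = CircuitBuilder::new();
//! builder.declare_variable("a", EncryptedType::U8)
//!        .declare_variable("b", EncryptedType::U8);
//!
//! let a = builder.load("a");
//! let b = builder.load("b");
//! let sum = builder.add(a, b);
//! let circuit = builder.build(sum)?;
//!
//! // Encrypt the inputs, keyed by variable name
//! let mut inputs = HashMap::new();
//! inputs.insert(
//!     "a".to_string(),
//!     EncryptedU8::encrypt(5, keypair.client_key()).to_cipher_blob()?,
//! );
//! inputs.insert(
//!     "b".to_string(),
//!     EncryptedU8::encrypt(3, keypair.client_key()).to_cipher_blob()?,
//! );
//!
//! // Execute the circuit and decrypt the result
//! let executor = FheExecutor::new();
//! let result = executor.execute(&circuit, &inputs)?;
//! let value = EncryptedU8::from_cipher_blob(&result)?;
//! assert_eq!(value.decrypt(keypair.client_key()), 8);
//! ```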
37
38pub mod circuit;
39pub mod key_manager;
40pub mod keys;
41pub mod operations;
42pub mod optimizer;
43pub mod predicate;
44
45#[cfg(test)]
46mod filter_tests;
47
48// Re-export commonly used types
49pub use circuit::{
50    BinaryOperator, Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator,
51    EncryptedType, UnaryOperator,
52};
53pub use key_manager::{ClientId, KeyManager};
54pub use keys::{FheKeyPair, InMemoryKeyStorage, KeyStorage};
55pub use operations::{EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64};
56pub use optimizer::{CircuitOptimizer, DependencyGraph, NodeId, OptimizationStats};
57pub use predicate::{PredicateCompiler, compile_predicate};
58
59use crate::error::{AmateRSError, ErrorContext, Result};
60use crate::types::CipherBlob;
61use std::collections::HashMap;
62
63/// FHE executor for circuit execution
64///
65/// This executor takes a compiled circuit and encrypted inputs,
66/// executes the circuit on the encrypted data, and returns encrypted results.
67#[derive(Debug, Clone)]
68pub struct FheExecutor {
69    optimizer: CircuitOptimizer,
70    optimization_enabled: bool,
71}
72
73impl FheExecutor {
74    /// Create a new FHE executor with optimization enabled
75    pub fn new() -> Self {
76        Self {
77            optimizer: CircuitOptimizer::new(),
78            optimization_enabled: true,
79        }
80    }
81
82    /// Create a new FHE executor with optimization control
83    pub fn with_optimization(enable: bool) -> Self {
84        Self {
85            optimizer: if enable {
86                CircuitOptimizer::new()
87            } else {
88                CircuitOptimizer::disabled()
89            },
90            optimization_enabled: enable,
91        }
92    }
93
94    /// Get the optimization statistics from the last execution
95    pub fn optimization_stats(&self) -> &OptimizationStats {
96        self.optimizer.stats()
97    }
98
99    /// Get the dependency graph from the last execution
100    pub fn dependency_graph(&self) -> &DependencyGraph {
101        self.optimizer.dependency_graph()
102    }
103
104    /// Execute FHE circuit on encrypted data
105    ///
106    /// # Arguments
107    ///
108    /// * `circuit` - The compiled circuit to execute
109    /// * `inputs` - Map of variable names to encrypted values (CipherBlob)
110    ///
111    /// # Returns
112    ///
113    /// The encrypted result as a CipherBlob
114    #[cfg(feature = "compute")]
115    pub fn execute(
116        &self,
117        circuit: &Circuit,
118        inputs: &HashMap<String, CipherBlob>,
119    ) -> Result<CipherBlob> {
120        // Validate circuit
121        circuit.validate()?;
122
123        // Check that all required inputs are provided
124        for var_name in circuit.variable_types.keys() {
125            if !inputs.contains_key(var_name) {
126                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
127                    "Missing input for variable: {}",
128                    var_name
129                ))));
130            }
131        }
132
133        // Optimize circuit if enabled
134        let optimized = if self.optimization_enabled {
135            // Need mutable reference for optimizer
136            let mut optimizer = self.optimizer.clone();
137            optimizer.optimize(circuit.clone())?
138        } else {
139            circuit.clone()
140        };
141
142        // Execute the circuit
143        let result_value = self.execute_node(&optimized.root, inputs, &optimized.variable_types)?;
144
145        // Serialize result to CipherBlob
146        match result_value {
147            EncryptedValue::Bool(v) => v.to_cipher_blob(),
148            EncryptedValue::U8(v) => v.to_cipher_blob(),
149            EncryptedValue::U16(v) => v.to_cipher_blob(),
150            EncryptedValue::U32(v) => v.to_cipher_blob(),
151            EncryptedValue::U64(v) => v.to_cipher_blob(),
152        }
153    }
154
155    /// Stub implementation when compute feature is disabled
156    #[cfg(not(feature = "compute"))]
157    pub fn execute(
158        &self,
159        _circuit: &Circuit,
160        _inputs: &HashMap<String, CipherBlob>,
161    ) -> Result<CipherBlob> {
162        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
163            "FHE compute feature is not enabled".to_string(),
164        )))
165    }
166
167    /// Execute a single circuit node recursively
168    #[cfg(feature = "compute")]
169    #[allow(clippy::only_used_in_recursion)]
170    fn execute_node(
171        &self,
172        node: &CircuitNode,
173        inputs: &HashMap<String, CipherBlob>,
174        variable_types: &HashMap<String, EncryptedType>,
175    ) -> Result<EncryptedValue> {
176        match node {
177            CircuitNode::Load(name) => {
178                let blob = inputs.get(name).ok_or_else(|| {
179                    AmateRSError::FheComputation(ErrorContext::new(format!(
180                        "Missing input: {}",
181                        name
182                    )))
183                })?;
184
185                let var_type = variable_types.get(name).ok_or_else(|| {
186                    AmateRSError::FheComputation(ErrorContext::new(format!(
187                        "Unknown variable type: {}",
188                        name
189                    )))
190                })?;
191
192                match var_type {
193                    EncryptedType::Bool => {
194                        Ok(EncryptedValue::Bool(EncryptedBool::from_cipher_blob(blob)?))
195                    }
196                    EncryptedType::U8 => {
197                        Ok(EncryptedValue::U8(EncryptedU8::from_cipher_blob(blob)?))
198                    }
199                    EncryptedType::U16 => {
200                        Ok(EncryptedValue::U16(EncryptedU16::from_cipher_blob(blob)?))
201                    }
202                    EncryptedType::U32 => {
203                        Ok(EncryptedValue::U32(EncryptedU32::from_cipher_blob(blob)?))
204                    }
205                    EncryptedType::U64 => {
206                        Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(blob)?))
207                    }
208                }
209            }
210
211            CircuitNode::Constant(_value) => {
212                // TODO: Support encrypted constants
213                Err(AmateRSError::FheComputation(ErrorContext::new(
214                    "Encrypted constants not yet supported".to_string(),
215                )))
216            }
217
218            CircuitNode::BinaryOp { op, left, right } => {
219                let left_val = self.execute_node(left, inputs, variable_types)?;
220                let right_val = self.execute_node(right, inputs, variable_types)?;
221
222                match (op, left_val, right_val) {
223                    // Boolean operations
224                    (BinaryOperator::And, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
225                        Ok(EncryptedValue::Bool(l.and(&r)))
226                    }
227                    (BinaryOperator::Or, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
228                        Ok(EncryptedValue::Bool(l.or(&r)))
229                    }
230                    (BinaryOperator::Xor, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
231                        Ok(EncryptedValue::Bool(l.xor(&r)))
232                    }
233
234                    // U8 arithmetic
235                    (BinaryOperator::Add, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
236                        Ok(EncryptedValue::U8(l.add(&r)))
237                    }
238                    (BinaryOperator::Sub, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
239                        Ok(EncryptedValue::U8(l.sub(&r)))
240                    }
241                    (BinaryOperator::Mul, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
242                        Ok(EncryptedValue::U8(l.mul(&r)))
243                    }
244
245                    // U16 arithmetic
246                    (BinaryOperator::Add, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
247                        Ok(EncryptedValue::U16(l.add(&r)))
248                    }
249                    (BinaryOperator::Sub, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
250                        Ok(EncryptedValue::U16(l.sub(&r)))
251                    }
252                    (BinaryOperator::Mul, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
253                        Ok(EncryptedValue::U16(l.mul(&r)))
254                    }
255
256                    // U32 arithmetic
257                    (BinaryOperator::Add, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
258                        Ok(EncryptedValue::U32(l.add(&r)))
259                    }
260                    (BinaryOperator::Sub, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
261                        Ok(EncryptedValue::U32(l.sub(&r)))
262                    }
263                    (BinaryOperator::Mul, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
264                        Ok(EncryptedValue::U32(l.mul(&r)))
265                    }
266
267                    // U64 arithmetic
268                    (BinaryOperator::Add, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
269                        Ok(EncryptedValue::U64(l.add(&r)))
270                    }
271                    (BinaryOperator::Sub, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
272                        Ok(EncryptedValue::U64(l.sub(&r)))
273                    }
274                    (BinaryOperator::Mul, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
275                        Ok(EncryptedValue::U64(l.mul(&r)))
276                    }
277
278                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
279                        "Type mismatch in binary operation".to_string(),
280                    ))),
281                }
282            }
283
284            CircuitNode::UnaryOp { op, operand } => {
285                let operand_val = self.execute_node(operand, inputs, variable_types)?;
286
287                match (op, operand_val) {
288                    (UnaryOperator::Not, EncryptedValue::Bool(v)) => {
289                        Ok(EncryptedValue::Bool(v.not()))
290                    }
291
292                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
293                        "Type mismatch in unary operation".to_string(),
294                    ))),
295                }
296            }
297
298            CircuitNode::Compare { op, left, right } => {
299                let left_val = self.execute_node(left, inputs, variable_types)?;
300                let right_val = self.execute_node(right, inputs, variable_types)?;
301
302                match (left_val, right_val) {
303                    (EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
304                        let result = match op {
305                            CompareOperator::Eq => l.eq(&r),
306                            CompareOperator::Ne => l.ne(&r),
307                            CompareOperator::Lt => l.lt(&r),
308                            CompareOperator::Le => l.le(&r),
309                            CompareOperator::Gt => l.gt(&r),
310                            CompareOperator::Ge => l.ge(&r),
311                        };
312                        Ok(EncryptedValue::Bool(result))
313                    }
314
315                    (EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
316                        let result = match op {
317                            CompareOperator::Eq => l.eq(&r),
318                            CompareOperator::Ne => l.ne(&r),
319                            CompareOperator::Lt => l.lt(&r),
320                            CompareOperator::Le => l.le(&r),
321                            CompareOperator::Gt => l.gt(&r),
322                            CompareOperator::Ge => l.ge(&r),
323                        };
324                        Ok(EncryptedValue::Bool(result))
325                    }
326
327                    (EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
328                        let result = match op {
329                            CompareOperator::Eq => l.eq(&r),
330                            CompareOperator::Ne => l.ne(&r),
331                            CompareOperator::Lt => l.lt(&r),
332                            CompareOperator::Le => l.le(&r),
333                            CompareOperator::Gt => l.gt(&r),
334                            CompareOperator::Ge => l.ge(&r),
335                        };
336                        Ok(EncryptedValue::Bool(result))
337                    }
338
339                    (EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
340                        let result = match op {
341                            CompareOperator::Eq => l.eq(&r),
342                            CompareOperator::Ne => l.ne(&r),
343                            CompareOperator::Lt => l.lt(&r),
344                            CompareOperator::Le => l.le(&r),
345                            CompareOperator::Gt => l.gt(&r),
346                            CompareOperator::Ge => l.ge(&r),
347                        };
348                        Ok(EncryptedValue::Bool(result))
349                    }
350
351                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
352                        "Type mismatch in comparison".to_string(),
353                    ))),
354                }
355            }
356        }
357    }
358}
359
360impl Default for FheExecutor {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
/// An intermediate ciphertext passed between circuit nodes while
/// `FheExecutor` evaluates a circuit, tagged with the `EncryptedType` it
/// was decoded as.
#[cfg(feature = "compute")]
enum EncryptedValue {
    /// Ciphertext of an `EncryptedType::Bool` value.
    Bool(EncryptedBool),
    /// Ciphertext of an `EncryptedType::U8` value.
    U8(EncryptedU8),
    /// Ciphertext of an `EncryptedType::U16` value.
    U16(EncryptedU16),
    /// Ciphertext of an `EncryptedType::U32` value.
    U32(EncryptedU32),
    /// Ciphertext of an `EncryptedType::U64` value.
    U64(EncryptedU64),
}
375
376// Legacy types for backward compatibility (to be removed in future versions)
377
/// Legacy circuit gate kinds.
///
/// Superseded by [`CircuitNode`]; kept only for backward compatibility with
/// existing callers.
#[deprecated(since = "0.1.0", note = "Use CircuitNode instead")]
#[derive(Clone, Debug)]
pub enum Gate {
    Add,
    Mul,
    Not,
    Bootstrap,
}
387
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    /// `a + b` on two encrypted `U8` values decrypts to the plaintext sum.
    #[test]
    fn test_fhe_executor_basic() -> Result<()> {
        // Generate keys
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        // Build circuit: a + b
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);

        let circuit = builder.build(sum_node)?;

        // Prepare inputs
        let a = EncryptedU8::encrypt(5, keypair.client_key());
        let b = EncryptedU8::encrypt(3, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        // Execute
        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        // Verify
        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 8);

        Ok(())
    }

    /// The same `a + b` circuit evaluates correctly with optimization turned off.
    #[test]
    fn test_fhe_executor_without_optimization() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);

        let circuit = builder.build(sum_node)?;

        let a = EncryptedU8::encrypt(5, keypair.client_key());
        let b = EncryptedU8::encrypt(3, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::with_optimization(false);
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 8);

        Ok(())
    }

    /// Boolean AND of encrypted `true` and `false` decrypts to `false`.
    #[test]
    fn test_fhe_executor_boolean() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x_node = builder.load("x");
        let y_node = builder.load("y");
        let and_node = builder.and(x_node, y_node);

        let circuit = builder.build(and_node)?;

        let x = EncryptedBool::encrypt(true, keypair.client_key());
        let y = EncryptedBool::encrypt(false, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), x.to_cipher_blob()?);
        inputs.insert("y".to_string(), y.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(!result.decrypt(keypair.client_key()));

        Ok(())
    }

    /// `10 > 5` on encrypted `U8` values decrypts to `true`.
    #[test]
    fn test_fhe_executor_comparison() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let gt_node = builder.gt(a_node, b_node);

        let circuit = builder.build(gt_node)?;

        let a = EncryptedU8::encrypt(10, keypair.client_key());
        let b = EncryptedU8::encrypt(5, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(result.decrypt(keypair.client_key()));

        Ok(())
    }

    /// Executing without the declared inputs fails with an `FheComputation` error.
    #[test]
    fn test_missing_input_error() -> Result<()> {
        // `execute` checks for missing inputs before any homomorphic work, so
        // no key generation is needed. The same circuit as the basic test is
        // used so that validation is known to pass.
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);
        let circuit = builder.build(sum_node)?;

        let inputs = HashMap::new(); // No inputs provided

        let executor = FheExecutor::new();
        let result = executor.execute(&circuit, &inputs);

        assert!(matches!(result, Err(AmateRSError::FheComputation(_))));

        Ok(())
    }
}