Skip to main content

amaters_core/compute/
mod.rs

//! Compute engine module (Yata - The Eight-Span Mirror)
//!
//! This module provides FHE circuit execution on encrypted data using TFHE.
//!
//! # Architecture
//!
//! The compute engine is built from these main components:
//!
//! - **Key Management** (`keys`): Client and server key generation, serialization
//! - **FHE Operations** (`operations`): Encrypted boolean and integer operations
//! - **Circuit Compilation** (`circuit`): AST representation and type inference
//! - **Circuit Optimization** (`optimizer`): Circuit optimization applied before execution (when enabled)
//! - **FHE Executor** (`FheExecutor`): Circuit execution engine
//!
//! # Example
//!
//! ```rust,ignore
//! use amaters_core::compute::{
//!     CircuitBuilder, EncryptedType, EncryptedU8, FheExecutor, FheKeyPair,
//! };
//! use std::collections::HashMap;
//!
//! // Generate keys
//! let keypair = FheKeyPair::generate()?;
//! keypair.set_as_global_server_key();
//!
//! // Build a circuit: a + b
//! let mut builder = CircuitBuilder::new();
//! builder.declare_variable("a", EncryptedType::U8)
//!        .declare_variable("b", EncryptedType::U8);
//!
//! let a = builder.load("a");
//! let b = builder.load("b");
//! let sum = builder.add(a, b);
//! let circuit = builder.build(sum)?;
//!
//! // Encrypt the inputs with the client key
//! let mut inputs = HashMap::new();
//! inputs.insert("a".to_string(), EncryptedU8::encrypt(5, keypair.client_key()).to_cipher_blob()?);
//! inputs.insert("b".to_string(), EncryptedU8::encrypt(3, keypair.client_key()).to_cipher_blob()?);
//!
//! // Execute the circuit and decrypt the result
//! let executor = FheExecutor::new();
//! let result = executor.execute(&circuit, &inputs)?;
//! let value = EncryptedU8::from_cipher_blob(&result)?.decrypt(keypair.client_key());
//! assert_eq!(value, 8);
//! ```
37
38pub mod circuit;
39pub mod gpu;
40pub mod key_manager;
41pub mod keys;
42pub mod operations;
43pub mod optimizer;
44pub mod plan_cache;
45pub mod planner;
46pub mod predicate;
47
48#[cfg(test)]
49mod filter_tests;
50
51// Re-export commonly used types
52pub use circuit::{
53    BinaryOperator, Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator,
54    ConstantType, EncryptedType, UnaryOperator, count_encrypted_constants,
55    count_plaintext_constants, decrypt_constant, encrypt_circuit_constants, encrypt_constant,
56    is_encrypted_constant,
57};
58pub use key_manager::{ClientId, KeyManager};
59pub use keys::{FheKeyPair, InMemoryKeyStorage, KeyStorage};
60pub use operations::{EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64};
61pub use optimizer::{CircuitOptimizer, DependencyGraph, NodeId, OptimizationStats};
62pub use planner::{LogicalPlan, PhysicalPlan, PlanCost, PlannerStats, QueryPlanner};
63pub use predicate::{PredicateCompiler, compile_predicate};
64
65use crate::error::{AmateRSError, ErrorContext, Result};
66use crate::types::CipherBlob;
67use std::collections::HashMap;
68
69/// FHE executor for circuit execution
70///
71/// This executor takes a compiled circuit and encrypted inputs,
72/// executes the circuit on the encrypted data, and returns encrypted results.
73#[derive(Debug, Clone)]
74pub struct FheExecutor {
75    optimizer: CircuitOptimizer,
76    optimization_enabled: bool,
77}
78
79impl FheExecutor {
80    /// Create a new FHE executor with optimization enabled
81    pub fn new() -> Self {
82        Self {
83            optimizer: CircuitOptimizer::new(),
84            optimization_enabled: true,
85        }
86    }
87
88    /// Create a new FHE executor with optimization control
89    pub fn with_optimization(enable: bool) -> Self {
90        Self {
91            optimizer: if enable {
92                CircuitOptimizer::new()
93            } else {
94                CircuitOptimizer::disabled()
95            },
96            optimization_enabled: enable,
97        }
98    }
99
100    /// Get the optimization statistics from the last execution
101    pub fn optimization_stats(&self) -> &OptimizationStats {
102        self.optimizer.stats()
103    }
104
105    /// Get the dependency graph from the last execution
106    pub fn dependency_graph(&self) -> &DependencyGraph {
107        self.optimizer.dependency_graph()
108    }
109
110    /// Execute FHE circuit on encrypted data
111    ///
112    /// # Arguments
113    ///
114    /// * `circuit` - The compiled circuit to execute
115    /// * `inputs` - Map of variable names to encrypted values (CipherBlob)
116    ///
117    /// # Returns
118    ///
119    /// The encrypted result as a CipherBlob
120    #[cfg(feature = "compute")]
121    pub fn execute(
122        &self,
123        circuit: &Circuit,
124        inputs: &HashMap<String, CipherBlob>,
125    ) -> Result<CipherBlob> {
126        // Validate circuit
127        circuit.validate()?;
128
129        // Check that all required inputs are provided
130        for var_name in circuit.variable_types.keys() {
131            if !inputs.contains_key(var_name) {
132                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
133                    "Missing input for variable: {}",
134                    var_name
135                ))));
136            }
137        }
138
139        // Optimize circuit if enabled
140        let optimized = if self.optimization_enabled {
141            // Need mutable reference for optimizer
142            let mut optimizer = self.optimizer.clone();
143            optimizer.optimize(circuit.clone())?
144        } else {
145            circuit.clone()
146        };
147
148        // Execute the circuit
149        let result_value = self.execute_node(&optimized.root, inputs, &optimized.variable_types)?;
150
151        // Serialize result to CipherBlob
152        match result_value {
153            EncryptedValue::Bool(v) => v.to_cipher_blob(),
154            EncryptedValue::U8(v) => v.to_cipher_blob(),
155            EncryptedValue::U16(v) => v.to_cipher_blob(),
156            EncryptedValue::U32(v) => v.to_cipher_blob(),
157            EncryptedValue::U64(v) => v.to_cipher_blob(),
158        }
159    }
160
161    /// Stub implementation when compute feature is disabled
162    #[cfg(not(feature = "compute"))]
163    pub fn execute(
164        &self,
165        _circuit: &Circuit,
166        _inputs: &HashMap<String, CipherBlob>,
167    ) -> Result<CipherBlob> {
168        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
169            "FHE compute feature is not enabled".to_string(),
170        )))
171    }
172
173    /// Execute a single circuit node recursively
174    #[cfg(feature = "compute")]
175    #[allow(clippy::only_used_in_recursion)]
176    fn execute_node(
177        &self,
178        node: &CircuitNode,
179        inputs: &HashMap<String, CipherBlob>,
180        variable_types: &HashMap<String, EncryptedType>,
181    ) -> Result<EncryptedValue> {
182        match node {
183            CircuitNode::Load(name) => {
184                let blob = inputs.get(name).ok_or_else(|| {
185                    AmateRSError::FheComputation(ErrorContext::new(format!(
186                        "Missing input: {}",
187                        name
188                    )))
189                })?;
190
191                let var_type = variable_types.get(name).ok_or_else(|| {
192                    AmateRSError::FheComputation(ErrorContext::new(format!(
193                        "Unknown variable type: {}",
194                        name
195                    )))
196                })?;
197
198                match var_type {
199                    EncryptedType::Bool => {
200                        Ok(EncryptedValue::Bool(EncryptedBool::from_cipher_blob(blob)?))
201                    }
202                    EncryptedType::U8 => {
203                        Ok(EncryptedValue::U8(EncryptedU8::from_cipher_blob(blob)?))
204                    }
205                    EncryptedType::U16 => {
206                        Ok(EncryptedValue::U16(EncryptedU16::from_cipher_blob(blob)?))
207                    }
208                    EncryptedType::U32 => {
209                        Ok(EncryptedValue::U32(EncryptedU32::from_cipher_blob(blob)?))
210                    }
211                    EncryptedType::U64 => {
212                        Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(blob)?))
213                    }
214                }
215            }
216
217            CircuitNode::Constant(_value) => {
218                // Plaintext constants in FHE context are not directly supported.
219                // Use encrypt_circuit_constants() to pre-process the circuit before
220                // execution, converting all Constant nodes to EncryptedConstant.
221                Err(AmateRSError::FheComputation(ErrorContext::new(
222                    "Plaintext constants cannot be used in FHE execution. \
223                     Use encrypt_circuit_constants() to encrypt constants before evaluation."
224                        .to_string(),
225                )))
226            }
227
228            CircuitNode::EncryptedConstant {
229                data,
230                original_type,
231            } => {
232                // Encrypted constants are already in ciphertext form.
233                // Deserialize the CipherBlob from the encrypted data and
234                // convert to the appropriate EncryptedValue based on original_type.
235                let blob = CipherBlob::new(data.clone());
236                match original_type {
237                    ConstantType::Boolean => Ok(EncryptedValue::Bool(
238                        EncryptedBool::from_cipher_blob(&blob)?,
239                    )),
240                    ConstantType::Integer => {
241                        // Try to deserialize as the most common integer type (U64)
242                        // In practice, the caller should ensure the encrypted data
243                        // matches the expected type from the circuit context.
244                        Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(&blob)?))
245                    }
246                    ConstantType::Float | ConstantType::Bytes => {
247                        Err(AmateRSError::FheComputation(ErrorContext::new(format!(
248                            "EncryptedConstant of type {} is not directly evaluable in FHE circuits",
249                            original_type
250                        ))))
251                    }
252                }
253            }
254
255            CircuitNode::BinaryOp { op, left, right } => {
256                let left_val = self.execute_node(left, inputs, variable_types)?;
257                let right_val = self.execute_node(right, inputs, variable_types)?;
258
259                match (op, left_val, right_val) {
260                    // Boolean operations
261                    (BinaryOperator::And, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
262                        Ok(EncryptedValue::Bool(l.and(&r)))
263                    }
264                    (BinaryOperator::Or, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
265                        Ok(EncryptedValue::Bool(l.or(&r)))
266                    }
267                    (BinaryOperator::Xor, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
268                        Ok(EncryptedValue::Bool(l.xor(&r)))
269                    }
270
271                    // U8 arithmetic
272                    (BinaryOperator::Add, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
273                        Ok(EncryptedValue::U8(l.add(&r)))
274                    }
275                    (BinaryOperator::Sub, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
276                        Ok(EncryptedValue::U8(l.sub(&r)))
277                    }
278                    (BinaryOperator::Mul, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
279                        Ok(EncryptedValue::U8(l.mul(&r)))
280                    }
281
282                    // U16 arithmetic
283                    (BinaryOperator::Add, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
284                        Ok(EncryptedValue::U16(l.add(&r)))
285                    }
286                    (BinaryOperator::Sub, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
287                        Ok(EncryptedValue::U16(l.sub(&r)))
288                    }
289                    (BinaryOperator::Mul, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
290                        Ok(EncryptedValue::U16(l.mul(&r)))
291                    }
292
293                    // U32 arithmetic
294                    (BinaryOperator::Add, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
295                        Ok(EncryptedValue::U32(l.add(&r)))
296                    }
297                    (BinaryOperator::Sub, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
298                        Ok(EncryptedValue::U32(l.sub(&r)))
299                    }
300                    (BinaryOperator::Mul, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
301                        Ok(EncryptedValue::U32(l.mul(&r)))
302                    }
303
304                    // U64 arithmetic
305                    (BinaryOperator::Add, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
306                        Ok(EncryptedValue::U64(l.add(&r)))
307                    }
308                    (BinaryOperator::Sub, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
309                        Ok(EncryptedValue::U64(l.sub(&r)))
310                    }
311                    (BinaryOperator::Mul, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
312                        Ok(EncryptedValue::U64(l.mul(&r)))
313                    }
314
315                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
316                        "Type mismatch in binary operation".to_string(),
317                    ))),
318                }
319            }
320
321            CircuitNode::UnaryOp { op, operand } => {
322                let operand_val = self.execute_node(operand, inputs, variable_types)?;
323
324                match (op, operand_val) {
325                    (UnaryOperator::Not, EncryptedValue::Bool(v)) => {
326                        Ok(EncryptedValue::Bool(v.not()))
327                    }
328
329                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
330                        "Type mismatch in unary operation".to_string(),
331                    ))),
332                }
333            }
334
335            CircuitNode::Compare { op, left, right } => {
336                let left_val = self.execute_node(left, inputs, variable_types)?;
337                let right_val = self.execute_node(right, inputs, variable_types)?;
338
339                match (left_val, right_val) {
340                    (EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
341                        let result = match op {
342                            CompareOperator::Eq => l.eq(&r),
343                            CompareOperator::Ne => l.ne(&r),
344                            CompareOperator::Lt => l.lt(&r),
345                            CompareOperator::Le => l.le(&r),
346                            CompareOperator::Gt => l.gt(&r),
347                            CompareOperator::Ge => l.ge(&r),
348                        };
349                        Ok(EncryptedValue::Bool(result))
350                    }
351
352                    (EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
353                        let result = match op {
354                            CompareOperator::Eq => l.eq(&r),
355                            CompareOperator::Ne => l.ne(&r),
356                            CompareOperator::Lt => l.lt(&r),
357                            CompareOperator::Le => l.le(&r),
358                            CompareOperator::Gt => l.gt(&r),
359                            CompareOperator::Ge => l.ge(&r),
360                        };
361                        Ok(EncryptedValue::Bool(result))
362                    }
363
364                    (EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
365                        let result = match op {
366                            CompareOperator::Eq => l.eq(&r),
367                            CompareOperator::Ne => l.ne(&r),
368                            CompareOperator::Lt => l.lt(&r),
369                            CompareOperator::Le => l.le(&r),
370                            CompareOperator::Gt => l.gt(&r),
371                            CompareOperator::Ge => l.ge(&r),
372                        };
373                        Ok(EncryptedValue::Bool(result))
374                    }
375
376                    (EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
377                        let result = match op {
378                            CompareOperator::Eq => l.eq(&r),
379                            CompareOperator::Ne => l.ne(&r),
380                            CompareOperator::Lt => l.lt(&r),
381                            CompareOperator::Le => l.le(&r),
382                            CompareOperator::Gt => l.gt(&r),
383                            CompareOperator::Ge => l.ge(&r),
384                        };
385                        Ok(EncryptedValue::Bool(result))
386                    }
387
388                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
389                        "Type mismatch in comparison".to_string(),
390                    ))),
391                }
392            }
393        }
394    }
395}
396
397impl Default for FheExecutor {
398    fn default() -> Self {
399        Self::new()
400    }
401}
402
/// Intermediate ciphertext produced while the executor walks a circuit.
///
/// There is one variant per [`EncryptedType`]. The executor matches on these
/// variants to pick an operation implementation and to serialize the final result.
#[cfg(feature = "compute")]
enum EncryptedValue {
    /// Encrypted boolean.
    Bool(EncryptedBool),
    /// Encrypted 8-bit unsigned integer.
    U8(EncryptedU8),
    /// Encrypted 16-bit unsigned integer.
    U16(EncryptedU16),
    /// Encrypted 32-bit unsigned integer.
    U32(EncryptedU32),
    /// Encrypted 64-bit unsigned integer.
    U64(EncryptedU64),
}
412
// Legacy types, kept only for backward compatibility; slated for removal in a future version.

/// Legacy circuit gate kind.
///
/// Deprecated: build circuits from [`CircuitNode`] values instead.
#[deprecated(since = "0.1.0", note = "Use CircuitNode instead")]
#[derive(Debug, Clone)]
pub enum Gate {
    /// Addition gate.
    Add,
    /// Multiplication gate.
    Mul,
    /// Negation gate.
    Not,
    /// Bootstrapping gate.
    Bootstrap,
}
424
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    #[test]
    fn test_fhe_executor_basic() -> Result<()> {
        // Generate keys
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        // Build circuit: a + b
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let sum_node = builder.add(a_node, b_node);

        let circuit = builder.build(sum_node)?;

        // Prepare inputs
        let a = EncryptedU8::encrypt(5, keypair.client_key());
        let b = EncryptedU8::encrypt(3, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        // Execute
        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        // Verify
        let result = EncryptedU8::from_cipher_blob(&result_blob)?;
        assert_eq!(result.decrypt(keypair.client_key()), 8);

        Ok(())
    }

    #[test]
    fn test_fhe_executor_boolean() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x_node = builder.load("x");
        let y_node = builder.load("y");
        let and_node = builder.and(x_node, y_node);

        let circuit = builder.build(and_node)?;

        let x = EncryptedBool::encrypt(true, keypair.client_key());
        let y = EncryptedBool::encrypt(false, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("x".to_string(), x.to_cipher_blob()?);
        inputs.insert("y".to_string(), y.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(!result.decrypt(keypair.client_key()));

        Ok(())
    }

    #[test]
    fn test_fhe_executor_comparison() -> Result<()> {
        let keypair = FheKeyPair::generate()?;
        keypair.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a_node = builder.load("a");
        let b_node = builder.load("b");
        let gt_node = builder.gt(a_node, b_node);

        let circuit = builder.build(gt_node)?;

        let a = EncryptedU8::encrypt(10, keypair.client_key());
        let b = EncryptedU8::encrypt(5, keypair.client_key());

        let mut inputs = HashMap::new();
        inputs.insert("a".to_string(), a.to_cipher_blob()?);
        inputs.insert("b".to_string(), b.to_cipher_blob()?);

        let executor = FheExecutor::new();
        let result_blob = executor.execute(&circuit, &inputs)?;

        let result = EncryptedBool::from_cipher_blob(&result_blob)?;
        assert!(result.decrypt(keypair.client_key()));

        Ok(())
    }

    #[test]
    fn test_missing_input_error() -> Result<()> {
        // `execute` checks for missing inputs before optimizing or evaluating the
        // circuit, so no (expensive) key generation is needed for this test.
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("a", EncryptedType::U8);

        let a_node = builder.load("a");
        let circuit = builder.build(a_node)?;

        let inputs = HashMap::new(); // No inputs provided

        let executor = FheExecutor::new();
        let result = executor.execute(&circuit, &inputs);

        // A missing input is reported as an FHE computation error.
        assert!(matches!(result, Err(AmateRSError::FheComputation(_))));

        Ok(())
    }
}