Skip to main content

amaters_core/compute/
mod.rs

1//! Compute engine module (Yata - The Eight-Span Mirror)
2//!
3//! This module provides FHE circuit execution on encrypted data using TFHE.
4//!
5//! # Architecture
6//!
7//! The compute engine consists of four main components:
8//!
9//! - **Key Management** (`keys`): Client and server key generation, serialization
10//! - **FHE Operations** (`operations`): Encrypted boolean and integer operations
11//! - **Circuit Compilation** (`circuit`): AST representation and type inference
12//! - **FHE Executor** (`FheExecutor`): Circuit execution engine
13//!
14//! # Example
15//!
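//! ```rust,ignore
//! use std::collections::HashMap;
//! use amaters_core::compute::{
//!     CircuitBuilder, EncryptedType, EncryptedU8, FheExecutor, FheKeyPair,
//! };
//!
//! // Generate keys
//! let keypair = FheKeyPair::generate()?;
//! keypair.set_as_global_server_key();
//!
//! // Build a circuit: a + b
//! let mut builder = CircuitBuilder::new();
//! builder.declare_variable("a", EncryptedType::U8)
//!        .declare_variable("b", EncryptedType::U8);
//!
//! let a = builder.load("a");
//! let b = builder.load("b");
//! let sum = builder.add(a, b);
//! let circuit = builder.build(sum)?;
//!
//! // Encrypt the inputs, keyed by the declared variable names
//! let mut inputs = HashMap::new();
//! inputs.insert("a".to_string(), EncryptedU8::encrypt(5, keypair.client_key()).to_cipher_blob()?);
//! inputs.insert("b".to_string(), EncryptedU8::encrypt(3, keypair.client_key()).to_cipher_blob()?);
//!
//! // Execute the circuit and decrypt the result (5 + 3 = 8)
//! let executor = FheExecutor::new();
//! let result = executor.execute(&circuit, &inputs)?;
//! let total = EncryptedU8::from_cipher_blob(&result)?.decrypt(keypair.client_key());
//! assert_eq!(total, 8);
//! ```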
37//!
38//! ## Compute Pipeline
39//!
40//! ```text
41//!  Query → Predicate → CircuitNode → CircuitOptimizer → FheExecutor
42//!                          │                │
43//!                   (BinaryOp,       (bootstrap-min,
44//!                    UnaryOp,         gate fusion,
45//!                    Compare)         parallelism)
46//! ```
47
48pub mod circuit;
49pub mod gpu;
50pub mod key_manager;
51pub mod keys;
52pub mod operations;
53pub mod optimizer;
54pub mod plan_cache;
55pub mod planner;
56pub mod predicate;
57
58#[cfg(test)]
59mod filter_tests;
60
61// Re-export commonly used types
62pub use circuit::{
63    BinaryOperator, Circuit, CircuitBuilder, CircuitNode, CircuitValue, CompareOperator,
64    ConstantType, EncryptedType, UnaryOperator, count_encrypted_constants,
65    count_plaintext_constants, decrypt_constant, encrypt_circuit_constants, encrypt_constant,
66    is_encrypted_constant,
67};
68pub use key_manager::{ClientId, KeyManager};
69pub use keys::{FheKeyPair, InMemoryKeyStorage, KeyStorage};
70pub use operations::{EncryptedBool, EncryptedU8, EncryptedU16, EncryptedU32, EncryptedU64};
71pub use optimizer::{CircuitOptimizer, DependencyGraph, NodeId, OptimizationStats};
72pub use planner::{LogicalPlan, PhysicalPlan, PlanCost, PlannerStats, QueryPlanner};
73pub use predicate::{PredicateCompiler, compile_predicate};
74
75use crate::error::{AmateRSError, ErrorContext, Result};
76use crate::types::CipherBlob;
77use std::collections::HashMap;
78#[cfg(feature = "compute")]
79use tfhe::prelude::*;
80#[cfg(feature = "compute")]
81use tfhe::{FheBool, FheUint8, FheUint16, FheUint32, FheUint64};
82
83/// FHE executor for circuit execution
84///
85/// This executor takes a compiled circuit and encrypted inputs,
86/// executes the circuit on the encrypted data, and returns encrypted results.
87#[derive(Debug, Clone)]
88pub struct FheExecutor {
89    optimizer: CircuitOptimizer,
90    optimization_enabled: bool,
91}
92
93impl FheExecutor {
94    /// Create a new FHE executor with optimization enabled
95    pub fn new() -> Self {
96        Self {
97            optimizer: CircuitOptimizer::new(),
98            optimization_enabled: true,
99        }
100    }
101
102    /// Create a new FHE executor with optimization control
103    pub fn with_optimization(enable: bool) -> Self {
104        Self {
105            optimizer: if enable {
106                CircuitOptimizer::new()
107            } else {
108                CircuitOptimizer::disabled()
109            },
110            optimization_enabled: enable,
111        }
112    }
113
114    /// Get the optimization statistics from the last execution
115    pub fn optimization_stats(&self) -> &OptimizationStats {
116        self.optimizer.stats()
117    }
118
119    /// Get the dependency graph from the last execution
120    pub fn dependency_graph(&self) -> &DependencyGraph {
121        self.optimizer.dependency_graph()
122    }
123
124    /// Execute FHE circuit on encrypted data
125    ///
126    /// # Arguments
127    ///
128    /// * `circuit` - The compiled circuit to execute
129    /// * `inputs` - Map of variable names to encrypted values (CipherBlob)
130    ///
131    /// # Returns
132    ///
133    /// The encrypted result as a CipherBlob
134    #[cfg(feature = "compute")]
135    pub fn execute(
136        &self,
137        circuit: &Circuit,
138        inputs: &HashMap<String, CipherBlob>,
139    ) -> Result<CipherBlob> {
140        // Validate circuit
141        circuit.validate()?;
142
143        // Check that all required inputs are provided
144        for var_name in circuit.variable_types.keys() {
145            if !inputs.contains_key(var_name) {
146                return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
147                    "Missing input for variable: {}",
148                    var_name
149                ))));
150            }
151        }
152
153        // Optimize circuit if enabled
154        let optimized = if self.optimization_enabled {
155            // Need mutable reference for optimizer
156            let mut optimizer = self.optimizer.clone();
157            optimizer.optimize(circuit.clone())?
158        } else {
159            circuit.clone()
160        };
161
162        // Execute the circuit
163        let result_value = self.execute_node(&optimized.root, inputs, &optimized.variable_types)?;
164
165        // Serialize result to CipherBlob
166        match result_value {
167            EncryptedValue::Bool(v) => v.to_cipher_blob(),
168            EncryptedValue::U8(v) => v.to_cipher_blob(),
169            EncryptedValue::U16(v) => v.to_cipher_blob(),
170            EncryptedValue::U32(v) => v.to_cipher_blob(),
171            EncryptedValue::U64(v) => v.to_cipher_blob(),
172        }
173    }
174
175    /// Stub implementation when compute feature is disabled
176    #[cfg(not(feature = "compute"))]
177    pub fn execute(
178        &self,
179        _circuit: &Circuit,
180        _inputs: &HashMap<String, CipherBlob>,
181    ) -> Result<CipherBlob> {
182        Err(AmateRSError::FeatureNotEnabled(ErrorContext::new(
183            "FHE compute feature is not enabled".to_string(),
184        )))
185    }
186
187    /// Execute a single circuit node recursively
188    #[cfg(feature = "compute")]
189    #[allow(clippy::only_used_in_recursion)]
190    fn execute_node(
191        &self,
192        node: &CircuitNode,
193        inputs: &HashMap<String, CipherBlob>,
194        variable_types: &HashMap<String, EncryptedType>,
195    ) -> Result<EncryptedValue> {
196        let node_type_str = match node {
197            CircuitNode::Load(_) => "load",
198            CircuitNode::Constant(_) => "constant",
199            CircuitNode::EncryptedConstant { .. } => "encrypted_constant",
200            CircuitNode::BinaryOp { op, .. } => match op {
201                BinaryOperator::Add => "binary_add",
202                BinaryOperator::Sub => "binary_sub",
203                BinaryOperator::Mul => "binary_mul",
204                BinaryOperator::And => "binary_and",
205                BinaryOperator::Or => "binary_or",
206                BinaryOperator::Xor => "binary_xor",
207            },
208            CircuitNode::UnaryOp { op, .. } => match op {
209                UnaryOperator::Not => "unary_not",
210                UnaryOperator::Neg => "unary_neg",
211            },
212            CircuitNode::Compare { op, .. } => match op {
213                CompareOperator::Eq => "compare_eq",
214                CompareOperator::Ne => "compare_ne",
215                CompareOperator::Lt => "compare_lt",
216                CompareOperator::Le => "compare_le",
217                CompareOperator::Gt => "compare_gt",
218                CompareOperator::Ge => "compare_ge",
219            },
220            CircuitNode::NaryOp { op, .. } => match op {
221                BinaryOperator::And => "nary_and",
222                BinaryOperator::Or => "nary_or",
223                BinaryOperator::Add => "nary_add",
224                BinaryOperator::Mul => "nary_mul",
225                BinaryOperator::Sub => "nary_sub",
226                BinaryOperator::Xor => "nary_xor",
227            },
228        };
229        let _span = tracing::debug_span!("amaters.fhe.gate", "amaters.gate.type" = node_type_str,)
230            .entered();
231        match node {
232            CircuitNode::Load(name) => {
233                let blob = inputs.get(name).ok_or_else(|| {
234                    AmateRSError::FheComputation(ErrorContext::new(format!(
235                        "Missing input: {}",
236                        name
237                    )))
238                })?;
239
240                let var_type = variable_types.get(name).ok_or_else(|| {
241                    AmateRSError::FheComputation(ErrorContext::new(format!(
242                        "Unknown variable type: {}",
243                        name
244                    )))
245                })?;
246
247                match var_type {
248                    EncryptedType::Bool => {
249                        Ok(EncryptedValue::Bool(EncryptedBool::from_cipher_blob(blob)?))
250                    }
251                    EncryptedType::U8 => {
252                        Ok(EncryptedValue::U8(EncryptedU8::from_cipher_blob(blob)?))
253                    }
254                    EncryptedType::U16 => {
255                        Ok(EncryptedValue::U16(EncryptedU16::from_cipher_blob(blob)?))
256                    }
257                    EncryptedType::U32 => {
258                        Ok(EncryptedValue::U32(EncryptedU32::from_cipher_blob(blob)?))
259                    }
260                    EncryptedType::U64 => {
261                        Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(blob)?))
262                    }
263                }
264            }
265
266            CircuitNode::Constant(value) => {
267                // Use TFHE trivial encryption to create a public constant ciphertext.
268                // Trivially-encrypted values are public (no confidentiality) and
269                // can be used in FHE computations as circuit constants.
270                match value {
271                    CircuitValue::Bool(b) => {
272                        let fhe_bool = FheBool::try_encrypt_trivial(*b).map_err(|e| {
273                            AmateRSError::FheComputation(ErrorContext::new(format!(
274                                "Failed to trivially encrypt bool constant: {}",
275                                e
276                            )))
277                        })?;
278                        Ok(EncryptedValue::Bool(EncryptedBool::from_fhe(fhe_bool)))
279                    }
280                    CircuitValue::U8(v) => {
281                        let fhe_val = FheUint8::try_encrypt_trivial(*v).map_err(|e| {
282                            AmateRSError::FheComputation(ErrorContext::new(format!(
283                                "Failed to trivially encrypt u8 constant: {}",
284                                e
285                            )))
286                        })?;
287                        Ok(EncryptedValue::U8(EncryptedU8::from_fhe(fhe_val)))
288                    }
289                    CircuitValue::U16(v) => {
290                        let fhe_val = FheUint16::try_encrypt_trivial(*v).map_err(|e| {
291                            AmateRSError::FheComputation(ErrorContext::new(format!(
292                                "Failed to trivially encrypt u16 constant: {}",
293                                e
294                            )))
295                        })?;
296                        Ok(EncryptedValue::U16(EncryptedU16::from_fhe(fhe_val)))
297                    }
298                    CircuitValue::U32(v) => {
299                        let fhe_val = FheUint32::try_encrypt_trivial(*v).map_err(|e| {
300                            AmateRSError::FheComputation(ErrorContext::new(format!(
301                                "Failed to trivially encrypt u32 constant: {}",
302                                e
303                            )))
304                        })?;
305                        Ok(EncryptedValue::U32(EncryptedU32::from_fhe(fhe_val)))
306                    }
307                    CircuitValue::U64(v) => {
308                        let fhe_val = FheUint64::try_encrypt_trivial(*v).map_err(|e| {
309                            AmateRSError::FheComputation(ErrorContext::new(format!(
310                                "Failed to trivially encrypt u64 constant: {}",
311                                e
312                            )))
313                        })?;
314                        Ok(EncryptedValue::U64(EncryptedU64::from_fhe(fhe_val)))
315                    }
316                }
317            }
318
319            CircuitNode::EncryptedConstant {
320                data,
321                original_type,
322            } => {
323                // Encrypted constants are already in ciphertext form.
324                // Deserialize the CipherBlob from the encrypted data and
325                // convert to the appropriate EncryptedValue based on original_type.
326                let blob = CipherBlob::new(data.clone());
327                match original_type {
328                    ConstantType::Boolean => Ok(EncryptedValue::Bool(
329                        EncryptedBool::from_cipher_blob(&blob)?,
330                    )),
331                    ConstantType::Integer => {
332                        // Try to deserialize as the most common integer type (U64)
333                        // In practice, the caller should ensure the encrypted data
334                        // matches the expected type from the circuit context.
335                        Ok(EncryptedValue::U64(EncryptedU64::from_cipher_blob(&blob)?))
336                    }
337                    ConstantType::Float | ConstantType::Bytes => {
338                        Err(AmateRSError::FheComputation(ErrorContext::new(format!(
339                            "EncryptedConstant of type {} is not directly evaluable in FHE circuits",
340                            original_type
341                        ))))
342                    }
343                }
344            }
345
346            CircuitNode::BinaryOp { op, left, right } => {
347                let left_val = self.execute_node(left, inputs, variable_types)?;
348                let right_val = self.execute_node(right, inputs, variable_types)?;
349
350                match (op, left_val, right_val) {
351                    // Boolean operations
352                    (BinaryOperator::And, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
353                        Ok(EncryptedValue::Bool(l.and(&r)))
354                    }
355                    (BinaryOperator::Or, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
356                        Ok(EncryptedValue::Bool(l.or(&r)))
357                    }
358                    (BinaryOperator::Xor, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
359                        Ok(EncryptedValue::Bool(l.xor(&r)))
360                    }
361
362                    // U8 arithmetic
363                    (BinaryOperator::Add, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
364                        Ok(EncryptedValue::U8(l.add(&r)))
365                    }
366                    (BinaryOperator::Sub, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
367                        Ok(EncryptedValue::U8(l.sub(&r)))
368                    }
369                    (BinaryOperator::Mul, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
370                        Ok(EncryptedValue::U8(l.mul(&r)))
371                    }
372
373                    // U16 arithmetic
374                    (BinaryOperator::Add, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
375                        Ok(EncryptedValue::U16(l.add(&r)))
376                    }
377                    (BinaryOperator::Sub, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
378                        Ok(EncryptedValue::U16(l.sub(&r)))
379                    }
380                    (BinaryOperator::Mul, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
381                        Ok(EncryptedValue::U16(l.mul(&r)))
382                    }
383
384                    // U32 arithmetic
385                    (BinaryOperator::Add, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
386                        Ok(EncryptedValue::U32(l.add(&r)))
387                    }
388                    (BinaryOperator::Sub, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
389                        Ok(EncryptedValue::U32(l.sub(&r)))
390                    }
391                    (BinaryOperator::Mul, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
392                        Ok(EncryptedValue::U32(l.mul(&r)))
393                    }
394
395                    // U64 arithmetic
396                    (BinaryOperator::Add, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
397                        Ok(EncryptedValue::U64(l.add(&r)))
398                    }
399                    (BinaryOperator::Sub, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
400                        Ok(EncryptedValue::U64(l.sub(&r)))
401                    }
402                    (BinaryOperator::Mul, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
403                        Ok(EncryptedValue::U64(l.mul(&r)))
404                    }
405
406                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
407                        "Type mismatch in binary operation".to_string(),
408                    ))),
409                }
410            }
411
412            CircuitNode::UnaryOp { op, operand } => {
413                let operand_val = self.execute_node(operand, inputs, variable_types)?;
414
415                match (op, operand_val) {
416                    (UnaryOperator::Not, EncryptedValue::Bool(v)) => {
417                        Ok(EncryptedValue::Bool(v.not()))
418                    }
419
420                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
421                        "Type mismatch in unary operation".to_string(),
422                    ))),
423                }
424            }
425
426            CircuitNode::Compare { op, left, right } => {
427                let left_val = self.execute_node(left, inputs, variable_types)?;
428                let right_val = self.execute_node(right, inputs, variable_types)?;
429
430                match (left_val, right_val) {
431                    (EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
432                        let result = match op {
433                            CompareOperator::Eq => l.eq(&r),
434                            CompareOperator::Ne => l.ne(&r),
435                            CompareOperator::Lt => l.lt(&r),
436                            CompareOperator::Le => l.le(&r),
437                            CompareOperator::Gt => l.gt(&r),
438                            CompareOperator::Ge => l.ge(&r),
439                        };
440                        Ok(EncryptedValue::Bool(result))
441                    }
442
443                    (EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
444                        let result = match op {
445                            CompareOperator::Eq => l.eq(&r),
446                            CompareOperator::Ne => l.ne(&r),
447                            CompareOperator::Lt => l.lt(&r),
448                            CompareOperator::Le => l.le(&r),
449                            CompareOperator::Gt => l.gt(&r),
450                            CompareOperator::Ge => l.ge(&r),
451                        };
452                        Ok(EncryptedValue::Bool(result))
453                    }
454
455                    (EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
456                        let result = match op {
457                            CompareOperator::Eq => l.eq(&r),
458                            CompareOperator::Ne => l.ne(&r),
459                            CompareOperator::Lt => l.lt(&r),
460                            CompareOperator::Le => l.le(&r),
461                            CompareOperator::Gt => l.gt(&r),
462                            CompareOperator::Ge => l.ge(&r),
463                        };
464                        Ok(EncryptedValue::Bool(result))
465                    }
466
467                    (EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
468                        let result = match op {
469                            CompareOperator::Eq => l.eq(&r),
470                            CompareOperator::Ne => l.ne(&r),
471                            CompareOperator::Lt => l.lt(&r),
472                            CompareOperator::Le => l.le(&r),
473                            CompareOperator::Gt => l.gt(&r),
474                            CompareOperator::Ge => l.ge(&r),
475                        };
476                        Ok(EncryptedValue::Bool(result))
477                    }
478
479                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
480                        "Type mismatch in comparison".to_string(),
481                    ))),
482                }
483            }
484            CircuitNode::NaryOp { op, operands } => {
485                if operands.is_empty() {
486                    return Err(AmateRSError::FheComputation(ErrorContext::new(
487                        "NaryOp has no operands".to_string(),
488                    )));
489                }
490                // Evaluate all operands
491                let mut values: Vec<EncryptedValue> = Vec::with_capacity(operands.len());
492                for operand in operands {
493                    values.push(self.execute_node(operand, inputs, variable_types)?);
494                }
495                // Fold using binary op
496                let mut iter = values.into_iter();
497                let first = iter.next().ok_or_else(|| {
498                    AmateRSError::FheComputation(ErrorContext::new(
499                        "NaryOp has no operands after collection".to_string(),
500                    ))
501                })?;
502                iter.try_fold(first, |acc, next| match (op, acc, next) {
503                    (&BinaryOperator::And, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
504                        Ok(EncryptedValue::Bool(l.and(&r)))
505                    }
506                    (&BinaryOperator::Or, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
507                        Ok(EncryptedValue::Bool(l.or(&r)))
508                    }
509                    (&BinaryOperator::Xor, EncryptedValue::Bool(l), EncryptedValue::Bool(r)) => {
510                        Ok(EncryptedValue::Bool(l.xor(&r)))
511                    }
512                    (&BinaryOperator::Add, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
513                        Ok(EncryptedValue::U8(l.add(&r)))
514                    }
515                    (&BinaryOperator::Mul, EncryptedValue::U8(l), EncryptedValue::U8(r)) => {
516                        Ok(EncryptedValue::U8(l.mul(&r)))
517                    }
518                    (&BinaryOperator::Add, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
519                        Ok(EncryptedValue::U16(l.add(&r)))
520                    }
521                    (&BinaryOperator::Mul, EncryptedValue::U16(l), EncryptedValue::U16(r)) => {
522                        Ok(EncryptedValue::U16(l.mul(&r)))
523                    }
524                    (&BinaryOperator::Add, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
525                        Ok(EncryptedValue::U32(l.add(&r)))
526                    }
527                    (&BinaryOperator::Mul, EncryptedValue::U32(l), EncryptedValue::U32(r)) => {
528                        Ok(EncryptedValue::U32(l.mul(&r)))
529                    }
530                    (&BinaryOperator::Add, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
531                        Ok(EncryptedValue::U64(l.add(&r)))
532                    }
533                    (&BinaryOperator::Mul, EncryptedValue::U64(l), EncryptedValue::U64(r)) => {
534                        Ok(EncryptedValue::U64(l.mul(&r)))
535                    }
536                    _ => Err(AmateRSError::FheComputation(ErrorContext::new(
537                        "Type mismatch in NaryOp".to_string(),
538                    ))),
539                })
540            }
541        }
542    }
543}
544
545impl Default for FheExecutor {
546    fn default() -> Self {
547        Self::new()
548    }
549}
550
/// Intermediate encrypted value produced while evaluating circuit nodes.
///
/// Each variant wraps the encrypted wrapper type for one supported width.
/// This enum is only used inside this module.
#[cfg(feature = "compute")]
enum EncryptedValue {
    /// Encrypted boolean.
    Bool(EncryptedBool),
    /// Encrypted 8-bit unsigned integer.
    U8(EncryptedU8),
    /// Encrypted 16-bit unsigned integer.
    U16(EncryptedU16),
    /// Encrypted 32-bit unsigned integer.
    U32(EncryptedU32),
    /// Encrypted 64-bit unsigned integer.
    U64(EncryptedU64),
}
560
// Legacy types for backward compatibility (to be removed in future versions)

/// Legacy circuit gate kind, superseded by [`CircuitNode`].
///
/// Kept only so older callers still compile; new code should build circuits
/// from [`CircuitNode`] values.
#[deprecated(since = "0.1.0", note = "Use CircuitNode instead")]
#[derive(Clone, Debug)]
pub enum Gate {
    /// Addition gate.
    Add,
    /// Multiplication gate.
    Mul,
    /// Negation gate.
    Not,
    /// Bootstrapping step.
    Bootstrap,
}
572
#[cfg(all(test, feature = "compute"))]
mod tests {
    use super::*;

    /// `a + b` over two encrypted u8 inputs decrypts to the plaintext sum.
    #[test]
    fn test_fhe_executor_basic() -> Result<()> {
        let keys = FheKeyPair::generate()?;
        keys.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);
        let lhs = builder.load("a");
        let rhs = builder.load("b");
        let root = builder.add(lhs, rhs);
        let circuit = builder.build(root)?;

        let inputs = HashMap::from([
            (
                "a".to_string(),
                EncryptedU8::encrypt(5, keys.client_key()).to_cipher_blob()?,
            ),
            (
                "b".to_string(),
                EncryptedU8::encrypt(3, keys.client_key()).to_cipher_blob()?,
            ),
        ]);

        let output = FheExecutor::new().execute(&circuit, &inputs)?;
        let decrypted = EncryptedU8::from_cipher_blob(&output)?.decrypt(keys.client_key());
        assert_eq!(decrypted, 8);
        Ok(())
    }

    /// Encrypted `true AND false` decrypts to `false`.
    #[test]
    fn test_fhe_executor_boolean() -> Result<()> {
        let keys = FheKeyPair::generate()?;
        keys.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);
        let lhs = builder.load("x");
        let rhs = builder.load("y");
        let root = builder.and(lhs, rhs);
        let circuit = builder.build(root)?;

        let inputs = HashMap::from([
            (
                "x".to_string(),
                EncryptedBool::encrypt(true, keys.client_key()).to_cipher_blob()?,
            ),
            (
                "y".to_string(),
                EncryptedBool::encrypt(false, keys.client_key()).to_cipher_blob()?,
            ),
        ]);

        let output = FheExecutor::new().execute(&circuit, &inputs)?;
        let decrypted = EncryptedBool::from_cipher_blob(&output)?.decrypt(keys.client_key());
        assert!(!decrypted);
        Ok(())
    }

    /// Encrypted `10 > 5` decrypts to `true`.
    #[test]
    fn test_fhe_executor_comparison() -> Result<()> {
        let keys = FheKeyPair::generate()?;
        keys.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);
        let lhs = builder.load("a");
        let rhs = builder.load("b");
        let root = builder.gt(lhs, rhs);
        let circuit = builder.build(root)?;

        let inputs = HashMap::from([
            (
                "a".to_string(),
                EncryptedU8::encrypt(10, keys.client_key()).to_cipher_blob()?,
            ),
            (
                "b".to_string(),
                EncryptedU8::encrypt(5, keys.client_key()).to_cipher_blob()?,
            ),
        ]);

        let output = FheExecutor::new().execute(&circuit, &inputs)?;
        let decrypted = EncryptedBool::from_cipher_blob(&output)?.decrypt(keys.client_key());
        assert!(decrypted);
        Ok(())
    }

    /// Executing without the declared input must fail rather than panic.
    #[test]
    fn test_missing_input_error() -> Result<()> {
        let keys = FheKeyPair::generate()?;
        keys.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder.declare_variable("a", EncryptedType::U8);
        let root = builder.load("a");
        let circuit = builder.build(root)?;

        // Deliberately empty: "a" is never supplied.
        let no_inputs = HashMap::new();
        let outcome = FheExecutor::new().execute(&circuit, &no_inputs);
        assert!(outcome.is_err());
        Ok(())
    }

    /// `a + Constant(5u8)` works via trivially-encrypted constants.
    #[test]
    fn test_trivial_constant_in_circuit() -> Result<()> {
        let keys = FheKeyPair::generate()?;
        keys.set_as_global_server_key();

        let mut builder = CircuitBuilder::new();
        builder.declare_variable("a", EncryptedType::U8);
        let var = builder.load("a");
        let five = builder.constant(CircuitValue::U8(5));
        let root = builder.add(var, five);
        let circuit = builder.build(root)?;

        let inputs = HashMap::from([(
            "a".to_string(),
            EncryptedU8::encrypt(3, keys.client_key()).to_cipher_blob()?,
        )]);

        let output = FheExecutor::new().execute(&circuit, &inputs)?;
        let decrypted = EncryptedU8::from_cipher_blob(&output)?.decrypt(keys.client_key());
        assert_eq!(decrypted, 8);
        Ok(())
    }
}
730
#[cfg(test)]
mod trace_tests {
    /// Smoke test: entering the per-gate span used by the executor must not
    /// panic, even with no tracing subscriber installed.
    #[test]
    fn test_execute_node_emits_trace() {
        let guard =
            tracing::debug_span!("amaters.fhe.gate", "amaters.gate.type" = "test").entered();
        drop(guard);
    }
}