Skip to main content

amaters_core/compute/
circuit.rs

1//! Circuit compilation and optimization
2//!
3//! This module provides circuit AST representation, type inference,
4//! and basic optimization for FHE operations.
5
6use crate::error::{AmateRSError, ErrorContext, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Type tag for encrypted constants, indicating the original plaintext type
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum ConstantType {
13    /// Original value was an integer (u8, u16, u32, u64)
14    Integer,
15    /// Original value was a boolean
16    Boolean,
17    /// Original value was a floating-point number
18    Float,
19    /// Original value was raw bytes
20    Bytes,
21}
22
23impl std::fmt::Display for ConstantType {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            ConstantType::Integer => write!(f, "integer"),
27            ConstantType::Boolean => write!(f, "boolean"),
28            ConstantType::Float => write!(f, "float"),
29            ConstantType::Bytes => write!(f, "bytes"),
30        }
31    }
32}
33
34/// Circuit AST node representing FHE operations
35#[derive(Debug, Clone, PartialEq, Eq)]
36pub enum CircuitNode {
37    /// Load a variable by name
38    Load(String),
39
40    /// Constant value (plaintext)
41    Constant(CircuitValue),
42
43    /// Encrypted constant value (ciphertext form, opaque to the optimizer)
44    ///
45    /// This variant holds a constant that has already been encrypted for use
46    /// in FHE evaluation. The optimizer must NOT attempt to constant-fold or
47    /// simplify encrypted constants since their plaintext values are unknown.
48    EncryptedConstant {
49        /// Encrypted ciphertext data
50        data: Vec<u8>,
51        /// The type of the original plaintext value before encryption
52        original_type: ConstantType,
53    },
54
55    /// Binary operation
56    BinaryOp {
57        op: BinaryOperator,
58        left: Box<CircuitNode>,
59        right: Box<CircuitNode>,
60    },
61
62    /// Unary operation
63    UnaryOp {
64        op: UnaryOperator,
65        operand: Box<CircuitNode>,
66    },
67
68    /// Comparison operation
69    Compare {
70        op: CompareOperator,
71        left: Box<CircuitNode>,
72        right: Box<CircuitNode>,
73    },
74
75    /// N-ary operation for associative+commutative ops (optimizer-only IR)
76    ///
77    /// Valid only for associative+commutative operators: Add, Mul, And, Or, Xor.
78    /// Created by the optimizer's gate fusion pass; BinaryOp is canonical.
79    /// Invariant: `operands.len() >= 2`.
80    NaryOp {
81        op: BinaryOperator,
82        operands: Vec<CircuitNode>,
83    },
84}
85
/// Binary operators: three arithmetic (`Add`, `Sub`, `Mul`) and three
/// logical (`And`, `Or`, `Xor`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    Add,
    Sub,
    Mul,
    And,
    Or,
    Xor,
}

impl BinaryOperator {
    /// Returns the operator's textual symbol (`+`, `-`, `*`, `AND`, `OR`, `XOR`).
    ///
    /// Every symbol is a string literal, so the result is `'static` and can
    /// outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Sub => "-",
            BinaryOperator::Mul => "*",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Xor => "XOR",
        }
    }
}
110
/// Unary operators: logical `Not` and arithmetic `Neg`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    Not,
    Neg,
}

impl UnaryOperator {
    /// Returns the operator's textual symbol (`NOT` or `-`).
    ///
    /// Every symbol is a string literal, so the result is `'static` and can
    /// outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Neg => "-",
        }
    }
}
127
/// Comparison operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOperator {
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,
}

impl CompareOperator {
    /// Returns the operator's textual symbol (`=`, `!=`, `<`, `<=`, `>`, `>=`).
    ///
    /// Note that equality is rendered as a single `=`.
    ///
    /// Every symbol is a string literal, so the result is `'static` and can
    /// outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompareOperator::Eq => "=",
            CompareOperator::Ne => "!=",
            CompareOperator::Lt => "<",
            CompareOperator::Le => "<=",
            CompareOperator::Gt => ">",
            CompareOperator::Ge => ">=",
        }
    }
}
152
/// A plaintext constant value: a boolean or an unsigned integer of a fixed
/// width (8, 16, 32, or 64 bits).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitValue {
    /// Boolean value.
    Bool(bool),
    /// 8-bit unsigned integer.
    U8(u8),
    /// 16-bit unsigned integer.
    U16(u16),
    /// 32-bit unsigned integer.
    U32(u32),
    /// 64-bit unsigned integer.
    U64(u64),
}
162
163impl std::fmt::Display for CircuitNode {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        match self {
166            CircuitNode::Load(name) => write!(f, "Load({})", name),
167            CircuitNode::Constant(value) => match value {
168                CircuitValue::Bool(v) => write!(f, "Const({})", v),
169                CircuitValue::U8(v) => write!(f, "Const({}u8)", v),
170                CircuitValue::U16(v) => write!(f, "Const({}u16)", v),
171                CircuitValue::U32(v) => write!(f, "Const({}u32)", v),
172                CircuitValue::U64(v) => write!(f, "Const({}u64)", v),
173            },
174            CircuitNode::EncryptedConstant {
175                data,
176                original_type,
177            } => {
178                write!(f, "EncryptedConst({}, {} bytes)", original_type, data.len())
179            }
180            CircuitNode::BinaryOp { op, left, right } => {
181                write!(f, "({} {} {})", left, op.as_str(), right)
182            }
183            CircuitNode::UnaryOp { op, operand } => {
184                write!(f, "{}({})", op.as_str(), operand)
185            }
186            CircuitNode::Compare { op, left, right } => {
187                write!(f, "({} {} {})", left, op.as_str(), right)
188            }
189            CircuitNode::NaryOp { op, operands } => {
190                write!(f, "{}(", op.as_str())?;
191                for (i, operand) in operands.iter().enumerate() {
192                    if i > 0 {
193                        write!(f, ", ")?;
194                    }
195                    write!(f, "{}", operand)?;
196                }
197                write!(f, ")")
198            }
199        }
200    }
201}
202
/// The type of an encrypted value: a boolean or a fixed-width unsigned integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptedType {
    Bool,
    U8,
    U16,
    U32,
    U64,
}

impl EncryptedType {
    /// Number of bits in a value of this type (`Bool` counts as 1).
    pub fn bit_width(&self) -> usize {
        match *self {
            EncryptedType::Bool => 1,
            EncryptedType::U8 => 8,
            EncryptedType::U16 => 16,
            EncryptedType::U32 => 32,
            EncryptedType::U64 => 64,
        }
    }

    /// `true` for every type except `Bool`.
    pub fn is_numeric(&self) -> bool {
        *self != EncryptedType::Bool
    }

    /// `true` only for `Bool`.
    pub fn is_boolean(&self) -> bool {
        *self == EncryptedType::Bool
    }
}
235
236impl std::fmt::Display for EncryptedType {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        match self {
239            EncryptedType::Bool => write!(f, "bool"),
240            EncryptedType::U8 => write!(f, "u8"),
241            EncryptedType::U16 => write!(f, "u16"),
242            EncryptedType::U32 => write!(f, "u32"),
243            EncryptedType::U64 => write!(f, "u64"),
244        }
245    }
246}
247
248/// Circuit representation with metadata
249#[derive(Debug, Clone)]
250pub struct Circuit {
251    /// Root node of the circuit
252    pub root: CircuitNode,
253
254    /// Type information for variables
255    pub variable_types: HashMap<String, EncryptedType>,
256
257    /// Inferred result type
258    pub result_type: EncryptedType,
259
260    /// Circuit depth (for complexity estimation)
261    pub depth: usize,
262
263    /// Number of gates (for complexity estimation)
264    pub gate_count: usize,
265}
266
267impl Circuit {
268    /// Create a new circuit from a root node
269    pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
270        let result_type = Self::infer_type(&root, &variable_types)?;
271        let depth = Self::compute_depth(&root);
272        let gate_count = Self::count_gates(&root);
273
274        Ok(Self {
275            root,
276            variable_types,
277            result_type,
278            depth,
279            gate_count,
280        })
281    }
282
283    /// Infer the type of a circuit node
284    fn infer_type(
285        node: &CircuitNode,
286        variable_types: &HashMap<String, EncryptedType>,
287    ) -> Result<EncryptedType> {
288        match node {
289            CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
290                AmateRSError::FheComputation(ErrorContext::new(format!(
291                    "Undefined variable: {}",
292                    name
293                )))
294            }),
295
296            CircuitNode::Constant(value) => Ok(match value {
297                CircuitValue::Bool(_) => EncryptedType::Bool,
298                CircuitValue::U8(_) => EncryptedType::U8,
299                CircuitValue::U16(_) => EncryptedType::U16,
300                CircuitValue::U32(_) => EncryptedType::U32,
301                CircuitValue::U64(_) => EncryptedType::U64,
302            }),
303
304            CircuitNode::EncryptedConstant { original_type, .. } => {
305                Ok(match original_type {
306                    ConstantType::Boolean => EncryptedType::Bool,
307                    // For non-boolean encrypted constants, we default to U64
308                    // since the exact width is not recoverable from the encrypted data.
309                    // In practice, users should ensure type consistency.
310                    ConstantType::Integer | ConstantType::Float | ConstantType::Bytes => {
311                        EncryptedType::U64
312                    }
313                })
314            }
315
316            CircuitNode::BinaryOp { op, left, right } => {
317                let left_type = Self::infer_type(left, variable_types)?;
318                let right_type = Self::infer_type(right, variable_types)?;
319
320                match op {
321                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
322                        if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
323                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
324                                "Logical operation requires boolean operands, got {} and {}",
325                                left_type, right_type
326                            ))));
327                        }
328                        Ok(EncryptedType::Bool)
329                    }
330
331                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
332                        if !left_type.is_numeric() || !right_type.is_numeric() {
333                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
334                                "Arithmetic operation requires numeric operands, got {} and {}",
335                                left_type, right_type
336                            ))));
337                        }
338
339                        if left_type != right_type {
340                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
341                                "Arithmetic operation requires matching types, got {} and {}",
342                                left_type, right_type
343                            ))));
344                        }
345
346                        Ok(left_type)
347                    }
348                }
349            }
350
351            CircuitNode::UnaryOp { op, operand } => {
352                let operand_type = Self::infer_type(operand, variable_types)?;
353
354                match op {
355                    UnaryOperator::Not => {
356                        if operand_type != EncryptedType::Bool {
357                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
358                                "NOT operation requires boolean operand, got {}",
359                                operand_type
360                            ))));
361                        }
362                        Ok(EncryptedType::Bool)
363                    }
364
365                    UnaryOperator::Neg => {
366                        if !operand_type.is_numeric() {
367                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
368                                "Negation operation requires numeric operand, got {}",
369                                operand_type
370                            ))));
371                        }
372                        Ok(operand_type)
373                    }
374                }
375            }
376
377            CircuitNode::Compare { left, right, .. } => {
378                let left_type = Self::infer_type(left, variable_types)?;
379                let right_type = Self::infer_type(right, variable_types)?;
380
381                if left_type != right_type {
382                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
383                        "Comparison requires matching types, got {} and {}",
384                        left_type, right_type
385                    ))));
386                }
387
388                Ok(EncryptedType::Bool)
389            }
390            CircuitNode::NaryOp { op, operands } => {
391                if operands.len() < 2 {
392                    return Err(AmateRSError::FheComputation(ErrorContext::new(
393                        "NaryOp requires at least 2 operands".to_string(),
394                    )));
395                }
396                let first_type = Self::infer_type(&operands[0], variable_types)?;
397                for operand in &operands[1..] {
398                    let t = Self::infer_type(operand, variable_types)?;
399                    if t != first_type {
400                        return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
401                            "NaryOp operands have mismatched types: {} and {}",
402                            first_type, t
403                        ))));
404                    }
405                }
406                match op {
407                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
408                        if first_type != EncryptedType::Bool {
409                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
410                                "Logical NaryOp requires boolean operands, got {}",
411                                first_type
412                            ))));
413                        }
414                        Ok(EncryptedType::Bool)
415                    }
416                    BinaryOperator::Add | BinaryOperator::Mul => {
417                        if !first_type.is_numeric() {
418                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
419                                "Arithmetic NaryOp requires numeric operands, got {}",
420                                first_type
421                            ))));
422                        }
423                        Ok(first_type)
424                    }
425                    BinaryOperator::Sub => Err(AmateRSError::FheComputation(ErrorContext::new(
426                        "NaryOp is not valid for Sub (non-associative)".to_string(),
427                    ))),
428                }
429            }
430        }
431    }
432
433    /// Compute the depth of the circuit
434    fn compute_depth(node: &CircuitNode) -> usize {
435        match node {
436            CircuitNode::Load(_)
437            | CircuitNode::Constant(_)
438            | CircuitNode::EncryptedConstant { .. } => 1,
439
440            CircuitNode::BinaryOp { left, right, .. }
441            | CircuitNode::Compare { left, right, .. } => {
442                1 + Self::compute_depth(left).max(Self::compute_depth(right))
443            }
444
445            CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
446            CircuitNode::NaryOp { operands, .. } => {
447                let max_operand_depth = operands.iter().map(Self::compute_depth).max().unwrap_or(1);
448                let n = operands.len();
449                let log2_n = (n as f64).log2().ceil() as usize;
450                log2_n.max(1) + max_operand_depth
451            }
452        }
453    }
454
455    /// Count the number of gates in the circuit
456    fn count_gates(node: &CircuitNode) -> usize {
457        match node {
458            CircuitNode::Load(_)
459            | CircuitNode::Constant(_)
460            | CircuitNode::EncryptedConstant { .. } => 0,
461
462            CircuitNode::BinaryOp { left, right, .. }
463            | CircuitNode::Compare { left, right, .. } => {
464                1 + Self::count_gates(left) + Self::count_gates(right)
465            }
466
467            CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
468            CircuitNode::NaryOp { operands, .. } => {
469                let inner_gates: usize = operands.iter().map(Self::count_gates).sum();
470                inner_gates + operands.len().saturating_sub(1)
471            }
472        }
473    }
474
475    /// Validate the circuit for correctness
476    pub fn validate(&self) -> Result<()> {
477        Self::validate_node(&self.root, &self.variable_types)?;
478        Ok(())
479    }
480
481    fn validate_node(
482        node: &CircuitNode,
483        variable_types: &HashMap<String, EncryptedType>,
484    ) -> Result<()> {
485        match node {
486            CircuitNode::Load(name) => {
487                if !variable_types.contains_key(name) {
488                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
489                        "Undefined variable: {}",
490                        name
491                    ))));
492                }
493                Ok(())
494            }
495
496            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => Ok(()),
497
498            CircuitNode::BinaryOp { left, right, .. }
499            | CircuitNode::Compare { left, right, .. } => {
500                Self::validate_node(left, variable_types)?;
501                Self::validate_node(right, variable_types)?;
502                Ok(())
503            }
504
505            CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
506            CircuitNode::NaryOp { op, operands } => {
507                if operands.len() < 2 {
508                    return Err(AmateRSError::FheComputation(ErrorContext::new(
509                        "NaryOp requires at least 2 operands".to_string(),
510                    )));
511                }
512                if op == &BinaryOperator::Sub {
513                    return Err(AmateRSError::FheComputation(ErrorContext::new(
514                        "NaryOp is not valid for Sub (non-associative)".to_string(),
515                    )));
516                }
517                for operand in operands {
518                    Self::validate_node(operand, variable_types)?;
519                }
520                Ok(())
521            }
522        }
523    }
524}
525
526/// Circuit builder for constructing circuits programmatically
527#[derive(Default)]
528pub struct CircuitBuilder {
529    variable_types: HashMap<String, EncryptedType>,
530}
531
532impl CircuitBuilder {
533    pub fn new() -> Self {
534        Self::default()
535    }
536
537    /// Get the variable types map
538    pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
539        &self.variable_types
540    }
541
542    /// Clone the variable types map
543    pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
544        self.variable_types.clone()
545    }
546
547    /// Declare a variable with its type
548    pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
549        self.variable_types.insert(name.into(), ty);
550        self
551    }
552
553    /// Build the circuit from a root node
554    pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
555        Circuit::new(root, self.variable_types.clone())
556    }
557
558    /// Create a load node
559    pub fn load(&self, name: impl Into<String>) -> CircuitNode {
560        CircuitNode::Load(name.into())
561    }
562
563    /// Create a constant node
564    pub fn constant(&self, value: CircuitValue) -> CircuitNode {
565        CircuitNode::Constant(value)
566    }
567
568    /// Create an addition node
569    pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
570        CircuitNode::BinaryOp {
571            op: BinaryOperator::Add,
572            left: Box::new(left),
573            right: Box::new(right),
574        }
575    }
576
577    /// Create a subtraction node
578    pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
579        CircuitNode::BinaryOp {
580            op: BinaryOperator::Sub,
581            left: Box::new(left),
582            right: Box::new(right),
583        }
584    }
585
586    /// Create a multiplication node
587    pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
588        CircuitNode::BinaryOp {
589            op: BinaryOperator::Mul,
590            left: Box::new(left),
591            right: Box::new(right),
592        }
593    }
594
595    /// Create an AND node
596    pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
597        CircuitNode::BinaryOp {
598            op: BinaryOperator::And,
599            left: Box::new(left),
600            right: Box::new(right),
601        }
602    }
603
604    /// Create an OR node
605    pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
606        CircuitNode::BinaryOp {
607            op: BinaryOperator::Or,
608            left: Box::new(left),
609            right: Box::new(right),
610        }
611    }
612
613    /// Create an XOR node
614    pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
615        CircuitNode::BinaryOp {
616            op: BinaryOperator::Xor,
617            left: Box::new(left),
618            right: Box::new(right),
619        }
620    }
621
622    /// Create a NOT node
623    pub fn not(&self, operand: CircuitNode) -> CircuitNode {
624        CircuitNode::UnaryOp {
625            op: UnaryOperator::Not,
626            operand: Box::new(operand),
627        }
628    }
629
630    /// Create an equality comparison node
631    pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
632        CircuitNode::Compare {
633            op: CompareOperator::Eq,
634            left: Box::new(left),
635            right: Box::new(right),
636        }
637    }
638
639    /// Create a less-than comparison node
640    pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
641        CircuitNode::Compare {
642            op: CompareOperator::Lt,
643            left: Box::new(left),
644            right: Box::new(right),
645        }
646    }
647
648    /// Create a greater-than comparison node
649    pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
650        CircuitNode::Compare {
651            op: CompareOperator::Gt,
652            left: Box::new(left),
653            right: Box::new(right),
654        }
655    }
656
657    /// Create an encrypted constant node directly
658    pub fn encrypted_constant(&self, data: Vec<u8>, original_type: ConstantType) -> CircuitNode {
659        CircuitNode::EncryptedConstant {
660            data,
661            original_type,
662        }
663    }
664}
665
666// ---------------------------------------------------------------------------
667// Encrypted constant helpers
668// ---------------------------------------------------------------------------
669
670/// Encrypt a plaintext circuit constant value using a symmetric key.
671///
672/// This uses a simple XOR-based stream cipher derived from the key for
673/// demonstration purposes. In production, this would delegate to the
674/// actual FHE encryption backend (e.g., TFHE key-switch + bootstrap).
675///
676/// The output ciphertext contains a 1-byte type tag followed by the
677/// XOR-encrypted payload so that it can be correctly interpreted during
678/// evaluation.
679pub fn encrypt_constant(value: &CircuitValue, key: &[u8]) -> Result<Vec<u8>> {
680    if key.is_empty() {
681        return Err(AmateRSError::FheComputation(ErrorContext::new(
682            "Encryption key must not be empty".to_string(),
683        )));
684    }
685
686    // Serialize the plaintext value to bytes
687    let (type_tag, plaintext): (u8, Vec<u8>) = match value {
688        CircuitValue::Bool(v) => (0x00, vec![if *v { 1 } else { 0 }]),
689        CircuitValue::U8(v) => (0x01, v.to_le_bytes().to_vec()),
690        CircuitValue::U16(v) => (0x02, v.to_le_bytes().to_vec()),
691        CircuitValue::U32(v) => (0x03, v.to_le_bytes().to_vec()),
692        CircuitValue::U64(v) => (0x04, v.to_le_bytes().to_vec()),
693    };
694
695    // Generate a keystream by repeating and hashing the key material
696    let keystream = derive_keystream(key, plaintext.len());
697
698    // XOR plaintext with keystream
699    let ciphertext: Vec<u8> = plaintext
700        .iter()
701        .zip(keystream.iter())
702        .map(|(p, k)| p ^ k)
703        .collect();
704
705    // Prepend type tag (unencrypted, needed for type inference)
706    let mut output = Vec::with_capacity(1 + ciphertext.len());
707    output.push(type_tag);
708    output.extend_from_slice(&ciphertext);
709
710    Ok(output)
711}
712
713/// Decrypt an encrypted constant back to its plaintext value.
714///
715/// Inverse of [`encrypt_constant`]. Returns an error if the data is
716/// malformed or the key does not match.
717pub fn decrypt_constant(data: &[u8], key: &[u8]) -> Result<CircuitValue> {
718    if key.is_empty() {
719        return Err(AmateRSError::FheComputation(ErrorContext::new(
720            "Decryption key must not be empty".to_string(),
721        )));
722    }
723    if data.is_empty() {
724        return Err(AmateRSError::FheComputation(ErrorContext::new(
725            "Encrypted constant data is empty".to_string(),
726        )));
727    }
728
729    let type_tag = data[0];
730    let ciphertext = &data[1..];
731
732    let keystream = derive_keystream(key, ciphertext.len());
733    let plaintext: Vec<u8> = ciphertext
734        .iter()
735        .zip(keystream.iter())
736        .map(|(c, k)| c ^ k)
737        .collect();
738
739    match type_tag {
740        0x00 => {
741            if plaintext.is_empty() {
742                return Err(AmateRSError::FheComputation(ErrorContext::new(
743                    "Encrypted boolean constant has no payload".to_string(),
744                )));
745            }
746            Ok(CircuitValue::Bool(plaintext[0] != 0))
747        }
748        0x01 => {
749            let arr: [u8; 1] = plaintext.as_slice().try_into().map_err(|_| {
750                AmateRSError::FheComputation(ErrorContext::new(
751                    "Invalid encrypted u8 constant length".to_string(),
752                ))
753            })?;
754            Ok(CircuitValue::U8(u8::from_le_bytes(arr)))
755        }
756        0x02 => {
757            let arr: [u8; 2] = plaintext.as_slice().try_into().map_err(|_| {
758                AmateRSError::FheComputation(ErrorContext::new(
759                    "Invalid encrypted u16 constant length".to_string(),
760                ))
761            })?;
762            Ok(CircuitValue::U16(u16::from_le_bytes(arr)))
763        }
764        0x03 => {
765            let arr: [u8; 4] = plaintext.as_slice().try_into().map_err(|_| {
766                AmateRSError::FheComputation(ErrorContext::new(
767                    "Invalid encrypted u32 constant length".to_string(),
768                ))
769            })?;
770            Ok(CircuitValue::U32(u32::from_le_bytes(arr)))
771        }
772        0x04 => {
773            let arr: [u8; 8] = plaintext.as_slice().try_into().map_err(|_| {
774                AmateRSError::FheComputation(ErrorContext::new(
775                    "Invalid encrypted u64 constant length".to_string(),
776                ))
777            })?;
778            Ok(CircuitValue::U64(u64::from_le_bytes(arr)))
779        }
780        _ => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
781            "Unknown encrypted constant type tag: 0x{:02x}",
782            type_tag
783        )))),
784    }
785}
786
787/// Recursively walk a circuit tree and encrypt all `Constant` nodes into
788/// `EncryptedConstant` nodes. This is a pre-processing step before FHE
789/// evaluation to ensure no plaintext constants leak into the encrypted
790/// computation.
791pub fn encrypt_circuit_constants(node: &CircuitNode, key: &[u8]) -> Result<CircuitNode> {
792    match node {
793        CircuitNode::Load(name) => Ok(CircuitNode::Load(name.clone())),
794
795        CircuitNode::Constant(value) => {
796            let data = encrypt_constant(value, key)?;
797            let original_type = match value {
798                CircuitValue::Bool(_) => ConstantType::Boolean,
799                CircuitValue::U8(_)
800                | CircuitValue::U16(_)
801                | CircuitValue::U32(_)
802                | CircuitValue::U64(_) => ConstantType::Integer,
803            };
804            Ok(CircuitNode::EncryptedConstant {
805                data,
806                original_type,
807            })
808        }
809
810        // Already encrypted — pass through
811        CircuitNode::EncryptedConstant {
812            data,
813            original_type,
814        } => Ok(CircuitNode::EncryptedConstant {
815            data: data.clone(),
816            original_type: *original_type,
817        }),
818
819        CircuitNode::BinaryOp { op, left, right } => {
820            let left = encrypt_circuit_constants(left, key)?;
821            let right = encrypt_circuit_constants(right, key)?;
822            Ok(CircuitNode::BinaryOp {
823                op: *op,
824                left: Box::new(left),
825                right: Box::new(right),
826            })
827        }
828
829        CircuitNode::UnaryOp { op, operand } => {
830            let operand = encrypt_circuit_constants(operand, key)?;
831            Ok(CircuitNode::UnaryOp {
832                op: *op,
833                operand: Box::new(operand),
834            })
835        }
836
837        CircuitNode::Compare { op, left, right } => {
838            let left = encrypt_circuit_constants(left, key)?;
839            let right = encrypt_circuit_constants(right, key)?;
840            Ok(CircuitNode::Compare {
841                op: *op,
842                left: Box::new(left),
843                right: Box::new(right),
844            })
845        }
846        CircuitNode::NaryOp { op, operands } => {
847            let new_operands: Result<Vec<CircuitNode>> = operands
848                .iter()
849                .map(|o| encrypt_circuit_constants(o, key))
850                .collect();
851            Ok(CircuitNode::NaryOp {
852                op: *op,
853                operands: new_operands?,
854            })
855        }
856    }
857}
858
/// Derive a deterministic keystream of exactly `length` bytes from `key`.
///
/// The stream is produced in 8-byte blocks. Block `i` is the little-endian
/// encoding of an FNV-1a-style 64-bit hash over `key` followed by the
/// little-endian bytes of the block index `i`; the final block is truncated
/// to fit `length`.
///
/// This is NOT cryptographically secure — real FHE would use the actual FHE
/// encryption scheme.
fn derive_keystream(key: &[u8], length: usize) -> Vec<u8> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;

    /// Fold `bytes` into an FNV-1a hash state (xor the byte, then multiply).
    fn absorb(state: u64, bytes: &[u8]) -> u64 {
        bytes
            .iter()
            .fold(state, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
    }

    // The key prefix is identical for every block, so hash it once and
    // resume from that state for each block index.
    let key_state = absorb(FNV_OFFSET_BASIS, key);

    let mut keystream = Vec::with_capacity(length);
    let mut block_index: u64 = 0;

    while keystream.len() < length {
        let block = absorb(key_state, &block_index.to_le_bytes()).to_le_bytes();

        // Take the whole block, or only what is still needed for the last one.
        let wanted = (length - keystream.len()).min(block.len());
        keystream.extend_from_slice(&block[..wanted]);

        block_index += 1;
    }

    keystream
}
893
894/// Check whether a circuit node is an encrypted constant
895pub fn is_encrypted_constant(node: &CircuitNode) -> bool {
896    matches!(node, CircuitNode::EncryptedConstant { .. })
897}
898
899/// Count the number of plaintext constants in a circuit tree
900pub fn count_plaintext_constants(node: &CircuitNode) -> usize {
901    match node {
902        CircuitNode::Constant(_) => 1,
903        CircuitNode::EncryptedConstant { .. } | CircuitNode::Load(_) => 0,
904        CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
905            count_plaintext_constants(left) + count_plaintext_constants(right)
906        }
907        CircuitNode::UnaryOp { operand, .. } => count_plaintext_constants(operand),
908        CircuitNode::NaryOp { operands, .. } => {
909            operands.iter().map(count_plaintext_constants).sum()
910        }
911    }
912}
913
914/// Count the number of encrypted constants in a circuit tree
915pub fn count_encrypted_constants(node: &CircuitNode) -> usize {
916    match node {
917        CircuitNode::EncryptedConstant { .. } => 1,
918        CircuitNode::Constant(_) | CircuitNode::Load(_) => 0,
919        CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
920            count_encrypted_constants(left) + count_encrypted_constants(right)
921        }
922        CircuitNode::UnaryOp { operand, .. } => count_encrypted_constants(operand),
923        CircuitNode::NaryOp { operands, .. } => {
924            operands.iter().map(count_encrypted_constants).sum()
925        }
926    }
927}
928
929/// Basic circuit optimizer for backward compatibility
930///
931/// This is a legacy optimizer kept for backward compatibility.
932/// For advanced optimizations, use the `optimizer` module instead.
933#[derive(Debug, Clone, Default)]
934#[deprecated(
935    since = "0.1.0",
936    note = "Use CircuitOptimizer from optimizer module instead"
937)]
938pub struct CircuitOptimizer;
939
940#[allow(deprecated)]
941impl CircuitOptimizer {
942    pub fn new() -> Self {
943        Self
944    }
945
946    /// Optimize circuit by applying basic optimization passes
947    ///
948    /// For advanced optimizations including bootstrap minimization,
949    /// gate fusion, and parallelization analysis, use the optimizer module.
950    pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
951        // Delegate to the advanced optimizer
952        let mut advanced_optimizer = crate::compute::optimizer::CircuitOptimizer::new();
953        advanced_optimizer.optimize(circuit)
954    }
955}
956
#[cfg(test)]
mod tests {
    use super::*;

    // ── Builder, type inference and optimizer tests ───────────────────

    /// Adding two declared `U8` variables yields a `U8` circuit with one gate.
    #[test]
    fn test_circuit_builder() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let sum = builder.add(a, b);

        let circuit = builder.build(sum)?;
        assert_eq!(circuit.result_type, EncryptedType::U8);
        assert_eq!(circuit.gate_count, 1);

        Ok(())
    }

    /// AND of two `Bool` variables infers a `Bool` result.
    #[test]
    fn test_type_inference() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x = builder.load("x");
        let y = builder.load("y");
        let result = builder.and(x, y);

        let circuit = builder.build(result)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);

        Ok(())
    }

    /// Adding a `U8` to a `Bool` must be rejected at build time.
    #[test]
    fn test_type_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::Bool);

        let a = builder.load("a");
        let b = builder.load("b");
        let invalid = builder.add(a, b);

        let result = builder.build(invalid);
        assert!(result.is_err());
    }

    /// The legacy optimizer still folds `5u8 + 3u8` into the constant `8u8`.
    #[test]
    #[allow(deprecated)]
    fn test_constant_folding() -> Result<()> {
        let optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        let circuit = Circuit::new(sum, HashMap::new())?;
        let optimized = optimizer.optimize(circuit)?;

        // Should fold to constant 8
        match optimized.root {
            CircuitNode::Constant(CircuitValue::U8(8)) => Ok(()),
            _ => Err(AmateRSError::FheComputation(ErrorContext::new(
                "Constant folding failed".to_string(),
            ))),
        }
    }

    /// `(a + b) + c` has depth 3.
    #[test]
    fn test_circuit_depth() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8)
            .declare_variable("c", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let c = builder.load("c");

        // (a + b) + c
        let sum1 = builder.add(a, b);
        let sum2 = builder.add(sum1, c);

        let circuit = builder.build(sum2)?;
        assert_eq!(circuit.depth, 3); // Load -> Add -> Add

        Ok(())
    }

    // ── Encrypted constant tests ──────────────────────────────────────

    /// The builder stores the ciphertext bytes and type tag unchanged.
    #[test]
    fn test_encrypted_constant_creation() {
        let builder = CircuitBuilder::new();
        let enc = builder.encrypted_constant(vec![0xAA, 0xBB], ConstantType::Integer);

        match enc {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, vec![0xAA, 0xBB]);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant"),
        }
    }

    /// Bool values round-trip and carry type tag `0x00`.
    #[test]
    fn test_encrypt_constant_bool() -> Result<()> {
        let key = b"test-encryption-key";
        let value = CircuitValue::Bool(true);

        let encrypted = encrypt_constant(&value, key)?;
        assert!(!encrypted.is_empty());
        // Type tag should be 0x00 for Bool
        assert_eq!(encrypted[0], 0x00);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// U8 values round-trip and carry type tag `0x01`.
    #[test]
    fn test_encrypt_constant_u8() -> Result<()> {
        let key = b"test-key-u8";
        let value = CircuitValue::U8(42);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x01); // Type tag for U8

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// U16 values round-trip and carry type tag `0x02`.
    #[test]
    fn test_encrypt_constant_u16() -> Result<()> {
        let key = b"test-key-u16";
        let value = CircuitValue::U16(12345);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x02);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// U32 values round-trip and carry type tag `0x03`.
    #[test]
    fn test_encrypt_constant_u32() -> Result<()> {
        let key = b"test-key-u32";
        let value = CircuitValue::U32(1_000_000);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x03);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// U64 values round-trip and carry type tag `0x04`.
    #[test]
    fn test_encrypt_constant_u64() -> Result<()> {
        let key = b"test-key-u64";
        let value = CircuitValue::U64(0xDEAD_BEEF_CAFE_BABE);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x04);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    /// Decrypting with a different key succeeds but yields a different value.
    #[test]
    fn test_encrypt_constant_wrong_key_produces_wrong_value() -> Result<()> {
        let key1 = b"correct-key";
        let key2 = b"wrong-key!!";
        let value = CircuitValue::U8(42);

        let encrypted = encrypt_constant(&value, key1)?;
        let decrypted = decrypt_constant(&encrypted, key2)?;
        // With the wrong key, we get a different value (XOR-based, so it decrypts but wrongly)
        assert_ne!(decrypted, value);

        Ok(())
    }

    /// An empty key is rejected by `encrypt_constant`.
    #[test]
    fn test_encrypt_constant_empty_key_error() {
        let key: &[u8] = &[];
        let value = CircuitValue::U8(1);

        let result = encrypt_constant(&value, key);
        assert!(result.is_err());
    }

    /// Empty ciphertext is rejected by `decrypt_constant`.
    #[test]
    fn test_decrypt_constant_empty_data_error() {
        let key = b"some-key";
        let result = decrypt_constant(&[], key);
        assert!(result.is_err());
    }

    /// Every plaintext constant in a tree is replaced by an encrypted one.
    #[test]
    fn test_encrypt_circuit_constants_transforms_all() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"circuit-encryption-key";

        // Build: Constant(5u8) + Constant(3u8)
        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        // Before encryption: 2 plaintext constants, 0 encrypted
        assert_eq!(count_plaintext_constants(&sum), 2);
        assert_eq!(count_encrypted_constants(&sum), 0);

        // Encrypt
        let encrypted = encrypt_circuit_constants(&sum, key)?;

        // After encryption: 0 plaintext constants, 2 encrypted
        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);

        // Verify structure is preserved (BinaryOp Add with two EncryptedConstant children)
        match &encrypted {
            CircuitNode::BinaryOp { op, left, right } => {
                assert_eq!(*op, BinaryOperator::Add);
                assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp after encryption"),
        }

        Ok(())
    }

    /// `Load` nodes are left untouched while sibling constants are encrypted.
    #[test]
    fn test_encrypt_circuit_constants_preserves_loads() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::U8);
        let key = b"key-for-loads-test";

        // Build: Load("x") + Constant(10u8)
        let x = builder.load("x");
        let c = builder.constant(CircuitValue::U8(10));
        let sum = builder.add(x, c);

        let encrypted = encrypt_circuit_constants(&sum, key)?;

        // Load should be preserved, Constant should become EncryptedConstant
        match &encrypted {
            CircuitNode::BinaryOp { left, right, .. } => {
                assert!(matches!(**left, CircuitNode::Load(ref name) if name == "x"));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp"),
        }

        Ok(())
    }

    /// Constants that are already encrypted are not encrypted a second time.
    #[test]
    fn test_encrypt_circuit_constants_already_encrypted_pass_through() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"key-pass-through";

        // Create an already-encrypted constant
        let enc = builder.encrypted_constant(vec![0x01, 0x02, 0x03], ConstantType::Integer);
        let original_data = vec![0x01, 0x02, 0x03];

        let result = encrypt_circuit_constants(&enc, key)?;

        // The encrypted constant should pass through unchanged
        match result {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, original_data);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant pass-through"),
        }

        Ok(())
    }

    // ── Display and ConstantType tests ────────────────────────────────

    /// Display of an encrypted constant names the variant, type and byte length.
    #[test]
    fn test_encrypted_constant_display() {
        let node = CircuitNode::EncryptedConstant {
            data: vec![0xAA, 0xBB, 0xCC],
            original_type: ConstantType::Boolean,
        };
        let display = format!("{}", node);
        assert!(display.contains("EncryptedConst"));
        assert!(display.contains("boolean"));
        assert!(display.contains("3 bytes"));
    }

    /// Display output for the leaf node variants.
    #[test]
    fn test_circuit_node_display_variants() {
        // Load
        let load = CircuitNode::Load("x".to_string());
        assert_eq!(format!("{}", load), "Load(x)");

        // Constant
        let constant = CircuitNode::Constant(CircuitValue::U8(42));
        assert_eq!(format!("{}", constant), "Const(42u8)");

        // Bool constant
        let bool_const = CircuitNode::Constant(CircuitValue::Bool(true));
        assert_eq!(format!("{}", bool_const), "Const(true)");
    }

    /// Display output for each `ConstantType` variant.
    #[test]
    fn test_constant_type_display() {
        assert_eq!(format!("{}", ConstantType::Integer), "integer");
        assert_eq!(format!("{}", ConstantType::Boolean), "boolean");
        assert_eq!(format!("{}", ConstantType::Float), "float");
        assert_eq!(format!("{}", ConstantType::Bytes), "bytes");
    }

    /// All `ConstantType` variants compare equal to themselves and unequal to the others.
    #[test]
    fn test_constant_type_variants() {
        // Verify all variants exist and are distinct
        let variants = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b);
                } else {
                    assert_ne!(a, b);
                }
            }
        }
    }

    /// Each `ConstantType` survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_constant_type_serialization_roundtrip() -> Result<()> {
        let types = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];

        for ct in &types {
            let json = serde_json::to_string(ct).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Serialization failed: {}",
                    e
                )))
            })?;
            let deserialized: ConstantType = serde_json::from_str(&json).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Deserialization failed: {}",
                    e
                )))
            })?;
            assert_eq!(*ct, deserialized);
        }

        Ok(())
    }

    // ── Tree inspection helpers ───────────────────────────────────────

    /// `is_encrypted_constant` is true only for the `EncryptedConstant` variant.
    #[test]
    fn test_is_encrypted_constant() {
        let enc = CircuitNode::EncryptedConstant {
            data: vec![1, 2, 3],
            original_type: ConstantType::Integer,
        };
        assert!(is_encrypted_constant(&enc));

        let plain = CircuitNode::Constant(CircuitValue::U8(5));
        assert!(!is_encrypted_constant(&plain));

        let load = CircuitNode::Load("x".to_string());
        assert!(!is_encrypted_constant(&load));
    }

    /// A lone encrypted constant is a valid circuit whose type follows its tag.
    #[test]
    fn test_encrypted_constant_in_circuit_validation() -> Result<()> {
        // EncryptedConstant should pass validation
        let enc = CircuitNode::EncryptedConstant {
            data: vec![0x00, 0x01],
            original_type: ConstantType::Boolean,
        };
        let circuit = Circuit::new(enc, HashMap::new())?;
        circuit.validate()?;
        // EncryptedConstant with Boolean type infers to Bool
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    /// An encrypted constant is a leaf: depth 1 and no gates.
    #[test]
    fn test_encrypted_constant_depth_and_gate_count() -> Result<()> {
        let builder = CircuitBuilder::new();

        // EncryptedConstant has depth 1 and gate count 0 (same as Constant/Load)
        let enc = builder.encrypted_constant(vec![0x01, 0x42], ConstantType::Integer);
        let circuit = Circuit::new(enc, HashMap::new())?;
        assert_eq!(circuit.depth, 1);
        assert_eq!(circuit.gate_count, 0);

        Ok(())
    }

    /// Counting works on individual nodes and on a tree that mixes plaintext
    /// and encrypted constants.
    #[test]
    fn test_mixed_plain_and_encrypted_constants() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"mixed-circuit-key";

        let plain = builder.constant(CircuitValue::U8(10));
        let encrypted_data = encrypt_constant(&CircuitValue::U8(20), key)?;
        let enc = builder.encrypted_constant(encrypted_data, ConstantType::Integer);

        let not_node = CircuitNode::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(CircuitNode::Constant(CircuitValue::Bool(true))),
        };

        // Each node on its own.
        assert_eq!(count_plaintext_constants(&plain), 1);
        assert_eq!(count_encrypted_constants(&plain), 0);
        assert_eq!(count_plaintext_constants(&enc), 0);
        assert_eq!(count_encrypted_constants(&enc), 1);
        assert_eq!(count_plaintext_constants(&not_node), 1);
        assert_eq!(count_encrypted_constants(&not_node), 0);

        // A genuinely mixed tree: plain + enc. The node is assembled directly
        // rather than through the type-checking build path, because the
        // counting helpers only look at the tree shape.
        let mixed = CircuitNode::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(plain),
            right: Box::new(enc),
        };
        assert_eq!(count_plaintext_constants(&mixed), 1);
        assert_eq!(count_encrypted_constants(&mixed), 1);

        Ok(())
    }

    // ── Encryption properties ─────────────────────────────────────────

    /// Same key and value always give the same ciphertext.
    #[test]
    fn test_encrypt_constant_deterministic() -> Result<()> {
        let key = b"deterministic-test-key";
        let value = CircuitValue::U32(999);

        let enc1 = encrypt_constant(&value, key)?;
        let enc2 = encrypt_constant(&value, key)?;
        // Same key + same value => same ciphertext (deterministic)
        assert_eq!(enc1, enc2);

        Ok(())
    }

    /// Different keys give different payloads under the same type tag.
    #[test]
    fn test_encrypt_constant_different_keys_differ() -> Result<()> {
        let key1 = b"key-alpha";
        let key2 = b"key-bravo";
        let value = CircuitValue::U64(123456789);

        let enc1 = encrypt_constant(&value, key1)?;
        let enc2 = encrypt_constant(&value, key2)?;
        // Different keys should produce different ciphertext (with high probability)
        // Type tag (first byte) is the same, but payload differs
        assert_eq!(enc1[0], enc2[0]); // Same type tag
        assert_ne!(enc1[1..], enc2[1..]); // Different payload

        Ok(())
    }

    /// Boundary values of every supported type survive encrypt → decrypt.
    #[test]
    fn test_encrypt_decrypt_roundtrip_all_types() -> Result<()> {
        let key = b"roundtrip-all-types";

        let values = vec![
            CircuitValue::Bool(false),
            CircuitValue::Bool(true),
            CircuitValue::U8(0),
            CircuitValue::U8(255),
            CircuitValue::U16(0),
            CircuitValue::U16(65535),
            CircuitValue::U32(0),
            CircuitValue::U32(u32::MAX),
            CircuitValue::U64(0),
            CircuitValue::U64(u64::MAX),
        ];

        for value in &values {
            let encrypted = encrypt_constant(value, key)?;
            let decrypted = decrypt_constant(&encrypted, key)?;
            assert_eq!(*value, decrypted, "Roundtrip failed for {:?}", value);
        }

        Ok(())
    }

    /// Constants nested below several operators are all encrypted and the
    /// operator structure is preserved.
    #[test]
    fn test_encrypt_circuit_constants_nested() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"nested-circuit-key";

        // Build: NOT(Constant(true) AND Constant(false))
        let t = builder.constant(CircuitValue::Bool(true));
        let f = builder.constant(CircuitValue::Bool(false));
        let and_node = builder.and(t, f);
        let not_node = builder.not(and_node);

        assert_eq!(count_plaintext_constants(&not_node), 2);
        assert_eq!(count_encrypted_constants(&not_node), 0);

        let encrypted = encrypt_circuit_constants(&not_node, key)?;

        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);

        // Verify structure: NOT(AND(EncryptedConstant, EncryptedConstant))
        match &encrypted {
            CircuitNode::UnaryOp { op, operand } => {
                assert_eq!(*op, UnaryOperator::Not);
                match operand.as_ref() {
                    CircuitNode::BinaryOp { op, left, right } => {
                        assert_eq!(*op, BinaryOperator::And);
                        assert!(is_encrypted_constant(left));
                        assert!(is_encrypted_constant(right));
                    }
                    _ => panic!("Expected BinaryOp inside UnaryOp"),
                }
            }
            _ => panic!("Expected UnaryOp at root"),
        }

        Ok(())
    }

    /// The keystream has exactly the requested length, is deterministic, and
    /// a shorter request is a prefix of a longer one for the same key.
    #[test]
    fn test_derive_keystream_length_and_determinism() {
        let key = b"keystream-test-key";

        assert!(derive_keystream(key, 0).is_empty());

        let long = derive_keystream(key, 21);
        assert_eq!(long.len(), 21);
        assert_eq!(long, derive_keystream(key, 21));
        assert_eq!(derive_keystream(key, 5).as_slice(), &long[..5]);
    }
}