Skip to main content

amaters_core/compute/
circuit.rs

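//! Circuit compilation and optimization
//!
//! This module provides the circuit AST representation, a programmatic
//! circuit builder, type inference, helpers for encrypting circuit constants,
//! and a deprecated optimizer entry point that delegates to the `optimizer`
//! module, all for FHE operations.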
5
6use crate::error::{AmateRSError, ErrorContext, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Type tag for encrypted constants, indicating the original plaintext type
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum ConstantType {
13    /// Original value was an integer (u8, u16, u32, u64)
14    Integer,
15    /// Original value was a boolean
16    Boolean,
17    /// Original value was a floating-point number
18    Float,
19    /// Original value was raw bytes
20    Bytes,
21}
22
23impl std::fmt::Display for ConstantType {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            ConstantType::Integer => write!(f, "integer"),
27            ConstantType::Boolean => write!(f, "boolean"),
28            ConstantType::Float => write!(f, "float"),
29            ConstantType::Bytes => write!(f, "bytes"),
30        }
31    }
32}
33
/// Circuit AST node representing FHE operations
///
/// A circuit is a tree of these nodes. `Load`, `Constant` and
/// `EncryptedConstant` are leaves; every other variant owns its operand
/// subtrees through `Box`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitNode {
    /// Load a variable by name
    ///
    /// The name is looked up in the circuit's declared variable types; an
    /// undeclared name is rejected by `Circuit::new` and `Circuit::validate`.
    Load(String),

    /// Constant value (plaintext)
    Constant(CircuitValue),

    /// Encrypted constant value (ciphertext form, opaque to the optimizer)
    ///
    /// This variant holds a constant that has already been encrypted for use
    /// in FHE evaluation. The optimizer must NOT attempt to constant-fold or
    /// simplify encrypted constants since their plaintext values are unknown.
    EncryptedConstant {
        /// Encrypted ciphertext data
        ///
        /// When produced by `encrypt_constant`, the first byte is an
        /// unencrypted type tag and the remaining bytes are the ciphertext.
        data: Vec<u8>,
        /// The type of the original plaintext value before encryption
        original_type: ConstantType,
    },

    /// Binary operation
    ///
    /// Arithmetic (`Add`/`Sub`/`Mul`) or logical (`And`/`Or`/`Xor`) operator
    /// applied to two operand subtrees.
    BinaryOp {
        /// Operator to apply
        op: BinaryOperator,
        /// Left operand subtree
        left: Box<CircuitNode>,
        /// Right operand subtree
        right: Box<CircuitNode>,
    },

    /// Unary operation
    UnaryOp {
        /// Operator to apply (`Not` or `Neg`)
        op: UnaryOperator,
        /// Operand subtree
        operand: Box<CircuitNode>,
    },

    /// Comparison operation
    ///
    /// Always yields a boolean result during type inference.
    Compare {
        /// Comparison to perform
        op: CompareOperator,
        /// Left operand subtree
        left: Box<CircuitNode>,
        /// Right operand subtree
        right: Box<CircuitNode>,
    },
}
75
/// Binary operators.
///
/// `Add`, `Sub` and `Mul` are arithmetic and require numeric operands of the
/// same type; `And`, `Or` and `Xor` are logical and require boolean operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    /// Addition (`+`)
    Add,
    /// Subtraction (`-`)
    Sub,
    /// Multiplication (`*`)
    Mul,
    /// Logical conjunction (`AND`)
    And,
    /// Logical disjunction (`OR`)
    Or,
    /// Logical exclusive-or (`XOR`)
    Xor,
}

impl BinaryOperator {
    /// Get the string representation used when displaying circuits.
    ///
    /// Every representation is a string literal, so the result is `'static`
    /// and may outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Sub => "-",
            BinaryOperator::Mul => "*",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Xor => "XOR",
        }
    }
}
100
/// Unary operators.
///
/// `Not` requires a boolean operand; `Neg` requires a numeric operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    /// Logical negation (`NOT`)
    Not,
    /// Arithmetic negation (`-`)
    Neg,
}

impl UnaryOperator {
    /// Get the string representation used when displaying circuits.
    ///
    /// Every representation is a string literal, so the result is `'static`
    /// and may outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Neg => "-",
        }
    }
}
117
/// Comparison operators.
///
/// Both operands must have the same type; the result is always boolean.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOperator {
    /// Equal (rendered SQL-style as `=`)
    Eq,
    /// Not equal (`!=`)
    Ne,
    /// Less than (`<`)
    Lt,
    /// Less than or equal (`<=`)
    Le,
    /// Greater than (`>`)
    Gt,
    /// Greater than or equal (`>=`)
    Ge,
}

impl CompareOperator {
    /// Get the string representation used when displaying circuits.
    ///
    /// Every representation is a string literal, so the result is `'static`
    /// and may outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompareOperator::Eq => "=",
            CompareOperator::Ne => "!=",
            CompareOperator::Lt => "<",
            CompareOperator::Le => "<=",
            CompareOperator::Gt => ">",
            CompareOperator::Ge => ">=",
        }
    }
}
142
/// Circuit value types
///
/// Plaintext constant values that may appear in a circuit. During type
/// inference each variant maps to the `EncryptedType` of the same name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CircuitValue {
    /// Boolean constant
    Bool(bool),
    /// Unsigned 8-bit constant
    U8(u8),
    /// Unsigned 16-bit constant
    U16(u16),
    /// Unsigned 32-bit constant
    U32(u32),
    /// Unsigned 64-bit constant
    U64(u64),
}
152
153impl std::fmt::Display for CircuitNode {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        match self {
156            CircuitNode::Load(name) => write!(f, "Load({})", name),
157            CircuitNode::Constant(value) => match value {
158                CircuitValue::Bool(v) => write!(f, "Const({})", v),
159                CircuitValue::U8(v) => write!(f, "Const({}u8)", v),
160                CircuitValue::U16(v) => write!(f, "Const({}u16)", v),
161                CircuitValue::U32(v) => write!(f, "Const({}u32)", v),
162                CircuitValue::U64(v) => write!(f, "Const({}u64)", v),
163            },
164            CircuitNode::EncryptedConstant {
165                data,
166                original_type,
167            } => {
168                write!(f, "EncryptedConst({}, {} bytes)", original_type, data.len())
169            }
170            CircuitNode::BinaryOp { op, left, right } => {
171                write!(f, "({} {} {})", left, op.as_str(), right)
172            }
173            CircuitNode::UnaryOp { op, operand } => {
174                write!(f, "{}({})", op.as_str(), operand)
175            }
176            CircuitNode::Compare { op, left, right } => {
177                write!(f, "({} {} {})", left, op.as_str(), right)
178            }
179        }
180    }
181}
182
/// Type of an encrypted value flowing through a circuit.
///
/// `Bool` is the only boolean type; the remaining variants are unsigned
/// integers of the named width.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EncryptedType {
    Bool,
    U8,
    U16,
    U32,
    U64,
}

impl EncryptedType {
    /// Number of plaintext bits represented by this type (`Bool` is 1 bit).
    pub fn bit_width(&self) -> usize {
        match *self {
            Self::Bool => 1,
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => 64,
        }
    }

    /// `true` for every integer type, i.e. anything other than `Bool`.
    pub fn is_numeric(&self) -> bool {
        !self.is_boolean()
    }

    /// `true` only for `Bool`.
    pub fn is_boolean(&self) -> bool {
        *self == Self::Bool
    }
}

impl std::fmt::Display for EncryptedType {
    /// Writes the Rust-style type name (`bool`, `u8`, `u16`, `u32`, `u64`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Bool => "bool",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
        })
    }
}
227
/// Circuit representation with metadata
///
/// Built by `Circuit::new` (or `CircuitBuilder::build`), which type-checks
/// `root` and fills in `result_type`, `depth` and `gate_count`.
/// NOTE(review): all fields are `pub`, so mutating them afterwards is not
/// re-checked and can leave the derived fields out of sync with `root`.
#[derive(Debug, Clone)]
pub struct Circuit {
    /// Root node of the circuit
    pub root: CircuitNode,

    /// Type information for variables, keyed by the name used in `Load` nodes
    pub variable_types: HashMap<String, EncryptedType>,

    /// Inferred result type
    pub result_type: EncryptedType,

    /// Circuit depth (for complexity estimation)
    ///
    /// Leaves count as depth 1, so `(a + b) + c` has depth 3.
    pub depth: usize,

    /// Number of gates (for complexity estimation)
    ///
    /// Counts `BinaryOp`, `UnaryOp` and `Compare` nodes; leaves count zero.
    pub gate_count: usize,
}
246
247impl Circuit {
248    /// Create a new circuit from a root node
249    pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
250        let result_type = Self::infer_type(&root, &variable_types)?;
251        let depth = Self::compute_depth(&root);
252        let gate_count = Self::count_gates(&root);
253
254        Ok(Self {
255            root,
256            variable_types,
257            result_type,
258            depth,
259            gate_count,
260        })
261    }
262
263    /// Infer the type of a circuit node
264    fn infer_type(
265        node: &CircuitNode,
266        variable_types: &HashMap<String, EncryptedType>,
267    ) -> Result<EncryptedType> {
268        match node {
269            CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
270                AmateRSError::FheComputation(ErrorContext::new(format!(
271                    "Undefined variable: {}",
272                    name
273                )))
274            }),
275
276            CircuitNode::Constant(value) => Ok(match value {
277                CircuitValue::Bool(_) => EncryptedType::Bool,
278                CircuitValue::U8(_) => EncryptedType::U8,
279                CircuitValue::U16(_) => EncryptedType::U16,
280                CircuitValue::U32(_) => EncryptedType::U32,
281                CircuitValue::U64(_) => EncryptedType::U64,
282            }),
283
284            CircuitNode::EncryptedConstant { original_type, .. } => {
285                Ok(match original_type {
286                    ConstantType::Boolean => EncryptedType::Bool,
287                    // For non-boolean encrypted constants, we default to U64
288                    // since the exact width is not recoverable from the encrypted data.
289                    // In practice, users should ensure type consistency.
290                    ConstantType::Integer | ConstantType::Float | ConstantType::Bytes => {
291                        EncryptedType::U64
292                    }
293                })
294            }
295
296            CircuitNode::BinaryOp { op, left, right } => {
297                let left_type = Self::infer_type(left, variable_types)?;
298                let right_type = Self::infer_type(right, variable_types)?;
299
300                match op {
301                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
302                        if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
303                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
304                                "Logical operation requires boolean operands, got {} and {}",
305                                left_type, right_type
306                            ))));
307                        }
308                        Ok(EncryptedType::Bool)
309                    }
310
311                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
312                        if !left_type.is_numeric() || !right_type.is_numeric() {
313                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
314                                "Arithmetic operation requires numeric operands, got {} and {}",
315                                left_type, right_type
316                            ))));
317                        }
318
319                        if left_type != right_type {
320                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
321                                "Arithmetic operation requires matching types, got {} and {}",
322                                left_type, right_type
323                            ))));
324                        }
325
326                        Ok(left_type)
327                    }
328                }
329            }
330
331            CircuitNode::UnaryOp { op, operand } => {
332                let operand_type = Self::infer_type(operand, variable_types)?;
333
334                match op {
335                    UnaryOperator::Not => {
336                        if operand_type != EncryptedType::Bool {
337                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
338                                "NOT operation requires boolean operand, got {}",
339                                operand_type
340                            ))));
341                        }
342                        Ok(EncryptedType::Bool)
343                    }
344
345                    UnaryOperator::Neg => {
346                        if !operand_type.is_numeric() {
347                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
348                                "Negation operation requires numeric operand, got {}",
349                                operand_type
350                            ))));
351                        }
352                        Ok(operand_type)
353                    }
354                }
355            }
356
357            CircuitNode::Compare { left, right, .. } => {
358                let left_type = Self::infer_type(left, variable_types)?;
359                let right_type = Self::infer_type(right, variable_types)?;
360
361                if left_type != right_type {
362                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
363                        "Comparison requires matching types, got {} and {}",
364                        left_type, right_type
365                    ))));
366                }
367
368                Ok(EncryptedType::Bool)
369            }
370        }
371    }
372
373    /// Compute the depth of the circuit
374    fn compute_depth(node: &CircuitNode) -> usize {
375        match node {
376            CircuitNode::Load(_)
377            | CircuitNode::Constant(_)
378            | CircuitNode::EncryptedConstant { .. } => 1,
379
380            CircuitNode::BinaryOp { left, right, .. }
381            | CircuitNode::Compare { left, right, .. } => {
382                1 + Self::compute_depth(left).max(Self::compute_depth(right))
383            }
384
385            CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
386        }
387    }
388
389    /// Count the number of gates in the circuit
390    fn count_gates(node: &CircuitNode) -> usize {
391        match node {
392            CircuitNode::Load(_)
393            | CircuitNode::Constant(_)
394            | CircuitNode::EncryptedConstant { .. } => 0,
395
396            CircuitNode::BinaryOp { left, right, .. }
397            | CircuitNode::Compare { left, right, .. } => {
398                1 + Self::count_gates(left) + Self::count_gates(right)
399            }
400
401            CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
402        }
403    }
404
405    /// Validate the circuit for correctness
406    pub fn validate(&self) -> Result<()> {
407        Self::validate_node(&self.root, &self.variable_types)?;
408        Ok(())
409    }
410
411    fn validate_node(
412        node: &CircuitNode,
413        variable_types: &HashMap<String, EncryptedType>,
414    ) -> Result<()> {
415        match node {
416            CircuitNode::Load(name) => {
417                if !variable_types.contains_key(name) {
418                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
419                        "Undefined variable: {}",
420                        name
421                    ))));
422                }
423                Ok(())
424            }
425
426            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => Ok(()),
427
428            CircuitNode::BinaryOp { left, right, .. }
429            | CircuitNode::Compare { left, right, .. } => {
430                Self::validate_node(left, variable_types)?;
431                Self::validate_node(right, variable_types)?;
432                Ok(())
433            }
434
435            CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
436        }
437    }
438}
439
440/// Circuit builder for constructing circuits programmatically
441#[derive(Default)]
442pub struct CircuitBuilder {
443    variable_types: HashMap<String, EncryptedType>,
444}
445
446impl CircuitBuilder {
447    pub fn new() -> Self {
448        Self::default()
449    }
450
451    /// Get the variable types map
452    pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
453        &self.variable_types
454    }
455
456    /// Clone the variable types map
457    pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
458        self.variable_types.clone()
459    }
460
461    /// Declare a variable with its type
462    pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
463        self.variable_types.insert(name.into(), ty);
464        self
465    }
466
467    /// Build the circuit from a root node
468    pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
469        Circuit::new(root, self.variable_types.clone())
470    }
471
472    /// Create a load node
473    pub fn load(&self, name: impl Into<String>) -> CircuitNode {
474        CircuitNode::Load(name.into())
475    }
476
477    /// Create a constant node
478    pub fn constant(&self, value: CircuitValue) -> CircuitNode {
479        CircuitNode::Constant(value)
480    }
481
482    /// Create an addition node
483    pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
484        CircuitNode::BinaryOp {
485            op: BinaryOperator::Add,
486            left: Box::new(left),
487            right: Box::new(right),
488        }
489    }
490
491    /// Create a subtraction node
492    pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
493        CircuitNode::BinaryOp {
494            op: BinaryOperator::Sub,
495            left: Box::new(left),
496            right: Box::new(right),
497        }
498    }
499
500    /// Create a multiplication node
501    pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
502        CircuitNode::BinaryOp {
503            op: BinaryOperator::Mul,
504            left: Box::new(left),
505            right: Box::new(right),
506        }
507    }
508
509    /// Create an AND node
510    pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
511        CircuitNode::BinaryOp {
512            op: BinaryOperator::And,
513            left: Box::new(left),
514            right: Box::new(right),
515        }
516    }
517
518    /// Create an OR node
519    pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
520        CircuitNode::BinaryOp {
521            op: BinaryOperator::Or,
522            left: Box::new(left),
523            right: Box::new(right),
524        }
525    }
526
527    /// Create an XOR node
528    pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
529        CircuitNode::BinaryOp {
530            op: BinaryOperator::Xor,
531            left: Box::new(left),
532            right: Box::new(right),
533        }
534    }
535
536    /// Create a NOT node
537    pub fn not(&self, operand: CircuitNode) -> CircuitNode {
538        CircuitNode::UnaryOp {
539            op: UnaryOperator::Not,
540            operand: Box::new(operand),
541        }
542    }
543
544    /// Create an equality comparison node
545    pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
546        CircuitNode::Compare {
547            op: CompareOperator::Eq,
548            left: Box::new(left),
549            right: Box::new(right),
550        }
551    }
552
553    /// Create a less-than comparison node
554    pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
555        CircuitNode::Compare {
556            op: CompareOperator::Lt,
557            left: Box::new(left),
558            right: Box::new(right),
559        }
560    }
561
562    /// Create a greater-than comparison node
563    pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
564        CircuitNode::Compare {
565            op: CompareOperator::Gt,
566            left: Box::new(left),
567            right: Box::new(right),
568        }
569    }
570
571    /// Create an encrypted constant node directly
572    pub fn encrypted_constant(&self, data: Vec<u8>, original_type: ConstantType) -> CircuitNode {
573        CircuitNode::EncryptedConstant {
574            data,
575            original_type,
576        }
577    }
578}
579
580// ---------------------------------------------------------------------------
581// Encrypted constant helpers
582// ---------------------------------------------------------------------------
583
584/// Encrypt a plaintext circuit constant value using a symmetric key.
585///
586/// This uses a simple XOR-based stream cipher derived from the key for
587/// demonstration purposes. In production, this would delegate to the
588/// actual FHE encryption backend (e.g., TFHE key-switch + bootstrap).
589///
590/// The output ciphertext contains a 1-byte type tag followed by the
591/// XOR-encrypted payload so that it can be correctly interpreted during
592/// evaluation.
593pub fn encrypt_constant(value: &CircuitValue, key: &[u8]) -> Result<Vec<u8>> {
594    if key.is_empty() {
595        return Err(AmateRSError::FheComputation(ErrorContext::new(
596            "Encryption key must not be empty".to_string(),
597        )));
598    }
599
600    // Serialize the plaintext value to bytes
601    let (type_tag, plaintext): (u8, Vec<u8>) = match value {
602        CircuitValue::Bool(v) => (0x00, vec![if *v { 1 } else { 0 }]),
603        CircuitValue::U8(v) => (0x01, v.to_le_bytes().to_vec()),
604        CircuitValue::U16(v) => (0x02, v.to_le_bytes().to_vec()),
605        CircuitValue::U32(v) => (0x03, v.to_le_bytes().to_vec()),
606        CircuitValue::U64(v) => (0x04, v.to_le_bytes().to_vec()),
607    };
608
609    // Generate a keystream by repeating and hashing the key material
610    let keystream = derive_keystream(key, plaintext.len());
611
612    // XOR plaintext with keystream
613    let ciphertext: Vec<u8> = plaintext
614        .iter()
615        .zip(keystream.iter())
616        .map(|(p, k)| p ^ k)
617        .collect();
618
619    // Prepend type tag (unencrypted, needed for type inference)
620    let mut output = Vec::with_capacity(1 + ciphertext.len());
621    output.push(type_tag);
622    output.extend_from_slice(&ciphertext);
623
624    Ok(output)
625}
626
627/// Decrypt an encrypted constant back to its plaintext value.
628///
629/// Inverse of [`encrypt_constant`]. Returns an error if the data is
630/// malformed or the key does not match.
631pub fn decrypt_constant(data: &[u8], key: &[u8]) -> Result<CircuitValue> {
632    if key.is_empty() {
633        return Err(AmateRSError::FheComputation(ErrorContext::new(
634            "Decryption key must not be empty".to_string(),
635        )));
636    }
637    if data.is_empty() {
638        return Err(AmateRSError::FheComputation(ErrorContext::new(
639            "Encrypted constant data is empty".to_string(),
640        )));
641    }
642
643    let type_tag = data[0];
644    let ciphertext = &data[1..];
645
646    let keystream = derive_keystream(key, ciphertext.len());
647    let plaintext: Vec<u8> = ciphertext
648        .iter()
649        .zip(keystream.iter())
650        .map(|(c, k)| c ^ k)
651        .collect();
652
653    match type_tag {
654        0x00 => {
655            if plaintext.is_empty() {
656                return Err(AmateRSError::FheComputation(ErrorContext::new(
657                    "Encrypted boolean constant has no payload".to_string(),
658                )));
659            }
660            Ok(CircuitValue::Bool(plaintext[0] != 0))
661        }
662        0x01 => {
663            let arr: [u8; 1] = plaintext.as_slice().try_into().map_err(|_| {
664                AmateRSError::FheComputation(ErrorContext::new(
665                    "Invalid encrypted u8 constant length".to_string(),
666                ))
667            })?;
668            Ok(CircuitValue::U8(u8::from_le_bytes(arr)))
669        }
670        0x02 => {
671            let arr: [u8; 2] = plaintext.as_slice().try_into().map_err(|_| {
672                AmateRSError::FheComputation(ErrorContext::new(
673                    "Invalid encrypted u16 constant length".to_string(),
674                ))
675            })?;
676            Ok(CircuitValue::U16(u16::from_le_bytes(arr)))
677        }
678        0x03 => {
679            let arr: [u8; 4] = plaintext.as_slice().try_into().map_err(|_| {
680                AmateRSError::FheComputation(ErrorContext::new(
681                    "Invalid encrypted u32 constant length".to_string(),
682                ))
683            })?;
684            Ok(CircuitValue::U32(u32::from_le_bytes(arr)))
685        }
686        0x04 => {
687            let arr: [u8; 8] = plaintext.as_slice().try_into().map_err(|_| {
688                AmateRSError::FheComputation(ErrorContext::new(
689                    "Invalid encrypted u64 constant length".to_string(),
690                ))
691            })?;
692            Ok(CircuitValue::U64(u64::from_le_bytes(arr)))
693        }
694        _ => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
695            "Unknown encrypted constant type tag: 0x{:02x}",
696            type_tag
697        )))),
698    }
699}
700
701/// Recursively walk a circuit tree and encrypt all `Constant` nodes into
702/// `EncryptedConstant` nodes. This is a pre-processing step before FHE
703/// evaluation to ensure no plaintext constants leak into the encrypted
704/// computation.
705pub fn encrypt_circuit_constants(node: &CircuitNode, key: &[u8]) -> Result<CircuitNode> {
706    match node {
707        CircuitNode::Load(name) => Ok(CircuitNode::Load(name.clone())),
708
709        CircuitNode::Constant(value) => {
710            let data = encrypt_constant(value, key)?;
711            let original_type = match value {
712                CircuitValue::Bool(_) => ConstantType::Boolean,
713                CircuitValue::U8(_)
714                | CircuitValue::U16(_)
715                | CircuitValue::U32(_)
716                | CircuitValue::U64(_) => ConstantType::Integer,
717            };
718            Ok(CircuitNode::EncryptedConstant {
719                data,
720                original_type,
721            })
722        }
723
724        // Already encrypted — pass through
725        CircuitNode::EncryptedConstant {
726            data,
727            original_type,
728        } => Ok(CircuitNode::EncryptedConstant {
729            data: data.clone(),
730            original_type: *original_type,
731        }),
732
733        CircuitNode::BinaryOp { op, left, right } => {
734            let left = encrypt_circuit_constants(left, key)?;
735            let right = encrypt_circuit_constants(right, key)?;
736            Ok(CircuitNode::BinaryOp {
737                op: *op,
738                left: Box::new(left),
739                right: Box::new(right),
740            })
741        }
742
743        CircuitNode::UnaryOp { op, operand } => {
744            let operand = encrypt_circuit_constants(operand, key)?;
745            Ok(CircuitNode::UnaryOp {
746                op: *op,
747                operand: Box::new(operand),
748            })
749        }
750
751        CircuitNode::Compare { op, left, right } => {
752            let left = encrypt_circuit_constants(left, key)?;
753            let right = encrypt_circuit_constants(right, key)?;
754            Ok(CircuitNode::Compare {
755                op: *op,
756                left: Box::new(left),
757                right: Box::new(right),
758            })
759        }
760    }
761}
762
/// Derive a deterministic keystream of `length` bytes from `key` for the
/// demo XOR cipher.
///
/// Block `i` (for `i = 0, 1, 2, ...`) is the 8 little-endian bytes of a
/// 64-bit FNV-1a hash over `key` followed by `i.to_le_bytes()`. Blocks are
/// concatenated and the last one is truncated to `length`.
///
/// This is NOT cryptographically secure: the same key always yields the same
/// keystream. Real FHE would use the actual FHE encryption scheme.
fn derive_keystream(key: &[u8], length: usize) -> Vec<u8> {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    /// Feed `bytes` into an FNV-1a state (xor each byte, then multiply).
    fn absorb(state: u64, bytes: &[u8]) -> u64 {
        bytes
            .iter()
            .fold(state, |h, &b| (h ^ u64::from(b)).wrapping_mul(FNV_PRIME))
    }

    // The key prefix is hashed identically for every block, so absorb it once
    // instead of once per 8 output bytes.
    let key_state = absorb(FNV_OFFSET_BASIS, key);

    let mut keystream = Vec::with_capacity(length);
    let mut block_index: u64 = 0;
    while keystream.len() < length {
        let block = absorb(key_state, &block_index.to_le_bytes()).to_le_bytes();
        let take = (length - keystream.len()).min(block.len());
        keystream.extend_from_slice(&block[..take]);
        block_index += 1;
    }

    keystream
}
797
798/// Check whether a circuit node is an encrypted constant
799pub fn is_encrypted_constant(node: &CircuitNode) -> bool {
800    matches!(node, CircuitNode::EncryptedConstant { .. })
801}
802
803/// Count the number of plaintext constants in a circuit tree
804pub fn count_plaintext_constants(node: &CircuitNode) -> usize {
805    match node {
806        CircuitNode::Constant(_) => 1,
807        CircuitNode::EncryptedConstant { .. } | CircuitNode::Load(_) => 0,
808        CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
809            count_plaintext_constants(left) + count_plaintext_constants(right)
810        }
811        CircuitNode::UnaryOp { operand, .. } => count_plaintext_constants(operand),
812    }
813}
814
815/// Count the number of encrypted constants in a circuit tree
816pub fn count_encrypted_constants(node: &CircuitNode) -> usize {
817    match node {
818        CircuitNode::EncryptedConstant { .. } => 1,
819        CircuitNode::Constant(_) | CircuitNode::Load(_) => 0,
820        CircuitNode::BinaryOp { left, right, .. } | CircuitNode::Compare { left, right, .. } => {
821            count_encrypted_constants(left) + count_encrypted_constants(right)
822        }
823        CircuitNode::UnaryOp { operand, .. } => count_encrypted_constants(operand),
824    }
825}
826
827/// Basic circuit optimizer for backward compatibility
828///
829/// This is a legacy optimizer kept for backward compatibility.
830/// For advanced optimizations, use the `optimizer` module instead.
831#[derive(Debug, Clone, Default)]
832#[deprecated(
833    since = "0.1.0",
834    note = "Use CircuitOptimizer from optimizer module instead"
835)]
836pub struct CircuitOptimizer;
837
838#[allow(deprecated)]
839impl CircuitOptimizer {
840    pub fn new() -> Self {
841        Self
842    }
843
844    /// Optimize circuit by applying basic optimization passes
845    ///
846    /// For advanced optimizations including bootstrap minimization,
847    /// gate fusion, and parallelization analysis, use the optimizer module.
848    pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
849        // Delegate to the advanced optimizer
850        let mut advanced_optimizer = crate::compute::optimizer::CircuitOptimizer::new();
851        advanced_optimizer.optimize(circuit)
852    }
853}
854
#[cfg(test)]
mod tests {
    //! Unit tests for circuit building, type inference, depth/gate accounting,
    //! the deprecated optimizer shim, and the constant-encryption helpers.

    use super::*;

    #[test]
    fn test_circuit_builder() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let sum = builder.add(a, b);

        let circuit = builder.build(sum)?;
        assert_eq!(circuit.result_type, EncryptedType::U8);
        assert_eq!(circuit.gate_count, 1);

        Ok(())
    }

    #[test]
    fn test_type_inference() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x = builder.load("x");
        let y = builder.load("y");
        let result = builder.and(x, y);

        let circuit = builder.build(result)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);

        Ok(())
    }

    #[test]
    fn test_type_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::Bool);

        let a = builder.load("a");
        let b = builder.load("b");
        // The builder itself does not type-check; the mismatch surfaces in `build`.
        let invalid = builder.add(a, b);

        let result = builder.build(invalid);
        assert!(result.is_err());
    }

    #[test]
    #[allow(deprecated)]
    fn test_constant_folding() -> Result<()> {
        let optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        let circuit = Circuit::new(sum, HashMap::new())?;
        let optimized = optimizer.optimize(circuit)?;

        // 5u8 + 3u8 should fold to the single constant 8u8. Match by reference
        // so the failure message can show what the optimizer actually produced.
        match &optimized.root {
            CircuitNode::Constant(CircuitValue::U8(8)) => Ok(()),
            other => Err(AmateRSError::FheComputation(ErrorContext::new(format!(
                "Constant folding failed: expected Const(8u8), got {:?}",
                other
            )))),
        }
    }

    #[test]
    fn test_circuit_depth() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8)
            .declare_variable("c", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let c = builder.load("c");

        // (a + b) + c
        let sum1 = builder.add(a, b);
        let sum2 = builder.add(sum1, c);

        let circuit = builder.build(sum2)?;
        assert_eq!(circuit.depth, 3); // Load -> Add -> Add

        Ok(())
    }

    // ── Encrypted constant tests ──────────────────────────────────────

    #[test]
    fn test_encrypted_constant_creation() {
        let builder = CircuitBuilder::new();
        let enc = builder.encrypted_constant(vec![0xAA, 0xBB], ConstantType::Integer);

        match enc {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, vec![0xAA, 0xBB]);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant"),
        }
    }

    #[test]
    fn test_encrypt_constant_bool() -> Result<()> {
        let key = b"test-encryption-key";
        let value = CircuitValue::Bool(true);

        let encrypted = encrypt_constant(&value, key)?;
        assert!(!encrypted.is_empty());
        // Type tag should be 0x00 for Bool
        assert_eq!(encrypted[0], 0x00);

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u8() -> Result<()> {
        let key = b"test-key-u8";
        let value = CircuitValue::U8(42);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x01); // Type tag for U8

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u16() -> Result<()> {
        let key = b"test-key-u16";
        let value = CircuitValue::U16(12345);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x02); // Type tag for U16

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u32() -> Result<()> {
        let key = b"test-key-u32";
        let value = CircuitValue::U32(1_000_000);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x03); // Type tag for U32

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_u64() -> Result<()> {
        let key = b"test-key-u64";
        let value = CircuitValue::U64(0xDEAD_BEEF_CAFE_BABE);

        let encrypted = encrypt_constant(&value, key)?;
        assert_eq!(encrypted[0], 0x04); // Type tag for U64

        let decrypted = decrypt_constant(&encrypted, key)?;
        assert_eq!(decrypted, value);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_wrong_key_produces_wrong_value() -> Result<()> {
        let key1 = b"correct-key";
        let key2 = b"wrong-key!!";
        let value = CircuitValue::U8(42);

        let encrypted = encrypt_constant(&value, key1)?;
        let decrypted = decrypt_constant(&encrypted, key2)?;
        // Decryption with the wrong key still succeeds (there is no
        // authentication), but it yields a different value.
        assert_ne!(decrypted, value);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_empty_key_error() {
        let key: &[u8] = &[];
        let value = CircuitValue::U8(1);

        let result = encrypt_constant(&value, key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_constant_empty_data_error() {
        let key = b"some-key";
        let result = decrypt_constant(&[], key);
        assert!(result.is_err());
    }

    #[test]
    fn test_encrypt_circuit_constants_transforms_all() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"circuit-encryption-key";

        // Build: Constant(5u8) + Constant(3u8)
        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        // Before encryption: 2 plaintext constants, 0 encrypted
        assert_eq!(count_plaintext_constants(&sum), 2);
        assert_eq!(count_encrypted_constants(&sum), 0);

        // Encrypt
        let encrypted = encrypt_circuit_constants(&sum, key)?;

        // After encryption: 0 plaintext constants, 2 encrypted
        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);

        // Verify structure is preserved (BinaryOp Add with two EncryptedConstant children)
        match &encrypted {
            CircuitNode::BinaryOp { op, left, right } => {
                assert_eq!(*op, BinaryOperator::Add);
                assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp after encryption"),
        }

        Ok(())
    }

    #[test]
    fn test_encrypt_circuit_constants_preserves_loads() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::U8);
        let key = b"key-for-loads-test";

        // Build: Load("x") + Constant(10u8)
        let x = builder.load("x");
        let c = builder.constant(CircuitValue::U8(10));
        let sum = builder.add(x, c);

        let encrypted = encrypt_circuit_constants(&sum, key)?;

        // Load should be preserved, Constant should become EncryptedConstant
        match &encrypted {
            CircuitNode::BinaryOp { left, right, .. } => {
                assert!(matches!(**left, CircuitNode::Load(ref name) if name == "x"));
                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
            }
            _ => panic!("Expected BinaryOp"),
        }

        Ok(())
    }

    #[test]
    fn test_encrypt_circuit_constants_already_encrypted_pass_through() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"key-pass-through";

        // Create an already-encrypted constant
        let enc = builder.encrypted_constant(vec![0x01, 0x02, 0x03], ConstantType::Integer);
        let original_data = vec![0x01, 0x02, 0x03];

        let result = encrypt_circuit_constants(&enc, key)?;

        // The encrypted constant should pass through unchanged
        match result {
            CircuitNode::EncryptedConstant {
                data,
                original_type,
            } => {
                assert_eq!(data, original_data);
                assert_eq!(original_type, ConstantType::Integer);
            }
            _ => panic!("Expected EncryptedConstant pass-through"),
        }

        Ok(())
    }

    #[test]
    fn test_encrypted_constant_display() {
        let node = CircuitNode::EncryptedConstant {
            data: vec![0xAA, 0xBB, 0xCC],
            original_type: ConstantType::Boolean,
        };
        let display = format!("{}", node);
        assert!(display.contains("EncryptedConst"));
        assert!(display.contains("boolean"));
        assert!(display.contains("3 bytes"));
    }

    #[test]
    fn test_circuit_node_display_variants() {
        // Load
        let load = CircuitNode::Load("x".to_string());
        assert_eq!(format!("{}", load), "Load(x)");

        // Constant
        let constant = CircuitNode::Constant(CircuitValue::U8(42));
        assert_eq!(format!("{}", constant), "Const(42u8)");

        // Bool constant
        let bool_const = CircuitNode::Constant(CircuitValue::Bool(true));
        assert_eq!(format!("{}", bool_const), "Const(true)");
    }

    #[test]
    fn test_constant_type_display() {
        assert_eq!(format!("{}", ConstantType::Integer), "integer");
        assert_eq!(format!("{}", ConstantType::Boolean), "boolean");
        assert_eq!(format!("{}", ConstantType::Float), "float");
        assert_eq!(format!("{}", ConstantType::Bytes), "bytes");
    }

    #[test]
    fn test_constant_type_variants() {
        // Verify all variants exist and are pairwise distinct
        let variants = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];
        for (i, a) in variants.iter().enumerate() {
            for (j, b) in variants.iter().enumerate() {
                if i == j {
                    assert_eq!(a, b);
                } else {
                    assert_ne!(a, b);
                }
            }
        }
    }

    #[test]
    fn test_constant_type_serialization_roundtrip() -> Result<()> {
        let types = [
            ConstantType::Integer,
            ConstantType::Boolean,
            ConstantType::Float,
            ConstantType::Bytes,
        ];

        for ct in &types {
            let json = serde_json::to_string(ct).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Serialization failed: {}",
                    e
                )))
            })?;
            let deserialized: ConstantType = serde_json::from_str(&json).map_err(|e| {
                AmateRSError::FheComputation(ErrorContext::new(format!(
                    "Deserialization failed: {}",
                    e
                )))
            })?;
            assert_eq!(*ct, deserialized);
        }

        Ok(())
    }

    #[test]
    fn test_is_encrypted_constant() {
        let enc = CircuitNode::EncryptedConstant {
            data: vec![1, 2, 3],
            original_type: ConstantType::Integer,
        };
        assert!(is_encrypted_constant(&enc));

        let plain = CircuitNode::Constant(CircuitValue::U8(5));
        assert!(!is_encrypted_constant(&plain));

        let load = CircuitNode::Load("x".to_string());
        assert!(!is_encrypted_constant(&load));
    }

    #[test]
    fn test_encrypted_constant_in_circuit_validation() -> Result<()> {
        // EncryptedConstant should pass validation
        let enc = CircuitNode::EncryptedConstant {
            data: vec![0x00, 0x01],
            original_type: ConstantType::Boolean,
        };
        let circuit = Circuit::new(enc, HashMap::new())?;
        circuit.validate()?;
        // EncryptedConstant with Boolean type infers to Bool
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        Ok(())
    }

    #[test]
    fn test_encrypted_constant_depth_and_gate_count() -> Result<()> {
        let builder = CircuitBuilder::new();

        // EncryptedConstant has depth 1 and gate count 0 (same as Constant/Load)
        let enc = builder.encrypted_constant(vec![0x01, 0x42], ConstantType::Integer);
        let circuit = Circuit::new(enc, HashMap::new())?;
        assert_eq!(circuit.depth, 1);
        assert_eq!(circuit.gate_count, 0);

        Ok(())
    }

    #[test]
    fn test_mixed_plain_and_encrypted_constants() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"mixed-circuit-key";

        let plain = builder.constant(CircuitValue::U8(10));
        let encrypted_data = encrypt_constant(&CircuitValue::U8(20), key)?;
        let enc = builder.encrypted_constant(encrypted_data, ConstantType::Integer);

        // Each leaf is counted in exactly one category.
        assert_eq!(count_plaintext_constants(&plain), 1);
        assert_eq!(count_encrypted_constants(&plain), 0);
        assert_eq!(count_plaintext_constants(&enc), 0);
        assert_eq!(count_encrypted_constants(&enc), 1);

        // A tree that genuinely mixes both kinds. The builder does not
        // type-check (that happens in `Circuit::new` / `build`), so the node can
        // be constructed here; only the counters are exercised, not validation.
        let mixed = builder.add(plain, enc);
        assert_eq!(count_plaintext_constants(&mixed), 1);
        assert_eq!(count_encrypted_constants(&mixed), 1);

        // Unary operators are traversed as well.
        let not_node = CircuitNode::UnaryOp {
            op: UnaryOperator::Not,
            operand: Box::new(CircuitNode::Constant(CircuitValue::Bool(true))),
        };
        assert_eq!(count_plaintext_constants(&not_node), 1);
        assert_eq!(count_encrypted_constants(&not_node), 0);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_deterministic() -> Result<()> {
        let key = b"deterministic-test-key";
        let value = CircuitValue::U32(999);

        let enc1 = encrypt_constant(&value, key)?;
        let enc2 = encrypt_constant(&value, key)?;
        // Same key + same value => same ciphertext (deterministic)
        assert_eq!(enc1, enc2);

        Ok(())
    }

    #[test]
    fn test_encrypt_constant_different_keys_differ() -> Result<()> {
        let key1 = b"key-alpha";
        let key2 = b"key-bravo";
        let value = CircuitValue::U64(123456789);

        let enc1 = encrypt_constant(&value, key1)?;
        let enc2 = encrypt_constant(&value, key2)?;
        // Different keys should produce different ciphertext (with high probability).
        // The type tag (first byte) is the same, but the payload differs.
        assert_eq!(enc1[0], enc2[0]); // Same type tag
        assert_ne!(enc1[1..], enc2[1..]); // Different payload

        Ok(())
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_all_types() -> Result<()> {
        let key = b"roundtrip-all-types";

        let values = vec![
            CircuitValue::Bool(false),
            CircuitValue::Bool(true),
            CircuitValue::U8(0),
            CircuitValue::U8(255),
            CircuitValue::U16(0),
            CircuitValue::U16(65535),
            CircuitValue::U32(0),
            CircuitValue::U32(u32::MAX),
            CircuitValue::U64(0),
            CircuitValue::U64(u64::MAX),
        ];

        for value in &values {
            let encrypted = encrypt_constant(value, key)?;
            let decrypted = decrypt_constant(&encrypted, key)?;
            assert_eq!(*value, decrypted, "Roundtrip failed for {:?}", value);
        }

        Ok(())
    }

    #[test]
    fn test_encrypt_circuit_constants_nested() -> Result<()> {
        let builder = CircuitBuilder::new();
        let key = b"nested-circuit-key";

        // Build: NOT(Constant(true) AND Constant(false))
        let t = builder.constant(CircuitValue::Bool(true));
        let f = builder.constant(CircuitValue::Bool(false));
        let and_node = builder.and(t, f);
        let not_node = builder.not(and_node);

        assert_eq!(count_plaintext_constants(&not_node), 2);
        assert_eq!(count_encrypted_constants(&not_node), 0);

        let encrypted = encrypt_circuit_constants(&not_node, key)?;

        assert_eq!(count_plaintext_constants(&encrypted), 0);
        assert_eq!(count_encrypted_constants(&encrypted), 2);

        // Verify structure: NOT(AND(EncryptedConstant, EncryptedConstant))
        match &encrypted {
            CircuitNode::UnaryOp { op, operand } => {
                assert_eq!(*op, UnaryOperator::Not);
                match operand.as_ref() {
                    CircuitNode::BinaryOp { op, left, right } => {
                        assert_eq!(*op, BinaryOperator::And);
                        assert!(is_encrypted_constant(left));
                        assert!(is_encrypted_constant(right));
                    }
                    _ => panic!("Expected BinaryOp inside UnaryOp"),
                }
            }
            _ => panic!("Expected UnaryOp at root"),
        }

        Ok(())
    }
}