amaters_core/compute/
circuit.rs

//! Circuit compilation and optimization
//!
//! This module provides the circuit AST representation, type inference,
//! circuit metrics (depth and gate count), and a legacy optimization entry
//! point for FHE operations.
5
6use crate::error::{AmateRSError, ErrorContext, Result};
7use std::collections::HashMap;
8
9/// Circuit AST node representing FHE operations
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub enum CircuitNode {
12    /// Load a variable by name
13    Load(String),
14
15    /// Constant value
16    Constant(CircuitValue),
17
18    /// Binary operation
19    BinaryOp {
20        op: BinaryOperator,
21        left: Box<CircuitNode>,
22        right: Box<CircuitNode>,
23    },
24
25    /// Unary operation
26    UnaryOp {
27        op: UnaryOperator,
28        operand: Box<CircuitNode>,
29    },
30
31    /// Comparison operation
32    Compare {
33        op: CompareOperator,
34        left: Box<CircuitNode>,
35        right: Box<CircuitNode>,
36    },
37}
38
/// Two-operand operators usable in a `CircuitNode::BinaryOp`.
///
/// `Add`, `Sub` and `Mul` are arithmetic; `And`, `Or` and `Xor` are logical.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOperator {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Logical conjunction (`AND`).
    And,
    /// Logical disjunction (`OR`).
    Or,
    /// Logical exclusive or (`XOR`).
    Xor,
}

impl BinaryOperator {
    /// Returns the operator's textual symbol, e.g. `"+"` or `"AND"`.
    ///
    /// The symbols are string literals, so the result is `'static` and may
    /// outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            BinaryOperator::Add => "+",
            BinaryOperator::Sub => "-",
            BinaryOperator::Mul => "*",
            BinaryOperator::And => "AND",
            BinaryOperator::Or => "OR",
            BinaryOperator::Xor => "XOR",
        }
    }
}
63
/// One-operand operators usable in a `CircuitNode::UnaryOp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOperator {
    /// Logical negation (`NOT`).
    Not,
    /// Arithmetic negation (`-`).
    Neg,
}

impl UnaryOperator {
    /// Returns the operator's textual symbol, e.g. `"NOT"`.
    ///
    /// The symbols are string literals, so the result is `'static` and may
    /// outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            UnaryOperator::Not => "NOT",
            UnaryOperator::Neg => "-",
        }
    }
}
80
/// Comparison operators usable in a `CircuitNode::Compare`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOperator {
    /// Equal (`=`).
    Eq,
    /// Not equal (`!=`).
    Ne,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
}

impl CompareOperator {
    /// Returns the operator's textual symbol, e.g. `"<="`.
    ///
    /// The symbols are string literals, so the result is `'static` and may
    /// outlive the operator value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            CompareOperator::Eq => "=",
            CompareOperator::Ne => "!=",
            CompareOperator::Lt => "<",
            CompareOperator::Le => "<=",
            CompareOperator::Gt => ">",
            CompareOperator::Ge => ">=",
        }
    }
}
105
/// A literal value that can be embedded in a circuit as a constant.
///
/// Each variant pairs with the `EncryptedType` variant of the same name.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CircuitValue {
    /// Boolean literal.
    Bool(bool),
    /// Unsigned 8-bit literal.
    U8(u8),
    /// Unsigned 16-bit literal.
    U16(u16),
    /// Unsigned 32-bit literal.
    U32(u32),
    /// Unsigned 64-bit literal.
    U64(u64),
}
115
/// The type of an encrypted value flowing through a circuit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum EncryptedType {
    /// Encrypted boolean.
    Bool,
    /// Encrypted unsigned 8-bit integer.
    U8,
    /// Encrypted unsigned 16-bit integer.
    U16,
    /// Encrypted unsigned 32-bit integer.
    U32,
    /// Encrypted unsigned 64-bit integer.
    U64,
}

impl EncryptedType {
    /// Number of bits the type carries: 1 for `Bool`, otherwise the integer width.
    pub fn bit_width(&self) -> usize {
        match *self {
            Self::Bool => 1,
            Self::U8 => 8,
            Self::U16 => 16,
            Self::U32 => 32,
            Self::U64 => 64,
        }
    }

    /// `true` for every integer type, i.e. anything that is not `Bool`.
    pub fn is_numeric(&self) -> bool {
        !self.is_boolean()
    }

    /// `true` only for `Bool`.
    pub fn is_boolean(&self) -> bool {
        *self == Self::Bool
    }
}

impl std::fmt::Display for EncryptedType {
    /// Writes the lowercase name of the type (`bool`, `u8`, `u16`, `u32`, `u64`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Bool => "bool",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
        };
        f.write_str(name)
    }
}
160
161/// Circuit representation with metadata
162#[derive(Debug, Clone)]
163pub struct Circuit {
164    /// Root node of the circuit
165    pub root: CircuitNode,
166
167    /// Type information for variables
168    pub variable_types: HashMap<String, EncryptedType>,
169
170    /// Inferred result type
171    pub result_type: EncryptedType,
172
173    /// Circuit depth (for complexity estimation)
174    pub depth: usize,
175
176    /// Number of gates (for complexity estimation)
177    pub gate_count: usize,
178}
179
180impl Circuit {
181    /// Create a new circuit from a root node
182    pub fn new(root: CircuitNode, variable_types: HashMap<String, EncryptedType>) -> Result<Self> {
183        let result_type = Self::infer_type(&root, &variable_types)?;
184        let depth = Self::compute_depth(&root);
185        let gate_count = Self::count_gates(&root);
186
187        Ok(Self {
188            root,
189            variable_types,
190            result_type,
191            depth,
192            gate_count,
193        })
194    }
195
196    /// Infer the type of a circuit node
197    fn infer_type(
198        node: &CircuitNode,
199        variable_types: &HashMap<String, EncryptedType>,
200    ) -> Result<EncryptedType> {
201        match node {
202            CircuitNode::Load(name) => variable_types.get(name).copied().ok_or_else(|| {
203                AmateRSError::FheComputation(ErrorContext::new(format!(
204                    "Undefined variable: {}",
205                    name
206                )))
207            }),
208
209            CircuitNode::Constant(value) => Ok(match value {
210                CircuitValue::Bool(_) => EncryptedType::Bool,
211                CircuitValue::U8(_) => EncryptedType::U8,
212                CircuitValue::U16(_) => EncryptedType::U16,
213                CircuitValue::U32(_) => EncryptedType::U32,
214                CircuitValue::U64(_) => EncryptedType::U64,
215            }),
216
217            CircuitNode::BinaryOp { op, left, right } => {
218                let left_type = Self::infer_type(left, variable_types)?;
219                let right_type = Self::infer_type(right, variable_types)?;
220
221                match op {
222                    BinaryOperator::And | BinaryOperator::Or | BinaryOperator::Xor => {
223                        if left_type != EncryptedType::Bool || right_type != EncryptedType::Bool {
224                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
225                                "Logical operation requires boolean operands, got {} and {}",
226                                left_type, right_type
227                            ))));
228                        }
229                        Ok(EncryptedType::Bool)
230                    }
231
232                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
233                        if !left_type.is_numeric() || !right_type.is_numeric() {
234                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
235                                "Arithmetic operation requires numeric operands, got {} and {}",
236                                left_type, right_type
237                            ))));
238                        }
239
240                        if left_type != right_type {
241                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
242                                "Arithmetic operation requires matching types, got {} and {}",
243                                left_type, right_type
244                            ))));
245                        }
246
247                        Ok(left_type)
248                    }
249                }
250            }
251
252            CircuitNode::UnaryOp { op, operand } => {
253                let operand_type = Self::infer_type(operand, variable_types)?;
254
255                match op {
256                    UnaryOperator::Not => {
257                        if operand_type != EncryptedType::Bool {
258                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
259                                "NOT operation requires boolean operand, got {}",
260                                operand_type
261                            ))));
262                        }
263                        Ok(EncryptedType::Bool)
264                    }
265
266                    UnaryOperator::Neg => {
267                        if !operand_type.is_numeric() {
268                            return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
269                                "Negation operation requires numeric operand, got {}",
270                                operand_type
271                            ))));
272                        }
273                        Ok(operand_type)
274                    }
275                }
276            }
277
278            CircuitNode::Compare { left, right, .. } => {
279                let left_type = Self::infer_type(left, variable_types)?;
280                let right_type = Self::infer_type(right, variable_types)?;
281
282                if left_type != right_type {
283                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
284                        "Comparison requires matching types, got {} and {}",
285                        left_type, right_type
286                    ))));
287                }
288
289                Ok(EncryptedType::Bool)
290            }
291        }
292    }
293
294    /// Compute the depth of the circuit
295    fn compute_depth(node: &CircuitNode) -> usize {
296        match node {
297            CircuitNode::Load(_) | CircuitNode::Constant(_) => 1,
298
299            CircuitNode::BinaryOp { left, right, .. }
300            | CircuitNode::Compare { left, right, .. } => {
301                1 + Self::compute_depth(left).max(Self::compute_depth(right))
302            }
303
304            CircuitNode::UnaryOp { operand, .. } => 1 + Self::compute_depth(operand),
305        }
306    }
307
308    /// Count the number of gates in the circuit
309    fn count_gates(node: &CircuitNode) -> usize {
310        match node {
311            CircuitNode::Load(_) | CircuitNode::Constant(_) => 0,
312
313            CircuitNode::BinaryOp { left, right, .. }
314            | CircuitNode::Compare { left, right, .. } => {
315                1 + Self::count_gates(left) + Self::count_gates(right)
316            }
317
318            CircuitNode::UnaryOp { operand, .. } => 1 + Self::count_gates(operand),
319        }
320    }
321
322    /// Validate the circuit for correctness
323    pub fn validate(&self) -> Result<()> {
324        Self::validate_node(&self.root, &self.variable_types)?;
325        Ok(())
326    }
327
328    fn validate_node(
329        node: &CircuitNode,
330        variable_types: &HashMap<String, EncryptedType>,
331    ) -> Result<()> {
332        match node {
333            CircuitNode::Load(name) => {
334                if !variable_types.contains_key(name) {
335                    return Err(AmateRSError::FheComputation(ErrorContext::new(format!(
336                        "Undefined variable: {}",
337                        name
338                    ))));
339                }
340                Ok(())
341            }
342
343            CircuitNode::Constant(_) => Ok(()),
344
345            CircuitNode::BinaryOp { left, right, .. }
346            | CircuitNode::Compare { left, right, .. } => {
347                Self::validate_node(left, variable_types)?;
348                Self::validate_node(right, variable_types)?;
349                Ok(())
350            }
351
352            CircuitNode::UnaryOp { operand, .. } => Self::validate_node(operand, variable_types),
353        }
354    }
355}
356
357/// Circuit builder for constructing circuits programmatically
358#[derive(Default)]
359pub struct CircuitBuilder {
360    variable_types: HashMap<String, EncryptedType>,
361}
362
363impl CircuitBuilder {
364    pub fn new() -> Self {
365        Self::default()
366    }
367
368    /// Get the variable types map
369    pub fn variable_types(&self) -> &HashMap<String, EncryptedType> {
370        &self.variable_types
371    }
372
373    /// Clone the variable types map
374    pub fn variable_types_clone(&self) -> HashMap<String, EncryptedType> {
375        self.variable_types.clone()
376    }
377
378    /// Declare a variable with its type
379    pub fn declare_variable(&mut self, name: impl Into<String>, ty: EncryptedType) -> &mut Self {
380        self.variable_types.insert(name.into(), ty);
381        self
382    }
383
384    /// Build the circuit from a root node
385    pub fn build(&self, root: CircuitNode) -> Result<Circuit> {
386        Circuit::new(root, self.variable_types.clone())
387    }
388
389    /// Create a load node
390    pub fn load(&self, name: impl Into<String>) -> CircuitNode {
391        CircuitNode::Load(name.into())
392    }
393
394    /// Create a constant node
395    pub fn constant(&self, value: CircuitValue) -> CircuitNode {
396        CircuitNode::Constant(value)
397    }
398
399    /// Create an addition node
400    pub fn add(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
401        CircuitNode::BinaryOp {
402            op: BinaryOperator::Add,
403            left: Box::new(left),
404            right: Box::new(right),
405        }
406    }
407
408    /// Create a subtraction node
409    pub fn sub(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
410        CircuitNode::BinaryOp {
411            op: BinaryOperator::Sub,
412            left: Box::new(left),
413            right: Box::new(right),
414        }
415    }
416
417    /// Create a multiplication node
418    pub fn mul(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
419        CircuitNode::BinaryOp {
420            op: BinaryOperator::Mul,
421            left: Box::new(left),
422            right: Box::new(right),
423        }
424    }
425
426    /// Create an AND node
427    pub fn and(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
428        CircuitNode::BinaryOp {
429            op: BinaryOperator::And,
430            left: Box::new(left),
431            right: Box::new(right),
432        }
433    }
434
435    /// Create an OR node
436    pub fn or(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
437        CircuitNode::BinaryOp {
438            op: BinaryOperator::Or,
439            left: Box::new(left),
440            right: Box::new(right),
441        }
442    }
443
444    /// Create an XOR node
445    pub fn xor(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
446        CircuitNode::BinaryOp {
447            op: BinaryOperator::Xor,
448            left: Box::new(left),
449            right: Box::new(right),
450        }
451    }
452
453    /// Create a NOT node
454    pub fn not(&self, operand: CircuitNode) -> CircuitNode {
455        CircuitNode::UnaryOp {
456            op: UnaryOperator::Not,
457            operand: Box::new(operand),
458        }
459    }
460
461    /// Create an equality comparison node
462    pub fn eq(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
463        CircuitNode::Compare {
464            op: CompareOperator::Eq,
465            left: Box::new(left),
466            right: Box::new(right),
467        }
468    }
469
470    /// Create a less-than comparison node
471    pub fn lt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
472        CircuitNode::Compare {
473            op: CompareOperator::Lt,
474            left: Box::new(left),
475            right: Box::new(right),
476        }
477    }
478
479    /// Create a greater-than comparison node
480    pub fn gt(&self, left: CircuitNode, right: CircuitNode) -> CircuitNode {
481        CircuitNode::Compare {
482            op: CompareOperator::Gt,
483            left: Box::new(left),
484            right: Box::new(right),
485        }
486    }
487}
488
489/// Basic circuit optimizer for backward compatibility
490///
491/// This is a legacy optimizer kept for backward compatibility.
492/// For advanced optimizations, use the `optimizer` module instead.
493#[derive(Debug, Clone, Default)]
494#[deprecated(
495    since = "0.1.0",
496    note = "Use CircuitOptimizer from optimizer module instead"
497)]
498pub struct CircuitOptimizer;
499
500#[allow(deprecated)]
501impl CircuitOptimizer {
502    pub fn new() -> Self {
503        Self
504    }
505
506    /// Optimize circuit by applying basic optimization passes
507    ///
508    /// For advanced optimizations including bootstrap minimization,
509    /// gate fusion, and parallelization analysis, use the optimizer module.
510    pub fn optimize(&self, circuit: Circuit) -> Result<Circuit> {
511        // Delegate to the advanced optimizer
512        let mut advanced_optimizer = crate::compute::optimizer::CircuitOptimizer::new();
513        advanced_optimizer.optimize(circuit)
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_circuit_builder() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let sum = builder.add(a, b);

        let circuit = builder.build(sum)?;
        assert_eq!(circuit.result_type, EncryptedType::U8);
        assert_eq!(circuit.gate_count, 1);

        Ok(())
    }

    #[test]
    fn test_type_inference() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("x", EncryptedType::Bool)
            .declare_variable("y", EncryptedType::Bool);

        let x = builder.load("x");
        let y = builder.load("y");
        let result = builder.and(x, y);

        let circuit = builder.build(result)?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);

        Ok(())
    }

    #[test]
    fn test_type_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::Bool);

        let a = builder.load("a");
        let b = builder.load("b");
        let invalid = builder.add(a, b);

        let result = builder.build(invalid);
        assert!(result.is_err());
    }

    #[test]
    fn test_undefined_variable_error() {
        let builder = CircuitBuilder::new();
        let missing = builder.load("missing");
        assert!(builder.build(missing).is_err());
    }

    #[test]
    fn test_logical_op_rejects_numeric_operands() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let invalid = builder.and(builder.load("a"), builder.load("b"));
        assert!(builder.build(invalid).is_err());
    }

    #[test]
    fn test_not_requires_boolean_operand() {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("a", EncryptedType::U32);

        let invalid = builder.not(builder.load("a"));
        assert!(builder.build(invalid).is_err());
    }

    #[test]
    fn test_not_on_boolean_metrics() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::Bool);

        let circuit = builder.build(builder.not(builder.load("x")))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert_eq!(circuit.depth, 2);
        assert_eq!(circuit.gate_count, 1);

        Ok(())
    }

    #[test]
    fn test_negation_preserves_numeric_type() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("n", EncryptedType::U16);

        let negated = CircuitNode::UnaryOp {
            op: UnaryOperator::Neg,
            operand: Box::new(builder.load("n")),
        };
        let circuit = builder.build(negated)?;
        assert_eq!(circuit.result_type, EncryptedType::U16);

        Ok(())
    }

    #[test]
    fn test_comparison_yields_bool() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U32)
            .declare_variable("b", EncryptedType::U32);

        let circuit = builder.build(builder.lt(builder.load("a"), builder.load("b")))?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert_eq!(circuit.depth, 2);
        assert_eq!(circuit.gate_count, 1);

        Ok(())
    }

    #[test]
    fn test_comparison_type_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U16);

        let invalid = builder.eq(builder.load("a"), builder.load("b"));
        assert!(builder.build(invalid).is_err());
    }

    #[test]
    fn test_arithmetic_width_mismatch_error() {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U16);

        let invalid = builder.add(builder.load("a"), builder.load("b"));
        assert!(builder.build(invalid).is_err());
    }

    #[test]
    fn test_constant_only_circuit_metrics() -> Result<()> {
        let circuit = Circuit::new(
            CircuitNode::Constant(CircuitValue::Bool(true)),
            HashMap::new(),
        )?;
        assert_eq!(circuit.result_type, EncryptedType::Bool);
        assert_eq!(circuit.depth, 1);
        assert_eq!(circuit.gate_count, 0);

        Ok(())
    }

    #[test]
    fn test_validate_detects_undefined_variable() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::Bool);

        let mut circuit = builder.build(builder.load("x"))?;
        circuit.validate()?;

        // `root` is public, so a caller can swap in an invalid tree after construction.
        circuit.root = CircuitNode::Load("ghost".to_string());
        assert!(circuit.validate().is_err());

        Ok(())
    }

    #[test]
    #[allow(deprecated)]
    fn test_constant_folding() -> Result<()> {
        let optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        let circuit = Circuit::new(sum, HashMap::new())?;
        let optimized = optimizer.optimize(circuit)?;

        // Should fold to constant 8
        match optimized.root {
            CircuitNode::Constant(CircuitValue::U8(8)) => Ok(()),
            _ => Err(AmateRSError::FheComputation(ErrorContext::new(
                "Constant folding failed".to_string(),
            ))),
        }
    }

    #[test]
    fn test_circuit_depth() -> Result<()> {
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8)
            .declare_variable("c", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let c = builder.load("c");

        // (a + b) + c
        let sum1 = builder.add(a, b);
        let sum2 = builder.add(sum1, c);

        let circuit = builder.build(sum2)?;
        assert_eq!(circuit.depth, 3); // Load -> Add -> Add

        Ok(())
    }
}