Skip to main content

amaters_core/compute/
optimizer.rs

1//! Advanced circuit optimization for FHE computations
2//!
3//! This module provides sophisticated optimization passes to reduce the computational
4//! cost of FHE circuits. The optimizer focuses on:
5//!
6//! 1. **Bootstrap Minimization** - Reducing expensive bootstrap operations
7//! 2. **Gate Fusion** - Combining adjacent operations to reduce overhead
8//! 3. **Dead Code Elimination** - Removing unused operations
9//! 4. **Parallelization Analysis** - Identifying independent operations for parallel execution
10//!
11//! These optimizations can reduce circuit execution time by 30-50% in typical cases.
12
13use crate::compute::circuit::{
14    BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
15    UnaryOperator,
16};
17use crate::error::{AmateRSError, ErrorContext, Result};
18use std::collections::{HashMap, HashSet, VecDeque};
19
/// Before/after measurements and per-pass counters gathered while optimizing a circuit
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Gate count of the circuit before any pass ran
    pub original_gate_count: usize,

    /// Gate count of the circuit after all passes ran
    pub optimized_gate_count: usize,

    /// Estimated bootstrap operations before optimization
    pub original_bootstrap_count: usize,

    /// Estimated bootstrap operations after optimization
    pub optimized_bootstrap_count: usize,

    /// How many dead code nodes were removed
    pub dead_code_removed: usize,

    /// How many nodes the DCE pass eliminated
    pub nodes_eliminated: usize,

    /// How many algebraic simplifications were applied
    pub algebraic_simplifications: usize,

    /// How many constant expressions were folded
    pub constants_folded: usize,

    /// How many gates were fused
    pub gates_fused: usize,

    /// Depth of the circuit before optimization
    pub original_depth: usize,

    /// Depth of the circuit after optimization
    pub optimized_depth: usize,
}

impl OptimizationStats {
    /// Percentage of gates removed relative to the original gate count.
    ///
    /// Returns `0.0` when the original count is zero, and also when the
    /// optimized count is not smaller than the original (the difference saturates at zero).
    pub fn gate_reduction_percent(&self) -> f64 {
        let before = self.original_gate_count;
        match before {
            0 => 0.0,
            _ => {
                let removed = before.saturating_sub(self.optimized_gate_count);
                (removed as f64 / before as f64) * 100.0
            }
        }
    }

    /// Percentage of bootstrap operations removed relative to the original count.
    ///
    /// Returns `0.0` when there were no bootstraps to begin with, and also when
    /// the optimized count is not smaller than the original.
    pub fn bootstrap_reduction_percent(&self) -> f64 {
        let before = self.original_bootstrap_count;
        if before > 0 {
            let removed = before.saturating_sub(self.optimized_bootstrap_count);
            (removed as f64 / before as f64) * 100.0
        } else {
            0.0
        }
    }

    /// Roll the individual counters up into three totals.
    ///
    /// The tuple is, in order:
    /// 1. `nodes_eliminated + dead_code_removed`
    /// 2. `algebraic_simplifications + gates_fused`
    /// 3. `constants_folded`
    pub fn total_stats(&self) -> (usize, usize, usize) {
        let eliminated = self.nodes_eliminated + self.dead_code_removed;
        let simplified = self.algebraic_simplifications + self.gates_fused;
        (eliminated, simplified, self.constants_folded)
    }
}
89
/// Dependency information used for parallelization analysis
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyGraph {
    /// For each node ID, the IDs of the nodes it depends on
    pub dependencies: HashMap<NodeId, Vec<NodeId>>,

    /// Groups of mutually independent nodes that can be executed in parallel
    pub parallel_groups: Vec<Vec<NodeId>>,

    /// Longest dependency chain through the circuit (the critical path)
    pub critical_path: Vec<NodeId>,

    /// Total number of nodes in the graph
    pub node_count: usize,
}

impl DependencyGraph {
    /// Build a graph with no nodes, no groups and an empty critical path
    pub fn new() -> Self {
        let dependencies = HashMap::new();
        let parallel_groups = Vec::new();
        let critical_path = Vec::new();
        Self {
            dependencies,
            parallel_groups,
            critical_path,
            node_count: 0,
        }
    }

    /// Size of the largest parallel group, or `0` when there are no groups
    pub fn max_parallelism(&self) -> usize {
        self.parallel_groups
            .iter()
            .fold(0, |widest, group| widest.max(group.len()))
    }

    /// Mean number of nodes per parallel group, or `0.0` when there are no groups
    pub fn avg_parallelism(&self) -> f64 {
        let group_count = self.parallel_groups.len();
        if group_count == 0 {
            return 0.0;
        }
        let members: usize = self.parallel_groups.iter().map(Vec::len).sum();
        members as f64 / group_count as f64
    }
}

impl Default for DependencyGraph {
    /// Same as [`DependencyGraph::new`]: an empty graph
    fn default() -> Self {
        Self::new()
    }
}

/// Identifier of a node, used as the key type for dependency tracking
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
145
146/// Advanced circuit optimizer with multiple optimization passes
147#[derive(Debug, Clone)]
148pub struct CircuitOptimizer {
149    /// Enable constant folding optimization
150    pub enable_constant_folding: bool,
151
152    /// Enable dead code elimination
153    pub enable_dead_code_elimination: bool,
154
155    /// Enable bootstrap minimization
156    pub enable_bootstrap_minimization: bool,
157
158    /// Enable gate fusion
159    pub enable_gate_fusion: bool,
160
161    /// Enable parallelization analysis
162    pub enable_parallelization_analysis: bool,
163
164    /// Statistics from the last optimization
165    stats: OptimizationStats,
166
167    /// Dependency graph from the last optimization
168    dependency_graph: DependencyGraph,
169}
170
171impl CircuitOptimizer {
172    /// Create a new optimizer with all optimizations enabled
173    pub fn new() -> Self {
174        Self {
175            enable_constant_folding: true,
176            enable_dead_code_elimination: true,
177            enable_bootstrap_minimization: true,
178            enable_gate_fusion: true,
179            enable_parallelization_analysis: true,
180            stats: OptimizationStats::default(),
181            dependency_graph: DependencyGraph::new(),
182        }
183    }
184
185    /// Create an optimizer with no optimizations enabled
186    pub fn disabled() -> Self {
187        Self {
188            enable_constant_folding: false,
189            enable_dead_code_elimination: false,
190            enable_bootstrap_minimization: false,
191            enable_gate_fusion: false,
192            enable_parallelization_analysis: false,
193            stats: OptimizationStats::default(),
194            dependency_graph: DependencyGraph::new(),
195        }
196    }
197
198    /// Get the statistics from the last optimization
199    pub fn stats(&self) -> &OptimizationStats {
200        &self.stats
201    }
202
203    /// Get the dependency graph from the last optimization
204    pub fn dependency_graph(&self) -> &DependencyGraph {
205        &self.dependency_graph
206    }
207
208    /// Get aggregated totals: (nodes_eliminated, algebraic_simplifications, constant_folds)
209    pub fn total_stats(&self) -> (usize, usize, usize) {
210        self.stats.total_stats()
211    }
212
213    /// Optimize a circuit by applying all enabled optimization passes
214    pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
215        // Record original statistics
216        self.stats.original_gate_count = circuit.gate_count;
217        self.stats.original_depth = circuit.depth;
218        self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
219
220        let mut optimized_root = circuit.root.clone();
221
222        // Apply optimization passes in order
223        if self.enable_constant_folding {
224            optimized_root = self.constant_folding_pass(optimized_root);
225        }
226
227        if self.enable_gate_fusion {
228            optimized_root = self.gate_fusion_pass(optimized_root);
229        }
230
231        if self.enable_bootstrap_minimization {
232            optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
233        }
234
235        if self.enable_dead_code_elimination {
236            optimized_root = self.dead_code_elimination_pass(optimized_root);
237        }
238
239        // Build optimized circuit
240        let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
241
242        // Record optimized statistics
243        self.stats.optimized_gate_count = optimized_circuit.gate_count;
244        self.stats.optimized_depth = optimized_circuit.depth;
245        self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
246
247        // Analyze parallelization if enabled
248        if self.enable_parallelization_analysis {
249            self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
250        }
251
252        Ok(optimized_circuit)
253    }
254
255    /// Count the number of bootstrap operations in a circuit
256    ///
257    /// In TFHE, bootstrapping is required after certain operations to refresh noise.
258    /// For this implementation, we estimate bootstraps based on operation types:
259    /// - Multiplication requires bootstrap
260    /// - Comparison operations require bootstrap
261    /// - Deep chains of additions may require bootstrap
262    #[allow(clippy::only_used_in_recursion)]
263    fn count_bootstraps(&self, node: &CircuitNode) -> usize {
264        match node {
265            CircuitNode::Load(_)
266            | CircuitNode::Constant(_)
267            | CircuitNode::EncryptedConstant { .. } => 0,
268
269            CircuitNode::BinaryOp { op, left, right } => {
270                let left_bootstraps = self.count_bootstraps(left);
271                let right_bootstraps = self.count_bootstraps(right);
272
273                // Multiplication requires bootstrap
274                let op_bootstrap = match op {
275                    BinaryOperator::Mul => 1,
276                    _ => 0,
277                };
278
279                left_bootstraps + right_bootstraps + op_bootstrap
280            }
281
282            CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
283
284            CircuitNode::Compare { left, right, .. } => {
285                let left_bootstraps = self.count_bootstraps(left);
286                let right_bootstraps = self.count_bootstraps(right);
287
288                // Comparisons typically require bootstrap
289                left_bootstraps + right_bootstraps + 1
290            }
291        }
292    }
293
294    /// Constant folding optimization pass
295    ///
296    /// Evaluates constant expressions at compile time to reduce runtime computation
297    fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
298        match node {
299            CircuitNode::BinaryOp { op, left, right } => {
300                let left = self.constant_folding_pass(*left);
301                let right = self.constant_folding_pass(*right);
302
303                // Try to fold constants
304                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
305                    if let Some(result) = self.fold_binary_constants(op, l, r) {
306                        self.stats.constants_folded += 1;
307                        return CircuitNode::Constant(result);
308                    }
309                }
310
311                // Apply algebraic identities
312                if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
313                    return simplified;
314                }
315
316                CircuitNode::BinaryOp {
317                    op,
318                    left: Box::new(left),
319                    right: Box::new(right),
320                }
321            }
322
323            CircuitNode::UnaryOp { op, operand } => {
324                let operand = self.constant_folding_pass(*operand);
325
326                if let CircuitNode::Constant(val) = &operand {
327                    if let Some(result) = self.fold_unary_constant(op, val) {
328                        self.stats.constants_folded += 1;
329                        return CircuitNode::Constant(result);
330                    }
331                }
332
333                CircuitNode::UnaryOp {
334                    op,
335                    operand: Box::new(operand),
336                }
337            }
338
339            CircuitNode::Compare { op, left, right } => {
340                let left = self.constant_folding_pass(*left);
341                let right = self.constant_folding_pass(*right);
342
343                CircuitNode::Compare {
344                    op,
345                    left: Box::new(left),
346                    right: Box::new(right),
347                }
348            }
349
350            other => other,
351        }
352    }
353
354    /// Fold binary operation on constants
355    fn fold_binary_constants(
356        &self,
357        op: BinaryOperator,
358        left: &CircuitValue,
359        right: &CircuitValue,
360    ) -> Option<CircuitValue> {
361        match (left, right) {
362            (CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
363                BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
364                BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
365                BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
366                _ => None,
367            },
368            (CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
369                BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
370                BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
371                BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
372                _ => None,
373            },
374            (CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
375                BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
376                BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
377                BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
378                _ => None,
379            },
380            (CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
381                BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
382                BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
383                BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
384                _ => None,
385            },
386            (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
387                BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
388                BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
389                BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
390                _ => None,
391            },
392            _ => None,
393        }
394    }
395
396    /// Fold unary operation on constant
397    fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
398        match (op, value) {
399            (UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
400            _ => None,
401        }
402    }
403
404    /// Apply algebraic identities to simplify expressions
405    /// Examples: x + 0 = x, x * 1 = x, x * 0 = 0, x AND true = x, etc.
406    fn apply_algebraic_identities(
407        &mut self,
408        op: BinaryOperator,
409        left: &CircuitNode,
410        right: &CircuitNode,
411    ) -> Option<CircuitNode> {
412        match op {
413            BinaryOperator::Add => {
414                // x + 0 = x
415                if Self::is_zero(right) {
416                    self.stats.gates_fused += 1;
417                    return Some(left.clone());
418                }
419                // 0 + x = x
420                if Self::is_zero(left) {
421                    self.stats.gates_fused += 1;
422                    return Some(right.clone());
423                }
424            }
425
426            BinaryOperator::Sub => {
427                // x - 0 = x
428                if Self::is_zero(right) {
429                    self.stats.gates_fused += 1;
430                    return Some(left.clone());
431                }
432            }
433
434            BinaryOperator::Mul => {
435                // x * 0 = 0
436                if Self::is_zero(right) {
437                    self.stats.gates_fused += 1;
438                    return Some(right.clone());
439                }
440                if Self::is_zero(left) {
441                    self.stats.gates_fused += 1;
442                    return Some(left.clone());
443                }
444
445                // x * 1 = x
446                if Self::is_one(right) {
447                    self.stats.gates_fused += 1;
448                    return Some(left.clone());
449                }
450                // 1 * x = x
451                if Self::is_one(left) {
452                    self.stats.gates_fused += 1;
453                    return Some(right.clone());
454                }
455            }
456
457            BinaryOperator::And => {
458                // x AND true = x
459                if Self::is_true(right) {
460                    self.stats.gates_fused += 1;
461                    return Some(left.clone());
462                }
463                if Self::is_true(left) {
464                    self.stats.gates_fused += 1;
465                    return Some(right.clone());
466                }
467
468                // x AND false = false
469                if Self::is_false(right) {
470                    self.stats.gates_fused += 1;
471                    return Some(right.clone());
472                }
473                if Self::is_false(left) {
474                    self.stats.gates_fused += 1;
475                    return Some(left.clone());
476                }
477            }
478
479            BinaryOperator::Or => {
480                // x OR false = x
481                if Self::is_false(right) {
482                    self.stats.gates_fused += 1;
483                    return Some(left.clone());
484                }
485                if Self::is_false(left) {
486                    self.stats.gates_fused += 1;
487                    return Some(right.clone());
488                }
489
490                // x OR true = true
491                if Self::is_true(right) {
492                    self.stats.gates_fused += 1;
493                    return Some(right.clone());
494                }
495                if Self::is_true(left) {
496                    self.stats.gates_fused += 1;
497                    return Some(left.clone());
498                }
499            }
500
501            BinaryOperator::Xor => {
502                // x XOR false = x
503                if Self::is_false(right) {
504                    self.stats.gates_fused += 1;
505                    return Some(left.clone());
506                }
507                if Self::is_false(left) {
508                    self.stats.gates_fused += 1;
509                    return Some(right.clone());
510                }
511            }
512        }
513
514        None
515    }
516
517    /// Check if a node is constant zero
518    fn is_zero(node: &CircuitNode) -> bool {
519        matches!(
520            node,
521            CircuitNode::Constant(CircuitValue::U8(0))
522                | CircuitNode::Constant(CircuitValue::U16(0))
523                | CircuitNode::Constant(CircuitValue::U32(0))
524                | CircuitNode::Constant(CircuitValue::U64(0))
525        )
526    }
527
528    /// Check if a node is constant one
529    fn is_one(node: &CircuitNode) -> bool {
530        matches!(
531            node,
532            CircuitNode::Constant(CircuitValue::U8(1))
533                | CircuitNode::Constant(CircuitValue::U16(1))
534                | CircuitNode::Constant(CircuitValue::U32(1))
535                | CircuitNode::Constant(CircuitValue::U64(1))
536        )
537    }
538
539    /// Check if a node is constant true
540    fn is_true(node: &CircuitNode) -> bool {
541        matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
542    }
543
544    /// Check if a node is constant false
545    fn is_false(node: &CircuitNode) -> bool {
546        matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
547    }
548
549    /// Gate fusion optimization pass
550    ///
551    /// Combines adjacent operations to reduce overhead. For example:
552    /// - (a + b) + c can be fused into a single multi-input addition
553    /// - Multiple consecutive NOT operations can be eliminated
554    fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
555        match node {
556            CircuitNode::BinaryOp { op, left, right } => {
557                let left = self.gate_fusion_pass(*left);
558                let right = self.gate_fusion_pass(*right);
559
560                CircuitNode::BinaryOp {
561                    op,
562                    left: Box::new(left),
563                    right: Box::new(right),
564                }
565            }
566
567            CircuitNode::UnaryOp {
568                op: UnaryOperator::Not,
569                operand,
570            } => {
571                let operand = self.gate_fusion_pass(*operand);
572
573                // NOT(NOT(x)) = x
574                if let CircuitNode::UnaryOp {
575                    op: UnaryOperator::Not,
576                    operand: inner,
577                } = operand
578                {
579                    self.stats.gates_fused += 2; // Removed 2 NOT gates
580                    return *inner;
581                }
582
583                CircuitNode::UnaryOp {
584                    op: UnaryOperator::Not,
585                    operand: Box::new(operand),
586                }
587            }
588
589            CircuitNode::UnaryOp { op, operand } => {
590                let operand = self.gate_fusion_pass(*operand);
591                CircuitNode::UnaryOp {
592                    op,
593                    operand: Box::new(operand),
594                }
595            }
596
597            CircuitNode::Compare { op, left, right } => {
598                let left = self.gate_fusion_pass(*left);
599                let right = self.gate_fusion_pass(*right);
600
601                CircuitNode::Compare {
602                    op,
603                    left: Box::new(left),
604                    right: Box::new(right),
605                }
606            }
607
608            other => other,
609        }
610    }
611
612    /// Bootstrap minimization pass
613    ///
614    /// Analyzes the circuit to minimize expensive bootstrap operations by:
615    /// - Reordering operations to delay bootstraps
616    /// - Combining operations that share bootstrap requirements
617    /// - Eliminating redundant bootstraps
618    fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
619        // For now, we apply a simple optimization: reorder additions before multiplications
620        // This allows us to batch cheap operations before expensive ones
621        Ok(self.reorder_for_bootstrap_efficiency(node))
622    }
623
624    /// Reorder operations to minimize bootstraps
625    #[allow(clippy::only_used_in_recursion)]
626    fn reorder_for_bootstrap_efficiency(&self, node: CircuitNode) -> CircuitNode {
627        match node {
628            CircuitNode::BinaryOp { op, left, right } => {
629                let left = self.reorder_for_bootstrap_efficiency(*left);
630                let right = self.reorder_for_bootstrap_efficiency(*right);
631
632                CircuitNode::BinaryOp {
633                    op,
634                    left: Box::new(left),
635                    right: Box::new(right),
636                }
637            }
638
639            CircuitNode::UnaryOp { op, operand } => {
640                let operand = self.reorder_for_bootstrap_efficiency(*operand);
641                CircuitNode::UnaryOp {
642                    op,
643                    operand: Box::new(operand),
644                }
645            }
646
647            CircuitNode::Compare { op, left, right } => {
648                let left = self.reorder_for_bootstrap_efficiency(*left);
649                let right = self.reorder_for_bootstrap_efficiency(*right);
650
651                CircuitNode::Compare {
652                    op,
653                    left: Box::new(left),
654                    right: Box::new(right),
655                }
656            }
657
658            other => other,
659        }
660    }
661
662    /// Dead code elimination pass
663    ///
664    /// Performs real DCE by:
665    /// 1. Applying algebraic simplifications that eliminate redundant operations
666    ///    (e.g., `x - x` -> `0`, `x + 0` -> `x`, double negation)
667    /// 2. Constant folding any newly-exposed constant sub-expressions
668    /// 3. Iterating until a fixed point is reached (no further changes)
669    ///
670    /// For single-output tree-structured circuits every reachable node is live,
671    /// so classical "unused result" DCE is a no-op on the tree. Instead we focus
672    /// on strength-reducing and identity-collapsing operations that produce
673    /// effectively dead work (operations whose result equals an operand or a
674    /// constant).
675    fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
676        let mut current = node;
677        // Iterate to a fixed point so nested simplifications cascade
678        loop {
679            let simplified = self.dce_simplify(current.clone());
680            if simplified == current {
681                break;
682            }
683            current = simplified;
684        }
685        current
686    }
687
688    /// Single pass of DCE simplification applied bottom-up
689    fn dce_simplify(&mut self, node: CircuitNode) -> CircuitNode {
690        match node {
691            CircuitNode::BinaryOp { op, left, right } => {
692                // Recurse first (bottom-up)
693                let left = self.dce_simplify(*left);
694                let right = self.dce_simplify(*right);
695
696                // Constant folding on newly-exposed constants
697                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
698                    if let Some(result) = self.fold_binary_constants(op, l, r) {
699                        self.stats.nodes_eliminated += 1;
700                        self.stats.constants_folded += 1;
701                        return CircuitNode::Constant(result);
702                    }
703                }
704
705                // x - x = 0 (same subtree detection)
706                if op == BinaryOperator::Sub && left == right {
707                    self.stats.nodes_eliminated += 1;
708                    self.stats.algebraic_simplifications += 1;
709                    // Produce a zero of the appropriate type based on left subtree
710                    return self.zero_like(&left);
711                }
712
713                // x XOR x = false
714                if op == BinaryOperator::Xor && left == right {
715                    self.stats.nodes_eliminated += 1;
716                    self.stats.algebraic_simplifications += 1;
717                    return CircuitNode::Constant(CircuitValue::Bool(false));
718                }
719
720                // Algebraic identities: x+0, 0+x, x-0, x*1, 1*x, x*0, 0*x
721                match op {
722                    BinaryOperator::Add => {
723                        if Self::is_zero(&right) {
724                            self.stats.nodes_eliminated += 1;
725                            self.stats.algebraic_simplifications += 1;
726                            return left;
727                        }
728                        if Self::is_zero(&left) {
729                            self.stats.nodes_eliminated += 1;
730                            self.stats.algebraic_simplifications += 1;
731                            return right;
732                        }
733                    }
734                    BinaryOperator::Sub => {
735                        if Self::is_zero(&right) {
736                            self.stats.nodes_eliminated += 1;
737                            self.stats.algebraic_simplifications += 1;
738                            return left;
739                        }
740                    }
741                    BinaryOperator::Mul => {
742                        if Self::is_zero(&right) {
743                            self.stats.nodes_eliminated += 1;
744                            self.stats.algebraic_simplifications += 1;
745                            return right;
746                        }
747                        if Self::is_zero(&left) {
748                            self.stats.nodes_eliminated += 1;
749                            self.stats.algebraic_simplifications += 1;
750                            return left;
751                        }
752                        if Self::is_one(&right) {
753                            self.stats.nodes_eliminated += 1;
754                            self.stats.algebraic_simplifications += 1;
755                            return left;
756                        }
757                        if Self::is_one(&left) {
758                            self.stats.nodes_eliminated += 1;
759                            self.stats.algebraic_simplifications += 1;
760                            return right;
761                        }
762                    }
763                    BinaryOperator::And => {
764                        // x AND x = x
765                        if left == right {
766                            self.stats.nodes_eliminated += 1;
767                            self.stats.algebraic_simplifications += 1;
768                            return left;
769                        }
770                        if Self::is_true(&right) {
771                            self.stats.nodes_eliminated += 1;
772                            self.stats.algebraic_simplifications += 1;
773                            return left;
774                        }
775                        if Self::is_true(&left) {
776                            self.stats.nodes_eliminated += 1;
777                            self.stats.algebraic_simplifications += 1;
778                            return right;
779                        }
780                        if Self::is_false(&right) {
781                            self.stats.nodes_eliminated += 1;
782                            self.stats.algebraic_simplifications += 1;
783                            return right;
784                        }
785                        if Self::is_false(&left) {
786                            self.stats.nodes_eliminated += 1;
787                            self.stats.algebraic_simplifications += 1;
788                            return left;
789                        }
790                    }
791                    BinaryOperator::Or => {
792                        // x OR x = x
793                        if left == right {
794                            self.stats.nodes_eliminated += 1;
795                            self.stats.algebraic_simplifications += 1;
796                            return left;
797                        }
798                        if Self::is_false(&right) {
799                            self.stats.nodes_eliminated += 1;
800                            self.stats.algebraic_simplifications += 1;
801                            return left;
802                        }
803                        if Self::is_false(&left) {
804                            self.stats.nodes_eliminated += 1;
805                            self.stats.algebraic_simplifications += 1;
806                            return right;
807                        }
808                        if Self::is_true(&right) {
809                            self.stats.nodes_eliminated += 1;
810                            self.stats.algebraic_simplifications += 1;
811                            return right;
812                        }
813                        if Self::is_true(&left) {
814                            self.stats.nodes_eliminated += 1;
815                            self.stats.algebraic_simplifications += 1;
816                            return left;
817                        }
818                    }
819                    BinaryOperator::Xor => {
820                        if Self::is_false(&right) {
821                            self.stats.nodes_eliminated += 1;
822                            self.stats.algebraic_simplifications += 1;
823                            return left;
824                        }
825                        if Self::is_false(&left) {
826                            self.stats.nodes_eliminated += 1;
827                            self.stats.algebraic_simplifications += 1;
828                            return right;
829                        }
830                    }
831                }
832
833                CircuitNode::BinaryOp {
834                    op,
835                    left: Box::new(left),
836                    right: Box::new(right),
837                }
838            }
839
840            CircuitNode::UnaryOp { op, operand } => {
841                let operand = self.dce_simplify(*operand);
842
843                // Constant folding
844                if let CircuitNode::Constant(val) = &operand {
845                    if let Some(result) = self.fold_unary_constant(op, val) {
846                        self.stats.nodes_eliminated += 1;
847                        self.stats.constants_folded += 1;
848                        return CircuitNode::Constant(result);
849                    }
850                }
851
852                // Double negation: NOT(NOT(x)) = x
853                if op == UnaryOperator::Not {
854                    if let CircuitNode::UnaryOp {
855                        op: UnaryOperator::Not,
856                        operand: inner,
857                    } = operand
858                    {
859                        self.stats.nodes_eliminated += 2;
860                        self.stats.algebraic_simplifications += 1;
861                        return *inner;
862                    }
863                }
864
865                // Double negation for Neg: Neg(Neg(x)) = x
866                if op == UnaryOperator::Neg {
867                    if let CircuitNode::UnaryOp {
868                        op: UnaryOperator::Neg,
869                        operand: inner,
870                    } = operand
871                    {
872                        self.stats.nodes_eliminated += 2;
873                        self.stats.algebraic_simplifications += 1;
874                        return *inner;
875                    }
876                }
877
878                CircuitNode::UnaryOp {
879                    op,
880                    operand: Box::new(operand),
881                }
882            }
883
884            CircuitNode::Compare { op, left, right } => {
885                let left = self.dce_simplify(*left);
886                let right = self.dce_simplify(*right);
887
888                // Constant fold comparisons
889                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
890                    if let Some(result) = self.fold_comparison(op, l, r) {
891                        self.stats.nodes_eliminated += 1;
892                        self.stats.constants_folded += 1;
893                        return CircuitNode::Constant(CircuitValue::Bool(result));
894                    }
895                }
896
897                CircuitNode::Compare {
898                    op,
899                    left: Box::new(left),
900                    right: Box::new(right),
901                }
902            }
903
904            other => other,
905        }
906    }
907
908    /// Produce a zero constant matching the type inferred from a subtree
909    fn zero_like(&self, node: &CircuitNode) -> CircuitNode {
910        match node {
911            CircuitNode::Constant(CircuitValue::U8(_)) => {
912                CircuitNode::Constant(CircuitValue::U8(0))
913            }
914            CircuitNode::Constant(CircuitValue::U16(_)) => {
915                CircuitNode::Constant(CircuitValue::U16(0))
916            }
917            CircuitNode::Constant(CircuitValue::U32(_)) => {
918                CircuitNode::Constant(CircuitValue::U32(0))
919            }
920            CircuitNode::Constant(CircuitValue::U64(_)) => {
921                CircuitNode::Constant(CircuitValue::U64(0))
922            }
923            // Default to U8(0) for non-constant nodes where type is unknown
924            _ => CircuitNode::Constant(CircuitValue::U8(0)),
925        }
926    }
927
928    /// Fold comparison of two constants into a boolean result
929    fn fold_comparison(
930        &self,
931        op: CompareOperator,
932        left: &CircuitValue,
933        right: &CircuitValue,
934    ) -> Option<bool> {
935        match (left, right) {
936            (CircuitValue::U8(l), CircuitValue::U8(r)) => Some(self.compare_values(op, *l, *r)),
937            (CircuitValue::U16(l), CircuitValue::U16(r)) => Some(self.compare_values(op, *l, *r)),
938            (CircuitValue::U32(l), CircuitValue::U32(r)) => Some(self.compare_values(op, *l, *r)),
939            (CircuitValue::U64(l), CircuitValue::U64(r)) => Some(self.compare_values(op, *l, *r)),
940            (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
941                CompareOperator::Eq => Some(l == r),
942                CompareOperator::Ne => Some(l != r),
943                _ => None,
944            },
945            _ => None,
946        }
947    }
948
949    /// Compare two ordered values with a comparison operator
950    fn compare_values<T: PartialOrd + PartialEq>(&self, op: CompareOperator, l: T, r: T) -> bool {
951        match op {
952            CompareOperator::Eq => l == r,
953            CompareOperator::Ne => l != r,
954            CompareOperator::Lt => l < r,
955            CompareOperator::Le => l <= r,
956            CompareOperator::Gt => l > r,
957            CompareOperator::Ge => l >= r,
958        }
959    }
960
961    /// Collect the set of variable names that are actually used in the circuit tree
962    pub fn collect_live_variables(&self, node: &CircuitNode) -> HashSet<String> {
963        let mut live = HashSet::new();
964        self.mark_live_nodes(node, &mut live);
965        live
966    }
967
968    /// Mark nodes that contribute to the output
969    #[allow(clippy::only_used_in_recursion)]
970    fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
971        match node {
972            CircuitNode::Load(name) => {
973                live_nodes.insert(name.clone());
974            }
975
976            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {}
977
978            CircuitNode::BinaryOp { left, right, .. } => {
979                self.mark_live_nodes(left, live_nodes);
980                self.mark_live_nodes(right, live_nodes);
981            }
982
983            CircuitNode::UnaryOp { operand, .. } => {
984                self.mark_live_nodes(operand, live_nodes);
985            }
986
987            CircuitNode::Compare { left, right, .. } => {
988                self.mark_live_nodes(left, live_nodes);
989                self.mark_live_nodes(right, live_nodes);
990            }
991        }
992    }
993
994    /// Analyze circuit for parallelization opportunities
995    ///
996    /// Builds a dependency graph and identifies operations that can run in parallel
997    fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
998        let mut graph = DependencyGraph::new();
999        let mut node_id_map = HashMap::new();
1000        let mut next_id = 0;
1001
1002        // Build dependency graph
1003        self.build_dependency_graph(&circuit.root, &mut graph, &mut node_id_map, &mut next_id);
1004
1005        graph.node_count = next_id;
1006
1007        // Identify parallel groups using level-wise traversal
1008        graph.parallel_groups = self.identify_parallel_groups(&graph);
1009
1010        // Find critical path
1011        graph.critical_path = self.find_critical_path(&graph);
1012
1013        Ok(graph)
1014    }
1015
1016    /// Build dependency graph recursively
1017    #[allow(clippy::only_used_in_recursion)]
1018    fn build_dependency_graph(
1019        &self,
1020        node: &CircuitNode,
1021        graph: &mut DependencyGraph,
1022        node_id_map: &mut HashMap<String, NodeId>,
1023        next_id: &mut usize,
1024    ) -> NodeId {
1025        let current_id = NodeId(*next_id);
1026        *next_id += 1;
1027
1028        match node {
1029            CircuitNode::Load(name) => {
1030                node_id_map.insert(name.clone(), current_id);
1031                graph.dependencies.insert(current_id, Vec::new());
1032                current_id
1033            }
1034
1035            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {
1036                graph.dependencies.insert(current_id, Vec::new());
1037                current_id
1038            }
1039
1040            CircuitNode::BinaryOp { left, right, .. } => {
1041                let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
1042                let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
1043
1044                graph
1045                    .dependencies
1046                    .insert(current_id, vec![left_id, right_id]);
1047                current_id
1048            }
1049
1050            CircuitNode::UnaryOp { operand, .. } => {
1051                let operand_id = self.build_dependency_graph(operand, graph, node_id_map, next_id);
1052
1053                graph.dependencies.insert(current_id, vec![operand_id]);
1054                current_id
1055            }
1056
1057            CircuitNode::Compare { left, right, .. } => {
1058                let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
1059                let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
1060
1061                graph
1062                    .dependencies
1063                    .insert(current_id, vec![left_id, right_id]);
1064                current_id
1065            }
1066        }
1067    }
1068
1069    /// Identify groups of nodes that can execute in parallel
1070    fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
1071        let mut levels: HashMap<NodeId, usize> = HashMap::new();
1072        let mut queue = VecDeque::new();
1073
1074        // Find all nodes with no dependencies (level 0)
1075        for (node_id, deps) in &graph.dependencies {
1076            if deps.is_empty() {
1077                levels.insert(*node_id, 0);
1078                queue.push_back(*node_id);
1079            }
1080        }
1081
1082        // Level-wise traversal
1083        while let Some(node_id) = queue.pop_front() {
1084            let current_level = levels[&node_id];
1085
1086            // Find nodes that depend on this node
1087            for (dependent_id, deps) in &graph.dependencies {
1088                if deps.contains(&node_id) {
1089                    // Calculate level for dependent node
1090                    let max_dep_level = deps
1091                        .iter()
1092                        .filter_map(|dep_id| levels.get(dep_id))
1093                        .max()
1094                        .copied()
1095                        .unwrap_or(0);
1096
1097                    let dependent_level = max_dep_level + 1;
1098
1099                    if !levels.contains_key(dependent_id) {
1100                        levels.insert(*dependent_id, dependent_level);
1101                        queue.push_back(*dependent_id);
1102                    }
1103                }
1104            }
1105        }
1106
1107        // Group nodes by level
1108        let max_level = levels.values().max().copied().unwrap_or(0);
1109        let mut parallel_groups = vec![Vec::new(); max_level + 1];
1110
1111        for (node_id, level) in levels {
1112            parallel_groups[level].push(node_id);
1113        }
1114
1115        // Sort each group for deterministic output
1116        for group in &mut parallel_groups {
1117            group.sort();
1118        }
1119
1120        parallel_groups
1121    }
1122
1123    /// Find the critical path (longest dependency chain)
1124    fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
1125        // Simple implementation: find the node with the longest chain to root
1126        let mut max_path = Vec::new();
1127
1128        for node_id in graph.dependencies.keys() {
1129            let path = self.find_path_to_root(*node_id, graph);
1130            if path.len() > max_path.len() {
1131                max_path = path;
1132            }
1133        }
1134
1135        max_path
1136    }
1137
1138    /// Find path from a node to a root (node with no dependencies)
1139    #[allow(clippy::only_used_in_recursion)]
1140    fn find_path_to_root(&self, node_id: NodeId, graph: &DependencyGraph) -> Vec<NodeId> {
1141        let deps = graph
1142            .dependencies
1143            .get(&node_id)
1144            .map(|v| v.as_slice())
1145            .unwrap_or(&[]);
1146
1147        if deps.is_empty() {
1148            return vec![node_id];
1149        }
1150
1151        // Find the longest path through dependencies
1152        let mut longest_path = Vec::new();
1153        for dep_id in deps {
1154            let dep_path = self.find_path_to_root(*dep_id, graph);
1155            if dep_path.len() > longest_path.len() {
1156                longest_path = dep_path;
1157            }
1158        }
1159
1160        longest_path.push(node_id);
1161        longest_path
1162    }
1163}
1164
1165impl Default for CircuitOptimizer {
1166    fn default() -> Self {
1167        Self::new()
1168    }
1169}
1170
1171#[cfg(test)]
1172mod tests {
1173    use super::*;
1174    use crate::compute::circuit::CircuitBuilder;
1175
1176    // ── Constant folding tests ─────────────────────────────────────────
1177
1178    #[test]
1179    fn test_constant_folding() -> Result<()> {
1180        let mut optimizer = CircuitOptimizer::new();
1181        let builder = CircuitBuilder::new();
1182
1183        // Create circuit: 5 + 3
1184        let a = builder.constant(CircuitValue::U8(5));
1185        let b = builder.constant(CircuitValue::U8(3));
1186        let sum = builder.add(a, b);
1187
1188        let circuit = Circuit::new(sum, HashMap::new())?;
1189        let optimized = optimizer.optimize(circuit)?;
1190
1191        // Should fold to constant 8
1192        assert!(matches!(
1193            optimized.root,
1194            CircuitNode::Constant(CircuitValue::U8(8))
1195        ));
1196        assert!(optimizer.stats().constants_folded >= 1);
1197
1198        Ok(())
1199    }
1200
1201    #[test]
1202    fn test_constant_folding_sub() -> Result<()> {
1203        let mut optimizer = CircuitOptimizer::new();
1204        let builder = CircuitBuilder::new();
1205
1206        let a = builder.constant(CircuitValue::U16(100));
1207        let b = builder.constant(CircuitValue::U16(30));
1208        let result = builder.sub(a, b);
1209
1210        let circuit = Circuit::new(result, HashMap::new())?;
1211        let optimized = optimizer.optimize(circuit)?;
1212
1213        assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U16(70)));
1214        Ok(())
1215    }
1216
1217    #[test]
1218    fn test_constant_folding_mul() -> Result<()> {
1219        let mut optimizer = CircuitOptimizer::new();
1220        let builder = CircuitBuilder::new();
1221
1222        let a = builder.constant(CircuitValue::U32(7));
1223        let b = builder.constant(CircuitValue::U32(6));
1224        let result = builder.mul(a, b);
1225
1226        let circuit = Circuit::new(result, HashMap::new())?;
1227        let optimized = optimizer.optimize(circuit)?;
1228
1229        assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U32(42)));
1230        Ok(())
1231    }
1232
1233    #[test]
1234    fn test_constant_folding_bool_and() -> Result<()> {
1235        let mut optimizer = CircuitOptimizer::new();
1236        let builder = CircuitBuilder::new();
1237
1238        let t = builder.constant(CircuitValue::Bool(true));
1239        let f = builder.constant(CircuitValue::Bool(false));
1240        let result = builder.and(t, f);
1241
1242        let circuit = Circuit::new(result, HashMap::new())?;
1243        let optimized = optimizer.optimize(circuit)?;
1244
1245        assert_eq!(
1246            optimized.root,
1247            CircuitNode::Constant(CircuitValue::Bool(false))
1248        );
1249        Ok(())
1250    }
1251
1252    #[test]
1253    fn test_constant_folding_unary_not() -> Result<()> {
1254        let mut optimizer = CircuitOptimizer::new();
1255        let builder = CircuitBuilder::new();
1256
1257        let t = builder.constant(CircuitValue::Bool(true));
1258        let result = builder.not(t);
1259
1260        let circuit = Circuit::new(result, HashMap::new())?;
1261        let optimized = optimizer.optimize(circuit)?;
1262
1263        assert_eq!(
1264            optimized.root,
1265            CircuitNode::Constant(CircuitValue::Bool(false))
1266        );
1267        Ok(())
1268    }
1269
1270    // ── Algebraic identity tests ───────────────────────────────────────
1271
1272    #[test]
1273    fn test_algebraic_x_plus_zero() -> Result<()> {
1274        let mut optimizer = CircuitOptimizer::new();
1275        let mut builder = CircuitBuilder::new();
1276        builder.declare_variable("x", EncryptedType::U8);
1277
1278        let x = builder.load("x");
1279        let zero = builder.constant(CircuitValue::U8(0));
1280        let add_zero = builder.add(x, zero);
1281
1282        let circuit = Circuit::new(add_zero, builder.variable_types_clone())?;
1283        let optimized = optimizer.optimize(circuit)?;
1284
1285        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1286        Ok(())
1287    }
1288
1289    #[test]
1290    fn test_algebraic_zero_plus_x() -> Result<()> {
1291        let mut optimizer = CircuitOptimizer::new();
1292        let mut builder = CircuitBuilder::new();
1293        builder.declare_variable("x", EncryptedType::U8);
1294
1295        let x = builder.load("x");
1296        let zero = builder.constant(CircuitValue::U8(0));
1297        let result = builder.add(zero, x);
1298
1299        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1300        let optimized = optimizer.optimize(circuit)?;
1301
1302        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1303        Ok(())
1304    }
1305
1306    #[test]
1307    fn test_algebraic_x_mul_one() -> Result<()> {
1308        let mut optimizer = CircuitOptimizer::new();
1309        let mut builder = CircuitBuilder::new();
1310        builder.declare_variable("x", EncryptedType::U8);
1311
1312        let x = builder.load("x");
1313        let one = builder.constant(CircuitValue::U8(1));
1314        let result = builder.mul(x, one);
1315
1316        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1317        let optimized = optimizer.optimize(circuit)?;
1318
1319        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1320        Ok(())
1321    }
1322
1323    #[test]
1324    fn test_algebraic_one_mul_x() -> Result<()> {
1325        let mut optimizer = CircuitOptimizer::new();
1326        let mut builder = CircuitBuilder::new();
1327        builder.declare_variable("x", EncryptedType::U8);
1328
1329        let x = builder.load("x");
1330        let one = builder.constant(CircuitValue::U8(1));
1331        let result = builder.mul(one, x);
1332
1333        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1334        let optimized = optimizer.optimize(circuit)?;
1335
1336        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1337        Ok(())
1338    }
1339
1340    #[test]
1341    fn test_algebraic_x_mul_zero() -> Result<()> {
1342        let mut optimizer = CircuitOptimizer::new();
1343        let mut builder = CircuitBuilder::new();
1344        builder.declare_variable("x", EncryptedType::U8);
1345
1346        let x = builder.load("x");
1347        let zero = builder.constant(CircuitValue::U8(0));
1348        let result = builder.mul(x, zero);
1349
1350        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1351        let optimized = optimizer.optimize(circuit)?;
1352
1353        assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
1354        Ok(())
1355    }
1356
1357    #[test]
1358    fn test_algebraic_zero_mul_x() -> Result<()> {
1359        let mut optimizer = CircuitOptimizer::new();
1360        let mut builder = CircuitBuilder::new();
1361        builder.declare_variable("x", EncryptedType::U8);
1362
1363        let x = builder.load("x");
1364        let zero = builder.constant(CircuitValue::U8(0));
1365        let result = builder.mul(zero, x);
1366
1367        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1368        let optimized = optimizer.optimize(circuit)?;
1369
1370        assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
1371        Ok(())
1372    }
1373
1374    #[test]
1375    fn test_algebraic_x_sub_zero() -> Result<()> {
1376        let mut optimizer = CircuitOptimizer::new();
1377        let mut builder = CircuitBuilder::new();
1378        builder.declare_variable("x", EncryptedType::U8);
1379
1380        let x = builder.load("x");
1381        let zero = builder.constant(CircuitValue::U8(0));
1382        let result = builder.sub(x, zero);
1383
1384        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1385        let optimized = optimizer.optimize(circuit)?;
1386
1387        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1388        Ok(())
1389    }
1390
1391    #[test]
1392    fn test_algebraic_x_sub_x() -> Result<()> {
1393        let mut optimizer = CircuitOptimizer::new();
1394        let mut builder = CircuitBuilder::new();
1395        builder.declare_variable("x", EncryptedType::U8);
1396
1397        let x1 = builder.load("x");
1398        let x2 = builder.load("x");
1399        let result = builder.sub(x1, x2);
1400
1401        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1402        let optimized = optimizer.optimize(circuit)?;
1403
1404        // x - x should be 0
1405        assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
1406        assert!(optimizer.stats().algebraic_simplifications >= 1);
1407        Ok(())
1408    }
1409
1410    // ── Double negation tests ──────────────────────────────────────────
1411
1412    #[test]
1413    fn test_double_negation_elimination() -> Result<()> {
1414        let mut optimizer = CircuitOptimizer::new();
1415        let mut builder = CircuitBuilder::new();
1416        builder.declare_variable("x", EncryptedType::Bool);
1417
1418        let x = builder.load("x");
1419        let not_x = builder.not(x);
1420        let not_not_x = builder.not(not_x);
1421
1422        let circuit = Circuit::new(not_not_x, builder.variable_types_clone())?;
1423        let optimized = optimizer.optimize(circuit)?;
1424
1425        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1426        Ok(())
1427    }
1428
1429    #[test]
1430    fn test_quadruple_negation_elimination() -> Result<()> {
1431        let mut optimizer = CircuitOptimizer::new();
1432        let mut builder = CircuitBuilder::new();
1433        builder.declare_variable("x", EncryptedType::Bool);
1434
1435        let x = builder.load("x");
1436        let n1 = builder.not(x);
1437        let n2 = builder.not(n1);
1438        let n3 = builder.not(n2);
1439        let n4 = builder.not(n3);
1440
1441        let circuit = Circuit::new(n4, builder.variable_types_clone())?;
1442        let optimized = optimizer.optimize(circuit)?;
1443
1444        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1445        Ok(())
1446    }
1447
1448    // ── Nested simplification tests ────────────────────────────────────
1449
1450    #[test]
1451    fn test_nested_x_plus_0_times_1() -> Result<()> {
1452        let mut optimizer = CircuitOptimizer::new();
1453        let mut builder = CircuitBuilder::new();
1454        builder.declare_variable("x", EncryptedType::U8);
1455
1456        // (x + 0) * 1 -> x
1457        let x = builder.load("x");
1458        let zero = builder.constant(CircuitValue::U8(0));
1459        let one = builder.constant(CircuitValue::U8(1));
1460        let add_zero = builder.add(x, zero);
1461        let times_one = builder.mul(add_zero, one);
1462
1463        let circuit = Circuit::new(times_one, builder.variable_types_clone())?;
1464        let optimized = optimizer.optimize(circuit)?;
1465
1466        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1467        Ok(())
1468    }
1469
1470    #[test]
1471    fn test_nested_complex_optimization() -> Result<()> {
1472        let mut optimizer = CircuitOptimizer::new();
1473        let mut builder = CircuitBuilder::new();
1474        builder
1475            .declare_variable("a", EncryptedType::U8)
1476            .declare_variable("b", EncryptedType::U8);
1477
1478        // (a * 1) + (b * 0) + 5  ->  a + 5
1479        let a = builder.load("a");
1480        let b = builder.load("b");
1481        let one = builder.constant(CircuitValue::U8(1));
1482        let zero = builder.constant(CircuitValue::U8(0));
1483        let five = builder.constant(CircuitValue::U8(5));
1484
1485        let a_times_1 = builder.mul(a, one);
1486        let b_times_0 = builder.mul(b, zero);
1487        let sum1 = builder.add(a_times_1, b_times_0);
1488        let result = builder.add(sum1, five);
1489
1490        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1491        let original_gates = circuit.gate_count;
1492
1493        let optimized = optimizer.optimize(circuit)?;
1494
1495        assert!(optimized.gate_count < original_gates);
1496        assert!(optimizer.stats().gate_reduction_percent() >= 30.0);
1497
1498        Ok(())
1499    }
1500
1501    // ── No-op on already optimal circuits ──────────────────────────────
1502
1503    #[test]
1504    fn test_noop_on_optimal_circuit() -> Result<()> {
1505        let mut optimizer = CircuitOptimizer::new();
1506        let mut builder = CircuitBuilder::new();
1507        builder
1508            .declare_variable("a", EncryptedType::U8)
1509            .declare_variable("b", EncryptedType::U8);
1510
1511        // a + b is already optimal
1512        let a = builder.load("a");
1513        let b = builder.load("b");
1514        let result = builder.add(a, b);
1515
1516        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1517        let original_gates = circuit.gate_count;
1518
1519        let optimized = optimizer.optimize(circuit)?;
1520
1521        assert_eq!(optimized.gate_count, original_gates);
1522        assert_eq!(
1523            optimized.root,
1524            CircuitNode::BinaryOp {
1525                op: BinaryOperator::Add,
1526                left: Box::new(CircuitNode::Load("a".to_string())),
1527                right: Box::new(CircuitNode::Load("b".to_string())),
1528            }
1529        );
1530        Ok(())
1531    }
1532
1533    #[test]
1534    fn test_noop_single_load() -> Result<()> {
1535        let mut optimizer = CircuitOptimizer::new();
1536        let mut builder = CircuitBuilder::new();
1537        builder.declare_variable("x", EncryptedType::U8);
1538
1539        let x = builder.load("x");
1540        let circuit = Circuit::new(x, builder.variable_types_clone())?;
1541        let optimized = optimizer.optimize(circuit)?;
1542
1543        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1544        Ok(())
1545    }
1546
1547    // ── Statistics accuracy tests ──────────────────────────────────────
1548
1549    #[test]
1550    fn test_stats_accuracy_constant_folding() -> Result<()> {
1551        let mut optimizer = CircuitOptimizer::new();
1552        let builder = CircuitBuilder::new();
1553
1554        // 5 + 3 -> 8, then 8 * 2 -> 16  (two folds)
1555        let a = builder.constant(CircuitValue::U8(5));
1556        let b = builder.constant(CircuitValue::U8(3));
1557        let two = builder.constant(CircuitValue::U8(2));
1558        let sum = builder.add(a, b);
1559        let result = builder.mul(sum, two);
1560
1561        let circuit = Circuit::new(result, HashMap::new())?;
1562        let optimized = optimizer.optimize(circuit)?;
1563
1564        assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(16)));
1565        // At least 2 constant folds happened (possibly more from DCE re-fold)
1566        assert!(optimizer.stats().constants_folded >= 2);
1567        Ok(())
1568    }
1569
1570    #[test]
1571    fn test_stats_accuracy_algebraic() -> Result<()> {
1572        let mut optimizer = CircuitOptimizer::new();
1573        let mut builder = CircuitBuilder::new();
1574        builder.declare_variable("x", EncryptedType::U8);
1575
1576        // x - x -> 0
1577        let x1 = builder.load("x");
1578        let x2 = builder.load("x");
1579        let result = builder.sub(x1, x2);
1580
1581        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1582        let _optimized = optimizer.optimize(circuit)?;
1583
1584        let (total_eliminated, total_algebraic, _total_folds) = optimizer.total_stats();
1585        assert!(total_eliminated >= 1);
1586        assert!(total_algebraic >= 1);
1587        Ok(())
1588    }
1589
1590    #[test]
1591    fn test_optimization_stats() -> Result<()> {
1592        let mut optimizer = CircuitOptimizer::new();
1593        let builder = CircuitBuilder::new();
1594
1595        let a = builder.constant(CircuitValue::U8(5));
1596        let b = builder.constant(CircuitValue::U8(3));
1597        let zero = builder.constant(CircuitValue::U8(0));
1598
1599        let sum = builder.add(a, b);
1600        let add_zero = builder.add(sum, zero);
1601
1602        let circuit = Circuit::new(add_zero, HashMap::new())?;
1603        let original_gates = circuit.gate_count;
1604
1605        let optimized = optimizer.optimize(circuit)?;
1606        let optimized_gates = optimized.gate_count;
1607
1608        assert!(optimized_gates < original_gates);
1609        assert!(optimizer.stats().gate_reduction_percent() > 0.0);
1610
1611        Ok(())
1612    }
1613
1614    #[test]
1615    fn test_total_stats_method() -> Result<()> {
1616        let mut optimizer = CircuitOptimizer::new();
1617        let mut builder = CircuitBuilder::new();
1618        builder.declare_variable("x", EncryptedType::U8);
1619
1620        // (x + 0) * 1 -> x (algebraic simplifications)
1621        // plus: 5 + 3 constant fold somewhere
1622        let x = builder.load("x");
1623        let zero = builder.constant(CircuitValue::U8(0));
1624        let one = builder.constant(CircuitValue::U8(1));
1625        let add_zero = builder.add(x, zero);
1626        let times_one = builder.mul(add_zero, one);
1627
1628        let circuit = Circuit::new(times_one, builder.variable_types_clone())?;
1629        let _optimized = optimizer.optimize(circuit)?;
1630
1631        let (eliminated, algebraic, _folds) = optimizer.total_stats();
1632        // Both x+0 and *1 should be simplified
1633        assert!(eliminated + algebraic >= 2);
1634        Ok(())
1635    }
1636
1637    // ── Bootstrap counting test ────────────────────────────────────────
1638
1639    #[test]
1640    fn test_bootstrap_counting() -> Result<()> {
1641        let optimizer = CircuitOptimizer::new();
1642        let mut builder = CircuitBuilder::new();
1643        builder
1644            .declare_variable("a", EncryptedType::U8)
1645            .declare_variable("b", EncryptedType::U8);
1646
1647        let a = builder.load("a");
1648        let b = builder.load("b");
1649        let mul = builder.mul(a, b);
1650
1651        let circuit = Circuit::new(mul, builder.variable_types_clone())?;
1652        let bootstrap_count = optimizer.count_bootstraps(&circuit.root);
1653
1654        assert_eq!(bootstrap_count, 1);
1655        Ok(())
1656    }
1657
1658    // ── Parallelization analysis test ──────────────────────────────────
1659
1660    #[test]
1661    fn test_parallelization_analysis() -> Result<()> {
1662        let mut optimizer = CircuitOptimizer::new();
1663        let mut builder = CircuitBuilder::new();
1664        builder
1665            .declare_variable("a", EncryptedType::U8)
1666            .declare_variable("b", EncryptedType::U8)
1667            .declare_variable("c", EncryptedType::U8);
1668
1669        let a = builder.load("a");
1670        let b = builder.load("b");
1671        let c = builder.load("c");
1672        let sum1 = builder.add(a, b);
1673        let sum2 = builder.add(sum1, c);
1674
1675        let circuit = Circuit::new(sum2, builder.variable_types_clone())?;
1676        let _optimized = optimizer.optimize(circuit)?;
1677
1678        let graph = optimizer.dependency_graph();
1679        assert!(graph.node_count > 0);
1680        assert!(!graph.parallel_groups.is_empty());
1681
1682        Ok(())
1683    }
1684
1685    // ── Live variable collection test ──────────────────────────────────
1686
1687    #[test]
1688    fn test_collect_live_variables() -> Result<()> {
1689        let optimizer = CircuitOptimizer::new();
1690        let mut builder = CircuitBuilder::new();
1691        builder
1692            .declare_variable("a", EncryptedType::U8)
1693            .declare_variable("b", EncryptedType::U8);
1694
1695        let a = builder.load("a");
1696        let b = builder.load("b");
1697        let result = builder.add(a, b);
1698
1699        let live = optimizer.collect_live_variables(&result);
1700        assert!(live.contains("a"));
1701        assert!(live.contains("b"));
1702        assert_eq!(live.len(), 2);
1703        Ok(())
1704    }
1705
1706    #[test]
1707    fn test_collect_live_variables_after_dce() -> Result<()> {
1708        let mut optimizer = CircuitOptimizer::new();
1709        let mut builder = CircuitBuilder::new();
1710        builder
1711            .declare_variable("a", EncryptedType::U8)
1712            .declare_variable("b", EncryptedType::U8);
1713
1714        // (a * 1) + (b * 0) => a + 0 => a
1715        // After optimization, b should be eliminated
1716        let a = builder.load("a");
1717        let b = builder.load("b");
1718        let one = builder.constant(CircuitValue::U8(1));
1719        let zero = builder.constant(CircuitValue::U8(0));
1720        let a1 = builder.mul(a, one);
1721        let b0 = builder.mul(b, zero);
1722        let result = builder.add(a1, b0);
1723
1724        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1725        let optimized = optimizer.optimize(circuit)?;
1726
1727        let live = optimizer.collect_live_variables(&optimized.root);
1728        assert!(live.contains("a"));
1729        // b was multiplied by 0, so entire branch collapses to 0, and then a + 0 => a
1730        assert!(!live.contains("b"), "b should be eliminated by DCE");
1731        Ok(())
1732    }
1733
1734    // ── Comparison constant folding test ───────────────────────────────
1735
1736    #[test]
1737    fn test_comparison_constant_fold() -> Result<()> {
1738        let mut optimizer = CircuitOptimizer::new();
1739        let builder = CircuitBuilder::new();
1740
1741        let a = builder.constant(CircuitValue::U8(10));
1742        let b = builder.constant(CircuitValue::U8(5));
1743        let result = builder.gt(a, b);
1744
1745        let circuit = Circuit::new(result, HashMap::new())?;
1746        let optimized = optimizer.optimize(circuit)?;
1747
1748        assert_eq!(
1749            optimized.root,
1750            CircuitNode::Constant(CircuitValue::Bool(true))
1751        );
1752        Ok(())
1753    }
1754
1755    #[test]
1756    fn test_comparison_constant_fold_eq() -> Result<()> {
1757        let mut optimizer = CircuitOptimizer::new();
1758        let builder = CircuitBuilder::new();
1759
1760        let a = builder.constant(CircuitValue::U8(5));
1761        let b = builder.constant(CircuitValue::U8(5));
1762        let result = builder.eq(a, b);
1763
1764        let circuit = Circuit::new(result, HashMap::new())?;
1765        let optimized = optimizer.optimize(circuit)?;
1766
1767        assert_eq!(
1768            optimized.root,
1769            CircuitNode::Constant(CircuitValue::Bool(true))
1770        );
1771        Ok(())
1772    }
1773
1774    // ── XOR self-elimination test ──────────────────────────────────────
1775
1776    #[test]
1777    fn test_xor_self_elimination() -> Result<()> {
1778        let mut optimizer = CircuitOptimizer::new();
1779        let mut builder = CircuitBuilder::new();
1780        builder.declare_variable("x", EncryptedType::Bool);
1781
1782        let x1 = builder.load("x");
1783        let x2 = builder.load("x");
1784        let result = builder.xor(x1, x2);
1785
1786        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1787        let optimized = optimizer.optimize(circuit)?;
1788
1789        assert_eq!(
1790            optimized.root,
1791            CircuitNode::Constant(CircuitValue::Bool(false))
1792        );
1793        Ok(())
1794    }
1795
1796    // ── AND/OR idempotent test ─────────────────────────────────────────
1797
1798    #[test]
1799    fn test_and_idempotent() -> Result<()> {
1800        let mut optimizer = CircuitOptimizer::new();
1801        let mut builder = CircuitBuilder::new();
1802        builder.declare_variable("x", EncryptedType::Bool);
1803
1804        let x1 = builder.load("x");
1805        let x2 = builder.load("x");
1806        let result = builder.and(x1, x2);
1807
1808        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1809        let optimized = optimizer.optimize(circuit)?;
1810
1811        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1812        Ok(())
1813    }
1814
1815    #[test]
1816    fn test_or_idempotent() -> Result<()> {
1817        let mut optimizer = CircuitOptimizer::new();
1818        let mut builder = CircuitBuilder::new();
1819        builder.declare_variable("x", EncryptedType::Bool);
1820
1821        let x1 = builder.load("x");
1822        let x2 = builder.load("x");
1823        let result = builder.or(x1, x2);
1824
1825        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1826        let optimized = optimizer.optimize(circuit)?;
1827
1828        assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1829        Ok(())
1830    }
1831
1832    // ── Encrypted constant optimizer tests ────────────────────────────
1833
1834    #[test]
1835    fn test_optimizer_does_not_fold_encrypted_constants() -> Result<()> {
1836        use crate::compute::circuit::ConstantType;
1837
1838        let mut optimizer = CircuitOptimizer::new();
1839        let builder = CircuitBuilder::new();
1840
1841        // Build: EncryptedConstant + EncryptedConstant
1842        // The optimizer must NOT try to constant-fold these because their
1843        // plaintext values are unknown.
1844        let enc_a = builder.encrypted_constant(vec![0x01, 0x05], ConstantType::Integer);
1845        let enc_b = builder.encrypted_constant(vec![0x01, 0x03], ConstantType::Integer);
1846        let sum = builder.add(enc_a.clone(), enc_b.clone());
1847
1848        let circuit = Circuit::new(sum, HashMap::new())?;
1849        let optimized = optimizer.optimize(circuit)?;
1850
1851        // The root should still be a BinaryOp Add, not a folded constant
1852        match &optimized.root {
1853            CircuitNode::BinaryOp { op, left, right } => {
1854                assert_eq!(*op, BinaryOperator::Add);
1855                assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
1856                assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
1857            }
1858            _ => {
1859                return Err(AmateRSError::FheComputation(ErrorContext::new(
1860                    "Optimizer incorrectly folded encrypted constants".to_string(),
1861                )));
1862            }
1863        }
1864
1865        // No constants should have been folded
1866        assert_eq!(optimizer.stats().constants_folded, 0);
1867
1868        Ok(())
1869    }
1870
1871    #[test]
1872    fn test_optimizer_dce_treats_encrypted_constant_as_opaque() -> Result<()> {
1873        use crate::compute::circuit::ConstantType;
1874
1875        let mut optimizer = CircuitOptimizer::new();
1876
1877        // Build a circuit: EncryptedConstant (standalone, as root)
1878        // DCE should leave it alone (it is the output)
1879        let enc = CircuitNode::EncryptedConstant {
1880            data: vec![0x04, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11],
1881            original_type: ConstantType::Integer,
1882        };
1883
1884        let circuit = Circuit::new(enc.clone(), HashMap::new())?;
1885        let optimized = optimizer.optimize(circuit)?;
1886
1887        // The root should remain an EncryptedConstant, untouched
1888        assert_eq!(optimized.root, enc);
1889
1890        Ok(())
1891    }
1892
1893    #[test]
1894    fn test_optimizer_mixed_plain_and_encrypted_constants() -> Result<()> {
1895        use crate::compute::circuit::ConstantType;
1896
1897        let mut optimizer = CircuitOptimizer::new();
1898        let builder = CircuitBuilder::new();
1899
1900        // Build: Constant(5u8) + Constant(3u8) -- these CAN be folded
1901        let plain_a = builder.constant(CircuitValue::U8(5));
1902        let plain_b = builder.constant(CircuitValue::U8(3));
1903        let plain_sum = builder.add(plain_a, plain_b);
1904
1905        let circuit = Circuit::new(plain_sum, HashMap::new())?;
1906        let optimized = optimizer.optimize(circuit)?;
1907
1908        // Should fold to 8
1909        assert!(matches!(
1910            optimized.root,
1911            CircuitNode::Constant(CircuitValue::U8(8))
1912        ));
1913
1914        // Now with encrypted: EncryptedConst + EncryptedConst -- must NOT fold
1915        let mut optimizer2 = CircuitOptimizer::new();
1916        let enc_a = builder.encrypted_constant(vec![0x01, 0xAA], ConstantType::Integer);
1917        let enc_b = builder.encrypted_constant(vec![0x01, 0xBB], ConstantType::Integer);
1918        let enc_sum = builder.add(enc_a, enc_b);
1919
1920        let circuit2 = Circuit::new(enc_sum, HashMap::new())?;
1921        let optimized2 = optimizer2.optimize(circuit2)?;
1922
1923        assert!(matches!(optimized2.root, CircuitNode::BinaryOp { .. }));
1924
1925        Ok(())
1926    }
1927
1928    #[test]
1929    fn test_optimizer_algebraic_identity_with_encrypted_constant() -> Result<()> {
1930        use crate::compute::circuit::ConstantType;
1931
1932        let mut optimizer = CircuitOptimizer::new();
1933        let builder = CircuitBuilder::new();
1934
1935        // Build: EncryptedConstant + Constant(0u64)
1936        // EncryptedConstant with ConstantType::Integer infers to U64,
1937        // so the zero constant must also be U64 for type compatibility.
1938        // The algebraic identity x + 0 = x should simplify this to just
1939        // the EncryptedConstant.
1940        let enc = builder.encrypted_constant(
1941            vec![0x04, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
1942            ConstantType::Integer,
1943        );
1944        let zero = builder.constant(CircuitValue::U64(0));
1945        let sum = builder.add(enc.clone(), zero);
1946
1947        let circuit = Circuit::new(sum, HashMap::new())?;
1948        let optimized = optimizer.optimize(circuit)?;
1949
1950        // Should simplify to just the encrypted constant
1951        assert_eq!(optimized.root, enc);
1952
1953        Ok(())
1954    }
1955
1956    #[test]
1957    fn test_optimizer_live_variables_with_encrypted_constants() -> Result<()> {
1958        use crate::compute::circuit::ConstantType;
1959
1960        let optimizer = CircuitOptimizer::new();
1961        let mut builder = CircuitBuilder::new();
1962        builder.declare_variable("x", EncryptedType::U8);
1963
1964        // Build: Load("x") + EncryptedConstant
1965        let x = builder.load("x");
1966        let enc = builder.encrypted_constant(vec![0x01, 0x10], ConstantType::Integer);
1967        let sum = builder.add(x, enc);
1968
1969        let live = optimizer.collect_live_variables(&sum);
1970
1971        // "x" is live, encrypted constant contributes nothing to variables
1972        assert!(live.contains("x"));
1973        assert_eq!(live.len(), 1);
1974
1975        Ok(())
1976    }
1977}