amaters_core/compute/
optimizer.rs

1//! Advanced circuit optimization for FHE computations
2//!
3//! This module provides sophisticated optimization passes to reduce the computational
4//! cost of FHE circuits. The optimizer focuses on:
5//!
6//! 1. **Bootstrap Minimization** - Reducing expensive bootstrap operations
7//! 2. **Gate Fusion** - Combining adjacent operations to reduce overhead
8//! 3. **Dead Code Elimination** - Removing unused operations
9//! 4. **Parallelization Analysis** - Identifying independent operations for parallel execution
10//!
11//! These optimizations can reduce circuit execution time by 30-50% in typical cases.
12
13use crate::compute::circuit::{
14    BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
15    UnaryOperator,
16};
17use crate::error::{AmateRSError, ErrorContext, Result};
18use std::collections::{HashMap, HashSet, VecDeque};
19
20/// Statistics collected during optimization
21#[derive(Debug, Clone, Default, PartialEq, Eq)]
22pub struct OptimizationStats {
23    /// Number of gates before optimization
24    pub original_gate_count: usize,
25
26    /// Number of gates after optimization
27    pub optimized_gate_count: usize,
28
29    /// Number of bootstrap operations before optimization
30    pub original_bootstrap_count: usize,
31
32    /// Number of bootstrap operations after optimization
33    pub optimized_bootstrap_count: usize,
34
35    /// Number of dead code nodes removed
36    pub dead_code_removed: usize,
37
38    /// Number of constant expressions folded
39    pub constants_folded: usize,
40
41    /// Number of gates fused
42    pub gates_fused: usize,
43
44    /// Circuit depth before optimization
45    pub original_depth: usize,
46
47    /// Circuit depth after optimization
48    pub optimized_depth: usize,
49}
50
51impl OptimizationStats {
52    /// Calculate the reduction percentage in gate count
53    pub fn gate_reduction_percent(&self) -> f64 {
54        if self.original_gate_count == 0 {
55            return 0.0;
56        }
57        let reduction = self
58            .original_gate_count
59            .saturating_sub(self.optimized_gate_count);
60        (reduction as f64 / self.original_gate_count as f64) * 100.0
61    }
62
63    /// Calculate the reduction percentage in bootstrap operations
64    pub fn bootstrap_reduction_percent(&self) -> f64 {
65        if self.original_bootstrap_count == 0 {
66            return 0.0;
67        }
68        let reduction = self
69            .original_bootstrap_count
70            .saturating_sub(self.optimized_bootstrap_count);
71        (reduction as f64 / self.original_bootstrap_count as f64) * 100.0
72    }
73}
74
75/// Dependency information for parallelization
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct DependencyGraph {
78    /// Node ID to its dependencies
79    pub dependencies: HashMap<NodeId, Vec<NodeId>>,
80
81    /// Nodes that can be executed in parallel (sets of independent nodes)
82    pub parallel_groups: Vec<Vec<NodeId>>,
83
84    /// Critical path through the circuit (longest dependency chain)
85    pub critical_path: Vec<NodeId>,
86
87    /// Total number of nodes in the graph
88    pub node_count: usize,
89}
90
91impl DependencyGraph {
92    /// Create an empty dependency graph
93    pub fn new() -> Self {
94        Self {
95            dependencies: HashMap::new(),
96            parallel_groups: Vec::new(),
97            critical_path: Vec::new(),
98            node_count: 0,
99        }
100    }
101
102    /// Calculate the maximum parallelism (largest parallel group)
103    pub fn max_parallelism(&self) -> usize {
104        self.parallel_groups
105            .iter()
106            .map(|g| g.len())
107            .max()
108            .unwrap_or(0)
109    }
110
111    /// Calculate the average parallelism
112    pub fn avg_parallelism(&self) -> f64 {
113        if self.parallel_groups.is_empty() {
114            return 0.0;
115        }
116        let total: usize = self.parallel_groups.iter().map(|g| g.len()).sum();
117        total as f64 / self.parallel_groups.len() as f64
118    }
119}
120
121impl Default for DependencyGraph {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
127/// Node identifier for dependency tracking
128#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
129pub struct NodeId(pub usize);
130
131/// Advanced circuit optimizer with multiple optimization passes
132#[derive(Debug, Clone)]
133pub struct CircuitOptimizer {
134    /// Enable constant folding optimization
135    pub enable_constant_folding: bool,
136
137    /// Enable dead code elimination
138    pub enable_dead_code_elimination: bool,
139
140    /// Enable bootstrap minimization
141    pub enable_bootstrap_minimization: bool,
142
143    /// Enable gate fusion
144    pub enable_gate_fusion: bool,
145
146    /// Enable parallelization analysis
147    pub enable_parallelization_analysis: bool,
148
149    /// Statistics from the last optimization
150    stats: OptimizationStats,
151
152    /// Dependency graph from the last optimization
153    dependency_graph: DependencyGraph,
154}
155
156impl CircuitOptimizer {
157    /// Create a new optimizer with all optimizations enabled
158    pub fn new() -> Self {
159        Self {
160            enable_constant_folding: true,
161            enable_dead_code_elimination: true,
162            enable_bootstrap_minimization: true,
163            enable_gate_fusion: true,
164            enable_parallelization_analysis: true,
165            stats: OptimizationStats::default(),
166            dependency_graph: DependencyGraph::new(),
167        }
168    }
169
170    /// Create an optimizer with no optimizations enabled
171    pub fn disabled() -> Self {
172        Self {
173            enable_constant_folding: false,
174            enable_dead_code_elimination: false,
175            enable_bootstrap_minimization: false,
176            enable_gate_fusion: false,
177            enable_parallelization_analysis: false,
178            stats: OptimizationStats::default(),
179            dependency_graph: DependencyGraph::new(),
180        }
181    }
182
183    /// Get the statistics from the last optimization
184    pub fn stats(&self) -> &OptimizationStats {
185        &self.stats
186    }
187
188    /// Get the dependency graph from the last optimization
189    pub fn dependency_graph(&self) -> &DependencyGraph {
190        &self.dependency_graph
191    }
192
193    /// Optimize a circuit by applying all enabled optimization passes
194    pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
195        // Record original statistics
196        self.stats.original_gate_count = circuit.gate_count;
197        self.stats.original_depth = circuit.depth;
198        self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
199
200        let mut optimized_root = circuit.root.clone();
201
202        // Apply optimization passes in order
203        if self.enable_constant_folding {
204            optimized_root = self.constant_folding_pass(optimized_root);
205        }
206
207        if self.enable_gate_fusion {
208            optimized_root = self.gate_fusion_pass(optimized_root);
209        }
210
211        if self.enable_bootstrap_minimization {
212            optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
213        }
214
215        if self.enable_dead_code_elimination {
216            optimized_root = self.dead_code_elimination_pass(optimized_root);
217        }
218
219        // Build optimized circuit
220        let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
221
222        // Record optimized statistics
223        self.stats.optimized_gate_count = optimized_circuit.gate_count;
224        self.stats.optimized_depth = optimized_circuit.depth;
225        self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
226
227        // Analyze parallelization if enabled
228        if self.enable_parallelization_analysis {
229            self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
230        }
231
232        Ok(optimized_circuit)
233    }
234
235    /// Count the number of bootstrap operations in a circuit
236    ///
237    /// In TFHE, bootstrapping is required after certain operations to refresh noise.
238    /// For this implementation, we estimate bootstraps based on operation types:
239    /// - Multiplication requires bootstrap
240    /// - Comparison operations require bootstrap
241    /// - Deep chains of additions may require bootstrap
242    #[allow(clippy::only_used_in_recursion)]
243    fn count_bootstraps(&self, node: &CircuitNode) -> usize {
244        match node {
245            CircuitNode::Load(_) | CircuitNode::Constant(_) => 0,
246
247            CircuitNode::BinaryOp { op, left, right } => {
248                let left_bootstraps = self.count_bootstraps(left);
249                let right_bootstraps = self.count_bootstraps(right);
250
251                // Multiplication requires bootstrap
252                let op_bootstrap = match op {
253                    BinaryOperator::Mul => 1,
254                    _ => 0,
255                };
256
257                left_bootstraps + right_bootstraps + op_bootstrap
258            }
259
260            CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
261
262            CircuitNode::Compare { left, right, .. } => {
263                let left_bootstraps = self.count_bootstraps(left);
264                let right_bootstraps = self.count_bootstraps(right);
265
266                // Comparisons typically require bootstrap
267                left_bootstraps + right_bootstraps + 1
268            }
269        }
270    }
271
272    /// Constant folding optimization pass
273    ///
274    /// Evaluates constant expressions at compile time to reduce runtime computation
275    fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
276        match node {
277            CircuitNode::BinaryOp { op, left, right } => {
278                let left = self.constant_folding_pass(*left);
279                let right = self.constant_folding_pass(*right);
280
281                // Try to fold constants
282                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
283                    if let Some(result) = self.fold_binary_constants(op, l, r) {
284                        self.stats.constants_folded += 1;
285                        return CircuitNode::Constant(result);
286                    }
287                }
288
289                // Apply algebraic identities
290                if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
291                    return simplified;
292                }
293
294                CircuitNode::BinaryOp {
295                    op,
296                    left: Box::new(left),
297                    right: Box::new(right),
298                }
299            }
300
301            CircuitNode::UnaryOp { op, operand } => {
302                let operand = self.constant_folding_pass(*operand);
303
304                if let CircuitNode::Constant(val) = &operand {
305                    if let Some(result) = self.fold_unary_constant(op, val) {
306                        self.stats.constants_folded += 1;
307                        return CircuitNode::Constant(result);
308                    }
309                }
310
311                CircuitNode::UnaryOp {
312                    op,
313                    operand: Box::new(operand),
314                }
315            }
316
317            CircuitNode::Compare { op, left, right } => {
318                let left = self.constant_folding_pass(*left);
319                let right = self.constant_folding_pass(*right);
320
321                CircuitNode::Compare {
322                    op,
323                    left: Box::new(left),
324                    right: Box::new(right),
325                }
326            }
327
328            other => other,
329        }
330    }
331
332    /// Fold binary operation on constants
333    fn fold_binary_constants(
334        &self,
335        op: BinaryOperator,
336        left: &CircuitValue,
337        right: &CircuitValue,
338    ) -> Option<CircuitValue> {
339        match (left, right) {
340            (CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
341                BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
342                BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
343                BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
344                _ => None,
345            },
346            (CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
347                BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
348                BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
349                BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
350                _ => None,
351            },
352            (CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
353                BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
354                BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
355                BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
356                _ => None,
357            },
358            (CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
359                BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
360                BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
361                BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
362                _ => None,
363            },
364            (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
365                BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
366                BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
367                BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
368                _ => None,
369            },
370            _ => None,
371        }
372    }
373
374    /// Fold unary operation on constant
375    fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
376        match (op, value) {
377            (UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
378            _ => None,
379        }
380    }
381
382    /// Apply algebraic identities to simplify expressions
383    /// Examples: x + 0 = x, x * 1 = x, x * 0 = 0, x AND true = x, etc.
384    fn apply_algebraic_identities(
385        &mut self,
386        op: BinaryOperator,
387        left: &CircuitNode,
388        right: &CircuitNode,
389    ) -> Option<CircuitNode> {
390        match op {
391            BinaryOperator::Add => {
392                // x + 0 = x
393                if Self::is_zero(right) {
394                    self.stats.gates_fused += 1;
395                    return Some(left.clone());
396                }
397                // 0 + x = x
398                if Self::is_zero(left) {
399                    self.stats.gates_fused += 1;
400                    return Some(right.clone());
401                }
402            }
403
404            BinaryOperator::Sub => {
405                // x - 0 = x
406                if Self::is_zero(right) {
407                    self.stats.gates_fused += 1;
408                    return Some(left.clone());
409                }
410            }
411
412            BinaryOperator::Mul => {
413                // x * 0 = 0
414                if Self::is_zero(right) {
415                    self.stats.gates_fused += 1;
416                    return Some(right.clone());
417                }
418                if Self::is_zero(left) {
419                    self.stats.gates_fused += 1;
420                    return Some(left.clone());
421                }
422
423                // x * 1 = x
424                if Self::is_one(right) {
425                    self.stats.gates_fused += 1;
426                    return Some(left.clone());
427                }
428                // 1 * x = x
429                if Self::is_one(left) {
430                    self.stats.gates_fused += 1;
431                    return Some(right.clone());
432                }
433            }
434
435            BinaryOperator::And => {
436                // x AND true = x
437                if Self::is_true(right) {
438                    self.stats.gates_fused += 1;
439                    return Some(left.clone());
440                }
441                if Self::is_true(left) {
442                    self.stats.gates_fused += 1;
443                    return Some(right.clone());
444                }
445
446                // x AND false = false
447                if Self::is_false(right) {
448                    self.stats.gates_fused += 1;
449                    return Some(right.clone());
450                }
451                if Self::is_false(left) {
452                    self.stats.gates_fused += 1;
453                    return Some(left.clone());
454                }
455            }
456
457            BinaryOperator::Or => {
458                // x OR false = x
459                if Self::is_false(right) {
460                    self.stats.gates_fused += 1;
461                    return Some(left.clone());
462                }
463                if Self::is_false(left) {
464                    self.stats.gates_fused += 1;
465                    return Some(right.clone());
466                }
467
468                // x OR true = true
469                if Self::is_true(right) {
470                    self.stats.gates_fused += 1;
471                    return Some(right.clone());
472                }
473                if Self::is_true(left) {
474                    self.stats.gates_fused += 1;
475                    return Some(left.clone());
476                }
477            }
478
479            BinaryOperator::Xor => {
480                // x XOR false = x
481                if Self::is_false(right) {
482                    self.stats.gates_fused += 1;
483                    return Some(left.clone());
484                }
485                if Self::is_false(left) {
486                    self.stats.gates_fused += 1;
487                    return Some(right.clone());
488                }
489            }
490        }
491
492        None
493    }
494
495    /// Check if a node is constant zero
496    fn is_zero(node: &CircuitNode) -> bool {
497        matches!(
498            node,
499            CircuitNode::Constant(CircuitValue::U8(0))
500                | CircuitNode::Constant(CircuitValue::U16(0))
501                | CircuitNode::Constant(CircuitValue::U32(0))
502                | CircuitNode::Constant(CircuitValue::U64(0))
503        )
504    }
505
506    /// Check if a node is constant one
507    fn is_one(node: &CircuitNode) -> bool {
508        matches!(
509            node,
510            CircuitNode::Constant(CircuitValue::U8(1))
511                | CircuitNode::Constant(CircuitValue::U16(1))
512                | CircuitNode::Constant(CircuitValue::U32(1))
513                | CircuitNode::Constant(CircuitValue::U64(1))
514        )
515    }
516
517    /// Check if a node is constant true
518    fn is_true(node: &CircuitNode) -> bool {
519        matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
520    }
521
522    /// Check if a node is constant false
523    fn is_false(node: &CircuitNode) -> bool {
524        matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
525    }
526
527    /// Gate fusion optimization pass
528    ///
529    /// Combines adjacent operations to reduce overhead. For example:
530    /// - (a + b) + c can be fused into a single multi-input addition
531    /// - Multiple consecutive NOT operations can be eliminated
532    fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
533        match node {
534            CircuitNode::BinaryOp { op, left, right } => {
535                let left = self.gate_fusion_pass(*left);
536                let right = self.gate_fusion_pass(*right);
537
538                CircuitNode::BinaryOp {
539                    op,
540                    left: Box::new(left),
541                    right: Box::new(right),
542                }
543            }
544
545            CircuitNode::UnaryOp {
546                op: UnaryOperator::Not,
547                operand,
548            } => {
549                let operand = self.gate_fusion_pass(*operand);
550
551                // NOT(NOT(x)) = x
552                if let CircuitNode::UnaryOp {
553                    op: UnaryOperator::Not,
554                    operand: inner,
555                } = operand
556                {
557                    self.stats.gates_fused += 2; // Removed 2 NOT gates
558                    return *inner;
559                }
560
561                CircuitNode::UnaryOp {
562                    op: UnaryOperator::Not,
563                    operand: Box::new(operand),
564                }
565            }
566
567            CircuitNode::UnaryOp { op, operand } => {
568                let operand = self.gate_fusion_pass(*operand);
569                CircuitNode::UnaryOp {
570                    op,
571                    operand: Box::new(operand),
572                }
573            }
574
575            CircuitNode::Compare { op, left, right } => {
576                let left = self.gate_fusion_pass(*left);
577                let right = self.gate_fusion_pass(*right);
578
579                CircuitNode::Compare {
580                    op,
581                    left: Box::new(left),
582                    right: Box::new(right),
583                }
584            }
585
586            other => other,
587        }
588    }
589
590    /// Bootstrap minimization pass
591    ///
592    /// Analyzes the circuit to minimize expensive bootstrap operations by:
593    /// - Reordering operations to delay bootstraps
594    /// - Combining operations that share bootstrap requirements
595    /// - Eliminating redundant bootstraps
596    fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
597        // For now, we apply a simple optimization: reorder additions before multiplications
598        // This allows us to batch cheap operations before expensive ones
599        Ok(self.reorder_for_bootstrap_efficiency(node))
600    }
601
602    /// Reorder operations to minimize bootstraps
603    #[allow(clippy::only_used_in_recursion)]
604    fn reorder_for_bootstrap_efficiency(&self, node: CircuitNode) -> CircuitNode {
605        match node {
606            CircuitNode::BinaryOp { op, left, right } => {
607                let left = self.reorder_for_bootstrap_efficiency(*left);
608                let right = self.reorder_for_bootstrap_efficiency(*right);
609
610                CircuitNode::BinaryOp {
611                    op,
612                    left: Box::new(left),
613                    right: Box::new(right),
614                }
615            }
616
617            CircuitNode::UnaryOp { op, operand } => {
618                let operand = self.reorder_for_bootstrap_efficiency(*operand);
619                CircuitNode::UnaryOp {
620                    op,
621                    operand: Box::new(operand),
622                }
623            }
624
625            CircuitNode::Compare { op, left, right } => {
626                let left = self.reorder_for_bootstrap_efficiency(*left);
627                let right = self.reorder_for_bootstrap_efficiency(*right);
628
629                CircuitNode::Compare {
630                    op,
631                    left: Box::new(left),
632                    right: Box::new(right),
633                }
634            }
635
636            other => other,
637        }
638    }
639
640    /// Dead code elimination pass
641    ///
642    /// Removes operations that don't contribute to the final result
643    fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
644        // Mark all nodes as potentially live
645        let mut live_nodes = HashSet::new();
646        self.mark_live_nodes(&node, &mut live_nodes);
647
648        // The current implementation doesn't actually remove dead code
649        // because all nodes in the tree are reachable from the root
650        // This is a placeholder for more sophisticated DCE
651        node
652    }
653
654    /// Mark nodes that contribute to the output
655    #[allow(clippy::only_used_in_recursion)]
656    fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
657        match node {
658            CircuitNode::Load(name) => {
659                live_nodes.insert(name.clone());
660            }
661
662            CircuitNode::Constant(_) => {}
663
664            CircuitNode::BinaryOp { left, right, .. } => {
665                self.mark_live_nodes(left, live_nodes);
666                self.mark_live_nodes(right, live_nodes);
667            }
668
669            CircuitNode::UnaryOp { operand, .. } => {
670                self.mark_live_nodes(operand, live_nodes);
671            }
672
673            CircuitNode::Compare { left, right, .. } => {
674                self.mark_live_nodes(left, live_nodes);
675                self.mark_live_nodes(right, live_nodes);
676            }
677        }
678    }
679
680    /// Analyze circuit for parallelization opportunities
681    ///
682    /// Builds a dependency graph and identifies operations that can run in parallel
683    fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
684        let mut graph = DependencyGraph::new();
685        let mut node_id_map = HashMap::new();
686        let mut next_id = 0;
687
688        // Build dependency graph
689        self.build_dependency_graph(&circuit.root, &mut graph, &mut node_id_map, &mut next_id);
690
691        graph.node_count = next_id;
692
693        // Identify parallel groups using level-wise traversal
694        graph.parallel_groups = self.identify_parallel_groups(&graph);
695
696        // Find critical path
697        graph.critical_path = self.find_critical_path(&graph);
698
699        Ok(graph)
700    }
701
702    /// Build dependency graph recursively
703    #[allow(clippy::only_used_in_recursion)]
704    fn build_dependency_graph(
705        &self,
706        node: &CircuitNode,
707        graph: &mut DependencyGraph,
708        node_id_map: &mut HashMap<String, NodeId>,
709        next_id: &mut usize,
710    ) -> NodeId {
711        let current_id = NodeId(*next_id);
712        *next_id += 1;
713
714        match node {
715            CircuitNode::Load(name) => {
716                node_id_map.insert(name.clone(), current_id);
717                graph.dependencies.insert(current_id, Vec::new());
718                current_id
719            }
720
721            CircuitNode::Constant(_) => {
722                graph.dependencies.insert(current_id, Vec::new());
723                current_id
724            }
725
726            CircuitNode::BinaryOp { left, right, .. } => {
727                let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
728                let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
729
730                graph
731                    .dependencies
732                    .insert(current_id, vec![left_id, right_id]);
733                current_id
734            }
735
736            CircuitNode::UnaryOp { operand, .. } => {
737                let operand_id = self.build_dependency_graph(operand, graph, node_id_map, next_id);
738
739                graph.dependencies.insert(current_id, vec![operand_id]);
740                current_id
741            }
742
743            CircuitNode::Compare { left, right, .. } => {
744                let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
745                let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
746
747                graph
748                    .dependencies
749                    .insert(current_id, vec![left_id, right_id]);
750                current_id
751            }
752        }
753    }
754
755    /// Identify groups of nodes that can execute in parallel
756    fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
757        let mut levels: HashMap<NodeId, usize> = HashMap::new();
758        let mut queue = VecDeque::new();
759
760        // Find all nodes with no dependencies (level 0)
761        for (node_id, deps) in &graph.dependencies {
762            if deps.is_empty() {
763                levels.insert(*node_id, 0);
764                queue.push_back(*node_id);
765            }
766        }
767
768        // Level-wise traversal
769        while let Some(node_id) = queue.pop_front() {
770            let current_level = levels[&node_id];
771
772            // Find nodes that depend on this node
773            for (dependent_id, deps) in &graph.dependencies {
774                if deps.contains(&node_id) {
775                    // Calculate level for dependent node
776                    let max_dep_level = deps
777                        .iter()
778                        .filter_map(|dep_id| levels.get(dep_id))
779                        .max()
780                        .copied()
781                        .unwrap_or(0);
782
783                    let dependent_level = max_dep_level + 1;
784
785                    if !levels.contains_key(dependent_id) {
786                        levels.insert(*dependent_id, dependent_level);
787                        queue.push_back(*dependent_id);
788                    }
789                }
790            }
791        }
792
793        // Group nodes by level
794        let max_level = levels.values().max().copied().unwrap_or(0);
795        let mut parallel_groups = vec![Vec::new(); max_level + 1];
796
797        for (node_id, level) in levels {
798            parallel_groups[level].push(node_id);
799        }
800
801        // Sort each group for deterministic output
802        for group in &mut parallel_groups {
803            group.sort();
804        }
805
806        parallel_groups
807    }
808
809    /// Find the critical path (longest dependency chain)
810    fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
811        // Simple implementation: find the node with the longest chain to root
812        let mut max_path = Vec::new();
813
814        for node_id in graph.dependencies.keys() {
815            let path = self.find_path_to_root(*node_id, graph);
816            if path.len() > max_path.len() {
817                max_path = path;
818            }
819        }
820
821        max_path
822    }
823
824    /// Find path from a node to a root (node with no dependencies)
825    #[allow(clippy::only_used_in_recursion)]
826    fn find_path_to_root(&self, node_id: NodeId, graph: &DependencyGraph) -> Vec<NodeId> {
827        let deps = graph
828            .dependencies
829            .get(&node_id)
830            .map(|v| v.as_slice())
831            .unwrap_or(&[]);
832
833        if deps.is_empty() {
834            return vec![node_id];
835        }
836
837        // Find the longest path through dependencies
838        let mut longest_path = Vec::new();
839        for dep_id in deps {
840            let dep_path = self.find_path_to_root(*dep_id, graph);
841            if dep_path.len() > longest_path.len() {
842                longest_path = dep_path;
843            }
844        }
845
846        longest_path.push(node_id);
847        longest_path
848    }
849}
850
851impl Default for CircuitOptimizer {
852    fn default() -> Self {
853        Self::new()
854    }
855}
856
857#[cfg(test)]
858mod tests {
859    use super::*;
860    use crate::compute::circuit::CircuitBuilder;
861
862    #[test]
863    fn test_constant_folding() -> Result<()> {
864        let mut optimizer = CircuitOptimizer::new();
865        let builder = CircuitBuilder::new();
866
867        // Create circuit: 5 + 3
868        let a = builder.constant(CircuitValue::U8(5));
869        let b = builder.constant(CircuitValue::U8(3));
870        let sum = builder.add(a, b);
871
872        let circuit = Circuit::new(sum, HashMap::new())?;
873        let optimized = optimizer.optimize(circuit)?;
874
875        // Should fold to constant 8
876        assert!(matches!(
877            optimized.root,
878            CircuitNode::Constant(CircuitValue::U8(8))
879        ));
880        assert_eq!(optimizer.stats().constants_folded, 1);
881
882        Ok(())
883    }
884
885    #[test]
886    fn test_algebraic_identities() -> Result<()> {
887        let mut optimizer = CircuitOptimizer::new();
888        let mut builder = CircuitBuilder::new();
889        builder.declare_variable("x", EncryptedType::U8);
890
891        // Test x + 0 = x
892        let x = builder.load("x");
893        let zero = builder.constant(CircuitValue::U8(0));
894        let add_zero = builder.add(x.clone(), zero);
895
896        let circuit = Circuit::new(add_zero, builder.variable_types_clone())?;
897        let optimized = optimizer.optimize(circuit)?;
898
899        // Should simplify to just x
900        assert!(matches!(optimized.root, CircuitNode::Load(_)));
901
902        Ok(())
903    }
904
905    #[test]
906    fn test_double_negation_elimination() -> Result<()> {
907        let mut optimizer = CircuitOptimizer::new();
908        let mut builder = CircuitBuilder::new();
909        builder.declare_variable("x", EncryptedType::Bool);
910
911        // Test NOT(NOT(x)) = x
912        let x = builder.load("x");
913        let not_x = builder.not(x);
914        let not_not_x = builder.not(not_x);
915
916        let circuit = Circuit::new(not_not_x, builder.variable_types_clone())?;
917        let optimized = optimizer.optimize(circuit)?;
918
919        // Should simplify to just x
920        assert!(matches!(optimized.root, CircuitNode::Load(_)));
921        assert!(optimizer.stats().gates_fused >= 2);
922
923        Ok(())
924    }
925
926    #[test]
927    fn test_bootstrap_counting() -> Result<()> {
928        let optimizer = CircuitOptimizer::new();
929        let mut builder = CircuitBuilder::new();
930        builder
931            .declare_variable("a", EncryptedType::U8)
932            .declare_variable("b", EncryptedType::U8);
933
934        // Circuit with multiplication (requires bootstrap)
935        let a = builder.load("a");
936        let b = builder.load("b");
937        let mul = builder.mul(a, b);
938
939        let circuit = Circuit::new(mul, builder.variable_types_clone())?;
940        let bootstrap_count = optimizer.count_bootstraps(&circuit.root);
941
942        assert_eq!(bootstrap_count, 1); // One multiplication
943
944        Ok(())
945    }
946
947    #[test]
948    fn test_parallelization_analysis() -> Result<()> {
949        let mut optimizer = CircuitOptimizer::new();
950        let mut builder = CircuitBuilder::new();
951        builder
952            .declare_variable("a", EncryptedType::U8)
953            .declare_variable("b", EncryptedType::U8)
954            .declare_variable("c", EncryptedType::U8);
955
956        // Circuit: (a + b) + c - has some parallelism potential
957        let a = builder.load("a");
958        let b = builder.load("b");
959        let c = builder.load("c");
960        let sum1 = builder.add(a, b);
961        let sum2 = builder.add(sum1, c);
962
963        let circuit = Circuit::new(sum2, builder.variable_types_clone())?;
964        let optimized = optimizer.optimize(circuit)?;
965
966        let graph = optimizer.dependency_graph();
967        assert!(graph.node_count > 0);
968        assert!(!graph.parallel_groups.is_empty());
969
970        Ok(())
971    }
972
973    #[test]
974    fn test_optimization_stats() -> Result<()> {
975        let mut optimizer = CircuitOptimizer::new();
976        let builder = CircuitBuilder::new();
977
978        // Complex circuit with optimization opportunities
979        let a = builder.constant(CircuitValue::U8(5));
980        let b = builder.constant(CircuitValue::U8(3));
981        let zero = builder.constant(CircuitValue::U8(0));
982
983        let sum = builder.add(a, b); // Should fold to 8
984        let add_zero = builder.add(sum, zero); // Should eliminate +0
985
986        let circuit = Circuit::new(add_zero, HashMap::new())?;
987        let original_gates = circuit.gate_count;
988
989        let optimized = optimizer.optimize(circuit)?;
990        let optimized_gates = optimized.gate_count;
991
992        assert!(optimized_gates < original_gates);
993        assert!(optimizer.stats().gate_reduction_percent() > 0.0);
994
995        Ok(())
996    }
997
998    #[test]
999    fn test_complex_circuit_optimization() -> Result<()> {
1000        let mut optimizer = CircuitOptimizer::new();
1001        let mut builder = CircuitBuilder::new();
1002        builder
1003            .declare_variable("a", EncryptedType::U8)
1004            .declare_variable("b", EncryptedType::U8);
1005
1006        // Circuit: (a * 1) + (b * 0) + 5
1007        // Should optimize to: a + 5
1008        let a = builder.load("a");
1009        let b = builder.load("b");
1010        let one = builder.constant(CircuitValue::U8(1));
1011        let zero = builder.constant(CircuitValue::U8(0));
1012        let five = builder.constant(CircuitValue::U8(5));
1013
1014        let a_times_1 = builder.mul(a, one);
1015        let b_times_0 = builder.mul(b, zero);
1016        let sum1 = builder.add(a_times_1, b_times_0);
1017        let result = builder.add(sum1, five);
1018
1019        let circuit = Circuit::new(result, builder.variable_types_clone())?;
1020        let original_gates = circuit.gate_count;
1021
1022        let optimized = optimizer.optimize(circuit)?;
1023
1024        assert!(optimized.gate_count < original_gates);
1025        assert!(optimizer.stats().gate_reduction_percent() >= 30.0);
1026
1027        Ok(())
1028    }
1029}