Skip to main content

amaters_core/compute/
optimizer.rs

//! Advanced circuit optimization for FHE computations
//!
//! This module provides optimization passes that reduce the computational
//! cost of FHE circuits. The optimizer focuses on:
//!
//! 1. **Constant Folding** - Evaluating plaintext-constant subexpressions and applying
//!    algebraic identities (e.g. `x + 0 = x`) ahead of time
//! 2. **Bootstrap Minimization** - Reducing expensive bootstrap operations
//! 3. **Gate Fusion** - Combining adjacent operations to reduce overhead
//! 4. **Dead Code Elimination** - Removing unused operations
//! 5. **Parallelization Analysis** - Identifying independent operations for parallel execution
//!
//! These optimizations can reduce circuit execution time by 30-50% in typical cases.
12
13use crate::compute::circuit::{
14    BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
15    UnaryOperator,
16};
17use crate::error::{AmateRSError, ErrorContext, Result};
18use std::collections::{HashMap, HashSet, VecDeque};
19
/// Counters gathered while a circuit optimizer runs its passes.
///
/// The `original_*` fields describe the circuit handed to the optimizer, the
/// `optimized_*` fields describe the circuit it produced, and the remaining
/// fields count how often individual rewrites fired.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct OptimizationStats {
    /// Gate count of the input circuit.
    pub original_gate_count: usize,

    /// Gate count of the output circuit.
    pub optimized_gate_count: usize,

    /// Estimated bootstrap operations in the input circuit.
    pub original_bootstrap_count: usize,

    /// Estimated bootstrap operations in the output circuit.
    pub optimized_bootstrap_count: usize,

    /// Nodes removed as dead code.
    pub dead_code_removed: usize,

    /// Nodes eliminated by the dead-code-elimination pass.
    pub nodes_eliminated: usize,

    /// Algebraic simplifications applied.
    pub algebraic_simplifications: usize,

    /// Constant sub-expressions folded into a single constant.
    pub constants_folded: usize,

    /// Gates merged away by fusion rewrites.
    pub gates_fused: usize,

    /// Depth of the input circuit.
    pub original_depth: usize,

    /// Depth of the output circuit.
    pub optimized_depth: usize,
}
56
57impl OptimizationStats {
58    /// Calculate the reduction percentage in gate count
59    pub fn gate_reduction_percent(&self) -> f64 {
60        if self.original_gate_count == 0 {
61            return 0.0;
62        }
63        let reduction = self
64            .original_gate_count
65            .saturating_sub(self.optimized_gate_count);
66        (reduction as f64 / self.original_gate_count as f64) * 100.0
67    }
68
69    /// Calculate the reduction percentage in bootstrap operations
70    pub fn bootstrap_reduction_percent(&self) -> f64 {
71        if self.original_bootstrap_count == 0 {
72            return 0.0;
73        }
74        let reduction = self
75            .original_bootstrap_count
76            .saturating_sub(self.optimized_bootstrap_count);
77        (reduction as f64 / self.original_bootstrap_count as f64) * 100.0
78    }
79
80    /// Aggregate total statistics across all passes
81    pub fn total_stats(&self) -> (usize, usize, usize) {
82        (
83            self.nodes_eliminated + self.dead_code_removed,
84            self.algebraic_simplifications + self.gates_fused,
85            self.constants_folded,
86        )
87    }
88}
89
90/// Dependency information for parallelization
91#[derive(Debug, Clone, PartialEq, Eq)]
92pub struct DependencyGraph {
93    /// Node ID to its dependencies
94    pub dependencies: HashMap<NodeId, Vec<NodeId>>,
95
96    /// Nodes that can be executed in parallel (sets of independent nodes)
97    pub parallel_groups: Vec<Vec<NodeId>>,
98
99    /// Critical path through the circuit (longest dependency chain)
100    pub critical_path: Vec<NodeId>,
101
102    /// Total number of nodes in the graph
103    pub node_count: usize,
104}
105
106impl DependencyGraph {
107    /// Create an empty dependency graph
108    pub fn new() -> Self {
109        Self {
110            dependencies: HashMap::new(),
111            parallel_groups: Vec::new(),
112            critical_path: Vec::new(),
113            node_count: 0,
114        }
115    }
116
117    /// Calculate the maximum parallelism (largest parallel group)
118    pub fn max_parallelism(&self) -> usize {
119        self.parallel_groups
120            .iter()
121            .map(|g| g.len())
122            .max()
123            .unwrap_or(0)
124    }
125
126    /// Calculate the average parallelism
127    pub fn avg_parallelism(&self) -> f64 {
128        if self.parallel_groups.is_empty() {
129            return 0.0;
130        }
131        let total: usize = self.parallel_groups.iter().map(|g| g.len()).sum();
132        total as f64 / self.parallel_groups.len() as f64
133    }
134
135    /// Returns nodes in topological order (dependencies before dependents)
136    pub fn topological_order(&self) -> Vec<NodeId> {
137        self.compute_topological_order()
138    }
139
140    fn compute_topological_order(&self) -> Vec<NodeId> {
141        // Kahn's algorithm: in_degree[node] = number of prerequisites (dependencies)
142        let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
143
144        // Initialize all nodes to in-degree = number of their dependencies
145        for (node_id, deps) in &self.dependencies {
146            *in_degree.entry(*node_id).or_insert(0) = deps.len();
147            // Ensure deps are also in the map
148            for dep_id in deps {
149                in_degree.entry(*dep_id).or_insert(0);
150            }
151        }
152
153        // Start with nodes that have no dependencies (leaves)
154        let mut queue: std::collections::BTreeSet<NodeId> = in_degree
155            .iter()
156            .filter(|&(_, deg)| *deg == 0)
157            .map(|(&id, _)| id)
158            .collect();
159
160        let mut result = Vec::new();
161
162        // Build reverse edges: for each dep, who depends on it?
163        let mut dependents: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
164        for (node_id, deps) in &self.dependencies {
165            for dep_id in deps {
166                dependents.entry(*dep_id).or_default().push(*node_id);
167            }
168        }
169
170        while let Some(&node_id) = queue.iter().next() {
171            queue.remove(&node_id);
172            result.push(node_id);
173
174            if let Some(dep_nodes) = dependents.get(&node_id) {
175                for &dependent_id in dep_nodes {
176                    if let Some(deg) = in_degree.get_mut(&dependent_id) {
177                        if *deg > 0 {
178                            *deg -= 1;
179                            if *deg == 0 {
180                                queue.insert(dependent_id);
181                            }
182                        }
183                    }
184                }
185            }
186        }
187
188        result
189    }
190}
191
192impl Default for DependencyGraph {
193    fn default() -> Self {
194        Self::new()
195    }
196}
197
/// Identifier of a node in a dependency graph.
///
/// A newtype over `usize`; comparison, ordering and hashing follow the wrapped index.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeId(pub usize);
201
202/// Advanced circuit optimizer with multiple optimization passes
203#[derive(Debug, Clone)]
204pub struct CircuitOptimizer {
205    /// Enable constant folding optimization
206    pub enable_constant_folding: bool,
207
208    /// Enable dead code elimination
209    pub enable_dead_code_elimination: bool,
210
211    /// Enable bootstrap minimization
212    pub enable_bootstrap_minimization: bool,
213
214    /// Enable gate fusion
215    pub enable_gate_fusion: bool,
216
217    /// Enable parallelization analysis
218    pub enable_parallelization_analysis: bool,
219
220    /// Statistics from the last optimization
221    stats: OptimizationStats,
222
223    /// Dependency graph from the last optimization
224    dependency_graph: DependencyGraph,
225}
226
227impl CircuitOptimizer {
228    /// Create a new optimizer with all optimizations enabled
229    pub fn new() -> Self {
230        Self {
231            enable_constant_folding: true,
232            enable_dead_code_elimination: true,
233            enable_bootstrap_minimization: true,
234            enable_gate_fusion: true,
235            enable_parallelization_analysis: true,
236            stats: OptimizationStats::default(),
237            dependency_graph: DependencyGraph::new(),
238        }
239    }
240
241    /// Create an optimizer with no optimizations enabled
242    pub fn disabled() -> Self {
243        Self {
244            enable_constant_folding: false,
245            enable_dead_code_elimination: false,
246            enable_bootstrap_minimization: false,
247            enable_gate_fusion: false,
248            enable_parallelization_analysis: false,
249            stats: OptimizationStats::default(),
250            dependency_graph: DependencyGraph::new(),
251        }
252    }
253
254    /// Get the statistics from the last optimization
255    pub fn stats(&self) -> &OptimizationStats {
256        &self.stats
257    }
258
259    /// Get the dependency graph from the last optimization
260    pub fn dependency_graph(&self) -> &DependencyGraph {
261        &self.dependency_graph
262    }
263
264    /// Get aggregated totals: (nodes_eliminated, algebraic_simplifications, constant_folds)
265    pub fn total_stats(&self) -> (usize, usize, usize) {
266        self.stats.total_stats()
267    }
268
269    /// Optimize a circuit by applying all enabled optimization passes
270    pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
271        // Record original statistics
272        self.stats.original_gate_count = circuit.gate_count;
273        self.stats.original_depth = circuit.depth;
274        self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
275
276        let mut optimized_root = circuit.root.clone();
277
278        // Apply optimization passes in order
279        if self.enable_constant_folding {
280            optimized_root = self.constant_folding_pass(optimized_root);
281        }
282
283        if self.enable_gate_fusion {
284            optimized_root = self.gate_fusion_pass(optimized_root);
285        }
286
287        if self.enable_bootstrap_minimization {
288            optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
289        }
290
291        if self.enable_dead_code_elimination {
292            optimized_root = self.dead_code_elimination_pass(optimized_root);
293        }
294
295        // Build optimized circuit
296        let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
297
298        // Record optimized statistics
299        self.stats.optimized_gate_count = optimized_circuit.gate_count;
300        self.stats.optimized_depth = optimized_circuit.depth;
301        self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
302
303        // Analyze parallelization if enabled
304        if self.enable_parallelization_analysis {
305            self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
306        }
307
308        Ok(optimized_circuit)
309    }
310
311    /// Count the number of bootstrap operations in a circuit
312    ///
313    /// In TFHE, bootstrapping is required after certain operations to refresh noise.
314    /// For this implementation, we estimate bootstraps based on operation types:
315    /// - Multiplication requires bootstrap
316    /// - Comparison operations require bootstrap
317    /// - Deep chains of additions may require bootstrap
318    #[allow(clippy::only_used_in_recursion)]
319    fn count_bootstraps(&self, node: &CircuitNode) -> usize {
320        match node {
321            CircuitNode::Load(_)
322            | CircuitNode::Constant(_)
323            | CircuitNode::EncryptedConstant { .. } => 0,
324
325            CircuitNode::BinaryOp { op, left, right } => {
326                let left_bootstraps = self.count_bootstraps(left);
327                let right_bootstraps = self.count_bootstraps(right);
328
329                // Multiplication requires bootstrap
330                let op_bootstrap = match op {
331                    BinaryOperator::Mul => 1,
332                    _ => 0,
333                };
334
335                left_bootstraps + right_bootstraps + op_bootstrap
336            }
337
338            CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
339
340            CircuitNode::Compare { left, right, .. } => {
341                let left_bootstraps = self.count_bootstraps(left);
342                let right_bootstraps = self.count_bootstraps(right);
343
344                // Comparisons typically require bootstrap
345                left_bootstraps + right_bootstraps + 1
346            }
347            CircuitNode::NaryOp { op, operands } => {
348                let operand_bootstraps: usize =
349                    operands.iter().map(|o| self.count_bootstraps(o)).sum();
350                let op_bootstraps = match op {
351                    BinaryOperator::Mul => operands.len().saturating_sub(1),
352                    _ => 0,
353                };
354                operand_bootstraps + op_bootstraps
355            }
356        }
357    }
358
359    /// Constant folding optimization pass
360    ///
361    /// Evaluates constant expressions at compile time to reduce runtime computation
362    fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
363        match node {
364            CircuitNode::BinaryOp { op, left, right } => {
365                let left = self.constant_folding_pass(*left);
366                let right = self.constant_folding_pass(*right);
367
368                // Try to fold constants
369                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
370                    if let Some(result) = self.fold_binary_constants(op, l, r) {
371                        self.stats.constants_folded += 1;
372                        return CircuitNode::Constant(result);
373                    }
374                }
375
376                // Apply algebraic identities
377                if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
378                    return simplified;
379                }
380
381                CircuitNode::BinaryOp {
382                    op,
383                    left: Box::new(left),
384                    right: Box::new(right),
385                }
386            }
387
388            CircuitNode::UnaryOp { op, operand } => {
389                let operand = self.constant_folding_pass(*operand);
390
391                if let CircuitNode::Constant(val) = &operand {
392                    if let Some(result) = self.fold_unary_constant(op, val) {
393                        self.stats.constants_folded += 1;
394                        return CircuitNode::Constant(result);
395                    }
396                }
397
398                CircuitNode::UnaryOp {
399                    op,
400                    operand: Box::new(operand),
401                }
402            }
403
404            CircuitNode::Compare { op, left, right } => {
405                let left = self.constant_folding_pass(*left);
406                let right = self.constant_folding_pass(*right);
407
408                CircuitNode::Compare {
409                    op,
410                    left: Box::new(left),
411                    right: Box::new(right),
412                }
413            }
414
415            CircuitNode::NaryOp { op, operands } => {
416                let new_operands: Vec<CircuitNode> = operands
417                    .into_iter()
418                    .map(|o| self.constant_folding_pass(o))
419                    .collect();
420                CircuitNode::NaryOp {
421                    op,
422                    operands: new_operands,
423                }
424            }
425
426            other => other,
427        }
428    }
429
430    /// Fold binary operation on constants
431    fn fold_binary_constants(
432        &self,
433        op: BinaryOperator,
434        left: &CircuitValue,
435        right: &CircuitValue,
436    ) -> Option<CircuitValue> {
437        match (left, right) {
438            (CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
439                BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
440                BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
441                BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
442                _ => None,
443            },
444            (CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
445                BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
446                BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
447                BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
448                _ => None,
449            },
450            (CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
451                BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
452                BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
453                BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
454                _ => None,
455            },
456            (CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
457                BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
458                BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
459                BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
460                _ => None,
461            },
462            (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
463                BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
464                BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
465                BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
466                _ => None,
467            },
468            _ => None,
469        }
470    }
471
472    /// Fold unary operation on constant
473    fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
474        match (op, value) {
475            (UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
476            _ => None,
477        }
478    }
479
480    /// Apply algebraic identities to simplify expressions
481    /// Examples: x + 0 = x, x * 1 = x, x * 0 = 0, x AND true = x, etc.
482    fn apply_algebraic_identities(
483        &mut self,
484        op: BinaryOperator,
485        left: &CircuitNode,
486        right: &CircuitNode,
487    ) -> Option<CircuitNode> {
488        match op {
489            BinaryOperator::Add => {
490                // x + 0 = x
491                if Self::is_zero(right) {
492                    self.stats.gates_fused += 1;
493                    return Some(left.clone());
494                }
495                // 0 + x = x
496                if Self::is_zero(left) {
497                    self.stats.gates_fused += 1;
498                    return Some(right.clone());
499                }
500            }
501
502            BinaryOperator::Sub => {
503                // x - 0 = x
504                if Self::is_zero(right) {
505                    self.stats.gates_fused += 1;
506                    return Some(left.clone());
507                }
508            }
509
510            BinaryOperator::Mul => {
511                // x * 0 = 0
512                if Self::is_zero(right) {
513                    self.stats.gates_fused += 1;
514                    return Some(right.clone());
515                }
516                if Self::is_zero(left) {
517                    self.stats.gates_fused += 1;
518                    return Some(left.clone());
519                }
520
521                // x * 1 = x
522                if Self::is_one(right) {
523                    self.stats.gates_fused += 1;
524                    return Some(left.clone());
525                }
526                // 1 * x = x
527                if Self::is_one(left) {
528                    self.stats.gates_fused += 1;
529                    return Some(right.clone());
530                }
531            }
532
533            BinaryOperator::And => {
534                // x AND true = x
535                if Self::is_true(right) {
536                    self.stats.gates_fused += 1;
537                    return Some(left.clone());
538                }
539                if Self::is_true(left) {
540                    self.stats.gates_fused += 1;
541                    return Some(right.clone());
542                }
543
544                // x AND false = false
545                if Self::is_false(right) {
546                    self.stats.gates_fused += 1;
547                    return Some(right.clone());
548                }
549                if Self::is_false(left) {
550                    self.stats.gates_fused += 1;
551                    return Some(left.clone());
552                }
553            }
554
555            BinaryOperator::Or => {
556                // x OR false = x
557                if Self::is_false(right) {
558                    self.stats.gates_fused += 1;
559                    return Some(left.clone());
560                }
561                if Self::is_false(left) {
562                    self.stats.gates_fused += 1;
563                    return Some(right.clone());
564                }
565
566                // x OR true = true
567                if Self::is_true(right) {
568                    self.stats.gates_fused += 1;
569                    return Some(right.clone());
570                }
571                if Self::is_true(left) {
572                    self.stats.gates_fused += 1;
573                    return Some(left.clone());
574                }
575            }
576
577            BinaryOperator::Xor => {
578                // x XOR false = x
579                if Self::is_false(right) {
580                    self.stats.gates_fused += 1;
581                    return Some(left.clone());
582                }
583                if Self::is_false(left) {
584                    self.stats.gates_fused += 1;
585                    return Some(right.clone());
586                }
587            }
588        }
589
590        None
591    }
592
593    /// Check if a node is constant zero
594    fn is_zero(node: &CircuitNode) -> bool {
595        matches!(
596            node,
597            CircuitNode::Constant(CircuitValue::U8(0))
598                | CircuitNode::Constant(CircuitValue::U16(0))
599                | CircuitNode::Constant(CircuitValue::U32(0))
600                | CircuitNode::Constant(CircuitValue::U64(0))
601        )
602    }
603
604    /// Check if a node is constant one
605    fn is_one(node: &CircuitNode) -> bool {
606        matches!(
607            node,
608            CircuitNode::Constant(CircuitValue::U8(1))
609                | CircuitNode::Constant(CircuitValue::U16(1))
610                | CircuitNode::Constant(CircuitValue::U32(1))
611                | CircuitNode::Constant(CircuitValue::U64(1))
612        )
613    }
614
615    /// Check if a node is constant true
616    fn is_true(node: &CircuitNode) -> bool {
617        matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
618    }
619
620    /// Check if a node is constant false
621    fn is_false(node: &CircuitNode) -> bool {
622        matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
623    }
624
625    /// Gate fusion optimization pass
626    ///
627    /// Combines adjacent operations to reduce overhead:
628    /// - Associative+commutative same-op chains are flattened into NaryOp nodes
629    /// - Multiple consecutive NOT operations are eliminated
630    fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
631        match node {
632            CircuitNode::BinaryOp { op, left, right } => {
633                let left = self.gate_fusion_pass(*left);
634                let right = self.gate_fusion_pass(*right);
635
636                match op {
637                    BinaryOperator::Add
638                    | BinaryOperator::Mul
639                    | BinaryOperator::And
640                    | BinaryOperator::Or
641                    | BinaryOperator::Xor => {
642                        // Collect flat operand list by flattening same-op children.
643                        // After these two calls, left/right are consumed into operands.
644                        let mut operands: Vec<CircuitNode> = Vec::new();
645                        Self::collect_nary_operands(op, left, &mut operands);
646                        Self::collect_nary_operands(op, right, &mut operands);
647                        // Invariant: operands.len() >= 2 (each of left/right contributes >= 1)
648
649                        if operands.len() >= 3 {
650                            self.stats.gates_fused += operands.len().saturating_sub(2);
651                            CircuitNode::NaryOp { op, operands }
652                        } else {
653                            // Exactly 2 operands — BinaryOp is the canonical form
654                            Self::build_balanced_reduction(op, operands)
655                        }
656                    }
657                    _ => CircuitNode::BinaryOp {
658                        op,
659                        left: Box::new(left),
660                        right: Box::new(right),
661                    },
662                }
663            }
664
665            CircuitNode::NaryOp { op, operands } => {
666                // Recurse into operands and potentially absorb more
667                let new_operands: Vec<CircuitNode> = operands
668                    .into_iter()
669                    .map(|o| self.gate_fusion_pass(o))
670                    .collect();
671                // Re-flatten after recursion
672                let mut flat_operands = Vec::new();
673                for operand in new_operands {
674                    Self::collect_nary_operands(op, operand, &mut flat_operands);
675                }
676                if flat_operands.len() >= 2 {
677                    CircuitNode::NaryOp {
678                        op,
679                        operands: flat_operands,
680                    }
681                } else if flat_operands.len() == 1 {
682                    flat_operands.remove(0)
683                } else {
684                    CircuitNode::NaryOp {
685                        op,
686                        operands: flat_operands,
687                    }
688                }
689            }
690
691            CircuitNode::UnaryOp {
692                op: UnaryOperator::Not,
693                operand,
694            } => {
695                let operand = self.gate_fusion_pass(*operand);
696
697                // NOT(NOT(x)) = x
698                if let CircuitNode::UnaryOp {
699                    op: UnaryOperator::Not,
700                    operand: inner,
701                } = operand
702                {
703                    self.stats.gates_fused += 2;
704                    return *inner;
705                }
706
707                CircuitNode::UnaryOp {
708                    op: UnaryOperator::Not,
709                    operand: Box::new(operand),
710                }
711            }
712
713            CircuitNode::UnaryOp { op, operand } => {
714                let operand = self.gate_fusion_pass(*operand);
715                CircuitNode::UnaryOp {
716                    op,
717                    operand: Box::new(operand),
718                }
719            }
720
721            CircuitNode::Compare { op, left, right } => {
722                let left = self.gate_fusion_pass(*left);
723                let right = self.gate_fusion_pass(*right);
724                CircuitNode::Compare {
725                    op,
726                    left: Box::new(left),
727                    right: Box::new(right),
728                }
729            }
730
731            other => other,
732        }
733    }
734
735    /// Helper to collect operands for N-ary fusion by flattening same-op chains
736    fn collect_nary_operands(op: BinaryOperator, node: CircuitNode, out: &mut Vec<CircuitNode>) {
737        match node {
738            CircuitNode::BinaryOp {
739                op: child_op,
740                left,
741                right,
742            } if child_op == op => {
743                Self::collect_nary_operands(op, *left, out);
744                Self::collect_nary_operands(op, *right, out);
745            }
746            CircuitNode::NaryOp {
747                op: child_op,
748                operands,
749            } if child_op == op => {
750                for operand in operands {
751                    Self::collect_nary_operands(op, operand, out);
752                }
753            }
754            other => out.push(other),
755        }
756    }
757
758    /// Bootstrap minimization pass
759    ///
760    /// Analyzes the circuit to minimize expensive bootstrap operations by:
761    /// - Reordering operations to delay bootstraps
762    /// - Combining operations that share bootstrap requirements
763    /// - Eliminating redundant bootstraps
764    fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
765        Ok(self.reorder_for_bootstrap_efficiency(node))
766    }
767
768    /// Reorder operations to minimize bootstraps
769    ///
770    /// For commutative operators, places the higher-bootstrap-cost subtree
771    /// first (left), which improves scheduling locality. For NaryOp Mul
772    /// (bootstrap-heavy), builds a balanced binary reduction tree.
773    fn reorder_for_bootstrap_efficiency(&mut self, node: CircuitNode) -> CircuitNode {
774        match node {
775            CircuitNode::BinaryOp { op, left, right } => {
776                let left = self.reorder_for_bootstrap_efficiency(*left);
777                let right = self.reorder_for_bootstrap_efficiency(*right);
778
779                let is_commutative = matches!(
780                    op,
781                    BinaryOperator::Add
782                        | BinaryOperator::Mul
783                        | BinaryOperator::And
784                        | BinaryOperator::Or
785                        | BinaryOperator::Xor
786                );
787
788                if is_commutative {
789                    let left_cost = self.count_bootstraps(&left);
790                    let right_cost = self.count_bootstraps(&right);
791                    if right_cost > left_cost {
792                        return CircuitNode::BinaryOp {
793                            op,
794                            left: Box::new(right),
795                            right: Box::new(left),
796                        };
797                    }
798                }
799
800                CircuitNode::BinaryOp {
801                    op,
802                    left: Box::new(left),
803                    right: Box::new(right),
804                }
805            }
806
807            CircuitNode::NaryOp { op, operands } => {
808                // Recurse into operands
809                let processed_operands: Vec<CircuitNode> = operands
810                    .into_iter()
811                    .map(|o| self.reorder_for_bootstrap_efficiency(o))
812                    .collect();
813
814                // For Mul (bootstrap-heavy), build balanced binary reduction tree
815                if matches!(op, BinaryOperator::Mul) && processed_operands.len() >= 2 {
816                    return Self::build_balanced_reduction(op, processed_operands);
817                }
818
819                // For Add and logical ops (no bootstrap cost), sort by cost descending
820                let mut with_costs: Vec<(usize, CircuitNode)> = processed_operands
821                    .into_iter()
822                    .map(|o| {
823                        let cost = self.count_bootstraps(&o);
824                        (cost, o)
825                    })
826                    .collect();
827                with_costs.sort_by_key(|b| std::cmp::Reverse(b.0));
828                let sorted_operands: Vec<CircuitNode> =
829                    with_costs.into_iter().map(|(_, o)| o).collect();
830
831                CircuitNode::NaryOp {
832                    op,
833                    operands: sorted_operands,
834                }
835            }
836
837            CircuitNode::UnaryOp { op, operand } => {
838                let operand = self.reorder_for_bootstrap_efficiency(*operand);
839                CircuitNode::UnaryOp {
840                    op,
841                    operand: Box::new(operand),
842                }
843            }
844
845            CircuitNode::Compare { op, left, right } => {
846                let left = self.reorder_for_bootstrap_efficiency(*left);
847                let right = self.reorder_for_bootstrap_efficiency(*right);
848                CircuitNode::Compare {
849                    op,
850                    left: Box::new(left),
851                    right: Box::new(right),
852                }
853            }
854
855            other => other,
856        }
857    }
858
859    /// Build a balanced binary reduction tree for N operands
860    fn build_balanced_reduction(op: BinaryOperator, operands: Vec<CircuitNode>) -> CircuitNode {
861        if operands.is_empty() {
862            // Degenerate case: return a zero-value placeholder
863            return CircuitNode::Constant(crate::compute::circuit::CircuitValue::U8(0));
864        }
865        if operands.len() == 1 {
866            // unwrap is safe here: len == 1, so next() always returns Some
867            return operands.into_iter().next().unwrap_or(CircuitNode::Constant(
868                crate::compute::circuit::CircuitValue::U8(0),
869            ));
870        }
871        if operands.len() == 2 {
872            let mut it = operands.into_iter();
873            // Both next() calls succeed because len == 2
874            let left = it.next().unwrap_or(CircuitNode::Constant(
875                crate::compute::circuit::CircuitValue::U8(0),
876            ));
877            let right = it.next().unwrap_or(CircuitNode::Constant(
878                crate::compute::circuit::CircuitValue::U8(0),
879            ));
880            return CircuitNode::BinaryOp {
881                op,
882                left: Box::new(left),
883                right: Box::new(right),
884            };
885        }
886
887        let mid = operands.len() / 2;
888        let (left_operands, right_operands) = operands.into_iter().enumerate().fold(
889            (Vec::new(), Vec::new()),
890            |(mut l, mut r), (i, node)| {
891                if i < mid {
892                    l.push(node);
893                } else {
894                    r.push(node);
895                }
896                (l, r)
897            },
898        );
899
900        let left_node = Self::build_balanced_reduction(op, left_operands);
901        let right_node = Self::build_balanced_reduction(op, right_operands);
902
903        CircuitNode::BinaryOp {
904            op,
905            left: Box::new(left_node),
906            right: Box::new(right_node),
907        }
908    }
909
910    /// Dead code elimination pass
911    ///
912    /// Performs real DCE by:
913    /// 1. Applying algebraic simplifications that eliminate redundant operations
914    ///    (e.g., `x - x` -> `0`, `x + 0` -> `x`, double negation)
915    /// 2. Constant folding any newly-exposed constant sub-expressions
916    /// 3. Iterating until a fixed point is reached (no further changes)
917    ///
918    /// For single-output tree-structured circuits every reachable node is live,
919    /// so classical "unused result" DCE is a no-op on the tree. Instead we focus
920    /// on strength-reducing and identity-collapsing operations that produce
921    /// effectively dead work (operations whose result equals an operand or a
922    /// constant).
923    fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
924        let mut current = node;
925        // Iterate to a fixed point so nested simplifications cascade
926        loop {
927            let simplified = self.dce_simplify(current.clone());
928            if simplified == current {
929                break;
930            }
931            current = simplified;
932        }
933        current
934    }
935
936    /// Single pass of DCE simplification applied bottom-up
937    fn dce_simplify(&mut self, node: CircuitNode) -> CircuitNode {
938        match node {
939            CircuitNode::BinaryOp { op, left, right } => {
940                // Recurse first (bottom-up)
941                let left = self.dce_simplify(*left);
942                let right = self.dce_simplify(*right);
943
944                // Constant folding on newly-exposed constants
945                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
946                    if let Some(result) = self.fold_binary_constants(op, l, r) {
947                        self.stats.nodes_eliminated += 1;
948                        self.stats.constants_folded += 1;
949                        return CircuitNode::Constant(result);
950                    }
951                }
952
953                // x - x = 0 (same subtree detection)
954                if op == BinaryOperator::Sub && left == right {
955                    self.stats.nodes_eliminated += 1;
956                    self.stats.algebraic_simplifications += 1;
957                    // Produce a zero of the appropriate type based on left subtree
958                    return self.zero_like(&left);
959                }
960
961                // x XOR x = false
962                if op == BinaryOperator::Xor && left == right {
963                    self.stats.nodes_eliminated += 1;
964                    self.stats.algebraic_simplifications += 1;
965                    return CircuitNode::Constant(CircuitValue::Bool(false));
966                }
967
968                // Algebraic identities: x+0, 0+x, x-0, x*1, 1*x, x*0, 0*x
969                match op {
970                    BinaryOperator::Add => {
971                        if Self::is_zero(&right) {
972                            self.stats.nodes_eliminated += 1;
973                            self.stats.algebraic_simplifications += 1;
974                            return left;
975                        }
976                        if Self::is_zero(&left) {
977                            self.stats.nodes_eliminated += 1;
978                            self.stats.algebraic_simplifications += 1;
979                            return right;
980                        }
981                    }
982                    BinaryOperator::Sub => {
983                        if Self::is_zero(&right) {
984                            self.stats.nodes_eliminated += 1;
985                            self.stats.algebraic_simplifications += 1;
986                            return left;
987                        }
988                    }
989                    BinaryOperator::Mul => {
990                        if Self::is_zero(&right) {
991                            self.stats.nodes_eliminated += 1;
992                            self.stats.algebraic_simplifications += 1;
993                            return right;
994                        }
995                        if Self::is_zero(&left) {
996                            self.stats.nodes_eliminated += 1;
997                            self.stats.algebraic_simplifications += 1;
998                            return left;
999                        }
1000                        if Self::is_one(&right) {
1001                            self.stats.nodes_eliminated += 1;
1002                            self.stats.algebraic_simplifications += 1;
1003                            return left;
1004                        }
1005                        if Self::is_one(&left) {
1006                            self.stats.nodes_eliminated += 1;
1007                            self.stats.algebraic_simplifications += 1;
1008                            return right;
1009                        }
1010                    }
1011                    BinaryOperator::And => {
1012                        // x AND x = x
1013                        if left == right {
1014                            self.stats.nodes_eliminated += 1;
1015                            self.stats.algebraic_simplifications += 1;
1016                            return left;
1017                        }
1018                        if Self::is_true(&right) {
1019                            self.stats.nodes_eliminated += 1;
1020                            self.stats.algebraic_simplifications += 1;
1021                            return left;
1022                        }
1023                        if Self::is_true(&left) {
1024                            self.stats.nodes_eliminated += 1;
1025                            self.stats.algebraic_simplifications += 1;
1026                            return right;
1027                        }
1028                        if Self::is_false(&right) {
1029                            self.stats.nodes_eliminated += 1;
1030                            self.stats.algebraic_simplifications += 1;
1031                            return right;
1032                        }
1033                        if Self::is_false(&left) {
1034                            self.stats.nodes_eliminated += 1;
1035                            self.stats.algebraic_simplifications += 1;
1036                            return left;
1037                        }
1038                    }
1039                    BinaryOperator::Or => {
1040                        // x OR x = x
1041                        if left == right {
1042                            self.stats.nodes_eliminated += 1;
1043                            self.stats.algebraic_simplifications += 1;
1044                            return left;
1045                        }
1046                        if Self::is_false(&right) {
1047                            self.stats.nodes_eliminated += 1;
1048                            self.stats.algebraic_simplifications += 1;
1049                            return left;
1050                        }
1051                        if Self::is_false(&left) {
1052                            self.stats.nodes_eliminated += 1;
1053                            self.stats.algebraic_simplifications += 1;
1054                            return right;
1055                        }
1056                        if Self::is_true(&right) {
1057                            self.stats.nodes_eliminated += 1;
1058                            self.stats.algebraic_simplifications += 1;
1059                            return right;
1060                        }
1061                        if Self::is_true(&left) {
1062                            self.stats.nodes_eliminated += 1;
1063                            self.stats.algebraic_simplifications += 1;
1064                            return left;
1065                        }
1066                    }
1067                    BinaryOperator::Xor => {
1068                        if Self::is_false(&right) {
1069                            self.stats.nodes_eliminated += 1;
1070                            self.stats.algebraic_simplifications += 1;
1071                            return left;
1072                        }
1073                        if Self::is_false(&left) {
1074                            self.stats.nodes_eliminated += 1;
1075                            self.stats.algebraic_simplifications += 1;
1076                            return right;
1077                        }
1078                    }
1079                }
1080
1081                CircuitNode::BinaryOp {
1082                    op,
1083                    left: Box::new(left),
1084                    right: Box::new(right),
1085                }
1086            }
1087
1088            CircuitNode::UnaryOp { op, operand } => {
1089                let operand = self.dce_simplify(*operand);
1090
1091                // Constant folding
1092                if let CircuitNode::Constant(val) = &operand {
1093                    if let Some(result) = self.fold_unary_constant(op, val) {
1094                        self.stats.nodes_eliminated += 1;
1095                        self.stats.constants_folded += 1;
1096                        return CircuitNode::Constant(result);
1097                    }
1098                }
1099
1100                // Double negation: NOT(NOT(x)) = x
1101                if op == UnaryOperator::Not {
1102                    if let CircuitNode::UnaryOp {
1103                        op: UnaryOperator::Not,
1104                        operand: inner,
1105                    } = operand
1106                    {
1107                        self.stats.nodes_eliminated += 2;
1108                        self.stats.algebraic_simplifications += 1;
1109                        return *inner;
1110                    }
1111                }
1112
1113                // Double negation for Neg: Neg(Neg(x)) = x
1114                if op == UnaryOperator::Neg {
1115                    if let CircuitNode::UnaryOp {
1116                        op: UnaryOperator::Neg,
1117                        operand: inner,
1118                    } = operand
1119                    {
1120                        self.stats.nodes_eliminated += 2;
1121                        self.stats.algebraic_simplifications += 1;
1122                        return *inner;
1123                    }
1124                }
1125
1126                CircuitNode::UnaryOp {
1127                    op,
1128                    operand: Box::new(operand),
1129                }
1130            }
1131
1132            CircuitNode::Compare { op, left, right } => {
1133                let left = self.dce_simplify(*left);
1134                let right = self.dce_simplify(*right);
1135
1136                // Constant fold comparisons
1137                if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
1138                    if let Some(result) = self.fold_comparison(op, l, r) {
1139                        self.stats.nodes_eliminated += 1;
1140                        self.stats.constants_folded += 1;
1141                        return CircuitNode::Constant(CircuitValue::Bool(result));
1142                    }
1143                }
1144
1145                CircuitNode::Compare {
1146                    op,
1147                    left: Box::new(left),
1148                    right: Box::new(right),
1149                }
1150            }
1151
1152            CircuitNode::NaryOp { op, operands } => {
1153                let new_operands: Vec<CircuitNode> =
1154                    operands.into_iter().map(|o| self.dce_simplify(o)).collect();
1155                CircuitNode::NaryOp {
1156                    op,
1157                    operands: new_operands,
1158                }
1159            }
1160
1161            other => other,
1162        }
1163    }
1164
1165    /// Produce a zero constant matching the type inferred from a subtree
1166    fn zero_like(&self, node: &CircuitNode) -> CircuitNode {
1167        match node {
1168            CircuitNode::Constant(CircuitValue::U8(_)) => {
1169                CircuitNode::Constant(CircuitValue::U8(0))
1170            }
1171            CircuitNode::Constant(CircuitValue::U16(_)) => {
1172                CircuitNode::Constant(CircuitValue::U16(0))
1173            }
1174            CircuitNode::Constant(CircuitValue::U32(_)) => {
1175                CircuitNode::Constant(CircuitValue::U32(0))
1176            }
1177            CircuitNode::Constant(CircuitValue::U64(_)) => {
1178                CircuitNode::Constant(CircuitValue::U64(0))
1179            }
1180            // Default to U8(0) for non-constant nodes where type is unknown
1181            _ => CircuitNode::Constant(CircuitValue::U8(0)),
1182        }
1183    }
1184
1185    /// Fold comparison of two constants into a boolean result
1186    fn fold_comparison(
1187        &self,
1188        op: CompareOperator,
1189        left: &CircuitValue,
1190        right: &CircuitValue,
1191    ) -> Option<bool> {
1192        match (left, right) {
1193            (CircuitValue::U8(l), CircuitValue::U8(r)) => Some(self.compare_values(op, *l, *r)),
1194            (CircuitValue::U16(l), CircuitValue::U16(r)) => Some(self.compare_values(op, *l, *r)),
1195            (CircuitValue::U32(l), CircuitValue::U32(r)) => Some(self.compare_values(op, *l, *r)),
1196            (CircuitValue::U64(l), CircuitValue::U64(r)) => Some(self.compare_values(op, *l, *r)),
1197            (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
1198                CompareOperator::Eq => Some(l == r),
1199                CompareOperator::Ne => Some(l != r),
1200                _ => None,
1201            },
1202            _ => None,
1203        }
1204    }
1205
1206    /// Compare two ordered values with a comparison operator
1207    fn compare_values<T: PartialOrd + PartialEq>(&self, op: CompareOperator, l: T, r: T) -> bool {
1208        match op {
1209            CompareOperator::Eq => l == r,
1210            CompareOperator::Ne => l != r,
1211            CompareOperator::Lt => l < r,
1212            CompareOperator::Le => l <= r,
1213            CompareOperator::Gt => l > r,
1214            CompareOperator::Ge => l >= r,
1215        }
1216    }
1217
1218    /// Collect the set of variable names that are actually used in the circuit tree
1219    pub fn collect_live_variables(&self, node: &CircuitNode) -> HashSet<String> {
1220        let mut live = HashSet::new();
1221        self.mark_live_nodes(node, &mut live);
1222        live
1223    }
1224
1225    /// Mark nodes that contribute to the output
1226    #[allow(clippy::only_used_in_recursion)]
1227    fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
1228        match node {
1229            CircuitNode::Load(name) => {
1230                live_nodes.insert(name.clone());
1231            }
1232
1233            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {}
1234
1235            CircuitNode::BinaryOp { left, right, .. } => {
1236                self.mark_live_nodes(left, live_nodes);
1237                self.mark_live_nodes(right, live_nodes);
1238            }
1239
1240            CircuitNode::UnaryOp { operand, .. } => {
1241                self.mark_live_nodes(operand, live_nodes);
1242            }
1243
1244            CircuitNode::Compare { left, right, .. } => {
1245                self.mark_live_nodes(left, live_nodes);
1246                self.mark_live_nodes(right, live_nodes);
1247            }
1248            CircuitNode::NaryOp { operands, .. } => {
1249                for operand in operands {
1250                    self.mark_live_nodes(operand, live_nodes);
1251                }
1252            }
1253        }
1254    }
1255
1256    /// Analyze circuit for parallelization opportunities
1257    ///
1258    /// Builds a dependency graph and identifies operations that can run in parallel
1259    fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
1260        let mut graph = DependencyGraph::new();
1261        let mut node_id_map = HashMap::new();
1262        let mut cse_map = HashMap::new();
1263        let mut next_id = 0;
1264
1265        // Build dependency graph with CSE deduplication
1266        self.build_dependency_graph(
1267            &circuit.root,
1268            &mut graph,
1269            &mut node_id_map,
1270            &mut cse_map,
1271            &mut next_id,
1272        );
1273
1274        graph.node_count = next_id;
1275
1276        // Identify parallel groups using level-wise traversal
1277        graph.parallel_groups = self.identify_parallel_groups(&graph);
1278
1279        // Find critical path using memoized algorithm
1280        graph.critical_path = self.find_critical_path(&graph);
1281
1282        Ok(graph)
1283    }
1284
1285    /// Build dependency graph recursively, using CSE map to deduplicate identical subtrees
1286    #[allow(clippy::only_used_in_recursion)]
1287    fn build_dependency_graph(
1288        &self,
1289        node: &CircuitNode,
1290        graph: &mut DependencyGraph,
1291        node_id_map: &mut HashMap<String, NodeId>,
1292        cse_map: &mut HashMap<u64, NodeId>,
1293        next_id: &mut usize,
1294    ) -> NodeId {
1295        // Check for structural CSE deduplication
1296        let node_hash = Self::structural_hash(node);
1297        if let Some(&existing_id) = cse_map.get(&node_hash) {
1298            return existing_id;
1299        }
1300
1301        let current_id = NodeId(*next_id);
1302        *next_id += 1;
1303        cse_map.insert(node_hash, current_id);
1304
1305        match node {
1306            CircuitNode::Load(name) => {
1307                node_id_map.insert(name.clone(), current_id);
1308                graph.dependencies.insert(current_id, Vec::new());
1309                current_id
1310            }
1311
1312            CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {
1313                graph.dependencies.insert(current_id, Vec::new());
1314                current_id
1315            }
1316
1317            CircuitNode::BinaryOp { left, right, .. } => {
1318                let left_id =
1319                    self.build_dependency_graph(left, graph, node_id_map, cse_map, next_id);
1320                let right_id =
1321                    self.build_dependency_graph(right, graph, node_id_map, cse_map, next_id);
1322                graph
1323                    .dependencies
1324                    .insert(current_id, vec![left_id, right_id]);
1325                current_id
1326            }
1327
1328            CircuitNode::UnaryOp { operand, .. } => {
1329                let operand_id =
1330                    self.build_dependency_graph(operand, graph, node_id_map, cse_map, next_id);
1331                graph.dependencies.insert(current_id, vec![operand_id]);
1332                current_id
1333            }
1334
1335            CircuitNode::Compare { left, right, .. } => {
1336                let left_id =
1337                    self.build_dependency_graph(left, graph, node_id_map, cse_map, next_id);
1338                let right_id =
1339                    self.build_dependency_graph(right, graph, node_id_map, cse_map, next_id);
1340                graph
1341                    .dependencies
1342                    .insert(current_id, vec![left_id, right_id]);
1343                current_id
1344            }
1345
1346            CircuitNode::NaryOp { operands, .. } => {
1347                let dep_ids: Vec<NodeId> = operands
1348                    .iter()
1349                    .map(|o| self.build_dependency_graph(o, graph, node_id_map, cse_map, next_id))
1350                    .collect();
1351                graph.dependencies.insert(current_id, dep_ids);
1352                current_id
1353            }
1354        }
1355    }
1356
1357    /// Compute a structural hash for a circuit node (for CSE deduplication)
1358    fn structural_hash(node: &CircuitNode) -> u64 {
1359        use std::collections::hash_map::DefaultHasher;
1360        use std::hash::Hasher;
1361
1362        let mut hasher = DefaultHasher::new();
1363        Self::hash_node(node, &mut hasher);
1364        hasher.finish()
1365    }
1366
1367    fn hash_node(node: &CircuitNode, hasher: &mut impl std::hash::Hasher) {
1368        use std::hash::Hash;
1369        match node {
1370            CircuitNode::Load(name) => {
1371                0u8.hash(hasher);
1372                name.hash(hasher);
1373            }
1374            CircuitNode::Constant(value) => {
1375                1u8.hash(hasher);
1376                match value {
1377                    crate::compute::circuit::CircuitValue::Bool(v) => {
1378                        0u8.hash(hasher);
1379                        v.hash(hasher);
1380                    }
1381                    crate::compute::circuit::CircuitValue::U8(v) => {
1382                        1u8.hash(hasher);
1383                        v.hash(hasher);
1384                    }
1385                    crate::compute::circuit::CircuitValue::U16(v) => {
1386                        2u8.hash(hasher);
1387                        v.hash(hasher);
1388                    }
1389                    crate::compute::circuit::CircuitValue::U32(v) => {
1390                        3u8.hash(hasher);
1391                        v.hash(hasher);
1392                    }
1393                    crate::compute::circuit::CircuitValue::U64(v) => {
1394                        4u8.hash(hasher);
1395                        v.hash(hasher);
1396                    }
1397                }
1398            }
1399            CircuitNode::EncryptedConstant {
1400                data,
1401                original_type,
1402            } => {
1403                2u8.hash(hasher);
1404                data.hash(hasher);
1405                match original_type {
1406                    crate::compute::circuit::ConstantType::Integer => 0u8.hash(hasher),
1407                    crate::compute::circuit::ConstantType::Boolean => 1u8.hash(hasher),
1408                    crate::compute::circuit::ConstantType::Float => 2u8.hash(hasher),
1409                    crate::compute::circuit::ConstantType::Bytes => 3u8.hash(hasher),
1410                }
1411            }
1412            CircuitNode::BinaryOp { op, left, right } => {
1413                3u8.hash(hasher);
1414                Self::hash_binary_op(*op, hasher);
1415                Self::hash_node(left, hasher);
1416                Self::hash_node(right, hasher);
1417            }
1418            CircuitNode::UnaryOp { op, operand } => {
1419                4u8.hash(hasher);
1420                match op {
1421                    UnaryOperator::Not => 0u8.hash(hasher),
1422                    UnaryOperator::Neg => 1u8.hash(hasher),
1423                }
1424                Self::hash_node(operand, hasher);
1425            }
1426            CircuitNode::Compare { op, left, right } => {
1427                5u8.hash(hasher);
1428                match op {
1429                    CompareOperator::Eq => 0u8.hash(hasher),
1430                    CompareOperator::Ne => 1u8.hash(hasher),
1431                    CompareOperator::Lt => 2u8.hash(hasher),
1432                    CompareOperator::Le => 3u8.hash(hasher),
1433                    CompareOperator::Gt => 4u8.hash(hasher),
1434                    CompareOperator::Ge => 5u8.hash(hasher),
1435                }
1436                Self::hash_node(left, hasher);
1437                Self::hash_node(right, hasher);
1438            }
1439            CircuitNode::NaryOp { op, operands } => {
1440                6u8.hash(hasher);
1441                Self::hash_binary_op(*op, hasher);
1442                operands.len().hash(hasher);
1443                for o in operands {
1444                    Self::hash_node(o, hasher);
1445                }
1446            }
1447        }
1448    }
1449
1450    fn hash_binary_op(op: BinaryOperator, hasher: &mut impl std::hash::Hasher) {
1451        use std::hash::Hash;
1452        match op {
1453            BinaryOperator::Add => 0u8.hash(hasher),
1454            BinaryOperator::Sub => 1u8.hash(hasher),
1455            BinaryOperator::Mul => 2u8.hash(hasher),
1456            BinaryOperator::And => 3u8.hash(hasher),
1457            BinaryOperator::Or => 4u8.hash(hasher),
1458            BinaryOperator::Xor => 5u8.hash(hasher),
1459        }
1460    }
1461
1462    /// Identify groups of nodes that can execute in parallel
1463    fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
1464        let mut levels: HashMap<NodeId, usize> = HashMap::new();
1465        let mut queue = VecDeque::new();
1466
1467        // Find all nodes with no dependencies (level 0)
1468        for (node_id, deps) in &graph.dependencies {
1469            if deps.is_empty() {
1470                levels.insert(*node_id, 0);
1471                queue.push_back(*node_id);
1472            }
1473        }
1474
1475        // Level-wise traversal
1476        while let Some(node_id) = queue.pop_front() {
1477            let current_level = levels[&node_id];
1478
1479            // Find nodes that depend on this node
1480            for (dependent_id, deps) in &graph.dependencies {
1481                if deps.contains(&node_id) {
1482                    // Calculate level for dependent node
1483                    let max_dep_level = deps
1484                        .iter()
1485                        .filter_map(|dep_id| levels.get(dep_id))
1486                        .max()
1487                        .copied()
1488                        .unwrap_or(0);
1489
1490                    let dependent_level = max_dep_level + 1;
1491
1492                    if !levels.contains_key(dependent_id) {
1493                        levels.insert(*dependent_id, dependent_level);
1494                        queue.push_back(*dependent_id);
1495                    }
1496                }
1497            }
1498        }
1499
1500        // Group nodes by level
1501        let max_level = levels.values().max().copied().unwrap_or(0);
1502        let mut parallel_groups = vec![Vec::new(); max_level + 1];
1503
1504        for (node_id, level) in levels {
1505            parallel_groups[level].push(node_id);
1506        }
1507
1508        // Sort each group for deterministic output
1509        for group in &mut parallel_groups {
1510            group.sort();
1511        }
1512
1513        parallel_groups
1514    }
1515
1516    /// Find the critical path (longest dependency chain) using memoization
1517    fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
1518        let mut memo = HashMap::new();
1519
1520        // Compute longest path length to each node
1521        for &node_id in graph.dependencies.keys() {
1522            self.longest_path_to(node_id, graph, &mut memo);
1523        }
1524
1525        // Find the node with the maximum path length
1526        let max_node = graph
1527            .dependencies
1528            .keys()
1529            .max_by_key(|&&id| memo.get(&id).copied().unwrap_or(0));
1530
1531        let Some(&end_node) = max_node else {
1532            return Vec::new();
1533        };
1534
1535        // Reconstruct path from end_node following max-cost dependencies
1536        let mut path = Vec::new();
1537        let mut current = end_node;
1538        path.push(current);
1539
1540        loop {
1541            let deps = match graph.dependencies.get(&current) {
1542                Some(d) if !d.is_empty() => d,
1543                _ => break,
1544            };
1545            let next = deps
1546                .iter()
1547                .max_by_key(|&&dep_id| memo.get(&dep_id).copied().unwrap_or(0))
1548                .copied();
1549            match next {
1550                Some(next_id) if next_id != current => {
1551                    path.push(next_id);
1552                    current = next_id;
1553                }
1554                _ => break,
1555            }
1556        }
1557
1558        path.reverse();
1559        path
1560    }
1561
1562    /// Memoized computation of longest path from a leaf to `node_id`
1563    fn longest_path_to(
1564        &self,
1565        node_id: NodeId,
1566        graph: &DependencyGraph,
1567        memo: &mut HashMap<NodeId, usize>,
1568    ) -> usize {
1569        if let Some(&cached) = memo.get(&node_id) {
1570            return cached;
1571        }
1572
1573        let deps = graph
1574            .dependencies
1575            .get(&node_id)
1576            .map(|v| v.as_slice())
1577            .unwrap_or(&[]);
1578
1579        let result = if deps.is_empty() {
1580            1
1581        } else {
1582            let max_dep = deps
1583                .iter()
1584                .map(|&dep_id| self.longest_path_to(dep_id, graph, memo))
1585                .max()
1586                .unwrap_or(0);
1587            max_dep + 1
1588        };
1589
1590        memo.insert(node_id, result);
1591        result
1592    }
1593}
1594
1595impl Default for CircuitOptimizer {
1596    fn default() -> Self {
1597        Self::new()
1598    }
1599}
1600
1601#[cfg(test)]
1602#[path = "optimizer_tests.rs"]
1603mod tests;