1use crate::compute::circuit::{
14 BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
15 UnaryOperator,
16};
17use crate::error::{AmateRSError, ErrorContext, Result};
18use std::collections::{HashMap, HashSet, VecDeque};
19
/// Counters collected by [`CircuitOptimizer`] while optimizing a circuit.
///
/// The `original_*` / `optimized_*` pairs describe the circuit before and after
/// optimization; the remaining fields tally individual rewrites performed by the
/// optimization passes.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Gate count of the input circuit.
    pub original_gate_count: usize,

    /// Gate count of the optimized circuit.
    pub optimized_gate_count: usize,

    /// Bootstrap count of the input circuit, as estimated by the optimizer.
    pub original_bootstrap_count: usize,

    /// Bootstrap count of the optimized circuit, as estimated by the optimizer.
    pub optimized_bootstrap_count: usize,

    /// Counter for nodes removed as dead code.
    pub dead_code_removed: usize,

    /// Counter for nodes eliminated by simplification rewrites.
    pub nodes_eliminated: usize,

    /// Counter for algebraic-identity rewrites (e.g. `x - x`, `x & x`).
    pub algebraic_simplifications: usize,

    /// Counter for sub-expressions evaluated at optimization time.
    pub constants_folded: usize,

    /// Counter for gates merged or removed by fusion.
    pub gates_fused: usize,

    /// Depth of the input circuit.
    pub original_depth: usize,

    /// Depth of the optimized circuit.
    pub optimized_depth: usize,
}

impl OptimizationStats {
    /// Percentage of gates removed by optimization.
    ///
    /// Returns `0.0` when the original circuit had no gates. A circuit that
    /// grew during optimization is also reported as `0.0`, never as a negative value.
    pub fn gate_reduction_percent(&self) -> f64 {
        Self::percent_saved(self.original_gate_count, self.optimized_gate_count)
    }

    /// Percentage of bootstraps removed by optimization.
    ///
    /// Same conventions as [`Self::gate_reduction_percent`].
    pub fn bootstrap_reduction_percent(&self) -> f64 {
        Self::percent_saved(self.original_bootstrap_count, self.optimized_bootstrap_count)
    }

    /// Condensed rewrite summary `(removed, simplified, folded)`:
    ///
    /// * `removed`    = `nodes_eliminated + dead_code_removed`
    /// * `simplified` = `algebraic_simplifications + gates_fused`
    /// * `folded`     = `constants_folded`
    pub fn total_stats(&self) -> (usize, usize, usize) {
        let removed = self.nodes_eliminated + self.dead_code_removed;
        let simplified = self.algebraic_simplifications + self.gates_fused;
        (removed, simplified, self.constants_folded)
    }

    /// Share of `before` that is gone in `after`, as a percentage.
    ///
    /// `0.0` for an empty baseline; growth saturates to `0.0`.
    fn percent_saved(before: usize, after: usize) -> f64 {
        match before {
            0 => 0.0,
            _ => (before.saturating_sub(after) as f64 / before as f64) * 100.0,
        }
    }
}
89
/// Data-dependency view of a circuit, produced by the optimizer's parallelism analysis.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyGraph {
    /// For each node, the ids of the nodes whose results it consumes.
    pub dependencies: HashMap<NodeId, Vec<NodeId>>,

    /// Node groups identified by the parallelism analysis. Their sizes are what
    /// [`Self::max_parallelism`] and [`Self::avg_parallelism`] summarize.
    pub parallel_groups: Vec<Vec<NodeId>>,

    /// Node sequence reported by the critical-path analysis.
    pub critical_path: Vec<NodeId>,

    /// Number of node ids allocated while building the graph.
    pub node_count: usize,
}

impl DependencyGraph {
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self {
            dependencies: HashMap::default(),
            parallel_groups: Vec::default(),
            critical_path: Vec::default(),
            node_count: 0,
        }
    }

    /// Size of the largest parallel group, or `0` when there are no groups.
    pub fn max_parallelism(&self) -> usize {
        self.parallel_groups.iter().map(Vec::len).max().unwrap_or(0)
    }

    /// Mean parallel-group size, or `0.0` when there are no groups.
    pub fn avg_parallelism(&self) -> f64 {
        match self.parallel_groups.len() {
            0 => 0.0,
            group_count => {
                let members: usize = self.parallel_groups.iter().map(Vec::len).sum();
                members as f64 / group_count as f64
            }
        }
    }

    /// Returns every node such that each node appears after all of its dependencies.
    ///
    /// Among nodes that are ready at the same time, the smallest [`NodeId`] is emitted
    /// first, so the result is deterministic. Nodes that lie on a dependency cycle are
    /// never emitted.
    pub fn topological_order(&self) -> Vec<NodeId> {
        self.compute_topological_order()
    }

    /// Kahn's algorithm, using an ordered set as the ready queue.
    fn compute_topological_order(&self) -> Vec<NodeId> {
        use std::collections::BTreeSet;

        // Unsatisfied-dependency count per node. Ids that only occur as a
        // dependency are registered with a count of zero.
        let mut pending: HashMap<NodeId, usize> = HashMap::new();
        // Reverse edges: dependency -> nodes consuming it. A node that depends
        // on the same id twice is listed twice, matching its pending count.
        let mut consumers: HashMap<NodeId, Vec<NodeId>> = HashMap::new();

        for (&id, deps) in &self.dependencies {
            *pending.entry(id).or_insert(0) = deps.len();
            for &dep in deps {
                pending.entry(dep).or_insert(0);
                consumers.entry(dep).or_default().push(id);
            }
        }

        let mut ready: BTreeSet<NodeId> = pending
            .iter()
            .filter(|(_, count)| **count == 0)
            .map(|(&id, _)| id)
            .collect();

        let mut order = Vec::with_capacity(pending.len());
        while let Some(next) = ready.iter().next().copied() {
            ready.remove(&next);
            order.push(next);

            for consumer in consumers.get(&next).into_iter().flatten() {
                if let Some(count) = pending.get_mut(consumer) {
                    if *count > 0 {
                        *count -= 1;
                        if *count == 0 {
                            ready.insert(*consumer);
                        }
                    }
                }
            }
        }

        order
    }
}

impl Default for DependencyGraph {
    fn default() -> Self {
        Self::new()
    }
}

/// Identifier of a node in a [`DependencyGraph`].
///
/// The ordering is what makes [`DependencyGraph::topological_order`] deterministic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
201
202#[derive(Debug, Clone)]
204pub struct CircuitOptimizer {
205 pub enable_constant_folding: bool,
207
208 pub enable_dead_code_elimination: bool,
210
211 pub enable_bootstrap_minimization: bool,
213
214 pub enable_gate_fusion: bool,
216
217 pub enable_parallelization_analysis: bool,
219
220 stats: OptimizationStats,
222
223 dependency_graph: DependencyGraph,
225}
226
227impl CircuitOptimizer {
228 pub fn new() -> Self {
230 Self {
231 enable_constant_folding: true,
232 enable_dead_code_elimination: true,
233 enable_bootstrap_minimization: true,
234 enable_gate_fusion: true,
235 enable_parallelization_analysis: true,
236 stats: OptimizationStats::default(),
237 dependency_graph: DependencyGraph::new(),
238 }
239 }
240
241 pub fn disabled() -> Self {
243 Self {
244 enable_constant_folding: false,
245 enable_dead_code_elimination: false,
246 enable_bootstrap_minimization: false,
247 enable_gate_fusion: false,
248 enable_parallelization_analysis: false,
249 stats: OptimizationStats::default(),
250 dependency_graph: DependencyGraph::new(),
251 }
252 }
253
254 pub fn stats(&self) -> &OptimizationStats {
256 &self.stats
257 }
258
259 pub fn dependency_graph(&self) -> &DependencyGraph {
261 &self.dependency_graph
262 }
263
264 pub fn total_stats(&self) -> (usize, usize, usize) {
266 self.stats.total_stats()
267 }
268
269 pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
271 self.stats.original_gate_count = circuit.gate_count;
273 self.stats.original_depth = circuit.depth;
274 self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
275
276 let mut optimized_root = circuit.root.clone();
277
278 if self.enable_constant_folding {
280 optimized_root = self.constant_folding_pass(optimized_root);
281 }
282
283 if self.enable_gate_fusion {
284 optimized_root = self.gate_fusion_pass(optimized_root);
285 }
286
287 if self.enable_bootstrap_minimization {
288 optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
289 }
290
291 if self.enable_dead_code_elimination {
292 optimized_root = self.dead_code_elimination_pass(optimized_root);
293 }
294
295 let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
297
298 self.stats.optimized_gate_count = optimized_circuit.gate_count;
300 self.stats.optimized_depth = optimized_circuit.depth;
301 self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
302
303 if self.enable_parallelization_analysis {
305 self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
306 }
307
308 Ok(optimized_circuit)
309 }
310
311 #[allow(clippy::only_used_in_recursion)]
319 fn count_bootstraps(&self, node: &CircuitNode) -> usize {
320 match node {
321 CircuitNode::Load(_)
322 | CircuitNode::Constant(_)
323 | CircuitNode::EncryptedConstant { .. } => 0,
324
325 CircuitNode::BinaryOp { op, left, right } => {
326 let left_bootstraps = self.count_bootstraps(left);
327 let right_bootstraps = self.count_bootstraps(right);
328
329 let op_bootstrap = match op {
331 BinaryOperator::Mul => 1,
332 _ => 0,
333 };
334
335 left_bootstraps + right_bootstraps + op_bootstrap
336 }
337
338 CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
339
340 CircuitNode::Compare { left, right, .. } => {
341 let left_bootstraps = self.count_bootstraps(left);
342 let right_bootstraps = self.count_bootstraps(right);
343
344 left_bootstraps + right_bootstraps + 1
346 }
347 CircuitNode::NaryOp { op, operands } => {
348 let operand_bootstraps: usize =
349 operands.iter().map(|o| self.count_bootstraps(o)).sum();
350 let op_bootstraps = match op {
351 BinaryOperator::Mul => operands.len().saturating_sub(1),
352 _ => 0,
353 };
354 operand_bootstraps + op_bootstraps
355 }
356 }
357 }
358
359 fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
363 match node {
364 CircuitNode::BinaryOp { op, left, right } => {
365 let left = self.constant_folding_pass(*left);
366 let right = self.constant_folding_pass(*right);
367
368 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
370 if let Some(result) = self.fold_binary_constants(op, l, r) {
371 self.stats.constants_folded += 1;
372 return CircuitNode::Constant(result);
373 }
374 }
375
376 if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
378 return simplified;
379 }
380
381 CircuitNode::BinaryOp {
382 op,
383 left: Box::new(left),
384 right: Box::new(right),
385 }
386 }
387
388 CircuitNode::UnaryOp { op, operand } => {
389 let operand = self.constant_folding_pass(*operand);
390
391 if let CircuitNode::Constant(val) = &operand {
392 if let Some(result) = self.fold_unary_constant(op, val) {
393 self.stats.constants_folded += 1;
394 return CircuitNode::Constant(result);
395 }
396 }
397
398 CircuitNode::UnaryOp {
399 op,
400 operand: Box::new(operand),
401 }
402 }
403
404 CircuitNode::Compare { op, left, right } => {
405 let left = self.constant_folding_pass(*left);
406 let right = self.constant_folding_pass(*right);
407
408 CircuitNode::Compare {
409 op,
410 left: Box::new(left),
411 right: Box::new(right),
412 }
413 }
414
415 CircuitNode::NaryOp { op, operands } => {
416 let new_operands: Vec<CircuitNode> = operands
417 .into_iter()
418 .map(|o| self.constant_folding_pass(o))
419 .collect();
420 CircuitNode::NaryOp {
421 op,
422 operands: new_operands,
423 }
424 }
425
426 other => other,
427 }
428 }
429
430 fn fold_binary_constants(
432 &self,
433 op: BinaryOperator,
434 left: &CircuitValue,
435 right: &CircuitValue,
436 ) -> Option<CircuitValue> {
437 match (left, right) {
438 (CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
439 BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
440 BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
441 BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
442 _ => None,
443 },
444 (CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
445 BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
446 BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
447 BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
448 _ => None,
449 },
450 (CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
451 BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
452 BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
453 BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
454 _ => None,
455 },
456 (CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
457 BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
458 BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
459 BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
460 _ => None,
461 },
462 (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
463 BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
464 BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
465 BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
466 _ => None,
467 },
468 _ => None,
469 }
470 }
471
472 fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
474 match (op, value) {
475 (UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
476 _ => None,
477 }
478 }
479
480 fn apply_algebraic_identities(
483 &mut self,
484 op: BinaryOperator,
485 left: &CircuitNode,
486 right: &CircuitNode,
487 ) -> Option<CircuitNode> {
488 match op {
489 BinaryOperator::Add => {
490 if Self::is_zero(right) {
492 self.stats.gates_fused += 1;
493 return Some(left.clone());
494 }
495 if Self::is_zero(left) {
497 self.stats.gates_fused += 1;
498 return Some(right.clone());
499 }
500 }
501
502 BinaryOperator::Sub => {
503 if Self::is_zero(right) {
505 self.stats.gates_fused += 1;
506 return Some(left.clone());
507 }
508 }
509
510 BinaryOperator::Mul => {
511 if Self::is_zero(right) {
513 self.stats.gates_fused += 1;
514 return Some(right.clone());
515 }
516 if Self::is_zero(left) {
517 self.stats.gates_fused += 1;
518 return Some(left.clone());
519 }
520
521 if Self::is_one(right) {
523 self.stats.gates_fused += 1;
524 return Some(left.clone());
525 }
526 if Self::is_one(left) {
528 self.stats.gates_fused += 1;
529 return Some(right.clone());
530 }
531 }
532
533 BinaryOperator::And => {
534 if Self::is_true(right) {
536 self.stats.gates_fused += 1;
537 return Some(left.clone());
538 }
539 if Self::is_true(left) {
540 self.stats.gates_fused += 1;
541 return Some(right.clone());
542 }
543
544 if Self::is_false(right) {
546 self.stats.gates_fused += 1;
547 return Some(right.clone());
548 }
549 if Self::is_false(left) {
550 self.stats.gates_fused += 1;
551 return Some(left.clone());
552 }
553 }
554
555 BinaryOperator::Or => {
556 if Self::is_false(right) {
558 self.stats.gates_fused += 1;
559 return Some(left.clone());
560 }
561 if Self::is_false(left) {
562 self.stats.gates_fused += 1;
563 return Some(right.clone());
564 }
565
566 if Self::is_true(right) {
568 self.stats.gates_fused += 1;
569 return Some(right.clone());
570 }
571 if Self::is_true(left) {
572 self.stats.gates_fused += 1;
573 return Some(left.clone());
574 }
575 }
576
577 BinaryOperator::Xor => {
578 if Self::is_false(right) {
580 self.stats.gates_fused += 1;
581 return Some(left.clone());
582 }
583 if Self::is_false(left) {
584 self.stats.gates_fused += 1;
585 return Some(right.clone());
586 }
587 }
588 }
589
590 None
591 }
592
593 fn is_zero(node: &CircuitNode) -> bool {
595 matches!(
596 node,
597 CircuitNode::Constant(CircuitValue::U8(0))
598 | CircuitNode::Constant(CircuitValue::U16(0))
599 | CircuitNode::Constant(CircuitValue::U32(0))
600 | CircuitNode::Constant(CircuitValue::U64(0))
601 )
602 }
603
604 fn is_one(node: &CircuitNode) -> bool {
606 matches!(
607 node,
608 CircuitNode::Constant(CircuitValue::U8(1))
609 | CircuitNode::Constant(CircuitValue::U16(1))
610 | CircuitNode::Constant(CircuitValue::U32(1))
611 | CircuitNode::Constant(CircuitValue::U64(1))
612 )
613 }
614
615 fn is_true(node: &CircuitNode) -> bool {
617 matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
618 }
619
620 fn is_false(node: &CircuitNode) -> bool {
622 matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
623 }
624
625 fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
631 match node {
632 CircuitNode::BinaryOp { op, left, right } => {
633 let left = self.gate_fusion_pass(*left);
634 let right = self.gate_fusion_pass(*right);
635
636 match op {
637 BinaryOperator::Add
638 | BinaryOperator::Mul
639 | BinaryOperator::And
640 | BinaryOperator::Or
641 | BinaryOperator::Xor => {
642 let mut operands: Vec<CircuitNode> = Vec::new();
645 Self::collect_nary_operands(op, left, &mut operands);
646 Self::collect_nary_operands(op, right, &mut operands);
647 if operands.len() >= 3 {
650 self.stats.gates_fused += operands.len().saturating_sub(2);
651 CircuitNode::NaryOp { op, operands }
652 } else {
653 Self::build_balanced_reduction(op, operands)
655 }
656 }
657 _ => CircuitNode::BinaryOp {
658 op,
659 left: Box::new(left),
660 right: Box::new(right),
661 },
662 }
663 }
664
665 CircuitNode::NaryOp { op, operands } => {
666 let new_operands: Vec<CircuitNode> = operands
668 .into_iter()
669 .map(|o| self.gate_fusion_pass(o))
670 .collect();
671 let mut flat_operands = Vec::new();
673 for operand in new_operands {
674 Self::collect_nary_operands(op, operand, &mut flat_operands);
675 }
676 if flat_operands.len() >= 2 {
677 CircuitNode::NaryOp {
678 op,
679 operands: flat_operands,
680 }
681 } else if flat_operands.len() == 1 {
682 flat_operands.remove(0)
683 } else {
684 CircuitNode::NaryOp {
685 op,
686 operands: flat_operands,
687 }
688 }
689 }
690
691 CircuitNode::UnaryOp {
692 op: UnaryOperator::Not,
693 operand,
694 } => {
695 let operand = self.gate_fusion_pass(*operand);
696
697 if let CircuitNode::UnaryOp {
699 op: UnaryOperator::Not,
700 operand: inner,
701 } = operand
702 {
703 self.stats.gates_fused += 2;
704 return *inner;
705 }
706
707 CircuitNode::UnaryOp {
708 op: UnaryOperator::Not,
709 operand: Box::new(operand),
710 }
711 }
712
713 CircuitNode::UnaryOp { op, operand } => {
714 let operand = self.gate_fusion_pass(*operand);
715 CircuitNode::UnaryOp {
716 op,
717 operand: Box::new(operand),
718 }
719 }
720
721 CircuitNode::Compare { op, left, right } => {
722 let left = self.gate_fusion_pass(*left);
723 let right = self.gate_fusion_pass(*right);
724 CircuitNode::Compare {
725 op,
726 left: Box::new(left),
727 right: Box::new(right),
728 }
729 }
730
731 other => other,
732 }
733 }
734
735 fn collect_nary_operands(op: BinaryOperator, node: CircuitNode, out: &mut Vec<CircuitNode>) {
737 match node {
738 CircuitNode::BinaryOp {
739 op: child_op,
740 left,
741 right,
742 } if child_op == op => {
743 Self::collect_nary_operands(op, *left, out);
744 Self::collect_nary_operands(op, *right, out);
745 }
746 CircuitNode::NaryOp {
747 op: child_op,
748 operands,
749 } if child_op == op => {
750 for operand in operands {
751 Self::collect_nary_operands(op, operand, out);
752 }
753 }
754 other => out.push(other),
755 }
756 }
757
758 fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
765 Ok(self.reorder_for_bootstrap_efficiency(node))
766 }
767
768 fn reorder_for_bootstrap_efficiency(&mut self, node: CircuitNode) -> CircuitNode {
774 match node {
775 CircuitNode::BinaryOp { op, left, right } => {
776 let left = self.reorder_for_bootstrap_efficiency(*left);
777 let right = self.reorder_for_bootstrap_efficiency(*right);
778
779 let is_commutative = matches!(
780 op,
781 BinaryOperator::Add
782 | BinaryOperator::Mul
783 | BinaryOperator::And
784 | BinaryOperator::Or
785 | BinaryOperator::Xor
786 );
787
788 if is_commutative {
789 let left_cost = self.count_bootstraps(&left);
790 let right_cost = self.count_bootstraps(&right);
791 if right_cost > left_cost {
792 return CircuitNode::BinaryOp {
793 op,
794 left: Box::new(right),
795 right: Box::new(left),
796 };
797 }
798 }
799
800 CircuitNode::BinaryOp {
801 op,
802 left: Box::new(left),
803 right: Box::new(right),
804 }
805 }
806
807 CircuitNode::NaryOp { op, operands } => {
808 let processed_operands: Vec<CircuitNode> = operands
810 .into_iter()
811 .map(|o| self.reorder_for_bootstrap_efficiency(o))
812 .collect();
813
814 if matches!(op, BinaryOperator::Mul) && processed_operands.len() >= 2 {
816 return Self::build_balanced_reduction(op, processed_operands);
817 }
818
819 let mut with_costs: Vec<(usize, CircuitNode)> = processed_operands
821 .into_iter()
822 .map(|o| {
823 let cost = self.count_bootstraps(&o);
824 (cost, o)
825 })
826 .collect();
827 with_costs.sort_by_key(|b| std::cmp::Reverse(b.0));
828 let sorted_operands: Vec<CircuitNode> =
829 with_costs.into_iter().map(|(_, o)| o).collect();
830
831 CircuitNode::NaryOp {
832 op,
833 operands: sorted_operands,
834 }
835 }
836
837 CircuitNode::UnaryOp { op, operand } => {
838 let operand = self.reorder_for_bootstrap_efficiency(*operand);
839 CircuitNode::UnaryOp {
840 op,
841 operand: Box::new(operand),
842 }
843 }
844
845 CircuitNode::Compare { op, left, right } => {
846 let left = self.reorder_for_bootstrap_efficiency(*left);
847 let right = self.reorder_for_bootstrap_efficiency(*right);
848 CircuitNode::Compare {
849 op,
850 left: Box::new(left),
851 right: Box::new(right),
852 }
853 }
854
855 other => other,
856 }
857 }
858
859 fn build_balanced_reduction(op: BinaryOperator, operands: Vec<CircuitNode>) -> CircuitNode {
861 if operands.is_empty() {
862 return CircuitNode::Constant(crate::compute::circuit::CircuitValue::U8(0));
864 }
865 if operands.len() == 1 {
866 return operands.into_iter().next().unwrap_or(CircuitNode::Constant(
868 crate::compute::circuit::CircuitValue::U8(0),
869 ));
870 }
871 if operands.len() == 2 {
872 let mut it = operands.into_iter();
873 let left = it.next().unwrap_or(CircuitNode::Constant(
875 crate::compute::circuit::CircuitValue::U8(0),
876 ));
877 let right = it.next().unwrap_or(CircuitNode::Constant(
878 crate::compute::circuit::CircuitValue::U8(0),
879 ));
880 return CircuitNode::BinaryOp {
881 op,
882 left: Box::new(left),
883 right: Box::new(right),
884 };
885 }
886
887 let mid = operands.len() / 2;
888 let (left_operands, right_operands) = operands.into_iter().enumerate().fold(
889 (Vec::new(), Vec::new()),
890 |(mut l, mut r), (i, node)| {
891 if i < mid {
892 l.push(node);
893 } else {
894 r.push(node);
895 }
896 (l, r)
897 },
898 );
899
900 let left_node = Self::build_balanced_reduction(op, left_operands);
901 let right_node = Self::build_balanced_reduction(op, right_operands);
902
903 CircuitNode::BinaryOp {
904 op,
905 left: Box::new(left_node),
906 right: Box::new(right_node),
907 }
908 }
909
910 fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
924 let mut current = node;
925 loop {
927 let simplified = self.dce_simplify(current.clone());
928 if simplified == current {
929 break;
930 }
931 current = simplified;
932 }
933 current
934 }
935
936 fn dce_simplify(&mut self, node: CircuitNode) -> CircuitNode {
938 match node {
939 CircuitNode::BinaryOp { op, left, right } => {
940 let left = self.dce_simplify(*left);
942 let right = self.dce_simplify(*right);
943
944 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
946 if let Some(result) = self.fold_binary_constants(op, l, r) {
947 self.stats.nodes_eliminated += 1;
948 self.stats.constants_folded += 1;
949 return CircuitNode::Constant(result);
950 }
951 }
952
953 if op == BinaryOperator::Sub && left == right {
955 self.stats.nodes_eliminated += 1;
956 self.stats.algebraic_simplifications += 1;
957 return self.zero_like(&left);
959 }
960
961 if op == BinaryOperator::Xor && left == right {
963 self.stats.nodes_eliminated += 1;
964 self.stats.algebraic_simplifications += 1;
965 return CircuitNode::Constant(CircuitValue::Bool(false));
966 }
967
968 match op {
970 BinaryOperator::Add => {
971 if Self::is_zero(&right) {
972 self.stats.nodes_eliminated += 1;
973 self.stats.algebraic_simplifications += 1;
974 return left;
975 }
976 if Self::is_zero(&left) {
977 self.stats.nodes_eliminated += 1;
978 self.stats.algebraic_simplifications += 1;
979 return right;
980 }
981 }
982 BinaryOperator::Sub => {
983 if Self::is_zero(&right) {
984 self.stats.nodes_eliminated += 1;
985 self.stats.algebraic_simplifications += 1;
986 return left;
987 }
988 }
989 BinaryOperator::Mul => {
990 if Self::is_zero(&right) {
991 self.stats.nodes_eliminated += 1;
992 self.stats.algebraic_simplifications += 1;
993 return right;
994 }
995 if Self::is_zero(&left) {
996 self.stats.nodes_eliminated += 1;
997 self.stats.algebraic_simplifications += 1;
998 return left;
999 }
1000 if Self::is_one(&right) {
1001 self.stats.nodes_eliminated += 1;
1002 self.stats.algebraic_simplifications += 1;
1003 return left;
1004 }
1005 if Self::is_one(&left) {
1006 self.stats.nodes_eliminated += 1;
1007 self.stats.algebraic_simplifications += 1;
1008 return right;
1009 }
1010 }
1011 BinaryOperator::And => {
1012 if left == right {
1014 self.stats.nodes_eliminated += 1;
1015 self.stats.algebraic_simplifications += 1;
1016 return left;
1017 }
1018 if Self::is_true(&right) {
1019 self.stats.nodes_eliminated += 1;
1020 self.stats.algebraic_simplifications += 1;
1021 return left;
1022 }
1023 if Self::is_true(&left) {
1024 self.stats.nodes_eliminated += 1;
1025 self.stats.algebraic_simplifications += 1;
1026 return right;
1027 }
1028 if Self::is_false(&right) {
1029 self.stats.nodes_eliminated += 1;
1030 self.stats.algebraic_simplifications += 1;
1031 return right;
1032 }
1033 if Self::is_false(&left) {
1034 self.stats.nodes_eliminated += 1;
1035 self.stats.algebraic_simplifications += 1;
1036 return left;
1037 }
1038 }
1039 BinaryOperator::Or => {
1040 if left == right {
1042 self.stats.nodes_eliminated += 1;
1043 self.stats.algebraic_simplifications += 1;
1044 return left;
1045 }
1046 if Self::is_false(&right) {
1047 self.stats.nodes_eliminated += 1;
1048 self.stats.algebraic_simplifications += 1;
1049 return left;
1050 }
1051 if Self::is_false(&left) {
1052 self.stats.nodes_eliminated += 1;
1053 self.stats.algebraic_simplifications += 1;
1054 return right;
1055 }
1056 if Self::is_true(&right) {
1057 self.stats.nodes_eliminated += 1;
1058 self.stats.algebraic_simplifications += 1;
1059 return right;
1060 }
1061 if Self::is_true(&left) {
1062 self.stats.nodes_eliminated += 1;
1063 self.stats.algebraic_simplifications += 1;
1064 return left;
1065 }
1066 }
1067 BinaryOperator::Xor => {
1068 if Self::is_false(&right) {
1069 self.stats.nodes_eliminated += 1;
1070 self.stats.algebraic_simplifications += 1;
1071 return left;
1072 }
1073 if Self::is_false(&left) {
1074 self.stats.nodes_eliminated += 1;
1075 self.stats.algebraic_simplifications += 1;
1076 return right;
1077 }
1078 }
1079 }
1080
1081 CircuitNode::BinaryOp {
1082 op,
1083 left: Box::new(left),
1084 right: Box::new(right),
1085 }
1086 }
1087
1088 CircuitNode::UnaryOp { op, operand } => {
1089 let operand = self.dce_simplify(*operand);
1090
1091 if let CircuitNode::Constant(val) = &operand {
1093 if let Some(result) = self.fold_unary_constant(op, val) {
1094 self.stats.nodes_eliminated += 1;
1095 self.stats.constants_folded += 1;
1096 return CircuitNode::Constant(result);
1097 }
1098 }
1099
1100 if op == UnaryOperator::Not {
1102 if let CircuitNode::UnaryOp {
1103 op: UnaryOperator::Not,
1104 operand: inner,
1105 } = operand
1106 {
1107 self.stats.nodes_eliminated += 2;
1108 self.stats.algebraic_simplifications += 1;
1109 return *inner;
1110 }
1111 }
1112
1113 if op == UnaryOperator::Neg {
1115 if let CircuitNode::UnaryOp {
1116 op: UnaryOperator::Neg,
1117 operand: inner,
1118 } = operand
1119 {
1120 self.stats.nodes_eliminated += 2;
1121 self.stats.algebraic_simplifications += 1;
1122 return *inner;
1123 }
1124 }
1125
1126 CircuitNode::UnaryOp {
1127 op,
1128 operand: Box::new(operand),
1129 }
1130 }
1131
1132 CircuitNode::Compare { op, left, right } => {
1133 let left = self.dce_simplify(*left);
1134 let right = self.dce_simplify(*right);
1135
1136 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
1138 if let Some(result) = self.fold_comparison(op, l, r) {
1139 self.stats.nodes_eliminated += 1;
1140 self.stats.constants_folded += 1;
1141 return CircuitNode::Constant(CircuitValue::Bool(result));
1142 }
1143 }
1144
1145 CircuitNode::Compare {
1146 op,
1147 left: Box::new(left),
1148 right: Box::new(right),
1149 }
1150 }
1151
1152 CircuitNode::NaryOp { op, operands } => {
1153 let new_operands: Vec<CircuitNode> =
1154 operands.into_iter().map(|o| self.dce_simplify(o)).collect();
1155 CircuitNode::NaryOp {
1156 op,
1157 operands: new_operands,
1158 }
1159 }
1160
1161 other => other,
1162 }
1163 }
1164
1165 fn zero_like(&self, node: &CircuitNode) -> CircuitNode {
1167 match node {
1168 CircuitNode::Constant(CircuitValue::U8(_)) => {
1169 CircuitNode::Constant(CircuitValue::U8(0))
1170 }
1171 CircuitNode::Constant(CircuitValue::U16(_)) => {
1172 CircuitNode::Constant(CircuitValue::U16(0))
1173 }
1174 CircuitNode::Constant(CircuitValue::U32(_)) => {
1175 CircuitNode::Constant(CircuitValue::U32(0))
1176 }
1177 CircuitNode::Constant(CircuitValue::U64(_)) => {
1178 CircuitNode::Constant(CircuitValue::U64(0))
1179 }
1180 _ => CircuitNode::Constant(CircuitValue::U8(0)),
1182 }
1183 }
1184
1185 fn fold_comparison(
1187 &self,
1188 op: CompareOperator,
1189 left: &CircuitValue,
1190 right: &CircuitValue,
1191 ) -> Option<bool> {
1192 match (left, right) {
1193 (CircuitValue::U8(l), CircuitValue::U8(r)) => Some(self.compare_values(op, *l, *r)),
1194 (CircuitValue::U16(l), CircuitValue::U16(r)) => Some(self.compare_values(op, *l, *r)),
1195 (CircuitValue::U32(l), CircuitValue::U32(r)) => Some(self.compare_values(op, *l, *r)),
1196 (CircuitValue::U64(l), CircuitValue::U64(r)) => Some(self.compare_values(op, *l, *r)),
1197 (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
1198 CompareOperator::Eq => Some(l == r),
1199 CompareOperator::Ne => Some(l != r),
1200 _ => None,
1201 },
1202 _ => None,
1203 }
1204 }
1205
1206 fn compare_values<T: PartialOrd + PartialEq>(&self, op: CompareOperator, l: T, r: T) -> bool {
1208 match op {
1209 CompareOperator::Eq => l == r,
1210 CompareOperator::Ne => l != r,
1211 CompareOperator::Lt => l < r,
1212 CompareOperator::Le => l <= r,
1213 CompareOperator::Gt => l > r,
1214 CompareOperator::Ge => l >= r,
1215 }
1216 }
1217
1218 pub fn collect_live_variables(&self, node: &CircuitNode) -> HashSet<String> {
1220 let mut live = HashSet::new();
1221 self.mark_live_nodes(node, &mut live);
1222 live
1223 }
1224
1225 #[allow(clippy::only_used_in_recursion)]
1227 fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
1228 match node {
1229 CircuitNode::Load(name) => {
1230 live_nodes.insert(name.clone());
1231 }
1232
1233 CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {}
1234
1235 CircuitNode::BinaryOp { left, right, .. } => {
1236 self.mark_live_nodes(left, live_nodes);
1237 self.mark_live_nodes(right, live_nodes);
1238 }
1239
1240 CircuitNode::UnaryOp { operand, .. } => {
1241 self.mark_live_nodes(operand, live_nodes);
1242 }
1243
1244 CircuitNode::Compare { left, right, .. } => {
1245 self.mark_live_nodes(left, live_nodes);
1246 self.mark_live_nodes(right, live_nodes);
1247 }
1248 CircuitNode::NaryOp { operands, .. } => {
1249 for operand in operands {
1250 self.mark_live_nodes(operand, live_nodes);
1251 }
1252 }
1253 }
1254 }
1255
1256 fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
1260 let mut graph = DependencyGraph::new();
1261 let mut node_id_map = HashMap::new();
1262 let mut cse_map = HashMap::new();
1263 let mut next_id = 0;
1264
1265 self.build_dependency_graph(
1267 &circuit.root,
1268 &mut graph,
1269 &mut node_id_map,
1270 &mut cse_map,
1271 &mut next_id,
1272 );
1273
1274 graph.node_count = next_id;
1275
1276 graph.parallel_groups = self.identify_parallel_groups(&graph);
1278
1279 graph.critical_path = self.find_critical_path(&graph);
1281
1282 Ok(graph)
1283 }
1284
1285 #[allow(clippy::only_used_in_recursion)]
1287 fn build_dependency_graph(
1288 &self,
1289 node: &CircuitNode,
1290 graph: &mut DependencyGraph,
1291 node_id_map: &mut HashMap<String, NodeId>,
1292 cse_map: &mut HashMap<u64, NodeId>,
1293 next_id: &mut usize,
1294 ) -> NodeId {
1295 let node_hash = Self::structural_hash(node);
1297 if let Some(&existing_id) = cse_map.get(&node_hash) {
1298 return existing_id;
1299 }
1300
1301 let current_id = NodeId(*next_id);
1302 *next_id += 1;
1303 cse_map.insert(node_hash, current_id);
1304
1305 match node {
1306 CircuitNode::Load(name) => {
1307 node_id_map.insert(name.clone(), current_id);
1308 graph.dependencies.insert(current_id, Vec::new());
1309 current_id
1310 }
1311
1312 CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {
1313 graph.dependencies.insert(current_id, Vec::new());
1314 current_id
1315 }
1316
1317 CircuitNode::BinaryOp { left, right, .. } => {
1318 let left_id =
1319 self.build_dependency_graph(left, graph, node_id_map, cse_map, next_id);
1320 let right_id =
1321 self.build_dependency_graph(right, graph, node_id_map, cse_map, next_id);
1322 graph
1323 .dependencies
1324 .insert(current_id, vec![left_id, right_id]);
1325 current_id
1326 }
1327
1328 CircuitNode::UnaryOp { operand, .. } => {
1329 let operand_id =
1330 self.build_dependency_graph(operand, graph, node_id_map, cse_map, next_id);
1331 graph.dependencies.insert(current_id, vec![operand_id]);
1332 current_id
1333 }
1334
1335 CircuitNode::Compare { left, right, .. } => {
1336 let left_id =
1337 self.build_dependency_graph(left, graph, node_id_map, cse_map, next_id);
1338 let right_id =
1339 self.build_dependency_graph(right, graph, node_id_map, cse_map, next_id);
1340 graph
1341 .dependencies
1342 .insert(current_id, vec![left_id, right_id]);
1343 current_id
1344 }
1345
1346 CircuitNode::NaryOp { operands, .. } => {
1347 let dep_ids: Vec<NodeId> = operands
1348 .iter()
1349 .map(|o| self.build_dependency_graph(o, graph, node_id_map, cse_map, next_id))
1350 .collect();
1351 graph.dependencies.insert(current_id, dep_ids);
1352 current_id
1353 }
1354 }
1355 }
1356
1357 fn structural_hash(node: &CircuitNode) -> u64 {
1359 use std::collections::hash_map::DefaultHasher;
1360 use std::hash::Hasher;
1361
1362 let mut hasher = DefaultHasher::new();
1363 Self::hash_node(node, &mut hasher);
1364 hasher.finish()
1365 }
1366
1367 fn hash_node(node: &CircuitNode, hasher: &mut impl std::hash::Hasher) {
1368 use std::hash::Hash;
1369 match node {
1370 CircuitNode::Load(name) => {
1371 0u8.hash(hasher);
1372 name.hash(hasher);
1373 }
1374 CircuitNode::Constant(value) => {
1375 1u8.hash(hasher);
1376 match value {
1377 crate::compute::circuit::CircuitValue::Bool(v) => {
1378 0u8.hash(hasher);
1379 v.hash(hasher);
1380 }
1381 crate::compute::circuit::CircuitValue::U8(v) => {
1382 1u8.hash(hasher);
1383 v.hash(hasher);
1384 }
1385 crate::compute::circuit::CircuitValue::U16(v) => {
1386 2u8.hash(hasher);
1387 v.hash(hasher);
1388 }
1389 crate::compute::circuit::CircuitValue::U32(v) => {
1390 3u8.hash(hasher);
1391 v.hash(hasher);
1392 }
1393 crate::compute::circuit::CircuitValue::U64(v) => {
1394 4u8.hash(hasher);
1395 v.hash(hasher);
1396 }
1397 }
1398 }
1399 CircuitNode::EncryptedConstant {
1400 data,
1401 original_type,
1402 } => {
1403 2u8.hash(hasher);
1404 data.hash(hasher);
1405 match original_type {
1406 crate::compute::circuit::ConstantType::Integer => 0u8.hash(hasher),
1407 crate::compute::circuit::ConstantType::Boolean => 1u8.hash(hasher),
1408 crate::compute::circuit::ConstantType::Float => 2u8.hash(hasher),
1409 crate::compute::circuit::ConstantType::Bytes => 3u8.hash(hasher),
1410 }
1411 }
1412 CircuitNode::BinaryOp { op, left, right } => {
1413 3u8.hash(hasher);
1414 Self::hash_binary_op(*op, hasher);
1415 Self::hash_node(left, hasher);
1416 Self::hash_node(right, hasher);
1417 }
1418 CircuitNode::UnaryOp { op, operand } => {
1419 4u8.hash(hasher);
1420 match op {
1421 UnaryOperator::Not => 0u8.hash(hasher),
1422 UnaryOperator::Neg => 1u8.hash(hasher),
1423 }
1424 Self::hash_node(operand, hasher);
1425 }
1426 CircuitNode::Compare { op, left, right } => {
1427 5u8.hash(hasher);
1428 match op {
1429 CompareOperator::Eq => 0u8.hash(hasher),
1430 CompareOperator::Ne => 1u8.hash(hasher),
1431 CompareOperator::Lt => 2u8.hash(hasher),
1432 CompareOperator::Le => 3u8.hash(hasher),
1433 CompareOperator::Gt => 4u8.hash(hasher),
1434 CompareOperator::Ge => 5u8.hash(hasher),
1435 }
1436 Self::hash_node(left, hasher);
1437 Self::hash_node(right, hasher);
1438 }
1439 CircuitNode::NaryOp { op, operands } => {
1440 6u8.hash(hasher);
1441 Self::hash_binary_op(*op, hasher);
1442 operands.len().hash(hasher);
1443 for o in operands {
1444 Self::hash_node(o, hasher);
1445 }
1446 }
1447 }
1448 }
1449
1450 fn hash_binary_op(op: BinaryOperator, hasher: &mut impl std::hash::Hasher) {
1451 use std::hash::Hash;
1452 match op {
1453 BinaryOperator::Add => 0u8.hash(hasher),
1454 BinaryOperator::Sub => 1u8.hash(hasher),
1455 BinaryOperator::Mul => 2u8.hash(hasher),
1456 BinaryOperator::And => 3u8.hash(hasher),
1457 BinaryOperator::Or => 4u8.hash(hasher),
1458 BinaryOperator::Xor => 5u8.hash(hasher),
1459 }
1460 }
1461
1462 fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
1464 let mut levels: HashMap<NodeId, usize> = HashMap::new();
1465 let mut queue = VecDeque::new();
1466
1467 for (node_id, deps) in &graph.dependencies {
1469 if deps.is_empty() {
1470 levels.insert(*node_id, 0);
1471 queue.push_back(*node_id);
1472 }
1473 }
1474
1475 while let Some(node_id) = queue.pop_front() {
1477 let current_level = levels[&node_id];
1478
1479 for (dependent_id, deps) in &graph.dependencies {
1481 if deps.contains(&node_id) {
1482 let max_dep_level = deps
1484 .iter()
1485 .filter_map(|dep_id| levels.get(dep_id))
1486 .max()
1487 .copied()
1488 .unwrap_or(0);
1489
1490 let dependent_level = max_dep_level + 1;
1491
1492 if !levels.contains_key(dependent_id) {
1493 levels.insert(*dependent_id, dependent_level);
1494 queue.push_back(*dependent_id);
1495 }
1496 }
1497 }
1498 }
1499
1500 let max_level = levels.values().max().copied().unwrap_or(0);
1502 let mut parallel_groups = vec![Vec::new(); max_level + 1];
1503
1504 for (node_id, level) in levels {
1505 parallel_groups[level].push(node_id);
1506 }
1507
1508 for group in &mut parallel_groups {
1510 group.sort();
1511 }
1512
1513 parallel_groups
1514 }
1515
1516 fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
1518 let mut memo = HashMap::new();
1519
1520 for &node_id in graph.dependencies.keys() {
1522 self.longest_path_to(node_id, graph, &mut memo);
1523 }
1524
1525 let max_node = graph
1527 .dependencies
1528 .keys()
1529 .max_by_key(|&&id| memo.get(&id).copied().unwrap_or(0));
1530
1531 let Some(&end_node) = max_node else {
1532 return Vec::new();
1533 };
1534
1535 let mut path = Vec::new();
1537 let mut current = end_node;
1538 path.push(current);
1539
1540 loop {
1541 let deps = match graph.dependencies.get(¤t) {
1542 Some(d) if !d.is_empty() => d,
1543 _ => break,
1544 };
1545 let next = deps
1546 .iter()
1547 .max_by_key(|&&dep_id| memo.get(&dep_id).copied().unwrap_or(0))
1548 .copied();
1549 match next {
1550 Some(next_id) if next_id != current => {
1551 path.push(next_id);
1552 current = next_id;
1553 }
1554 _ => break,
1555 }
1556 }
1557
1558 path.reverse();
1559 path
1560 }
1561
1562 fn longest_path_to(
1564 &self,
1565 node_id: NodeId,
1566 graph: &DependencyGraph,
1567 memo: &mut HashMap<NodeId, usize>,
1568 ) -> usize {
1569 if let Some(&cached) = memo.get(&node_id) {
1570 return cached;
1571 }
1572
1573 let deps = graph
1574 .dependencies
1575 .get(&node_id)
1576 .map(|v| v.as_slice())
1577 .unwrap_or(&[]);
1578
1579 let result = if deps.is_empty() {
1580 1
1581 } else {
1582 let max_dep = deps
1583 .iter()
1584 .map(|&dep_id| self.longest_path_to(dep_id, graph, memo))
1585 .max()
1586 .unwrap_or(0);
1587 max_dep + 1
1588 };
1589
1590 memo.insert(node_id, result);
1591 result
1592 }
1593}
1594
1595impl Default for CircuitOptimizer {
1596 fn default() -> Self {
1597 Self::new()
1598 }
1599}
1600
// Unit tests for this module are kept in `optimizer_tests.rs`, loaded via `#[path]`, and
// compiled only in test builds.
#[cfg(test)]
#[path = "optimizer_tests.rs"]
mod tests;