1use crate::compute::circuit::{
14 BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
15 UnaryOperator,
16};
17use crate::error::{AmateRSError, ErrorContext, Result};
18use std::collections::{HashMap, HashSet, VecDeque};
19
/// Counters and before/after metrics recorded by [`CircuitOptimizer`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Gate count of the circuit handed to `CircuitOptimizer::optimize`.
    pub original_gate_count: usize,

    /// Gate count of the rebuilt circuit returned by `CircuitOptimizer::optimize`.
    pub optimized_gate_count: usize,

    /// Estimated bootstraps in the input circuit (the optimizer's estimate charges one per
    /// multiplication and one per comparison).
    pub original_bootstrap_count: usize,

    /// Estimated bootstraps in the optimized circuit, using the same estimate.
    pub optimized_bootstrap_count: usize,

    /// NOTE(review): no optimizer pass in the visible code increments this counter; it only
    /// contributes to `total_stats`. Confirm whether it is still needed.
    pub dead_code_removed: usize,

    /// Nodes removed by the dead-code-elimination pass (constant folds and identity
    /// rewrites count 1; a cancelled double `NOT`/`NEG` pair counts 2).
    pub nodes_eliminated: usize,

    /// Identity rewrites applied by dead-code elimination (e.g. `x - x`, `x ^ x`, `x + 0`).
    pub algebraic_simplifications: usize,

    /// Constant sub-expressions evaluated at optimization time.
    pub constants_folded: usize,

    /// Gates removed by gate fusion (2 per cancelled `NOT(NOT x)` pair) and by the
    /// algebraic identities applied during constant folding.
    pub gates_fused: usize,

    /// Depth of the input circuit.
    pub original_depth: usize,

    /// Depth of the optimized circuit.
    pub optimized_depth: usize,
}
56
57impl OptimizationStats {
58 pub fn gate_reduction_percent(&self) -> f64 {
60 if self.original_gate_count == 0 {
61 return 0.0;
62 }
63 let reduction = self
64 .original_gate_count
65 .saturating_sub(self.optimized_gate_count);
66 (reduction as f64 / self.original_gate_count as f64) * 100.0
67 }
68
69 pub fn bootstrap_reduction_percent(&self) -> f64 {
71 if self.original_bootstrap_count == 0 {
72 return 0.0;
73 }
74 let reduction = self
75 .original_bootstrap_count
76 .saturating_sub(self.optimized_bootstrap_count);
77 (reduction as f64 / self.original_bootstrap_count as f64) * 100.0
78 }
79
80 pub fn total_stats(&self) -> (usize, usize, usize) {
82 (
83 self.nodes_eliminated + self.dead_code_removed,
84 self.algebraic_simplifications + self.gates_fused,
85 self.constants_folded,
86 )
87 }
88}
89
/// Operand dependencies of an optimized circuit plus the scheduling data derived from them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyGraph {
    /// Maps each node to the ids of its direct operands; loads and constants map to an
    /// empty list.
    pub dependencies: HashMap<NodeId, Vec<NodeId>>,

    /// Nodes grouped by scheduling level, as produced by the optimizer's parallelism
    /// analysis; each group is sorted by id.
    pub parallel_groups: Vec<Vec<NodeId>>,

    /// Longest dependency chain found, ordered from a node without dependencies up to the
    /// last node of the chain.
    pub critical_path: Vec<NodeId>,

    /// Number of node ids assigned while building the graph.
    pub node_count: usize,
}
105
106impl DependencyGraph {
107 pub fn new() -> Self {
109 Self {
110 dependencies: HashMap::new(),
111 parallel_groups: Vec::new(),
112 critical_path: Vec::new(),
113 node_count: 0,
114 }
115 }
116
117 pub fn max_parallelism(&self) -> usize {
119 self.parallel_groups
120 .iter()
121 .map(|g| g.len())
122 .max()
123 .unwrap_or(0)
124 }
125
126 pub fn avg_parallelism(&self) -> f64 {
128 if self.parallel_groups.is_empty() {
129 return 0.0;
130 }
131 let total: usize = self.parallel_groups.iter().map(|g| g.len()).sum();
132 total as f64 / self.parallel_groups.len() as f64
133 }
134}
135
136impl Default for DependencyGraph {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
/// Identifier of a node in a [`DependencyGraph`].
///
/// The optimizer hands out ids sequentially from 0 in pre-order (a node receives its id
/// before its operands do). `Ord` is derived so parallel groups can be sorted by id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
145
/// Rewrites encrypted-computation circuits to reduce gate and bootstrap counts.
///
/// Each `enable_*` flag switches one stage of `optimize` on or off.
#[derive(Debug, Clone)]
pub struct CircuitOptimizer {
    /// Run constant folding (together with the algebraic identities applied during it).
    pub enable_constant_folding: bool,

    /// Run the dead-code-elimination / expression-simplification pass.
    pub enable_dead_code_elimination: bool,

    /// Run the bootstrap-minimization pass (currently rebuilds the tree without reordering).
    pub enable_bootstrap_minimization: bool,

    /// Run gate fusion (cancels `NOT(NOT x)` pairs).
    pub enable_gate_fusion: bool,

    /// Build a [`DependencyGraph`] for the optimized circuit after the passes have run.
    pub enable_parallelization_analysis: bool,

    /// Statistics recorded by `optimize`.
    stats: OptimizationStats,

    /// Last dependency graph produced by the parallelism analysis.
    dependency_graph: DependencyGraph,
}
170
171impl CircuitOptimizer {
172 pub fn new() -> Self {
174 Self {
175 enable_constant_folding: true,
176 enable_dead_code_elimination: true,
177 enable_bootstrap_minimization: true,
178 enable_gate_fusion: true,
179 enable_parallelization_analysis: true,
180 stats: OptimizationStats::default(),
181 dependency_graph: DependencyGraph::new(),
182 }
183 }
184
185 pub fn disabled() -> Self {
187 Self {
188 enable_constant_folding: false,
189 enable_dead_code_elimination: false,
190 enable_bootstrap_minimization: false,
191 enable_gate_fusion: false,
192 enable_parallelization_analysis: false,
193 stats: OptimizationStats::default(),
194 dependency_graph: DependencyGraph::new(),
195 }
196 }
197
198 pub fn stats(&self) -> &OptimizationStats {
200 &self.stats
201 }
202
203 pub fn dependency_graph(&self) -> &DependencyGraph {
205 &self.dependency_graph
206 }
207
208 pub fn total_stats(&self) -> (usize, usize, usize) {
210 self.stats.total_stats()
211 }
212
213 pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
215 self.stats.original_gate_count = circuit.gate_count;
217 self.stats.original_depth = circuit.depth;
218 self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
219
220 let mut optimized_root = circuit.root.clone();
221
222 if self.enable_constant_folding {
224 optimized_root = self.constant_folding_pass(optimized_root);
225 }
226
227 if self.enable_gate_fusion {
228 optimized_root = self.gate_fusion_pass(optimized_root);
229 }
230
231 if self.enable_bootstrap_minimization {
232 optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
233 }
234
235 if self.enable_dead_code_elimination {
236 optimized_root = self.dead_code_elimination_pass(optimized_root);
237 }
238
239 let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
241
242 self.stats.optimized_gate_count = optimized_circuit.gate_count;
244 self.stats.optimized_depth = optimized_circuit.depth;
245 self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
246
247 if self.enable_parallelization_analysis {
249 self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
250 }
251
252 Ok(optimized_circuit)
253 }
254
255 #[allow(clippy::only_used_in_recursion)]
263 fn count_bootstraps(&self, node: &CircuitNode) -> usize {
264 match node {
265 CircuitNode::Load(_)
266 | CircuitNode::Constant(_)
267 | CircuitNode::EncryptedConstant { .. } => 0,
268
269 CircuitNode::BinaryOp { op, left, right } => {
270 let left_bootstraps = self.count_bootstraps(left);
271 let right_bootstraps = self.count_bootstraps(right);
272
273 let op_bootstrap = match op {
275 BinaryOperator::Mul => 1,
276 _ => 0,
277 };
278
279 left_bootstraps + right_bootstraps + op_bootstrap
280 }
281
282 CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
283
284 CircuitNode::Compare { left, right, .. } => {
285 let left_bootstraps = self.count_bootstraps(left);
286 let right_bootstraps = self.count_bootstraps(right);
287
288 left_bootstraps + right_bootstraps + 1
290 }
291 }
292 }
293
294 fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
298 match node {
299 CircuitNode::BinaryOp { op, left, right } => {
300 let left = self.constant_folding_pass(*left);
301 let right = self.constant_folding_pass(*right);
302
303 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
305 if let Some(result) = self.fold_binary_constants(op, l, r) {
306 self.stats.constants_folded += 1;
307 return CircuitNode::Constant(result);
308 }
309 }
310
311 if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
313 return simplified;
314 }
315
316 CircuitNode::BinaryOp {
317 op,
318 left: Box::new(left),
319 right: Box::new(right),
320 }
321 }
322
323 CircuitNode::UnaryOp { op, operand } => {
324 let operand = self.constant_folding_pass(*operand);
325
326 if let CircuitNode::Constant(val) = &operand {
327 if let Some(result) = self.fold_unary_constant(op, val) {
328 self.stats.constants_folded += 1;
329 return CircuitNode::Constant(result);
330 }
331 }
332
333 CircuitNode::UnaryOp {
334 op,
335 operand: Box::new(operand),
336 }
337 }
338
339 CircuitNode::Compare { op, left, right } => {
340 let left = self.constant_folding_pass(*left);
341 let right = self.constant_folding_pass(*right);
342
343 CircuitNode::Compare {
344 op,
345 left: Box::new(left),
346 right: Box::new(right),
347 }
348 }
349
350 other => other,
351 }
352 }
353
354 fn fold_binary_constants(
356 &self,
357 op: BinaryOperator,
358 left: &CircuitValue,
359 right: &CircuitValue,
360 ) -> Option<CircuitValue> {
361 match (left, right) {
362 (CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
363 BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
364 BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
365 BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
366 _ => None,
367 },
368 (CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
369 BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
370 BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
371 BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
372 _ => None,
373 },
374 (CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
375 BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
376 BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
377 BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
378 _ => None,
379 },
380 (CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
381 BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
382 BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
383 BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
384 _ => None,
385 },
386 (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
387 BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
388 BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
389 BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
390 _ => None,
391 },
392 _ => None,
393 }
394 }
395
396 fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
398 match (op, value) {
399 (UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
400 _ => None,
401 }
402 }
403
404 fn apply_algebraic_identities(
407 &mut self,
408 op: BinaryOperator,
409 left: &CircuitNode,
410 right: &CircuitNode,
411 ) -> Option<CircuitNode> {
412 match op {
413 BinaryOperator::Add => {
414 if Self::is_zero(right) {
416 self.stats.gates_fused += 1;
417 return Some(left.clone());
418 }
419 if Self::is_zero(left) {
421 self.stats.gates_fused += 1;
422 return Some(right.clone());
423 }
424 }
425
426 BinaryOperator::Sub => {
427 if Self::is_zero(right) {
429 self.stats.gates_fused += 1;
430 return Some(left.clone());
431 }
432 }
433
434 BinaryOperator::Mul => {
435 if Self::is_zero(right) {
437 self.stats.gates_fused += 1;
438 return Some(right.clone());
439 }
440 if Self::is_zero(left) {
441 self.stats.gates_fused += 1;
442 return Some(left.clone());
443 }
444
445 if Self::is_one(right) {
447 self.stats.gates_fused += 1;
448 return Some(left.clone());
449 }
450 if Self::is_one(left) {
452 self.stats.gates_fused += 1;
453 return Some(right.clone());
454 }
455 }
456
457 BinaryOperator::And => {
458 if Self::is_true(right) {
460 self.stats.gates_fused += 1;
461 return Some(left.clone());
462 }
463 if Self::is_true(left) {
464 self.stats.gates_fused += 1;
465 return Some(right.clone());
466 }
467
468 if Self::is_false(right) {
470 self.stats.gates_fused += 1;
471 return Some(right.clone());
472 }
473 if Self::is_false(left) {
474 self.stats.gates_fused += 1;
475 return Some(left.clone());
476 }
477 }
478
479 BinaryOperator::Or => {
480 if Self::is_false(right) {
482 self.stats.gates_fused += 1;
483 return Some(left.clone());
484 }
485 if Self::is_false(left) {
486 self.stats.gates_fused += 1;
487 return Some(right.clone());
488 }
489
490 if Self::is_true(right) {
492 self.stats.gates_fused += 1;
493 return Some(right.clone());
494 }
495 if Self::is_true(left) {
496 self.stats.gates_fused += 1;
497 return Some(left.clone());
498 }
499 }
500
501 BinaryOperator::Xor => {
502 if Self::is_false(right) {
504 self.stats.gates_fused += 1;
505 return Some(left.clone());
506 }
507 if Self::is_false(left) {
508 self.stats.gates_fused += 1;
509 return Some(right.clone());
510 }
511 }
512 }
513
514 None
515 }
516
517 fn is_zero(node: &CircuitNode) -> bool {
519 matches!(
520 node,
521 CircuitNode::Constant(CircuitValue::U8(0))
522 | CircuitNode::Constant(CircuitValue::U16(0))
523 | CircuitNode::Constant(CircuitValue::U32(0))
524 | CircuitNode::Constant(CircuitValue::U64(0))
525 )
526 }
527
528 fn is_one(node: &CircuitNode) -> bool {
530 matches!(
531 node,
532 CircuitNode::Constant(CircuitValue::U8(1))
533 | CircuitNode::Constant(CircuitValue::U16(1))
534 | CircuitNode::Constant(CircuitValue::U32(1))
535 | CircuitNode::Constant(CircuitValue::U64(1))
536 )
537 }
538
539 fn is_true(node: &CircuitNode) -> bool {
541 matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
542 }
543
544 fn is_false(node: &CircuitNode) -> bool {
546 matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
547 }
548
549 fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
555 match node {
556 CircuitNode::BinaryOp { op, left, right } => {
557 let left = self.gate_fusion_pass(*left);
558 let right = self.gate_fusion_pass(*right);
559
560 CircuitNode::BinaryOp {
561 op,
562 left: Box::new(left),
563 right: Box::new(right),
564 }
565 }
566
567 CircuitNode::UnaryOp {
568 op: UnaryOperator::Not,
569 operand,
570 } => {
571 let operand = self.gate_fusion_pass(*operand);
572
573 if let CircuitNode::UnaryOp {
575 op: UnaryOperator::Not,
576 operand: inner,
577 } = operand
578 {
579 self.stats.gates_fused += 2; return *inner;
581 }
582
583 CircuitNode::UnaryOp {
584 op: UnaryOperator::Not,
585 operand: Box::new(operand),
586 }
587 }
588
589 CircuitNode::UnaryOp { op, operand } => {
590 let operand = self.gate_fusion_pass(*operand);
591 CircuitNode::UnaryOp {
592 op,
593 operand: Box::new(operand),
594 }
595 }
596
597 CircuitNode::Compare { op, left, right } => {
598 let left = self.gate_fusion_pass(*left);
599 let right = self.gate_fusion_pass(*right);
600
601 CircuitNode::Compare {
602 op,
603 left: Box::new(left),
604 right: Box::new(right),
605 }
606 }
607
608 other => other,
609 }
610 }
611
612 fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
619 Ok(self.reorder_for_bootstrap_efficiency(node))
622 }
623
624 #[allow(clippy::only_used_in_recursion)]
626 fn reorder_for_bootstrap_efficiency(&self, node: CircuitNode) -> CircuitNode {
627 match node {
628 CircuitNode::BinaryOp { op, left, right } => {
629 let left = self.reorder_for_bootstrap_efficiency(*left);
630 let right = self.reorder_for_bootstrap_efficiency(*right);
631
632 CircuitNode::BinaryOp {
633 op,
634 left: Box::new(left),
635 right: Box::new(right),
636 }
637 }
638
639 CircuitNode::UnaryOp { op, operand } => {
640 let operand = self.reorder_for_bootstrap_efficiency(*operand);
641 CircuitNode::UnaryOp {
642 op,
643 operand: Box::new(operand),
644 }
645 }
646
647 CircuitNode::Compare { op, left, right } => {
648 let left = self.reorder_for_bootstrap_efficiency(*left);
649 let right = self.reorder_for_bootstrap_efficiency(*right);
650
651 CircuitNode::Compare {
652 op,
653 left: Box::new(left),
654 right: Box::new(right),
655 }
656 }
657
658 other => other,
659 }
660 }
661
662 fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
676 let mut current = node;
677 loop {
679 let simplified = self.dce_simplify(current.clone());
680 if simplified == current {
681 break;
682 }
683 current = simplified;
684 }
685 current
686 }
687
688 fn dce_simplify(&mut self, node: CircuitNode) -> CircuitNode {
690 match node {
691 CircuitNode::BinaryOp { op, left, right } => {
692 let left = self.dce_simplify(*left);
694 let right = self.dce_simplify(*right);
695
696 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
698 if let Some(result) = self.fold_binary_constants(op, l, r) {
699 self.stats.nodes_eliminated += 1;
700 self.stats.constants_folded += 1;
701 return CircuitNode::Constant(result);
702 }
703 }
704
705 if op == BinaryOperator::Sub && left == right {
707 self.stats.nodes_eliminated += 1;
708 self.stats.algebraic_simplifications += 1;
709 return self.zero_like(&left);
711 }
712
713 if op == BinaryOperator::Xor && left == right {
715 self.stats.nodes_eliminated += 1;
716 self.stats.algebraic_simplifications += 1;
717 return CircuitNode::Constant(CircuitValue::Bool(false));
718 }
719
720 match op {
722 BinaryOperator::Add => {
723 if Self::is_zero(&right) {
724 self.stats.nodes_eliminated += 1;
725 self.stats.algebraic_simplifications += 1;
726 return left;
727 }
728 if Self::is_zero(&left) {
729 self.stats.nodes_eliminated += 1;
730 self.stats.algebraic_simplifications += 1;
731 return right;
732 }
733 }
734 BinaryOperator::Sub => {
735 if Self::is_zero(&right) {
736 self.stats.nodes_eliminated += 1;
737 self.stats.algebraic_simplifications += 1;
738 return left;
739 }
740 }
741 BinaryOperator::Mul => {
742 if Self::is_zero(&right) {
743 self.stats.nodes_eliminated += 1;
744 self.stats.algebraic_simplifications += 1;
745 return right;
746 }
747 if Self::is_zero(&left) {
748 self.stats.nodes_eliminated += 1;
749 self.stats.algebraic_simplifications += 1;
750 return left;
751 }
752 if Self::is_one(&right) {
753 self.stats.nodes_eliminated += 1;
754 self.stats.algebraic_simplifications += 1;
755 return left;
756 }
757 if Self::is_one(&left) {
758 self.stats.nodes_eliminated += 1;
759 self.stats.algebraic_simplifications += 1;
760 return right;
761 }
762 }
763 BinaryOperator::And => {
764 if left == right {
766 self.stats.nodes_eliminated += 1;
767 self.stats.algebraic_simplifications += 1;
768 return left;
769 }
770 if Self::is_true(&right) {
771 self.stats.nodes_eliminated += 1;
772 self.stats.algebraic_simplifications += 1;
773 return left;
774 }
775 if Self::is_true(&left) {
776 self.stats.nodes_eliminated += 1;
777 self.stats.algebraic_simplifications += 1;
778 return right;
779 }
780 if Self::is_false(&right) {
781 self.stats.nodes_eliminated += 1;
782 self.stats.algebraic_simplifications += 1;
783 return right;
784 }
785 if Self::is_false(&left) {
786 self.stats.nodes_eliminated += 1;
787 self.stats.algebraic_simplifications += 1;
788 return left;
789 }
790 }
791 BinaryOperator::Or => {
792 if left == right {
794 self.stats.nodes_eliminated += 1;
795 self.stats.algebraic_simplifications += 1;
796 return left;
797 }
798 if Self::is_false(&right) {
799 self.stats.nodes_eliminated += 1;
800 self.stats.algebraic_simplifications += 1;
801 return left;
802 }
803 if Self::is_false(&left) {
804 self.stats.nodes_eliminated += 1;
805 self.stats.algebraic_simplifications += 1;
806 return right;
807 }
808 if Self::is_true(&right) {
809 self.stats.nodes_eliminated += 1;
810 self.stats.algebraic_simplifications += 1;
811 return right;
812 }
813 if Self::is_true(&left) {
814 self.stats.nodes_eliminated += 1;
815 self.stats.algebraic_simplifications += 1;
816 return left;
817 }
818 }
819 BinaryOperator::Xor => {
820 if Self::is_false(&right) {
821 self.stats.nodes_eliminated += 1;
822 self.stats.algebraic_simplifications += 1;
823 return left;
824 }
825 if Self::is_false(&left) {
826 self.stats.nodes_eliminated += 1;
827 self.stats.algebraic_simplifications += 1;
828 return right;
829 }
830 }
831 }
832
833 CircuitNode::BinaryOp {
834 op,
835 left: Box::new(left),
836 right: Box::new(right),
837 }
838 }
839
840 CircuitNode::UnaryOp { op, operand } => {
841 let operand = self.dce_simplify(*operand);
842
843 if let CircuitNode::Constant(val) = &operand {
845 if let Some(result) = self.fold_unary_constant(op, val) {
846 self.stats.nodes_eliminated += 1;
847 self.stats.constants_folded += 1;
848 return CircuitNode::Constant(result);
849 }
850 }
851
852 if op == UnaryOperator::Not {
854 if let CircuitNode::UnaryOp {
855 op: UnaryOperator::Not,
856 operand: inner,
857 } = operand
858 {
859 self.stats.nodes_eliminated += 2;
860 self.stats.algebraic_simplifications += 1;
861 return *inner;
862 }
863 }
864
865 if op == UnaryOperator::Neg {
867 if let CircuitNode::UnaryOp {
868 op: UnaryOperator::Neg,
869 operand: inner,
870 } = operand
871 {
872 self.stats.nodes_eliminated += 2;
873 self.stats.algebraic_simplifications += 1;
874 return *inner;
875 }
876 }
877
878 CircuitNode::UnaryOp {
879 op,
880 operand: Box::new(operand),
881 }
882 }
883
884 CircuitNode::Compare { op, left, right } => {
885 let left = self.dce_simplify(*left);
886 let right = self.dce_simplify(*right);
887
888 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
890 if let Some(result) = self.fold_comparison(op, l, r) {
891 self.stats.nodes_eliminated += 1;
892 self.stats.constants_folded += 1;
893 return CircuitNode::Constant(CircuitValue::Bool(result));
894 }
895 }
896
897 CircuitNode::Compare {
898 op,
899 left: Box::new(left),
900 right: Box::new(right),
901 }
902 }
903
904 other => other,
905 }
906 }
907
908 fn zero_like(&self, node: &CircuitNode) -> CircuitNode {
910 match node {
911 CircuitNode::Constant(CircuitValue::U8(_)) => {
912 CircuitNode::Constant(CircuitValue::U8(0))
913 }
914 CircuitNode::Constant(CircuitValue::U16(_)) => {
915 CircuitNode::Constant(CircuitValue::U16(0))
916 }
917 CircuitNode::Constant(CircuitValue::U32(_)) => {
918 CircuitNode::Constant(CircuitValue::U32(0))
919 }
920 CircuitNode::Constant(CircuitValue::U64(_)) => {
921 CircuitNode::Constant(CircuitValue::U64(0))
922 }
923 _ => CircuitNode::Constant(CircuitValue::U8(0)),
925 }
926 }
927
928 fn fold_comparison(
930 &self,
931 op: CompareOperator,
932 left: &CircuitValue,
933 right: &CircuitValue,
934 ) -> Option<bool> {
935 match (left, right) {
936 (CircuitValue::U8(l), CircuitValue::U8(r)) => Some(self.compare_values(op, *l, *r)),
937 (CircuitValue::U16(l), CircuitValue::U16(r)) => Some(self.compare_values(op, *l, *r)),
938 (CircuitValue::U32(l), CircuitValue::U32(r)) => Some(self.compare_values(op, *l, *r)),
939 (CircuitValue::U64(l), CircuitValue::U64(r)) => Some(self.compare_values(op, *l, *r)),
940 (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
941 CompareOperator::Eq => Some(l == r),
942 CompareOperator::Ne => Some(l != r),
943 _ => None,
944 },
945 _ => None,
946 }
947 }
948
949 fn compare_values<T: PartialOrd + PartialEq>(&self, op: CompareOperator, l: T, r: T) -> bool {
951 match op {
952 CompareOperator::Eq => l == r,
953 CompareOperator::Ne => l != r,
954 CompareOperator::Lt => l < r,
955 CompareOperator::Le => l <= r,
956 CompareOperator::Gt => l > r,
957 CompareOperator::Ge => l >= r,
958 }
959 }
960
961 pub fn collect_live_variables(&self, node: &CircuitNode) -> HashSet<String> {
963 let mut live = HashSet::new();
964 self.mark_live_nodes(node, &mut live);
965 live
966 }
967
968 #[allow(clippy::only_used_in_recursion)]
970 fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
971 match node {
972 CircuitNode::Load(name) => {
973 live_nodes.insert(name.clone());
974 }
975
976 CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {}
977
978 CircuitNode::BinaryOp { left, right, .. } => {
979 self.mark_live_nodes(left, live_nodes);
980 self.mark_live_nodes(right, live_nodes);
981 }
982
983 CircuitNode::UnaryOp { operand, .. } => {
984 self.mark_live_nodes(operand, live_nodes);
985 }
986
987 CircuitNode::Compare { left, right, .. } => {
988 self.mark_live_nodes(left, live_nodes);
989 self.mark_live_nodes(right, live_nodes);
990 }
991 }
992 }
993
994 fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
998 let mut graph = DependencyGraph::new();
999 let mut node_id_map = HashMap::new();
1000 let mut next_id = 0;
1001
1002 self.build_dependency_graph(&circuit.root, &mut graph, &mut node_id_map, &mut next_id);
1004
1005 graph.node_count = next_id;
1006
1007 graph.parallel_groups = self.identify_parallel_groups(&graph);
1009
1010 graph.critical_path = self.find_critical_path(&graph);
1012
1013 Ok(graph)
1014 }
1015
1016 #[allow(clippy::only_used_in_recursion)]
1018 fn build_dependency_graph(
1019 &self,
1020 node: &CircuitNode,
1021 graph: &mut DependencyGraph,
1022 node_id_map: &mut HashMap<String, NodeId>,
1023 next_id: &mut usize,
1024 ) -> NodeId {
1025 let current_id = NodeId(*next_id);
1026 *next_id += 1;
1027
1028 match node {
1029 CircuitNode::Load(name) => {
1030 node_id_map.insert(name.clone(), current_id);
1031 graph.dependencies.insert(current_id, Vec::new());
1032 current_id
1033 }
1034
1035 CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {
1036 graph.dependencies.insert(current_id, Vec::new());
1037 current_id
1038 }
1039
1040 CircuitNode::BinaryOp { left, right, .. } => {
1041 let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
1042 let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
1043
1044 graph
1045 .dependencies
1046 .insert(current_id, vec![left_id, right_id]);
1047 current_id
1048 }
1049
1050 CircuitNode::UnaryOp { operand, .. } => {
1051 let operand_id = self.build_dependency_graph(operand, graph, node_id_map, next_id);
1052
1053 graph.dependencies.insert(current_id, vec![operand_id]);
1054 current_id
1055 }
1056
1057 CircuitNode::Compare { left, right, .. } => {
1058 let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
1059 let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
1060
1061 graph
1062 .dependencies
1063 .insert(current_id, vec![left_id, right_id]);
1064 current_id
1065 }
1066 }
1067 }
1068
1069 fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
1071 let mut levels: HashMap<NodeId, usize> = HashMap::new();
1072 let mut queue = VecDeque::new();
1073
1074 for (node_id, deps) in &graph.dependencies {
1076 if deps.is_empty() {
1077 levels.insert(*node_id, 0);
1078 queue.push_back(*node_id);
1079 }
1080 }
1081
1082 while let Some(node_id) = queue.pop_front() {
1084 let current_level = levels[&node_id];
1085
1086 for (dependent_id, deps) in &graph.dependencies {
1088 if deps.contains(&node_id) {
1089 let max_dep_level = deps
1091 .iter()
1092 .filter_map(|dep_id| levels.get(dep_id))
1093 .max()
1094 .copied()
1095 .unwrap_or(0);
1096
1097 let dependent_level = max_dep_level + 1;
1098
1099 if !levels.contains_key(dependent_id) {
1100 levels.insert(*dependent_id, dependent_level);
1101 queue.push_back(*dependent_id);
1102 }
1103 }
1104 }
1105 }
1106
1107 let max_level = levels.values().max().copied().unwrap_or(0);
1109 let mut parallel_groups = vec![Vec::new(); max_level + 1];
1110
1111 for (node_id, level) in levels {
1112 parallel_groups[level].push(node_id);
1113 }
1114
1115 for group in &mut parallel_groups {
1117 group.sort();
1118 }
1119
1120 parallel_groups
1121 }
1122
1123 fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
1125 let mut max_path = Vec::new();
1127
1128 for node_id in graph.dependencies.keys() {
1129 let path = self.find_path_to_root(*node_id, graph);
1130 if path.len() > max_path.len() {
1131 max_path = path;
1132 }
1133 }
1134
1135 max_path
1136 }
1137
1138 #[allow(clippy::only_used_in_recursion)]
1140 fn find_path_to_root(&self, node_id: NodeId, graph: &DependencyGraph) -> Vec<NodeId> {
1141 let deps = graph
1142 .dependencies
1143 .get(&node_id)
1144 .map(|v| v.as_slice())
1145 .unwrap_or(&[]);
1146
1147 if deps.is_empty() {
1148 return vec![node_id];
1149 }
1150
1151 let mut longest_path = Vec::new();
1153 for dep_id in deps {
1154 let dep_path = self.find_path_to_root(*dep_id, graph);
1155 if dep_path.len() > longest_path.len() {
1156 longest_path = dep_path;
1157 }
1158 }
1159
1160 longest_path.push(node_id);
1161 longest_path
1162 }
1163}
1164
1165impl Default for CircuitOptimizer {
1166 fn default() -> Self {
1167 Self::new()
1168 }
1169}
1170
1171#[cfg(test)]
1172mod tests {
1173 use super::*;
1174 use crate::compute::circuit::CircuitBuilder;
1175
1176 #[test]
1179 fn test_constant_folding() -> Result<()> {
1180 let mut optimizer = CircuitOptimizer::new();
1181 let builder = CircuitBuilder::new();
1182
1183 let a = builder.constant(CircuitValue::U8(5));
1185 let b = builder.constant(CircuitValue::U8(3));
1186 let sum = builder.add(a, b);
1187
1188 let circuit = Circuit::new(sum, HashMap::new())?;
1189 let optimized = optimizer.optimize(circuit)?;
1190
1191 assert!(matches!(
1193 optimized.root,
1194 CircuitNode::Constant(CircuitValue::U8(8))
1195 ));
1196 assert!(optimizer.stats().constants_folded >= 1);
1197
1198 Ok(())
1199 }
1200
1201 #[test]
1202 fn test_constant_folding_sub() -> Result<()> {
1203 let mut optimizer = CircuitOptimizer::new();
1204 let builder = CircuitBuilder::new();
1205
1206 let a = builder.constant(CircuitValue::U16(100));
1207 let b = builder.constant(CircuitValue::U16(30));
1208 let result = builder.sub(a, b);
1209
1210 let circuit = Circuit::new(result, HashMap::new())?;
1211 let optimized = optimizer.optimize(circuit)?;
1212
1213 assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U16(70)));
1214 Ok(())
1215 }
1216
1217 #[test]
1218 fn test_constant_folding_mul() -> Result<()> {
1219 let mut optimizer = CircuitOptimizer::new();
1220 let builder = CircuitBuilder::new();
1221
1222 let a = builder.constant(CircuitValue::U32(7));
1223 let b = builder.constant(CircuitValue::U32(6));
1224 let result = builder.mul(a, b);
1225
1226 let circuit = Circuit::new(result, HashMap::new())?;
1227 let optimized = optimizer.optimize(circuit)?;
1228
1229 assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U32(42)));
1230 Ok(())
1231 }
1232
1233 #[test]
1234 fn test_constant_folding_bool_and() -> Result<()> {
1235 let mut optimizer = CircuitOptimizer::new();
1236 let builder = CircuitBuilder::new();
1237
1238 let t = builder.constant(CircuitValue::Bool(true));
1239 let f = builder.constant(CircuitValue::Bool(false));
1240 let result = builder.and(t, f);
1241
1242 let circuit = Circuit::new(result, HashMap::new())?;
1243 let optimized = optimizer.optimize(circuit)?;
1244
1245 assert_eq!(
1246 optimized.root,
1247 CircuitNode::Constant(CircuitValue::Bool(false))
1248 );
1249 Ok(())
1250 }
1251
1252 #[test]
1253 fn test_constant_folding_unary_not() -> Result<()> {
1254 let mut optimizer = CircuitOptimizer::new();
1255 let builder = CircuitBuilder::new();
1256
1257 let t = builder.constant(CircuitValue::Bool(true));
1258 let result = builder.not(t);
1259
1260 let circuit = Circuit::new(result, HashMap::new())?;
1261 let optimized = optimizer.optimize(circuit)?;
1262
1263 assert_eq!(
1264 optimized.root,
1265 CircuitNode::Constant(CircuitValue::Bool(false))
1266 );
1267 Ok(())
1268 }
1269
1270 #[test]
1273 fn test_algebraic_x_plus_zero() -> Result<()> {
1274 let mut optimizer = CircuitOptimizer::new();
1275 let mut builder = CircuitBuilder::new();
1276 builder.declare_variable("x", EncryptedType::U8);
1277
1278 let x = builder.load("x");
1279 let zero = builder.constant(CircuitValue::U8(0));
1280 let add_zero = builder.add(x, zero);
1281
1282 let circuit = Circuit::new(add_zero, builder.variable_types_clone())?;
1283 let optimized = optimizer.optimize(circuit)?;
1284
1285 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1286 Ok(())
1287 }
1288
1289 #[test]
1290 fn test_algebraic_zero_plus_x() -> Result<()> {
1291 let mut optimizer = CircuitOptimizer::new();
1292 let mut builder = CircuitBuilder::new();
1293 builder.declare_variable("x", EncryptedType::U8);
1294
1295 let x = builder.load("x");
1296 let zero = builder.constant(CircuitValue::U8(0));
1297 let result = builder.add(zero, x);
1298
1299 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1300 let optimized = optimizer.optimize(circuit)?;
1301
1302 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1303 Ok(())
1304 }
1305
1306 #[test]
1307 fn test_algebraic_x_mul_one() -> Result<()> {
1308 let mut optimizer = CircuitOptimizer::new();
1309 let mut builder = CircuitBuilder::new();
1310 builder.declare_variable("x", EncryptedType::U8);
1311
1312 let x = builder.load("x");
1313 let one = builder.constant(CircuitValue::U8(1));
1314 let result = builder.mul(x, one);
1315
1316 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1317 let optimized = optimizer.optimize(circuit)?;
1318
1319 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1320 Ok(())
1321 }
1322
1323 #[test]
1324 fn test_algebraic_one_mul_x() -> Result<()> {
1325 let mut optimizer = CircuitOptimizer::new();
1326 let mut builder = CircuitBuilder::new();
1327 builder.declare_variable("x", EncryptedType::U8);
1328
1329 let x = builder.load("x");
1330 let one = builder.constant(CircuitValue::U8(1));
1331 let result = builder.mul(one, x);
1332
1333 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1334 let optimized = optimizer.optimize(circuit)?;
1335
1336 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1337 Ok(())
1338 }
1339
1340 #[test]
1341 fn test_algebraic_x_mul_zero() -> Result<()> {
1342 let mut optimizer = CircuitOptimizer::new();
1343 let mut builder = CircuitBuilder::new();
1344 builder.declare_variable("x", EncryptedType::U8);
1345
1346 let x = builder.load("x");
1347 let zero = builder.constant(CircuitValue::U8(0));
1348 let result = builder.mul(x, zero);
1349
1350 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1351 let optimized = optimizer.optimize(circuit)?;
1352
1353 assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
1354 Ok(())
1355 }
1356
1357 #[test]
1358 fn test_algebraic_zero_mul_x() -> Result<()> {
1359 let mut optimizer = CircuitOptimizer::new();
1360 let mut builder = CircuitBuilder::new();
1361 builder.declare_variable("x", EncryptedType::U8);
1362
1363 let x = builder.load("x");
1364 let zero = builder.constant(CircuitValue::U8(0));
1365 let result = builder.mul(zero, x);
1366
1367 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1368 let optimized = optimizer.optimize(circuit)?;
1369
1370 assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
1371 Ok(())
1372 }
1373
1374 #[test]
1375 fn test_algebraic_x_sub_zero() -> Result<()> {
1376 let mut optimizer = CircuitOptimizer::new();
1377 let mut builder = CircuitBuilder::new();
1378 builder.declare_variable("x", EncryptedType::U8);
1379
1380 let x = builder.load("x");
1381 let zero = builder.constant(CircuitValue::U8(0));
1382 let result = builder.sub(x, zero);
1383
1384 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1385 let optimized = optimizer.optimize(circuit)?;
1386
1387 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1388 Ok(())
1389 }
1390
1391 #[test]
1392 fn test_algebraic_x_sub_x() -> Result<()> {
1393 let mut optimizer = CircuitOptimizer::new();
1394 let mut builder = CircuitBuilder::new();
1395 builder.declare_variable("x", EncryptedType::U8);
1396
1397 let x1 = builder.load("x");
1398 let x2 = builder.load("x");
1399 let result = builder.sub(x1, x2);
1400
1401 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1402 let optimized = optimizer.optimize(circuit)?;
1403
1404 assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
1406 assert!(optimizer.stats().algebraic_simplifications >= 1);
1407 Ok(())
1408 }
1409
1410 #[test]
1413 fn test_double_negation_elimination() -> Result<()> {
1414 let mut optimizer = CircuitOptimizer::new();
1415 let mut builder = CircuitBuilder::new();
1416 builder.declare_variable("x", EncryptedType::Bool);
1417
1418 let x = builder.load("x");
1419 let not_x = builder.not(x);
1420 let not_not_x = builder.not(not_x);
1421
1422 let circuit = Circuit::new(not_not_x, builder.variable_types_clone())?;
1423 let optimized = optimizer.optimize(circuit)?;
1424
1425 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1426 Ok(())
1427 }
1428
1429 #[test]
1430 fn test_quadruple_negation_elimination() -> Result<()> {
1431 let mut optimizer = CircuitOptimizer::new();
1432 let mut builder = CircuitBuilder::new();
1433 builder.declare_variable("x", EncryptedType::Bool);
1434
1435 let x = builder.load("x");
1436 let n1 = builder.not(x);
1437 let n2 = builder.not(n1);
1438 let n3 = builder.not(n2);
1439 let n4 = builder.not(n3);
1440
1441 let circuit = Circuit::new(n4, builder.variable_types_clone())?;
1442 let optimized = optimizer.optimize(circuit)?;
1443
1444 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1445 Ok(())
1446 }
1447
1448 #[test]
1451 fn test_nested_x_plus_0_times_1() -> Result<()> {
1452 let mut optimizer = CircuitOptimizer::new();
1453 let mut builder = CircuitBuilder::new();
1454 builder.declare_variable("x", EncryptedType::U8);
1455
1456 let x = builder.load("x");
1458 let zero = builder.constant(CircuitValue::U8(0));
1459 let one = builder.constant(CircuitValue::U8(1));
1460 let add_zero = builder.add(x, zero);
1461 let times_one = builder.mul(add_zero, one);
1462
1463 let circuit = Circuit::new(times_one, builder.variable_types_clone())?;
1464 let optimized = optimizer.optimize(circuit)?;
1465
1466 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1467 Ok(())
1468 }
1469
1470 #[test]
1471 fn test_nested_complex_optimization() -> Result<()> {
1472 let mut optimizer = CircuitOptimizer::new();
1473 let mut builder = CircuitBuilder::new();
1474 builder
1475 .declare_variable("a", EncryptedType::U8)
1476 .declare_variable("b", EncryptedType::U8);
1477
1478 let a = builder.load("a");
1480 let b = builder.load("b");
1481 let one = builder.constant(CircuitValue::U8(1));
1482 let zero = builder.constant(CircuitValue::U8(0));
1483 let five = builder.constant(CircuitValue::U8(5));
1484
1485 let a_times_1 = builder.mul(a, one);
1486 let b_times_0 = builder.mul(b, zero);
1487 let sum1 = builder.add(a_times_1, b_times_0);
1488 let result = builder.add(sum1, five);
1489
1490 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1491 let original_gates = circuit.gate_count;
1492
1493 let optimized = optimizer.optimize(circuit)?;
1494
1495 assert!(optimized.gate_count < original_gates);
1496 assert!(optimizer.stats().gate_reduction_percent() >= 30.0);
1497
1498 Ok(())
1499 }
1500
1501 #[test]
1504 fn test_noop_on_optimal_circuit() -> Result<()> {
1505 let mut optimizer = CircuitOptimizer::new();
1506 let mut builder = CircuitBuilder::new();
1507 builder
1508 .declare_variable("a", EncryptedType::U8)
1509 .declare_variable("b", EncryptedType::U8);
1510
1511 let a = builder.load("a");
1513 let b = builder.load("b");
1514 let result = builder.add(a, b);
1515
1516 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1517 let original_gates = circuit.gate_count;
1518
1519 let optimized = optimizer.optimize(circuit)?;
1520
1521 assert_eq!(optimized.gate_count, original_gates);
1522 assert_eq!(
1523 optimized.root,
1524 CircuitNode::BinaryOp {
1525 op: BinaryOperator::Add,
1526 left: Box::new(CircuitNode::Load("a".to_string())),
1527 right: Box::new(CircuitNode::Load("b".to_string())),
1528 }
1529 );
1530 Ok(())
1531 }
1532
1533 #[test]
1534 fn test_noop_single_load() -> Result<()> {
1535 let mut optimizer = CircuitOptimizer::new();
1536 let mut builder = CircuitBuilder::new();
1537 builder.declare_variable("x", EncryptedType::U8);
1538
1539 let x = builder.load("x");
1540 let circuit = Circuit::new(x, builder.variable_types_clone())?;
1541 let optimized = optimizer.optimize(circuit)?;
1542
1543 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1544 Ok(())
1545 }
1546
1547 #[test]
1550 fn test_stats_accuracy_constant_folding() -> Result<()> {
1551 let mut optimizer = CircuitOptimizer::new();
1552 let builder = CircuitBuilder::new();
1553
1554 let a = builder.constant(CircuitValue::U8(5));
1556 let b = builder.constant(CircuitValue::U8(3));
1557 let two = builder.constant(CircuitValue::U8(2));
1558 let sum = builder.add(a, b);
1559 let result = builder.mul(sum, two);
1560
1561 let circuit = Circuit::new(result, HashMap::new())?;
1562 let optimized = optimizer.optimize(circuit)?;
1563
1564 assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(16)));
1565 assert!(optimizer.stats().constants_folded >= 2);
1567 Ok(())
1568 }
1569
1570 #[test]
1571 fn test_stats_accuracy_algebraic() -> Result<()> {
1572 let mut optimizer = CircuitOptimizer::new();
1573 let mut builder = CircuitBuilder::new();
1574 builder.declare_variable("x", EncryptedType::U8);
1575
1576 let x1 = builder.load("x");
1578 let x2 = builder.load("x");
1579 let result = builder.sub(x1, x2);
1580
1581 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1582 let _optimized = optimizer.optimize(circuit)?;
1583
1584 let (total_eliminated, total_algebraic, _total_folds) = optimizer.total_stats();
1585 assert!(total_eliminated >= 1);
1586 assert!(total_algebraic >= 1);
1587 Ok(())
1588 }
1589
1590 #[test]
1591 fn test_optimization_stats() -> Result<()> {
1592 let mut optimizer = CircuitOptimizer::new();
1593 let builder = CircuitBuilder::new();
1594
1595 let a = builder.constant(CircuitValue::U8(5));
1596 let b = builder.constant(CircuitValue::U8(3));
1597 let zero = builder.constant(CircuitValue::U8(0));
1598
1599 let sum = builder.add(a, b);
1600 let add_zero = builder.add(sum, zero);
1601
1602 let circuit = Circuit::new(add_zero, HashMap::new())?;
1603 let original_gates = circuit.gate_count;
1604
1605 let optimized = optimizer.optimize(circuit)?;
1606 let optimized_gates = optimized.gate_count;
1607
1608 assert!(optimized_gates < original_gates);
1609 assert!(optimizer.stats().gate_reduction_percent() > 0.0);
1610
1611 Ok(())
1612 }
1613
1614 #[test]
1615 fn test_total_stats_method() -> Result<()> {
1616 let mut optimizer = CircuitOptimizer::new();
1617 let mut builder = CircuitBuilder::new();
1618 builder.declare_variable("x", EncryptedType::U8);
1619
1620 let x = builder.load("x");
1623 let zero = builder.constant(CircuitValue::U8(0));
1624 let one = builder.constant(CircuitValue::U8(1));
1625 let add_zero = builder.add(x, zero);
1626 let times_one = builder.mul(add_zero, one);
1627
1628 let circuit = Circuit::new(times_one, builder.variable_types_clone())?;
1629 let _optimized = optimizer.optimize(circuit)?;
1630
1631 let (eliminated, algebraic, _folds) = optimizer.total_stats();
1632 assert!(eliminated + algebraic >= 2);
1634 Ok(())
1635 }
1636
1637 #[test]
1640 fn test_bootstrap_counting() -> Result<()> {
1641 let optimizer = CircuitOptimizer::new();
1642 let mut builder = CircuitBuilder::new();
1643 builder
1644 .declare_variable("a", EncryptedType::U8)
1645 .declare_variable("b", EncryptedType::U8);
1646
1647 let a = builder.load("a");
1648 let b = builder.load("b");
1649 let mul = builder.mul(a, b);
1650
1651 let circuit = Circuit::new(mul, builder.variable_types_clone())?;
1652 let bootstrap_count = optimizer.count_bootstraps(&circuit.root);
1653
1654 assert_eq!(bootstrap_count, 1);
1655 Ok(())
1656 }
1657
1658 #[test]
1661 fn test_parallelization_analysis() -> Result<()> {
1662 let mut optimizer = CircuitOptimizer::new();
1663 let mut builder = CircuitBuilder::new();
1664 builder
1665 .declare_variable("a", EncryptedType::U8)
1666 .declare_variable("b", EncryptedType::U8)
1667 .declare_variable("c", EncryptedType::U8);
1668
1669 let a = builder.load("a");
1670 let b = builder.load("b");
1671 let c = builder.load("c");
1672 let sum1 = builder.add(a, b);
1673 let sum2 = builder.add(sum1, c);
1674
1675 let circuit = Circuit::new(sum2, builder.variable_types_clone())?;
1676 let _optimized = optimizer.optimize(circuit)?;
1677
1678 let graph = optimizer.dependency_graph();
1679 assert!(graph.node_count > 0);
1680 assert!(!graph.parallel_groups.is_empty());
1681
1682 Ok(())
1683 }
1684
1685 #[test]
1688 fn test_collect_live_variables() -> Result<()> {
1689 let optimizer = CircuitOptimizer::new();
1690 let mut builder = CircuitBuilder::new();
1691 builder
1692 .declare_variable("a", EncryptedType::U8)
1693 .declare_variable("b", EncryptedType::U8);
1694
1695 let a = builder.load("a");
1696 let b = builder.load("b");
1697 let result = builder.add(a, b);
1698
1699 let live = optimizer.collect_live_variables(&result);
1700 assert!(live.contains("a"));
1701 assert!(live.contains("b"));
1702 assert_eq!(live.len(), 2);
1703 Ok(())
1704 }
1705
1706 #[test]
1707 fn test_collect_live_variables_after_dce() -> Result<()> {
1708 let mut optimizer = CircuitOptimizer::new();
1709 let mut builder = CircuitBuilder::new();
1710 builder
1711 .declare_variable("a", EncryptedType::U8)
1712 .declare_variable("b", EncryptedType::U8);
1713
1714 let a = builder.load("a");
1717 let b = builder.load("b");
1718 let one = builder.constant(CircuitValue::U8(1));
1719 let zero = builder.constant(CircuitValue::U8(0));
1720 let a1 = builder.mul(a, one);
1721 let b0 = builder.mul(b, zero);
1722 let result = builder.add(a1, b0);
1723
1724 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1725 let optimized = optimizer.optimize(circuit)?;
1726
1727 let live = optimizer.collect_live_variables(&optimized.root);
1728 assert!(live.contains("a"));
1729 assert!(!live.contains("b"), "b should be eliminated by DCE");
1731 Ok(())
1732 }
1733
1734 #[test]
1737 fn test_comparison_constant_fold() -> Result<()> {
1738 let mut optimizer = CircuitOptimizer::new();
1739 let builder = CircuitBuilder::new();
1740
1741 let a = builder.constant(CircuitValue::U8(10));
1742 let b = builder.constant(CircuitValue::U8(5));
1743 let result = builder.gt(a, b);
1744
1745 let circuit = Circuit::new(result, HashMap::new())?;
1746 let optimized = optimizer.optimize(circuit)?;
1747
1748 assert_eq!(
1749 optimized.root,
1750 CircuitNode::Constant(CircuitValue::Bool(true))
1751 );
1752 Ok(())
1753 }
1754
1755 #[test]
1756 fn test_comparison_constant_fold_eq() -> Result<()> {
1757 let mut optimizer = CircuitOptimizer::new();
1758 let builder = CircuitBuilder::new();
1759
1760 let a = builder.constant(CircuitValue::U8(5));
1761 let b = builder.constant(CircuitValue::U8(5));
1762 let result = builder.eq(a, b);
1763
1764 let circuit = Circuit::new(result, HashMap::new())?;
1765 let optimized = optimizer.optimize(circuit)?;
1766
1767 assert_eq!(
1768 optimized.root,
1769 CircuitNode::Constant(CircuitValue::Bool(true))
1770 );
1771 Ok(())
1772 }
1773
1774 #[test]
1777 fn test_xor_self_elimination() -> Result<()> {
1778 let mut optimizer = CircuitOptimizer::new();
1779 let mut builder = CircuitBuilder::new();
1780 builder.declare_variable("x", EncryptedType::Bool);
1781
1782 let x1 = builder.load("x");
1783 let x2 = builder.load("x");
1784 let result = builder.xor(x1, x2);
1785
1786 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1787 let optimized = optimizer.optimize(circuit)?;
1788
1789 assert_eq!(
1790 optimized.root,
1791 CircuitNode::Constant(CircuitValue::Bool(false))
1792 );
1793 Ok(())
1794 }
1795
1796 #[test]
1799 fn test_and_idempotent() -> Result<()> {
1800 let mut optimizer = CircuitOptimizer::new();
1801 let mut builder = CircuitBuilder::new();
1802 builder.declare_variable("x", EncryptedType::Bool);
1803
1804 let x1 = builder.load("x");
1805 let x2 = builder.load("x");
1806 let result = builder.and(x1, x2);
1807
1808 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1809 let optimized = optimizer.optimize(circuit)?;
1810
1811 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1812 Ok(())
1813 }
1814
1815 #[test]
1816 fn test_or_idempotent() -> Result<()> {
1817 let mut optimizer = CircuitOptimizer::new();
1818 let mut builder = CircuitBuilder::new();
1819 builder.declare_variable("x", EncryptedType::Bool);
1820
1821 let x1 = builder.load("x");
1822 let x2 = builder.load("x");
1823 let result = builder.or(x1, x2);
1824
1825 let circuit = Circuit::new(result, builder.variable_types_clone())?;
1826 let optimized = optimizer.optimize(circuit)?;
1827
1828 assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
1829 Ok(())
1830 }
1831
1832 #[test]
1835 fn test_optimizer_does_not_fold_encrypted_constants() -> Result<()> {
1836 use crate::compute::circuit::ConstantType;
1837
1838 let mut optimizer = CircuitOptimizer::new();
1839 let builder = CircuitBuilder::new();
1840
1841 let enc_a = builder.encrypted_constant(vec![0x01, 0x05], ConstantType::Integer);
1845 let enc_b = builder.encrypted_constant(vec![0x01, 0x03], ConstantType::Integer);
1846 let sum = builder.add(enc_a.clone(), enc_b.clone());
1847
1848 let circuit = Circuit::new(sum, HashMap::new())?;
1849 let optimized = optimizer.optimize(circuit)?;
1850
1851 match &optimized.root {
1853 CircuitNode::BinaryOp { op, left, right } => {
1854 assert_eq!(*op, BinaryOperator::Add);
1855 assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
1856 assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
1857 }
1858 _ => {
1859 return Err(AmateRSError::FheComputation(ErrorContext::new(
1860 "Optimizer incorrectly folded encrypted constants".to_string(),
1861 )));
1862 }
1863 }
1864
1865 assert_eq!(optimizer.stats().constants_folded, 0);
1867
1868 Ok(())
1869 }
1870
1871 #[test]
1872 fn test_optimizer_dce_treats_encrypted_constant_as_opaque() -> Result<()> {
1873 use crate::compute::circuit::ConstantType;
1874
1875 let mut optimizer = CircuitOptimizer::new();
1876
1877 let enc = CircuitNode::EncryptedConstant {
1880 data: vec![0x04, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11],
1881 original_type: ConstantType::Integer,
1882 };
1883
1884 let circuit = Circuit::new(enc.clone(), HashMap::new())?;
1885 let optimized = optimizer.optimize(circuit)?;
1886
1887 assert_eq!(optimized.root, enc);
1889
1890 Ok(())
1891 }
1892
1893 #[test]
1894 fn test_optimizer_mixed_plain_and_encrypted_constants() -> Result<()> {
1895 use crate::compute::circuit::ConstantType;
1896
1897 let mut optimizer = CircuitOptimizer::new();
1898 let builder = CircuitBuilder::new();
1899
1900 let plain_a = builder.constant(CircuitValue::U8(5));
1902 let plain_b = builder.constant(CircuitValue::U8(3));
1903 let plain_sum = builder.add(plain_a, plain_b);
1904
1905 let circuit = Circuit::new(plain_sum, HashMap::new())?;
1906 let optimized = optimizer.optimize(circuit)?;
1907
1908 assert!(matches!(
1910 optimized.root,
1911 CircuitNode::Constant(CircuitValue::U8(8))
1912 ));
1913
1914 let mut optimizer2 = CircuitOptimizer::new();
1916 let enc_a = builder.encrypted_constant(vec![0x01, 0xAA], ConstantType::Integer);
1917 let enc_b = builder.encrypted_constant(vec![0x01, 0xBB], ConstantType::Integer);
1918 let enc_sum = builder.add(enc_a, enc_b);
1919
1920 let circuit2 = Circuit::new(enc_sum, HashMap::new())?;
1921 let optimized2 = optimizer2.optimize(circuit2)?;
1922
1923 assert!(matches!(optimized2.root, CircuitNode::BinaryOp { .. }));
1924
1925 Ok(())
1926 }
1927
1928 #[test]
1929 fn test_optimizer_algebraic_identity_with_encrypted_constant() -> Result<()> {
1930 use crate::compute::circuit::ConstantType;
1931
1932 let mut optimizer = CircuitOptimizer::new();
1933 let builder = CircuitBuilder::new();
1934
1935 let enc = builder.encrypted_constant(
1941 vec![0x04, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
1942 ConstantType::Integer,
1943 );
1944 let zero = builder.constant(CircuitValue::U64(0));
1945 let sum = builder.add(enc.clone(), zero);
1946
1947 let circuit = Circuit::new(sum, HashMap::new())?;
1948 let optimized = optimizer.optimize(circuit)?;
1949
1950 assert_eq!(optimized.root, enc);
1952
1953 Ok(())
1954 }
1955
1956 #[test]
1957 fn test_optimizer_live_variables_with_encrypted_constants() -> Result<()> {
1958 use crate::compute::circuit::ConstantType;
1959
1960 let optimizer = CircuitOptimizer::new();
1961 let mut builder = CircuitBuilder::new();
1962 builder.declare_variable("x", EncryptedType::U8);
1963
1964 let x = builder.load("x");
1966 let enc = builder.encrypted_constant(vec![0x01, 0x10], ConstantType::Integer);
1967 let sum = builder.add(x, enc);
1968
1969 let live = optimizer.collect_live_variables(&sum);
1970
1971 assert!(live.contains("x"));
1973 assert_eq!(live.len(), 1);
1974
1975 Ok(())
1976 }
1977}