1use crate::compute::circuit::{
14 BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
15 UnaryOperator,
16};
17use crate::error::{AmateRSError, ErrorContext, Result};
18use std::collections::{HashMap, HashSet, VecDeque};
19
/// Before/after metrics recorded by [`CircuitOptimizer::optimize`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
    /// Gate count of the circuit handed to the optimizer.
    pub original_gate_count: usize,

    /// Gate count of the circuit the optimizer returned.
    pub optimized_gate_count: usize,

    /// Bootstrap estimate (one per `Mul` and per comparison) for the input circuit.
    pub original_bootstrap_count: usize,

    /// Bootstrap estimate (one per `Mul` and per comparison) for the optimized circuit.
    pub optimized_bootstrap_count: usize,

    /// Nodes removed by dead-code elimination (not currently incremented by the optimizer).
    pub dead_code_removed: usize,

    /// Sub-expressions replaced by a single computed constant.
    pub constants_folded: usize,

    /// Gates removed by double-negation elimination and algebraic identities.
    pub gates_fused: usize,

    /// Depth of the input circuit.
    pub original_depth: usize,

    /// Depth of the optimized circuit.
    pub optimized_depth: usize,
}

impl OptimizationStats {
    /// Percentage of gates eliminated, in `[0.0, 100.0]`.
    ///
    /// Returns `0.0` when the input circuit had no gates or the optimized circuit grew.
    pub fn gate_reduction_percent(&self) -> f64 {
        Self::reduction_percent(self.original_gate_count, self.optimized_gate_count)
    }

    /// Percentage of estimated bootstraps eliminated, in `[0.0, 100.0]`.
    ///
    /// Returns `0.0` when the input circuit needed no bootstraps or the count grew.
    pub fn bootstrap_reduction_percent(&self) -> f64 {
        Self::reduction_percent(self.original_bootstrap_count, self.optimized_bootstrap_count)
    }

    /// Share of `before` that is gone in `after`, expressed as a percentage.
    ///
    /// The saturating subtraction clamps growth to zero. An empty baseline short-circuits
    /// to `0.0` instead of dividing by zero.
    fn reduction_percent(before: usize, after: usize) -> f64 {
        match before {
            0 => 0.0,
            _ => before.saturating_sub(after) as f64 / before as f64 * 100.0,
        }
    }
}
74
/// Dependency structure of a circuit, used to reason about parallel execution.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DependencyGraph {
    /// For each node, the nodes whose results it consumes.
    pub dependencies: HashMap<NodeId, Vec<NodeId>>,

    /// Nodes bucketed by execution level; index `i` holds the nodes assigned level `i`.
    pub parallel_groups: Vec<Vec<NodeId>>,

    /// Longest dependency chain, ordered from its first input to its final node.
    pub critical_path: Vec<NodeId>,

    /// Number of nodes in the graph.
    pub node_count: usize,
}

impl DependencyGraph {
    /// Creates an empty graph (no nodes, no groups, empty critical path).
    pub fn new() -> Self {
        Self::default()
    }

    /// Size of the largest parallel group, or `0` when there are no groups.
    pub fn max_parallelism(&self) -> usize {
        self.parallel_groups.iter().map(Vec::len).max().unwrap_or(0)
    }

    /// Mean size of the parallel groups, or `0.0` when there are no groups.
    pub fn avg_parallelism(&self) -> f64 {
        let group_count = self.parallel_groups.len();
        if group_count == 0 {
            return 0.0;
        }
        let member_count: usize = self.parallel_groups.iter().map(Vec::len).sum();
        member_count as f64 / group_count as f64
    }
}

/// Numeric identifier of a node in a [`DependencyGraph`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
130
131#[derive(Debug, Clone)]
133pub struct CircuitOptimizer {
134 pub enable_constant_folding: bool,
136
137 pub enable_dead_code_elimination: bool,
139
140 pub enable_bootstrap_minimization: bool,
142
143 pub enable_gate_fusion: bool,
145
146 pub enable_parallelization_analysis: bool,
148
149 stats: OptimizationStats,
151
152 dependency_graph: DependencyGraph,
154}
155
156impl CircuitOptimizer {
157 pub fn new() -> Self {
159 Self {
160 enable_constant_folding: true,
161 enable_dead_code_elimination: true,
162 enable_bootstrap_minimization: true,
163 enable_gate_fusion: true,
164 enable_parallelization_analysis: true,
165 stats: OptimizationStats::default(),
166 dependency_graph: DependencyGraph::new(),
167 }
168 }
169
170 pub fn disabled() -> Self {
172 Self {
173 enable_constant_folding: false,
174 enable_dead_code_elimination: false,
175 enable_bootstrap_minimization: false,
176 enable_gate_fusion: false,
177 enable_parallelization_analysis: false,
178 stats: OptimizationStats::default(),
179 dependency_graph: DependencyGraph::new(),
180 }
181 }
182
183 pub fn stats(&self) -> &OptimizationStats {
185 &self.stats
186 }
187
188 pub fn dependency_graph(&self) -> &DependencyGraph {
190 &self.dependency_graph
191 }
192
193 pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
195 self.stats.original_gate_count = circuit.gate_count;
197 self.stats.original_depth = circuit.depth;
198 self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
199
200 let mut optimized_root = circuit.root.clone();
201
202 if self.enable_constant_folding {
204 optimized_root = self.constant_folding_pass(optimized_root);
205 }
206
207 if self.enable_gate_fusion {
208 optimized_root = self.gate_fusion_pass(optimized_root);
209 }
210
211 if self.enable_bootstrap_minimization {
212 optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
213 }
214
215 if self.enable_dead_code_elimination {
216 optimized_root = self.dead_code_elimination_pass(optimized_root);
217 }
218
219 let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
221
222 self.stats.optimized_gate_count = optimized_circuit.gate_count;
224 self.stats.optimized_depth = optimized_circuit.depth;
225 self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
226
227 if self.enable_parallelization_analysis {
229 self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
230 }
231
232 Ok(optimized_circuit)
233 }
234
235 #[allow(clippy::only_used_in_recursion)]
243 fn count_bootstraps(&self, node: &CircuitNode) -> usize {
244 match node {
245 CircuitNode::Load(_) | CircuitNode::Constant(_) => 0,
246
247 CircuitNode::BinaryOp { op, left, right } => {
248 let left_bootstraps = self.count_bootstraps(left);
249 let right_bootstraps = self.count_bootstraps(right);
250
251 let op_bootstrap = match op {
253 BinaryOperator::Mul => 1,
254 _ => 0,
255 };
256
257 left_bootstraps + right_bootstraps + op_bootstrap
258 }
259
260 CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
261
262 CircuitNode::Compare { left, right, .. } => {
263 let left_bootstraps = self.count_bootstraps(left);
264 let right_bootstraps = self.count_bootstraps(right);
265
266 left_bootstraps + right_bootstraps + 1
268 }
269 }
270 }
271
272 fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
276 match node {
277 CircuitNode::BinaryOp { op, left, right } => {
278 let left = self.constant_folding_pass(*left);
279 let right = self.constant_folding_pass(*right);
280
281 if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
283 if let Some(result) = self.fold_binary_constants(op, l, r) {
284 self.stats.constants_folded += 1;
285 return CircuitNode::Constant(result);
286 }
287 }
288
289 if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
291 return simplified;
292 }
293
294 CircuitNode::BinaryOp {
295 op,
296 left: Box::new(left),
297 right: Box::new(right),
298 }
299 }
300
301 CircuitNode::UnaryOp { op, operand } => {
302 let operand = self.constant_folding_pass(*operand);
303
304 if let CircuitNode::Constant(val) = &operand {
305 if let Some(result) = self.fold_unary_constant(op, val) {
306 self.stats.constants_folded += 1;
307 return CircuitNode::Constant(result);
308 }
309 }
310
311 CircuitNode::UnaryOp {
312 op,
313 operand: Box::new(operand),
314 }
315 }
316
317 CircuitNode::Compare { op, left, right } => {
318 let left = self.constant_folding_pass(*left);
319 let right = self.constant_folding_pass(*right);
320
321 CircuitNode::Compare {
322 op,
323 left: Box::new(left),
324 right: Box::new(right),
325 }
326 }
327
328 other => other,
329 }
330 }
331
332 fn fold_binary_constants(
334 &self,
335 op: BinaryOperator,
336 left: &CircuitValue,
337 right: &CircuitValue,
338 ) -> Option<CircuitValue> {
339 match (left, right) {
340 (CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
341 BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
342 BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
343 BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
344 _ => None,
345 },
346 (CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
347 BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
348 BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
349 BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
350 _ => None,
351 },
352 (CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
353 BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
354 BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
355 BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
356 _ => None,
357 },
358 (CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
359 BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
360 BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
361 BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
362 _ => None,
363 },
364 (CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
365 BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
366 BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
367 BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
368 _ => None,
369 },
370 _ => None,
371 }
372 }
373
374 fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
376 match (op, value) {
377 (UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
378 _ => None,
379 }
380 }
381
382 fn apply_algebraic_identities(
385 &mut self,
386 op: BinaryOperator,
387 left: &CircuitNode,
388 right: &CircuitNode,
389 ) -> Option<CircuitNode> {
390 match op {
391 BinaryOperator::Add => {
392 if Self::is_zero(right) {
394 self.stats.gates_fused += 1;
395 return Some(left.clone());
396 }
397 if Self::is_zero(left) {
399 self.stats.gates_fused += 1;
400 return Some(right.clone());
401 }
402 }
403
404 BinaryOperator::Sub => {
405 if Self::is_zero(right) {
407 self.stats.gates_fused += 1;
408 return Some(left.clone());
409 }
410 }
411
412 BinaryOperator::Mul => {
413 if Self::is_zero(right) {
415 self.stats.gates_fused += 1;
416 return Some(right.clone());
417 }
418 if Self::is_zero(left) {
419 self.stats.gates_fused += 1;
420 return Some(left.clone());
421 }
422
423 if Self::is_one(right) {
425 self.stats.gates_fused += 1;
426 return Some(left.clone());
427 }
428 if Self::is_one(left) {
430 self.stats.gates_fused += 1;
431 return Some(right.clone());
432 }
433 }
434
435 BinaryOperator::And => {
436 if Self::is_true(right) {
438 self.stats.gates_fused += 1;
439 return Some(left.clone());
440 }
441 if Self::is_true(left) {
442 self.stats.gates_fused += 1;
443 return Some(right.clone());
444 }
445
446 if Self::is_false(right) {
448 self.stats.gates_fused += 1;
449 return Some(right.clone());
450 }
451 if Self::is_false(left) {
452 self.stats.gates_fused += 1;
453 return Some(left.clone());
454 }
455 }
456
457 BinaryOperator::Or => {
458 if Self::is_false(right) {
460 self.stats.gates_fused += 1;
461 return Some(left.clone());
462 }
463 if Self::is_false(left) {
464 self.stats.gates_fused += 1;
465 return Some(right.clone());
466 }
467
468 if Self::is_true(right) {
470 self.stats.gates_fused += 1;
471 return Some(right.clone());
472 }
473 if Self::is_true(left) {
474 self.stats.gates_fused += 1;
475 return Some(left.clone());
476 }
477 }
478
479 BinaryOperator::Xor => {
480 if Self::is_false(right) {
482 self.stats.gates_fused += 1;
483 return Some(left.clone());
484 }
485 if Self::is_false(left) {
486 self.stats.gates_fused += 1;
487 return Some(right.clone());
488 }
489 }
490 }
491
492 None
493 }
494
495 fn is_zero(node: &CircuitNode) -> bool {
497 matches!(
498 node,
499 CircuitNode::Constant(CircuitValue::U8(0))
500 | CircuitNode::Constant(CircuitValue::U16(0))
501 | CircuitNode::Constant(CircuitValue::U32(0))
502 | CircuitNode::Constant(CircuitValue::U64(0))
503 )
504 }
505
506 fn is_one(node: &CircuitNode) -> bool {
508 matches!(
509 node,
510 CircuitNode::Constant(CircuitValue::U8(1))
511 | CircuitNode::Constant(CircuitValue::U16(1))
512 | CircuitNode::Constant(CircuitValue::U32(1))
513 | CircuitNode::Constant(CircuitValue::U64(1))
514 )
515 }
516
517 fn is_true(node: &CircuitNode) -> bool {
519 matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
520 }
521
522 fn is_false(node: &CircuitNode) -> bool {
524 matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
525 }
526
527 fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
533 match node {
534 CircuitNode::BinaryOp { op, left, right } => {
535 let left = self.gate_fusion_pass(*left);
536 let right = self.gate_fusion_pass(*right);
537
538 CircuitNode::BinaryOp {
539 op,
540 left: Box::new(left),
541 right: Box::new(right),
542 }
543 }
544
545 CircuitNode::UnaryOp {
546 op: UnaryOperator::Not,
547 operand,
548 } => {
549 let operand = self.gate_fusion_pass(*operand);
550
551 if let CircuitNode::UnaryOp {
553 op: UnaryOperator::Not,
554 operand: inner,
555 } = operand
556 {
557 self.stats.gates_fused += 2; return *inner;
559 }
560
561 CircuitNode::UnaryOp {
562 op: UnaryOperator::Not,
563 operand: Box::new(operand),
564 }
565 }
566
567 CircuitNode::UnaryOp { op, operand } => {
568 let operand = self.gate_fusion_pass(*operand);
569 CircuitNode::UnaryOp {
570 op,
571 operand: Box::new(operand),
572 }
573 }
574
575 CircuitNode::Compare { op, left, right } => {
576 let left = self.gate_fusion_pass(*left);
577 let right = self.gate_fusion_pass(*right);
578
579 CircuitNode::Compare {
580 op,
581 left: Box::new(left),
582 right: Box::new(right),
583 }
584 }
585
586 other => other,
587 }
588 }
589
590 fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
597 Ok(self.reorder_for_bootstrap_efficiency(node))
600 }
601
602 #[allow(clippy::only_used_in_recursion)]
604 fn reorder_for_bootstrap_efficiency(&self, node: CircuitNode) -> CircuitNode {
605 match node {
606 CircuitNode::BinaryOp { op, left, right } => {
607 let left = self.reorder_for_bootstrap_efficiency(*left);
608 let right = self.reorder_for_bootstrap_efficiency(*right);
609
610 CircuitNode::BinaryOp {
611 op,
612 left: Box::new(left),
613 right: Box::new(right),
614 }
615 }
616
617 CircuitNode::UnaryOp { op, operand } => {
618 let operand = self.reorder_for_bootstrap_efficiency(*operand);
619 CircuitNode::UnaryOp {
620 op,
621 operand: Box::new(operand),
622 }
623 }
624
625 CircuitNode::Compare { op, left, right } => {
626 let left = self.reorder_for_bootstrap_efficiency(*left);
627 let right = self.reorder_for_bootstrap_efficiency(*right);
628
629 CircuitNode::Compare {
630 op,
631 left: Box::new(left),
632 right: Box::new(right),
633 }
634 }
635
636 other => other,
637 }
638 }
639
640 fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
644 let mut live_nodes = HashSet::new();
646 self.mark_live_nodes(&node, &mut live_nodes);
647
648 node
652 }
653
654 #[allow(clippy::only_used_in_recursion)]
656 fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
657 match node {
658 CircuitNode::Load(name) => {
659 live_nodes.insert(name.clone());
660 }
661
662 CircuitNode::Constant(_) => {}
663
664 CircuitNode::BinaryOp { left, right, .. } => {
665 self.mark_live_nodes(left, live_nodes);
666 self.mark_live_nodes(right, live_nodes);
667 }
668
669 CircuitNode::UnaryOp { operand, .. } => {
670 self.mark_live_nodes(operand, live_nodes);
671 }
672
673 CircuitNode::Compare { left, right, .. } => {
674 self.mark_live_nodes(left, live_nodes);
675 self.mark_live_nodes(right, live_nodes);
676 }
677 }
678 }
679
680 fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
684 let mut graph = DependencyGraph::new();
685 let mut node_id_map = HashMap::new();
686 let mut next_id = 0;
687
688 self.build_dependency_graph(&circuit.root, &mut graph, &mut node_id_map, &mut next_id);
690
691 graph.node_count = next_id;
692
693 graph.parallel_groups = self.identify_parallel_groups(&graph);
695
696 graph.critical_path = self.find_critical_path(&graph);
698
699 Ok(graph)
700 }
701
702 #[allow(clippy::only_used_in_recursion)]
704 fn build_dependency_graph(
705 &self,
706 node: &CircuitNode,
707 graph: &mut DependencyGraph,
708 node_id_map: &mut HashMap<String, NodeId>,
709 next_id: &mut usize,
710 ) -> NodeId {
711 let current_id = NodeId(*next_id);
712 *next_id += 1;
713
714 match node {
715 CircuitNode::Load(name) => {
716 node_id_map.insert(name.clone(), current_id);
717 graph.dependencies.insert(current_id, Vec::new());
718 current_id
719 }
720
721 CircuitNode::Constant(_) => {
722 graph.dependencies.insert(current_id, Vec::new());
723 current_id
724 }
725
726 CircuitNode::BinaryOp { left, right, .. } => {
727 let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
728 let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
729
730 graph
731 .dependencies
732 .insert(current_id, vec![left_id, right_id]);
733 current_id
734 }
735
736 CircuitNode::UnaryOp { operand, .. } => {
737 let operand_id = self.build_dependency_graph(operand, graph, node_id_map, next_id);
738
739 graph.dependencies.insert(current_id, vec![operand_id]);
740 current_id
741 }
742
743 CircuitNode::Compare { left, right, .. } => {
744 let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
745 let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
746
747 graph
748 .dependencies
749 .insert(current_id, vec![left_id, right_id]);
750 current_id
751 }
752 }
753 }
754
755 fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
757 let mut levels: HashMap<NodeId, usize> = HashMap::new();
758 let mut queue = VecDeque::new();
759
760 for (node_id, deps) in &graph.dependencies {
762 if deps.is_empty() {
763 levels.insert(*node_id, 0);
764 queue.push_back(*node_id);
765 }
766 }
767
768 while let Some(node_id) = queue.pop_front() {
770 let current_level = levels[&node_id];
771
772 for (dependent_id, deps) in &graph.dependencies {
774 if deps.contains(&node_id) {
775 let max_dep_level = deps
777 .iter()
778 .filter_map(|dep_id| levels.get(dep_id))
779 .max()
780 .copied()
781 .unwrap_or(0);
782
783 let dependent_level = max_dep_level + 1;
784
785 if !levels.contains_key(dependent_id) {
786 levels.insert(*dependent_id, dependent_level);
787 queue.push_back(*dependent_id);
788 }
789 }
790 }
791 }
792
793 let max_level = levels.values().max().copied().unwrap_or(0);
795 let mut parallel_groups = vec![Vec::new(); max_level + 1];
796
797 for (node_id, level) in levels {
798 parallel_groups[level].push(node_id);
799 }
800
801 for group in &mut parallel_groups {
803 group.sort();
804 }
805
806 parallel_groups
807 }
808
809 fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
811 let mut max_path = Vec::new();
813
814 for node_id in graph.dependencies.keys() {
815 let path = self.find_path_to_root(*node_id, graph);
816 if path.len() > max_path.len() {
817 max_path = path;
818 }
819 }
820
821 max_path
822 }
823
824 #[allow(clippy::only_used_in_recursion)]
826 fn find_path_to_root(&self, node_id: NodeId, graph: &DependencyGraph) -> Vec<NodeId> {
827 let deps = graph
828 .dependencies
829 .get(&node_id)
830 .map(|v| v.as_slice())
831 .unwrap_or(&[]);
832
833 if deps.is_empty() {
834 return vec![node_id];
835 }
836
837 let mut longest_path = Vec::new();
839 for dep_id in deps {
840 let dep_path = self.find_path_to_root(*dep_id, graph);
841 if dep_path.len() > longest_path.len() {
842 longest_path = dep_path;
843 }
844 }
845
846 longest_path.push(node_id);
847 longest_path
848 }
849}
850
851impl Default for CircuitOptimizer {
852 fn default() -> Self {
853 Self::new()
854 }
855}
856
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compute::circuit::CircuitBuilder;

    /// `5 + 3` on constants collapses to the single constant `8`.
    #[test]
    fn test_constant_folding() -> Result<()> {
        let mut optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let sum = builder.add(a, b);

        let circuit = Circuit::new(sum, HashMap::new())?;
        let optimized = optimizer.optimize(circuit)?;

        assert!(matches!(
            optimized.root,
            CircuitNode::Constant(CircuitValue::U8(8))
        ));
        assert_eq!(optimizer.stats().constants_folded, 1);

        Ok(())
    }

    /// `x + 0` simplifies to the load of `x`.
    #[test]
    fn test_algebraic_identities() -> Result<()> {
        let mut optimizer = CircuitOptimizer::new();
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::U8);

        let x = builder.load("x");
        let zero = builder.constant(CircuitValue::U8(0));
        let add_zero = builder.add(x, zero);

        let circuit = Circuit::new(add_zero, builder.variable_types_clone())?;
        let optimized = optimizer.optimize(circuit)?;

        assert!(matches!(optimized.root, CircuitNode::Load(_)));

        Ok(())
    }

    /// `!!x` simplifies to `x`, and both removed gates are counted.
    #[test]
    fn test_double_negation_elimination() -> Result<()> {
        let mut optimizer = CircuitOptimizer::new();
        let mut builder = CircuitBuilder::new();
        builder.declare_variable("x", EncryptedType::Bool);

        let x = builder.load("x");
        let not_x = builder.not(x);
        let not_not_x = builder.not(not_x);

        let circuit = Circuit::new(not_not_x, builder.variable_types_clone())?;
        let optimized = optimizer.optimize(circuit)?;

        assert!(matches!(optimized.root, CircuitNode::Load(_)));
        assert!(optimizer.stats().gates_fused >= 2);

        Ok(())
    }

    /// A single multiplication costs exactly one bootstrap.
    #[test]
    fn test_bootstrap_counting() -> Result<()> {
        let optimizer = CircuitOptimizer::new();
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let mul = builder.mul(a, b);

        let circuit = Circuit::new(mul, builder.variable_types_clone())?;
        let bootstrap_count = optimizer.count_bootstraps(&circuit.root);

        assert_eq!(bootstrap_count, 1);

        Ok(())
    }

    /// `(a + b) + c` yields a populated dependency graph.
    #[test]
    fn test_parallelization_analysis() -> Result<()> {
        let mut optimizer = CircuitOptimizer::new();
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8)
            .declare_variable("c", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let c = builder.load("c");
        let sum1 = builder.add(a, b);
        let sum2 = builder.add(sum1, c);

        let circuit = Circuit::new(sum2, builder.variable_types_clone())?;
        // Only the analysis recorded on the optimizer is inspected here.
        optimizer.optimize(circuit)?;

        let graph = optimizer.dependency_graph();
        assert!(graph.node_count > 0);
        assert!(!graph.parallel_groups.is_empty());

        // Every node of the graph is placed in exactly one parallel group.
        let grouped: usize = graph.parallel_groups.iter().map(Vec::len).sum();
        assert_eq!(grouped, graph.node_count);
        assert!(!graph.critical_path.is_empty());

        Ok(())
    }

    /// Folding plus an identity rewrite reduces the gate count.
    #[test]
    fn test_optimization_stats() -> Result<()> {
        let mut optimizer = CircuitOptimizer::new();
        let builder = CircuitBuilder::new();

        let a = builder.constant(CircuitValue::U8(5));
        let b = builder.constant(CircuitValue::U8(3));
        let zero = builder.constant(CircuitValue::U8(0));

        let sum = builder.add(a, b);
        let add_zero = builder.add(sum, zero);
        let circuit = Circuit::new(add_zero, HashMap::new())?;
        let original_gates = circuit.gate_count;

        let optimized = optimizer.optimize(circuit)?;
        let optimized_gates = optimized.gate_count;

        assert!(optimized_gates < original_gates);
        assert!(optimizer.stats().gate_reduction_percent() > 0.0);

        Ok(())
    }

    /// `a * 1 + b * 0 + 5` loses at least 30% of its gates.
    #[test]
    fn test_complex_circuit_optimization() -> Result<()> {
        let mut optimizer = CircuitOptimizer::new();
        let mut builder = CircuitBuilder::new();
        builder
            .declare_variable("a", EncryptedType::U8)
            .declare_variable("b", EncryptedType::U8);

        let a = builder.load("a");
        let b = builder.load("b");
        let one = builder.constant(CircuitValue::U8(1));
        let zero = builder.constant(CircuitValue::U8(0));
        let five = builder.constant(CircuitValue::U8(5));

        let a_times_1 = builder.mul(a, one);
        let b_times_0 = builder.mul(b, zero);
        let sum1 = builder.add(a_times_1, b_times_0);
        let result = builder.add(sum1, five);

        let circuit = Circuit::new(result, builder.variable_types_clone())?;
        let original_gates = circuit.gate_count;

        let optimized = optimizer.optimize(circuit)?;

        assert!(optimized.gate_count < original_gates);
        assert!(optimizer.stats().gate_reduction_percent() >= 30.0);

        Ok(())
    }
}