use crate::compute::circuit::{
BinaryOperator, Circuit, CircuitNode, CircuitValue, CompareOperator, EncryptedType,
UnaryOperator,
};
use crate::error::{AmateRSError, ErrorContext, Result};
use std::collections::{HashMap, HashSet, VecDeque};
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationStats {
pub original_gate_count: usize,
pub optimized_gate_count: usize,
pub original_bootstrap_count: usize,
pub optimized_bootstrap_count: usize,
pub dead_code_removed: usize,
pub nodes_eliminated: usize,
pub algebraic_simplifications: usize,
pub constants_folded: usize,
pub gates_fused: usize,
pub original_depth: usize,
pub optimized_depth: usize,
}
impl OptimizationStats {
pub fn gate_reduction_percent(&self) -> f64 {
if self.original_gate_count == 0 {
return 0.0;
}
let reduction = self
.original_gate_count
.saturating_sub(self.optimized_gate_count);
(reduction as f64 / self.original_gate_count as f64) * 100.0
}
pub fn bootstrap_reduction_percent(&self) -> f64 {
if self.original_bootstrap_count == 0 {
return 0.0;
}
let reduction = self
.original_bootstrap_count
.saturating_sub(self.optimized_bootstrap_count);
(reduction as f64 / self.original_bootstrap_count as f64) * 100.0
}
pub fn total_stats(&self) -> (usize, usize, usize) {
(
self.nodes_eliminated + self.dead_code_removed,
self.algebraic_simplifications + self.gates_fused,
self.constants_folded,
)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyGraph {
pub dependencies: HashMap<NodeId, Vec<NodeId>>,
pub parallel_groups: Vec<Vec<NodeId>>,
pub critical_path: Vec<NodeId>,
pub node_count: usize,
}
impl DependencyGraph {
pub fn new() -> Self {
Self {
dependencies: HashMap::new(),
parallel_groups: Vec::new(),
critical_path: Vec::new(),
node_count: 0,
}
}
pub fn max_parallelism(&self) -> usize {
self.parallel_groups
.iter()
.map(|g| g.len())
.max()
.unwrap_or(0)
}
pub fn avg_parallelism(&self) -> f64 {
if self.parallel_groups.is_empty() {
return 0.0;
}
let total: usize = self.parallel_groups.iter().map(|g| g.len()).sum();
total as f64 / self.parallel_groups.len() as f64
}
}
impl Default for DependencyGraph {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub usize);
#[derive(Debug, Clone)]
pub struct CircuitOptimizer {
pub enable_constant_folding: bool,
pub enable_dead_code_elimination: bool,
pub enable_bootstrap_minimization: bool,
pub enable_gate_fusion: bool,
pub enable_parallelization_analysis: bool,
stats: OptimizationStats,
dependency_graph: DependencyGraph,
}
impl CircuitOptimizer {
pub fn new() -> Self {
Self {
enable_constant_folding: true,
enable_dead_code_elimination: true,
enable_bootstrap_minimization: true,
enable_gate_fusion: true,
enable_parallelization_analysis: true,
stats: OptimizationStats::default(),
dependency_graph: DependencyGraph::new(),
}
}
pub fn disabled() -> Self {
Self {
enable_constant_folding: false,
enable_dead_code_elimination: false,
enable_bootstrap_minimization: false,
enable_gate_fusion: false,
enable_parallelization_analysis: false,
stats: OptimizationStats::default(),
dependency_graph: DependencyGraph::new(),
}
}
pub fn stats(&self) -> &OptimizationStats {
&self.stats
}
pub fn dependency_graph(&self) -> &DependencyGraph {
&self.dependency_graph
}
pub fn total_stats(&self) -> (usize, usize, usize) {
self.stats.total_stats()
}
pub fn optimize(&mut self, circuit: Circuit) -> Result<Circuit> {
self.stats.original_gate_count = circuit.gate_count;
self.stats.original_depth = circuit.depth;
self.stats.original_bootstrap_count = self.count_bootstraps(&circuit.root);
let mut optimized_root = circuit.root.clone();
if self.enable_constant_folding {
optimized_root = self.constant_folding_pass(optimized_root);
}
if self.enable_gate_fusion {
optimized_root = self.gate_fusion_pass(optimized_root);
}
if self.enable_bootstrap_minimization {
optimized_root = self.bootstrap_minimization_pass(optimized_root)?;
}
if self.enable_dead_code_elimination {
optimized_root = self.dead_code_elimination_pass(optimized_root);
}
let optimized_circuit = Circuit::new(optimized_root, circuit.variable_types)?;
self.stats.optimized_gate_count = optimized_circuit.gate_count;
self.stats.optimized_depth = optimized_circuit.depth;
self.stats.optimized_bootstrap_count = self.count_bootstraps(&optimized_circuit.root);
if self.enable_parallelization_analysis {
self.dependency_graph = self.analyze_parallelism(&optimized_circuit)?;
}
Ok(optimized_circuit)
}
#[allow(clippy::only_used_in_recursion)]
fn count_bootstraps(&self, node: &CircuitNode) -> usize {
match node {
CircuitNode::Load(_)
| CircuitNode::Constant(_)
| CircuitNode::EncryptedConstant { .. } => 0,
CircuitNode::BinaryOp { op, left, right } => {
let left_bootstraps = self.count_bootstraps(left);
let right_bootstraps = self.count_bootstraps(right);
let op_bootstrap = match op {
BinaryOperator::Mul => 1,
_ => 0,
};
left_bootstraps + right_bootstraps + op_bootstrap
}
CircuitNode::UnaryOp { operand, .. } => self.count_bootstraps(operand),
CircuitNode::Compare { left, right, .. } => {
let left_bootstraps = self.count_bootstraps(left);
let right_bootstraps = self.count_bootstraps(right);
left_bootstraps + right_bootstraps + 1
}
}
}
fn constant_folding_pass(&mut self, node: CircuitNode) -> CircuitNode {
match node {
CircuitNode::BinaryOp { op, left, right } => {
let left = self.constant_folding_pass(*left);
let right = self.constant_folding_pass(*right);
if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
if let Some(result) = self.fold_binary_constants(op, l, r) {
self.stats.constants_folded += 1;
return CircuitNode::Constant(result);
}
}
if let Some(simplified) = self.apply_algebraic_identities(op, &left, &right) {
return simplified;
}
CircuitNode::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
}
}
CircuitNode::UnaryOp { op, operand } => {
let operand = self.constant_folding_pass(*operand);
if let CircuitNode::Constant(val) = &operand {
if let Some(result) = self.fold_unary_constant(op, val) {
self.stats.constants_folded += 1;
return CircuitNode::Constant(result);
}
}
CircuitNode::UnaryOp {
op,
operand: Box::new(operand),
}
}
CircuitNode::Compare { op, left, right } => {
let left = self.constant_folding_pass(*left);
let right = self.constant_folding_pass(*right);
CircuitNode::Compare {
op,
left: Box::new(left),
right: Box::new(right),
}
}
other => other,
}
}
fn fold_binary_constants(
&self,
op: BinaryOperator,
left: &CircuitValue,
right: &CircuitValue,
) -> Option<CircuitValue> {
match (left, right) {
(CircuitValue::U8(l), CircuitValue::U8(r)) => match op {
BinaryOperator::Add => Some(CircuitValue::U8(l.wrapping_add(*r))),
BinaryOperator::Sub => Some(CircuitValue::U8(l.wrapping_sub(*r))),
BinaryOperator::Mul => Some(CircuitValue::U8(l.wrapping_mul(*r))),
_ => None,
},
(CircuitValue::U16(l), CircuitValue::U16(r)) => match op {
BinaryOperator::Add => Some(CircuitValue::U16(l.wrapping_add(*r))),
BinaryOperator::Sub => Some(CircuitValue::U16(l.wrapping_sub(*r))),
BinaryOperator::Mul => Some(CircuitValue::U16(l.wrapping_mul(*r))),
_ => None,
},
(CircuitValue::U32(l), CircuitValue::U32(r)) => match op {
BinaryOperator::Add => Some(CircuitValue::U32(l.wrapping_add(*r))),
BinaryOperator::Sub => Some(CircuitValue::U32(l.wrapping_sub(*r))),
BinaryOperator::Mul => Some(CircuitValue::U32(l.wrapping_mul(*r))),
_ => None,
},
(CircuitValue::U64(l), CircuitValue::U64(r)) => match op {
BinaryOperator::Add => Some(CircuitValue::U64(l.wrapping_add(*r))),
BinaryOperator::Sub => Some(CircuitValue::U64(l.wrapping_sub(*r))),
BinaryOperator::Mul => Some(CircuitValue::U64(l.wrapping_mul(*r))),
_ => None,
},
(CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
BinaryOperator::And => Some(CircuitValue::Bool(*l && *r)),
BinaryOperator::Or => Some(CircuitValue::Bool(*l || *r)),
BinaryOperator::Xor => Some(CircuitValue::Bool(*l ^ *r)),
_ => None,
},
_ => None,
}
}
fn fold_unary_constant(&self, op: UnaryOperator, value: &CircuitValue) -> Option<CircuitValue> {
match (op, value) {
(UnaryOperator::Not, CircuitValue::Bool(v)) => Some(CircuitValue::Bool(!*v)),
_ => None,
}
}
fn apply_algebraic_identities(
&mut self,
op: BinaryOperator,
left: &CircuitNode,
right: &CircuitNode,
) -> Option<CircuitNode> {
match op {
BinaryOperator::Add => {
if Self::is_zero(right) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
if Self::is_zero(left) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
}
BinaryOperator::Sub => {
if Self::is_zero(right) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
}
BinaryOperator::Mul => {
if Self::is_zero(right) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
if Self::is_zero(left) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
if Self::is_one(right) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
if Self::is_one(left) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
}
BinaryOperator::And => {
if Self::is_true(right) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
if Self::is_true(left) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
if Self::is_false(right) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
if Self::is_false(left) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
}
BinaryOperator::Or => {
if Self::is_false(right) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
if Self::is_false(left) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
if Self::is_true(right) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
if Self::is_true(left) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
}
BinaryOperator::Xor => {
if Self::is_false(right) {
self.stats.gates_fused += 1;
return Some(left.clone());
}
if Self::is_false(left) {
self.stats.gates_fused += 1;
return Some(right.clone());
}
}
}
None
}
fn is_zero(node: &CircuitNode) -> bool {
matches!(
node,
CircuitNode::Constant(CircuitValue::U8(0))
| CircuitNode::Constant(CircuitValue::U16(0))
| CircuitNode::Constant(CircuitValue::U32(0))
| CircuitNode::Constant(CircuitValue::U64(0))
)
}
fn is_one(node: &CircuitNode) -> bool {
matches!(
node,
CircuitNode::Constant(CircuitValue::U8(1))
| CircuitNode::Constant(CircuitValue::U16(1))
| CircuitNode::Constant(CircuitValue::U32(1))
| CircuitNode::Constant(CircuitValue::U64(1))
)
}
fn is_true(node: &CircuitNode) -> bool {
matches!(node, CircuitNode::Constant(CircuitValue::Bool(true)))
}
fn is_false(node: &CircuitNode) -> bool {
matches!(node, CircuitNode::Constant(CircuitValue::Bool(false)))
}
fn gate_fusion_pass(&mut self, node: CircuitNode) -> CircuitNode {
match node {
CircuitNode::BinaryOp { op, left, right } => {
let left = self.gate_fusion_pass(*left);
let right = self.gate_fusion_pass(*right);
CircuitNode::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
}
}
CircuitNode::UnaryOp {
op: UnaryOperator::Not,
operand,
} => {
let operand = self.gate_fusion_pass(*operand);
if let CircuitNode::UnaryOp {
op: UnaryOperator::Not,
operand: inner,
} = operand
{
self.stats.gates_fused += 2; return *inner;
}
CircuitNode::UnaryOp {
op: UnaryOperator::Not,
operand: Box::new(operand),
}
}
CircuitNode::UnaryOp { op, operand } => {
let operand = self.gate_fusion_pass(*operand);
CircuitNode::UnaryOp {
op,
operand: Box::new(operand),
}
}
CircuitNode::Compare { op, left, right } => {
let left = self.gate_fusion_pass(*left);
let right = self.gate_fusion_pass(*right);
CircuitNode::Compare {
op,
left: Box::new(left),
right: Box::new(right),
}
}
other => other,
}
}
fn bootstrap_minimization_pass(&mut self, node: CircuitNode) -> Result<CircuitNode> {
Ok(self.reorder_for_bootstrap_efficiency(node))
}
#[allow(clippy::only_used_in_recursion)]
fn reorder_for_bootstrap_efficiency(&self, node: CircuitNode) -> CircuitNode {
match node {
CircuitNode::BinaryOp { op, left, right } => {
let left = self.reorder_for_bootstrap_efficiency(*left);
let right = self.reorder_for_bootstrap_efficiency(*right);
CircuitNode::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
}
}
CircuitNode::UnaryOp { op, operand } => {
let operand = self.reorder_for_bootstrap_efficiency(*operand);
CircuitNode::UnaryOp {
op,
operand: Box::new(operand),
}
}
CircuitNode::Compare { op, left, right } => {
let left = self.reorder_for_bootstrap_efficiency(*left);
let right = self.reorder_for_bootstrap_efficiency(*right);
CircuitNode::Compare {
op,
left: Box::new(left),
right: Box::new(right),
}
}
other => other,
}
}
fn dead_code_elimination_pass(&mut self, node: CircuitNode) -> CircuitNode {
let mut current = node;
loop {
let simplified = self.dce_simplify(current.clone());
if simplified == current {
break;
}
current = simplified;
}
current
}
fn dce_simplify(&mut self, node: CircuitNode) -> CircuitNode {
match node {
CircuitNode::BinaryOp { op, left, right } => {
let left = self.dce_simplify(*left);
let right = self.dce_simplify(*right);
if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
if let Some(result) = self.fold_binary_constants(op, l, r) {
self.stats.nodes_eliminated += 1;
self.stats.constants_folded += 1;
return CircuitNode::Constant(result);
}
}
if op == BinaryOperator::Sub && left == right {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return self.zero_like(&left);
}
if op == BinaryOperator::Xor && left == right {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return CircuitNode::Constant(CircuitValue::Bool(false));
}
match op {
BinaryOperator::Add => {
if Self::is_zero(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_zero(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
}
BinaryOperator::Sub => {
if Self::is_zero(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
}
BinaryOperator::Mul => {
if Self::is_zero(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
if Self::is_zero(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_one(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_one(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
}
BinaryOperator::And => {
if left == right {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_true(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_true(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
if Self::is_false(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
if Self::is_false(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
}
BinaryOperator::Or => {
if left == right {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_false(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_false(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
if Self::is_true(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
if Self::is_true(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
}
BinaryOperator::Xor => {
if Self::is_false(&right) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return left;
}
if Self::is_false(&left) {
self.stats.nodes_eliminated += 1;
self.stats.algebraic_simplifications += 1;
return right;
}
}
}
CircuitNode::BinaryOp {
op,
left: Box::new(left),
right: Box::new(right),
}
}
CircuitNode::UnaryOp { op, operand } => {
let operand = self.dce_simplify(*operand);
if let CircuitNode::Constant(val) = &operand {
if let Some(result) = self.fold_unary_constant(op, val) {
self.stats.nodes_eliminated += 1;
self.stats.constants_folded += 1;
return CircuitNode::Constant(result);
}
}
if op == UnaryOperator::Not {
if let CircuitNode::UnaryOp {
op: UnaryOperator::Not,
operand: inner,
} = operand
{
self.stats.nodes_eliminated += 2;
self.stats.algebraic_simplifications += 1;
return *inner;
}
}
if op == UnaryOperator::Neg {
if let CircuitNode::UnaryOp {
op: UnaryOperator::Neg,
operand: inner,
} = operand
{
self.stats.nodes_eliminated += 2;
self.stats.algebraic_simplifications += 1;
return *inner;
}
}
CircuitNode::UnaryOp {
op,
operand: Box::new(operand),
}
}
CircuitNode::Compare { op, left, right } => {
let left = self.dce_simplify(*left);
let right = self.dce_simplify(*right);
if let (CircuitNode::Constant(l), CircuitNode::Constant(r)) = (&left, &right) {
if let Some(result) = self.fold_comparison(op, l, r) {
self.stats.nodes_eliminated += 1;
self.stats.constants_folded += 1;
return CircuitNode::Constant(CircuitValue::Bool(result));
}
}
CircuitNode::Compare {
op,
left: Box::new(left),
right: Box::new(right),
}
}
other => other,
}
}
fn zero_like(&self, node: &CircuitNode) -> CircuitNode {
match node {
CircuitNode::Constant(CircuitValue::U8(_)) => {
CircuitNode::Constant(CircuitValue::U8(0))
}
CircuitNode::Constant(CircuitValue::U16(_)) => {
CircuitNode::Constant(CircuitValue::U16(0))
}
CircuitNode::Constant(CircuitValue::U32(_)) => {
CircuitNode::Constant(CircuitValue::U32(0))
}
CircuitNode::Constant(CircuitValue::U64(_)) => {
CircuitNode::Constant(CircuitValue::U64(0))
}
_ => CircuitNode::Constant(CircuitValue::U8(0)),
}
}
fn fold_comparison(
&self,
op: CompareOperator,
left: &CircuitValue,
right: &CircuitValue,
) -> Option<bool> {
match (left, right) {
(CircuitValue::U8(l), CircuitValue::U8(r)) => Some(self.compare_values(op, *l, *r)),
(CircuitValue::U16(l), CircuitValue::U16(r)) => Some(self.compare_values(op, *l, *r)),
(CircuitValue::U32(l), CircuitValue::U32(r)) => Some(self.compare_values(op, *l, *r)),
(CircuitValue::U64(l), CircuitValue::U64(r)) => Some(self.compare_values(op, *l, *r)),
(CircuitValue::Bool(l), CircuitValue::Bool(r)) => match op {
CompareOperator::Eq => Some(l == r),
CompareOperator::Ne => Some(l != r),
_ => None,
},
_ => None,
}
}
fn compare_values<T: PartialOrd + PartialEq>(&self, op: CompareOperator, l: T, r: T) -> bool {
match op {
CompareOperator::Eq => l == r,
CompareOperator::Ne => l != r,
CompareOperator::Lt => l < r,
CompareOperator::Le => l <= r,
CompareOperator::Gt => l > r,
CompareOperator::Ge => l >= r,
}
}
pub fn collect_live_variables(&self, node: &CircuitNode) -> HashSet<String> {
let mut live = HashSet::new();
self.mark_live_nodes(node, &mut live);
live
}
#[allow(clippy::only_used_in_recursion)]
fn mark_live_nodes(&self, node: &CircuitNode, live_nodes: &mut HashSet<String>) {
match node {
CircuitNode::Load(name) => {
live_nodes.insert(name.clone());
}
CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {}
CircuitNode::BinaryOp { left, right, .. } => {
self.mark_live_nodes(left, live_nodes);
self.mark_live_nodes(right, live_nodes);
}
CircuitNode::UnaryOp { operand, .. } => {
self.mark_live_nodes(operand, live_nodes);
}
CircuitNode::Compare { left, right, .. } => {
self.mark_live_nodes(left, live_nodes);
self.mark_live_nodes(right, live_nodes);
}
}
}
fn analyze_parallelism(&self, circuit: &Circuit) -> Result<DependencyGraph> {
let mut graph = DependencyGraph::new();
let mut node_id_map = HashMap::new();
let mut next_id = 0;
self.build_dependency_graph(&circuit.root, &mut graph, &mut node_id_map, &mut next_id);
graph.node_count = next_id;
graph.parallel_groups = self.identify_parallel_groups(&graph);
graph.critical_path = self.find_critical_path(&graph);
Ok(graph)
}
#[allow(clippy::only_used_in_recursion)]
fn build_dependency_graph(
&self,
node: &CircuitNode,
graph: &mut DependencyGraph,
node_id_map: &mut HashMap<String, NodeId>,
next_id: &mut usize,
) -> NodeId {
let current_id = NodeId(*next_id);
*next_id += 1;
match node {
CircuitNode::Load(name) => {
node_id_map.insert(name.clone(), current_id);
graph.dependencies.insert(current_id, Vec::new());
current_id
}
CircuitNode::Constant(_) | CircuitNode::EncryptedConstant { .. } => {
graph.dependencies.insert(current_id, Vec::new());
current_id
}
CircuitNode::BinaryOp { left, right, .. } => {
let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
graph
.dependencies
.insert(current_id, vec![left_id, right_id]);
current_id
}
CircuitNode::UnaryOp { operand, .. } => {
let operand_id = self.build_dependency_graph(operand, graph, node_id_map, next_id);
graph.dependencies.insert(current_id, vec![operand_id]);
current_id
}
CircuitNode::Compare { left, right, .. } => {
let left_id = self.build_dependency_graph(left, graph, node_id_map, next_id);
let right_id = self.build_dependency_graph(right, graph, node_id_map, next_id);
graph
.dependencies
.insert(current_id, vec![left_id, right_id]);
current_id
}
}
}
fn identify_parallel_groups(&self, graph: &DependencyGraph) -> Vec<Vec<NodeId>> {
let mut levels: HashMap<NodeId, usize> = HashMap::new();
let mut queue = VecDeque::new();
for (node_id, deps) in &graph.dependencies {
if deps.is_empty() {
levels.insert(*node_id, 0);
queue.push_back(*node_id);
}
}
while let Some(node_id) = queue.pop_front() {
let current_level = levels[&node_id];
for (dependent_id, deps) in &graph.dependencies {
if deps.contains(&node_id) {
let max_dep_level = deps
.iter()
.filter_map(|dep_id| levels.get(dep_id))
.max()
.copied()
.unwrap_or(0);
let dependent_level = max_dep_level + 1;
if !levels.contains_key(dependent_id) {
levels.insert(*dependent_id, dependent_level);
queue.push_back(*dependent_id);
}
}
}
}
let max_level = levels.values().max().copied().unwrap_or(0);
let mut parallel_groups = vec![Vec::new(); max_level + 1];
for (node_id, level) in levels {
parallel_groups[level].push(node_id);
}
for group in &mut parallel_groups {
group.sort();
}
parallel_groups
}
fn find_critical_path(&self, graph: &DependencyGraph) -> Vec<NodeId> {
let mut max_path = Vec::new();
for node_id in graph.dependencies.keys() {
let path = self.find_path_to_root(*node_id, graph);
if path.len() > max_path.len() {
max_path = path;
}
}
max_path
}
#[allow(clippy::only_used_in_recursion)]
fn find_path_to_root(&self, node_id: NodeId, graph: &DependencyGraph) -> Vec<NodeId> {
let deps = graph
.dependencies
.get(&node_id)
.map(|v| v.as_slice())
.unwrap_or(&[]);
if deps.is_empty() {
return vec![node_id];
}
let mut longest_path = Vec::new();
for dep_id in deps {
let dep_path = self.find_path_to_root(*dep_id, graph);
if dep_path.len() > longest_path.len() {
longest_path = dep_path;
}
}
longest_path.push(node_id);
longest_path
}
}
impl Default for CircuitOptimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compute::circuit::CircuitBuilder;
#[test]
fn test_constant_folding() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U8(5));
let b = builder.constant(CircuitValue::U8(3));
let sum = builder.add(a, b);
let circuit = Circuit::new(sum, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert!(matches!(
optimized.root,
CircuitNode::Constant(CircuitValue::U8(8))
));
assert!(optimizer.stats().constants_folded >= 1);
Ok(())
}
#[test]
fn test_constant_folding_sub() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U16(100));
let b = builder.constant(CircuitValue::U16(30));
let result = builder.sub(a, b);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U16(70)));
Ok(())
}
#[test]
fn test_constant_folding_mul() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U32(7));
let b = builder.constant(CircuitValue::U32(6));
let result = builder.mul(a, b);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U32(42)));
Ok(())
}
#[test]
fn test_constant_folding_bool_and() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let t = builder.constant(CircuitValue::Bool(true));
let f = builder.constant(CircuitValue::Bool(false));
let result = builder.and(t, f);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(
optimized.root,
CircuitNode::Constant(CircuitValue::Bool(false))
);
Ok(())
}
#[test]
fn test_constant_folding_unary_not() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let t = builder.constant(CircuitValue::Bool(true));
let result = builder.not(t);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(
optimized.root,
CircuitNode::Constant(CircuitValue::Bool(false))
);
Ok(())
}
#[test]
fn test_algebraic_x_plus_zero() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let add_zero = builder.add(x, zero);
let circuit = Circuit::new(add_zero, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_algebraic_zero_plus_x() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let result = builder.add(zero, x);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_algebraic_x_mul_one() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let one = builder.constant(CircuitValue::U8(1));
let result = builder.mul(x, one);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_algebraic_one_mul_x() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let one = builder.constant(CircuitValue::U8(1));
let result = builder.mul(one, x);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_algebraic_x_mul_zero() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let result = builder.mul(x, zero);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
Ok(())
}
#[test]
fn test_algebraic_zero_mul_x() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let result = builder.mul(zero, x);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
Ok(())
}
#[test]
fn test_algebraic_x_sub_zero() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let result = builder.sub(x, zero);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_algebraic_x_sub_x() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x1 = builder.load("x");
let x2 = builder.load("x");
let result = builder.sub(x1, x2);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(0)));
assert!(optimizer.stats().algebraic_simplifications >= 1);
Ok(())
}
#[test]
fn test_double_negation_elimination() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::Bool);
let x = builder.load("x");
let not_x = builder.not(x);
let not_not_x = builder.not(not_x);
let circuit = Circuit::new(not_not_x, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_quadruple_negation_elimination() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::Bool);
let x = builder.load("x");
let n1 = builder.not(x);
let n2 = builder.not(n1);
let n3 = builder.not(n2);
let n4 = builder.not(n3);
let circuit = Circuit::new(n4, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_nested_x_plus_0_times_1() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let one = builder.constant(CircuitValue::U8(1));
let add_zero = builder.add(x, zero);
let times_one = builder.mul(add_zero, one);
let circuit = Circuit::new(times_one, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_nested_complex_optimization() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder
.declare_variable("a", EncryptedType::U8)
.declare_variable("b", EncryptedType::U8);
let a = builder.load("a");
let b = builder.load("b");
let one = builder.constant(CircuitValue::U8(1));
let zero = builder.constant(CircuitValue::U8(0));
let five = builder.constant(CircuitValue::U8(5));
let a_times_1 = builder.mul(a, one);
let b_times_0 = builder.mul(b, zero);
let sum1 = builder.add(a_times_1, b_times_0);
let result = builder.add(sum1, five);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let original_gates = circuit.gate_count;
let optimized = optimizer.optimize(circuit)?;
assert!(optimized.gate_count < original_gates);
assert!(optimizer.stats().gate_reduction_percent() >= 30.0);
Ok(())
}
#[test]
fn test_noop_on_optimal_circuit() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder
.declare_variable("a", EncryptedType::U8)
.declare_variable("b", EncryptedType::U8);
let a = builder.load("a");
let b = builder.load("b");
let result = builder.add(a, b);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let original_gates = circuit.gate_count;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.gate_count, original_gates);
assert_eq!(
optimized.root,
CircuitNode::BinaryOp {
op: BinaryOperator::Add,
left: Box::new(CircuitNode::Load("a".to_string())),
right: Box::new(CircuitNode::Load("b".to_string())),
}
);
Ok(())
}
#[test]
fn test_noop_single_load() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let circuit = Circuit::new(x, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_stats_accuracy_constant_folding() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U8(5));
let b = builder.constant(CircuitValue::U8(3));
let two = builder.constant(CircuitValue::U8(2));
let sum = builder.add(a, b);
let result = builder.mul(sum, two);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Constant(CircuitValue::U8(16)));
assert!(optimizer.stats().constants_folded >= 2);
Ok(())
}
#[test]
fn test_stats_accuracy_algebraic() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x1 = builder.load("x");
let x2 = builder.load("x");
let result = builder.sub(x1, x2);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let _optimized = optimizer.optimize(circuit)?;
let (total_eliminated, total_algebraic, _total_folds) = optimizer.total_stats();
assert!(total_eliminated >= 1);
assert!(total_algebraic >= 1);
Ok(())
}
#[test]
fn test_optimization_stats() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U8(5));
let b = builder.constant(CircuitValue::U8(3));
let zero = builder.constant(CircuitValue::U8(0));
let sum = builder.add(a, b);
let add_zero = builder.add(sum, zero);
let circuit = Circuit::new(add_zero, HashMap::new())?;
let original_gates = circuit.gate_count;
let optimized = optimizer.optimize(circuit)?;
let optimized_gates = optimized.gate_count;
assert!(optimized_gates < original_gates);
assert!(optimizer.stats().gate_reduction_percent() > 0.0);
Ok(())
}
#[test]
fn test_total_stats_method() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let zero = builder.constant(CircuitValue::U8(0));
let one = builder.constant(CircuitValue::U8(1));
let add_zero = builder.add(x, zero);
let times_one = builder.mul(add_zero, one);
let circuit = Circuit::new(times_one, builder.variable_types_clone())?;
let _optimized = optimizer.optimize(circuit)?;
let (eliminated, algebraic, _folds) = optimizer.total_stats();
assert!(eliminated + algebraic >= 2);
Ok(())
}
#[test]
fn test_bootstrap_counting() -> Result<()> {
let optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder
.declare_variable("a", EncryptedType::U8)
.declare_variable("b", EncryptedType::U8);
let a = builder.load("a");
let b = builder.load("b");
let mul = builder.mul(a, b);
let circuit = Circuit::new(mul, builder.variable_types_clone())?;
let bootstrap_count = optimizer.count_bootstraps(&circuit.root);
assert_eq!(bootstrap_count, 1);
Ok(())
}
#[test]
fn test_parallelization_analysis() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder
.declare_variable("a", EncryptedType::U8)
.declare_variable("b", EncryptedType::U8)
.declare_variable("c", EncryptedType::U8);
let a = builder.load("a");
let b = builder.load("b");
let c = builder.load("c");
let sum1 = builder.add(a, b);
let sum2 = builder.add(sum1, c);
let circuit = Circuit::new(sum2, builder.variable_types_clone())?;
let _optimized = optimizer.optimize(circuit)?;
let graph = optimizer.dependency_graph();
assert!(graph.node_count > 0);
assert!(!graph.parallel_groups.is_empty());
Ok(())
}
#[test]
fn test_collect_live_variables() -> Result<()> {
let optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder
.declare_variable("a", EncryptedType::U8)
.declare_variable("b", EncryptedType::U8);
let a = builder.load("a");
let b = builder.load("b");
let result = builder.add(a, b);
let live = optimizer.collect_live_variables(&result);
assert!(live.contains("a"));
assert!(live.contains("b"));
assert_eq!(live.len(), 2);
Ok(())
}
#[test]
fn test_collect_live_variables_after_dce() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder
.declare_variable("a", EncryptedType::U8)
.declare_variable("b", EncryptedType::U8);
let a = builder.load("a");
let b = builder.load("b");
let one = builder.constant(CircuitValue::U8(1));
let zero = builder.constant(CircuitValue::U8(0));
let a1 = builder.mul(a, one);
let b0 = builder.mul(b, zero);
let result = builder.add(a1, b0);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
let live = optimizer.collect_live_variables(&optimized.root);
assert!(live.contains("a"));
assert!(!live.contains("b"), "b should be eliminated by DCE");
Ok(())
}
#[test]
fn test_comparison_constant_fold() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U8(10));
let b = builder.constant(CircuitValue::U8(5));
let result = builder.gt(a, b);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(
optimized.root,
CircuitNode::Constant(CircuitValue::Bool(true))
);
Ok(())
}
#[test]
fn test_comparison_constant_fold_eq() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let a = builder.constant(CircuitValue::U8(5));
let b = builder.constant(CircuitValue::U8(5));
let result = builder.eq(a, b);
let circuit = Circuit::new(result, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(
optimized.root,
CircuitNode::Constant(CircuitValue::Bool(true))
);
Ok(())
}
#[test]
fn test_xor_self_elimination() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::Bool);
let x1 = builder.load("x");
let x2 = builder.load("x");
let result = builder.xor(x1, x2);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(
optimized.root,
CircuitNode::Constant(CircuitValue::Bool(false))
);
Ok(())
}
#[test]
fn test_and_idempotent() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::Bool);
let x1 = builder.load("x");
let x2 = builder.load("x");
let result = builder.and(x1, x2);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_or_idempotent() -> Result<()> {
let mut optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::Bool);
let x1 = builder.load("x");
let x2 = builder.load("x");
let result = builder.or(x1, x2);
let circuit = Circuit::new(result, builder.variable_types_clone())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, CircuitNode::Load("x".to_string()));
Ok(())
}
#[test]
fn test_optimizer_does_not_fold_encrypted_constants() -> Result<()> {
use crate::compute::circuit::ConstantType;
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let enc_a = builder.encrypted_constant(vec![0x01, 0x05], ConstantType::Integer);
let enc_b = builder.encrypted_constant(vec![0x01, 0x03], ConstantType::Integer);
let sum = builder.add(enc_a.clone(), enc_b.clone());
let circuit = Circuit::new(sum, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
match &optimized.root {
CircuitNode::BinaryOp { op, left, right } => {
assert_eq!(*op, BinaryOperator::Add);
assert!(matches!(**left, CircuitNode::EncryptedConstant { .. }));
assert!(matches!(**right, CircuitNode::EncryptedConstant { .. }));
}
_ => {
return Err(AmateRSError::FheComputation(ErrorContext::new(
"Optimizer incorrectly folded encrypted constants".to_string(),
)));
}
}
assert_eq!(optimizer.stats().constants_folded, 0);
Ok(())
}
#[test]
fn test_optimizer_dce_treats_encrypted_constant_as_opaque() -> Result<()> {
use crate::compute::circuit::ConstantType;
let mut optimizer = CircuitOptimizer::new();
let enc = CircuitNode::EncryptedConstant {
data: vec![0x04, 0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x00, 0x11],
original_type: ConstantType::Integer,
};
let circuit = Circuit::new(enc.clone(), HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, enc);
Ok(())
}
#[test]
fn test_optimizer_mixed_plain_and_encrypted_constants() -> Result<()> {
use crate::compute::circuit::ConstantType;
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let plain_a = builder.constant(CircuitValue::U8(5));
let plain_b = builder.constant(CircuitValue::U8(3));
let plain_sum = builder.add(plain_a, plain_b);
let circuit = Circuit::new(plain_sum, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert!(matches!(
optimized.root,
CircuitNode::Constant(CircuitValue::U8(8))
));
let mut optimizer2 = CircuitOptimizer::new();
let enc_a = builder.encrypted_constant(vec![0x01, 0xAA], ConstantType::Integer);
let enc_b = builder.encrypted_constant(vec![0x01, 0xBB], ConstantType::Integer);
let enc_sum = builder.add(enc_a, enc_b);
let circuit2 = Circuit::new(enc_sum, HashMap::new())?;
let optimized2 = optimizer2.optimize(circuit2)?;
assert!(matches!(optimized2.root, CircuitNode::BinaryOp { .. }));
Ok(())
}
#[test]
fn test_optimizer_algebraic_identity_with_encrypted_constant() -> Result<()> {
use crate::compute::circuit::ConstantType;
let mut optimizer = CircuitOptimizer::new();
let builder = CircuitBuilder::new();
let enc = builder.encrypted_constant(
vec![0x04, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
ConstantType::Integer,
);
let zero = builder.constant(CircuitValue::U64(0));
let sum = builder.add(enc.clone(), zero);
let circuit = Circuit::new(sum, HashMap::new())?;
let optimized = optimizer.optimize(circuit)?;
assert_eq!(optimized.root, enc);
Ok(())
}
#[test]
fn test_optimizer_live_variables_with_encrypted_constants() -> Result<()> {
use crate::compute::circuit::ConstantType;
let optimizer = CircuitOptimizer::new();
let mut builder = CircuitBuilder::new();
builder.declare_variable("x", EncryptedType::U8);
let x = builder.load("x");
let enc = builder.encrypted_constant(vec![0x01, 0x10], ConstantType::Integer);
let sum = builder.add(x, enc);
let live = optimizer.collect_live_variables(&sum);
assert!(live.contains("x"));
assert_eq!(live.len(), 1);
Ok(())
}
}