use std::collections::BTreeMap;
use itertools::izip;
use primitives::{algebra::elliptic_curve::Curve, izip_eq};
use crate::circuit::{
batcher::builder::{CircuitBuilder, Wire},
errors::CircuitError,
AlgebraicType,
BitPlaintextBinaryOp,
BitPlaintextUnaryOp,
BitShareBinaryOp,
BitShareUnaryOp,
Circuit,
FieldPlaintextBinaryOp,
FieldPlaintextUnaryOp,
FieldShareBinaryOp,
FieldShareUnaryOp,
Gate,
GateExt,
GateIndex,
GateLevel,
PointPlaintextBinaryOp,
PointPlaintextUnaryOp,
PointShareBinaryOp,
PointShareUnaryOp,
ShareOrPlaintext,
};
pub fn batch_circuit<C: Curve>(circuit: &Circuit<C>) -> Result<Circuit<C>, CircuitError<C>> {
let gate_groups = group_into_batches(circuit);
let mut circuit_builder = CircuitBuilder::<C>::default();
let mut old_to_new_idx = vec![Wire::default(); circuit.nb_gates() as usize];
for ((_, gate_batch_type), gates_to_batch) in gate_groups {
if let GateBatchType::Unbatched = gate_batch_type {
for gate_old_idx in gates_to_batch {
let gate = circuit.gate_ext_unchecked(gate_old_idx).gate.clone();
let gate_inps_new = gate
.get_inputs()
.into_iter()
.map(|old_idx| old_to_new_idx[old_idx as usize].clone())
.collect();
old_to_new_idx[gate_old_idx as usize] =
circuit_builder.add_gate_new_inputs(gate, gate_inps_new)?;
}
} else {
let mut gate = circuit.gate_unchecked(gates_to_batch[0]).clone();
let batched_output = if let Gate::Random { batch_size, .. }
| Gate::DaBit { batch_size, .. } = &mut gate
{
*batch_size = gates_to_batch
.iter()
.map(|gate_idx| circuit.gate_output_unchecked(*gate_idx).batch_size)
.sum::<u32>();
circuit_builder.add_randomness_gate(gate)?
} else {
let mut inputs_to_batch = vec![vec![]; gate.get_inputs().len()];
gates_to_batch.iter().for_each(|gate_idx_old| {
let gate_input_indices_old =
circuit.gate_ext_unchecked(*gate_idx_old).gate.get_inputs();
izip!(&mut inputs_to_batch, gate_input_indices_old).for_each(
|(gate_inputs_new, input_idx_old)| {
gate_inputs_new.push(old_to_new_idx[input_idx_old as usize].clone())
},
)
});
let batched_inputs = inputs_to_batch.into_iter().map(Wire::merge).collect();
circuit_builder.add_gate_new_inputs(gate, batched_inputs)?
};
let mut start = 0u32;
for output_idx_old in gates_to_batch {
let batch_size = circuit.gate_output_unchecked(output_idx_old).batch_size;
let end = start + batch_size;
old_to_new_idx[output_idx_old as usize] =
batched_output.extract_range(start, end)?;
start = end;
}
}
}
circuit.iter_output_indices().try_for_each(|output| {
circuit_builder.add_output(old_to_new_idx[*output as usize].clone())
})?;
Ok(circuit_builder.into_circuit())
}
/// Partitions the gates of `circuit` into groups keyed by `(level, batch type)`.
///
/// Every gate is classified with `GateBatchType::new` and grouped with the other gates on the
/// same level that have an equal batch type.
///
/// A group that ends up with a single gate gains nothing from batching. It is re-keyed as
/// `GateBatchType::Unbatched` and merged into that level's unbatched group.
///
/// The returned map iterates in `(level, batch type)` order, and every group lists its gate
/// indices in ascending order.
fn group_into_batches<C: Curve>(
    circuit: &Circuit<C>,
) -> BTreeMap<(GateLevel, GateBatchType), Vec<GateIndex>> {
    let mut groups = BTreeMap::<(GateLevel, GateBatchType), Vec<GateIndex>>::new();
    for (index, gate_ext) in izip_eq!(0..circuit.nb_gates(), circuit.iter_gates_ext()) {
        let batch_type = GateBatchType::new(gate_ext.clone(), circuit);
        groups
            .entry((gate_ext.level, batch_type))
            .or_default()
            .push(index);
    }

    let mut groups_filtered = BTreeMap::<(GateLevel, GateBatchType), Vec<GateIndex>>::new();
    for ((level, batch_type), group) in groups {
        // A lone gate gains nothing from batching; treat it as unbatched.
        let batch_type = if group.len() == 1 {
            GateBatchType::Unbatched
        } else {
            batch_type
        };
        groups_filtered
            .entry((level, batch_type))
            .or_default()
            .extend(group);
    }

    // Singletons are appended to their level's unbatched group in batch-type key order, not in
    // gate order. Re-sort every group so that `batch_circuit` re-adds gates in the same relative
    // order as the original circuit. Groups that were already sorted are unaffected.
    for group in groups_filtered.values_mut() {
        group.sort_unstable();
    }
    groups_filtered
}
/// Key deciding which gates of one level may be merged into a single batched gate.
///
/// Gates on the same level with equal `GateBatchType` values are grouped together by
/// `group_into_batches`. Besides the operation itself, a variant records the operand properties
/// that must agree for a merge to be valid. Depending on the variant, these are the algebraic
/// type and/or whether an operand is a share or a plaintext (`ShareOrPlaintext`). The values are
/// filled in by `GateBatchType::new`.
///
/// NOTE: the derived `Ord` follows variant declaration order, and this type is part of the
/// `BTreeMap` key in `group_into_batches`. Reordering variants therefore changes the order in
/// which the groups of one level are processed. `Unbatched` is declared last, so it sorts after
/// every batched group of the same level.
#[derive(Debug, PartialOrd, Ord, Eq, PartialEq)]
enum GateBatchType {
    /// `Gate::Random`, keyed by the algebraic type stored on the gate.
    Random {
        algebraic_type: AlgebraicType,
    },
    /// `Gate::FieldShareUnaryOp`, keyed by op and by the algebraic type of operand `x`.
    FieldShareUnaryOp {
        op: FieldShareUnaryOp,
        algebraic_type: AlgebraicType,
    },
    /// `Gate::FieldShareBinaryOp`, keyed by op, the algebraic type of `x`, and the form of `y`.
    FieldShareBinaryOp {
        op: FieldShareBinaryOp,
        algebraic_type: AlgebraicType,
        y_form: ShareOrPlaintext,
    },
    /// `Gate::BitShareUnaryOp`, keyed by op only.
    BitShareUnaryOp {
        op: BitShareUnaryOp,
    },
    /// `Gate::BitShareBinaryOp`, keyed by op and the form of operand `y`.
    BitShareBinaryOp {
        op: BitShareBinaryOp,
        y_form: ShareOrPlaintext,
    },
    /// `Gate::PointShareUnaryOp`, keyed by op only.
    PointShareUnaryOp {
        op: PointShareUnaryOp,
    },
    /// `Gate::PointShareBinaryOp`, keyed by op and the forms of operands `p` and `y`.
    PointShareBinaryOp {
        op: PointShareBinaryOp,
        p_form: ShareOrPlaintext,
        y_form: ShareOrPlaintext,
    },
    /// `Gate::FieldPlaintextUnaryOp`, keyed by op and the algebraic type of operand `x`.
    FieldPlaintextUnaryOp {
        op: FieldPlaintextUnaryOp,
        algebraic_type: AlgebraicType,
    },
    /// `Gate::FieldPlaintextBinaryOp`, keyed by op and the algebraic type of operand `x`.
    FieldPlaintextBinaryOp {
        op: FieldPlaintextBinaryOp,
        algebraic_type: AlgebraicType,
    },
    /// `Gate::BitPlaintextUnaryOp`, keyed by op only.
    BitPlaintextUnaryOp {
        op: BitPlaintextUnaryOp,
    },
    /// `Gate::BitPlaintextBinaryOp`, keyed by op only.
    BitPlaintextBinaryOp {
        op: BitPlaintextBinaryOp,
    },
    /// `Gate::PointPlaintextUnaryOp`, keyed by op only.
    PointPlaintextUnaryOp {
        op: PointPlaintextUnaryOp,
    },
    /// `Gate::PointPlaintextBinaryOp`, keyed by op only.
    PointPlaintextBinaryOp {
        op: PointPlaintextBinaryOp,
    },
    /// `Gate::DaBit`, keyed by the gate's `field_type` converted into an `AlgebraicType`.
    DaBit {
        algebraic_type: AlgebraicType,
    },
    /// `Gate::GetDaBitFieldShare`, keyed by the algebraic type of operand `x`.
    GetDaBitFieldShare {
        algebraic_type: AlgebraicType,
    },
    /// `Gate::GetDaBitSharedBit`, keyed by the algebraic type of operand `x`.
    GetDaBitSharedBit {
        algebraic_type: AlgebraicType,
    },
    /// `Gate::BitPlaintextToField`, keyed by the algebraic type of the gate's own output.
    BitPlaintextToField {
        algebraic_type: AlgebraicType,
    },
    /// `Gate::FieldPlaintextToBit`, keyed by the algebraic type of operand `x`.
    FieldPlaintextToBit {
        algebraic_type: AlgebraicType,
    },
    /// Gates that are never merged. `group_into_batches` also places here any gate whose batch
    /// group would contain only that gate.
    Unbatched,
}
impl GateBatchType {
fn new<C: Curve>(GateExt { gate, output, .. }: GateExt<C>, circuit: &Circuit<C>) -> Self {
match gate {
Gate::Input(_) => GateBatchType::Unbatched,
Gate::Constant(_) => GateBatchType::Unbatched,
Gate::Random { algebraic_type, .. } => GateBatchType::Random { algebraic_type },
Gate::FieldShareUnaryOp { x, op } => GateBatchType::FieldShareUnaryOp {
op,
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
},
Gate::FieldShareBinaryOp { x, y, op } => GateBatchType::FieldShareBinaryOp {
op,
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
y_form: circuit.gate_output_unchecked(y).form,
},
Gate::BatchSummation { .. } => GateBatchType::Unbatched,
Gate::BitShareUnaryOp { op, .. } => GateBatchType::BitShareUnaryOp { op },
Gate::BitShareBinaryOp { y, op, .. } => GateBatchType::BitShareBinaryOp {
op,
y_form: circuit.gate_output_unchecked(y).form,
},
Gate::PointShareUnaryOp { op, .. } => GateBatchType::PointShareUnaryOp { op },
Gate::PointShareBinaryOp { p: x, y, op, .. } => GateBatchType::PointShareBinaryOp {
op,
p_form: circuit.gate_output_unchecked(x).form,
y_form: circuit.gate_output_unchecked(y).form,
},
Gate::FieldPlaintextUnaryOp { x, op, .. } => GateBatchType::FieldPlaintextUnaryOp {
op,
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
},
Gate::FieldPlaintextBinaryOp { x, op, .. } => GateBatchType::FieldPlaintextBinaryOp {
op,
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
},
Gate::BitPlaintextUnaryOp { op, .. } => GateBatchType::BitPlaintextUnaryOp { op },
Gate::BitPlaintextBinaryOp { op, .. } => GateBatchType::BitPlaintextBinaryOp { op },
Gate::PointPlaintextUnaryOp { op, .. } => GateBatchType::PointPlaintextUnaryOp { op },
Gate::PointPlaintextBinaryOp { op, .. } => GateBatchType::PointPlaintextBinaryOp { op },
Gate::DaBit { field_type, .. } => GateBatchType::DaBit {
algebraic_type: field_type.into(),
},
Gate::GetDaBitFieldShare { x } => GateBatchType::GetDaBitFieldShare {
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
},
Gate::GetDaBitSharedBit { x } => GateBatchType::GetDaBitSharedBit {
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
},
Gate::BaseFieldPow { .. } => GateBatchType::Unbatched,
Gate::BitPlaintextToField { .. } => GateBatchType::BitPlaintextToField {
algebraic_type: output.algebraic_type,
},
Gate::FieldPlaintextToBit { x, .. } => GateBatchType::FieldPlaintextToBit {
algebraic_type: circuit.gate_output_unchecked(x).algebraic_type,
},
Gate::ExtractFromBatch { .. } => GateBatchType::Unbatched,
Gate::CollectToBatch { .. } => GateBatchType::Unbatched,
Gate::PointFromPlaintextExtendedEdwards { .. } => GateBatchType::Unbatched,
Gate::PlaintextPointToExtendedEdwards { .. } => GateBatchType::Unbatched,
Gate::PlaintextKeccakF1600 { .. } => GateBatchType::Unbatched,
Gate::CompressPlaintextPoint { .. } => GateBatchType::Unbatched,
Gate::KeyRecoveryPlaintextComputeErrors { .. } => GateBatchType::Unbatched,
}
}
}
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use primitives::algebra::elliptic_curve::Curve25519Ristretto as C;

    use crate::circuit::{
        batcher::batch_circuit,
        tests::{create_add_tree_circuit, create_mul_tree_circuit},
        AlgebraicType,
        Circuit,
        FieldShareBinaryOp,
        Gate,
        Input,
    };

    /// Builds `h = ((x0 + x1) + (x2 + x3)) * (x4 + x5) + x6 * x7`.
    ///
    /// The inputs are eight single-element Mersenne107 `SecretPlaintext` values. The result has
    /// 8 input gates, 7 `FieldShareBinaryOp` gates and one output.
    fn create_mixed_circuit() -> Circuit<C> {
        let mut circuit = Circuit::<C>::new();
        let xs = (0..8)
            .map(|_| {
                let input = Input::SecretPlaintext {
                    inputer: 0,
                    algebraic_type: AlgebraicType::Mersenne107,
                    batch_size: 1,
                };
                circuit.add_gate(Gate::Input(input)).unwrap()
            })
            .collect_vec();

        // Every arithmetic gate below is a `FieldShareBinaryOp`; only operands and op vary.
        let mut binary = |x, y, op| {
            circuit
                .add_gate(Gate::FieldShareBinaryOp { x, y, op })
                .unwrap()
        };
        let (a, b, c) = xs[..6]
            .chunks_exact(2)
            .map(|pair| binary(pair[0], pair[1], FieldShareBinaryOp::Add))
            .collect_tuple()
            .unwrap();
        let d = binary(xs[6], xs[7], FieldShareBinaryOp::Mul);
        let e = binary(a, b, FieldShareBinaryOp::Add);
        let f = binary(e, c, FieldShareBinaryOp::Mul);
        let h = binary(f, d, FieldShareBinaryOp::Add);

        circuit.add_output(h).unwrap();
        circuit
    }

    #[test]
    fn test_batcher_mul_tree() {
        let depth = 4;
        let tree = create_mul_tree_circuit::<C>(depth);
        assert_eq!(tree.nb_inputs(), 1 << depth);
        assert_eq!(tree.nb_gates(), (1 << (depth + 1)) - 1);
        assert_eq!(tree.nb_outputs(), 1);
        let _batched = batch_circuit(&tree).expect("batching a mul tree must succeed");
    }

    #[test]
    fn test_batcher_add_tree() {
        let depth = 4;
        let tree = create_add_tree_circuit::<C>(depth);
        assert_eq!(tree.nb_inputs(), 1 << depth);
        assert_eq!(tree.nb_gates(), (1 << (depth + 1)) - 1);
        assert_eq!(tree.nb_outputs(), 1);
        let _batched = batch_circuit(&tree).expect("batching an add tree must succeed");
    }

    #[test]
    fn test_batcher_mixed_circuit() {
        let mixed = create_mixed_circuit();
        assert_eq!(mixed.nb_inputs(), 8);
        assert_eq!(mixed.nb_gates(), 15);
        assert_eq!(mixed.nb_outputs(), 1);
        let _batched = batch_circuit(&mixed).expect("batching the mixed circuit must succeed");
    }
}