use std::{
collections::HashMap,
hash::{DefaultHasher, Hasher},
};
use itertools::Itertools;
use primitives::{algebra::elliptic_curve::Curve, izip_eq};
use crate::circuit::{
errors::{CircuitError, InternalError},
Circuit,
Gate,
GateIndex,
Slice,
};
/// Incrementally builds a [`Circuit`], reusing identical gates and caching long slices.
// NOTE(review): `#[derive(Default)]` puts an implicit `C: Default` bound on the generated
// `Default` impl — confirm every curve type used here implements `Default`.
#[derive(Default)]
pub(super) struct CircuitBuilder<C: Curve> {
    /// The circuit under construction; returned by `into_circuit`.
    circuit: Circuit<C>,
    /// Index of every gate already added, so adding an identical gate again returns the
    /// existing index instead of a new one (`Input`, `Random` and `DaBit` gates bypass this).
    gate_to_index: HashMap<Gate<C>, GateIndex>,
    /// Cache of slices built from long index lists, keyed by a hash of the zero-based
    /// (min-subtracted) indices.
    indices_id_to_slice: HashMap<u64, Slice>,
}
/// An ordered list of individual gate outputs, each given as a `(gate index, output position)`
/// pair. A wire may reference outputs of several different gates.
#[derive(Clone, Debug, Default)]
pub(super) struct Wire(Vec<(GateIndex, u32)>);
impl Wire {
pub fn range(gate_idx: GateIndex, start: u32, end: u32) -> Result<Self, InternalError> {
if start >= end {
return Err(InternalError(format!("Invalid range [{start}; {end})")));
}
let indices = (start..end).map(|idx| (gate_idx, idx)).collect_vec();
Ok(Self(indices))
}
pub fn merge(wires: Vec<Wire>) -> Self {
Self(wires.into_iter().flat_map(|wire| wire.0).collect())
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn extract_range(&self, new_start: u32, new_end: u32) -> Result<Self, InternalError> {
let (new_start, new_end) = (new_start as usize, new_end as usize);
if new_start < new_end && new_end <= self.0.len() {
Ok(Wire(self.0[new_start..new_end].to_vec()))
} else {
Err(InternalError(format!(
"Invalid or out-of-bounds extract range [{new_start}; {new_end}) for wire of length {:?}", self.0.len(),
)))
}
}
pub fn extract_slice(self, slice: Slice) -> Result<Self, InternalError> {
if slice.is_empty() {
return Err(InternalError(format!("Unexpected empty slice: {slice:?}.")));
}
let gate_index_pairs = self.0;
let indices = slice.get_indices();
if indices
.iter()
.any(|&idx| idx as usize >= gate_index_pairs.len())
{
return Err(InternalError(format!(
"Slice indices out of bounds: {slice:?} for wire of length {:?}.",
gate_index_pairs.len(),
)));
}
let output_gate_index_pairs = indices
.into_iter()
.map(|idx| gate_index_pairs[idx as usize])
.collect_vec();
Ok(Self(output_gate_index_pairs))
}
fn group_by_gate(self) -> Result<Vec<(GateIndex, Vec<u32>)>, InternalError> {
let gate_index_pairs = self.0;
if gate_index_pairs.is_empty() {
return Err(InternalError("Unexpected empty wire".to_string()));
}
let mut output_ranges = vec![];
let mut curr_gate_idx = gate_index_pairs[0].0;
let mut curr_group = vec![gate_index_pairs[0].1];
for (gate_idx, idx) in gate_index_pairs.into_iter().skip(1) {
if gate_idx == curr_gate_idx {
curr_group.push(idx);
} else {
output_ranges.push((curr_gate_idx, curr_group));
curr_gate_idx = gate_idx;
curr_group = vec![idx];
}
}
output_ranges.push((curr_gate_idx, curr_group));
Ok(output_ranges)
}
}
impl<C: Curve> CircuitBuilder<C> {
pub fn add_gate_new_inputs(
&mut self,
gate: Gate<C>,
inputs: Vec<Wire>,
) -> Result<Wire, CircuitError<C>> {
if let Gate::ExtractFromBatch { slice, .. } = gate {
if inputs.len() != 1 {
return Err(InternalError(format!(
"ExtractFromBatch gate should have exactly one input, got {}",
inputs.len()
))
.into());
}
let wire = inputs[0].clone().extract_slice(slice)?;
return Ok(wire);
}
if let Gate::CollectToBatch { wires } = gate {
if wires.len() != inputs.len() {
return Err(InternalError(format!(
"CollectToBatch gate should have exactly {} inputs, got {}",
wires.len(),
inputs.len()
))
.into());
}
let wire = Wire::merge(inputs);
return Ok(wire);
}
let inputs: Vec<_> = inputs
.into_iter()
.map(|wire| self.materialize_wire(wire))
.try_collect()?;
let new_gate = gate.try_replace_inputs(inputs)?;
let index = self.add_gate(new_gate)?;
Ok(Wire::range(
index,
0,
self.circuit.gate_output_unchecked(index).batch_size,
)?)
}
pub fn add_randomness_gate(&mut self, gate: Gate<C>) -> Result<Wire, CircuitError<C>> {
if !matches!(gate, Gate::Random { .. } | Gate::DaBit { .. }) {
return Err(
InternalError(format!("Expected gate with no inputs, found: {gate:?}")).into(),
);
}
let index = self.add_gate(gate)?;
Ok(Wire::range(
index,
0,
self.circuit.gate_output_unchecked(index).batch_size,
)?)
}
pub fn add_output(&mut self, wire: Wire) -> Result<(), CircuitError<C>> {
let index = self.materialize_wire(wire)?;
self.circuit.add_output(index)
}
pub fn into_circuit(self) -> Circuit<C> {
self.circuit
}
fn add_gate(&mut self, gate: Gate<C>) -> Result<GateIndex, CircuitError<C>> {
if let Gate::Input { .. } | Gate::Random { .. } | Gate::DaBit { .. } = gate {
return self.circuit.add_gate(gate);
}
if let Some(index) = self.gate_to_index.get(&gate) {
return Ok(*index);
}
let index = self.circuit.add_gate(gate.clone())?;
self.gate_to_index.insert(gate, index);
Ok(index)
}
fn materialize_wire(&mut self, wire: Wire) -> Result<GateIndex, CircuitError<C>> {
if wire.is_empty() {
return Err(InternalError("Unexpected empty wire".to_string()).into());
}
let groups = wire.group_by_gate()?;
let mut gates = Vec::new();
for (x, indices) in groups {
if indices.len() == self.circuit.gate_output_unchecked(x).batch_size as usize
&& indices.iter().enumerate().all(|(i, j)| i == *j as usize)
{
gates.push(x);
} else {
let slice = self.get_slice_for_indices(indices);
let x = self.add_gate(Gate::ExtractFromBatch { x, slice })?;
gates.push(x);
}
}
let idx = if gates.len() == 1 {
gates[0]
} else {
self.add_gate(Gate::CollectToBatch { wires: gates })?
};
Ok(idx)
}
fn get_slice_for_indices(&mut self, mut indices: Vec<u32>) -> Slice {
#[inline]
fn hash_indices(indices: &[u32]) -> u64 {
let mut hasher = DefaultHasher::new();
indices.iter().for_each(|idx| hasher.write_u32(*idx));
hasher.finish()
}
const THRESHOLD: usize = 100;
if indices.len() < THRESHOLD {
Slice::from_indices(indices)
} else {
let shift = *indices.iter().min().expect("unexpected empty indices");
if shift > 0 {
indices.iter_mut().for_each(|idx| {
*idx -= shift;
});
}
let id = hash_indices(&indices);
if let Some(slice) = self.indices_id_to_slice.get(&id) {
if izip_eq!(slice.get_indices(), &indices).all(|(x, y)| x == *y) {
let mut slice = slice.clone();
slice.shift_start(shift);
return slice;
}
}
let mut slice = Slice::from_indices(indices);
self.indices_id_to_slice.insert(id, slice.clone());
slice.shift_start(shift);
slice
}
}
}