arcium-core-utils 0.4.1

Arcium core utils
Documentation
use std::{
    collections::HashMap,
    hash::{DefaultHasher, Hasher},
};

use itertools::Itertools;
use primitives::{algebra::elliptic_curve::Curve, izip_eq};

use crate::circuit::{
    errors::{CircuitError, InternalError},
    Circuit,
    Gate,
    GateIndex,
    Slice,
};

/// A wrapper for building circuits with lazy instantiation of batch/unbatch gates.
///
/// The circuit builder allows adding new gates and circuit outputs.
///
/// Instead of returning a `GateIndex` when adding a new gate, the builder returns a `Wire`. A
/// `Wire` is a list of (gate index, output index) pairs describing a batch of values that may come
/// from several gates. One can extract a sub-range or a slice from a wire, or merge several wires,
/// without touching the circuit. Wires are lazily materialized when they are used as inputs to a
/// new gate or as circuit outputs. The "materialization" is performed by adding the required
/// `ExtractFromBatch`/`CollectToBatch` gates to the circuit.
///
/// Additionally, the circuit builder tracks the gates added so far and reuses an existing gate
/// when an identical one is added again. `Input`, `Random` and `DaBit` gates are exempt from this
/// reuse and are always added as new gates.
#[derive(Default)]
pub(super) struct CircuitBuilder<C: Curve> {
    circuit: Circuit<C>,
    // Gate to index mapping for fast lookup, used to avoid adding duplicate gates. `Input`,
    // `Random` and `DaBit` gates are never inserted here (they are never deduplicated).
    gate_to_index: HashMap<Gate<C>, GateIndex>,
    // Cache of slices built from long index lists (see `get_slice_for_indices`), keyed by a hash of
    // the list after shifting it so that its smallest index is zero. It avoids optimizing the same
    // slice multiple times. Because the key is only a hash, a cached slice is reused only after its
    // indices have been compared against the requested ones.
    indices_id_to_slice: HashMap<u64, Slice>,
}

/// Internal structure to represent a list of (gate index, output index) pairs.
///
/// A pair `(g, i)` denotes the `i`-th value of the batch output of gate `g`. The order of the pairs
/// is the order of the values in the (not yet materialized) batch the wire stands for.
#[derive(Clone, Debug, Default)]
pub(super) struct Wire(Vec<(GateIndex, u32)>);

impl Wire {
    /// Creates a wire over the outputs `start..end` of the gate `gate_idx`.
    ///
    /// # Errors
    /// Returns an `InternalError` if the range is empty, i.e. `end <= start`.
    pub fn range(gate_idx: GateIndex, start: u32, end: u32) -> Result<Self, InternalError> {
        if end <= start {
            return Err(InternalError(format!("Invalid range [{start}; {end})")));
        }
        Ok(Self((start..end).map(|output| (gate_idx, output)).collect()))
    }

    /// Concatenates `wires`, in order, into a single wire.
    pub fn merge(wires: Vec<Wire>) -> Self {
        let mut pairs = Vec::new();
        for wire in wires {
            pairs.extend(wire.0);
        }
        Self(pairs)
    }

    /// Returns `true` if the wire carries no values.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Returns the sub-wire made of the positions `new_start..new_end` of this wire.
    ///
    /// The bounds are positions within this wire, not output indices of the underlying gates. E.g.
    /// for a wire over the outputs `[1, 2, 3, 4]` of some gate, `new_start = 1` and `new_end = 3`
    /// select the second and third elements, i.e. the outputs `[2, 3]`.
    ///
    /// # Errors
    /// Returns an `InternalError` if the range is empty or extends past the end of the wire.
    pub fn extract_range(&self, new_start: u32, new_end: u32) -> Result<Self, InternalError> {
        let (new_start, new_end) = (new_start as usize, new_end as usize);
        match self.0.get(new_start..new_end) {
            Some(selected) if new_start < new_end => Ok(Wire(selected.to_vec())),
            _ => Err(InternalError(format!(
                "Invalid or out-of-bounds extract range [{new_start}; {new_end}) for wire of length {:?}", self.0.len(),
            ))),
        }
    }

    /// Returns the wire made of the positions of this wire selected by `slice`, in the order given
    /// by `slice.get_indices()`.
    ///
    /// # Errors
    /// Returns an `InternalError` if `slice` is empty or selects a position past the end of the
    /// wire.
    pub fn extract_slice(self, slice: Slice) -> Result<Self, InternalError> {
        if slice.is_empty() {
            return Err(InternalError(format!("Unexpected empty slice: {slice:?}.")));
        }

        let Self(pairs) = self;
        let positions = slice.get_indices();
        let all_in_bounds = positions.iter().all(|&pos| (pos as usize) < pairs.len());
        if !all_in_bounds {
            return Err(InternalError(format!(
                "Slice indices out of bounds: {slice:?} for wire of length {:?}.",
                pairs.len(),
            )));
        }

        Ok(Self(
            positions
                .into_iter()
                .map(|pos| pairs[pos as usize])
                .collect(),
        ))
    }

    /// Splits the wire into maximal runs of consecutive pairs that share the same gate, returning
    /// each run as `(gate index, output indices in wire order)`.
    ///
    /// The output indices of a run are not necessarily contiguous or sorted, and the same gate may
    /// appear in several runs if its pairs are not adjacent in the wire.
    ///
    /// # Errors
    /// Returns an `InternalError` if the wire is empty.
    fn group_by_gate(self) -> Result<Vec<(GateIndex, Vec<u32>)>, InternalError> {
        let Self(pairs) = self;
        if pairs.is_empty() {
            return Err(InternalError("Unexpected empty wire".to_string()));
        }

        let mut groups: Vec<(GateIndex, Vec<u32>)> = Vec::new();
        for (gate_idx, output) in pairs {
            // Extend the current run when the gate repeats; otherwise start a new run.
            if let Some((last_gate, outputs)) = groups.last_mut() {
                if *last_gate == gate_idx {
                    outputs.push(output);
                    continue;
                }
            }
            groups.push((gate_idx, vec![output]));
        }

        Ok(groups)
    }
}

impl<C: Curve> CircuitBuilder<C> {
    /// Adds `gate` with the given input wires and returns a wire over all of the gate's outputs.
    ///
    /// `ExtractFromBatch` and `CollectToBatch` gates are not added to the circuit. They are applied
    /// directly to the input wire(s), and the resulting wire is returned; the corresponding gates
    /// are only created when that wire is later materialized. For every other gate, each input
    /// wire is materialized into a single gate index. Those indices replace the gate's inputs (via
    /// `Gate::try_replace_inputs`), and the gate is added, reusing an identical existing gate if
    /// there is one.
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - an `ExtractFromBatch` gate does not receive exactly one input wire;
    /// - a `CollectToBatch` gate does not receive one input wire per collected wire;
    /// - an input wire is empty or cannot be sliced;
    /// - replacing the inputs or adding the gate fails.
    pub fn add_gate_new_inputs(
        &mut self,
        gate: Gate<C>,
        mut inputs: Vec<Wire>,
    ) -> Result<Wire, CircuitError<C>> {
        // If the gate is an extract from a batch, we can directly return a Wire without adding
        // a new gate to the circuit.
        if let Gate::ExtractFromBatch { slice, .. } = gate {
            if inputs.len() != 1 {
                return Err(InternalError(format!(
                    "ExtractFromBatch gate should have exactly one input, got {}",
                    inputs.len()
                ))
                .into());
            }
            // Move the single input out instead of deep-copying the whole wire.
            let input = inputs.pop().expect("inputs has exactly one element");
            return Ok(input.extract_slice(slice)?);
        }

        // If the gate is a collect to batch, we can directly return a Wire without adding a new
        // gate to the circuit.
        if let Gate::CollectToBatch { wires } = gate {
            if wires.len() != inputs.len() {
                return Err(InternalError(format!(
                    "CollectToBatch gate should have exactly {} inputs, got {}",
                    wires.len(),
                    inputs.len()
                ))
                .into());
            }
            return Ok(Wire::merge(inputs));
        }

        let inputs: Vec<_> = inputs
            .into_iter()
            .map(|wire| self.materialize_wire(wire))
            .try_collect()?;

        let new_gate = gate.try_replace_inputs(inputs)?;
        let index = self.add_gate(new_gate)?;
        self.full_output_wire(index)
    }

    /// Adds a `Random` or `DaBit` gate and returns a wire over all of its outputs.
    ///
    /// # Errors
    /// Returns an error if `gate` is any other kind of gate, or if adding it fails.
    pub fn add_randomness_gate(&mut self, gate: Gate<C>) -> Result<Wire, CircuitError<C>> {
        if !matches!(gate, Gate::Random { .. } | Gate::DaBit { .. }) {
            return Err(
                InternalError(format!("Expected gate with no inputs, found: {gate:?}")).into(),
            );
        }

        let index = self.add_gate(gate)?;
        self.full_output_wire(index)
    }

    /// Materializes `wire` and registers the resulting gate as a circuit output.
    ///
    /// # Errors
    /// Returns an error if `wire` is empty or if materializing it or adding the output fails.
    pub fn add_output(&mut self, wire: Wire) -> Result<(), CircuitError<C>> {
        let index = self.materialize_wire(wire)?;
        self.circuit.add_output(index)
    }

    /// Consumes the builder and returns the circuit built so far.
    pub fn into_circuit(self) -> Circuit<C> {
        self.circuit
    }

    /// Returns a wire over every output of the gate at `index`, i.e. `0..batch_size`.
    fn full_output_wire(&self, index: GateIndex) -> Result<Wire, CircuitError<C>> {
        let batch_size = self.circuit.gate_output_unchecked(index).batch_size;
        Ok(Wire::range(index, 0, batch_size)?)
    }

    /// Adds `gate` to the circuit and returns its index, reusing the index of an identical gate
    /// added earlier when possible.
    ///
    /// `Input`, `Random` and `DaBit` gates are never deduplicated: each call adds a new gate.
    fn add_gate(&mut self, gate: Gate<C>) -> Result<GateIndex, CircuitError<C>> {
        // Gates with no inputs are always added to the circuit.
        // NOTE(review): presumably because each such gate denotes a distinct value even when it is
        // structurally equal to another one.
        if let Gate::Input { .. } | Gate::Random { .. } | Gate::DaBit { .. } = gate {
            return self.circuit.add_gate(gate);
        }

        // Check if the gate is already present in the circuit.
        if let Some(&index) = self.gate_to_index.get(&gate) {
            return Ok(index);
        }

        // Otherwise, add the gate to the circuit and remember it for later reuse.
        let index = self.circuit.add_gate(gate.clone())?;
        self.gate_to_index.insert(gate, index);
        Ok(index)
    }

    /// Turns `wire` into a single gate index whose batch output holds exactly the wire's values,
    /// adding `ExtractFromBatch`/`CollectToBatch` gates as needed.
    ///
    /// Consecutive values coming from the same gate are grouped:
    /// - a group covering all outputs of its gate, in order, uses that gate directly;
    /// - any other group goes through an `ExtractFromBatch` gate.
    ///
    /// If more than one group remains, they are joined with a `CollectToBatch` gate. All added gates
    /// go through `add_gate` and are therefore deduplicated.
    ///
    /// # Errors
    /// Returns an error if `wire` is empty or if adding a gate fails.
    fn materialize_wire(&mut self, wire: Wire) -> Result<GateIndex, CircuitError<C>> {
        if wire.is_empty() {
            return Err(InternalError("Unexpected empty wire".to_string()).into());
        }

        let groups = wire.group_by_gate()?;
        let mut gates = Vec::with_capacity(groups.len());
        for (x, indices) in groups {
            // If we extract all gate outputs in order, we can use the gate index directly.
            // Otherwise, add an extract gate.
            let batch_size = self.circuit.gate_output_unchecked(x).batch_size as usize;
            let is_whole_output = indices.len() == batch_size
                && indices.iter().enumerate().all(|(i, &j)| i == j as usize);
            let gate = if is_whole_output {
                x
            } else {
                let slice = self.get_slice_for_indices(indices);
                self.add_gate(Gate::ExtractFromBatch { x, slice })?
            };
            gates.push(gate);
        }

        // If there is only a single gate, return its index directly. Otherwise, add a batch
        // collect gate.
        if gates.len() == 1 {
            Ok(gates[0])
        } else {
            self.add_gate(Gate::CollectToBatch { wires: gates })
        }
    }

    /// Builds the `Slice` selecting `indices`, in order.
    ///
    /// Lists shorter than `THRESHOLD` are converted directly with `Slice::from_indices`. Longer
    /// lists are first shifted so that their smallest index is zero, then looked up in
    /// `indices_id_to_slice` by a hash of the shifted list. On a miss, the slice is built and
    /// cached. In both cases the returned slice is shifted back with `Slice::shift_start`.
    fn get_slice_for_indices(&mut self, mut indices: Vec<u32>) -> Slice {
        fn hash_indices(indices: &[u32]) -> u64 {
            let mut hasher = DefaultHasher::new();
            indices.iter().for_each(|idx| hasher.write_u32(*idx));
            hasher.finish()
        }

        // If the slice is small, optimize it directly and do not cache it. Otherwise, caching the
        // optimized slice is empirically more efficient.
        const THRESHOLD: usize = 100;
        if indices.len() < THRESHOLD {
            return Slice::from_indices(indices);
        }

        // Shift the indices by the smallest index because shifted slices are equivalent.
        let shift = *indices.iter().min().expect("unexpected empty indices");
        if shift > 0 {
            indices.iter_mut().for_each(|idx| *idx -= shift);
        }

        // Try to find in the cache the optimized slice for the shifted indices. The key is only a
        // hash, so a hit must be confirmed against the actual indices. The length check runs first
        // because `izip_eq!` expects equal-length inputs: a collision between lists of different
        // lengths must be a cache miss, not a failure.
        let id = hash_indices(&indices);
        if let Some(cached) = self.indices_id_to_slice.get(&id) {
            let cached_indices = cached.get_indices();
            if cached_indices.len() == indices.len()
                && izip_eq!(cached_indices, &indices).all(|(x, y)| x == *y)
            {
                let mut slice = cached.clone();
                slice.shift_start(shift);
                return slice;
            }
        }

        // Cache miss (or hash collision): optimize the slice and cache it, overwriting any
        // colliding entry.
        let mut slice = Slice::from_indices(indices);
        self.indices_id_to_slice.insert(id, slice.clone());
        slice.shift_start(shift);
        slice
    }
}