arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::{
            boolean::{boolean_value::BooleanValue, utils::decoder_circuit},
            traits::arithmetic_circuit::ArithmeticCircuit,
        },
        expressions::{expr::EvalFailure, field_expr::FieldExpr},
        global_value::value::FieldValue,
    },
    traits::{GetBit, Select},
    utils::used_field::UsedField,
};
use num_traits::ToPrimitive;

/// One-hot encodes `x` into a boolean vector of length `len`.
///
/// For `0 <= x < len` the result holds a single `true` at position `x` and `false`
/// everywhere else. The bounds of `x` are used to keep the circuit small: only the
/// positions that `x` can actually reach are produced by a decoder; every other position
/// is the constant `false`.
///
/// If the bounds show that `x` can never lie in `0..len` (in particular when `len == 0`),
/// the result is all `false`. For other values of `x` outside `0..len` the output is not
/// specified and callers should not rely on it.
pub fn one_hot_encode<F: ActuallyUsedField>(x: FieldValue<F>, len: usize) -> Vec<BooleanValue> {
    // Nothing to encode into. This also keeps `len - 1` below from underflowing.
    if len == 0 {
        return Vec::new();
    }
    let bounds = x.bounds();
    let (min, max) = bounds.min_and_max(true);
    if max.is_lt_zero() || (min.is_ge_zero() && F::from(len as u64) <= min) {
        return vec![BooleanValue::from(false); len];
    }
    let mut res = if min.is_ge_zero() {
        // Here `0 <= min < len`, so `min` fits in a usize. Positions below `min` are
        // unreachable and stay `false`.
        let offset = min.to_unsigned_number().to_usize().unwrap();
        let mut encoded = vec![BooleanValue::from(false); offset];
        if min == max {
            // `x` is the constant `min`.
            encoded.push(BooleanValue::from(true));
        } else {
            // `x - min` ranges over `0..=gap`, where `gap` is capped both by the last
            // position (`len - 1 - min`) and by the width of the bounds (`max - min`).
            let x_sub = x - FieldValue::from(min);
            let gap = (len - 1 - offset).min(
                (max - min)
                    .to_unsigned_number()
                    .to_usize()
                    .unwrap_or(usize::MAX),
            );
            let bin_size = gap.max(1).ilog2() as usize + 1;
            let bits = (0..bin_size)
                .map(|i| x_sub.get_bit(i, false))
                .collect::<Vec<BooleanValue>>();
            // TODO: optimize the case where gap is a power of 2
            encoded.extend(decoder_circuit(bits).into_iter().take(gap + 1));
        }
        encoded
    } else {
        // `min < 0 <= max`: decode the low bits of `x` directly. Only positions up to
        // `min(len - 1, max)` are reachable.
        let bin_size = bounds.bin_size(true);
        let bits = (0..bin_size)
            .map(|i| x.get_bit(i, true))
            .collect::<Vec<BooleanValue>>();
        let last_reachable = F::from((len - 1) as u64)
            .min(max, false)
            .to_unsigned_number()
            .to_usize()
            .unwrap();
        let n_bits_to_keep = last_reachable.max(1).ilog2() as usize + 1;
        let n_outputs = F::from(len as u64)
            .min(max + F::ONE, false)
            .to_unsigned_number()
            .to_usize()
            .unwrap();
        // TODO: optimize the case where len - 1 is a power of 2
        decoder_circuit(
            bits.into_iter()
                .take(n_bits_to_keep)
                .collect::<Vec<BooleanValue>>(),
        )
        .into_iter()
        .take(n_outputs)
        .collect::<Vec<BooleanValue>>()
    };
    // Positions past the last reachable one are always `false`.
    res.resize(len, BooleanValue::from(false));
    res
}

/// Circuit for the indexing operation `container[index]`.
///
/// As an arithmetic circuit, its inputs are the container elements followed by the index,
/// and its single output is the selected element. Indexing out of range is undefined
/// behaviour (see the `eval` implementation).
#[derive(Clone, Debug, Default)]
pub struct Index;

impl Index {
    /// Builds the circuit for `container[index]`.
    ///
    /// The bounds of `index` are used to avoid the one-hot selection where possible:
    /// - if `index` can never lie in `0..container.len()` (always the case for an empty
    ///   container), the result is the constant zero;
    /// - if `index` is at most `0`, the first element is returned;
    /// - if `index` is at least `container.len() - 1`, the last element is returned;
    /// - otherwise `index` is one-hot encoded and the selected elements are summed.
    pub fn index<F: ActuallyUsedField>(
        container: Vec<FieldValue<F>>,
        index: FieldValue<F>,
    ) -> FieldValue<F> {
        let len = container.len();
        let bounds = index.bounds();
        let signed_min = bounds.signed_min();
        let signed_max = bounds.signed_max();
        if len == 0
            || signed_max.is_lt_zero()
            || (signed_min.is_ge_zero() && signed_min >= F::from(len as u64))
        {
            FieldValue::from(F::ZERO)
        } else if signed_max == F::ZERO {
            // The only in-range index is 0.
            container[0]
        } else if signed_min == F::from((len - 1) as u64) {
            // The only in-range index is `len - 1`.
            container[len - 1]
        } else {
            let index_ohe = one_hot_encode(index, len);
            let zero = FieldValue::<F>::from(0);
            container
                .into_iter()
                .zip(index_ohe)
                .fold(zero, |acc, (x, b)| acc + b.select(x, zero))
        }
    }
}

impl<F: UsedField> ArithmeticCircuit<F> for Index {
    /// Evaluates `container[index]` on clear values.
    ///
    /// The last input is the index; the others are the container elements. An index outside
    /// `0..container.len()` (including any index into an empty container) is reported as
    /// undefined behaviour.
    fn eval(&self, mut x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        let index = x.pop().unwrap();
        // Once the sign is checked, comparing against `len` (rather than subtracting
        // `len - 1`) cannot underflow for an empty container.
        if index.is_lt_zero() || index >= F::from(x.len() as u64) {
            return EvalFailure::err_ub("index out of range");
        }
        Ok(vec![x[index.to_unsigned_number().to_usize().unwrap()]])
    }

    /// Bounds of `container[index]`, following the same case split as [`Index::index`].
    fn bounds(&self, mut bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let bounds_index = bounds.pop().unwrap();
        let len = bounds.len();
        let signed_min = bounds_index.signed_min();
        let signed_max = bounds_index.signed_max();
        let zero = FieldBounds::new(F::ZERO, F::ZERO);
        let res_bounds = if len == 0
            || signed_max.is_lt_zero()
            || (signed_min.is_ge_zero() && signed_min >= F::from(len as u64))
        {
            // The index is never in range: the circuit outputs zero.
            zero
        } else if signed_max == F::ZERO {
            bounds[0]
        } else if signed_min == F::from((len - 1) as u64) {
            bounds[len - 1]
        } else {
            let init_bounds =
                if signed_min.is_lt_zero() || signed_max > F::from((len - 1) as u64) {
                    // In this case index_ohe might consist of 0's only, giving zero.
                    zero
                } else {
                    FieldBounds::Empty
                };
            bounds
                .into_iter()
                .fold(init_bounds, |acc, bds| acc.union(bds))
        };
        vec![res_bounds]
    }

    /// Builds the circuit: the last value is the index, the others are the container.
    fn run(&self, mut vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        let index = vals.pop().unwrap();
        vec![Self::index(vals, index)]
    }
}

/// Performs the operation `container[index] = op(container[index], value)`,
/// where `op(a, b)` is one of:
/// - op(a, b) = b
/// - op(a, b) = a + b
/// - op(a, b) = a - b
/// - op(a, b) = a * b
/// - op(a, b) = a / b
/// - op(a, b) = a % b
///
/// In `op`, the leaf `false` stands for the current element `a = container[index]` and the
/// leaf `true` for the assigned value `b = value`.
///
/// As an arithmetic circuit, the inputs are the container elements followed by `value` and
/// then `index`; the outputs are the updated container elements.
#[derive(Clone, Debug)]
pub struct IndexOpAssign<F: ActuallyUsedField> {
    /// Expression over the leaves `false` (current element) and `true` (assigned value).
    op: FieldExpr<F, bool>,
}

impl<F: ActuallyUsedField> IndexOpAssign<F> {
    /// Creates the circuit for `container[index] = op(container[index], value)`.
    // NOTE(review): `unused` is presumably allowed because this constructor is only
    // reached from some build configurations (e.g. tests) — TODO confirm.
    #[allow(unused)]
    pub fn new(op: FieldExpr<F, bool>) -> Self {
        IndexOpAssign { op }
    }

    /// Builds the circuit for `container[index] = op(container[index], value)` and returns
    /// the updated container.
    ///
    /// The bounds of `index` are used to avoid the one-hot selection where possible:
    /// - if `index` can never lie in `0..container.len()` (always the case for an empty
    ///   container), the container is returned unchanged;
    /// - if `index` is at most `0` (resp. at least `len - 1`), only the first (resp. last)
    ///   element is updated;
    /// - otherwise `index` is one-hot encoded. For `Add`/`Sub`, `op` is applied to every
    ///   element and the result is kept only at the selected position. For other ops, the
    ///   selected element is gathered once, `op` is applied once, and the result is written
    ///   back at the selected position.
    pub fn index_op_assign(
        &self,
        mut container: Vec<FieldValue<F>>,
        index: FieldValue<F>,
        value: FieldValue<F>,
    ) -> Vec<FieldValue<F>> {
        let len = container.len();
        let bounds = index.bounds();
        let signed_min = bounds.signed_min();
        let signed_max = bounds.signed_max();
        if len == 0
            || signed_max.is_lt_zero()
            || (signed_min.is_ge_zero() && signed_min >= F::from(len as u64))
        {
            return container;
        }
        if signed_max == F::ZERO {
            container[0] = self.apply_op(container[0], value);
            return container;
        }
        if signed_min == F::from((len - 1) as u64) {
            container[len - 1] = self.apply_op(container[len - 1], value);
            return container;
        }
        let index_ohe = one_hot_encode(index, len);
        match self.op {
            // These ops are cheap, so they are applied to every element directly.
            FieldExpr::Add(_, _) | FieldExpr::Sub(_, _) => container
                .into_iter()
                .zip(index_ohe)
                .map(|(x, b)| b.select(self.apply_op(x, value), x))
                .collect::<Vec<FieldValue<F>>>(),
            _ => {
                // Gather `container[index]` with the one-hot vector computed above, rather
                // than calling `Index::index`, which would encode `index` a second time.
                let zero = FieldValue::<F>::from(0);
                let current = container
                    .iter()
                    .copied()
                    .zip(index_ohe.iter().cloned())
                    .fold(zero, |acc, (x, b)| acc + b.select(x, zero));
                let new_value = self.apply_op(current, value);
                container
                    .into_iter()
                    .zip(index_ohe)
                    .map(|(x, b)| b.select(new_value, x))
                    .collect::<Vec<FieldValue<F>>>()
            }
        }
    }

    /// Builds the circuit value `op(current, value)`.
    fn apply_op(&self, current: FieldValue<F>, value: FieldValue<F>) -> FieldValue<F> {
        let operands = [current, value];
        FieldValue::new(self.op.clone().apply(|i| operands[i as usize]))
    }
}

impl<F: ActuallyUsedField> ArithmeticCircuit<F> for IndexOpAssign<F> {
    /// Evaluates the assignment on clear values.
    ///
    /// The inputs are the container elements followed by `value` and then `index`. The
    /// output is the updated container. An index outside `0..container.len()` is reported
    /// as undefined behaviour.
    fn eval(&self, mut x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        let index = x.pop().unwrap();
        let value = x.pop().unwrap();
        // Reject negative indices explicitly, as `Index::eval` does, so that the conversion
        // to `usize` below only ever sees an in-range, non-negative index.
        if index.is_lt_zero() || index >= F::from(x.len() as u64) {
            return EvalFailure::err_ub("index out of range");
        }
        let index = index.to_unsigned_number().to_usize().unwrap();
        let operands = [x[index], value];
        x[index] = self.op.clone().apply(|i| operands[i as usize]).eval()?;
        Ok(x)
    }

    /// Bounds of the updated container, following the same case split as
    /// [`IndexOpAssign::index_op_assign`].
    fn bounds(&self, mut bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let bounds_index = bounds.pop().unwrap();
        let bounds_value = bounds.pop().unwrap();
        let len = bounds.len();
        let signed_min = bounds_index.signed_min();
        let signed_max = bounds_index.signed_max();
        // Bounds of `op(current, value)` for an element with bounds `current`.
        let apply_op = |current: FieldBounds<F>| {
            let operands = [current, bounds_value];
            self.op.clone().apply(|i| operands[i as usize]).bounds()
        };
        if len == 0
            || signed_max.is_lt_zero()
            || (signed_min.is_ge_zero() && signed_min >= F::from(len as u64))
        {
            // The index is never in range: the container is left unchanged.
            bounds
        } else if signed_max == F::ZERO {
            bounds[0] = apply_op(bounds[0]);
            bounds
        } else if signed_min == F::from((len - 1) as u64) {
            bounds[len - 1] = apply_op(bounds[len - 1]);
            bounds
        } else {
            let bit_bounds = FieldBounds::from((F::ZERO, F::ONE));
            // If the index may lie outside `0..len`, index_ohe might consist of 0's only;
            // zero is then kept in every element's bounds.
            let may_be_out_of_range =
                signed_min.is_lt_zero() || signed_max > F::from((len - 1) as u64);
            bounds
                .into_iter()
                .map(|x_bounds| {
                    let selected = FieldExpr::bounds(FieldExpr::Where(
                        bit_bounds,
                        apply_op(x_bounds),
                        x_bounds,
                    ));
                    if may_be_out_of_range {
                        selected.union(FieldBounds::new(F::ZERO, F::ZERO))
                    } else {
                        selected
                    }
                })
                .collect::<Vec<FieldBounds<F>>>()
        }
    }

    /// Builds the circuit: the last value is the index, the one before it is the assigned
    /// value, and the rest form the container.
    fn run(&self, mut vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        let index = vals.pop().unwrap();
        let value = vals.pop().unwrap();
        self.index_op_assign(vals, index, value)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{
            circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
            expressions::field_expr::expr_lincomb,
        },
        utils::field::ScalarField,
    };
    use rand::Rng;

    /// Samples the total number of circuit inputs. The base is 9, or 300 with probability
    /// 1/8. Groups of 3 are then added while a biased coin (p = 7/8) keeps coming up true.
    fn sample_n_inputs<R: Rng + ?Sized>(rng: &mut R) -> usize {
        let base: usize = if rng.gen_bool(0.125) { 300 } else { 9 };
        let extra_groups = std::iter::repeat_with(|| rng.gen_bool(0.875))
            .take_while(|&keep_going| keep_going)
            .count();
        base + 3 * extra_groups
    }

    impl TestedArithmeticCircuit<ScalarField> for Index {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            Self
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            sample_n_inputs(rng)
        }
    }

    impl TestedArithmeticCircuit<ScalarField> for IndexOpAssign<ScalarField> {
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            // TODO: improve so that all different functions are generated.
            let op: FieldExpr<ScalarField, bool> = match rng.gen_range(0..6) {
                // container[index] = val
                0 => expr_lincomb!((true, 1);0),
                // container[index] += val
                1 => FieldExpr::Add(false, true),
                // container[index] -= val
                2 => FieldExpr::Sub(false, true),
                // container[index] *= val
                3 => FieldExpr::Mul(false, true),
                // container[index] /= val
                4 => FieldExpr::Div(false, true),
                // container[index] %= val
                5 => FieldExpr::Rem(false, true),
                _ => unreachable!("variant more than 5 should not be possible"),
            };
            Self::new(op)
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            sample_n_inputs(rng)
        }
    }

    #[test]
    fn tested_index() {
        Index::test(64, 4)
    }

    #[test]
    fn tested_index_op_assign() {
        IndexOpAssign::test(64, 4)
    }
}