// arcis-compiler 0.9.7
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
// Documentation
use crate::{
    core::{
        actually_used_field::ActuallyUsedField,
        bounds::FieldBounds,
        circuits::traits::general_circuit::GeneralCircuit,
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::{GetBit, Select},
    utils::field::{BaseField, ScalarField},
};

/// Stateless description of the field-conversion circuit.
///
/// The type has no fields, so all instances compare equal and `Default` simply yields the
/// unit value. The conversion logic itself lives in the `impl` blocks of this type.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct ConversionCircuit;

impl ConversionCircuit {
    /// Signedness flag handed to `get_bit` by [`Self::convert_mpc`].
    ///
    /// Always `false`.
    fn signed(&self) -> bool {
        false
    }

    /// Converts a plaintext element of field `F` into field `G`.
    ///
    /// The value is first turned into whatever `to_unsigned_number` returns and that number
    /// is then converted into `G` via `Into`.
    fn convert_plaintext<F: ActuallyUsedField, G: ActuallyUsedField>(&self, f_val: F) -> G {
        let unsigned = f_val.to_unsigned_number();
        unsigned.into()
    }

    /// Translates bounds known for an `F` value into bounds for the converted `G` value.
    ///
    /// The endpoints come from `to_unsigned_number_pair`. When the distance between them is
    /// strictly below `G::modulus()` the endpoints are converted into `G` and used as the new
    /// bounds; otherwise the result is `FieldBounds::All`.
    fn convert_bounds<F: ActuallyUsedField, G: ActuallyUsedField>(
        &self,
        f_val: FieldBounds<F>,
    ) -> FieldBounds<G> {
        let (lower, upper) = f_val.to_unsigned_number_pair();
        let span = &upper - &lower;
        let fits_in_target = span < G::modulus();
        if !fits_in_target {
            return FieldBounds::All;
        }
        FieldBounds::new(lower.into(), upper.into())
    }

    /// Converts a circuit value over `F` into a circuit value over `G`.
    ///
    /// For each of the `F::NUM_BITS` bit positions, the bit is read with `get_bit` and used
    /// to select between `G::power_of_two(position)` and zero; the selected terms are
    /// accumulated into the result.
    fn convert_mpc<F: ActuallyUsedField, G: ActuallyUsedField>(
        &self,
        f_val: FieldValue<F>,
    ) -> FieldValue<G> {
        let bit_count = F::NUM_BITS as usize;
        (0usize..bit_count).fold(FieldValue::<G>::from(0), |mut acc, position| {
            let bit = f_val.get_bit(position, self.signed());
            // Arguments are built in the same order as before: weight first, then zero.
            acc += bit.select(
                FieldValue::from(G::power_of_two(position)),
                FieldValue::<G>::from(0),
            );
            acc
        })
    }
}

impl GeneralCircuit for ConversionCircuit {
    /// Plaintext evaluation of the conversion.
    ///
    /// Every scalar input is converted into a base-field element and every base input into a
    /// scalar-field element. The returned tuple is `(scalar outputs, base outputs)`, i.e. the
    /// scalar outputs are the converted `bases` and the base outputs are the converted
    /// `scalars`.
    ///
    /// This implementation always returns `Ok`.
    fn eval(
        &self,
        scalars: Vec<ScalarField>,
        bases: Vec<BaseField>,
    ) -> Result<(Vec<ScalarField>, Vec<BaseField>), EvalFailure> {
        // The explicit element types pin the generic target field of `convert_plaintext`.
        // Without them the target is inferred from the tuple position, and placing the
        // vectors in the wrong slots silently degrades to a same-field conversion.
        let new_bases: Vec<BaseField> = scalars
            .into_iter()
            .map(|scalar| self.convert_plaintext(scalar))
            .collect();
        let new_scalars: Vec<ScalarField> = bases
            .into_iter()
            .map(|base| self.convert_plaintext(base))
            .collect();
        Ok((new_scalars, new_bases))
    }

    /// Bounds propagation matching [`Self::eval`].
    ///
    /// Bounds of scalar inputs become bounds of the base outputs and bounds of base inputs
    /// become bounds of the scalar outputs. The result is `(scalar bounds, base bounds)`.
    fn bounds(
        &self,
        scalar_bounds: Vec<FieldBounds<ScalarField>>,
        base_bounds: Vec<FieldBounds<BaseField>>,
    ) -> (Vec<FieldBounds<ScalarField>>, Vec<FieldBounds<BaseField>>) {
        // Explicit types for the same reason as in `eval`.
        let new_base_bounds: Vec<FieldBounds<BaseField>> = scalar_bounds
            .into_iter()
            .map(|bounds| self.convert_bounds(bounds))
            .collect();
        let new_scalar_bounds: Vec<FieldBounds<ScalarField>> = base_bounds
            .into_iter()
            .map(|bounds| self.convert_bounds(bounds))
            .collect();
        (new_scalar_bounds, new_base_bounds)
    }

    /// Circuit-value counterpart of [`Self::eval`].
    ///
    /// Scalar-field values are converted into base-field values and base-field values into
    /// scalar-field values. The result is `(scalar values, base values)`.
    fn run(
        &self,
        scalar_vals: Vec<FieldValue<ScalarField>>,
        base_vals: Vec<FieldValue<BaseField>>,
    ) -> (Vec<FieldValue<ScalarField>>, Vec<FieldValue<BaseField>>) {
        // Explicit types for the same reason as in `eval`.
        let new_base_vals: Vec<FieldValue<BaseField>> = scalar_vals
            .into_iter()
            .map(|val| self.convert_mpc(val))
            .collect();
        let new_scalar_vals: Vec<FieldValue<ScalarField>> = base_vals
            .into_iter()
            .map(|val| self.convert_mpc(val))
            .collect();
        (new_scalar_vals, new_base_vals)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::general_circuit::tests::TestedGeneralCircuit;
    use rand::Rng;

    /// Draws a small random vector length.
    ///
    /// Starts at 1 when a `gen_bool(0.75)` draw is true (0 otherwise), then adds 1 for each
    /// consecutive `gen_bool(0.25)` draw that is true.
    fn gen_vec_len<R: Rng + ?Sized>(rng: &mut R) -> usize {
        let mut len = usize::from(rng.gen_bool(0.75));
        loop {
            if !rng.gen_bool(0.25) {
                break len;
            }
            len += 1;
        }
    }

    impl TestedGeneralCircuit for ConversionCircuit {
        /// The circuit has no parameters, so the RNG is not consulted.
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            Self
        }

        /// Number of scalar inputs, drawn with `gen_vec_len`.
        fn gen_n_scalars<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
            gen_vec_len(rng)
        }

        /// Number of base inputs, drawn with `gen_vec_len` independently of `_n_scalars`.
        fn gen_n_bases<R: Rng + ?Sized>(&self, rng: &mut R, _n_scalars: usize) -> usize {
            gen_vec_len(rng)
        }
    }

    /// Runs the shared `TestedGeneralCircuit` harness for this circuit.
    #[test]
    fn tested() {
        ConversionCircuit::test(1, 16)
    }
}