arcis-compiler 0.9.7

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        bounds::FieldBounds,
        circuits::{
            boolean::{
                boolean_value::BooleanValue,
                byte::Byte,
                utils::{addition_circuit, CircuitType},
            },
            f64::utils::F64,
            traits::f64_circuit::F64Circuit,
        },
        expressions::expr::EvalFailure,
        global_value::value::FieldValue,
    },
    traits::{FromLeBits, GetBit},
    utils::{field::BaseField, used_field::UsedField},
};
use core::panic;
use ff::Field;

/// The constant -1023 (the binary64 exponent bias, negated) as a 12-bit two's-complement
/// number, least-significant bit first: 0b1100_0000_0001 = 3073 = 4096 - 1023.
#[allow(dead_code)]
const EXPONENT_OFFSET: [bool; 12] = [
    true, // bit 0
    false, false, false, false, false, false, false, false, false, // bits 1..=9
    true, // bit 10
    true, // bit 11 (two's-complement sign bit)
];

/// Parameter-free descriptor of the float64 multiplication circuit; the actual
/// circuit is built by `F64Mul::mul`.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct F64Mul;

impl F64Mul {
    /// Builds the circuit for the product of two `F64` values.
    ///
    /// * The result sign is the XOR of the operand signs.
    /// * The stored mantissas are fraction fields without the implicit leading one, so `2^52`
    ///   is added to each before multiplying. The 53x53-bit product is then normalised back to
    ///   53 bits and the leading one is removed again.
    /// * The product mantissa is **truncated**: the discarded low bits are not used for
    ///   rounding, so the result may differ from IEEE round-to-nearest.
    /// * The biased result exponent is `e_lhs + e_rhs - 1023`, plus one when the mantissa
    ///   product is 106 bits long.
    ///
    /// This function does not treat zero, subnormals, infinities or NaN specially, and it
    /// does not detect exponent overflow or underflow. `F64Circuit::eval` below classifies
    /// such inputs as undefined behaviour.
    ///
    /// # Panics
    ///
    /// Panics if the exponent sum does not have the expected length (see the message below).
    // NOTE(review): `mul` is called from `F64Circuit::run` in this file; the dead_code allow
    // may be a leftover — confirm.
    #[allow(dead_code)]
    pub fn mul(lhs: F64, rhs: F64) -> F64 {
        let sign_res = lhs.sign ^ rhs.sign;

        // Restore the implicit leading one: mantissas become 53-bit integers in [2^52, 2^53).
        let mantissa_lhs = FieldValue::from(BaseField::power_of_two(52)) + lhs.mantissa;
        let mantissa_rhs = FieldValue::from(BaseField::power_of_two(52)) + rhs.mantissa;
        // the product of two 53-bit numbers is either 105 or 106 bits long
        let mantissa_prod = mantissa_lhs * mantissa_rhs;
        let mantissa_prod_bits = (0..106)
            .map(|i| mantissa_prod.get_bit(i, false))
            .collect::<Vec<BooleanValue>>();
        // is_full_length if mantissa_prod is 106 bits long
        let is_full_length = *mantissa_prod_bits.last().unwrap();
        // is_not_full_length if mantissa_prod is 105 bits long
        let is_not_full_length = !is_full_length;
        // this is the least significant bit in case is_not_full_length
        let lsb_bit = mantissa_prod_bits[52];
        let lsb_offset = is_not_full_length & lsb_bit;
        // Bits 53..106 of the product; bits below 52/53 are simply dropped (truncation).
        let mantissa_before_correction = FieldValue::<BaseField>::from_le_bits(
            mantissa_prod_bits
                .into_iter()
                .skip(53)
                .collect::<Vec<BooleanValue>>(),
            false,
        );
        // correction_term = 1 if is_full_length and correction_term = 2 if is_not_full_length
        let correction_term =
            FieldValue::<BaseField>::from_le_bits(vec![is_full_length, is_not_full_length], false);
        // mantissa is 53 bits long with a leading 1; it corresponds to mantissa_prod >> 53 if
        // is_full_length and mantissa_prod >> 52 if is_not_full_length
        let mantissa = mantissa_before_correction * correction_term
            + FieldValue::<BaseField>::from(lsb_offset);
        // Drop the leading one again; the stored fraction lies in [0, 2^52 - 1].
        let mantissa_res = (mantissa + FieldValue::from(BaseField::negative_power_of_two(52)))
            .with_bounds(FieldBounds::new(
                BaseField::ZERO,
                BaseField::power_of_two(52) - BaseField::ONE,
            ));

        // The exponent is an unsigned 11-bit integer, which has to be seen with a fixed
        // offset of -1023. Since we add two 'raw' exponents and then subtract
        // the offset we can consider the sum as a 12-bit signed integer.
        let mut exponent_lhs = lhs.exponent.to_vec();
        exponent_lhs.push(BooleanValue::from(false));
        let mut exponent_rhs = rhs.exponent.to_vec();
        exponent_rhs.push(BooleanValue::from(false));
        // TODO: expose add_bitnums_circuit(a, b, c)
        // NOTE(review): the sums below are presumably taken modulo 2^12 (the result is
        // expected to be 12 bits long, see the length check further down).
        let tmp = addition_circuit(
            exponent_lhs,
            EXPONENT_OFFSET
                .into_iter()
                .map(BooleanValue::from)
                .collect::<Vec<BooleanValue>>(),
            // if is_full_length we have been right-shifting mantissa_prod by 53 and not 52
            // positions, hence we need to increase EXPONENT_OFFSET by 1 to compensate
            is_full_length,
            CircuitType::default(),
        );
        let mut exponent_res = addition_circuit(
            tmp,
            exponent_rhs,
            BooleanValue::from(false),
            CircuitType::default(),
        );
        // NOTE(review): despite its name, this is the most significant bit of the 12-bit
        // (little-endian) exponent sum, i.e. its two's-complement sign bit. It is discarded,
        // so an out-of-range exponent is not detected here.
        let _is_nan = exponent_res.pop();

        // The remaining 11 bits are the biased exponent of the product.
        let exponent_res = exponent_res
            .try_into()
            .unwrap_or_else(|v: Vec<BooleanValue>| {
                panic!("Expected a Vec of length 11 (found {})", v.len())
            });

        F64::new(sign_res, exponent_res, mantissa_res)
    }
}

impl F64Circuit for F64Mul {
    /// Plain-`f64` reference implementation of the multiplication circuit.
    ///
    /// For `x = [lhs, rhs]`, returns `Ok(vec![lhs * rhs])`. It returns an "undefined
    /// behaviour" failure instead when the inputs fall outside the domain that
    /// [`F64Mul::mul`] handles:
    ///
    /// * either operand is not a normal number (zero, subnormal, infinity or NaN), or
    /// * the product's exponent could leave the normal range.
    ///
    /// # Panics
    ///
    /// Panics if `x` does not contain exactly two elements.
    fn eval(&self, x: Vec<f64>) -> Result<Vec<f64>, EvalFailure> {
        let [lhs, rhs]: [f64; 2] = x
            .try_into()
            .unwrap_or_else(|_: Vec<f64>| panic!("F64Mul expects input Vec of length 2"));

        // Biased 11-bit exponent field of an f64: bits 52..63 of its little-endian bit
        // pattern (bit 63, the sign, is excluded).
        let biased_exponent = |v: f64| -> i16 {
            v.to_le_bytes()
                .into_iter()
                .flat_map(|byte| Byte::from(byte).to_vec())
                .skip(52)
                .take(11)
                .enumerate()
                .fold(0i16, |acc, (i, bit)| if bit { acc | (1 << i) } else { acc })
        };
        let exponent_lhs = biased_exponent(lhs);
        let exponent_rhs = biased_exponent(rhs);
        // Biased exponent of the product before normalisation. `F64Mul::mul` adds one more
        // when the 53x53-bit mantissa product is 106 bits long.
        let exponent_res = exponent_lhs + exponent_rhs - 1023;

        // IEEE 754 binary64 reserves only the biased exponents 0 (zero/subnormal) and 2047
        // (infinity/NaN); normal numbers use 1..=2046. The circuit assumes an implicit
        // leading one, so both inputs must be normal. The product's exponent is restricted
        // to 1..=2045 so that it is still normal after the possible +1 normalisation step.
        let is_normal = |e: i16| (1..=2046).contains(&e);
        if !is_normal(exponent_lhs)
            || !is_normal(exponent_rhs)
            || !(1..=2045).contains(&exponent_res)
        {
            return EvalFailure::err_ub("inputs or product out of range");
        }
        Ok(vec![lhs * rhs])
    }

    /// Relative tolerance used when comparing the circuit output with `eval`.
    ///
    /// NOTE(review): `F64Mul::mul` truncates the mantissa product instead of rounding it.
    /// It can therefore differ from the natively rounded product by up to one ulp, which is
    /// presumably why the tolerance is 2^-52.
    fn rtol(&self) -> f64 {
        2f64.powi(-52)
    }

    /// Builds the multiplication circuit for `vals = [lhs, rhs]`.
    ///
    /// # Panics
    ///
    /// Panics if `vals` does not contain exactly two elements.
    fn run(&self, vals: Vec<F64>) -> Vec<F64> {
        // Moving the operands out of the Vec avoids cloning them.
        let [lhs, rhs]: [F64; 2] = vals
            .try_into()
            .unwrap_or_else(|_: Vec<F64>| panic!("F64Mul expects input Vec of length 2"));
        vec![F64Mul::mul(lhs, rhs)]
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::circuits::traits::f64_circuit::tests::TestedF64Circuit;
    use rand::Rng;

    /// `F64Mul` has no random parameters and always takes exactly two operands.
    impl TestedF64Circuit for F64Mul {
        fn gen_desc<R: Rng + ?Sized>(_: &mut R) -> Self {
            F64Mul
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _: &mut R) -> usize {
            2
        }
    }

    /// Runs the shared `TestedF64Circuit` harness on `F64Mul`.
    #[test]
    fn tested_f64_mul() {
        F64Mul::test(4, 16);
    }
}