// arcis-compiler 0.9.7
//
// A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
// Documentation
use crate::{
    traits::{FromLeBytes, Invert, Pow},
    utils::{matrix::Matrix, number::Number},
};
use ff::{
    derive::bitvec::{order::Lsb0, view::AsBits},
    PrimeField,
};
use num_bigint::{BigInt, BigUint};
use num_traits::Zero;
use rand::Rng;
use std::{cmp::Ordering, hash::Hash};

/// A collection of functions on prime fields `F_p`.
///
/// Many of these functions are tied to the different orders one can define over the cyclic
/// group `F_p`:
/// - the *unsigned* order given by `PartialOrd` (assumed to compare the canonical integer
///   representatives in `[0, p)`);
/// - the *signed* order, in which `[0, (p - 1) / 2]` are the non-negative elements and
///   `[(p + 1) / 2, p)` stand for the negative values `-(p - 1) / 2 ..= -1`
///   (note that `TWO_INV = (p + 1) / 2`);
/// - the *cyclic* order, which compares two elements along the shorter arc between them.
pub trait UsedField:
    PrimeField
    + Hash
    + PartialOrd
    + From<Number>
    + From<i32>
    + From<bool>
    + From<f64>
    + Zero
    + std::ops::Shr<usize, Output = Self>
    + FromLeBytes
{
    /// The prime number p such that the field is F_p.
    fn modulus() -> Number;

    /// The smallest prime that does not divide p - 1.
    fn get_alpha() -> Number;

    /// The smallest positive integer n such that alpha * n = 1 mod (p - 1).
    fn get_alpha_inverse() -> Number;

    /// An MDS matrix of the given width, together with its inverse.
    fn mds_matrix_and_inverse(width: usize) -> (Matrix<Self>, Matrix<Self>);

    /// Returns 2^exponent.
    fn power_of_two(exponent: usize) -> Self;

    /// Returns -(2^exponent).
    fn negative_power_of_two(exponent: usize) -> Self {
        Self::ZERO - Self::power_of_two(exponent)
    }

    /// Converts `self` to its unsigned representative, reading `to_repr()` as little-endian
    /// bytes (so the result lies in `[0, p)` for a canonical little-endian representation).
    fn to_unsigned_number(self) -> Number {
        BigInt::from(BigUint::from_bytes_le(self.to_repr().as_ref())).into()
    }

    /// Converts `self` to its signed representative in `[-(p - 1) / 2, (p - 1) / 2]`.
    fn to_signed_number(self) -> Number {
        if self.is_ge_zero() {
            self.to_unsigned_number()
        } else {
            // Negative elements are represented as the negation of their additive inverse.
            -(Self::ZERO - self).to_unsigned_number()
        }
    }

    /// Whether `self` is 0 or 1 (i.e. `self <= 1` in the unsigned order).
    fn is_binary(self) -> bool {
        self <= Self::ONE
    }

    /// Whether a number is greater than or equal to zero according to the signed order.
    #[inline(always)]
    fn is_ge_zero(self) -> bool {
        // Equivalent to `self <= Self::ZERO - self` (see test): `TWO_INV = (p + 1) / 2` is the
        // first element of the negative half.
        self < Self::TWO_INV
    }

    /// Whether a number is less than or equal to zero according to the signed order.
    fn is_le_zero(self) -> bool {
        self >= Self::ZERO - self
    }

    /// Whether a number is greater than zero according to the signed order.
    #[inline(always)]
    fn is_gt_zero(self) -> bool {
        !self.is_le_zero()
    }

    /// Whether a number is less than zero according to the signed order.
    #[inline(always)]
    fn is_lt_zero(self) -> bool {
        !self.is_ge_zero()
    }

    /// Max according to the cyclic order on the smaller interval between the two field elements.
    ///
    /// Returns the maximum together with a flag that is `true` iff the returned element is
    /// `other` (which is the case when `other - self` is non-negative in the signed order,
    /// including when both are equal).
    fn max_cyclic(self, other: Self) -> (Self, bool) {
        if (other - self).is_ge_zero() {
            (other, true)
        } else {
            (self, false)
        }
    }

    /// Min according to the cyclic order on the smaller interval between the two field elements.
    ///
    /// Returns the minimum together with a flag that is `true` iff the returned element is
    /// `other` (when both are equal, `self` is returned with `false`).
    fn min_cyclic(self, other: Self) -> (Self, bool) {
        if (other - self).is_ge_zero() {
            (self, false)
        } else {
            (other, true)
        }
    }

    /// Max according to the signed order (`signed == true`) or the unsigned order.
    /// Returns `self` on ties.
    fn max(self, other: Self, signed: bool) -> Self {
        // Subtracting TWO_INV shifts the signed range [-(p-1)/2, (p-1)/2] onto [0, p),
        // so the unsigned comparison then implements the signed order.
        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
        if self - offset < other - offset {
            other
        } else {
            self
        }
    }

    /// Min according to the signed order (`signed == true`) or the unsigned order.
    /// Returns `self` on ties.
    fn min(self, other: Self, signed: bool) -> Self {
        // Same shift as in `max`.
        let offset = if signed { Self::TWO_INV } else { Self::ZERO };
        if self - offset > other - offset {
            other
        } else {
            self
        }
    }

    /// Sort according to the cyclic order on the smaller interval between the two field elements.
    /// Returns `(min, max)`; equal inputs come back as `(self, other)`.
    fn sort_pair(self, other: Self) -> (Self, Self) {
        if (other - self).is_ge_zero() {
            (self, other)
        } else {
            (other, self)
        }
    }

    /// Abs according to the signed order.
    fn abs(self) -> Self {
        if self.is_ge_zero() {
            self
        } else {
            Self::ZERO - self
        }
    }

    /// Whether the integer product of the unsigned representatives of `self` and `other`
    /// reaches `p`, i.e. whether the field multiplication wraps around.
    ///
    /// Intended for operands that are non-negative in the signed order.
    fn does_mul_overflow(self, other: Self) -> bool {
        if self.is_zero_vartime() || other.is_zero_vartime() {
            return false;
        }
        let prod = self.to_unsigned_number() * other.to_unsigned_number();
        prod >= Self::modulus()
    }

    /// Whether `self + other` overflows in the signed order.
    fn does_add_signed_overflow(self, other: Self) -> bool {
        let sum = self + other;
        match (self.is_ge_zero(), other.is_ge_zero()) {
            // Two non-negative summands overflow when the sum wraps into the negative half.
            (true, true) => sum.is_lt_zero(),
            // Two negative summands overflow when the sum wraps into the non-negative half.
            (false, false) => sum.is_ge_zero(),
            // Summands of opposite signs can never overflow.
            (true, false) | (false, true) => false,
        }
    }

    /// Whether the integer sum of the unsigned representatives of `self` and `other` reaches `p`.
    fn does_add_unsigned_overflow(self, other: Self) -> bool {
        if self.is_zero_vartime() || other.is_zero_vartime() {
            false
        } else {
            // For `other != 0`, `-other` is `p - other`, so `self + other >= p` iff `self >= -other`.
            self >= -other
        }
    }

    /// The number of bits of self in unsigned notation (0 for zero).
    fn unsigned_bits(self) -> usize {
        let binding = self.to_repr();
        let bits = binding.as_bits::<Lsb0>();
        // With Lsb0 ordering the trailing zeros of the bit slice are the high-order zero bits.
        bits.len() - bits.trailing_zeros()
    }

    /// The number of bits of the absolute value of `self` in the signed order.
    fn signed_bits(self) -> usize {
        self.abs().unsigned_bits()
    }

    /// The `idx`-th bit (least significant first) of self in unsigned notation.
    /// Indices past the representation are `false`.
    fn unsigned_bit(&self, idx: usize) -> bool {
        let repr = self.to_repr();
        repr.as_bits::<Lsb0>().get(idx).map_or(false, |bit| *bit)
    }

    /// The `idx`-th bit of self in two's-complement signed notation.
    fn signed_bit(&self, idx: usize) -> bool {
        if self.is_ge_zero() {
            self.unsigned_bit(idx)
        } else {
            // Two's complement: the bits of -x are the complement of the bits of x - 1.
            !(self.abs() - Self::ONE).unsigned_bit(idx)
        }
    }

    /// The division of the unsigned representatives (using `Number`'s `/`).
    /// Returns 0 if the divisor is 0.
    fn unsigned_euclidean_division(self, other: Self) -> Self {
        if other == Self::ZERO {
            Self::ZERO
        } else {
            (self.to_unsigned_number() / other.to_unsigned_number()).into()
        }
    }

    /// The division of the signed representatives (using `Number`'s `/`; the rounding of
    /// negative quotients follows that operator). Returns 0 if the divisor is 0.
    fn signed_euclidean_division(self, other: Self) -> Self {
        if other == Self::ZERO {
            Self::ZERO
        } else {
            (self.to_signed_number() / other.to_signed_number()).into()
        }
    }

    /// Generates a field element between min and max, included.
    ///
    /// Draws an offset with `Number::gen_range(0, max - min + 1)` (assumed to exclude its upper
    /// bound) and adds it to `min`. If `min > max` in the unsigned order, `max - min` wraps
    /// around modulo p.
    fn gen_inclusive_range<R: Rng + ?Sized>(rng: &mut R, min: Self, max: Self) -> Self {
        min + Self::from(Number::gen_range(
            rng,
            &0.into(),
            &((max - min).to_unsigned_number() + 1),
        ))
    }

    /// Converts a number in lsb-to-msb binary expansion to the corresponding element in Self.
    ///
    /// Every character other than `'1'` is treated as a zero bit.
    fn from_bin(bin: &str) -> Self {
        let value = bin
            .chars()
            .enumerate()
            .filter(|&(_, c)| c == '1')
            .fold(Number::from(0), |acc, (i, _)| acc + Number::power_of_two(i));
        Self::from(value)
    }

    /// Converts self to its lsb-to-msb binary expansion, one character per bit of the modulus.
    fn to_bin(&self) -> String {
        (0..Self::modulus().bits())
            .map(|i| if self.unsigned_bit(i) { '1' } else { '0' })
            .collect()
    }

    /// Returns `Some(k)` if `self == 2^k` for some `k <= CAPACITY`, and `None` otherwise.
    fn as_power_of_two(self) -> Option<usize> {
        if self == Self::ZERO {
            return None;
        }
        // Binary search over the exponents 0..=CAPACITY; the corresponding powers of two are all
        // below p and therefore strictly increasing in the unsigned order.
        let mut low = 0usize;
        let mut high = Self::CAPACITY as usize;
        while low <= high {
            let mid = (low + high) / 2;
            match self.partial_cmp(&Self::power_of_two(mid)) {
                Some(Ordering::Equal) => return Some(mid),
                // `self >= 1 = 2^0` here (assuming the unsigned order), so `mid` is never 0 in
                // this branch and `mid - 1` cannot underflow.
                Some(Ordering::Less) => high = mid - 1,
                Some(Ordering::Greater) => low = mid + 1,
                None => panic!("order should be total"),
            }
        }
        None
    }

    /// Whether `self > other` in the signed order.
    fn signed_gt(self, other: Self) -> bool {
        self.max(other, true) != other
    }

    /// Whether `self >= other` in the signed order.
    fn signed_ge(self, other: Self) -> bool {
        self.max(other, true) == self
    }

    /// Whether `self < other` in the signed order.
    fn signed_lt(self, other: Self) -> bool {
        self.min(other, true) != other
    }

    /// Whether `self <= other` in the signed order.
    fn signed_le(self, other: Self) -> bool {
        self.min(other, true) == self
    }
}

impl<F: UsedField> Invert for F {
    /// Multiplicative inverse of `self`, mapping zero (which has no inverse) to zero.
    ///
    /// The `_is_expected_non_zero` hint is not used by this implementation.
    fn invert(self, _is_expected_non_zero: bool) -> Self {
        let maybe_inverse = F::invert(&self);
        maybe_inverse.unwrap_or(F::ZERO)
    }
}

impl<F: UsedField> Pow for F {
    fn pow(self, e: &Number, _is_expected_non_zero: bool) -> Self {
        let e = e % (F::modulus() - 1);
        let mut e_u64 = [0u64; 4];
        let bytes: [u8; 32] = e.into();
        for (i, chunk) in bytes.chunks_exact(8).enumerate() {
            e_u64[i] = u64::from_le_bytes(chunk.try_into().unwrap());
        }

        F::pow(&self, e_u64)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::field::ScalarField;
    use ff::Field;

    /// Representative values around the boundaries of the signed order.
    fn boundary_values() -> [ScalarField; 5] {
        [
            ScalarField::ZERO,
            ScalarField::ONE,
            ScalarField::TWO_INV - ScalarField::ONE,
            ScalarField::TWO_INV,
            ScalarField::ZERO - ScalarField::ONE,
        ]
    }

    #[test]
    fn is_ge_zero() {
        // `is_ge_zero` is implemented as `self < TWO_INV`; check it against its definition.
        for n in boundary_values() {
            assert_eq!(n.is_ge_zero(), n <= ScalarField::ZERO - n, "n = {:?}", n);
        }
    }

    #[test]
    fn is_le_zero() {
        for n in boundary_values() {
            assert_eq!(
                n.is_le_zero(),
                n == ScalarField::ZERO || n.is_lt_zero(),
                "n = {:?}",
                n
            );
        }
    }

    #[test]
    fn signed_bit_is_twos_complement() {
        let minus_one = ScalarField::ZERO - ScalarField::ONE;
        let minus_two = minus_one - ScalarField::ONE;
        for idx in 0..8 {
            // -1 = ...1111 and -2 = ...1110 in two's complement.
            assert!(minus_one.signed_bit(idx), "idx = {}", idx);
            assert_eq!(minus_two.signed_bit(idx), idx != 0, "idx = {}", idx);
        }
    }

    #[test]
    fn as_power_of_two() {
        assert_eq!(ScalarField::ZERO.as_power_of_two(), None);
        assert_eq!(ScalarField::from(3i32).as_power_of_two(), None);
        let capacity = ScalarField::CAPACITY as usize;
        for exponent in [0, 1, 7, capacity / 2, capacity] {
            assert_eq!(
                ScalarField::power_of_two(exponent).as_power_of_two(),
                Some(exponent),
                "exponent = {}",
                exponent
            );
        }
        for exponent in [1, 7, capacity / 2, capacity] {
            // 2^k + 1 is not a power of two for k >= 1.
            let n = ScalarField::power_of_two(exponent) + ScalarField::ONE;
            assert_eq!(n.as_power_of_two(), None, "exponent = {}", exponent);
        }
    }
}