arcium-primitives 0.4.1

Arcium primitives
Documentation
use core::iter::{Product, Sum};
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use ff::Field;
use hybrid_array::Array;
use serde::{Deserialize, Serialize};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
use typenum::U1;
use wincode::{SchemaRead, SchemaWrite};

use crate::algebra::{
    field::FieldExtension,
    ops::{AccReduce, DefaultDotProduct, IntoWide, MulAccReduce, ReduceWide},
    uniform_bytes::FromUniformBytes,
};

/// An element of the binary field GF(2), stored in a single byte.
///
/// The inner byte is expected to be `0` (zero) or `1` (one): addition is implemented as a
/// raw XOR and multiplication as a raw AND of this byte, and the checked constructors in
/// this file (`From<u8>`, `from_le_bytes`, `from_uniform_bytes`) only ever produce 0 or 1.
/// NOTE(review): the derived `Deserialize` / `SchemaRead` impls perform no such check, so a
/// malformed encoding could produce a byte outside {0, 1} — confirm whether that matters.
// TODO: see if changing u8 to Choice makes more sense
#[derive(
    Clone,
    Copy,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    SchemaRead,
    SchemaWrite,
)]
#[repr(transparent)]
pub struct Gf2(pub(super) u8);

impl Default for Gf2 {
    fn default() -> Self {
        Gf2::ZERO
    }
}

impl IntoWide for Gf2 {
    #[inline]
    fn to_wide(&self) -> Gf2 {
        *self
    }

    #[inline]
    fn zero_wide() -> Gf2 {
        Gf2::ZERO
    }
}

/// Because the wide type of `Gf2` is `Gf2` itself, reduction is the identity.
impl ReduceWide for Gf2 {
    #[inline]
    fn reduce_mod_order(a: Self) -> Self {
        // Nothing to reduce: `a` is already a canonical GF(2) element.
        a
    }
}

// Dot products over GF(2): each term `a * b` is a bitwise AND, and terms are accumulated with
// XOR (GF(2) addition). The four impls below cover every owned/borrowed operand combination;
// the wide accumulator is `Gf2` itself.

// Operands: Gf2 x Gf2
impl MulAccReduce for Gf2 {
    type WideType = Self;

    #[inline]
    fn mul_acc(acc: &mut Self::WideType, a: Self, b: Self) {
        let term = a.0 & b.0;
        acc.0 ^= term;
    }
}

impl DefaultDotProduct for Gf2 {}

// Operands: Gf2 x &Gf2
impl<'a> MulAccReduce<Self, &'a Self> for Gf2 {
    type WideType = Self;

    #[inline]
    fn mul_acc(acc: &mut Self::WideType, a: Self, b: &'a Self) {
        let term = a.0 & b.0;
        acc.0 ^= term;
    }
}

impl DefaultDotProduct<Self, &Self> for Gf2 {}

// Operands: &Gf2 x Gf2
impl<'a> MulAccReduce<&'a Self, Self> for Gf2 {
    type WideType = Self;

    #[inline]
    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: Self) {
        let term = a.0 & b.0;
        acc.0 ^= term;
    }
}

impl DefaultDotProduct<&Self, Self> for Gf2 {}

// Operands: &Gf2 x &Gf2
impl<'a, 'b> MulAccReduce<&'a Self, &'b Self> for Gf2 {
    type WideType = Self;

    #[inline]
    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: &'b Self) {
        let term = a.0 & b.0;
        acc.0 ^= term;
    }
}

impl DefaultDotProduct<&Self, &Self> for Gf2 {}

impl AccReduce for Gf2 {
    type WideType = Gf2;

    #[inline]
    fn acc(acc: &mut Self::WideType, a: Self) {
        acc.0 ^= a.0;
    }
}

impl AccReduce<&Self> for Gf2 {
    type WideType = Gf2;

    #[inline]
    fn acc(acc: &mut Self::WideType, a: &Self) {
        acc.0 ^= a.0;
    }
}

/// `ff::Field` implementation for GF(2): zero is `Gf2(0)`, one is `Gf2(1)`.
impl ff::Field for Gf2 {
    const ZERO: Self = Self(0);
    const ONE: Self = Self(1);

    /// Draws one byte from `rng` and keeps only its lowest bit.
    fn random(mut rng: impl rand::RngCore) -> Self {
        let mut tmp = [0u8; 1];
        rng.fill_bytes(&mut tmp);
        Self(tmp[0] & 1)
    }

    /// Squaring is the identity on GF(2): 0² = 0 and 1² = 1.
    fn square(&self) -> Self {
        *self
    }

    /// In characteristic 2, `x + x = 0` for every `x`.
    fn double(&self) -> Self {
        Self::ZERO
    }

    /// ONE is its own inverse; ZERO has no inverse, so the returned `CtOption` is "none" for it.
    fn invert(&self) -> CtOption<Self> {
        CtOption::new(*self, self.ct_eq(&Self::ONE))
    }

    /// Computes `sqrt(num / div)` following the `ff::Field::sqrt_ratio` contract.
    ///
    /// Every element of GF(2) is a square (its own square root), so the cases reduce to:
    /// - `num == 0`              -> `(true, 0)`
    /// - `num == 1`, `div == 1`  -> `(true, 1)`
    /// - `num == 1`, `div == 0`  -> `(false, 0)` (division by zero)
    ///
    /// All three are captured by `(num == 0 | div == 1, num AND div)`, evaluated in constant time.
    /// Previously this was `unimplemented!()`, which made any square-root call panic.
    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        let is_square = num.ct_eq(&Self::ZERO) | div.ct_eq(&Self::ONE);
        (is_square, Self(num.0 & div.0))
    }
}

impl FieldExtension for Gf2 {
    type Subfield = Self;

    type Degree = U1;
    type FieldBitSize = U1;
    type FieldBytesSize = U1;

    fn to_subfield_elements(&self) -> impl ExactSizeIterator<Item = Self::Subfield> {
        std::iter::once(*self)
    }

    fn from_subfield_elements(elems: &[Self::Subfield]) -> Option<Self> {
        if elems.len() == 1 {
            elems.first().copied()
        } else {
            None
        }
    }

    fn to_le_bytes(&self) -> Array<u8, Self::FieldBytesSize> {
        [self.0].into()
    }

    fn from_le_bytes(bytes: &[u8]) -> Option<Self> {
        bytes
            .first()
            .and_then(|&byte| if byte <= 1 { Some(Self(byte)) } else { None })
    }

    fn mul_by_subfield(&self, other: &Self::Subfield) -> Self {
        self * other
    }

    fn generator() -> Self {
        Self::ONE
    }
}

// === Field traits implementations === //

/// Constant-time selection, delegated to `subtle`'s implementation for the inner byte.
impl ConditionallySelectable for Gf2 {
    #[inline]
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let selected = u8::conditional_select(&a.0, &b.0, choice);
        Self(selected)
    }
}

/// Constant-time equality, delegated to `subtle`'s implementation for the inner byte.
impl ConstantTimeEq for Gf2 {
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}

/// Negation in GF(2).
#[macros::op_variants(borrowed)]
impl Neg for Gf2 {
    type Output = Gf2;

    /// Every element is its own additive inverse (`x + x = 0`), so negation returns `self`.
    #[inline]
    fn neg(self) -> Self::Output {
        self
    }
}

/// Addition in GF(2), i.e. XOR of the two bits.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl Add<&Gf2> for Gf2 {
    type Output = Gf2;

    #[inline]
    // XOR is the correct field addition here, not a typo for `+`.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn add(self, rhs: &Gf2) -> Self::Output {
        let sum_bit = self.0 ^ rhs.0;
        Gf2(sum_bit)
    }
}

/// In-place GF(2) addition, expressed through the `Add` impl.
#[macros::op_variants(owned)]
impl AddAssign<&Gf2> for Gf2 {
    #[inline]
    fn add_assign(&mut self, rhs: &Gf2) {
        let sum = *self + rhs;
        *self = sum;
    }
}

#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl Sub<&Gf2> for Gf2 {
    type Output = Gf2;

    #[inline]
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn sub(self, rhs: &Gf2) -> Self::Output {
        self + rhs
    }
}

/// In-place GF(2) subtraction, expressed through the `Sub` impl.
#[macros::op_variants(owned)]
impl SubAssign<&Gf2> for Gf2 {
    #[inline]
    fn sub_assign(&mut self, rhs: &Gf2) {
        let difference = *self - rhs;
        *self = difference;
    }
}

/// Multiplication in GF(2), i.e. AND of the two bits.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl Mul<&Gf2> for Gf2 {
    type Output = Gf2;

    #[inline]
    // AND is the correct field multiplication here.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn mul(self, rhs: &Gf2) -> Self::Output {
        let product_bit = self.0 & rhs.0;
        Gf2(product_bit)
    }
}
/// In-place GF(2) multiplication, expressed through the `Mul` impl.
#[macros::op_variants(owned)]
impl MulAssign<&Gf2> for Gf2 {
    #[inline]
    fn mul_assign(&mut self, rhs: &Gf2) {
        let product = *self * rhs;
        *self = product;
    }
}

impl Sum for Gf2 {
    #[inline]
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Gf2::ZERO, |a, b| a + b)
    }
}

impl<'a> Sum<&'a Gf2> for Gf2 {
    #[inline]
    fn sum<I: Iterator<Item = &'a Gf2>>(iter: I) -> Self {
        iter.fold(Gf2::ZERO, |a, b| a + b)
    }
}

impl Product for Gf2 {
    #[inline]
    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Gf2::ONE, |a, b| a * b)
    }
}

impl<'a> Product<&'a Gf2> for Gf2 {
    #[inline]
    fn product<I: Iterator<Item = &'a Gf2>>(iter: I) -> Self {
        iter.fold(Gf2::ONE, |a, b| a * b)
    }
}

impl From<Gf2> for bool {
    fn from(value: Gf2) -> Self {
        value.0 == 1
    }
}

impl From<&Gf2> for bool {
    fn from(value: &Gf2) -> Self {
        value.0 == 1
    }
}

impl From<bool> for Gf2 {
    fn from(value: bool) -> Self {
        Gf2(value.into())
    }
}

impl From<&bool> for Gf2 {
    fn from(value: &bool) -> Self {
        (*value).into()
    }
}

impl From<u8> for Gf2 {
    fn from(val: u8) -> Self {
        Gf2(val & 1)
    }
}

impl From<Gf2> for u8 {
    fn from(value: Gf2) -> Self {
        value.0
    }
}

impl From<&Gf2> for u64 {
    fn from(value: &Gf2) -> Self {
        value.0 as u64
    }
}

impl From<u64> for Gf2 {
    fn from(val: u64) -> Self {
        Gf2((val & 1) as u8)
    }
}

impl From<u128> for Gf2 {
    fn from(val: u128) -> Self {
        Gf2((val & 1) as u8)
    }
}

impl From<Choice> for Gf2 {
    fn from(value: Choice) -> Self {
        Gf2(value.unwrap_u8())
    }
}

impl From<&Choice> for Gf2 {
    fn from(value: &Choice) -> Self {
        (*value).into()
    }
}

impl From<Gf2> for Choice {
    fn from(value: Gf2) -> Self {
        value.0.into()
    }
}

impl From<&Gf2> for Choice {
    fn from(value: &Gf2) -> Self {
        value.0.into()
    }
}

/// Maps one uniformly random byte to a uniformly random GF(2) element via its lowest bit.
impl FromUniformBytes for Gf2 {
    type UniformBytes = U1;

    fn from_uniform_bytes(bytes: &Array<u8, Self::UniformBytes>) -> Self {
        let lowest_bit = bytes[0] & 0b1;
        Gf2(lowest_bit)
    }
}

impl AsRef<[u8]> for Gf2 {
    fn as_ref(&self) -> &[u8] {
        unsafe {
            std::slice::from_raw_parts(self as *const Gf2 as *const u8, std::mem::size_of::<Gf2>())
        }
    }
}