arcis-compiler 0.9.4

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        circuits::boolean::{
            boolean_value::{Boolean, BooleanValue},
            utils::decoder_circuit,
        },
        global_value::value::FieldValue,
    },
    traits::{GetBit, Random, Reveal},
    utils::field::BaseField,
};
use core::panic;
use std::ops::{BitAnd, BitAndAssign, BitXor, BitXorAssign, Not};

/// Lsb-to-msb representation of a u8.
///
/// Index `i` of the inner array holds the bit of weight `2^i`: index 0 is the least
/// significant bit and index 7 the most significant. `B` is the per-bit representation;
/// this file instantiates it with plain `bool` (see the `u8` conversion) and with
/// `BooleanValue` (see the `FieldValue<BaseField>` conversion).
#[derive(Debug, Clone, Copy)]
pub struct Byte<B: Boolean>([B; 8]);

impl<B: Boolean> Byte<B> {
    /// Wraps eight bits, given in lsb-to-msb order, as a `Byte`.
    pub fn new(bits: [B; 8]) -> Self {
        Self(bits)
    }

    /// Returns a copy of the eight bits in lsb-to-msb order.
    pub fn get_bits(&self) -> [B; 8] {
        self.0
    }

    /// Consumes the byte and returns its bits (lsb first) as a `Vec`.
    pub fn to_vec(self) -> Vec<B> {
        self.0.to_vec()
    }

    /// Feeds the bits (lsb first) to [`decoder_circuit`] and returns its 256 outputs.
    ///
    /// Going by the method name, the result is expected to be the one-hot encoding of the
    /// byte value; the exact output ordering is defined by `decoder_circuit`.
    ///
    /// # Panics
    ///
    /// Panics if `decoder_circuit` does not return exactly 256 bits.
    pub fn one_hot_encode(&self) -> [B; 256] {
        decoder_circuit(self.to_vec())
            .try_into()
            .unwrap_or_else(|v: Vec<B>| panic!("Expected a Vec of length 256 (found {})", v.len()))
    }

    /// Builds a byte whose bit `i` is `b` when bit `i` of `value` is set, and the constant
    /// `false` bit otherwise.
    ///
    /// With `b = B::from(true)` this is the plain constant encoding of `value`.
    pub fn from_u8_and_bool_val(value: u8, b: B) -> Self {
        let bit_false = B::from(false);
        // Build the fixed-size array directly: no intermediate Vec and no length check needed.
        Self(std::array::from_fn(|i| {
            if (value >> i) & 1 == 1 {
                b
            } else {
                bit_false
            }
        }))
    }
}

impl<B: Boolean> BitAnd for Byte<B> {
    type Output = Byte<B>;

    /// Bitwise AND: bit `i` of the result is bit `i` of `self` ANDed with bit `i` of `rhs`.
    fn bitand(self, rhs: Self) -> Self::Output {
        // Pairwise over the fixed-size arrays; avoids a heap Vec and an unreachable length panic.
        Self(std::array::from_fn(|i| self.0[i] & rhs.0[i]))
    }
}

impl<B: Boolean> BitAndAssign for Byte<B> {
    /// In-place bitwise AND (`self &= rhs`), delegating to the `BitAnd` impl on `Byte`.
    fn bitand_assign(&mut self, rhs: Self) {
        // `Byte` is `Copy`, so take the current value out before overwriting it.
        let current = *self;
        *self = BitAnd::bitand(current, rhs);
    }
}

impl<B: Boolean> BitXor for Byte<B> {
    type Output = Byte<B>;

    /// Bitwise XOR: bit `i` of the result is bit `i` of `self` XORed with bit `i` of `rhs`.
    fn bitxor(self, rhs: Self) -> Self::Output {
        // Pairwise over the fixed-size arrays; avoids a heap Vec and an unreachable length panic.
        Self(std::array::from_fn(|i| self.0[i] ^ rhs.0[i]))
    }
}

impl<B: Boolean> BitXorAssign for Byte<B> {
    /// In-place bitwise XOR (`self ^= rhs`), delegating to the `BitXor` impl on `Byte`.
    fn bitxor_assign(&mut self, rhs: Self) {
        // `Byte` is `Copy`, so take the current value out before overwriting it.
        let current = *self;
        *self = BitXor::bitxor(current, rhs);
    }
}

impl<B: Boolean> Not for Byte<B> {
    type Output = Byte<B>;

    /// Bitwise NOT: every bit of the byte is negated.
    fn not(self) -> Self::Output {
        // Array `map` keeps the fixed size; no intermediate Vec or length check needed.
        Self(self.0.map(|bit| !bit))
    }
}

impl<B: Boolean> Random for Byte<B> {
    /// Draws eight independent `B::random()` bits, lsb first.
    fn random() -> Self {
        // `from_fn` invokes the closure once per index in ascending order, so the draw
        // order matches bit order (lsb first) without an intermediate Vec.
        Self(std::array::from_fn(|_| B::random()))
    }
}

impl<B: Boolean> Reveal for Byte<B> {
    /// Reveals the byte by revealing each of its eight bits independently.
    fn reveal(self) -> Self {
        // Array `map` preserves the fixed size; no intermediate Vec or length check needed.
        Self(self.0.map(|bit| bit.reveal()))
    }
}

impl<B: Boolean> From<u8> for Byte<B> {
    fn from(value: u8) -> Self {
        Self::from_u8_and_bool_val(value, B::from(true))
    }
}

impl From<FieldValue<BaseField>> for Byte<BooleanValue> {
    /// Decomposes a field value whose bounds fit in a byte into its eight low bits, lsb first.
    ///
    /// Bit `i` is obtained as `val.get_bit(i, false)`.
    ///
    /// # Panics
    ///
    /// Panics if `val.bounds().unsigned_max()` is greater than 255.
    fn from(val: FieldValue<BaseField>) -> Self {
        let bounds = val.bounds();
        let upper = bounds.unsigned_max();
        assert!(
            !upper.gt(&BaseField::from(255)),
            "bounds must be in [0, 255] (found {:?})",
            bounds
        );
        let bits: Vec<BooleanValue> = (0..8).map(|i| val.get_bit(i, false)).collect();
        Self(bits.try_into().unwrap_or_else(|v: Vec<BooleanValue>| {
            panic!("Expected a Vec of length 8 (found {})", v.len())
        }))
    }
}

impl From<Byte<bool>> for u8 {
    /// Packs the eight bits (lsb first) back into a `u8`; a set bit at index `i` contributes `1 << i`.
    fn from(value: Byte<bool>) -> Self {
        // Horner-style: walk from the msb down, shifting the accumulator left before
        // or-ing in each bit, so the lsb ends up at position 0.
        value
            .get_bits()
            .iter()
            .rev()
            .fold(0u8, |packed, &bit| (packed << 1) | u8::from(bit))
    }
}