arcis-compiler 0.10.3

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::{
        expressions::{domain::DomainElement, expr::EvalValue},
        mxe_input::{ArxInput, MxeInput},
        profile_circuit::{explain_circuit_depth, get_circuit_depth},
    },
    network_content::NetworkContent,
    profile_info::ProfileInfo,
    utils::{curve_point::CurvePoint, number::Number},
    AsyncMPCCircuit,
};
use core_utils::circuit::{mock_eval::MockRng, AlgebraicType};
use ff::Field;
use num_bigint::{BigUint, ToBigInt};
use num_traits::Zero;
use primitives::algebra::{
    elliptic_curve::{BaseFieldElement, Curve, Curve25519Ristretto, Point, Scalar},
    field::subfield_element::Mersenne107Element,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha3::Digest;

/// Randomness source handed to the circuit's mock evaluator.
///
/// Bits requested through `gen_bit` are first taken from a caller-supplied script
/// (`bools`); every other random value comes from the wrapped `rng`.
struct MockEvalRng<'a, R: Rng + ?Sized> {
    /// Fallback generator for unscripted draws.
    rng: &'a mut R,
    /// Scripted bits still to be consumed; a `None` entry defers to `rng`.
    bools: &'a [Option<bool>],
}

impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
    /// Wraps `rng`, with `bools` consumed front-to-back as bits are drawn.
    pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
        MockEvalRng { bools, rng }
    }
}
impl<R: Rng + ?Sized> MockRng for MockEvalRng<'_, R> {
    /// Pops the next scripted bit. A `Some(b)` entry is returned as-is; a `None`
    /// entry, or an exhausted script, yields a fresh random bit from `rng`.
    fn gen_bit(&mut self) -> bool {
        let script = self.bools;
        match script.split_first() {
            Some((next, rest)) => {
                // The scripted slot is consumed whether or not it forces a value.
                self.bools = rest;
                match *next {
                    Some(forced) => forced,
                    None => self.rng.gen(),
                }
            }
            None => self.rng.gen(),
        }
    }

    /// Always random: the bit script only applies to `gen_bit`.
    fn gen_da_bit(&mut self) -> bool {
        self.rng.gen()
    }

    /// Random scalar of curve `C`, drawn from `rng`.
    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
        Scalar::<C>::random(&mut self.rng)
    }

    /// Random base-field element of curve `C`, drawn from `rng`.
    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
        BaseFieldElement::<C>::random(&mut self.rng)
    }

    /// Random `Mersenne107Element`, drawn from `rng`.
    fn gen_mersenne(&mut self) -> Mersenne107Element {
        Mersenne107Element::random(&mut self.rng)
    }

    /// Random point of curve `C`, drawn from `rng`.
    fn gen_point<C: Curve>(&mut self) -> Point<C> {
        self.rng.gen()
    }
}

/// The compiler's main product: an executable circuit together with the metadata
/// describing where each input goes.
///
/// The whole struct is bincode-serialized (see `to_bytes`), so field order is part
/// of the on-disk format.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArcisInstruction {
    /// The circuit, in the representation consumed by the async-mpc library.
    pub circuit: AsyncMPCCircuit,
    /// Metadata; at present it only records which input occupies which position.
    pub metadata: ArcisMetadata,
}

/// A payload-free [`DomainElement`], used purely to name a value's domain
/// (`Bit`, `Scalar`, `Base` or `Curve`).
pub type DomainKind = DomainElement<(), (), (), ()>;

impl ArcisInstruction {
    /// Evaluates the circuit locally, without any network, on the given inputs.
    /// Used for testing.
    ///
    /// * `inputs` - user-provided values, looked up by the ids stored in
    ///   `metadata.input_order` (`InputId(id)` and `TypedInput(id, _)`). MXE inputs are not
    ///   read from here; their values come from calling `mock_eval()` on the `MxeInput`.
    /// * `output_domains` - the domain each circuit output is decoded into.
    /// * `bools` - scripted values for bits drawn via `MockRng::gen_bit`, consumed in order.
    ///   `None` entries, and any bits requested after the script runs out, come from `rng`.
    /// * `rng` - source of every other random value needed by the evaluation.
    ///
    /// Circuit outputs are paired with `output_domains` via `zip`. The returned vector
    /// therefore has the length of the shorter of the two.
    ///
    /// # Panics
    ///
    /// Panics if an input id is out of range for `inputs`, or if a `Curve` output does not
    /// decode to a valid [`CurvePoint`].
    pub fn mock_eval_vec<R: Rng + ?Sized>(
        &self,
        inputs: Vec<EvalValue>,
        output_domains: &[DomainKind],
        bools: &[Option<bool>],
        rng: &mut R,
    ) -> Vec<EvalValue> {
        // Arrange the inputs in the order the circuit consumes them.
        let numbers: Vec<EvalValue> = self
            .metadata
            .input_order
            .iter()
            .map(|x| match *x {
                ArcisInput::InputId(id) | ArcisInput::TypedInput(id, _) => inputs[id],
                ArcisInput::Mxe(input) => input.mock_eval(),
            })
            .collect();

        // The mock evaluator works on raw little-endian integers.
        let v = numbers
            .into_iter()
            .map(|x| match x {
                EvalValue::Bit(b) => BigUint::from(b),
                EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
                EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
                EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
            })
            .collect();
        let mut mock_rng = MockEvalRng::new(rng, bools);
        let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);

        // Decode each raw output back into the domain the caller asked for.
        std::iter::zip(output_domains, res)
            .map(|(domain, x)| match domain {
                DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
                DomainElement::Scalar(_) => {
                    EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
                }
                DomainElement::Base(_) => {
                    EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
                }
                DomainElement::Curve(_) => {
                    // Points are encoded in 32 bytes; `to_bytes_le` drops trailing zeros,
                    // so pad back up before decoding.
                    let mut bytes = x.to_bytes_le();
                    bytes.resize(32, 0);
                    EvalValue::Curve(
                        CurvePoint::from_le_bytes(&bytes)
                            .expect("Failed to convert to CurvePoint."),
                    )
                }
            })
            .collect()
    }

    /// The network depth in number of rounds of communication necessary to compute the circuit.
    /// Used for profiling.
    pub fn network_depth(&self) -> usize {
        get_circuit_depth(&self.circuit)
    }

    /// Gives profiling information on the circuit. This covers network depth, gate count,
    /// network content, required preprocessing, and a hash of the serialized instruction.
    pub fn profile_info(&self) -> ProfileInfo {
        let network_depth = self.network_depth();
        let total_gates = self.circuit.nb_gates() as usize;
        let preprocessing = self.circuit.required_preprocessing();
        let pre_process = (&preprocessing).into();
        let network_content = NetworkContent::from_circuit(&self.circuit);

        // Hash the versioned serialization and keep the first `size_of::<usize>()` digest
        // bytes. Slicing by the platform's `usize` width (instead of a hard-coded 8) avoids
        // a guaranteed panic on 32-bit targets; on 64-bit targets the value is unchanged.
        let mut hasher = sha3::Keccak256::new();
        hasher.update(self.to_bytes());
        let digest = hasher.finalize();
        let circuit_hash = usize::from_le_bytes(
            digest[..std::mem::size_of::<usize>()]
                .try_into()
                .expect("a Keccak-256 digest (32 bytes) is longer than a usize"),
        );

        ProfileInfo {
            network_depth,
            total_gates,
            network_content,
            pre_process,
            circuit_hash,
        }
    }

    /// The weight of the circuit. Used to compute the cost to run it.
    pub fn weight(&self) -> usize {
        self.profile_info().weight()
    }

    /// A function used for debugging when the depth is different from expected.
    // Only called during ad-hoc debugging sessions, hence the lint allowance.
    #[allow(dead_code)]
    pub fn explain_depth(&self) {
        explain_circuit_depth(&self.circuit);
    }

    /// Serializes the instruction with bincode and appends the current version suffix.
    /// Inverse of [`Self::from_bytes`].
    ///
    /// # Panics
    ///
    /// Panics if bincode serialization fails.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut data = bincode::serialize(self).unwrap();
        CURRENT_VERSION.add_to_bytes(&mut data);
        data
    }

    /// Deserializes an instruction produced by [`Self::to_bytes`].
    ///
    /// Both the current (V2) format and the legacy V1 format, which has no version suffix,
    /// are accepted. V1 circuits are converted to the current circuit representation on load.
    ///
    /// # Errors
    ///
    /// Returns an error message if the version suffix is invalid, bincode decoding fails,
    /// or a V1 circuit cannot be converted.
    pub fn from_bytes(mut bytes: Vec<u8>) -> Result<Self, String> {
        let version = Version::from_bytes(&bytes)?;
        // `Version::from_bytes` only reports a suffixed version when the buffer is long
        // enough to contain that suffix, so this subtraction cannot underflow.
        let len_without_version = bytes.len() - version.len();
        bytes.truncate(len_without_version);
        match version {
            Version::V1 => {
                let v1_instruction: ArcisInstructionV1 =
                    bincode::deserialize(&bytes).map_err(|e| e.to_string())?;
                Ok(ArcisInstruction {
                    circuit: v1_instruction
                        .circuit
                        .into_v2()
                        .map_err(|e| e.to_string())?,
                    metadata: v1_instruction.metadata,
                })
            }
            Version::V2 => bincode::deserialize(&bytes).map_err(|e| e.to_string()),
        }
    }
}

/// Metadata attached to a compiled circuit: the order in which inputs are fed to it.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ArcisMetadata {
    /// Entry `i` names the source of the circuit's `i`-th input.
    pub input_order: Vec<ArcisInput>,
}

/// Identifies where one circuit input comes from.
///
/// This enum is serialized with serde and is part of the stored instruction format.
/// Reordering its variants would change the encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArcisInput {
    /// An input provided by the user. `InputId(0)` is the 1st, `InputId(1)` the 2nd, etc.
    InputId(usize),
    /// An input provided by the MXE.
    Mxe(MxeInput),
    /// An input provided by the user, together with its expected type.
    /// `TypedInput(0, _)` is the 1st, `TypedInput(1, _)` the 2nd, etc.
    /// The algebraic type specifies the type the circuit expects.
    TypedInput(usize, AlgebraicType),
}

/// Serialization format version of an `ArcisInstruction`.
///
/// V1 files carry no suffix. Later versions append a two-byte suffix `[version, 255]`,
/// and the trailing `255` acts as the marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Version {
    V1,
    V2,
}

impl Version {
    /// Detects the version from the tail of a serialized instruction.
    ///
    /// A last byte other than `255` means the data has no suffix, i.e. V1.
    ///
    /// # Errors
    ///
    /// Fails on empty input, on a lone `255` byte, or on an unknown `[v, 255]` suffix.
    pub fn from_bytes(bytes: &[u8]) -> Result<Version, String> {
        match bytes {
            [] => Err("Empty bytes.".into()),
            [.., last] if *last != 255 => Ok(Version::V1),
            [.., 2, 255] => Ok(Version::V2),
            [.., tag, _] => Err(format!("Invalid version suffix: [{tag}, 255].")),
            [_] => Err("Invalid bytes length.".into()),
        }
    }

    /// The bytes this version appends to a serialized instruction.
    fn suffix(self) -> &'static [u8] {
        match self {
            Version::V1 => &[],
            Version::V2 => &[2, 255],
        }
    }

    /// Appends this version's suffix to `bytes` (nothing for V1).
    pub fn add_to_bytes(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(self.suffix());
    }

    /// Number of suffix bytes this version occupies at the end of the data.
    pub fn len(&self) -> usize {
        self.suffix().len()
    }
}

/// Version written by `ArcisInstruction::to_bytes`.
const CURRENT_VERSION: Version = Version::V2;

/// Circuit representation used by legacy (V1) serialized instructions.
type AsyncMPCCircuitV1 = core_utils::circuit::CircuitV1<Curve25519Ristretto>;

/// Wire layout of a legacy V1 instruction. It is only ever deserialized and then
/// converted into an `ArcisInstruction` by `ArcisInstruction::from_bytes`.
///
/// Field order mirrors the V1 bincode encoding and must not change.
#[derive(Deserialize)]
struct ArcisInstructionV1 {
    /// The V1 circuit, still to be converted to the current representation.
    pub circuit: AsyncMPCCircuitV1,
    /// Input ordering metadata; its format is shared with the current version.
    pub metadata: ArcisMetadata,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MxeBitInput, MxeFieldInput, MxeScalarInput};

    /// Legacy V1 `.arcis` files must still deserialize.
    #[test]
    fn v1_compatibility() {
        const V1_FILES: &[&[u8]] = &[include_bytes!("../../old-arcis-files/v1/keygen.arcis")];
        V1_FILES.iter().for_each(|file| {
            ArcisInstruction::from_bytes(file.to_vec()).unwrap();
        });
    }

    /// A V1 file has no version suffix: it simply ends with its serialized metadata.
    /// This checks that, for every realistic metadata value, version detection
    /// classifies such a file as V1.
    #[test]
    fn version_v1() {
        // Serializes `metadata` as the tail of a V1 file and runs version detection on it.
        fn detected_as_v1(metadata: ArcisMetadata) -> bool {
            let encoded = bincode::serialize(&metadata).unwrap();
            matches!(Version::from_bytes(&encoded), Ok(Version::V1))
        }
        // Same check, for metadata holding exactly one input.
        fn single_input_detected_as_v1(input: ArcisInput) -> bool {
            detected_as_v1(ArcisMetadata {
                input_order: vec![input],
            })
        }
        // Asserts the detection outcome for each input kind whose encoding ends in `number`.
        fn assert_trailing_number(number: usize, expected: bool) {
            let inputs = [
                ArcisInput::InputId(number),
                ArcisInput::Mxe(MxeInput::Scalar(MxeFieldInput::RescueKey(number))),
                ArcisInput::Mxe(MxeInput::Bit(MxeBitInput::AES256Key(number))),
            ];
            for input in inputs {
                assert_eq!(single_input_detected_as_v1(input), expected);
            }
        }

        // Empty metadata is serialized as `[0u8; 8]`.
        assert!(detected_as_v1(ArcisMetadata::default()));
        // An input whose encoding consists only of enum variant tags.
        assert!(single_input_detected_as_v1(ArcisInput::Mxe(
            MxeInput::ScalarOnly(MxeScalarInput::ElGamalSecretKey())
        )));
        // Inputs whose encoding ends with a `usize`.
        for (number, expected) in [
            (0, true),
            (255, true),
            (usize::MAX >> 1, true),
            (usize::MAX, false),
        ] {
            assert_trailing_number(number, expected);
        }
        // Pin down the exact threshold at which detection stops working.
        let limit = 255usize << 56;
        assert_trailing_number(limit, false);
        assert_trailing_number(limit - 1, true);
        // As long as users have fewer than `255usize << 56` inputs,
        // and no input id holds a negative number,
        // this versioning system will correctly identify V1 instructions as V1.
    }
}