use crate::{
core::{
expressions::{domain::DomainElement, expr::EvalValue},
mxe_input::{ArxInput, MxeInput},
profile_circuit::{explain_circuit_depth, get_circuit_depth},
},
network_content::NetworkContent,
profile_info::ProfileInfo,
utils::{curve_point::CurvePoint, number::Number},
AsyncMPCCircuit,
};
use core_utils::circuit::{mock_eval::MockRng, AlgebraicType};
use ff::Field;
use num_bigint::{BigUint, ToBigInt};
use num_traits::Zero;
use primitives::algebra::{
elliptic_curve::{BaseFieldElement, Curve, Curve25519Ristretto, Point, Scalar},
field::subfield_element::Mersenne107Element,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha3::Digest;
/// Randomness source for mock circuit evaluation whose bit stream can be partially scripted.
///
/// `bools` acts as a queue of predetermined answers for `gen_bit`: entries are consumed
/// front to back, and a `None` entry (or an exhausted queue) falls through to `rng`.
struct MockEvalRng<'a, R: Rng + ?Sized> {
    /// Remaining scripted bits, consumed from the front by `gen_bit`.
    bools: &'a [Option<bool>],
    /// Fallback generator used for everything `bools` does not cover.
    rng: &'a mut R,
}
impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
    /// Wraps `rng`, letting `bools` dictate the first `gen_bit` results where they are `Some`.
    pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
        MockEvalRng { bools, rng }
    }
}
impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
    /// Returns the next scripted bit if there is one, otherwise a bit drawn from `rng`.
    ///
    /// While `bools` is non-empty, every call consumes exactly one entry, even when that
    /// entry is `None` and the answer therefore comes from `rng`.
    fn gen_bit(&mut self) -> bool {
        let scripted = match self.bools.split_first() {
            Some((&head, tail)) => {
                self.bools = tail;
                head
            }
            None => None,
        };
        scripted.unwrap_or_else(|| self.rng.gen())
    }
    /// Always drawn from `rng`; the `bools` script does not affect it.
    fn gen_da_bit(&mut self) -> bool {
        self.rng.gen()
    }
    /// Uniformly random scalar of curve `C`, drawn from `rng`.
    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
        Scalar::<C>::random(&mut self.rng)
    }
    /// Random base-field element of curve `C`, drawn from `rng`.
    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
        BaseFieldElement::<C>::random(&mut self.rng)
    }
    /// Random Mersenne-107 field element, drawn from `rng`.
    fn gen_mersenne(&mut self) -> Mersenne107Element {
        Mersenne107Element::random(&mut self.rng)
    }
    /// Random point of curve `C`, sampled through `Rng::gen`.
    fn gen_point<C: Curve>(&mut self) -> Point<C> {
        self.rng.gen()
    }
}
/// A compiled Arcis instruction: an MPC circuit plus the metadata describing where
/// each of its inputs comes from.
///
/// Stored with bincode (see `ArcisInstruction::to_bytes` / `ArcisInstruction::from_bytes`),
/// which encodes fields in declaration order, so the field order is part of the format.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArcisInstruction {
    /// The circuit to evaluate.
    pub circuit: AsyncMPCCircuit,
    /// Input wiring for `circuit`.
    pub metadata: ArcisMetadata,
}
/// Payload-free domain tag (bit, scalar, base field or curve point), used by
/// `ArcisInstruction::mock_eval_vec` to decide how each circuit output is decoded.
pub type DomainKind = DomainElement<(), (), (), ()>;
impl ArcisInstruction {
pub fn mock_eval_vec<R: Rng + ?Sized>(
&self,
inputs: Vec<EvalValue>,
output_domains: &[DomainKind],
bools: &[Option<bool>],
rng: &mut R,
) -> Vec<EvalValue> {
let numbers: Vec<EvalValue> = self
.metadata
.input_order
.iter()
.map(|x| match *x {
ArcisInput::InputId(id) => inputs[id],
ArcisInput::Mxe(input) => input.mock_eval(),
ArcisInput::TypedInput(id, _) => inputs[id],
})
.collect();
let v = numbers
.into_iter()
.map(|x| match x {
EvalValue::Bit(b) => BigUint::from(b),
EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
})
.collect();
let mut mock_rng = MockEvalRng::new(rng, bools);
let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
std::iter::zip(output_domains, res)
.map(|(domain, x)| match domain {
DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
DomainElement::Scalar(_) => {
EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
}
DomainElement::Base(_) => {
EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
}
DomainElement::Curve(_) => {
let mut bytes = x.to_bytes_le();
bytes.resize(32, 0);
EvalValue::Curve(
CurvePoint::from_le_bytes(&bytes)
.expect("Failed to convert to CurvePoint."),
)
}
})
.collect()
}
pub fn network_depth(&self) -> usize {
get_circuit_depth(&self.circuit)
}
pub fn profile_info(&self) -> ProfileInfo {
let network_depth = self.network_depth();
let total_gates = self.circuit.nb_gates() as usize;
let preprocessing = self.circuit.required_preprocessing();
let pre_process = (&preprocessing).into();
let network_content = NetworkContent::from_circuit(&self.circuit);
let mut hash = sha3::Keccak256::new();
hash.update(self.to_bytes());
let hash = hash.finalize();
let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
ProfileInfo {
network_depth,
total_gates,
network_content,
pre_process,
circuit_hash,
}
}
pub fn weight(&self) -> usize {
self.profile_info().weight()
}
#[allow(dead_code)]
pub fn explain_depth(&self) {
explain_circuit_depth(&self.circuit);
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut data = bincode::serialize(self).unwrap();
CURRENT_VERSION.add_to_bytes(&mut data);
data
}
pub fn from_bytes(mut bytes: Vec<u8>) -> Result<Self, String> {
let version = Version::from_bytes(&bytes)?;
let len_without_version = bytes.len() - version.len();
bytes.truncate(len_without_version);
match version {
Version::V1 => {
let v1_instruction: ArcisInstructionV1 =
bincode::deserialize(&bytes).map_err(|e| e.to_string())?;
Ok(ArcisInstruction {
circuit: v1_instruction
.circuit
.into_v2()
.map_err(|e| e.to_string())?,
metadata: v1_instruction.metadata,
})
}
Version::V2 => bincode::deserialize(&bytes).map_err(|e| e.to_string()),
}
}
}
/// Metadata stored alongside an `ArcisInstruction`'s circuit.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ArcisMetadata {
    /// Source of each circuit input, in the order the inputs are fed to the circuit.
    pub input_order: Vec<ArcisInput>,
}
/// Where a single circuit input comes from.
///
/// This enum is bincode-serialized inside instructions, so its variant order is part of
/// the on-disk format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArcisInput {
    /// Caller-supplied value: index into the `inputs` passed to `mock_eval_vec`.
    InputId(usize),
    /// MXE-provided input; during mock evaluation its value comes from `MxeInput::mock_eval`.
    Mxe(MxeInput),
    /// Caller-supplied value at the given index, tagged with an `AlgebraicType`. The tag is
    /// ignored by `mock_eval_vec`.
    TypedInput(usize, AlgebraicType),
}
/// Serialization format version of an `ArcisInstruction`, encoded as a trailing suffix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Version {
    /// Legacy format: bincode data with no suffix. The circuit is a `CircuitV1`, which is
    /// upgraded with `into_v2` on load.
    V1,
    /// Current format: bincode data followed by the two-byte suffix `[2, 255]`.
    V2,
}
impl Version {
    /// Detects the format version from the trailing bytes of serialized data.
    ///
    /// Data ending in `[2, 255]` is V2. Any data whose final byte is not `255` predates
    /// version markers and is treated as V1.
    ///
    /// # Errors
    /// Fails on empty input, on the single byte `[255]`, or on a `[v, 255]` suffix whose
    /// `v` is not a known version.
    pub fn from_bytes(bytes: &[u8]) -> Result<Version, String> {
        match bytes {
            [] => Err("Empty bytes.".into()),
            [.., 2, 255] => Ok(Version::V2),
            [.., v, 255] => Err(format!("Invalid version suffix: [{v}, 255].")),
            [255] => Err("Invalid bytes length.".into()),
            _ => Ok(Version::V1),
        }
    }
    /// Appends this version's suffix to `bytes`; V1 has an empty suffix.
    pub fn add_to_bytes(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(self.suffix());
    }
    /// Number of suffix bytes this version appends to serialized data.
    pub fn len(&self) -> usize {
        self.suffix().len()
    }
    /// Raw suffix bytes that mark this version.
    fn suffix(&self) -> &'static [u8] {
        match self {
            Version::V1 => &[],
            Version::V2 => &[2, 255],
        }
    }
}
/// Version written by `ArcisInstruction::to_bytes`.
const CURRENT_VERSION: Version = Version::V2;
/// Circuit type stored in V1-format instructions.
type AsyncMPCCircuitV1 = core_utils::circuit::CircuitV1<Curve25519Ristretto>;
/// Deserialization-only mirror of the V1 layout of `ArcisInstruction`.
///
/// Field order must match the V1 encoding. Because `metadata` is the last field, it is
/// what produces the final bytes of V1 data inspected by `Version::from_bytes`.
#[derive(Deserialize)]
struct ArcisInstructionV1 {
    pub circuit: AsyncMPCCircuitV1,
    pub metadata: ArcisMetadata,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every checked-in V1 `.arcis` file must still load and upgrade to the V2 circuit.
    #[test]
    fn v1_compatibility() {
        const V1_FILES: &[&[u8]] = &[include_bytes!("../../old-arcis-files/v1/keygen.arcis")];
        for file in V1_FILES {
            ArcisInstruction::from_bytes(file.to_vec()).unwrap();
        }
    }

    /// Checks which V1 payloads are (correctly) detected as V1.
    ///
    /// V1 data has no version suffix, so it is recognised as V1 only when its final byte is
    /// not `255`. The metadata is the last field of a V1 instruction, so serializing the
    /// metadata alone is enough to decide detection.
    ///
    /// With bincode's default fixed-width little-endian integer encoding, a trailing
    /// `usize` ends with its most significant byte. Numbers whose top byte is `0xFF` are
    /// therefore misdetected. These cases assume a 64-bit `usize`:
    /// - `255usize << 56` does not even compile on 32-bit targets (`arithmetic_overflow`).
    /// - `usize::MAX` would not end in `0xFF` there.
    ///
    /// Hence the test is gated on 64-bit targets.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn version_v1() {
        use crate::{MxeBitInput, MxeFieldInput, MxeScalarInput};

        fn would_version_be_v1(metadata: ArcisMetadata) -> bool {
            let data = bincode::serialize(&metadata).unwrap();
            let version = Version::from_bytes(&data);
            version.is_ok_and(|x| x == Version::V1)
        }
        assert!(would_version_be_v1(ArcisMetadata::default()));
        fn would_version_be_v1_one_item(item: ArcisInput) -> bool {
            let metadata = ArcisMetadata {
                input_order: vec![item],
            };
            would_version_be_v1(metadata)
        }
        assert!(would_version_be_v1_one_item(ArcisInput::Mxe(
            MxeInput::ScalarOnly(MxeScalarInput::ElGamalSecretKey())
        )));
        // Each item ends with `number`, so its most significant byte decides the outcome.
        fn check_number(number: usize, result: bool) {
            for item in [
                ArcisInput::InputId(number),
                ArcisInput::Mxe(MxeInput::Scalar(MxeFieldInput::RescueKey(number))),
                ArcisInput::Mxe(MxeInput::Bit(MxeBitInput::AES256Key(number))),
            ] {
                assert_eq!(would_version_be_v1_one_item(item), result);
            }
        }
        check_number(0, true);
        check_number(255, true);
        check_number(usize::MAX >> 1, true);
        check_number(usize::MAX, false);
        // Smallest value whose top byte is 0xFF, and the largest value below it.
        let limit = 255usize << 56;
        check_number(limit, false);
        check_number(limit - 1, true);
    }
}