use crate::{
core::{
expressions::{domain::DomainElement, expr::EvalValue},
mxe_input::{ArxInput, MxeInput},
profile_circuit::{explain_circuit_depth, get_circuit_depth},
},
network_content::NetworkContent,
profile_info::ProfileInfo,
utils::{curve_point::CurvePoint, number::Number},
AsyncMPCCircuit,
};
use core_utils::circuit::mock_eval::MockRng;
use ff::Field;
use num_bigint::{BigUint, ToBigInt};
use num_traits::Zero;
use primitives::algebra::{
elliptic_curve::{BaseFieldElement, Curve, Point, Scalar},
field::subfield_element::Mersenne107Element,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha3::Digest;
/// Adapter implementing `MockRng` on top of an arbitrary [`Rng`], with the
/// ability to script the results of `gen_bit`.
struct MockEvalRng<'a, R: Rng + ?Sized> {
/// Source of randomness for every draw that is not scripted by `bools`.
rng: &'a mut R,
/// Remaining scripted bits, consumed front to back by `gen_bit`:
/// `Some(b)` forces the next bit to `b`, `None` lets `rng` decide.
bools: &'a [Option<bool>],
}
impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
Self { rng, bools }
}
}
impl<'a, R> MockRng for MockEvalRng<'a, R>
where
    R: Rng + ?Sized,
{
    /// Returns the next scripted bit when one is queued, otherwise a random bit.
    ///
    /// Every call pops one entry from `bools` while any remain. A `Some(bit)`
    /// entry forces the result; a `None` entry (or an empty queue) defers to
    /// the wrapped RNG.
    fn gen_bit(&mut self) -> bool {
        let scripted = match self.bools.split_first() {
            Some((&head, rest)) => {
                self.bools = rest;
                head
            }
            None => None,
        };
        scripted.unwrap_or_else(|| self.rng.gen())
    }

    /// Draws a daBit directly from the wrapped RNG; `bools` is not consulted.
    fn gen_da_bit(&mut self) -> bool {
        self.rng.gen::<bool>()
    }

    /// Draws a uniformly random scalar of curve `C` from the wrapped RNG.
    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
        Scalar::<C>::random(&mut self.rng)
    }

    /// Draws a uniformly random base-field element of curve `C` from the wrapped RNG.
    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
        BaseFieldElement::<C>::random(&mut self.rng)
    }

    /// Draws a uniformly random Mersenne-107 field element from the wrapped RNG.
    fn gen_mersenne(&mut self) -> Mersenne107Element {
        Mersenne107Element::random(&mut self.rng)
    }

    /// Draws a random curve point via the RNG's `Standard` distribution.
    fn gen_point<C: Curve>(&mut self) -> Point<C> {
        self.rng.gen::<Point<C>>()
    }
}
/// A circuit together with the metadata describing how its inputs are supplied.
///
/// NOTE: the whole struct is bincode-serialized and hashed in
/// `ArcisInstruction::profile_info`, so renaming or reordering fields changes
/// the reported `circuit_hash`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArcisInstruction {
/// The circuit that is mock-evaluated and profiled.
pub circuit: AsyncMPCCircuit,
/// Where each circuit input comes from (see `ArcisMetadata::input_order`).
pub metadata: ArcisMetadata,
}
impl ArcisInstruction {
pub fn mock_eval_vec<R: Rng + ?Sized>(
&self,
inputs: Vec<EvalValue>,
output_domains: &[DomainElement<(), (), (), ()>],
bools: &[Option<bool>],
rng: &mut R,
) -> Vec<EvalValue> {
let numbers: Vec<EvalValue> = self
.metadata
.input_order
.iter()
.map(|x| match *x {
ArcisInput::InputId(id) => inputs[id],
ArcisInput::Mxe(input) => input.mock_eval(),
})
.collect();
let v = numbers
.into_iter()
.map(|x| match x {
EvalValue::Bit(b) => BigUint::from(b),
EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
})
.collect();
let mut mock_rng = MockEvalRng::new(rng, bools);
let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
std::iter::zip(output_domains, res)
.map(|(domain, x)| match domain {
DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
DomainElement::Scalar(_) => {
EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
}
DomainElement::Base(_) => {
EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
}
DomainElement::Curve(_) => {
let mut bytes = x.to_bytes_le();
bytes.resize(32, 0);
EvalValue::Curve(
CurvePoint::from_le_bytes(&bytes)
.expect("Failed to convert to CurvePoint."),
)
}
})
.collect()
}
pub fn network_depth(&self) -> usize {
get_circuit_depth(&self.circuit)
}
pub fn profile_info(&self) -> ProfileInfo {
let network_depth = self.network_depth();
let total_gates = self.circuit.ops_count() as usize;
let preprocessing = self.circuit.required_preprocessing();
let pre_process = (&preprocessing).into();
let network_content = NetworkContent::from_circuit(&self.circuit);
let mut hash = sha3::Keccak256::new();
hash.update(bincode::serialize(&self).unwrap());
let hash = hash.finalize();
let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
ProfileInfo {
network_depth,
total_gates,
network_content,
pre_process,
circuit_hash,
}
}
pub fn weight(&self) -> usize {
self.profile_info().weight()
}
#[allow(dead_code)]
pub fn explain_depth(&self) {
explain_circuit_depth(&self.circuit);
}
}
/// Metadata describing how the inputs of an `ArcisInstruction` are assembled.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ArcisMetadata {
/// Source of each circuit input, in the order `ArcisInstruction::mock_eval_vec`
/// passes them to the circuit.
pub input_order: Vec<ArcisInput>,
}
/// Where a single circuit input comes from.
///
/// NOTE: variant order is part of the serialized form (and therefore of the
/// hash computed in `ArcisInstruction::profile_info`); append, don't reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArcisInput {
/// Index into the caller-supplied `inputs` list of
/// `ArcisInstruction::mock_eval_vec`.
InputId(usize),
/// An input described by `MxeInput` (presumably supplied by the MXE rather than
/// the caller — confirm); in mock evaluation its value comes from
/// `MxeInput::mock_eval`.
Mxe(MxeInput),
}