// arcis_compiler/core/instruction.rs
use crate::{
    core::{
        expressions::{domain::DomainElement, expr::EvalValue},
        mxe_input::{ArxInput, MxeInput},
        profile_circuit::{explain_circuit_depth, get_circuit_depth},
    },
    network_content::NetworkContent,
    profile_info::ProfileInfo,
    utils::{curve_point::CurvePoint, number::Number},
    AsyncMPCCircuit,
};
use core_utils::circuit::mock_eval::MockRng;
use ff::Field;
use num_bigint::{BigUint, ToBigInt};
use num_traits::Zero;
use primitives::algebra::{
    elliptic_curve::{BaseFieldElement, Curve, Point, Scalar},
    field::subfield_element::Mersenne107Element,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use sha3::Digest;
use std::mem::size_of;

/// Randomness source for mock evaluation that lets callers force the outcomes
/// of the first bit draws.
///
/// `bools` is consumed front to back, one entry per `gen_bit` call: `Some(b)`
/// forces that draw to `b`, `None` defers to `rng`. Once the slice is exhausted,
/// every draw comes from `rng`.
///
/// The `Rng` bound lives on the impl blocks rather than on the struct; `?Sized`
/// is kept so `R` may be an unsized generator behind the `&mut`.
struct MockEvalRng<'a, R: ?Sized> {
    /// Fallback generator for every value that is not forced.
    rng: &'a mut R,
    /// Remaining forced outcomes for `gen_bit`, consumed from the front.
    bools: &'a [Option<bool>],
}

29impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
30 pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
31 Self { rng, bools }
32 }
33}
34impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
35 fn gen_bit(&mut self) -> bool {
36 if let Some(b) = self.bools.first() {
37 self.bools = &self.bools[1..];
38 if let Some(b) = b {
39 return *b;
40 }
41 }
42 self.rng.gen()
43 }
44
45 fn gen_da_bit(&mut self) -> bool {
46 self.rng.gen()
47 }
48
49 fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
50 Scalar::<C>::random(&mut self.rng)
51 }
52
53 fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
54 BaseFieldElement::<C>::random(&mut self.rng)
55 }
56
57 fn gen_mersenne(&mut self) -> Mersenne107Element {
58 Mersenne107Element::random(&mut self.rng)
59 }
60
61 fn gen_point<C: Curve>(&mut self) -> Point<C> {
62 self.rng.gen()
63 }
64}
65
66#[derive(Debug, Serialize, Deserialize)]
68pub struct ArcisInstruction {
69 pub circuit: AsyncMPCCircuit,
71 pub metadata: ArcisMetadata,
74}
75
76impl ArcisInstruction {
77 pub fn mock_eval_vec<R: Rng + ?Sized>(
80 &self,
81 inputs: Vec<EvalValue>,
82 output_domains: &[DomainElement<(), (), (), ()>],
83 bools: &[Option<bool>],
84 rng: &mut R,
85 ) -> Vec<EvalValue> {
86 let numbers: Vec<EvalValue> = self
87 .metadata
88 .input_order
89 .iter()
90 .map(|x| match *x {
91 ArcisInput::InputId(id) => inputs[id],
92 ArcisInput::Mxe(input) => input.mock_eval(),
93 })
94 .collect();
95
96 let v = numbers
97 .into_iter()
98 .map(|x| match x {
99 EvalValue::Bit(b) => BigUint::from(b),
100 EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
101 EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
102 EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
103 })
104 .collect();
105 let mut mock_rng = MockEvalRng::new(rng, bools);
106 let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
107 std::iter::zip(output_domains, res)
108 .map(|(domain, x)| match domain {
109 DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
110 DomainElement::Scalar(_) => {
111 EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
112 }
113 DomainElement::Base(_) => {
114 EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
115 }
116 DomainElement::Curve(_) => {
117 let mut bytes = x.to_bytes_le();
118 bytes.resize(32, 0);
119 EvalValue::Curve(
120 CurvePoint::from_le_bytes(&bytes)
121 .expect("Failed to convert to CurvePoint."),
122 )
123 }
124 })
125 .collect()
126 }
127 pub fn network_depth(&self) -> usize {
130 get_circuit_depth(&self.circuit)
131 }
132 pub fn profile_info(&self) -> ProfileInfo {
134 let network_depth = self.network_depth();
135 let total_gates = self.circuit.ops_count() as usize;
136 let preprocessing = self.circuit.required_preprocessing();
137 let pre_process = (&preprocessing).into();
138 let network_content = NetworkContent::from_circuit(&self.circuit);
139 let mut hash = sha3::Keccak256::new();
140 hash.update(bincode::serialize(&self).unwrap());
141 let hash = hash.finalize();
142 let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
143
144 ProfileInfo {
145 network_depth,
146 total_gates,
147 network_content,
148 pre_process,
149 circuit_hash,
150 }
151 }
152 pub fn weight(&self) -> usize {
154 self.profile_info().weight()
155 }
156 #[allow(dead_code)]
158 pub fn explain_depth(&self) {
159 explain_circuit_depth(&self.circuit);
160 }
161}
162
163#[derive(Debug, Serialize, Deserialize, Default)]
166pub struct ArcisMetadata {
167 pub input_order: Vec<ArcisInput>,
168}
169
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
172pub enum ArcisInput {
173 InputId(usize),
175 Mxe(MxeInput),
177}