1use crate::{
2 auxiliary_circuit_info::AuxiliaryCircuitInfo,
3 core::{
4 expressions::{domain::DomainElement, expr::EvalValue, other_expr::SharedRescueKeyData},
5 mxe_input::{ArxInput, MxeInput},
6 profile_circuit::{explain_circuit_depth, get_circuit_depth},
7 },
8 network_content::NetworkContent,
9 profile_info::ProfileInfo,
10 utils::{curve_point::CurvePoint, number::Number},
11 AsyncMPCCircuit,
12};
13use core_utils::circuit::{mock_eval::MockRng, AlgebraicType};
14use ff::Field;
15use num_bigint::{BigUint, ToBigInt};
16use num_traits::Zero;
17use primitives::algebra::{
18 elliptic_curve::{BaseFieldElement, Curve, Curve25519Ristretto, Point, Scalar},
19 field::subfield_element::Mersenne107Element,
20};
21use rand::Rng;
22use serde::{Deserialize, Serialize};
23use sha3::Digest;
24
25struct MockEvalRng<'a, R: Rng + ?Sized> {
26 rng: &'a mut R,
27 bools: &'a [Option<bool>],
28}
29
30impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
31 pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
32 Self { rng, bools }
33 }
34}
35impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
36 fn gen_bit(&mut self) -> bool {
37 if let Some(b) = self.bools.first() {
38 self.bools = &self.bools[1..];
39 if let Some(b) = b {
40 return *b;
41 }
42 }
43 self.rng.gen()
44 }
45
46 fn gen_da_bit(&mut self) -> bool {
47 self.rng.gen()
48 }
49
50 fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
51 Scalar::<C>::random(&mut self.rng)
52 }
53
54 fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
55 BaseFieldElement::<C>::random(&mut self.rng)
56 }
57
58 fn gen_mersenne(&mut self) -> Mersenne107Element {
59 Mersenne107Element::random(&mut self.rng)
60 }
61
62 fn gen_point<C: Curve>(&mut self) -> Point<C> {
63 self.rng.gen()
64 }
65}
66
/// A circuit together with the metadata describing where its inputs come from.
///
/// Persisted with bincode through `ArcisInstruction::to_bytes` /
/// `ArcisInstruction::from_bytes`; the field order below is part of that
/// on-disk format, so do not reorder fields.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArcisInstruction {
    /// The circuit that `mock_eval_vec` evaluates and `profile_info` measures.
    pub circuit: AsyncMPCCircuit,
    /// Origin and order of the circuit's inputs.
    pub metadata: ArcisMetadata,
}
76
/// `DomainElement` with unit payloads: only the variant (bit, scalar, base-field
/// element or curve point) matters. `mock_eval_vec` uses it to decide how to
/// decode each circuit output.
pub type DomainKind = DomainElement<(), (), (), ()>;
78
79impl ArcisInstruction {
80 pub fn mock_eval_vec<R: Rng + ?Sized>(
83 &self,
84 inputs: Vec<EvalValue>,
85 output_domains: &[DomainKind],
86 bools: &[Option<bool>],
87 rng: &mut R,
88 ) -> Vec<EvalValue> {
89 let numbers: Vec<EvalValue> = self
90 .metadata
91 .input_order
92 .iter()
93 .map(|x| match *x {
94 ArcisInput::InputId(id) => inputs[id],
95 ArcisInput::Mxe(input) => input.mock_eval(),
96 ArcisInput::TypedInput(id, _) => inputs[id],
97 ArcisInput::SharedRescueKey(data) => data
98 .eval(inputs[data.pubkey_input_id])
99 .unwrap_or_else(|e| panic!("{}", e)),
100 })
101 .collect();
102
103 let v = numbers
104 .into_iter()
105 .map(|x| match x {
106 EvalValue::Bit(b) => BigUint::from(b),
107 EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
108 EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
109 EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
110 })
111 .collect();
112 let mut mock_rng = MockEvalRng::new(rng, bools);
113 let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
114 std::iter::zip(output_domains, res)
115 .map(|(domain, x)| match domain {
116 DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
117 DomainElement::Scalar(_) => {
118 EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
119 }
120 DomainElement::Base(_) => {
121 EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
122 }
123 DomainElement::Curve(_) => {
124 let mut bytes = x.to_bytes_le();
125 bytes.resize(32, 0);
126 EvalValue::Curve(
127 CurvePoint::from_le_bytes(&bytes)
128 .expect("Failed to convert to CurvePoint."),
129 )
130 }
131 })
132 .collect()
133 }
134 pub fn network_depth(&self) -> usize {
137 get_circuit_depth(&self.circuit)
138 }
139 pub fn profile_info(&self) -> ProfileInfo {
141 let network_depth = self.network_depth();
142 let total_gates = self.circuit.nb_gates() as usize;
143 let preprocessing = self.circuit.required_preprocessing();
144 let pre_process = (&preprocessing).into();
145 let network_content = NetworkContent::from_circuit(&self.circuit);
146 let mut hash = sha3::Keccak256::new();
147 hash.update(self.to_bytes());
148 let hash = hash.finalize();
149 let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
150 let auxiliary_circuits = AuxiliaryCircuitInfo::counts(&self.metadata);
151
152 ProfileInfo {
153 network_depth,
154 total_gates,
155 network_content,
156 pre_process,
157 auxiliary_circuits,
158 circuit_hash,
159 }
160 }
161 pub fn weight(&self) -> usize {
163 self.profile_info().weight()
164 }
165 #[allow(dead_code)]
167 pub fn explain_depth(&self) {
168 explain_circuit_depth(&self.circuit);
169 }
170 pub fn to_bytes(&self) -> Vec<u8> {
171 let mut data = bincode::serialize(self).unwrap();
172 CURRENT_VERSION.add_to_bytes(&mut data);
173 data
174 }
175 pub fn from_bytes(mut bytes: Vec<u8>) -> Result<Self, String> {
176 let version = Version::from_bytes(&bytes)?;
177 let len_without_version = bytes.len() - version.len();
178 bytes.truncate(len_without_version);
179 match version {
180 Version::V1 => {
181 let v1_instruction: ArcisInstructionV1 =
182 bincode::deserialize(&bytes).map_err(|e| e.to_string())?;
183 Ok(ArcisInstruction {
184 circuit: v1_instruction
185 .circuit
186 .into_v2()
187 .map_err(|e| e.to_string())?,
188 metadata: v1_instruction.metadata,
189 })
190 }
191 Version::V2 => bincode::deserialize(&bytes).map_err(|e| e.to_string()),
192 }
193 }
194}
/// Metadata stored next to the circuit in an [`ArcisInstruction`].
///
/// Part of the bincode on-disk format; do not reorder fields.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ArcisMetadata {
    /// Origin of each circuit input, in the order the circuit consumes them.
    pub input_order: Vec<ArcisInput>,
}
201
/// Where a single circuit input comes from (one entry of `ArcisMetadata::input_order`).
///
/// Serialized with bincode as part of `ArcisInstruction`. bincode encodes enum
/// variants by their index, so do not reorder or remove variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArcisInput {
    /// The caller-supplied input at this index (`inputs[id]` in `mock_eval_vec`).
    InputId(usize),
    /// An MXE-provided input; mock evaluation resolves it with
    /// `MxeInput::mock_eval`, without consulting the caller's inputs.
    Mxe(MxeInput),
    /// A caller-supplied input at this index, tagged with its algebraic type
    /// (the type tag is not consulted by `mock_eval_vec`).
    TypedInput(usize, AlgebraicType),
    /// A shared Rescue key derived from the caller input at
    /// `SharedRescueKeyData::pubkey_input_id`.
    SharedRescueKey(SharedRescueKeyData),
}
215
/// On-disk format version of a serialized `ArcisInstruction`.
///
/// V1 payloads carry no marker. V2 payloads end with the two-byte suffix
/// `[2, 255]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Version {
    /// Legacy format: bare bincode, no suffix.
    V1,
    /// Current format: bincode followed by `[2, 255]`.
    V2,
}

impl Version {
    /// Detects the format version from the trailing bytes of `bytes`.
    ///
    /// A final byte other than `255` means V1. A final `255` must be preceded
    /// by `2`, which means V2.
    ///
    /// # Errors
    ///
    /// Returns a message for an empty slice, a lone `255`, or a `[v, 255]`
    /// suffix with `v != 2`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Version, String> {
        match bytes {
            [] => Err("Empty bytes.".into()),
            [255] => Err("Invalid bytes length.".into()),
            [.., 2, 255] => Ok(Version::V2),
            [.., v, 255] => Err(format!("Invalid version suffix: [{v}, 255].")),
            [..] => Ok(Version::V1),
        }
    }

    /// Marker bytes that terminate a payload of this version.
    fn suffix(self) -> &'static [u8] {
        match self {
            Version::V1 => &[],
            Version::V2 => &[2, 255],
        }
    }

    /// Appends this version's marker (possibly nothing) to `bytes`.
    pub fn add_to_bytes(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(self.suffix());
    }

    /// Number of marker bytes `add_to_bytes` appends for this version.
    pub fn len(&self) -> usize {
        self.suffix().len()
    }
}

/// Version written by `ArcisInstruction::to_bytes`.
const CURRENT_VERSION: Version = Version::V2;
258
/// Circuit type stored by the legacy V1 on-disk format.
type AsyncMPCCircuitV1 = core_utils::circuit::CircuitV1<Curve25519Ristretto>;
/// Layout of a V1-format instruction. It is only deserialized; its circuit is
/// then upgraded with `into_v2` in `ArcisInstruction::from_bytes`.
///
/// Field order must match the V1 bincode layout; do not reorder.
#[derive(Deserialize)]
struct ArcisInstructionV1 {
    pub circuit: AsyncMPCCircuitV1,
    pub metadata: ArcisMetadata,
}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MxeBitInput, MxeFieldInput, MxeScalarInput};

    /// Archived V1 `.arcis` files that must keep loading.
    const V1_FILES: &[&[u8]] = &[include_bytes!("../../old-arcis-files/v1/keygen.arcis")];

    /// Every archived V1 file must still deserialize (upgrading its circuit to V2).
    #[test]
    fn v1_compatibility() {
        for file in V1_FILES {
            ArcisInstruction::from_bytes(file.to_vec()).unwrap();
        }
    }

    /// Re-encoding an upgraded instruction yields bytes tagged with the current
    /// version that `from_bytes` accepts again.
    #[test]
    fn v2_round_trip() {
        for file in V1_FILES {
            let instruction = ArcisInstruction::from_bytes(file.to_vec()).unwrap();
            let encoded = instruction.to_bytes();
            assert_eq!(Version::from_bytes(&encoded), Ok(CURRENT_VERSION));
            ArcisInstruction::from_bytes(encoded).unwrap();
        }
    }

    /// The suffix written by `add_to_bytes` is recognized by `from_bytes`, and
    /// malformed suffixes are rejected.
    #[test]
    fn version_suffix() {
        let mut data = vec![7, 8, 9];
        CURRENT_VERSION.add_to_bytes(&mut data);
        assert_eq!(data.len(), 3 + CURRENT_VERSION.len());
        assert_eq!(Version::from_bytes(&data), Ok(CURRENT_VERSION));
        assert!(Version::from_bytes(&[]).is_err());
        assert!(Version::from_bytes(&[255]).is_err());
        assert!(Version::from_bytes(&[3, 255]).is_err());
        assert_eq!(Version::from_bytes(&[1, 2, 3]), Ok(Version::V1));
    }

    /// Checks which V1 payloads (un-suffixed bincode) are still detected as V1.
    ///
    /// A V1 payload is misdetected only when its last byte is 255. bincode
    /// writes integers little-endian at fixed width (`usize` as `u64`), so here
    /// the last byte is the most significant byte of the trailing number.
    ///
    /// Only meaningful with a 64-bit `usize`: `255usize << 56` overflows (a
    /// compile error) on narrower targets, and `usize::MAX` would not have a
    /// 255 high byte once widened to `u64`.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn version_v1() {
        fn would_version_be_v1(metadata: ArcisMetadata) -> bool {
            let data = bincode::serialize(&metadata).unwrap();
            let version = Version::from_bytes(&data);
            version.is_ok_and(|x| x == Version::V1)
        }
        assert!(would_version_be_v1(ArcisMetadata::default()));

        fn would_version_be_v1_one_item(item: ArcisInput) -> bool {
            let metadata = ArcisMetadata {
                input_order: vec![item],
            };
            would_version_be_v1(metadata)
        }
        assert!(would_version_be_v1_one_item(ArcisInput::Mxe(
            MxeInput::ScalarOnly(MxeScalarInput::ElGamalSecretKey())
        )));

        /// Asserts the V1 detection result for several inputs ending in `number`.
        fn check_number(number: usize, result: bool) {
            for item in [
                ArcisInput::InputId(number),
                ArcisInput::Mxe(MxeInput::Scalar(MxeFieldInput::RescueKey(number))),
                ArcisInput::Mxe(MxeInput::Bit(MxeBitInput::AES256Key(number))),
            ] {
                assert_eq!(would_version_be_v1_one_item(item), result);
            }
        }
        check_number(0, true);
        check_number(255, true);
        check_number(usize::MAX >> 1, true);
        check_number(usize::MAX, false);
        // Smallest value whose most significant byte is 255.
        let limit = 255usize << 56;
        check_number(limit, false);
        check_number(limit - 1, true);
    }
}