1use crate::{
2 core::{
3 expressions::{domain::DomainElement, expr::EvalValue},
4 mxe_input::{ArxInput, MxeInput},
5 profile_circuit::{explain_circuit_depth, get_circuit_depth},
6 },
7 network_content::NetworkContent,
8 profile_info::ProfileInfo,
9 utils::{curve_point::CurvePoint, number::Number},
10 AsyncMPCCircuit,
11};
12use core_utils::circuit::{mock_eval::MockRng, AlgebraicType};
13use ff::Field;
14use num_bigint::{BigUint, ToBigInt};
15use num_traits::Zero;
16use primitives::algebra::{
17 elliptic_curve::{BaseFieldElement, Curve, Curve25519Ristretto, Point, Scalar},
18 field::subfield_element::Mersenne107Element,
19};
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use sha3::Digest;
23
24struct MockEvalRng<'a, R: Rng + ?Sized> {
25 rng: &'a mut R,
26 bools: &'a [Option<bool>],
27}
28
29impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
30 pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
31 Self { rng, bools }
32 }
33}
34impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
35 fn gen_bit(&mut self) -> bool {
36 if let Some(b) = self.bools.first() {
37 self.bools = &self.bools[1..];
38 if let Some(b) = b {
39 return *b;
40 }
41 }
42 self.rng.gen()
43 }
44
45 fn gen_da_bit(&mut self) -> bool {
46 self.rng.gen()
47 }
48
49 fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
50 Scalar::<C>::random(&mut self.rng)
51 }
52
53 fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
54 BaseFieldElement::<C>::random(&mut self.rng)
55 }
56
57 fn gen_mersenne(&mut self) -> Mersenne107Element {
58 Mersenne107Element::random(&mut self.rng)
59 }
60
61 fn gen_point<C: Curve>(&mut self) -> Point<C> {
62 self.rng.gen()
63 }
64}
65
66#[derive(Debug, Serialize, Deserialize)]
68pub struct ArcisInstruction {
69 pub circuit: AsyncMPCCircuit,
71 pub metadata: ArcisMetadata,
74}
75
76pub type DomainKind = DomainElement<(), (), (), ()>;
77
78impl ArcisInstruction {
79 pub fn mock_eval_vec<R: Rng + ?Sized>(
82 &self,
83 inputs: Vec<EvalValue>,
84 output_domains: &[DomainKind],
85 bools: &[Option<bool>],
86 rng: &mut R,
87 ) -> Vec<EvalValue> {
88 let numbers: Vec<EvalValue> = self
89 .metadata
90 .input_order
91 .iter()
92 .map(|x| match *x {
93 ArcisInput::InputId(id) => inputs[id],
94 ArcisInput::Mxe(input) => input.mock_eval(),
95 ArcisInput::TypedInput(id, _) => inputs[id],
96 })
97 .collect();
98
99 let v = numbers
100 .into_iter()
101 .map(|x| match x {
102 EvalValue::Bit(b) => BigUint::from(b),
103 EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
104 EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
105 EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
106 })
107 .collect();
108 let mut mock_rng = MockEvalRng::new(rng, bools);
109 let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
110 std::iter::zip(output_domains, res)
111 .map(|(domain, x)| match domain {
112 DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
113 DomainElement::Scalar(_) => {
114 EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
115 }
116 DomainElement::Base(_) => {
117 EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
118 }
119 DomainElement::Curve(_) => {
120 let mut bytes = x.to_bytes_le();
121 bytes.resize(32, 0);
122 EvalValue::Curve(
123 CurvePoint::from_le_bytes(&bytes)
124 .expect("Failed to convert to CurvePoint."),
125 )
126 }
127 })
128 .collect()
129 }
130 pub fn network_depth(&self) -> usize {
133 get_circuit_depth(&self.circuit)
134 }
135 pub fn profile_info(&self) -> ProfileInfo {
137 let network_depth = self.network_depth();
138 let total_gates = self.circuit.nb_gates() as usize;
139 let preprocessing = self.circuit.required_preprocessing();
140 let pre_process = (&preprocessing).into();
141 let network_content = NetworkContent::from_circuit(&self.circuit);
142 let mut hash = sha3::Keccak256::new();
143 hash.update(self.to_bytes());
144 let hash = hash.finalize();
145 let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
146
147 ProfileInfo {
148 network_depth,
149 total_gates,
150 network_content,
151 pre_process,
152 circuit_hash,
153 }
154 }
155 pub fn weight(&self) -> usize {
157 self.profile_info().weight()
158 }
159 #[allow(dead_code)]
161 pub fn explain_depth(&self) {
162 explain_circuit_depth(&self.circuit);
163 }
164 pub fn to_bytes(&self) -> Vec<u8> {
165 let mut data = bincode::serialize(self).unwrap();
166 CURRENT_VERSION.add_to_bytes(&mut data);
167 data
168 }
169 pub fn from_bytes(mut bytes: Vec<u8>) -> Result<Self, String> {
170 let version = Version::from_bytes(&bytes)?;
171 let len_without_version = bytes.len() - version.len();
172 bytes.truncate(len_without_version);
173 match version {
174 Version::V1 => {
175 let v1_instruction: ArcisInstructionV1 =
176 bincode::deserialize(&bytes).map_err(|e| e.to_string())?;
177 Ok(ArcisInstruction {
178 circuit: v1_instruction
179 .circuit
180 .into_v2()
181 .map_err(|e| e.to_string())?,
182 metadata: v1_instruction.metadata,
183 })
184 }
185 Version::V2 => bincode::deserialize(&bytes).map_err(|e| e.to_string()),
186 }
187 }
188}
189
190#[derive(Debug, Serialize, Deserialize, Default)]
193pub struct ArcisMetadata {
194 pub input_order: Vec<ArcisInput>,
195}
196
197#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
199pub enum ArcisInput {
200 InputId(usize),
202 Mxe(MxeInput),
204 TypedInput(usize, AlgebraicType),
207}
208
/// Serialization format of an `ArcisInstruction` byte blob.
///
/// V1 blobs carry no marker; V2 blobs end with the two-byte suffix `[2, 255]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Version {
    V1,
    V2,
}

impl Version {
    /// Detects the format version from the trailing bytes of `bytes`.
    ///
    /// A final byte of 255 marks a versioned blob whose preceding byte is the
    /// version number (only 2 is recognized). Any other final byte means V1.
    ///
    /// # Errors
    ///
    /// Fails on empty input, on a lone `255`, or on an unrecognized version
    /// number before a trailing `255`.
    pub fn from_bytes(bytes: &[u8]) -> Result<Version, String> {
        match bytes {
            [] => Err("Empty bytes.".into()),
            [255] => Err("Invalid bytes length.".into()),
            [.., 2, 255] => Ok(Version::V2),
            [.., v, 255] => Err(format!("Invalid version suffix: [{v}, 255].")),
            _ => Ok(Version::V1),
        }
    }

    /// The marker bytes appended after the payload for this version.
    fn suffix(&self) -> &'static [u8] {
        match self {
            Version::V1 => &[],
            Version::V2 => &[2, 255],
        }
    }

    /// Appends this version's marker to `bytes`.
    pub fn add_to_bytes(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(self.suffix());
    }

    /// Number of marker bytes this version appends.
    pub fn len(&self) -> usize {
        self.suffix().len()
    }
}
249
250const CURRENT_VERSION: Version = Version::V2;
251
252type AsyncMPCCircuitV1 = core_utils::circuit::CircuitV1<Curve25519Ristretto>;
253#[derive(Deserialize)]
254struct ArcisInstructionV1 {
255 pub circuit: AsyncMPCCircuitV1,
257 pub metadata: ArcisMetadata,
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MxeBitInput, MxeFieldInput, MxeScalarInput};

    /// Archived V1 files must still load and upgrade to the current circuit format.
    #[test]
    fn v1_compatibility() {
        const V1_FILES: &[&[u8]] = &[include_bytes!("../../old-arcis-files/v1/keygen.arcis")];
        for file in V1_FILES {
            ArcisInstruction::from_bytes(file.to_vec()).unwrap();
        }
    }

    /// V1 blobs carry no version suffix, so they are only recognized correctly
    /// when their last byte is not 255. The last field of a V1 blob is its
    /// metadata, so this checks which metadata values keep V1 detection working.
    #[test]
    fn version_v1() {
        fn would_version_be_v1(metadata: ArcisMetadata) -> bool {
            let data = bincode::serialize(&metadata).unwrap();
            let version = Version::from_bytes(&data);
            version.is_ok_and(|x| x == Version::V1)
        }
        fn would_version_be_v1_one_item(item: ArcisInput) -> bool {
            let metadata = ArcisMetadata {
                input_order: vec![item],
            };
            would_version_be_v1(metadata)
        }
        fn check_number(number: usize, result: bool) {
            for item in [
                ArcisInput::InputId(number),
                ArcisInput::Mxe(MxeInput::Scalar(MxeFieldInput::RescueKey(number))),
                ArcisInput::Mxe(MxeInput::Bit(MxeBitInput::AES256Key(number))),
            ] {
                assert_eq!(would_version_be_v1_one_item(item), result);
            }
        }

        assert!(would_version_be_v1(ArcisMetadata::default()));
        assert!(would_version_be_v1_one_item(ArcisInput::Mxe(
            MxeInput::ScalarOnly(MxeScalarInput::ElGamalSecretKey())
        )));
        check_number(0, true);
        check_number(255, true);
        check_number(usize::MAX >> 1, true);

        // These boundaries assume a 64-bit `usize`: only then does the top byte
        // of the serialized value reach 255. `255usize << 56` would not compile
        // on narrower targets, so the checks are compiled only for 64-bit pointers.
        #[cfg(target_pointer_width = "64")]
        {
            check_number(usize::MAX, false);
            let limit = 255usize << 56;
            check_number(limit, false);
            check_number(limit - 1, true);
        }
    }
}