Skip to main content

arcis_compiler/core/
instruction.rs

1use crate::{
2    core::{
3        expressions::{domain::DomainElement, expr::EvalValue},
4        mxe_input::{ArxInput, MxeInput},
5        profile_circuit::{explain_circuit_depth, get_circuit_depth},
6    },
7    network_content::NetworkContent,
8    profile_info::ProfileInfo,
9    utils::{curve_point::CurvePoint, number::Number},
10    AsyncMPCCircuit,
11};
12use core_utils::circuit::{mock_eval::MockRng, AlgebraicType};
13use ff::Field;
14use num_bigint::{BigUint, ToBigInt};
15use num_traits::Zero;
16use primitives::algebra::{
17    elliptic_curve::{BaseFieldElement, Curve, Curve25519Ristretto, Point, Scalar},
18    field::subfield_element::Mersenne107Element,
19};
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use sha3::Digest;
23
24struct MockEvalRng<'a, R: Rng + ?Sized> {
25    rng: &'a mut R,
26    bools: &'a [Option<bool>],
27}
28
29impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
30    pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
31        Self { rng, bools }
32    }
33}
34impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
35    fn gen_bit(&mut self) -> bool {
36        if let Some(b) = self.bools.first() {
37            self.bools = &self.bools[1..];
38            if let Some(b) = b {
39                return *b;
40            }
41        }
42        self.rng.gen()
43    }
44
45    fn gen_da_bit(&mut self) -> bool {
46        self.rng.gen()
47    }
48
49    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
50        Scalar::<C>::random(&mut self.rng)
51    }
52
53    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
54        BaseFieldElement::<C>::random(&mut self.rng)
55    }
56
57    fn gen_mersenne(&mut self) -> Mersenne107Element {
58        Mersenne107Element::random(&mut self.rng)
59    }
60
61    fn gen_point<C: Curve>(&mut self) -> Point<C> {
62        self.rng.gen()
63    }
64}
65
/// The main output of the compiler.
///
/// Bundles the compiled circuit with the metadata needed to feed it.
/// `to_bytes` / `from_bytes` convert it to and from its versioned binary form.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArcisInstruction {
    /// The circuit, as understood by our async-mpc library.
    pub circuit: AsyncMPCCircuit,
    /// Metadata.
    /// Currently, only states which input to put in which position.
    pub metadata: ArcisMetadata,
}
75
/// A `DomainElement` whose variants all carry `()`.
/// It only names a domain (`Bit`, `Scalar`, `Base` or `Curve`) and holds no value.
pub type DomainKind = DomainElement<(), (), (), ()>;
77
78impl ArcisInstruction {
79    /// Local evaluation of a circuit on some inputs.
80    /// Used for testing.
81    pub fn mock_eval_vec<R: Rng + ?Sized>(
82        &self,
83        inputs: Vec<EvalValue>,
84        output_domains: &[DomainKind],
85        bools: &[Option<bool>],
86        rng: &mut R,
87    ) -> Vec<EvalValue> {
88        let numbers: Vec<EvalValue> = self
89            .metadata
90            .input_order
91            .iter()
92            .map(|x| match *x {
93                ArcisInput::InputId(id) => inputs[id],
94                ArcisInput::Mxe(input) => input.mock_eval(),
95                ArcisInput::TypedInput(id, _) => inputs[id],
96            })
97            .collect();
98
99        let v = numbers
100            .into_iter()
101            .map(|x| match x {
102                EvalValue::Bit(b) => BigUint::from(b),
103                EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
104                EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
105                EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
106            })
107            .collect();
108        let mut mock_rng = MockEvalRng::new(rng, bools);
109        let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
110        std::iter::zip(output_domains, res)
111            .map(|(domain, x)| match domain {
112                DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
113                DomainElement::Scalar(_) => {
114                    EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
115                }
116                DomainElement::Base(_) => {
117                    EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
118                }
119                DomainElement::Curve(_) => {
120                    let mut bytes = x.to_bytes_le();
121                    bytes.resize(32, 0);
122                    EvalValue::Curve(
123                        CurvePoint::from_le_bytes(&bytes)
124                            .expect("Failed to convert to CurvePoint."),
125                    )
126                }
127            })
128            .collect()
129    }
130    /// The network depth in number of rounds of communication necessary to compute the circuit.
131    /// Used for profiling.
132    pub fn network_depth(&self) -> usize {
133        get_circuit_depth(&self.circuit)
134    }
135    /// Gives profiling information on the circuit.
136    pub fn profile_info(&self) -> ProfileInfo {
137        let network_depth = self.network_depth();
138        let total_gates = self.circuit.nb_gates() as usize;
139        let preprocessing = self.circuit.required_preprocessing();
140        let pre_process = (&preprocessing).into();
141        let network_content = NetworkContent::from_circuit(&self.circuit);
142        let mut hash = sha3::Keccak256::new();
143        hash.update(self.to_bytes());
144        let hash = hash.finalize();
145        let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
146
147        ProfileInfo {
148            network_depth,
149            total_gates,
150            network_content,
151            pre_process,
152            circuit_hash,
153        }
154    }
155    /// The weight of the circuit. Used to compute the cost to run it.
156    pub fn weight(&self) -> usize {
157        self.profile_info().weight()
158    }
159    /// A function used for debugging when the depth is different from expected.
160    #[allow(dead_code)]
161    pub fn explain_depth(&self) {
162        explain_circuit_depth(&self.circuit);
163    }
164    pub fn to_bytes(&self) -> Vec<u8> {
165        let mut data = bincode::serialize(self).unwrap();
166        CURRENT_VERSION.add_to_bytes(&mut data);
167        data
168    }
169    pub fn from_bytes(mut bytes: Vec<u8>) -> Result<Self, String> {
170        let version = Version::from_bytes(&bytes)?;
171        let len_without_version = bytes.len() - version.len();
172        bytes.truncate(len_without_version);
173        match version {
174            Version::V1 => {
175                let v1_instruction: ArcisInstructionV1 =
176                    bincode::deserialize(&bytes).map_err(|e| e.to_string())?;
177                Ok(ArcisInstruction {
178                    circuit: v1_instruction
179                        .circuit
180                        .into_v2()
181                        .map_err(|e| e.to_string())?,
182                    metadata: v1_instruction.metadata,
183                })
184            }
185            Version::V2 => bincode::deserialize(&bytes).map_err(|e| e.to_string()),
186        }
187    }
188}
189
/// The circuit metadata.
/// Gives the order of inputs.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ArcisMetadata {
    /// Source of each circuit input, listed in the order the circuit receives them.
    pub input_order: Vec<ArcisInput>,
}
196
/// A form of input id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArcisInput {
    /// An input given by user. `InputId(0)` will be the 1st, `InputId(1)` will be the 2nd, etc...
    InputId(usize),
    /// An input given by the MXE.
    Mxe(MxeInput),
    /// An input given by user. `TypedInput(0, _)` will be the 1st, `TypedInput(1, _)` will be
    /// the 2nd, etc...
    /// The algebraic type specifies the type the circuit expects.
    TypedInput(usize, AlgebraicType),
}
208
/// Version of the binary format of a serialized instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Version {
    /// Original format: no version marker at all.
    V1,
    /// Format whose bytes end with the marker `[2, 255]`.
    V2,
}

impl Version {
    /// Marker bytes found at the very end of data written in this version.
    fn suffix(&self) -> &'static [u8] {
        match self {
            Version::V1 => &[],
            Version::V2 => &[2, 255],
        }
    }

    /// Detects the version from the end of `bytes`.
    ///
    /// Data that does not end in `255` is considered V1. Data ending in `255`
    /// must end in a known `[version, 255]` marker.
    ///
    /// # Errors
    /// Fails on empty input, on a lone `255` byte, and on an unknown marker.
    pub fn from_bytes(bytes: &[u8]) -> Result<Version, String> {
        match bytes {
            [] => Err("Empty bytes.".into()),
            [.., 2, 255] => Ok(Version::V2),
            [.., v, 255] => Err(format!("Invalid version suffix: [{v}, 255].")),
            // Only reachable with exactly one byte: there is no room for a version number.
            [255] => Err("Invalid bytes length.".into()),
            _ => Ok(Version::V1),
        }
    }

    /// Appends the marker of this version to `bytes` (nothing for V1).
    pub fn add_to_bytes(&self, bytes: &mut Vec<u8>) {
        bytes.extend_from_slice(self.suffix());
    }

    /// Number of bytes `add_to_bytes` appends for this version.
    pub fn len(&self) -> usize {
        self.suffix().len()
    }
}
249
/// The version whose suffix `ArcisInstruction::to_bytes` appends to serialized data.
const CURRENT_VERSION: Version = Version::V2;
251
/// The circuit type stored in V1 instructions.
type AsyncMPCCircuitV1 = core_utils::circuit::CircuitV1<Curve25519Ristretto>;
/// Layout of a V1 instruction.
/// Only deserialized: `ArcisInstruction::from_bytes` converts its circuit with `into_v2`.
#[derive(Deserialize)]
struct ArcisInstructionV1 {
    /// The circuit, as understood by our async-mpc library.
    pub circuit: AsyncMPCCircuitV1,
    /// Metadata.
    /// Currently, only states which input to put in which position.
    pub metadata: ArcisMetadata,
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MxeBitInput, MxeFieldInput, MxeScalarInput};

    /// Archived V1 files must still be readable.
    #[test]
    fn v1_compatibility() {
        const V1_FILES: &[&[u8]] = &[include_bytes!("../../old-arcis-files/v1/keygen.arcis")];
        V1_FILES.iter().for_each(|file| {
            ArcisInstruction::from_bytes(file.to_vec()).unwrap();
        });
    }

    /// V1 circuits end in their metadata.
    /// We check our understanding of how the version would be computed:
    /// in all reasonable cases the computed version must be V1.
    #[test]
    fn version_v1() {
        /// Whether data ending with this serialized metadata is detected as V1.
        fn detected_as_v1(metadata: &ArcisMetadata) -> bool {
            let serialized = bincode::serialize(metadata).unwrap();
            matches!(Version::from_bytes(&serialized), Ok(Version::V1))
        }
        /// Same check, for metadata holding a single input.
        fn single_input_detected_as_v1(item: ArcisInput) -> bool {
            detected_as_v1(&ArcisMetadata {
                input_order: vec![item],
            })
        }
        /// Checks every kind of input ending in a `usize` against `expected`.
        fn check_number(number: usize, expected: bool) {
            let items = [
                ArcisInput::InputId(number),
                ArcisInput::Mxe(MxeInput::Scalar(MxeFieldInput::RescueKey(number))),
                ArcisInput::Mxe(MxeInput::Bit(MxeBitInput::AES256Key(number))),
            ];
            for item in items {
                assert_eq!(single_input_detected_as_v1(item), expected);
            }
        }

        // Empty ArcisMetadata is serialized as `[0u8; 8]`.
        assert!(detected_as_v1(&ArcisMetadata::default()));
        // An input which is only a variant.
        assert!(single_input_detected_as_v1(ArcisInput::Mxe(
            MxeInput::ScalarOnly(MxeScalarInput::ElGamalSecretKey())
        )));
        // Inputs ending in a `usize`.
        check_number(0, true);
        check_number(255, true);
        check_number(usize::MAX >> 1, true);
        check_number(usize::MAX, false);
        // The limit at which it stops working must really be the one we think it is.
        let limit = 255usize << 56;
        check_number(limit, false);
        check_number(limit - 1, true);
        // As long as users have less than `255usize << 56` inputs,
        // and we did not use negative numbers in input ids,
        // this versioning system will correctly identify V1 instructions as V1.
    }
}