Skip to main content

arcis_compiler/core/instruction.rs

1use crate::{
2    auxiliary_circuit_info::AuxiliaryCircuitInfo,
3    core::{
4        expressions::{domain::DomainElement, expr::EvalValue, other_expr::SharedRescueKeyData},
5        mxe_input::{ArxInput, MxeInput},
6        profile_circuit::{explain_circuit_depth, get_circuit_depth},
7    },
8    network_content::NetworkContent,
9    profile_info::ProfileInfo,
10    utils::{curve_point::CurvePoint, number::Number},
11    AsyncMPCCircuit,
12};
13use core_utils::circuit::{mock_eval::MockRng, AlgebraicType};
14use ff::Field;
15use num_bigint::{BigUint, ToBigInt};
16use num_traits::Zero;
17use primitives::algebra::{
18    elliptic_curve::{BaseFieldElement, Curve, Curve25519Ristretto, Point, Scalar},
19    field::subfield_element::Mersenne107Element,
20};
21use rand::Rng;
22use serde::{Deserialize, Serialize};
23use sha3::Digest;
24
/// A [`MockRng`] used by circuit mock-evaluation whose bit draws can be scripted.
///
/// Bits requested through `gen_bit` are taken from `bools` first (`None` meaning
/// "let `rng` decide"); every other random value comes from `rng`.
///
/// The `Rng` bound lives on the impls that need it; the struct itself only
/// requires `?Sized` so it can hold a reference to an unsized RNG.
struct MockEvalRng<'a, R: ?Sized> {
    /// Source of every random value not overridden by `bools`.
    rng: &'a mut R,
    /// Remaining scripted bits; shrinks by one entry per `gen_bit` call until empty.
    bools: &'a [Option<bool>],
}
29
30impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
31    pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
32        Self { rng, bools }
33    }
34}
35impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
36    fn gen_bit(&mut self) -> bool {
37        if let Some(b) = self.bools.first() {
38            self.bools = &self.bools[1..];
39            if let Some(b) = b {
40                return *b;
41            }
42        }
43        self.rng.gen()
44    }
45
46    fn gen_da_bit(&mut self) -> bool {
47        self.rng.gen()
48    }
49
50    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
51        Scalar::<C>::random(&mut self.rng)
52    }
53
54    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
55        BaseFieldElement::<C>::random(&mut self.rng)
56    }
57
58    fn gen_mersenne(&mut self) -> Mersenne107Element {
59        Mersenne107Element::random(&mut self.rng)
60    }
61
62    fn gen_point<C: Curve>(&mut self) -> Point<C> {
63        self.rng.gen()
64    }
65}
66
/// The main output of the compiler: a circuit plus the metadata needed to feed it.
///
/// Serialized with bincode by `ArcisInstruction::to_bytes`; the field order below is
/// therefore part of the on-disk format and must not change.
#[derive(Debug, Serialize, Deserialize)]
pub struct ArcisInstruction {
    /// The circuit, as understood by our async-mpc library.
    pub circuit: AsyncMPCCircuit,
    /// Metadata.
    /// Currently, it only states which input goes in which position.
    pub metadata: ArcisMetadata,
}
76
/// A `DomainElement` with no payload: only which variant it is (Bit, Scalar, Base or
/// Curve) matters. `ArcisInstruction::mock_eval_vec` uses it to choose how each raw
/// circuit output is decoded.
pub type DomainKind = DomainElement<(), (), (), ()>;
78
79impl ArcisInstruction {
80    /// Local evaluation of a circuit on some inputs.
81    /// Used for testing.
82    pub fn mock_eval_vec<R: Rng + ?Sized>(
83        &self,
84        inputs: Vec<EvalValue>,
85        output_domains: &[DomainKind],
86        bools: &[Option<bool>],
87        rng: &mut R,
88    ) -> Vec<EvalValue> {
89        let numbers: Vec<EvalValue> = self
90            .metadata
91            .input_order
92            .iter()
93            .map(|x| match *x {
94                ArcisInput::InputId(id) => inputs[id],
95                ArcisInput::Mxe(input) => input.mock_eval(),
96                ArcisInput::TypedInput(id, _) => inputs[id],
97                ArcisInput::SharedRescueKey(data) => data
98                    .eval(inputs[data.pubkey_input_id])
99                    .unwrap_or_else(|e| panic!("{}", e)),
100            })
101            .collect();
102
103        let v = numbers
104            .into_iter()
105            .map(|x| match x {
106                EvalValue::Bit(b) => BigUint::from(b),
107                EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
108                EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
109                EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
110            })
111            .collect();
112        let mut mock_rng = MockEvalRng::new(rng, bools);
113        let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
114        std::iter::zip(output_domains, res)
115            .map(|(domain, x)| match domain {
116                DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
117                DomainElement::Scalar(_) => {
118                    EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
119                }
120                DomainElement::Base(_) => {
121                    EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
122                }
123                DomainElement::Curve(_) => {
124                    let mut bytes = x.to_bytes_le();
125                    bytes.resize(32, 0);
126                    EvalValue::Curve(
127                        CurvePoint::from_le_bytes(&bytes)
128                            .expect("Failed to convert to CurvePoint."),
129                    )
130                }
131            })
132            .collect()
133    }
134    /// The network depth in number of rounds of communication necessary to compute the circuit.
135    /// Used for profiling.
136    pub fn network_depth(&self) -> usize {
137        get_circuit_depth(&self.circuit)
138    }
139    /// Gives profiling information on the circuit.
140    pub fn profile_info(&self) -> ProfileInfo {
141        let network_depth = self.network_depth();
142        let total_gates = self.circuit.nb_gates() as usize;
143        let preprocessing = self.circuit.required_preprocessing();
144        let pre_process = (&preprocessing).into();
145        let network_content = NetworkContent::from_circuit(&self.circuit);
146        let mut hash = sha3::Keccak256::new();
147        hash.update(self.to_bytes());
148        let hash = hash.finalize();
149        let circuit_hash = usize::from_le_bytes(hash[0..8].try_into().unwrap());
150        let auxiliary_circuits = AuxiliaryCircuitInfo::counts(&self.metadata);
151
152        ProfileInfo {
153            network_depth,
154            total_gates,
155            network_content,
156            pre_process,
157            auxiliary_circuits,
158            circuit_hash,
159        }
160    }
161    /// The weight of the circuit. Used to compute the cost to run it.
162    pub fn weight(&self) -> usize {
163        self.profile_info().weight()
164    }
165    /// A function used for debugging when the depth is different from expected.
166    #[allow(dead_code)]
167    pub fn explain_depth(&self) {
168        explain_circuit_depth(&self.circuit);
169    }
170    pub fn to_bytes(&self) -> Vec<u8> {
171        let mut data = bincode::serialize(self).unwrap();
172        CURRENT_VERSION.add_to_bytes(&mut data);
173        data
174    }
175    pub fn from_bytes(mut bytes: Vec<u8>) -> Result<Self, String> {
176        let version = Version::from_bytes(&bytes)?;
177        let len_without_version = bytes.len() - version.len();
178        bytes.truncate(len_without_version);
179        match version {
180            Version::V1 => {
181                let v1_instruction: ArcisInstructionV1 =
182                    bincode::deserialize(&bytes).map_err(|e| e.to_string())?;
183                Ok(ArcisInstruction {
184                    circuit: v1_instruction
185                        .circuit
186                        .into_v2()
187                        .map_err(|e| e.to_string())?,
188                    metadata: v1_instruction.metadata,
189                })
190            }
191            Version::V2 => bincode::deserialize(&bytes).map_err(|e| e.to_string()),
192        }
193    }
194}
/// The circuit metadata.
/// Gives the order of inputs: `input_order[i]` says where the i-th input fed to the
/// circuit comes from (see `ArcisInstruction::mock_eval_vec`).
///
/// Serialized as part of every instruction (V1 and V2); keep the field layout stable.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ArcisMetadata {
    pub input_order: Vec<ArcisInput>,
}
201
/// A form of input id: where one circuit input comes from.
///
/// This type is serialized inside `ArcisMetadata`. Reordering its variants would
/// change their bincode encoding and break reading previously written files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArcisInput {
    /// An input given by user. `InputId(0)` will be the 1st, `InputId(1)` will be the 2nd, etc...
    InputId(usize),
    /// An input given by the MXE.
    Mxe(MxeInput),
    /// An input given by user. `TypedInput(0, _)` will be the 1st, `TypedInput(1, _)` will be
    /// the 2nd, etc...
    /// The algebraic type specifies the type the circuit expects.
    TypedInput(usize, AlgebraicType),
    /// Shared Rescue Key between a client and the MXE.
    /// In mock evaluation it is derived from the user input at `pubkey_input_id`.
    SharedRescueKey(SharedRescueKeyData),
}
215
/// Format version of a serialized `ArcisInstruction`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
enum Version {
    /// Legacy format: plain bincode with no trailing version marker.
    V1,
    /// Current format: bincode followed by the two-byte suffix `[2, 255]`.
    V2,
}
221
222impl Version {
223    pub fn from_bytes(bytes: &[u8]) -> Result<Version, String> {
224        let Some(d) = bytes.last() else {
225            return Err("Empty bytes.".into());
226        };
227        if *d == 255 {
228            let [.., v, _] = bytes else {
229                return Err("Invalid bytes length.".into());
230            };
231            if *v == 2 {
232                Ok(Version::V2)
233            } else {
234                Err(format!("Invalid version suffix: [{v}, 255]."))
235            }
236        } else {
237            Ok(Version::V1)
238        }
239    }
240    pub fn add_to_bytes(&self, bytes: &mut Vec<u8>) {
241        match self {
242            Version::V1 => {}
243            Version::V2 => {
244                bytes.push(2);
245                bytes.push(255);
246            }
247        }
248    }
249    pub fn len(&self) -> usize {
250        match self {
251            Version::V1 => 0,
252            Version::V2 => 2,
253        }
254    }
255}
256
/// The version `ArcisInstruction::to_bytes` writes.
const CURRENT_VERSION: Version = Version::V2;
258
/// Circuit type stored in legacy V1 instructions; upgraded with `into_v2` on load.
type AsyncMPCCircuitV1 = core_utils::circuit::CircuitV1<Curve25519Ristretto>;
/// Deserialization-only mirror of the legacy V1 `ArcisInstruction` layout.
/// Field order matches the V1 wire format and must not change.
#[derive(Deserialize)]
struct ArcisInstructionV1 {
    /// The circuit, as understood by our async-mpc library (V1 encoding).
    pub circuit: AsyncMPCCircuitV1,
    /// Metadata.
    /// Currently, it only states which input goes in which position.
    pub metadata: ArcisMetadata,
}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MxeBitInput, MxeFieldInput, MxeScalarInput};

    /// Instructions written in the legacy V1 format must still deserialize.
    #[test]
    fn v1_compatibility() {
        const V1_FILES: &[&[u8]] = &[include_bytes!("../../old-arcis-files/v1/keygen.arcis")];
        for file in V1_FILES {
            ArcisInstruction::from_bytes(file.to_vec()).unwrap();
        }
    }

    #[test]
    fn version_v1() {
        // V1 circuits end in their metadata.
        // We check our understanding of how the version would be computed.
        // We check that in all reasonable cases the computed version would be V1.
        fn would_version_be_v1(metadata: ArcisMetadata) -> bool {
            let data = bincode::serialize(&metadata).unwrap();
            let version = Version::from_bytes(&data);
            version.is_ok_and(|x| x == Version::V1)
        }
        // Empty ArcisMetadata is serialized as `[0u8; 8]`.
        assert!(would_version_be_v1(ArcisMetadata::default()));
        fn would_version_be_v1_one_item(item: ArcisInput) -> bool {
            let metadata = ArcisMetadata {
                input_order: vec![item],
            };
            would_version_be_v1(metadata)
        }
        // Checking with an input which is only a variant.
        assert!(would_version_be_v1_one_item(ArcisInput::Mxe(
            MxeInput::ScalarOnly(MxeScalarInput::ElGamalSecretKey())
        )));
        // Checking with inputs ending in a `usize`.
        fn check_number(number: usize, result: bool) {
            for item in [
                ArcisInput::InputId(number),
                ArcisInput::Mxe(MxeInput::Scalar(MxeFieldInput::RescueKey(number))),
                ArcisInput::Mxe(MxeInput::Bit(MxeBitInput::AES256Key(number))),
            ] {
                assert_eq!(would_version_be_v1_one_item(item), result);
            }
        }
        check_number(0, true);
        check_number(255, true);
        check_number(usize::MAX >> 1, true);
        check_number(usize::MAX, false);
        // We check that the limit at which it works is really what we think it is.
        let limit = 255usize << 56;
        check_number(limit, false);
        check_number(limit - 1, true);
        // As long as users have less than `255usize << 56` inputs,
        // and we did not use negative numbers in input ids,
        // this versioning system will correctly identify V1 instructions as V1.
    }

    /// The V2 suffix round-trips through `add_to_bytes`/`from_bytes`, V1 adds no
    /// suffix, and malformed suffixes are rejected.
    #[test]
    fn version_v2_suffix() {
        let mut data = bincode::serialize(&ArcisMetadata::default()).unwrap();
        let payload_len = data.len();
        Version::V2.add_to_bytes(&mut data);
        assert_eq!(data.len(), payload_len + Version::V2.len());
        assert_eq!(&data[payload_len..], &[2u8, 255][..]);
        assert_eq!(Version::from_bytes(&data), Ok(Version::V2));

        let mut legacy = vec![1u8, 2, 3];
        Version::V1.add_to_bytes(&mut legacy);
        assert_eq!(legacy, vec![1u8, 2, 3]);
        assert_eq!(Version::V1.len(), 0);

        let empty: [u8; 0] = [];
        assert!(Version::from_bytes(&empty).is_err());
        assert!(Version::from_bytes(&[255u8]).is_err());
        assert!(Version::from_bytes(&[3u8, 255]).is_err());
    }
}