Skip to main content

arcis_compiler/core/
instruction.rs

1use crate::{
2    core::{
3        expressions::{domain::DomainElement, expr::EvalValue},
4        mxe_input::{ArxInput, MxeInput},
5        profile_circuit::{explain_circuit_depth, get_circuit_depth},
6    },
7    network_content::NetworkContent,
8    profile_info::ProfileInfo,
9    utils::{curve_point::CurvePoint, number::Number},
10    AsyncMPCCircuit,
11};
12use core_utils::circuit::mock_eval::MockRng;
13use ff::Field;
14use num_bigint::{BigUint, ToBigInt};
15use num_traits::Zero;
16use primitives::algebra::{
17    elliptic_curve::{BaseFieldElement, Curve, Point, Scalar},
18    field::subfield_element::Mersenne107Element,
19};
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22
23struct MockEvalRng<'a, R: Rng + ?Sized> {
24    rng: &'a mut R,
25    bools: &'a [Option<bool>],
26}
27
28impl<'a, R: Rng + ?Sized> MockEvalRng<'a, R> {
29    pub fn new(rng: &'a mut R, bools: &'a [Option<bool>]) -> Self {
30        Self { rng, bools }
31    }
32}
33impl<'a, R: Rng + ?Sized> MockRng for MockEvalRng<'a, R> {
34    fn gen_bit(&mut self) -> bool {
35        if let Some(b) = self.bools.first() {
36            self.bools = &self.bools[1..];
37            if let Some(b) = b {
38                return *b;
39            }
40        }
41        self.rng.gen()
42    }
43
44    fn gen_da_bit(&mut self) -> bool {
45        self.rng.gen()
46    }
47
48    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
49        Scalar::<C>::random(&mut self.rng)
50    }
51
52    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
53        BaseFieldElement::<C>::random(&mut self.rng)
54    }
55
56    fn gen_mersenne(&mut self) -> Mersenne107Element {
57        Mersenne107Element::random(&mut self.rng)
58    }
59
60    fn gen_point<C: Curve>(&mut self) -> Point<C> {
61        self.rng.gen()
62    }
63}
64
65/// The main output of the compiler.
66#[derive(Debug, Serialize, Deserialize)]
67pub struct ArcisInstruction {
68    /// The circuit, as understood by our async-mpc library.
69    pub circuit: AsyncMPCCircuit,
70    /// Metadata.
71    /// Currently, only states which input to put in which position.
72    pub metadata: ArcisMetadata,
73}
74
75impl ArcisInstruction {
76    /// Local evaluation of a circuit on some inputs.
77    /// Used for testing.
78    pub fn mock_eval_vec<R: Rng + ?Sized>(
79        &self,
80        inputs: Vec<EvalValue>,
81        output_domains: &[DomainElement<(), (), (), ()>],
82        bools: &[Option<bool>],
83        rng: &mut R,
84    ) -> Vec<EvalValue> {
85        let numbers: Vec<EvalValue> = self
86            .metadata
87            .input_order
88            .iter()
89            .map(|x| match *x {
90                ArcisInput::InputId(id) => inputs[id],
91                ArcisInput::Mxe(input) => input.mock_eval(),
92            })
93            .collect();
94
95        let v = numbers
96            .into_iter()
97            .map(|x| match x {
98                EvalValue::Bit(b) => BigUint::from(b),
99                EvalValue::Scalar(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
100                EvalValue::Base(b) => BigUint::from_bytes_le(&b.to_le_bytes()),
101                EvalValue::Curve(b) => BigUint::from_bytes_le(&b.to_bytes()),
102            })
103            .collect();
104        let mut mock_rng = MockEvalRng::new(rng, bools);
105        let res = self.circuit.mock_eval_big_uint(v, &mut mock_rng);
106        std::iter::zip(output_domains, res)
107            .map(|(domain, x)| match domain {
108                DomainElement::Bit(_) => EvalValue::Bit(!x.is_zero()),
109                DomainElement::Scalar(_) => {
110                    EvalValue::Scalar(Number::from(x.to_bigint().unwrap()).into())
111                }
112                DomainElement::Base(_) => {
113                    EvalValue::Base(Number::from(x.to_bigint().unwrap()).into())
114                }
115                DomainElement::Curve(_) => {
116                    let mut bytes = x.to_bytes_le();
117                    bytes.resize(32, 0);
118                    EvalValue::Curve(
119                        CurvePoint::from_le_bytes(&bytes)
120                            .expect("Failed to convert to CurvePoint."),
121                    )
122                }
123            })
124            .collect()
125    }
126    /// The network depth in number of rounds of communication necessary to compute the circuit.
127    /// Used for profiling.
128    pub fn network_depth(&self) -> usize {
129        get_circuit_depth(&self.circuit)
130    }
131    /// Gives profiling information on the circuit.
132    pub fn profile_info(&self) -> ProfileInfo {
133        let network_depth = self.network_depth();
134        let total_gates = self.circuit.ops_count() as usize;
135        let preprocessing = self.circuit.required_preprocessing();
136        let pre_process = (&preprocessing).into();
137        let network_content = NetworkContent::from_circuit(&self.circuit);
138        ProfileInfo {
139            network_depth,
140            total_gates,
141            network_content,
142            pre_process,
143        }
144    }
145    /// The weight of the circuit. Used to compute the cost to run it.
146    pub fn weight(&self) -> usize {
147        self.profile_info().weight()
148    }
149    /// A function used for debugging when the depth is different from expected.
150    #[allow(dead_code)]
151    pub fn explain_depth(&self) {
152        explain_circuit_depth(&self.circuit);
153    }
154}
155
156/// The circuit metadata.
157/// Gives the order of inputs.
158#[derive(Debug, Serialize, Deserialize, Default)]
159pub struct ArcisMetadata {
160    pub input_order: Vec<ArcisInput>,
161}
162
163/// A form of input id.
164#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
165pub enum ArcisInput {
166    /// An input given by user. `InputId(0)` will be the 1st, `InputId(1)` will be the 2nd, etc...
167    InputId(usize),
168    /// An input given by the MXE.
169    Mxe(MxeInput),
170}