Skip to main content

arcis_compiler/core/expressions/
circuit.rs

1use crate::{
2    core::{
3        actually_used_field::ActuallyUsedField,
4        bounds::FieldBounds,
5        circuits::{
6            arithmetic::{
7                abs::AbsCircuit,
8                bitwise_and::BitwiseAnd,
9                float_div::Div,
10                float_exp::{Exp, Exp2},
11                float_log::{Ln, Log2},
12                float_sqrt::{DivSqrt, Sqrt},
13                lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
14                max::Max,
15                min::Min,
16                sigmoid::Sigmoid,
17                zero::ZeroCircuit,
18            },
19            boolean::{
20                ed25519::{
21                    Ed25519MXESign,
22                    Ed25519Sign,
23                    Ed25519Verify,
24                    Ed25519VerifyingKeyFromSecretKey,
25                },
26                sha3::{SHA3_256, SHA3_512},
27            },
28            general::{conversion::ConversionCircuit, identity::IdentityCircuit},
29            traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
30        },
31        expressions::{
32            expr::{EvalValue, Expr},
33            field_expr::FieldExpr,
34            other_expr::OtherExpr,
35        },
36        global_value::{
37            field_array::FieldArray,
38            global_expr_store::with_global_expr_store_as_local,
39            value::FieldValue,
40        },
41    },
42    types::DOUBLE_PRECISION_MANTISSA,
43    utils::field::BaseField,
44};
45use serde::{Deserialize, Serialize};
46
47// inline_const is unavailable,
48// so instead I follow the desugaring example on
49// https://github.com/rust-lang/rust/pull/104087 :
50struct StaticCircuits;
51impl StaticCircuits {
52    const MIN: Min = Min::new(true);
53    const MAX: Max = Max::new(true);
54    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
55    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
56    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
57    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
58    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
59    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
60    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
61    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
62    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
66pub enum ArithmeticCircuitId {
67    Min,
68    Max,
69    Log2,
70    Ln,
71    Exp2,
72    Exp,
73    Sqrt,
74    DivSqrt,
75    Div,
76    BitwiseAnd,
77    LowestBiggerPowerOfTwoMinusOne,
78    Sigmoid,
79    Zero,
80    Abs,
81    Identity,
82}
83
84impl ArithmeticCircuitId {
85    pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
86        match self {
87            ArithmeticCircuitId::Min => &StaticCircuits::MIN,
88            ArithmeticCircuitId::Max => &StaticCircuits::MAX,
89            ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
90            ArithmeticCircuitId::Ln => &StaticCircuits::LN,
91            ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
92            ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
93            ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
94            ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
95            ArithmeticCircuitId::Div => &StaticCircuits::DIV,
96            ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
97            ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
98            ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
99            ArithmeticCircuitId::Zero => &ZeroCircuit,
100            ArithmeticCircuitId::Abs => &AbsCircuit,
101            ArithmeticCircuitId::Identity => &IdentityCircuit,
102        }
103    }
104    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
105        T::apply_arithmetic_circuit_id(args, self)
106    }
107}
108
109#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
110pub enum GeneralCircuitId {
111    Conversion,
112}
113
114impl GeneralCircuitId {
115    pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
116        match self {
117            GeneralCircuitId::Conversion => &ConversionCircuit,
118        }
119    }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
123pub enum BaseCircuitId {
124    Ed25519Sign,
125    Ed25519MXESign,
126    Ed25519Verify,
127    Ed25519VerifyingKeyFromSecretKey,
128    Sha3_256,
129    Sha3_512,
130    Arith(ArithmeticCircuitId),
131}
132
133impl BaseCircuitId {
134    pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
135        match self {
136            BaseCircuitId::Ed25519Sign => &Ed25519Sign,
137            BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
138            BaseCircuitId::Ed25519Verify => &Ed25519Verify,
139            BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
140            BaseCircuitId::Sha3_256 => &SHA3_256,
141            BaseCircuitId::Sha3_512 => &SHA3_512,
142            BaseCircuitId::Arith(a) => a.to_circuit(),
143        }
144    }
145    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
146        T::apply_base_circuit_id(args, self)
147    }
148    pub fn new_from_str(str: &str) -> Option<Self> {
149        let res = match str {
150            // Please keep alphabetical order.
151            "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
152            "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
153            "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
154            "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
155            "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
156            "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
157            "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
158            "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
159            "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
160            "lowest_bigger_power_of_two_minus_one" => {
161                BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
162            }
163            "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
164            "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
165            "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
166            "sign" => BaseCircuitId::Ed25519Sign,
167            "mxe-sign" => BaseCircuitId::Ed25519MXESign,
168            "verify" => BaseCircuitId::Ed25519Verify,
169            "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
170            "sha3-256" => BaseCircuitId::Sha3_256,
171            "sha3-512" => BaseCircuitId::Sha3_512,
172            "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
173            _ => return None,
174        };
175        Some(res)
176    }
177}
178
179pub trait CircuitArg: Sized {
180    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
181    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
182}
183
184impl CircuitArg for BaseField {
185    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
186        c.to_circuit().eval(v).unwrap()
187    }
188
189    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
190        c.to_circuit().eval(v).unwrap()
191    }
192}
193
194impl CircuitArg for EvalValue {
195    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
196        CircuitArg::apply_arithmetic_circuit_id(
197            v.into_iter()
198                .map(|x| BaseField::from(x.to_signed_number()))
199                .collect(),
200            c,
201        )
202        .into_iter()
203        .map(EvalValue::Base)
204        .collect()
205    }
206
207    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
208        CircuitArg::apply_base_circuit_id(
209            v.into_iter()
210                .map(|x| BaseField::from(x.to_signed_number()))
211                .collect(),
212            c,
213        )
214        .into_iter()
215        .map(EvalValue::Base)
216        .collect()
217    }
218}
219
220impl CircuitArg for FieldValue<BaseField> {
221    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
222        let all_bounds =
223            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
224        let n = c.to_circuit().bounds(all_bounds).len();
225        (0..n)
226            .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
227            .collect()
228    }
229
230    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
231        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
232        let n = c.to_circuit().bounds(all_bounds).len();
233        (0..n)
234            .map(|i| {
235                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
236                    expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
237                        v.iter().map(FieldValue::get_id).collect(),
238                        c,
239                        i,
240                    )))
241                }))
242            })
243            .collect()
244    }
245}
246
247impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
248    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
249        let all_bounds =
250            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
251        let n = c.to_circuit().bounds(all_bounds).len();
252        (0..n)
253            .map(|i| {
254                FieldArray::from(
255                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
256                        (0..N)
257                            .map(|j| {
258                                FieldValue::new(FieldExpr::SubCircuit(
259                                    v.iter()
260                                        .copied()
261                                        .map(|x| x[j])
262                                        .collect::<Vec<FieldValue<BaseField>>>(),
263                                    c,
264                                    i,
265                                ))
266                            })
267                            .collect::<Vec<FieldValue<BaseField>>>(),
268                    )
269                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
270                        panic!("Expected a Vec of length {} (found {})", N, v.len())
271                    }),
272                )
273            })
274            .collect::<Vec<Self>>()
275    }
276
277    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
278        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
279        let n = c.to_circuit().bounds(all_bounds).len();
280        (0..n)
281            .map(|i| {
282                FieldArray::from(
283                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
284                        (0..N)
285                            .map(|j| {
286                                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
287                                    expr_store.new_expr(Expr::Other(
288                                        OtherExpr::BaseArithmeticCircuit(
289                                            v.iter()
290                                                .copied()
291                                                .map(|x| x[j].get_id())
292                                                .collect::<Vec<usize>>(),
293                                            c,
294                                            i,
295                                        ),
296                                    ))
297                                }))
298                            })
299                            .collect::<Vec<FieldValue<BaseField>>>(),
300                    )
301                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
302                        panic!("Expected a Vec of length {} (found {})", N, v.len())
303                    }),
304                )
305            })
306            .collect::<Vec<Self>>()
307    }
308}