Skip to main content

arcis_compiler/core/expressions/
circuit.rs

1use crate::{
2    core::{
3        actually_used_field::ActuallyUsedField,
4        bounds::FieldBounds,
5        circuits::{
6            arithmetic::{
7                abs::AbsCircuit,
8                bitwise_and::BitwiseAnd,
9                fast_euclidean_by_constant::FastEuclideanByConstant,
10                float_div::Div,
11                float_exp::{Exp, Exp2},
12                float_log::{Ln, Log2},
13                float_sqrt::{DivSqrt, Sqrt},
14                lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
15                max::Max,
16                min::Min,
17                sigmoid::Sigmoid,
18                signed_divide::SignedDivide,
19                zero::ZeroCircuit,
20            },
21            boolean::{
22                ed25519::{
23                    Ed25519MXESign,
24                    Ed25519Sign,
25                    Ed25519Verify,
26                    Ed25519VerifyingKeyFromSecretKey,
27                },
28                sha3::{SHA3_256, SHA3_512},
29            },
30            general::{conversion::ConversionCircuit, identity::IdentityCircuit},
31            traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
32        },
33        expressions::{
34            expr::{EvalValue, Expr},
35            field_expr::FieldExpr,
36            other_expr::OtherExpr,
37        },
38        global_value::{
39            field_array::FieldArray,
40            global_expr_store::with_global_expr_store_as_local,
41            value::FieldValue,
42        },
43    },
44    types::DOUBLE_PRECISION_MANTISSA,
45    utils::field::BaseField,
46};
47use serde::{Deserialize, Serialize};
48use std::num::NonZeroU8;
49
50// inline_const is unavailable,
51// so instead I follow the desugaring example on
52// https://github.com/rust-lang/rust/pull/104087 :
53struct StaticCircuits;
54impl StaticCircuits {
55    const MIN: Min = Min::new(true);
56    const MAX: Max = Max::new(true);
57    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
58    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
59    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
60    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
61    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
62    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
63    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
64    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
65    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
66    const FAST_EUCLIDEAN_U8: [FastEuclideanByConstant; 255] = {
67        let mut arr = [FastEuclideanByConstant {
68            b: NonZeroU8::new(1).unwrap(),
69        }; 255];
70        let mut idx = 1;
71        while idx < arr.len() {
72            arr[idx].b = NonZeroU8::new(idx as u8 + 1).unwrap();
73            idx += 1;
74        }
75        arr
76    };
77}
78
/// Identifier of an arithmetic circuit that can be instantiated over any
/// `ActuallyUsedField` (see `ArithmeticCircuitId::to_circuit`).
///
/// NOTE: this enum derives `Serialize`/`Deserialize`. Serde formats that encode
/// enum variants by index depend on the declaration order below, so append new
/// variants at the end instead of reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    /// `Min` circuit, constructed with `Min::new(true)`.
    Min,
    /// `Max` circuit, constructed with `Max::new(true)`.
    Max,
    /// `Log2` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Log2,
    /// `Ln` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Ln,
    /// `Exp2` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Exp2,
    /// `Exp` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Exp,
    /// `Sqrt` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Sqrt,
    /// `DivSqrt` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    DivSqrt,
    /// Floating-point `Div` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Div,
    /// `BitwiseAnd` circuit, constructed with `BitwiseAnd::new(true)`.
    BitwiseAnd,
    /// `LowestBiggerPowerOfTwoMinusOne` circuit.
    LowestBiggerPowerOfTwoMinusOne,
    /// `Sigmoid` circuit at `DOUBLE_PRECISION_MANTISSA` precision.
    Sigmoid,
    /// `ZeroCircuit`.
    Zero,
    /// `AbsCircuit`.
    Abs,
    /// `IdentityCircuit`.
    Identity,
    /// `SignedDivide` circuit.
    SignedIntegerDiv,
    /// The precomputed `FastEuclideanByConstant` instance whose `b` field equals
    /// the payload (presumably Euclidean division by that constant — confirm).
    FastEuclideanByConstant(NonZeroU8),
}
99
100impl ArithmeticCircuitId {
101    pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
102        match self {
103            ArithmeticCircuitId::Min => &StaticCircuits::MIN,
104            ArithmeticCircuitId::Max => &StaticCircuits::MAX,
105            ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
106            ArithmeticCircuitId::Ln => &StaticCircuits::LN,
107            ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
108            ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
109            ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
110            ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
111            ArithmeticCircuitId::Div => &StaticCircuits::DIV,
112            ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
113            ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
114            ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
115            ArithmeticCircuitId::Zero => &ZeroCircuit,
116            ArithmeticCircuitId::Abs => &AbsCircuit,
117            ArithmeticCircuitId::Identity => &IdentityCircuit,
118            ArithmeticCircuitId::SignedIntegerDiv => &SignedDivide,
119            ArithmeticCircuitId::FastEuclideanByConstant(x) => {
120                &StaticCircuits::FAST_EUCLIDEAN_U8[(x.get() - 1) as usize]
121            }
122        }
123    }
124    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
125        T::apply_arithmetic_circuit_id(args, self)
126    }
127}
128
/// Identifier of a `GeneralCircuit` (see `GeneralCircuitId::to_circuit`).
///
/// NOTE: derives `Serialize`/`Deserialize`; append new variants at the end to
/// keep index-based serde encodings stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// `ConversionCircuit`.
    Conversion,
}
133
134impl GeneralCircuitId {
135    pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
136        match self {
137            GeneralCircuitId::Conversion => &ConversionCircuit,
138        }
139    }
140}
141
/// Identifier of a circuit over `BaseField`: either one of the base-field-only
/// Ed25519 / SHA-3 circuits, or any `ArithmeticCircuitId` wrapped in `Arith`
/// (see `BaseCircuitId::to_circuit`).
///
/// NOTE: derives `Serialize`/`Deserialize`; append new variants at the end to
/// keep index-based serde encodings stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    /// `Ed25519Sign` circuit.
    Ed25519Sign,
    /// `Ed25519MXESign` circuit.
    Ed25519MXESign,
    /// `Ed25519Verify` circuit.
    Ed25519Verify,
    /// `Ed25519VerifyingKeyFromSecretKey` circuit.
    Ed25519VerifyingKeyFromSecretKey,
    /// `SHA3_256` circuit.
    Sha3_256,
    /// `SHA3_512` circuit.
    Sha3_512,
    /// An arithmetic circuit instantiated at `F = BaseField`.
    Arith(ArithmeticCircuitId),
}
152
153impl BaseCircuitId {
154    pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
155        match self {
156            BaseCircuitId::Ed25519Sign => &Ed25519Sign,
157            BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
158            BaseCircuitId::Ed25519Verify => &Ed25519Verify,
159            BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
160            BaseCircuitId::Sha3_256 => &SHA3_256,
161            BaseCircuitId::Sha3_512 => &SHA3_512,
162            BaseCircuitId::Arith(a) => a.to_circuit(),
163        }
164    }
165    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
166        T::apply_base_circuit_id(args, self)
167    }
168    pub fn new_from_str(str: &str) -> Option<Self> {
169        let res = match str {
170            // Please keep alphabetical order.
171            "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
172            "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
173            "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
174            "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
175            "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
176            "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
177            "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
178            "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
179            "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
180            "lowest_bigger_power_of_two_minus_one" => {
181                BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
182            }
183            "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
184            "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
185            "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
186            "sign" => BaseCircuitId::Ed25519Sign,
187            "mxe-sign" => BaseCircuitId::Ed25519MXESign,
188            "verify" => BaseCircuitId::Ed25519Verify,
189            "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
190            "sha3-256" => BaseCircuitId::Sha3_256,
191            "sha3-512" => BaseCircuitId::Sha3_512,
192            "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
193            _ => return None,
194        };
195        Some(res)
196    }
197}
198
199pub trait CircuitArg: Sized {
200    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
201    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
202}
203
204impl CircuitArg for BaseField {
205    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
206        c.to_circuit().eval(v).unwrap()
207    }
208
209    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
210        c.to_circuit().eval(v).unwrap()
211    }
212}
213
214impl CircuitArg for EvalValue {
215    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
216        CircuitArg::apply_arithmetic_circuit_id(
217            v.into_iter()
218                .map(|x| BaseField::from(x.to_signed_number()))
219                .collect(),
220            c,
221        )
222        .into_iter()
223        .map(EvalValue::Base)
224        .collect()
225    }
226
227    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
228        CircuitArg::apply_base_circuit_id(
229            v.into_iter()
230                .map(|x| BaseField::from(x.to_signed_number()))
231                .collect(),
232            c,
233        )
234        .into_iter()
235        .map(EvalValue::Base)
236        .collect()
237    }
238}
239
240impl CircuitArg for FieldValue<BaseField> {
241    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
242        let all_bounds =
243            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
244        let n = c.to_circuit().bounds(all_bounds).len();
245        (0..n)
246            .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
247            .collect()
248    }
249
250    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
251        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
252        let n = c.to_circuit().bounds(all_bounds).len();
253        (0..n)
254            .map(|i| {
255                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
256                    expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
257                        v.iter().map(FieldValue::get_id).collect(),
258                        c,
259                        i,
260                    )))
261                }))
262            })
263            .collect()
264    }
265}
266
267impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
268    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
269        let all_bounds =
270            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
271        let n = c.to_circuit().bounds(all_bounds).len();
272        (0..n)
273            .map(|i| {
274                FieldArray::from(
275                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
276                        (0..N)
277                            .map(|j| {
278                                FieldValue::new(FieldExpr::SubCircuit(
279                                    v.iter()
280                                        .copied()
281                                        .map(|x| x[j])
282                                        .collect::<Vec<FieldValue<BaseField>>>(),
283                                    c,
284                                    i,
285                                ))
286                            })
287                            .collect::<Vec<FieldValue<BaseField>>>(),
288                    )
289                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
290                        panic!("Expected a Vec of length {} (found {})", N, v.len())
291                    }),
292                )
293            })
294            .collect::<Vec<Self>>()
295    }
296
297    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
298        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
299        let n = c.to_circuit().bounds(all_bounds).len();
300        (0..n)
301            .map(|i| {
302                FieldArray::from(
303                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
304                        (0..N)
305                            .map(|j| {
306                                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
307                                    expr_store.new_expr(Expr::Other(
308                                        OtherExpr::BaseArithmeticCircuit(
309                                            v.iter()
310                                                .copied()
311                                                .map(|x| x[j].get_id())
312                                                .collect::<Vec<usize>>(),
313                                            c,
314                                            i,
315                                        ),
316                                    ))
317                                }))
318                            })
319                            .collect::<Vec<FieldValue<BaseField>>>(),
320                    )
321                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
322                        panic!("Expected a Vec of length {} (found {})", N, v.len())
323                    }),
324                )
325            })
326            .collect::<Vec<Self>>()
327    }
328}