Skip to main content

arcis_compiler/core/expressions/
circuit.rs

1use crate::{
2    core::{
3        actually_used_field::ActuallyUsedField,
4        bounds::FieldBounds,
5        circuits::{
6            arithmetic::{
7                abs::AbsCircuit,
8                bitwise_and::BitwiseAnd,
9                float_div::Div,
10                float_exp::{Exp, Exp2},
11                float_log::{Ln, Log2},
12                float_sqrt::{DivSqrt, Sqrt},
13                lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
14                max::Max,
15                min::Min,
16                sigmoid::Sigmoid,
17                zero::ZeroCircuit,
18            },
19            boolean::{
20                ed25519::{
21                    Ed25519MXESign,
22                    Ed25519Sign,
23                    Ed25519Verify,
24                    Ed25519VerifyingKeyFromSecretKey,
25                },
26                sha3::{SHA3_256, SHA3_512},
27            },
28            general::conversion::ConversionCircuit,
29            traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
30        },
31        expressions::{
32            expr::{EvalValue, Expr},
33            field_expr::FieldExpr,
34            other_expr::OtherExpr,
35        },
36        global_value::{
37            field_array::FieldArray,
38            global_expr_store::with_global_expr_store_as_local,
39            value::FieldValue,
40        },
41    },
42    types::DOUBLE_PRECISION_MANTISSA,
43    utils::field::BaseField,
44};
45use serde::{Deserialize, Serialize};
46
47// inline_const is unavailable,
48// so instead I follow the desugaring example on
49// https://github.com/rust-lang/rust/pull/104087 :
50struct StaticCircuits;
51impl StaticCircuits {
52    const MIN: Min = Min::new(true);
53    const MAX: Max = Max::new(true);
54    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
55    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
56    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
57    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
58    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
59    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
60    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
61    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
62    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
63}
64
/// Identifies one of the field-generic arithmetic circuits.
///
/// [`ArithmeticCircuitId::to_circuit`] resolves an id to a shared `'static` circuit instance
/// over any `F: ActuallyUsedField`. Ids are `Copy` and serde-serializable, and are stored inside
/// `FieldExpr::SubCircuit` nodes.
// NOTE(review): derived serde impls also carry the variant index, which index-based formats
// (e.g. bincode) use as the encoding; append new variants rather than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    /// `Min::new(true)` (`StaticCircuits::MIN`); textual name `"min"`.
    Min,
    /// `Max::new(true)` (`StaticCircuits::MAX`); textual name `"max"`.
    Max,
    /// `float_log::Log2` at `DOUBLE_PRECISION_MANTISSA`; textual name `"float_log2"`.
    Log2,
    /// `float_log::Ln` at `DOUBLE_PRECISION_MANTISSA`; textual name `"float_ln"`.
    Ln,
    /// `float_exp::Exp2` at `DOUBLE_PRECISION_MANTISSA`; textual name `"float_exp2"`.
    Exp2,
    /// `float_exp::Exp` at `DOUBLE_PRECISION_MANTISSA`; textual name `"float_exp"`.
    Exp,
    /// `float_sqrt::Sqrt` at `DOUBLE_PRECISION_MANTISSA`; textual name `"float_sqrt"`.
    Sqrt,
    /// `float_sqrt::DivSqrt` at `DOUBLE_PRECISION_MANTISSA`; has no textual name in
    /// `BaseCircuitId::new_from_str`.
    DivSqrt,
    /// `float_div::Div` at `DOUBLE_PRECISION_MANTISSA`; textual name `"float_div"`.
    Div,
    /// `BitwiseAnd::new(true)`; textual name `"bitwise_and"`.
    BitwiseAnd,
    /// `LowestBiggerPowerOfTwoMinusOne`; textual name `"lowest_bigger_power_of_two_minus_one"`.
    LowestBiggerPowerOfTwoMinusOne,
    /// `sigmoid::Sigmoid` at `DOUBLE_PRECISION_MANTISSA`; textual name `"sigmoid"`.
    Sigmoid,
    /// `ZeroCircuit`; textual name `"zero"`.
    Zero,
    /// `AbsCircuit`; textual name `"abs"`.
    Abs,
}
82
83impl ArithmeticCircuitId {
84    pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
85        match self {
86            ArithmeticCircuitId::Min => &StaticCircuits::MIN,
87            ArithmeticCircuitId::Max => &StaticCircuits::MAX,
88            ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
89            ArithmeticCircuitId::Ln => &StaticCircuits::LN,
90            ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
91            ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
92            ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
93            ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
94            ArithmeticCircuitId::Div => &StaticCircuits::DIV,
95            ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
96            ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
97            ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
98            ArithmeticCircuitId::Zero => &ZeroCircuit,
99            ArithmeticCircuitId::Abs => &AbsCircuit,
100        }
101    }
102    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
103        T::apply_arithmetic_circuit_id(args, self)
104    }
105}
106
/// Identifies a circuit that implements [`GeneralCircuit`] (resolved by
/// [`GeneralCircuitId::to_circuit`]). Currently the only such circuit is the conversion circuit.
// NOTE(review): derived serde impls also carry the variant index; append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// `general::conversion::ConversionCircuit`.
    Conversion,
}
111
112impl GeneralCircuitId {
113    pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
114        match self {
115            GeneralCircuitId::Conversion => &ConversionCircuit,
116        }
117    }
118}
119
/// Identifies a circuit evaluated over [`BaseField`].
///
/// Besides the field-generic arithmetic circuits (wrapped in `Arith`), this covers the Ed25519
/// and SHA3 circuits from `circuits::boolean`. [`BaseCircuitId::new_from_str`] parses the
/// textual names and [`BaseCircuitId::to_circuit`] resolves an id to its circuit.
// NOTE(review): derived serde impls also carry the variant index, which index-based formats
// (e.g. bincode) use as the encoding; append new variants rather than reordering existing ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    /// `ed25519::Ed25519Sign`; textual name `"sign"`.
    Ed25519Sign,
    /// `ed25519::Ed25519MXESign`; textual name `"mxe-sign"`.
    Ed25519MXESign,
    /// `ed25519::Ed25519Verify`; textual name `"verify"`.
    Ed25519Verify,
    /// `ed25519::Ed25519VerifyingKeyFromSecretKey`; textual name
    /// `"verifying_key_from_secret_key"`.
    Ed25519VerifyingKeyFromSecretKey,
    /// `sha3::SHA3_256`; textual name `"sha3-256"`.
    Sha3_256,
    /// `sha3::SHA3_512`; textual name `"sha3-512"`.
    Sha3_512,
    /// A field-generic arithmetic circuit, instantiated over `BaseField`.
    Arith(ArithmeticCircuitId),
}
130
131impl BaseCircuitId {
132    pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
133        match self {
134            BaseCircuitId::Ed25519Sign => &Ed25519Sign,
135            BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
136            BaseCircuitId::Ed25519Verify => &Ed25519Verify,
137            BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
138            BaseCircuitId::Sha3_256 => &SHA3_256,
139            BaseCircuitId::Sha3_512 => &SHA3_512,
140            BaseCircuitId::Arith(a) => a.to_circuit(),
141        }
142    }
143    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
144        T::apply_base_circuit_id(args, self)
145    }
146    pub fn new_from_str(str: &str) -> Option<Self> {
147        let res = match str {
148            // Please keep alphabetical order.
149            "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
150            "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
151            "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
152            "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
153            "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
154            "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
155            "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
156            "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
157            "lowest_bigger_power_of_two_minus_one" => {
158                BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
159            }
160            "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
161            "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
162            "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
163            "sign" => BaseCircuitId::Ed25519Sign,
164            "mxe-sign" => BaseCircuitId::Ed25519MXESign,
165            "verify" => BaseCircuitId::Ed25519Verify,
166            "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
167            "sha3-256" => BaseCircuitId::Sha3_256,
168            "sha3-512" => BaseCircuitId::Sha3_512,
169            "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
170            _ => return None,
171        };
172        Some(res)
173    }
174}
175
176pub trait CircuitArg: Sized {
177    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
178    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
179}
180
181impl CircuitArg for BaseField {
182    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
183        c.to_circuit().eval(v).unwrap()
184    }
185
186    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
187        c.to_circuit().eval(v).unwrap()
188    }
189}
190
191impl CircuitArg for EvalValue {
192    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
193        CircuitArg::apply_arithmetic_circuit_id(
194            v.into_iter()
195                .map(|x| BaseField::from(x.to_signed_number()))
196                .collect(),
197            c,
198        )
199        .into_iter()
200        .map(EvalValue::Base)
201        .collect()
202    }
203
204    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
205        CircuitArg::apply_base_circuit_id(
206            v.into_iter()
207                .map(|x| BaseField::from(x.to_signed_number()))
208                .collect(),
209            c,
210        )
211        .into_iter()
212        .map(EvalValue::Base)
213        .collect()
214    }
215}
216
217impl CircuitArg for FieldValue<BaseField> {
218    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
219        let all_bounds =
220            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
221        let n = c.to_circuit().bounds(all_bounds).len();
222        (0..n)
223            .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
224            .collect()
225    }
226
227    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
228        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
229        let n = c.to_circuit().bounds(all_bounds).len();
230        (0..n)
231            .map(|i| {
232                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
233                    expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
234                        v.iter().map(FieldValue::get_id).collect(),
235                        c,
236                        i,
237                    )))
238                }))
239            })
240            .collect()
241    }
242}
243
244impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
245    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
246        let all_bounds =
247            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
248        let n = c.to_circuit().bounds(all_bounds).len();
249        (0..n)
250            .map(|i| {
251                FieldArray::from(
252                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
253                        (0..N)
254                            .map(|j| {
255                                FieldValue::new(FieldExpr::SubCircuit(
256                                    v.iter()
257                                        .copied()
258                                        .map(|x| x[j])
259                                        .collect::<Vec<FieldValue<BaseField>>>(),
260                                    c,
261                                    i,
262                                ))
263                            })
264                            .collect::<Vec<FieldValue<BaseField>>>(),
265                    )
266                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
267                        panic!("Expected a Vec of length {} (found {})", N, v.len())
268                    }),
269                )
270            })
271            .collect::<Vec<Self>>()
272    }
273
274    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
275        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
276        let n = c.to_circuit().bounds(all_bounds).len();
277        (0..n)
278            .map(|i| {
279                FieldArray::from(
280                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
281                        (0..N)
282                            .map(|j| {
283                                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
284                                    expr_store.new_expr(Expr::Other(
285                                        OtherExpr::BaseArithmeticCircuit(
286                                            v.iter()
287                                                .copied()
288                                                .map(|x| x[j].get_id())
289                                                .collect::<Vec<usize>>(),
290                                            c,
291                                            i,
292                                        ),
293                                    ))
294                                }))
295                            })
296                            .collect::<Vec<FieldValue<BaseField>>>(),
297                    )
298                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
299                        panic!("Expected a Vec of length {} (found {})", N, v.len())
300                    }),
301                )
302            })
303            .collect::<Vec<Self>>()
304    }
305}