Skip to main content

arcis_compiler/core/expressions/
circuit.rs

1use crate::{
2    core::{
3        actually_used_field::ActuallyUsedField,
4        bounds::FieldBounds,
5        circuits::{
6            arithmetic::{
7                abs::AbsCircuit,
8                bitwise_and::BitwiseAnd,
9                float_div::Div,
10                float_exp::{Exp, Exp2},
11                float_log::{Ln, Log2},
12                float_sqrt::{DivSqrt, Sqrt},
13                lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
14                max::Max,
15                min::Min,
16                sigmoid::Sigmoid,
17                signed_divide::SignedDivide,
18                zero::ZeroCircuit,
19            },
20            boolean::{
21                ed25519::{
22                    Ed25519MXESign,
23                    Ed25519Sign,
24                    Ed25519Verify,
25                    Ed25519VerifyingKeyFromSecretKey,
26                },
27                sha3::{SHA3_256, SHA3_512},
28            },
29            general::{conversion::ConversionCircuit, identity::IdentityCircuit},
30            traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
31        },
32        expressions::{
33            expr::{EvalValue, Expr},
34            field_expr::FieldExpr,
35            other_expr::OtherExpr,
36        },
37        global_value::{
38            field_array::FieldArray,
39            global_expr_store::with_global_expr_store_as_local,
40            value::FieldValue,
41        },
42    },
43    types::DOUBLE_PRECISION_MANTISSA,
44    utils::field::BaseField,
45};
46use serde::{Deserialize, Serialize};
47
// Holder for the circuit instances that need construction arguments.
//
// Each circuit is declared as an associated constant so that
// `ArithmeticCircuitId::to_circuit` can hand out `&StaticCircuits::X` as a
// `&'static` reference. `inline_const` was unavailable when this was written,
// so this follows the desugaring example on
// https://github.com/rust-lang/rust/pull/104087 :
// NOTE(review): inline `const { .. }` blocks appear to have been stabilised
// since then — confirm against the toolchain in use before simplifying.
struct StaticCircuits;
impl StaticCircuits {
    // Circuits constructed from a boolean flag (meaning defined by each `new`).
    const MIN: Min = Min::new(true);
    const MAX: Max = Max::new(true);
    // Circuits parameterised by `DOUBLE_PRECISION_MANTISSA`.
    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
    // These two take the mantisssa constant twice, once per constructor argument.
    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
}
65
/// Copyable identifier for an arithmetic circuit.
///
/// An identifier is turned into the circuit it names by
/// [`ArithmeticCircuitId::to_circuit`].
// NOTE(review): the type derives `Serialize`/`Deserialize`, so variant names
// and order may be part of a persisted format — confirm before reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    Min,
    Max,
    Log2,
    Ln,
    Exp2,
    Exp,
    Sqrt,
    DivSqrt,
    /// Resolves to the `float_div::Div` circuit.
    Div,
    BitwiseAnd,
    LowestBiggerPowerOfTwoMinusOne,
    Sigmoid,
    /// Resolves to `ZeroCircuit`.
    Zero,
    /// Resolves to `AbsCircuit`.
    Abs,
    /// Resolves to `IdentityCircuit`.
    Identity,
    /// Resolves to the `SignedDivide` circuit.
    SignedIntegerDiv,
}
85
86impl ArithmeticCircuitId {
87    pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
88        match self {
89            ArithmeticCircuitId::Min => &StaticCircuits::MIN,
90            ArithmeticCircuitId::Max => &StaticCircuits::MAX,
91            ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
92            ArithmeticCircuitId::Ln => &StaticCircuits::LN,
93            ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
94            ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
95            ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
96            ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
97            ArithmeticCircuitId::Div => &StaticCircuits::DIV,
98            ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
99            ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
100            ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
101            ArithmeticCircuitId::Zero => &ZeroCircuit,
102            ArithmeticCircuitId::Abs => &AbsCircuit,
103            ArithmeticCircuitId::Identity => &IdentityCircuit,
104            ArithmeticCircuitId::SignedIntegerDiv => &SignedDivide,
105        }
106    }
107    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
108        T::apply_arithmetic_circuit_id(args, self)
109    }
110}
111
/// Copyable identifier for a general (non-arithmetic) circuit.
///
/// Resolved to a `GeneralCircuit` by [`GeneralCircuitId::to_circuit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// Resolves to `ConversionCircuit`.
    Conversion,
}
116
117impl GeneralCircuitId {
118    pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
119        match self {
120            GeneralCircuitId::Conversion => &ConversionCircuit,
121        }
122    }
123}
124
/// Copyable identifier for a circuit that operates over `BaseField`.
///
/// Covers the Ed25519 and SHA-3 circuits as well as every
/// [`ArithmeticCircuitId`] (through [`BaseCircuitId::Arith`]). Resolved to a
/// circuit by [`BaseCircuitId::to_circuit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    Ed25519Sign,
    Ed25519MXESign,
    Ed25519Verify,
    Ed25519VerifyingKeyFromSecretKey,
    Sha3_256,
    Sha3_512,
    /// Wraps an arithmetic circuit identifier.
    Arith(ArithmeticCircuitId),
}
135
136impl BaseCircuitId {
137    pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
138        match self {
139            BaseCircuitId::Ed25519Sign => &Ed25519Sign,
140            BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
141            BaseCircuitId::Ed25519Verify => &Ed25519Verify,
142            BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
143            BaseCircuitId::Sha3_256 => &SHA3_256,
144            BaseCircuitId::Sha3_512 => &SHA3_512,
145            BaseCircuitId::Arith(a) => a.to_circuit(),
146        }
147    }
148    pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
149        T::apply_base_circuit_id(args, self)
150    }
151    pub fn new_from_str(str: &str) -> Option<Self> {
152        let res = match str {
153            // Please keep alphabetical order.
154            "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
155            "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
156            "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
157            "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
158            "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
159            "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
160            "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
161            "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
162            "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
163            "lowest_bigger_power_of_two_minus_one" => {
164                BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
165            }
166            "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
167            "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
168            "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
169            "sign" => BaseCircuitId::Ed25519Sign,
170            "mxe-sign" => BaseCircuitId::Ed25519MXESign,
171            "verify" => BaseCircuitId::Ed25519Verify,
172            "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
173            "sha3-256" => BaseCircuitId::Sha3_256,
174            "sha3-512" => BaseCircuitId::Sha3_512,
175            "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
176            _ => return None,
177        };
178        Some(res)
179    }
180}
181
/// A value type that circuits can be applied to.
///
/// This is the dispatch point behind `ArithmeticCircuitId::apply` and
/// `BaseCircuitId::apply`: each implementor decides what applying a circuit to
/// a vector of its values produces. This file implements it for `BaseField`,
/// `EvalValue`, `FieldValue<BaseField>` and `FieldArray<N, BaseField>`.
pub trait CircuitArg: Sized {
    /// Applies the arithmetic circuit named by `c` to the inputs `v` and
    /// returns the circuit's outputs.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
    /// Applies the base-field circuit named by `c` to the inputs `v` and
    /// returns the circuit's outputs.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
}
186
187impl CircuitArg for BaseField {
188    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
189        c.to_circuit().eval(v).unwrap()
190    }
191
192    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
193        c.to_circuit().eval(v).unwrap()
194    }
195}
196
197impl CircuitArg for EvalValue {
198    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
199        CircuitArg::apply_arithmetic_circuit_id(
200            v.into_iter()
201                .map(|x| BaseField::from(x.to_signed_number()))
202                .collect(),
203            c,
204        )
205        .into_iter()
206        .map(EvalValue::Base)
207        .collect()
208    }
209
210    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
211        CircuitArg::apply_base_circuit_id(
212            v.into_iter()
213                .map(|x| BaseField::from(x.to_signed_number()))
214                .collect(),
215            c,
216        )
217        .into_iter()
218        .map(EvalValue::Base)
219        .collect()
220    }
221}
222
223impl CircuitArg for FieldValue<BaseField> {
224    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
225        let all_bounds =
226            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
227        let n = c.to_circuit().bounds(all_bounds).len();
228        (0..n)
229            .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
230            .collect()
231    }
232
233    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
234        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
235        let n = c.to_circuit().bounds(all_bounds).len();
236        (0..n)
237            .map(|i| {
238                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
239                    expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
240                        v.iter().map(FieldValue::get_id).collect(),
241                        c,
242                        i,
243                    )))
244                }))
245            })
246            .collect()
247    }
248}
249
250impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
251    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
252        let all_bounds =
253            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
254        let n = c.to_circuit().bounds(all_bounds).len();
255        (0..n)
256            .map(|i| {
257                FieldArray::from(
258                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
259                        (0..N)
260                            .map(|j| {
261                                FieldValue::new(FieldExpr::SubCircuit(
262                                    v.iter()
263                                        .copied()
264                                        .map(|x| x[j])
265                                        .collect::<Vec<FieldValue<BaseField>>>(),
266                                    c,
267                                    i,
268                                ))
269                            })
270                            .collect::<Vec<FieldValue<BaseField>>>(),
271                    )
272                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
273                        panic!("Expected a Vec of length {} (found {})", N, v.len())
274                    }),
275                )
276            })
277            .collect::<Vec<Self>>()
278    }
279
280    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
281        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
282        let n = c.to_circuit().bounds(all_bounds).len();
283        (0..n)
284            .map(|i| {
285                FieldArray::from(
286                    TryInto::<[FieldValue<BaseField>; N]>::try_into(
287                        (0..N)
288                            .map(|j| {
289                                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
290                                    expr_store.new_expr(Expr::Other(
291                                        OtherExpr::BaseArithmeticCircuit(
292                                            v.iter()
293                                                .copied()
294                                                .map(|x| x[j].get_id())
295                                                .collect::<Vec<usize>>(),
296                                            c,
297                                            i,
298                                        ),
299                                    ))
300                                }))
301                            })
302                            .collect::<Vec<FieldValue<BaseField>>>(),
303                    )
304                    .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
305                        panic!("Expected a Vec of length {} (found {})", N, v.len())
306                    }),
307                )
308            })
309            .collect::<Vec<Self>>()
310    }
311}