1use crate::{
2 core::{
3 actually_used_field::ActuallyUsedField,
4 bounds::FieldBounds,
5 circuits::{
6 arithmetic::{
7 abs::AbsCircuit,
8 bitwise_and::BitwiseAnd,
9 float_div::Div,
10 float_exp::{Exp, Exp2},
11 float_log::{Ln, Log2},
12 float_sqrt::{DivSqrt, Sqrt},
13 lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
14 max::Max,
15 min::Min,
16 sigmoid::Sigmoid,
17 zero::ZeroCircuit,
18 },
19 boolean::{
20 ed25519::{
21 Ed25519MXESign,
22 Ed25519Sign,
23 Ed25519Verify,
24 Ed25519VerifyingKeyFromSecretKey,
25 },
26 sha3::{SHA3_256, SHA3_512},
27 },
28 general::{conversion::ConversionCircuit, identity::IdentityCircuit},
29 traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
30 },
31 expressions::{
32 expr::{EvalValue, Expr},
33 field_expr::FieldExpr,
34 other_expr::OtherExpr,
35 },
36 global_value::{
37 field_array::FieldArray,
38 global_expr_store::with_global_expr_store_as_local,
39 value::FieldValue,
40 },
41 },
42 types::DOUBLE_PRECISION_MANTISSA,
43 utils::field::BaseField,
44};
45use serde::{Deserialize, Serialize};
46
47struct StaticCircuits;
51impl StaticCircuits {
52 const MIN: Min = Min::new(true);
53 const MAX: Max = Max::new(true);
54 const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
55 const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
56 const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
57 const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
58 const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
59 const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
60 const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
61 const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
62 const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
66pub enum ArithmeticCircuitId {
67 Min,
68 Max,
69 Log2,
70 Ln,
71 Exp2,
72 Exp,
73 Sqrt,
74 DivSqrt,
75 Div,
76 BitwiseAnd,
77 LowestBiggerPowerOfTwoMinusOne,
78 Sigmoid,
79 Zero,
80 Abs,
81 Identity,
82}
83
84impl ArithmeticCircuitId {
85 pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
86 match self {
87 ArithmeticCircuitId::Min => &StaticCircuits::MIN,
88 ArithmeticCircuitId::Max => &StaticCircuits::MAX,
89 ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
90 ArithmeticCircuitId::Ln => &StaticCircuits::LN,
91 ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
92 ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
93 ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
94 ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
95 ArithmeticCircuitId::Div => &StaticCircuits::DIV,
96 ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
97 ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
98 ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
99 ArithmeticCircuitId::Zero => &ZeroCircuit,
100 ArithmeticCircuitId::Abs => &AbsCircuit,
101 ArithmeticCircuitId::Identity => &IdentityCircuit,
102 }
103 }
104 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
105 T::apply_arithmetic_circuit_id(args, self)
106 }
107}
108
109#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
110pub enum GeneralCircuitId {
111 Conversion,
112}
113
114impl GeneralCircuitId {
115 pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
116 match self {
117 GeneralCircuitId::Conversion => &ConversionCircuit,
118 }
119 }
120}
121
122#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
123pub enum BaseCircuitId {
124 Ed25519Sign,
125 Ed25519MXESign,
126 Ed25519Verify,
127 Ed25519VerifyingKeyFromSecretKey,
128 Sha3_256,
129 Sha3_512,
130 Arith(ArithmeticCircuitId),
131}
132
133impl BaseCircuitId {
134 pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
135 match self {
136 BaseCircuitId::Ed25519Sign => &Ed25519Sign,
137 BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
138 BaseCircuitId::Ed25519Verify => &Ed25519Verify,
139 BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
140 BaseCircuitId::Sha3_256 => &SHA3_256,
141 BaseCircuitId::Sha3_512 => &SHA3_512,
142 BaseCircuitId::Arith(a) => a.to_circuit(),
143 }
144 }
145 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
146 T::apply_base_circuit_id(args, self)
147 }
148 pub fn new_from_str(str: &str) -> Option<Self> {
149 let res = match str {
150 "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
152 "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
153 "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
154 "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
155 "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
156 "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
157 "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
158 "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
159 "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
160 "lowest_bigger_power_of_two_minus_one" => {
161 BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
162 }
163 "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
164 "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
165 "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
166 "sign" => BaseCircuitId::Ed25519Sign,
167 "mxe-sign" => BaseCircuitId::Ed25519MXESign,
168 "verify" => BaseCircuitId::Ed25519Verify,
169 "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
170 "sha3-256" => BaseCircuitId::Sha3_256,
171 "sha3-512" => BaseCircuitId::Sha3_512,
172 "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
173 _ => return None,
174 };
175 Some(res)
176 }
177}
178
179pub trait CircuitArg: Sized {
180 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
181 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
182}
183
184impl CircuitArg for BaseField {
185 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
186 c.to_circuit().eval(v).unwrap()
187 }
188
189 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
190 c.to_circuit().eval(v).unwrap()
191 }
192}
193
194impl CircuitArg for EvalValue {
195 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
196 CircuitArg::apply_arithmetic_circuit_id(
197 v.into_iter()
198 .map(|x| BaseField::from(x.to_signed_number()))
199 .collect(),
200 c,
201 )
202 .into_iter()
203 .map(EvalValue::Base)
204 .collect()
205 }
206
207 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
208 CircuitArg::apply_base_circuit_id(
209 v.into_iter()
210 .map(|x| BaseField::from(x.to_signed_number()))
211 .collect(),
212 c,
213 )
214 .into_iter()
215 .map(EvalValue::Base)
216 .collect()
217 }
218}
219
220impl CircuitArg for FieldValue<BaseField> {
221 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
222 let all_bounds =
223 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
224 let n = c.to_circuit().bounds(all_bounds).len();
225 (0..n)
226 .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
227 .collect()
228 }
229
230 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
231 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
232 let n = c.to_circuit().bounds(all_bounds).len();
233 (0..n)
234 .map(|i| {
235 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
236 expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
237 v.iter().map(FieldValue::get_id).collect(),
238 c,
239 i,
240 )))
241 }))
242 })
243 .collect()
244 }
245}
246
247impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
248 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
249 let all_bounds =
250 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
251 let n = c.to_circuit().bounds(all_bounds).len();
252 (0..n)
253 .map(|i| {
254 FieldArray::from(
255 TryInto::<[FieldValue<BaseField>; N]>::try_into(
256 (0..N)
257 .map(|j| {
258 FieldValue::new(FieldExpr::SubCircuit(
259 v.iter()
260 .copied()
261 .map(|x| x[j])
262 .collect::<Vec<FieldValue<BaseField>>>(),
263 c,
264 i,
265 ))
266 })
267 .collect::<Vec<FieldValue<BaseField>>>(),
268 )
269 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
270 panic!("Expected a Vec of length {} (found {})", N, v.len())
271 }),
272 )
273 })
274 .collect::<Vec<Self>>()
275 }
276
277 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
278 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
279 let n = c.to_circuit().bounds(all_bounds).len();
280 (0..n)
281 .map(|i| {
282 FieldArray::from(
283 TryInto::<[FieldValue<BaseField>; N]>::try_into(
284 (0..N)
285 .map(|j| {
286 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
287 expr_store.new_expr(Expr::Other(
288 OtherExpr::BaseArithmeticCircuit(
289 v.iter()
290 .copied()
291 .map(|x| x[j].get_id())
292 .collect::<Vec<usize>>(),
293 c,
294 i,
295 ),
296 ))
297 }))
298 })
299 .collect::<Vec<FieldValue<BaseField>>>(),
300 )
301 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
302 panic!("Expected a Vec of length {} (found {})", N, v.len())
303 }),
304 )
305 })
306 .collect::<Vec<Self>>()
307 }
308}