1use crate::{
2 core::{
3 actually_used_field::ActuallyUsedField,
4 bounds::FieldBounds,
5 circuits::{
6 arithmetic::{
7 abs::AbsCircuit,
8 bitwise_and::BitwiseAnd,
9 fast_euclidean_by_constant::FastEuclideanByConstant,
10 float_div::Div,
11 float_exp::{Exp, Exp2},
12 float_log::{Ln, Log2},
13 float_sqrt::{DivSqrt, Sqrt},
14 lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
15 max::Max,
16 min::Min,
17 sigmoid::Sigmoid,
18 signed_divide::SignedDivide,
19 zero::ZeroCircuit,
20 },
21 boolean::{
22 ed25519::{
23 Ed25519MXESign,
24 Ed25519Sign,
25 Ed25519Verify,
26 Ed25519VerifyingKeyFromSecretKey,
27 },
28 sha3::{SHA3_256, SHA3_512},
29 },
30 general::{conversion::ConversionCircuit, identity::IdentityCircuit},
31 traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
32 },
33 expressions::{
34 expr::{EvalValue, Expr},
35 field_expr::FieldExpr,
36 other_expr::OtherExpr,
37 },
38 global_value::{
39 field_array::FieldArray,
40 global_expr_store::with_global_expr_store_as_local,
41 value::FieldValue,
42 },
43 },
44 types::DOUBLE_PRECISION_MANTISSA,
45 utils::field::BaseField,
46};
47use serde::{Deserialize, Serialize};
48use std::num::NonZeroU8;
49
50struct StaticCircuits;
54impl StaticCircuits {
55 const MIN: Min = Min::new(true);
56 const MAX: Max = Max::new(true);
57 const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
58 const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
59 const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
60 const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
61 const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
62 const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
63 const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
64 const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
65 const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
66 const FAST_EUCLIDEAN_U8: [FastEuclideanByConstant; 255] = {
67 let mut arr = [FastEuclideanByConstant {
68 b: NonZeroU8::new(1).unwrap(),
69 }; 255];
70 let mut idx = 1;
71 while idx < arr.len() {
72 arr[idx].b = NonZeroU8::new(idx as u8 + 1).unwrap();
73 idx += 1;
74 }
75 arr
76 };
77}
78
/// Identifier of an arithmetic circuit that can be instantiated over any field
/// `F: ActuallyUsedField` via `ArithmeticCircuitId::to_circuit`.
///
/// NOTE(review): because of the derived `Serialize`/`Deserialize`/`Hash`, variant order
/// may be observable (e.g. in serialization formats that encode the variant index);
/// prefer appending new variants at the end — confirm against the format in use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    /// `Min` circuit (name `"min"` in `BaseCircuitId::new_from_str`).
    Min,
    /// `Max` circuit (name `"max"`).
    Max,
    /// `Log2` circuit (name `"float_log2"`).
    Log2,
    /// `Ln` circuit (name `"float_ln"`).
    Ln,
    /// `Exp2` circuit (name `"float_exp2"`).
    Exp2,
    /// `Exp` circuit (name `"float_exp"`).
    Exp,
    /// `Sqrt` circuit (name `"float_sqrt"`).
    Sqrt,
    /// `DivSqrt` circuit; has no string name in `BaseCircuitId::new_from_str`.
    DivSqrt,
    /// `Div` circuit (name `"float_div"`).
    Div,
    /// `BitwiseAnd` circuit (name `"bitwise_and"`).
    BitwiseAnd,
    /// `LowestBiggerPowerOfTwoMinusOne` circuit (name `"lowest_bigger_power_of_two_minus_one"`).
    LowestBiggerPowerOfTwoMinusOne,
    /// `Sigmoid` circuit (name `"sigmoid"`).
    Sigmoid,
    /// `ZeroCircuit` (name `"zero"`).
    Zero,
    /// `AbsCircuit` (name `"abs"`).
    Abs,
    /// `IdentityCircuit` (name `"identity"`).
    Identity,
    /// `SignedDivide` circuit; has no string name in `BaseCircuitId::new_from_str`.
    SignedIntegerDiv,
    /// `FastEuclideanByConstant` circuit whose constant `b` is the payload.
    FastEuclideanByConstant(NonZeroU8),
}
99
100impl ArithmeticCircuitId {
101 pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
102 match self {
103 ArithmeticCircuitId::Min => &StaticCircuits::MIN,
104 ArithmeticCircuitId::Max => &StaticCircuits::MAX,
105 ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
106 ArithmeticCircuitId::Ln => &StaticCircuits::LN,
107 ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
108 ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
109 ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
110 ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
111 ArithmeticCircuitId::Div => &StaticCircuits::DIV,
112 ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
113 ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
114 ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
115 ArithmeticCircuitId::Zero => &ZeroCircuit,
116 ArithmeticCircuitId::Abs => &AbsCircuit,
117 ArithmeticCircuitId::Identity => &IdentityCircuit,
118 ArithmeticCircuitId::SignedIntegerDiv => &SignedDivide,
119 ArithmeticCircuitId::FastEuclideanByConstant(x) => {
120 &StaticCircuits::FAST_EUCLIDEAN_U8[(x.get() - 1) as usize]
121 }
122 }
123 }
124 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
125 T::apply_arithmetic_circuit_id(args, self)
126 }
127}
128
/// Identifier of a [`GeneralCircuit`], resolved by `GeneralCircuitId::to_circuit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// Resolves to `ConversionCircuit`.
    Conversion,
}
132}
133
134impl GeneralCircuitId {
135 pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
136 match self {
137 GeneralCircuitId::Conversion => &ConversionCircuit,
138 }
139 }
140}
141
/// Identifier of a circuit over [`BaseField`]: the Ed25519 and SHA3 circuits, which
/// are not part of [`ArithmeticCircuitId`], plus any arithmetic circuit via `Arith`.
///
/// NOTE(review): as with `ArithmeticCircuitId`, variant order may be observable through
/// the derived serde/Hash impls — append new variants at the end; confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    /// `Ed25519Sign` circuit (name `"sign"`).
    Ed25519Sign,
    /// `Ed25519MXESign` circuit (name `"mxe-sign"`).
    Ed25519MXESign,
    /// `Ed25519Verify` circuit (name `"verify"`).
    Ed25519Verify,
    /// `Ed25519VerifyingKeyFromSecretKey` circuit (name `"verifying_key_from_secret_key"`).
    Ed25519VerifyingKeyFromSecretKey,
    /// `SHA3_256` circuit (name `"sha3-256"`).
    Sha3_256,
    /// `SHA3_512` circuit (name `"sha3-512"`).
    Sha3_512,
    /// Any arithmetic circuit, instantiated over `BaseField`.
    Arith(ArithmeticCircuitId),
}
152
153impl BaseCircuitId {
154 pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
155 match self {
156 BaseCircuitId::Ed25519Sign => &Ed25519Sign,
157 BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
158 BaseCircuitId::Ed25519Verify => &Ed25519Verify,
159 BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
160 BaseCircuitId::Sha3_256 => &SHA3_256,
161 BaseCircuitId::Sha3_512 => &SHA3_512,
162 BaseCircuitId::Arith(a) => a.to_circuit(),
163 }
164 }
165 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
166 T::apply_base_circuit_id(args, self)
167 }
168 pub fn new_from_str(str: &str) -> Option<Self> {
169 let res = match str {
170 "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
172 "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
173 "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
174 "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
175 "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
176 "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
177 "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
178 "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
179 "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
180 "lowest_bigger_power_of_two_minus_one" => {
181 BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
182 }
183 "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
184 "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
185 "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
186 "sign" => BaseCircuitId::Ed25519Sign,
187 "mxe-sign" => BaseCircuitId::Ed25519MXESign,
188 "verify" => BaseCircuitId::Ed25519Verify,
189 "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
190 "sha3-256" => BaseCircuitId::Sha3_256,
191 "sha3-512" => BaseCircuitId::Sha3_512,
192 "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
193 _ => return None,
194 };
195 Some(res)
196 }
197}
198
/// A value type that circuits identified by [`ArithmeticCircuitId`] or
/// [`BaseCircuitId`] can be applied to.
///
/// The implementations in this file either evaluate the circuit eagerly
/// (`BaseField`, `EvalValue`) or record expression nodes for later evaluation
/// (`FieldValue<BaseField>`, `FieldArray<N, BaseField>`).
pub trait CircuitArg: Sized {
    /// Applies the arithmetic circuit `c` to the inputs `v` and returns its outputs.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
    /// Applies the base-field circuit `c` to the inputs `v` and returns its outputs.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
}
203
204impl CircuitArg for BaseField {
205 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
206 c.to_circuit().eval(v).unwrap()
207 }
208
209 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
210 c.to_circuit().eval(v).unwrap()
211 }
212}
213
214impl CircuitArg for EvalValue {
215 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
216 CircuitArg::apply_arithmetic_circuit_id(
217 v.into_iter()
218 .map(|x| BaseField::from(x.to_signed_number()))
219 .collect(),
220 c,
221 )
222 .into_iter()
223 .map(EvalValue::Base)
224 .collect()
225 }
226
227 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
228 CircuitArg::apply_base_circuit_id(
229 v.into_iter()
230 .map(|x| BaseField::from(x.to_signed_number()))
231 .collect(),
232 c,
233 )
234 .into_iter()
235 .map(EvalValue::Base)
236 .collect()
237 }
238}
239
240impl CircuitArg for FieldValue<BaseField> {
241 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
242 let all_bounds =
243 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
244 let n = c.to_circuit().bounds(all_bounds).len();
245 (0..n)
246 .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
247 .collect()
248 }
249
250 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
251 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
252 let n = c.to_circuit().bounds(all_bounds).len();
253 (0..n)
254 .map(|i| {
255 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
256 expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
257 v.iter().map(FieldValue::get_id).collect(),
258 c,
259 i,
260 )))
261 }))
262 })
263 .collect()
264 }
265}
266
267impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
268 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
269 let all_bounds =
270 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
271 let n = c.to_circuit().bounds(all_bounds).len();
272 (0..n)
273 .map(|i| {
274 FieldArray::from(
275 TryInto::<[FieldValue<BaseField>; N]>::try_into(
276 (0..N)
277 .map(|j| {
278 FieldValue::new(FieldExpr::SubCircuit(
279 v.iter()
280 .copied()
281 .map(|x| x[j])
282 .collect::<Vec<FieldValue<BaseField>>>(),
283 c,
284 i,
285 ))
286 })
287 .collect::<Vec<FieldValue<BaseField>>>(),
288 )
289 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
290 panic!("Expected a Vec of length {} (found {})", N, v.len())
291 }),
292 )
293 })
294 .collect::<Vec<Self>>()
295 }
296
297 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
298 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
299 let n = c.to_circuit().bounds(all_bounds).len();
300 (0..n)
301 .map(|i| {
302 FieldArray::from(
303 TryInto::<[FieldValue<BaseField>; N]>::try_into(
304 (0..N)
305 .map(|j| {
306 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
307 expr_store.new_expr(Expr::Other(
308 OtherExpr::BaseArithmeticCircuit(
309 v.iter()
310 .copied()
311 .map(|x| x[j].get_id())
312 .collect::<Vec<usize>>(),
313 c,
314 i,
315 ),
316 ))
317 }))
318 })
319 .collect::<Vec<FieldValue<BaseField>>>(),
320 )
321 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
322 panic!("Expected a Vec of length {} (found {})", N, v.len())
323 }),
324 )
325 })
326 .collect::<Vec<Self>>()
327 }
328}