1use crate::{
2 core::{
3 actually_used_field::ActuallyUsedField,
4 bounds::FieldBounds,
5 circuits::{
6 arithmetic::{
7 abs::AbsCircuit,
8 bitwise_and::BitwiseAnd,
9 float_div::Div,
10 float_exp::{Exp, Exp2},
11 float_log::{Ln, Log2},
12 float_sqrt::{DivSqrt, Sqrt},
13 lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
14 max::Max,
15 min::Min,
16 sigmoid::Sigmoid,
17 signed_divide::SignedDivide,
18 zero::ZeroCircuit,
19 },
20 boolean::{
21 ed25519::{
22 Ed25519MXESign,
23 Ed25519Sign,
24 Ed25519Verify,
25 Ed25519VerifyingKeyFromSecretKey,
26 },
27 sha3::{SHA3_256, SHA3_512},
28 },
29 general::{conversion::ConversionCircuit, identity::IdentityCircuit},
30 traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
31 },
32 expressions::{
33 expr::{EvalValue, Expr},
34 field_expr::FieldExpr,
35 other_expr::OtherExpr,
36 },
37 global_value::{
38 field_array::FieldArray,
39 global_expr_store::with_global_expr_store_as_local,
40 value::FieldValue,
41 },
42 },
43 types::DOUBLE_PRECISION_MANTISSA,
44 utils::field::BaseField,
45};
46use serde::{Deserialize, Serialize};
47
/// Holder for the pre-configured circuit instances handed out by
/// `ArithmeticCircuitId::to_circuit`.
///
/// The instances are associated `const`s so that an expression such as
/// `&StaticCircuits::MIN` can be returned as a `&'static dyn ArithmeticCircuit<F>`
/// (which is exactly what `ArithmeticCircuitId::to_circuit` does).
struct StaticCircuits;
impl StaticCircuits {
    // NOTE(review): the meaning of the `true` argument to `Min::new`, `Max::new` and
    // `BitwiseAnd::new` is defined by those constructors, which are not visible here.
    const MIN: Min = Min::new(true);
    const MAX: Max = Max::new(true);
    // The following circuits are all configured with `DOUBLE_PRECISION_MANTISSA`
    // (`DivSqrt` and `Div` take it for both of their parameters).
    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
}
65
/// Identifies an arithmetic circuit that can be instantiated over any
/// `F: ActuallyUsedField` through `ArithmeticCircuitId::to_circuit`.
///
/// String names listed below are the ones accepted by `BaseCircuitId::new_from_str`
/// (which wraps the result in `BaseCircuitId::Arith`).
///
/// NOTE(review): this enum derives serde `Serialize`/`Deserialize`; serialization formats
/// that encode the variant index would change if variants were reordered, so append new
/// variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    /// `Min::new(true)`; name `"min"`.
    Min,
    /// `Max::new(true)`; name `"max"`.
    Max,
    /// `Log2` with `DOUBLE_PRECISION_MANTISSA`; name `"float_log2"`.
    Log2,
    /// `Ln` with `DOUBLE_PRECISION_MANTISSA`; name `"float_ln"`.
    Ln,
    /// `Exp2` with `DOUBLE_PRECISION_MANTISSA`; name `"float_exp2"`.
    Exp2,
    /// `Exp` with `DOUBLE_PRECISION_MANTISSA`; name `"float_exp"`.
    Exp,
    /// `Sqrt` with `DOUBLE_PRECISION_MANTISSA`; name `"float_sqrt"`.
    Sqrt,
    /// `DivSqrt` with `DOUBLE_PRECISION_MANTISSA` for both parameters.
    /// NOTE(review): has no string name in `BaseCircuitId::new_from_str`.
    DivSqrt,
    /// `Div` with `DOUBLE_PRECISION_MANTISSA` for both parameters; name `"float_div"`.
    Div,
    /// `BitwiseAnd::new(true)`; name `"bitwise_and"`.
    BitwiseAnd,
    /// The `LowestBiggerPowerOfTwoMinusOne` circuit; name
    /// `"lowest_bigger_power_of_two_minus_one"`.
    LowestBiggerPowerOfTwoMinusOne,
    /// `Sigmoid` with `DOUBLE_PRECISION_MANTISSA`; name `"sigmoid"`.
    Sigmoid,
    /// The `ZeroCircuit`; name `"zero"`.
    Zero,
    /// The `AbsCircuit`; name `"abs"`.
    Abs,
    /// The `IdentityCircuit`; name `"identity"`.
    Identity,
    /// The `SignedDivide` circuit.
    /// NOTE(review): has no string name in `BaseCircuitId::new_from_str`.
    SignedIntegerDiv,
}
85
86impl ArithmeticCircuitId {
87 pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
88 match self {
89 ArithmeticCircuitId::Min => &StaticCircuits::MIN,
90 ArithmeticCircuitId::Max => &StaticCircuits::MAX,
91 ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
92 ArithmeticCircuitId::Ln => &StaticCircuits::LN,
93 ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
94 ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
95 ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
96 ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
97 ArithmeticCircuitId::Div => &StaticCircuits::DIV,
98 ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
99 ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
100 ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
101 ArithmeticCircuitId::Zero => &ZeroCircuit,
102 ArithmeticCircuitId::Abs => &AbsCircuit,
103 ArithmeticCircuitId::Identity => &IdentityCircuit,
104 ArithmeticCircuitId::SignedIntegerDiv => &SignedDivide,
105 }
106 }
107 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
108 T::apply_arithmetic_circuit_id(args, self)
109 }
110}
111
/// Identifies a circuit that `GeneralCircuitId::to_circuit` resolves to a
/// `&'static dyn GeneralCircuit`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// The `ConversionCircuit`.
    Conversion,
}
116
117impl GeneralCircuitId {
118 pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
119 match self {
120 GeneralCircuitId::Conversion => &ConversionCircuit,
121 }
122 }
123}
124
/// Identifies a circuit over `BaseField`: one of the Ed25519 / SHA3 circuits, or any
/// `ArithmeticCircuitId` instantiated at `BaseField` (see `BaseCircuitId::to_circuit`).
///
/// String names listed below are the ones accepted by `BaseCircuitId::new_from_str`.
///
/// NOTE(review): this enum derives serde `Serialize`/`Deserialize`; serialization formats
/// that encode the variant index would change if variants were reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    /// The `Ed25519Sign` circuit; name `"sign"`.
    Ed25519Sign,
    /// The `Ed25519MXESign` circuit; name `"mxe-sign"`.
    Ed25519MXESign,
    /// The `Ed25519Verify` circuit; name `"verify"`.
    Ed25519Verify,
    /// The `Ed25519VerifyingKeyFromSecretKey` circuit; name `"verifying_key_from_secret_key"`.
    Ed25519VerifyingKeyFromSecretKey,
    /// The `SHA3_256` circuit; name `"sha3-256"`.
    Sha3_256,
    /// The `SHA3_512` circuit; name `"sha3-512"`.
    Sha3_512,
    /// An arithmetic circuit, instantiated over `BaseField`.
    Arith(ArithmeticCircuitId),
}
135
136impl BaseCircuitId {
137 pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
138 match self {
139 BaseCircuitId::Ed25519Sign => &Ed25519Sign,
140 BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
141 BaseCircuitId::Ed25519Verify => &Ed25519Verify,
142 BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
143 BaseCircuitId::Sha3_256 => &SHA3_256,
144 BaseCircuitId::Sha3_512 => &SHA3_512,
145 BaseCircuitId::Arith(a) => a.to_circuit(),
146 }
147 }
148 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
149 T::apply_base_circuit_id(args, self)
150 }
151 pub fn new_from_str(str: &str) -> Option<Self> {
152 let res = match str {
153 "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
155 "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
156 "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
157 "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
158 "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
159 "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
160 "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
161 "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
162 "identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
163 "lowest_bigger_power_of_two_minus_one" => {
164 BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
165 }
166 "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
167 "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
168 "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
169 "sign" => BaseCircuitId::Ed25519Sign,
170 "mxe-sign" => BaseCircuitId::Ed25519MXESign,
171 "verify" => BaseCircuitId::Ed25519Verify,
172 "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
173 "sha3-256" => BaseCircuitId::Sha3_256,
174 "sha3-512" => BaseCircuitId::Sha3_512,
175 "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
176 _ => return None,
177 };
178 Some(res)
179 }
180}
181
/// A value representation that circuit ids can be applied to.
///
/// `ArithmeticCircuitId::apply` and `BaseCircuitId::apply` dispatch through this trait.
/// In this file it is implemented for `BaseField` and `EvalValue` (which evaluate the
/// circuit on concrete values) and for `FieldValue<BaseField>` and `FieldArray<N, BaseField>`
/// (which create one new expression per circuit output).
pub trait CircuitArg: Sized {
    /// Applies the arithmetic circuit `c` to the inputs `v`, returning its outputs.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
    /// Applies the base circuit `c` to the inputs `v`, returning its outputs.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
}
186
187impl CircuitArg for BaseField {
188 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
189 c.to_circuit().eval(v).unwrap()
190 }
191
192 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
193 c.to_circuit().eval(v).unwrap()
194 }
195}
196
197impl CircuitArg for EvalValue {
198 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
199 CircuitArg::apply_arithmetic_circuit_id(
200 v.into_iter()
201 .map(|x| BaseField::from(x.to_signed_number()))
202 .collect(),
203 c,
204 )
205 .into_iter()
206 .map(EvalValue::Base)
207 .collect()
208 }
209
210 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
211 CircuitArg::apply_base_circuit_id(
212 v.into_iter()
213 .map(|x| BaseField::from(x.to_signed_number()))
214 .collect(),
215 c,
216 )
217 .into_iter()
218 .map(EvalValue::Base)
219 .collect()
220 }
221}
222
223impl CircuitArg for FieldValue<BaseField> {
224 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
225 let all_bounds =
226 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
227 let n = c.to_circuit().bounds(all_bounds).len();
228 (0..n)
229 .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
230 .collect()
231 }
232
233 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
234 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
235 let n = c.to_circuit().bounds(all_bounds).len();
236 (0..n)
237 .map(|i| {
238 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
239 expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
240 v.iter().map(FieldValue::get_id).collect(),
241 c,
242 i,
243 )))
244 }))
245 })
246 .collect()
247 }
248}
249
250impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
251 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
252 let all_bounds =
253 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
254 let n = c.to_circuit().bounds(all_bounds).len();
255 (0..n)
256 .map(|i| {
257 FieldArray::from(
258 TryInto::<[FieldValue<BaseField>; N]>::try_into(
259 (0..N)
260 .map(|j| {
261 FieldValue::new(FieldExpr::SubCircuit(
262 v.iter()
263 .copied()
264 .map(|x| x[j])
265 .collect::<Vec<FieldValue<BaseField>>>(),
266 c,
267 i,
268 ))
269 })
270 .collect::<Vec<FieldValue<BaseField>>>(),
271 )
272 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
273 panic!("Expected a Vec of length {} (found {})", N, v.len())
274 }),
275 )
276 })
277 .collect::<Vec<Self>>()
278 }
279
280 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
281 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
282 let n = c.to_circuit().bounds(all_bounds).len();
283 (0..n)
284 .map(|i| {
285 FieldArray::from(
286 TryInto::<[FieldValue<BaseField>; N]>::try_into(
287 (0..N)
288 .map(|j| {
289 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
290 expr_store.new_expr(Expr::Other(
291 OtherExpr::BaseArithmeticCircuit(
292 v.iter()
293 .copied()
294 .map(|x| x[j].get_id())
295 .collect::<Vec<usize>>(),
296 c,
297 i,
298 ),
299 ))
300 }))
301 })
302 .collect::<Vec<FieldValue<BaseField>>>(),
303 )
304 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
305 panic!("Expected a Vec of length {} (found {})", N, v.len())
306 }),
307 )
308 })
309 .collect::<Vec<Self>>()
310 }
311}