1use crate::{
2 core::{
3 actually_used_field::ActuallyUsedField,
4 bounds::FieldBounds,
5 circuits::{
6 arithmetic::{
7 abs::AbsCircuit,
8 bitwise_and::BitwiseAnd,
9 float_div::Div,
10 float_exp::{Exp, Exp2},
11 float_log::{Ln, Log2},
12 float_sqrt::{DivSqrt, Sqrt},
13 lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
14 max::Max,
15 min::Min,
16 sigmoid::Sigmoid,
17 zero::ZeroCircuit,
18 },
19 boolean::{
20 ed25519::{
21 Ed25519MXESign,
22 Ed25519Sign,
23 Ed25519Verify,
24 Ed25519VerifyingKeyFromSecretKey,
25 },
26 sha3::{SHA3_256, SHA3_512},
27 },
28 general::conversion::ConversionCircuit,
29 traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
30 },
31 expressions::{
32 expr::{EvalValue, Expr},
33 field_expr::FieldExpr,
34 other_expr::OtherExpr,
35 },
36 global_value::{
37 field_array::FieldArray,
38 global_expr_store::with_global_expr_store_as_local,
39 value::FieldValue,
40 },
41 },
42 types::DOUBLE_PRECISION_MANTISSA,
43 utils::field::BaseField,
44};
45use serde::{Deserialize, Serialize};
46
47struct StaticCircuits;
51impl StaticCircuits {
52 const MIN: Min = Min::new(true);
53 const MAX: Max = Max::new(true);
54 const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
55 const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
56 const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
57 const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);
58 const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
59 const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
60 const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
61 const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);
62 const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
66pub enum ArithmeticCircuitId {
67 Min,
68 Max,
69 Log2,
70 Ln,
71 Exp2,
72 Exp,
73 Sqrt,
74 DivSqrt,
75 Div,
76 BitwiseAnd,
77 LowestBiggerPowerOfTwoMinusOne,
78 Sigmoid,
79 Zero,
80 Abs,
81}
82
83impl ArithmeticCircuitId {
84 pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
85 match self {
86 ArithmeticCircuitId::Min => &StaticCircuits::MIN,
87 ArithmeticCircuitId::Max => &StaticCircuits::MAX,
88 ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
89 ArithmeticCircuitId::Ln => &StaticCircuits::LN,
90 ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
91 ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
92 ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
93 ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
94 ArithmeticCircuitId::Div => &StaticCircuits::DIV,
95 ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
96 ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
97 ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
98 ArithmeticCircuitId::Zero => &ZeroCircuit,
99 ArithmeticCircuitId::Abs => &AbsCircuit,
100 }
101 }
102 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
103 T::apply_arithmetic_circuit_id(args, self)
104 }
105}
106
107#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
108pub enum GeneralCircuitId {
109 Conversion,
110}
111
112impl GeneralCircuitId {
113 pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
114 match self {
115 GeneralCircuitId::Conversion => &ConversionCircuit,
116 }
117 }
118}
119
120#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
121pub enum BaseCircuitId {
122 Ed25519Sign,
123 Ed25519MXESign,
124 Ed25519Verify,
125 Ed25519VerifyingKeyFromSecretKey,
126 Sha3_256,
127 Sha3_512,
128 Arith(ArithmeticCircuitId),
129}
130
131impl BaseCircuitId {
132 pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
133 match self {
134 BaseCircuitId::Ed25519Sign => &Ed25519Sign,
135 BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
136 BaseCircuitId::Ed25519Verify => &Ed25519Verify,
137 BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
138 BaseCircuitId::Sha3_256 => &SHA3_256,
139 BaseCircuitId::Sha3_512 => &SHA3_512,
140 BaseCircuitId::Arith(a) => a.to_circuit(),
141 }
142 }
143 pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
144 T::apply_base_circuit_id(args, self)
145 }
146 pub fn new_from_str(str: &str) -> Option<Self> {
147 let res = match str {
148 "abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
150 "bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
151 "float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
152 "float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
153 "float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
154 "float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
155 "float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
156 "float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
157 "lowest_bigger_power_of_two_minus_one" => {
158 BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
159 }
160 "max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
161 "min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
162 "sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
163 "sign" => BaseCircuitId::Ed25519Sign,
164 "mxe-sign" => BaseCircuitId::Ed25519MXESign,
165 "verify" => BaseCircuitId::Ed25519Verify,
166 "verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
167 "sha3-256" => BaseCircuitId::Sha3_256,
168 "sha3-512" => BaseCircuitId::Sha3_512,
169 "zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
170 _ => return None,
171 };
172 Some(res)
173 }
174}
175
176pub trait CircuitArg: Sized {
177 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self>;
178 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self>;
179}
180
181impl CircuitArg for BaseField {
182 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
183 c.to_circuit().eval(v).unwrap()
184 }
185
186 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
187 c.to_circuit().eval(v).unwrap()
188 }
189}
190
191impl CircuitArg for EvalValue {
192 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
193 CircuitArg::apply_arithmetic_circuit_id(
194 v.into_iter()
195 .map(|x| BaseField::from(x.to_signed_number()))
196 .collect(),
197 c,
198 )
199 .into_iter()
200 .map(EvalValue::Base)
201 .collect()
202 }
203
204 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
205 CircuitArg::apply_base_circuit_id(
206 v.into_iter()
207 .map(|x| BaseField::from(x.to_signed_number()))
208 .collect(),
209 c,
210 )
211 .into_iter()
212 .map(EvalValue::Base)
213 .collect()
214 }
215}
216
217impl CircuitArg for FieldValue<BaseField> {
218 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
219 let all_bounds =
220 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
221 let n = c.to_circuit().bounds(all_bounds).len();
222 (0..n)
223 .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
224 .collect()
225 }
226
227 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
228 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
229 let n = c.to_circuit().bounds(all_bounds).len();
230 (0..n)
231 .map(|i| {
232 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
233 expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
234 v.iter().map(FieldValue::get_id).collect(),
235 c,
236 i,
237 )))
238 }))
239 })
240 .collect()
241 }
242}
243
244impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
245 fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
246 let all_bounds =
247 std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
248 let n = c.to_circuit().bounds(all_bounds).len();
249 (0..n)
250 .map(|i| {
251 FieldArray::from(
252 TryInto::<[FieldValue<BaseField>; N]>::try_into(
253 (0..N)
254 .map(|j| {
255 FieldValue::new(FieldExpr::SubCircuit(
256 v.iter()
257 .copied()
258 .map(|x| x[j])
259 .collect::<Vec<FieldValue<BaseField>>>(),
260 c,
261 i,
262 ))
263 })
264 .collect::<Vec<FieldValue<BaseField>>>(),
265 )
266 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
267 panic!("Expected a Vec of length {} (found {})", N, v.len())
268 }),
269 )
270 })
271 .collect::<Vec<Self>>()
272 }
273
274 fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
275 let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
276 let n = c.to_circuit().bounds(all_bounds).len();
277 (0..n)
278 .map(|i| {
279 FieldArray::from(
280 TryInto::<[FieldValue<BaseField>; N]>::try_into(
281 (0..N)
282 .map(|j| {
283 FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
284 expr_store.new_expr(Expr::Other(
285 OtherExpr::BaseArithmeticCircuit(
286 v.iter()
287 .copied()
288 .map(|x| x[j].get_id())
289 .collect::<Vec<usize>>(),
290 c,
291 i,
292 ),
293 ))
294 }))
295 })
296 .collect::<Vec<FieldValue<BaseField>>>(),
297 )
298 .unwrap_or_else(|v: Vec<FieldValue<BaseField>>| {
299 panic!("Expected a Vec of length {} (found {})", N, v.len())
300 }),
301 )
302 })
303 .collect::<Vec<Self>>()
304 }
305}