use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::{
arithmetic::{
abs::AbsCircuit,
bitwise_and::BitwiseAnd,
float_div::Div,
float_exp::{Exp, Exp2},
float_log::{Ln, Log2},
float_sqrt::{DivSqrt, Sqrt},
lowest_bigger_power_of_two_minus_one::LowestBiggerPowerOfTwoMinusOne,
max::Max,
min::Min,
sigmoid::Sigmoid,
zero::ZeroCircuit,
},
boolean::{
ed25519::{
Ed25519MXESign,
Ed25519Sign,
Ed25519Verify,
Ed25519VerifyingKeyFromSecretKey,
},
sha3::{SHA3_256, SHA3_512},
},
general::{conversion::ConversionCircuit, identity::IdentityCircuit},
traits::{arithmetic_circuit::ArithmeticCircuit, general_circuit::GeneralCircuit},
},
expressions::{
expr::{EvalValue, Expr},
field_expr::FieldExpr,
other_expr::OtherExpr,
},
global_value::{
field_array::FieldArray,
global_expr_store::with_global_expr_store_as_local,
value::FieldValue,
},
},
types::DOUBLE_PRECISION_MANTISSA,
utils::field::BaseField,
};
use serde::{Deserialize, Serialize};
/// Namespace holding compile-time instances of the parameterised arithmetic circuits.
///
/// `ArithmeticCircuitId::to_circuit` hands out `'static` references to these
/// constants. Every mantissa-parameterised circuit here is configured with
/// `DOUBLE_PRECISION_MANTISSA`.
struct StaticCircuits;

impl StaticCircuits {
    // Circuits constructed with a boolean flag.
    // NOTE(review): the meaning of `true` is defined by each constructor — confirm
    // there before changing it.
    const MIN: Min = Min::new(true);
    const MAX: Max = Max::new(true);
    const BITWISE_AND: BitwiseAnd = BitwiseAnd::new(true);

    // Logarithm and exponential circuits (`float_log` / `float_exp` modules).
    const LOG2: Log2 = Log2::new(DOUBLE_PRECISION_MANTISSA);
    const LN: Ln = Ln::new(DOUBLE_PRECISION_MANTISSA);
    const EXP2: Exp2 = Exp2::new(DOUBLE_PRECISION_MANTISSA);
    const EXP: Exp = Exp::new(DOUBLE_PRECISION_MANTISSA);

    // Single-parameter circuits.
    const SQRT: Sqrt = Sqrt::new(DOUBLE_PRECISION_MANTISSA);
    const SIGMOID: Sigmoid = Sigmoid::new(DOUBLE_PRECISION_MANTISSA);

    // Two-parameter circuits; both parameters use the same mantissa width.
    const DIV_SQRT: DivSqrt = DivSqrt::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
    const DIV: Div = Div::new(DOUBLE_PRECISION_MANTISSA, DOUBLE_PRECISION_MANTISSA);
}
/// Identifier of an arithmetic circuit.
///
/// Resolved to an implementation with `ArithmeticCircuitId::to_circuit` and applied
/// to values with `ArithmeticCircuitId::apply`.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize` and `Hash`. Some serde
/// formats encode unit variants by their index, so reordering or inserting variants
/// may break previously serialized data. Append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ArithmeticCircuitId {
    /// `Min` circuit (constructed as `Min::new(true)`).
    Min,
    /// `Max` circuit (constructed as `Max::new(true)`).
    Max,
    /// `Log2` circuit from `float_log`, at double-precision mantissa width.
    Log2,
    /// `Ln` circuit from `float_log`, at double-precision mantissa width.
    Ln,
    /// `Exp2` circuit from `float_exp`, at double-precision mantissa width.
    Exp2,
    /// `Exp` circuit from `float_exp`, at double-precision mantissa width.
    Exp,
    /// `Sqrt` circuit from `float_sqrt`, at double-precision mantissa width.
    Sqrt,
    /// `DivSqrt` circuit from `float_sqrt`, with both parameters at double-precision width.
    DivSqrt,
    /// `Div` circuit from `float_div`, with both parameters at double-precision width.
    Div,
    /// `BitwiseAnd` circuit (constructed as `BitwiseAnd::new(true)`).
    BitwiseAnd,
    /// `LowestBiggerPowerOfTwoMinusOne` circuit (used by value, no parameters).
    LowestBiggerPowerOfTwoMinusOne,
    /// `Sigmoid` circuit, at double-precision mantissa width.
    Sigmoid,
    /// `ZeroCircuit` (used by value, no parameters).
    Zero,
    /// `AbsCircuit` (used by value, no parameters).
    Abs,
    /// `IdentityCircuit` (used by value, no parameters).
    Identity,
}
impl ArithmeticCircuitId {
pub fn to_circuit<F: ActuallyUsedField>(self) -> &'static dyn ArithmeticCircuit<F> {
match self {
ArithmeticCircuitId::Min => &StaticCircuits::MIN,
ArithmeticCircuitId::Max => &StaticCircuits::MAX,
ArithmeticCircuitId::Log2 => &StaticCircuits::LOG2,
ArithmeticCircuitId::Ln => &StaticCircuits::LN,
ArithmeticCircuitId::Exp2 => &StaticCircuits::EXP2,
ArithmeticCircuitId::Exp => &StaticCircuits::EXP,
ArithmeticCircuitId::Sqrt => &StaticCircuits::SQRT,
ArithmeticCircuitId::DivSqrt => &StaticCircuits::DIV_SQRT,
ArithmeticCircuitId::Div => &StaticCircuits::DIV,
ArithmeticCircuitId::BitwiseAnd => &StaticCircuits::BITWISE_AND,
ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne => &LowestBiggerPowerOfTwoMinusOne,
ArithmeticCircuitId::Sigmoid => &StaticCircuits::SIGMOID,
ArithmeticCircuitId::Zero => &ZeroCircuit,
ArithmeticCircuitId::Abs => &AbsCircuit,
ArithmeticCircuitId::Identity => &IdentityCircuit,
}
}
pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
T::apply_arithmetic_circuit_id(args, self)
}
}
/// Identifier of a general (non-arithmetic) circuit.
///
/// Resolved to an implementation with `GeneralCircuitId::to_circuit`.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize` and `Hash`. Keep the
/// variant order stable, because some serde formats encode variants by index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GeneralCircuitId {
    /// `ConversionCircuit` from `general::conversion`.
    Conversion,
}
impl GeneralCircuitId {
    /// Returns the `'static` general-circuit implementation for this id.
    pub fn to_circuit(self) -> &'static dyn GeneralCircuit {
        let circuit: &'static dyn GeneralCircuit = match self {
            Self::Conversion => &ConversionCircuit,
        };
        circuit
    }
}
/// Identifier of a circuit evaluated over `BaseField`.
///
/// Covers the Ed25519 and SHA3 circuits, plus any arithmetic circuit wrapped in
/// `Arith`. Resolved with `BaseCircuitId::to_circuit` and parsed from a name with
/// `BaseCircuitId::new_from_str`.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize` and `Hash`. Keep the
/// variant order stable, because some serde formats encode variants by index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BaseCircuitId {
    /// `Ed25519Sign` circuit (name `"sign"`).
    Ed25519Sign,
    /// `Ed25519MXESign` circuit (name `"mxe-sign"`).
    Ed25519MXESign,
    /// `Ed25519Verify` circuit (name `"verify"`).
    Ed25519Verify,
    /// `Ed25519VerifyingKeyFromSecretKey` circuit (name `"verifying_key_from_secret_key"`).
    Ed25519VerifyingKeyFromSecretKey,
    /// `SHA3_256` circuit (name `"sha3-256"`).
    Sha3_256,
    /// `SHA3_512` circuit (name `"sha3-512"`).
    Sha3_512,
    /// An arithmetic circuit, instantiated at `BaseField`.
    Arith(ArithmeticCircuitId),
}
impl BaseCircuitId {
pub fn to_circuit(self) -> &'static dyn ArithmeticCircuit<BaseField> {
match self {
BaseCircuitId::Ed25519Sign => &Ed25519Sign,
BaseCircuitId::Ed25519MXESign => &Ed25519MXESign,
BaseCircuitId::Ed25519Verify => &Ed25519Verify,
BaseCircuitId::Ed25519VerifyingKeyFromSecretKey => &Ed25519VerifyingKeyFromSecretKey,
BaseCircuitId::Sha3_256 => &SHA3_256,
BaseCircuitId::Sha3_512 => &SHA3_512,
BaseCircuitId::Arith(a) => a.to_circuit(),
}
}
pub fn apply<T: CircuitArg>(self, args: Vec<T>) -> Vec<T> {
T::apply_base_circuit_id(args, self)
}
pub fn new_from_str(str: &str) -> Option<Self> {
let res = match str {
"abs" => BaseCircuitId::Arith(ArithmeticCircuitId::Abs),
"bitwise_and" => BaseCircuitId::Arith(ArithmeticCircuitId::BitwiseAnd),
"float_div" => BaseCircuitId::Arith(ArithmeticCircuitId::Div),
"float_exp" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp),
"float_exp2" => BaseCircuitId::Arith(ArithmeticCircuitId::Exp2),
"float_ln" => BaseCircuitId::Arith(ArithmeticCircuitId::Ln),
"float_log2" => BaseCircuitId::Arith(ArithmeticCircuitId::Log2),
"float_sqrt" => BaseCircuitId::Arith(ArithmeticCircuitId::Sqrt),
"identity" => BaseCircuitId::Arith(ArithmeticCircuitId::Identity),
"lowest_bigger_power_of_two_minus_one" => {
BaseCircuitId::Arith(ArithmeticCircuitId::LowestBiggerPowerOfTwoMinusOne)
}
"max" => BaseCircuitId::Arith(ArithmeticCircuitId::Max),
"min" => BaseCircuitId::Arith(ArithmeticCircuitId::Min),
"sigmoid" => BaseCircuitId::Arith(ArithmeticCircuitId::Sigmoid),
"sign" => BaseCircuitId::Ed25519Sign,
"mxe-sign" => BaseCircuitId::Ed25519MXESign,
"verify" => BaseCircuitId::Ed25519Verify,
"verifying_key_from_secret_key" => BaseCircuitId::Ed25519VerifyingKeyFromSecretKey,
"sha3-256" => BaseCircuitId::Sha3_256,
"sha3-512" => BaseCircuitId::Sha3_512,
"zero" => BaseCircuitId::Arith(ArithmeticCircuitId::Zero),
_ => return None,
};
Some(res)
}
}
/// A value type that arithmetic and base circuits can be applied to.
///
/// The implementations in this module either evaluate the circuit immediately
/// (concrete values) or record a symbolic expression node per circuit output
/// (symbolic values).
pub trait CircuitArg: Sized {
    /// Applies the arithmetic circuit identified by `circuit` to `args`.
    fn apply_arithmetic_circuit_id(args: Vec<Self>, circuit: ArithmeticCircuitId) -> Vec<Self>;
    /// Applies the base circuit identified by `circuit` to `args`.
    fn apply_base_circuit_id(args: Vec<Self>, circuit: BaseCircuitId) -> Vec<Self>;
}
/// Concrete field elements: the circuit is evaluated eagerly.
///
/// # Panics
/// Panics if the circuit's `eval` returns an error.
impl CircuitArg for BaseField {
    fn apply_arithmetic_circuit_id(inputs: Vec<Self>, id: ArithmeticCircuitId) -> Vec<Self> {
        let circuit = id.to_circuit::<BaseField>();
        circuit.eval(inputs).unwrap()
    }

    fn apply_base_circuit_id(inputs: Vec<Self>, id: BaseCircuitId) -> Vec<Self> {
        let circuit = id.to_circuit();
        circuit.eval(inputs).unwrap()
    }
}
impl CircuitArg for EvalValue {
fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
CircuitArg::apply_arithmetic_circuit_id(
v.into_iter()
.map(|x| BaseField::from(x.to_signed_number()))
.collect(),
c,
)
.into_iter()
.map(EvalValue::Base)
.collect()
}
fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
CircuitArg::apply_base_circuit_id(
v.into_iter()
.map(|x| BaseField::from(x.to_signed_number()))
.collect(),
c,
)
.into_iter()
.map(EvalValue::Base)
.collect()
}
}
/// Symbolic field values: applying a circuit records expression nodes instead of
/// evaluating anything.
///
/// The number of outputs is the length of the circuit's `bounds` when every input
/// is given `FieldBounds::All`. One node is created per output, in output order.
impl CircuitArg for FieldValue<BaseField> {
    /// Creates one `FieldExpr::SubCircuit` node per circuit output.
    ///
    /// Each node holds its own copy of the full input list plus its output index.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
        let all_bounds =
            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
        let n = c.to_circuit().bounds(all_bounds).len();
        (0..n)
            .map(|i| FieldValue::new(FieldExpr::SubCircuit(v.clone(), c, i)))
            .collect()
    }

    /// Creates one `OtherExpr::BaseArithmeticCircuit` node per circuit output.
    ///
    /// Each node is registered in the global expression store and references the
    /// inputs by id.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
        let n = c.to_circuit().bounds(all_bounds).len();
        // The input ids are the same for every output, so collect them once rather
        // than once per output node.
        let input_ids: Vec<usize> = v.iter().map(FieldValue::get_id).collect();
        (0..n)
            .map(|i| {
                FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
                    expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
                        input_ids.clone(),
                        c,
                        i,
                    )))
                }))
            })
            .collect()
    }
}
/// Symbolic arrays of `N` field values: a circuit is applied lane-wise.
///
/// For lane `j`, the circuit receives the `j`-th component of every input array.
/// Output `i` of that application becomes component `j` of the `i`-th returned
/// array. The number of outputs is the length of the circuit's `bounds` when every
/// input is given `FieldBounds::All`. Nodes are created output by output, and within
/// each output lane by lane in ascending order.
impl<const N: usize> CircuitArg for FieldArray<N, BaseField> {
    /// Creates one `FieldExpr::SubCircuit` node per (output, lane) pair.
    fn apply_arithmetic_circuit_id(v: Vec<Self>, c: ArithmeticCircuitId) -> Vec<Self> {
        let all_bounds =
            std::iter::repeat_n(FieldBounds::<BaseField>::All, v.len()).collect::<Vec<_>>();
        let n = c.to_circuit().bounds(all_bounds).len();
        // Transpose once: `lanes[j]` holds the `j`-th component of every input array.
        // It is reused for every output instead of being rebuilt per (output, lane).
        let lanes: Vec<Vec<FieldValue<BaseField>>> = (0..N)
            .map(|j| v.iter().copied().map(|x| x[j]).collect())
            .collect();
        (0..n)
            .map(|i| {
                // `from_fn` fills the array in ascending lane order, so nodes are
                // created in the same order as a sequential loop over `0..N`.
                let components: [FieldValue<BaseField>; N] = std::array::from_fn(|j| {
                    FieldValue::new(FieldExpr::SubCircuit(lanes[j].clone(), c, i))
                });
                FieldArray::from(components)
            })
            .collect::<Vec<Self>>()
    }

    /// Creates one `OtherExpr::BaseArithmeticCircuit` node per (output, lane) pair,
    /// each registered in the global expression store and referencing its inputs by id.
    fn apply_base_circuit_id(v: Vec<Self>, c: BaseCircuitId) -> Vec<Self> {
        let all_bounds = std::iter::repeat_n(FieldBounds::All, v.len()).collect::<Vec<_>>();
        let n = c.to_circuit().bounds(all_bounds).len();
        // Transpose once: `lane_ids[j]` holds the ids of the `j`-th component of every
        // input array.
        let lane_ids: Vec<Vec<usize>> = (0..N)
            .map(|j| v.iter().copied().map(|x| x[j].get_id()).collect())
            .collect();
        (0..n)
            .map(|i| {
                // `from_fn` fills the array in ascending lane order, so expressions are
                // added to the store in the same order as a sequential loop over `0..N`.
                let components: [FieldValue<BaseField>; N] = std::array::from_fn(|j| {
                    FieldValue::from_id(with_global_expr_store_as_local(|expr_store| {
                        expr_store.new_expr(Expr::Other(OtherExpr::BaseArithmeticCircuit(
                            lane_ids[j].clone(),
                            c,
                            i,
                        )))
                    }))
                });
                FieldArray::from(components)
            })
            .collect::<Vec<Self>>()
    }
}