use crate::{
core::{
bounds::{BoolBounds, Bounds, CurveBounds, FieldBounds},
expressions::{
bit_expr::BitExpr,
conversion_expr::{ConversionBounds, ConversionExpr, ConversionValue},
curve_expr::CurveExpr,
domain::DomainElement,
field_expr::{FieldExpr, InputId},
macro_uses::DefaultFiller,
other_expr::OtherExpr,
},
},
traits::ToMontgomery,
utils::{
curve_point::CurvePoint,
field::{BaseField, ScalarField},
number::Number,
used_field::UsedField,
},
};
use serde::{Deserialize, Serialize};
use std::{cell::Cell, hash::Hash};
/// An expression tree node, partitioned by the domain it operates in.
///
/// Each type parameter is the payload carried by sub-expressions of one domain:
/// `Scalar` for scalar-field nodes, `Bit` for boolean nodes, `Base` for
/// base-field nodes and `Curve` for curve nodes. `Bit`, `Base` and `Curve`
/// default to `Scalar`, so `Expr<T>` uses a single payload type everywhere.
/// This module instantiates it with concrete values (for `eval`) and with
/// bounds (for `bounds`).
///
/// NOTE(review): `Hash` and `Serialize`/`Deserialize` are derived, so the
/// variant order is part of the hashed representation and of the serialized
/// form for index-based serde formats — avoid reordering variants.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Expr<Scalar: Clone, Bit: Clone = Scalar, Base: Clone = Scalar, Curve: Clone = Scalar> {
    /// Scalar-field expression.
    Scalar(FieldExpr<ScalarField, Scalar>),
    /// Boolean expression.
    Bit(BitExpr<Bit>),
    /// Conversion involving the scalar field and bits; its result is either a
    /// bit or a scalar-field element (see `eval` / `bounds`).
    ScalarConversion(ConversionExpr<ScalarField, Scalar, Bit>),
    /// Base-field expression.
    Base(FieldExpr<BaseField, Base>),
    /// Conversion involving the base field and bits; its result is either a
    /// bit or a base-field element (see `eval` / `bounds`).
    BaseConversion(ConversionExpr<BaseField, Base, Bit>),
    /// Curve expression. It also carries `Scalar` payloads
    /// (presumably for scalar operands — confirm in `CurveExpr`).
    Curve(CurveExpr<Curve, Scalar>),
    /// Any remaining expression kind. Its result domain is not fixed by the
    /// variant; `result_domain` determines it by evaluating with defaults.
    Other(OtherExpr<Scalar, Base, Curve>),
}
/// Details attached to [`EvalFailure::UndefinedBehavior`].
///
/// Holds the reason string supplied by whoever reported the failure. The field
/// is private, so outside this module values are built through
/// `EvalFailure::ub` or `EvalFailure::err_ub`.
#[derive(Clone, Debug, PartialEq)]
pub struct UndefinedBehavior {
    /// Caller-supplied description of the failure.
    reason: String,
}
/// Reasons an expression evaluation can fail; the error half of [`EvalResult`].
#[derive(Clone, Debug, PartialEq)]
pub enum EvalFailure {
    /// Undefined behavior was reported; built via `ub` / `err_ub`.
    UndefinedBehavior(UndefinedBehavior),
    /// A bounds requirement was not met; built via `err_bounds`.
    BoundsNotRespected(String),
    /// An impossible gate was reported; built via `err_imp`.
    ImpossibleGate(String),
    /// A value had an unexpected type.
    /// NOTE(review): no helper in this module constructs this variant —
    /// confirm where it is produced.
    WrongType(String),
}
/// A concrete evaluated value tagged by its domain: a `bool` bit, a
/// scalar-field element, a base-field element or a curve point.
pub type EvalValue = DomainElement<bool, ScalarField, BaseField, CurvePoint>;
impl EvalValue {
    /// Converts this value into a signed [`Number`].
    ///
    /// - Bits go through `Number::from(bool)`.
    /// - Scalar and base field elements use their own `to_signed_number`.
    /// - Curve points use the first component of `to_montgomery(false)`.
    pub fn to_signed_number(self) -> Number {
        match self {
            EvalValue::Scalar(scalar) => scalar.to_signed_number(),
            EvalValue::Base(base) => base.to_signed_number(),
            EvalValue::Bit(bit) => Number::from(bit),
            EvalValue::Curve(point) => {
                // Only the first Montgomery component is used for the numeric view.
                let montgomery = point.to_montgomery(false);
                montgomery.0.to_signed_number()
            }
        }
    }

    /// Extracts the curve point held by this value.
    ///
    /// # Panics
    /// Panics if the value is not in the curve domain.
    pub fn to_curve(self) -> CurvePoint {
        if let EvalValue::Curve(point) = self {
            point
        } else {
            panic!("cannot convert to CurvePoint")
        }
    }

    /// Wraps a base-field element as an [`EvalValue::Base`].
    pub fn arcis(a: BaseField) -> Self {
        Self::Base(a)
    }
}
/// Outcome of evaluating an expression: the resulting [`EvalValue`], or the
/// [`EvalFailure`] that stopped evaluation.
pub type EvalResult = Result<EvalValue, EvalFailure>;
impl EvalFailure {
    /// Builds an [`EvalFailure::UndefinedBehavior`] carrying `reason`
    /// (the bare failure, not wrapped in `Err`).
    pub fn ub(reason: impl Into<String>) -> EvalFailure {
        let reason = reason.into();
        Self::UndefinedBehavior(UndefinedBehavior { reason })
    }

    /// Returns `Err` holding an [`EvalFailure::UndefinedBehavior`] with `reason`.
    pub fn err_ub<T>(reason: impl Into<String>) -> Result<T, EvalFailure> {
        Err(Self::ub(reason))
    }

    /// Returns `Err` holding an [`EvalFailure::BoundsNotRespected`] with `reason`.
    pub fn err_bounds<T>(reason: impl Into<String>) -> Result<T, EvalFailure> {
        let reason = reason.into();
        Err(Self::BoundsNotRespected(reason))
    }

    /// Returns `Err` holding an [`EvalFailure::ImpossibleGate`] with `reason`.
    pub fn err_imp<T>(reason: impl Into<String>) -> Result<T, EvalFailure> {
        let reason = reason.into();
        Err(Self::ImpossibleGate(reason))
    }
}
impl Expr<ScalarField, bool, BaseField, CurvePoint> {
    /// Evaluates this expression on concrete values.
    ///
    /// Each variant delegates to its inner expression's `eval`.
    /// - A scalar conversion yields a `Bit` or `Scalar` value.
    /// - A base conversion yields a `Bit` or `Base` value.
    /// - `Other` produces an [`EvalValue`] directly.
    ///
    /// # Errors
    /// Propagates the failure reported by the inner expression's `eval`.
    pub fn eval(self) -> EvalResult {
        Ok(match self {
            Expr::Bit(inner) => EvalValue::Bit(inner.eval()?),
            Expr::Scalar(inner) => EvalValue::Scalar(inner.eval()?),
            Expr::Base(inner) => EvalValue::Base(inner.eval()?),
            Expr::Curve(inner) => EvalValue::Curve(inner.eval()?),
            Expr::ScalarConversion(conversion) => match conversion.eval()? {
                ConversionValue::Scalar(value) => EvalValue::Scalar(value),
                ConversionValue::Bit(bit) => EvalValue::Bit(bit),
            },
            Expr::BaseConversion(conversion) => match conversion.eval()? {
                // A "scalar" result of a base-field conversion is a base element.
                ConversionValue::Scalar(value) => EvalValue::Base(value),
                ConversionValue::Bit(bit) => EvalValue::Bit(bit),
            },
            Expr::Other(inner) => inner.eval()?,
        })
    }
}
impl Expr<bool> {
    /// Reports whether this expression is plaintext.
    ///
    /// The answer comes from the wrapped expression's own `is_plaintext`.
    /// Here every payload type is `bool` (presumably a per-node plaintext
    /// flag — confirm with the inner types).
    pub fn is_plaintext(&self) -> bool {
        match self {
            Expr::Bit(inner) => inner.is_plaintext(),
            Expr::Scalar(inner) => inner.is_plaintext(),
            Expr::Base(inner) => inner.is_plaintext(),
            Expr::Curve(inner) => inner.is_plaintext(),
            Expr::ScalarConversion(inner) => inner.is_plaintext(),
            Expr::BaseConversion(inner) => inner.is_plaintext(),
            Expr::Other(inner) => inner.is_plaintext(),
        }
    }
}
impl<T: Clone, U: Clone, V: Clone, W: Clone> Expr<T, U, V, W> {
    /// Returns `true` when this expression produces a single bit.
    ///
    /// `Bit` nodes always do, and conversion nodes ask their inner conversion.
    /// Every other variant is treated as non-boolean.
    pub fn is_boolean(&self) -> bool {
        match self {
            Expr::Bit(_) => true,
            Expr::ScalarConversion(conv) => conv.is_boolean(),
            Expr::BaseConversion(conv) => conv.is_boolean(),
            Expr::Scalar(_) | Expr::Base(_) | Expr::Curve(_) | Expr::Other(_) => false,
        }
    }

    /// Delegates to the wrapped expression's `is_eval_deterministic_fn_from_deps`.
    pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
        match self {
            Expr::Bit(inner) => inner.is_eval_deterministic_fn_from_deps(),
            Expr::Scalar(inner) => inner.is_eval_deterministic_fn_from_deps(),
            Expr::Base(inner) => inner.is_eval_deterministic_fn_from_deps(),
            Expr::Curve(inner) => inner.is_eval_deterministic_fn_from_deps(),
            Expr::ScalarConversion(inner) => inner.is_eval_deterministic_fn_from_deps(),
            Expr::BaseConversion(inner) => inner.is_eval_deterministic_fn_from_deps(),
            Expr::Other(inner) => inner.is_eval_deterministic_fn_from_deps(),
        }
    }

    /// Returns the input id reported by the wrapped expression, if any.
    pub fn get_input(&self) -> Option<InputId> {
        match self {
            Expr::Bit(inner) => inner.get_input(),
            Expr::Scalar(inner) => inner.get_input(),
            Expr::Base(inner) => inner.get_input(),
            Expr::Curve(inner) => inner.get_input(),
            Expr::ScalarConversion(inner) => inner.get_input(),
            Expr::BaseConversion(inner) => inner.get_input(),
            Expr::Other(inner) => inner.get_input(),
        }
    }

    /// Returns the input name reported by the wrapped expression.
    pub fn get_input_name(&self) -> &str {
        match self {
            Expr::Bit(inner) => inner.get_input_name(),
            Expr::Scalar(inner) => inner.get_input_name(),
            Expr::Base(inner) => inner.get_input_name(),
            Expr::Curve(inner) => inner.get_input_name(),
            Expr::ScalarConversion(inner) => inner.get_input_name(),
            Expr::BaseConversion(inner) => inner.get_input_name(),
            Expr::Other(inner) => inner.get_input_name(),
        }
    }

    /// Returns the wrapped expression's "input already optimized out" flag
    /// cell, if it has one.
    pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
        match self {
            Expr::Bit(inner) => inner.get_is_input_already_optimized_out(),
            Expr::Scalar(inner) => inner.get_is_input_already_optimized_out(),
            Expr::Base(inner) => inner.get_is_input_already_optimized_out(),
            Expr::Curve(inner) => inner.get_is_input_already_optimized_out(),
            Expr::ScalarConversion(inner) => inner.get_is_input_already_optimized_out(),
            Expr::BaseConversion(inner) => inner.get_is_input_already_optimized_out(),
            Expr::Other(inner) => inner.get_is_input_already_optimized_out(),
        }
    }

    /// Returns the domain (bit, scalar, base or curve) of this expression's
    /// result, without any value attached.
    ///
    /// Conversions report `Bit` when boolean; otherwise they report their
    /// field (`Scalar` or `Base`).
    ///
    /// # Panics
    /// For `Other` nodes the domain is found by cloning the expression,
    /// filling it with `DefaultFiller` and evaluating it. The method panics
    /// if that evaluation fails.
    pub fn result_domain(&self) -> DomainElement<(), (), (), ()> {
        match self {
            Expr::Bit(_) => DomainElement::Bit(()),
            Expr::Scalar(_) => DomainElement::Scalar(()),
            Expr::Base(_) => DomainElement::Base(()),
            Expr::Curve(_) => DomainElement::Curve(()),
            Expr::ScalarConversion(conv) if conv.is_boolean() => DomainElement::Bit(()),
            Expr::ScalarConversion(_) => DomainElement::Scalar(()),
            Expr::BaseConversion(conv) if conv.is_boolean() => DomainElement::Bit(()),
            Expr::BaseConversion(_) => DomainElement::Base(()),
            Expr::Other(_) => {
                // The variant alone does not fix the domain, so evaluate a
                // default-filled copy and read the domain of the result.
                let defaulted = self.clone().apply_2(&mut DefaultFiller::default());
                defaulted
                    .eval()
                    .expect("Eval of defaults failed on OtherExpr")
                    .to_domain()
            }
        }
    }
}
impl Expr<FieldBounds<ScalarField>, BoolBounds, FieldBounds<BaseField>, CurveBounds> {
    /// Computes the [`Bounds`] of this expression's result.
    ///
    /// Each variant delegates to its inner expression's `bounds`.
    /// - Scalar conversions report `Bit` or `Scalar` bounds.
    /// - Base conversions report `Bit` or `Base` bounds.
    /// - `Other` returns [`Bounds`] directly.
    pub fn bounds(self) -> Bounds {
        match self {
            Expr::Bit(inner) => Bounds::Bit(inner.bounds()),
            Expr::Scalar(inner) => Bounds::Scalar(inner.bounds()),
            Expr::Base(inner) => Bounds::Base(inner.bounds()),
            Expr::Curve(inner) => Bounds::Curve(inner.bounds()),
            Expr::ScalarConversion(conversion) => match conversion.bounds() {
                ConversionBounds::Scalar(field) => Bounds::Scalar(field),
                ConversionBounds::Bit(bit) => Bounds::Bit(bit),
            },
            Expr::BaseConversion(conversion) => match conversion.bounds() {
                // "Scalar" bounds of a base-field conversion are base-field bounds.
                ConversionBounds::Scalar(field) => Bounds::Base(field),
                ConversionBounds::Bit(bit) => Bounds::Bit(bit),
            },
            Expr::Other(inner) => inner.bounds(),
        }
    }
}
/// Returns the constant bit `true` as an `Expr<usize>`.
pub fn expr_true() -> Expr<usize> {
    let literal = BitExpr::Val(true);
    Expr::Bit(literal)
}
/// Returns the constant bit `false` as an `Expr<usize>`.
pub fn expr_false() -> Expr<usize> {
    let literal = BitExpr::Val(false);
    Expr::Bit(literal)
}