arcis_compiler/core/expressions/
expr.rs1use crate::{
2 core::{
3 bounds::{BoolBounds, Bounds, CurveBounds, FieldBounds},
4 expressions::{
5 bit_expr::BitExpr,
6 conversion_expr::{ConversionBounds, ConversionExpr, ConversionValue},
7 curve_expr::CurveExpr,
8 domain::DomainElement,
9 field_expr::{FieldExpr, InputId},
10 macro_uses::DefaultFiller,
11 other_expr::OtherExpr,
12 },
13 },
14 traits::ToMontgomery,
15 utils::{
16 curve_point::CurvePoint,
17 field::{BaseField, ScalarField},
18 number::Number,
19 used_field::UsedField,
20 },
21};
22use serde::{Deserialize, Serialize};
23use std::{cell::Cell, hash::Hash};
24
25#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
27pub enum Expr<Scalar: Clone, Bit: Clone = Scalar, Base: Clone = Scalar, Curve: Clone = Scalar> {
28 Scalar(FieldExpr<ScalarField, Scalar>),
31 Bit(BitExpr<Bit>),
32 ScalarConversion(ConversionExpr<ScalarField, Scalar, Bit>),
33 Base(FieldExpr<BaseField, Base>),
34 BaseConversion(ConversionExpr<BaseField, Base, Bit>),
35 Curve(CurveExpr<Curve, Scalar>),
36 Other(OtherExpr<Scalar, Base, Curve>),
37}
38
/// Explanation attached to an [`EvalFailure::UndefinedBehavior`].
///
/// Instances are built by [`EvalFailure::err_ub`] and [`EvalFailure::ub`].
#[derive(Clone, Debug, PartialEq)]
pub struct UndefinedBehavior {
    /// Human-readable description of what triggered the undefined behavior.
    reason: String,
}
44
45#[derive(Clone, Debug, PartialEq)]
47pub enum EvalFailure {
48 UndefinedBehavior(UndefinedBehavior),
51 BoundsNotRespected(String),
53 ImpossibleGate(String),
57 WrongType(String),
60}
61
62pub type EvalValue = DomainElement<bool, ScalarField, BaseField, CurvePoint>;
64
65impl EvalValue {
66 pub fn to_signed_number(self) -> Number {
67 match self {
68 EvalValue::Bit(b) => Number::from(b),
69 EvalValue::Scalar(n) => n.to_signed_number(),
70 EvalValue::Base(n) => n.to_signed_number(),
71 EvalValue::Curve(n) => n.to_montgomery(false).0.to_signed_number(),
72 }
73 }
74 pub fn to_curve(self) -> CurvePoint {
75 match self {
76 EvalValue::Curve(point) => point,
77 _ => panic!("cannot convert to CurvePoint"),
78 }
79 }
80 pub fn arcis(a: BaseField) -> Self {
81 EvalValue::Base(a)
82 }
83}
84
85pub type EvalResult = Result<EvalValue, EvalFailure>;
86
87impl EvalFailure {
88 pub fn err_ub<T>(reason: impl Into<String>) -> Result<T, EvalFailure> {
89 Err(EvalFailure::UndefinedBehavior(UndefinedBehavior {
90 reason: reason.into(),
91 }))
92 }
93 pub fn err_bounds<T>(reason: impl Into<String>) -> Result<T, EvalFailure> {
94 Err(EvalFailure::BoundsNotRespected(reason.into()))
95 }
96 pub fn ub(reason: impl Into<String>) -> EvalFailure {
97 EvalFailure::UndefinedBehavior(UndefinedBehavior {
98 reason: reason.into(),
99 })
100 }
101 pub fn err_imp<T>(reason: impl Into<String>) -> Result<T, EvalFailure> {
102 Err(EvalFailure::ImpossibleGate(reason.into()))
103 }
104}
105
106impl Expr<ScalarField, bool, BaseField, CurvePoint> {
108 pub fn eval(self) -> EvalResult {
110 let val: EvalValue = match self {
111 Expr::Scalar(expr) => EvalValue::Scalar(expr.eval()?),
112 Expr::Bit(expr) => EvalValue::Bit(expr.eval()?),
113 Expr::ScalarConversion(expr) => match expr.eval()? {
114 ConversionValue::Bit(b) => EvalValue::Bit(b),
115 ConversionValue::Scalar(s) => EvalValue::Scalar(s),
116 },
117 Expr::Base(expr) => EvalValue::Base(expr.eval()?),
118 Expr::BaseConversion(expr) => match expr.eval()? {
119 ConversionValue::Bit(b) => EvalValue::Bit(b),
120 ConversionValue::Scalar(s) => EvalValue::Base(s),
121 },
122 Expr::Curve(expr) => EvalValue::Curve(expr.eval()?),
123 Expr::Other(expr) => expr.eval()?,
124 };
125 Ok(val)
126 }
127}
128
129impl Expr<bool> {
130 pub fn is_plaintext(&self) -> bool {
131 match self {
132 Expr::Scalar(e) => e.is_plaintext(),
133 Expr::Bit(e) => e.is_plaintext(),
134 Expr::ScalarConversion(e) => e.is_plaintext(),
135 Expr::Base(e) => e.is_plaintext(),
136 Expr::BaseConversion(e) => e.is_plaintext(),
137 Expr::Curve(e) => e.is_plaintext(),
138 Expr::Other(e) => e.is_plaintext(),
139 }
140 }
141}
142
143impl<T: Clone, U: Clone, V: Clone, W: Clone> Expr<T, U, V, W> {
144 pub fn is_boolean(&self) -> bool {
145 match self {
146 Expr::Scalar(_) => false,
147 Expr::Bit(_) => true,
148 Expr::ScalarConversion(e) => e.is_boolean(),
149 Expr::Base(_) => false,
150 Expr::BaseConversion(e) => e.is_boolean(),
151 Expr::Curve(_) => false,
152 Expr::Other(_) => false,
153 }
154 }
155 pub fn is_eval_deterministic_fn_from_deps(&self) -> bool {
156 match self {
157 Expr::Scalar(e) => e.is_eval_deterministic_fn_from_deps(),
158 Expr::Bit(e) => e.is_eval_deterministic_fn_from_deps(),
159 Expr::ScalarConversion(e) => e.is_eval_deterministic_fn_from_deps(),
160 Expr::Base(e) => e.is_eval_deterministic_fn_from_deps(),
161 Expr::BaseConversion(e) => e.is_eval_deterministic_fn_from_deps(),
162 Expr::Curve(e) => e.is_eval_deterministic_fn_from_deps(),
163 Expr::Other(e) => e.is_eval_deterministic_fn_from_deps(),
164 }
165 }
166 pub fn get_input(&self) -> Option<InputId> {
167 match self {
168 Expr::Scalar(e) => e.get_input(),
169 Expr::Bit(e) => e.get_input(),
170 Expr::ScalarConversion(e) => e.get_input(),
171 Expr::Base(e) => e.get_input(),
172 Expr::BaseConversion(e) => e.get_input(),
173 Expr::Curve(e) => e.get_input(),
174 Expr::Other(e) => e.get_input(),
175 }
176 }
177 pub fn get_input_name(&self) -> &str {
178 match self {
179 Expr::Scalar(e) => e.get_input_name(),
180 Expr::Bit(e) => e.get_input_name(),
181 Expr::ScalarConversion(e) => e.get_input_name(),
182 Expr::Base(e) => e.get_input_name(),
183 Expr::BaseConversion(e) => e.get_input_name(),
184 Expr::Curve(e) => e.get_input_name(),
185 Expr::Other(e) => e.get_input_name(),
186 }
187 }
188 pub fn get_is_input_already_optimized_out(&self) -> Option<&Cell<bool>> {
189 match self {
190 Expr::Scalar(e) => e.get_is_input_already_optimized_out(),
191 Expr::Bit(e) => e.get_is_input_already_optimized_out(),
192 Expr::ScalarConversion(e) => e.get_is_input_already_optimized_out(),
193 Expr::Base(e) => e.get_is_input_already_optimized_out(),
194 Expr::BaseConversion(e) => e.get_is_input_already_optimized_out(),
195 Expr::Curve(e) => e.get_is_input_already_optimized_out(),
196 Expr::Other(e) => e.get_is_input_already_optimized_out(),
197 }
198 }
199 pub fn result_domain(&self) -> DomainElement<(), (), (), ()> {
200 match self {
201 Expr::Scalar(_) => DomainElement::Scalar(()),
202 Expr::Bit(_) => DomainElement::Bit(()),
203 Expr::ScalarConversion(_) => {
204 if self.is_boolean() {
205 DomainElement::Bit(())
206 } else {
207 DomainElement::Scalar(())
208 }
209 }
210 Expr::Base(_) => DomainElement::Base(()),
211 Expr::BaseConversion(_) => {
212 if self.is_boolean() {
213 DomainElement::Bit(())
214 } else {
215 DomainElement::Base(())
216 }
217 }
218 Expr::Curve(_) => DomainElement::Curve(()),
219 Expr::Other(_) => self
220 .clone()
221 .apply_2(&mut DefaultFiller::default())
222 .eval()
223 .expect("Eval of defaults failed on OtherExpr")
224 .to_domain(),
225 }
226 }
227}
228
229impl Expr<FieldBounds<ScalarField>, BoolBounds, FieldBounds<BaseField>, CurveBounds> {
230 pub fn bounds(self) -> Bounds {
231 match self {
232 Expr::Scalar(e) => Bounds::Scalar(e.bounds()),
233 Expr::Bit(e) => Bounds::Bit(e.bounds()),
234 Expr::ScalarConversion(e) => match e.bounds() {
235 ConversionBounds::Bit(b) => Bounds::Bit(b),
236 ConversionBounds::Scalar(b) => Bounds::Scalar(b),
237 },
238 Expr::Base(e) => Bounds::Base(e.bounds()),
239 Expr::BaseConversion(e) => match e.bounds() {
240 ConversionBounds::Bit(b) => Bounds::Bit(b),
241 ConversionBounds::Scalar(b) => Bounds::Base(b),
242 },
243 Expr::Curve(e) => Bounds::Curve(e.bounds()),
244 Expr::Other(e) => e.bounds(),
245 }
246 }
247}
248pub fn expr_true() -> Expr<usize> {
249 Expr::Bit(BitExpr::Val(true))
250}
251pub fn expr_false() -> Expr<usize> {
252 Expr::Bit(BitExpr::Val(false))
253}