Skip to main content

haloumi_ir/expr/
aexpr.rs

1//! Structs for handling arithmetic expressions.
2
3use crate::{
4    expr::{ExprProperties, ExprProperty},
5    printer::IRPrintable,
6    traits::{Canonicalize, ConstantFolding, Evaluate},
7};
8use eqv::{EqvRelation, equiv};
9use haloumi_core::{eqv::SymbolicEqv, felt::Felt, slot::Slot};
10use haloumi_lowering::{ExprLowering, lowerable::LowerableExpr};
11use std::fmt::Write;
12use std::{
13    convert::Infallible,
14    ops::{Add, Mul, Neg},
15};
16
17/// Represents an arithmetic expression.
18#[derive(PartialEq, Eq, Clone, Debug)]
19pub struct IRAexpr(pub(crate) IRAexprImpl);
20
21#[derive(PartialEq, Eq, Clone)]
22pub(crate) enum IRAexprImpl {
23    /// Constant value.
24    Constant(Felt),
25    /// IO element of the circuit; inputs, outputs, cells, etc.
26    IO(Slot),
27    /// Represents the negation of the inner expression.
28    Negated(Box<IRAexpr>),
29    /// Represents the sum of the inner expressions.
30    Sum(Box<IRAexpr>, Box<IRAexpr>),
31    /// Represents the product of the inner expresions.
32    Product(Box<IRAexpr>, Box<IRAexpr>),
33}
34
35impl IRAexpr {
36    /// Creates a constant expression.
37    pub fn constant(felt: Felt) -> Self {
38        Self(IRAexprImpl::Constant(felt))
39    }
40
41    /// Creates an expression pointing to a slot.
42    pub fn slot(s: impl Into<Slot>) -> Self {
43        Self(IRAexprImpl::IO(s.into()))
44    }
45
46    /// Maps the IO in-place.
47    pub fn try_map_io<E>(&mut self, f: &impl Fn(&mut Slot) -> Result<(), E>) -> Result<(), E> {
48        match &mut self.0 {
49            IRAexprImpl::IO(func_io) => f(func_io),
50            IRAexprImpl::Negated(expr) => expr.try_map_io(f),
51            IRAexprImpl::Sum(lhs, rhs) => {
52                lhs.try_map_io(f)?;
53                rhs.try_map_io(f)
54            }
55            IRAexprImpl::Product(lhs, rhs) => {
56                lhs.try_map_io(f)?;
57                rhs.try_map_io(f)
58            }
59            _ => Ok(()),
60        }
61    }
62}
63
64impl Neg for IRAexpr {
65    type Output = Self;
66
67    fn neg(self) -> Self::Output {
68        Self(IRAexprImpl::Negated(Box::new(self)))
69    }
70}
71
72impl Add for IRAexpr {
73    type Output = Self;
74
75    fn add(self, rhs: Self) -> Self::Output {
76        Self(IRAexprImpl::Sum(Box::new(self), Box::new(rhs)))
77    }
78}
79
80impl Mul for IRAexpr {
81    type Output = Self;
82
83    fn mul(self, rhs: Self) -> Self::Output {
84        Self(IRAexprImpl::Product(Box::new(self), Box::new(rhs)))
85    }
86}
87
88impl From<Felt> for IRAexpr {
89    fn from(value: Felt) -> Self {
90        Self(IRAexprImpl::Constant(value))
91    }
92}
93
94impl From<Slot> for IRAexpr {
95    fn from(value: Slot) -> Self {
96        Self(IRAexprImpl::IO(value))
97    }
98}
99
100impl Evaluate<Option<Felt>> for IRAexpr {
101    fn evaluate(&self) -> Option<Felt> {
102        match &self.0 {
103            IRAexprImpl::Constant(felt) => Some(*felt),
104            IRAexprImpl::IO(_) => None,
105            IRAexprImpl::Negated(expr) => Evaluate::<Option<Felt>>::evaluate(expr).map(|f| -f),
106            IRAexprImpl::Sum(lhs, rhs) => Evaluate::<Option<Felt>>::evaluate(lhs)
107                .zip(Evaluate::<Option<Felt>>::evaluate(rhs))
108                .map(|(lhs, rhs)| lhs + rhs),
109            IRAexprImpl::Product(lhs, rhs) => Evaluate::<Option<Felt>>::evaluate(lhs)
110                .zip(Evaluate::<Option<Felt>>::evaluate(rhs))
111                .map(|(lhs, rhs)| lhs * rhs),
112        }
113    }
114}
115
116impl Evaluate<ExprProperties> for IRAexpr {
117    fn evaluate(&self) -> ExprProperties {
118        match &self.0 {
119            IRAexprImpl::Constant(_) => ExprProperty::Const.into(),
120            IRAexprImpl::IO(_) => Default::default(),
121            IRAexprImpl::Negated(expr) => expr.evaluate(),
122            IRAexprImpl::Sum(lhs, rhs) | IRAexprImpl::Product(lhs, rhs) => {
123                Evaluate::<ExprProperties>::evaluate(lhs)
124                    & Evaluate::<ExprProperties>::evaluate(rhs)
125            }
126        }
127    }
128}
129
130impl ConstantFolding for IRAexpr {
131    type T = Felt;
132
133    type Error = Infallible;
134
135    fn constant_fold(&mut self) -> Result<(), Self::Error> {
136        match &mut self.0 {
137            IRAexprImpl::Constant(_) => {}
138            IRAexprImpl::IO(_) => {}
139            IRAexprImpl::Negated(expr) => {
140                expr.constant_fold()?;
141                if let Some(f) = expr.const_value() {
142                    *self = (-f).into();
143                }
144            }
145
146            IRAexprImpl::Sum(lhs, rhs) => {
147                lhs.constant_fold()?;
148                rhs.constant_fold()?;
149
150                match (lhs.const_value(), rhs.const_value()) {
151                    (Some(lhs), Some(rhs)) => {
152                        *self = Self(IRAexprImpl::Constant(lhs + rhs));
153                    }
154                    (None, Some(rhs)) if rhs == 0usize => {
155                        *self = (**lhs).clone();
156                    }
157                    (Some(lhs), None) if lhs == 0usize => {
158                        *self = (**rhs).clone();
159                    }
160                    _ => {}
161                }
162            }
163            IRAexprImpl::Product(lhs, rhs) => {
164                lhs.constant_fold()?;
165                rhs.constant_fold()?;
166                match (lhs.const_value(), rhs.const_value()) {
167                    (Some(lhs), Some(rhs)) => {
168                        *self = (lhs * rhs).into();
169                    }
170                    // (* 1 X) => X
171                    (None, Some(rhs)) if rhs == 1usize => {
172                        *self = (**lhs).clone();
173                    }
174                    (Some(lhs), None) if lhs == 1usize => {
175                        *self = (**rhs).clone();
176                    }
177                    // (* 0 X) => 0
178                    (None, Some(rhs)) if rhs == 0usize => {
179                        *self = rhs.into();
180                    }
181                    (Some(lhs), None) if lhs == 0usize => {
182                        *self = lhs.into();
183                    }
184                    // (* -1 X) => -X
185                    (None, Some(rhs)) if rhs.is_minus_one() => {
186                        *self = Self(IRAexprImpl::Negated(lhs.clone()));
187                    }
188                    (Some(lhs), None) if lhs.is_minus_one() => {
189                        *self = Self(IRAexprImpl::Negated(rhs.clone()));
190                    }
191                    _ => {}
192                }
193            }
194        }
195        Ok(())
196    }
197
198    /// Returns `Some(_)` if the expression is a constant value. None otherwise.
199    fn const_value(&self) -> Option<Felt> {
200        match &self.0 {
201            IRAexprImpl::Constant(f) => Some(*f),
202            _ => None,
203        }
204    }
205}
206
207impl IRAexpr {
208    /// Returns the inner element of the expression if it matches [`IRAexprImpl::Negated`].
209    fn negated_inner(&self) -> Option<&IRAexpr> {
210        match &self.0 {
211            IRAexprImpl::Negated(inner) => Some(inner),
212            _ => None,
213        }
214    }
215}
216
217impl Canonicalize for IRAexpr {
218    fn canonicalize(&mut self) {
219        match &mut self.0 {
220            IRAexprImpl::Constant(_) => {}
221            IRAexprImpl::IO(_) => {}
222            IRAexprImpl::Negated(expr) => {
223                expr.canonicalize();
224                // (- (- X)) => X
225                if let Some(inner) = expr.negated_inner() {
226                    *self = inner.clone();
227                }
228            }
229            IRAexprImpl::Sum(_, _) => todo!(),
230            IRAexprImpl::Product(_, _) => todo!(),
231        };
232    }
233}
234
235impl std::fmt::Debug for IRAexprImpl {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        match self {
238            Self::Constant(arg0) => write!(f, "{arg0:?}"),
239            Self::IO(arg0) => write!(f, "{arg0:?}"),
240            Self::Negated(arg0) => write!(f, "(- {arg0:?})"),
241            Self::Sum(arg0, arg1) => write!(f, "(+ {arg0:?} {arg1:?})"),
242            Self::Product(arg0, arg1) => write!(f, "(* {arg0:?} {arg1:?})"),
243        }
244    }
245}
246
247impl EqvRelation<IRAexpr> for SymbolicEqv {
248    /// Two arithmetic expressions are equivalent if they are structurally equal, constant values
249    /// equal and variables are equivalent.
250    fn equivalent(lhs: &IRAexpr, rhs: &IRAexpr) -> bool {
251        match (&lhs.0, &rhs.0) {
252            (IRAexprImpl::Constant(lhs), IRAexprImpl::Constant(rhs)) => lhs == rhs,
253            (IRAexprImpl::IO(lhs), IRAexprImpl::IO(rhs)) => equiv!(Self | lhs, rhs),
254            (IRAexprImpl::Negated(lhs), IRAexprImpl::Negated(rhs)) => equiv!(Self | lhs, rhs),
255            (IRAexprImpl::Sum(lhs0, lhs1), IRAexprImpl::Sum(rhs0, rhs1)) => {
256                equiv!(Self | lhs0, rhs0) && equiv!(Self | lhs1, rhs1)
257            }
258            (IRAexprImpl::Product(lhs0, lhs1), IRAexprImpl::Product(rhs0, rhs1)) => {
259                equiv!(Self | lhs0, rhs0) && equiv!(Self | lhs1, rhs1)
260            }
261            _ => false,
262        }
263    }
264}
265
266impl LowerableExpr for IRAexpr {
267    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<L::CellOutput>
268    where
269        L: ExprLowering + ?Sized,
270    {
271        match self.0 {
272            IRAexprImpl::Constant(f) => l.lower_constant(f),
273            IRAexprImpl::IO(io) => l.lower_funcio(io),
274            IRAexprImpl::Negated(expr) => l.lower_neg(&expr.lower(l)?),
275            IRAexprImpl::Sum(lhs, rhs) => l.lower_sum(&lhs.lower(l)?, &rhs.lower(l)?),
276            IRAexprImpl::Product(lhs, rhs) => l.lower_product(&lhs.lower(l)?, &rhs.lower(l)?),
277        }
278    }
279}
280
281impl IRPrintable for IRAexpr {
282    fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
283        match &self.0 {
284            IRAexprImpl::Constant(felt) => ctx.list("const", |ctx| write!(ctx, "{}", felt)),
285            IRAexprImpl::IO(slot) => slot.fmt(ctx),
286            IRAexprImpl::Negated(expr) => ctx.block("-", |ctx| expr.fmt(ctx)),
287            IRAexprImpl::Sum(lhs, rhs) => ctx.block("+", |ctx| {
288                let do_nl = lhs.depth() > 1 || rhs.depth() > 1;
289                if lhs.depth() > 1 {
290                    ctx.nl()?;
291                }
292                lhs.fmt(ctx)?;
293                if do_nl {
294                    ctx.nl()?;
295                } else {
296                    write!(ctx, " ")?;
297                }
298                rhs.fmt(ctx)
299            }),
300            IRAexprImpl::Product(lhs, rhs) => ctx.block("*", |ctx| {
301                let do_nl = lhs.depth() > 1 || rhs.depth() > 1;
302                if lhs.depth() > 1 {
303                    ctx.nl()?;
304                }
305                lhs.fmt(ctx)?;
306                if do_nl {
307                    ctx.nl()?;
308                } else {
309                    write!(ctx, " ")?;
310                }
311                rhs.fmt(ctx)
312            }),
313        }
314    }
315
316    fn depth(&self) -> usize {
317        match &self.0 {
318            IRAexprImpl::Constant(_) | IRAexprImpl::IO(_) => 1,
319            IRAexprImpl::Negated(expr) => 1 + expr.depth(),
320            IRAexprImpl::Sum(lhs, rhs) | IRAexprImpl::Product(lhs, rhs) => {
321                1 + std::cmp::max(lhs.depth(), rhs.depth())
322            }
323        }
324    }
325}
326
#[cfg(test)]
mod folding_tests {
    use super::*;
    use rstest::rstest;

    use ff::PrimeField;

    /// Implementation of BabyBear used for testing.
    #[derive(PrimeField)]
    #[PrimeFieldModulus = "2013265921"]
    #[PrimeFieldGenerator = "31"]
    #[PrimeFieldReprEndianness = "little"]
    pub struct BabyBear([u64; 1]);

    /// Creates a constant value under BabyBear.
    fn c(v: impl Into<BabyBear>) -> IRAexpr {
        IRAexpr(IRAexprImpl::Constant(Felt::from(v.into())))
    }

    /// The non-constant operand used throughout these tests: the slot `Arg(0)`.
    fn arg0() -> IRAexpr {
        IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())))
    }

    /// Builds `(+ lhs rhs)` without simplifying it.
    fn sum(lhs: IRAexpr, rhs: IRAexpr) -> IRAexpr {
        IRAexpr(IRAexprImpl::Sum(Box::new(lhs), Box::new(rhs)))
    }

    /// Builds `(* lhs rhs)` without simplifying it.
    fn product(lhs: IRAexpr, rhs: IRAexpr) -> IRAexpr {
        IRAexpr(IRAexprImpl::Product(Box::new(lhs), Box::new(rhs)))
    }

    /// Constant-folds `expr` and returns the result.
    fn folded(mut expr: IRAexpr) -> IRAexpr {
        expr.constant_fold().unwrap();
        expr
    }

    #[rstest]
    fn folding_constant_within_field() {
        assert_eq!(folded(c(5)), c(5));
    }

    #[rstest]
    fn folding_constant_outside_field() {
        // 2013265922 = p + 1, which reduces to 1 in BabyBear.
        assert_eq!(folded(c(2013265922)), c(1));
    }

    #[rstest]
    fn mult_identity() {
        assert_eq!(folded(product(c(1), arg0())), arg0());
    }

    #[rstest]
    fn mult_identity_rev() {
        assert_eq!(folded(product(arg0(), c(1))), arg0());
    }

    #[rstest]
    fn mult_by_zero() {
        assert_eq!(folded(product(c(0), arg0())), c(0));
    }

    #[rstest]
    fn mult_by_zero_rev() {
        assert_eq!(folded(product(arg0(), c(0))), c(0));
    }

    #[rstest]
    fn sum_identity() {
        assert_eq!(folded(sum(c(0), arg0())), arg0());
    }

    #[rstest]
    fn sum_identity_rev() {
        assert_eq!(folded(sum(arg0(), c(0))), arg0());
    }

    #[rstest]
    fn sum_of_constants() {
        assert_eq!(folded(sum(c(2), c(3))), c(5));
    }

    #[rstest]
    fn product_of_constants() {
        assert_eq!(folded(product(c(2), c(3))), c(6));
    }

    #[rstest]
    fn nested_identities() {
        // (* (+ 0 X) 1) => X: the inner sum is folded before the outer product is inspected.
        assert_eq!(folded(product(sum(c(0), arg0()), c(1))), arg0());
    }
}