1use crate::{
4 expr::{ExprProperties, ExprProperty},
5 printer::IRPrintable,
6 traits::{Canonicalize, ConstantFolding, Evaluate},
7};
8use eqv::{EqvRelation, equiv};
9use haloumi_core::{eqv::SymbolicEqv, felt::Felt, slot::Slot};
10use haloumi_lowering::{ExprLowering, lowerable::LowerableExpr};
11use std::fmt::Write;
12use std::{
13 convert::Infallible,
14 ops::{Add, Mul, Neg},
15};
16
17#[derive(PartialEq, Eq, Clone, Debug)]
19pub struct IRAexpr(pub(crate) IRAexprImpl);
20
21#[derive(PartialEq, Eq, Clone)]
22pub(crate) enum IRAexprImpl {
23 Constant(Felt),
25 IO(Slot),
27 Negated(Box<IRAexpr>),
29 Sum(Box<IRAexpr>, Box<IRAexpr>),
31 Product(Box<IRAexpr>, Box<IRAexpr>),
33}
34
35impl IRAexpr {
36 pub fn constant(felt: Felt) -> Self {
38 Self(IRAexprImpl::Constant(felt))
39 }
40
41 pub fn slot(s: impl Into<Slot>) -> Self {
43 Self(IRAexprImpl::IO(s.into()))
44 }
45
46 pub fn try_map_io<E>(&mut self, f: &impl Fn(&mut Slot) -> Result<(), E>) -> Result<(), E> {
48 match &mut self.0 {
49 IRAexprImpl::IO(func_io) => f(func_io),
50 IRAexprImpl::Negated(expr) => expr.try_map_io(f),
51 IRAexprImpl::Sum(lhs, rhs) => {
52 lhs.try_map_io(f)?;
53 rhs.try_map_io(f)
54 }
55 IRAexprImpl::Product(lhs, rhs) => {
56 lhs.try_map_io(f)?;
57 rhs.try_map_io(f)
58 }
59 _ => Ok(()),
60 }
61 }
62}
63
64impl Neg for IRAexpr {
65 type Output = Self;
66
67 fn neg(self) -> Self::Output {
68 Self(IRAexprImpl::Negated(Box::new(self)))
69 }
70}
71
72impl Add for IRAexpr {
73 type Output = Self;
74
75 fn add(self, rhs: Self) -> Self::Output {
76 Self(IRAexprImpl::Sum(Box::new(self), Box::new(rhs)))
77 }
78}
79
80impl Mul for IRAexpr {
81 type Output = Self;
82
83 fn mul(self, rhs: Self) -> Self::Output {
84 Self(IRAexprImpl::Product(Box::new(self), Box::new(rhs)))
85 }
86}
87
88impl From<Felt> for IRAexpr {
89 fn from(value: Felt) -> Self {
90 Self(IRAexprImpl::Constant(value))
91 }
92}
93
94impl From<Slot> for IRAexpr {
95 fn from(value: Slot) -> Self {
96 Self(IRAexprImpl::IO(value))
97 }
98}
99
100impl Evaluate<Option<Felt>> for IRAexpr {
101 fn evaluate(&self) -> Option<Felt> {
102 match &self.0 {
103 IRAexprImpl::Constant(felt) => Some(*felt),
104 IRAexprImpl::IO(_) => None,
105 IRAexprImpl::Negated(expr) => Evaluate::<Option<Felt>>::evaluate(expr).map(|f| -f),
106 IRAexprImpl::Sum(lhs, rhs) => Evaluate::<Option<Felt>>::evaluate(lhs)
107 .zip(Evaluate::<Option<Felt>>::evaluate(rhs))
108 .map(|(lhs, rhs)| lhs + rhs),
109 IRAexprImpl::Product(lhs, rhs) => Evaluate::<Option<Felt>>::evaluate(lhs)
110 .zip(Evaluate::<Option<Felt>>::evaluate(rhs))
111 .map(|(lhs, rhs)| lhs * rhs),
112 }
113 }
114}
115
116impl Evaluate<ExprProperties> for IRAexpr {
117 fn evaluate(&self) -> ExprProperties {
118 match &self.0 {
119 IRAexprImpl::Constant(_) => ExprProperty::Const.into(),
120 IRAexprImpl::IO(_) => Default::default(),
121 IRAexprImpl::Negated(expr) => expr.evaluate(),
122 IRAexprImpl::Sum(lhs, rhs) | IRAexprImpl::Product(lhs, rhs) => {
123 Evaluate::<ExprProperties>::evaluate(lhs)
124 & Evaluate::<ExprProperties>::evaluate(rhs)
125 }
126 }
127 }
128}
129
130impl ConstantFolding for IRAexpr {
131 type T = Felt;
132
133 type Error = Infallible;
134
135 fn constant_fold(&mut self) -> Result<(), Self::Error> {
136 match &mut self.0 {
137 IRAexprImpl::Constant(_) => {}
138 IRAexprImpl::IO(_) => {}
139 IRAexprImpl::Negated(expr) => {
140 expr.constant_fold()?;
141 if let Some(f) = expr.const_value() {
142 *self = (-f).into();
143 }
144 }
145
146 IRAexprImpl::Sum(lhs, rhs) => {
147 lhs.constant_fold()?;
148 rhs.constant_fold()?;
149
150 match (lhs.const_value(), rhs.const_value()) {
151 (Some(lhs), Some(rhs)) => {
152 *self = Self(IRAexprImpl::Constant(lhs + rhs));
153 }
154 (None, Some(rhs)) if rhs == 0usize => {
155 *self = (**lhs).clone();
156 }
157 (Some(lhs), None) if lhs == 0usize => {
158 *self = (**rhs).clone();
159 }
160 _ => {}
161 }
162 }
163 IRAexprImpl::Product(lhs, rhs) => {
164 lhs.constant_fold()?;
165 rhs.constant_fold()?;
166 match (lhs.const_value(), rhs.const_value()) {
167 (Some(lhs), Some(rhs)) => {
168 *self = (lhs * rhs).into();
169 }
170 (None, Some(rhs)) if rhs == 1usize => {
172 *self = (**lhs).clone();
173 }
174 (Some(lhs), None) if lhs == 1usize => {
175 *self = (**rhs).clone();
176 }
177 (None, Some(rhs)) if rhs == 0usize => {
179 *self = rhs.into();
180 }
181 (Some(lhs), None) if lhs == 0usize => {
182 *self = lhs.into();
183 }
184 (None, Some(rhs)) if rhs.is_minus_one() => {
186 *self = Self(IRAexprImpl::Negated(lhs.clone()));
187 }
188 (Some(lhs), None) if lhs.is_minus_one() => {
189 *self = Self(IRAexprImpl::Negated(rhs.clone()));
190 }
191 _ => {}
192 }
193 }
194 }
195 Ok(())
196 }
197
198 fn const_value(&self) -> Option<Felt> {
200 match &self.0 {
201 IRAexprImpl::Constant(f) => Some(*f),
202 _ => None,
203 }
204 }
205}
206
207impl IRAexpr {
208 fn negated_inner(&self) -> Option<&IRAexpr> {
210 match &self.0 {
211 IRAexprImpl::Negated(inner) => Some(inner),
212 _ => None,
213 }
214 }
215}
216
217impl Canonicalize for IRAexpr {
218 fn canonicalize(&mut self) {
219 match &mut self.0 {
220 IRAexprImpl::Constant(_) => {}
221 IRAexprImpl::IO(_) => {}
222 IRAexprImpl::Negated(expr) => {
223 expr.canonicalize();
224 if let Some(inner) = expr.negated_inner() {
226 *self = inner.clone();
227 }
228 }
229 IRAexprImpl::Sum(_, _) => todo!(),
230 IRAexprImpl::Product(_, _) => todo!(),
231 };
232 }
233}
234
235impl std::fmt::Debug for IRAexprImpl {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 match self {
238 Self::Constant(arg0) => write!(f, "{arg0:?}"),
239 Self::IO(arg0) => write!(f, "{arg0:?}"),
240 Self::Negated(arg0) => write!(f, "(- {arg0:?})"),
241 Self::Sum(arg0, arg1) => write!(f, "(+ {arg0:?} {arg1:?})"),
242 Self::Product(arg0, arg1) => write!(f, "(* {arg0:?} {arg1:?})"),
243 }
244 }
245}
246
247impl EqvRelation<IRAexpr> for SymbolicEqv {
248 fn equivalent(lhs: &IRAexpr, rhs: &IRAexpr) -> bool {
251 match (&lhs.0, &rhs.0) {
252 (IRAexprImpl::Constant(lhs), IRAexprImpl::Constant(rhs)) => lhs == rhs,
253 (IRAexprImpl::IO(lhs), IRAexprImpl::IO(rhs)) => equiv!(Self | lhs, rhs),
254 (IRAexprImpl::Negated(lhs), IRAexprImpl::Negated(rhs)) => equiv!(Self | lhs, rhs),
255 (IRAexprImpl::Sum(lhs0, lhs1), IRAexprImpl::Sum(rhs0, rhs1)) => {
256 equiv!(Self | lhs0, rhs0) && equiv!(Self | lhs1, rhs1)
257 }
258 (IRAexprImpl::Product(lhs0, lhs1), IRAexprImpl::Product(rhs0, rhs1)) => {
259 equiv!(Self | lhs0, rhs0) && equiv!(Self | lhs1, rhs1)
260 }
261 _ => false,
262 }
263 }
264}
265
266impl LowerableExpr for IRAexpr {
267 fn lower<L>(self, l: &L) -> haloumi_lowering::Result<L::CellOutput>
268 where
269 L: ExprLowering + ?Sized,
270 {
271 match self.0 {
272 IRAexprImpl::Constant(f) => l.lower_constant(f),
273 IRAexprImpl::IO(io) => l.lower_funcio(io),
274 IRAexprImpl::Negated(expr) => l.lower_neg(&expr.lower(l)?),
275 IRAexprImpl::Sum(lhs, rhs) => l.lower_sum(&lhs.lower(l)?, &rhs.lower(l)?),
276 IRAexprImpl::Product(lhs, rhs) => l.lower_product(&lhs.lower(l)?, &rhs.lower(l)?),
277 }
278 }
279}
280
281impl IRPrintable for IRAexpr {
282 fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
283 match &self.0 {
284 IRAexprImpl::Constant(felt) => ctx.list("const", |ctx| write!(ctx, "{}", felt)),
285 IRAexprImpl::IO(slot) => slot.fmt(ctx),
286 IRAexprImpl::Negated(expr) => ctx.block("-", |ctx| expr.fmt(ctx)),
287 IRAexprImpl::Sum(lhs, rhs) => ctx.block("+", |ctx| {
288 let do_nl = lhs.depth() > 1 || rhs.depth() > 1;
289 if lhs.depth() > 1 {
290 ctx.nl()?;
291 }
292 lhs.fmt(ctx)?;
293 if do_nl {
294 ctx.nl()?;
295 } else {
296 write!(ctx, " ")?;
297 }
298 rhs.fmt(ctx)
299 }),
300 IRAexprImpl::Product(lhs, rhs) => ctx.block("*", |ctx| {
301 let do_nl = lhs.depth() > 1 || rhs.depth() > 1;
302 if lhs.depth() > 1 {
303 ctx.nl()?;
304 }
305 lhs.fmt(ctx)?;
306 if do_nl {
307 ctx.nl()?;
308 } else {
309 write!(ctx, " ")?;
310 }
311 rhs.fmt(ctx)
312 }),
313 }
314 }
315
316 fn depth(&self) -> usize {
317 match &self.0 {
318 IRAexprImpl::Constant(_) | IRAexprImpl::IO(_) => 1,
319 IRAexprImpl::Negated(expr) => 1 + expr.depth(),
320 IRAexprImpl::Sum(lhs, rhs) | IRAexprImpl::Product(lhs, rhs) => {
321 1 + std::cmp::max(lhs.depth(), rhs.depth())
322 }
323 }
324 }
325}
326
#[cfg(test)]
mod folding_tests {
    use super::*;
    use rstest::rstest;

    use ff::PrimeField;

    /// BabyBear prime field (p = 2013265921), used to build test constants.
    #[derive(PrimeField)]
    #[PrimeFieldModulus = "2013265921"]
    #[PrimeFieldGenerator = "31"]
    #[PrimeFieldReprEndianness = "little"]
    pub struct BabyBear([u64; 1]);

    /// Builds a constant expression from a BabyBear value.
    fn c(v: impl Into<BabyBear>) -> IRAexpr {
        IRAexpr(IRAexprImpl::Constant(Felt::from(v.into())))
    }

    #[rstest]
    fn folding_constant_within_field() {
        let mut test = c(5);
        let expected = test.clone();
        test.constant_fold().unwrap();
        assert_eq!(test, expected);
    }

    #[rstest]
    fn folding_constant_outside_field() {
        // The reduction mod p already happens when the BabyBear value is built;
        // folding a lone constant must leave that reduced value unchanged.
        let mut test = c(2013265922);
        let expected = c(1);
        test.constant_fold().unwrap();
        assert_eq!(test, expected);
    }

    #[rstest]
    fn mult_identity() {
        let lhs = c(1);
        let rhs = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let mut mul = IRAexpr(IRAexprImpl::Product(Box::new(lhs), Box::new(rhs.clone())));
        mul.constant_fold().unwrap();
        assert_eq!(mul, rhs);
    }

    #[rstest]
    fn mult_identity_rev() {
        let rhs = c(1);
        let lhs = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let mut mul = IRAexpr(IRAexprImpl::Product(Box::new(lhs.clone()), Box::new(rhs)));
        mul.constant_fold().unwrap();
        assert_eq!(mul, lhs);
    }

    #[rstest]
    fn mult_by_zero() {
        let lhs = c(0);
        let rhs = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let mut mul = IRAexpr(IRAexprImpl::Product(Box::new(lhs.clone()), Box::new(rhs)));
        mul.constant_fold().unwrap();
        assert_eq!(mul, lhs);
    }

    #[rstest]
    fn mult_by_zero_rev() {
        let rhs = c(0);
        let lhs = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let mut mul = IRAexpr(IRAexprImpl::Product(Box::new(lhs), Box::new(rhs.clone())));
        mul.constant_fold().unwrap();
        assert_eq!(mul, rhs);
    }

    #[rstest]
    fn sum_identity() {
        let lhs = c(0);
        let rhs = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let mut sum = IRAexpr(IRAexprImpl::Sum(Box::new(lhs), Box::new(rhs.clone())));
        sum.constant_fold().unwrap();
        assert_eq!(sum, rhs);
    }

    #[rstest]
    fn sum_identity_rev() {
        let rhs = c(0);
        let lhs = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let mut sum = IRAexpr(IRAexprImpl::Sum(Box::new(lhs.clone()), Box::new(rhs)));
        sum.constant_fold().unwrap();
        assert_eq!(sum, lhs);
    }

    #[rstest]
    fn sum_of_constants() {
        let mut sum = IRAexpr(IRAexprImpl::Sum(Box::new(c(2)), Box::new(c(3))));
        sum.constant_fold().unwrap();
        assert_eq!(sum, c(5));
    }

    #[rstest]
    fn product_of_constants() {
        let mut mul = IRAexpr(IRAexprImpl::Product(Box::new(c(2)), Box::new(c(3))));
        mul.constant_fold().unwrap();
        assert_eq!(mul, c(6));
    }

    #[rstest]
    fn nested_identities_fold_bottom_up() {
        // (x + 0) * 1 => x: the inner sum must fold before the outer product.
        let x = IRAexpr(IRAexprImpl::IO(Slot::Arg(0.into())));
        let sum = IRAexpr(IRAexprImpl::Sum(Box::new(x.clone()), Box::new(c(0))));
        let mut mul = IRAexpr(IRAexprImpl::Product(Box::new(sum), Box::new(c(1))));
        mul.constant_fold().unwrap();
        assert_eq!(mul, x);
    }
}