Skip to main content

haloumi_ir/expr/
bexpr.rs

1//! Structs for handling boolean expressions.
2
3use crate::diagnostics::{SimpleDiagnostic, Validation};
4use crate::error::Error;
5use crate::expr::{ExprProperties, ExprProperty};
6use crate::printer::IRPrintable;
7use crate::traits::{Canonicalize, ConstantFolding, Evaluate, Validatable};
8use crate::{canon::canonicalize_constraint, expr::IRAexpr};
9use eqv::{EqvRelation, equiv};
10use haloumi_core::cmp::CmpOp;
11use haloumi_core::eqv::SymbolicEqv;
12use haloumi_lowering::lowering_err;
13use haloumi_lowering::{ExprLowering, lowerable::LowerableExpr};
14use std::borrow::{Borrow, BorrowMut};
15use std::ops::{Deref, DerefMut};
16use std::{
17    convert::identity,
18    fmt::Write,
19    ops::{BitAnd, BitOr, Not},
20};
21use thiserror::Error;
22
/// Represents boolean expressions over some arithmetic expression type `A`.
///
/// The node representation is private; values are built with the constructor methods (e.g.
/// [`IRBexpr::and`], [`IRBexpr::det`], [`IRBexpr::implies`]), the `&`, `|` and `!` operators,
/// or `From<bool>`.
// NOTE(review): the derived `Debug` wraps the inner S-expression as `IRBexpr(...)`, so nested
// operands printed by `IRBexprImpl`'s `Debug` impl carry that prefix. A manual impl delegating
// to the inner node would read better — confirm no snapshot/log test depends on it first.
#[derive(Debug)]
pub struct IRBexpr<A>(IRBexprImpl<A>);
26
/// Private node representation behind [`IRBexpr`].
enum IRBexprImpl<A> {
    /// Literal value for true.
    True,
    /// Literal value for false.
    False,
    /// Comparison operation of two inner arithmetic expressions.
    Cmp(CmpOp, A, A),
    /// Represents the conjunction of the inner expressions.
    ///
    /// May be empty; constant folding rewrites an empty conjunction to `true`.
    And(Vec<IRBexpr<A>>),
    /// Represents the disjunction of the inner expressions.
    ///
    /// May be empty; constant folding rewrites an empty disjunction to `false`.
    Or(Vec<IRBexpr<A>>),
    /// Represents the negation of the inner expression.
    Not(Box<IRBexpr<A>>),
    /// Declares that the inner arithmetic expression needs to be proven deterministic.
    Det(A),
    /// Logical implication operator (`lhs => rhs`).
    Implies(Box<IRBexpr<A>>, Box<IRBexpr<A>>),
    /// Logical double-implication operator (`lhs <=> rhs`).
    Iff(Box<IRBexpr<A>>, Box<IRBexpr<A>>),
}
47
48impl<T> IRBexpr<T> {
49    /// Transforms the inner expression into a different type.
50    pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> IRBexpr<O> {
51        match self.0 {
52            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => IRBexpr(IRBexprImpl::Cmp(cmp_op, f(lhs), f(rhs))),
53            IRBexprImpl::And(exprs) => IRBexpr(IRBexprImpl::And(
54                exprs.into_iter().map(|e| e.map(f)).collect(),
55            )),
56            IRBexprImpl::Or(exprs) => IRBexpr(IRBexprImpl::Or(
57                exprs.into_iter().map(|e| e.map(f)).collect(),
58            )),
59            IRBexprImpl::Not(expr) => IRBexpr(IRBexprImpl::Not(Box::new(expr.map(f)))),
60            IRBexprImpl::True => IRBexpr(IRBexprImpl::True),
61            IRBexprImpl::False => IRBexpr(IRBexprImpl::False),
62            IRBexprImpl::Det(expr) => IRBexpr(IRBexprImpl::Det(f(expr))),
63            IRBexprImpl::Implies(lhs, rhs) => IRBexpr(IRBexprImpl::Implies(
64                Box::new(lhs.map(f)),
65                Box::new(rhs.map(f)),
66            )),
67            IRBexprImpl::Iff(lhs, rhs) => {
68                IRBexpr(IRBexprImpl::Iff(Box::new(lhs.map(f)), Box::new(rhs.map(f))))
69            }
70        }
71    }
72
73    /// Transforms the inner expression into a different type without moving the struct.
74    pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> IRBexpr<O> {
75        match &self.0 {
76            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
77                IRBexpr(IRBexprImpl::Cmp(*cmp_op, f(lhs), f(rhs)))
78            }
79            IRBexprImpl::And(exprs) => IRBexpr(IRBexprImpl::And(
80                exprs.iter().map(|e| e.map_into(f)).collect(),
81            )),
82            IRBexprImpl::Or(exprs) => IRBexpr(IRBexprImpl::Or(
83                exprs.iter().map(|e| e.map_into(f)).collect(),
84            )),
85            IRBexprImpl::Not(expr) => IRBexpr(IRBexprImpl::Not(Box::new(expr.map_into(f)))),
86            IRBexprImpl::True => IRBexpr(IRBexprImpl::True),
87            IRBexprImpl::False => IRBexpr(IRBexprImpl::False),
88            IRBexprImpl::Det(expr) => IRBexpr(IRBexprImpl::Det(f(expr))),
89            IRBexprImpl::Implies(lhs, rhs) => IRBexpr(IRBexprImpl::Implies(
90                Box::new(lhs.map_into(f)),
91                Box::new(rhs.map_into(f)),
92            )),
93            IRBexprImpl::Iff(lhs, rhs) => IRBexpr(IRBexprImpl::Iff(
94                Box::new(lhs.map_into(f)),
95                Box::new(rhs.map_into(f)),
96            )),
97        }
98    }
99
100    /// Transforms the inner expression into a different type, potentially failing.
101    pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<IRBexpr<O>, E> {
102        Ok(match self.0 {
103            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
104                IRBexpr(IRBexprImpl::Cmp(cmp_op, f(lhs)?, f(rhs)?))
105            }
106            IRBexprImpl::And(exprs) => IRBexpr(IRBexprImpl::And(
107                exprs
108                    .into_iter()
109                    .map(|e| e.try_map(f))
110                    .collect::<Result<Vec<_>, _>>()?,
111            )),
112            IRBexprImpl::Or(exprs) => IRBexpr(IRBexprImpl::Or(
113                exprs
114                    .into_iter()
115                    .map(|e| e.try_map(f))
116                    .collect::<Result<Vec<_>, _>>()?,
117            )),
118            IRBexprImpl::Not(expr) => IRBexpr(IRBexprImpl::Not(Box::new(expr.try_map(f)?))),
119            IRBexprImpl::True => IRBexpr(IRBexprImpl::True),
120            IRBexprImpl::False => IRBexpr(IRBexprImpl::False),
121            IRBexprImpl::Det(expr) => IRBexpr(IRBexprImpl::Det(f(expr)?)),
122            IRBexprImpl::Implies(lhs, rhs) => IRBexpr(IRBexprImpl::Implies(
123                Box::new(lhs.try_map(f)?),
124                Box::new(rhs.try_map(f)?),
125            )),
126            IRBexprImpl::Iff(lhs, rhs) => IRBexpr(IRBexprImpl::Iff(
127                Box::new(lhs.try_map(f)?),
128                Box::new(rhs.try_map(f)?),
129            )),
130        })
131    }
132
133    /// Transforms the inner expression in place instead of returning a new expression.
134    pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
135        match &mut self.0 {
136            IRBexprImpl::Cmp(_, lhs, rhs) => {
137                f(lhs);
138                f(rhs);
139            }
140            IRBexprImpl::And(exprs) => {
141                for expr in exprs {
142                    expr.map_inplace(f);
143                }
144            }
145            IRBexprImpl::Or(exprs) => {
146                for expr in exprs {
147                    expr.map_inplace(f);
148                }
149            }
150            IRBexprImpl::Not(expr) => expr.map_inplace(f),
151            IRBexprImpl::True => {}
152            IRBexprImpl::False => {}
153            IRBexprImpl::Det(expr) => f(expr),
154            IRBexprImpl::Implies(lhs, rhs) => {
155                lhs.map_inplace(f);
156                rhs.map_inplace(f);
157            }
158            IRBexprImpl::Iff(lhs, rhs) => {
159                lhs.map_inplace(f);
160                rhs.map_inplace(f);
161            }
162        }
163    }
164
165    /// Tries to transform the inner expression in place instead of returning a new expression.
166    pub fn try_map_inplace<E>(
167        &mut self,
168        f: &mut impl FnMut(&mut T) -> Result<(), E>,
169    ) -> Result<(), E> {
170        match &mut self.0 {
171            IRBexprImpl::Cmp(_, lhs, rhs) => {
172                f(lhs)?;
173                f(rhs)
174            }
175            IRBexprImpl::And(exprs) => {
176                for expr in exprs {
177                    expr.try_map_inplace(f)?;
178                }
179                Ok(())
180            }
181            IRBexprImpl::Or(exprs) => {
182                for expr in exprs {
183                    expr.try_map_inplace(f)?;
184                }
185                Ok(())
186            }
187            IRBexprImpl::Not(expr) => expr.try_map_inplace(f),
188            IRBexprImpl::True => Ok(()),
189            IRBexprImpl::False => Ok(()),
190            IRBexprImpl::Det(expr) => f(expr),
191            IRBexprImpl::Implies(lhs, rhs) => {
192                lhs.try_map_inplace(f)?;
193                rhs.try_map_inplace(f)
194            }
195            IRBexprImpl::Iff(lhs, rhs) => {
196                lhs.try_map_inplace(f)?;
197                rhs.try_map_inplace(f)
198            }
199        }
200    }
201
202    pub(crate) fn cmp(op: CmpOp, lhs: T, rhs: T) -> Self {
203        Self(IRBexprImpl::Cmp(op, lhs, rhs))
204    }
205
206    /// Creates a expression that indicates the backend must prove deterministic.
207    pub fn det(expr: T) -> Self {
208        Self(IRBexprImpl::Det(expr))
209    }
210
211    #[inline]
212    /// Creates a constraint with [`CmpOp::Eq`] between two expressions.
213    pub fn eq(lhs: T, rhs: T) -> Self {
214        Self(IRBexprImpl::Cmp(CmpOp::Eq, lhs, rhs))
215    }
216
217    #[inline]
218    /// Creates a constraint with [`CmpOp::Lt`] between two expressions.
219    pub fn lt(lhs: T, rhs: T) -> Self {
220        Self(IRBexprImpl::Cmp(CmpOp::Lt, lhs, rhs))
221    }
222
223    #[inline]
224    /// Creates a constraint with [`CmpOp::Le`] between two expressions.
225    pub fn le(lhs: T, rhs: T) -> Self {
226        Self(IRBexprImpl::Cmp(CmpOp::Le, lhs, rhs))
227    }
228
229    #[inline]
230    /// Creates a constraint with [`CmpOp::Gt`] between two expressions.
231    pub fn gt(lhs: T, rhs: T) -> Self {
232        Self(IRBexprImpl::Cmp(CmpOp::Gt, lhs, rhs))
233    }
234
235    #[inline]
236    /// Creates a constraint with [`CmpOp::Ge`] between two expressions.
237    pub fn ge(lhs: T, rhs: T) -> Self {
238        Self(IRBexprImpl::Cmp(CmpOp::Ge, lhs, rhs))
239    }
240
241    #[inline]
242    /// Creates an implication expression.
243    pub fn implies(self, rhs: Self) -> Self {
244        Self(IRBexprImpl::Implies(Box::new(self), Box::new(rhs)))
245    }
246
247    #[inline]
248    /// Creates a double implication expression.
249    pub fn iff(self, rhs: Self) -> Self {
250        Self(IRBexprImpl::Iff(Box::new(self), Box::new(rhs)))
251    }
252
253    /// Creates a logical AND.
254    pub fn and(self, rhs: Self) -> Self {
255        Self(match (self.0, rhs.0) {
256            (IRBexprImpl::And(mut lhs), IRBexprImpl::And(rhs)) => {
257                lhs.reserve(rhs.len());
258                lhs.extend(rhs);
259                IRBexprImpl::And(lhs)
260            }
261            // The order of the operators is irrelevant
262            (exp, IRBexprImpl::And(mut lst)) | (IRBexprImpl::And(mut lst), exp) => {
263                lst.push(Self(exp));
264                IRBexprImpl::And(lst)
265            }
266            (lhs, rhs) => IRBexprImpl::And(vec![Self(lhs), Self(rhs)]),
267        })
268    }
269
270    /// Creates a logical AND from a sequence of expressions.
271    pub fn and_many(exprs: impl IntoIterator<Item = Self>) -> Self {
272        Self(IRBexprImpl::And(exprs.into_iter().collect()))
273    }
274
275    /// Creates a logical OR.
276    pub fn or(self, rhs: Self) -> Self {
277        Self(match (self.0, rhs.0) {
278            (IRBexprImpl::Or(mut lhs), IRBexprImpl::Or(rhs)) => {
279                lhs.reserve(rhs.len());
280                lhs.extend(rhs);
281                IRBexprImpl::Or(lhs)
282            }
283            // The order of the operators is irrelevant
284            (exp, IRBexprImpl::Or(mut lst)) | (IRBexprImpl::Or(mut lst), exp) => {
285                lst.push(Self(exp));
286                IRBexprImpl::Or(lst)
287            }
288            (lhs, rhs) => IRBexprImpl::Or(vec![Self(lhs), Self(rhs)]),
289        })
290    }
291
292    /// Creates a logical OR from a sequence of expressions.
293    pub fn or_many(exprs: impl IntoIterator<Item = Self>) -> Self {
294        Self(IRBexprImpl::Or(exprs.into_iter().collect()))
295    }
296
297    /// Maps the statement's inner type to a tuple with the passed value.
298    pub fn with<O>(self, other: O) -> IRBexpr<(O, T)>
299    where
300        O: Clone,
301    {
302        self.map(&mut |t| (other.clone(), t))
303    }
304
305    /// Maps the statement's inner type to a tuple with the result of the closure.
306    pub fn with_fn<O>(self, other: impl Fn() -> O) -> IRBexpr<(O, T)> {
307        self.map(&mut |t| (other(), t))
308    }
309}
310
311struct LogLine {
312    before: Option<String>,
313    ident: usize,
314}
315
316impl LogLine {
317    fn new<T: std::fmt::Debug>(expr: &IRBexprImpl<T>, ident: usize) -> Self {
318        if matches!(
319            expr,
320            IRBexprImpl::True | IRBexprImpl::False | IRBexprImpl::Cmp(_, _, _)
321        ) {
322            Self {
323                before: Some(format!("{expr:?}")),
324                ident,
325            }
326        } else {
327            log::debug!("[constant_fold] {:ident$} {expr:?} {{", "", ident = ident);
328            Self {
329                before: None,
330                ident,
331            }
332        }
333    }
334
335    fn log<T: std::fmt::Debug>(self, expr: &mut IRBexprImpl<T>) {
336        match self.before {
337            Some(before) => {
338                log::debug!(
339                    "[constant_fold] {:ident$} {} -> {expr:?}",
340                    "",
341                    before,
342                    ident = self.ident
343                );
344            }
345            None => {
346                log::debug!(
347                    "[constant_fold] {:ident$} }} -> {expr:?}",
348                    "",
349                    ident = self.ident
350                );
351            }
352        }
353    }
354}
355
356impl Canonicalize for IRBexpr<IRAexpr> {
357    /// Matches the expressions against a series of known patterns and applies rewrites if able to.
358    fn canonicalize(&mut self) {
359        match &mut self.0 {
360            IRBexprImpl::True => {}
361            IRBexprImpl::False => {}
362            IRBexprImpl::Cmp(op, lhs, rhs) => {
363                if let Some((op, lhs, rhs)) = canonicalize_constraint(*op, lhs, rhs) {
364                    *self = Self(IRBexprImpl::Cmp(op, lhs, rhs));
365                }
366            }
367            IRBexprImpl::And(exprs) => {
368                for expr in exprs {
369                    expr.canonicalize();
370                }
371            }
372            IRBexprImpl::Or(exprs) => {
373                for expr in exprs {
374                    expr.canonicalize();
375                }
376            }
377            IRBexprImpl::Not(expr) => {
378                expr.canonicalize();
379                match &expr.0 {
380                    IRBexprImpl::True => {
381                        *self = Self(IRBexprImpl::False);
382                    }
383                    IRBexprImpl::False => {
384                        *self = Self(IRBexprImpl::True);
385                    }
386                    IRBexprImpl::Cmp(op, lhs, rhs) => {
387                        *self = Self(IRBexprImpl::Cmp(
388                            match op {
389                                CmpOp::Eq => CmpOp::Ne,
390                                CmpOp::Lt => CmpOp::Ge,
391                                CmpOp::Le => CmpOp::Gt,
392                                CmpOp::Gt => CmpOp::Le,
393                                CmpOp::Ge => CmpOp::Lt,
394                                CmpOp::Ne => CmpOp::Eq,
395                            },
396                            lhs.clone(),
397                            rhs.clone(),
398                        ));
399                        self.canonicalize();
400                    }
401                    _ => {}
402                }
403            }
404            IRBexprImpl::Det(_) => {}
405            IRBexprImpl::Implies(lhs, rhs) => {
406                lhs.canonicalize();
407                rhs.canonicalize();
408            }
409            IRBexprImpl::Iff(lhs, rhs) => {
410                lhs.canonicalize();
411                rhs.canonicalize();
412            }
413        }
414    }
415}
416
417impl<T> IRBexpr<T>
418where
419    T: ConstantFolding + std::fmt::Debug,
420    T::T: Eq + Ord,
421{
422    /// Folds the expression if the values are constant.
423    fn constant_fold_impl(&mut self, indent: usize) -> Result<(), T::Error> {
424        let log = LogLine::new(&self.0, indent);
425        match &mut self.0 {
426            IRBexprImpl::True => {
427                log.log(&mut self.0);
428            }
429            IRBexprImpl::False => {
430                log.log(&mut self.0);
431            }
432            IRBexprImpl::Cmp(op, lhs, rhs) => {
433                lhs.constant_fold()?;
434                rhs.constant_fold()?;
435                if let Some((lhs, rhs)) = lhs.const_value().zip(rhs.const_value()) {
436                    *self = match op {
437                        CmpOp::Eq => lhs == rhs,
438                        CmpOp::Lt => lhs < rhs,
439                        CmpOp::Le => lhs <= rhs,
440                        CmpOp::Gt => lhs > rhs,
441                        CmpOp::Ge => lhs >= rhs,
442                        CmpOp::Ne => lhs != rhs,
443                    }
444                    .into()
445                }
446                log.log(&mut self.0);
447            }
448            IRBexprImpl::And(exprs) => {
449                for expr in &mut *exprs {
450                    expr.constant_fold_impl(indent + 2)?;
451                }
452                // If any value is a literal 'false' convert into IRBexprImpl::False
453                if exprs.iter().any(|expr| {
454                    expr.const_value()
455                        // If the expr is false-y flip the boolean to return 'true'.
456                        .map(|b| !b)
457                        // Default to 'false' for non-literal expressions.
458                        .unwrap_or_default()
459                }) {
460                    *self = Self(IRBexprImpl::False);
461                    log.log(&mut self.0);
462                    return Ok(());
463                }
464                // Remove any literal 'true' values.
465                exprs.retain(|expr| {
466                    expr.const_value()
467                        // If the expr is IRBexprImpl::True we don't want to retain.
468                        .map(|b| !b)
469                        // Default to true to keep the non-literal values.
470                        .unwrap_or(true)
471                });
472                if exprs.is_empty() {
473                    *self = Self(IRBexprImpl::True);
474                }
475                log.log(&mut self.0);
476            }
477            IRBexprImpl::Or(exprs) => {
478                for expr in &mut *exprs {
479                    expr.constant_fold_impl(indent + 2)?;
480                }
481                // If any value is a literal 'true' convert into IRBexprImpl::True.
482                if exprs
483                    .iter()
484                    .any(|expr| expr.const_value().unwrap_or_default())
485                {
486                    *self = Self(IRBexprImpl::True);
487                    log.log(&mut self.0);
488                    return Ok(());
489                }
490                // Remove any literal 'false' values.
491                exprs.retain(|expr| {
492                    expr.const_value()
493                        // Default to true to keep the non-literal values.
494                        .unwrap_or(true)
495                });
496                if exprs.is_empty() {
497                    *self = Self(IRBexprImpl::False);
498                }
499                log.log(&mut self.0);
500            }
501            IRBexprImpl::Not(expr) => {
502                expr.constant_fold_impl(indent + 2)?;
503                if let Some(b) = expr.const_value() {
504                    *self = (!b).into();
505                }
506                log.log(&mut self.0);
507            }
508            IRBexprImpl::Det(expr) => expr.constant_fold()?,
509            IRBexprImpl::Implies(lhs, rhs) => {
510                lhs.constant_fold_impl(indent + 2)?;
511                rhs.constant_fold_impl(indent + 2)?;
512                if let Some((lhs, rhs)) = lhs.const_value().zip(rhs.const_value()) {
513                    *self = (!lhs || rhs).into();
514                }
515            }
516            IRBexprImpl::Iff(lhs, rhs) => {
517                lhs.constant_fold_impl(indent + 2)?;
518                rhs.constant_fold_impl(indent + 2)?;
519                if let Some((lhs, rhs)) = lhs.const_value().zip(rhs.const_value()) {
520                    *self = (lhs == rhs).into();
521                }
522            }
523        }
524        Ok(())
525    }
526}
527
528impl<T> ConstantFolding for IRBexpr<T>
529where
530    T: ConstantFolding + std::fmt::Debug,
531    T::T: Eq + Ord,
532{
533    type T = bool;
534
535    type Error = T::Error;
536
537    fn constant_fold(&mut self) -> Result<(), Self::Error> {
538        self.constant_fold_impl(0)
539    }
540
541    /// Returns `Some(true)` or `Some(false)` if the expression is constant, `None` otherwise.
542    fn const_value(&self) -> Option<bool> {
543        match &self.0 {
544            IRBexprImpl::True => Some(true),
545            IRBexprImpl::False => Some(false),
546            _ => None,
547        }
548    }
549}
550
551impl<T: Evaluate<ExprProperties>> Evaluate<ExprProperties> for IRBexpr<T> {
552    fn evaluate(&self) -> ExprProperties {
553        match &self.0 {
554            IRBexprImpl::True | IRBexprImpl::False => ExprProperty::Const.into(),
555            IRBexprImpl::Cmp(_, lhs, rhs) => lhs.evaluate() & rhs.evaluate(),
556            IRBexprImpl::And(exprs) | IRBexprImpl::Or(exprs) => {
557                exprs.iter().map(Evaluate::evaluate).product()
558            }
559            IRBexprImpl::Not(expr) => expr.evaluate(),
560            IRBexprImpl::Det(expr) => expr.evaluate(),
561            IRBexprImpl::Implies(lhs, rhs) | IRBexprImpl::Iff(lhs, rhs) => {
562                lhs.evaluate() & rhs.evaluate()
563            }
564        }
565    }
566}
567
568impl<T> From<bool> for IRBexpr<T> {
569    fn from(value: bool) -> Self {
570        Self(if value {
571            IRBexprImpl::True
572        } else {
573            IRBexprImpl::False
574        })
575    }
576}
577
578/// IRBexprImpl transitively inherits the symbolic equivalence relation.
579impl<L, R> EqvRelation<IRBexpr<L>, IRBexpr<R>> for SymbolicEqv
580where
581    SymbolicEqv: EqvRelation<L, R>,
582{
583    /// Two boolean expressions are equivalent if they are structurally equal and their inner entities
584    /// are equivalent.
585    fn equivalent(lhs: &IRBexpr<L>, rhs: &IRBexpr<R>) -> bool {
586        match (&lhs.0, &rhs.0) {
587            (IRBexprImpl::True, IRBexprImpl::True) | (IRBexprImpl::False, IRBexprImpl::False) => {
588                true
589            }
590            (IRBexprImpl::Cmp(op1, lhs1, rhs1), IRBexprImpl::Cmp(op2, lhs2, rhs2)) => {
591                op1 == op2 && equiv!(Self | lhs1, lhs2) && equiv!(Self | rhs1, rhs2)
592            }
593            (IRBexprImpl::And(lhs), IRBexprImpl::And(rhs)) => {
594                equiv!(Self | lhs, rhs)
595            }
596            (IRBexprImpl::Or(lhs), IRBexprImpl::Or(rhs)) => {
597                equiv!(Self | lhs, rhs)
598            }
599            (IRBexprImpl::Not(lhs), IRBexprImpl::Not(rhs)) => {
600                equiv!(Self | lhs, rhs)
601            }
602            (IRBexprImpl::Det(lhs), IRBexprImpl::Det(rhs)) => equiv!(Self | lhs, rhs),
603            (IRBexprImpl::Implies(lhs1, rhs1), IRBexprImpl::Implies(lhs2, rhs2)) => {
604                equiv!(Self | lhs1, lhs2) && equiv!(Self | rhs1, rhs2)
605            }
606
607            (IRBexprImpl::Iff(lhs1, rhs1), IRBexprImpl::Iff(lhs2, rhs2)) => {
608                equiv!(Self | lhs1, lhs2) && equiv!(Self | rhs1, rhs2)
609            }
610            _ => false,
611        }
612    }
613}
614
615impl<T> BitAnd for IRBexpr<T> {
616    type Output = Self;
617
618    fn bitand(self, rhs: Self) -> Self::Output {
619        self.and(rhs)
620    }
621}
622
623impl<T> BitOr for IRBexpr<T> {
624    type Output = Self;
625
626    fn bitor(self, rhs: Self) -> Self::Output {
627        self.or(rhs)
628    }
629}
630
631impl<T> Not for IRBexpr<T> {
632    type Output = Self;
633
634    fn not(self) -> Self::Output {
635        match self.0 {
636            IRBexprImpl::Not(e) => *e,
637            e => Self(IRBexprImpl::Not(Box::new(Self(e)))),
638        }
639    }
640}
641
642impl<T: std::fmt::Debug> std::fmt::Debug for IRBexprImpl<T> {
643    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644        match self {
645            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => write!(f, "({cmp_op} {lhs:?} {rhs:?})",),
646            IRBexprImpl::And(exprs) => write!(f, "(&& {exprs:?})"),
647            IRBexprImpl::Or(exprs) => write!(f, "(|| {exprs:?})"),
648            IRBexprImpl::Not(expr) => write!(f, "(! {expr:?})"),
649            IRBexprImpl::True => write!(f, "(true)"),
650            IRBexprImpl::False => write!(f, "(false)"),
651            IRBexprImpl::Det(expr) => write!(f, "(det {expr:?})"),
652            IRBexprImpl::Implies(lhs, rhs) => write!(f, "(=> {lhs:?} {rhs:?})"),
653            IRBexprImpl::Iff(lhs, rhs) => write!(f, "(<=> {lhs:?} {rhs:?})"),
654        }
655    }
656}
657
658impl<T: Clone> Clone for IRBexpr<T> {
659    fn clone(&self) -> Self {
660        Self(match &self.0 {
661            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
662                IRBexprImpl::Cmp(*cmp_op, lhs.clone(), rhs.clone())
663            }
664            IRBexprImpl::And(exprs) => IRBexprImpl::And(exprs.clone()),
665            IRBexprImpl::Or(exprs) => IRBexprImpl::Or(exprs.clone()),
666            IRBexprImpl::Not(expr) => IRBexprImpl::Not(expr.clone()),
667            IRBexprImpl::True => IRBexprImpl::True,
668            IRBexprImpl::False => IRBexprImpl::False,
669            IRBexprImpl::Det(expr) => IRBexprImpl::Det(expr.clone()),
670            IRBexprImpl::Implies(lhs, rhs) => IRBexprImpl::Implies(lhs.clone(), rhs.clone()),
671            IRBexprImpl::Iff(lhs, rhs) => IRBexprImpl::Iff(lhs.clone(), rhs.clone()),
672        })
673    }
674}
675
676impl<T: PartialEq> PartialEq for IRBexpr<T> {
677    fn eq(&self, other: &Self) -> bool {
678        match (&self.0, &other.0) {
679            (IRBexprImpl::Cmp(op1, lhs1, rhs1), IRBexprImpl::Cmp(op2, lhs2, rhs2)) => {
680                op1 == op2 && lhs1 == lhs2 && rhs1 == rhs2
681            }
682            (IRBexprImpl::And(lhs), IRBexprImpl::And(rhs)) => lhs == rhs,
683            (IRBexprImpl::Or(lhs), IRBexprImpl::Or(rhs)) => lhs == rhs,
684            (IRBexprImpl::Not(lhs), IRBexprImpl::Not(rhs)) => lhs == rhs,
685            (IRBexprImpl::True, IRBexprImpl::True) => true,
686            (IRBexprImpl::False, IRBexprImpl::False) => true,
687            (IRBexprImpl::Det(lhs), IRBexprImpl::Det(rhs)) => lhs == rhs,
688            (IRBexprImpl::Implies(lhs1, rhs1), IRBexprImpl::Implies(lhs2, rhs2)) => {
689                lhs1 == lhs2 && rhs1 == rhs2
690            }
691            (IRBexprImpl::Iff(lhs1, rhs1), IRBexprImpl::Iff(lhs2, rhs2)) => {
692                lhs1 == lhs2 && rhs1 == rhs2
693            }
694            _ => false,
695        }
696    }
697}
698
699fn reduce_bool_expr<A, L>(
700    exprs: impl IntoIterator<Item = IRBexpr<A>>,
701    l: &L,
702    cb: impl Fn(&L, &L::CellOutput, &L::CellOutput) -> haloumi_lowering::Result<L::CellOutput>,
703) -> haloumi_lowering::Result<L::CellOutput>
704where
705    A: LowerableExpr,
706    L: ExprLowering + ?Sized,
707{
708    exprs
709        .into_iter()
710        .map(|e| e.lower(l))
711        .reduce(|lhs, rhs| lhs.and_then(|lhs| rhs.and_then(|rhs| cb(l, &lhs, &rhs))))
712        .ok_or_else(|| lowering_err!(Error::EmptyBexpr))
713        .and_then(identity)
714}
715
716impl<A: LowerableExpr> LowerableExpr for IRBexpr<A> {
717    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<L::CellOutput>
718    where
719        L: ExprLowering + ?Sized,
720    {
721        match self.0 {
722            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
723                let lhs = lhs.lower(l)?;
724                let rhs = rhs.lower(l)?;
725                match cmp_op {
726                    CmpOp::Eq => l.lower_eq(&lhs, &rhs),
727                    CmpOp::Lt => l.lower_lt(&lhs, &rhs),
728                    CmpOp::Le => l.lower_le(&lhs, &rhs),
729                    CmpOp::Gt => l.lower_gt(&lhs, &rhs),
730                    CmpOp::Ge => l.lower_ge(&lhs, &rhs),
731                    CmpOp::Ne => l.lower_ne(&lhs, &rhs),
732                }
733            }
734            IRBexprImpl::And(exprs) => reduce_bool_expr(exprs, l, L::lower_and),
735            IRBexprImpl::Or(exprs) => reduce_bool_expr(exprs, l, L::lower_or),
736            IRBexprImpl::Not(expr) => expr.lower(l).and_then(|e| l.lower_not(&e)),
737            IRBexprImpl::True => l.lower_true(),
738            IRBexprImpl::False => l.lower_false(),
739            IRBexprImpl::Det(expr) => expr.lower(l).and_then(|e| l.lower_det(&e)),
740            IRBexprImpl::Implies(lhs, rhs) => {
741                let lhs = lhs.lower(l)?;
742                let rhs = rhs.lower(l)?;
743                l.lower_implies(&lhs, &rhs)
744            }
745            IRBexprImpl::Iff(lhs, rhs) => {
746                let lhs = lhs.lower(l)?;
747                let rhs = rhs.lower(l)?;
748                l.lower_iff(&lhs, &rhs)
749            }
750        }
751    }
752}
753
754impl<T: IRPrintable> IRPrintable for IRBexpr<T> {
755    fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
756        match &self.0 {
757            IRBexprImpl::True => write!(ctx, "(true)"),
758            IRBexprImpl::False => write!(ctx, "(false)"),
759            IRBexprImpl::Cmp(cmp_op, lhs, rhs) => ctx.block(format!("{cmp_op}").as_str(), |ctx| {
760                if lhs.depth() > 1 {
761                    ctx.nl()?;
762                }
763                lhs.fmt(ctx)?;
764                if lhs.depth() > 1 || rhs.depth() > 1 {
765                    ctx.nl()?;
766                }
767                rhs.fmt(ctx)
768            }),
769            IRBexprImpl::And(exprs) => ctx.block("&&", |ctx| {
770                let do_nl = exprs.iter().any(|expr| expr.depth() > 1);
771                let mut is_first = true;
772                for expr in exprs {
773                    if do_nl && !is_first {
774                        ctx.nl()?;
775                    }
776                    is_first = false;
777                    expr.fmt(ctx)?;
778                }
779                Ok(())
780            }),
781            IRBexprImpl::Or(exprs) => ctx.block("||", |ctx| {
782                let do_nl = exprs.iter().any(|expr| expr.depth() > 1);
783                let mut is_first = true;
784                for expr in exprs {
785                    if do_nl && !is_first {
786                        ctx.nl()?;
787                    }
788                    is_first = false;
789                    expr.fmt(ctx)?;
790                }
791                Ok(())
792            }),
793            IRBexprImpl::Not(expr) => ctx.block("!", |ctx| expr.fmt(ctx)),
794            IRBexprImpl::Det(expr) => ctx.block("det", |ctx| expr.fmt(ctx)),
795            IRBexprImpl::Implies(lhs, rhs) => ctx.block("=>", |ctx| {
796                if lhs.depth() > 1 {
797                    ctx.nl()?;
798                }
799                lhs.fmt(ctx)?;
800                if lhs.depth() > 1 || rhs.depth() > 1 {
801                    ctx.nl()?;
802                }
803                rhs.fmt(ctx)
804            }),
805            IRBexprImpl::Iff(lhs, rhs) => ctx.block("<=>", |ctx| {
806                if lhs.depth() > 1 {
807                    ctx.nl()?;
808                }
809                lhs.fmt(ctx)?;
810                if lhs.depth() > 1 || rhs.depth() > 1 {
811                    ctx.nl()?;
812                }
813                rhs.fmt(ctx)
814            }),
815        }
816    }
817
818    fn depth(&self) -> usize {
819        match &self.0 {
820            IRBexprImpl::True | IRBexprImpl::False => 1,
821            IRBexprImpl::Cmp(_, lhs, rhs) => 1 + std::cmp::max(lhs.depth(), rhs.depth()),
822            IRBexprImpl::And(exprs) | IRBexprImpl::Or(exprs) => {
823                1 + exprs
824                    .iter()
825                    .map(|expr| expr.depth())
826                    .max()
827                    .unwrap_or_default()
828            }
829            IRBexprImpl::Not(expr) => 1 + expr.depth(),
830            IRBexprImpl::Det(expr) => 1 + expr.depth(),
831            IRBexprImpl::Implies(lhs, rhs) | IRBexprImpl::Iff(lhs, rhs) => {
832                1 + std::cmp::max(lhs.depth(), rhs.depth())
833            }
834        }
835    }
836}
837
838/// A constant boolean expression.
839///
840/// Is used to guarantee that an expression is constant leveraging the type system.
841/// The only way to build this type is via [`TryInto`].
842#[derive(Debug, Clone, PartialEq)]
843pub struct IRConstBexpr<A>(IRBexpr<A>);
844
845impl<A> IRConstBexpr<A> {
846    #[allow(dead_code)]
847    pub(crate) fn map<O>(expr: IRConstBexpr<O>, f: &mut impl FnMut(O) -> A) -> Self {
848        Self(expr.0.map(f))
849    }
850
851    #[allow(dead_code)]
852    pub(crate) fn map_into<O>(expr: &IRConstBexpr<O>, f: &mut impl FnMut(&O) -> A) -> Self {
853        Self(expr.0.map_into(f))
854    }
855
856    #[allow(dead_code)]
857    pub(crate) fn try_map<O, E>(
858        expr: IRConstBexpr<O>,
859        f: &mut impl FnMut(O) -> Result<A, E>,
860    ) -> Result<Self, E> {
861        Ok(Self(expr.0.try_map(f)?))
862    }
863
864    #[allow(dead_code)]
865    pub(crate) fn map_inplace(expr: &mut Self, f: &mut impl FnMut(&mut A)) {
866        expr.0.map_inplace(f);
867    }
868
869    #[allow(dead_code)]
870    pub(crate) fn try_map_inplace<E>(
871        expr: &mut Self,
872        f: &mut impl FnMut(&mut A) -> Result<(), E>,
873    ) -> Result<(), E> {
874        expr.0.try_map_inplace(f)
875    }
876}
877
878impl<A> Deref for IRConstBexpr<A> {
879    type Target = IRBexpr<A>;
880
881    fn deref(&self) -> &Self::Target {
882        &self.0
883    }
884}
885
886impl<A> DerefMut for IRConstBexpr<A> {
887    fn deref_mut(&mut self) -> &mut Self::Target {
888        &mut self.0
889    }
890}
891
892impl<A> AsRef<IRBexpr<A>> for IRConstBexpr<A> {
893    fn as_ref(&self) -> &IRBexpr<A> {
894        self.deref()
895    }
896}
897
898impl<A> AsMut<IRBexpr<A>> for IRConstBexpr<A> {
899    fn as_mut(&mut self) -> &mut IRBexpr<A> {
900        self.deref_mut()
901    }
902}
903
904impl<A> Borrow<IRBexpr<A>> for IRConstBexpr<A> {
905    fn borrow(&self) -> &IRBexpr<A> {
906        self.deref()
907    }
908}
909
910impl<A> BorrowMut<IRBexpr<A>> for IRConstBexpr<A> {
911    fn borrow_mut(&mut self) -> &mut IRBexpr<A> {
912        self.deref_mut()
913    }
914}
915
916/// Raised when attempting to transform an [`IRBexpr`] into a [`IRConstBexpr`].
917#[derive(Debug, Error, Clone, Copy)]
918#[error("attempted to transform a non constant boolean expression")]
919pub struct NonConstIRBexprError;
920
921impl<A> TryFrom<IRBexpr<A>> for IRConstBexpr<A>
922where
923    IRBexpr<A>: Evaluate<ExprProperties>,
924{
925    type Error = NonConstIRBexprError;
926
927    fn try_from(value: IRBexpr<A>) -> Result<Self, Self::Error> {
928        let props = value.evaluate();
929        if props != ExprProperty::Const {
930            return Err(NonConstIRBexprError);
931        }
932        Ok(Self(value))
933    }
934}
935
936impl<A> From<IRConstBexpr<A>> for IRBexpr<A> {
937    fn from(value: IRConstBexpr<A>) -> Self {
938        value.0
939    }
940}
941
942impl<A> Validatable for IRConstBexpr<A>
943where
944    IRBexpr<A>: Evaluate<ExprProperties>,
945{
946    type Diagnostic = SimpleDiagnostic;
947
948    type Context = ();
949
950    fn validate_with_context(
951        &self,
952        _: &Self::Context,
953    ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
954        let mut validation = Validation::new();
955        if self.0.evaluate() != ExprProperty::Const {
956            validation.with_error(SimpleDiagnostic::error(
957                "boolean expression is not constant",
958            ));
959        }
960        validation.into()
961    }
962}
963
#[cfg(test)]
mod tests {
    use super::*;

    /// The unit type stands in for the arithmetic payload. Folding it does
    /// nothing, so the tests exercise only the boolean folding logic.
    impl ConstantFolding for () {
        type Error = std::convert::Infallible;

        type T = ();

        fn constant_fold(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// The literal `true` over the unit payload.
    fn t() -> IRBexpr<()> {
        true.into()
    }

    /// The literal `false` over the unit payload.
    fn f() -> IRBexpr<()> {
        false.into()
    }

    /// Folds `expr` in place and checks that it ends up equal to `expected`.
    fn assert_folds_to(mut expr: IRBexpr<()>, expected: IRBexpr<()>) {
        expr.constant_fold().unwrap();
        assert_eq!(expr, expected);
    }

    #[test]
    fn constant_fold_not_true() {
        assert_folds_to(!t(), f());
    }

    #[test]
    fn constant_fold_not_false() {
        assert_folds_to(!f(), t());
    }
}