1use crate::diagnostics::{SimpleDiagnostic, Validation};
4use crate::error::Error;
5use crate::expr::{ExprProperties, ExprProperty};
6use crate::printer::IRPrintable;
7use crate::traits::{Canonicalize, ConstantFolding, Evaluate, Validatable};
8use crate::{canon::canonicalize_constraint, expr::IRAexpr};
9use eqv::{EqvRelation, equiv};
10use haloumi_core::cmp::CmpOp;
11use haloumi_core::eqv::SymbolicEqv;
12use haloumi_lowering::lowering_err;
13use haloumi_lowering::{ExprLowering, lowerable::LowerableExpr};
14use std::borrow::{Borrow, BorrowMut};
15use std::ops::{Deref, DerefMut};
16use std::{
17 convert::identity,
18 fmt::Write,
19 ops::{BitAnd, BitOr, Not},
20};
21use thiserror::Error;
22
23#[derive(Debug)]
25pub struct IRBexpr<A>(IRBexprImpl<A>);
26
27enum IRBexprImpl<A> {
28 True,
30 False,
32 Cmp(CmpOp, A, A),
34 And(Vec<IRBexpr<A>>),
36 Or(Vec<IRBexpr<A>>),
38 Not(Box<IRBexpr<A>>),
40 Det(A),
42 Implies(Box<IRBexpr<A>>, Box<IRBexpr<A>>),
44 Iff(Box<IRBexpr<A>>, Box<IRBexpr<A>>),
46}
47
48impl<T> IRBexpr<T> {
49 pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> IRBexpr<O> {
51 match self.0 {
52 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => IRBexpr(IRBexprImpl::Cmp(cmp_op, f(lhs), f(rhs))),
53 IRBexprImpl::And(exprs) => IRBexpr(IRBexprImpl::And(
54 exprs.into_iter().map(|e| e.map(f)).collect(),
55 )),
56 IRBexprImpl::Or(exprs) => IRBexpr(IRBexprImpl::Or(
57 exprs.into_iter().map(|e| e.map(f)).collect(),
58 )),
59 IRBexprImpl::Not(expr) => IRBexpr(IRBexprImpl::Not(Box::new(expr.map(f)))),
60 IRBexprImpl::True => IRBexpr(IRBexprImpl::True),
61 IRBexprImpl::False => IRBexpr(IRBexprImpl::False),
62 IRBexprImpl::Det(expr) => IRBexpr(IRBexprImpl::Det(f(expr))),
63 IRBexprImpl::Implies(lhs, rhs) => IRBexpr(IRBexprImpl::Implies(
64 Box::new(lhs.map(f)),
65 Box::new(rhs.map(f)),
66 )),
67 IRBexprImpl::Iff(lhs, rhs) => {
68 IRBexpr(IRBexprImpl::Iff(Box::new(lhs.map(f)), Box::new(rhs.map(f))))
69 }
70 }
71 }
72
73 pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> IRBexpr<O> {
75 match &self.0 {
76 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
77 IRBexpr(IRBexprImpl::Cmp(*cmp_op, f(lhs), f(rhs)))
78 }
79 IRBexprImpl::And(exprs) => IRBexpr(IRBexprImpl::And(
80 exprs.iter().map(|e| e.map_into(f)).collect(),
81 )),
82 IRBexprImpl::Or(exprs) => IRBexpr(IRBexprImpl::Or(
83 exprs.iter().map(|e| e.map_into(f)).collect(),
84 )),
85 IRBexprImpl::Not(expr) => IRBexpr(IRBexprImpl::Not(Box::new(expr.map_into(f)))),
86 IRBexprImpl::True => IRBexpr(IRBexprImpl::True),
87 IRBexprImpl::False => IRBexpr(IRBexprImpl::False),
88 IRBexprImpl::Det(expr) => IRBexpr(IRBexprImpl::Det(f(expr))),
89 IRBexprImpl::Implies(lhs, rhs) => IRBexpr(IRBexprImpl::Implies(
90 Box::new(lhs.map_into(f)),
91 Box::new(rhs.map_into(f)),
92 )),
93 IRBexprImpl::Iff(lhs, rhs) => IRBexpr(IRBexprImpl::Iff(
94 Box::new(lhs.map_into(f)),
95 Box::new(rhs.map_into(f)),
96 )),
97 }
98 }
99
100 pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<IRBexpr<O>, E> {
102 Ok(match self.0 {
103 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
104 IRBexpr(IRBexprImpl::Cmp(cmp_op, f(lhs)?, f(rhs)?))
105 }
106 IRBexprImpl::And(exprs) => IRBexpr(IRBexprImpl::And(
107 exprs
108 .into_iter()
109 .map(|e| e.try_map(f))
110 .collect::<Result<Vec<_>, _>>()?,
111 )),
112 IRBexprImpl::Or(exprs) => IRBexpr(IRBexprImpl::Or(
113 exprs
114 .into_iter()
115 .map(|e| e.try_map(f))
116 .collect::<Result<Vec<_>, _>>()?,
117 )),
118 IRBexprImpl::Not(expr) => IRBexpr(IRBexprImpl::Not(Box::new(expr.try_map(f)?))),
119 IRBexprImpl::True => IRBexpr(IRBexprImpl::True),
120 IRBexprImpl::False => IRBexpr(IRBexprImpl::False),
121 IRBexprImpl::Det(expr) => IRBexpr(IRBexprImpl::Det(f(expr)?)),
122 IRBexprImpl::Implies(lhs, rhs) => IRBexpr(IRBexprImpl::Implies(
123 Box::new(lhs.try_map(f)?),
124 Box::new(rhs.try_map(f)?),
125 )),
126 IRBexprImpl::Iff(lhs, rhs) => IRBexpr(IRBexprImpl::Iff(
127 Box::new(lhs.try_map(f)?),
128 Box::new(rhs.try_map(f)?),
129 )),
130 })
131 }
132
133 pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
135 match &mut self.0 {
136 IRBexprImpl::Cmp(_, lhs, rhs) => {
137 f(lhs);
138 f(rhs);
139 }
140 IRBexprImpl::And(exprs) => {
141 for expr in exprs {
142 expr.map_inplace(f);
143 }
144 }
145 IRBexprImpl::Or(exprs) => {
146 for expr in exprs {
147 expr.map_inplace(f);
148 }
149 }
150 IRBexprImpl::Not(expr) => expr.map_inplace(f),
151 IRBexprImpl::True => {}
152 IRBexprImpl::False => {}
153 IRBexprImpl::Det(expr) => f(expr),
154 IRBexprImpl::Implies(lhs, rhs) => {
155 lhs.map_inplace(f);
156 rhs.map_inplace(f);
157 }
158 IRBexprImpl::Iff(lhs, rhs) => {
159 lhs.map_inplace(f);
160 rhs.map_inplace(f);
161 }
162 }
163 }
164
165 pub fn try_map_inplace<E>(
167 &mut self,
168 f: &mut impl FnMut(&mut T) -> Result<(), E>,
169 ) -> Result<(), E> {
170 match &mut self.0 {
171 IRBexprImpl::Cmp(_, lhs, rhs) => {
172 f(lhs)?;
173 f(rhs)
174 }
175 IRBexprImpl::And(exprs) => {
176 for expr in exprs {
177 expr.try_map_inplace(f)?;
178 }
179 Ok(())
180 }
181 IRBexprImpl::Or(exprs) => {
182 for expr in exprs {
183 expr.try_map_inplace(f)?;
184 }
185 Ok(())
186 }
187 IRBexprImpl::Not(expr) => expr.try_map_inplace(f),
188 IRBexprImpl::True => Ok(()),
189 IRBexprImpl::False => Ok(()),
190 IRBexprImpl::Det(expr) => f(expr),
191 IRBexprImpl::Implies(lhs, rhs) => {
192 lhs.try_map_inplace(f)?;
193 rhs.try_map_inplace(f)
194 }
195 IRBexprImpl::Iff(lhs, rhs) => {
196 lhs.try_map_inplace(f)?;
197 rhs.try_map_inplace(f)
198 }
199 }
200 }
201
202 pub(crate) fn cmp(op: CmpOp, lhs: T, rhs: T) -> Self {
203 Self(IRBexprImpl::Cmp(op, lhs, rhs))
204 }
205
206 pub fn det(expr: T) -> Self {
208 Self(IRBexprImpl::Det(expr))
209 }
210
211 #[inline]
212 pub fn eq(lhs: T, rhs: T) -> Self {
214 Self(IRBexprImpl::Cmp(CmpOp::Eq, lhs, rhs))
215 }
216
217 #[inline]
218 pub fn lt(lhs: T, rhs: T) -> Self {
220 Self(IRBexprImpl::Cmp(CmpOp::Lt, lhs, rhs))
221 }
222
223 #[inline]
224 pub fn le(lhs: T, rhs: T) -> Self {
226 Self(IRBexprImpl::Cmp(CmpOp::Le, lhs, rhs))
227 }
228
229 #[inline]
230 pub fn gt(lhs: T, rhs: T) -> Self {
232 Self(IRBexprImpl::Cmp(CmpOp::Gt, lhs, rhs))
233 }
234
235 #[inline]
236 pub fn ge(lhs: T, rhs: T) -> Self {
238 Self(IRBexprImpl::Cmp(CmpOp::Ge, lhs, rhs))
239 }
240
241 #[inline]
242 pub fn implies(self, rhs: Self) -> Self {
244 Self(IRBexprImpl::Implies(Box::new(self), Box::new(rhs)))
245 }
246
247 #[inline]
248 pub fn iff(self, rhs: Self) -> Self {
250 Self(IRBexprImpl::Iff(Box::new(self), Box::new(rhs)))
251 }
252
253 pub fn and(self, rhs: Self) -> Self {
255 Self(match (self.0, rhs.0) {
256 (IRBexprImpl::And(mut lhs), IRBexprImpl::And(rhs)) => {
257 lhs.reserve(rhs.len());
258 lhs.extend(rhs);
259 IRBexprImpl::And(lhs)
260 }
261 (exp, IRBexprImpl::And(mut lst)) | (IRBexprImpl::And(mut lst), exp) => {
263 lst.push(Self(exp));
264 IRBexprImpl::And(lst)
265 }
266 (lhs, rhs) => IRBexprImpl::And(vec![Self(lhs), Self(rhs)]),
267 })
268 }
269
270 pub fn and_many(exprs: impl IntoIterator<Item = Self>) -> Self {
272 Self(IRBexprImpl::And(exprs.into_iter().collect()))
273 }
274
275 pub fn or(self, rhs: Self) -> Self {
277 Self(match (self.0, rhs.0) {
278 (IRBexprImpl::Or(mut lhs), IRBexprImpl::Or(rhs)) => {
279 lhs.reserve(rhs.len());
280 lhs.extend(rhs);
281 IRBexprImpl::Or(lhs)
282 }
283 (exp, IRBexprImpl::Or(mut lst)) | (IRBexprImpl::Or(mut lst), exp) => {
285 lst.push(Self(exp));
286 IRBexprImpl::Or(lst)
287 }
288 (lhs, rhs) => IRBexprImpl::Or(vec![Self(lhs), Self(rhs)]),
289 })
290 }
291
292 pub fn or_many(exprs: impl IntoIterator<Item = Self>) -> Self {
294 Self(IRBexprImpl::Or(exprs.into_iter().collect()))
295 }
296
297 pub fn with<O>(self, other: O) -> IRBexpr<(O, T)>
299 where
300 O: Clone,
301 {
302 self.map(&mut |t| (other.clone(), t))
303 }
304
305 pub fn with_fn<O>(self, other: impl Fn() -> O) -> IRBexpr<(O, T)> {
307 self.map(&mut |t| (other(), t))
308 }
309}
310
311struct LogLine {
312 before: Option<String>,
313 ident: usize,
314}
315
316impl LogLine {
317 fn new<T: std::fmt::Debug>(expr: &IRBexprImpl<T>, ident: usize) -> Self {
318 if matches!(
319 expr,
320 IRBexprImpl::True | IRBexprImpl::False | IRBexprImpl::Cmp(_, _, _)
321 ) {
322 Self {
323 before: Some(format!("{expr:?}")),
324 ident,
325 }
326 } else {
327 log::debug!("[constant_fold] {:ident$} {expr:?} {{", "", ident = ident);
328 Self {
329 before: None,
330 ident,
331 }
332 }
333 }
334
335 fn log<T: std::fmt::Debug>(self, expr: &mut IRBexprImpl<T>) {
336 match self.before {
337 Some(before) => {
338 log::debug!(
339 "[constant_fold] {:ident$} {} -> {expr:?}",
340 "",
341 before,
342 ident = self.ident
343 );
344 }
345 None => {
346 log::debug!(
347 "[constant_fold] {:ident$} }} -> {expr:?}",
348 "",
349 ident = self.ident
350 );
351 }
352 }
353 }
354}
355
356impl Canonicalize for IRBexpr<IRAexpr> {
357 fn canonicalize(&mut self) {
359 match &mut self.0 {
360 IRBexprImpl::True => {}
361 IRBexprImpl::False => {}
362 IRBexprImpl::Cmp(op, lhs, rhs) => {
363 if let Some((op, lhs, rhs)) = canonicalize_constraint(*op, lhs, rhs) {
364 *self = Self(IRBexprImpl::Cmp(op, lhs, rhs));
365 }
366 }
367 IRBexprImpl::And(exprs) => {
368 for expr in exprs {
369 expr.canonicalize();
370 }
371 }
372 IRBexprImpl::Or(exprs) => {
373 for expr in exprs {
374 expr.canonicalize();
375 }
376 }
377 IRBexprImpl::Not(expr) => {
378 expr.canonicalize();
379 match &expr.0 {
380 IRBexprImpl::True => {
381 *self = Self(IRBexprImpl::False);
382 }
383 IRBexprImpl::False => {
384 *self = Self(IRBexprImpl::True);
385 }
386 IRBexprImpl::Cmp(op, lhs, rhs) => {
387 *self = Self(IRBexprImpl::Cmp(
388 match op {
389 CmpOp::Eq => CmpOp::Ne,
390 CmpOp::Lt => CmpOp::Ge,
391 CmpOp::Le => CmpOp::Gt,
392 CmpOp::Gt => CmpOp::Le,
393 CmpOp::Ge => CmpOp::Lt,
394 CmpOp::Ne => CmpOp::Eq,
395 },
396 lhs.clone(),
397 rhs.clone(),
398 ));
399 self.canonicalize();
400 }
401 _ => {}
402 }
403 }
404 IRBexprImpl::Det(_) => {}
405 IRBexprImpl::Implies(lhs, rhs) => {
406 lhs.canonicalize();
407 rhs.canonicalize();
408 }
409 IRBexprImpl::Iff(lhs, rhs) => {
410 lhs.canonicalize();
411 rhs.canonicalize();
412 }
413 }
414 }
415}
416
417impl<T> IRBexpr<T>
418where
419 T: ConstantFolding + std::fmt::Debug,
420 T::T: Eq + Ord,
421{
422 fn constant_fold_impl(&mut self, indent: usize) -> Result<(), T::Error> {
424 let log = LogLine::new(&self.0, indent);
425 match &mut self.0 {
426 IRBexprImpl::True => {
427 log.log(&mut self.0);
428 }
429 IRBexprImpl::False => {
430 log.log(&mut self.0);
431 }
432 IRBexprImpl::Cmp(op, lhs, rhs) => {
433 lhs.constant_fold()?;
434 rhs.constant_fold()?;
435 if let Some((lhs, rhs)) = lhs.const_value().zip(rhs.const_value()) {
436 *self = match op {
437 CmpOp::Eq => lhs == rhs,
438 CmpOp::Lt => lhs < rhs,
439 CmpOp::Le => lhs <= rhs,
440 CmpOp::Gt => lhs > rhs,
441 CmpOp::Ge => lhs >= rhs,
442 CmpOp::Ne => lhs != rhs,
443 }
444 .into()
445 }
446 log.log(&mut self.0);
447 }
448 IRBexprImpl::And(exprs) => {
449 for expr in &mut *exprs {
450 expr.constant_fold_impl(indent + 2)?;
451 }
452 if exprs.iter().any(|expr| {
454 expr.const_value()
455 .map(|b| !b)
457 .unwrap_or_default()
459 }) {
460 *self = Self(IRBexprImpl::False);
461 log.log(&mut self.0);
462 return Ok(());
463 }
464 exprs.retain(|expr| {
466 expr.const_value()
467 .map(|b| !b)
469 .unwrap_or(true)
471 });
472 if exprs.is_empty() {
473 *self = Self(IRBexprImpl::True);
474 }
475 log.log(&mut self.0);
476 }
477 IRBexprImpl::Or(exprs) => {
478 for expr in &mut *exprs {
479 expr.constant_fold_impl(indent + 2)?;
480 }
481 if exprs
483 .iter()
484 .any(|expr| expr.const_value().unwrap_or_default())
485 {
486 *self = Self(IRBexprImpl::True);
487 log.log(&mut self.0);
488 return Ok(());
489 }
490 exprs.retain(|expr| {
492 expr.const_value()
493 .unwrap_or(true)
495 });
496 if exprs.is_empty() {
497 *self = Self(IRBexprImpl::False);
498 }
499 log.log(&mut self.0);
500 }
501 IRBexprImpl::Not(expr) => {
502 expr.constant_fold_impl(indent + 2)?;
503 if let Some(b) = expr.const_value() {
504 *self = (!b).into();
505 }
506 log.log(&mut self.0);
507 }
508 IRBexprImpl::Det(expr) => expr.constant_fold()?,
509 IRBexprImpl::Implies(lhs, rhs) => {
510 lhs.constant_fold_impl(indent + 2)?;
511 rhs.constant_fold_impl(indent + 2)?;
512 if let Some((lhs, rhs)) = lhs.const_value().zip(rhs.const_value()) {
513 *self = (!lhs || rhs).into();
514 }
515 }
516 IRBexprImpl::Iff(lhs, rhs) => {
517 lhs.constant_fold_impl(indent + 2)?;
518 rhs.constant_fold_impl(indent + 2)?;
519 if let Some((lhs, rhs)) = lhs.const_value().zip(rhs.const_value()) {
520 *self = (lhs == rhs).into();
521 }
522 }
523 }
524 Ok(())
525 }
526}
527
528impl<T> ConstantFolding for IRBexpr<T>
529where
530 T: ConstantFolding + std::fmt::Debug,
531 T::T: Eq + Ord,
532{
533 type T = bool;
534
535 type Error = T::Error;
536
537 fn constant_fold(&mut self) -> Result<(), Self::Error> {
538 self.constant_fold_impl(0)
539 }
540
541 fn const_value(&self) -> Option<bool> {
543 match &self.0 {
544 IRBexprImpl::True => Some(true),
545 IRBexprImpl::False => Some(false),
546 _ => None,
547 }
548 }
549}
550
551impl<T: Evaluate<ExprProperties>> Evaluate<ExprProperties> for IRBexpr<T> {
552 fn evaluate(&self) -> ExprProperties {
553 match &self.0 {
554 IRBexprImpl::True | IRBexprImpl::False => ExprProperty::Const.into(),
555 IRBexprImpl::Cmp(_, lhs, rhs) => lhs.evaluate() & rhs.evaluate(),
556 IRBexprImpl::And(exprs) | IRBexprImpl::Or(exprs) => {
557 exprs.iter().map(Evaluate::evaluate).product()
558 }
559 IRBexprImpl::Not(expr) => expr.evaluate(),
560 IRBexprImpl::Det(expr) => expr.evaluate(),
561 IRBexprImpl::Implies(lhs, rhs) | IRBexprImpl::Iff(lhs, rhs) => {
562 lhs.evaluate() & rhs.evaluate()
563 }
564 }
565 }
566}
567
568impl<T> From<bool> for IRBexpr<T> {
569 fn from(value: bool) -> Self {
570 Self(if value {
571 IRBexprImpl::True
572 } else {
573 IRBexprImpl::False
574 })
575 }
576}
577
578impl<L, R> EqvRelation<IRBexpr<L>, IRBexpr<R>> for SymbolicEqv
580where
581 SymbolicEqv: EqvRelation<L, R>,
582{
583 fn equivalent(lhs: &IRBexpr<L>, rhs: &IRBexpr<R>) -> bool {
586 match (&lhs.0, &rhs.0) {
587 (IRBexprImpl::True, IRBexprImpl::True) | (IRBexprImpl::False, IRBexprImpl::False) => {
588 true
589 }
590 (IRBexprImpl::Cmp(op1, lhs1, rhs1), IRBexprImpl::Cmp(op2, lhs2, rhs2)) => {
591 op1 == op2 && equiv!(Self | lhs1, lhs2) && equiv!(Self | rhs1, rhs2)
592 }
593 (IRBexprImpl::And(lhs), IRBexprImpl::And(rhs)) => {
594 equiv!(Self | lhs, rhs)
595 }
596 (IRBexprImpl::Or(lhs), IRBexprImpl::Or(rhs)) => {
597 equiv!(Self | lhs, rhs)
598 }
599 (IRBexprImpl::Not(lhs), IRBexprImpl::Not(rhs)) => {
600 equiv!(Self | lhs, rhs)
601 }
602 (IRBexprImpl::Det(lhs), IRBexprImpl::Det(rhs)) => equiv!(Self | lhs, rhs),
603 (IRBexprImpl::Implies(lhs1, rhs1), IRBexprImpl::Implies(lhs2, rhs2)) => {
604 equiv!(Self | lhs1, lhs2) && equiv!(Self | rhs1, rhs2)
605 }
606
607 (IRBexprImpl::Iff(lhs1, rhs1), IRBexprImpl::Iff(lhs2, rhs2)) => {
608 equiv!(Self | lhs1, lhs2) && equiv!(Self | rhs1, rhs2)
609 }
610 _ => false,
611 }
612 }
613}
614
615impl<T> BitAnd for IRBexpr<T> {
616 type Output = Self;
617
618 fn bitand(self, rhs: Self) -> Self::Output {
619 self.and(rhs)
620 }
621}
622
623impl<T> BitOr for IRBexpr<T> {
624 type Output = Self;
625
626 fn bitor(self, rhs: Self) -> Self::Output {
627 self.or(rhs)
628 }
629}
630
631impl<T> Not for IRBexpr<T> {
632 type Output = Self;
633
634 fn not(self) -> Self::Output {
635 match self.0 {
636 IRBexprImpl::Not(e) => *e,
637 e => Self(IRBexprImpl::Not(Box::new(Self(e)))),
638 }
639 }
640}
641
642impl<T: std::fmt::Debug> std::fmt::Debug for IRBexprImpl<T> {
643 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644 match self {
645 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => write!(f, "({cmp_op} {lhs:?} {rhs:?})",),
646 IRBexprImpl::And(exprs) => write!(f, "(&& {exprs:?})"),
647 IRBexprImpl::Or(exprs) => write!(f, "(|| {exprs:?})"),
648 IRBexprImpl::Not(expr) => write!(f, "(! {expr:?})"),
649 IRBexprImpl::True => write!(f, "(true)"),
650 IRBexprImpl::False => write!(f, "(false)"),
651 IRBexprImpl::Det(expr) => write!(f, "(det {expr:?})"),
652 IRBexprImpl::Implies(lhs, rhs) => write!(f, "(=> {lhs:?} {rhs:?})"),
653 IRBexprImpl::Iff(lhs, rhs) => write!(f, "(<=> {lhs:?} {rhs:?})"),
654 }
655 }
656}
657
658impl<T: Clone> Clone for IRBexpr<T> {
659 fn clone(&self) -> Self {
660 Self(match &self.0 {
661 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
662 IRBexprImpl::Cmp(*cmp_op, lhs.clone(), rhs.clone())
663 }
664 IRBexprImpl::And(exprs) => IRBexprImpl::And(exprs.clone()),
665 IRBexprImpl::Or(exprs) => IRBexprImpl::Or(exprs.clone()),
666 IRBexprImpl::Not(expr) => IRBexprImpl::Not(expr.clone()),
667 IRBexprImpl::True => IRBexprImpl::True,
668 IRBexprImpl::False => IRBexprImpl::False,
669 IRBexprImpl::Det(expr) => IRBexprImpl::Det(expr.clone()),
670 IRBexprImpl::Implies(lhs, rhs) => IRBexprImpl::Implies(lhs.clone(), rhs.clone()),
671 IRBexprImpl::Iff(lhs, rhs) => IRBexprImpl::Iff(lhs.clone(), rhs.clone()),
672 })
673 }
674}
675
676impl<T: PartialEq> PartialEq for IRBexpr<T> {
677 fn eq(&self, other: &Self) -> bool {
678 match (&self.0, &other.0) {
679 (IRBexprImpl::Cmp(op1, lhs1, rhs1), IRBexprImpl::Cmp(op2, lhs2, rhs2)) => {
680 op1 == op2 && lhs1 == lhs2 && rhs1 == rhs2
681 }
682 (IRBexprImpl::And(lhs), IRBexprImpl::And(rhs)) => lhs == rhs,
683 (IRBexprImpl::Or(lhs), IRBexprImpl::Or(rhs)) => lhs == rhs,
684 (IRBexprImpl::Not(lhs), IRBexprImpl::Not(rhs)) => lhs == rhs,
685 (IRBexprImpl::True, IRBexprImpl::True) => true,
686 (IRBexprImpl::False, IRBexprImpl::False) => true,
687 (IRBexprImpl::Det(lhs), IRBexprImpl::Det(rhs)) => lhs == rhs,
688 (IRBexprImpl::Implies(lhs1, rhs1), IRBexprImpl::Implies(lhs2, rhs2)) => {
689 lhs1 == lhs2 && rhs1 == rhs2
690 }
691 (IRBexprImpl::Iff(lhs1, rhs1), IRBexprImpl::Iff(lhs2, rhs2)) => {
692 lhs1 == lhs2 && rhs1 == rhs2
693 }
694 _ => false,
695 }
696 }
697}
698
699fn reduce_bool_expr<A, L>(
700 exprs: impl IntoIterator<Item = IRBexpr<A>>,
701 l: &L,
702 cb: impl Fn(&L, &L::CellOutput, &L::CellOutput) -> haloumi_lowering::Result<L::CellOutput>,
703) -> haloumi_lowering::Result<L::CellOutput>
704where
705 A: LowerableExpr,
706 L: ExprLowering + ?Sized,
707{
708 exprs
709 .into_iter()
710 .map(|e| e.lower(l))
711 .reduce(|lhs, rhs| lhs.and_then(|lhs| rhs.and_then(|rhs| cb(l, &lhs, &rhs))))
712 .ok_or_else(|| lowering_err!(Error::EmptyBexpr))
713 .and_then(identity)
714}
715
716impl<A: LowerableExpr> LowerableExpr for IRBexpr<A> {
717 fn lower<L>(self, l: &L) -> haloumi_lowering::Result<L::CellOutput>
718 where
719 L: ExprLowering + ?Sized,
720 {
721 match self.0 {
722 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => {
723 let lhs = lhs.lower(l)?;
724 let rhs = rhs.lower(l)?;
725 match cmp_op {
726 CmpOp::Eq => l.lower_eq(&lhs, &rhs),
727 CmpOp::Lt => l.lower_lt(&lhs, &rhs),
728 CmpOp::Le => l.lower_le(&lhs, &rhs),
729 CmpOp::Gt => l.lower_gt(&lhs, &rhs),
730 CmpOp::Ge => l.lower_ge(&lhs, &rhs),
731 CmpOp::Ne => l.lower_ne(&lhs, &rhs),
732 }
733 }
734 IRBexprImpl::And(exprs) => reduce_bool_expr(exprs, l, L::lower_and),
735 IRBexprImpl::Or(exprs) => reduce_bool_expr(exprs, l, L::lower_or),
736 IRBexprImpl::Not(expr) => expr.lower(l).and_then(|e| l.lower_not(&e)),
737 IRBexprImpl::True => l.lower_true(),
738 IRBexprImpl::False => l.lower_false(),
739 IRBexprImpl::Det(expr) => expr.lower(l).and_then(|e| l.lower_det(&e)),
740 IRBexprImpl::Implies(lhs, rhs) => {
741 let lhs = lhs.lower(l)?;
742 let rhs = rhs.lower(l)?;
743 l.lower_implies(&lhs, &rhs)
744 }
745 IRBexprImpl::Iff(lhs, rhs) => {
746 let lhs = lhs.lower(l)?;
747 let rhs = rhs.lower(l)?;
748 l.lower_iff(&lhs, &rhs)
749 }
750 }
751 }
752}
753
754impl<T: IRPrintable> IRPrintable for IRBexpr<T> {
755 fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
756 match &self.0 {
757 IRBexprImpl::True => write!(ctx, "(true)"),
758 IRBexprImpl::False => write!(ctx, "(false)"),
759 IRBexprImpl::Cmp(cmp_op, lhs, rhs) => ctx.block(format!("{cmp_op}").as_str(), |ctx| {
760 if lhs.depth() > 1 {
761 ctx.nl()?;
762 }
763 lhs.fmt(ctx)?;
764 if lhs.depth() > 1 || rhs.depth() > 1 {
765 ctx.nl()?;
766 }
767 rhs.fmt(ctx)
768 }),
769 IRBexprImpl::And(exprs) => ctx.block("&&", |ctx| {
770 let do_nl = exprs.iter().any(|expr| expr.depth() > 1);
771 let mut is_first = true;
772 for expr in exprs {
773 if do_nl && !is_first {
774 ctx.nl()?;
775 }
776 is_first = false;
777 expr.fmt(ctx)?;
778 }
779 Ok(())
780 }),
781 IRBexprImpl::Or(exprs) => ctx.block("||", |ctx| {
782 let do_nl = exprs.iter().any(|expr| expr.depth() > 1);
783 let mut is_first = true;
784 for expr in exprs {
785 if do_nl && !is_first {
786 ctx.nl()?;
787 }
788 is_first = false;
789 expr.fmt(ctx)?;
790 }
791 Ok(())
792 }),
793 IRBexprImpl::Not(expr) => ctx.block("!", |ctx| expr.fmt(ctx)),
794 IRBexprImpl::Det(expr) => ctx.block("det", |ctx| expr.fmt(ctx)),
795 IRBexprImpl::Implies(lhs, rhs) => ctx.block("=>", |ctx| {
796 if lhs.depth() > 1 {
797 ctx.nl()?;
798 }
799 lhs.fmt(ctx)?;
800 if lhs.depth() > 1 || rhs.depth() > 1 {
801 ctx.nl()?;
802 }
803 rhs.fmt(ctx)
804 }),
805 IRBexprImpl::Iff(lhs, rhs) => ctx.block("<=>", |ctx| {
806 if lhs.depth() > 1 {
807 ctx.nl()?;
808 }
809 lhs.fmt(ctx)?;
810 if lhs.depth() > 1 || rhs.depth() > 1 {
811 ctx.nl()?;
812 }
813 rhs.fmt(ctx)
814 }),
815 }
816 }
817
818 fn depth(&self) -> usize {
819 match &self.0 {
820 IRBexprImpl::True | IRBexprImpl::False => 1,
821 IRBexprImpl::Cmp(_, lhs, rhs) => 1 + std::cmp::max(lhs.depth(), rhs.depth()),
822 IRBexprImpl::And(exprs) | IRBexprImpl::Or(exprs) => {
823 1 + exprs
824 .iter()
825 .map(|expr| expr.depth())
826 .max()
827 .unwrap_or_default()
828 }
829 IRBexprImpl::Not(expr) => 1 + expr.depth(),
830 IRBexprImpl::Det(expr) => 1 + expr.depth(),
831 IRBexprImpl::Implies(lhs, rhs) | IRBexprImpl::Iff(lhs, rhs) => {
832 1 + std::cmp::max(lhs.depth(), rhs.depth())
833 }
834 }
835 }
836}
837
838#[derive(Debug, Clone, PartialEq)]
843pub struct IRConstBexpr<A>(IRBexpr<A>);
844
845impl<A> IRConstBexpr<A> {
846 #[allow(dead_code)]
847 pub(crate) fn map<O>(expr: IRConstBexpr<O>, f: &mut impl FnMut(O) -> A) -> Self {
848 Self(expr.0.map(f))
849 }
850
851 #[allow(dead_code)]
852 pub(crate) fn map_into<O>(expr: &IRConstBexpr<O>, f: &mut impl FnMut(&O) -> A) -> Self {
853 Self(expr.0.map_into(f))
854 }
855
856 #[allow(dead_code)]
857 pub(crate) fn try_map<O, E>(
858 expr: IRConstBexpr<O>,
859 f: &mut impl FnMut(O) -> Result<A, E>,
860 ) -> Result<Self, E> {
861 Ok(Self(expr.0.try_map(f)?))
862 }
863
864 #[allow(dead_code)]
865 pub(crate) fn map_inplace(expr: &mut Self, f: &mut impl FnMut(&mut A)) {
866 expr.0.map_inplace(f);
867 }
868
869 #[allow(dead_code)]
870 pub(crate) fn try_map_inplace<E>(
871 expr: &mut Self,
872 f: &mut impl FnMut(&mut A) -> Result<(), E>,
873 ) -> Result<(), E> {
874 expr.0.try_map_inplace(f)
875 }
876}
877
878impl<A> Deref for IRConstBexpr<A> {
879 type Target = IRBexpr<A>;
880
881 fn deref(&self) -> &Self::Target {
882 &self.0
883 }
884}
885
886impl<A> DerefMut for IRConstBexpr<A> {
887 fn deref_mut(&mut self) -> &mut Self::Target {
888 &mut self.0
889 }
890}
891
892impl<A> AsRef<IRBexpr<A>> for IRConstBexpr<A> {
893 fn as_ref(&self) -> &IRBexpr<A> {
894 self.deref()
895 }
896}
897
898impl<A> AsMut<IRBexpr<A>> for IRConstBexpr<A> {
899 fn as_mut(&mut self) -> &mut IRBexpr<A> {
900 self.deref_mut()
901 }
902}
903
904impl<A> Borrow<IRBexpr<A>> for IRConstBexpr<A> {
905 fn borrow(&self) -> &IRBexpr<A> {
906 self.deref()
907 }
908}
909
910impl<A> BorrowMut<IRBexpr<A>> for IRConstBexpr<A> {
911 fn borrow_mut(&mut self) -> &mut IRBexpr<A> {
912 self.deref_mut()
913 }
914}
915
916#[derive(Debug, Error, Clone, Copy)]
918#[error("attempted to transform a non constant boolean expression")]
919pub struct NonConstIRBexprError;
920
921impl<A> TryFrom<IRBexpr<A>> for IRConstBexpr<A>
922where
923 IRBexpr<A>: Evaluate<ExprProperties>,
924{
925 type Error = NonConstIRBexprError;
926
927 fn try_from(value: IRBexpr<A>) -> Result<Self, Self::Error> {
928 let props = value.evaluate();
929 if props != ExprProperty::Const {
930 return Err(NonConstIRBexprError);
931 }
932 Ok(Self(value))
933 }
934}
935
936impl<A> From<IRConstBexpr<A>> for IRBexpr<A> {
937 fn from(value: IRConstBexpr<A>) -> Self {
938 value.0
939 }
940}
941
942impl<A> Validatable for IRConstBexpr<A>
943where
944 IRBexpr<A>: Evaluate<ExprProperties>,
945{
946 type Diagnostic = SimpleDiagnostic;
947
948 type Context = ();
949
950 fn validate_with_context(
951 &self,
952 _: &Self::Context,
953 ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
954 let mut validation = Validation::new();
955 if self.0.evaluate() != ExprProperty::Const {
956 validation.with_error(SimpleDiagnostic::error(
957 "boolean expression is not constant",
958 ));
959 }
960 validation.into()
961 }
962}
963
#[cfg(test)]
mod tests {
    use super::*;

    fn t() -> IRBexpr<()> {
        true.into()
    }

    fn f() -> IRBexpr<()> {
        false.into()
    }

    /// A leaf that folding cannot evaluate (a `det` over the unit operand).
    fn opaque() -> IRBexpr<()> {
        IRBexpr::det(())
    }

    /// Constant-folds `expr` and returns the result.
    fn folded(mut expr: IRBexpr<()>) -> IRBexpr<()> {
        expr.constant_fold().unwrap();
        expr
    }

    #[test]
    fn constant_fold_not_true() {
        assert_eq!(folded(!t()), f());
    }

    #[test]
    fn constant_fold_not_false() {
        assert_eq!(folded(!f()), t());
    }

    #[test]
    fn double_negation_cancels() {
        assert_eq!(!!t(), t());
        assert_eq!(!!opaque(), opaque());
    }

    #[test]
    fn constant_fold_and() {
        assert_eq!(folded(t() & t()), t());
        assert_eq!(folded(t() & f()), f());
        assert_eq!(folded(opaque() & f()), f());
        assert_eq!(folded(t() & opaque()), IRBexpr::and_many(vec![opaque()]));
        assert_eq!(folded(IRBexpr::and_many(Vec::new())), t());
    }

    #[test]
    fn constant_fold_or() {
        assert_eq!(folded(f() | f()), f());
        assert_eq!(folded(t() | f()), t());
        assert_eq!(folded(opaque() | t()), t());
        assert_eq!(folded(f() | opaque()), IRBexpr::or_many(vec![opaque()]));
        assert_eq!(folded(IRBexpr::or_many(Vec::new())), f());
    }

    #[test]
    fn constant_fold_implies() {
        assert_eq!(folded(t().implies(t())), t());
        assert_eq!(folded(t().implies(f())), f());
        assert_eq!(folded(f().implies(t())), t());
        assert_eq!(folded(f().implies(f())), t());
    }

    #[test]
    fn constant_fold_iff() {
        assert_eq!(folded(t().iff(t())), t());
        assert_eq!(folded(t().iff(f())), f());
        assert_eq!(folded(f().iff(t())), f());
        assert_eq!(folded(f().iff(f())), t());
    }

    impl ConstantFolding for () {
        type Error = std::convert::Infallible;

        type T = ();

        fn constant_fold(&mut self) -> Result<(), Self::Error> {
            Ok(())
        }
    }
}