Skip to main content

haloumi_ir/
stmt.rs

1//! Structs for representing statements of the circuit's logic.
2
3use super::expr::IRBexpr;
4use crate::{
5    diagnostics::{Diagnostic, Validation},
6    error::Error,
7    expr::{ExprProperties, IRAexpr, IRConstBexpr},
8    meta::{HasMeta, Meta},
9    printer::IRPrintable,
10    traits::{Canonicalize, ConstantFolding, Evaluate, Validatable},
11};
12use eqv::{EqvRelation, equiv};
13use haloumi_core::{cmp::CmpOp, eqv::SymbolicEqv, slot::Slot};
14use haloumi_lowering::{
15    Lowering,
16    lowerable::{LowerableExpr, LowerableStmt},
17};
18use std::fmt::Write;
19
20mod assert;
21mod assume_determ;
22mod block_comment;
23mod call;
24mod comment;
25mod cond_block;
26mod constraint;
27mod post_cond;
28mod seq;
29
30use assert::Assert;
31use assume_determ::AssumeDeterministic;
32use block_comment::BlockComment;
33use call::Call;
34use comment::Comment;
35use cond_block::CondBlock;
36use constraint::Constraint;
37use post_cond::PostCond;
38use seq::Seq;
39
/// Private module holding the supertrait that seals `EmitIf`.
///
/// Code outside this crate cannot name `EmitIfSealed`, so it cannot add its own `EmitIf`
/// implementations.
mod sealed {
    /// Marker supertrait of `EmitIf`; only implemented inside this module's parent.
    pub trait EmitIfSealed {}
}
43
44/// Trait that enables wrapping IR statements into conditionally emitted blocks.
45pub trait EmitIf<T>: sealed::EmitIfSealed {
46    /// Creates a conditional block.
47    fn emit_if(self, cond: IRConstBexpr<T>) -> IRStmt<T>;
48
49    /// Creates a conditional block that only gets folded if the boolean expression folds to false.
50    fn emit_unless_false(self, cond: IRBexpr<T>) -> IRStmt<T>;
51}
52
53impl<T, I> EmitIf<T> for I
54where
55    I: IntoIterator<Item = IRStmt<T>>,
56{
57    fn emit_if(self, cond: IRConstBexpr<T>) -> IRStmt<T> {
58        CondBlock::new(cond.into(), self.into_iter().collect()).into()
59    }
60
61    fn emit_unless_false(self, cond: IRBexpr<T>) -> IRStmt<T> {
62        CondBlock::new(cond, self.into_iter().collect()).into()
63    }
64}
65impl<T, I> sealed::EmitIfSealed for I where I: IntoIterator<Item = IRStmt<T>> {}
66
67/// IR for operations that occur in the main circuit.
68pub struct IRStmt<T>(IRStmtImpl<T>, Meta);
69
70enum IRStmtImpl<T> {
71    /// A call to another module.
72    ConstraintCall(Call<T>),
73    /// A constraint between two expressions.
74    Constraint(Constraint<T>),
75    /// A text comment.
76    Comment(Comment),
77    /// Indicates that a [`Slot`] must be assumed deterministic by the backend.
78    AssumeDeterministic(AssumeDeterministic),
79    /// A constraint that a [`IRBexpr`] must be true.
80    Assert(Assert<T>),
81    /// A sequence of statements.
82    Seq(Seq<T>),
83    /// A post-condition expression.
84    PostCond(PostCond<T>),
85    /// A conditionally emitted block.
86    CondBlock(CondBlock<T>),
87    /// A block of code with a header comment.
88    BlockComment(BlockComment<T>),
89}
90
91impl<T> HasMeta for IRStmt<T> {
92    fn meta(&self) -> &Meta {
93        &self.1
94    }
95
96    fn meta_mut(&mut self) -> &mut Meta {
97        &mut self.1
98    }
99}
100
101impl<T: PartialEq> PartialEq for IRStmt<T> {
102    /// Equality is defined by the sequence of statements regardless of how they are structured
103    /// inside.
104    ///
105    /// For example:
106    ///     Seq([a, Seq([b, c])]) == Seq([a, b, c])
107    ///     a == Seq([a])
108    fn eq(&self, other: &Self) -> bool {
109        std::iter::zip(self.iter(), other.iter()).all(|(lhs, rhs)| match (&lhs.0, &rhs.0) {
110            (IRStmtImpl::ConstraintCall(lhs), IRStmtImpl::ConstraintCall(rhs)) => lhs.eq(rhs),
111            (IRStmtImpl::Constraint(lhs), IRStmtImpl::Constraint(rhs)) => lhs.eq(rhs),
112            (IRStmtImpl::Comment(lhs), IRStmtImpl::Comment(rhs)) => lhs.eq(rhs),
113            (IRStmtImpl::AssumeDeterministic(lhs), IRStmtImpl::AssumeDeterministic(rhs)) => {
114                lhs.eq(rhs)
115            }
116            (IRStmtImpl::Assert(lhs), IRStmtImpl::Assert(rhs)) => lhs.eq(rhs),
117            (IRStmtImpl::PostCond(lhs), IRStmtImpl::PostCond(rhs)) => lhs.eq(rhs),
118            (IRStmtImpl::CondBlock(lhs), IRStmtImpl::CondBlock(rhs)) => lhs.eq(rhs),
119            (IRStmtImpl::BlockComment(lhs), IRStmtImpl::BlockComment(rhs)) => lhs.eq(rhs),
120            (IRStmtImpl::Seq(_), _) | (_, IRStmtImpl::Seq(_)) => unreachable!(),
121            _ => false,
122        })
123    }
124}
125
126impl<T: std::fmt::Debug> std::fmt::Debug for IRStmt<T> {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        match &self.0 {
129            IRStmtImpl::ConstraintCall(call) => write!(f, "{call:?}"),
130            IRStmtImpl::Constraint(constraint) => write!(f, "{constraint:?}"),
131            IRStmtImpl::Comment(comment) => write!(f, "{comment:?}"),
132            IRStmtImpl::AssumeDeterministic(assume_deterministic) => {
133                write!(f, "{assume_deterministic:?}")
134            }
135            IRStmtImpl::Assert(assert) => write!(f, "{assert:?}"),
136            IRStmtImpl::PostCond(pc) => write!(f, "{pc:?}"),
137            IRStmtImpl::CondBlock(cb) => write!(f, "{cb:?}"),
138            IRStmtImpl::Seq(seq) => write!(f, "{seq:?}"),
139            IRStmtImpl::BlockComment(bc) => write!(f, "{bc:?}"),
140        }
141    }
142}
143
144impl<T> IRStmt<T> {
145    /// Creates a call to another module.
146    pub fn call(
147        callee: impl AsRef<str>,
148        inputs: impl IntoIterator<Item = T>,
149        outputs: impl IntoIterator<Item = Slot>,
150    ) -> Self {
151        Call::new(callee, inputs, outputs).into()
152    }
153
154    /// Creates a post condition formula.
155    pub fn post_cond(cond: IRBexpr<T>) -> Self {
156        PostCond::new(cond).into()
157    }
158
159    /// Creates a constraint between two expressions.
160    pub fn constraint(op: CmpOp, lhs: T, rhs: T) -> Self {
161        Constraint::new(op, lhs, rhs).into()
162    }
163
164    #[inline]
165    /// Creates a constraint with [`CmpOp::Eq`] between two expressions.
166    pub fn eq(lhs: T, rhs: T) -> Self {
167        Self::constraint(CmpOp::Eq, lhs, rhs)
168    }
169
170    #[inline]
171    /// Creates a constraint with [`CmpOp::Lt`] between two expressions.
172    pub fn lt(lhs: T, rhs: T) -> Self {
173        Self::constraint(CmpOp::Lt, lhs, rhs)
174    }
175
176    #[inline]
177    /// Creates a constraint with [`CmpOp::Le`] between two expressions.
178    pub fn le(lhs: T, rhs: T) -> Self {
179        Self::constraint(CmpOp::Le, lhs, rhs)
180    }
181
182    #[inline]
183    /// Creates a constraint with [`CmpOp::Gt`] between two expressions.
184    pub fn gt(lhs: T, rhs: T) -> Self {
185        Self::constraint(CmpOp::Gt, lhs, rhs)
186    }
187
188    #[inline]
189    /// Creates a constraint with [`CmpOp::Ge`] between two expressions.
190    pub fn ge(lhs: T, rhs: T) -> Self {
191        Self::constraint(CmpOp::Ge, lhs, rhs)
192    }
193
194    /// Creates a text comment.
195    pub fn comment(s: impl AsRef<str>) -> Self {
196        Comment::new(s).into()
197    }
198
199    /// Indicates that the [`Slot`] must be assumed deterministic by the backend.
200    pub fn assume_deterministic(f: impl Into<Slot>) -> Self {
201        AssumeDeterministic::new(f.into()).into()
202    }
203
204    /// Creates an assertion in the circuit.
205    pub fn assert(cond: IRBexpr<T>) -> Self {
206        Assert::new(cond).into()
207    }
208
209    /// Creates a statement that is a sequence of other statements.
210    pub fn seq<I>(stmts: impl IntoIterator<Item = IRStmt<I>>) -> Self
211    where
212        I: Into<T>,
213    {
214        Seq::new(stmts).into()
215    }
216
217    /// Creates an empty statement.
218    pub fn empty() -> Self {
219        Seq::empty().into()
220    }
221
222    /// Returns true if the statement is empty.
223    pub fn is_empty(&self) -> bool {
224        match &self.0 {
225            IRStmtImpl::Seq(s) => s.is_empty(),
226            _ => false,
227        }
228    }
229
230    /// Prepends a comment to the statement.
231    pub fn with_comment(self, comment: String) -> Self {
232        let meta = self.1.clone();
233        Self(
234            IRStmtImpl::BlockComment(BlockComment::new(Some(comment), self)),
235            meta,
236        )
237    }
238
239    /// Transforms the inner expression type into another.
240    pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> IRStmt<O> {
241        match self.0 {
242            IRStmtImpl::ConstraintCall(call) => call.map(f).into(),
243            IRStmtImpl::Constraint(constraint) => constraint.map(f).into(),
244            IRStmtImpl::Comment(comment) => Comment::new(comment.value()).into(),
245            IRStmtImpl::AssumeDeterministic(ad) => AssumeDeterministic::new(ad.value()).into(),
246            IRStmtImpl::Assert(assert) => assert.map(f).into(),
247            IRStmtImpl::PostCond(pc) => pc.map(f).into(),
248            IRStmtImpl::CondBlock(cb) => cb.map(f).into(),
249            IRStmtImpl::Seq(seq) => Seq::new(seq.into_iter().map(|s| s.map(f))).into(),
250            IRStmtImpl::BlockComment(bc) => {
251                BlockComment::new(bc.value().map(ToOwned::to_owned), bc.take_body().map(f)).into()
252            }
253        }
254    }
255
256    /// Maps the statement's inner type to a tuple with the passed value.
257    pub fn with<O>(self, other: O) -> IRStmt<(O, T)>
258    where
259        O: Clone,
260    {
261        self.map(&mut |t| (other.clone(), t))
262    }
263
264    /// Maps the statement's inner type to a tuple with the result of the closure.
265    pub fn with_fn<O>(self, other: impl Fn() -> O) -> IRStmt<(O, T)> {
266        self.map(&mut |t| (other(), t))
267    }
268
269    /// Transforms the inner expression type using [`Into::into`].
270    pub fn into<O>(self) -> IRStmt<O>
271    where
272        O: From<T> + Evaluate<ExprProperties>,
273    {
274        self.map(&mut Into::into)
275    }
276
277    /// Transforms the inner expression type using [`From::from`].
278    pub fn from<O>(value: IRStmt<O>) -> Self
279    where
280        O: Into<T>,
281    {
282        value.map(&mut Into::into)
283    }
284
285    /// Appends the given statement to the current one.
286    pub fn then(self, other: impl Into<Self>) -> Self {
287        match self.0 {
288            IRStmtImpl::Seq(mut seq) => {
289                seq.push(other.into());
290                seq.into()
291            }
292            this => Seq::new([Self(this, self.1), other.into()]).into(),
293        }
294    }
295
296    /// Transforms the inner expression type into another, without moving.
297    pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> IRStmt<O> {
298        match &self.0 {
299            IRStmtImpl::ConstraintCall(call) => call.map_into(f).into(),
300            IRStmtImpl::Constraint(constraint) => constraint.map_into(f).into(),
301            IRStmtImpl::Comment(comment) => Comment::new(comment.value()).into(),
302            IRStmtImpl::AssumeDeterministic(ad) => AssumeDeterministic::new(ad.value()).into(),
303            IRStmtImpl::Assert(assert) => assert.map_into(f).into(),
304            IRStmtImpl::PostCond(pc) => pc.map_into(f).into(),
305            IRStmtImpl::CondBlock(cb) => cb.map_into(f).into(),
306            IRStmtImpl::Seq(seq) => Seq::new(seq.iter().map(|s| s.map_into(f))).into(),
307            IRStmtImpl::BlockComment(bc) => {
308                BlockComment::new(bc.value().map(ToOwned::to_owned), bc.body().map_into(f)).into()
309            }
310        }
311    }
312
313    /// Tries to transform the inner expression type into another.
314    pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<IRStmt<O>, E> {
315        Ok(match self.0 {
316            IRStmtImpl::ConstraintCall(call) => call.try_map(f)?.into(),
317            IRStmtImpl::Constraint(constraint) => constraint.try_map(f)?.into(),
318            IRStmtImpl::Comment(comment) => Comment::new(comment.value()).into(),
319            IRStmtImpl::AssumeDeterministic(ad) => AssumeDeterministic::new(ad.value()).into(),
320            IRStmtImpl::Assert(assert) => assert.try_map(f)?.into(),
321            IRStmtImpl::PostCond(pc) => pc.try_map(f)?.into(),
322            IRStmtImpl::CondBlock(cb) => cb.try_map(f)?.into(),
323            IRStmtImpl::Seq(seq) => Seq::new(
324                seq.into_iter()
325                    .map(|s| s.try_map(f))
326                    .collect::<Result<Vec<_>, _>>()?,
327            )
328            .into(),
329            IRStmtImpl::BlockComment(bc) => BlockComment::new(
330                bc.value().map(ToOwned::to_owned),
331                bc.take_body().try_map(f)?,
332            )
333            .into(),
334        })
335    }
336
337    /// Modifies the inner expression type in place.
338    pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
339        match &mut self.0 {
340            IRStmtImpl::ConstraintCall(call) => call.map_inplace(f),
341            IRStmtImpl::Constraint(constraint) => constraint.map_inplace(f),
342            IRStmtImpl::Assert(assert) => assert.map_inplace(f),
343            IRStmtImpl::PostCond(pc) => pc.map_inplace(f),
344            IRStmtImpl::CondBlock(cb) => cb.map_inplace(f),
345            IRStmtImpl::Seq(seq) => seq.iter_mut().for_each(|stmt| stmt.map_inplace(f)),
346
347            IRStmtImpl::BlockComment(bc) => bc.body_mut().map_inplace(f),
348            _ => {}
349        }
350    }
351
352    /// Tries to modify the inner expression type in place.
353    pub fn try_map_inplace<E>(
354        &mut self,
355        f: &mut impl FnMut(&mut T) -> Result<(), E>,
356    ) -> Result<(), E> {
357        match &mut self.0 {
358            IRStmtImpl::ConstraintCall(call) => call.try_map_inplace(f),
359            IRStmtImpl::Constraint(constraint) => constraint.try_map_inplace(f),
360            IRStmtImpl::Assert(assert) => assert.try_map_inplace(f),
361            IRStmtImpl::PostCond(pc) => pc.try_map_inplace(f),
362            IRStmtImpl::CondBlock(cb) => cb.try_map_inplace(f),
363            IRStmtImpl::Seq(seq) => seq.iter_mut().try_for_each(|stmt| stmt.try_map_inplace(f)),
364            IRStmtImpl::BlockComment(bc) => bc.body_mut().try_map_inplace(f),
365            _ => Ok(()),
366        }
367    }
368
369    /// Modifies the inner slots in place.
370    pub fn map_slot_inplace(&mut self, f: &mut impl FnMut(&mut Slot)) {
371        match &mut self.0 {
372            IRStmtImpl::ConstraintCall(call) => call.outputs_mut().iter_mut().for_each(f),
373            IRStmtImpl::AssumeDeterministic(det) => f(det.value_mut()),
374            IRStmtImpl::Seq(seq) => seq.iter_mut().for_each(|stmt| stmt.map_slot_inplace(f)),
375            IRStmtImpl::CondBlock(cb) => cb.body_mut().map_slot_inplace(f),
376            IRStmtImpl::BlockComment(bc) => bc.body_mut().map_slot_inplace(f),
377            _ => {}
378        }
379    }
380
381    /// Tries to modify the inner slots n place.
382    pub fn try_map_slot_inplace<E>(
383        &mut self,
384        f: &mut impl FnMut(&mut Slot) -> Result<(), E>,
385    ) -> Result<(), E> {
386        match &mut self.0 {
387            IRStmtImpl::ConstraintCall(call) => call.outputs_mut().iter_mut().try_for_each(f),
388            IRStmtImpl::AssumeDeterministic(det) => f(det.value_mut()),
389            IRStmtImpl::Seq(seq) => seq
390                .iter_mut()
391                .try_for_each(|stmt| stmt.try_map_slot_inplace(f)),
392            IRStmtImpl::CondBlock(cb) => cb.body_mut().try_map_slot_inplace(f),
393            IRStmtImpl::BlockComment(bc) => bc.body_mut().try_map_slot_inplace(f),
394            _ => Ok(()),
395        }
396    }
397
398    /// Returns an iterator of references to the statements.
399    pub fn iter(&self) -> IRStmtRefIter<'_, T> {
400        IRStmtRefIter { stack: vec![self] }
401    }
402
403    /// Returns an iterator of mutable references to the statements.
404    pub fn iter_mut(&mut self) -> IRStmtRefMutIter<'_, T> {
405        IRStmtRefMutIter { stack: vec![self] }
406    }
407
408    /// Propagates the metadata of this statement to the inner statements.
409    pub fn propagate_meta(&mut self) {
410        if let IRStmtImpl::Seq(s) = &mut self.0 {
411            for stmt in s.iter_mut() {
412                stmt.meta_mut().complete_with(self.1);
413            }
414        }
415    }
416}
417
418impl<T> ConstantFolding for IRStmt<T>
419where
420    T: ConstantFolding + std::fmt::Debug + Clone,
421    Error: From<T::Error>,
422    T::T: Eq + Ord,
423{
424    type Error = Error;
425    type T = ();
426
427    /// Folds the statements if the expressions are constant.
428    /// If a assert-like statement folds into a tautology (i.e. `(= 0 0 )`) gets removed. If it
429    /// folds into a unsatisfiable proposition the method returns an error.
430    fn constant_fold(&mut self) -> Result<(), Error> {
431        match &mut self.0 {
432            IRStmtImpl::ConstraintCall(call) => call.constant_fold()?,
433            IRStmtImpl::Constraint(constraint) => {
434                if let Some(replacement) = constraint.constant_fold(self.1)? {
435                    *self = replacement;
436                }
437            }
438            IRStmtImpl::Comment(_) => {}
439            IRStmtImpl::AssumeDeterministic(_) => {}
440            IRStmtImpl::Assert(assert) => {
441                if let Some(replacement) = assert.constant_fold(self.1)? {
442                    *self = replacement;
443                }
444            }
445            IRStmtImpl::PostCond(pc) => {
446                if let Some(replacement) = pc.constant_fold(self.1)? {
447                    *self = replacement;
448                }
449            }
450            IRStmtImpl::CondBlock(cb) => {
451                if let Some(replacement) = cb.constant_fold()? {
452                    *self = replacement;
453                }
454            }
455            IRStmtImpl::Seq(seq) => seq.constant_fold()?,
456            IRStmtImpl::BlockComment(bc) => bc.constant_fold()?,
457        }
458        Ok(())
459    }
460}
461
462impl Canonicalize for IRStmt<IRAexpr> {
463    /// Matches the statements against a series of known patterns and applies rewrites if able to.
464    fn canonicalize(&mut self) {
465        match &mut self.0 {
466            IRStmtImpl::ConstraintCall(call) => call.canonicalize(),
467            IRStmtImpl::Constraint(constraint) => constraint.canonicalize(),
468            IRStmtImpl::Comment(_) => {}
469            IRStmtImpl::AssumeDeterministic(_) => {}
470            IRStmtImpl::Assert(assert) => assert.canonicalize(),
471            IRStmtImpl::PostCond(pc) => pc.canonicalize(),
472            IRStmtImpl::CondBlock(cb) => cb.canonicalize(),
473            IRStmtImpl::Seq(seq) => seq.canonicalize(),
474            IRStmtImpl::BlockComment(bc) => bc.canonicalize(),
475        }
476    }
477}
478
479impl<T, D> Validatable for IRStmt<T>
480where
481    IRConstBexpr<T>: Validatable<Diagnostic = D, Context = ()>,
482    D: Diagnostic,
483{
484    type Diagnostic = D;
485
486    type Context = ();
487
488    fn validate_with_context(
489        &self,
490        _: &Self::Context,
491    ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
492        // Nothing that requires validation anymore.
493        Validation::new().into()
494    }
495}
496
497/// IRStmt transilitively inherits the [`SymbolicEqv`] equivalence relation.
498impl<L, R> EqvRelation<IRStmt<L>, IRStmt<R>> for SymbolicEqv
499where
500    SymbolicEqv: EqvRelation<L, R> + EqvRelation<Slot, Slot>,
501{
502    /// Two statements are equivalent if they are structurally equal and their inner entities
503    /// are equivalent.
504    fn equivalent(lhs: &IRStmt<L>, rhs: &IRStmt<R>) -> bool {
505        std::iter::zip(lhs.iter(), rhs.iter()).all(|(lhs, rhs)| match (&lhs.0, &rhs.0) {
506            (IRStmtImpl::ConstraintCall(lhs), IRStmtImpl::ConstraintCall(rhs)) => {
507                equiv! { SymbolicEqv | lhs, rhs }
508            }
509            (IRStmtImpl::Constraint(lhs), IRStmtImpl::Constraint(rhs)) => {
510                equiv! { SymbolicEqv | lhs, rhs }
511            }
512            (IRStmtImpl::Comment(_), IRStmtImpl::Comment(_)) => true,
513            (IRStmtImpl::AssumeDeterministic(lhs), IRStmtImpl::AssumeDeterministic(rhs)) => {
514                equiv! { SymbolicEqv | lhs, rhs }
515            }
516            (IRStmtImpl::Assert(lhs), IRStmtImpl::Assert(rhs)) => {
517                equiv! { SymbolicEqv | lhs, rhs }
518            }
519            (IRStmtImpl::PostCond(lhs), IRStmtImpl::PostCond(rhs)) => {
520                equiv! { SymbolicEqv | lhs, rhs }
521            }
522            (IRStmtImpl::CondBlock(lhs), IRStmtImpl::CondBlock(rhs)) => {
523                equiv! { SymbolicEqv | lhs, rhs }
524            }
525            (IRStmtImpl::BlockComment(lhs), IRStmtImpl::BlockComment(rhs)) => {
526                equiv! { SymbolicEqv | lhs.body(), rhs.body()}
527            }
528            (IRStmtImpl::Seq(_), _) | (_, IRStmtImpl::Seq(_)) => unreachable!(),
529            _ => false,
530        })
531    }
532}
533
534/// Iterator over references.
535#[derive(Debug)]
536pub struct IRStmtRefIter<'a, T> {
537    stack: Vec<&'a IRStmt<T>>,
538}
539
540impl<'a, T> Iterator for IRStmtRefIter<'a, T> {
541    type Item = &'a IRStmt<T>;
542
543    fn next(&mut self) -> Option<Self::Item> {
544        while let Some(node) = self.stack.pop() {
545            match &node.0 {
546                IRStmtImpl::Seq(children) => {
547                    // Reverse to preserve left-to-right order
548                    self.stack.extend(children.iter().rev());
549                }
550                _ => return Some(node),
551            }
552        }
553        None
554    }
555}
556
557/// Iterator over mutable references.
558#[derive(Debug)]
559pub struct IRStmtRefMutIter<'a, T> {
560    stack: Vec<&'a mut IRStmt<T>>,
561}
562
563impl<'a, T> Iterator for IRStmtRefMutIter<'a, T> {
564    type Item = &'a mut IRStmt<T>;
565
566    fn next(&mut self) -> Option<Self::Item> {
567        while let Some(node) = self.stack.pop() {
568            if let IRStmt(IRStmtImpl::Seq(children), _) = node {
569                // Reverse to preserve left-to-right order
570                self.stack.extend(children.iter_mut().rev());
571            } else {
572                return Some(node);
573            }
574        }
575        None
576    }
577}
578
579impl<T> Default for IRStmt<T> {
580    fn default() -> Self {
581        Self::empty()
582    }
583}
584
585/// Iterator of statements.
586#[derive(Debug)]
587pub struct IRStmtIter<T> {
588    stack: Vec<IRStmt<T>>,
589}
590
591impl<T> Iterator for IRStmtIter<T> {
592    type Item = IRStmt<T>;
593
594    fn next(&mut self) -> Option<Self::Item> {
595        while let Some(node) = self.stack.pop() {
596            match node {
597                IRStmt(IRStmtImpl::Seq(children), _) => {
598                    // Reverse to preserve left-to-right order
599                    self.stack.extend(children.into_iter().rev());
600                }
601                stmt => return Some(stmt),
602            }
603        }
604        None
605    }
606}
607
608impl<T> IntoIterator for IRStmt<T> {
609    type Item = Self;
610
611    type IntoIter = IRStmtIter<T>;
612
613    fn into_iter(self) -> Self::IntoIter {
614        IRStmtIter { stack: vec![self] }
615    }
616}
617
618impl<'a, T> IntoIterator for &'a IRStmt<T> {
619    type Item = Self;
620
621    type IntoIter = IRStmtRefIter<'a, T>;
622
623    fn into_iter(self) -> Self::IntoIter {
624        self.iter()
625    }
626}
627
628impl<'a, T> IntoIterator for &'a mut IRStmt<T> {
629    type Item = Self;
630
631    type IntoIter = IRStmtRefMutIter<'a, T>;
632
633    fn into_iter(self) -> Self::IntoIter {
634        self.iter_mut()
635    }
636}
637
638impl<I> FromIterator<IRStmt<I>> for IRStmt<I> {
639    fn from_iter<T: IntoIterator<Item = IRStmt<I>>>(iter: T) -> Self {
640        Self::seq(iter)
641    }
642}
643
644impl<T> From<Call<T>> for IRStmt<T> {
645    fn from(value: Call<T>) -> Self {
646        Self(IRStmtImpl::ConstraintCall(value), Default::default())
647    }
648}
649impl<T> From<Constraint<T>> for IRStmt<T> {
650    fn from(value: Constraint<T>) -> Self {
651        Self(IRStmtImpl::Constraint(value), Default::default())
652    }
653}
654impl<T> From<Comment> for IRStmt<T> {
655    fn from(value: Comment) -> Self {
656        Self(IRStmtImpl::Comment(value), Default::default())
657    }
658}
659impl<T> From<AssumeDeterministic> for IRStmt<T> {
660    fn from(value: AssumeDeterministic) -> Self {
661        Self(IRStmtImpl::AssumeDeterministic(value), Default::default())
662    }
663}
664impl<T> From<Assert<T>> for IRStmt<T> {
665    fn from(value: Assert<T>) -> Self {
666        Self(IRStmtImpl::Assert(value), Default::default())
667    }
668}
669impl<T> From<PostCond<T>> for IRStmt<T> {
670    fn from(value: PostCond<T>) -> Self {
671        Self(IRStmtImpl::PostCond(value), Default::default())
672    }
673}
674impl<T> From<CondBlock<T>> for IRStmt<T> {
675    fn from(value: CondBlock<T>) -> Self {
676        Self(IRStmtImpl::CondBlock(value), Default::default())
677    }
678}
679impl<T> From<Seq<T>> for IRStmt<T> {
680    fn from(value: Seq<T>) -> Self {
681        Self(IRStmtImpl::Seq(value), Default::default())
682    }
683}
684impl<T> From<BlockComment<T>> for IRStmt<T> {
685    fn from(value: BlockComment<T>) -> Self {
686        Self(IRStmtImpl::BlockComment(value), Default::default())
687    }
688}
689
/// Error raised while lowering when the lowered statement is a conditionally emitted block that
/// has not been resolved yet.
#[derive(Debug)]
pub struct UnresolvedCondBlockError;

impl std::fmt::Display for UnresolvedCondBlockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("attempted to lower an unresolved conditionally emitted block")
    }
}

impl std::error::Error for UnresolvedCondBlockError {}
705
706impl<T: LowerableExpr> LowerableStmt for IRStmt<T>
707where
708    CondBlock<T>: LowerableStmt,
709{
710    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
711    where
712        L: Lowering + ?Sized,
713    {
714        match self.0 {
715            IRStmtImpl::ConstraintCall(call) => call.lower(l),
716            IRStmtImpl::Constraint(constraint) => constraint.lower(l),
717            IRStmtImpl::Comment(comment) => comment.lower(l),
718            IRStmtImpl::AssumeDeterministic(ad) => ad.lower(l),
719            IRStmtImpl::Assert(assert) => assert.lower(l),
720            IRStmtImpl::PostCond(pc) => pc.lower(l),
721            IRStmtImpl::CondBlock(cb) => cb.lower(l),
722            IRStmtImpl::Seq(seq) => seq.lower(l),
723            IRStmtImpl::BlockComment(block_comment) => block_comment.lower(l),
724        }
725    }
726}
727
728impl<T: Clone> Clone for IRStmt<T> {
729    fn clone(&self) -> Self {
730        match &self.0 {
731            IRStmtImpl::ConstraintCall(call) => call.clone().into(),
732            IRStmtImpl::Constraint(c) => c.clone().into(),
733            IRStmtImpl::Comment(c) => c.clone().into(),
734            IRStmtImpl::AssumeDeterministic(func_io) => func_io.clone().into(),
735            IRStmtImpl::Assert(e) => e.clone().into(),
736            IRStmtImpl::PostCond(e) => e.clone().into(),
737            IRStmtImpl::CondBlock(e) => e.clone().into(),
738            IRStmtImpl::Seq(stmts) => stmts.clone().into(),
739            IRStmtImpl::BlockComment(block_comment) => block_comment.clone().into(),
740        }
741    }
742}
743
744impl<T: IRPrintable> IRPrintable for IRStmt<T> {
745    fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
746        match &self.0 {
747            IRStmtImpl::ConstraintCall(call) => {
748                ctx.fmt_call(call.callee(), call.inputs(), call.outputs(), None)
749            }
750            IRStmtImpl::Constraint(constraint) => {
751                ctx.block(format!("assert/{}", constraint.op()).as_str(), |ctx| {
752                    if constraint.lhs().depth() == 1 {
753                        writeln!(ctx, " ")?;
754                    }
755                    constraint.lhs().fmt(ctx)?;
756                    if constraint.rhs().depth() > 1 {
757                        ctx.nl()?;
758                    } else {
759                        writeln!(ctx, " ")?;
760                    }
761                    constraint.rhs().fmt(ctx)
762                })
763            }
764            IRStmtImpl::Comment(comment) => {
765                ctx.nl()?;
766                writeln!(ctx, "; {}", comment.value())
767            }
768            IRStmtImpl::AssumeDeterministic(assume_deterministic) => ctx
769                .list_nl("assume-deterministic", |ctx| {
770                    assume_deterministic.value().fmt(ctx)
771                }),
772            IRStmtImpl::Assert(assert) => ctx.block("assert", |ctx| assert.cond().fmt(ctx)),
773            IRStmtImpl::Seq(seq) => {
774                for stmt in seq.iter() {
775                    stmt.fmt(ctx)?;
776                }
777                Ok(())
778            }
779            IRStmtImpl::CondBlock(cb) => ctx.block("emit-if", |ctx| {
780                cb.cond().fmt(ctx)?;
781                ctx.nl()?;
782                cb.body().fmt(ctx)
783            }),
784            IRStmtImpl::PostCond(post_cond) => {
785                ctx.block("post-cond", |ctx| post_cond.cond().fmt(ctx))
786            }
787            IRStmtImpl::BlockComment(block_comment) => {
788                if let Some(comment) = block_comment.value() {
789                    ctx.nl()?;
790                    writeln!(ctx, "; {}", comment)?;
791                }
792                block_comment.body().fmt(ctx)
793            }
794        }
795    }
796}
797
798/// Errors caused by failable map operations
799#[derive(Debug, thiserror::Error)]
800#[error(transparent)]
801pub struct TryMapError(#[from] Box<dyn std::error::Error>);
802
803#[cfg(test)]
804mod test;