haloumi-ir 0.5.9

Intermediate representation of the haloumi framework.
Documentation
use crate::{
    error::Error,
    traits::{Canonicalize, ConstantFolding},
};
use haloumi_lowering::{
    Lowering,
    lowerable::{LowerableExpr, LowerableStmt},
};

use super::IRStmt;

pub struct Seq<T>(Vec<IRStmt<T>>);

impl<T> Seq<T> {
    pub fn new<I: Into<T>>(stmts: impl IntoIterator<Item = IRStmt<I>>) -> Self {
        Self(
            stmts
                .into_iter()
                .map(|stmt| stmt.map(&mut Into::into))
                .collect(),
        )
    }

    pub fn empty() -> Self {
        Self(vec![])
    }

    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    pub fn iter(&self) -> std::slice::Iter<'_, IRStmt<T>> {
        self.0.iter()
    }

    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, IRStmt<T>> {
        self.0.iter_mut()
    }

    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Folds the statements if the expressions are constant.
    /// If a assert-like statement folds into a tautology (i.e. `(= 0 0 )`) gets removed. If it
    /// folds into a unsatisfiable proposition the method returns an error.
    pub fn constant_fold(&mut self) -> Result<(), Error>
    where
        T: ConstantFolding + std::fmt::Debug + Clone,
        Error: From<T::Error>,
        T::T: Eq + Ord,
    {
        self.0
            .iter_mut()
            .try_for_each(|inner| inner.constant_fold())
    }

    /// Appends a statement to the sequence
    pub fn push(&mut self, stmt: IRStmt<T>) {
        self.0.push(stmt)
    }
}

impl<T> Seq<T>
where
    IRStmt<T>: Canonicalize,
{
    /// Matches the statements against a series of known patterns and applies rewrites if able to.
    pub fn canonicalize(&mut self) {
        for inner in &mut self.0 {
            inner.canonicalize();
        }
    }
}

impl<T: LowerableExpr> LowerableStmt for Seq<T>
where
    IRStmt<T>: LowerableStmt,
{
    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
    where
        L: Lowering + ?Sized,
    {
        self.0.into_iter().try_for_each(|s| s.lower(l))
    }
}

impl<T: Clone> Clone for Seq<T> {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<T> IntoIterator for Seq<T> {
    type Item = IRStmt<T>;

    type IntoIter = <Vec<IRStmt<T>> as IntoIterator>::IntoIter;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

impl<T: PartialEq> PartialEq for Seq<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T: std::fmt::Debug> std::fmt::Debug for Seq<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if f.alternate() {
            writeln!(f, "{{")?;
            self.0
                .iter()
                .enumerate()
                .try_for_each(|(idx, stmt)| writeln!(f, "{idx}: {stmt:#?};"))?;
            writeln!(f, "}}")
        } else {
            write!(f, "{{ ")?;
            self.0.iter().try_for_each(|stmt| write!(f, "{stmt:?}; "))?;
            write!(f, " }}")
        }
    }
}