haloumi-ir 0.5.11

Intermediate representation of the haloumi framework.
Documentation
use crate::traits::{Canonicalize, ConstantFolding};
use eqv::{EqvRelation, equiv};
use haloumi_core::{eqv::SymbolicEqv, slot::Slot};
use haloumi_lowering::{
    Lowering,
    lowerable::{LowerableExpr, LowerableStmt},
};

pub struct Call<I> {
    callee: String,
    inputs: Vec<I>,
    outputs: Vec<Slot>,
}

impl<T> Call<T> {
    pub fn new(
        callee: impl AsRef<str>,
        inputs: impl IntoIterator<Item = T>,
        outputs: impl IntoIterator<Item = Slot>,
    ) -> Self {
        Self {
            callee: callee.as_ref().to_owned(),
            inputs: inputs.into_iter().collect(),
            outputs: outputs.into_iter().collect(),
        }
    }

    pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> Call<O> {
        Call::new(self.callee, self.inputs.into_iter().map(f), self.outputs)
    }

    pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> Call<O> {
        Call::new(
            self.callee.clone(),
            self.inputs.iter().map(f),
            self.outputs.clone(),
        )
    }

    pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<Call<O>, E> {
        Ok(Call::new(
            self.callee,
            self.inputs
                .into_iter()
                .map(f)
                .collect::<Result<Vec<_>, _>>()?,
            self.outputs,
        ))
    }

    pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
        for i in &mut self.inputs {
            f(i);
        }
    }

    pub fn try_map_inplace<E>(
        &mut self,
        f: &mut impl FnMut(&mut T) -> Result<(), E>,
    ) -> Result<(), E> {
        for i in &mut self.inputs {
            f(i)?;
        }
        Ok(())
    }

    pub fn callee(&self) -> &str {
        &self.callee
    }

    pub fn inputs(&self) -> &[T] {
        &self.inputs
    }

    pub fn outputs(&self) -> &[Slot] {
        &self.outputs
    }

    pub fn inputs_mut(&mut self) -> &mut Vec<T> {
        &mut self.inputs
    }

    pub fn outputs_mut(&mut self) -> &mut Vec<Slot> {
        &mut self.outputs
    }

    /// Folds the statements if the expressions are constant.
    pub fn constant_fold(&mut self) -> Result<(), T::Error>
    where
        T: ConstantFolding,
    {
        for i in &mut self.inputs {
            i.constant_fold()?;
        }
        Ok(())
    }

    /// Applies canonicalization patterns
    pub fn canonicalize(&mut self)
    where
        T: Canonicalize,
    {
        for i in &mut self.inputs {
            i.canonicalize();
        }
    }
}

impl<I: LowerableExpr> LowerableStmt for Call<I> {
    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
    where
        L: Lowering + ?Sized,
    {
        let inputs = self
            .inputs
            .into_iter()
            .map(|i| i.lower(l))
            .collect::<Result<Vec<_>, _>>()?;
        l.generate_call(self.callee.as_str(), &inputs, &self.outputs)
    }
}

impl<T: Clone> Clone for Call<T> {
    fn clone(&self) -> Self {
        Self {
            callee: self.callee.clone(),
            inputs: self.inputs.clone(),
            outputs: self.outputs.clone(),
        }
    }
}

impl<T: PartialEq> PartialEq for Call<T> {
    fn eq(&self, other: &Self) -> bool {
        self.callee == other.callee && self.inputs == other.inputs && self.outputs == other.outputs
    }
}

impl<T: std::fmt::Debug> std::fmt::Debug for Call<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if f.alternate() {
            write!(
                f,
                "call '{}'({:#?}) -> ({:#?})",
                self.callee, self.inputs, self.outputs
            )
        } else {
            write!(
                f,
                "call '{}'({:?}) -> ({:?})",
                self.callee, self.inputs, self.outputs
            )
        }
    }
}

impl<L, R> EqvRelation<Call<L>, Call<R>> for SymbolicEqv
where
    SymbolicEqv: EqvRelation<L, R> + EqvRelation<Slot, Slot>,
{
    /// A call statement is equivalent to another if their input and outputs are equivalent and
    /// point to the same callee.
    fn equivalent(lhs: &Call<L>, rhs: &Call<R>) -> bool {
        lhs.callee == rhs.callee
            && equiv! { SymbolicEqv | &lhs.inputs, &rhs.inputs }
            && equiv! { SymbolicEqv | &lhs.outputs, &rhs.outputs }
    }
}