use crate::traits::{Canonicalize, ConstantFolding};
use eqv::{EqvRelation, equiv};
use haloumi_core::{eqv::SymbolicEqv, slot::Slot};
use haloumi_lowering::{
Lowering,
lowerable::{LowerableExpr, LowerableStmt},
};
pub struct Call<I> {
callee: String,
inputs: Vec<I>,
outputs: Vec<Slot>,
}
impl<T> Call<T> {
pub fn new(
callee: impl AsRef<str>,
inputs: impl IntoIterator<Item = T>,
outputs: impl IntoIterator<Item = Slot>,
) -> Self {
Self {
callee: callee.as_ref().to_owned(),
inputs: inputs.into_iter().collect(),
outputs: outputs.into_iter().collect(),
}
}
pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> Call<O> {
Call::new(self.callee, self.inputs.into_iter().map(f), self.outputs)
}
pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> Call<O> {
Call::new(
self.callee.clone(),
self.inputs.iter().map(f),
self.outputs.clone(),
)
}
pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<Call<O>, E> {
Ok(Call::new(
self.callee,
self.inputs
.into_iter()
.map(f)
.collect::<Result<Vec<_>, _>>()?,
self.outputs,
))
}
pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
for i in &mut self.inputs {
f(i);
}
}
pub fn try_map_inplace<E>(
&mut self,
f: &mut impl FnMut(&mut T) -> Result<(), E>,
) -> Result<(), E> {
for i in &mut self.inputs {
f(i)?;
}
Ok(())
}
pub fn callee(&self) -> &str {
&self.callee
}
pub fn inputs(&self) -> &[T] {
&self.inputs
}
pub fn outputs(&self) -> &[Slot] {
&self.outputs
}
pub fn inputs_mut(&mut self) -> &mut Vec<T> {
&mut self.inputs
}
pub fn outputs_mut(&mut self) -> &mut Vec<Slot> {
&mut self.outputs
}
pub fn constant_fold(&mut self) -> Result<(), T::Error>
where
T: ConstantFolding,
{
for i in &mut self.inputs {
i.constant_fold()?;
}
Ok(())
}
pub fn canonicalize(&mut self)
where
T: Canonicalize,
{
for i in &mut self.inputs {
i.canonicalize();
}
}
}
impl<I: LowerableExpr> LowerableStmt for Call<I> {
fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
where
L: Lowering + ?Sized,
{
let inputs = self
.inputs
.into_iter()
.map(|i| i.lower(l))
.collect::<Result<Vec<_>, _>>()?;
l.generate_call(self.callee.as_str(), &inputs, &self.outputs)
}
}
impl<T: Clone> Clone for Call<T> {
fn clone(&self) -> Self {
Self {
callee: self.callee.clone(),
inputs: self.inputs.clone(),
outputs: self.outputs.clone(),
}
}
}
impl<T: PartialEq> PartialEq for Call<T> {
fn eq(&self, other: &Self) -> bool {
self.callee == other.callee && self.inputs == other.inputs && self.outputs == other.outputs
}
}
impl<T: std::fmt::Debug> std::fmt::Debug for Call<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if f.alternate() {
write!(
f,
"call '{}'({:#?}) -> ({:#?})",
self.callee, self.inputs, self.outputs
)
} else {
write!(
f,
"call '{}'({:?}) -> ({:?})",
self.callee, self.inputs, self.outputs
)
}
}
}
impl<L, R> EqvRelation<Call<L>, Call<R>> for SymbolicEqv
where
SymbolicEqv: EqvRelation<L, R> + EqvRelation<Slot, Slot>,
{
fn equivalent(lhs: &Call<L>, rhs: &Call<R>) -> bool {
lhs.callee == rhs.callee
&& equiv! { SymbolicEqv | &lhs.inputs, &rhs.inputs }
&& equiv! { SymbolicEqv | &lhs.outputs, &rhs.outputs }
}
}