use crate::abstract_interpreter::{CoreSSA, InterpreterError, Result};
use crate::category::core::Dtype;
/// A type expression: either a variable or a concrete n-dimensional array type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TypeExpr {
    /// A type variable, distinguished by its `usize` payload.
    Var(usize),
    /// An array type described by a dtype expression and a shape expression.
    NdArrayType(NdArrayType),
}
/// A shape expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ShapeExpr {
    /// A shape variable, distinguished by its `usize` payload.
    Var(usize),
    /// NOTE(review): presumably "the shape of the type variable with this
    /// index" — confirm against the code that constructs this variant.
    OfType(usize),
    /// An explicit list of [`NatExpr`] entries, one per dimension.
    Shape(Vec<NatExpr>),
}
/// An n-dimensional array type: a dtype expression paired with a shape expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NdArrayType {
    /// Expression for the element dtype.
    pub dtype: DtypeExpr,
    /// Expression for the array shape.
    pub shape: ShapeExpr,
}
/// An expression over natural numbers, built from variables, constants,
/// and n-ary `Mul` / `Add` nodes.
///
/// Equality and hashing are derived, so they are purely structural: two
/// expressions compare equal only when their trees are identical.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NatExpr {
    /// A variable, distinguished by its `usize` payload.
    Var(usize),
    /// A literal value.
    Constant(usize),
    /// An n-ary multiplication node over the contained operands.
    Mul(Vec<NatExpr>),
    /// An n-ary addition node over the contained operands.
    Add(Vec<NatExpr>),
}
/// A dtype expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DtypeExpr {
    /// A dtype variable, distinguished by its `usize` payload.
    Var(usize),
    /// NOTE(review): presumably "the dtype of the type variable with this
    /// index" — confirm against the code that constructs this variant.
    OfType(usize),
    /// A concrete, known [`Dtype`].
    Constant(Dtype),
}
impl NatExpr {
    /// Returns the normal form of this expression.
    ///
    /// All of the work is delegated to `normalize` in the sibling
    /// `isomorphism` module; `self` is left untouched.
    pub(crate) fn nf(&self) -> Self {
        super::isomorphism::normalize(self)
    }
}
impl ShapeExpr {
pub(crate) fn nf(&self) -> Self {
match self {
ShapeExpr::Var(_) => self.clone(),
ShapeExpr::OfType(_) => self.clone(),
ShapeExpr::Shape(nat_exprs) => {
ShapeExpr::Shape(nat_exprs.iter().map(|m| m.nf()).collect())
}
}
}
}
impl DtypeExpr {
    /// Returns the normal form of this dtype expression.
    ///
    /// Every variant is returned as an unmodified clone of `self`.
    pub(crate) fn nf(&self) -> Self {
        match self {
            // No variant holds a nested expression, so each is returned as-is.
            Self::Var(_) | Self::OfType(_) | Self::Constant(_) => self.clone(),
        }
    }
}
impl TypeExpr {
    /// Consumes `self` and returns the contained [`NdArrayType`].
    ///
    /// # Errors
    ///
    /// Returns [`InterpreterError::TypeError`] carrying `ssa.edge_id` when
    /// `self` is a `Var` rather than an `NdArrayType`.
    pub(crate) fn into_ndarraytype(self, ssa: &CoreSSA) -> Result<NdArrayType> {
        match self {
            Self::NdArrayType(t) => Ok(t),
            // Listed explicitly (no `_` catch-all) so that adding a variant
            // forces this conversion to be revisited.
            Self::Var(_) => Err(InterpreterError::TypeError(ssa.edge_id)),
        }
    }

    /// Consumes `self` and returns its `(shape, dtype)` pair.
    ///
    /// # Errors
    ///
    /// Returns [`InterpreterError::TypeError`] carrying `ssa.edge_id` when
    /// `self` is a `Var` rather than an `NdArrayType`.
    pub(crate) fn into_shapeexpr_dtype(self, ssa: &CoreSSA) -> Result<(ShapeExpr, DtypeExpr)> {
        self.into_ndarraytype(ssa)
            .map(|NdArrayType { shape, dtype }| (shape, dtype))
    }

    /// Returns the normal form of this type expression.
    ///
    /// An `NdArrayType` is rebuilt from the normal forms of its dtype and
    /// shape. A `Var` has no sub-expressions and is returned as a clone,
    /// matching how `ShapeExpr::nf` and `DtypeExpr::nf` treat variables
    /// (previously this arm was `todo!()` and panicked).
    pub(crate) fn nf(&self) -> Self {
        match self {
            Self::Var(_) => self.clone(),
            Self::NdArrayType(NdArrayType { dtype, shape }) => Self::NdArrayType(NdArrayType {
                dtype: dtype.nf(),
                shape: shape.nf(),
            }),
        }
    }
}