use open_hypergraphs::lax::{EdgeId, NodeId};
use std::fmt::Debug;
use crate::category::{
core,
core::{Object, Operation, TensorOp},
};
use crate::definition::Def;
use crate::path::Path;
use crate::ssa::{SSA, SSAError};
/// The SSA form handled by this module: `SSA` instantiated with core-category
/// [`Object`]s and `Def<Path, Operation>` labels.
pub type CoreSSA = SSA<Object, Def<Path, Operation>>;

/// Module-wide result type; the error is always [`InterpreterError`].
pub type Result<T> = std::result::Result<T, InterpreterError>;

/// Result of an operation that yields a list of [`Value`]s for interpreter `I`.
/// (Spelled via the local [`Result`] alias; it is exactly
/// `std::result::Result<Vec<Value<I>>, InterpreterError>`.)
pub type ResultValues<I> = Result<Vec<Value<I>>>;
/// A pluggable backend supplying the runtime representations and primitive
/// operations used when working with a [`CoreSSA`] term.
///
/// Implementors choose one concrete type for each kind of [`Value`] (naturals,
/// dtypes, shapes, array types, tensors) and implement the operations below on
/// those types. The `Clone` supertrait means an interpreter can be copied by value.
pub trait Interpreter: Clone {
/// Representation of natural numbers; used as the dimensions passed to `pack`.
type Nat: Clone + Debug + PartialEq;
/// Representation of element dtypes.
type Dtype: Clone + Debug + PartialEq;
/// Representation of shapes, constructed from `Nat` dimensions via `pack`.
type Shape: Clone + Debug + PartialEq;
/// Representation of array types.
/// NOTE(review): presumably pairs a dtype with a shape — confirm in implementors.
type NdArrayType: Clone + Debug + PartialEq;
/// Representation of tensors. Unlike the other associated types this is NOT
/// required to be `PartialEq` (equality on `Value` is only available when it is).
type Tensor: Clone + Debug;
/// Build a shape from a list of dimensions.
fn pack(dims: Vec<Self::Nat>) -> Self::Shape;
/// Split a shape back into a list of dimensions; `None` when the backend cannot
/// decompose this shape (conditions are implementation-defined).
fn unpack(shape: Self::Shape) -> Option<Vec<Self::Nat>>;
/// The shape of `tensor`, or `None` if the backend cannot determine it.
/// Takes the tensor by value, so callers that still need it must clone first.
fn shape(tensor: Self::Tensor) -> Option<Self::Shape>;
/// The dtype of `tensor`, or `None` if the backend cannot determine it.
/// Takes the tensor by value, so callers that still need it must clone first.
fn dtype(tensor: Self::Tensor) -> Option<Self::Dtype>;
/// Convert a core-category dtype constant into this backend's `Dtype`.
fn dtype_constant(dtype: core::Dtype) -> Self::Dtype;
/// Convert a literal `usize` into this backend's `Nat`.
fn nat_constant(nat: usize) -> Self::Nat;
/// Addition on `Nat`.
fn nat_add(a: Self::Nat, b: Self::Nat) -> Self::Nat;
/// Multiplication on `Nat`.
fn nat_mul(a: Self::Nat, b: Self::Nat) -> Self::Nat;
/// Produce the values for a load of `path` within `ssa`, or `None` if the backend
/// cannot supply them.
/// NOTE(review): callers presumably map `None` to `InterpreterError::Load` — confirm.
fn handle_load(&self, ssa: &CoreSSA, path: &Path) -> Option<Vec<Value<Self>>>;
/// Evaluate the definition identified by `path` on the argument values `args`,
/// returning its output values or an [`InterpreterError`].
fn handle_definition(
&self,
ssa: &CoreSSA,
args: Vec<Value<Self>>,
path: &Path,
) -> ResultValues<Self>;
/// Evaluate the primitive tensor operation `op` on the argument values `args`,
/// returning its output values or an [`InterpreterError`].
fn tensor_op(&self, ssa: &CoreSSA, args: Vec<Value<Self>>, op: &TensorOp)
-> ResultValues<Self>;
}
#[derive(Debug, Clone)]
pub enum Value<V: Interpreter> {
Nat(V::Nat),
Dtype(V::Dtype),
Shape(V::Shape),
Type(V::NdArrayType),
Tensor(V::Tensor),
}
impl<I: Interpreter> PartialEq for Value<I>
where
I::Tensor: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Nat(l0), Self::Nat(r0)) => l0 == r0,
(Self::Dtype(l0), Self::Dtype(r0)) => l0 == r0,
(Self::Shape(l0), Self::Shape(r0)) => l0 == r0,
(Self::Type(l0), Self::Type(r0)) => l0 == r0,
(Self::Tensor(l0), Self::Tensor(r0)) => l0 == r0,
_ => false,
}
}
}
/// Errors produced by this interpreter module.
///
/// NOTE(review): the per-variant descriptions below are read off the variant
/// names and payload types; confirm them against the code that constructs each one.
#[derive(Clone, Debug)]
pub enum InterpreterError {
/// The given node was read more than once.
MultipleRead(NodeId),
/// The given node was written more than once.
MultipleWrite(NodeId),
/// An error from the SSA layer; produced from `SSAError` via the `From` impl.
SSAError(SSAError),
/// A type error attributed to the given edge.
TypeError(EdgeId),
/// An arity (argument count) error attributed to the given edge.
ArityError(EdgeId),
/// A load of `Path` failed at the given edge.
Load(EdgeId, Path),
/// Applying the operation at the given edge failed.
ApplyError(EdgeId),
}
impl From<SSAError> for InterpreterError {
fn from(value: SSAError) -> Self {
InterpreterError::SSAError(value)
}
}
/// User-facing rendering of an [`InterpreterError`].
impl std::fmt::Display for InterpreterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The display text is the derived `Debug` representation
        // (e.g. the variant name plus its payload).
        write!(f, "{:?}", self)
    }
}

/// Marker impl so the error can be boxed as / converted into `dyn std::error::Error`.
/// `source()` keeps its default implementation (returns `None`).
impl std::error::Error for InterpreterError {}