use crate::definition::Def;
use crate::path::Path;
use open_hypergraphs::lax::OpenHypergraph;
/// A term of the language, represented as a lax open hypergraph.
///
/// Wires are labelled by [`Object`] and operations by `Def<Path, Operation>`
/// (assuming the `open_hypergraphs` convention `OpenHypergraph<NodeLabel, EdgeLabel>`
/// — TODO confirm against the crate version in use).
///
/// NOTE(review): `Def<Path, Operation>` presumably distinguishes a primitive
/// [`Operation`] from a reference to another named definition by `Path`; verify in
/// `crate::definition`.
pub type Term = OpenHypergraph<Object, Def<Path, Operation>>;
/// The type of an n-dimensional array: its element type together with its shape.
///
/// Derives `Hash` in addition to `Eq`, matching its `Dtype` and `Shape` fields, so
/// array types can be used as `HashMap`/`HashSet` keys like the other type-level values
/// in this module. The derived `Hash` is consistent with the derived `Eq` because both
/// are field-wise.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NdArrayType {
    /// Element type of the array.
    pub dtype: Dtype,
    /// Extent of each axis of the array.
    pub shape: Shape,
}
/// Element type of an n-dimensional array (see `NdArrayType::dtype`).
///
/// The variants mirror those of `Scalar`, so `F32` presumably corresponds to Rust's
/// `f32` and `U32` to `u32`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Dtype {
    /// 32-bit float elements.
    F32,
    /// 32-bit unsigned integer elements.
    U32,
}
/// The extent of each axis of an n-dimensional array, outermost axis first.
///
/// A shape with no axes (`Shape(vec![])`) describes a rank-0 (scalar) array.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

impl Shape {
    /// Number of axes.
    pub fn rank(&self) -> usize {
        self.0.len()
    }

    /// Total number of elements: the product of all extents.
    ///
    /// A rank-0 shape has size 1 (the empty product).
    pub fn size(&self) -> usize {
        self.0.iter().product()
    }

    /// Row-major (C-order) strides, measured in elements, for a densely packed array of
    /// this shape.
    ///
    /// The last axis has stride 1, and each earlier axis's stride is the product of all
    /// later extents. For example, `[2, 3, 4]` gives `[12, 4, 1]`. The result always has
    /// exactly `rank()` entries, so a rank-0 shape yields an empty vector. (Previously it
    /// yielded `[1]`, whose length disagreed with `rank()`.)
    ///
    /// Extents are converted with `as isize`, and products use ordinary `isize`
    /// arithmetic; very large shapes are not range-checked.
    pub fn contiguous_strides(&self) -> Vec<isize> {
        let rank = self.rank();
        let mut strides = vec![1isize; rank];
        // Fill right-to-left. The outermost extent never contributes to any stride, so it
        // is never multiplied in.
        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * self.0[axis] as isize;
        }
        strides
    }
}

/// Read access to the extent of one axis.
///
/// # Panics
///
/// Panics if `index >= self.rank()`.
impl std::ops::Index<usize> for Shape {
    type Output = usize;
    fn index(&self, index: usize) -> &Self::Output {
        &self.0[index]
    }
}

/// Write access to the extent of one axis.
///
/// # Panics
///
/// Panics if `index >= self.rank()`.
impl std::ops::IndexMut<usize> for Shape {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.0[index]
    }
}
/// The sort of a wire in a `Term` (used as the hypergraph's node label).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Object {
    Nat,
    Dtype,
    NdArrayType,
    Shape,
    Tensor,
}

/// User-facing rendering of an `Object`: the variant name, exactly as `Debug` prints it.
impl std::fmt::Display for Object {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
use crate::category::lang;
/// Primitive operations of the language; they appear in `Term` edge labels via
/// `Def<Path, Operation>`.
///
/// Only `PartialEq` (not `Eq`/`Hash`) is derived because `TensorOp` does not implement
/// `Eq`/`Hash`.
#[derive(Clone, Debug, PartialEq)]
pub enum Operation {
    /// A type-level operation (see `TypeOp`).
    Type(TypeOp),
    /// An operation on natural numbers (see `NatOp`).
    Nat(NatOp),
    /// A constant element type.
    DtypeConstant(Dtype),
    /// An operation on tensors (see `TensorOp`).
    Tensor(TensorOp),
    /// NOTE(review): presumably duplicates a wire — confirm its arity where terms are typed.
    Copy,
    /// Refers to something by `lang::Path`; its semantics are defined elsewhere.
    Load(lang::Path),
}
/// Operations on natural-number (`Object::Nat`) values.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum NatOp {
    /// A literal natural-number constant.
    Constant(usize),
    /// Multiplication.
    Mul,
    /// Addition.
    Add,
}
/// Type-level operations relating dtypes, shapes and array types.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum TypeOp {
    /// NOTE(review): presumably builds an `NdArrayType` from a dtype and a shape.
    Pack,
    /// NOTE(review): presumably the inverse of `Pack`.
    Unpack,
    /// NOTE(review): presumably projects the shape out of an array type.
    Shape,
    /// NOTE(review): presumably projects the dtype out of an array type.
    Dtype,
}
/// A scalar literal, with one variant per `Dtype`.
///
/// Only `PartialEq` is derived because `f32` has no total equality (NaN != NaN).
#[derive(Clone, Debug, PartialEq)]
pub enum Scalar {
    /// A 32-bit float value.
    F32(f32),
    /// A 32-bit unsigned integer value.
    U32(u32),
}
/// Operations on tensors.
///
/// Only `PartialEq` is derived because the `Scalar` payload holds an `f32`.
/// Semantics of the individual operations are defined by the code that interprets them;
/// the notes below are hedged readings of the variant names.
#[derive(Clone, Debug, PartialEq)]
pub enum TensorOp {
    /// Applies a `ScalarOp` (presumably elementwise).
    Map(ScalarOp),
    /// NOTE(review): presumably turns a `Nat` into a `U32` tensor.
    NatToU32,
    /// NOTE(review): presumably converts between dtypes.
    Cast,
    MatMul,
    /// A constant scalar value.
    Scalar(Scalar),
    // Presumably reductions.
    Sum,
    Max,
    Argmax,
    // Presumably shape/layout manipulation.
    Broadcast,
    Reshape,
    Transpose,
    Slice,
    Concat,
    // Presumably index generation and gathering.
    Arange,
    Index,
}
/// Scalar operations, used as the payload of `TensorOp::Map`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ScalarOp {
    Add,
    Sub,
    Mul,
    Div,
    Neg,
    Pow,
    LT,
    EQ,
    Cos,
    Sin,
}

impl ScalarOp {
    /// Returns `(inputs, outputs)` for this operation.
    ///
    /// Every operation has exactly one output. `Neg`, `Cos` and `Sin` take one input;
    /// all other operations take two.
    pub fn profile(&self) -> (usize, usize) {
        let inputs = match self {
            ScalarOp::Neg | ScalarOp::Cos | ScalarOp::Sin => 1,
            ScalarOp::Add
            | ScalarOp::Sub
            | ScalarOp::Mul
            | ScalarOp::Div
            | ScalarOp::Pow
            | ScalarOp::LT
            | ScalarOp::EQ => 2,
        };
        (inputs, 1)
    }
}