use crate::category::lang::*;
use crate::typecheck::Value;
use crate::util::build_typed;
use open_hypergraphs::lax::var;
/// A named, reusable sub-diagram ("module") with `A` inputs and `B` outputs.
///
/// Implementors provide three things:
/// - the type signature ([`Module::ty`]),
/// - the [`Path`] naming the definition ([`Module::path`]),
/// - the body ([`Module::def`]).
///
/// The remaining methods have default implementations built from those three.
pub trait Module<const A: usize, const B: usize> {
    /// Input and output types of the module.
    fn ty(&self) -> ([Type; A], [Type; B]);

    /// Path naming this definition. [`Module::op`] uses it as the payload of
    /// `Operation::Definition` when it emits the module as a single operation.
    fn path(&self) -> Path;

    /// Build the module's body into `builder` from the input variables `args`,
    /// returning the output variables.
    fn def(&self, builder: &Builder, args: [Var; A]) -> [Var; B];

    /// Sorts of the inputs and outputs, obtained by mapping `to_sort` over
    /// both halves of [`Module::ty`].
    fn sort(&self) -> ([Object; A], [Object; B]) {
        let (source, target) = self.ty();
        (source.map(to_sort), target.map(to_sort))
    }

    /// Expand the module's definition in place. By default this is exactly
    /// [`Module::def`].
    fn inline(&self, builder: &Builder, args: [Var; A]) -> [Var; B] {
        self.def(builder, args)
    }

    /// Emit the module as one `Operation::Definition(self.path())` operation
    /// applied to `args`. The result types come from the output half of
    /// [`Module::sort`].
    ///
    /// # Panics
    ///
    /// Panics if the variables returned by `var::operation` cannot be converted
    /// into `[Var; B]`. Presumably this only happens on a length mismatch.
    fn op(&self, builder: &Builder, args: [Var; A]) -> [Var; B] {
        let result_types = self.sort().1.to_vec();
        var::operation(
            builder,
            &args,
            result_types,
            Operation::Definition(self.path()),
        )
        .try_into()
        .expect("Module::op: operation must yield one Var per declared output sort")
    }

    /// Build a standalone typed term for this module. It inlines the definition
    /// over inputs whose sorts come from the source half of [`Module::ty`].
    ///
    /// Returns `None` when the built term's target sorts differ from the sorts
    /// of the declared output types, i.e. the definition disagrees with `ty()`.
    ///
    /// # Panics
    ///
    /// Panics if `build_typed` fails.
    fn term(&self) -> Option<TypedTerm> {
        // Presumably brings into scope the trait that provides `term.target()`.
        use open_hypergraphs::category::*;

        let (source_type, target_type) = self.ty();
        // `map` consumes the array; the originals are still needed for the result.
        let source_object = source_type.clone().map(to_sort);
        let term = build_typed(source_object, |builder, args| {
            self.inline(builder, args).to_vec()
        })
        .expect("Module::term: build_typed failed while inlining the module definition");

        if term.target() != target_type.clone().map(to_sort) {
            return None;
        }

        Some(TypedTerm {
            term,
            source_type: source_type.to_vec(),
            target_type: target_type.to_vec(),
        })
    }
}
/// Erase a type-level [`Value`] to the [`Object`] sort of its variant.
///
/// Only the variant tag is inspected; each variant's payload is ignored.
fn to_sort(ty: Type) -> Object {
    match &ty {
        Value::Type(..) => Object::NdArrayType,
        Value::Shape(..) => Object::Shape,
        Value::Nat(..) => Object::Nat,
        Value::Dtype(..) => Object::Dtype,
        Value::Tensor(..) => Object::Tensor,
    }
}
/// A [`Module`] with exactly one output, usable like an ordinary function call.
///
/// Any `Module<N, 1>` gets this trait through the blanket impl in this module.
pub trait FnModule<const N: usize>: Module<N, 1> {
    /// Emit the module via [`Module::op`] and return its single output variable.
    fn call(&self, builder: &Builder, args: [Var; N]) -> Var {
        match self.op(builder, args) {
            [output] => output,
        }
    }
}
/// Every single-output module is automatically a [`FnModule`].
impl<const N: usize, M> FnModule<N> for M where M: Module<N, 1> {}