use std::collections::HashMap;
use crate::category::{core, lang};
use crate::definition::Def;
use crate::path::Path;
use std::fmt::Debug;
/// Lookup tables used when lowering `lang` terms and operations to `core`.
#[derive(Debug, Clone)]
pub struct Environment {
    /// Primitive core operations keyed by path; consulted when lowering
    /// `lang::Operation::Declaration`s.
    pub declarations: HashMap<Path, core::Operation>,
    /// Typed `lang` terms keyed by path.
    pub definitions: HashMap<Path, lang::TypedTerm>,
}

impl Environment {
    /// Lowers a single `lang` operation, resolving declarations against
    /// `self.declarations` (delegates to the module-level `op_to_core`).
    pub fn op_to_core(&self, op: lang::Operation) -> Def<Path, core::Operation> {
        let Self { declarations, .. } = self;
        op_to_core(op, declarations)
    }

    /// Lowers a whole `lang` term, resolving declarations against
    /// `self.declarations` (delegates to the module-level `to_core`).
    pub fn to_core(&self, term: lang::Term) -> core::Term {
        let Self { declarations, .. } = self;
        to_core(term, declarations)
    }
}
/// Lowers a `lang` term by applying `op_to_core` to each of its edges via
/// `map_edges`.
///
/// Declarations found in `core_ops` become primitive core operations; every
/// other operation path is kept as a `Def::Def` reference.
pub fn to_core(
    term: lang::Term,
    core_ops: &HashMap<lang::Path, core::Operation>,
) -> core::Term {
    let lower_edge = |edge| op_to_core(edge, core_ops);
    term.map_edges(lower_edge)
}
/// Lowers one `lang` operation to a core definition.
///
/// - `Definition(path)` is always kept as a `Def::Def(path)` reference.
/// - `Declaration(path)` becomes `Def::Arr` holding a clone of the matching
///   entry in `core_ops`; if there is no entry it falls back to
///   `Def::Def(path)`.
/// - Literals become the corresponding constant core operation, wrapped in
///   `Def::Arr`.
fn op_to_core(
    op: lang::Operation,
    core_ops: &HashMap<lang::Path, core::Operation>,
) -> Def<Path, core::Operation> {
    match op {
        lang::Operation::Definition(path) => Def::Def(path),
        lang::Operation::Declaration(path) => match core_ops.get(&path) {
            Some(core_op) => Def::Arr(core_op.clone()),
            // `path` is owned and `get` does not keep it borrowed, so it can be
            // moved into the reference instead of being cloned.
            None => Def::Def(path),
        },
        lang::Operation::Literal(lit) => Def::Arr(match lit {
            lang::Literal::F32(x) => {
                core::Operation::Tensor(core::TensorOp::Scalar(core::Scalar::F32(x)))
            }
            lang::Literal::U32(x) => {
                core::Operation::Tensor(core::TensorOp::Scalar(core::Scalar::U32(x)))
            }
            // NOTE(review): `as` silently truncates/wraps if the Nat literal's
            // integer type is wider than, or signed relative to, `usize` —
            // confirm the source type.
            lang::Literal::Nat(n) => core::Operation::Nat(core::NatOp::Constant(n as usize)),
            lang::Literal::Dtype(d) => core::Operation::DtypeConstant(d),
        }),
    }
}
/// Returns `env` unchanged.
///
/// NOTE(review): despite the name, no conversion happens here; `Environment`
/// already stores `core::Operation` declarations. Confirm whether this is an
/// intentional no-op or a placeholder.
pub fn env_to_core(env: Environment) -> Environment {
    std::convert::identity(env)
}
/// Builds a path from a comma-separated list of segment expressions (a trailing
/// comma is allowed) by passing them as a `Vec` to `$crate::path::path`.
///
/// Panics with "invalid operation name" if `$crate::path::path` fails.
macro_rules! path {
    ($($segment:expr),* $(,)?) => {{
        $crate::path::path(vec![$($segment),*]).expect("invalid operation name")
    }};
}
pub(crate) fn core_declarations() -> HashMap<lang::Path, core::Operation> {
use crate::category::core::{NatOp, Operation, ScalarOp::*, TensorOp::*, TypeOp};
HashMap::from([
(path!["cartesian", "copy"], Operation::Copy),
(path!["tensor", "add"], Operation::Tensor(Map(Add))),
(path!["tensor", "sub"], Operation::Tensor(Map(Sub))),
(path!["tensor", "neg"], Operation::Tensor(Map(Neg))),
(path!["tensor", "mul"], Operation::Tensor(Map(Mul))),
(path!["tensor", "div"], Operation::Tensor(Map(Div))),
(path!["tensor", "pow"], Operation::Tensor(Map(Pow))),
(path!["tensor", "sin"], Operation::Tensor(Map(Sin))),
(path!["tensor", "cos"], Operation::Tensor(Map(Cos))),
(path!["tensor", "lt"], Operation::Tensor(Map(LT))),
(path!["tensor", "eq"], Operation::Tensor(Map(EQ))),
(path!["tensor", "matmul"], Operation::Tensor(MatMul)),
(path!["tensor", "reshape"], Operation::Tensor(Reshape)),
(path!["tensor", "transpose"], Operation::Tensor(Transpose)),
(path!["tensor", "broadcast"], Operation::Tensor(Broadcast)),
(path!["tensor", "cast"], Operation::Tensor(Cast)),
(path!["tensor", "index"], Operation::Tensor(Index)),
(path!["tensor", "slice"], Operation::Tensor(Slice)),
(path!["tensor", "sum"], Operation::Tensor(Sum)),
(path!["tensor", "max"], Operation::Tensor(Max)),
(path!["tensor", "argmax"], Operation::Tensor(Argmax)),
(path!["tensor", "arange"], Operation::Tensor(Arange)),
(path!["tensor", "concat"], Operation::Tensor(Concat)),
(path!["tensor", "nat_to_u32"], Operation::Tensor(NatToU32)),
(path!["tensor", "shape"], Operation::Type(TypeOp::Shape)),
(path!["tensor", "dtype"], Operation::Type(TypeOp::Dtype)),
(path!["shape", "pack"], Operation::Type(TypeOp::Pack)),
(path!["shape", "unpack"], Operation::Type(TypeOp::Unpack)),
(path!["nat", "add"], Operation::Nat(NatOp::Add)),
(path!["nat", "mul"], Operation::Nat(NatOp::Mul)),
])
}