use crate::category::core;
use crate::category::lang;
use super::module::*;
pub use crate::pass::to_core::Environment;
use crate::pass::to_core::core_declarations;
use std::collections::HashMap;
/// Converts a module into a `(path, term)` entry for the definitions table.
///
/// The path comes from `def.path()` and the term from `def.term()`. They are
/// queried in that order.
///
/// # Panics
///
/// Panics if `def.term()` does not produce a term. Only modules that are
/// expected to carry a term should be passed here.
fn to_pair<const A: usize, const B: usize, T: Module<A, B>>(
    def: T,
) -> (lang::Path, lang::TypedTerm) {
    let path = def.path();
    // `expect` instead of a bare `unwrap` so a failure names the broken invariant
    // (and still reports the underlying error, if `term()` returns a `Result`).
    let term = def
        .term()
        .expect("module registered in the stdlib must provide a term");
    (path, term)
}
/// Builds the table of built-in definitions: `Sigmoid`, `Exp`, `Sqrt` and `Gelu`
/// from `super::nn`.
///
/// Each module is turned into a `(path, term)` entry via `to_pair` and keyed by
/// its path. If two entries ever shared a path, the later one would win.
fn definitions() -> HashMap<lang::Path, lang::TypedTerm> {
    use super::nn::*;

    // Evaluate every entry up front, exactly as the array literal did before.
    let entries = [
        to_pair(Sigmoid),
        to_pair(Exp),
        to_pair(Sqrt),
        to_pair(Gelu),
    ];

    let mut table = HashMap::new();
    for (path, term) in entries {
        table.insert(path, term);
    }
    table
}
/// Returns the standard-library `Environment`.
///
/// The environment combines the declarations from `core_declarations()` with the
/// built-in definitions assembled by this module's `definitions()`.
pub fn stdlib() -> Environment {
    let decls = core_declarations();
    let defs = definitions();
    Environment {
        declarations: decls,
        definitions: defs,
    }
}
/// Creates one `core::Operation::Load` for each path in `paths`.
///
/// Each yielded pair is `(prefix.concat(path), Operation::Load(path.clone()))`:
/// - the first element is the path combined with `prefix` via `concat`;
/// - the operation loads the original, unprefixed path.
///
/// The returned iterator takes ownership of `prefix`. The per-path mapping is
/// lazy and runs only as the iterator is consumed.
pub fn to_load_ops<'a, I>(
    prefix: lang::Path,
    paths: I,
) -> impl Iterator<Item = (lang::Path, core::Operation)>
where
    I: IntoIterator<Item = &'a lang::Path>,
{
    paths
        .into_iter()
        .map(move |path| (prefix.concat(path), core::Operation::Load(path.clone())))
}