use crate::category::lang::{Literal, ops};
use crate::prelude::{Builder, Var};
pub use ops::{
argmax, broadcast, cos, dtype, dtype_constant, index, lt, matmul, max, nat_to_u32, pack, param,
pow, reshape, shape, sin, sum, unpack,
};
/// Convenience wrapper around [`ops::cast`].
///
/// `d` may be anything implementing [`IntoDtypeVar`]; it is converted to a
/// `Var` with `to_var` before being forwarded together with `x`.
pub fn cast(builder: &Builder, x: Var, d: impl IntoDtypeVar) -> Var {
    let target = d.to_var(builder);
    ops::cast(builder, x, target)
}
/// Convenience wrapper around [`ops::arange`].
///
/// `end` may be anything implementing [`IntoNatVar`]; it is converted to a
/// `Var` with `to_var` and passed on as the only operand.
pub fn arange(builder: &Builder, end: impl IntoNatVar) -> Var {
    let end_var = end.to_var(builder);
    ops::arange(builder, end_var)
}
/// Convenience wrapper around [`ops::concat`].
///
/// `dim` is converted to a `Var` via [`IntoNatVar::to_var`]; `x` and `y` are
/// forwarded unchanged, in that order.
pub fn concat(builder: &Builder, dim: impl IntoNatVar, x: Var, y: Var) -> Var {
    let dim_var = dim.to_var(builder);
    ops::concat(builder, dim_var, x, y)
}
/// Convenience wrapper around [`ops::slice`].
///
/// `dim`, `start` and `len` are each converted to a `Var` via
/// [`IntoNatVar::to_var`] (in that order) and forwarded together with `x`.
pub fn slice(
    builder: &Builder,
    dim: impl IntoNatVar,
    start: impl IntoNatVar,
    len: impl IntoNatVar,
    x: Var,
) -> Var {
    // Conversions happen in argument order, matching the order in which any
    // literal nodes get added to the builder.
    let dim_var = dim.to_var(builder);
    let start_var = start.to_var(builder);
    let len_var = len.to_var(builder);
    ops::slice(builder, dim_var, start_var, len_var, x)
}
/// Converts `x` into a [`Literal`] and hands it to [`ops::lit`], returning
/// the resulting `Var`.
pub fn lit<T>(builder: &Builder, x: T) -> Var
where
    T: Into<Literal>,
{
    let literal: Literal = x.into();
    ops::lit(builder, literal)
}
/// Wraps a `u32` in the [`Literal::Nat`] variant.
pub fn nat(x: u32) -> Literal {
Literal::Nat(x)
}
/// Creates a literal from `x` and broadcasts it with `s`.
///
/// The literal is built with [`lit`]; the result is then passed to
/// [`ops::broadcast`] along with a clone of `s`.
/// NOTE(review): `s` is presumably a shape-valued `Var` — confirm against
/// `ops::broadcast`.
pub fn constant<T: Into<Literal>>(builder: &Builder, x: T, s: &Var) -> Var {
    let scalar = lit(builder, x);
    let target = s.clone();
    ops::broadcast(builder, scalar, target)
}
/// Builds `1.0 / x`.
///
/// The numerator is the literal `1.0` broadcast (via [`constant`]) to the
/// result of `shape(builder, x)`, and the division uses the `/` operator on
/// `Var`.
/// NOTE(review): presumably an elementwise reciprocal — confirm against the
/// `Div` implementation for `Var`.
pub fn inverse(builder: &Builder, x: Var) -> Var {
    let x_shape = shape(builder, x.clone());
    let ones = constant(builder, 1.0, &x_shape);
    ones / x
}
/// Convenience wrapper around [`ops::transpose`].
///
/// `a` and `b` are converted to `Var`s via [`IntoNatVar::to_var`] (in that
/// order) and forwarded together with `x`.
pub fn transpose(builder: &Builder, a: impl IntoNatVar, b: impl IntoNatVar, x: Var) -> Var {
    let a_var = a.to_var(builder);
    let b_var = b.to_var(builder);
    ops::transpose(builder, a_var, b_var, x)
}
/// Conversion of a value into a `Var` for use as a natural-number operand.
///
/// Used as the bound on the `impl IntoNatVar` parameters of the wrapper
/// functions in this file (e.g. `arange`, `concat`, `slice`, `transpose`).
pub trait IntoNatVar {
/// Produces the `Var` for `self`, using `builder` if a new node is needed.
fn to_var(&self, builder: &Builder) -> Var;
}
impl IntoNatVar for Var {
fn to_var(&self, _builder: &Builder) -> Var {
self.clone()
}
}
/// A `u32` becomes a `Literal::Nat` literal created through [`lit`].
impl IntoNatVar for u32 {
    fn to_var(&self, builder: &Builder) -> Var {
        let literal = nat(*self);
        lit(builder, literal)
    }
}
/// An `i32` is converted to `u32` and emitted as a `Literal::Nat` literal.
///
/// # Panics
///
/// Panics if the value is negative, since it cannot be represented as a
/// `u32`.
impl IntoNatVar for i32 {
    fn to_var(&self, builder: &Builder) -> Var {
        // Checked conversion: a negative value is a caller bug, so fail loudly
        // with a message that names the offending value instead of a bare
        // `TryFromIntError`.
        let value = u32::try_from(*self).unwrap_or_else(|_| {
            panic!(
                "IntoNatVar: i32 value {} is negative and cannot be used as a Nat",
                self
            )
        });
        lit(builder, nat(value))
    }
}
/// A `usize` is converted to `u32` and emitted as a `Literal::Nat` literal.
///
/// # Panics
///
/// Panics if the value does not fit in a `u32`.
impl IntoNatVar for usize {
    fn to_var(&self, builder: &Builder) -> Var {
        // Checked conversion: on targets where `usize` is wider than 32 bits a
        // large value would not fit; fail with a message naming the value
        // instead of a bare `TryFromIntError`.
        let value = u32::try_from(*self).unwrap_or_else(|_| {
            panic!(
                "IntoNatVar: usize value {} does not fit in u32 and cannot be used as a Nat",
                self
            )
        });
        lit(builder, nat(value))
    }
}
/// Conversion of a value into a `Var` for use as a dtype operand.
///
/// Used as the bound on the `d` parameter of `cast` in this file.
pub trait IntoDtypeVar {
/// Produces the `Var` for `self`, using `builder` if a new node is needed.
fn to_var(&self, builder: &Builder) -> Var;
}
/// A concrete `Dtype` becomes a dtype constant created through
/// `dtype_constant`.
impl IntoDtypeVar for crate::category::core::Dtype {
    fn to_var(&self, builder: &Builder) -> Var {
        // Use the receiver's own value. Hard-coding `Dtype::F32` here would
        // make every conversion (e.g. `cast(builder, x, Dtype::U32)`) silently
        // target F32 regardless of what the caller asked for.
        dtype_constant(builder, self.clone())
    }
}
impl IntoDtypeVar for Var {
fn to_var(&self, _builder: &Builder) -> Var {
self.clone()
}
}