use arrayfire::Array;
use crate::graph::node::{BinaryReverseFn, Node, UnaryReverseFn};
use crate::tensor::{constant::Constant, variable::Variable};
/// Turns the already-computed result of a one-operand operation on `self`
/// into a value of type `Out`.
///
/// `Out` is whatever [`SingleParam::push_unary`] hands back; each implementor
/// decides how `result` and `reverse` are used to build it.
pub trait SingleParam<Out> {
    /// Packages `result` (and, where the implementor needs it, `reverse`)
    /// into an `Out`.
    ///
    /// NOTE(review): `reverse` is presumably the backward-pass callback for
    /// the operation -- confirm against the definition of `UnaryReverseFn`.
    fn push_unary(
        &self,
        result: Array<f32>,
        reverse: UnaryReverseFn,
    ) -> Out;
}
/// One-operand operation on a [`Variable`] that yields another [`Variable`].
///
/// This impl does not relate the output const parameters (`YB`, `YL`, `YR`,
/// `YC`) to the input ones (`B`, `L`, `R`, `C`); the output type is selected
/// by the caller.
impl<
    const B: u64,
    const L: u64,
    const R: u64,
    const C: u64,
    const YB: u64,
    const YL: u64,
    const YR: u64,
    const YC: u64,
> SingleParam<Variable<YB, YL, YR, YC>> for Variable<B, L, R, C>
{
    /// Builds a unary node from `result`, `self` (converted with `into()`)
    /// and `reverse`, then returns a new variable constructed from a clone of
    /// `self.tape()` and that node.
    fn push_unary(
        &self,
        result: Array<f32>,
        reverse: UnaryReverseFn,
    ) -> Variable<YB, YL, YR, YC> {
        // The node is created first and the tape is cloned second; keep this
        // order so the sequence of calls stays the same.
        let unary_node = Node::unary(result, self.into(), reverse);
        let tape = self.tape().clone();
        Variable::<YB, YL, YR, YC>::new(tape, unary_node)
    }
}
/// One-operand operation on a [`Constant`] that yields another [`Constant`].
///
/// This impl does not relate the output const parameters (`YB`, `YL`, `YR`,
/// `YC`) to the input ones (`B`, `L`, `R`, `C`); the output type is selected
/// by the caller.
impl<
    const B: u64,
    const L: u64,
    const R: u64,
    const C: u64,
    const YB: u64,
    const YL: u64,
    const YR: u64,
    const YC: u64,
> SingleParam<Constant<YB, YL, YR, YC>> for Constant<B, L, R, C>
{
    /// Wraps `result` in a new constant.
    ///
    /// Neither `self` nor `_reverse` is read: no node is created on this path
    /// and the reverse function is simply dropped.
    fn push_unary(
        &self,
        result: Array<f32>,
        _reverse: UnaryReverseFn,
    ) -> Constant<YB, YL, YR, YC> {
        Constant::<YB, YL, YR, YC>::new(result)
    }
}
/// Turns the already-computed result of a two-operand operation on `self`
/// and an `Rhs` into a value of type `Out`.
///
/// `Rhs` is the type of the second operand and `Out` is whatever
/// [`DoubleParam::push_binary`] hands back; each implementor decides how
/// `other`, `result` and `reverse` are used to build it.
pub trait DoubleParam<Rhs, Out> {
    /// Packages `result` (and, where the implementor needs them, `other` and
    /// `reverse`) into an `Out`.
    ///
    /// NOTE(review): `reverse` is presumably the backward-pass callback for
    /// the operation -- confirm against the definition of `BinaryReverseFn`.
    fn push_binary(
        &self,
        other: &Rhs,
        result: Array<f32>,
        reverse: BinaryReverseFn,
    ) -> Out;
}
/// Two-operand operation where both operands are [`Variable`]s and the
/// output is a [`Variable`].
///
/// This impl does not relate the three sets of const parameters (`B..C` for
/// `self`, `YB..YC` for `other`, `ZB..ZC` for the output) to each other; the
/// output type is selected by the caller.
impl<
    const B: u64,
    const L: u64,
    const R: u64,
    const C: u64,
    const YB: u64,
    const YL: u64,
    const YR: u64,
    const YC: u64,
    const ZB: u64,
    const ZL: u64,
    const ZR: u64,
    const ZC: u64,
> DoubleParam<Variable<YB, YL, YR, YC>, Variable<ZB, ZL, ZR, ZC>> for Variable<B, L, R, C>
{
    /// Builds a node with `Node::binary_varvar` from `result`, the pair
    /// (`self`, `other`) converted with `into()`, and `reverse`; the returned
    /// variable is constructed from `self.tape().merge(other.tape())` and
    /// that node.
    fn push_binary(
        &self,
        other: &Variable<YB, YL, YR, YC>,
        result: Array<f32>,
        reverse: BinaryReverseFn,
    ) -> Variable<ZB, ZL, ZR, ZC> {
        // The node is created first and the tapes are merged second; keep
        // this order so the sequence of calls stays the same.
        let binary_node = Node::binary_varvar(result, (self.into(), other.into()), reverse);
        let merged_tape = self.tape().merge(other.tape());
        Variable::<ZB, ZL, ZR, ZC>::new(merged_tape, binary_node)
    }
}
/// Two-operand operation with a [`Variable`] on the left and a [`Constant`]
/// on the right, yielding a [`Variable`].
///
/// This impl does not relate the three sets of const parameters (`B..C` for
/// `self`, `YB..YC` for `other`, `ZB..ZC` for the output) to each other; the
/// output type is selected by the caller.
impl<
    const B: u64,
    const L: u64,
    const R: u64,
    const C: u64,
    const YB: u64,
    const YL: u64,
    const YR: u64,
    const YC: u64,
    const ZB: u64,
    const ZL: u64,
    const ZR: u64,
    const ZC: u64,
> DoubleParam<Constant<YB, YL, YR, YC>, Variable<ZB, ZL, ZR, ZC>> for Variable<B, L, R, C>
{
    /// Builds a node with `Node::binary_varconst` from `result`, the pair
    /// (`self`, `other`) converted with `into()`, and `reverse`; the returned
    /// variable is constructed from a clone of `self.tape()` and that node.
    ///
    /// `other` contributes only to the node; the tape comes from `self` alone.
    fn push_binary(
        &self,
        other: &Constant<YB, YL, YR, YC>,
        result: Array<f32>,
        reverse: BinaryReverseFn,
    ) -> Variable<ZB, ZL, ZR, ZC> {
        // The node is created first and the tape is cloned second; keep this
        // order so the sequence of calls stays the same.
        let binary_node = Node::binary_varconst(result, (self.into(), other.into()), reverse);
        let tape = self.tape().clone();
        Variable::<ZB, ZL, ZR, ZC>::new(tape, binary_node)
    }
}
/// Two-operand operation with a [`Constant`] on the left and a [`Variable`]
/// on the right, yielding a [`Variable`].
///
/// This impl does not relate the three sets of const parameters (`B..C` for
/// `self`, `YB..YC` for `other`, `ZB..ZC` for the output) to each other; the
/// output type is selected by the caller.
impl<
    const B: u64,
    const L: u64,
    const R: u64,
    const C: u64,
    const YB: u64,
    const YL: u64,
    const YR: u64,
    const YC: u64,
    const ZB: u64,
    const ZL: u64,
    const ZR: u64,
    const ZC: u64,
> DoubleParam<Variable<YB, YL, YR, YC>, Variable<ZB, ZL, ZR, ZC>> for Constant<B, L, R, C>
{
    /// Builds a node with `Node::binary_constvar` from `result`, the pair
    /// (`self`, `other`) converted with `into()`, and `reverse`; the returned
    /// variable is constructed from a clone of `other.tape()` and that node.
    ///
    /// `self` contributes only to the node; the tape comes from `other` alone.
    fn push_binary(
        &self,
        other: &Variable<YB, YL, YR, YC>,
        result: Array<f32>,
        reverse: BinaryReverseFn,
    ) -> Variable<ZB, ZL, ZR, ZC> {
        // The node is created first and the tape is cloned second; keep this
        // order so the sequence of calls stays the same.
        let binary_node = Node::binary_constvar(result, (self.into(), other.into()), reverse);
        let tape = other.tape().clone();
        Variable::<ZB, ZL, ZR, ZC>::new(tape, binary_node)
    }
}
/// Two-operand operation where both operands are [`Constant`]s and the
/// output is a [`Constant`].
///
/// This impl does not relate the three sets of const parameters (`B..C` for
/// `self`, `YB..YC` for the other operand, `ZB..ZC` for the output) to each
/// other; the output type is selected by the caller.
impl<
    const B: u64,
    const L: u64,
    const R: u64,
    const C: u64,
    const YB: u64,
    const YL: u64,
    const YR: u64,
    const YC: u64,
    const ZB: u64,
    const ZL: u64,
    const ZR: u64,
    const ZC: u64,
> DoubleParam<Constant<YB, YL, YR, YC>, Constant<ZB, ZL, ZR, ZC>> for Constant<B, L, R, C>
{
    /// Wraps `result` in a new constant.
    ///
    /// `self`, `_other` and `_reverse` are not read: no node is created on
    /// this path and the reverse function is simply dropped.
    fn push_binary(
        &self,
        _other: &Constant<YB, YL, YR, YC>,
        result: Array<f32>,
        _reverse: BinaryReverseFn,
    ) -> Constant<ZB, ZL, ZR, ZC> {
        Constant::<ZB, ZL, ZR, ZC>::new(result)
    }
}