use crate::{
core::{CoreAlgebra, HasDims},
error::Result,
graph::{Config1, ConfigN, Graph, Value},
linked::LinkedAlgebra,
store::GradientStore,
Check, Eval, Number,
};
/// Basic arithmetic over values of type `Value`.
///
/// This module implements it for evaluation backends (`Eval`), for shape
/// checking (`Check`), and for the differentiable computation graph (`Graph`).
pub trait ArithAlgebra<Value> {
    /// Computes `v0 - v1`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation rejects the operands
    /// (for instance, array backends reject mismatched dimensions).
    fn sub(&mut self, v0: &Value, v1: &Value) -> Result<Value>;

    /// Computes the product `v0 * v1` (element-wise for array values).
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation rejects the operands.
    fn mul(&mut self, v0: &Value, v1: &Value) -> Result<Value>;

    /// Returns the additive identity matching `v`
    /// (e.g. a zero array with the same dimensions as `v`).
    fn zeros(&mut self, v: &Value) -> Value;

    /// Returns the multiplicative identity matching `v`
    /// (e.g. an array of ones with the same dimensions as `v`).
    fn ones(&mut self, v: &Value) -> Value;

    /// Computes `-v`.
    ///
    /// The default implementation evaluates `zeros(v) - v`.
    ///
    /// # Panics
    ///
    /// Panics if `sub` rejects `zeros(v)` and `v`. That can only happen if
    /// `zeros` produced a value incompatible with its own argument.
    fn neg(&mut self, v: &Value) -> Value {
        let zero = self.zeros(v);
        let negated = self.sub(&zero, v);
        negated.expect("subtracting zero should not fail")
    }
}
#[cfg(feature = "arrayfire")]
mod af_arith {
    // ArrayFire backends: `Eval` computes on `af::Array<T>` values, and
    // `Check` validates the corresponding `af::Dim4` shapes.
    use super::*;
    use crate::error::check_equal_dimensions;
    use arrayfire as af;

    impl<T> ArithAlgebra<af::Array<T>> for Eval
    where
        Self: CoreAlgebra<af::Array<T>, Value = af::Array<T>>,
        T: af::HasAfEnum
            + af::ImplicitPromote<T, Output = T>
            + af::ConstGenerator<OutType = T>
            + num::Zero
            + num::One,
    {
        /// Element-wise `v0 - v1`, after validating the operand shapes
        /// through `self.check()`.
        #[inline]
        fn sub(&mut self, v0: &af::Array<T>, v1: &af::Array<T>) -> Result<af::Array<T>> {
            let (lhs_dims, rhs_dims) = (v0.dims(), v1.dims());
            self.check().sub(&lhs_dims, &rhs_dims)?;
            Ok(v0 - v1)
        }

        /// Element-wise `v0 * v1`, after validating the operand shapes
        /// through `self.check()`.
        #[inline]
        fn mul(&mut self, v0: &af::Array<T>, v1: &af::Array<T>) -> Result<af::Array<T>> {
            let (lhs_dims, rhs_dims) = (v0.dims(), v1.dims());
            self.check().mul(&lhs_dims, &rhs_dims)?;
            Ok(v0 * v1)
        }

        /// Negation computed as `0 - v` against a zero array of `v`'s dimensions.
        #[inline]
        fn neg(&mut self, v: &af::Array<T>) -> af::Array<T> {
            let zero = af::constant(T::zero(), v.dims());
            zero - v
        }

        /// Array of zeros with the same dimensions as `v`.
        #[inline]
        fn zeros(&mut self, v: &af::Array<T>) -> af::Array<T> {
            let dims = v.dims();
            af::constant(T::zero(), dims)
        }

        /// Array of ones with the same dimensions as `v`.
        #[inline]
        fn ones(&mut self, v: &af::Array<T>) -> af::Array<T> {
            let dims = v.dims();
            af::constant(T::one(), dims)
        }
    }

    impl ArithAlgebra<af::Dim4> for Check {
        /// Requires both operands to have equal dimensions; the result comes
        /// from `check_equal_dimensions` (presumably the shared shape).
        #[inline]
        fn sub(&mut self, v0: &af::Dim4, v1: &af::Dim4) -> Result<af::Dim4> {
            check_equal_dimensions(func_name!(), &[v0, v1])
        }

        /// Same dimension rule as `sub`: element-wise products need equal shapes.
        #[inline]
        fn mul(&mut self, v0: &af::Dim4, v1: &af::Dim4) -> Result<af::Dim4> {
            check_equal_dimensions(func_name!(), &[v0, v1])
        }

        /// Negation preserves the shape.
        #[inline]
        fn neg(&mut self, v: &af::Dim4) -> af::Dim4 {
            *v
        }

        /// A zero array has the shape of its template.
        #[inline]
        fn zeros(&mut self, v: &af::Dim4) -> af::Dim4 {
            *v
        }

        /// An array of ones has the shape of its template.
        #[inline]
        fn ones(&mut self, v: &af::Dim4) -> af::Dim4 {
            *v
        }
    }
}
// Scalar backend: any `Number` is evaluated directly with its native operators.
// Scalars carry no shape, so the fallible operations always return `Ok`.
impl<T: Number> ArithAlgebra<T> for Eval {
    /// Scalar difference `v0 - v1`; never returns an error.
    #[inline]
    fn sub(&mut self, v0: &T, v1: &T) -> Result<T> {
        let (lhs, rhs) = (*v0, *v1);
        Ok(lhs - rhs)
    }

    /// Scalar product `v0 * v1`; never returns an error.
    #[inline]
    fn mul(&mut self, v0: &T, v1: &T) -> Result<T> {
        let (lhs, rhs) = (*v0, *v1);
        Ok(lhs * rhs)
    }

    /// Native negation of the scalar.
    #[inline]
    fn neg(&mut self, v: &T) -> T {
        let value = *v;
        -value
    }

    /// Additive identity; the argument only selects the type.
    #[inline]
    fn zeros(&mut self, _v: &T) -> T {
        T::zero()
    }

    /// Multiplicative identity; the argument only selects the type.
    #[inline]
    fn ones(&mut self, _v: &T) -> T {
        T::one()
    }
}
// Shape checking for values whose dimension type is `()` (presumably scalars).
// There is nothing to compare, so every operation trivially succeeds.
impl ArithAlgebra<()> for Check {
    /// Unit dimensions always agree, so subtraction always checks out.
    #[inline]
    fn sub(&mut self, _lhs: &(), _rhs: &()) -> Result<()> {
        Ok(())
    }

    /// Unit dimensions always agree, so multiplication always checks out.
    #[inline]
    fn mul(&mut self, _lhs: &(), _rhs: &()) -> Result<()> {
        Ok(())
    }

    /// Negation keeps the (unit) dimensions.
    #[inline]
    fn neg(&mut self, _dims: &()) {}

    /// A zero value has the (unit) dimensions of its template.
    #[inline]
    fn zeros(&mut self, _dims: &()) {}

    /// A one value has the (unit) dimensions of its template.
    #[inline]
    fn ones(&mut self, _dims: &()) {}
}
// Implements `ArithAlgebra<Value<D>>` for the computation graph `Graph<$config<E>>`.
//
// `$config` is the graph configuration type constructor (instantiated in this
// file with `Config1` and `ConfigN`). `E` is the inner algebra that computes
// forward values through `self.eval()`. `D` is the data carried by each `Value`.
// Differentiable operations record a node through `self.make_node`, passing the
// forward result, the operands' `input()`s, and a backward closure. That closure
// receives `graph`, `store`, and the incoming `gradient`, and accumulates each
// operand's share of the gradient through `store.add_gradient`.
macro_rules! impl_graph {
    ($config:ident) => {
        impl<D, E, Dims> ArithAlgebra<Value<D>> for Graph<$config<E>>
        where
            E: Default
                + Clone
                + CoreAlgebra<D, Value = D>
                + ArithAlgebra<D>
                + LinkedAlgebra<Value<D>, D>,
            D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
            Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
        {
            // The result of the inner algebra's `zeros` is registered with
            // `self.constant`. No backward closure is supplied, so nothing is
            // propagated to `v`.
            fn zeros(&mut self, v: &Value<D>) -> Value<D> {
                let result = self.eval().zeros(v.data());
                self.constant(result)
            }
            // Same as `zeros`, using the inner algebra's `ones`.
            fn ones(&mut self, v: &Value<D>) -> Value<D> {
                let result = self.eval().ones(v.data());
                self.constant(result)
            }
            // Forward: `-v`.
            // Backward: since d(-v)/dv = -1, `v` receives the negated gradient.
            fn neg(&mut self, v: &Value<D>) -> Value<D> {
                let result = self.eval().neg(v.data());
                self.make_node(result, vec![v.input()], {
                    // Only the id is captured. When it is `None` (presumably
                    // `v` needs no gradient — TODO confirm), nothing is stored.
                    let id = v.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id {
                            let n = graph.neg(&gradient);
                            store.add_gradient(graph, id, &n)?;
                        }
                        Ok(())
                    }
                })
            }
            // Forward: `v0 - v1`. Any error from the inner algebra's `sub` is
            // returned before a node is recorded.
            // Backward: `v0` receives the gradient unchanged and `v1` receives
            // its negation.
            fn sub(&mut self, v0: &Value<D>, v1: &Value<D>) -> Result<Value<D>> {
                let result = self.eval().sub(v0.data(), v1.data())?;
                let value = self.make_node(result, vec![v0.input(), v1.input()], {
                    // The backward pass only needs the operand ids, not their values.
                    let id0 = v0.id();
                    let id1 = v1.id();
                    move |graph, store, gradient| {
                        if let Some(id) = id0 {
                            store.add_gradient(graph, id, &gradient)?;
                        }
                        if let Some(id) = id1 {
                            let n = graph.neg(&gradient);
                            store.add_gradient(graph, id, &n)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }
            // Forward: `v0 * v1`. Any error from the inner algebra's `mul` is
            // returned before a node is recorded.
            // Backward (product rule): `v0` receives `gradient * v1` and `v1`
            // receives `v0 * gradient`.
            fn mul(&mut self, v0: &Value<D>, v1: &Value<D>) -> Result<Value<D>> {
                let result = self.eval().mul(v0.data(), v1.data())?;
                let value = self.make_node(result, vec![v0.input(), v1.input()], {
                    // Unlike `sub`, the backward pass needs the operand values
                    // themselves, so clones are moved into the closure.
                    let v0 = v0.clone();
                    let v1 = v1.clone();
                    move |graph, store, gradient| {
                        if let Some(id) = v0.id() {
                            // NOTE(review): `graph.link` presumably brings the
                            // forward value `v1` into the graph used during the
                            // backward pass — confirm against `LinkedAlgebra`.
                            let c1 = graph.link(&v1);
                            let grad = graph.mul(&gradient, c1)?;
                            store.add_gradient(graph, id, &grad)?;
                        }
                        if let Some(id) = v1.id() {
                            let c0 = graph.link(&v0);
                            let grad = graph.mul(c0, &gradient)?;
                            store.add_gradient(graph, id, &grad)?;
                        }
                        Ok(())
                    }
                });
                Ok(value)
            }
        }
    };
}
// Instantiate graph arithmetic for each graph configuration type.
// The two impls are independent, so their order does not matter.
impl_graph!(ConfigN);
impl_graph!(Config1);