use crate::{
error::{Error, Result},
graph::{Config1, ConfigN, Graph, Value},
store::GradientStore,
Check, Eval, Number,
};
pub trait CoreAlgebra<Data> {
type Value;
fn variable(&mut self, data: Data) -> Self::Value;
fn constant(&mut self, data: Data) -> Self::Value;
fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result<Self::Value>;
fn add_all(&mut self, values: &[&Self::Value]) -> Result<Self::Value>
where
Self::Value: Clone,
{
let mut values = values.iter();
let mut result: Self::Value =
(*values.next().ok_or_else(|| Error::empty(func_name!()))?).clone();
for value in values {
result = self.add(&result, *value)?;
}
Ok(result)
}
}
/// Types that can report their dimensions.
pub trait HasDims {
    /// Type used to describe the dimensions.
    type Dims;

    /// Returns the dimensions of `self`.
    fn dims(&self) -> Self::Dims;
}
/// A graph value reports the dimensions of the data it holds.
impl<A: HasDims> HasDims for crate::graph::Value<A> {
    type Dims = <A as HasDims>::Dims;

    /// Forwards to `dims` on whatever `data()` returns.
    #[inline]
    fn dims(&self) -> Self::Dims {
        let inner = self.data();
        inner.dims()
    }
}
/// Scalar numbers carry no dimension information: their `Dims` is `()`.
impl<T> HasDims for T
where
    T: Number,
{
    type Dims = ();

    /// Always returns `()`.
    #[inline]
    fn dims(&self) -> Self::Dims {}
}
/// A shared pointer reports the dimensions of the value it points to.
impl<T> HasDims for std::sync::Arc<T>
where
    T: HasDims,
{
    type Dims = <T as HasDims>::Dims;

    /// Forwards to the pointee's `dims`.
    #[inline]
    fn dims(&self) -> Self::Dims {
        T::dims(&**self)
    }
}
/// Checking algebra over scalar numbers.
///
/// Scalars have `()` as their dimensions (see `HasDims` for `T: Number`), so values
/// are represented by `()` and a pairwise addition always succeeds.
impl<T: Number> CoreAlgebra<T> for Check {
    type Value = ();

    /// Nothing is recorded for a scalar variable.
    #[inline]
    fn variable(&mut self, _data: T) {}

    /// Nothing is recorded for a scalar constant.
    #[inline]
    fn constant(&mut self, _data: T) {}

    /// Adding two scalars cannot fail.
    #[inline]
    fn add(&mut self, _v0: &(), _v1: &()) -> Result<()> {
        Ok(())
    }

    // `add_all` is intentionally NOT overridden. The provided trait method rejects an
    // empty operand list with `Error::empty`, which is what `Eval` reports for scalars
    // as well. A previous override returned `Ok(())` unconditionally, so the checker
    // accepted a sum that evaluation rejects.
}
/// Evaluation algebra over scalar numbers: a value is the number itself.
impl<T> CoreAlgebra<T> for Eval
where
    T: Number,
{
    type Value = T;

    /// Returns `data` unchanged.
    #[inline]
    fn variable(&mut self, data: T) -> Self::Value {
        data
    }

    /// Returns `data` unchanged.
    #[inline]
    fn constant(&mut self, data: T) -> Self::Value {
        data
    }

    /// Returns `*v0 + *v1`; this implementation never produces an error.
    #[inline]
    fn add(&mut self, v0: &T, v1: &T) -> Result<Self::Value> {
        let sum = *v0 + *v1;
        Ok(sum)
    }
}
/// Implementations for `arrayfire` arrays; compiled only when the `arrayfire`
/// feature is enabled.
#[cfg(feature = "arrayfire")]
mod af_core {
    use super::*;
    use crate::error::check_equal_dimensions;
    use arrayfire as af;

    /// An array reports its `af::Dim4`.
    impl<T> HasDims for af::Array<T>
    where
        T: af::HasAfEnum,
    {
        type Dims = af::Dim4;

        #[inline]
        fn dims(&self) -> Self::Dims {
            // NOTE(review): presumably resolves to the inherent `af::Array::dims`
            // rather than recursing into this trait method; confirm against arrayfire.
            self.dims()
        }
    }

    /// A `Dim4` is its own dimensions.
    impl HasDims for af::Dim4 {
        type Dims = af::Dim4;

        #[inline]
        fn dims(&self) -> Self::Dims {
            *self
        }
    }

    /// Checking algebra over arrays: an array is represented by its dimensions only.
    impl<T> CoreAlgebra<af::Array<T>> for Check
    where
        T: af::HasAfEnum,
    {
        type Value = af::Dim4;

        /// Keeps only the dimensions of `array`.
        #[inline]
        fn variable(&mut self, array: af::Array<T>) -> Self::Value {
            array.dims()
        }

        /// Keeps only the dimensions of `array`.
        #[inline]
        fn constant(&mut self, array: af::Array<T>) -> Self::Value {
            array.dims()
        }

        /// Delegates to `check_equal_dimensions` on the two operands.
        #[inline]
        fn add(&mut self, v0: &af::Dim4, v1: &af::Dim4) -> Result<Self::Value> {
            check_equal_dimensions(func_name!(), &[v0, v1])
        }

        /// Delegates to `check_equal_dimensions` on all operands at once.
        #[inline]
        fn add_all(&mut self, values: &[&af::Dim4]) -> Result<Self::Value> {
            check_equal_dimensions(func_name!(), values)
        }
    }

    /// Evaluation algebra over arrays: a value is the array itself.
    impl<T> CoreAlgebra<af::Array<T>> for Eval
    where
        T: af::HasAfEnum + af::ImplicitPromote<T, Output = T>,
    {
        type Value = af::Array<T>;

        /// Returns `array` unchanged.
        #[inline]
        fn variable(&mut self, array: af::Array<T>) -> Self::Value {
            array
        }

        /// Returns `array` unchanged.
        #[inline]
        fn constant(&mut self, array: af::Array<T>) -> Self::Value {
            array
        }

        /// Runs the `Check` algebra's `add` on the operands' dimensions and, when
        /// that succeeds, returns `v0 + v1`.
        ///
        /// # Errors
        ///
        /// Propagates the error returned by the dimension check.
        #[inline]
        fn add(&mut self, v0: &af::Array<T>, v1: &af::Array<T>) -> Result<Self::Value> {
            let checker = self.check();
            let lhs_dims = v0.dims();
            let rhs_dims = v1.dims();
            <Check as CoreAlgebra<af::Array<T>>>::add(checker, &lhs_dims, &rhs_dims)?;
            Ok(v0 + v1)
        }
    }
}
macro_rules! impl_graph {
($config:ident) => {
impl<D, E, Dims> CoreAlgebra<D> for Graph<$config<E>>
where
E: Default + Clone + CoreAlgebra<D, Value = D>,
D: HasDims<Dims = Dims> + Clone + 'static + Send + Sync,
Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
{
type Value = Value<D>;
fn variable(&mut self, data: D) -> Value<D> {
self.make_variable(data)
}
fn constant(&mut self, data: D) -> Value<D> {
Value::constant(data)
}
fn add(&mut self, v1: &Value<D>, v2: &Value<D>) -> Result<Value<D>> {
let result = self.eval().add(v1.data(), v2.data())?;
let value = self.make_node(result, vec![v1.input(), v2.input()], {
let id1 = v1.id();
let id2 = v2.id();
move |graph, store, gradient| {
if let Some(id) = id1 {
store.add_gradient(graph, id, &gradient)?;
}
if let Some(id) = id2 {
store.add_gradient(graph, id, &gradient)?;
}
Ok(())
}
});
Ok(value)
}
fn add_all(&mut self, values: &[&Value<D>]) -> Result<Value<D>> {
let result = self
.eval()
.add_all(&values.iter().map(|v| v.data()).collect::<Vec<_>>())?;
let inputs = values.iter().map(|v| v.input()).collect::<Vec<_>>();
let value = self.make_node(result, inputs, {
let ids = values.iter().map(|v| v.id()).collect::<Vec<_>>();
move |graph, store, gradient| {
for id in &ids {
if let Some(id) = id {
store.add_gradient(graph, *id, &gradient)?;
}
}
Ok(())
}
});
Ok(value)
}
}
};
}
impl_graph!(Config1);
impl_graph!(ConfigN);