use crate::autotuple::AutoTuple;
use crate::traits::{InstOne, InstZero, GradientIdentity};
use ndarray::{ArrayBase, DataOwned, Dimension, RawDataClone, DimAdd, DimMax, OwnedRepr, IxDyn, LinalgScalar};
use num::traits::{One, Zero};
use std::ops::{Add, Mul};
use crate::forward::ForwardMul;
use crate::gradienttype::GradientType;
use crate::ad_ndarray::dimabssub::DimAbsSub;
use ndarray_einsum_beta::einsum;
#[cfg(test)]
use ndarray::{Array0, Array1, Array2, arr1};
impl<A, S, D> InstZero for ArrayBase<S, D>
where
    D: Dimension,
    S: DataOwned<Elem = A>,
    A: Clone + InstZero + Zero,
    Self: Sized + Add<Self, Output = Self>,
{
    /// Builds a new owned array of the same shape as `self`, filled with `A`'s `Zero::zero()`.
    fn zero(&self) -> Self {
        let shape = self.raw_dim();
        Self::zeros(shape)
    }

    /// True when every element satisfies `Zero::is_zero`; an empty array is trivially zero.
    fn is_zero(&self) -> bool {
        // Fully-qualified path: `A` also implements `InstZero`, which has its own `is_zero`.
        self.iter().all(<A as Zero>::is_zero)
    }
}
impl<A, S, D> InstOne for ArrayBase<S, D>
where
    D: Dimension,
    S: DataOwned<Elem = A>,
    A: Clone + InstOne + One,
    Self: Sized + Mul<Self, Output = Self>,
{
    /// Builds a new owned array of the same shape as `self`, filled with `A`'s `One::one()`.
    fn one(&self) -> Self {
        let shape = self.raw_dim();
        Self::ones(shape)
    }
}
impl<AI, DI, AG, DG> GradientIdentity for ArrayBase<OwnedRepr<AI>, DI>
where
    DI: Dimension,
    DG: Dimension,
    AI: Clone + GradientType<AI, GradientType = AG>,
    AG: Clone + InstOne + One + Zero,
    Self: Sized + GradientType<Self, GradientType = ArrayBase<OwnedRepr<AG>, DG>>,
{
    /// Returns the gradient of `self` with respect to itself.
    ///
    /// The result has shape `self.shape()` followed by `self.shape()` again (rank
    /// `2 * self.ndim()`). An element is `One::one()` when its leading multi-index equals
    /// its trailing multi-index, and `Zero::zero()` otherwise. For a 0-d array this is
    /// the single value one.
    ///
    /// # Panics
    /// Panics if `DG` cannot hold an array of rank `2 * self.ndim()`
    /// (`into_dimensionality` fails).
    fn grad_identity(&self) -> ArrayBase<OwnedRepr<AG>, DG> {
        let ndim = self.ndim();
        let grad_shape: Vec<usize> = self
            .shape()
            .iter()
            .chain(self.shape())
            .copied()
            .collect();
        // Visit every element once and compare the two halves of its index. The previous
        // hand-rolled loop only touched indices whose other coordinates equalled an axis
        // number, which left most of the diagonal unset (or indexed out of bounds) for
        // ndim >= 2 and left 0-d inputs at zero.
        let grad = ArrayBase::<OwnedRepr<AG>, IxDyn>::from_shape_fn(grad_shape, |idx: IxDyn| {
            if (0..ndim).all(|k| idx[k] == idx[k + ndim]) {
                <AG as One>::one()
            } else {
                <AG as Zero>::zero()
            }
        });
        grad.into_dimensionality::<DG>()
            .expect("gradient dimension type DG must have rank 2 * self.ndim()")
    }
}
impl<A, S, D> From<ArrayBase<S, D>> for AutoTuple<(ArrayBase<S, D>,)>
where
    D: Dimension,
    S: DataOwned<Elem = A> + RawDataClone,
    A: Clone + PartialEq,
{
    /// Wraps a single array as a one-element `AutoTuple` by passing `(array,)` to `AutoTuple::new`.
    fn from(array: ArrayBase<S, D>) -> Self {
        let single = (array,);
        AutoTuple::new(single)
    }
}
/// Gradient type of an owned array (elements `AI`, dimension `DI`) paired with an owned
/// output array (elements `AO`, dimension `DO`).
///
/// The result is an owned array whose element type is `AI`'s gradient type with respect
/// to `AO`, and whose dimension type is `DI` combined with `DO` via `DimAdd`.
impl<AI, DI, AO, DO, AG, DG> GradientType<ArrayBase<OwnedRepr<AO>, DO>> for ArrayBase<OwnedRepr<AI>, DI>
where
    AI: GradientType<AO, GradientType = AG>,
    DI: Dimension + DimAdd<DO, Output = DG>,
    DO: Dimension,
    DG: Dimension,
{
    type GradientType = ArrayBase<OwnedRepr<AG>, DG>;
}
#[test]
fn test_gradient_type() {
    // A 1-D array's gradient with respect to a 0-D array is expected to be 1-D.
    type VecWrtScalar = <Array1<f64> as GradientType<Array0<f64>>>::GradientType;
    let zeros_1d: Array1<f64> = VecWrtScalar::zeros(1);
    assert_eq!(zeros_1d, arr1(&[0.0]));

    // A one-element tuple's gradient with respect to a two-element tuple is expected
    // to be a two-element tuple of arrays.
    type TupleGrad =
        <AutoTuple<(Array1<f64>,)> as GradientType<AutoTuple<(Array0<f64>, Array1<f64>)>>>::GradientType;
    let tuple_default: AutoTuple<(Array1<f64>, Array2<f64>)> = <TupleGrad as Default>::default();
    let expected = AutoTuple::new((Array1::<f64>::default(), Array2::<f64>::default()));
    assert_eq!(tuple_default, expected);
}
/// Builds an einsum subscript string contracting two operands.
///
/// The first `sum_idxs` axes of operand 1 are summed against the last `sum_idxs` axes
/// of operand 2. The output lists operand 2's remaining axes followed by operand 1's
/// remaining axes. Example: `get_einsum_str(2, 3, 1) == "ab,cda->cdb"`.
///
/// # Panics
/// Panics if `sum_idxs` exceeds either operand's rank, or if the number of distinct
/// axes (`op1_ndim + op2_ndim - sum_idxs`) exceeds the 26 lowercase letters available.
fn get_einsum_str(op1_ndim: u8, op2_ndim: u8, sum_idxs: u8) -> String {
    /// Maps axis numbers to consecutive lowercase letters (`0 -> 'a'`).
    fn letters(range: std::ops::Range<u8>) -> String {
        range.map(|i| char::from(b'a' + i)).collect()
    }

    assert!(op1_ndim >= sum_idxs, "cannot contract more axes than operand 1 has");
    assert!(op2_ndim >= sum_idxs, "cannot contract more axes than operand 2 has");
    // Every distinct axis needs its own letter. Checking each operand separately is not
    // enough, because their free axes together can exceed 26. Sum in usize so that
    // oversized inputs hit this assertion instead of overflowing u8.
    let distinct = usize::from(op1_ndim) + usize::from(op2_ndim) - usize::from(sum_idxs);
    assert!(distinct <= 26, "einsum needs {} distinct axes, only 26 letters available", distinct);

    let sum_str = letters(0..sum_idxs);
    let op1_str = letters(sum_idxs..op1_ndim);
    let op2_str = letters(op1_ndim..op1_ndim + op2_ndim - sum_idxs);
    format!("{}{},{}{}->{}{}", sum_str, op1_str, op2_str, sum_str, op2_str, op1_str)
}
impl<AI, DI, AS, DS, DG, DR, MAXGD>
    ForwardMul<
        ArrayBase<OwnedRepr<AI>, DI>,
        ArrayBase<OwnedRepr<AS>, DG>,
    > for ArrayBase<OwnedRepr<AS>, DS>
where
    DI: Dimension,
    DS: Dimension + DimMax<DG, Output = MAXGD>,
    MAXGD: Dimension + DimAbsSub<DI, Output = DR>,
    DG: Dimension + DimMax<DS, Output = MAXGD>,
    DR: Dimension,
    AI: Clone,
    AS: Clone + Mul<AS, Output = AS> + LinalgScalar,
{
    type ResultGrad = ArrayBase<OwnedRepr<AS>, DR>;

    /// Contracts `self` with `other` over `DI::NDIM` axes using `einsum`.
    ///
    /// The first `DI::NDIM` axes of `self` are summed against the last `DI::NDIM` axes
    /// of `other`. The result's axes are `other`'s remaining axes followed by `self`'s
    /// remaining axes, converted to dimension type `DR`.
    ///
    /// # Panics
    /// - `DI` has no static rank (`DI::NDIM` is `None`, e.g. `IxDyn`).
    /// - An operand rank does not fit in `u8`.
    /// - `get_einsum_str` rejects the ranks.
    /// - `einsum` rejects the operand shapes.
    /// - The result's rank does not match `DR`.
    fn forward_mul(&self, other: &ArrayBase<OwnedRepr<AS>, DG>) -> Self::ResultGrad {
        let self_ndim: u8 = self
            .ndim()
            .try_into()
            .expect("forward_mul: rank of self must fit in u8");
        let other_ndim: u8 = other
            .ndim()
            .try_into()
            .expect("forward_mul: rank of other must fit in u8");
        let contracted: u8 = DI::NDIM
            .expect("forward_mul: input dimension type DI must have a static rank")
            .try_into()
            .expect("forward_mul: rank of DI must fit in u8");
        let spec = get_einsum_str(self_ndim, other_ndim, contracted);
        let res_dyn: ArrayBase<OwnedRepr<AS>, IxDyn> = einsum(&spec, &[self, other])
            .expect("forward_mul: operand shapes are incompatible with the einsum contraction");
        res_dyn
            .into_dimensionality::<DR>()
            .expect("forward_mul: einsum result rank does not match DR")
    }
}