1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright (c) Facebook, Inc. and its affiliates
// SPDX-License-Identifier: MIT OR Apache-2.0

use crate::{core, graph, net, store, Check, Eval, Graph1, GraphN};
use arrayfire as af;

/// Generic trait for an algebra implementing all known operations over `af::Array<T>` (and `T`) for a
/// given float type `T`.
///
/// The trait adds no methods of its own: it bundles the supertraits listed below and
/// ties their associated types to the three associated types declared in the body, so
/// that generic code can depend on a single `A: AfAlgebra<T>` bound.
pub trait AfAlgebra<T>:
    // Gradient reading; the reader type is the one declared in this trait.
    net::HasGradientReader<GradientReader = <Self as AfAlgebra<T>>::GradientReader>
    // Core operations over arrays (`Self::Value`) and over scalars (`Self::Scalar`).
    + core::CoreAlgebra<af::Array<T>, Value = <Self as AfAlgebra<T>>::Value>
    + core::CoreAlgebra<T, Value = <Self as AfAlgebra<T>>::Scalar>
    // Matrix and array operations over `Self::Value`; the array algebra's scalar type
    // is required to be `Self::Scalar`.
    + crate::matrix::MatrixAlgebra<<Self as AfAlgebra<T>>::Value>
    + crate::array::ArrayAlgebra<
        <Self as AfAlgebra<T>>::Value,
        Scalar = <Self as AfAlgebra<T>>::Scalar,
    > + crate::analytic::AnalyticAlgebra<<Self as AfAlgebra<T>>::Value>
    + crate::analytic::AnalyticAlgebra<<Self as AfAlgebra<T>>::Scalar>
    + crate::arith::ArithAlgebra<<Self as AfAlgebra<T>>::Value>
    + crate::arith::ArithAlgebra<<Self as AfAlgebra<T>>::Scalar>
    // Arithmetic with constants of type `T` and of type `i16`, for arrays and scalars.
    + crate::const_arith::ConstArithAlgebra<<Self as AfAlgebra<T>>::Value, T>
    + crate::const_arith::ConstArithAlgebra<<Self as AfAlgebra<T>>::Scalar, T>
    + crate::const_arith::ConstArithAlgebra<<Self as AfAlgebra<T>>::Value, i16>
    + crate::const_arith::ConstArithAlgebra<<Self as AfAlgebra<T>>::Scalar, i16>
    // Comparisons for arrays and scalars, plus array-specific comparisons.
    + crate::compare::CompareAlgebra<<Self as AfAlgebra<T>>::Value>
    + crate::compare::CompareAlgebra<<Self as AfAlgebra<T>>::Scalar>
    + crate::array_compare::ArrayCompareAlgebra<<Self as AfAlgebra<T>>::Value>
where
    T: Float,
{
    /// The algebra's representation of a scalar of type `T`
    /// (the `Value` of `core::CoreAlgebra<T>`).
    type Scalar;
    /// The algebra's representation of an `af::Array<T>`
    /// (the `Value` of `core::CoreAlgebra<af::Array<T>>`). It must expose a gradient id.
    type Value: net::HasGradientId;
    /// The gradient reader of the algebra, parameterized by the gradient id type of
    /// `Self::Value` and by `af::Array<T>`.
    type GradientReader: store::GradientReader<
        <<Self as AfAlgebra<T>>::Value as net::HasGradientId>::GradientId,
        af::Array<T>,
    >;
}

/// `Eval` as an [`AfAlgebra`]: array values are plain `af::Array<T>`, scalars are plain
/// `T`, and the gradient reader is `store::EmptyGradientMap`.
impl<T> AfAlgebra<T> for Eval
where
    T: Float,
{
    type Value = af::Array<T>;
    type Scalar = T;
    type GradientReader = store::EmptyGradientMap;
}

/// `Check` as an [`AfAlgebra`]: array values are represented by an `af::Dim4`, scalars
/// by `()`, and the gradient reader is `store::EmptyGradientMap`.
impl<T> AfAlgebra<T> for Check
where
    T: Float,
{
    type Value = af::Dim4;
    type Scalar = ();
    type GradientReader = store::EmptyGradientMap;
}

/// `Graph1` as an [`AfAlgebra`]: arrays and scalars are wrapped in `graph::Value`, and
/// gradients are read through `store::GenericGradientMap1`.
impl<T> AfAlgebra<T> for Graph1
where
    T: Float,
{
    type Value = graph::Value<af::Array<T>>;
    type Scalar = graph::Value<T>;
    type GradientReader = store::GenericGradientMap1;
}

/// `GraphN` as an [`AfAlgebra`]: arrays and scalars are wrapped in `graph::Value`, and
/// gradients are read through `store::GenericGradientMapN`.
impl<T> AfAlgebra<T> for GraphN
where
    T: Float,
{
    type Value = graph::Value<af::Array<T>>;
    type Scalar = graph::Value<T>;
    type GradientReader = store::GenericGradientMapN;
}

/// All supported float types.
///
/// This is a marker trait with no methods: it gathers every bound that the
/// arrayfire-based algebras need from an element type. It is implemented in this file
/// for `f32` and `f64`.
pub trait Float:
    // Crate-level numeric requirements and general-purpose numeric traits.
    crate::Number
    + Default
    + PartialOrd
    + num::Float
    // Conversion from `i16` and exponentiation by `i16` or by `Self`.
    + From<i16>
    + num::pow::Pow<i16, Output = Self>
    + num::pow::Pow<Self, Output = Self>
    // Pins the listed arrayfire associated output types to `Self`, so the corresponding
    // operations on `af::Array<Self>` stay within the same element type.
    + af::HasAfEnum<
        InType = Self,
        AggregateOutType = Self,
        ProductOutType = Self,
        UnaryOutType = Self,
        AbsOutType = Self,
    > + af::ImplicitPromote<Self, Output = Self>
    + af::ConstGenerator<OutType = Self>
    + af::Convertable<OutType = Self>
    + af::FloatingPoint
    // Allows `scalar / &array`, producing an array of the same element type.
    + for<'a> std::ops::Div<&'a af::Array<Self>, Output = af::Array<Self>>
{
}

// `Float` has no items, so these impls only assert that `f32` and `f64` satisfy all of
// its supertrait bounds.
impl Float for f32 {}
impl Float for f64 {}

/// An AfAlgebra for all supported floats.
///
/// Requires `AfAlgebra<f32>` and `AfAlgebra<f64>` at the same time, and forces both to
/// use the same gradient reader type, exposed as `FullAlgebra::GradientReader`.
pub trait FullAlgebra:
    AfAlgebra<f32, GradientReader = <Self as FullAlgebra>::GradientReader>
    + AfAlgebra<f64, GradientReader = <Self as FullAlgebra>::GradientReader>
{
    /// The gradient reader type shared by the `f32` and `f64` algebras.
    type GradientReader;
}

/// `Eval` supports both float types; the reader matches the one used by its
/// `AfAlgebra<T>` impl (`store::EmptyGradientMap`).
impl FullAlgebra for Eval {
    type GradientReader = store::EmptyGradientMap;
}

/// `Check` supports both float types; the reader matches the one used by its
/// `AfAlgebra<T>` impl (`store::EmptyGradientMap`).
impl FullAlgebra for Check {
    type GradientReader = store::EmptyGradientMap;
}

/// `Graph1` supports both float types; the reader matches the one used by its
/// `AfAlgebra<T>` impl (`store::GenericGradientMap1`).
impl FullAlgebra for Graph1 {
    type GradientReader = store::GenericGradientMap1;
}

/// `GraphN` supports both float types; the reader matches the one used by its
/// `AfAlgebra<T>` impl (`store::GenericGradientMapN`).
impl FullAlgebra for GraphN {
    type GradientReader = store::GenericGradientMapN;
}

/// Convenient functions used for testing.
pub mod testing {
    use super::*;
    use crate::array::ArrayAlgebra;

    /// Numerically estimate the gradient of `x -> dot(f(x), direction)` at `input`.
    ///
    /// Each element of `input` is perturbed in turn by `+epsilon` and `-epsilon` on a
    /// host-side copy; `f` is evaluated at both points, each output is projected onto
    /// `direction` with `Eval`'s `dot`, and the difference of the two projections is
    /// divided by `epsilon + epsilon`. The result has the same dimensions as `input`.
    ///
    /// # Panics
    ///
    /// Panics if `dot` returns an error for an output of `f` and `direction`.
    #[allow(clippy::suspicious_operation_groupings)]
    pub fn estimate_gradient<T, F>(
        input: &af::Array<T>,
        direction: &af::Array<T>,
        epsilon: T,
        f: F,
    ) -> af::Array<T>
    where
        T: Float + std::fmt::Display,
        F: Fn(&af::Array<T>) -> af::Array<T>,
    {
        let count = input.elements();

        // Host-side copy of the input; one coordinate at a time is perturbed in place.
        let mut point = vec![T::zero(); count];
        input.host(&mut point);

        // Evaluates `f` at the given host-side point and projects onto `direction`.
        let project = |coords: &[T]| -> T {
            let out = f(&af::Array::new(coords, input.dims()));
            Eval::default().dot(&out, direction).unwrap()
        };

        let mut gradient = Vec::with_capacity(count);
        for index in 0..count {
            let original = point[index];

            point[index] = original + epsilon;
            let forward = project(&point);

            point[index] = original - epsilon;
            let backward = project(&point);

            // Restore the coordinate before moving on to the next one.
            point[index] = original;
            gradient.push((forward - backward) / (epsilon + epsilon));
        }

        af::Array::new(&gradient, input.dims())
    }

    /// Assert that `v1` and `v2` have the same dimensions and that the largest absolute
    /// element-wise difference between them is strictly below `precision`
    /// (i.e. the arrays are close for the L-infinity norm).
    ///
    /// # Panics
    ///
    /// Panics if the dimensions differ or if the maximum deviation is not strictly less
    /// than `precision`.
    pub fn assert_almost_all_equal<T>(v1: &af::Array<T>, v2: &af::Array<T>, precision: T)
    where
        T: af::HasAfEnum<AbsOutType = T, InType = T, BaseType = T>
            + af::ImplicitPromote<T, Output = T>
            + af::Fromf64
            + std::cmp::PartialOrd,
    {
        assert_eq!(v1.dims(), v2.dims());

        let difference = v1 - v2;
        let magnitudes = af::abs(&difference);
        // Only the first component of the reduction result is compared.
        let (d, _) = af::max_all(&magnitudes);
        assert!(d < precision);
    }
}