concision_params/impls/impl_params_ops.rs

/*
    Appellation: impl_ops <module>
    Contrib: @FL03
*/
use crate::{Params, ParamsBase};
use concision_traits::{Backward, Forward, Norm};
use ndarray::linalg::Dot;
use ndarray::{
    Array, ArrayBase, ArrayView, Data, Dimension, Ix0, Ix1, Ix2, RemoveAxis, ScalarOperand,
};
use num_traits::{Float, FromPrimitive, Num, Signed};

/// implements element-wise unary methods on the parameters by mapping the named method
/// over every element and returning an owned [`Params`]; each listed method receives a
/// default `A: Float` bound
macro_rules! impl_tensor_unary {
    (@impl $method:ident $(where $($where:tt)*)?) => {
        pub fn $method(&self) -> Params<A, D> $(where $($where)*)? {
            self.mapv(|x| x.$method())
        }
    };
    ($($method:ident),* $(,)?) => {
        $(impl_tensor_unary! { @impl $method where A: Float })*
    }
}
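// For reference, each method name passed to `impl_tensor_unary!` in the impl block
// below expands to roughly:
//
//     pub fn cos(&self) -> Params<A, D>
//     where
//         A: Float,
//     {
//         self.mapv(|x| x.cos())
//     }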

impl<S, D, A> ParamsBase<S, D>
where
    A: 'static + Clone,
    D: Dimension,
    S: Data<Elem = A>,
{
    impl_tensor_unary! {
        cos,
        cosh,
        exp,
        ln,
        sin,
        sinh,
        sqrt,
        tan,
        tanh,
    }
    /// take the absolute value of each element within the parameters
    pub fn abs(&self) -> Params<A, D>
    where
        A: Signed,
    {
        self.mapv(|x| x.abs())
    }
    /// compute the conjugate of each element within the parameters
    #[cfg(feature = "complex")]
    pub fn conj(&self) -> Params<A, D>
    where
        A: num_complex::ComplexFloat,
    {
        self.mapv(|x| x.conj())
    }
}
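// Usage sketch for the element-wise helpers above (hedged: assumes a value
// `params: Params<f64, ndarray::Ix2>` has already been constructed elsewhere):
//
//     let activated = params.tanh(); // element-wise tanh of every parameter
//     let magnitude = params.abs();  // element-wise absolute value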

impl<A, S, D> ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// apply a single backward propagation step, updating the parameters in place
    pub fn backward<X, Y>(&mut self, input: &X, grad: &Y, lr: A)
    where
        Self: Backward<X, Y, Elem = A>,
    {
        <Self as Backward<X, Y>>::backward(self, input, grad, lr)
    }
    /// invoke a single forward step; this is simply a convenience method provided
    /// to reduce the number of `Forward` imports.
    pub fn forward<X, Y>(&self, input: &X) -> Y
    where
        Self: Forward<X, Output = Y>,
    {
        <Self as Forward<X>>::forward(self, input)
    }
}
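// Training-step sketch showing how the two convenience methods compose without
// importing `Forward` or `Backward` directly (hedged: `params`, `x`, `y_true`,
// and the learning rate `lr` are assumed to exist with compatible shapes):
//
//     let y_pred = params.forward(&x);   // output = weights^T * x + bias (see the `Forward` impl below)
//     let delta = y_pred - y_true;       // error signal
//     params.backward(&x, &delta, lr);   // in-place, `lr`-scaled parameter update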

impl<A, S, D> ParamsBase<S, D>
where
    A: ScalarOperand + Float + FromPrimitive,
    D: Dimension,
    S: Data<Elem = A>,
{
    /// computes the `l1` norm of the parameters as the sum of the bias and weight `l1` norms
    pub fn l1_norm(&self) -> A {
        let bias = self.bias().l1_norm();
        let weights = self.weights().l1_norm();
        bias + weights
    }
    /// computes the `l2` norm of the parameters as the sum of the bias and weight `l2` norms
    pub fn l2_norm(&self) -> A {
        let bias = self.bias().l2_norm();
        let weights = self.weights().l2_norm();
        bias + weights
    }
}
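// Worked example (hedged: assumes `Norm::l1_norm` sums absolute values and
// `Norm::l2_norm` is the Euclidean norm of each array). With weights
// [[1.0, -2.0], [3.0, -4.0]] and bias [0.5, -0.5]:
//
//     l1_norm = (1 + 2 + 3 + 4) + (0.5 + 0.5)            = 11.0
//     l2_norm = sqrt(1 + 4 + 9 + 16) + sqrt(0.25 + 0.25) ≈ 5.477 + 0.707 ≈ 6.18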

impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D>
where
    A: Clone,
    D: Dimension,
    S: Data<Elem = A>,
    for<'a> ArrayView<'a, A, D>: Dot<X, Output = Y>,
    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
{
    type Output = Z;

    fn forward(&self, input: &X) -> Self::Output {
        // affine map: output = weights^T * input + bias
        self.weights().t().dot(input) + self.bias()
    }
}
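// Hedged usage sketch for the `Forward` impl above, assuming `params: Params<f64, ndarray::Ix2>`
// and an input vector `x: ndarray::Array1<f64>` of compatible length already exist:
//
//     let y: ndarray::Array1<f64> = params.forward(&x);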

impl<A, S, T> Backward<ArrayBase<S, Ix0>, ArrayBase<T, Ix0>> for Params<A, Ix1>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix0>,
        delta: &ArrayBase<T, Ix0>,
        gamma: Self::Elem,
    ) {
        // w += gamma * (x * delta); b += gamma * delta
        self.weights_mut().scaled_add(gamma, &(input * delta));
        self.bias_mut().scaled_add(gamma, delta);
    }
}

impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for Params<A, Ix2>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Elem = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        delta: &ArrayBase<T, Ix1>,
        gamma: Self::Elem,
    ) {
        // w += gamma * (delta * x), broadcast over the weight matrix; b += gamma * delta
        self.weights_mut().scaled_add(gamma, &(delta * input));
        self.bias_mut().scaled_add(gamma, delta);
    }
}

impl<A, D1, D2, S1, S2> Backward<ArrayBase<S1, D1>, ArrayBase<S2, D2>> for Params<A, D1>
where
    A: 'static + Copy + Num,
    D1: RemoveAxis<Smaller = D2>,
    D2: Dimension<Larger = D1>,
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    for<'b> &'b ArrayBase<S1, D1>: Dot<ArrayView<'b, A, D2>, Output = Array<A, D2>>,
{
    type Elem = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S1, D1>,
        delta: &ArrayBase<S2, D2>,
        gamma: Self::Elem,
    ) {
        // delegate the weight update to the `Backward` impl on the weight array, then
        // add the scaled delta to the (one-dimension-smaller) bias
        self.weights_mut().backward(input, delta, gamma);
        self.bias_mut().scaled_add(gamma, delta);
    }
}
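// Hedged usage note for the `Backward` impls above: `scaled_add(gamma, rhs)` performs
// `self += gamma * rhs`, so a gradient-descent update would typically pass a negative
// learning rate (or a negated `delta`). Sketch, assuming `params: Params<f64, ndarray::Ix2>`,
// `x: Array1<f64>`, and `delta: Array1<f64>` of compatible shapes already exist:
//
//     params.backward(&x, &delta, -0.01);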