1use crate::{Params, ParamsBase};
6use concision_traits::{Backward, Forward, Norm};
7use ndarray::linalg::Dot;
8use ndarray::{
9 Array, ArrayBase, ArrayView, Data, Dimension, Ix0, Ix1, Ix2, RemoveAxis, ScalarOperand,
10};
11use num_traits::{Float, FromPrimitive, Num, Signed};
12
macro_rules! impl_tensor_unary {
    // Internal rule: emit a single method `$op` with an optional trailing `where` clause.
    (@impl $op:ident $(where $($bound:tt)*)?) => {
        #[doc = concat!(
            "Returns a new [`Params`] obtained by calling `",
            stringify!($op),
            "` on each value passed through `mapv`."
        )]
        pub fn $op(&self) -> Params<A, D> $(where $($bound)*)? {
            self.mapv(|value| value.$op())
        }
    };
    // Entry rule: a comma-separated list of `Float` method names, each bounded by `A: Float`.
    ($($op:ident),* $(,)?) => {
        $(
            impl_tensor_unary! { @impl $op where A: Float }
        )*
    };
}
23
24impl<S, D, A> ParamsBase<S, D, A>
25where
26 A: 'static + Clone,
27 D: Dimension,
28 S: Data<Elem = A>,
29{
30 impl_tensor_unary! {
31 cos,
32 cosh,
33 exp,
34 ln,
35 sin,
36 sinh,
37 sqrt,
38 tan,
39 tanh,
40 }
41 pub fn abs(&self) -> Params<A, D>
43 where
44 A: Signed,
45 {
46 self.mapv(|x| x.abs())
47 }
48 #[cfg(feature = "complex")]
49 pub fn conj(&self) -> Params<A, D>
51 where
52 A: num_complex::ComplexFloat,
53 {
54 self.mapv(|x| x.conj())
55 }
56}
57
58impl<A, S, D> ParamsBase<S, D, A>
59where
60 A: Clone,
61 D: Dimension,
62 S: Data<Elem = A>,
63{
64 pub fn backward<X, Y>(&mut self, input: &X, grad: &Y, lr: A)
66 where
67 Self: Backward<X, Y, Elem = A>,
68 {
69 <Self as Backward<X, Y>>::backward(self, input, grad, lr)
70 }
71 pub fn forward<X, Y>(&self, input: &X) -> Y
74 where
75 Self: Forward<X, Output = Y>,
76 {
77 <Self as Forward<X>>::forward(self, input)
78 }
79}
80
81impl<A, S, D> ParamsBase<S, D, A>
82where
83 A: ScalarOperand + Float + FromPrimitive,
84 D: Dimension,
85 S: Data<Elem = A>,
86{
87 pub fn l1_norm(&self) -> A {
89 let bias = self.bias().l1_norm();
90 let weights = self.weights().l1_norm();
91 bias + weights
92 }
93 pub fn l2_norm(&self) -> A {
95 let bias = self.bias().l2_norm();
96 let weights = self.weights().l2_norm();
97 bias + weights
98 }
99}
100
101impl<A, X, Y, Z, S, D> Forward<X> for ParamsBase<S, D, A>
102where
103 A: Clone,
104 D: Dimension,
105 S: Data<Elem = A>,
106 for<'a> ArrayView<'a, A, D>: Dot<X, Output = Y>,
107 Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller, A>, Output = Z>,
108{
109 type Output = Z;
110
111 fn forward(&self, input: &X) -> Self::Output {
112 self.weights().t().dot(input) + self.bias()
113 }
114}
115
116impl<A, S, T> Backward<ArrayBase<S, Ix0, A>, ArrayBase<T, Ix0, A>> for Params<A, Ix1>
117where
118 A: Float + FromPrimitive + ScalarOperand,
119 S: Data<Elem = A>,
120 T: Data<Elem = A>,
121{
122 type Elem = A;
123
124 fn backward(
125 &mut self,
126 input: &ArrayBase<S, Ix0, A>,
127 delta: &ArrayBase<T, Ix0, A>,
128 gamma: Self::Elem,
129 ) {
130 self.weights_mut().scaled_add(gamma, &(input * delta));
131 self.bias_mut().scaled_add(gamma, delta);
132 }
133}
134
135impl<A, S, T> Backward<ArrayBase<S, Ix1, A>, ArrayBase<T, Ix1, A>> for Params<A, Ix2>
136where
137 A: Float + FromPrimitive + ScalarOperand,
138 S: Data<Elem = A>,
139 T: Data<Elem = A>,
140{
141 type Elem = A;
142
143 fn backward(
144 &mut self,
145 input: &ArrayBase<S, Ix1, A>,
146 delta: &ArrayBase<T, Ix1, A>,
147 gamma: Self::Elem,
148 ) {
149 self.weights_mut().scaled_add(gamma, &(delta * input));
150 self.bias_mut().scaled_add(gamma, delta);
151 }
152}
153
/// Backward pass where the input has the same rank as the weights (`D1`) and the gradient
/// `delta` has one fewer axis (`D2 = D1::Smaller`, the bias rank).
///
/// The weight update is delegated to the weights array's own `backward` method. The bias is
/// updated in place via `scaled_add`, i.e. `bias += gamma * delta`.
// NOTE(review): the `Dot` bound below is not used directly in this body; it presumably exists to
// satisfy the `Backward` impl that `weights_mut().backward(..)` resolves to — confirm.
impl<A, D1, D2, S1, S2> Backward<ArrayBase<S1, D1, A>, ArrayBase<S2, D2, A>> for Params<A, D1>
where
    A: 'static + Copy + Num,
    D1: RemoveAxis<Smaller = D2>,
    D2: Dimension<Larger = D1>,
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
    for<'b> &'b ArrayBase<S1, D1, A>: Dot<ArrayView<'b, A, D2>, Output = Array<A, D2>>,
{
    type Elem = A;

    fn backward(
        &mut self,
        input: &ArrayBase<S1, D1, A>,
        delta: &ArrayBase<S2, D2, A>,
        gamma: Self::Elem,
    ) {
        // Weight gradient/update is computed by the weights array's `Backward` impl.
        self.weights_mut().backward(input, delta, gamma);
        // Bias gradient is `delta` itself, scaled by `gamma`.
        self.bias_mut().scaled_add(gamma, delta);
    }
}