concision_traits/impls/
impl_backward.rs1use crate::Backward;
7use ndarray::linalg::Dot;
8use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension};
9use num_traits::Num;
10
11impl<A, S, D, S1, D1, S2, D2> Backward<ArrayBase<S1, D1, A>, ArrayBase<S2, D2, A>>
12 for ArrayBase<S, D, A>
13where
14 A: 'static + Copy + Num,
15 D: Dimension,
16 S: DataMut<Elem = A>,
17 D1: Dimension,
18 D2: Dimension,
19 S1: Data<Elem = A>,
20 S2: Data<Elem = A>,
21 for<'b> &'b ArrayBase<S1, D1, A>: Dot<ArrayView<'b, A, D2>, Output = Array<A, D2>>,
22{
23 type Elem = A;
24
25 fn backward(
26 &mut self,
27 input: &ArrayBase<S1, D1, A>,
28 delta: &ArrayBase<S2, D2, A>,
29 gamma: Self::Elem,
30 ) {
31 self.scaled_add(gamma, &input.dot(&delta.t()))
32 }
33}