concision_traits/impls/
impl_backward.rs

/*
    Appellation: impl_backward <module>
    Created At: 2025.12.14:09:36:08
    Contrib: @FL03
*/
6use crate::Backward;
7use ndarray::linalg::Dot;
8use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension};
9use num_traits::Num;
10
11impl<A, S, D, S1, D1, S2, D2> Backward<ArrayBase<S1, D1, A>, ArrayBase<S2, D2, A>>
12    for ArrayBase<S, D, A>
13where
14    A: 'static + Copy + Num,
15    D: Dimension,
16    S: DataMut<Elem = A>,
17    D1: Dimension,
18    D2: Dimension,
19    S1: Data<Elem = A>,
20    S2: Data<Elem = A>,
21    for<'b> &'b ArrayBase<S1, D1, A>: Dot<ArrayView<'b, A, D2>, Output = Array<A, D2>>,
22{
23    type Elem = A;
24
25    fn backward(
26        &mut self,
27        input: &ArrayBase<S1, D1, A>,
28        delta: &ArrayBase<S2, D2, A>,
29        gamma: Self::Elem,
30    ) {
31        self.scaled_add(gamma, &input.dot(&delta.t()))
32    }
33}