concision_core/traits/
propagation.rs

/*
    Appellation: propagation <module>
    Contrib: @FL03
*/
5
6/// [Backward] propagate a delta through the system;
7pub trait Backward<X, Delta = X> {
8    type Elem;
9    type Output;
10
11    fn backward(
12        &mut self,
13        input: &X,
14        delta: &Delta,
15        gamma: Self::Elem,
16    ) -> crate::Result<Self::Output>;
17}
18
19/// This trait denotes entities capable of performing a single forward step
20pub trait Forward<Rhs> {
21    type Output;
22    /// a single forward step
23    fn forward(&self, input: &Rhs) -> crate::Result<Self::Output>;
24    /// this method enables the forward pass to be generically _activated_ using some closure.
25    /// This is useful for isolating the logic of the forward pass from that of the activation
26    /// function and is often used by layers and models.
27    fn forward_then<F>(&self, input: &Rhs, then: F) -> crate::Result<Self::Output>
28    where
29        F: FnOnce(Self::Output) -> Self::Output,
30    {
31        self.forward(input).map(then)
32    }
33}
34
/*
 ************* Implementations *************
*/
38
39use ndarray::linalg::Dot;
40use ndarray::{ArrayBase, Data, Dimension};
41// impl<X, Y, Dx, A, S, D> Backward<X, Y> for ArrayBase<S, D>
42// where
43//     A: LinalgScalar + FromPrimitive,
44//     D: Dimension,
45//     S: DataMut<Elem = A>,
46//     Dx: core::ops::Mul<A, Output = Dx>,
47//     for<'a> X: Dot<Y, Output = Dx>,
48//     for<'a> &'a Self: core::ops::Add<Dx, Output = Self>,
49
50// {
51//     type Elem = A;
52//     type Output = ();
53
54//     fn backward(
55//         &mut self,
56//         input: &X,
57//         delta: &Y,
58//         gamma: Self::Elem,
59//     ) -> crate::Result<Self::Output> {
60//         let grad = input.dot(delta);
61//         let next = &self + grad * gamma;
62//         self.assign(&next)?;
63//         Ok(())
64
65//     }
66// }
67
68impl<X, Y, A, S, D> Forward<X> for ArrayBase<S, D>
69where
70    A: Clone,
71    D: Dimension,
72    S: Data<Elem = A>,
73    for<'a> X: Dot<ArrayBase<S, D>, Output = Y>,
74{
75    type Output = Y;
76
77    fn forward(&self, input: &X) -> crate::Result<Self::Output> {
78        let output = input.dot(self);
79        Ok(output)
80    }
81}