concision_core/traits/train.rs

/*
    Appellation: train <module>
    Contrib: @FL03
*/
/// A trait declaring basic gradient-related routines for a neural network
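///
/// # Example
///
/// A minimal sketch using the [`ndarray`]-backed implementation defined later in this
/// module; the values and the `weights`/`grad` names are illustrative only.
///
/// ```ignore
/// use ndarray::array;
///
/// let mut weights = array![1.0_f64, 2.0, 3.0];
/// let grad = array![1.0_f64, 1.0, 1.0];
/// // the update applies `lr / batch_size * grad`, with the batch size taken from the leading axis
/// weights.apply_gradient(&grad, 0.3)?;
/// ```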
pub trait ApplyGradient<Grad, A> {
    type Output;

    fn apply_gradient(&mut self, grad: &Grad, lr: A) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output>;
}

/// This trait extends the [`ApplyGradient`] trait by allowing for momentum-based optimization
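///
/// # Example
///
/// A minimal sketch of the momentum variant, again relying on the [`ndarray`]-backed
/// implementation below; `velocity` is simply a zero-initialized array of the same shape.
///
/// ```ignore
/// use ndarray::{array, Array1};
///
/// let mut weights = array![1.0_f64, 2.0, 3.0];
/// let grad = array![1.0_f64, 1.0, 1.0];
/// let mut velocity = Array1::<f64>::zeros(3);
/// // the velocity becomes an exponential moving average of the gradients
/// weights.apply_gradient_with_momentum(&grad, 0.3, 0.9, &mut velocity)?;
/// ```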
pub trait ApplyGradientExt<Grad, A>: ApplyGradient<Grad, A> {
    type Velocity;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Grad,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output>;
}

/// A simple trait denoting a single backward pass through a layer of a neural network; the
/// trait is generic over the input `X` and the propagated error signal `Y`.
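///
/// # Example
///
/// A sketch of the intended call shape; `layer`, `input`, `delta`, and `gamma` are
/// hypothetical names, and the meaning of the `gamma` hyper-parameter (e.g. a learning
/// rate) is left entirely to the implementor.
///
/// ```ignore
/// // `layer` is assumed to implement `Backward<X, Y>`
/// let output = layer.backward(&input, &delta, gamma)?;
/// ```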
pub trait Backward<X, Y> {
    type HParam;
    type Output;

    fn backward(
        &mut self,
        input: &X,
        delta: &Y,
        gamma: Self::HParam,
    ) -> crate::Result<Self::Output>;
}

/// This trait defines the training process for the network
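///
/// # Example
///
/// A sketch of the intended usage; `model`, `input`, and `target` are hypothetical and
/// only stand in for an implementor of [`Train`].
///
/// ```ignore
/// // a single training step
/// let out = model.train(&input, &target)?;
/// // or use the provided `train_for` helper, which calls `train` once per epoch
/// // and returns the output of the final epoch
/// let last = model.train_for(&input, &target, 10)?;
/// ```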
pub trait Train<X, Y> {
    type Output;

    /// Perform a single training step against the given input and target.
    fn train(&mut self, input: &X, target: &Y) -> crate::Result<Self::Output>;

    /// Call [`Train::train`] once per epoch, returning the output of the final epoch.
    fn train_for(&mut self, input: &X, target: &Y, epochs: usize) -> crate::Result<Self::Output> {
        let mut output = None;

        for _ in 0..epochs {
            output = match self.train(input, target) {
                Ok(o) => Some(o),
                Err(e) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!("Training failed: {e}");
                    return Err(e);
                }
            };
        }
        // only reachable when `epochs == 0`
        output.ok_or_else(|| crate::error::Error::TrainingFailed("no epochs were run".into()))
    }
}

use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::{Float, FromPrimitive};

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Output = ();

    fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        // infer the batch size from the leading axis; zero-dimensional arrays use 1
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // apply the gradient, scaled by `lr / batch_size`
        self.scaled_add(lr / batch_size, grad);
        Ok(())
    }

    fn apply_gradient_with_decay(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        // infer the batch size from the leading axis; zero-dimensional arrays use 1
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // weight decay adjusts the gradient: g' = g + decay * w
        self.scaled_add(lr / batch_size, &(grad + &*self * decay));
        Ok(())
    }
}

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where
    A: Float + FromPrimitive + ScalarOperand,
    S: ndarray::DataMut<Elem = A>,
    T: ndarray::Data<Elem = A>,
    D: Dimension,
{
    type Velocity = ndarray::Array<A, D>;

    fn apply_gradient_with_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        // infer the batch size from the leading axis; zero-dimensional arrays use 1
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // the velocity is an exponential moving average of the gradients
        *velocity = &*velocity * momentum + grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size, &*velocity);
        Ok(())
    }

    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &ArrayBase<T, D>,
        lr: A,
        decay: A,
        momentum: A,
        velocity: &mut Self::Velocity,
    ) -> crate::Result<Self::Output> {
        if self.shape() != grad.shape() {
            return Err(crate::error::Error::ShapeMismatch(
                self.shape().to_vec(),
                grad.shape().to_vec(),
            ));
        }
        // infer the batch size from the leading axis; zero-dimensional arrays use 1
        let batch_size = if !grad.shape().is_empty() {
            A::from_usize(self.shape()[0]).unwrap()
        } else {
            A::one()
        };
        // weight decay adjusts the gradient before the momentum update: g' = g + decay * w
        let adjusted_grad = grad + &*self * decay;
        *velocity = &*velocity * momentum + adjusted_grad * (A::one() - momentum);
        self.scaled_add(lr / batch_size, &*velocity);
        Ok(())
    }
}
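
// A small sanity check for the `ndarray`-backed implementations above. This test module is a
// sketch added for illustration; it assumes nothing beyond `ndarray` and the impls in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array1};

    #[test]
    fn apply_gradient_scales_by_batch_size() {
        let mut weights = array![1.0_f64, 2.0, 3.0];
        let grad = array![1.0_f64, 1.0, 1.0];
        // the leading axis has length 3, so the effective step is lr / 3 = 0.1
        weights.apply_gradient(&grad, 0.3).unwrap();
        assert!((weights[0] - 1.1).abs() < 1e-9);
    }

    #[test]
    fn apply_gradient_rejects_shape_mismatch() {
        let mut weights = array![1.0_f64, 2.0, 3.0];
        let grad = array![1.0_f64, 1.0];
        assert!(weights.apply_gradient(&grad, 0.1).is_err());
    }

    #[test]
    fn momentum_update_tracks_the_gradient() {
        let mut weights = array![0.0_f64, 0.0];
        let grad = array![1.0_f64, 1.0];
        let mut velocity = Array1::<f64>::zeros(2);
        // with zero initial velocity, the first step uses (1 - momentum) * grad
        weights
            .apply_gradient_with_momentum(&grad, 0.2, 0.9, &mut velocity)
            .unwrap();
        assert!((velocity[0] - 0.1).abs() < 1e-9);
        assert!((weights[0] - 0.01).abs() < 1e-9);
    }
}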