// concision_ext/simple.rs

/*
    Appellation: simple <module>
    Contrib: @FL03
*/
use cnc::nn::{Model, ModelFeatures, ModelParams, NeuralError, StandardModelConfig, Train};
use cnc::{Forward, Norm, Params, ReLU, Sigmoid};

use ndarray::prelude::*;
use ndarray::{Data, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign};

12pub struct SimpleModel<T = f64> {
13    pub config: StandardModelConfig<T>,
14    pub features: ModelFeatures,
15    pub params: ModelParams<T>,
16}
17
18impl<T> SimpleModel<T> {
19    pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
20    where
21        T: Clone + Default,
22    {
23        let params = ModelParams::default(features);
24        SimpleModel {
25            config,
26            features,
27            params,
28        }
29    }
30    #[cfg(feature = "rand")]
31    pub fn init(self) -> Self
32    where
33        T: Float + FromPrimitive,
34        cnc::init::rand_distr::StandardNormal: cnc::init::rand_distr::Distribution<T>,
35    {
36        let params = ModelParams::glorot_normal(self.features);
37        SimpleModel { params, ..self }
38    }
39
40    pub const fn config(&self) -> &StandardModelConfig<T> {
41        &self.config
42    }
43
44    pub fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
45        &mut self.config
46    }
47
48    pub const fn features(&self) -> ModelFeatures {
49        self.features
50    }
51
52    pub const fn params(&self) -> &ModelParams<T> {
53        &self.params
54    }
55
56    pub fn params_mut(&mut self) -> &mut ModelParams<T> {
57        &mut self.params
58    }
59}
60
61impl<T> Model<T> for SimpleModel<T> {
62    type Config = StandardModelConfig<T>;
63    type Layout = ModelFeatures;
64
65    fn config(&self) -> &StandardModelConfig<T> {
66        &self.config
67    }
68
69    fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
70        &mut self.config
71    }
72
73    fn layout(&self) -> ModelFeatures {
74        self.features
75    }
76
77    fn params(&self) -> &ModelParams<T> {
78        &self.params
79    }
80
81    fn params_mut(&mut self) -> &mut ModelParams<T> {
82        &mut self.params
83    }
84}
85
86impl<A, S, D> Forward<ArrayBase<S, D>> for SimpleModel<A>
87where
88    A: Float + FromPrimitive + ScalarOperand,
89    D: Dimension,
90    S: Data<Elem = A>,
91    Params<A>: Forward<Array<A, D>, Output = Array<A, D>>,
92{
93    type Output = Array<A, D>;
94
95    fn forward(&self, input: &ArrayBase<S, D>) -> cnc::Result<Self::Output> {
96        let mut output = self
97            .params()
98            .input()
99            .forward_then(&input.to_owned(), |y| y.relu())?;
100
101        for layer in self.params().hidden() {
102            output = layer.forward_then(&output, |y| y.relu())?;
103        }
104
105        let y = self
106            .params()
107            .output()
108            .forward_then(&output, |y| y.sigmoid())?;
109        Ok(y)
110    }
111}
112
113impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for SimpleModel<A>
114where
115    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
116    S: Data<Elem = A>,
117    T: Data<Elem = A>,
118{
119    type Output = A;
120
121    #[cfg_attr(
122        feature = "tracing",
123        tracing::instrument(
124            skip(self, input, target),
125            level = "trace",
126            name = "backward",
127            target = "model",
128        )
129    )]
130    fn train(
131        &mut self,
132        input: &ArrayBase<S, Ix1>,
133        target: &ArrayBase<T, Ix1>,
134    ) -> Result<Self::Output, NeuralError> {
135        if input.len() != self.features().input() {
136            return Err(NeuralError::InvalidInputShape);
137        }
138        if target.len() != self.features().output() {
139            return Err(NeuralError::InvalidOutputShape);
140        }
141        // get the learning rate from the model's configuration
142        let lr = self
143            .config()
144            .learning_rate()
145            .copied()
146            .unwrap_or(A::from_f32(0.01).unwrap());
147        // Normalize the input and target
148        let input = input / input.l2_norm();
149        let target_norm = target.l2_norm();
150        let target = target / target_norm;
151        // self.prev_target_norm = Some(target_norm);
152        // Forward pass to collect activations
153        let mut activations = Vec::new();
154        activations.push(input.to_owned());
155
156        let mut output = self.params().input().forward(&input)?.relu();
157        activations.push(output.to_owned());
158        // collect the activations of the hidden
159        for layer in self.params().hidden() {
160            output = layer.forward(&output)?.relu();
161            activations.push(output.to_owned());
162        }
163
164        output = self.params().output().forward(&output)?.sigmoid();
165        activations.push(output.to_owned());
166
167        // Calculate output layer error
168        let error = &target - &output;
169        let loss = error.pow2().mean().unwrap_or(A::zero());
170        #[cfg(feature = "tracing")]
171        tracing::trace!("Training loss: {loss:?}");
172        let mut delta = error * output.sigmoid_derivative();
173        delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients
174
175        // Update output weights
176        self.params_mut()
177            .output_mut()
178            .backward(activations.last().unwrap(), &delta, lr)?;
179
180        let num_hidden = self.features().layers();
181        // Iterate through hidden layers in reverse order
182        for i in (0..num_hidden).rev() {
183            // Calculate error for this layer
184            delta = if i == num_hidden - 1 {
185                // use the output activations for the final hidden layer
186                self.params().output().weights().dot(&delta) * activations[i + 1].relu_derivative()
187            } else {
188                // else; backpropagate using the previous hidden layer
189                self.params().hidden()[i + 1].weights().t().dot(&delta)
190                    * activations[i + 1].relu_derivative()
191            };
192            // Normalize delta to prevent exploding gradients
193            delta /= delta.l2_norm();
194            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
195        }
196        /*
197            Backpropagate to the input layer
198            The delta for the input layer is computed using the weights of the first hidden layer
199            and the derivative of the activation function of the first hidden layer.
200
201            (h, h).dot(h) * derivative(h) = dim(h) where h is the number of features within a hidden layer
202        */
203        delta = self.params().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
204        delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients
205        self.params_mut()
206            .input_mut()
207            .backward(&activations[1], &delta, lr)?;
208
209        Ok(loss)
210    }
211}
212
213impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for SimpleModel<A>
214where
215    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
216    S: Data<Elem = A>,
217    T: Data<Elem = A>,
218{
219    type Output = A;
220
221    #[cfg_attr(
222        feature = "tracing",
223        tracing::instrument(
224            skip(self, input, target),
225            level = "trace",
226            name = "train",
227            target = "model",
228            fields(input_shape = ?input.shape(), target_shape = ?target.shape())
229        )
230    )]
231    fn train(
232        &mut self,
233        input: &ArrayBase<S, Ix2>,
234        target: &ArrayBase<T, Ix2>,
235    ) -> Result<Self::Output, NeuralError> {
236        if input.nrows() == 0 || target.nrows() == 0 {
237            return Err(NeuralError::InvalidBatchSize);
238        }
239        if input.ncols() != self.features().input() {
240            return Err(NeuralError::InvalidInputShape);
241        }
242        if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
243            return Err(NeuralError::InvalidOutputShape);
244        }
245        let mut loss = A::zero();
246
247        for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
248            loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
249                Ok(l) => l,
250                Err(err) => {
251                    #[cfg(feature = "tracing")]
252                    tracing::error!(
253                        "Training failed for batch {}/{}: {:?}",
254                        i + 1,
255                        input.nrows(),
256                        err
257                    );
258                    #[cfg(not(feature = "tracing"))]
259                    eprintln!(
260                        "Training failed for batch {}/{}: {:?}",
261                        i + 1,
262                        input.nrows(),
263                        err
264                    );
265                    return Err(err);
266                }
267            };
268        }
269
270        Ok(loss)
271    }
272}