// concision_ext/simple.rs

/*
    Appellation: simple <module>
    Contrib: @FL03
*/
use cnc::nn::{Model, ModelFeatures, ModelParams, NeuralError, StandardModelConfig, Train};
use cnc::{Forward, Norm, Params, ReLU, Sigmoid};

use ndarray::prelude::*;
use ndarray::{Data, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign};

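/// A simple feed-forward network bundling its configuration, layout
/// ([`ModelFeatures`]), and parameters; the forward pass applies ReLU after the
/// input and hidden layers and a sigmoid after the output layer.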
#[derive(Clone, Debug)]
pub struct SimpleModel<T = f64> {
    pub config: StandardModelConfig<T>,
    pub features: ModelFeatures,
    pub params: ModelParams<T>,
}

impl<T> SimpleModel<T> {
    pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
    where
        T: Clone + Default,
    {
        let params = ModelParams::default(features);
        SimpleModel {
            config,
            features,
            params,
        }
    }
    /// returns a reference to the model configuration
    pub const fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }
    /// returns a mutable reference to the model configuration
    pub const fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }
    /// returns the model features
    pub const fn features(&self) -> ModelFeatures {
        self.features
    }
    /// returns a mutable reference to the model features
    pub const fn features_mut(&mut self) -> &mut ModelFeatures {
        &mut self.features
    }
    /// returns a reference to the model parameters
    pub const fn params(&self) -> &ModelParams<T> {
        &self.params
    }
    /// returns a mutable reference to the model parameters
    pub const fn params_mut(&mut self) -> &mut ModelParams<T> {
        &mut self.params
    }
    /// set the current configuration and return a mutable reference to the model
    pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
        self.config = config;
        self
    }
    /// set the current features and return a mutable reference to the model
    pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
        self.features = features;
        self
    }
    /// set the current parameters and return a mutable reference to the model
    pub fn set_params(&mut self, params: ModelParams<T>) -> &mut Self {
        self.params = params;
        self
    }
    /// consumes the current instance to create another with the given configuration
    pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
        Self { config, ..self }
    }
    /// consumes the current instance to create another with the given features
    pub fn with_features(self, features: ModelFeatures) -> Self {
        Self { features, ..self }
    }
    /// consumes the current instance to create another with the given parameters
    pub fn with_params(self, params: ModelParams<T>) -> Self {
        Self { params, ..self }
    }
    /// initializes the model with Glorot normal distribution
    #[cfg(feature = "rand")]
    pub fn init(self) -> Self
    where
        T: Float + FromPrimitive,
        cnc::rand_distr::StandardNormal: cnc::rand_distr::Distribution<T>,
    {
        let params = ModelParams::glorot_normal(self.features());
        SimpleModel { params, ..self }
    }
}
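
// Usage sketch for the builder-style API above. This is illustrative only: the
// `Default` constructors for `StandardModelConfig` and `ModelFeatures` are
// assumed here, not taken from the crate's documented API:
//
//     let config = StandardModelConfig::<f64>::default();
//     let features = ModelFeatures::default();
//     let model = SimpleModel::new(config, features).init(); // `init` requires the `rand` feature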

impl<T> Model<T> for SimpleModel<T> {
    type Config = StandardModelConfig<T>;
    type Layout = ModelFeatures;

    fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }

    fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }

    fn layout(&self) -> ModelFeatures {
        self.features
    }

    fn params(&self) -> &ModelParams<T> {
        &self.params
    }

    fn params_mut(&mut self) -> &mut ModelParams<T> {
        &mut self.params
    }
}

impl<A, S, D> Forward<ArrayBase<S, D>> for SimpleModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    D: Dimension,
    S: Data<Elem = A>,
    Params<A>: Forward<Array<A, D>, Output = Array<A, D>>,
{
    type Output = Array<A, D>;

    fn forward(&self, input: &ArrayBase<S, D>) -> cnc::Result<Self::Output> {
        // apply the input layer followed by a ReLU activation
        let mut output = self
            .params()
            .input()
            .forward_then(&input.to_owned(), |y| y.relu())?;
        // propagate through the hidden layers, applying ReLU after each
        for layer in self.params().hidden() {
            output = layer.forward_then(&output, |y| y.relu())?;
        }
        // apply the output layer followed by a sigmoid activation
        let y = self
            .params()
            .output()
            .forward_then(&output, |y| y.sigmoid())?;
        Ok(y)
    }
}
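
// Inference sketch: `forward` accepts any `ndarray` array whose element type and
// dimension satisfy the bounds above (1-D for a single sample). The input width
// here is hypothetical and must match the model's `ModelFeatures`:
//
//     let x = ndarray::array![0.1, 0.2, 0.3];
//     let y = model.forward(&x)?;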

impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for SimpleModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<Self::Output, NeuralError> {
        // validate the input and target lengths against the model's layout
        if input.len() != self.features().input() {
            return Err(NeuralError::InvalidInputShape);
        }
        if target.len() != self.features().output() {
            return Err(NeuralError::InvalidOutputShape);
        }
        // get the learning rate from the model's configuration, defaulting to 0.01
        let lr = self
            .config()
            .learning_rate()
            .copied()
            .unwrap_or_else(|| A::from_f32(0.01).unwrap());
        // normalize the input and target to unit (l2) length
        let input = input / input.l2_norm();
        let target = target / target.l2_norm();
        // forward pass, collecting every layer's activation: activations[0] is
        // the normalized input, activations[i + 1] is the input to hidden layer
        // i, and the final entry is the model's output
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        let mut output = self.params().input().forward(&input)?.relu();
        activations.push(output.to_owned());
        // collect the activations of the hidden layers
        for layer in self.params().hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        output = self.params().output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

        // calculate the output-layer error and the mean squared loss
        let error = &target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        let mut delta = error * output.sigmoid_derivative();
        delta /= delta.l2_norm(); // normalize the delta to prevent exploding gradients

        // update the output layer; its input is the last hidden activation
        // (the second-to-last entry), not the model's own output
        let last_hidden = activations.len() - 2;
        self.params_mut()
            .output_mut()
            .backward(&activations[last_hidden], &delta, lr)?;

        let num_hidden = self.features().layers();
        // iterate through the hidden layers in reverse order
        for i in (0..num_hidden).rev() {
            // calculate the error for hidden layer i; activations[i + 2] is the
            // layer's own (post-ReLU) output
            delta = if i == num_hidden - 1 {
                // the final hidden layer takes its error through the output layer's weights
                self.params().output().weights().dot(&delta)
                    * activations[i + 2].relu_derivative()
            } else {
                // else; backpropagate through the next hidden layer's weights
                self.params().hidden()[i + 1].weights().dot(&delta)
                    * activations[i + 2].relu_derivative()
            };
            // normalize delta to prevent exploding gradients
            delta /= delta.l2_norm();
            // activations[i + 1] is the input to hidden layer i
            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        /*
            Backpropagate to the input layer.
            The delta for the input layer is computed using the weights of the first
            hidden layer and the derivative of the input layer's own activation:

            (h, h).dot(h) * derivative(h) = dim(h), where h is the number of features
            within a hidden layer
        */
        delta = self.params().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm(); // normalize the delta to prevent exploding gradients
        // the input layer's input is the normalized sample, activations[0]
        self.params_mut()
            .input_mut()
            .backward(&activations[0], &delta, lr)?;

        Ok(loss)
    }
}
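
// Training sketch for a single sample (hypothetical data; the input and target
// widths must match the model's `ModelFeatures`). `train` returns the mean
// squared loss for the sample:
//
//     let x = ndarray::array![0.5, 0.1, 0.4];
//     let t = ndarray::array![1.0];
//     let loss = model.train(&x, &t)?;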

impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for SimpleModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "train",
            target = "model",
            fields(input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        target: &ArrayBase<T, Ix2>,
    ) -> Result<Self::Output, NeuralError> {
        // validate the batch: both arrays must be non-empty, rows must pair up,
        // and the column counts must match the model's input/output widths
        if input.nrows() == 0 || target.nrows() == 0 {
            return Err(NeuralError::InvalidBatchSize);
        }
        if input.ncols() != self.features().input() {
            return Err(NeuralError::InvalidInputShape);
        }
        if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
            return Err(NeuralError::InvalidOutputShape);
        }
        // accumulate the loss over every (input, target) row pair in the batch
        let mut loss = A::zero();

        for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
            loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
                Ok(l) => l,
                Err(err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "Training failed for sample {}/{}: {:?}",
                        i + 1,
                        input.nrows(),
                        err
                    );
                    #[cfg(not(feature = "tracing"))]
                    eprintln!(
                        "Training failed for sample {}/{}: {:?}",
                        i + 1,
                        input.nrows(),
                        err
                    );
                    return Err(err);
                }
            };
        }

        Ok(loss)
    }
}
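
// Batch-training sketch: the `Ix2` implementation returns the summed per-sample
// loss over the rows of the batch, so a simple epoch loop might look like this
// (hypothetical data; `inputs` and `targets` are `Array2` values with matching
// row counts):
//
//     for epoch in 0..100 {
//         let loss = model.train(&inputs, &targets)?;
//         println!("epoch {epoch}: loss = {loss:?}");
//     }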