eryon_surface/model/impls/impl_model.rs

/*
    Appellation: impl_model <module>
    Contrib: @FL03
*/
use crate::model::SurfaceModel;

use cnc::prelude::{Forward, Norm, Predict, ReLU, Sigmoid, Train};
use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign};
use rustfft::FftNum;

impl<A> SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
{
    /// predict some outcome by forward propagating the given input through the model
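    ///
    /// A minimal sketch of the intended call pattern, shown as a non-compiled example
    /// because the constructor is an assumption here (use whatever builder this crate
    /// actually exposes, and any input type for which `Predict` is implemented, e.g.
    /// via the `Forward` impls below):
    ///
    /// ```ignore
    /// let model: SurfaceModel<f64> = /* construct the model */;
    /// let x = ndarray::array![0.25_f64, 0.5, 0.75];
    /// let y = model.predict(&x)?;
    /// ```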
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, input), target = "surface")
    )]
    pub fn predict<X>(&self, input: &X) -> cnc::nn::NeuralResult<<Self as Predict<X>>::Output>
    where
        Self: Predict<X>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Forwarding input through model");
        <Self as Predict<X>>::predict(self, input)
    }
    /// train the model on some input and target data by backpropagating the error through the
    /// model
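    ///
    /// A minimal sketch of the call pattern (the model construction is hypothetical; the
    /// 1-D `Train` impl below returns the scalar loss for the sample):
    ///
    /// ```ignore
    /// let mut model: SurfaceModel<f64> = /* construct the model */;
    /// let x = ndarray::array![0.25_f64, 0.5, 0.75];
    /// let y = ndarray::array![1.0_f64];
    /// let loss = model.train(&x, &y)?;
    /// ```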
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, input, target), target = "surface")
    )]
    pub fn train<X, Y, Z>(&mut self, input: &X, target: &Y) -> cnc::nn::NeuralResult<Z>
    where
        Self: Train<X, Y, Output = Z>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Backpropagating through model");
        <Self as Train<X, Y>>::train(self, input, target)
    }
}

#[doc(hidden)]
#[allow(deprecated)]
impl<A> SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
{
    #[deprecated(since = "0.0.2", note = "Use `train` instead")]
    pub fn backward<X, Y, Z>(&mut self, input: &X, target: &Y) -> cnc::nn::NeuralResult<Z>
    where
        Self: Train<X, Y, Output = Z>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Backpropagating through model");
        <Self as Train<X, Y>>::train(self, input, target)
    }
    /// forward propagate input through the model
    #[deprecated(since = "0.0.2", note = "Use `predict` instead")]
    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Self: Forward<X, Output = Y>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Forwarding input through model");
        <Self as Forward<X>>::forward(self, input)
    }
}

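// Single-sample forward pass: input layer -> ReLU, each hidden layer -> ReLU, an optional
// attention step, an optional rescale by the previously recorded target norm, then the
// output layer followed by a sigmoid.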
impl<S, T> Forward<ArrayBase<S, Ix1>> for SurfaceModel<T>
where
    S: Data<Elem = T>,
    T: FftNum + Float + FromPrimitive + ScalarOperand,
{
    type Output = Array1<T>;

    fn forward(&self, input: &ArrayBase<S, Ix1>) -> cnc::Result<Self::Output> {
        if input.len() != self.features().input() {
            return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
        }
        // apply the input layer and the activation function
        let mut output = self.input().forward(input)?.relu();
        // iterate through the hidden layers and apply the activation function
        for layer in self.hidden() {
            output = layer.forward(&output)?.relu();
        }
        // if the model has an attention mechanism, apply it
        if let Some(ref attention) = self.attention {
            output = attention.forward(&output)?;
        }
        // if a previous target norm was recorded, rescale the output by it
        if let Some(prev_target_norm) = self.prev_target_norm {
            output = &output * prev_target_norm;
        }
        self.output().forward_then(&output, |y| y.sigmoid())
    }
}

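// Batched forward pass over a 2-D input (one sample per row). Same pipeline as the 1-D
// impl above, except that the rescale by `prev_target_norm` is applied after the output
// sigmoid rather than before the output layer.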
impl<S, T> Forward<ArrayBase<S, Ix2>> for SurfaceModel<T>
where
    S: Data<Elem = T>,
    T: FftNum + Float + FromPrimitive + ScalarOperand,
{
    type Output = Array2<T>;

    fn forward(&self, input: &ArrayBase<S, Ix2>) -> cnc::Result<Self::Output> {
        if input.ncols() != self.features().input() {
            return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
        }
        let mut output = self.input().forward(input)?.relu();
        for layer in self.hidden() {
            output = layer.forward(&output)?.relu();
        }
        // if the model has an attention mechanism, apply it
        if let Some(ref attention) = self.attention {
            output = attention.forward(&output)?;
        }
        output = self.output().forward(&output)?.sigmoid();

        if let Some(prev_target_norm) = self.prev_target_norm {
            // rescale the output by the previous target norm to restore its original scale
            output = &output * prev_target_norm;
        }
        Ok(output)
    }
}

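// Single-sample training step: the input and target are L2-normalized (with the target's
// norm recorded in `prev_target_norm`), a forward pass collects per-layer activations,
// the mean squared error against the target is taken as the loss, and the error is then
// backpropagated through the output, hidden, and input layers with each delta normalized
// to keep the gradients bounded.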
impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for SurfaceModel<A>
where
    A: FftNum + Float + FromPrimitive + NumAssign + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
            fields(learning_rate = ?self.learning_rate(), input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<A, cnc::nn::NeuralError> {
        if input.len() != self.features().input() {
            return Err(cnc::nn::NeuralError::InvalidInputShape.into());
        }
        if target.len() != self.features().output() {
            return Err(cnc::Error::ShapeError(ndarray::ShapeError::from_kind(
                ndarray::ErrorKind::IncompatibleShape,
            ))
            .into());
        }
        // get the learning rate from the model's configuration
        let lr = *self.learning_rate();
        // Normalize the input and target
        let input = input / input.l2_norm();
        let target_norm = target.l2_norm();
        let target = target / target_norm;
        self.prev_target_norm = Some(target_norm);
        // Forward pass to collect activations
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        let mut output = self.input().forward(&input)?.relu();
        activations.push(output.to_owned());
        // collect the activations of the hidden layers
        for layer in self.hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        // if initialized, apply the attention mechanism
        if let Some(ref attention) = self.attention {
            output = attention.forward(&output)?;
        }

        output = self.output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

        // Calculate output layer error
        let error = target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        let mut delta = error * output.sigmoid_derivative();
        delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients

        // Update output weights
        self.output_mut()
            .backward(activations.last().unwrap(), &delta, lr)?;

        let num_hidden = self.features().layers();
        // Iterate through hidden layers in reverse order
        for i in (0..num_hidden).rev() {
            // Calculate error for this layer
            delta = if i == num_hidden - 1 {
                // use the output activations for the final hidden layer
                self.output().weights().dot(&delta) * activations[i + 1].relu_derivative()
            } else {
                // otherwise, backpropagate the delta through the following hidden layer's weights
                self.hidden()[i + 1].weights().t().dot(&delta)
                    * activations[i + 1].relu_derivative()
            };
            // Normalize delta to prevent exploding gradients
            delta /= delta.l2_norm();
            self.hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        /*
            Backpropagate to the input layer.
            The delta for the input layer is computed from the weights of the first hidden
            layer and the ReLU derivative evaluated at the input layer's activations.

            Dimensionally: (h, h).dot((h,)) * relu'((h,)) -> (h,), where h is the number of
            features within a hidden layer.
        */
        delta = self.hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients
        self.input_mut().backward(&activations[1], &delta, lr)?;

        Ok(loss)
    }
}

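// Batched training: after validating the batch and feature dimensions, each row of the
// input/target pair is trained individually via the 1-D `Train` impl above and the
// per-sample losses are summed into the value returned for the whole batch.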
impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for SurfaceModel<A>
where
    A: FftNum + Float + FromPrimitive + NumAssign + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip_all,
            level = "trace",
            name = "train",
            target = "model",
            fields(input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        target: &ArrayBase<T, Ix2>,
    ) -> Result<A, cnc::nn::NeuralError> {
        if input.nrows() == 0 || target.nrows() == 0 {
            return Err(cnc::nn::NeuralError::InvalidBatchSize);
        }
        if input.ncols() != self.features().input() {
            return Err(cnc::nn::NeuralError::InvalidInputShape);
        }
        if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
            return Err(cnc::Error::ShapeError(ndarray::ShapeError::from_kind(
                ndarray::ErrorKind::IncompatibleShape,
            ))
            .into());
        }
        let _batch_size = input.nrows();
        let mut loss = A::zero();

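        // walk the batch row-by-row; if any sample fails, report which one and abort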
        for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
            loss += match self.train(&x, &e) {
                Ok(l) => l,
                Err(err) => {
                    let bi = i + 1;
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "Training failed for sample {s}/{b}: {err:?}",
                        s = bi,
                        b = _batch_size
                    );
                    #[cfg(not(feature = "tracing"))]
                    eprintln!(
                        "Training failed for sample {s}/{b}: {err:?}",
                        s = bi,
                        b = _batch_size
                    );
                    return Err(err);
                }
            };
        }

        Ok(loss)
    }
}