concision_transformer/model.rs

/*
    Appellation: transformer <library>
    Contrib: @FL03
*/

use cnc::nn::{DeepModelParams, Model, ModelError, ModelFeatures, StandardModelConfig, Train};
#[cfg(feature = "rand")]
use cnc::rand_distr;
use cnc::{Forward, Norm, Params, ReLU, Sigmoid};

use ndarray::prelude::*;
use ndarray::{Data, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign};

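/// A deep, fully-connected model: an input layer, a configurable stack of hidden
/// layers with ReLU activations, and a sigmoid-activated output layer.
///
/// The model is described by three pieces of state:
///
/// - [`StandardModelConfig`]: training hyperparameters such as the learning rate,
/// - [`ModelFeatures`]: the layer dimensions used to shape the parameters,
/// - [`DeepModelParams`]: the weights and biases of every layer.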
#[derive(Clone, Debug)]
pub struct TransformerModel<T = f64> {
    pub config: StandardModelConfig<T>,
    pub features: ModelFeatures,
    pub params: DeepModelParams<T>,
}

impl<T> TransformerModel<T> {
    pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
    where
        T: Clone + Default,
    {
        let params = DeepModelParams::default(features);
        TransformerModel {
            config,
            features,
            params,
        }
    }
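    /// consumes the model and re-initializes its parameters by sampling from a
    /// Glorot (Xavier) normal distribution; only available with the `rand` feature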
    #[cfg(feature = "rand")]
    pub fn init(self) -> Self
    where
        T: Float + FromPrimitive,
        rand_distr::StandardNormal: rand_distr::Distribution<T>,
    {
        let params = DeepModelParams::glorot_normal(self.features());
        TransformerModel { params, ..self }
    }
    /// returns a reference to the model configuration
    pub const fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }
    /// returns a mutable reference to the model configuration
    pub const fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }
    /// returns the model features
    pub const fn features(&self) -> ModelFeatures {
        self.features
    }
    /// returns a mutable reference to the model features
    pub const fn features_mut(&mut self) -> &mut ModelFeatures {
        &mut self.features
    }
    /// returns a reference to the model parameters
    pub const fn params(&self) -> &DeepModelParams<T> {
        &self.params
    }
    /// returns a mutable reference to the model parameters
    pub const fn params_mut(&mut self) -> &mut DeepModelParams<T> {
        &mut self.params
    }
    /// set the current configuration and return a mutable reference to the model
    pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
        self.config = config;
        self
    }
    /// set the current features and return a mutable reference to the model
    pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
        self.features = features;
        self
    }
    /// set the current parameters and return a mutable reference to the model
    pub fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
        self.params = params;
        self
    }
    /// consumes the current instance to create another with the given configuration
    pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
        Self { config, ..self }
    }
    /// consumes the current instance to create another with the given features
    pub fn with_features(self, features: ModelFeatures) -> Self {
        Self { features, ..self }
    }
    /// consumes the current instance to create another with the given parameters
    pub fn with_params(self, params: DeepModelParams<T>) -> Self {
        Self { params, ..self }
    }
}

impl<T> Model<T> for TransformerModel<T> {
    type Config = StandardModelConfig<T>;
    type Layout = ModelFeatures;

    fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }

    fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }

    fn layout(&self) -> ModelFeatures {
        self.features
    }

    fn params(&self) -> &DeepModelParams<T> {
        &self.params
    }

    fn params_mut(&mut self) -> &mut DeepModelParams<T> {
        &mut self.params
    }
}

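// The forward pass is generic over the input type `U` and an intermediate value
// type `V`: each layer must map a `U` or a `V` to a `V` via an affine step
// (`x.dot(W) + b`, per the `Dot`/`Add` bounds below); hidden layers apply ReLU
// and the output layer applies a sigmoid.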
impl<A, U, V> Forward<U> for TransformerModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    V: ReLU<Output = V> + Sigmoid<Output = V>,
    Params<A>: Forward<U, Output = V> + Forward<V, Output = V>,
    for<'a> &'a U: ndarray::linalg::Dot<Array2<A>, Output = V> + core::ops::Add<&'a Array1<A>>,
    V: for<'a> core::ops::Add<&'a Array1<A>, Output = V>,
{
    type Output = V;

    fn forward(&self, input: &U) -> cnc::Result<Self::Output> {
        // input layer followed by a ReLU activation
        let mut output = self.params().input().forward_then(input, |y| y.relu())?;
        // propagate through each hidden layer, applying ReLU after every affine step
        for layer in self.params().hidden() {
            output = layer.forward_then(&output, |y| y.relu())?;
        }
        // sigmoid-activated output layer
        let y = self
            .params()
            .output()
            .forward_then(&output, |y| y.sigmoid())?;
        Ok(y)
    }
}

impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for TransformerModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<Self::Output, ModelError> {
        if input.len() != self.features().input() {
            return Err(ModelError::InvalidInputShape);
        }
        if target.len() != self.features().output() {
            return Err(ModelError::InvalidOutputShape);
        }
        // get the learning rate from the model's configuration, falling back to 0.01
        let lr = self
            .config()
            .learning_rate()
            .copied()
            .unwrap_or(A::from_f32(0.01).unwrap());
        // normalize the input and target to unit L2 norm
        let input = input / input.l2_norm();
        let target_norm = target.l2_norm();
        let target = target / target_norm;
        // self.prev_target_norm = Some(target_norm);
        // forward pass, recording the activation produced by every layer
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        let mut output = self.params().input().forward(&input)?.relu();
        activations.push(output.to_owned());
        // collect the activations of the hidden layers
        for layer in self.params().hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        output = self.params().output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

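        // Output layer error: the loss is the mean squared error between the
        // normalized target and the prediction, and the delta pushed back through
        // the network is
        //     delta_out = (target - output) ⊙ sigmoid'(output)
        // (elementwise), which is paired with the activation that fed the output
        // layer when updating the output weights.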
        // Calculate the output layer error
        let error = &target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        let mut delta = error * output.sigmoid_derivative();
        delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients

        // Update the output weights using the activation that fed the output layer
        // (the last hidden activation, i.e. the second-to-last entry)
        self.params_mut()
            .output_mut()
            .backward(&activations[activations.len() - 2], &delta, lr)?;

        let num_hidden = self.features().layers();
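        // Backpropagate through the hidden stack. For hidden layer i the error is
        //     delta_i = W_{i+1} · delta_{i+1} ⊙ relu'(a_i)
        // where W_{i+1} belongs to the layer that consumed a_i (the output layer
        // when i is the last hidden layer) and a_i is this layer's own activation.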
        // Iterate through the hidden layers in reverse order
        for i in (0..num_hidden).rev() {
            // Calculate the error for this layer
            delta = if i == num_hidden - 1 {
                // the final hidden layer receives its error through the output weights
                self.params().output().weights().dot(&delta) * activations[i + 2].relu_derivative()
            } else {
                // otherwise, backpropagate through the next hidden layer's weights
                self.params().hidden()[i + 1].weights().dot(&delta)
                    * activations[i + 2].relu_derivative()
            };
            // Normalize delta to prevent exploding gradients
            delta /= delta.l2_norm();
            // pair the delta with this layer's input activation to update its weights
            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        /*
            Backpropagate to the input layer
            The delta for the input layer is computed using the weights of the first hidden layer
            and the derivative of the input layer's own activation.

            (h, h).dot(h) * derivative(h) = dim(h) where h is the number of features within a hidden layer
        */
        delta = self.params().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients
        // the input layer's weight update pairs its delta with the (normalized) raw input
        self.params_mut()
            .input_mut()
            .backward(&activations[0], &delta, lr)?;

        Ok(loss)
    }
}

impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for TransformerModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "train",
            target = "model",
            fields(input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        target: &ArrayBase<T, Ix2>,
    ) -> Result<Self::Output, ModelError> {
        if input.nrows() == 0 || target.nrows() == 0 {
            return Err(ModelError::InvalidBatchSize);
        }
        if input.ncols() != self.features().input() {
            return Err(ModelError::InvalidInputShape);
        }
        if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
            return Err(ModelError::InvalidOutputShape);
        }
        let mut loss = A::zero();

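        // Each row of `input`/`target` is treated as an independent sample: rows
        // are trained one at a time with the Ix1 implementation above, per-sample
        // losses are summed (not averaged), and the first failure aborts the batch.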
        for (_i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
            loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
                Ok(l) => l,
                Err(err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "Training failed for sample {}/{}: {:?}",
                        _i + 1,
                        input.nrows(),
                        err
                    );
                    return Err(err);
                }
            };
        }

        Ok(loss)
    }
}
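
// --- Usage sketch (illustrative only) ---
// A minimal sketch of how the model defined above might be driven end-to-end.
// How `ModelFeatures` and `StandardModelConfig` are constructed is left elided,
// since their constructors are not part of this file:
//
//     let features: ModelFeatures = /* layer dimensions: input, hidden, output */;
//     let config: StandardModelConfig<f64> = /* hyperparameters, incl. the learning rate */;
//     let mut model = TransformerModel::new(config, features).init(); // `init` requires the "rand" feature
//     let loss = model.train(&x, &y)?; // x: Array1<f64>, y: Array1<f64>
//     let y_hat = model.forward(&x)?;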