use cnc::nn::{DeepModelParams, Model, ModelError, ModelFeatures, StandardModelConfig, Train};
#[cfg(feature = "rand")]
use cnc::rand_distr;
use cnc::{Forward, Norm, Params, ReLU, Sigmoid};

use ndarray::prelude::*;
use ndarray::{Data, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign};
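/// A deep model described by a [`StandardModelConfig`], a [`ModelFeatures`] layout,
/// and the [`DeepModelParams`] that hold its weights.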
#[derive(Clone, Debug)]
pub struct TransformerModel<T = f64> {
    pub config: StandardModelConfig<T>,
    pub features: ModelFeatures,
    pub params: DeepModelParams<T>,
}

impl<T> TransformerModel<T> {
    pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
    where
        T: Clone + Default,
    {
        let params = DeepModelParams::default(features);
        TransformerModel {
            config,
            features,
            params,
        }
    }

    #[cfg(feature = "rand")]
    pub fn init(self) -> Self
    where
        T: Float + FromPrimitive,
        rand_distr::StandardNormal: rand_distr::Distribution<T>,
    {
        let params = DeepModelParams::glorot_normal(self.features());
        TransformerModel { params, ..self }
    }

    pub const fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }

    pub const fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }

    pub const fn features(&self) -> ModelFeatures {
        self.features
    }

    pub const fn features_mut(&mut self) -> &mut ModelFeatures {
        &mut self.features
    }

    pub const fn params(&self) -> &DeepModelParams<T> {
        &self.params
    }

    pub const fn params_mut(&mut self) -> &mut DeepModelParams<T> {
        &mut self.params
    }

    pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
        self.config = config;
        self
    }

    pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
        self.features = features;
        self
    }

    pub fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
        self.params = params;
        self
    }

    pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
        Self { config, ..self }
    }

    pub fn with_features(self, features: ModelFeatures) -> Self {
        Self { features, ..self }
    }

    pub fn with_params(self, params: DeepModelParams<T>) -> Self {
        Self { params, ..self }
    }
}

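// Expose the configuration, feature layout, and parameters through the `Model` trait.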
impl<T> Model<T> for TransformerModel<T> {
    type Config = StandardModelConfig<T>;
    type Layout = ModelFeatures;

    fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }

    fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }

    fn layout(&self) -> ModelFeatures {
        self.features
    }

    fn params(&self) -> &DeepModelParams<T> {
        &self.params
    }

    fn params_mut(&mut self) -> &mut DeepModelParams<T> {
        &mut self.params
    }
}

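// Generic forward pass: `Params<A>` maps the caller's input type `U` into an
// intermediate type `V`, applying ReLU after the input and hidden layers and a
// sigmoid at the output, mirroring the training routine below.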
impl<A, U, V> Forward<U> for TransformerModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
    V: ReLU<Output = V> + Sigmoid<Output = V>,
    Params<A>: Forward<U, Output = V> + Forward<V, Output = V>,
    for<'a> &'a U: ndarray::linalg::Dot<Array2<A>, Output = V> + core::ops::Add<&'a Array1<A>>,
    V: for<'a> core::ops::Add<&'a Array1<A>, Output = V>,
{
    type Output = V;

    fn forward(&self, input: &U) -> cnc::Result<Self::Output> {
        let mut output = self.params().input().forward_then(&input, |y| y.relu())?;

        for layer in self.params().hidden() {
            output = layer.forward_then(&output, |y| y.relu())?;
        }

        let y = self
            .params()
            .output()
            .forward_then(&output, |y| y.sigmoid())?;
        Ok(y)
    }
}

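// Single-sample training: validate shapes, l2-normalize the input and target,
// run a forward pass while caching each activation, compute a mean-squared
// error, and back-propagate a re-normalized delta through every layer.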
impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for TransformerModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<Self::Output, ModelError> {
        // Validate the sample against the configured layout before touching the params.
        if input.len() != self.features().input() {
            return Err(ModelError::InvalidInputShape);
        }
        if target.len() != self.features().output() {
            return Err(ModelError::InvalidOutputShape);
        }
        // Fall back to a fixed step size when the config does not provide one.
        let lr = self
            .config()
            .learning_rate()
            .copied()
            .unwrap_or(A::from_f32(0.01).unwrap());
        // Work on l2-normalized copies of the sample and its target.
        let input = input / input.l2_norm();
        let target_norm = target.l2_norm();
        let target = target / target_norm;
        // Cache every layer's activation for the backward sweep below.
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        let mut output = self.params().input().forward(&input)?.relu();
        activations.push(output.to_owned());
        for layer in self.params().hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        output = self.params().output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

        // Mean-squared error on the normalized target.
        let error = &target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        // Output-layer delta, re-normalized before the update is applied.
        let mut delta = error * output.sigmoid_derivative();
        delta /= delta.l2_norm();
        self.params_mut()
            .output_mut()
            .backward(activations.last().unwrap(), &delta, lr)?;

        // Propagate the delta back through the hidden stack.
        let num_hidden = self.features().layers();
        for i in (0..num_hidden).rev() {
            delta = if i == num_hidden - 1 {
                self.params().output().weights().dot(&delta) * activations[i + 1].relu_derivative()
            } else {
                self.params().hidden()[i + 1].weights().t().dot(&delta)
                    * activations[i + 1].relu_derivative()
            };
            delta /= delta.l2_norm();
            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        // Finally, push the delta through the first hidden layer's weights and
        // update the input layer.
        delta = self.params().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm();
        self.params_mut()
            .input_mut()
            .backward(&activations[1], &delta, lr)?;

        Ok(loss)
    }
}

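// Batch training over 2-D inputs: validate the batch shapes, then train on each
// row with the single-sample implementation above, summing the returned losses.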
impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for TransformerModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "train",
            target = "model",
            fields(input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        target: &ArrayBase<T, Ix2>,
    ) -> Result<Self::Output, ModelError> {
        if input.nrows() == 0 || target.nrows() == 0 {
            return Err(ModelError::InvalidBatchSize);
        }
        if input.ncols() != self.features().input() {
            return Err(ModelError::InvalidInputShape);
        }
        if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
            return Err(ModelError::InvalidOutputShape);
        }
        let mut loss = A::zero();

        for (_i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
            loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
                Ok(l) => l,
                Err(err) => {
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "Training failed for batch {}/{}: {:?}",
                        _i + 1,
                        input.nrows(),
                        err
                    );
                    return Err(err);
                }
            };
        }

        Ok(loss)
    }
}
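
// A minimal usage sketch (illustrative only, not part of the original source).
// The `ModelFeatures::new` and `StandardModelConfig::default` constructors below
// are hypothetical placeholders for whatever builders the crate actually exposes:
//
//     let features = ModelFeatures::new(4, 2, 8, 1);      // hypothetical constructor
//     let config = StandardModelConfig::<f64>::default(); // hypothetical constructor
//     let mut model = TransformerModel::new(config, features).init(); // `init` needs the `rand` feature
//     let x = ndarray::array![0.1f64, 0.2, 0.3, 0.4];
//     let y = ndarray::array![1.0f64];
//     let loss = model.train(&x, &y)?;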