1use cnc::nn::{DeepModelParams, Model, ModelFeatures, NeuralError, StandardModelConfig, Train};
6use cnc::{Forward, Norm, Params, ReLU, Sigmoid};
7
8use ndarray::prelude::*;
9use ndarray::{Data, ScalarOperand};
10use num_traits::{Float, FromPrimitive, NumAssign};
11
/// A minimal feed-forward model bundling a standard configuration, the
/// feature layout (layer dimensions), and the deep (multi-layer) parameters.
#[derive(Clone, Debug)]
pub struct SimpleModel<T = f64> {
    // Hyperparameters; `learning_rate()` is read during training.
    pub config: StandardModelConfig<T>,
    // Layer layout; provides `input()`, `output()`, and `layers()` counts.
    pub features: ModelFeatures,
    // Per-layer weights/biases: input, hidden stack, and output layers.
    pub params: DeepModelParams<T>,
}
18
19impl<T> SimpleModel<T> {
20 pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
21 where
22 T: Clone + Default,
23 {
24 let params = DeepModelParams::default(features);
25 SimpleModel {
26 config,
27 features,
28 params,
29 }
30 }
31 pub const fn config(&self) -> &StandardModelConfig<T> {
33 &self.config
34 }
35 pub const fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
37 &mut self.config
38 }
39 pub const fn features(&self) -> ModelFeatures {
41 self.features
42 }
43 pub const fn features_mut(&mut self) -> &mut ModelFeatures {
45 &mut self.features
46 }
47 pub const fn params(&self) -> &DeepModelParams<T> {
49 &self.params
50 }
51 pub const fn params_mut(&mut self) -> &mut DeepModelParams<T> {
53 &mut self.params
54 }
55 pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
57 self.config = config;
58 self
59 }
60 pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
62 self.features = features;
63 self
64 }
65 pub fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
67 self.params = params;
68 self
69 }
70 pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
72 Self { config, ..self }
73 }
74 pub fn with_features(self, features: ModelFeatures) -> Self {
76 Self { features, ..self }
77 }
78 pub fn with_params(self, params: DeepModelParams<T>) -> Self {
80 Self { params, ..self }
81 }
82 #[cfg(feature = "rand")]
84 pub fn init(self) -> Self
85 where
86 T: Float + FromPrimitive,
87 cnc::rand_distr::StandardNormal: cnc::rand_distr::Distribution<T>,
88 {
89 let params = DeepModelParams::glorot_normal(self.features());
90 SimpleModel { params, ..self }
91 }
92}
93
/// `Model` trait wiring: each method delegates straight to the corresponding
/// struct field, exposing config, layout, and params to generic machinery.
impl<T> Model<T> for SimpleModel<T> {
    type Config = StandardModelConfig<T>;
    type Layout = ModelFeatures;

    // Borrow the hyperparameter configuration.
    fn config(&self) -> &StandardModelConfig<T> {
        &self.config
    }

    // Mutably borrow the hyperparameter configuration.
    fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
        &mut self.config
    }

    // The feature layout, returned by value.
    fn layout(&self) -> ModelFeatures {
        self.features
    }

    // Borrow the layer parameters.
    fn params(&self) -> &DeepModelParams<T> {
        &self.params
    }

    // Mutably borrow the layer parameters.
    fn params_mut(&mut self) -> &mut DeepModelParams<T> {
        &mut self.params
    }
}
118
119impl<A, S, D> Forward<ArrayBase<S, D>> for SimpleModel<A>
120where
121 A: Float + FromPrimitive + ScalarOperand,
122 D: Dimension,
123 S: Data<Elem = A>,
124 Params<A>: Forward<Array<A, D>, Output = Array<A, D>>,
125{
126 type Output = Array<A, D>;
127
128 fn forward(&self, input: &ArrayBase<S, D>) -> cnc::Result<Self::Output> {
129 let mut output = self
130 .params()
131 .input()
132 .forward_then(&input.to_owned(), |y| y.relu())?;
133
134 for layer in self.params().hidden() {
135 output = layer.forward_then(&output, |y| y.relu())?;
136 }
137
138 let y = self
139 .params()
140 .output()
141 .forward_then(&output, |y| y.sigmoid())?;
142 Ok(y)
143 }
144}
145
/// Single-sample training: forward pass with cached activations, MSE loss,
/// then manual backpropagation with L2-normalized deltas at every layer.
impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for SimpleModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<Self::Output, NeuralError> {
        // Shape guards: sample and target must match the declared layout.
        if input.len() != self.features().input() {
            return Err(NeuralError::InvalidInputShape);
        }
        if target.len() != self.features().output() {
            return Err(NeuralError::InvalidOutputShape);
        }
        // Learning rate from config, defaulting to 0.01 when unset.
        let lr = self
            .config()
            .learning_rate()
            .copied()
            .unwrap_or(A::from_f32(0.01).unwrap());
        // NOTE(review): both normalizations divide by an L2 norm with no
        // zero guard — an all-zero input or target yields NaNs; confirm
        // callers never pass zero vectors, or add a guard.
        let input = input / input.l2_norm();
        let target_norm = target.l2_norm();
        let target = target / target_norm;
        // Cache activations for backprop. Indexing after the pushes below:
        // [0] raw (normalized) input, [1] input-layer output,
        // [1+k] output of hidden layer k-1, [last] network output.
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        // Forward pass, recording each post-activation output.
        let mut output = self.params().input().forward(&input)?.relu();
        activations.push(output.to_owned());
        for layer in self.params().hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        output = self.params().output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

        // Mean-squared-error loss on the normalized target.
        let error = &target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        // Output-layer delta: error scaled by the sigmoid derivative,
        // then L2-normalized (gradient direction only, not magnitude).
        let mut delta = error * output.sigmoid_derivative();
        // NOTE(review): `activations.last()` is the network's OWN output;
        // the output layer's input is `activations[len - 2]`. Verify the
        // `backward(input, delta, lr)` contract — this looks off by one.
        delta /= delta.l2_norm(); self.params_mut()
            .output_mut()
            .backward(activations.last().unwrap(), &delta, lr)?;

        // Walk the hidden stack backwards, propagating delta layer by layer.
        let num_hidden = self.features().layers();
        for i in (0..num_hidden).rev() {
            // NOTE(review): transposition is inconsistent — the output and
            // hidden[0] branches use `weights().dot(delta)` while this
            // `.t().dot(delta)` transposes. Square hidden layers make both
            // compile, but at most one orientation can be mathematically
            // correct; confirm the weight-matrix layout used by `forward`.
            delta = if i == num_hidden - 1 {
                self.params().output().weights().dot(&delta) * activations[i + 1].relu_derivative()
            } else {
                self.params().hidden()[i + 1].weights().t().dot(&delta)
                    * activations[i + 1].relu_derivative()
            };
            delta /= delta.l2_norm();
            // hidden[i]'s input activation is activations[i + 1].
            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        // Propagate to the input layer.
        // NOTE(review): the input layer's input is the raw sample
        // `activations[0]`, yet `activations[1]` (its own output) is passed
        // to `backward` — presumably another off-by-one; confirm.
        delta = self.params().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm(); self.params_mut()
            .input_mut()
            .backward(&activations[1], &delta, lr)?;

        Ok(loss)
    }
}
245
246impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for SimpleModel<A>
247where
248 A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
249 S: Data<Elem = A>,
250 T: Data<Elem = A>,
251{
252 type Output = A;
253
254 #[cfg_attr(
255 feature = "tracing",
256 tracing::instrument(
257 skip(self, input, target),
258 level = "trace",
259 name = "train",
260 target = "model",
261 fields(input_shape = ?input.shape(), target_shape = ?target.shape())
262 )
263 )]
264 fn train(
265 &mut self,
266 input: &ArrayBase<S, Ix2>,
267 target: &ArrayBase<T, Ix2>,
268 ) -> Result<Self::Output, NeuralError> {
269 if input.nrows() == 0 || target.nrows() == 0 {
270 return Err(NeuralError::InvalidBatchSize);
271 }
272 if input.ncols() != self.features().input() {
273 return Err(NeuralError::InvalidInputShape);
274 }
275 if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
276 return Err(NeuralError::InvalidOutputShape);
277 }
278 let mut loss = A::zero();
279
280 for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
281 loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
282 Ok(l) => l,
283 Err(err) => {
284 #[cfg(feature = "tracing")]
285 tracing::error!(
286 "Training failed for batch {}/{}: {:?}",
287 i + 1,
288 input.nrows(),
289 err
290 );
291 #[cfg(not(feature = "tracing"))]
292 eprintln!(
293 "Training failed for batch {}/{}: {:?}",
294 i + 1,
295 input.nrows(),
296 err
297 );
298 return Err(err);
299 }
300 };
301 }
302
303 Ok(loss)
304 }
305}