1use cnc::nn::{Model, ModelFeatures, ModelParams, NeuralError, StandardModelConfig, Train};
6use cnc::{Forward, Norm, Params, ReLU, Sigmoid};
7
8use ndarray::prelude::*;
9use ndarray::{Data, ScalarOperand};
10use num_traits::{Float, FromPrimitive, NumAssign};
11
/// A simple feed-forward model: an input layer, zero or more hidden layers,
/// and an output layer (see the `Forward` and `Train` impls below).
pub struct SimpleModel<T = f64> {
    /// Hyperparameters for training (e.g. the learning rate read in `train`).
    pub config: StandardModelConfig<T>,
    /// Layer dimensions of the network (input/hidden/output sizes).
    pub features: ModelFeatures,
    /// Learnable parameters (weights/biases) for every layer.
    pub params: ModelParams<T>,
}
17
18impl<T> SimpleModel<T> {
19 pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
20 where
21 T: Clone + Default,
22 {
23 let params = ModelParams::default(features);
24 SimpleModel {
25 config,
26 features,
27 params,
28 }
29 }
30 #[cfg(feature = "rand")]
31 pub fn init(self) -> Self
32 where
33 T: Float + FromPrimitive,
34 cnc::init::rand_distr::StandardNormal: cnc::init::rand_distr::Distribution<T>,
35 {
36 let params = ModelParams::glorot_normal(self.features);
37 SimpleModel { params, ..self }
38 }
39
40 pub const fn config(&self) -> &StandardModelConfig<T> {
41 &self.config
42 }
43
44 pub fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
45 &mut self.config
46 }
47
48 pub const fn features(&self) -> ModelFeatures {
49 self.features
50 }
51
52 pub const fn params(&self) -> &ModelParams<T> {
53 &self.params
54 }
55
56 pub fn params_mut(&mut self) -> &mut ModelParams<T> {
57 &mut self.params
58 }
59}
60
61impl<T> Model<T> for SimpleModel<T> {
62 type Config = StandardModelConfig<T>;
63 type Layout = ModelFeatures;
64
65 fn config(&self) -> &StandardModelConfig<T> {
66 &self.config
67 }
68
69 fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
70 &mut self.config
71 }
72
73 fn layout(&self) -> ModelFeatures {
74 self.features
75 }
76
77 fn params(&self) -> &ModelParams<T> {
78 &self.params
79 }
80
81 fn params_mut(&mut self) -> &mut ModelParams<T> {
82 &mut self.params
83 }
84}
85
86impl<A, S, D> Forward<ArrayBase<S, D>> for SimpleModel<A>
87where
88 A: Float + FromPrimitive + ScalarOperand,
89 D: Dimension,
90 S: Data<Elem = A>,
91 Params<A>: Forward<Array<A, D>, Output = Array<A, D>>,
92{
93 type Output = Array<A, D>;
94
95 fn forward(&self, input: &ArrayBase<S, D>) -> cnc::Result<Self::Output> {
96 let mut output = self
97 .params()
98 .input()
99 .forward_then(&input.to_owned(), |y| y.relu())?;
100
101 for layer in self.params().hidden() {
102 output = layer.forward_then(&output, |y| y.relu())?;
103 }
104
105 let y = self
106 .params()
107 .output()
108 .forward_then(&output, |y| y.sigmoid())?;
109 Ok(y)
110 }
111}
112
impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for SimpleModel<A>
where
    A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    /// Single-sample training step: forward pass with cached activations,
    /// followed by a manual backpropagation sweep. Returns the MSE loss for
    /// this sample (computed against the L2-normalized target).
    ///
    /// # Errors
    /// `InvalidInputShape` / `InvalidOutputShape` when the sample or target
    /// length does not match the configured layout.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<Self::Output, NeuralError> {
        // Validate shapes before mutating any parameters.
        if input.len() != self.features().input() {
            return Err(NeuralError::InvalidInputShape);
        }
        if target.len() != self.features().output() {
            return Err(NeuralError::InvalidOutputShape);
        }
        // Learning rate from the config, defaulting to 0.01.
        let lr = self
            .config()
            .learning_rate()
            .copied()
            .unwrap_or(A::from_f32(0.01).unwrap());
        // Both sample and target are L2-normalized before training.
        // NOTE(review): an all-zero input or target gives l2_norm() == 0 and
        // turns everything below into NaNs — confirm callers never pass zero
        // vectors, or guard against it.
        let input = input / input.l2_norm();
        let target_norm = target.l2_norm();
        let target = target / target_norm;
        // Forward pass, caching every activation for the backward sweep:
        // activations[0] = normalized input, activations[1] = input-layer
        // output, then one entry per hidden layer, and finally the sigmoid
        // output of the output layer.
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        let mut output = self.params().input().forward(&input)?.relu();
        activations.push(output.to_owned());
        for layer in self.params().hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        output = self.params().output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

        // Mean-squared-error loss against the normalized target.
        let error = &target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        // Output-layer delta: error scaled by the sigmoid derivative, then
        // renormalized to unit length (a crude form of gradient clipping,
        // applied again after every layer below).
        let mut delta = error * output.sigmoid_derivative();
        // NOTE(review): `activations.last()` is the output layer's OWN
        // post-sigmoid output, while the hidden-layer updates below pass each
        // layer its *input* activation — confirm which one `backward`
        // actually expects; this looks inconsistent.
        delta /= delta.l2_norm(); self.params_mut()
            .output_mut()
            .backward(activations.last().unwrap(), &delta, lr)?;

        // Backpropagate through the hidden layers, last to first.
        // NOTE(review): if the layout has zero hidden layers, the
        // unconditional `hidden()[0]` access after this loop panics — verify
        // layouts always contain at least one hidden layer.
        let num_hidden = self.features().layers();
        for i in (0..num_hidden).rev() {
            // NOTE(review): the output-layer branch propagates with
            // `weights().dot(..)` (no transpose) while the hidden-layer
            // branch uses `weights().t().dot(..)`, and the input-layer
            // propagation below is again untransposed. Confirm the weight
            // storage convention makes this asymmetry intentional.
            delta = if i == num_hidden - 1 {
                self.params().output().weights().dot(&delta) * activations[i + 1].relu_derivative()
            } else {
                self.params().hidden()[i + 1].weights().t().dot(&delta)
                    * activations[i + 1].relu_derivative()
            };
            delta /= delta.l2_norm();
            // activations[i + 1] is the input that hidden layer i received.
            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        // Finally propagate through hidden layer 0 into the input layer.
        delta = self.params().hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm(); self.params_mut()
            .input_mut()
            .backward(&activations[1], &delta, lr)?;

        Ok(loss)
    }
}
212
213impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for SimpleModel<A>
214where
215 A: Float + FromPrimitive + NumAssign + ScalarOperand + core::fmt::Debug,
216 S: Data<Elem = A>,
217 T: Data<Elem = A>,
218{
219 type Output = A;
220
221 #[cfg_attr(
222 feature = "tracing",
223 tracing::instrument(
224 skip(self, input, target),
225 level = "trace",
226 name = "train",
227 target = "model",
228 fields(input_shape = ?input.shape(), target_shape = ?target.shape())
229 )
230 )]
231 fn train(
232 &mut self,
233 input: &ArrayBase<S, Ix2>,
234 target: &ArrayBase<T, Ix2>,
235 ) -> Result<Self::Output, NeuralError> {
236 if input.nrows() == 0 || target.nrows() == 0 {
237 return Err(NeuralError::InvalidBatchSize);
238 }
239 if input.ncols() != self.features().input() {
240 return Err(NeuralError::InvalidInputShape);
241 }
242 if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
243 return Err(NeuralError::InvalidOutputShape);
244 }
245 let mut loss = A::zero();
246
247 for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
248 loss += match Train::<ArrayView1<A>, ArrayView1<A>>::train(self, &x, &e) {
249 Ok(l) => l,
250 Err(err) => {
251 #[cfg(feature = "tracing")]
252 tracing::error!(
253 "Training failed for batch {}/{}: {:?}",
254 i + 1,
255 input.nrows(),
256 err
257 );
258 #[cfg(not(feature = "tracing"))]
259 eprintln!(
260 "Training failed for batch {}/{}: {:?}",
261 i + 1,
262 input.nrows(),
263 err
264 );
265 return Err(err);
266 }
267 };
268 }
269
270 Ok(loss)
271 }
272}