// eryon_surface/model/impls/impl_model.rs
use crate::model::SurfaceModel;

use cnc::prelude::{Forward, Norm, Predict, ReLU, Sigmoid, Train};
use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
use num_traits::{Float, FromPrimitive, NumAssign};
use rustfft::FftNum;

impl<A> SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
{
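    /// Runs a forward pass through the model, dispatching on the input type
    /// via the [`Predict`] trait implementation for that type.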
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, input), target = "surface")
    )]
    pub fn predict<X>(&self, input: &X) -> cnc::nn::NeuralResult<<Self as Predict<X>>::Output>
    where
        Self: Predict<X>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Forwarding input through model");
        <Self as Predict<X>>::predict(self, input)
    }
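
    /// Trains the model on an input/target pair, dispatching on the argument
    /// types via the [`Train`] trait, and returns the resulting loss.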
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip(self, input, target), target = "surface")
    )]
    pub fn train<X, Y, Z>(&mut self, input: &X, target: &Y) -> cnc::nn::NeuralResult<Z>
    where
        Self: Train<X, Y, Output = Z>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Backpropagating through model");
        <Self as Train<X, Y>>::train(self, input, target)
    }
}
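
// Deprecated aliases retained for backward compatibility; hidden from the docs.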
#[doc(hidden)]
#[allow(deprecated)]
impl<A> SurfaceModel<A>
where
    A: Float + FromPrimitive + ScalarOperand,
{
    #[deprecated(since = "0.0.2", note = "Use `train` instead")]
    pub fn backward<X, Y, Z>(&mut self, input: &X, target: &Y) -> cnc::nn::NeuralResult<Z>
    where
        Self: Train<X, Y, Output = Z>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Backpropagating through model");
        <Self as Train<X, Y>>::train(self, input, target)
    }

    #[deprecated(since = "0.0.2", note = "Use `predict` instead")]
    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        Self: Forward<X, Output = Y>,
    {
        #[cfg(feature = "tracing")]
        tracing::trace!("Forwarding input through model");
        <Self as Forward<X>>::forward(self, input)
    }
}
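
// Single-sample forward pass: input and hidden layers with ReLU activations,
// an optional attention block, then the output layer with a sigmoid.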
impl<S, T> Forward<ArrayBase<S, Ix1>> for SurfaceModel<T>
where
    S: Data<Elem = T>,
    T: FftNum + Float + FromPrimitive + ScalarOperand,
{
    type Output = Array1<T>;

    fn forward(&self, input: &ArrayBase<S, Ix1>) -> cnc::Result<Self::Output> {
        if input.len() != self.features().input() {
            return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
        }
        let mut output = self.input().forward(input)?.relu();
        for layer in self.hidden() {
            output = layer.forward(&output)?.relu();
        }
        if let Some(ref attention) = self.attention {
            output = attention.forward(&output)?;
        }
        let mut output = self.output().forward_then(&output, |y| y.sigmoid())?;
        // Rescale the sigmoid output by the target norm cached during the last
        // training step, mirroring the batched (`Ix2`) implementation below.
        if let Some(prev_target_norm) = self.prev_target_norm {
            output = &output * prev_target_norm;
        }
        Ok(output)
    }
}
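
// Batched forward pass over a matrix whose rows are individual samples; the
// pipeline matches the single-sample implementation above.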
impl<S, T> Forward<ArrayBase<S, Ix2>> for SurfaceModel<T>
where
    S: Data<Elem = T>,
    T: FftNum + Float + FromPrimitive + ScalarOperand,
{
    type Output = Array2<T>;

    fn forward(&self, input: &ArrayBase<S, Ix2>) -> cnc::Result<Self::Output> {
        if input.ncols() != self.features().input() {
            return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
        }
        let mut output = self.input().forward(input)?.relu();
        for layer in self.hidden() {
            output = layer.forward(&output)?.relu();
        }
        if let Some(ref attention) = self.attention {
            output = attention.forward(&output)?;
        }
        output = self.output().forward(&output)?.sigmoid();

        if let Some(prev_target_norm) = self.prev_target_norm {
            output = &output * prev_target_norm;
        }
        Ok(output)
    }
}
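
// Single-sample training step: stochastic gradient descent with manual
// backpropagation. The input and target are L2-normalized, every layer input
// is cached for the backward pass, and the attention block (when present) is
// applied in the forward pass but not updated here.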
impl<A, S, T> Train<ArrayBase<S, Ix1>, ArrayBase<T, Ix1>> for SurfaceModel<A>
where
    A: FftNum + Float + FromPrimitive + NumAssign + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip(self, input, target),
            level = "trace",
            name = "backward",
            target = "model",
            fields(learning_rate = ?self.learning_rate(), input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix1>,
        target: &ArrayBase<T, Ix1>,
    ) -> Result<A, cnc::nn::NeuralError> {
        if input.len() != self.features().input() {
            return Err(cnc::nn::NeuralError::InvalidInputShape);
        }
        if target.len() != self.features().output() {
            return Err(cnc::Error::ShapeError(ndarray::ShapeError::from_kind(
                ndarray::ErrorKind::IncompatibleShape,
            ))
            .into());
        }
        let lr = *self.learning_rate();
        // Normalize both the input and the target; the target norm is cached
        // so that `forward` can rescale future predictions back into range.
        let input = input / input.l2_norm();
        let target_norm = target.l2_norm();
        let target = target / target_norm;
        self.prev_target_norm = Some(target_norm);
        // Cache each layer's input for the backward pass: activations[0] is
        // the normalized input, activations[i + 1] is the input to hidden
        // layer `i`, and the final entry is the model output.
        let mut activations = Vec::new();
        activations.push(input.to_owned());

        let mut output = self.input().forward(&input)?.relu();
        activations.push(output.to_owned());
        for layer in self.hidden() {
            output = layer.forward(&output)?.relu();
            activations.push(output.to_owned());
        }

        if let Some(ref attention) = self.attention {
            output = attention.forward(&output)?;
        }

        output = self.output().forward(&output)?.sigmoid();
        activations.push(output.to_owned());

        // Mean-squared error against the normalized target.
        let error = target - &output;
        let loss = error.pow2().mean().unwrap_or(A::zero());
        #[cfg(feature = "tracing")]
        tracing::trace!("Training loss: {loss:?}");
        // Output layer: the delta is taken at the sigmoid output, and each
        // layer's `backward` receives the activation that was fed into it,
        // i.e. the second-to-last cached entry rather than the output itself.
        let mut delta = error * output.sigmoid_derivative();
        delta /= delta.l2_norm();
        self.output_mut()
            .backward(&activations[activations.len() - 2], &delta, lr)?;

        // Walk the hidden stack in reverse. Weight matrices are assumed to be
        // stored (input, output)-shaped, so `weights().dot(&delta)` carries a
        // layer's output delta back onto its input; the ReLU derivative is
        // evaluated at that layer's own output, and each delta is renormalized
        // to keep the updates well-conditioned.
        let num_hidden = self.features().layers();
        for i in (0..num_hidden).rev() {
            delta = if i == num_hidden - 1 {
                self.output().weights().dot(&delta) * activations[i + 2].relu_derivative()
            } else {
                self.hidden()[i + 1].weights().dot(&delta)
                    * activations[i + 2].relu_derivative()
            };
            delta /= delta.l2_norm();
            self.hidden_mut()[i].backward(&activations[i + 1], &delta, lr)?;
        }
        // Finally the input layer: its delta lives at its output
        // (activations[1]), while its own input is the normalized sample.
        delta = self.hidden()[0].weights().dot(&delta) * activations[1].relu_derivative();
        delta /= delta.l2_norm();
        self.input_mut().backward(&activations[0], &delta, lr)?;

        Ok(loss)
    }
}
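
// Batched training: delegates to the single-sample implementation row by row
// and returns the accumulated (summed) loss over the batch.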
impl<A, S, T> Train<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for SurfaceModel<A>
where
    A: FftNum + Float + FromPrimitive + NumAssign + ScalarOperand,
    S: Data<Elem = A>,
    T: Data<Elem = A>,
{
    type Output = A;

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            skip_all,
            level = "trace",
            name = "train",
            target = "model",
            fields(input_shape = ?input.shape(), target_shape = ?target.shape())
        )
    )]
    fn train(
        &mut self,
        input: &ArrayBase<S, Ix2>,
        target: &ArrayBase<T, Ix2>,
    ) -> Result<A, cnc::nn::NeuralError> {
        if input.nrows() == 0 || target.nrows() == 0 {
            return Err(cnc::nn::NeuralError::InvalidBatchSize);
        }
        if input.ncols() != self.features().input() {
            return Err(cnc::nn::NeuralError::InvalidInputShape);
        }
        if target.ncols() != self.features().output() || target.nrows() != input.nrows() {
            return Err(cnc::Error::ShapeError(ndarray::ShapeError::from_kind(
                ndarray::ErrorKind::IncompatibleShape,
            ))
            .into());
        }
        let batch_size = input.nrows();
        let mut loss = A::zero();

        for (i, (x, e)) in input.rows().into_iter().zip(target.rows()).enumerate() {
            loss += match self.train(&x, &e) {
                Ok(l) => l,
                Err(err) => {
                    let sample = i + 1;
                    #[cfg(feature = "tracing")]
                    tracing::error!(
                        "Training failed for sample {s}/{n}: {err:?}",
                        s = sample,
                        n = batch_size
                    );
                    #[cfg(not(feature = "tracing"))]
                    eprintln!(
                        "Training failed for sample {s}/{n}: {err:?}",
                        s = sample,
                        n = batch_size
                    );
                    return Err(err);
                }
            };
        }

        Ok(loss)
    }
}