1use crate::model::SurfaceModel;
6use crate::network::SurfaceNetwork;
7use crate::points::{Point, PointKind};
8use cnc::nn::{NeuralError, Predict, PredictWithConfidence, Train};
9use cnc::params::Params;
10use cnc::prelude::{Forward, Init, InitInplace};
11use ndarray::ScalarOperand;
12use num_traits::{Float, FromPrimitive, Num, NumAssign, Signed};
13use rstmt::nrt::{LPR, Triad};
14
15impl<A> SurfaceNetwork<A>
16where
17 A: FromPrimitive + ScalarOperand,
18{
19 pub fn add_critical_point(&mut self, kind: PointKind) -> &mut Self
21 where
22 A: Float,
23 {
24 let critical_point = Point::from_triad(kind, &self.headspace);
26
27 self.points.push(critical_point);
29
30 let current_layers = self.model.features().layers();
32 let required_layers = self.points.len().max(3);
33
34 if current_layers < required_layers {
35 self.model.features_mut().set_layers(required_layers);
37
38 let layer_dims = self.model.features().dim_hidden();
40 let hidden = self.model.hidden_mut();
41
42 while hidden.len() < required_layers {
44 hidden.push(Params::zeros(layer_dims));
45 }
46 }
47
48 self
49 }
50 #[doc(hidden)]
51 #[deprecated(since = "0.0.2", note = "Use `train` instead")]
52 pub fn backward<X, Y, Z>(&mut self, input: &X, expected: &Y) -> Result<Z, NeuralError>
54 where
55 A: Float + FromPrimitive + NumAssign + ScalarOperand + rustfft::FftNum,
56 SurfaceModel<A>: Train<X, Y, Output = Z>,
57 {
58 self.model_mut().train(input, expected)
59 }
60 #[doc(hidden)]
61 #[deprecated(since = "0.0.2", note = "Use `predict` instead")]
62 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
64 where
65 SurfaceModel<A>: cnc::Forward<X, Output = Y>,
66 {
67 self.model.forward(input)
68 }
69 pub fn predict<X, Y>(&self, input: &X) -> Result<Y, NeuralError>
71 where
72 SurfaceModel<A>: Predict<X, Output = Y>,
73 {
74 Predict::predict(&self.model, input)
75 }
76 pub fn train<X, Y, Z>(&mut self, input: &X, expected: &Y) -> Result<Z, NeuralError>
78 where
79 A: Float + FromPrimitive + NumAssign + ScalarOperand + rustfft::FftNum,
80 SurfaceModel<A>: Train<X, Y, Output = Z>,
81 {
82 self.model_mut().train(input, expected)
83 }
84 pub fn get_centroid(&self) -> Option<[A; 2]>
86 where
87 A: Float + FromPrimitive,
88 {
89 self.headspace().centroid()
90 }
91 pub fn get_position(&self, cp_index: usize) -> Option<[A; 3]>
93 where
94 A: Copy + Num,
95 {
96 if cp_index >= self.points.len() {
97 return None;
98 }
99
100 let cp = &self.points[cp_index];
101 let notes = self.headspace.notes();
102 let mut result = [A::zero(); 3];
103
104 for i in 0..3 {
106 let note_val = A::from_usize(notes[i]).unwrap();
107 let octave_shift =
109 A::from_isize(cp.modulus().into_inner()).unwrap() * A::from_isize(12).unwrap();
110 result[i] = (note_val + octave_shift) * cp.coordinates[i];
111 }
112
113 Some(result)
114 }
115 pub fn init(self) -> Self
117 where
118 SurfaceModel<A>: Init,
119 {
120 let model = self.model.init();
121 Self { model, ..self }
122 }
123 pub fn init_inplace(&mut self) -> &mut Self
125 where
126 SurfaceModel<A>: InitInplace,
127 {
128 self.model_mut().init_inplace();
129 self
130 }
131 pub fn predict_with_confidence<X, Y>(
133 &self,
134 input: &X,
135 ) -> Result<(Y, <SurfaceModel<A> as PredictWithConfidence<X>>::Confidence), NeuralError>
136 where
137 SurfaceModel<A>: PredictWithConfidence<X, Output = Y>,
138 {
139 self.model().predict_with_confidence(input)
140 }
141 #[cfg_attr(
144 feature = "tracing",
145 tracing::instrument(skip_all, name = "reconfigure", target = "network", level = "trace")
146 )]
147 pub fn reconfigure(&mut self, next: Triad)
148 where
149 A: Float,
150 {
151 let new_critical_points = self
153 .points
154 .iter()
155 .map(|cp| Point::from_triad(cp.kind, &next))
156 .collect::<Vec<_>>();
157
158 self.set_points(new_critical_points);
160 self.headspace = next;
162 }
163 pub fn transfer_knowledge_from(&mut self, other: &Self, adaptation_rate: A)
165 where
166 A: Copy + PartialOrd + Num + Signed,
167 {
168 if self.points.len() != other.points.len() {
170 return;
171 }
172
173 let mut cp_mapping = Vec::new();
175
176 for (i, cp) in self.points.iter().enumerate() {
177 let mut best_match = 0;
178 let mut best_score = A::zero();
179
180 for (j, other_cp) in other.points.iter().enumerate() {
181 if cp.kind() == other_cp.kind() {
183 let mut similarity = A::from_f32(0.5).unwrap(); let coord_similarity = (0..3).fold(A::zero(), |acc, k| {
187 acc + (A::one() - (cp.coordinates[k] - other_cp.coordinates[k]).abs())
188 }) / A::from_usize(3).unwrap();
189
190 similarity = similarity + coord_similarity * A::from_f32(0.5).unwrap();
191
192 if similarity > best_score {
193 best_score = similarity;
194 best_match = j;
195 }
196 }
197 }
198
199 cp_mapping.push((i, best_match, best_score));
200 }
201
202 for (self_idx, other_idx, score) in &cp_mapping {
204 if *score > A::from_f32(0.6).unwrap() {
205 for j in 0..3 {
206 self.model.input_mut().weights_mut()[[*self_idx, j]] = self
208 .model
209 .input()
210 .weights()[[*self_idx, j]]
211 * (A::one() - adaptation_rate * *score)
212 + other.model.input().weights()[[*other_idx, j]] * adaptation_rate * *score;
213 }
214 }
215 }
216
217 for (self_idx, other_idx, score) in &cp_mapping {
219 if *score > A::from_f32(0.6).unwrap() {
220 self.model.output_mut().weights_mut()[[0, *self_idx]] = self
222 .model
223 .output()
224 .weights()[[0, *self_idx]]
225 * (A::one() - adaptation_rate * *score)
226 + other.model.output().weights()[[0, *other_idx]] * adaptation_rate * *score;
227 }
228 }
229
230 if !self.model.hidden().is_empty() && !other.model.hidden().is_empty() {
232 for (self_i, other_i, score_i) in &cp_mapping {
233 for (self_j, other_j, score_j) in &cp_mapping {
234 if *score_i > A::from_f32(0.6).unwrap() && *score_j > A::from_f32(0.6).unwrap()
235 {
236 self.model.hidden_mut()[0].weights_mut()[[*self_i, *self_j]] =
238 self.model.hidden()[0].weights()[[*self_i, *self_j]]
239 * (A::one() - adaptation_rate * *score_i * *score_j)
240 + other.model.hidden()[0].weights()[[*other_i, *other_j]]
241 * adaptation_rate
242 * *score_i
243 * *score_j;
244 }
245 }
246 }
247 }
248 }
249 #[cfg_attr(
252 feature = "tracing",
253 tracing::instrument(skip_all, name = "transform", target = "network", level = "trace")
254 )]
255 pub fn transform(&mut self, transform: LPR) -> crate::SurfaceResult<()>
256 where
257 A: Float,
258 {
259 let new_triad = self.headspace().transform(transform);
261 self.reconfigure(new_triad);
263
264 Ok(())
265 }
266 #[cfg_attr(
268 feature = "tracing",
269 tracing::instrument(skip_all, name = "walk", target = "network", level = "trace")
270 )]
271 pub fn walk<I>(&mut self, iter: I) -> crate::SurfaceResult<()>
272 where
273 A: Float,
274 I: IntoIterator<Item = LPR>,
275 {
276 for transform in iter {
277 self.transform(transform)?;
278 }
279 Ok(())
280 }
281}
282
283impl<A, X, Y> Forward<X> for SurfaceNetwork<A>
284where
285 A: Float + FromPrimitive + NumAssign + ScalarOperand,
286 SurfaceModel<A>: Forward<X, Output = Y>,
287{
288 type Output = Y;
289
290 fn forward(&self, input: &X) -> cnc::Result<Self::Output> {
291 <SurfaceModel<A> as Forward<X>>::forward(self.model(), input)
292 }
293}
294
295impl<A, X, Y, Z> Train<X, Y> for SurfaceNetwork<A>
296where
297 A: Float + FromPrimitive + NumAssign + ScalarOperand,
298 SurfaceModel<A>: Train<X, Y, Output = Z>,
299{
300 type Output = Z;
301
302 fn train(&mut self, input: &X, expected: &Y) -> Result<Self::Output, cnc::nn::NeuralError> {
303 <SurfaceModel<A> as Train<X, Y>>::train(self.model_mut(), input, expected)
304 }
305}