// eryon_surface/network/impl_network.rs

1/*
2    Appellation: impl_network <module>
3    Contrib: @FL03
4*/
5use crate::model::SurfaceModel;
6use crate::network::SurfaceNetwork;
7use crate::points::{Point, PointKind};
8use cnc::nn::{NeuralError, Predict, PredictWithConfidence, Train};
9use cnc::params::Params;
10use cnc::prelude::{Forward, Init, InitInplace};
11use ndarray::ScalarOperand;
12use num_traits::{Float, FromPrimitive, Num, NumAssign, Signed};
13use rstmt::nrt::{LPR, Triad};
14
impl<A> SurfaceNetwork<A>
where
    A: FromPrimitive + ScalarOperand,
{
    /// Add a new critical point of the given [`PointKind`], derived from the current
    /// headspace triad, and grow the model's hidden-layer stack so that every critical
    /// point has a corresponding layer (with a minimum of 3 layers).
    ///
    /// Newly created layers are zero-initialized. Returns `&mut Self` for chaining.
    pub fn add_critical_point(&mut self, kind: PointKind) -> &mut Self
    where
        A: Float,
    {
        // Create a new critical point from the current triad
        let critical_point = Point::from_triad(kind, &self.headspace);

        // Add to critical points
        self.points.push(critical_point);

        // Ensure model has enough hidden layers for all critical points
        let current_layers = self.model.features().layers();
        // floor of 3 layers regardless of how few points exist
        let required_layers = self.points.len().max(3);

        if current_layers < required_layers {
            // Update model features
            self.model.features_mut().set_layers(required_layers);

            // Add additional layers as needed
            let layer_dims = self.model.features().dim_hidden();
            let hidden = self.model.hidden_mut();

            // Initialize new layers - use existing or zero initialization
            while hidden.len() < required_layers {
                hidden.push(Params::zeros(layer_dims));
            }
        }

        self
    }
    #[doc(hidden)]
    #[deprecated(since = "0.0.2", note = "Use `train` instead")]
    /// Perform a single backpropagation step on qualifying datasets.
    /// Deprecated alias for [`train`](Self::train); delegates to the inner model.
    pub fn backward<X, Y, Z>(&mut self, input: &X, expected: &Y) -> Result<Z, NeuralError>
    where
        A: Float + FromPrimitive + NumAssign + ScalarOperand + rustfft::FftNum,
        SurfaceModel<A>: Train<X, Y, Output = Z>,
    {
        self.model_mut().train(input, expected)
    }
    #[doc(hidden)]
    #[deprecated(since = "0.0.2", note = "Use `predict` instead")]
    /// Forward-propagate the given input through the network.
    /// Deprecated alias for [`predict`](Self::predict); delegates to the inner model.
    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
    where
        SurfaceModel<A>: cnc::Forward<X, Output = Y>,
    {
        self.model.forward(input)
    }
    /// Forward-propagate the given input through the network by delegating to the
    /// inner model's [`Predict`] implementation.
    pub fn predict<X, Y>(&self, input: &X) -> Result<Y, NeuralError>
    where
        SurfaceModel<A>: Predict<X, Output = Y>,
    {
        Predict::predict(&self.model, input)
    }
    /// Perform a single backpropagation step on qualifying datasets by delegating to
    /// the inner model's [`Train`] implementation.
    pub fn train<X, Y, Z>(&mut self, input: &X, expected: &Y) -> Result<Z, NeuralError>
    where
        A: Float + FromPrimitive + NumAssign + ScalarOperand + rustfft::FftNum,
        SurfaceModel<A>: Train<X, Y, Output = Z>,
    {
        self.model_mut().train(input, expected)
    }
    /// Get the position of the centroid w.r.t. the headspace.
    /// Returns `None` when the headspace cannot produce a centroid.
    pub fn get_centroid(&self) -> Option<[A; 2]>
    where
        A: Float + FromPrimitive,
    {
        self.headspace().centroid()
    }
    /// Returns the position of a particular critical point w.r.t. the current headspace,
    /// or `None` if `cp_index` is out of bounds.
    ///
    /// Each of the three headspace notes is octave-shifted by the point's modulus
    /// (12 semitones per modulus unit) and weighted by the point's coordinate.
    // NOTE(review): the `from_usize`/`from_isize` unwraps assume every note value and
    // the shifted modulus are representable in `A` — TODO confirm for small `A` types.
    pub fn get_position(&self, cp_index: usize) -> Option<[A; 3]>
    where
        A: Copy + Num,
    {
        if cp_index >= self.points.len() {
            return None;
        }

        let cp = &self.points[cp_index];
        let notes = self.headspace.notes();
        let mut result = [A::zero(); 3];

        // Weight each note by its barycentric coordinate
        for i in 0..3 {
            let note_val = A::from_usize(notes[i]).unwrap();
            // Apply modulus to properly position in tonal space
            let octave_shift =
                A::from_isize(cp.modulus().into_inner()).unwrap() * A::from_isize(12).unwrap();
            result[i] = (note_val + octave_shift) * cp.coordinates[i];
        }

        Some(result)
    }
    /// Initialize the network, consuming and returning `self` with an initialized model.
    pub fn init(self) -> Self
    where
        SurfaceModel<A>: Init,
    {
        let model = self.model.init();
        Self { model, ..self }
    }
    /// Initialize the network in-place; returns `&mut Self` for chaining.
    pub fn init_inplace(&mut self) -> &mut Self
    where
        SurfaceModel<A>: InitInplace,
    {
        self.model_mut().init_inplace();
        self
    }
    /// Returns the model's prediction together with a confidence value, as defined by
    /// the inner model's [`PredictWithConfidence`] implementation.
    pub fn predict_with_confidence<X, Y>(
        &self,
        input: &X,
    ) -> Result<(Y, <SurfaceModel<A> as PredictWithConfidence<X>>::Confidence), NeuralError>
    where
        SurfaceModel<A>: PredictWithConfidence<X, Output = Y>,
    {
        self.model().predict_with_confidence(input)
    }
    /// Reconfigure the network with respect to the given triad; the method is typically
    /// called after applying a transformation to the headspace. Every existing critical
    /// point is rebuilt (same kind) against the new triad, then the headspace is replaced.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip_all, name = "reconfigure", target = "network", level = "trace")
    )]
    pub fn reconfigure(&mut self, next: Triad)
    where
        A: Float,
    {
        // Create new critical points for the new triad
        let new_critical_points = self
            .points
            .iter()
            .map(|cp| Point::from_triad(cp.kind, &next))
            .collect::<Vec<_>>();

        // update the network's critical points
        self.set_points(new_critical_points);
        // update the model's headspace
        self.headspace = next;
    }
    /// Transfer knowledge from another surface network.
    ///
    /// Critical points are matched between the two networks by kind and coordinate
    /// similarity (base score 0.5 for a matching kind, plus up to 0.5 from coordinate
    /// closeness). Pairs scoring above 0.6 have their input, output, and first hidden
    /// layer weights blended toward `other`'s weights, scaled by `adaptation_rate`
    /// times the similarity score. No-op when the point counts differ.
    // NOTE(review): the 1 - |diff| coordinate similarity assumes coordinates lie in
    // [0, 1] (barycentric) — TODO confirm; values outside that range could yield
    // negative per-axis similarity.
    pub fn transfer_knowledge_from(&mut self, other: &Self, adaptation_rate: A)
    where
        A: Copy + PartialOrd + Num + Signed,
    {
        // Only transfer if the networks have the same number of critical points
        if self.points.len() != other.points.len() {
            return;
        }

        // Map critical points between networks based on type
        let mut cp_mapping = Vec::new();

        for (i, cp) in self.points.iter().enumerate() {
            let mut best_match = 0;
            let mut best_score = A::zero();

            for (j, other_cp) in other.points.iter().enumerate() {
                // Match based on kind and coordinates
                if cp.kind() == other_cp.kind() {
                    let mut similarity = A::from_f32(0.5).unwrap(); // Base score for matching kind

                    // Add coordinate similarity
                    let coord_similarity = (0..3).fold(A::zero(), |acc, k| {
                        acc + (A::one() - (cp.coordinates[k] - other_cp.coordinates[k]).abs())
                    }) / A::from_usize(3).unwrap();

                    similarity = similarity + coord_similarity * A::from_f32(0.5).unwrap();

                    if similarity > best_score {
                        best_score = similarity;
                        best_match = j;
                    }
                }
            }

            // Unmatched points keep (i, 0, 0); they are filtered out by the 0.6
            // threshold below.
            cp_mapping.push((i, best_match, best_score));
        }

        // Transfer primary weights
        for (self_idx, other_idx, score) in &cp_mapping {
            if *score > A::from_f32(0.6).unwrap() {
                for j in 0..3 {
                    // Blend weights based on similarity score and adaptation rate
                    self.model.input_mut().weights_mut()[[*self_idx, j]] = self
                        .model
                        .input()
                        .weights()[[*self_idx, j]]
                        * (A::one() - adaptation_rate * *score)
                        + other.model.input().weights()[[*other_idx, j]] * adaptation_rate * *score;
                }
            }
        }

        // Transfer secondary weights
        for (self_idx, other_idx, score) in &cp_mapping {
            if *score > A::from_f32(0.6).unwrap() {
                // Blend weights
                self.model.output_mut().weights_mut()[[0, *self_idx]] = self
                    .model
                    .output()
                    .weights()[[0, *self_idx]]
                    * (A::one() - adaptation_rate * *score)
                    + other.model.output().weights()[[0, *other_idx]] * adaptation_rate * *score;
            }
        }

        // Transfer temporal context if both have surface layers.
        // Only the first hidden layer is blended; the combined score is the product
        // of the row and column point similarities.
        if !self.model.hidden().is_empty() && !other.model.hidden().is_empty() {
            for (self_i, other_i, score_i) in &cp_mapping {
                for (self_j, other_j, score_j) in &cp_mapping {
                    if *score_i > A::from_f32(0.6).unwrap() && *score_j > A::from_f32(0.6).unwrap()
                    {
                        // Blend temporal weights
                        self.model.hidden_mut()[0].weights_mut()[[*self_i, *self_j]] =
                            self.model.hidden()[0].weights()[[*self_i, *self_j]]
                                * (A::one() - adaptation_rate * *score_i * *score_j)
                                + other.model.hidden()[0].weights()[[*other_i, *other_j]]
                                    * adaptation_rate
                                    * *score_i
                                    * *score_j;
                    }
                }
            }
        }
    }
    /// Applies the given [LPR] transformation to the headspace of the network before
    /// applying the necessary adjustments to the architecture (via
    /// [`reconfigure`](Self::reconfigure)).
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip_all, name = "transform", target = "network", level = "trace")
    )]
    pub fn transform(&mut self, transform: LPR) -> crate::SurfaceResult<()>
    where
        A: Float,
    {
        // apply transformation to the triad
        let new_triad = self.headspace().transform(transform);
        // Update network's triad and critical points
        self.reconfigure(new_triad);

        Ok(())
    }
    /// Walk the network through a series of [LPR] transformations, applying each in
    /// iteration order and stopping at the first error.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip_all, name = "walk", target = "network", level = "trace")
    )]
    pub fn walk<I>(&mut self, iter: I) -> crate::SurfaceResult<()>
    where
        A: Float,
        I: IntoIterator<Item = LPR>,
    {
        for transform in iter {
            self.transform(transform)?;
        }
        Ok(())
    }
}
282
283impl<A, X, Y> Forward<X> for SurfaceNetwork<A>
284where
285    A: Float + FromPrimitive + NumAssign + ScalarOperand,
286    SurfaceModel<A>: Forward<X, Output = Y>,
287{
288    type Output = Y;
289
290    fn forward(&self, input: &X) -> cnc::Result<Self::Output> {
291        <SurfaceModel<A> as Forward<X>>::forward(self.model(), input)
292    }
293}
294
295impl<A, X, Y, Z> Train<X, Y> for SurfaceNetwork<A>
296where
297    A: Float + FromPrimitive + NumAssign + ScalarOperand,
298    SurfaceModel<A>: Train<X, Y, Output = Z>,
299{
300    type Output = Z;
301
302    fn train(&mut self, input: &X, expected: &Y) -> Result<Self::Output, cnc::nn::NeuralError> {
303        <SurfaceModel<A> as Train<X, Y>>::train(self.model_mut(), input, expected)
304    }
305}