fekan/
kan_layer.rs

1// #![allow(dead_code)]
2pub(crate) mod edge;
3
4use crate::layer_errors::LayerError;
5use edge::{linspace, Edge};
6use log::{debug, trace};
7use rand::distributions::Distribution;
8use rand::thread_rng;
9use serde::{Deserialize, Serialize};
10use statrs::distribution::Normal; // apparently the statrs distributions use the rand Distribution trait
11
12use std::{
13    collections::VecDeque,
14    sync::{Arc, Mutex},
15    thread::{self, ScopedJoinHandle},
16    vec,
17};
18
/// A layer in a Kolmogorov-Arnold neural Network (KAN)
///
/// A KAN layer consists of a number of nodes equal to the output dimension of the layer.
/// Each node has a number of incoming edges equal to the input dimension of the layer, and each edge holds a B-spline that operates on the value travelling down the edge
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct KanLayer {
    // Storage layout: `splines` is input-node-major. Edge `i` reads from input `i / output_dimension`
    // and feeds output node `i % output_dimension`, so output element `k` is the sum of the outputs of
    // splines `k`, `output_dimension + k`, `2 * output_dimension + k`, ... (`input_dimension` terms).
    /// the splines in this layer. The first `output_dimension` splines read from the first "input node", the second `output_dimension` splines read from the second "input node", etc.
    ///
    /// ```text
    ///    /-1--X             /4---X
    ///  O - 2            O  /
    ///    \  \-X           /    /-X
    ///  O  \             O ---5-
    ///      3--X           \--6---X
    /// ```
    pub(crate) splines: Vec<Edge>,
    /// length of each preactivation vector this layer accepts
    input_dimension: usize,
    /// number of nodes in this layer, i.e. the length of each activation vector it produces
    output_dimension: usize,
    /// a vector of previous inputs to the layer, used to update the knot vectors for each incoming edge.
    ///
    /// dim0 = number of samples
    ///
    /// dim1 = input_dimension
    #[serde(skip)] // part of the layer's operating state, not part of the model
    samples: Vec<Vec<f64>>,

    /// mean of the edges' L1 norms, recorded at the end of each successful forward pass; `None` until then
    #[serde(skip)] // part of the layer's operating state, not part of the model
    layer_l1: Option<f64>,
}
49
/// Hyperparameters for a [`KanLayer`]
///
/// # Examples
/// see [`KanLayer::new`]
#[derive(Debug, Copy, Clone)]
pub struct KanLayerOptions {
    /// length of each preactivation vector the layer accepts (the number of "input nodes")
    pub input_dimension: usize,
    /// number of nodes in the layer, i.e. the length of each activation vector it produces
    pub output_dimension: usize,
    /// degree of the B-spline held by every edge
    pub degree: usize,
    /// number of control points (coefficients) in each edge's B-spline.
    ///
    /// If you plan to call [`KanLayer::update_knots_from_samples`], keep `coef_size >= 2 * degree + 1`.
    pub coef_size: usize,
}
62
63impl KanLayer {
64    /// create a new layer with `output_dimension` nodes in this layer that each expect an `input_dimension`-long preactivation vector.
65    ///
66    /// All incoming edges will be created with a degree `degree` B-spline and `coef_size` control points.
67    ///
68    /// All B-splines areinitialized with coefficients drawn from astandard normal distribution, and with
69    /// `degree + coef_size + 1` knots evenly spaced between -1.0 and 1.0. Because knots are always initialized to span the range [-1, 1], make sure you call [`KanLayer::update_knots_from_samples`] regularly during training, or at least after a good portion of the training data has been passed through the model, to ensure that the layer's supported input range covers the range spanned by the training data.
70    /// # Warning
71    /// If you plan on ever calling [`KanLayer::update_knots_from_samples`] on your layer, make sure coef_size >= 2 * degree + 1. [`KanLayer::update_knots_from_samples`] reserves the first and last `degree` knots as "padding", and you will get NaNs when you call [`KanLayer::forward`] after updating knots if there aren't enough "non-padding" knots
72    ///
73    /// If you don't plan on calling [`KanLayer::update_knots_from_samples`], any coef_size >= degree + 1 should be fine
74    /// # Examples
75    /// ```
76    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
77    /// let input_dimension = 3;
78    /// let output_dimension = 4;
79    /// let layer_options = KanLayerOptions {
80    ///     input_dimension,
81    ///     output_dimension,
82    ///     degree: 3,
83    ///     coef_size: 6,
84    /// };
85    /// let my_layer = KanLayer::new(&layer_options);
86    /// assert_eq!(my_layer.total_edges(), output_dimension * input_dimension);
87    /// ```
88    pub fn new(options: &KanLayerOptions) -> Self {
89        let num_edges = options.input_dimension * options.output_dimension;
90        let num_knots = options.coef_size + options.degree + 1;
91        let normal_dist = Normal::new(0.0, 0.1).expect("unable to create normal distribution");
92        let mut randomness = thread_rng();
93        let splines = (0..num_edges)
94            .map(|_| {
95                let coefficients: Vec<f64> = (0..options.coef_size)
96                    .map(|_| normal_dist.sample(&mut randomness) as f64)
97                    .collect();
98                Edge::new(options.degree, coefficients, linspace(-1.0, 1.0, num_knots))
99                    .expect("spline creation error")
100            })
101            .collect();
102
103        KanLayer {
104            splines,
105            input_dimension: options.input_dimension,
106            output_dimension: options.output_dimension,
107            samples: Vec::new(),
108            layer_l1: None,
109        }
110    }
111
112    // pub fn len(&self) -> usize {
113    //     self.nodes.len()
114    // }
115
116    // pub fn total_edges(&self) -> usize {
117    //     self.nodes.len() * self.nodes[0].0.len()
118    // }
119
120    /// calculate the activations of the nodes in this layer given the preactivations.
121    /// This operation mutates internal state, which will be read in [`KanLayer::backward()`] and [`KanLayer::update_knots_from_samples()`]
122    ///
123    /// each vector in `preactivations` should be of length `input_dimension`, and each vector in the output will have length `output_dimension`
124    /// # Errors
125    /// Returns an [`LayerError`] if
126    /// * the length of any `preactivation` in `preactivations` is not equal to the input_dimension this layer
127    /// * the output would contain NaNs.
128    ///
129    /// See [`LayerError`] for more information
130    ///
131    /// # Examples
132    /// ```
133    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
134    /// let input_dimension = 3;
135    /// let output_dimension = 4;
136    /// let layer_options = KanLayerOptions {
137    ///     input_dimension,
138    ///     output_dimension,
139    ///     degree: 5,
140    ///     coef_size: 6,
141    /// };
142    /// let mut my_layer = KanLayer::new(&layer_options);
143    /// let preacts = vec![vec![0.0; input_dimension], vec![0.5; input_dimension]];
144    /// let acts = my_layer.forward(preacts)?;
145    /// assert_eq!(acts.len(), 2);
146    /// assert_eq!(acts[0].len(), output_dimension);
147    /// # Ok::<(), fekan::layer_errors::LayerError>(())
148    /// ```
149    pub fn forward(&mut self, preactivations: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, LayerError> {
150        let num_inputs = preactivations.len(); // grab this value, since we're about to move the preactivations into the internal cache
151        self.forward_preamble(preactivations)?;
152        // preactivations dim0 = number of samples, dim1 = input_dimension
153
154        // clone the last `num_inputs` preactivations from the internal cache (since we just moved them in), then transpose them so that dim0 = input_dimension, dim1 = number of samples
155        let mut transposed_preacts = vec![vec![0.0; num_inputs]; self.input_dimension];
156        for i in 0..num_inputs {
157            for j in 0..self.input_dimension {
158                transposed_preacts[j][i] = self.samples[i][j];
159            }
160        }
161        // go output-node-by-output-node so we can sum as we go and reduce memory usage
162        // dim0 = num_inputs, dim1 = output_dimension
163        let mut activations = vec![vec![0.0; self.output_dimension]; num_inputs];
164
165        // not the cleanest implementation maybe, but it'll work
166        for edge_index in 0..self.splines.len() {
167            trace!("Calculating activations for edge {}", edge_index);
168            let in_node_idx = edge_index / self.output_dimension;
169            let out_node_idx = edge_index % self.output_dimension;
170            let sample_wise_outputs =
171                self.splines[edge_index].forward(&transposed_preacts[in_node_idx]);
172            if sample_wise_outputs.iter().any(|v| v.is_nan()) {
173                return Err(LayerError::nans_in_activations(
174                    edge_index,
175                    sample_wise_outputs,
176                    self.splines[edge_index].clone(),
177                ));
178            }
179            for sample_idx in 0..num_inputs {
180                activations[sample_idx][out_node_idx] += sample_wise_outputs[sample_idx];
181            }
182        }
183        self.layer_l1 = Some(
184            self.splines
185                .iter()
186                .map(|s| {
187                    s.l1_norm()
188                        .expect("edges should have L1 norm stored after forward pass")
189                })
190                .sum::<f64>()
191                / self.splines.len() as f64,
192        );
193        trace!("Activations: {:?}", activations);
194        Ok(activations)
195    }
196
197    /// As [`KanLayer::forward`], but divides the work among the passed number of threads
198    pub fn forward_multithreaded(
199        &mut self,
200        preactivations: Vec<Vec<f64>>,
201        num_threads: usize,
202    ) -> Result<Vec<Vec<f64>>, LayerError> {
203        let num_samples = preactivations.len(); // grab this value, since we're about to move the preactivations into the internal cache
204        self.forward_preamble(preactivations)?;
205        // clone the last `num_inputs` preactivations from the internal cache (since we just moved them in), then transpose them so that dim0 = input_dimension, dim1 = number of samples
206        let mut transposed_preacts = vec![vec![0.0; num_samples]; self.input_dimension];
207        for i in 0..num_samples {
208            for j in 0..self.input_dimension {
209                transposed_preacts[j][i] = self.samples[i][j];
210            }
211        }
212        // dim0 = num_samples, dim1 = output_dimension
213        let activations = Arc::new(Mutex::new(vec![
214            vec![0.0; self.output_dimension];
215            num_samples
216        ]));
217
218        /* spawn new threads, give those threads ownership of chunks of splines (so they don't have to worry about locking spline chaches),
219         * a reference to the transposed preacts, and the activations vector behnd a mutex.
220         *
221         * The threads will need to return ownership of the edges; as long as we join the threads in the same order we spawned them,
222         * we should be able to reassemble the splines in the correct order without issue
223         */
224        let edges_per_thread = (self.splines.len() as f64 / num_threads as f64).ceil() as usize;
225        let output_dimension = self.output_dimension;
226        let threaded_result: Result<(), LayerError> = thread::scope(|s| {
227            // since the threads will be adding to the activations vector themselves, they don't need to return anything but the edges they took ownership of
228            let handles: Vec<ScopedJoinHandle<Result<Vec<Edge>, LayerError>>> = (0..num_threads)
229                .map(|thread_idx| {
230                    let edges_for_thread: Vec<Edge> = self
231                        .splines
232                        .drain(0..edges_per_thread.min(self.splines.len()))
233                        .collect();
234                    let transposed_preacts = &transposed_preacts;
235                    let threaded_activations = Arc::clone(&activations);
236                    let thread_result = s.spawn(move || {
237                        let mut thread_edges = edges_for_thread; // just to explicitly move the edges into the thread
238                        for edge_index in 0..thread_edges.len() {
239                            let this_edge: &mut Edge = &mut thread_edges[edge_index];
240                            let true_edge_index = thread_idx * edges_per_thread + edge_index;
241                            let in_node_idx = true_edge_index / output_dimension;
242                            let out_node_idx = true_edge_index % output_dimension;
243                            let sample_wise_outputs =
244                                this_edge.forward(&transposed_preacts[in_node_idx]);
245
246                            if sample_wise_outputs.iter().any(|v| v.is_nan()) {
247                                return Err(LayerError::nans_in_activations(
248                                    edge_index,
249                                    sample_wise_outputs,
250                                    this_edge.clone(),
251                                ));
252                            }
253
254                            let mut mutex_acquired_activations =
255                                threaded_activations.lock().unwrap();
256                            for sample_idx in 0..num_samples {
257                                let sample_activations =
258                                    &mut mutex_acquired_activations[sample_idx];
259                                let out_node = &mut sample_activations[out_node_idx];
260                                let edge_output = sample_wise_outputs[sample_idx];
261                                *out_node += edge_output;
262                            }
263                        }
264                        Ok(thread_edges) // give the edges back once we're done
265                    });
266                    thread_result
267                })
268                .collect();
269            // reassmble the splines in the correct order
270            for handle in handles {
271                let thread_result = handle.join().unwrap()?;
272                self.splines.extend(thread_result);
273            }
274            Ok(())
275        });
276        threaded_result?;
277        self.layer_l1 = Some(
278            self.splines
279                .iter()
280                .map(|s| {
281                    s.l1_norm()
282                        .expect("edges should have L1 norm stored after forward pass")
283                })
284                .sum::<f64>()
285                / self.splines.len() as f64,
286        );
287        let result = activations.lock().unwrap().clone();
288        Ok(result)
289    }
290
291    /// check the length of the preactivation vector and save the inputs to the internal cache
292    fn forward_preamble(&mut self, preactivations: Vec<Vec<f64>>) -> Result<(), LayerError> {
293        for preact in preactivations.iter() {
294            if preact.len() != self.input_dimension {
295                return Err(LayerError::missized_preacts(
296                    preact.len(),
297                    self.input_dimension,
298                ));
299            }
300        }
301        // the preactivations to the internal cache. We'll work from this cache during forward and backward passes
302        let mut preactivations = preactivations;
303        self.samples.append(&mut preactivations);
304        Ok(())
305    }
306
307    /// as [KanLayer::forward], but does not accumulate any internal state
308    ///
309    /// This method should be used when the model is not being trained, for example during inference or validation: when you won't be backpropogating, this method is faster uses less memory than [`KanLayer::forward`]
310    ///
311    /// # Errors
312    /// Returns a [`LayerError`] if...
313    /// * the length of `preactivation` is not equal to the input_dimension this layer
314    /// * the output would contain NaNs.
315
316    pub fn infer(&self, preactivations: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, LayerError> {
317        for preactivation in preactivations.iter() {
318            if preactivation.len() != self.input_dimension {
319                return Err(LayerError::missized_preacts(
320                    preactivation.len(),
321                    self.input_dimension,
322                ));
323            }
324        }
325        let num_inputs = preactivations.len();
326        let mut transposed_preacts = vec![vec![0.0; num_inputs]; self.input_dimension];
327        for i in 0..num_inputs {
328            for j in 0..self.input_dimension {
329                transposed_preacts[j][i] = preactivations[i][j];
330            }
331        }
332        let mut activations = vec![vec![0.0; self.output_dimension]; num_inputs];
333        for edge_index in 0..self.splines.len() {
334            let in_node_idx = edge_index / self.output_dimension;
335            let out_node_idx = edge_index % self.output_dimension;
336            let sample_wise_outputs =
337                self.splines[edge_index].infer(&transposed_preacts[in_node_idx]);
338            for sample_idx in 0..num_inputs {
339                activations[sample_idx][out_node_idx] += sample_wise_outputs[sample_idx];
340            }
341        }
342
343        Ok(activations)
344    }
345
346    /// as [`KanLayer::infer`], but divides the work among the passed number of threads
347    pub fn infer_multithreaded(
348        &self,
349        preactivations: &[Vec<f64>],
350        num_threads: usize,
351    ) -> Result<Vec<Vec<f64>>, LayerError> {
352        for preactivation in preactivations.iter() {
353            if preactivation.len() != self.input_dimension {
354                return Err(LayerError::missized_preacts(
355                    preactivation.len(),
356                    self.input_dimension,
357                ));
358            }
359        }
360        let num_inputs = preactivations.len();
361        let mut transposed_preacts = vec![vec![0.0; num_inputs]; self.input_dimension];
362        for i in 0..num_inputs {
363            for j in 0..self.input_dimension {
364                transposed_preacts[j][i] = preactivations[i][j];
365            }
366        }
367
368        let activations = Arc::new(Mutex::new(vec![
369            vec![0.0; self.output_dimension];
370            num_inputs
371        ]));
372
373        let edges_per_thread = (self.splines.len() as f64 / num_threads as f64).ceil() as usize;
374        let output_dimension = self.output_dimension;
375        let num_samples = preactivations.len();
376        let threaded_result: Result<(), LayerError> = thread::scope(|s| {
377            // since the threads will be adding to the activations vector themselves, they don't need to return anything but the edges they took ownership of
378            let handles: Vec<ScopedJoinHandle<Result<(), LayerError>>> = (0..num_threads)
379                .map(|thread_idx| {
380                    let edges_for_thread: &[Edge] = &self.splines
381                        [thread_idx * edges_per_thread..(thread_idx + 1) * edges_per_thread];
382                    let transposed_preacts = &transposed_preacts;
383                    let threaded_activations = Arc::clone(&activations);
384                    let thread_result = s.spawn(move || {
385                        let thread_edges = edges_for_thread; // just to explicitly move the edges into the thread
386                        for edge_index in 0..thread_edges.len() {
387                            let this_edge: &Edge = &thread_edges[edge_index];
388                            let true_edge_index = thread_idx * edges_per_thread + edge_index;
389                            let in_node_idx = true_edge_index / output_dimension;
390                            let out_node_idx = true_edge_index % output_dimension;
391                            let sample_wise_outputs =
392                                this_edge.infer(&transposed_preacts[in_node_idx]);
393
394                            if sample_wise_outputs.iter().any(|v| v.is_nan()) {
395                                return Err(LayerError::nans_in_activations(
396                                    edge_index,
397                                    sample_wise_outputs,
398                                    this_edge.clone(),
399                                ));
400                            }
401
402                            let mut mutex_acquired_activations =
403                                threaded_activations.lock().unwrap();
404                            for sample_idx in 0..num_samples {
405                                let sample_activations =
406                                    &mut mutex_acquired_activations[sample_idx];
407                                let out_node = &mut sample_activations[out_node_idx];
408                                let edge_output = sample_wise_outputs[sample_idx];
409                                *out_node += edge_output;
410                            }
411                        }
412                        Ok(()) // give the edges back once we're done
413                    });
414                    thread_result
415                })
416                .collect();
417            // reassmble the splines in the correct order
418            for handle in handles {
419                handle.join().unwrap()?;
420            }
421            Ok(())
422        });
423        threaded_result?;
424
425        let result = activations.lock().unwrap().clone();
426        Ok(result)
427    }
428    /// Using samples memoized by [`KanLayer::forward`], update the knot vectors for each incoming edge in this layer.
429    ///
430    /// When `knot_adaptivity` is 0, the new knot vectors will be uniformly distributed over the range spanned by the samples;
431    /// when `knot_adaptivity` is 1, the new knots will be placed at the quantiles of the samples. 0 < `knot_adaptivity` < 1 will interpolate between these two extremes.
432    ///
433    /// ## Warning
434    /// calling this function with `knot_adaptivity = 1` can result in a large number of knots being placed at the same value, which can cause [`KanLayer::forward`] to output NaNs. In practice, `knot_adaptivity` should be set to something like 0.1, but anything < 1.0 should be fine
435    ///
436    /// calling this function with fewer samples than the number of knots in a spline AND `knot_adaptivity` > 0 results in undefined behavior
437    ///
438    /// # Errors
439    /// Returns an error if the layer has no memoized samples, which most likely means that [`KanLayer::forward`] has not been called since initialization or the last call to [`KanLayer::clear_samples`]
440    ///
441    /// # Examples
442    ///
443    /// Update the knots of a model every few samples during training, to make sure that the supported input range of a given layer covers the output range of the previous layer.
444    /// ```
445    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
446    ///
447    /// # let input_size = 5;
448    /// # let output_size = 3;
449    /// # let layer_options = KanLayerOptions {input_dimension: 5,output_dimension: 4,degree: 3, coef_size: 6};
450    /// # fn calculate_gradient(output: Vec<Vec<f64>>, label: Vec<f64>) -> Vec<Vec<f64>> {vec![vec![0.0; output[0].len()]; output.len()]}
451    /// # let training_data = vec![(vec![0.1, 0.2, 0.3, 0.4, 0.5], 1.0f64), (vec![0.2, 0.3, 0.4, 0.5, 0.6], 0.0f64), (vec![0.3, 0.4, 0.5, 0.6, 0.7], 1.0f64)];
452    /// let mut my_layer = KanLayer::new(&layer_options);
453    /// # let batch_size = 1;
454    /// for batch_data in training_data.chunks(batch_size) {
455    ///     let batch_features = batch_data.iter().map(|(f, _)| f.clone()).collect::<Vec<Vec<f64>>>();
456    ///     let batch_output: Vec<Vec<f64>> = my_layer.forward(batch_features)?;
457    ///     let batch_labels: Vec<f64> = batch_data.iter().map(|(_, l)| *l).collect();
458    ///     let batch_gradients: Vec<Vec<f64>> = calculate_gradient(batch_output, batch_labels);
459    ///     let _ = my_layer.backward(&batch_gradients)?;
460    ///     my_layer.update(0.1, 1.0, 1.0); // updating the model's parameters changes the output range of the b-splines that make up the model
461    ///     my_layer.update_knots_from_samples(0.1)?; // updating the knots adjusts the input range of the b-splines to match the output range of the previous layer
462    /// }
463    /// # Ok::<(), fekan::layer_errors::LayerError>(())
464    ///```
465    /// Note on the above example: even in this example, where range of input to the layer is the range of values in the training data and does not change during training, it's still important to update the knots at least once, after a good portion of the training data has been passed through, to ensure that the layer's supported input range covers the range spanned by the training data.
466    /// KanLayer knots are initialized to span the range [-1, 1], so if the training data is outside that range, the activations will be 0.0 until the knots are updated.
467    ///
468    ///
469    /// The below example shows why regularly updating the knots is important - especially early in training, before the model starts to converge when its parameters are changing rapidly
470    /// ```
471    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
472    /// # let some_layer_options = KanLayerOptions {input_dimension: 2,output_dimension: 4,degree: 5, coef_size: 6};
473    /// let mut my_layer = KanLayer::new(&some_layer_options);
474    /// let sample1 = vec![vec![100f64, -100f64]];
475    /// let sample2 = vec![vec![-100f64, 100f64]];
476    ///
477    /// let acts = my_layer.forward(sample1.clone()).unwrap();
478    /// assert!(acts[0].iter().all(|x| *x == 0.0)); // the preacts were all outside the initial knot range, so the activations should all be 0
479    /// let acts = my_layer.forward(sample2).unwrap();
480    /// assert!(acts[0].iter().all(|x| *x == 0.0)); // the preacts were all outside the initial knot range, so the activations should all be 0
481    /// my_layer.update_knots_from_samples(0.0).unwrap(); // we don't have enough samples to calculate quantiles, so we have to keep the knots uniformly distributed. In practice, this function should be called every few hundred forward passes or so
482    /// let new_acts = my_layer.forward(sample1).unwrap();
483    /// assert!(new_acts[0].iter().all(|x| *x != 0.0)); // the knot range now covers the samples, so the activations should be non-zero
484    /// # Ok::<(), fekan::layer_errors::LayerError>(())
485    /// ```
486    pub fn update_knots_from_samples(&mut self, knot_adaptivity: f64) -> Result<(), LayerError> {
487        trace!("Updating knots from {} samples", self.samples.len()); // trace since this happens every batch
488        if self.samples.is_empty() {
489            return Err(LayerError::no_samples());
490        }
491        // lets construct a sorted vector of the samples for each incoming value
492        // first we transpose the samples, so that dim0 = input_dimension, dim1 = number of samples
493        let mut sorted_samples: Vec<Vec<f64>> =
494            vec![Vec::with_capacity(self.samples.len()); self.input_dimension];
495        for i in 0..self.samples.len() {
496            for j in 0..self.input_dimension {
497                sorted_samples[j].push(self.samples[i][j]); // remember, push is just an indexed insert that checks capacity first. As long as capacity isn't exceeded, push is O(1)
498            }
499        }
500
501        // now we sort along dim1
502        for j in 0..self.input_dimension {
503            sorted_samples[j].sort_by(|a, b| a.partial_cmp(b).unwrap());
504        }
505        // TODO: it might be worth checking if the above operation would be faster if I changed the order of the loops and sorted inside the outer loop. Maybe something to do with cache performance?
506
507        for (idx, spline) in self.splines.iter_mut().enumerate() {
508            let sample_idx = idx % self.input_dimension; // the first `input_dimension` splines belong to the first "node", so every `input_dimension` splines, we move to the next node and reset which inner sample vector we're looking at
509            let sample = &sorted_samples[sample_idx];
510            trace!("Updating knots for edge {} from samples", idx);
511            spline.update_knots_from_samples(sample, knot_adaptivity);
512        }
513        if log::log_enabled!(log::Level::Trace) {
514            let mut ranges = vec![(0.0, 0.0); self.splines.len()];
515            for (idx, spline) in self.splines.iter().enumerate() {
516                ranges[idx] = spline.get_full_input_range();
517            }
518            trace!("Supported input ranges after knot update: {:#?}", ranges);
519        }
520
521        Ok(())
522    }
523
524    /// as [`KanLayer::update_knots_from_samples`], but divides the work among the passed number of threads
525    pub fn update_knots_from_samples_multithreaded(
526        &mut self,
527        knot_adaptivity: f64,
528        num_threads: usize,
529    ) -> Result<(), LayerError> {
530        if self.samples.is_empty() {
531            return Err(LayerError::no_samples());
532        }
533
534        let mut sorted_samples: Vec<Vec<f64>> =
535            vec![Vec::with_capacity(self.samples.len()); self.input_dimension];
536        for i in 0..self.samples.len() {
537            for j in 0..self.input_dimension {
538                sorted_samples[j].push(self.samples[i][j]); // remember, push is just an indexed insert that checks capacity first. As long as capacity isn't exceeded, push is O(1)
539            }
540        }
541
542        // dim0 = input_dimension, dim1 = number of samples
543        // now we sort along dim1
544        for j in 0..self.input_dimension {
545            sorted_samples[j].sort_by(|a, b| a.partial_cmp(b).unwrap());
546        }
547        // TODO: it might be worth checking if the above operation would be faster if I changed the order of the loops and sorted inside the outer loop. Maybe something to do with cache performance?
548        let edges_per_thread = (self.splines.len() as f64 / num_threads as f64).ceil() as usize;
549        let output_dimension = self.output_dimension;
550        thread::scope(|s| {
551            let handles: Vec<ScopedJoinHandle<Vec<Edge>>> = (0..num_threads)
552                .map(|thread_idx| {
553                    let edges_for_thread: Vec<Edge> = self
554                        .splines
555                        .drain(0..edges_per_thread.min(self.splines.len()))
556                        .collect();
557                    let sorted_samples = &sorted_samples;
558                    s.spawn(move || {
559                        let mut thread_edges = edges_for_thread; // just to explicitly move the edges into the thread
560                        for edge_idx in 0..thread_edges.len() {
561                            let this_edge = &mut thread_edges[edge_idx];
562                            let true_edge_idx = thread_idx * edges_per_thread + edge_idx;
563                            let input_idx = true_edge_idx / output_dimension;
564                            let samples = &sorted_samples[input_idx];
565                            this_edge.update_knots_from_samples(samples, knot_adaptivity);
566                        }
567                        thread_edges // make sure to give the edges back once we're done
568                    })
569                })
570                .collect();
571            for handle in handles {
572                let thread_result: Vec<Edge> = handle.join().unwrap();
573                self.splines.extend(thread_result);
574            }
575        });
576        Ok(())
577    }
578
579    /// wipe the internal state that tracks the samples used to update the knot vectors
580    ///
581    /// # Examples
582    ///
583    /// ```
584    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
585    /// # let some_layer_options = KanLayerOptions {input_dimension: 2,output_dimension: 4,degree: 5, coef_size: 6};
586    /// let mut my_layer = KanLayer::new(&some_layer_options);
587    /// /* After several forward passes... */
588    /// # let samples = vec![vec![100f64, -100f64],vec![-100f64, 100f64]];
589    /// # let _acts = my_layer.forward(samples)?;
590    /// let update_result = my_layer.update_knots_from_samples(0.0);
591    /// assert!(update_result.is_ok());
592    /// my_layer.clear_samples();
593    /// let update_result = my_layer.update_knots_from_samples(0.0);
594    /// assert!(update_result.is_err()); // we've cleared the samples, so we can't update the knot vectors
595    /// # Ok::<(), fekan::layer_errors::LayerError>(())
596    pub fn clear_samples(&mut self) {
597        self.samples.clear();
598    }
599
600    /// Given a vector of gradient values for the nodes in this layer, backpropogate the error through the layer, updating the internal gradients for the incoming edges
601    /// and return the error for the previous layer.
602    ///
603    /// This function relies on mutated inner state and should be called after [`KanLayer::forward`].
604    ///
605    /// Calculated gradients are stored internally, and only applied during [`KanLayer::update`].
606    ///
607    /// # Errors
608    /// Returns a [`LayerError`] if...
609    /// * the length of `gradient` is not equal to the number of nodes in this layer (i.e this layer's output dimension)
610    /// * this method is called before [`KanLayer::forward`]
611    ///
612    /// # Examples
613    /// Backpropgate the error through a two-layer network, and update the gradients
614    /// ```
615    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
616    /// let first_layer_options = KanLayerOptions { input_dimension: 2, output_dimension: 4, degree: 5, coef_size: 6 };
617    /// let second_layer_options = KanLayerOptions { input_dimension: 4, output_dimension: 3, degree: 5, coef_size: 6 };
618    /// let mut first_layer = KanLayer::new(&first_layer_options);
619    /// let mut second_layer = KanLayer::new(&second_layer_options);
620    /// /* forward pass */
621    /// let preacts = vec![vec![0.0, 0.5]];
622    /// let acts = first_layer.forward(preacts).unwrap();
623    /// let output = second_layer.forward(acts).unwrap();
624    /// /* calculate error */
625    /// # let error = vec![vec![1.0, 0.5, 0.5]];
626    /// assert_eq!(error[0].len(), second_layer_options.output_dimension);
627    /// let first_layer_error = second_layer.backward(&error).unwrap();
628    /// assert_eq!(first_layer_error[0].len(), first_layer_options.output_dimension);
629    /// let input_error = first_layer.backward(&first_layer_error).unwrap();
630    /// assert_eq!(input_error[0].len(), first_layer_options.input_dimension);
631    ///
632    /// // apply the gradients
633    /// let learning_rate = 0.1;
634    /// first_layer.update(learning_rate, 1.0, 1.0);
635    /// second_layer.update(learning_rate, 1.0, 1.0);
636    /// // reset the gradients
637    /// first_layer.zero_gradients();
638    /// second_layer.zero_gradients();
639    /// /* continue training */
640    /// ```
641    pub fn backward(&mut self, gradients: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, LayerError> {
642        for gradient in gradients.iter() {
643            if gradient.len() != self.output_dimension {
644                return Err(LayerError::missized_gradient(
645                    gradient.len(),
646                    self.output_dimension,
647                ));
648            }
649
650            if gradient.iter().any(|f| f.is_nan()) {
651                return Err(LayerError::nans_in_gradient());
652            }
653        }
654        let layer_l1 = self
655            .layer_l1
656            .ok_or(LayerError::backward_before_forward(None, 0))?;
657        let layer_entropy = self.layer_entropy();
658
659        let num_gradients = gradients.len();
660        let mut transposed_gradients = vec![vec![0.0; num_gradients]; self.output_dimension];
661        for i in 0..num_gradients {
662            for j in 0..self.output_dimension {
663                let to_move_gradient = gradients[i][j]; // separate the lines for easier debugging
664                transposed_gradients[j][i] = to_move_gradient;
665            }
666        }
667        let mut backpropped_gradients = vec![vec![0.0; self.input_dimension]; num_gradients];
668        let mut sibling_entropy_terms: VecDeque<f64> = self
669            .splines
670            .iter()
671            .map(|s| {
672                let edge_l1 = s.l1_norm().expect("edge should have an L1");
673                if edge_l1 == 0.0 {
674                    0.0
675                } else {
676                    edge_l1 * ((edge_l1 / layer_l1).ln() + 1.0)
677                }
678            })
679            .collect();
680        for edge_index in 0..self.splines.len() {
681            trace!("Backpropping gradients for edge {}", edge_index);
682            let in_node_idx = edge_index / self.output_dimension;
683            let out_node_idx = edge_index % self.output_dimension;
684            // let sibling_l1s: &[f64] = &sibling_entropy_terms.as_slices().0[1..];
685            let sample_wise_outputs = self.splines[edge_index]
686                .backward(&transposed_gradients[out_node_idx], layer_l1, layer_entropy)
687                .map_err(|e| LayerError::backward_before_forward(Some(e), edge_index))?; // TODO incorporate sparsity losses
688            trace!(
689                "Backpropped gradients for edge {}: {:?}",
690                edge_index,
691                sample_wise_outputs
692            );
693            for sample_idx in 0..num_gradients {
694                backpropped_gradients[sample_idx][in_node_idx] += sample_wise_outputs[sample_idx];
695            }
696
697            // rotate the sibling L1s so that the next edge gets the correct values
698            sibling_entropy_terms.rotate_left(1);
699            sibling_entropy_terms.make_contiguous(); // necessary for .as_slices() above to work properly
700        }
701        trace!("Backpropped gradients: {:?}", backpropped_gradients);
702        self.layer_l1 = None; // The L1 should be re-set after the next forward pass, and backward should not be called before forward, so this serves as a (redundant) check
703        Ok(backpropped_gradients)
704    }
705
706    /// as [`KanLayer::backward`], but divides the work among the passed number of threads
707    pub fn backward_multithreaded(
708        &mut self,
709        gradients: &[Vec<f64>],
710        num_threads: usize,
711    ) -> Result<Vec<Vec<f64>>, LayerError> {
712        for gradient in gradients.iter() {
713            if gradient.len() != self.output_dimension {
714                return Err(LayerError::missized_gradient(
715                    gradient.len(),
716                    self.output_dimension,
717                ));
718            }
719
720            if gradient.iter().any(|f| f.is_nan()) {
721                return Err(LayerError::nans_in_gradient());
722            }
723        }
724        let layer_l1 = self
725            .layer_l1
726            .ok_or(LayerError::backward_before_forward(None, 0))?;
727        let layer_entropy = self.layer_entropy();
728
729        let num_gradients = gradients.len();
730        let mut transposed_gradients = vec![vec![0.0; num_gradients]; self.output_dimension];
731        for i in 0..num_gradients {
732            for j in 0..self.output_dimension {
733                let to_move_gradient = gradients[i][j]; // separate the lines for easier debugging
734                transposed_gradients[j][i] = to_move_gradient;
735            }
736        }
737        // dim0 = num_gradients, dim1 = input_dimension
738        let backpropped_gradients = Arc::new(Mutex::new(vec![
739            vec![0.0; self.input_dimension];
740            num_gradients
741        ]));
742        let backprop_result = thread::scope(|s| {
743            let edges_per_thread = (self.splines.len() as f64 / num_threads as f64).ceil() as usize;
744            let output_dimension = self.output_dimension;
745            // let all_edge_l1s = self
746            //     .splines
747            //     .iter()
748            //     .map(|s| s.l1_norm().expect("edge should have an L1"))
749            //     .collect::<Vec<f64>>();
750            let handles: Vec<ScopedJoinHandle<Result<Vec<Edge>, LayerError>>> = (0..num_threads)
751                .map(|thread_idx| {
752                    let edges_for_thread: Vec<Edge> = self
753                        .splines
754                        .drain(0..edges_per_thread.min(self.splines.len()))
755                        .collect();
756                    // let edge_l1s = all_edge_l1s.clone();
757                    let transposed_gradients = &transposed_gradients;
758                    let threaded_gradients = Arc::clone(&backpropped_gradients);
759                    let thread_result = s.spawn(move || {
760                        let mut thread_edges = edges_for_thread; // explicitly move the edges into the thread because it makes me feel good
761                        for edge_index in 0..thread_edges.len() {
762                            let this_edge: &mut Edge = &mut thread_edges[edge_index];
763                            let true_edge_index = thread_idx * edges_per_thread + edge_index;
764                            // let sibling_l1s: Vec<f64> = edge_l1s
765                            //     .iter()
766                            //     .enumerate()
767                            //     .filter(|(idx, _l1)| *idx != true_edge_index)
768                            //     .map(|(_idx, l1)| *l1)
769                            //     .collect();
770                            let in_node_idx = true_edge_index / output_dimension;
771                            let out_node_idx = true_edge_index % output_dimension;
772                            let sample_wise_outputs = this_edge
773                                .backward(
774                                    &transposed_gradients[out_node_idx],
775                                    layer_l1,
776                                    layer_entropy,
777                                )
778                                .map_err(|e| {
779                                    LayerError::backward_before_forward(Some(e), true_edge_index)
780                                })?;
781                            let mut mutex_acquired_gradients = threaded_gradients.lock().unwrap();
782                            for sample_idx in 0..num_gradients {
783                                let sample_gradients = &mut mutex_acquired_gradients[sample_idx];
784                                let in_node = &mut sample_gradients[in_node_idx];
785                                let edge_output = sample_wise_outputs[sample_idx];
786                                *in_node += edge_output;
787                            }
788                        }
789                        Ok(thread_edges)
790                    });
791                    thread_result
792                })
793                .collect();
794            for handle in handles {
795                let thread_result: Vec<Edge> = handle.join().unwrap()?;
796                self.splines.extend(thread_result);
797            }
798            Ok(())
799        });
800        backprop_result?;
801        let result = backpropped_gradients.lock().unwrap().clone();
802        Ok(result)
803    }
804
805    fn layer_entropy(&self) -> f64 {
806        let layer_l1 = self
807            .layer_l1
808            .expect("Layer L1 should be set to calculate entropy");
809        self.splines
810            .iter()
811            .map(|s| {
812                s.l1_norm()
813                    .expect("Edge should have an L1 to calculate entropy")
814            })
815            .map(|edge_l1| {
816                if edge_l1 == 0.0 {
817                    0.0
818                } else {
819                    edge_l1 / layer_l1 * (edge_l1 / layer_l1).ln()
820                }
821            })
822            .sum::<f64>()
823            / self.splines.len() as f64
824    }
825
826    // /// as [KanLayer::backward], but divides the work among the passed thread pool
827    // pub fn backward_concurrent(
828    //     &mut self,
829    //     error: &[f64],
830    //     thread_pool: &ThreadPool,
831    // ) -> Result<Vec<f64>, BackwardLayerError> {
832    //     if error.len() != self.output_dimension {
833    //         return Err(BackwardLayerError::MissizedGradientError {
834    //             actual: error.len(),
835    //             expected: self.output_dimension,
836    //         });
837    //     }
838    //     // if error.iter().any(|f| f.is_nan()) {
839    //     //     return Err(BackwardLayerError::ReceivedNanError);
840    //     // }
841
842    //     let backprop_result: (Vec<f64>, Vec<BackwardSplineError>) = thread_pool.install(|| {
843    //         let mut input_gradient = vec![0.0; self.input_dimension];
844    //         let mut spline_errors = Vec::with_capacity(self.splines.len());
845    //         for i in 0..self.splines.len() {
846    //             // every `input_dimension` splines belong to the same node, and thus will use the same error value.
847    //             // "Distribute" the error at a given node among all incoming edges
848    //             let error_at_edge_output =
849    //                 error[i / self.input_dimension] / self.input_dimension as f64;
850    //             match self.splines[i].backward(error_at_edge_output) {
851    //                 Ok(error_at_edge_input) => {
852    //                     input_gradient[i % self.input_dimension] += error_at_edge_input
853    //                 }
854    //                 Err(e) => spline_errors.push(e),
855    //             }
856    //         }
857    //         (input_gradient, spline_errors)
858    //     });
859    //     if !backprop_result.1.is_empty() {
860    //         return Err(backprop_result.1[0].into());
861    //     }
862    //     Ok(backprop_result.0)
863    // }
864
865    /// set the length of the knot vectors for each incoming edge in this layer
866    ///
867    /// Generally used multiple times throughout training to increase the number of knots in the spline to increase fidelity of the curve
868    /// # Notes
869    /// * The number of control points is set to `knot_length - degree - 1`, and the control points are calculated using lstsq to approximate the previous curve
870    /// # Errors
871    /// * Returns an error if any of the splines' updated control points include NaNs. This should never happen, but if it does, it's a bug
872    /// # Examples
873    /// Extend the knot vectors during training to increase the fidelity of the splines
874    /// ```
875    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
876    /// let input_dimension = 2;
877    /// let output_dimension = 4;
878    /// let degree = 5;
879    /// let coef_size = 6;
880    /// let layer_options = KanLayerOptions {
881    ///     input_dimension,
882    ///     output_dimension,
883    ///     degree,
884    ///     coef_size
885    /// };
886    /// let mut my_layer = KanLayer::new(&layer_options);
887    ///
888    /// let num_splines = input_dimension * output_dimension;
889    /// let expected_knots_per_spline = coef_size + degree + 1;
890    /// assert_eq!(my_layer.knot_length(), expected_knots_per_spline, "starting knots per edge");
891    /// assert_eq!(my_layer.parameter_count(), num_splines * (coef_size + my_layer.knot_length()), "starting parameter count");
892    ///
893    /// /* train the layer a bit to start shaping the splines */
894    ///
895    /// let new_knot_length = my_layer.knot_length() * 2;
896    /// let update_result = my_layer.set_knot_length(new_knot_length);
897    /// assert!(update_result.is_ok(), "update knots");
898    ///
899    /// assert_eq!(my_layer.knot_length(), new_knot_length, "ending knots per edge");
900    /// let new_coef_size = new_knot_length - degree - 1;
901    /// assert_eq!(my_layer.parameter_count(), num_splines * (new_coef_size + new_knot_length), "ending parameter count");
902    ///
903    /// /* continue training layer, now with increased fidelity in the spline */
904    /// ```
905    /// # Panics
906    /// Panics if the Singular Value Decomposition (SVD) used to calculate the control points fails. This should never happen, but if it does, it's a bug
907    pub fn set_knot_length(&mut self, knot_length: usize) -> Result<(), LayerError> {
908        for i in 0..self.splines.len() {
909            self.splines[i]
910                .set_knot_length(knot_length)
911                .map_err(|e| LayerError::set_knot_length(i, e))?;
912        }
913        Ok(())
914    }
915
916    /// as [`KanLayer::set_knot_length`], but divides the work among the passed number of threads
917    pub fn set_knot_length_multithreaded(
918        &mut self,
919        knot_length: usize,
920        num_threads: usize,
921    ) -> Result<(), LayerError> {
922        let edges_per_thread = (self.splines.len() as f64 / num_threads as f64).ceil() as usize;
923        let threaded_result: Result<(), LayerError> = thread::scope(|s| {
924            let handles: Vec<ScopedJoinHandle<Result<Vec<Edge>, LayerError>>> = (0..num_threads)
925                .map(|thread_idx| {
926                    let edges_for_thread: Vec<Edge> = self
927                        .splines
928                        .drain(0..edges_per_thread.min(self.splines.len()))
929                        .collect();
930                    let thread_result = s.spawn(move || {
931                        let mut thread_edges = edges_for_thread; // just to explicitly move the edges into the thread
932                        for edge_index in 0..thread_edges.len() {
933                            let this_edge: &mut Edge = &mut thread_edges[edge_index];
934                            this_edge.set_knot_length(knot_length).map_err(|e| {
935                                LayerError::set_knot_length(
936                                    edge_index + thread_idx * edges_per_thread,
937                                    e,
938                                )
939                            })?;
940                        }
941                        Ok(thread_edges) // give the edges back once we're done
942                    });
943                    thread_result
944                })
945                .collect();
946            for handle in handles {
947                let thread_result = handle.join().unwrap()?;
948                self.splines.extend(thread_result);
949            }
950            Ok(())
951        });
952        threaded_result?;
953        Ok(())
954    }
955
956    /// return the length of the knot vectors for each incoming edge in this layer
957    /// # Examples
958    /// ```
959    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
960    /// let layer_options = KanLayerOptions {
961    ///     input_dimension: 2,
962    ///     output_dimension: 4,
963    ///     degree: 5,
964    ///     coef_size: 6
965    /// };
966    /// let mut my_layer = KanLayer::new(&layer_options);
967    /// assert_eq!(my_layer.knot_length(), 6 + 5 + 1);
968    pub fn knot_length(&self) -> usize {
969        for spline in self.splines.iter() {
970            if spline.knots().len() > 0 {
971                return spline.knots().len();
972            }
973        }
974        0
975    }
976
977    /// update the control points for each incoming edge in this layer given the learning rate
978    ///
979    /// this function relies on internally stored gradients calculated during [`KanLayer::backward()`]
980    ///
981    /// # Examples
982    /// see [`KanLayer::backward`]
983    pub fn update(&mut self, learning_rate: f64, l1_penalty: f64, entropy_penalty: f64) {
984        for spline in self.splines.iter_mut() {
985            spline.update_control_points(learning_rate, l1_penalty, entropy_penalty);
986        }
987    }
988
989    /// as [`KanLayer::update`], but divides the work among the passed number of threads
990    pub fn update_multithreaded(
991        &mut self,
992        learning_rate: f64,
993        l1_penalty: f64,
994        entropy_penalty: f64,
995        num_threads: usize,
996    ) {
997        let edges_per_thread = (self.splines.len() as f64 / num_threads as f64).ceil() as usize;
998        thread::scope(|s| {
999            let handles: Vec<ScopedJoinHandle<Vec<Edge>>> = (0..num_threads)
1000                .map(|_| {
1001                    let edges_for_thread: Vec<Edge> = self
1002                        .splines
1003                        .drain(0..edges_per_thread.min(self.splines.len()))
1004                        .collect();
1005                    let thread_result = s.spawn(move || {
1006                        let mut thread_edges = edges_for_thread; // just to explicitly move the edges into the thread
1007                        for edge_index in 0..thread_edges.len() {
1008                            let this_edge: &mut Edge = &mut thread_edges[edge_index];
1009                            this_edge.update_control_points(
1010                                learning_rate,
1011                                l1_penalty,
1012                                entropy_penalty,
1013                            );
1014                        }
1015                        thread_edges // make sure to give the edges back once we're done
1016                    });
1017                    thread_result
1018                })
1019                .collect();
1020            for handle in handles {
1021                let thread_result: Vec<Edge> = handle.join().unwrap();
1022                self.splines.extend(thread_result);
1023            }
1024        });
1025    }
1026
1027    /// clear gradients for each incoming edge in this layer
1028    ///
1029    /// # Examples
1030    /// see [`KanLayer::backward`]
1031    pub fn zero_gradients(&mut self) {
1032        for spline in self.splines.iter_mut() {
1033            spline.zero_gradients();
1034        }
1035    }
1036
1037    /// return the total number of parameters in this layer.
1038    /// A layer has `input_dimension * output_dimension` splines, each with `degree + coef_size + 1` knots and `coef_size` control points
1039    ///
1040    /// #Examples
1041    /// ```
1042    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
1043    /// let layer_options = KanLayerOptions {
1044    ///     input_dimension: 2,
1045    ///     output_dimension: 4,
1046    ///     degree: 5,
1047    ///     coef_size: 6
1048    /// };
1049    /// let my_layer = KanLayer::new(&layer_options);
1050    /// assert_eq!(my_layer.parameter_count(), 2 * 4 * (6 + (5 + 6 + 1)));
1051    ///```
1052    pub fn parameter_count(&self) -> usize {
1053        self.input_dimension * self.output_dimension * self.splines[0].parameter_count()
1054    }
1055
1056    /// returns the total number of trainable parameters in this layer.
1057    /// A layer has `input_dimension * output_dimension` splines, each with coef_size` control points, which are the trainable parameter in a KAN layer
1058    /// # Examples
1059    /// ```
1060    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
1061    /// let layer_options = KanLayerOptions {
1062    ///     input_dimension: 2,
1063    ///     output_dimension: 4,
1064    ///     degree: 5,
1065    ///     coef_size: 6
1066    /// };
1067    /// let my_layer = KanLayer::new(&layer_options);
1068    /// assert_eq!(my_layer.trainable_parameter_count(), 2 * 4 * 6);
1069    ///```
1070    pub fn trainable_parameter_count(&self) -> usize {
1071        self.input_dimension * self.output_dimension * self.splines[0].trainable_parameter_count()
1072    }
1073
1074    /// return the number of incoming edges to nodes in this layer
1075    pub fn total_edges(&self) -> usize {
1076        self.input_dimension * self.output_dimension
1077    }
1078
1079    /// Create a new KanLayer by merging the splines of multiple KanLayers. Splines are merged by averaging their knots and control points.
1080    /// `new_layer.splines[0] = spline_merge([layer1.splines[0], layer2.splines[0], ...])`, etc. The output of the merged layer is not necessarily the average of the outputs of the input layers.
1081    /// # Errors
1082    /// Returns a [`LayerError`] if...
1083    /// * `kan_layers` is empty
1084    /// * the input dimensions of the layers in `kan_layers` are not all equal
1085    /// * the output dimensions of the layers in `kan_layers` are not all equal
1086    /// * there is an error merging the splines of the layers, caused by the splines having different: degrees, number of control points, number or knots from each other
1087    /// # Examples
1088    /// Train a layer using multiple threads, then merge the results
1089    /// ```
1090    /// use fekan::kan_layer::{KanLayer, KanLayerOptions};
1091    /// use std::thread;
1092    /// # use fekan::Sample;
1093    /// # let layer_options = KanLayerOptions {
1094    /// #    input_dimension: 2,
1095    /// #    output_dimension: 4,
1096    /// #    degree: 5,
1097    /// #    coef_size: 6
1098    /// # };
1099    /// # let num_training_threads = 1;
1100    /// # let training_data = vec![Sample::new_regression_sample(vec![], 0.0)];
1101    /// # fn train_layer(layer: KanLayer, data: &[Sample]) -> KanLayer {layer}
1102    /// let my_layer = KanLayer::new(&layer_options);
1103    /// let partially_trained_layers: Vec<KanLayer> = thread::scope(|s|{
1104    ///     let chunk_size = f32::ceil(training_data.len() as f32 / num_training_threads as f32) as usize; // round up, since .chunks() gives up-to chunk_size chunks. This way to don't leave any data on the cutting room floor
1105    ///     let handles: Vec<_> = training_data.chunks(chunk_size).map(|training_data_chunk|{
1106    ///         let clone_layer = my_layer.clone();
1107    ///         s.spawn(move ||{
1108    ///             train_layer(clone_layer, training_data_chunk) // `train_layer` is a stand-in for whatever function you're using to train the layer - not actually defined in this crate
1109    ///         })
1110    ///     }).collect();
1111    ///     handles.into_iter().map(|handle| handle.join().unwrap()).collect()
1112    /// });
1113    /// let fully_trained_layer = KanLayer::merge_layers(partially_trained_layers)?;
1114    /// # Ok::<(), fekan::layer_errors::LayerError>(())
1115    /// ```
1116    pub fn merge_layers(kan_layers: Vec<KanLayer>) -> Result<KanLayer, LayerError> {
1117        if kan_layers.is_empty() {
1118            return Err(LayerError::merge_no_layers());
1119        }
1120        let expected_input_dimension = kan_layers[0].input_dimension;
1121        let expected_output_dimension = kan_layers[0].output_dimension;
1122        // check that all layers have the same input and output dimensions
1123        for i in 1..kan_layers.len() {
1124            if kan_layers[i].input_dimension != expected_input_dimension {
1125                return Err(LayerError::merge_mismatched_input_dimension(
1126                    i,
1127                    expected_input_dimension,
1128                    kan_layers[i].input_dimension,
1129                ));
1130            }
1131            if kan_layers[i].output_dimension != expected_output_dimension {
1132                return Err(LayerError::merge_mismatched_output_dimension(
1133                    i,
1134                    expected_output_dimension,
1135                    kan_layers[i].output_dimension,
1136                ));
1137            }
1138        }
1139        let edge_count = expected_input_dimension * expected_output_dimension;
1140        // // now build a row-major matrix of splines where each column is the splines in a given layer, and the rows are the ith spline in each layer
1141        // // splines_to_merge = [[L0_S0, L1_S0, ... LJ_S0],
1142        // //                     [L0_S1, L1_S1, ... LJ_S1],
1143        // //                     ...
1144        // //                     [L0_SN, L1_SN, ... LJ_SN]]
1145        // let num_splines = expected_input_dimension * expected_output_dimension;
1146        // let mut splines_to_merge: VecDeque<Vec<Edge>> = vec![vec![]; num_splines].into();
1147        // //populated in column-major order
1148        // for j in 0..kan_layers.len() {
1149        //     for i in 0..num_splines {
1150        //         splines_to_merge[i].push(kan_layers[j].splines.remove(i));
1151        //     }
1152        // }
1153        // let mut merged_splines = Vec::with_capacity(num_splines);
1154        // let mut i = 0;
1155        // while let Some(splines) = splines_to_merge.pop_front() {
1156        //     let merge_result =
1157        //         Edge::merge_edges(splines).map_err(|e| LayerError::spline_merge(i, e))?;
1158        //     i += 1;
1159        //     merged_splines.push(merge_result);
1160        // }
1161        let mut all_edges: Vec<VecDeque<Edge>> = kan_layers
1162            .into_iter()
1163            .map(|layer| layer.splines.into())
1164            .collect();
1165        let mut merged_edges =
1166            Vec::with_capacity(expected_input_dimension * expected_output_dimension);
1167        for i in 0..edge_count {
1168            let edges_to_merge: Vec<Edge> = all_edges
1169                .iter_mut()
1170                .map(|layer_dequeue| {
1171                    layer_dequeue
1172                        .pop_front()
1173                        .expect("iterated past end of dequeue while merging layers")
1174                })
1175                .collect();
1176            merged_edges.push(
1177                Edge::merge_edges(edges_to_merge).map_err(|e| LayerError::spline_merge(i, e))?,
1178            );
1179        }
1180
1181        Ok(KanLayer {
1182            splines: merged_edges,
1183            input_dimension: expected_input_dimension,
1184            output_dimension: expected_output_dimension,
1185            samples: vec![],
1186            layer_l1: None,
1187        })
1188    }
1189
1190    /// does no useful work at the moment - only here for benchmarking
1191    pub fn bench_suggest_symbolic(&self) {
1192        for (_idx, spline) in self.splines.iter().enumerate() {
1193            spline.suggest_symbolic(1);
1194        }
1195    }
1196
1197    /// test each spline in the layer for similarity to a symbolic function (e.g x^2, sin(x), etc.). If the R^2 value of the best fit is greater than `r2_threshold`, replace the spline with the symbolic function
1198    ///
1199    /// Useful at the end of training to enhance interpretability of the model
1200    pub fn test_and_set_symbolic(&mut self, r2_threshold: f64) -> Vec<usize> {
1201        debug!(
1202            "Testing and setting symbolic functions with R2 >= {}",
1203            r2_threshold
1204        );
1205        let mut clamped_edges = Vec::new();
1206        for i in 0..self.splines.len() {
1207            trace!("Testing edge {}", i);
1208            let mut suggestions = self.splines[i].suggest_symbolic(1);
1209            if suggestions.is_empty() {
1210                // pruned or already-symbolified edges will return an empty vector
1211                continue;
1212            }
1213            let (possible_symbol, r2) = suggestions.remove(0);
1214            if r2 >= r2_threshold {
1215                self.splines[i] = possible_symbol;
1216                clamped_edges.push(i);
1217            }
1218        }
1219        debug!("Symbolified layer:\n{}", self);
1220        clamped_edges
1221    }
1222
1223    /// Tests all edges using the samples provided, and 'prunes' any edges in this layer with an average absolute output value less than `threshold` - that is, they will be replaced with a constant edge that outputs 0.0
1224    /// # Returns
1225    /// A vector of the indices of the pruned edges
1226    pub fn prune(&mut self, samples: &[Vec<f64>], threshold: f64) -> Vec<usize> {
1227        assert!(threshold >= 0.0, "Pruning threhsold must be >= 0.0");
1228        let transposed_samples = transpose(samples);
1229        let mut pruned_indices = Vec::new();
1230        for i in 0..self.splines.len() {
1231            trace!("Pruning edge {}", i);
1232            let in_node_idx = i / self.output_dimension;
1233            if self.splines[i].prune(&transposed_samples[in_node_idx], threshold) {
1234                pruned_indices.push(i);
1235            }
1236        }
1237        pruned_indices
1238    }
1239
1240    /// return the input dimension of this layer
1241    pub fn input_dimension(&self) -> usize {
1242        self.input_dimension
1243    }
1244
1245    /// return the output dimension of this layer
1246    pub fn output_dimension(&self) -> usize {
1247        self.output_dimension
1248    }
1249
1250    /// wipe the edge's activations. Useful for debugging and benchmarking
1251    pub fn wipe_activations(&mut self) {
1252        for edge in self.splines.iter_mut() {
1253            edge.wipe_activations();
1254        }
1255    }
1256}
1257
/// Transposes a row-major matrix: element `[i][j]` of the input becomes element `[j][i]` of the
/// output.
///
/// The number of columns is taken from the first row, and any extra values in longer rows are
/// ignored. An empty input (or a single empty row) yields an empty output instead of panicking.
///
/// # Panics
/// Panics if any row is shorter than the first row.
fn transpose(matrix: &[Vec<f64>]) -> Vec<Vec<f64>> {
    // with no rows there is no first row to read a column count from, so the transpose is empty
    let column_count = matrix.first().map_or(0, Vec::len);
    (0..column_count)
        .map(|j| matrix.iter().map(|row| row[j]).collect())
        .collect()
}
1267
1268impl PartialEq for KanLayer {
1269    // only in a VERY contrived case would two layers have equal splines but different input/output dimensions
1270    // but it's technically possible, so we've got to check it
1271    fn eq(&self, other: &Self) -> bool {
1272        self.splines == other.splines
1273            && self.input_dimension == other.input_dimension
1274            && self.output_dimension == other.output_dimension
1275    }
1276}
1277
1278impl std::fmt::Display for KanLayer {
1279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1280        let edge_string = self
1281            .splines
1282            .iter()
1283            .map(|e| "- ".to_string() + &e.to_string())
1284            .collect::<Vec<String>>()
1285            .join("\n");
1286        write!(
1287            f,
1288            "KanLayer: input_dimension: {}, output_dimension: {}, edges:\n {}",
1289            self.input_dimension, self.output_dimension, edge_string
1290        )
1291    }
1292}
1293
#[cfg(test)]
mod test {
    // Unit tests for `KanLayer`. Most tests share the 2-in/2-out layer from `build_test_layer`
    // and compare float outputs after rounding them to a fixed number of decimal places.

    use edge::Edge;
    use test_log::test;

    use super::*;

    /// Scale used to round activations to four decimal places before comparison.
    const ACTIVATION_SCALE: f64 = 10000.0;
    /// Scale used to round back-propagated input errors to five decimal places before comparison.
    const ERROR_SCALE: f64 = 100000.0;

    /// Rounds every value in `values` to the precision implied by `scale`
    /// (a scale of `10^d` keeps `d` decimal places).
    fn round_all(values: &[f64], scale: f64) -> Vec<f64> {
        values.iter().map(|x| (x * scale).round() / scale).collect()
    }

    /// returns a new layer with input and output dimension = 2, k = 3, and coef_size = 4
    ///
    /// Edges 0 and 3 have every control point at 1.0, and edges 1 and 2 have every control
    /// point at -1.0. All four edges share one knot vector built with `linspace(-1.0, 1.0, ..)`.
    fn build_test_layer() -> KanLayer {
        let k = 3;
        let coef_size = 4;
        let knot_size = coef_size + k + 1;
        let knots = linspace(-1.0, 1.0, knot_size);
        let spline1 = Edge::new(k, vec![1.0; coef_size], knots.clone()).unwrap();
        let spline2 = Edge::new(k, vec![-1.0; coef_size], knots.clone()).unwrap();
        KanLayer {
            splines: vec![spline1.clone(), spline2.clone(), spline2, spline1],
            samples: vec![],
            input_dimension: 2,
            output_dimension: 2,
            layer_l1: None,
        }
    }

    #[test]
    fn test_new() {
        let input_dimension = 3;
        let output_dimension = 4;
        let k = 5;
        let coef_size = 6;
        let my_layer = KanLayer::new(&KanLayerOptions {
            input_dimension,
            output_dimension,
            degree: k,
            coef_size,
        });
        assert_eq!(my_layer.output_dimension, output_dimension);
        assert_eq!(my_layer.input_dimension, input_dimension);
        assert_eq!(my_layer.splines.len(), input_dimension * output_dimension);
    }

    #[test]
    fn test_forward() {
        // to properly test layer forward, I need a layer with output and input dim = 2, which means 4 total edges
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 0.5]];
        let acts = layer.forward(preacts).unwrap();
        let expected_activations = vec![0.3177, -0.3177];
        assert_eq!(round_all(&acts[0], ACTIVATION_SCALE), expected_activations);
    }

    #[test]
    fn test_forward_bad_activations() {
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 0.5, 0.5]];
        let acts = layer.forward(preacts);
        assert!(acts.is_err());
        let error = acts.err().unwrap();
        assert_eq!(error, LayerError::missized_preacts(3, 2));
        println!("{:?}", error); // make sure we can build the error message
    }

    #[test]
    fn test_forward_then_backward() {
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 0.5]];
        let acts = layer.forward(preacts).unwrap();
        let expected_activations = vec![0.3177, -0.3177];
        assert_eq!(
            round_all(&acts[0], ACTIVATION_SCALE),
            expected_activations,
            "forward failed"
        );

        let error = vec![vec![1.0, 0.5]];
        let input_error = layer.backward(&error).unwrap();
        let expected_input_error = vec![0.0, 1.20313];
        assert_eq!(
            round_all(&input_error[0], ERROR_SCALE),
            expected_input_error,
            "backward failed"
        );
    }

    #[test]
    fn test_forward_multithreaded_activations() {
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 0.5]];
        let acts = layer.forward_multithreaded(preacts, 4).unwrap();
        let expected_activations = vec![0.3177, -0.3177];
        assert_eq!(round_all(&acts[0], ACTIVATION_SCALE), expected_activations);
    }

    #[test]
    fn test_forward_multithreaded_reassemble() {
        // the multithreaded forward pass must put every edge back where it found it
        let mut layer = build_test_layer();
        let reference_layer = layer.clone();
        let preacts = vec![vec![0.0, 0.5]];
        let _ = layer.forward_multithreaded(preacts, 4).unwrap();
        assert_eq!(layer, reference_layer);
    }

    #[test]
    fn test_forward_then_backward_multithreaded_result() {
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 0.5]];
        let acts = layer.forward_multithreaded(preacts, 4).unwrap();
        let expected_activations = vec![0.3177, -0.3177];
        assert_eq!(
            round_all(&acts[0], ACTIVATION_SCALE),
            expected_activations,
            "forward failed"
        );

        let error = vec![vec![1.0, 0.5]];
        let input_error = layer.backward_multithreaded(&error, 4).unwrap();
        let expected_input_error = vec![0.0, 1.20313];
        assert_eq!(
            round_all(&input_error[0], ERROR_SCALE),
            expected_input_error,
            "backward failed"
        );
    }

    #[test]
    fn test_forward_then_backward_multithreaded_reassemble() {
        let mut layer = build_test_layer();
        let reference_layer = layer.clone();
        let preacts = vec![vec![0.0, 0.5]];
        let acts = layer.forward_multithreaded(preacts, 4).unwrap();
        let expected_activations = vec![0.3177, -0.3177];
        assert_eq!(
            round_all(&acts[0], ACTIVATION_SCALE),
            expected_activations,
            "forward failed"
        );

        let error = vec![vec![1.0, 0.5]];
        let _ = layer.backward_multithreaded(&error, 4).unwrap();
        assert_eq!(layer, reference_layer, "edges not reassembled correctly");
    }

    #[test]
    fn test_backward_before_forward() {
        let mut layer = build_test_layer();
        let error = vec![vec![1.0, 0.5]];
        let input_error = layer.backward(&error);
        assert!(input_error.is_err());
    }

    #[test]
    fn test_backward_bad_error_length() {
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 0.5]];
        let _ = layer.forward(preacts).unwrap();
        let error = vec![vec![1.0, 0.5, 0.5]];
        let input_error = layer.backward(&error);
        assert!(input_error.is_err());
    }

    #[test]
    fn test_update_knots_from_samples_multithreaded_results_and_reassemble() {
        let mut layer = build_test_layer();
        let preacts = vec![vec![0.0, 1.0], vec![0.5, 2.0]];
        let _ = layer.forward_multithreaded(preacts, 4).unwrap();
        layer
            .update_knots_from_samples_multithreaded(0.0, 4)
            .unwrap();
        let expected_knots_1 = vec![-0.1875, -0.125, -0.0625, 0.0, 0.5, 0.5625, 0.625, 0.6875]; // accounts for the padding added in Edge::update_knots_from_samples
        let expected_knots_2 = vec![0.625, 0.75, 0.875, 1.0, 2.0, 2.125, 2.25, 2.375]; // accounts for the padding added in Edge::update_knots_from_samples
        assert_eq!(layer.splines[0].knots(), expected_knots_1, "edge 0");
        assert_eq!(layer.splines[1].knots(), expected_knots_1, "edge 1");
        assert_eq!(layer.splines[2].knots(), expected_knots_2, "edge 2");
        assert_eq!(layer.splines[3].knots(), expected_knots_2, "edge 3");
    }

    #[test]
    // we checked the proper averaging of knots and control points in the spline tests, so we just need to check that the layer merge isn't messing up the order of the splines
    fn test_merge_identical_layers_yield_identical_output() {
        let layer1 = build_test_layer();
        let layer2 = layer1.clone();
        let input = vec![vec![0.0, 0.5]];
        let acts1 = layer1.infer(&input).unwrap();
        let acts2 = layer2.infer(&input).unwrap();
        assert_eq!(acts1, acts2);
        let merged_layer = KanLayer::merge_layers(vec![layer1, layer2]).unwrap();
        let acts3 = merged_layer.infer(&input).unwrap();
        assert_eq!(acts1, acts3);
    }

    #[test]
    fn test_layer_send() {
        fn assert_send<T: Send>() {}
        assert_send::<KanLayer>();
    }

    #[test]
    fn test_layer_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<KanLayer>();
    }
}