// fekan/kan.rs

1pub mod kan_error;
2use std::collections::VecDeque;
3
4use kan_error::KanError;
5use log::{debug, trace};
6
7use crate::embedding_layer::{EmbeddingLayer, EmbeddingOptions};
8use crate::kan_layer::{KanLayer, KanLayerOptions};
9
10use serde::{Deserialize, Serialize};
11
/// A full neural network model, consisting of multiple Kolmogorov-Arnold layers
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Kan {
    /// An optional trainable embedding layer that will replace designated discrete-valued features with a vector of real-valued features
    pub embedding_layer: Option<EmbeddingLayer>,
    /// the (true) layers of the model
    pub layers: Vec<KanLayer>,
    /// the type of model. This field is metadata and does not affect the operation of the model, though it is used elsewhere in the crate. See [`fekan::train_model()`](crate::train_model) for an example
    model_type: ModelType, // determines how the output is interpreted, and what the loss function ought to be
    /// A map of class names to node indices. Only used if the model is a classification model or multi-output regression model.
    class_map: Option<Vec<String>>,
}
24
/// Hyperparameters for a Kan model
///
/// # Example
/// see [Kan::new]
///
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct KanOptions {
    /// the number of input features the model should accept
    pub num_features: usize,
    /// the indexes of the features that should be embedded. These features will be replaced with a vector from the embedding table
    pub embedding_options: Option<EmbeddingOptions>,
    /// the sizes (output dimensions) of the layers to use in the model, including the output layer.
    /// The first layer's input dimension is `num_features`, or the embedding layer's output dimension if embedding is configured
    pub layer_sizes: Vec<usize>,
    /// the degree of the b-splines to use in each layer
    pub degree: usize,
    /// the number of coefficients to use in the b-splines in each layer
    pub coef_size: usize,
    /// the type of model to create. This field is metadata and does not affect the operation of the model, though it is used by [`fekan::train_model()`](crate::train_model) to determine the proper loss function
    pub model_type: ModelType,
    /// A list of human-readable names for the output nodes.
    /// The length of this vector should be equal to the number of output nodes in the model (the last number in `layer_sizes`), or else behavior is undefined
    pub class_map: Option<Vec<String>>,
}
48
/// Metadata suggesting how the model's output ought to be interpreted
///
/// For information on how model type can affect training, see [`train_model()`](crate::train_model)
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// For models designed to assign a discrete class to an input. For example, determining if an image contains a cat or a dog
    Classification,
    /// For models designed to predict a continuous value. For example, predicting the price of a house
    Regression,
}
59
60impl std::fmt::Display for ModelType {
61    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
62        match self {
63            ModelType::Classification => write!(f, "Classification"),
64            ModelType::Regression => write!(f, "Regression"),
65        }
66    }
67}
68
69impl Kan {
70    /// creates a new Kan model with the given hyperparameters
71    ///
72    /// # Example
73    /// Create a regression model with 5 input features, 2 hidden layers of size 4 and 3, and 1 output feature, using degree 3 b-splines with 6 coefficients per spline
74    /// ```
75    /// use fekan::kan::{Kan, KanOptions, ModelType};
76    ///
77    /// let options = KanOptions {
78    ///     num_features: 5,
79    ///     layer_sizes: vec![4, 3, 1],
80    ///     degree: 3,
81    ///     coef_size: 6,
82    ///     model_type: ModelType::Regression,
83    ///     class_map: None,
84    ///     embedding_options: None,
85    /// };
86    /// let mut model = Kan::new(&options);
87    ///```
88    pub fn new(options: &KanOptions) -> Self {
89        // build the embedding table
90        let embedding_table = match &options.embedding_options {
91            Some(emb_opt) => Some(EmbeddingLayer::new(emb_opt)),
92            None => None,
93        };
94        let mut prev_size = if let Some(emb_table) = embedding_table.as_ref() {
95            emb_table.output_dimension()
96        } else {
97            options.num_features
98        };
99        let mut layers = Vec::with_capacity(options.layer_sizes.len());
100        for &size in options.layer_sizes.iter() {
101            layers.push(KanLayer::new(&KanLayerOptions {
102                input_dimension: prev_size,
103                output_dimension: size,
104                degree: options.degree,
105                coef_size: options.coef_size,
106            }));
107            prev_size = size;
108        }
109        Kan {
110            layers,
111            embedding_layer: embedding_table,
112            model_type: options.model_type,
113            class_map: options.class_map.clone(),
114        }
115    }
116
117    /// returns the type of the model
118    pub fn model_type(&self) -> ModelType {
119        self.model_type
120    }
121
122    /// returns the class map of the model, if it has one
123    pub fn class_map(&self) -> Option<&Vec<String>> {
124        self.class_map.as_ref()
125    }
126
127    /// Returns the index of the output node that corresponds to the given label.
128    ///
129    /// Returns None if the label is not found in the model's class map, or if the model does not have a class map
130    ///
131    /// # Example
132    /// creating a model with a class map
133    /// ```
134    /// use fekan::kan::{Kan, KanOptions, ModelType};
135    /// let my_class_map = vec!["cat".to_string(), "dog".to_string()];
136    /// let options = KanOptions {
137    ///     num_features: 5,
138    ///     layer_sizes: vec![4, 2],
139    ///     degree: 3,
140    ///     coef_size: 6,
141    ///     model_type: ModelType::Regression,
142    ///     class_map: Some(my_class_map),
143    ///     embedding_options: None,
144    /// };
145    /// let model = Kan::new(&options);
146    /// assert_eq!(model.label_to_node("cat"), Some(0));
147    /// assert_eq!(model.label_to_node("dog"), Some(1));
148    /// assert_eq!(model.label_to_node("fish"), None);
149    /// ```
150    /// Using a model's class map during training to determine the index of node that should have had the highest value
151    /// ```
152    /// # use fekan::kan::{Kan, KanOptions, ModelType};
153    /// # let my_class_map = vec!["cat".to_string(), "dog".to_string()];
154    /// # let options = KanOptions {
155    /// #    num_features: 5,
156    /// #    layer_sizes: vec![4, 2],
157    /// #    degree: 3,
158    /// #    coef_size: 6,
159    /// #    model_type: ModelType::Regression,
160    /// #    class_map: Some(my_class_map),
161    /// #    embedding_options: None,
162    /// # };
163    /// # let mut model = Kan::new(&options);
164    /// # let feature_data = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
165    /// # let label = "cat";
166    /// # fn cross_entropy_loss(output: Vec<f64>, expected_highest_node: usize) -> f64 {0.0}
167    /// /* within your custom training function */
168    /// let batch_logits: Vec<Vec<f64>> = model.forward(feature_data)?;
169    /// for logits in batch_logits {
170    ///     let expected_highest_node: usize = model.label_to_node(label).unwrap();
171    ///     let loss: f64 = cross_entropy_loss(logits, expected_highest_node);
172    /// }
173    /// # Ok::<(), fekan::kan::kan_error::KanError>(())
174    /// ```
175    pub fn label_to_node(&self, label: &str) -> Option<usize> {
176        if let Some(class_map) = &self.class_map {
177            class_map.iter().position(|x| x == label)
178        } else {
179            None
180        }
181    }
182
    /// Returns the label for the output node at the given index.
    ///
    /// Returns None if the index is out of bounds, or if the model does not have a class map
    ///
    /// # Example
    /// ```
    /// use fekan::kan::{Kan, KanOptions, ModelType};
    /// let class_map = vec!["cat".to_string(), "dog".to_string()];
    /// let options = KanOptions {
    ///     num_features: 5,
    ///     layer_sizes: vec![4, 2],
    ///     degree: 3,
    ///     coef_size: 6,
    ///     model_type: ModelType::Regression,
    ///     class_map: Some(class_map),
    ///     embedding_options: None,
    /// };
    /// let model = Kan::new(&options);
    /// assert_eq!(model.node_to_label(0), Some("cat"));
    /// assert_eq!(model.node_to_label(1), Some("dog"));
    /// assert_eq!(model.node_to_label(2), None);
    /// ```
    /// Using a model's class map during inference to interpret the output of a classifier
    /// ```
    /// # use fekan::kan::{Kan, KanOptions, ModelType};
    /// # let my_class_map = vec!["cat".to_string(), "dog".to_string()];
    /// # let options = KanOptions {
    /// #    num_features: 5,
    /// #    layer_sizes: vec![4, 2],
    /// #    degree: 3,
    /// #    coef_size: 6,
    /// #    model_type: ModelType::Regression,
    /// #    class_map: Some(my_class_map),
    /// #    embedding_options: None,
    /// # };
    /// # let model = Kan::new(&options);
    /// # let feature_data = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
    /// /* using an already trained model... */
    /// let batch_logits: Vec<Vec<f64>> = model.infer(feature_data)?;
    /// for logits in batch_logits {
    ///     let highest_node: usize = logits.iter().enumerate().max_by(|(_, a_val), (_, b_val)| a_val.partial_cmp(b_val).unwrap()).unwrap().0;
    ///     let label: &str = model.node_to_label(highest_node).unwrap();
    ///     println!("The model predicts the input is a {}", label);
    /// }
    /// # Ok::<(), fekan::kan::kan_error::KanError>(())
    /// ```
    pub fn node_to_label(&self, node: usize) -> Option<&str> {
        if let Some(class_map) = &self.class_map {
            class_map.get(node).map(|x| x.as_str())
        } else {
            None
        }
    }
236
237    /// Forward-propogate the input through the model, by calling [`KanLayer::forward`] method of the first layer on the input,
238    /// then calling the `forward` method of each subsequent layer with the output of the previous layer,
239    /// returning the output of the final layer.
240    ///
241    /// This method accumulates internal state in the model needed for training. For inference or validation, use [`Kan::infer`], which does not accumulate state and is more efficient
242    ///
243    /// # Errors
244    /// returns a [`KanError`] if any layer returns an error.
245    /// See [`KanLayer::forward`] for more information
246    ///
247    /// # Example
248    /// ```
249    /// use fekan::kan::{Kan, KanOptions, ModelType, kan_error::KanError};
250    /// let num_features = 5;
251    /// let output_size = 3;
252    /// let options = KanOptions {
253    ///     num_features,
254    ///     layer_sizes: vec![4, output_size],
255    ///     degree: 3,
256    ///     coef_size: 6,
257    ///     model_type: ModelType::Classification,
258    ///     class_map: None,
259    ///     embedding_options: None,
260    /// };
261    /// let mut model = Kan::new(&options);
262    /// let batch_size = 2;
263    /// let input = vec![vec![1.0; num_features]; batch_size];
264    /// let output = model.forward(input)?;
265    /// assert_eq!(output.len(), batch_size);
266    /// assert_eq!(output[0].len(), output_size);
267    /// /* interpret the output as you like, for example as logits in a classifier, or as predicted value in a regressor */
268    /// # Ok::<(), fekan::kan::kan_error::KanError>(())
269    /// ```
270    pub fn forward(&mut self, input: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, KanError> {
271        debug!("Forwarding {} samples through model", input.len());
272        trace!("Preactivations: {:?}", input);
273        let mut preacts = input;
274        if let Some(embedding_layer) = self.embedding_layer.as_mut() {
275            preacts = embedding_layer
276                .forward(preacts)
277                .map_err(|e| KanError::forward(e, 0))?;
278        }
279        for (idx, layer) in self.layers.iter_mut().enumerate() {
280            debug!("Forwarding through layer {}", idx);
281            preacts = layer
282                .forward(preacts)
283                .map_err(|e| KanError::forward(e, idx))?;
284        }
285        Ok(preacts)
286    }
287
288    /// as [`Kan::forward`], but uses multiple threads to forward the input through the model
289    pub fn forward_multithreaded(
290        &mut self,
291        input: Vec<Vec<f64>>,
292        num_threads: usize,
293    ) -> Result<Vec<Vec<f64>>, KanError> {
294        debug!(
295            "Forwarding {} samples through model using {} threads",
296            input.len(),
297            num_threads
298        );
299        trace!("Preactivations: {:?}", input);
300        let mut preacts = input;
301        if let Some(embedding_layer) = self.embedding_layer.as_mut() {
302            preacts = embedding_layer
303                .forward(preacts)
304                .map_err(|e| KanError::forward(e, 0))?;
305        }
306        for (idx, layer) in self.layers.iter_mut().enumerate() {
307            debug!("Forwarding through layer {}", idx);
308            preacts = layer
309                .forward_multithreaded(preacts, num_threads)
310                .map_err(|e| KanError::forward(e, idx))?;
311        }
312        Ok(preacts)
313    }
314
315    /// as [`Kan::forward`], but does not accumulate any internal state
316    ///
317    /// This method should be used when the model is not being trained, for example during inference or validation: when you won't be backpropogating, this method is faster uses less memory than [`Kan::forward`]
318    ///
319    /// # Errors
320    /// returns a [KanError] if any layer returns an error.
321    /// # Example
322    /// see [`Kan::forward`] for an example
323    pub fn infer(&self, input: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, KanError> {
324        let mut preacts = input;
325        if let Some(embedding_layer) = self.embedding_layer.as_ref() {
326            preacts = embedding_layer
327                .infer(&preacts)
328                .map_err(|e| KanError::forward(e, 0))?;
329        }
330        for (idx, layer) in self.layers.iter().enumerate() {
331            preacts = layer
332                .infer(&preacts)
333                .map_err(|e| KanError::forward(e, idx))?;
334        }
335        Ok(preacts)
336    }
337
338    /// as [`Kan::infer`], but uses multiple threads to forward the input through the model
339    pub fn infer_multithreaded(
340        &self,
341        input: Vec<Vec<f64>>,
342        num_threads: usize,
343    ) -> Result<Vec<Vec<f64>>, KanError> {
344        let mut preacts = input;
345        if let Some(embedding_layer) = self.embedding_layer.as_ref() {
346            preacts = embedding_layer
347                .infer(&preacts)
348                .map_err(|e| KanError::forward(e, 0))?;
349        }
350        for (idx, layer) in self.layers.iter().enumerate() {
351            preacts = layer
352                .infer_multithreaded(&preacts, num_threads)
353                .map_err(|e| KanError::forward(e, idx))?;
354        }
355        Ok(preacts)
356    }
357
358    /// Back-propogate the gradient through the model, internally accumulating the gradients of the model's parameters, to be applied later with [`Kan::update`]
359    ///
360    /// # Errors
361    /// returns an error if any layer returns an error.
362    /// See [`KanLayer::backward`] for more information
363    ///
364    /// # Example
365    /// ```
366    /// use fekan::kan::{Kan, KanOptions, ModelType, kan_error::KanError};
367    ///
368    /// # let options = KanOptions {
369    /// #    num_features: 5,
370    /// #    layer_sizes: vec![4, 1],
371    /// #    degree: 3,
372    /// #    coef_size: 6,
373    /// #    model_type: ModelType::Regression,
374    /// #    class_map: None,
375    /// #    embedding_options: None,
376    /// # };
377    /// let mut model = Kan::new(&options);
378    ///
379    /// # fn calculate_gradient(output: &[Vec<f64>], label: f64) -> Vec<Vec<f64>> {vec![vec![1.0; output[0].len()]; output.len()]}
380    /// # let learning_rate = 0.1;
381    /// # let l1_penalty = 0.0;
382    /// # let entropy_penalty = 0.0;
383    /// # let features = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
384    /// # let label = 0;
385    /// let output = model.forward(features)?;
386    /// let gradient = calculate_gradient(&output, label as f64);
387    /// assert_eq!(gradient.len(), output.len());
388    /// let _ = model.backward(gradient)?; // the input gradient can be disregarded here.
389    ///
390    /// /*
391    /// * The model has stored the gradients for it's parameters internally.
392    /// * We can add conduct as many forward/backward pass-pairs as we like to accumulate gradient,
393    /// * until we're ready to update the paramaters.
394    /// */
395    ///
396    /// model.update(learning_rate, l1_penalty, entropy_penalty); // update the parameters of the model based on the accumulated gradients here
397    /// model.zero_gradients(); // zero the gradients for the next batch of training data
398    /// # Ok::<(), fekan::kan::kan_error::KanError>(())
399    /// ```
400    pub fn backward(&mut self, gradients: Vec<Vec<f64>>) -> Result<(), KanError> {
401        debug!("Backwarding {} gradients through model", gradients.len());
402        let mut gradients = gradients;
403        for (idx, layer) in self.layers.iter_mut().enumerate().rev() {
404            gradients = layer
405                .backward(&gradients)
406                .map_err(|e| KanError::backward(e, idx))?;
407        }
408        if let Some(embedding_layer) = self.embedding_layer.as_mut() {
409            embedding_layer
410                .backward(gradients)
411                .map_err(|e| KanError::backward(e, 0))?;
412        }
413
414        Ok(())
415    }
416
417    /// as [`Kan::backward`], but uses multiple threads to back-propogate the gradient through the model
418    pub fn backward_multithreaded(
419        &mut self,
420        gradients: Vec<Vec<f64>>,
421        num_threads: usize,
422    ) -> Result<(), KanError> {
423        debug!("Backwarding {} gradients through model", gradients.len());
424        let mut gradients = gradients;
425        for (idx, layer) in self.layers.iter_mut().enumerate().rev() {
426            gradients = layer
427                .backward_multithreaded(&gradients, num_threads)
428                .map_err(|e| KanError::backward(e, idx))?;
429        }
430        if let Some(embedding_layer) = self.embedding_layer.as_mut() {
431            embedding_layer
432                .backward(gradients)
433                .map_err(|e| KanError::backward(e, 0))?;
434        }
435
436        Ok(())
437    }
438
439    /// Update the model's parameters based on the gradients that have been accumulated with [`Kan::backward`].
440    /// # Example
441    /// see [`Kan::backward`]
442    pub fn update(&mut self, learning_rate: f64, l1_penalty: f64, entropy_penalty: f64) {
443        if let Some(embedding_table) = self.embedding_layer.as_mut() {
444            embedding_table.update(learning_rate);
445        }
446        for layer in self.layers.iter_mut() {
447            layer.update(learning_rate, l1_penalty, entropy_penalty);
448        }
449    }
450
451    /// Zero the internal gradients of the model's parameters
452    /// # Example
453    /// see [`Kan::backward`]
454    pub fn zero_gradients(&mut self) {
455        if let Some(embedding_table) = self.embedding_layer.as_mut() {
456            embedding_table.zero_gradients();
457        }
458        for layer in self.layers.iter_mut() {
459            layer.zero_gradients();
460        }
461    }
462
463    /// get the total number of parameters in the model, inlcuding untrained parameters. See [`KanLayer::parameter_count`] for more information
464    pub fn parameter_count(&self) -> usize {
465        self.layers
466            .iter()
467            .map(|layer| layer.parameter_count())
468            .sum()
469    }
470
471    /// get the total number of trainable parameters in the model. See [`KanLayer::trainable_parameter_count`] for more information
472    pub fn trainable_parameter_count(&self) -> usize {
473        self.layers
474            .iter()
475            .map(|layer| layer.trainable_parameter_count())
476            .sum()
477    }
478
479    /// Update the ranges spanned by the B-spline knots in the model, using samples accumulated by recent [`Kan::forward`] calls.
480    ///
481    /// see [`KanLayer::update_knots_from_samples`] for more information
482    /// # Errors
483    /// returns a [KanError] if any layer returns an error.
484    ///
485    /// # Example
486    /// see [`KanLayer::update_knots_from_samples`] for examples
487    pub fn update_knots_from_samples(&mut self, knot_adaptivity: f64) -> Result<(), KanError> {
488        for (idx, layer) in self.layers.iter_mut().enumerate() {
489            debug!("Updating knots for layer {}", idx);
490            if let Err(e) = layer.update_knots_from_samples(knot_adaptivity) {
491                return Err(KanError::update_knots(e, idx));
492            }
493        }
494        return Ok(());
495    }
496
497    /// as [`Kan::update_knots_from_samples`], but uses multiple threads to update the knot vectors
498    pub fn update_knots_from_samples_multithreaded(
499        &mut self,
500        knot_adaptivity: f64,
501        num_threads: usize,
502    ) -> Result<(), KanError> {
503        for (idx, layer) in self.layers.iter_mut().enumerate() {
504            debug!("Updating knots for layer {}", idx);
505            if let Err(e) =
506                layer.update_knots_from_samples_multithreaded(knot_adaptivity, num_threads)
507            {
508                return Err(KanError::update_knots(e, idx));
509            }
510        }
511        return Ok(());
512    }
513
514    /// Clear the cached samples used by [`Kan::update_knots_from_samples`]
515    ///
516    /// see [`KanLayer::clear_samples`] for more information
517    pub fn clear_samples(&mut self) {
518        debug!("Clearing samples from model");
519        if let Some(embedding_table) = self.embedding_layer.as_mut() {
520            embedding_table.clear_samples();
521        }
522        for layer_idx in 0..self.layers.len() {
523            debug!("Clearing samples from layer {}", layer_idx);
524            self.layers[layer_idx].clear_samples();
525        }
526    }
527
528    /// Set the size of the knot vector used in all splines in this model
529    /// see [`KanLayer::set_knot_length`] for more information
530    pub fn set_knot_length(&mut self, knot_length: usize) -> Result<(), KanError> {
531        for (idx, layer) in self.layers.iter_mut().enumerate() {
532            if let Err(e) = layer.set_knot_length(knot_length) {
533                return Err(KanError::set_knot_length(e, idx));
534            }
535        }
536        Ok(())
537    }
538
    /// Get the size of the knot vector used in all splines in this model
    ///
    /// ## Note
    /// if different layers have different knot lengths, this method will return the knot length of the first layer
    ///
    /// ## Panics
    /// panics if the model has no layers (`self.layers[0]` is indexed unconditionally)
    pub fn knot_length(&self) -> usize {
        self.layers[0].knot_length()
    }
546
    /// Create a new model by merging multiple models together. Models must be of the same type and have the same number of layers, and all layers must be mergable (see [`KanLayer::merge_layers`])
    /// # Errors
    /// Returns a [`KanError`] if:
    /// * the models are not mergable. See [`Kan::models_mergable`] for more information
    /// * any layer encounters an error during the merge. See [`KanLayer::merge_layers`] for more information
    /// # Example
    /// Train multiple copies of the model on different data in different threads, then merge the trained models together
    /// ```
    /// use fekan::{kan::{Kan, KanOptions, ModelType, kan_error::KanError}, Sample};
    /// use std::thread;
    /// # let model_options = KanOptions {
    /// #    num_features: 5,
    /// #    layer_sizes: vec![4, 3],
    /// #    degree: 3,
    /// #    coef_size: 6,
    /// #    model_type: ModelType::Regression,
    /// #    class_map: None,
    /// #    embedding_options: None,
    /// # };
    /// # let num_training_threads = 1;
    /// # let training_data = vec![ Sample::new_regression_sample(vec![], 0.0) ];
    /// # fn my_train_model_function(model: Kan, data: &[Sample]) -> Kan {model}
    /// let mut my_model = Kan::new(&model_options);
    /// let partially_trained_models: Vec<Kan> = thread::scope(|s|{
    ///     let chunk_size = f32::ceil(training_data.len() as f32 / num_training_threads as f32) as usize; // round up, since .chunks() gives up-to chunk_size chunks, so we don't leave any data on the cutting room floor
    ///     let handles: Vec<_> = training_data.chunks(chunk_size).map(|training_data_chunk|{
    ///         let clone_model = my_model.clone();
    ///         s.spawn(move ||{
    ///             my_train_model_function(clone_model, training_data_chunk) // `my_train_model_function` is a stand-in for whatever function you're using to train the model - not actually defined in this crate
    ///         })
    ///     }).collect();
    ///     handles.into_iter().map(|handle| handle.join().unwrap()).collect()
    /// });
    /// let fully_trained_model = Kan::merge_models(partially_trained_models)?;
    /// # Ok::<(), fekan::kan::kan_error::KanError>(())
    /// ```
    ///
    pub fn merge_models(models: Vec<Kan>) -> Result<Kan, KanError> {
        Self::models_mergable(&models)?; // check if the models are mergable
        // metadata is taken from the first model; models_mergable has verified it matches the rest
        let layer_count = models[0].layers.len();
        let model_type = models[0].model_type;
        let class_map = models[0].class_map.clone();
        // merge the embedding layers, if present. models_mergable guarantees either
        // all models have an embedding layer or none do
        let merged_embedding_layer = if models[0].embedding_layer.is_some() {
            let embedding_layers: Vec<&EmbeddingLayer> = models
                .iter()
                .map(|model| model.embedding_layer.as_ref().unwrap())
                .collect();
            let merged_embedding_layer = EmbeddingLayer::merge_layers(&embedding_layers)
                .map_err(|e| KanError::merge_unmergable_layers(e, 0))?;
            Some(merged_embedding_layer)
        } else {
            None
        };
        // merge the layers
        let mut all_layers: Vec<VecDeque<KanLayer>> = models
            .into_iter()
            .map(|model| model.layers.into())
            .collect();
        // aggregate layers by index: pop the front layer off every model's deque and merge them
        let mut merged_layers = Vec::new();
        for layer_idx in 0..layer_count {
            let layers_to_merge: Vec<KanLayer> = all_layers
                .iter_mut()
                .map(|layers| {
                    layers
                        .pop_front()
                        .expect("iterated past end of dequeue while merging models")
                })
                .collect();
            let merged_layer = KanLayer::merge_layers(layers_to_merge)
                .map_err(|e| KanError::merge_unmergable_layers(e, layer_idx))?;
            merged_layers.push(merged_layer);
        }

        let merged_model = Kan {
            embedding_layer: merged_embedding_layer,
            layers: merged_layers,
            model_type,
            class_map,
        };
        Ok(merged_model)
    }
629
    /// Check if the given models can be merged using [`Kan::merge_models`]. Returns Ok(()) if the models are mergable, an error otherwise
    /// # Errors
    /// Returns a [`KanError`] if any of the models:
    /// * have different model types (e.g. classification vs regression)
    /// * have different numbers of layers
    /// * have different class maps (if the models are classification models)
    /// * differ in whether they have an embedding layer
    /// # Panics
    /// panics if `models` is empty (`models[0]` is indexed unconditionally)
    pub fn models_mergable(models: &[Kan]) -> Result<(), KanError> {
        // the first model defines the expected shape; every other model is compared against it
        let expected_model_type = models[0].model_type;
        let expected_class_map = &models[0].class_map;
        let expected_layer_count = models[0].layers.len();
        let expected_embedding_table = &models[0].embedding_layer;
        for idx in 1..models.len() {
            if models[idx].model_type != expected_model_type {
                return Err(KanError::merge_mismatched_model_type(
                    idx,
                    expected_model_type,
                    models[idx].model_type,
                ));
            }
            if models[idx].class_map != *expected_class_map {
                return Err(KanError::merge_mismatched_class_map(
                    idx,
                    expected_class_map.clone(),
                    models[idx].class_map.clone(),
                ));
            }
            if models[idx].layers.len() != expected_layer_count {
                return Err(KanError::merge_mismatched_depth_model(
                    idx,
                    expected_layer_count,
                    models[idx].layers.len(),
                ));
            }
            // only presence/absence of the embedding layer is checked here; its contents
            // are validated by EmbeddingLayer::merge_layers during the merge itself
            if models[idx].embedding_layer.is_some() != expected_embedding_table.is_some() {
                return Err(KanError::merge_mismatched_embedding_table_presence(
                    idx,
                    expected_embedding_table.is_some(),
                    models[idx].embedding_layer.is_some(),
                ));
            }
        }
        Ok(())
    }
674
675    /// Test and set the symbolic status of the model, using the given R^2 threshold. See [`KanLayer::test_and_set_symbolic`] for more information
676    pub fn test_and_set_symbolic(&mut self, r2_threshold: f64) -> Vec<(usize, usize)> {
677        debug!("Testing and setting symbolic for the model");
678        let mut symbolified_edges = Vec::new();
679        for i in 0..self.layers.len() {
680            debug!("Symbolifying layer {}", i);
681            let clamped_edges = self.layers[i].test_and_set_symbolic(r2_threshold);
682            symbolified_edges.extend(clamped_edges.into_iter().map(|j| (i, j)));
683        }
684        symbolified_edges
685    }
686
    /// Prune the model, removing any edges that have low average output. See [`KanLayer::prune`] for more information
    /// Returns a list of the indices of pruned edges (i,j), where i is the index of the layer, and j is the index of the edge in that layer that was pruned
    pub fn prune(
        &mut self,
        samples: Vec<Vec<f64>>,
        threshold: f64,
    ) -> Result<Vec<(usize, usize)>, KanError> {
        let mut pruned_edges = Vec::new();
        // if the model embeds its input, the pruning samples must be embedded too
        let mut samples = match &self.embedding_layer {
            None => samples,
            Some(embedding_layer) => embedding_layer
                .infer(&samples)
                .map_err(|e| KanError::forward(e, 0))?,
        };
        for i in 0..self.layers.len() {
            debug!("Pruning layer {}", i);
            let this_layer = &mut self.layers[i];
            // NOTE(review): each layer's outputs are computed BEFORE that layer is pruned,
            // so downstream layers are fed pre-prune activations — confirm intentional
            let next_samples = this_layer
                .infer(&samples)
                .map_err(|e| KanError::forward(e, i))?;
            let layer_prunings = this_layer.prune(&samples, threshold);
            pruned_edges.extend(layer_prunings.into_iter().map(|j| (i, j)));
            samples = next_samples; // this layer's outputs become the next layer's inputs
        }
        Ok(pruned_edges)
    }
713}
714
impl PartialEq for Kan {
    // NOTE(review): equality compares only `layers` and `model_type`; `embedding_layer`
    // and `class_map` are ignored — confirm this is intentional
    fn eq(&self, other: &Self) -> bool {
        self.layers == other.layers && self.model_type == other.model_type
    }
}
720
721#[cfg(test)]
722mod test {
723    use super::*;
724
    // Two models with the same input/output widths but different hidden layers
    // should both accept the same input and produce one row of `output_size` values per sample.
    #[test]
    fn test_forward() {
        let kan_config = KanOptions {
            num_features: 3,
            layer_sizes: vec![4, 2, 3],
            degree: 3,
            coef_size: 4,
            model_type: ModelType::Classification,
            class_map: None,
            embedding_options: None,
        };
        let mut first_kan = Kan::new(&kan_config);
        let second_kan_config = KanOptions {
            layer_sizes: vec![2, 4, 3],
            ..kan_config
        };
        let mut second_kan = Kan::new(&second_kan_config);
        let input = vec![vec![0.5, 0.4, 0.5]];
        let result = first_kan.forward(input.clone()).unwrap();
        assert_eq!(result.len(), 1); // one output row per input sample
        assert_eq!(result[0].len(), 3); // width of the final layer
        let result = second_kan.forward(input).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 3);
    }
750
    // A forward pass followed by a backward pass with a gradient of the output's
    // shape should succeed without error.
    #[test]
    fn test_forward_then_backward() {
        let options = &KanOptions {
            num_features: 5,
            layer_sizes: vec![4, 2, 3],
            degree: 3,
            coef_size: 4,
            model_type: ModelType::Classification,
            class_map: None,
            embedding_options: None,
        };
        let mut first_kan = Kan::new(options);
        let input = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
        let result = first_kan.forward(input.clone()).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), options.layer_sizes.last().unwrap().clone());
        let error = vec![vec![0.5, 0.4, 0.5]]; // same shape as the output
        let result = first_kan.backward(error);
        assert!(result.is_ok());
    }
771
    // Merging two identical copies of a model should yield a model whose output
    // matches the originals exactly.
    #[test]
    fn test_merge_identical_models_yields_identical_output() {
        let kan_config = KanOptions {
            num_features: 3,
            layer_sizes: vec![4, 2, 3],
            degree: 3,
            coef_size: 4,
            model_type: ModelType::Classification,
            class_map: None,
            embedding_options: None,
        };
        let first_kan = Kan::new(&kan_config);
        let second_kan = first_kan.clone();
        let input = vec![vec![0.5, 0.4, 0.5]];
        let first_result = first_kan.infer(input.clone()).unwrap();
        let second_result = second_kan.infer(input.clone()).unwrap();
        assert_eq!(first_result, second_result); // clones agree before the merge
        let merged_kan = Kan::merge_models(vec![first_kan, second_kan]).unwrap();
        let merged_result = merged_kan.infer(input).unwrap();
        assert_eq!(first_result, merged_result); // and the merged model agrees too
    }
793
    // Compile-time marker-trait checks: `Kan` and `KanError` must be Send + Sync
    // so models and errors can cross thread boundaries (e.g. in merge_models' example).
    #[test]
    fn test_model_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Kan>();
    }

    #[test]
    fn test_model_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Kan>();
    }

    #[test]
    fn test_error_send() {
        fn assert_send<T: Send>() {}
        assert_send::<KanError>();
    }

    #[test]
    fn test_error_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<KanError>();
    }
817}