fekan/kan.rs
1pub mod kan_error;
2use std::collections::VecDeque;
3
4use kan_error::KanError;
5use log::{debug, trace};
6
7use crate::embedding_layer::{EmbeddingLayer, EmbeddingOptions};
8use crate::kan_layer::{KanLayer, KanLayerOptions};
9
10use serde::{Deserialize, Serialize};
11
/// A full neural network model, consisting of multiple Kolmogorov-Arnold layers
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Kan {
    /// An optional trainable embedding layer that will replace designated discrete-valued features with a vector of real-valued features
    pub embedding_layer: Option<EmbeddingLayer>,
    /// the (true) layers of the model
    pub layers: Vec<KanLayer>,
    /// the type of model. This field is metadata and does not affect the operation of the model, though it is used elsewhere in the crate. See [`fekan::train_model()`](crate::train_model) for an example
    model_type: ModelType, // determines how the output is interpreted, and what the loss function ought to be
    /// A map of class names to node indices. Only used if the model is a classification model or multi-output regression model.
    class_map: Option<Vec<String>>,
}
24
/// Hyperparameters for a Kan model
///
/// # Example
/// see [Kan::new]
///
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct KanOptions {
    /// the number of input features the model should accept
    pub num_features: usize,
    /// the indexes of the features that should be embedded. These features will be replaced with a vector from the embedding table
    pub embedding_options: Option<EmbeddingOptions>,
    /// the sizes of the layers to use in the model, including the output layer
    pub layer_sizes: Vec<usize>,
    /// the degree of the b-splines to use in each layer
    pub degree: usize,
    /// the number of coefficients to use in the b-splines in each layer
    pub coef_size: usize,
    /// the type of model to create. This field is metadata and does not affect the operation of the model, though it is used by [`fekan::train_model()`](crate::train_model) to determine the proper loss function
    pub model_type: ModelType,
    /// A list of human-readable names for the output nodes.
    /// The length of this vector should be equal to the number of output nodes in the model (the last number in `layer_sizes`), or else behavior is undefined
    pub class_map: Option<Vec<String>>,
}
48
/// Metadata suggesting how the model's output ought to be interpreted
///
/// For information on how model type can affect training, see [`train_model()`](crate::train_model)
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// For models designed to assign a discrete class to an input. For example, determining if an image contains a cat or a dog
    Classification,
    /// For models designed to predict a continuous value. For example, predicting the price of a house
    Regression,
}
59
60impl std::fmt::Display for ModelType {
61 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
62 match self {
63 ModelType::Classification => write!(f, "Classification"),
64 ModelType::Regression => write!(f, "Regression"),
65 }
66 }
67}
68
69impl Kan {
70 /// creates a new Kan model with the given hyperparameters
71 ///
72 /// # Example
73 /// Create a regression model with 5 input features, 2 hidden layers of size 4 and 3, and 1 output feature, using degree 3 b-splines with 6 coefficients per spline
74 /// ```
75 /// use fekan::kan::{Kan, KanOptions, ModelType};
76 ///
77 /// let options = KanOptions {
78 /// num_features: 5,
79 /// layer_sizes: vec![4, 3, 1],
80 /// degree: 3,
81 /// coef_size: 6,
82 /// model_type: ModelType::Regression,
83 /// class_map: None,
84 /// embedding_options: None,
85 /// };
86 /// let mut model = Kan::new(&options);
87 ///```
88 pub fn new(options: &KanOptions) -> Self {
89 // build the embedding table
90 let embedding_table = match &options.embedding_options {
91 Some(emb_opt) => Some(EmbeddingLayer::new(emb_opt)),
92 None => None,
93 };
94 let mut prev_size = if let Some(emb_table) = embedding_table.as_ref() {
95 emb_table.output_dimension()
96 } else {
97 options.num_features
98 };
99 let mut layers = Vec::with_capacity(options.layer_sizes.len());
100 for &size in options.layer_sizes.iter() {
101 layers.push(KanLayer::new(&KanLayerOptions {
102 input_dimension: prev_size,
103 output_dimension: size,
104 degree: options.degree,
105 coef_size: options.coef_size,
106 }));
107 prev_size = size;
108 }
109 Kan {
110 layers,
111 embedding_layer: embedding_table,
112 model_type: options.model_type,
113 class_map: options.class_map.clone(),
114 }
115 }
116
    /// returns the type of the model (metadata only; does not affect computation)
    pub fn model_type(&self) -> ModelType {
        self.model_type
    }
121
    /// returns a reference to the class map of the model, if it has one
    pub fn class_map(&self) -> Option<&Vec<String>> {
        self.class_map.as_ref()
    }
126
127 /// Returns the index of the output node that corresponds to the given label.
128 ///
129 /// Returns None if the label is not found in the model's class map, or if the model does not have a class map
130 ///
131 /// # Example
132 /// creating a model with a class map
133 /// ```
134 /// use fekan::kan::{Kan, KanOptions, ModelType};
135 /// let my_class_map = vec!["cat".to_string(), "dog".to_string()];
136 /// let options = KanOptions {
137 /// num_features: 5,
138 /// layer_sizes: vec![4, 2],
139 /// degree: 3,
140 /// coef_size: 6,
141 /// model_type: ModelType::Regression,
142 /// class_map: Some(my_class_map),
143 /// embedding_options: None,
144 /// };
145 /// let model = Kan::new(&options);
146 /// assert_eq!(model.label_to_node("cat"), Some(0));
147 /// assert_eq!(model.label_to_node("dog"), Some(1));
148 /// assert_eq!(model.label_to_node("fish"), None);
149 /// ```
150 /// Using a model's class map during training to determine the index of node that should have had the highest value
151 /// ```
152 /// # use fekan::kan::{Kan, KanOptions, ModelType};
153 /// # let my_class_map = vec!["cat".to_string(), "dog".to_string()];
154 /// # let options = KanOptions {
155 /// # num_features: 5,
156 /// # layer_sizes: vec![4, 2],
157 /// # degree: 3,
158 /// # coef_size: 6,
159 /// # model_type: ModelType::Regression,
160 /// # class_map: Some(my_class_map),
161 /// # embedding_options: None,
162 /// # };
163 /// # let mut model = Kan::new(&options);
164 /// # let feature_data = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
165 /// # let label = "cat";
166 /// # fn cross_entropy_loss(output: Vec<f64>, expected_highest_node: usize) -> f64 {0.0}
167 /// /* within your custom training function */
168 /// let batch_logits: Vec<Vec<f64>> = model.forward(feature_data)?;
169 /// for logits in batch_logits {
170 /// let expected_highest_node: usize = model.label_to_node(label).unwrap();
171 /// let loss: f64 = cross_entropy_loss(logits, expected_highest_node);
172 /// }
173 /// # Ok::<(), fekan::kan::kan_error::KanError>(())
174 /// ```
175 pub fn label_to_node(&self, label: &str) -> Option<usize> {
176 if let Some(class_map) = &self.class_map {
177 class_map.iter().position(|x| x == label)
178 } else {
179 None
180 }
181 }
182
183 /// Returns the label for the output node at the given index.
184 ///
185 /// Returns None if the index is out of bounds, or if the model does not have a class map
186 ///
187 /// # Example
188 /// ```
189 /// use fekan::kan::{Kan, KanOptions, ModelType};
190 /// let class_map = vec!["cat".to_string(), "dog".to_string()];
191 /// let options = KanOptions {
192 /// num_features: 5,
193 /// layer_sizes: vec![4, 2],
194 /// degree: 3,
195 /// coef_size: 6,
196 /// model_type: ModelType::Regression,
197 /// class_map: Some(class_map),
198 /// embedding_options: None,
199 /// };
200 /// let model = Kan::new(&options);
201 /// assert_eq!(model.node_to_label(0), Some("cat"));
202 /// assert_eq!(model.node_to_label(1), Some("dog"));
203 /// assert_eq!(model.node_to_label(2), None);
204 /// ```
205 /// Using a model's class map during inference to interpret the output of a classifier
206 /// ```
207 /// # use fekan::kan::{Kan, KanOptions, ModelType};
208 /// # let my_class_map = vec!["cat".to_string(), "dog".to_string()];
209 /// # let options = KanOptions {
210 /// # num_features: 5,
211 /// # layer_sizes: vec![4, 2],
212 /// # degree: 3,
213 /// # coef_size: 6,
214 /// # model_type: ModelType::Regression,
215 /// # class_map: Some(my_class_map),
216 /// embedding_options: None,
217 /// # };
218 /// # let model = Kan::new(&options);
219 /// # let feature_data = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
220 /// /* using an already trained model... */
221 /// let batch_logits: Vec<Vec<f64>> = model.infer(feature_data)?;
222 /// for logits in batch_logits {
223 /// let highest_node: usize = logits.iter().enumerate().max_by(|(a_idx, a_val), (b_idx, b_val)| a_val.partial_cmp(b_val).unwrap()).unwrap().0;
224 /// let label: &str = model.node_to_label(highest_node).unwrap();
225 /// println!("The model predicts the input is a {}", label);
226 /// }
227 /// Ok::<(), fekan::kan::kan_error::KanError>(())
228 /// ```
229 pub fn node_to_label(&self, node: usize) -> Option<&str> {
230 if let Some(class_map) = &self.class_map {
231 class_map.get(node).map(|x| x.as_str())
232 } else {
233 None
234 }
235 }
236
237 /// Forward-propogate the input through the model, by calling [`KanLayer::forward`] method of the first layer on the input,
238 /// then calling the `forward` method of each subsequent layer with the output of the previous layer,
239 /// returning the output of the final layer.
240 ///
241 /// This method accumulates internal state in the model needed for training. For inference or validation, use [`Kan::infer`], which does not accumulate state and is more efficient
242 ///
243 /// # Errors
244 /// returns a [`KanError`] if any layer returns an error.
245 /// See [`KanLayer::forward`] for more information
246 ///
247 /// # Example
248 /// ```
249 /// use fekan::kan::{Kan, KanOptions, ModelType, kan_error::KanError};
250 /// let num_features = 5;
251 /// let output_size = 3;
252 /// let options = KanOptions {
253 /// num_features,
254 /// layer_sizes: vec![4, output_size],
255 /// degree: 3,
256 /// coef_size: 6,
257 /// model_type: ModelType::Classification,
258 /// class_map: None,
259 /// embedding_options: None,
260 /// };
261 /// let mut model = Kan::new(&options);
262 /// let batch_size = 2;
263 /// let input = vec![vec![1.0; num_features]; batch_size];
264 /// let output = model.forward(input)?;
265 /// assert_eq!(output.len(), batch_size);
266 /// assert_eq!(output[0].len(), output_size);
267 /// /* interpret the output as you like, for example as logits in a classifier, or as predicted value in a regressor */
268 /// # Ok::<(), fekan::kan::kan_error::KanError>(())
269 /// ```
270 pub fn forward(&mut self, input: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, KanError> {
271 debug!("Forwarding {} samples through model", input.len());
272 trace!("Preactivations: {:?}", input);
273 let mut preacts = input;
274 if let Some(embedding_layer) = self.embedding_layer.as_mut() {
275 preacts = embedding_layer
276 .forward(preacts)
277 .map_err(|e| KanError::forward(e, 0))?;
278 }
279 for (idx, layer) in self.layers.iter_mut().enumerate() {
280 debug!("Forwarding through layer {}", idx);
281 preacts = layer
282 .forward(preacts)
283 .map_err(|e| KanError::forward(e, idx))?;
284 }
285 Ok(preacts)
286 }
287
288 /// as [`Kan::forward`], but uses multiple threads to forward the input through the model
289 pub fn forward_multithreaded(
290 &mut self,
291 input: Vec<Vec<f64>>,
292 num_threads: usize,
293 ) -> Result<Vec<Vec<f64>>, KanError> {
294 debug!(
295 "Forwarding {} samples through model using {} threads",
296 input.len(),
297 num_threads
298 );
299 trace!("Preactivations: {:?}", input);
300 let mut preacts = input;
301 if let Some(embedding_layer) = self.embedding_layer.as_mut() {
302 preacts = embedding_layer
303 .forward(preacts)
304 .map_err(|e| KanError::forward(e, 0))?;
305 }
306 for (idx, layer) in self.layers.iter_mut().enumerate() {
307 debug!("Forwarding through layer {}", idx);
308 preacts = layer
309 .forward_multithreaded(preacts, num_threads)
310 .map_err(|e| KanError::forward(e, idx))?;
311 }
312 Ok(preacts)
313 }
314
315 /// as [`Kan::forward`], but does not accumulate any internal state
316 ///
317 /// This method should be used when the model is not being trained, for example during inference or validation: when you won't be backpropogating, this method is faster uses less memory than [`Kan::forward`]
318 ///
319 /// # Errors
320 /// returns a [KanError] if any layer returns an error.
321 /// # Example
322 /// see [`Kan::forward`] for an example
323 pub fn infer(&self, input: Vec<Vec<f64>>) -> Result<Vec<Vec<f64>>, KanError> {
324 let mut preacts = input;
325 if let Some(embedding_layer) = self.embedding_layer.as_ref() {
326 preacts = embedding_layer
327 .infer(&preacts)
328 .map_err(|e| KanError::forward(e, 0))?;
329 }
330 for (idx, layer) in self.layers.iter().enumerate() {
331 preacts = layer
332 .infer(&preacts)
333 .map_err(|e| KanError::forward(e, idx))?;
334 }
335 Ok(preacts)
336 }
337
338 /// as [`Kan::infer`], but uses multiple threads to forward the input through the model
339 pub fn infer_multithreaded(
340 &self,
341 input: Vec<Vec<f64>>,
342 num_threads: usize,
343 ) -> Result<Vec<Vec<f64>>, KanError> {
344 let mut preacts = input;
345 if let Some(embedding_layer) = self.embedding_layer.as_ref() {
346 preacts = embedding_layer
347 .infer(&preacts)
348 .map_err(|e| KanError::forward(e, 0))?;
349 }
350 for (idx, layer) in self.layers.iter().enumerate() {
351 preacts = layer
352 .infer_multithreaded(&preacts, num_threads)
353 .map_err(|e| KanError::forward(e, idx))?;
354 }
355 Ok(preacts)
356 }
357
358 /// Back-propogate the gradient through the model, internally accumulating the gradients of the model's parameters, to be applied later with [`Kan::update`]
359 ///
360 /// # Errors
361 /// returns an error if any layer returns an error.
362 /// See [`KanLayer::backward`] for more information
363 ///
364 /// # Example
365 /// ```
366 /// use fekan::kan::{Kan, KanOptions, ModelType, kan_error::KanError};
367 ///
368 /// # let options = KanOptions {
369 /// # num_features: 5,
370 /// # layer_sizes: vec![4, 1],
371 /// # degree: 3,
372 /// # coef_size: 6,
373 /// # model_type: ModelType::Regression,
374 /// # class_map: None,
375 /// # embedding_options: None,
376 /// # };
377 /// let mut model = Kan::new(&options);
378 ///
379 /// # fn calculate_gradient(output: &[Vec<f64>], label: f64) -> Vec<Vec<f64>> {vec![vec![1.0; output[0].len()]; output.len()]}
380 /// # let learning_rate = 0.1;
381 /// # let l1_penalty = 0.0;
382 /// # let entropy_penalty = 0.0;
383 /// # let features = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];
384 /// # let label = 0;
385 /// let output = model.forward(features)?;
386 /// let gradient = calculate_gradient(&output, label as f64);
387 /// assert_eq!(gradient.len(), output.len());
388 /// let _ = model.backward(gradient)?; // the input gradient can be disregarded here.
389 ///
390 /// /*
391 /// * The model has stored the gradients for it's parameters internally.
392 /// * We can add conduct as many forward/backward pass-pairs as we like to accumulate gradient,
393 /// * until we're ready to update the paramaters.
394 /// */
395 ///
396 /// model.update(learning_rate, l1_penalty, entropy_penalty); // update the parameters of the model based on the accumulated gradients here
397 /// model.zero_gradients(); // zero the gradients for the next batch of training data
398 /// # Ok::<(), fekan::kan::kan_error::KanError>(())
399 /// ```
400 pub fn backward(&mut self, gradients: Vec<Vec<f64>>) -> Result<(), KanError> {
401 debug!("Backwarding {} gradients through model", gradients.len());
402 let mut gradients = gradients;
403 for (idx, layer) in self.layers.iter_mut().enumerate().rev() {
404 gradients = layer
405 .backward(&gradients)
406 .map_err(|e| KanError::backward(e, idx))?;
407 }
408 if let Some(embedding_layer) = self.embedding_layer.as_mut() {
409 embedding_layer
410 .backward(gradients)
411 .map_err(|e| KanError::backward(e, 0))?;
412 }
413
414 Ok(())
415 }
416
417 /// as [`Kan::backward`], but uses multiple threads to back-propogate the gradient through the model
418 pub fn backward_multithreaded(
419 &mut self,
420 gradients: Vec<Vec<f64>>,
421 num_threads: usize,
422 ) -> Result<(), KanError> {
423 debug!("Backwarding {} gradients through model", gradients.len());
424 let mut gradients = gradients;
425 for (idx, layer) in self.layers.iter_mut().enumerate().rev() {
426 gradients = layer
427 .backward_multithreaded(&gradients, num_threads)
428 .map_err(|e| KanError::backward(e, idx))?;
429 }
430 if let Some(embedding_layer) = self.embedding_layer.as_mut() {
431 embedding_layer
432 .backward(gradients)
433 .map_err(|e| KanError::backward(e, 0))?;
434 }
435
436 Ok(())
437 }
438
439 /// Update the model's parameters based on the gradients that have been accumulated with [`Kan::backward`].
440 /// # Example
441 /// see [`Kan::backward`]
442 pub fn update(&mut self, learning_rate: f64, l1_penalty: f64, entropy_penalty: f64) {
443 if let Some(embedding_table) = self.embedding_layer.as_mut() {
444 embedding_table.update(learning_rate);
445 }
446 for layer in self.layers.iter_mut() {
447 layer.update(learning_rate, l1_penalty, entropy_penalty);
448 }
449 }
450
451 /// Zero the internal gradients of the model's parameters
452 /// # Example
453 /// see [`Kan::backward`]
454 pub fn zero_gradients(&mut self) {
455 if let Some(embedding_table) = self.embedding_layer.as_mut() {
456 embedding_table.zero_gradients();
457 }
458 for layer in self.layers.iter_mut() {
459 layer.zero_gradients();
460 }
461 }
462
463 /// get the total number of parameters in the model, inlcuding untrained parameters. See [`KanLayer::parameter_count`] for more information
464 pub fn parameter_count(&self) -> usize {
465 self.layers
466 .iter()
467 .map(|layer| layer.parameter_count())
468 .sum()
469 }
470
471 /// get the total number of trainable parameters in the model. See [`KanLayer::trainable_parameter_count`] for more information
472 pub fn trainable_parameter_count(&self) -> usize {
473 self.layers
474 .iter()
475 .map(|layer| layer.trainable_parameter_count())
476 .sum()
477 }
478
479 /// Update the ranges spanned by the B-spline knots in the model, using samples accumulated by recent [`Kan::forward`] calls.
480 ///
481 /// see [`KanLayer::update_knots_from_samples`] for more information
482 /// # Errors
483 /// returns a [KanError] if any layer returns an error.
484 ///
485 /// # Example
486 /// see [`KanLayer::update_knots_from_samples`] for examples
487 pub fn update_knots_from_samples(&mut self, knot_adaptivity: f64) -> Result<(), KanError> {
488 for (idx, layer) in self.layers.iter_mut().enumerate() {
489 debug!("Updating knots for layer {}", idx);
490 if let Err(e) = layer.update_knots_from_samples(knot_adaptivity) {
491 return Err(KanError::update_knots(e, idx));
492 }
493 }
494 return Ok(());
495 }
496
497 /// as [`Kan::update_knots_from_samples`], but uses multiple threads to update the knot vectors
498 pub fn update_knots_from_samples_multithreaded(
499 &mut self,
500 knot_adaptivity: f64,
501 num_threads: usize,
502 ) -> Result<(), KanError> {
503 for (idx, layer) in self.layers.iter_mut().enumerate() {
504 debug!("Updating knots for layer {}", idx);
505 if let Err(e) =
506 layer.update_knots_from_samples_multithreaded(knot_adaptivity, num_threads)
507 {
508 return Err(KanError::update_knots(e, idx));
509 }
510 }
511 return Ok(());
512 }
513
514 /// Clear the cached samples used by [`Kan::update_knots_from_samples`]
515 ///
516 /// see [`KanLayer::clear_samples`] for more information
517 pub fn clear_samples(&mut self) {
518 debug!("Clearing samples from model");
519 if let Some(embedding_table) = self.embedding_layer.as_mut() {
520 embedding_table.clear_samples();
521 }
522 for layer_idx in 0..self.layers.len() {
523 debug!("Clearing samples from layer {}", layer_idx);
524 self.layers[layer_idx].clear_samples();
525 }
526 }
527
528 /// Set the size of the knot vector used in all splines in this model
529 /// see [`KanLayer::set_knot_length`] for more information
530 pub fn set_knot_length(&mut self, knot_length: usize) -> Result<(), KanError> {
531 for (idx, layer) in self.layers.iter_mut().enumerate() {
532 if let Err(e) = layer.set_knot_length(knot_length) {
533 return Err(KanError::set_knot_length(e, idx));
534 }
535 }
536 Ok(())
537 }
538
    /// Get the size of the knot vector used in all splines in this model
    ///
    /// ## Note
    /// if different layers have different knot lengths, this method will return the knot length of the first layer
    ///
    /// ## Panics
    /// panics if the model has no layers (i.e. it was built with an empty `layer_sizes`)
    pub fn knot_length(&self) -> usize {
        self.layers[0].knot_length()
    }
546
547 /// Create a new model by merging multiple models together. Models must be of the same type and have the same number of layers, and all layers must be mergable (see [`KanLayer::merge_layers`])
548 /// # Errors
549 /// Returns a [`KanError`] if:
550 /// * the models are not mergable. See [`Kan::models_mergable`] for more information
551 /// * any layer encounters an error during the merge. See [`KanLayer::merge_layers`] for more information
552 /// # Example
553 /// Train multiple copies of the model on different data in different threads, then merge the trained models together
554 /// ```
555 /// use fekan::{kan::{Kan, KanOptions, ModelType, kan_error::KanError}, Sample};
556 /// use std::thread;
557 /// # let model_options = KanOptions {
558 /// # num_features: 5,
559 /// # layer_sizes: vec![4, 3],
560 /// # degree: 3,
561 /// # coef_size: 6,
562 /// # model_type: ModelType::Regression,
563 /// # class_map: None,
564 /// # embedding_options: None,
565 /// };
566 /// # let num_training_threads = 1;
567 /// # let training_data = vec![ Sample::new_regression_sample(vec![], 0.0) ];
568 /// # fn my_train_model_function(model: Kan, data: &[Sample]) -> Kan {model}
569 /// let mut my_model = Kan::new(&model_options);
570 /// let partially_trained_models: Vec<Kan> = thread::scope(|s|{
571 /// let chunk_size = f32::ceil(training_data.len() as f32 / num_training_threads as f32) as usize; // round up, since .chunks() gives up-to chunk_size chunks. This way to don't leave any data on the cutting room floor
572 /// let handles: Vec<_> = training_data.chunks(chunk_size).map(|training_data_chunk|{
573 /// let clone_model = my_model.clone();
574 /// s.spawn(move ||{
575 /// my_train_model_function(clone_model, training_data_chunk) // `my_train_model_function` is a stand-in for whatever function you're using to train the model - not actually defined in this crate
576 /// })
577 /// }).collect();
578 /// handles.into_iter().map(|handle| handle.join().unwrap()).collect()
579 /// });
580 /// let fully_trained_model = Kan::merge_models(partially_trained_models)?;
581 /// # Ok::<(), fekan::kan::kan_error::KanError>(())
582 /// ```
583 ///
584 pub fn merge_models(models: Vec<Kan>) -> Result<Kan, KanError> {
585 Self::models_mergable(&models)?; // check if the models are mergable
586 let layer_count = models[0].layers.len();
587 let model_type = models[0].model_type;
588 let class_map = models[0].class_map.clone();
589 let merged_embedding_layer = if models[0].embedding_layer.is_some() {
590 let embedding_layers: Vec<&EmbeddingLayer> = models
591 .iter()
592 .map(|model| model.embedding_layer.as_ref().unwrap())
593 .collect();
594 let merged_embedding_layer = EmbeddingLayer::merge_layers(&embedding_layers)
595 .map_err(|e| KanError::merge_unmergable_layers(e, 0))?;
596 Some(merged_embedding_layer)
597 } else {
598 None
599 };
600 // merge the layers
601 let mut all_layers: Vec<VecDeque<KanLayer>> = models
602 .into_iter()
603 .map(|model| model.layers.into())
604 .collect();
605 // aggregate layers by index
606 let mut merged_layers = Vec::new();
607 for layer_idx in 0..layer_count {
608 let layers_to_merge: Vec<KanLayer> = all_layers
609 .iter_mut()
610 .map(|layers| {
611 layers
612 .pop_front()
613 .expect("iterated past end of dequeue while merging models")
614 })
615 .collect();
616 let merged_layer = KanLayer::merge_layers(layers_to_merge)
617 .map_err(|e| KanError::merge_unmergable_layers(e, layer_idx))?;
618 merged_layers.push(merged_layer);
619 }
620
621 let merged_model = Kan {
622 embedding_layer: merged_embedding_layer,
623 layers: merged_layers,
624 model_type,
625 class_map,
626 };
627 Ok(merged_model)
628 }
629
    /// Check if the given models can be merged using [`Kan::merge_models`]. Returns Ok(()) if the models are mergable, an error otherwise
    /// # Errors
    /// Returns a [`KanError`] if any of the models:
    /// * have different model types (e.g. classification vs regression)
    /// * have different numbers of layers
    /// * have different class maps (if the models are classification models)
    /// * have mismatched embedding-layer presence (some with, some without)
    /// # Panics
    /// panics if `models` is empty — `models[0]` is indexed unconditionally.
    // NOTE(review): the original docs claimed an error is returned for an empty
    // slice, but no such check exists; consider adding a dedicated KanError
    // variant and an `is_empty` guard — TODO confirm desired behavior
    pub fn models_mergable(models: &[Kan]) -> Result<(), KanError> {
        // the first model defines the expected shape; every other model is
        // compared against it
        let expected_model_type = models[0].model_type;
        let expected_class_map = &models[0].class_map;
        let expected_layer_count = models[0].layers.len();
        let expected_embedding_table = &models[0].embedding_layer;
        for idx in 1..models.len() {
            if models[idx].model_type != expected_model_type {
                return Err(KanError::merge_mismatched_model_type(
                    idx,
                    expected_model_type,
                    models[idx].model_type,
                ));
            }
            if models[idx].class_map != *expected_class_map {
                return Err(KanError::merge_mismatched_class_map(
                    idx,
                    expected_class_map.clone(),
                    models[idx].class_map.clone(),
                ));
            }
            if models[idx].layers.len() != expected_layer_count {
                return Err(KanError::merge_mismatched_depth_model(
                    idx,
                    expected_layer_count,
                    models[idx].layers.len(),
                ));
            }
            // only presence/absence of the embedding layer is checked here; the
            // contents are validated by EmbeddingLayer::merge_layers itself
            if models[idx].embedding_layer.is_some() != expected_embedding_table.is_some() {
                return Err(KanError::merge_mismatched_embedding_table_presence(
                    idx,
                    expected_embedding_table.is_some(),
                    models[idx].embedding_layer.is_some(),
                ));
            }
        }
        Ok(())
    }
674
675 /// Test and set the symbolic status of the model, using the given R^2 threshold. See [`KanLayer::test_and_set_symbolic`] for more information
676 pub fn test_and_set_symbolic(&mut self, r2_threshold: f64) -> Vec<(usize, usize)> {
677 debug!("Testing and setting symbolic for the model");
678 let mut symbolified_edges = Vec::new();
679 for i in 0..self.layers.len() {
680 debug!("Symbolifying layer {}", i);
681 let clamped_edges = self.layers[i].test_and_set_symbolic(r2_threshold);
682 symbolified_edges.extend(clamped_edges.into_iter().map(|j| (i, j)));
683 }
684 symbolified_edges
685 }
686
687 /// Prune the model, removing any edges that have low average output. See [`KanLayer::prune`] for more information
688 /// Returns a list of the indices of pruned edges (i,j), where i is the index of the layer, and j is the index of the edge in that layer that was pruned
689 pub fn prune(
690 &mut self,
691 samples: Vec<Vec<f64>>,
692 threshold: f64,
693 ) -> Result<Vec<(usize, usize)>, KanError> {
694 let mut pruned_edges = Vec::new();
695 let mut samples = match &self.embedding_layer {
696 None => samples,
697 Some(embedding_layer) => embedding_layer
698 .infer(&samples)
699 .map_err(|e| KanError::forward(e, 0))?,
700 };
701 for i in 0..self.layers.len() {
702 debug!("Pruning layer {}", i);
703 let this_layer = &mut self.layers[i];
704 let next_samples = this_layer
705 .infer(&samples)
706 .map_err(|e| KanError::forward(e, i))?;
707 let layer_prunings = this_layer.prune(&samples, threshold);
708 pruned_edges.extend(layer_prunings.into_iter().map(|j| (i, j)));
709 samples = next_samples;
710 }
711 Ok(pruned_edges)
712 }
713}
714
// NOTE(review): equality deliberately(?) compares only `layers` and `model_type`;
// `embedding_layer` and `class_map` are ignored, so two models with different
// class maps or embedding tables can compare equal — confirm this is intended
impl PartialEq for Kan {
    fn eq(&self, other: &Self) -> bool {
        self.layers == other.layers && self.model_type == other.model_type
    }
}
720
#[cfg(test)]
mod test {
    use super::*;

    /// Two models with different hidden-layer shapes both produce a batch of
    /// outputs sized by their final layer.
    #[test]
    fn test_forward() {
        let base_config = KanOptions {
            num_features: 3,
            layer_sizes: vec![4, 2, 3],
            degree: 3,
            coef_size: 4,
            model_type: ModelType::Classification,
            class_map: None,
            embedding_options: None,
        };
        let mut model_a = Kan::new(&base_config);
        let mut model_b = Kan::new(&KanOptions {
            layer_sizes: vec![2, 4, 3],
            ..base_config
        });
        let sample = vec![vec![0.5, 0.4, 0.5]];

        let output_a = model_a.forward(sample.clone()).unwrap();
        assert_eq!(output_a.len(), 1);
        assert_eq!(output_a[0].len(), 3);

        let output_b = model_b.forward(sample).unwrap();
        assert_eq!(output_b.len(), 1);
        assert_eq!(output_b[0].len(), 3);
    }

    /// A forward pass followed by a matching-shape backward pass succeeds.
    #[test]
    fn test_forward_then_backward() {
        let options = KanOptions {
            num_features: 5,
            layer_sizes: vec![4, 2, 3],
            degree: 3,
            coef_size: 4,
            model_type: ModelType::Classification,
            class_map: None,
            embedding_options: None,
        };
        let mut model = Kan::new(&options);
        let sample = vec![vec![0.5, 0.4, 0.5, 0.5, 0.4]];

        let output = model.forward(sample).unwrap();
        assert_eq!(output.len(), 1);
        assert_eq!(output[0].len(), *options.layer_sizes.last().unwrap());

        let gradient = vec![vec![0.5, 0.4, 0.5]];
        assert!(model.backward(gradient).is_ok());
    }

    /// Merging a model with an identical clone yields a model whose inference
    /// output matches the originals.
    #[test]
    fn test_merge_identical_models_yields_identical_output() {
        let config = KanOptions {
            num_features: 3,
            layer_sizes: vec![4, 2, 3],
            degree: 3,
            coef_size: 4,
            model_type: ModelType::Classification,
            class_map: None,
            embedding_options: None,
        };
        let original = Kan::new(&config);
        let duplicate = original.clone();
        let sample = vec![vec![0.5, 0.4, 0.5]];

        let expected = original.infer(sample.clone()).unwrap();
        assert_eq!(expected, duplicate.infer(sample.clone()).unwrap());

        let merged = Kan::merge_models(vec![original, duplicate]).unwrap();
        assert_eq!(expected, merged.infer(sample).unwrap());
    }

    #[test]
    fn test_model_send() {
        fn assert_send<T: Send>() {}
        assert_send::<Kan>();
    }

    #[test]
    fn test_model_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<Kan>();
    }

    #[test]
    fn test_error_send() {
        fn assert_send<T: Send>() {}
        assert_send::<KanError>();
    }

    #[test]
    fn test_error_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<KanError>();
    }
}
817}