burn_efficient_kan/
lib.rs1#![warn(clippy::nursery, clippy::pedantic)]
2mod down_dim;
3mod kaiming;
4mod least_squares;
5mod linear;
6use burn::{
7 config::Config,
8 module::Module,
9 tensor::{backend::Backend, Tensor},
10};
11use linear::Linear;
12
/// A two-layer Kolmogorov–Arnold network built from two [`Linear`] layers.
///
/// Construct it with [`Kan::new`] or [`KanOptions::init`]; run it with
/// [`Kan::forward`], which applies `layer_one` then `layer_two`.
///
/// NOTE(review): field order presumably determines the layout of the record
/// generated by `#[derive(Module)]`; keep it stable to stay compatible with
/// previously saved records.
#[derive(Debug, Module)]
pub struct Kan<B: Backend> {
    // Copied from `KanOptions::grid_size`; not read by the methods of `Kan`
    // visible here — presumably kept for use elsewhere (TODO confirm).
    grid_size: u16,
    // Copied from `KanOptions::spline_order`; not read by the methods of `Kan`
    // visible here — presumably kept for use elsewhere (TODO confirm).
    spline_order: u32,
    // Maps `layers_hidden[0]` input features to `layers_hidden[1]` features.
    layer_one: Linear<B>,
    // Maps `layers_hidden[1]` features to `layers_hidden[2]` output features.
    layer_two: Linear<B>,
}
20
/// Configuration used to build a [`Kan`] model.
///
/// Only `layers_hidden` is required; every other field has a default supplied
/// through `#[config(default = …)]`. The whole value is handed to each
/// `Linear::new` call when the model is built.
///
/// NOTE(review): the `Config` derive presumably serializes fields in
/// declaration order; keep the order stable for saved configs.
#[derive(Config, Debug)]
pub struct KanOptions {
    /// Layer widths `[input, hidden, output]`: the first layer maps
    /// `layers_hidden[0] -> layers_hidden[1]` features, the second maps
    /// `layers_hidden[1] -> layers_hidden[2]`.
    pub layers_hidden: [u32; 3],
    /// Grid size stored on the model (default `5`); presumably the number of
    /// spline grid intervals — TODO confirm in `linear.rs`.
    #[config(default = 5)]
    pub grid_size: u16,
    /// Spline order stored on the model (default `3`); presumably the B-spline
    /// degree — TODO confirm in `linear.rs`.
    #[config(default = 3)]
    pub spline_order: u32,
    /// Default `0.1`; presumably the magnitude of the noise used when
    /// initializing spline weights — TODO confirm in `linear.rs`.
    #[config(default = 0.1)]
    pub scale_noise: f32,
    /// Default `1.0`; presumably scales the base-weight initialization —
    /// TODO confirm in `linear.rs`.
    #[config(default = 1.0)]
    pub scale_base: f32,
    /// Default `1.0`; presumably scales the spline-weight initialization —
    /// TODO confirm in `linear.rs`.
    #[config(default = 1.0)]
    pub scale_spline: f32,
    /// Default `0.02`; presumably the blending factor between a uniform and an
    /// adaptive grid — TODO confirm in `linear.rs`.
    #[config(default = 0.02)]
    pub grid_eps: f32,
    /// Default `true`; presumably toggles a separate learnable spline scaler.
    ///
    /// NOTE(review): the name looks like a typo of "…scale_spline"; it is kept
    /// as-is because renaming a public config field would break callers and
    /// previously serialized configs.
    #[config(default = true)]
    pub enable_standalone_scale_spine: bool,
    /// Default `-1`; presumably the lower bound of the spline grid range —
    /// TODO confirm in `linear.rs`.
    #[config(default = -1)]
    pub grid_range_start: i32,
    /// Default `1`; presumably the upper bound of the spline grid range —
    /// TODO confirm in `linear.rs`.
    #[config(default = 1)]
    pub grid_range_end: i32,
}
43
44impl KanOptions {
45 pub fn init<B: Backend>(&self, device: &B::Device) -> Kan<B>
46 where
47 B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
48 {
49 Kan::new(self, device)
50 }
51}
52
53impl<B: Backend> Kan<B> {
54 pub fn new(options: &KanOptions, device: &B::Device) -> Self
55 where
56 B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
57 {
58 let layers_hidden = options.layers_hidden;
59 let zip = [
60 (layers_hidden[0], layers_hidden[1]),
61 (layers_hidden[1], layers_hidden[2]),
62 ];
63 let [layer_one, layer_two] = zip.map(|(in_features, out_features)| {
64 Linear::new(in_features, out_features, options, device)
65 });
66 Self {
67 grid_size: options.grid_size,
68 spline_order: options.spline_order,
69 layer_one,
70 layer_two,
71 }
72 }
73
74 const fn layers(&self) -> [&Linear<B>; 2] {
75 [&self.layer_one, &self.layer_two]
76 }
77
78 pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
79 let mut output = input;
80 for layer in self.layers() {
81 output = layer.forward(&output);
82 }
83 output
84 }
85}