// File: burn_efficient_kan/lib.rs

#![warn(clippy::nursery, clippy::pedantic)]

mod down_dim;
mod kaiming;
mod least_squares;
mod linear;

use burn::{
    config::Config,
    module::Module,
    tensor::{backend::Backend, Tensor},
};
use linear::Linear;
12
13#[derive(Debug, Module)]
14pub struct Kan<B: Backend> {
15    grid_size: u16,
16    spline_order: u32,
17    layer_one: Linear<B>,
18    layer_two: Linear<B>,
19}
20
21#[derive(Config, Debug)]
22pub struct KanOptions {
23    pub layers_hidden: [u32; 3],
24    #[config(default = 5)]
25    pub grid_size: u16,
26    #[config(default = 3)]
27    pub spline_order: u32,
28    #[config(default = 0.1)]
29    pub scale_noise: f32,
30    #[config(default = 1.0)]
31    pub scale_base: f32,
32    #[config(default = 1.0)]
33    pub scale_spline: f32,
34    #[config(default = 0.02)]
35    pub grid_eps: f32,
36    #[config(default = true)]
37    pub enable_standalone_scale_spine: bool,
38    #[config(default = -1)]
39    pub grid_range_start: i32,
40    #[config(default = 1)]
41    pub grid_range_end: i32,
42}
43
44impl KanOptions {
45    pub fn init<B: Backend>(&self, device: &B::Device) -> Kan<B>
46    where
47        B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
48    {
49        Kan::new(self, device)
50    }
51}
52
53impl<B: Backend> Kan<B> {
54    pub fn new(options: &KanOptions, device: &B::Device) -> Self
55    where
56        B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
57    {
58        let layers_hidden = options.layers_hidden;
59        let zip = [
60            (layers_hidden[0], layers_hidden[1]),
61            (layers_hidden[1], layers_hidden[2]),
62        ];
63        let [layer_one, layer_two] = zip.map(|(in_features, out_features)| {
64            Linear::new(in_features, out_features, options, device)
65        });
66        Self {
67            grid_size: options.grid_size,
68            spline_order: options.spline_order,
69            layer_one,
70            layer_two,
71        }
72    }
73
74    const fn layers(&self) -> [&Linear<B>; 2] {
75        [&self.layer_one, &self.layer_two]
76    }
77
78    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
79        let mut output = input;
80        for layer in self.layers() {
81            output = layer.forward(&output);
82        }
83        output
84    }
85}