concision_kan/
model.rs

/*
    appellation: model <module>
    authors: @FL03
*/
5
6use cnc::nn::{DeepModelParams, Model, ModelFeatures, StandardModelConfig};
7#[cfg(feature = "rand")]
8use cnc::rand_distr;
9
10use num_traits::{Float, FromPrimitive};
11
12#[derive(Clone, Debug)]
13pub struct KanModel<T = f64> {
14    pub config: StandardModelConfig<T>,
15    pub features: ModelFeatures,
16    pub params: DeepModelParams<T>,
17}
18
19impl<T> KanModel<T>
20where
21    T: Float + FromPrimitive,
22{
23    pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
24    where
25        T: Clone + Default,
26    {
27        let params = DeepModelParams::default(features);
28        KanModel {
29            config,
30            features,
31            params,
32        }
33    }
34    #[cfg(feature = "rand")]
35    pub fn init(self) -> Self
36    where
37        rand_distr::StandardNormal: rand_distr::Distribution<T>,
38    {
39        let params = DeepModelParams::glorot_normal(self.features());
40        KanModel { params, ..self }
41    }
42    /// returns a reference to the model configuration
43    pub const fn config(&self) -> &StandardModelConfig<T> {
44        &self.config
45    }
46    /// returns a mutable reference to the model configuration
47    pub const fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
48        &mut self.config
49    }
50    /// returns the model features
51    pub const fn features(&self) -> ModelFeatures {
52        self.features
53    }
54    /// returns a mutable reference to the model features
55    pub const fn features_mut(&mut self) -> &mut ModelFeatures {
56        &mut self.features
57    }
58    /// returns a reference to the model parameters
59    pub const fn params(&self) -> &DeepModelParams<T> {
60        &self.params
61    }
62    /// returns a mutable reference to the model parameters
63    pub const fn params_mut(&mut self) -> &mut DeepModelParams<T> {
64        &mut self.params
65    }
66    /// set the current configuration and return a mutable reference to the model
67    pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
68        self.config = config;
69        self
70    }
71    /// set the current features and return a mutable reference to the model
72    pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
73        self.features = features;
74        self
75    }
76    /// set the current parameters and return a mutable reference to the model
77    pub fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
78        self.params = params;
79        self
80    }
81    /// consumes the current instance to create another with the given configuration
82    pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
83        Self { config, ..self }
84    }
85    /// consumes the current instance to create another with the given features
86    pub fn with_features(self, features: ModelFeatures) -> Self {
87        Self { features, ..self }
88    }
89    /// consumes the current instance to create another with the given parameters
90    pub fn with_params(self, params: DeepModelParams<T>) -> Self {
91        Self { params, ..self }
92    }
93}
94
95impl<T> Model<T> for KanModel<T> {
96    type Config = StandardModelConfig<T>;
97    type Layout = ModelFeatures;
98
99    fn config(&self) -> &StandardModelConfig<T> {
100        &self.config
101    }
102
103    fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
104        &mut self.config
105    }
106
107    fn layout(&self) -> ModelFeatures {
108        self.features
109    }
110
111    fn params(&self) -> &DeepModelParams<T> {
112        &self.params
113    }
114
115    fn params_mut(&mut self) -> &mut DeepModelParams<T> {
116        &mut self.params
117    }
118}