1use cnc::nn::{DeepModelParams, Model, ModelFeatures, StandardModelConfig};
7#[cfg(feature = "rand")]
8use cnc::rand_distr;
9
10use num_traits::{Float, FromPrimitive};
11
12#[derive(Clone, Debug)]
13pub struct KanModel<T = f64> {
14 pub config: StandardModelConfig<T>,
15 pub features: ModelFeatures,
16 pub params: DeepModelParams<T>,
17}
18
19impl<T> KanModel<T>
20where
21 T: Float + FromPrimitive,
22{
23 pub fn new(config: StandardModelConfig<T>, features: ModelFeatures) -> Self
24 where
25 T: Clone + Default,
26 {
27 let params = DeepModelParams::default(features);
28 KanModel {
29 config,
30 features,
31 params,
32 }
33 }
34 #[cfg(feature = "rand")]
35 pub fn init(self) -> Self
36 where
37 rand_distr::StandardNormal: rand_distr::Distribution<T>,
38 {
39 let params = DeepModelParams::glorot_normal(self.features());
40 KanModel { params, ..self }
41 }
42 pub const fn config(&self) -> &StandardModelConfig<T> {
44 &self.config
45 }
46 pub const fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
48 &mut self.config
49 }
50 pub const fn features(&self) -> ModelFeatures {
52 self.features
53 }
54 pub const fn features_mut(&mut self) -> &mut ModelFeatures {
56 &mut self.features
57 }
58 pub const fn params(&self) -> &DeepModelParams<T> {
60 &self.params
61 }
62 pub const fn params_mut(&mut self) -> &mut DeepModelParams<T> {
64 &mut self.params
65 }
66 pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
68 self.config = config;
69 self
70 }
71 pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
73 self.features = features;
74 self
75 }
76 pub fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
78 self.params = params;
79 self
80 }
81 pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
83 Self { config, ..self }
84 }
85 pub fn with_features(self, features: ModelFeatures) -> Self {
87 Self { features, ..self }
88 }
89 pub fn with_params(self, params: DeepModelParams<T>) -> Self {
91 Self { params, ..self }
92 }
93}
94
95impl<T> Model<T> for KanModel<T> {
96 type Config = StandardModelConfig<T>;
97 type Layout = ModelFeatures;
98
99 fn config(&self) -> &StandardModelConfig<T> {
100 &self.config
101 }
102
103 fn config_mut(&mut self) -> &mut StandardModelConfig<T> {
104 &mut self.config
105 }
106
107 fn layout(&self) -> ModelFeatures {
108 self.features
109 }
110
111 fn params(&self) -> &DeepModelParams<T> {
112 &self.params
113 }
114
115 fn params_mut(&mut self) -> &mut DeepModelParams<T> {
116 &mut self.params
117 }
118}