eryon_surface/model/model.rs

/*
    Appellation: mlp <module>
    Contrib: @FL03
*/
use super::{FftAttention, SurfaceModelConfig};
use cnc::prelude::{ModelFeatures, ModelParams, Params};
use num_traits::FromPrimitive;

/// A multi-layer perceptron based surface model with an optional [`FftAttention`] mechanism
#[derive(Clone, Debug)]
pub struct SurfaceModel<A = f64> {
    pub(crate) config: SurfaceModelConfig<A>,
    pub(crate) features: ModelFeatures,
    pub(crate) attention: Option<FftAttention<A>>,
    pub(crate) params: ModelParams<A>,
    pub(crate) velocities: ModelParams<A>,
    pub(crate) prev_target_norm: Option<A>,
}

impl<A> SurfaceModel<A> {
    /// create a new instance of the model; all parameters are initialized to their defaults
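    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest), assuming the `Default`
    /// constructors of [`SurfaceModelConfig`] and [`ModelFeatures`] yield a
    /// usable starting point:
    ///
    /// ```ignore
    /// let model = SurfaceModel::<f64>::default(
    ///     SurfaceModelConfig::default(),
    ///     ModelFeatures::default(),
    /// );
    /// assert!(model.has_no_attention());
    /// ```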
    pub fn default(config: SurfaceModelConfig<A>, features: ModelFeatures) -> Self
    where
        A: Clone + Default,
    {
        Self {
            attention: None,
            config,
            features,
            params: ModelParams::default(features),
            velocities: ModelParams::default(features),
            prev_target_norm: None,
        }
    }
    /// create a new instance of the model; all parameters are initialized to 1
    pub fn ones(config: SurfaceModelConfig<A>, features: ModelFeatures) -> Self
    where
        A: Clone + num_traits::One,
    {
        Self {
            attention: None,
            config,
            features,
            params: ModelParams::ones(features),
            velocities: ModelParams::ones(features),
            prev_target_norm: None,
        }
    }
    /// create a new instance of the model; all parameters are initialized to 0
    pub fn zeros(config: SurfaceModelConfig<A>, features: ModelFeatures) -> Self
    where
        A: Clone + num_traits::Zero,
    {
        Self {
            attention: None,
            config,
            features,
            params: ModelParams::zeros(features),
            velocities: ModelParams::zeros(features),
            prev_target_norm: None,
        }
    }
    /// returns an immutable reference to the attention mechanism of the model, if one is set
    pub const fn attention(&self) -> Option<&FftAttention<A>> {
        self.attention.as_ref()
    }
    /// returns a mutable reference to the attention mechanism of the model, if one is set
    pub const fn attention_mut(&mut self) -> Option<&mut FftAttention<A>> {
        self.attention.as_mut()
    }
    /// returns an immutable reference to the model's configuration
    pub const fn config(&self) -> &SurfaceModelConfig<A> {
        &self.config
    }
    /// returns a mutable reference to the model's configuration
    pub const fn config_mut(&mut self) -> &mut SurfaceModelConfig<A> {
        &mut self.config
    }
    /// returns a copy of the model's features
    pub const fn features(&self) -> ModelFeatures {
        self.features
    }
    /// returns a mutable reference to the model's features
    pub const fn features_mut(&mut self) -> &mut ModelFeatures {
        &mut self.features
    }
    /// returns an immutable reference to the model's parameters
    pub const fn params(&self) -> &ModelParams<A> {
        &self.params
    }
    /// returns a mutable reference to the model's parameters
    pub const fn params_mut(&mut self) -> &mut ModelParams<A> {
        &mut self.params
    }
    /// returns an immutable reference to the model's velocities
    pub const fn velocities(&self) -> &ModelParams<A> {
        &self.velocities
    }
    /// returns a mutable reference to the model's velocities
    pub const fn velocities_mut(&mut self) -> &mut ModelParams<A> {
        &mut self.velocities
    }
    /// returns an immutable reference to the input layer of the model
    pub const fn input(&self) -> &Params<A> {
        self.params().input()
    }
    /// returns a mutable reference to the input layer of the model
    pub const fn input_mut(&mut self) -> &mut Params<A> {
        self.params_mut().input_mut()
    }
    /// returns an immutable reference to the hidden layers of the model
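    ///
    /// # Example
    ///
    /// A tiny sketch (not compiled as a doctest) inspecting the layer layout
    /// of a `SurfaceModel<f64>` named `model`:
    ///
    /// ```ignore
    /// let n_hidden = model.hidden().len();
    /// println!("hidden layers: {n_hidden}");
    /// ```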
    pub const fn hidden(&self) -> &Vec<Params<A>> {
        self.params().hidden()
    }
    /// returns a mutable reference to the hidden layers of the model
    pub const fn hidden_mut(&mut self) -> &mut Vec<Params<A>> {
        self.params_mut().hidden_mut()
    }
    /// returns an immutable reference to the hidden layers of the model as a slice
    #[inline]
    pub fn hidden_as_slice(&self) -> &[Params<A>] {
        self.params().hidden_as_slice()
    }
    /// returns an immutable reference to the output layer of the model
    pub const fn output(&self) -> &Params<A> {
        self.params().output()
    }
    /// returns a mutable reference to the output layer of the model
    pub const fn output_mut(&mut self) -> &mut Params<A> {
        self.params_mut().output_mut()
    }
    /// returns an immutable reference to the decay of the model
    pub const fn decay(&self) -> &A {
        self.config().decay()
    }
    /// returns an immutable reference to the learning rate of the model
    pub const fn learning_rate(&self) -> &A {
        self.config().learning_rate()
    }
    /// returns an immutable reference to the momentum of the model
    pub const fn momentum(&self) -> &A {
        self.config().momentum()
    }
    /// set the model's attention mechanism
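    ///
    /// # Example
    ///
    /// A rough sketch (not compiled as a doctest) toggling the attention
    /// mechanism on an existing `SurfaceModel<f64>` named `model`:
    ///
    /// ```ignore
    /// model.set_attention(Some(FftAttention::default()));
    /// assert!(model.has_attention());
    /// model.set_attention(None);
    /// assert!(model.has_no_attention());
    /// ```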
    #[inline]
    pub fn set_attention(&mut self, attention: Option<FftAttention<A>>) -> &mut Self {
        self.attention = attention;
        self
    }
    /// set the model's configuration
    #[inline]
    pub fn set_config(&mut self, config: SurfaceModelConfig<A>) -> &mut Self {
        *self.config_mut() = config;
        self
    }
    /// set the model's features
    #[inline]
    pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
        *self.features_mut() = features;
        self
    }
    /// set the model's parameters
    #[inline]
    pub fn set_params(&mut self, params: ModelParams<A>) -> &mut Self {
        *self.params_mut() = params;
        self
    }
    /// set the input layer of the model
    #[inline]
    pub fn set_input(&mut self, input: Params<A>) -> &mut Self {
        self.params_mut().set_input(input);
        self
    }
    /// set the hidden layers of the model
    #[inline]
    pub fn set_hidden<I>(&mut self, iter: I) -> &mut Self
    where
        I: IntoIterator<Item = Params<A>>,
    {
        self.params_mut().set_hidden(Vec::from_iter(iter));
        self
    }
    /// set the output layer of the model
    #[inline]
    pub fn set_output(&mut self, output: Params<A>) -> &mut Self {
        self.params_mut().set_output(output);
        self
    }
    /// set the decay for the model
    #[inline]
    pub fn set_decay(&mut self, decay: A) -> &mut Self {
        self.config_mut().set_decay(decay);
        self
    }
    /// set the learning rate for the model
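    ///
    /// # Example
    ///
    /// A brief sketch (not compiled as a doctest); since each setter returns
    /// `&mut Self`, the hyperparameters of a `SurfaceModel<f64>` named `model`
    /// can be updated in a single chain (the values are illustrative):
    ///
    /// ```ignore
    /// model
    ///     .set_learning_rate(1e-2)
    ///     .set_momentum(0.9)
    ///     .set_decay(1e-4);
    /// ```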
    #[inline]
    pub fn set_learning_rate(&mut self, learning_rate: A) -> &mut Self {
        self.config_mut().set_learning_rate(learning_rate);
        self
    }
    /// set the model's momentum
    #[inline]
    pub fn set_momentum(&mut self, momentum: A) -> &mut Self {
        self.config_mut().set_momentum(momentum);
        self
    }
    /// consumes the current instance and returns another with a default attention mechanism enabled
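    ///
    /// # Example
    ///
    /// A small sketch (not compiled as a doctest) enabling attention on a
    /// freshly constructed model; the constructor arguments are assumed to
    /// come from the `Default` implementations:
    ///
    /// ```ignore
    /// let model = SurfaceModel::<f64>::default(
    ///     SurfaceModelConfig::default(),
    ///     ModelFeatures::default(),
    /// )
    /// .with_attention();
    /// assert!(model.has_attention());
    /// ```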
    #[inline]
    pub fn with_attention(self) -> Self
    where
        A: FromPrimitive,
    {
        Self {
            attention: Some(FftAttention::default()),
            ..self
        }
    }
    /// consumes the current instance and returns another with the specified configuration
    pub fn with_config(self, config: SurfaceModelConfig<A>) -> Self {
        Self { config, ..self }
    }
    /// consumes the current instance and returns another with the given features
    pub fn with_features(self, features: ModelFeatures) -> Self {
        Self { features, ..self }
    }
    /// consumes the current instance and returns another with the specified input layer
    pub fn with_input(self, input: Params<A>) -> Self {
        Self {
            params: self.params.with_input(input),
            ..self
        }
    }
    /// consumes the current instance and returns another with the specified hidden layers
    pub fn with_hidden<I>(self, iter: I) -> Self
    where
        I: IntoIterator<Item = Params<A>>,
    {
        Self {
            params: self.params.with_hidden(iter),
            ..self
        }
    }
    /// consumes the current instance and returns another with the specified output layer
    pub fn with_output(self, output: Params<A>) -> Self {
        Self {
            params: self.params.with_output(output),
            ..self
        }
    }
    /// returns true if the model has an attention mechanism
    pub fn has_attention(&self) -> bool {
        self.attention().is_some()
    }
    /// returns true if the model has no attention mechanism
    pub fn has_no_attention(&self) -> bool {
        self.attention().is_none()
    }
}

impl<A> Default for SurfaceModel<A>
where
    A: Clone + Default + FromPrimitive,
{
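    /// Builds the model from `SurfaceModelConfig::default()` and
    /// `ModelFeatures::default()`.
    ///
    /// A short sketch (not compiled as a doctest); because the inherent,
    /// two-argument `SurfaceModel::default(config, features)` takes precedence
    /// during path resolution, the trait version is most easily reached via a
    /// type annotation or fully qualified syntax:
    ///
    /// ```ignore
    /// let a: SurfaceModel<f64> = Default::default();
    /// let b = <SurfaceModel<f64> as Default>::default();
    /// ```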
    fn default() -> Self {
        let config = SurfaceModelConfig::default();
        let features = ModelFeatures::default();
        Self::default(config, features)
    }
}

impl<A> PartialEq for SurfaceModel<A>
where
    A: PartialEq,
    ModelParams<A>: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        self.config == other.config
            && self.features == other.features
            && self.params == other.params
            && self.velocities == other.velocities
            && self.attention == other.attention
    }
}