eryon_surface/model/
model.rs1use super::{FftAttention, SurfaceModelConfig};
6use cnc::prelude::{ModelFeatures, ModelParams, Params};
7use num_traits::FromPrimitive;
8
9#[derive(Clone, Debug)]
11pub struct SurfaceModel<A = f64> {
12 pub(crate) config: SurfaceModelConfig<A>,
13 pub(crate) features: ModelFeatures,
14 pub(crate) attention: Option<FftAttention<A>>,
15 pub(crate) params: ModelParams<A>,
16 pub(crate) velocities: ModelParams<A>,
17 pub(crate) prev_target_norm: Option<A>,
18}
19
20impl<A> SurfaceModel<A> {
21 pub fn default(config: SurfaceModelConfig<A>, features: ModelFeatures) -> Self
23 where
24 A: Clone + Default,
25 {
26 Self {
27 attention: None,
28 config,
29 features,
30 params: ModelParams::default(features),
31 velocities: ModelParams::default(features),
32 prev_target_norm: None,
33 }
34 }
35 pub fn ones(config: SurfaceModelConfig<A>, features: ModelFeatures) -> Self
37 where
38 A: Clone + num_traits::One,
39 {
40 Self {
41 attention: None,
42 config,
43 features,
44 params: ModelParams::ones(features),
45 velocities: ModelParams::ones(features),
46 prev_target_norm: None,
47 }
48 }
49 pub fn zeros(config: SurfaceModelConfig<A>, features: ModelFeatures) -> Self
51 where
52 A: Clone + num_traits::Zero,
53 {
54 Self {
55 attention: None,
56 config,
57 features,
58 params: ModelParams::zeros(features),
59 velocities: ModelParams::zeros(features),
60 prev_target_norm: None,
61 }
62 }
63 pub const fn attention(&self) -> Option<&FftAttention<A>> {
65 self.attention.as_ref()
66 }
67 pub const fn attention_mut(&mut self) -> Option<&mut FftAttention<A>> {
69 self.attention.as_mut()
70 }
71 pub const fn config(&self) -> &SurfaceModelConfig<A> {
73 &self.config
74 }
75 pub const fn config_mut(&mut self) -> &mut SurfaceModelConfig<A> {
77 &mut self.config
78 }
79 pub const fn features(&self) -> ModelFeatures {
81 self.features
82 }
83 pub const fn features_mut(&mut self) -> &mut ModelFeatures {
85 &mut self.features
86 }
87 pub const fn params(&self) -> &ModelParams<A> {
89 &self.params
90 }
91 pub const fn params_mut(&mut self) -> &mut ModelParams<A> {
93 &mut self.params
94 }
95 pub const fn velocities(&self) -> &ModelParams<A> {
97 &self.velocities
98 }
99 pub const fn velocities_mut(&mut self) -> &mut ModelParams<A> {
101 &mut self.velocities
102 }
103 pub const fn input(&self) -> &Params<A> {
105 self.params().input()
106 }
107 pub const fn input_mut(&mut self) -> &mut Params<A> {
109 self.params_mut().input_mut()
110 }
111 pub const fn hidden(&self) -> &Vec<Params<A>> {
113 self.params().hidden()
114 }
115 pub const fn hidden_mut(&mut self) -> &mut Vec<Params<A>> {
117 self.params_mut().hidden_mut()
118 }
119 #[inline]
121 pub fn hidden_as_slice(&self) -> &[Params<A>] {
122 self.params().hidden_as_slice()
123 }
124 pub const fn output(&self) -> &Params<A> {
126 self.params().output()
127 }
128 pub const fn output_mut(&mut self) -> &mut Params<A> {
130 self.params_mut().output_mut()
131 }
132 pub const fn decay(&self) -> &A {
134 self.config().decay()
135 }
136 pub const fn learning_rate(&self) -> &A {
138 self.config().learning_rate()
139 }
140 pub const fn momentum(&self) -> &A {
142 self.config().momentum()
143 }
144 #[inline]
146 pub fn set_attention(&mut self, attention: Option<FftAttention<A>>) -> &mut Self {
147 self.attention = attention;
148 self
149 }
150 #[inline]
152 pub fn set_config(&mut self, config: SurfaceModelConfig<A>) -> &mut Self {
153 *self.config_mut() = config;
154 self
155 }
156 #[inline]
158 pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
159 *self.features_mut() = features;
160 self
161 }
162 #[inline]
164 pub fn set_params(&mut self, params: ModelParams<A>) -> &mut Self {
165 *self.params_mut() = params;
166 self
167 }
168 #[inline]
170 pub fn set_input(&mut self, input: Params<A>) -> &mut Self {
171 self.params_mut().set_input(input);
172 self
173 }
174 #[inline]
176 pub fn set_hidden<I>(&mut self, iter: I) -> &mut Self
177 where
178 I: IntoIterator<Item = Params<A>>,
179 {
180 self.params_mut().set_hidden(Vec::from_iter(iter));
181 self
182 }
183 #[inline]
185 pub fn set_output(&mut self, output: Params<A>) -> &mut Self {
186 self.params_mut().set_output(output);
187 self
188 }
189 #[inline]
191 pub fn set_decay(&mut self, decay: A) -> &mut Self {
192 self.config_mut().set_decay(decay);
193 self
194 }
195 #[inline]
197 pub fn set_learning_rate(&mut self, learning_rate: A) -> &mut Self {
198 self.config_mut().set_learning_rate(learning_rate);
199 self
200 }
201 #[inline]
203 pub fn set_momentum(&mut self, momentum: A) -> &mut Self {
204 self.config_mut().set_momentum(momentum);
205 self
206 }
207 #[inline]
209 pub fn with_attention(self) -> Self
210 where
211 A: FromPrimitive,
212 {
213 Self {
214 attention: Some(FftAttention::default()),
215 ..self
216 }
217 }
218 pub fn with_config(self, config: SurfaceModelConfig<A>) -> Self {
220 Self { config, ..self }
221 }
222 pub fn with_features(self, features: ModelFeatures) -> Self {
224 Self { features, ..self }
225 }
226 pub fn with_input(self, input: Params<A>) -> Self {
228 Self {
229 params: self.params.with_input(input),
230 ..self
231 }
232 }
233 pub fn with_hidden<I>(self, iter: I) -> Self
235 where
236 I: IntoIterator<Item = Params<A>>,
237 {
238 Self {
239 params: self.params.with_hidden(iter),
240 ..self
241 }
242 }
243 pub fn with_output(self, output: Params<A>) -> Self {
245 Self {
246 params: self.params.with_output(output),
247 ..self
248 }
249 }
250 pub fn has_attention(&self) -> bool {
252 self.attention().is_some()
253 }
254 pub fn has_no_attention(&self) -> bool {
256 self.attention().is_none()
257 }
258}
259
260impl<A> Default for SurfaceModel<A>
261where
262 A: Clone + Default + FromPrimitive,
263{
264 fn default() -> Self {
265 let config = SurfaceModelConfig::default();
266 let features = ModelFeatures::default();
267 Self::default(config, features)
268 }
269}
270
271impl<A> PartialEq for SurfaceModel<A>
272where
273 A: PartialEq,
274 ModelParams<A>: PartialEq,
275{
276 fn eq(&self, other: &Self) -> bool {
277 self.config == other.config
278 && self.features == other.features
279 && self.params == other.params
280 && self.velocities == other.velocities
281 && self.attention == other.attention
282 }
283}