optirs_learned/transformer_based_optimizer/config.rs

//! Configuration types for the transformer-based learned optimizer.

use super::positional_encoding::PositionalEncodingType;
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// Configuration for the transformer-based learned optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerBasedOptimizerConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Dimensionality of the transformer model (embedding size).
    pub model_dimension: usize,

    /// Number of stacked transformer layers.
    pub num_transformer_layers: usize,

    /// Number of attention heads per layer.
    pub num_attention_heads: usize,

    /// Dimensionality of each attention head.
    pub attention_head_dimension: usize,

    /// Hidden dimensionality of the feed-forward sublayer.
    pub feedforward_dimension: usize,

    /// Maximum input sequence length.
    pub sequence_length: usize,

    /// Dropout probability applied during training.
    pub dropout_rate: f64,

    /// Base learning rate.
    pub learning_rate: T,

    /// Training batch size.
    pub batch_size: usize,

    /// Number of training epochs.
    pub num_epochs: usize,

    /// Activation function used in the feed-forward sublayers.
    pub activation_function: ActivationFunction,

    /// Positional encoding scheme.
    pub positional_encoding_type: PositionalEncodingType,

    /// Memory management settings.
    pub memory_config: MemoryConfig,

    /// Meta-learning settings.
    pub meta_learning_config: MetaLearningConfig<T>,

    /// Performance tracking settings.
    pub performance_config: PerformanceConfig,

    /// Whether gradient clipping is enabled.
    pub enable_gradient_clipping: bool,

    /// Gradient clipping threshold.
    pub gradient_clip_value: T,

    /// Weight decay coefficient.
    pub weight_decay: T,

    /// Number of learning-rate warmup steps.
    pub warmup_steps: usize,

    /// Whether layer normalization is enabled.
    pub enable_layer_norm: bool,

    /// Use pre-norm (normalize before each sublayer) instead of post-norm.
    pub use_pre_norm: bool,

    /// Whether residual connections are enabled.
    pub enable_residual_connections: bool,
}

impl<T: Float + Debug + Send + Sync + 'static> Default for TransformerBasedOptimizerConfig<T> {
    fn default() -> Self {
        Self {
            model_dimension: 512,
            num_transformer_layers: 6,
            num_attention_heads: 8,
            attention_head_dimension: 64,
            feedforward_dimension: 2048,
            sequence_length: 128,
            dropout_rate: 0.1,
            learning_rate: scirs2_core::numeric::NumCast::from(1e-4).unwrap_or_else(|| T::zero()),
            batch_size: 32,
            num_epochs: 100,
            activation_function: ActivationFunction::ReLU,
            positional_encoding_type: PositionalEncodingType::Sinusoidal,
            memory_config: MemoryConfig::default(),
            meta_learning_config: MetaLearningConfig::default(),
            performance_config: PerformanceConfig::default(),
            enable_gradient_clipping: true,
            gradient_clip_value: scirs2_core::numeric::NumCast::from(1.0)
                .unwrap_or_else(|| T::zero()),
            weight_decay: scirs2_core::numeric::NumCast::from(1e-5).unwrap_or_else(|| T::zero()),
            warmup_steps: 1000,
            enable_layer_norm: true,
            use_pre_norm: true,
            enable_residual_connections: true,
        }
    }
}

impl<T: Float + Debug + Send + Sync + 'static> TransformerBasedOptimizerConfig<T> {
    /// Preset for a smaller model (256-dimensional, 4 layers, 4 heads).
    pub fn small() -> Self {
        Self {
            model_dimension: 256,
            num_transformer_layers: 4,
            num_attention_heads: 4,
            attention_head_dimension: 64,
            feedforward_dimension: 1024,
            sequence_length: 64,
            ..Self::default()
        }
    }

    /// Preset for a larger model (1024-dimensional, 12 layers, 16 heads).
    pub fn large() -> Self {
        Self {
            model_dimension: 1024,
            num_transformer_layers: 12,
            num_attention_heads: 16,
            attention_head_dimension: 64,
            feedforward_dimension: 4096,
            sequence_length: 256,
            ..Self::default()
        }
    }

    /// Preset tuned for training: larger batches, longer warmup, stronger regularization.
    pub fn for_training() -> Self {
        Self {
            batch_size: 64,
            num_epochs: 200,
            learning_rate: scirs2_core::numeric::NumCast::from(2e-4).unwrap_or_else(|| T::zero()),
            warmup_steps: 2000,
            enable_gradient_clipping: true,
            weight_decay: scirs2_core::numeric::NumCast::from(1e-4).unwrap_or_else(|| T::zero()),
            ..Self::default()
        }
    }

    /// Preset for inference: single-sample batches, no dropout, no gradient clipping.
    pub fn for_inference() -> Self {
        Self {
            batch_size: 1,
            dropout_rate: 0.0,
            enable_gradient_clipping: false,
            ..Self::default()
        }
    }

    /// Checks that the configuration is internally consistent.
    pub fn validate(&self) -> Result<(), String> {
        if self.model_dimension == 0 {
            return Err("model_dimension must be greater than 0".to_string());
        }

        if self.num_transformer_layers == 0 {
            return Err("num_transformer_layers must be greater than 0".to_string());
        }

        if self.num_attention_heads == 0 {
            return Err("num_attention_heads must be greater than 0".to_string());
        }

        if !self
            .model_dimension
            .is_multiple_of(self.num_attention_heads)
        {
            return Err("model_dimension must be divisible by num_attention_heads".to_string());
        }

        if self.attention_head_dimension * self.num_attention_heads != self.model_dimension {
            return Err(
                "attention_head_dimension * num_attention_heads must equal model_dimension"
                    .to_string(),
            );
        }

        if self.sequence_length == 0 {
            return Err("sequence_length must be greater than 0".to_string());
        }

        if self.dropout_rate < 0.0 || self.dropout_rate > 1.0 {
            return Err("dropout_rate must be between 0.0 and 1.0".to_string());
        }

        if self.learning_rate <= T::zero() {
            return Err("learning_rate must be positive".to_string());
        }

        if self.batch_size == 0 {
            return Err("batch_size must be greater than 0".to_string());
        }

        self.memory_config.validate()?;
        self.meta_learning_config.validate()?;
        self.performance_config.validate()?;

        Ok(())
    }

    /// Rough estimate of the total number of model parameters (bias terms are not counted).
    pub fn estimate_parameter_count(&self) -> usize {
        // Embedding and positional tables.
        let embedding_params = self.model_dimension * self.model_dimension;
        let positional_params = self.sequence_length * self.model_dimension;

        // Per-layer costs: Q/K/V/output projections, feed-forward weights, layer norms.
        let attention_params_per_layer = 4 * self.model_dimension * self.model_dimension;
        let ffn_params_per_layer = 2 * self.model_dimension * self.feedforward_dimension;
        let norm_params_per_layer = 2 * self.model_dimension;
        let layer_params =
            attention_params_per_layer + ffn_params_per_layer + norm_params_per_layer;
        let total_layer_params = layer_params * self.num_transformer_layers;

        // Output projection.
        let output_params = self.model_dimension * self.model_dimension;

        embedding_params + positional_params + total_layer_params + output_params
    }
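
    // For the default configuration (model_dimension = 512, 6 layers,
    // feedforward_dimension = 2048, sequence_length = 128) this works out to:
    //   embedding          512 * 512                       =    262,144
    //   positional         128 * 512                       =     65,536
    //   per layer          4*512*512 + 2*512*2048 + 2*512  =  3,146,752  (x6)
    //   output projection  512 * 512                       =    262,144
    //   total                                              ~ 19.5M parameters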

    /// Rough estimate of memory usage in MiB (weights, gradients, and activations).
    pub fn estimate_memory_usage(&self) -> f64 {
        let param_count = self.estimate_parameter_count();
        let bytes_per_param = if std::mem::size_of::<T>() == 4 {
            4.0
        } else {
            8.0
        };

        let model_memory = param_count as f64 * bytes_per_param;
        let activation_memory = self.batch_size as f64
            * self.sequence_length as f64
            * self.model_dimension as f64
            * bytes_per_param;
        // Gradients are assumed to need as much memory as the weights themselves.
        let gradient_memory = model_memory;

        let total_bytes = model_memory + activation_memory + gradient_memory;
        // Convert bytes to MiB.
        total_bytes / (1024.0 * 1024.0)
    }
}

/// Architecture-only view of the configuration, without training hyperparameters.
#[derive(Debug, Clone)]
pub struct TransformerArchConfig {
    pub model_dimension: usize,
    pub num_layers: usize,
    pub num_attention_heads: usize,
    pub feedforward_dimension: usize,
    pub dropout_rate: f64,
    pub use_pre_norm: bool,
    pub enable_residual_connections: bool,
}

impl TransformerArchConfig {
    /// Extracts the architecture parameters from a full optimizer configuration.
    pub fn from_optimizer_config<T: Float + Debug + Send + Sync + 'static>(
        config: &TransformerBasedOptimizerConfig<T>,
    ) -> Self {
        Self {
            model_dimension: config.model_dimension,
            num_layers: config.num_transformer_layers,
            num_attention_heads: config.num_attention_heads,
            feedforward_dimension: config.feedforward_dimension,
            dropout_rate: config.dropout_rate,
            use_pre_norm: config.use_pre_norm,
            enable_residual_connections: config.enable_residual_connections,
        }
    }
}

/// Activation functions supported by the transformer's feed-forward sublayers.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ActivationFunction {
    ReLU,
    GELU,
    Swish,
    Tanh,
    Sigmoid,
    LeakyReLU,
}

/// Memory management settings for the optimizer's internal caches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Maximum cache size in bytes.
    pub max_cache_size: usize,
    /// Whether cached entries are compressed.
    pub enable_compression: bool,
    /// Strategy used to evict entries when the cache is full.
    pub eviction_strategy: CacheEvictionStrategy,
    /// Block size, in bytes, used for cache allocations.
    pub allocation_block_size: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            max_cache_size: 1024 * 1024 * 1024, // 1 GiB
            enable_compression: false,
            eviction_strategy: CacheEvictionStrategy::LRU,
            allocation_block_size: 4096,
        }
    }
}

impl MemoryConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.max_cache_size == 0 {
            return Err("max_cache_size must be greater than 0".to_string());
        }

        if self.allocation_block_size == 0 {
            return Err("allocation_block_size must be greater than 0".to_string());
        }

        Ok(())
    }
}

/// Cache eviction strategies.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum CacheEvictionStrategy {
    LRU,
    LFU,
    FIFO,
    Random,
}

/// Meta-learning settings for the learned optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Learning rate for the outer (meta) update.
    pub meta_learning_rate: T,
    /// Number of inner-loop adaptation steps.
    pub inner_steps: usize,
    /// Learning rate used for the inner-loop updates.
    pub inner_learning_rate: T,
    /// Use the first-order approximation (skip second-order gradients).
    pub first_order: bool,
    /// Number of support examples per task.
    pub num_support: usize,
    /// Number of query examples per task.
    pub num_query: usize,
}

impl<T: Float + Debug + Send + Sync + 'static> Default for MetaLearningConfig<T> {
    fn default() -> Self {
        Self {
            meta_learning_rate: scirs2_core::numeric::NumCast::from(1e-3)
                .unwrap_or_else(|| T::zero()),
            inner_steps: 5,
            inner_learning_rate: scirs2_core::numeric::NumCast::from(1e-2)
                .unwrap_or_else(|| T::zero()),
            first_order: false,
            num_support: 5,
            num_query: 15,
        }
    }
}

impl<T: Float + Debug + Send + Sync + 'static> MetaLearningConfig<T> {
    pub fn validate(&self) -> Result<(), String> {
        if self.meta_learning_rate <= T::zero() {
            return Err("meta_learning_rate must be positive".to_string());
        }

        if self.inner_learning_rate <= T::zero() {
            return Err("inner_learning_rate must be positive".to_string());
        }

        if self.inner_steps == 0 {
            return Err("inner_steps must be greater than 0".to_string());
        }

        if self.num_support == 0 {
            return Err("num_support must be greater than 0".to_string());
        }

        if self.num_query == 0 {
            return Err("num_query must be greater than 0".to_string());
        }

        Ok(())
    }
}

/// Performance tracking settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Enable detailed metric tracking.
    pub enable_detailed_tracking: bool,
    /// Interval between metric collections.
    pub metrics_interval: usize,
    /// Maximum number of history entries to retain.
    pub max_history_size: usize,
    /// Track memory usage alongside other metrics.
    pub track_memory_usage: bool,
    /// Enable timing analysis.
    pub enable_timing_analysis: bool,
}

impl Default for PerformanceConfig {
    fn default() -> Self {
        Self {
            enable_detailed_tracking: true,
            metrics_interval: 10,
            max_history_size: 10000,
            track_memory_usage: true,
            enable_timing_analysis: true,
        }
    }
}

impl PerformanceConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.metrics_interval == 0 {
            return Err("metrics_interval must be greater than 0".to_string());
        }

        if self.max_history_size == 0 {
            return Err("max_history_size must be greater than 0".to_string());
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = TransformerBasedOptimizerConfig::<f32>::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.model_dimension, 512);
        assert_eq!(config.num_transformer_layers, 6);
    }

    #[test]
    fn test_config_validation() {
        let config = TransformerBasedOptimizerConfig::<f32> {
            model_dimension: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());

        let config = TransformerBasedOptimizerConfig::<f32> {
            model_dimension: 512,
            num_attention_heads: 7, // 512 is not divisible by 7
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_parameter_estimation() {
        let config = TransformerBasedOptimizerConfig::<f32>::small();
        let param_count = config.estimate_parameter_count();
        assert!(param_count > 0);

        let memory_usage = config.estimate_memory_usage();
        assert!(memory_usage > 0.0);
    }
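
    // The estimates for the default configuration follow directly from the
    // formulas above: ~19.5M parameters and ~156 MiB for f32 (weights +
    // gradients + one batch of activations).
    #[test]
    fn test_default_config_estimates() {
        let config = TransformerBasedOptimizerConfig::<f32>::default();
        assert_eq!(config.estimate_parameter_count(), 19_470_336);

        let memory_mib = config.estimate_memory_usage();
        assert!((150.0..165.0).contains(&memory_mib));
    }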

    #[test]
    fn test_preset_configs() {
        let small_config = TransformerBasedOptimizerConfig::<f32>::small();
        assert!(small_config.validate().is_ok());
        assert_eq!(small_config.model_dimension, 256);

        let large_config = TransformerBasedOptimizerConfig::<f32>::large();
        assert!(large_config.validate().is_ok());
        assert_eq!(large_config.model_dimension, 1024);

        let training_config = TransformerBasedOptimizerConfig::<f32>::for_training();
        assert!(training_config.validate().is_ok());
        assert_eq!(training_config.batch_size, 64);
    }
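
    // Sanity checks for the inference preset and the derived architecture view;
    // everything asserted here mirrors values set in `for_inference` and
    // `TransformerArchConfig::from_optimizer_config`.
    #[test]
    fn test_inference_and_arch_config() {
        let inference_config = TransformerBasedOptimizerConfig::<f32>::for_inference();
        assert!(inference_config.validate().is_ok());
        assert_eq!(inference_config.batch_size, 1);
        assert_eq!(inference_config.dropout_rate, 0.0);
        assert!(!inference_config.enable_gradient_clipping);

        let arch = TransformerArchConfig::from_optimizer_config(&inference_config);
        assert_eq!(arch.model_dimension, inference_config.model_dimension);
        assert_eq!(arch.num_layers, inference_config.num_transformer_layers);
        assert!(arch.enable_residual_connections);
    }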
}