optirs_learned/transformer_based_optimizer/config.rs

// Configuration structures for transformer-based optimizer

use super::positional_encoding::PositionalEncodingType;
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

/// Configuration for transformer-based optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransformerBasedOptimizerConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Model dimension (embedding size)
    pub model_dimension: usize,

    /// Number of transformer layers
    pub num_transformer_layers: usize,

    /// Number of attention heads
    pub num_attention_heads: usize,

    /// Dimension of each attention head
    pub attention_head_dimension: usize,

    /// Feed-forward network dimension
    pub feedforward_dimension: usize,

    /// Maximum sequence length
    pub sequence_length: usize,

    /// Dropout rate
    pub dropout_rate: f64,

    /// Learning rate
    pub learning_rate: T,

    /// Batch size for training
    pub batch_size: usize,

    /// Number of training epochs
    pub num_epochs: usize,

    /// Activation function type
    pub activation_function: ActivationFunction,

    /// Positional encoding type
    pub positional_encoding_type: PositionalEncodingType,

    /// Memory management configuration
    pub memory_config: MemoryConfig,

    /// Meta-learning configuration
    pub meta_learning_config: MetaLearningConfig<T>,

    /// Performance tracking configuration
    pub performance_config: PerformanceConfig,

    /// Enable gradient clipping
    pub enable_gradient_clipping: bool,

    /// Gradient clipping threshold
    pub gradient_clip_value: T,

    /// Weight decay factor
    pub weight_decay: T,

    /// Warmup steps for learning rate schedule
    pub warmup_steps: usize,

    /// Enable layer normalization
    pub enable_layer_norm: bool,

    /// Use pre-norm (true) rather than post-norm (false) layer ordering
    pub use_pre_norm: bool,

    /// Enable residual connections
    pub enable_residual_connections: bool,
}

impl<T: Float + Debug + Send + Sync + 'static> Default for TransformerBasedOptimizerConfig<T> {
    fn default() -> Self {
        Self {
            model_dimension: 512,
            num_transformer_layers: 6,
            num_attention_heads: 8,
            attention_head_dimension: 64,
            feedforward_dimension: 2048,
            sequence_length: 128,
            dropout_rate: 0.1,
            learning_rate: scirs2_core::numeric::NumCast::from(1e-4).unwrap_or_else(|| T::zero()),
            batch_size: 32,
            num_epochs: 100,
            activation_function: ActivationFunction::ReLU,
            positional_encoding_type: PositionalEncodingType::Sinusoidal,
            memory_config: MemoryConfig::default(),
            meta_learning_config: MetaLearningConfig::default(),
            performance_config: PerformanceConfig::default(),
            enable_gradient_clipping: true,
            gradient_clip_value: scirs2_core::numeric::NumCast::from(1.0)
                .unwrap_or_else(|| T::zero()),
            weight_decay: scirs2_core::numeric::NumCast::from(1e-5).unwrap_or_else(|| T::zero()),
            warmup_steps: 1000,
            enable_layer_norm: true,
            use_pre_norm: true,
            enable_residual_connections: true,
        }
    }
}

impl<T: Float + Debug + Send + Sync + 'static> TransformerBasedOptimizerConfig<T> {
    /// Create configuration for small models
    pub fn small() -> Self {
        Self {
            model_dimension: 256,
            num_transformer_layers: 4,
            num_attention_heads: 4,
            attention_head_dimension: 64,
            feedforward_dimension: 1024,
            sequence_length: 64,
            ..Self::default()
        }
    }

    /// Create configuration for large models
    pub fn large() -> Self {
        Self {
            model_dimension: 1024,
            num_transformer_layers: 12,
            num_attention_heads: 16,
            attention_head_dimension: 64,
            feedforward_dimension: 4096,
            sequence_length: 256,
            ..Self::default()
        }
    }

    /// Create configuration optimized for training
    pub fn for_training() -> Self {
        Self {
            batch_size: 64,
            num_epochs: 200,
            learning_rate: scirs2_core::numeric::NumCast::from(2e-4).unwrap_or_else(|| T::zero()),
            warmup_steps: 2000,
            enable_gradient_clipping: true,
            weight_decay: scirs2_core::numeric::NumCast::from(1e-4).unwrap_or_else(|| T::zero()),
            ..Self::default()
        }
    }

    /// Create configuration optimized for inference
    pub fn for_inference() -> Self {
        Self {
            batch_size: 1,
            dropout_rate: 0.0,
            enable_gradient_clipping: false,
            ..Self::default()
        }
    }
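
    // Illustrative sketch (not part of the API surface): the presets above can also
    // be combined with struct update syntax, e.g. an inference-oriented variant of
    // the small model:
    //
    //     let config = TransformerBasedOptimizerConfig::<f32> {
    //         batch_size: 1,
    //         dropout_rate: 0.0,
    //         enable_gradient_clipping: false,
    //         ..TransformerBasedOptimizerConfig::small()
    //     };
    //     assert!(config.validate().is_ok());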

    /// Validate configuration
    pub fn validate(&self) -> Result<(), String> {
        if self.model_dimension == 0 {
            return Err("model_dimension must be greater than 0".to_string());
        }

        if self.num_transformer_layers == 0 {
            return Err("num_transformer_layers must be greater than 0".to_string());
        }

        if self.num_attention_heads == 0 {
            return Err("num_attention_heads must be greater than 0".to_string());
        }

        if !self
            .model_dimension
            .is_multiple_of(self.num_attention_heads)
        {
            return Err("model_dimension must be divisible by num_attention_heads".to_string());
        }

        if self.attention_head_dimension * self.num_attention_heads != self.model_dimension {
            return Err(
                "attention_head_dimension * num_attention_heads must equal model_dimension"
                    .to_string(),
            );
        }

        if self.sequence_length == 0 {
            return Err("sequence_length must be greater than 0".to_string());
        }

        if self.dropout_rate < 0.0 || self.dropout_rate > 1.0 {
            return Err("dropout_rate must be between 0.0 and 1.0".to_string());
        }

        if self.learning_rate <= T::zero() {
            return Err("learning_rate must be positive".to_string());
        }

        if self.batch_size == 0 {
            return Err("batch_size must be greater than 0".to_string());
        }

        self.memory_config.validate()?;
        self.meta_learning_config.validate()?;
        self.performance_config.validate()?;

        Ok(())
    }

    /// Calculate total parameters estimate
    pub fn estimate_parameter_count(&self) -> usize {
        let embedding_params = self.model_dimension * self.model_dimension; // Input embedding
        let positional_params = self.sequence_length * self.model_dimension;

        let attention_params_per_layer = 4 * self.model_dimension * self.model_dimension; // Q, K, V, O projections
        let ffn_params_per_layer = 2 * self.model_dimension * self.feedforward_dimension; // Up and down projections
        let norm_params_per_layer = 2 * self.model_dimension; // Layer norm parameters

        let layer_params =
            attention_params_per_layer + ffn_params_per_layer + norm_params_per_layer;
        let total_layer_params = layer_params * self.num_transformer_layers;

        let output_params = self.model_dimension * self.model_dimension; // Output projection

        embedding_params + positional_params + total_layer_params + output_params
    }
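
    // Worked example (sketch, plugging in the Default values above): with
    // model_dimension = 512, feedforward_dimension = 2048, sequence_length = 128 and
    // 6 layers, each layer contributes 4*512*512 + 2*512*2048 + 2*512 = 3,146,752
    // parameters, so the total estimate is
    // 512*512 + 128*512 + 6*3,146,752 + 512*512 = 19,470,336 (about 19.5M).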

    /// Calculate memory requirements (in MB)
    pub fn estimate_memory_usage(&self) -> f64 {
        let param_count = self.estimate_parameter_count();
        // Size of one parameter of type T in bytes (4 for f32, 8 for f64)
        let bytes_per_param = std::mem::size_of::<T>() as f64;

        let model_memory = param_count as f64 * bytes_per_param;
        let activation_memory = self.batch_size as f64
            * self.sequence_length as f64
            * self.model_dimension as f64
            * bytes_per_param;
        let gradient_memory = model_memory; // Assume same as model for gradients

        let total_bytes = model_memory + activation_memory + gradient_memory;
        total_bytes / (1024.0 * 1024.0) // Convert to MB
    }
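
    // Worked example (sketch, f32 Default configuration): roughly 19.47M parameters
    // * 4 bytes for the model, the same again for gradients, plus 32 * 128 * 512 * 4
    // bytes of activations, i.e. about 157 MB.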
}

/// Transformer architecture configuration
#[derive(Debug, Clone)]
pub struct TransformerArchConfig {
    pub model_dimension: usize,
    pub num_layers: usize,
    pub num_attention_heads: usize,
    pub feedforward_dimension: usize,
    pub dropout_rate: f64,
    pub use_pre_norm: bool,
    pub enable_residual_connections: bool,
}

impl TransformerArchConfig {
    /// Build an architecture-only view from a full optimizer configuration
    pub fn from_optimizer_config<T: Float + Debug + Send + Sync + 'static>(
        config: &TransformerBasedOptimizerConfig<T>,
    ) -> Self {
        Self {
            model_dimension: config.model_dimension,
            num_layers: config.num_transformer_layers,
            num_attention_heads: config.num_attention_heads,
            feedforward_dimension: config.feedforward_dimension,
            dropout_rate: config.dropout_rate,
            use_pre_norm: config.use_pre_norm,
            enable_residual_connections: config.enable_residual_connections,
        }
    }
}

/// Activation function types
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ActivationFunction {
    ReLU,
    GELU,
    Swish,
    Tanh,
    Sigmoid,
    LeakyReLU,
}
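
// Illustrative sketch (not part of this module's API): the scalar functions these
// variants conventionally refer to, written for an f64 input x. GELU is shown with
// its tanh approximation, and the LeakyReLU slope of 0.01 is an assumption made
// here purely for illustration.
//
//     ReLU      -> x.max(0.0)
//     GELU      -> 0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
//     Swish     -> x / (1.0 + (-x).exp())            // x * sigmoid(x)
//     Tanh      -> x.tanh()
//     Sigmoid   -> 1.0 / (1.0 + (-x).exp())
//     LeakyReLU -> if x > 0.0 { x } else { 0.01 * x }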

/// Memory management configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Maximum memory cache size in bytes
    pub max_cache_size: usize,
    /// Enable memory compression
    pub enable_compression: bool,
    /// Cache eviction strategy
    pub eviction_strategy: CacheEvictionStrategy,
    /// Memory allocation block size in bytes
    pub allocation_block_size: usize,
}

impl Default for MemoryConfig {
    fn default() -> Self {
        Self {
            max_cache_size: 1024 * 1024 * 1024, // 1GB
            enable_compression: false,
            eviction_strategy: CacheEvictionStrategy::LRU,
            allocation_block_size: 4096,
        }
    }
}

impl MemoryConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.max_cache_size == 0 {
            return Err("max_cache_size must be greater than 0".to_string());
        }

        if self.allocation_block_size == 0 {
            return Err("allocation_block_size must be greater than 0".to_string());
        }

        Ok(())
    }
}

/// Cache eviction strategies
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum CacheEvictionStrategy {
    LRU,
    LFU,
    FIFO,
    Random,
}

/// Meta-learning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetaLearningConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-learning rate
    pub meta_learning_rate: T,
    /// Number of inner optimization steps
    pub inner_steps: usize,
    /// Inner learning rate
    pub inner_learning_rate: T,
    /// Enable first-order approximation
    pub first_order: bool,
    /// Number of support examples
    pub num_support: usize,
    /// Number of query examples
    pub num_query: usize,
}

impl<T: Float + Debug + Send + Sync + 'static> Default for MetaLearningConfig<T> {
    fn default() -> Self {
        Self {
            meta_learning_rate: scirs2_core::numeric::NumCast::from(1e-3)
                .unwrap_or_else(|| T::zero()),
            inner_steps: 5,
            inner_learning_rate: scirs2_core::numeric::NumCast::from(1e-2)
                .unwrap_or_else(|| T::zero()),
            first_order: false,
            num_support: 5,
            num_query: 15,
        }
    }
}

impl<T: Float + Debug + Send + Sync + 'static> MetaLearningConfig<T> {
    pub fn validate(&self) -> Result<(), String> {
        if self.meta_learning_rate <= T::zero() {
            return Err("meta_learning_rate must be positive".to_string());
        }

        if self.inner_learning_rate <= T::zero() {
            return Err("inner_learning_rate must be positive".to_string());
        }

        if self.inner_steps == 0 {
            return Err("inner_steps must be greater than 0".to_string());
        }

        if self.num_support == 0 {
            return Err("num_support must be greater than 0".to_string());
        }

        if self.num_query == 0 {
            return Err("num_query must be greater than 0".to_string());
        }

        Ok(())
    }
}

/// Performance tracking configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Enable detailed performance tracking
    pub enable_detailed_tracking: bool,
    /// Performance metrics collection interval
    pub metrics_interval: usize,
    /// Maximum history size for metrics
    pub max_history_size: usize,
    /// Enable memory usage tracking
    pub track_memory_usage: bool,
    /// Enable timing analysis
    pub enable_timing_analysis: bool,
}

impl Default for PerformanceConfig {
    fn default() -> Self {
        Self {
            enable_detailed_tracking: true,
            metrics_interval: 10,
            max_history_size: 10000,
            track_memory_usage: true,
            enable_timing_analysis: true,
        }
    }
}

impl PerformanceConfig {
    pub fn validate(&self) -> Result<(), String> {
        if self.metrics_interval == 0 {
            return Err("metrics_interval must be greater than 0".to_string());
        }

        if self.max_history_size == 0 {
            return Err("max_history_size must be greater than 0".to_string());
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = TransformerBasedOptimizerConfig::<f32>::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.model_dimension, 512);
        assert_eq!(config.num_transformer_layers, 6);
    }

    #[test]
    fn test_config_validation() {
        // Test invalid model dimension
        let config = TransformerBasedOptimizerConfig::<f32> {
            model_dimension: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());

        // Test mismatched attention dimensions
        let config = TransformerBasedOptimizerConfig::<f32> {
            model_dimension: 512,
            num_attention_heads: 7, // 512 is not divisible by 7
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_parameter_estimation() {
        let config = TransformerBasedOptimizerConfig::<f32>::small();
        let param_count = config.estimate_parameter_count();
        assert!(param_count > 0);

        let memory_usage = config.estimate_memory_usage();
        assert!(memory_usage > 0.0);
    }

    #[test]
    fn test_preset_configs() {
        let small_config = TransformerBasedOptimizerConfig::<f32>::small();
        assert!(small_config.validate().is_ok());
        assert_eq!(small_config.model_dimension, 256);

        let large_config = TransformerBasedOptimizerConfig::<f32>::large();
        assert!(large_config.validate().is_ok());
        assert_eq!(large_config.model_dimension, 1024);

        let training_config = TransformerBasedOptimizerConfig::<f32>::for_training();
        assert!(training_config.validate().is_ok());
        assert_eq!(training_config.batch_size, 64);
    }
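
    // Additional illustrative checks (sketches based on the implementations above;
    // the expected values simply mirror `for_inference` and `from_optimizer_config`).
    #[test]
    fn test_inference_config() {
        let config = TransformerBasedOptimizerConfig::<f32>::for_inference();
        assert!(config.validate().is_ok());
        assert_eq!(config.batch_size, 1);
        assert_eq!(config.dropout_rate, 0.0);
        assert!(!config.enable_gradient_clipping);
    }

    #[test]
    fn test_arch_config_conversion() {
        let config = TransformerBasedOptimizerConfig::<f32>::large();
        let arch = TransformerArchConfig::from_optimizer_config(&config);
        assert_eq!(arch.model_dimension, config.model_dimension);
        assert_eq!(arch.num_layers, config.num_transformer_layers);
        assert_eq!(arch.num_attention_heads, config.num_attention_heads);
    }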
}