// optirs_learned/transformer/mod.rs

1use std::fmt::Debug;
2// Transformer-based learned optimizer
3//
4// This module provides a modular implementation of a transformer-based neural optimizer
5// that uses self-attention mechanisms to adaptively update optimization parameters.
6// The implementation is organized into architectural components, optimization strategies,
7// and training infrastructure for improved maintainability and extensibility.
8
9pub mod architecture;
10pub mod strategies;
11pub mod training;
12
13// Re-export key types for convenience
14pub use architecture::{
15    ActivationFunction, AttentionOptimization, FeedForwardNetwork, InputEmbedding, LayerNorm,
16    MultiHeadAttention, OutputProjectionLayer, PositionalEncoder, PositionalEncodingType,
17    TransformerLayer,
18};
19
20pub use strategies::{
21    GradientProcessingStrategy, GradientProcessor, LearningRateAdaptationStrategy,
22    LearningRateAdapter, MomentumIntegrator, MomentumStrategy, RegularizationStrategy,
23    TransformerRegularizer,
24};
25
26pub use training::{
27    CurriculumLearner, CurriculumStrategy, EvaluationStrategy, MetaLearningStrategy,
28    TransformerEvaluator, TransformerMetaLearner,
29};
30
31use scirs2_core::ndarray::{Array1, Array2};
32use scirs2_core::numeric::Float;
33use std::collections::{HashMap, VecDeque};
34
35use super::{LearnedOptimizerConfig, MetaOptimizationStrategy};
36use crate::error::{OptimError, Result};
37
/// Configuration specific to Transformer optimizer
///
/// Bundles the shared [`LearnedOptimizerConfig`] with the transformer
/// architecture hyperparameters (attention, feed-forward, positional
/// encoding) and efficiency switches used by [`TransformerOptimizer`].
#[derive(Debug, Clone)]
pub struct TransformerOptimizerConfig {
    /// Base learned optimizer config
    pub base_config: LearnedOptimizerConfig,

    /// Model dimension (d_model): width of embeddings and hidden states
    // NOTE(review): `modeldim`/`numheads` are not snake_case; renaming them
    // would break callers, so they are left as-is and only documented.
    pub modeldim: usize,

    /// Number of attention heads
    pub numheads: usize,

    /// Feed-forward network dimension
    pub ff_dim: usize,

    /// Number of transformer layers
    pub num_layers: usize,

    /// Maximum sequence length
    pub max_sequence_length: usize,

    /// Attention dropout rate (dropout probability)
    pub attention_dropout: f64,

    /// Feed-forward dropout rate (dropout probability)
    pub ff_dropout: f64,

    /// Layer normalization epsilon (numerical-stability constant)
    pub layer_norm_eps: f64,

    /// Use pre-layer normalization (Pre-LN) instead of post-layer normalization
    pub pre_layer_norm: bool,

    /// Positional encoding type
    pub pos_encoding_type: PositionalEncodingType,

    /// Enable relative position bias
    pub relative_position_bias: bool,

    /// Use rotary position embedding (RoPE)
    pub use_rope: bool,

    /// Enable gradient checkpointing
    pub gradient_checkpointing: bool,

    /// Attention pattern optimization
    pub attention_optimization: AttentionOptimization,

    /// Multi-scale attention
    pub multi_scale_attention: bool,

    /// Cross-attention for multi-task learning
    pub cross_attention: bool,

    /// Memory efficiency mode
    pub memory_efficient: bool,
}
95
/// Transformer network architecture
///
/// Composes the full forward pipeline used by [`TransformerOptimizer`]:
/// input embedding → positional encoding → `num_layers` transformer layers
/// → output layer norm → output projection (see [`TransformerNetwork::forward`]).
#[derive(Debug, Clone)]
pub struct TransformerNetwork<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Input embedding layer
    input_embedding: InputEmbedding<T>,

    /// Transformer layers, applied in order during the forward pass
    layers: Vec<TransformerLayer<T>>,

    /// Output projection (applied last in the pipeline)
    output_projection: OutputProjectionLayer<T>,

    /// Layer normalization applied before the output projection
    output_layer_norm: LayerNorm<T>,

    /// Position encoder (applied right after the input embedding)
    position_encoder: PositionalEncoder<T>,

    /// Configuration (cloned from the config used at construction time)
    config: TransformerOptimizerConfig,
}
127
/// Transformer-based neural optimizer with self-attention mechanisms
///
/// Orchestrates a [`TransformerNetwork`] together with pluggable strategies
/// for gradient processing, learning-rate adaptation, momentum,
/// regularization, meta-learning, curriculum learning, and evaluation,
/// plus a bounded history buffer and step-level metrics.
#[derive(Debug)]
pub struct TransformerOptimizer<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Configuration for the Transformer optimizer
    config: TransformerOptimizerConfig,

    /// Transformer network architecture
    transformer_network: TransformerNetwork<T>,

    /// Gradient processing strategies (applied first in `step`)
    gradient_processor: GradientProcessor<T>,

    /// Learning rate adaptation
    lr_adapter: LearningRateAdapter<T>,

    /// Momentum integration
    momentum_integrator: MomentumIntegrator<T>,

    /// Regularization strategies
    regularizer: TransformerRegularizer<T>,

    /// Meta-learning components
    meta_learner: TransformerMetaLearner<T>,

    /// Curriculum learning
    curriculum_learner: CurriculumLearner<T>,

    /// Evaluation framework
    evaluator: TransformerEvaluator<T>,

    /// Sequence buffer for maintaining optimization history
    sequence_buffer: SequenceBuffer<T>,

    /// Performance metrics
    metrics: TransformerOptimizerMetrics,

    /// Current optimization step (incremented by `step`, reset by `reset`)
    step_count: usize,

    /// Random number generator
    // NOTE(review): not used by any method in this file — presumably consumed
    // by code elsewhere in the module; confirm before removing.
    rng: scirs2_core::random::CoreRandom,
}
180
/// Sequence buffer for optimization history
///
/// Each series is capped at `capacity` entries; pushing beyond the cap
/// evicts the oldest entry (FIFO), newest entries are at the back.
#[derive(Debug, Clone)]
pub struct SequenceBuffer<
    T: Float + Debug + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
> {
    /// Gradient sequences (most recent at the back)
    gradient_sequences: VecDeque<Array1<T>>,

    /// Parameter sequences
    // NOTE(review): no method in this module pushes to this buffer — confirm
    // it is populated elsewhere or intentionally reserved for future use.
    parameter_sequences: VecDeque<Array1<T>>,

    /// Loss sequences
    loss_sequences: VecDeque<T>,

    /// Learning rate sequences
    lr_sequences: VecDeque<T>,

    /// Buffer capacity (maximum length of each series)
    capacity: usize,
}
201
/// Performance metrics for transformer optimizer
#[derive(Debug, Clone)]
pub struct TransformerOptimizerMetrics {
    /// Total optimization steps recorded so far
    total_steps: usize,

    /// Convergence history: one loss value per recorded step,
    /// capped at 10,000 entries by `update_step` (oldest dropped first)
    convergence_history: Vec<f64>,

    /// Attention pattern statistics, keyed by statistic name
    attention_stats: HashMap<String, f64>,

    /// Strategy usage statistics, keyed by strategy name
    strategy_stats: HashMap<String, f64>,

    /// Performance comparisons, keyed by comparison name
    performance_comparisons: HashMap<String, f64>,
}
220
221impl<
222        T: Float
223            + Debug
224            + Default
225            + Clone
226            + std::iter::Sum
227            + scirs2_core::ndarray::ScalarOperand
228            + Send
229            + Sync
230            + 'static,
231    > TransformerNetwork<T>
232{
233    /// Create new transformer network
234    pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
235        let input_embedding = InputEmbedding::new(config.modeldim, config.modeldim)?;
236
237        let mut layers = Vec::new();
238        for _ in 0..config.num_layers {
239            let mut rng = scirs2_core::random::thread_rng();
240            layers.push(TransformerLayer::new(config, &mut rng)?);
241        }
242
243        let output_projection = OutputProjectionLayer::new(config.modeldim, config.modeldim)?;
244        let output_layer_norm = LayerNorm::new(config.modeldim);
245        let position_encoder = PositionalEncoder::new(config)?;
246
247        Ok(Self {
248            input_embedding,
249            layers,
250            output_projection,
251            output_layer_norm,
252            position_encoder,
253            config: config.clone(),
254        })
255    }
256
257    /// Forward pass through transformer network
258    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
259        // Input embedding
260        let mut x = self.input_embedding.forward(input)?;
261
262        // Add positional encoding
263        x = self.position_encoder.encode(&x)?;
264
265        // Pass through transformer layers
266        for layer in &mut self.layers {
267            x = layer.forward(&x)?;
268        }
269
270        // Output layer normalization
271        x = self.output_layer_norm.forward(&x)?;
272
273        // Output projection
274        let output = self.output_projection.forward(&x)?;
275
276        Ok(output)
277    }
278
279    /// Get attention patterns from all layers
280    pub fn get_attention_patterns(&self) -> Vec<Option<&scirs2_core::ndarray::Array3<T>>> {
281        self.layers
282            .iter()
283            .map(|layer| layer.get_attention_patterns())
284            .collect()
285    }
286}
287
288impl<
289        T: Float
290            + Debug
291            + Default
292            + Clone
293            + std::iter::Sum
294            + scirs2_core::ndarray::ScalarOperand
295            + Send
296            + Sync
297            + 'static,
298    > TransformerOptimizer<T>
299{
300    /// Create new transformer optimizer
301    pub fn new(config: TransformerOptimizerConfig) -> Result<Self> {
302        let transformer_network = TransformerNetwork::new(&config)?;
303        let gradient_processor = GradientProcessor::new(GradientProcessingStrategy::Adaptive);
304        let lr_adapter = LearningRateAdapter::new(
305            LearningRateAdaptationStrategy::TransformerPredicted,
306            scirs2_core::numeric::NumCast::from(0.001).unwrap_or_else(|| T::zero()),
307        );
308        let momentum_integrator = MomentumIntegrator::new(MomentumStrategy::TransformerPredicted);
309        let regularizer = TransformerRegularizer::new(RegularizationStrategy::Adaptive);
310        let meta_learner = TransformerMetaLearner::new(MetaLearningStrategy::GradientBased)?;
311        let curriculum_learner = CurriculumLearner::new(CurriculumStrategy::Adaptive)?;
312        let evaluator = TransformerEvaluator::new(EvaluationStrategy::Comprehensive)?;
313        let sequence_buffer = SequenceBuffer::new(1000);
314        let metrics = TransformerOptimizerMetrics::new();
315
316        Ok(Self {
317            config,
318            transformer_network,
319            gradient_processor,
320            lr_adapter,
321            momentum_integrator,
322            regularizer,
323            meta_learner,
324            curriculum_learner,
325            evaluator,
326            sequence_buffer,
327            metrics,
328            step_count: 0,
329            rng: scirs2_core::random::thread_rng(),
330        })
331    }
332
333    /// Perform optimization step
334    pub fn step(
335        &mut self,
336        parameters: &mut HashMap<String, Array2<T>>,
337        gradients: &mut HashMap<String, Array2<T>>,
338        loss: T,
339    ) -> Result<T> {
340        self.step_count += 1;
341
342        // Process gradients for each parameter
343        for (param_name, gradient) in gradients.iter_mut() {
344            // Flatten gradient for processing
345            let flat_gradient = gradient.iter().cloned().collect::<Vec<_>>();
346            let gradient_array = Array1::from_vec(flat_gradient);
347
348            // Apply gradient processing
349            let processed_gradient = self.gradient_processor.process_gradients(&gradient_array)?;
350
351            // Update learning rate
352            let current_lr = self
353                .lr_adapter
354                .update_learning_rate(Some(loss), Some(&processed_gradient))?;
355
356            // Apply momentum
357            let momentum_gradient = self.momentum_integrator.integrate_momentum(
358                &processed_gradient,
359                None, // Would pass attention patterns in full implementation
360            )?;
361
362            // Apply regularization
363            let mut param_map = HashMap::new();
364            if let Some(param_values) = parameters.get(param_name) {
365                param_map.insert(param_name.clone(), param_values.clone());
366            }
367
368            let mut grad_map = HashMap::new();
369            grad_map.insert(param_name.clone(), gradient.clone());
370
371            let _reg_loss = self.regularizer.apply_regularization(
372                &param_map,
373                &mut grad_map,
374                None, // Would pass attention patterns in full implementation
375            )?;
376
377            // Store processed gradients in sequence buffer
378            self.sequence_buffer.add_gradient(momentum_gradient);
379        }
380
381        // Update sequence buffer with loss and learning rate
382        self.sequence_buffer.add_loss(loss);
383        self.sequence_buffer
384            .add_learning_rate(self.lr_adapter.current_learning_rate());
385
386        // Update curriculum learning
387        let task_id = "current_task"; // In practice, this would be provided
388        self.curriculum_learner
389            .update_curriculum(task_id, loss, self.step_count)?;
390
391        // Update metrics
392        self.metrics
393            .update_step(loss.to_f64().unwrap_or(0.0), self.step_count);
394
395        Ok(loss)
396    }
397
398    /// Get current optimization statistics
399    pub fn get_statistics(&self) -> HashMap<String, f64> {
400        let mut stats = HashMap::new();
401
402        stats.insert("step_count".to_string(), self.step_count as f64);
403        stats.insert(
404            "current_lr".to_string(),
405            self.lr_adapter
406                .current_learning_rate()
407                .to_f64()
408                .unwrap_or(0.0),
409        );
410
411        // Add gradient processor statistics
412        let grad_stats = self.gradient_processor.statistics();
413        stats.insert(
414            "mean_gradient_magnitude".to_string(),
415            grad_stats.mean_magnitude().to_f64().unwrap_or(0.0),
416        );
417        stats.insert(
418            "gradient_sparsity".to_string(),
419            grad_stats.sparsity().to_f64().unwrap_or(0.0),
420        );
421
422        // Add momentum statistics
423        let momentum_stats = self.momentum_integrator.statistics();
424        stats.insert(
425            "momentum_magnitude".to_string(),
426            momentum_stats
427                .avg_momentum_magnitude
428                .to_f64()
429                .unwrap_or(0.0),
430        );
431
432        // Add curriculum statistics
433        let curriculum_stats = self.curriculum_learner.get_curriculum_statistics();
434        for (key, value) in curriculum_stats {
435            stats.insert(format!("curriculum_{}", key), value.to_f64().unwrap_or(0.0));
436        }
437
438        stats
439    }
440
441    /// Reset optimizer state
442    pub fn reset(&mut self) -> Result<()> {
443        self.step_count = 0;
444        self.gradient_processor.reset();
445        self.lr_adapter.reset();
446        self.momentum_integrator.reset();
447        self.regularizer.reset();
448        self.meta_learner.reset();
449        self.curriculum_learner.reset();
450        self.evaluator.reset();
451        self.sequence_buffer.clear();
452        self.metrics = TransformerOptimizerMetrics::new();
453
454        Ok(())
455    }
456}
457
458impl<
459        T: Float
460            + Debug
461            + Default
462            + Clone
463            + scirs2_core::ndarray::ScalarOperand
464            + Send
465            + Sync
466            + 'static,
467    > SequenceBuffer<T>
468{
469    /// Create new sequence buffer
470    pub fn new(capacity: usize) -> Self {
471        Self {
472            gradient_sequences: VecDeque::new(),
473            parameter_sequences: VecDeque::new(),
474            loss_sequences: VecDeque::new(),
475            lr_sequences: VecDeque::new(),
476            capacity,
477        }
478    }
479
480    /// Add gradient to buffer
481    pub fn add_gradient(&mut self, gradient: Array1<T>) {
482        self.gradient_sequences.push_back(gradient);
483        if self.gradient_sequences.len() > self.capacity {
484            self.gradient_sequences.pop_front();
485        }
486    }
487
488    /// Add loss to buffer
489    pub fn add_loss(&mut self, loss: T) {
490        self.loss_sequences.push_back(loss);
491        if self.loss_sequences.len() > self.capacity {
492            self.loss_sequences.pop_front();
493        }
494    }
495
496    /// Add learning rate to buffer
497    pub fn add_learning_rate(&mut self, lr: T) {
498        self.lr_sequences.push_back(lr);
499        if self.lr_sequences.len() > self.capacity {
500            self.lr_sequences.pop_front();
501        }
502    }
503
504    /// Clear buffer
505    pub fn clear(&mut self) {
506        self.gradient_sequences.clear();
507        self.parameter_sequences.clear();
508        self.loss_sequences.clear();
509        self.lr_sequences.clear();
510    }
511
512    /// Get recent gradient history
513    pub fn get_recent_gradients(&self, count: usize) -> Vec<&Array1<T>> {
514        self.gradient_sequences.iter().rev().take(count).collect()
515    }
516}
517
impl Default for TransformerOptimizerMetrics {
    /// Equivalent to [`TransformerOptimizerMetrics::new`]: an empty tracker.
    fn default() -> Self {
        Self::new()
    }
}
523
524impl TransformerOptimizerMetrics {
525    /// Create new metrics tracker
526    pub fn new() -> Self {
527        Self {
528            total_steps: 0,
529            convergence_history: Vec::new(),
530            attention_stats: HashMap::new(),
531            strategy_stats: HashMap::new(),
532            performance_comparisons: HashMap::new(),
533        }
534    }
535
536    /// Update metrics after optimization step
537    pub fn update_step(&mut self, loss: f64, step: usize) {
538        self.total_steps = step;
539        self.convergence_history.push(loss);
540
541        // Keep only recent history
542        if self.convergence_history.len() > 10000 {
543            self.convergence_history.remove(0);
544        }
545    }
546}