// optirs_learned/transformer/mod.rs

use std::fmt::Debug;

pub mod architecture;
pub mod strategies;
pub mod training;

pub use architecture::{
    ActivationFunction, AttentionOptimization, FeedForwardNetwork, InputEmbedding, LayerNorm,
    MultiHeadAttention, OutputProjectionLayer, PositionalEncoder, PositionalEncodingType,
    TransformerLayer,
};

pub use strategies::{
    GradientProcessingStrategy, GradientProcessor, LearningRateAdaptationStrategy,
    LearningRateAdapter, MomentumIntegrator, MomentumStrategy, RegularizationStrategy,
    TransformerRegularizer,
};

pub use training::{
    CurriculumLearner, CurriculumStrategy, EvaluationStrategy, MetaLearningStrategy,
    TransformerEvaluator, TransformerMetaLearner,
};

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::collections::{HashMap, VecDeque};

use super::{LearnedOptimizerConfig, MetaOptimizationStrategy};
use crate::error::{OptimError, Result};
/// Configuration for the transformer-based learned optimizer.
#[derive(Debug, Clone)]
pub struct TransformerOptimizerConfig {
    /// Base learned-optimizer configuration shared across architectures.
    pub base_config: LearnedOptimizerConfig,

    /// Model (embedding) dimension of the transformer.
    pub modeldim: usize,

    /// Number of attention heads.
    pub numheads: usize,

    /// Hidden dimension of the feed-forward sub-layers.
    pub ff_dim: usize,

    /// Number of transformer layers.
    pub num_layers: usize,

    /// Maximum sequence length processed by the network.
    pub max_sequence_length: usize,

    /// Dropout rate applied to attention weights.
    pub attention_dropout: f64,

    /// Dropout rate applied in the feed-forward sub-layers.
    pub ff_dropout: f64,

    /// Epsilon used for layer-norm numerical stability.
    pub layer_norm_eps: f64,

    /// Use pre-layer-norm (rather than post-layer-norm) ordering.
    pub pre_layer_norm: bool,

    /// Positional-encoding scheme.
    pub pos_encoding_type: PositionalEncodingType,

    /// Add a relative-position bias to attention scores.
    pub relative_position_bias: bool,

    /// Use rotary position embeddings (RoPE).
    pub use_rope: bool,

    /// Trade compute for memory via gradient checkpointing.
    pub gradient_checkpointing: bool,

    /// Attention implementation/optimization variant.
    pub attention_optimization: AttentionOptimization,

    /// Enable multi-scale attention.
    pub multi_scale_attention: bool,

    /// Enable cross-attention.
    pub cross_attention: bool,

    /// Prefer memory-efficient computation paths.
    pub memory_efficient: bool,
}
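
// Sketch: building a configuration by hand. The enum variants and the base
// config live in `architecture` and the parent module, so they are taken as
// parameters here rather than guessed; the numeric values are illustrative
// defaults, not tuned settings.
#[allow(dead_code)]
fn example_config(
    base_config: LearnedOptimizerConfig,
    pos_encoding_type: PositionalEncodingType,
    attention_optimization: AttentionOptimization,
) -> TransformerOptimizerConfig {
    TransformerOptimizerConfig {
        base_config,
        modeldim: 256,
        numheads: 8,
        ff_dim: 1024,
        num_layers: 4,
        max_sequence_length: 512,
        attention_dropout: 0.1,
        ff_dropout: 0.1,
        layer_norm_eps: 1e-5,
        pre_layer_norm: true,
        pos_encoding_type,
        relative_position_bias: false,
        use_rope: false,
        gradient_checkpointing: false,
        attention_optimization,
        multi_scale_attention: false,
        cross_attention: false,
        memory_efficient: false,
    }
}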

/// Transformer network that acts as the learned update rule.
#[derive(Debug, Clone)]
pub struct TransformerNetwork<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Embedding that projects raw optimizer features into the model dimension.
    input_embedding: InputEmbedding<T>,

    /// Stack of transformer layers.
    layers: Vec<TransformerLayer<T>>,

    /// Projection from the model dimension back to the update space.
    output_projection: OutputProjectionLayer<T>,

    /// Final layer normalization applied before the output projection.
    output_layer_norm: LayerNorm<T>,

    /// Positional encoder applied to the embedded input.
    position_encoder: PositionalEncoder<T>,

    /// Configuration this network was built from.
    config: TransformerOptimizerConfig,
}

/// Transformer-based learned optimizer combining the network with gradient,
/// learning-rate, momentum, regularization, meta-learning, curriculum, and
/// evaluation strategies.
#[derive(Debug)]
pub struct TransformerOptimizer<
    T: Float
        + Debug
        + Default
        + Clone
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Send
        + Sync
        + 'static,
> {
    /// Optimizer configuration.
    config: TransformerOptimizerConfig,

    /// Transformer network producing the learned updates.
    transformer_network: TransformerNetwork<T>,

    /// Gradient preprocessing pipeline.
    gradient_processor: GradientProcessor<T>,

    /// Learning-rate adaptation strategy.
    lr_adapter: LearningRateAdapter<T>,

    /// Momentum integration strategy.
    momentum_integrator: MomentumIntegrator<T>,

    /// Regularization applied during updates.
    regularizer: TransformerRegularizer<T>,

    /// Meta-learning component.
    meta_learner: TransformerMetaLearner<T>,

    /// Curriculum learning over tasks.
    curriculum_learner: CurriculumLearner<T>,

    /// Evaluation of optimizer behavior.
    evaluator: TransformerEvaluator<T>,

    /// Rolling buffer of recent gradients, losses, and learning rates.
    sequence_buffer: SequenceBuffer<T>,

    /// Aggregated metrics for this optimizer.
    metrics: TransformerOptimizerMetrics,

    /// Number of optimization steps taken.
    step_count: usize,

    /// Random number generator.
    rng: scirs2_core::random::CoreRandom,
}

/// Bounded FIFO buffer of recent optimization signals used as transformer input.
#[derive(Debug, Clone)]
pub struct SequenceBuffer<
    T: Float + Debug + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
> {
    /// Recent (momentum-integrated) gradient vectors.
    gradient_sequences: VecDeque<Array1<T>>,

    /// Recent parameter vectors.
    parameter_sequences: VecDeque<Array1<T>>,

    /// Recent loss values.
    loss_sequences: VecDeque<T>,

    /// Recent learning rates.
    lr_sequences: VecDeque<T>,

    /// Maximum number of entries kept per sequence.
    capacity: usize,
}

/// Aggregated metrics describing optimizer behavior over time.
#[derive(Debug, Clone)]
pub struct TransformerOptimizerMetrics {
    /// Total number of optimization steps recorded.
    total_steps: usize,

    /// Loss values recorded per step (bounded history).
    convergence_history: Vec<f64>,

    /// Statistics about attention behavior.
    attention_stats: HashMap<String, f64>,

    /// Statistics about the active strategies.
    strategy_stats: HashMap<String, f64>,

    /// Performance comparisons against baselines.
    performance_comparisons: HashMap<String, f64>,
}

impl<
        T: Float
            + Debug
            + Default
            + Clone
            + std::iter::Sum
            + scirs2_core::ndarray::ScalarOperand
            + Send
            + Sync
            + 'static,
    > TransformerNetwork<T>
{
    /// Builds the embedding, layer stack, and output head from the config.
    pub fn new(config: &TransformerOptimizerConfig) -> Result<Self> {
        let input_embedding = InputEmbedding::new(config.modeldim, config.modeldim)?;

        // One RNG shared across layer initialization.
        let mut rng = scirs2_core::random::thread_rng();
        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(TransformerLayer::new(config, &mut rng)?);
        }

        let output_projection = OutputProjectionLayer::new(config.modeldim, config.modeldim)?;
        let output_layer_norm = LayerNorm::new(config.modeldim);
        let position_encoder = PositionalEncoder::new(config)?;

        Ok(Self {
            input_embedding,
            layers,
            output_projection,
            output_layer_norm,
            position_encoder,
            config: config.clone(),
        })
    }

    /// Forward pass: embed the input, add positional information, apply each
    /// transformer layer, normalize, and project to the output space.
    pub fn forward(&mut self, input: &Array2<T>) -> Result<Array2<T>> {
        let mut x = self.input_embedding.forward(input)?;
        x = self.position_encoder.encode(&x)?;

        for layer in &mut self.layers {
            x = layer.forward(&x)?;
        }

        x = self.output_layer_norm.forward(&x)?;
        let output = self.output_projection.forward(&x)?;

        Ok(output)
    }

    /// Returns each layer's most recent attention pattern, if one was recorded.
    pub fn get_attention_patterns(&self) -> Vec<Option<&scirs2_core::ndarray::Array3<T>>> {
        self.layers
            .iter()
            .map(|layer| layer.get_attention_patterns())
            .collect()
    }
}
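
// Sketch: one forward pass through the trunk built by `TransformerNetwork::new`.
// The input shape `(sequence_length, modeldim)` is an assumption inferred from
// the 2-D signature of `forward` and the embedding dimensions; adjust it to
// the real feature layout.
#[allow(dead_code)]
fn forward_pass_example(config: &TransformerOptimizerConfig) -> Result<Array2<f64>> {
    let mut network = TransformerNetwork::<f64>::new(config)?;
    // A dummy 16-step input sequence, one `modeldim`-wide feature row per step.
    let input = Array2::<f64>::zeros((16, config.modeldim));
    network.forward(&input)
}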

impl<
        T: Float
            + Debug
            + Default
            + Clone
            + std::iter::Sum
            + scirs2_core::ndarray::ScalarOperand
            + Send
            + Sync
            + 'static,
    > TransformerOptimizer<T>
{
    /// Creates a transformer optimizer with default strategy choices for each
    /// component and an empty sequence buffer of capacity 1000.
    pub fn new(config: TransformerOptimizerConfig) -> Result<Self> {
        let transformer_network = TransformerNetwork::new(&config)?;
        let gradient_processor = GradientProcessor::new(GradientProcessingStrategy::Adaptive);
        let lr_adapter = LearningRateAdapter::new(
            LearningRateAdaptationStrategy::TransformerPredicted,
            scirs2_core::numeric::NumCast::from(0.001).unwrap_or_else(T::zero),
        );
        let momentum_integrator = MomentumIntegrator::new(MomentumStrategy::TransformerPredicted);
        let regularizer = TransformerRegularizer::new(RegularizationStrategy::Adaptive);
        let meta_learner = TransformerMetaLearner::new(MetaLearningStrategy::GradientBased)?;
        let curriculum_learner = CurriculumLearner::new(CurriculumStrategy::Adaptive)?;
        let evaluator = TransformerEvaluator::new(EvaluationStrategy::Comprehensive)?;
        let sequence_buffer = SequenceBuffer::new(1000);
        let metrics = TransformerOptimizerMetrics::new();

        Ok(Self {
            config,
            transformer_network,
            gradient_processor,
            lr_adapter,
            momentum_integrator,
            regularizer,
            meta_learner,
            curriculum_learner,
            evaluator,
            sequence_buffer,
            metrics,
            step_count: 0,
            rng: scirs2_core::random::thread_rng(),
        })
    }

    /// Performs one optimization step: processes each parameter's gradient,
    /// adapts the learning rate, integrates momentum, applies regularization,
    /// and records the step in the sequence buffer and metrics.
    pub fn step(
        &mut self,
        parameters: &mut HashMap<String, Array2<T>>,
        gradients: &mut HashMap<String, Array2<T>>,
        loss: T,
    ) -> Result<T> {
        self.step_count += 1;

        for (param_name, gradient) in gradients.iter_mut() {
            // Flatten the gradient for sequence-oriented processing.
            let flat_gradient = gradient.iter().cloned().collect::<Vec<_>>();
            let gradient_array = Array1::from_vec(flat_gradient);

            let processed_gradient = self.gradient_processor.process_gradients(&gradient_array)?;

            let _current_lr = self
                .lr_adapter
                .update_learning_rate(Some(loss), Some(&processed_gradient))?;

            let momentum_gradient = self
                .momentum_integrator
                .integrate_momentum(&processed_gradient, None)?;

            // Regularization operates on per-parameter copies; only the
            // regularization loss (discarded here) would feed back.
            let mut param_map = HashMap::new();
            if let Some(param_values) = parameters.get(param_name) {
                param_map.insert(param_name.clone(), param_values.clone());
            }

            let mut grad_map = HashMap::new();
            grad_map.insert(param_name.clone(), gradient.clone());

            let _reg_loss =
                self.regularizer
                    .apply_regularization(&param_map, &mut grad_map, None)?;

            self.sequence_buffer.add_gradient(momentum_gradient);
        }

        self.sequence_buffer.add_loss(loss);
        self.sequence_buffer
            .add_learning_rate(self.lr_adapter.current_learning_rate());

        let task_id = "current_task";
        self.curriculum_learner
            .update_curriculum(task_id, loss, self.step_count)?;

        self.metrics
            .update_step(loss.to_f64().unwrap_or(0.0), self.step_count);

        Ok(loss)
    }

    /// Collects scalar diagnostics: step count, current learning rate, and
    /// gradient, momentum, and curriculum statistics.
    pub fn get_statistics(&self) -> HashMap<String, f64> {
        let mut stats = HashMap::new();

        stats.insert("step_count".to_string(), self.step_count as f64);
        stats.insert(
            "current_lr".to_string(),
            self.lr_adapter
                .current_learning_rate()
                .to_f64()
                .unwrap_or(0.0),
        );

        let grad_stats = self.gradient_processor.statistics();
        stats.insert(
            "mean_gradient_magnitude".to_string(),
            grad_stats.mean_magnitude().to_f64().unwrap_or(0.0),
        );
        stats.insert(
            "gradient_sparsity".to_string(),
            grad_stats.sparsity().to_f64().unwrap_or(0.0),
        );

        let momentum_stats = self.momentum_integrator.statistics();
        stats.insert(
            "momentum_magnitude".to_string(),
            momentum_stats
                .avg_momentum_magnitude
                .to_f64()
                .unwrap_or(0.0),
        );

        let curriculum_stats = self.curriculum_learner.get_curriculum_statistics();
        for (key, value) in curriculum_stats {
            stats.insert(format!("curriculum_{}", key), value.to_f64().unwrap_or(0.0));
        }

        stats
    }

    /// Resets all components, the sequence buffer, metrics, and the step count.
    pub fn reset(&mut self) -> Result<()> {
        self.step_count = 0;
        self.gradient_processor.reset();
        self.lr_adapter.reset();
        self.momentum_integrator.reset();
        self.regularizer.reset();
        self.meta_learner.reset();
        self.curriculum_learner.reset();
        self.evaluator.reset();
        self.sequence_buffer.clear();
        self.metrics = TransformerOptimizerMetrics::new();

        Ok(())
    }
}
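
// Sketch: driving one learned-optimizer update at `T = f64`. Assumes the
// caller already holds parameter/gradient maps keyed by layer name and has
// built the optimizer from a config; everything else uses the APIs defined
// above.
#[allow(dead_code)]
fn optimizer_step_example(
    optimizer: &mut TransformerOptimizer<f64>,
    parameters: &mut HashMap<String, Array2<f64>>,
    gradients: &mut HashMap<String, Array2<f64>>,
    loss: f64,
) -> Result<()> {
    // One step: gradient processing, LR adaptation, momentum, regularization.
    let _loss = optimizer.step(parameters, gradients, loss)?;

    // Scalar diagnostics accumulated so far (step count, current LR, etc.).
    for (name, value) in optimizer.get_statistics() {
        println!("{name}: {value:.6}");
    }
    Ok(())
}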

impl<
        T: Float
            + Debug
            + Default
            + Clone
            + scirs2_core::ndarray::ScalarOperand
            + Send
            + Sync
            + 'static,
    > SequenceBuffer<T>
{
    /// Creates an empty buffer holding at most `capacity` entries per sequence.
    pub fn new(capacity: usize) -> Self {
        Self {
            gradient_sequences: VecDeque::new(),
            parameter_sequences: VecDeque::new(),
            loss_sequences: VecDeque::new(),
            lr_sequences: VecDeque::new(),
            capacity,
        }
    }

    /// Appends a gradient, evicting the oldest entry once over capacity.
    pub fn add_gradient(&mut self, gradient: Array1<T>) {
        self.gradient_sequences.push_back(gradient);
        if self.gradient_sequences.len() > self.capacity {
            self.gradient_sequences.pop_front();
        }
    }

    /// Appends a loss value, evicting the oldest entry once over capacity.
    pub fn add_loss(&mut self, loss: T) {
        self.loss_sequences.push_back(loss);
        if self.loss_sequences.len() > self.capacity {
            self.loss_sequences.pop_front();
        }
    }

    /// Appends a learning rate, evicting the oldest entry once over capacity.
    pub fn add_learning_rate(&mut self, lr: T) {
        self.lr_sequences.push_back(lr);
        if self.lr_sequences.len() > self.capacity {
            self.lr_sequences.pop_front();
        }
    }

    /// Clears all stored sequences.
    pub fn clear(&mut self) {
        self.gradient_sequences.clear();
        self.parameter_sequences.clear();
        self.loss_sequences.clear();
        self.lr_sequences.clear();
    }

    /// Returns up to `count` most recent gradients, newest first.
    pub fn get_recent_gradients(&self, count: usize) -> Vec<&Array1<T>> {
        self.gradient_sequences.iter().rev().take(count).collect()
    }
}
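
// Sketch: the bounded FIFO behavior of `SequenceBuffer` at a small capacity,
// using only the APIs defined above.
#[allow(dead_code)]
fn sequence_buffer_example() {
    let mut buffer = SequenceBuffer::<f64>::new(2);
    buffer.add_gradient(Array1::from_vec(vec![0.1, 0.2]));
    buffer.add_gradient(Array1::from_vec(vec![0.3, 0.4]));
    // A third push exceeds the capacity of 2 and evicts the oldest gradient.
    buffer.add_gradient(Array1::from_vec(vec![0.5, 0.6]));
    // `get_recent_gradients` returns newest-first, capped by what is stored.
    assert_eq!(buffer.get_recent_gradients(10).len(), 2);
}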

impl Default for TransformerOptimizerMetrics {
    fn default() -> Self {
        Self::new()
    }
}

impl TransformerOptimizerMetrics {
    /// Creates empty metrics.
    pub fn new() -> Self {
        Self {
            total_steps: 0,
            convergence_history: Vec::new(),
            attention_stats: HashMap::new(),
            strategy_stats: HashMap::new(),
            performance_comparisons: HashMap::new(),
        }
    }

    /// Records the loss for a step, keeping at most the 10_000 most recent
    /// entries in the convergence history.
    pub fn update_step(&mut self, loss: f64, step: usize) {
        self.total_steps = step;
        self.convergence_history.push(loss);

        if self.convergence_history.len() > 10000 {
            self.convergence_history.remove(0);
        }
    }
}
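
// Sketch: metrics accumulation via `update_step`. Field access below relies
// on being in the same module and is for illustration only.
#[allow(dead_code)]
fn metrics_example() {
    let mut metrics = TransformerOptimizerMetrics::new();
    for step in 1..=3 {
        metrics.update_step(1.0 / step as f64, step);
    }
    // Three steps recorded; the convergence history holds one loss per step.
    assert_eq!(metrics.total_steps, 3);
    assert_eq!(metrics.convergence_history.len(), 3);
}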