optirs_learned/transformer_based_optimizer/mod.rs

// Transformer-Based Meta-Learning for Optimization
//
// This module implements transformer architectures specifically designed for
// meta-learning in optimization tasks. It includes attention mechanisms,
// sequence modeling for optimization trajectories, and advanced transformer
// architectures tailored for learning optimization strategies.

pub mod architecture;
pub mod attention;
pub mod config;
pub mod feedforward;
pub mod layers;
pub mod memory_manager;
pub mod meta_learning;
pub mod performance_tracker;
pub mod positional_encoding;
pub mod sequence_processor;
pub mod state;

// Re-export main types for backward compatibility
pub use architecture::{TransformerArchitecture, TransformerLayer};
pub use attention::{AttentionMechanism, MultiHeadAttention};
pub use config::{TransformerArchConfig, TransformerBasedOptimizerConfig};
pub use feedforward::{ActivationFunction, FeedForwardNetwork};
pub use layers::{
    DropoutLayer, EmbeddingLayer, LayerNormalization, OutputProjection, ResidualConnections,
};
pub use memory_manager::{MemoryManagementStrategy, TransformerMemoryManager};
pub use meta_learning::{MetaLearningStrategy, TransformerMetaLearning};
pub use performance_tracker::{PerformanceMetrics, TransformerPerformanceTracker};
pub use positional_encoding::{PositionalEncoding, PositionalEncodingType};
pub use sequence_processor::{OptimizationSequenceProcessor, SequenceProcessingStrategy};
pub use state::{OptimizerStateSnapshot, TransformerOptimizerState};

// Re-export for backward compatibility - an alias for the old name
pub use config::TransformerBasedOptimizerConfig as TransformerOptimizerConfig;

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, ToPrimitive};
use std::fmt::Debug;
use std::time::{Duration, Instant};

use crate::error::{OptimError, Result};

/// Transformer-based meta-learning optimizer
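///
/// # Example
///
/// A minimal usage sketch, assuming `TransformerBasedOptimizerConfig::default()`
/// yields a workable configuration; the shapes are illustrative
/// (rows = time steps, columns = parameter dimension):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let config = TransformerBasedOptimizerConfig::default();
/// let mut optimizer = TransformerOptimizer::<f32>::new(config)?;
///
/// let gradients = Array2::<f32>::zeros((16, 8));
/// let parameters = Array2::<f32>::zeros((16, 8));
/// let losses = Array1::<f32>::zeros(16);
///
/// let step = optimizer.generate_optimization_step(&gradients, &parameters, &losses)?;
/// ```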
pub struct TransformerOptimizer<T: Float + Debug + Send + Sync + 'static> {
    /// Core transformer architecture
    transformer: TransformerArchitecture<T>,

    /// Positional encoding for sequence modeling
    positional_encoding: PositionalEncoding<T>,

    /// Attention mechanism for optimization history
    attention_mechanism: MultiHeadAttention<T>,

    /// Feed-forward networks for optimization steps
    feedforward_networks: Vec<FeedForwardNetwork<T>>,

    /// Meta-learning components
    meta_learning: TransformerMetaLearning<T>,

    /// Sequence processor for optimization trajectories
    sequence_processor: OptimizationSequenceProcessor<T>,

    /// Memory management for long sequences
    memory_manager: TransformerMemoryManager<T>,

    /// Configuration
    config: TransformerBasedOptimizerConfig<T>,

    /// Performance tracking
    performance_tracker: TransformerPerformanceTracker<T>,

    /// State management
    state: TransformerOptimizerState<T>,
}

impl<
        T: Float
            + Debug
            + Send
            + Sync
            + 'static
            + scirs2_core::ndarray::ScalarOperand
            + scirs2_core::numeric::FromPrimitive,
    > TransformerOptimizer<T>
{
    /// Create new transformer optimizer
    pub fn new(config: TransformerBasedOptimizerConfig<T>) -> Result<Self> {
        let transformer_config = TransformerArchConfig::from_optimizer_config(&config);
        let transformer = TransformerArchitecture::new(transformer_config)?;

        let positional_encoding = PositionalEncoding::new(
            config.sequence_length,
            config.model_dimension,
            config.positional_encoding_type,
        )?;

        let attention_mechanism = MultiHeadAttention::new(
            config.num_attention_heads,
            config.model_dimension,
            config.attention_head_dimension,
        )?;

        let mut feedforward_networks = Vec::new();
        for _ in 0..config.num_transformer_layers {
            feedforward_networks.push(FeedForwardNetwork::new(
                config.model_dimension,
                config.feedforward_dimension,
                config.activation_function,
            )?);
        }

        let meta_learning = TransformerMetaLearning::new(&config)?;
        let sequence_processor = OptimizationSequenceProcessor::new(&config)?;
        let memory_manager = TransformerMemoryManager::new(&config)?;
        let performance_tracker = TransformerPerformanceTracker::new();
        let state = TransformerOptimizerState::new(&config)?;

        Ok(Self {
            transformer,
            positional_encoding,
            attention_mechanism,
            feedforward_networks,
            meta_learning,
            sequence_processor,
            memory_manager,
            config,
            performance_tracker,
            state,
        })
    }

    /// Generate optimization step using transformer
    pub fn generate_optimization_step(
        &mut self,
        gradient_history: &Array2<T>,
        parameter_history: &Array2<T>,
        loss_history: &Array1<T>,
    ) -> Result<Array1<T>> {
        let start_time = Instant::now();

        // Process input sequences
        let processed_sequence = self.sequence_processor.process_optimization_sequence(
            gradient_history,
            parameter_history,
            loss_history,
        )?;

        // Apply positional encoding
        let encoded_sequence = self.positional_encoding.encode(&processed_sequence)?;

        // Forward pass through transformer
        let transformer_output = self.transformer.forward(&encoded_sequence)?;

        // Generate optimization step
        let optimization_step = self
            .meta_learning
            .generate_update(&transformer_output, &self.state.current_parameters)?;

        // Update state
        self.state
            .update_with_step(&optimization_step, loss_history.last().copied())?;

        // Track performance
        let elapsed = start_time.elapsed();
        self.performance_tracker
            .record_optimization_step(elapsed, &optimization_step);

        Ok(optimization_step)
    }

    /// Train the transformer on optimization trajectories
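    ///
    /// A minimal sketch of assembling a trajectory and training on it; the
    /// shapes and metadata values here are illustrative only:
    ///
    /// ```ignore
    /// let trajectory = OptimizationTrajectory::<f32> {
    ///     gradient_sequence: Array2::zeros((100, 5)),
    ///     parameter_sequence: Array2::zeros((100, 5)),
    ///     loss_sequence: Array1::zeros(100),
    ///     metadata: TrajectoryMetadata {
    ///         task_id: "example-task".to_string(),
    ///         optimizer_type: "sgd".to_string(),
    ///         convergence_achieved: true,
    ///         total_steps: 100,
    ///     },
    /// };
    /// let metrics = optimizer.train_on_trajectories(&[trajectory])?;
    /// println!("average loss: {}", metrics.loss);
    /// ```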
    pub fn train_on_trajectories(
        &mut self,
        trajectories: &[OptimizationTrajectory<T>],
    ) -> Result<TrainingMetrics> {
        let start_time = Instant::now();
        let mut total_loss = T::zero();
        let mut batch_count = 0;

        for trajectory in trajectories {
            // Process trajectory into sequences
            let sequences = self
                .sequence_processor
                .trajectory_to_sequences(trajectory)?;

            for sequence in sequences {
                // Forward pass
                let prediction = self.forward_sequence(&sequence.input)?;

                // Calculate loss
                let loss = self.calculate_sequence_loss(&prediction, &sequence.target)?;
                total_loss = total_loss + loss;

                // Backward pass (simplified)
                self.backward_pass(&sequence.input, &sequence.target, loss)?;

                batch_count += 1;
            }
        }

        let avg_loss = if batch_count > 0 {
            // `batch_count > 0` here, so the cast should not fail; fall back to
            // `T::one()` rather than dividing by zero if it ever does.
            total_loss / scirs2_core::numeric::NumCast::from(batch_count).unwrap_or_else(T::one)
        } else {
            T::zero()
        };

        let training_time = start_time.elapsed();

        let metrics = TrainingMetrics {
            loss: avg_loss.to_f64().unwrap_or(0.0),
            training_time,
            num_sequences: batch_count,
            convergence_rate: self.calculate_convergence_rate()?,
        };

        self.performance_tracker
            .record_training_epoch(metrics.clone());

        Ok(metrics)
    }

    /// Forward pass through the transformer for a sequence
    fn forward_sequence(&mut self, sequence: &Array2<T>) -> Result<Array2<T>> {
        // Apply positional encoding
        let encoded_sequence = self.positional_encoding.encode(sequence)?;

        // Forward through transformer
        self.transformer.forward(&encoded_sequence)
    }

    /// Calculate loss for sequence prediction
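    ///
    /// For example, a prediction of `[1.0, 2.0]` against a target of
    /// `[0.0, 2.0]` gives MSE = (1.0 + 0.0) / 2 = 0.5.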
    fn calculate_sequence_loss(&self, prediction: &Array2<T>, target: &Array2<T>) -> Result<T> {
        if prediction.shape() != target.shape() {
            return Err(OptimError::Other(
                "Shape mismatch in loss calculation".to_string(),
            ));
        }

        // Mean squared error: MSE = (1/n) * sum((prediction_i - target_i)^2)
        let diff = prediction - target;
        let squared_diff = &diff * &diff;
        let sum = squared_diff.sum();
        let mse = sum / T::from(prediction.len()).expect("element count fits in T");

        Ok(mse)
    }

    /// Simplified backward pass
    fn backward_pass(&mut self, _input: &Array2<T>, _target: &Array2<T>, loss: T) -> Result<()> {
        // In a full implementation, this would compute gradients and update parameters.
        // For now, only the meta-learning state is updated from the observed loss.
        self.meta_learning.update_from_loss(loss)?;
        Ok(())
    }

    /// Calculate convergence rate over the recent loss history
    fn calculate_convergence_rate(&self) -> Result<f64> {
        let loss_history = self.performance_tracker.get_loss_history();
        if loss_history.len() < 2 {
            return Ok(0.0);
        }

        // Most recent losses first: `first()` is the latest, `last()` the oldest.
        let recent_losses: Vec<_> = loss_history.iter().rev().take(10).collect();
        if recent_losses.len() < 2 {
            return Ok(0.0);
        }

        let initial_loss = *recent_losses.last().expect("length checked above");
        let final_loss = *recent_losses.first().expect("length checked above");

        // Relative improvement over the window, clamped to [0, 1].
        let improvement = (initial_loss - final_loss) / initial_loss;
        let improvement_f64 = improvement.to_f64().unwrap_or(0.0);
        Ok(improvement_f64.clamp(0.0, 1.0))
    }

    /// Get current state
    pub fn get_state(&self) -> &TransformerOptimizerState<T> {
        &self.state
    }

    /// Get performance metrics
    pub fn get_performance_metrics(&self) -> &TransformerPerformanceTracker<T> {
        &self.performance_tracker
    }

    /// Reset optimizer state
    pub fn reset_state(&mut self) -> Result<()> {
        self.state = TransformerOptimizerState::new(&self.config)?;
        self.performance_tracker.reset();
        Ok(())
    }
}

/// Optimization trajectory for training
#[derive(Debug, Clone)]
pub struct OptimizationTrajectory<T: Float + Debug + Send + Sync + 'static> {
    pub gradient_sequence: Array2<T>,
    pub parameter_sequence: Array2<T>,
    pub loss_sequence: Array1<T>,
    pub metadata: TrajectoryMetadata,
}

/// Trajectory metadata
#[derive(Debug, Clone)]
pub struct TrajectoryMetadata {
    pub task_id: String,
    pub optimizer_type: String,
    pub convergence_achieved: bool,
    pub total_steps: usize,
}

/// Training sequence
#[derive(Debug, Clone)]
pub struct TrainingSequence<T: Float + Debug + Send + Sync + 'static> {
    pub input: Array2<T>,
    pub target: Array2<T>,
    pub sequence_length: usize,
}

/// Training metrics
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    pub loss: f64,
    pub training_time: Duration,
    pub num_sequences: usize,
    pub convergence_rate: f64,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore]
    fn test_transformer_optimizer_creation() {
        let config = TransformerBasedOptimizerConfig::default();
        let optimizer = TransformerOptimizer::<f32>::new(config);
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_trajectory_creation() {
        let trajectory = OptimizationTrajectory::<f32> {
            gradient_sequence: Array2::zeros((10, 5)),
            parameter_sequence: Array2::zeros((10, 5)),
            loss_sequence: Array1::zeros(10),
            metadata: TrajectoryMetadata {
                task_id: "test".to_string(),
                optimizer_type: "adam".to_string(),
                convergence_achieved: true,
                total_steps: 10,
            },
        };

        assert_eq!(trajectory.gradient_sequence.shape(), &[10, 5]);
        assert_eq!(trajectory.loss_sequence.len(), 10);
    }
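
    // The backward-compatibility alias `TransformerOptimizerConfig` is a
    // re-export of `TransformerBasedOptimizerConfig`, so both names refer to
    // the same type. A small sanity-check sketch, assuming the `f32`
    // instantiation implements `Default` as in the creation test above.
    #[test]
    fn test_config_alias_is_same_type() {
        let config: TransformerOptimizerConfig<f32> = TransformerBasedOptimizerConfig::default();
        let _ = config;
    }
}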
371}