optirs_learned/transformer_based_optimizer/
mod.rs1pub mod architecture;
9pub mod attention;
10pub mod config;
11pub mod feedforward;
12pub mod layers;
13pub mod memory_manager;
14pub mod meta_learning;
15pub mod performance_tracker;
16pub mod positional_encoding;
17pub mod sequence_processor;
18pub mod state;
19
20pub use architecture::{TransformerArchitecture, TransformerLayer};
22pub use attention::{AttentionMechanism, MultiHeadAttention};
23pub use config::{TransformerArchConfig, TransformerBasedOptimizerConfig};
24pub use feedforward::{ActivationFunction, FeedForwardNetwork};
25pub use layers::{
26 DropoutLayer, EmbeddingLayer, LayerNormalization, OutputProjection, ResidualConnections,
27};
28pub use memory_manager::{MemoryManagementStrategy, TransformerMemoryManager};
29pub use meta_learning::{MetaLearningStrategy, TransformerMetaLearning};
30pub use performance_tracker::{PerformanceMetrics, TransformerPerformanceTracker};
31pub use positional_encoding::{PositionalEncoding, PositionalEncodingType};
32pub use sequence_processor::{OptimizationSequenceProcessor, SequenceProcessingStrategy};
33pub use state::{OptimizerStateSnapshot, TransformerOptimizerState};
34
35pub use TransformerBasedOptimizerConfig as TransformerOptimizerConfig;
37
38use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayBase, Axis, Data, Dimension};
39use scirs2_core::numeric::{Float, ToPrimitive};
40use serde::{Deserialize, Serialize};
41use std::collections::{HashMap, VecDeque};
42use std::fmt::Debug;
43use std::sync::{Arc, Mutex};
44use std::time::{Duration, Instant};
45
46use super::{
47 LearnedOptimizerConfig, MetaOptimizationStrategy, NeuralOptimizerMetrics, NeuralOptimizerType,
48 OptimizerState, TaskContext, TaskPerformance,
49};
50use crate::error::{OptimError, Result};
51use optirs_core::adaptive_selection::OptimizerType;
52
/// A learned optimizer that predicts parameter updates by running the recent
/// optimization history (gradients, parameters, losses) through a transformer.
pub struct TransformerOptimizer<T: Float + Debug + Send + Sync + 'static> {
    /// Core transformer stack; maps encoded history sequences to features.
    transformer: TransformerArchitecture<T>,

    /// Adds position information to sequences before each forward pass.
    positional_encoding: PositionalEncoding<T>,

    /// Multi-head attention component. NOTE(review): constructed in `new` but
    /// never referenced by the methods in this file — confirm it is used elsewhere.
    attention_mechanism: MultiHeadAttention<T>,

    /// One feed-forward network per transformer layer.
    feedforward_networks: Vec<FeedForwardNetwork<T>>,

    /// Turns transformer output into parameter updates and learns from losses.
    meta_learning: TransformerMetaLearning<T>,

    /// Converts gradient/parameter/loss histories into model-ready sequences.
    sequence_processor: OptimizationSequenceProcessor<T>,

    /// Memory-management component (strategy defined in `memory_manager`).
    memory_manager: TransformerMemoryManager<T>,

    /// Configuration used to build the optimizer (and rebuild state on reset).
    config: TransformerBasedOptimizerConfig<T>,

    /// Records per-step timings, losses, and training-epoch metrics.
    performance_tracker: TransformerPerformanceTracker<T>,

    /// Mutable optimizer state, including the current parameters.
    state: TransformerOptimizerState<T>,
}
87
88impl<
89 T: Float
90 + Debug
91 + Send
92 + Sync
93 + 'static
94 + scirs2_core::ndarray::ScalarOperand
95 + scirs2_core::numeric::FromPrimitive,
96 > TransformerOptimizer<T>
97{
98 pub fn new(config: TransformerBasedOptimizerConfig<T>) -> Result<Self> {
100 let transformer_config = TransformerArchConfig::from_optimizer_config(&config);
101 let transformer = TransformerArchitecture::new(transformer_config)?;
102
103 let positional_encoding = PositionalEncoding::new(
104 config.sequence_length,
105 config.model_dimension,
106 config.positional_encoding_type,
107 )?;
108
109 let attention_mechanism = MultiHeadAttention::new(
110 config.num_attention_heads,
111 config.model_dimension,
112 config.attention_head_dimension,
113 )?;
114
115 let mut feedforward_networks = Vec::new();
116 for _ in 0..config.num_transformer_layers {
117 feedforward_networks.push(FeedForwardNetwork::new(
118 config.model_dimension,
119 config.feedforward_dimension,
120 config.activation_function,
121 )?);
122 }
123
124 let meta_learning = TransformerMetaLearning::new(&config)?;
125 let sequence_processor = OptimizationSequenceProcessor::new(&config)?;
126 let memory_manager = TransformerMemoryManager::new(&config)?;
127 let performance_tracker = TransformerPerformanceTracker::new();
128 let state = TransformerOptimizerState::new(&config)?;
129
130 Ok(Self {
131 transformer,
132 positional_encoding,
133 attention_mechanism,
134 feedforward_networks,
135 meta_learning,
136 sequence_processor,
137 memory_manager,
138 config,
139 performance_tracker,
140 state,
141 })
142 }
143
144 pub fn generate_optimization_step(
146 &mut self,
147 gradient_history: &Array2<T>,
148 parameter_history: &Array2<T>,
149 loss_history: &Array1<T>,
150 ) -> Result<Array1<T>> {
151 let start_time = Instant::now();
152
153 let processed_sequence = self.sequence_processor.process_optimization_sequence(
155 gradient_history,
156 parameter_history,
157 loss_history,
158 )?;
159
160 let encoded_sequence = self.positional_encoding.encode(&processed_sequence)?;
162
163 let transformer_output = self.transformer.forward(&encoded_sequence)?;
165
166 let optimization_step = self
168 .meta_learning
169 .generate_update(&transformer_output, &self.state.current_parameters)?;
170
171 self.state
173 .update_with_step(&optimization_step, loss_history.last().copied())?;
174
175 let elapsed = start_time.elapsed();
177 self.performance_tracker
178 .record_optimization_step(elapsed, &optimization_step);
179
180 Ok(optimization_step)
181 }
182
183 pub fn train_on_trajectories(
185 &mut self,
186 trajectories: &[OptimizationTrajectory<T>],
187 ) -> Result<TrainingMetrics> {
188 let start_time = Instant::now();
189 let mut total_loss = T::zero();
190 let mut batch_count = 0;
191
192 for trajectory in trajectories {
193 let sequences = self
195 .sequence_processor
196 .trajectory_to_sequences(trajectory)?;
197
198 for sequence in sequences {
199 let prediction = self.forward_sequence(&sequence.input)?;
201
202 let loss = self.calculate_sequence_loss(&prediction, &sequence.target)?;
204 total_loss = total_loss + loss;
205
206 self.backward_pass(&sequence.input, &sequence.target, loss)?;
208
209 batch_count += 1;
210 }
211 }
212
213 let avg_loss = if batch_count > 0 {
214 total_loss
215 / scirs2_core::numeric::NumCast::from(batch_count).unwrap_or_else(|| T::zero())
216 } else {
217 T::zero()
218 };
219
220 let training_time = start_time.elapsed();
221
222 let metrics = TrainingMetrics {
223 loss: avg_loss.to_f64().unwrap_or(0.0),
224 training_time,
225 num_sequences: batch_count,
226 convergence_rate: self.calculate_convergence_rate()?,
227 };
228
229 self.performance_tracker
230 .record_training_epoch(metrics.clone());
231
232 Ok(metrics)
233 }
234
235 fn forward_sequence(&mut self, sequence: &Array2<T>) -> Result<Array2<T>> {
237 let encoded_sequence = self.positional_encoding.encode(sequence)?;
239
240 self.transformer.forward(&encoded_sequence)
242 }
243
244 fn calculate_sequence_loss(&self, prediction: &Array2<T>, target: &Array2<T>) -> Result<T> {
246 if prediction.shape() != target.shape() {
247 return Err(OptimError::Other(
248 "Shape mismatch in loss calculation".to_string(),
249 ));
250 }
251
252 let diff = prediction - target;
254 let squared_diff = &diff * &diff;
255 let sum = squared_diff.sum();
256 let mse = sum / T::from(prediction.len()).expect("unwrap failed");
257
258 Ok(mse)
259 }
260
261 fn backward_pass(&mut self, input: &Array2<T>, target: &Array2<T>, loss: T) -> Result<()> {
263 self.meta_learning.update_from_loss(loss)?;
266 Ok(())
267 }
268
269 fn calculate_convergence_rate(&self) -> Result<f64> {
271 let loss_history = self.performance_tracker.get_loss_history();
272 if loss_history.len() < 2 {
273 return Ok(0.0);
274 }
275
276 let recent_losses: Vec<_> = loss_history.iter().rev().take(10).collect();
277 if recent_losses.len() < 2 {
278 return Ok(0.0);
279 }
280
281 let initial_loss = *recent_losses.last().expect("unwrap failed");
282 let final_loss = *recent_losses.first().expect("unwrap failed");
283
284 let improvement = (initial_loss - final_loss) / initial_loss;
285 let improvement_f64 = improvement.to_f64().unwrap_or(0.0);
286 Ok(improvement_f64.clamp(0.0, 1.0))
287 }
288
289 pub fn get_state(&self) -> &TransformerOptimizerState<T> {
291 &self.state
292 }
293
294 pub fn get_performance_metrics(&self) -> &TransformerPerformanceTracker<T> {
296 &self.performance_tracker
297 }
298
299 pub fn reset_state(&mut self) -> Result<()> {
301 self.state = TransformerOptimizerState::new(&self.config)?;
302 self.performance_tracker.reset();
303 Ok(())
304 }
305}
306
/// A recorded optimization run used as training data for the learned optimizer.
#[derive(Debug, Clone)]
pub struct OptimizationTrajectory<T: Float + Debug + Send + Sync + 'static> {
    /// Gradients observed at each recorded step.
    pub gradient_sequence: Array2<T>,
    /// Parameter values at each recorded step.
    pub parameter_sequence: Array2<T>,
    /// Loss value at each recorded step.
    pub loss_sequence: Array1<T>,
    /// Descriptive information about the run.
    pub metadata: TrajectoryMetadata,
}
315
/// Descriptive metadata attached to an `OptimizationTrajectory`.
#[derive(Debug, Clone)]
pub struct TrajectoryMetadata {
    /// Identifier of the task the trajectory was recorded on.
    pub task_id: String,
    /// Name of the optimizer that produced the trajectory (e.g. "adam").
    pub optimizer_type: String,
    /// Whether the run was judged to have converged.
    pub convergence_achieved: bool,
    /// Total number of optimization steps in the run.
    pub total_steps: usize,
}
324
/// One (input, target) training pair extracted from a trajectory.
#[derive(Debug, Clone)]
pub struct TrainingSequence<T: Float + Debug + Send + Sync + 'static> {
    /// Model input sequence.
    pub input: Array2<T>,
    /// Expected model output for `input`.
    pub target: Array2<T>,
    /// Sequence length. NOTE(review): presumably the number of steps in
    /// `input` — confirm against the sequence processor that builds these.
    pub sequence_length: usize,
}
332
/// Aggregate metrics for one training epoch.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Mean sequence loss over the epoch (0.0 when no sequences were processed).
    pub loss: f64,
    /// Wall-clock duration of the training call.
    pub training_time: Duration,
    /// Number of training sequences processed.
    pub num_sequences: usize,
    /// Relative loss improvement over recent history, clamped to `[0, 1]`.
    pub convergence_rate: f64,
}
341
#[cfg(test)]
mod tests {
    use super::*;

    /// Building an optimizer from the default configuration should succeed.
    #[test]
    #[ignore]
    fn test_transformer_optimizer_creation() {
        let config = TransformerBasedOptimizerConfig::default();
        assert!(TransformerOptimizer::<f32>::new(config).is_ok());
    }

    /// A trajectory built from zeroed arrays keeps its declared dimensions.
    #[test]
    fn test_trajectory_creation() {
        let metadata = TrajectoryMetadata {
            task_id: String::from("test"),
            optimizer_type: String::from("adam"),
            convergence_achieved: true,
            total_steps: 10,
        };

        let trajectory = OptimizationTrajectory::<f32> {
            gradient_sequence: Array2::zeros((10, 5)),
            parameter_sequence: Array2::zeros((10, 5)),
            loss_sequence: Array1::zeros(10),
            metadata,
        };

        assert_eq!(trajectory.gradient_sequence.shape(), &[10, 5]);
        assert_eq!(trajectory.loss_sequence.len(), 10);
    }
}