optirs_learned/common.rs

//! Common types and configurations for learned optimizers

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;

/// Base configuration for learned optimizers
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedOptimizerConfig {
    /// Learning rate
    pub learning_rate: f64,

    /// Meta-learning rate for higher-level optimization
    pub meta_learning_rate: f64,

    /// Batch size for training
    pub batch_size: usize,

    /// Maximum number of optimization steps
    pub max_steps: usize,

    /// Convergence threshold
    pub convergence_threshold: f64,

    /// Whether to use momentum
    pub use_momentum: bool,

    /// Momentum decay factor
    pub momentum_decay: f64,

    /// Weight decay for regularization
    pub weight_decay: f64,

    /// Hidden size for neural networks
    pub hidden_size: usize,

    /// Number of attention heads
    pub attention_heads: usize,

    /// Size of gradient history buffer
    pub gradient_history_size: usize,

    /// Number of input features
    pub input_features: usize,

    /// Number of output features
    pub output_features: usize,

    /// Number of layers in neural networks
    pub num_layers: usize,

    /// Dropout rate for regularization
    pub dropout_rate: f64,

    /// Whether to use attention mechanisms
    pub use_attention: bool,

    /// Random seed for reproducibility
    pub seed: Option<u64>,
}

impl Default for LearnedOptimizerConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            meta_learning_rate: 0.0001,
            batch_size: 32,
            max_steps: 1000,
            convergence_threshold: 1e-6,
            use_momentum: true,
            momentum_decay: 0.9,
            weight_decay: 1e-4,
            hidden_size: 256,
            attention_heads: 8,
            gradient_history_size: 50,
            input_features: 256,
            output_features: 256,
            num_layers: 3,
            dropout_rate: 0.1,
            use_attention: true,
            seed: None,
        }
    }
}
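
// A minimal usage sketch: start from the defaults above and override
// individual fields with struct-update syntax. The values chosen here are
// illustrative, not recommendations.
#[cfg(test)]
mod config_example {
    use super::*;

    #[test]
    fn customize_defaults() {
        let config = LearnedOptimizerConfig {
            learning_rate: 0.01,
            seed: Some(42),
            ..Default::default()
        };
        assert!(config.use_attention);
        assert_eq!(config.batch_size, 32);
    }
}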

/// Meta-optimization strategies
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum MetaOptimizationStrategy {
    /// First-order approximation (FOMAML)
    #[default]
    FirstOrder,

    /// Full second-order gradients (MAML)
    SecondOrder,

    /// MAML algorithm (alias for SecondOrder)
    MAML,

    /// Reptile algorithm
    Reptile,

    /// Custom gradient-based meta-learning
    Custom {
        inner_steps: usize,
        outer_learning_rate: f64,
    },
}
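
// A short sketch of how the strategy enum is typically consumed: the default
// is the first-order approximation, and the Custom variant carries its own
// inner-loop hyperparameters. The values here are illustrative.
#[cfg(test)]
mod strategy_example {
    use super::*;

    #[test]
    fn select_and_inspect_strategies() {
        assert!(matches!(
            MetaOptimizationStrategy::default(),
            MetaOptimizationStrategy::FirstOrder
        ));

        let custom = MetaOptimizationStrategy::Custom {
            inner_steps: 5,
            outer_learning_rate: 0.01,
        };
        if let MetaOptimizationStrategy::Custom { inner_steps, .. } = custom {
            assert_eq!(inner_steps, 5);
        }
    }
}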

/// Base optimizer state
#[derive(Debug, Clone)]
pub struct OptimizerState<T: Float + Debug + Send + Sync + 'static> {
    /// Current parameters
    pub parameters: Array1<T>,

    /// Current gradients
    pub gradients: Array1<T>,

    /// Momentum buffer
    pub momentum: Option<Array1<T>>,

    /// Hidden states for neural optimizers
    pub hidden_states: HashMap<String, Array1<T>>,

    /// Memory buffers for attention mechanisms
    pub memory_buffers: HashMap<String, Array2<T>>,

    /// Current step number
    pub step: usize,

    /// Step count (alias for step)
    pub step_count: usize,

    /// Current loss value
    pub loss: Option<T>,

    /// Current learning rate
    pub learning_rate: T,

    /// Additional state metadata
    pub metadata: StateMetadata,
}

impl<T: Float + Debug + Send + Sync + 'static> OptimizerState<T> {
    /// Creates a zero-initialized state for `num_params` parameters.
    pub fn new(num_params: usize) -> Self {
        Self {
            parameters: Array1::zeros(num_params),
            gradients: Array1::zeros(num_params),
            momentum: None,
            hidden_states: HashMap::new(),
            memory_buffers: HashMap::new(),
            step: 0,
            step_count: 0,
            loss: None,
            learning_rate: scirs2_core::numeric::NumCast::from(0.001)
                .expect("0.001 must be representable in T"),
            metadata: StateMetadata::default(),
        }
    }
}
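
// A usage sketch for the state container: `f64` satisfies the trait bounds,
// all buffers start zeroed, and the momentum buffer is left for the caller
// to allocate when momentum is enabled in the config.
#[cfg(test)]
mod state_example {
    use super::*;

    #[test]
    fn new_state_starts_zeroed() {
        let mut state: OptimizerState<f64> = OptimizerState::new(4);
        assert_eq!(state.step, 0);
        assert!(state.momentum.is_none());

        // Allocate a momentum buffer matching the parameter shape.
        state.momentum = Some(Array1::zeros(4));
        assert_eq!(state.parameters.len(), 4);
    }
}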

/// State metadata
#[derive(Debug, Clone)]
pub struct StateMetadata {
    /// Task identifier
    pub task_id: Option<String>,

    /// Optimizer type used
    pub optimizer_type: Option<String>,

    /// Version information
    pub version: String,

    /// Timestamp of creation/update
    pub timestamp: std::time::SystemTime,

    /// Checksum for integrity
    pub checksum: u64,

    /// Compression level used
    pub compression_level: u8,

    /// Additional custom metadata
    pub custom_data: HashMap<String, String>,
}

impl Default for StateMetadata {
    fn default() -> Self {
        Self {
            task_id: None,
            optimizer_type: None,
            version: "1.0".to_string(),
            timestamp: std::time::SystemTime::now(),
            checksum: 0,
            compression_level: 0,
            custom_data: HashMap::new(),
        }
    }
}
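
// A small sketch showing how free-form entries can be attached to the
// metadata; the "dataset" key and its value are purely illustrative.
#[cfg(test)]
mod metadata_example {
    use super::*;

    #[test]
    fn attach_custom_metadata() {
        let mut meta = StateMetadata::default();
        meta.custom_data
            .insert("dataset".to_string(), "toy".to_string());
        assert_eq!(meta.version, "1.0");
        assert!(meta.task_id.is_none());
    }
}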

/// Neural optimizer type variants
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NeuralOptimizerType {
    /// Transformer-based optimizer
    Transformer,

    /// LSTM-based optimizer
    LSTM,

    /// Simple MLP optimizer
    MLP,

    /// Convolutional optimizer
    CNN,
}

/// Task context for meta-learning
#[derive(Debug, Clone)]
pub struct TaskContext<T: Float + Debug + Send + Sync + 'static> {
    /// Task identifier
    pub task_id: String,

    /// Initial parameters for the task
    pub initial_parameters: Array1<T>,

    /// Task-specific data
    pub task_data: Array2<T>,

    /// Target values for the task
    pub targets: Array1<T>,

    /// Task difficulty or complexity measure
    pub difficulty: f64,
}
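
// A construction sketch: the shapes are illustrative (8 samples of 4
// features), and the dimensions of `task_data`, `targets`, and
// `initial_parameters` are expected to agree with each other.
#[cfg(test)]
mod task_context_example {
    use super::*;

    #[test]
    fn build_task_context() {
        let context = TaskContext::<f64> {
            task_id: "toy-regression".to_string(),
            initial_parameters: Array1::zeros(4),
            task_data: Array2::zeros((8, 4)),
            targets: Array1::zeros(8),
            difficulty: 0.5,
        };
        assert_eq!(context.task_data.nrows(), 8);
        assert_eq!(context.targets.len(), 8);
    }
}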

/// Neural optimizer metrics
#[derive(Debug, Clone)]
pub struct NeuralOptimizerMetrics {
    /// Average loss over training
    pub avg_loss: f64,

    /// Convergence rate
    pub convergence_rate: f64,

    /// Number of steps to convergence
    pub steps_to_convergence: Option<usize>,

    /// Final accuracy/performance
    pub final_performance: f64,

    /// Training time in seconds
    pub training_time: f64,
}

/// Task performance metrics
#[derive(Debug, Clone)]
pub struct TaskPerformance {
    /// Task identifier
    pub task_id: String,

    /// Performance score (higher is better)
    pub score: f64,

    /// Whether the task converged
    pub converged: bool,

    /// Number of optimization steps taken
    pub steps_taken: usize,

    /// Final loss value
    pub final_loss: f64,
}
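
// A sketch of recording per-task results after a run; the numbers are
// invented for illustration. `steps_to_convergence` stays `None` when a run
// is cut off by `max_steps` without meeting the convergence threshold.
#[cfg(test)]
mod reporting_example {
    use super::*;

    #[test]
    fn record_run_results() {
        let metrics = NeuralOptimizerMetrics {
            avg_loss: 0.42,
            convergence_rate: 0.05,
            steps_to_convergence: Some(300),
            final_performance: 0.91,
            training_time: 12.5,
        };

        let performance = TaskPerformance {
            task_id: "toy-regression".to_string(),
            score: metrics.final_performance,
            converged: metrics.steps_to_convergence.is_some(),
            steps_taken: metrics.steps_to_convergence.unwrap_or(1000),
            final_loss: 0.42,
        };
        assert!(performance.converged);
    }
}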