1pub mod adam;
8pub mod adamw;
9pub mod approximation;
10pub mod momentum;
11pub mod new_adam;
12pub mod new_sgd;
13pub mod new_variance_reduction;
14pub mod optimizers;
15pub mod rmsprop;
16pub mod schedules;
17pub mod sgd;
18pub mod variance_reduction;
19
20pub use adam::{minimize_adam, AdamOptions};
22pub use adamw::{minimize_adamw, AdamWOptions};
23pub use momentum::{minimize_sgd_momentum, MomentumOptions};
24pub use rmsprop::{minimize_rmsprop, RMSPropOptions};
25pub use sgd::{minimize_sgd, SGDOptions};
26
27use crate::error::OptimizeError;
28use crate::unconstrained::result::OptimizeResult;
29use scirs2_core::ndarray::{Array1, ArrayView1};
30use scirs2_core::random::prelude::*;
31
/// Selects which stochastic optimization algorithm [`minimize_stochastic`] runs.
///
/// The enum is a plain tag: it carries no hyperparameters of its own, so it is
/// cheap to copy, compare and use as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StochasticMethod {
    /// Dispatches to `sgd::minimize_sgd`.
    SGD,
    /// Dispatches to `momentum::minimize_sgd_momentum`.
    Momentum,
    /// Dispatches to `rmsprop::minimize_rmsprop`.
    RMSProp,
    /// Dispatches to `adam::minimize_adam`.
    Adam,
    /// Dispatches to `adamw::minimize_adamw`.
    AdamW,
}
46
47#[derive(Debug, Clone)]
49pub struct StochasticOptions {
50 pub learning_rate: f64,
52 pub max_iter: usize,
54 pub batch_size: Option<usize>,
56 pub tol: f64,
58 pub adaptive_lr: bool,
60 pub lr_decay: f64,
62 pub lr_schedule: LearningRateSchedule,
64 pub gradient_clip: Option<f64>,
66 pub early_stopping_patience: Option<usize>,
68}
69
70impl Default for StochasticOptions {
71 fn default() -> Self {
72 Self {
73 learning_rate: 0.001,
74 max_iter: 1000,
75 batch_size: None,
76 tol: 1e-6,
77 adaptive_lr: false,
78 lr_decay: 0.99,
79 lr_schedule: LearningRateSchedule::Constant,
80 gradient_clip: None,
81 early_stopping_patience: None,
82 }
83 }
84}
85
/// Rule used by [`update_learning_rate`] to derive the learning rate of a
/// given epoch from the initial learning rate.
///
/// `PartialEq` is derived so schedules can be compared directly (the `f64`
/// payloads rule out `Eq`).
#[derive(Debug, Clone, PartialEq)]
pub enum LearningRateSchedule {
    /// The learning rate never changes.
    Constant,
    /// `initial_lr * decay_rate^epoch`.
    ExponentialDecay {
        /// Multiplicative factor applied once per epoch.
        decay_rate: f64,
    },
    /// `initial_lr * decay_factor^(epoch / decay_steps)` using integer division.
    StepDecay {
        /// Multiplicative factor applied at every step boundary.
        decay_factor: f64,
        /// Number of epochs between two consecutive decays.
        decay_steps: usize,
    },
    /// `initial_lr * max(0, 1 - epoch / max_epochs)`.
    LinearDecay,
    /// `initial_lr * 0.5 * (1 + cos(pi * epoch / max_epochs))`.
    CosineAnnealing,
    /// `initial_lr / (1 + decay_rate * epoch)`.
    InverseTimeDecay {
        /// Growth rate of the denominator per epoch.
        decay_rate: f64,
    },
}
105
/// Source of training samples for the stochastic optimizers.
///
/// Samples are exposed as flat `f64` values addressed by index.
pub trait DataProvider {
    /// Returns how many samples the provider holds.
    fn num_samples(&self) -> usize;

    /// Returns the samples stored at `indices`, in the order requested.
    fn get_batch(&self, indices: &[usize]) -> Vec<f64>;

    /// Returns every sample held by the provider.
    fn get_full_data(&self) -> Vec<f64>;
}
117
/// [`DataProvider`] backed by a `Vec<f64>` held entirely in memory.
///
/// `Debug` is derived so the provider can appear in diagnostics and in
/// `Debug` output of types that embed it.
#[derive(Debug, Clone)]
pub struct InMemoryDataProvider {
    /// The stored samples; one `f64` per sample.
    data: Vec<f64>,
}
123
124impl InMemoryDataProvider {
125 pub fn new(data: Vec<f64>) -> Self {
126 Self { data }
127 }
128}
129
130impl DataProvider for InMemoryDataProvider {
131 fn num_samples(&self) -> usize {
132 self.data.len()
133 }
134
135 fn get_batch(&self, indices: &[usize]) -> Vec<f64> {
136 indices.iter().map(|&i| self.data[i]).collect()
137 }
138
139 fn get_full_data(&self) -> Vec<f64> {
140 self.data.clone()
141 }
142}
143
144pub trait StochasticGradientFunction {
146 fn compute_gradient(&mut self, x: &ArrayView1<f64>, batchdata: &[f64]) -> Array1<f64>;
148
149 fn compute_value(&mut self, x: &ArrayView1<f64>, batchdata: &[f64]) -> f64;
151}
152
/// Adapter that pairs an objective closure with its gradient closure so the
/// pair can be used where a [`StochasticGradientFunction`] is expected.
pub struct BatchGradientWrapper<F, G> {
    /// Callable evaluating the objective value.
    func: F,
    /// Callable evaluating the gradient.
    grad: G,
}
158
159impl<F, G> BatchGradientWrapper<F, G>
160where
161 F: FnMut(&ArrayView1<f64>) -> f64,
162 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
163{
164 pub fn new(func: F, grad: G) -> Self {
165 Self { func, grad }
166 }
167}
168
169impl<F, G> StochasticGradientFunction for BatchGradientWrapper<F, G>
170where
171 F: FnMut(&ArrayView1<f64>) -> f64,
172 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
173{
174 fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
175 (self.grad)(x)
176 }
177
178 fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
179 (self.func)(x)
180 }
181}
182
183#[allow(dead_code)]
185pub fn update_learning_rate(
186 initial_lr: f64,
187 epoch: usize,
188 max_epochs: usize,
189 schedule: &LearningRateSchedule,
190) -> f64 {
191 match schedule {
192 LearningRateSchedule::Constant => initial_lr,
193 LearningRateSchedule::ExponentialDecay { decay_rate } => {
194 initial_lr * decay_rate.powi(epoch as i32)
195 }
196 LearningRateSchedule::StepDecay {
197 decay_factor,
198 decay_steps,
199 } => initial_lr * decay_factor.powi((epoch / decay_steps) as i32),
200 LearningRateSchedule::LinearDecay => {
201 initial_lr * (1.0 - epoch as f64 / max_epochs as f64).max(0.0)
202 }
203 LearningRateSchedule::CosineAnnealing => {
204 initial_lr
205 * 0.5
206 * (1.0 + (std::f64::consts::PI * epoch as f64 / max_epochs as f64).cos())
207 }
208 LearningRateSchedule::InverseTimeDecay { decay_rate } => {
209 initial_lr / (1.0 + decay_rate * epoch as f64)
210 }
211 }
212}
213
214#[allow(dead_code)]
216pub fn clip_gradients(gradient: &mut Array1<f64>, maxnorm: f64) {
217 let grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
218 if grad_norm > maxnorm {
219 let scale = maxnorm / grad_norm;
220 gradient.mapv_inplace(|x| x * scale);
221 }
222}
223
224#[allow(dead_code)]
226pub fn generate_batch_indices(_num_samples: usize, batchsize: usize, shuffle: bool) -> Vec<usize> {
227 let mut indices: Vec<usize> = (0.._num_samples).collect();
228
229 if shuffle {
230 use scirs2_core::random::seq::SliceRandom;
231 indices.shuffle(&mut thread_rng());
232 }
233
234 indices.into_iter().take(batchsize).collect()
235}
236
237#[allow(dead_code)]
239pub fn minimize_stochastic<F>(
240 method: StochasticMethod,
241 grad_func: F,
242 x0: Array1<f64>,
243 data_provider: Box<dyn DataProvider>,
244 options: StochasticOptions,
245) -> Result<OptimizeResult<f64>, OptimizeError>
246where
247 F: StochasticGradientFunction,
248{
249 match method {
250 StochasticMethod::SGD => {
251 let sgd_options = SGDOptions {
252 learning_rate: options.learning_rate,
253 max_iter: options.max_iter,
254 tol: options.tol,
255 lr_schedule: options.lr_schedule,
256 gradient_clip: options.gradient_clip,
257 batch_size: options.batch_size,
258 };
259 sgd::minimize_sgd(grad_func, x0, data_provider, sgd_options)
260 }
261 StochasticMethod::Momentum => {
262 let momentum_options = MomentumOptions {
263 learning_rate: options.learning_rate,
264 momentum: 0.9, max_iter: options.max_iter,
266 tol: options.tol,
267 lr_schedule: options.lr_schedule,
268 gradient_clip: options.gradient_clip,
269 batch_size: options.batch_size,
270 nesterov: false,
271 dampening: 0.0,
272 };
273 momentum::minimize_sgd_momentum(grad_func, x0, data_provider, momentum_options)
274 }
275 StochasticMethod::RMSProp => {
276 let rmsprop_options = RMSPropOptions {
277 learning_rate: options.learning_rate,
278 decay_rate: 0.99, epsilon: 1e-8,
280 max_iter: options.max_iter,
281 tol: options.tol,
282 lr_schedule: options.lr_schedule,
283 gradient_clip: options.gradient_clip,
284 batch_size: options.batch_size,
285 centered: false,
286 momentum: None,
287 };
288 rmsprop::minimize_rmsprop(grad_func, x0, data_provider, rmsprop_options)
289 }
290 StochasticMethod::Adam => {
291 let adam_options = AdamOptions {
292 learning_rate: options.learning_rate,
293 beta1: 0.9,
294 beta2: 0.999,
295 epsilon: 1e-8,
296 max_iter: options.max_iter,
297 tol: options.tol,
298 lr_schedule: options.lr_schedule,
299 gradient_clip: options.gradient_clip,
300 batch_size: options.batch_size,
301 amsgrad: false,
302 };
303 adam::minimize_adam(grad_func, x0, data_provider, adam_options)
304 }
305 StochasticMethod::AdamW => {
306 let adamw_options = AdamWOptions {
307 learning_rate: options.learning_rate,
308 beta1: 0.9,
309 beta2: 0.999,
310 epsilon: 1e-8,
311 weight_decay: 0.01, max_iter: options.max_iter,
313 tol: options.tol,
314 lr_schedule: options.lr_schedule,
315 gradient_clip: options.gradient_clip,
316 batch_size: options.batch_size,
317 decouple_weight_decay: true,
318 };
319 adamw::minimize_adamw(grad_func, x0, data_provider, adamw_options)
320 }
321 }
322}
323
324#[allow(dead_code)]
326pub fn create_stochastic_options_for_problem(
327 problem_type: &str,
328 dataset_size: usize,
329) -> StochasticOptions {
330 match problem_type.to_lowercase().as_str() {
331 "neural_network" => StochasticOptions {
332 learning_rate: 0.001,
333 max_iter: 1000,
334 batch_size: Some(32.min(dataset_size / 10)),
335 lr_schedule: LearningRateSchedule::ExponentialDecay { decay_rate: 0.99 },
336 gradient_clip: Some(1.0),
337 early_stopping_patience: Some(50),
338 ..Default::default()
339 },
340 "linear_regression" => StochasticOptions {
341 learning_rate: 0.01,
342 max_iter: 500,
343 batch_size: Some(64.min(dataset_size / 5)),
344 lr_schedule: LearningRateSchedule::LinearDecay,
345 ..Default::default()
346 },
347 "logistic_regression" => StochasticOptions {
348 learning_rate: 0.01,
349 max_iter: 200,
350 batch_size: Some(32.min(dataset_size / 10)),
351 lr_schedule: LearningRateSchedule::StepDecay {
352 decay_factor: 0.9,
353 decay_steps: 50,
354 },
355 ..Default::default()
356 },
357 "large_scale" => StochasticOptions {
358 learning_rate: 0.001,
359 max_iter: 2000,
360 batch_size: Some(128.min(dataset_size / 20)),
361 lr_schedule: LearningRateSchedule::CosineAnnealing,
362 gradient_clip: Some(5.0),
363 adaptive_lr: true,
364 ..Default::default()
365 },
366 "noisy_gradients" => StochasticOptions {
367 learning_rate: 0.01,
368 max_iter: 1000,
369 batch_size: Some(64.min(dataset_size / 5)),
370 lr_schedule: LearningRateSchedule::InverseTimeDecay { decay_rate: 1.0 },
371 gradient_clip: Some(2.0),
372 early_stopping_patience: Some(100),
373 ..Default::default()
374 },
375 _ => StochasticOptions::default(),
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Checks every schedule variant against its closed-form value.
    #[test]
    fn test_learning_rate_schedules() {
        let initial_lr = 0.1;
        let max_epochs = 100;

        let constant = LearningRateSchedule::Constant;
        assert_abs_diff_eq!(
            update_learning_rate(initial_lr, 50, max_epochs, &constant),
            initial_lr,
            epsilon = 1e-10
        );

        let exp_decay = LearningRateSchedule::ExponentialDecay { decay_rate: 0.9 };
        let lr_exp = update_learning_rate(initial_lr, 10, max_epochs, &exp_decay);
        assert_abs_diff_eq!(lr_exp, initial_lr * 0.9_f64.powi(10), epsilon = 1e-10);

        let linear = LearningRateSchedule::LinearDecay;
        let lr_linear = update_learning_rate(initial_lr, 50, max_epochs, &linear);
        assert_abs_diff_eq!(lr_linear, initial_lr * 0.5, epsilon = 1e-10);

        // Epoch 25 with a step every 10 epochs has crossed two boundaries.
        let step = LearningRateSchedule::StepDecay {
            decay_factor: 0.5,
            decay_steps: 10,
        };
        let lr_step = update_learning_rate(initial_lr, 25, max_epochs, &step);
        assert_abs_diff_eq!(lr_step, initial_lr * 0.25, epsilon = 1e-10);

        // Cosine annealing starts at the initial rate, is halved midway and
        // reaches zero at the final epoch.
        let cosine = LearningRateSchedule::CosineAnnealing;
        assert_abs_diff_eq!(
            update_learning_rate(initial_lr, 0, max_epochs, &cosine),
            initial_lr,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(
            update_learning_rate(initial_lr, 50, max_epochs, &cosine),
            initial_lr * 0.5,
            epsilon = 1e-10
        );
        assert_abs_diff_eq!(
            update_learning_rate(initial_lr, max_epochs, max_epochs, &cosine),
            0.0,
            epsilon = 1e-10
        );

        let inverse = LearningRateSchedule::InverseTimeDecay { decay_rate: 1.0 };
        let lr_inverse = update_learning_rate(initial_lr, 9, max_epochs, &inverse);
        assert_abs_diff_eq!(lr_inverse, initial_lr / 10.0, epsilon = 1e-10);
    }

    /// Clipping shrinks long gradients to the bound and leaves short ones alone.
    #[test]
    fn test_gradient_clipping() {
        // Norm of (3, 4) is 5, so the gradient must be scaled down to 2.5.
        let mut grad = Array1::from_vec(vec![3.0, 4.0]);
        clip_gradients(&mut grad, 2.5);

        let clipped_norm = grad.mapv(|x| x * x).sum().sqrt();
        assert_abs_diff_eq!(clipped_norm, 2.5, epsilon = 1e-10);

        // The direction is preserved.
        assert_abs_diff_eq!(grad[0] / grad[1], 3.0 / 4.0, epsilon = 1e-10);

        // A gradient already within the bound is not modified.
        let mut small = Array1::from_vec(vec![0.3, 0.4]);
        clip_gradients(&mut small, 1.0);
        assert_abs_diff_eq!(small[0], 0.3, epsilon = 1e-12);
        assert_abs_diff_eq!(small[1], 0.4, epsilon = 1e-12);
    }

    /// Batches have the requested size, valid indices and no duplicates.
    #[test]
    fn test_batch_indices_generation() {
        let indices = generate_batch_indices(100, 10, false);
        assert_eq!(indices.len(), 10);
        assert_eq!(indices, (0..10).collect::<Vec<usize>>());

        let shuffled = generate_batch_indices(100, 10, true);
        assert_eq!(shuffled.len(), 10);
        assert!(shuffled.iter().all(|&i| i < 100));

        // Indices come from a permutation, so none may repeat.
        let mut distinct = shuffled.clone();
        distinct.sort_unstable();
        distinct.dedup();
        assert_eq!(distinct.len(), 10);

        // Asking for more than is available returns every index once.
        let oversized = generate_batch_indices(5, 10, false);
        assert_eq!(oversized, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_in_memory_data_provider() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let provider = InMemoryDataProvider::new(data.clone());

        assert_eq!(provider.num_samples(), 5);
        assert_eq!(provider.get_full_data(), data);

        let batch = provider.get_batch(&[0, 2, 4]);
        assert_eq!(batch, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_problem_specific_options() {
        let nn_options = create_stochastic_options_for_problem("neural_network", 1000);
        assert_eq!(nn_options.learning_rate, 0.001);
        assert!(nn_options.batch_size.is_some());
        assert!(nn_options.gradient_clip.is_some());

        let lr_options = create_stochastic_options_for_problem("linear_regression", 500);
        assert_eq!(lr_options.learning_rate, 0.01);
        assert!(matches!(
            lr_options.lr_schedule,
            LearningRateSchedule::LinearDecay
        ));

        let large_options = create_stochastic_options_for_problem("large_scale", 10000);
        assert!(matches!(
            large_options.lr_schedule,
            LearningRateSchedule::CosineAnnealing
        ));
        assert_eq!(large_options.batch_size, Some(128));

        // Problem names are matched case-insensitively.
        let mixed_case = create_stochastic_options_for_problem("Neural_Network", 1000);
        assert_eq!(mixed_case.gradient_clip, Some(1.0));
        assert_eq!(mixed_case.batch_size, Some(32));

        // Unknown names fall back to the defaults.
        let fallback = create_stochastic_options_for_problem("unknown_problem", 1000);
        assert!(fallback.batch_size.is_none());
        assert!(fallback.gradient_clip.is_none());
        assert!(matches!(
            fallback.lr_schedule,
            LearningRateSchedule::Constant
        ));
    }
}