Skip to main content

scirs2_optimize/stochastic/
mod.rs

1//! Stochastic optimization methods for machine learning and large-scale problems
2//!
3//! This module provides stochastic optimization algorithms that are particularly
4//! well-suited for machine learning, neural networks, and large-scale problems
5//! where exact gradients are expensive or noisy.
6
7pub mod adam;
8pub mod adamw;
9pub mod approximation;
10pub mod momentum;
11pub mod new_adam;
12pub mod new_sgd;
13pub mod new_variance_reduction;
14pub mod optimizers;
15pub mod rmsprop;
16pub mod schedules;
17pub mod sgd;
18pub mod variance_reduction;
19
20// Re-export commonly used items
21pub use adam::{minimize_adam, AdamOptions};
22pub use adamw::{minimize_adamw, AdamWOptions};
23pub use momentum::{minimize_sgd_momentum, MomentumOptions};
24pub use rmsprop::{minimize_rmsprop, RMSPropOptions};
25pub use sgd::{minimize_sgd, SGDOptions};
26
27use crate::error::OptimizeError;
28use crate::unconstrained::result::OptimizeResult;
29use scirs2_core::ndarray::{Array1, ArrayView1};
30use scirs2_core::random::prelude::*;
31
/// Selects which optimizer [`minimize_stochastic`] dispatches to.
///
/// Each variant maps to the matching `minimize_*` routine in a submodule of
/// this module. The enum is `Copy` and comparable so it can be stored in
/// configuration, compared in tests, or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StochasticMethod {
    /// Plain stochastic gradient descent.
    SGD,
    /// Stochastic gradient descent with a momentum term.
    Momentum,
    /// Root Mean Square Propagation.
    RMSProp,
    /// Adaptive Moment Estimation.
    Adam,
    /// Adam with weight decay (AdamW).
    AdamW,
}
46
47/// Common options for stochastic optimization
48#[derive(Debug, Clone)]
49pub struct StochasticOptions {
50    /// Learning rate (step size)
51    pub learning_rate: f64,
52    /// Maximum number of iterations (epochs)
53    pub max_iter: usize,
54    /// Batch size for mini-batch optimization
55    pub batch_size: Option<usize>,
56    /// Tolerance for convergence
57    pub tol: f64,
58    /// Whether to use adaptive learning rate
59    pub adaptive_lr: bool,
60    /// Learning rate decay factor
61    pub lr_decay: f64,
62    /// Learning rate decay schedule
63    pub lr_schedule: LearningRateSchedule,
64    /// Gradient clipping threshold
65    pub gradient_clip: Option<f64>,
66    /// Early stopping patience
67    pub early_stopping_patience: Option<usize>,
68}
69
70impl Default for StochasticOptions {
71    fn default() -> Self {
72        Self {
73            learning_rate: 0.001,
74            max_iter: 1000,
75            batch_size: None,
76            tol: 1e-6,
77            adaptive_lr: false,
78            lr_decay: 0.99,
79            lr_schedule: LearningRateSchedule::Constant,
80            gradient_clip: None,
81            early_stopping_patience: None,
82        }
83    }
84}
85
/// Learning-rate schedules understood by `update_learning_rate`.
///
/// In the formulas below `lr` is the initial learning rate, `epoch` the
/// current epoch and `max_epochs` the planned number of epochs. The schedule
/// defaults to [`LearningRateSchedule::Constant`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum LearningRateSchedule {
    /// Constant learning rate: `lr`.
    #[default]
    Constant,
    /// Exponential decay: `lr * decay_rate^epoch`.
    ExponentialDecay {
        /// Factor the rate is multiplied by once per epoch.
        decay_rate: f64,
    },
    /// Step decay: the rate is multiplied by `decay_factor` once every
    /// `decay_steps` epochs, i.e. `lr * decay_factor^(epoch / decay_steps)`
    /// with integer division.
    StepDecay {
        /// Factor applied at each step.
        decay_factor: f64,
        /// Number of epochs between steps.
        decay_steps: usize,
    },
    /// Linear decay: `lr * (1 - epoch / max_epochs)`, never below zero.
    LinearDecay,
    /// Cosine annealing: `lr * 0.5 * (1 + cos(π * epoch / max_epochs))` for
    /// `epoch` in `0..=max_epochs`.
    CosineAnnealing,
    /// Inverse time decay: `lr / (1 + decay_rate * epoch)`.
    InverseTimeDecay {
        /// Coefficient on the epoch in the denominator.
        decay_rate: f64,
    },
}
105
/// Source of training samples consumed by the stochastic optimizers.
///
/// Samples are exchanged as flat `Vec<f64>`s; how individual samples map onto
/// those values is up to the implementation. The trait is object-safe, so the
/// optimizers can take it as `Box<dyn DataProvider>`.
pub trait DataProvider {
    /// Number of samples available from this provider.
    fn num_samples(&self) -> usize;

    /// Values for the samples selected by `indices`.
    fn get_batch(&self, indices: &[usize]) -> Vec<f64>;

    /// Values for every sample in the dataset.
    fn get_full_data(&self) -> Vec<f64>;
}
117
/// Data provider backed by a `Vec<f64>` held in memory, one value per sample.
///
/// Derives `Debug`, `PartialEq` and `Default` (an empty dataset) so it can be
/// inspected, compared in tests and constructed without data.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct InMemoryDataProvider {
    /// Sample values; sample `i` is `data[i]`.
    data: Vec<f64>,
}

impl InMemoryDataProvider {
    /// Wraps `data`, taking ownership of it.
    #[must_use]
    pub fn new(data: Vec<f64>) -> Self {
        Self { data }
    }
}

impl From<Vec<f64>> for InMemoryDataProvider {
    /// Same as [`InMemoryDataProvider::new`].
    fn from(data: Vec<f64>) -> Self {
        Self::new(data)
    }
}
129
130impl DataProvider for InMemoryDataProvider {
131    fn num_samples(&self) -> usize {
132        self.data.len()
133    }
134
135    fn get_batch(&self, indices: &[usize]) -> Vec<f64> {
136        indices.iter().map(|&i| self.data[i]).collect()
137    }
138
139    fn get_full_data(&self) -> Vec<f64> {
140        self.data.clone()
141    }
142}
143
144/// Stochastic gradient function trait
145pub trait StochasticGradientFunction {
146    /// Compute gradient on a batch of data
147    fn compute_gradient(&mut self, x: &ArrayView1<f64>, batchdata: &[f64]) -> Array1<f64>;
148
149    /// Compute function value on a batch of data
150    fn compute_value(&mut self, x: &ArrayView1<f64>, batchdata: &[f64]) -> f64;
151}
152
/// Adapts an ordinary objective/gradient pair, which sees only the parameter
/// vector, to the mini-batch `StochasticGradientFunction` interface.
///
/// The batch data handed in by the optimizer is ignored, so every "batch"
/// evaluation is really a full evaluation of the wrapped closures.
pub struct BatchGradientWrapper<F, G> {
    /// Objective closure.
    func: F,
    /// Gradient closure.
    grad: G,
}
158
159impl<F, G> BatchGradientWrapper<F, G>
160where
161    F: FnMut(&ArrayView1<f64>) -> f64,
162    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
163{
164    pub fn new(func: F, grad: G) -> Self {
165        Self { func, grad }
166    }
167}
168
169impl<F, G> StochasticGradientFunction for BatchGradientWrapper<F, G>
170where
171    F: FnMut(&ArrayView1<f64>) -> f64,
172    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
173{
174    fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
175        (self.grad)(x)
176    }
177
178    fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
179        (self.func)(x)
180    }
181}
182
183/// Update learning rate according to schedule
184#[allow(dead_code)]
185pub fn update_learning_rate(
186    initial_lr: f64,
187    epoch: usize,
188    max_epochs: usize,
189    schedule: &LearningRateSchedule,
190) -> f64 {
191    match schedule {
192        LearningRateSchedule::Constant => initial_lr,
193        LearningRateSchedule::ExponentialDecay { decay_rate } => {
194            initial_lr * decay_rate.powi(epoch as i32)
195        }
196        LearningRateSchedule::StepDecay {
197            decay_factor,
198            decay_steps,
199        } => initial_lr * decay_factor.powi((epoch / decay_steps) as i32),
200        LearningRateSchedule::LinearDecay => {
201            initial_lr * (1.0 - epoch as f64 / max_epochs as f64).max(0.0)
202        }
203        LearningRateSchedule::CosineAnnealing => {
204            initial_lr
205                * 0.5
206                * (1.0 + (std::f64::consts::PI * epoch as f64 / max_epochs as f64).cos())
207        }
208        LearningRateSchedule::InverseTimeDecay { decay_rate } => {
209            initial_lr / (1.0 + decay_rate * epoch as f64)
210        }
211    }
212}
213
214/// Clip gradients to prevent exploding gradients
215#[allow(dead_code)]
216pub fn clip_gradients(gradient: &mut Array1<f64>, maxnorm: f64) {
217    let grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
218    if grad_norm > maxnorm {
219        let scale = maxnorm / grad_norm;
220        gradient.mapv_inplace(|x| x * scale);
221    }
222}
223
224/// Generate random batch indices
225#[allow(dead_code)]
226pub fn generate_batch_indices(_num_samples: usize, batchsize: usize, shuffle: bool) -> Vec<usize> {
227    let mut indices: Vec<usize> = (0.._num_samples).collect();
228
229    if shuffle {
230        use scirs2_core::random::seq::SliceRandom;
231        indices.shuffle(&mut thread_rng());
232    }
233
234    indices.into_iter().take(batchsize).collect()
235}
236
237/// Main stochastic optimization function
238#[allow(dead_code)]
239pub fn minimize_stochastic<F>(
240    method: StochasticMethod,
241    grad_func: F,
242    x0: Array1<f64>,
243    data_provider: Box<dyn DataProvider>,
244    options: StochasticOptions,
245) -> Result<OptimizeResult<f64>, OptimizeError>
246where
247    F: StochasticGradientFunction,
248{
249    match method {
250        StochasticMethod::SGD => {
251            let sgd_options = SGDOptions {
252                learning_rate: options.learning_rate,
253                max_iter: options.max_iter,
254                tol: options.tol,
255                lr_schedule: options.lr_schedule,
256                gradient_clip: options.gradient_clip,
257                batch_size: options.batch_size,
258            };
259            sgd::minimize_sgd(grad_func, x0, data_provider, sgd_options)
260        }
261        StochasticMethod::Momentum => {
262            let momentum_options = MomentumOptions {
263                learning_rate: options.learning_rate,
264                momentum: 0.9, // Default momentum
265                max_iter: options.max_iter,
266                tol: options.tol,
267                lr_schedule: options.lr_schedule,
268                gradient_clip: options.gradient_clip,
269                batch_size: options.batch_size,
270                nesterov: false,
271                dampening: 0.0,
272            };
273            momentum::minimize_sgd_momentum(grad_func, x0, data_provider, momentum_options)
274        }
275        StochasticMethod::RMSProp => {
276            let rmsprop_options = RMSPropOptions {
277                learning_rate: options.learning_rate,
278                decay_rate: 0.99, // Default RMSProp decay
279                epsilon: 1e-8,
280                max_iter: options.max_iter,
281                tol: options.tol,
282                lr_schedule: options.lr_schedule,
283                gradient_clip: options.gradient_clip,
284                batch_size: options.batch_size,
285                centered: false,
286                momentum: None,
287            };
288            rmsprop::minimize_rmsprop(grad_func, x0, data_provider, rmsprop_options)
289        }
290        StochasticMethod::Adam => {
291            let adam_options = AdamOptions {
292                learning_rate: options.learning_rate,
293                beta1: 0.9,
294                beta2: 0.999,
295                epsilon: 1e-8,
296                max_iter: options.max_iter,
297                tol: options.tol,
298                lr_schedule: options.lr_schedule,
299                gradient_clip: options.gradient_clip,
300                batch_size: options.batch_size,
301                amsgrad: false,
302            };
303            adam::minimize_adam(grad_func, x0, data_provider, adam_options)
304        }
305        StochasticMethod::AdamW => {
306            let adamw_options = AdamWOptions {
307                learning_rate: options.learning_rate,
308                beta1: 0.9,
309                beta2: 0.999,
310                epsilon: 1e-8,
311                weight_decay: 0.01, // Default weight decay
312                max_iter: options.max_iter,
313                tol: options.tol,
314                lr_schedule: options.lr_schedule,
315                gradient_clip: options.gradient_clip,
316                batch_size: options.batch_size,
317                decouple_weight_decay: true,
318            };
319            adamw::minimize_adamw(grad_func, x0, data_provider, adamw_options)
320        }
321    }
322}
323
324/// Create stochastic options optimized for specific problem types
325#[allow(dead_code)]
326pub fn create_stochastic_options_for_problem(
327    problem_type: &str,
328    dataset_size: usize,
329) -> StochasticOptions {
330    match problem_type.to_lowercase().as_str() {
331        "neural_network" => StochasticOptions {
332            learning_rate: 0.001,
333            max_iter: 1000,
334            batch_size: Some(32.min(dataset_size / 10)),
335            lr_schedule: LearningRateSchedule::ExponentialDecay { decay_rate: 0.99 },
336            gradient_clip: Some(1.0),
337            early_stopping_patience: Some(50),
338            ..Default::default()
339        },
340        "linear_regression" => StochasticOptions {
341            learning_rate: 0.01,
342            max_iter: 500,
343            batch_size: Some(64.min(dataset_size / 5)),
344            lr_schedule: LearningRateSchedule::LinearDecay,
345            ..Default::default()
346        },
347        "logistic_regression" => StochasticOptions {
348            learning_rate: 0.01,
349            max_iter: 200,
350            batch_size: Some(32.min(dataset_size / 10)),
351            lr_schedule: LearningRateSchedule::StepDecay {
352                decay_factor: 0.9,
353                decay_steps: 50,
354            },
355            ..Default::default()
356        },
357        "large_scale" => StochasticOptions {
358            learning_rate: 0.001,
359            max_iter: 2000,
360            batch_size: Some(128.min(dataset_size / 20)),
361            lr_schedule: LearningRateSchedule::CosineAnnealing,
362            gradient_clip: Some(5.0),
363            adaptive_lr: true,
364            ..Default::default()
365        },
366        "noisy_gradients" => StochasticOptions {
367            learning_rate: 0.01,
368            max_iter: 1000,
369            batch_size: Some(64.min(dataset_size / 5)),
370            lr_schedule: LearningRateSchedule::InverseTimeDecay { decay_rate: 1.0 },
371            gradient_clip: Some(2.0),
372            early_stopping_patience: Some(100),
373            ..Default::default()
374        },
375        _ => StochasticOptions::default(),
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Euclidean norm of a 1-D array, used to check clipping results.
    fn l2_norm(values: &Array1<f64>) -> f64 {
        values.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    #[test]
    fn test_learning_rate_schedules() {
        const LR0: f64 = 0.1;
        const HORIZON: usize = 100;
        let lr_at = |epoch: usize, schedule: &LearningRateSchedule| {
            update_learning_rate(LR0, epoch, HORIZON, schedule)
        };

        // A constant schedule ignores the epoch entirely.
        assert_abs_diff_eq!(
            lr_at(50, &LearningRateSchedule::Constant),
            LR0,
            epsilon = 1e-10
        );

        // Exponential decay after ten epochs: LR0 * 0.9^10.
        let exponential = LearningRateSchedule::ExponentialDecay { decay_rate: 0.9 };
        assert_abs_diff_eq!(
            lr_at(10, &exponential),
            LR0 * 0.9_f64.powi(10),
            epsilon = 1e-10
        );

        // Linear decay is halfway down at the midpoint of the horizon.
        assert_abs_diff_eq!(
            lr_at(50, &LearningRateSchedule::LinearDecay),
            LR0 * 0.5,
            epsilon = 1e-10
        );
    }

    #[test]
    fn test_gradient_clipping() {
        // A 3-4-5 triangle: the starting norm is exactly 5.
        let mut gradient = Array1::from_vec(vec![3.0, 4.0]);
        clip_gradients(&mut gradient, 2.5);

        assert_abs_diff_eq!(l2_norm(&gradient), 2.5, epsilon = 1e-10);
        // Clipping rescales uniformly, so the component ratio is unchanged.
        assert_abs_diff_eq!(gradient[0] / gradient[1], 3.0 / 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_batch_indices_generation() {
        let ordered = generate_batch_indices(100, 10, false);
        assert_eq!(ordered.len(), 10);
        assert_eq!(ordered, (0..10).collect::<Vec<usize>>());

        let random = generate_batch_indices(100, 10, true);
        assert_eq!(random.len(), 10);
        // Every index must lie inside the dataset.
        assert!(random.iter().all(|&idx| idx < 100));
    }

    #[test]
    fn test_in_memory_data_provider() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let provider = InMemoryDataProvider::new(samples.clone());

        assert_eq!(provider.num_samples(), samples.len());
        assert_eq!(provider.get_full_data(), samples);
        assert_eq!(provider.get_batch(&[0, 2, 4]), vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_problem_specific_options() {
        let neural = create_stochastic_options_for_problem("neural_network", 1000);
        assert_eq!(neural.learning_rate, 0.001);
        assert!(neural.batch_size.is_some());
        assert!(neural.gradient_clip.is_some());

        let linear = create_stochastic_options_for_problem("linear_regression", 500);
        assert_eq!(linear.learning_rate, 0.01);
        assert!(matches!(linear.lr_schedule, LearningRateSchedule::LinearDecay));

        let large = create_stochastic_options_for_problem("large_scale", 10000);
        assert!(matches!(large.lr_schedule, LearningRateSchedule::CosineAnnealing));
        assert_eq!(large.batch_size, Some(128));
    }
}