//! Stochastic gradient descent optimizers: plain SGD, SVRG (stochastic
//! variance-reduced gradient), and mini-batch SGD with iterate averaging.

use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::*;

/// Options for the stochastic gradient descent optimizers in this module.
#[derive(Debug, Clone)]
pub struct SGDOptions {
    /// Initial learning rate (step size).
    pub learning_rate: f64,
    /// Maximum number of iterations (epochs for the epoch-based variants).
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Schedule used to adjust the learning rate over the course of optimization.
    pub lr_schedule: LearningRateSchedule,
    /// Optional gradient clipping threshold.
    pub gradient_clip: Option<f64>,
    /// Mini-batch size; each optimizer chooses its own default when `None`.
    pub batch_size: Option<usize>,
}

impl Default for SGDOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
        }
    }
}

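/// Minimize a stochastic objective with plain stochastic gradient descent (SGD).
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doc-test),
/// assuming a user-defined `StochasticGradientFunction` such as the hypothetical
/// `QuadraticFunction` used in the tests at the bottom of this file:
///
/// ```ignore
/// let x0 = Array1::from_vec(vec![1.0, -2.0]);
/// let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
/// let options = SGDOptions { learning_rate: 0.1, ..Default::default() };
/// let result = minimize_sgd(QuadraticFunction, x0, data_provider, options).unwrap();
/// assert!(result.fun < 1e-2);
/// ```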
#[allow(dead_code)]
pub fn minimize_sgd<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: SGDOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    let batch_size = options.batch_size.unwrap_or(num_samples);
    let actual_batch_size = batch_size.min(num_samples);

    // Track the best parameters and objective value seen so far.
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Convergence and stagnation tracking.
    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting SGD optimization:");
    println!(" Parameters: {}", x.len());
    println!(" Dataset size: {}", num_samples);
    println!(" Batch size: {}", actual_batch_size);
    println!(" Initial learning rate: {}", options.learning_rate);

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        // Adjust the learning rate according to the configured schedule.
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        // Sample a mini-batch (with replacement) or use the full dataset.
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);

        // Stochastic gradient on the current mini-batch.
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        // Optionally clip the gradient to stabilize the update.
        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Gradient descent step: x <- x - lr * g.
        x = &x - &(&gradient * current_lr);

        // Periodically evaluate the objective on the full dataset.
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            // Keep the best parameters seen so far and count stagnant evaluations.
            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                println!(
                    " Iteration {}: loss = {:.6e}, |grad| = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, current_lr
                );
            }

            // Converged if the change in full-dataset loss falls below the tolerance.
            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "SGD converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            // Stop early if the loss has not improved for many consecutive evaluations.
            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "SGD stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    // Final evaluation of the best parameters on the full dataset.
    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "SGD reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

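/// Minimize a stochastic objective with SVRG (stochastic variance-reduced gradient).
///
/// Each epoch ends by taking a fresh snapshot of the parameters and recomputing the
/// full gradient there; the inner loop uses that full gradient as a control variate
/// for its stochastic updates.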
#[allow(dead_code)]
pub fn minimize_svrg<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: SGDOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    let batch_size = options.batch_size.unwrap_or(1);
    // Number of inner stochastic updates per epoch: roughly one pass over the data.
    let update_frequency = num_samples / batch_size;

    // Full gradient at the initial snapshot, used as the control variate.
    let full_data = data_provider.get_full_data();
    let mut full_gradient = grad_func.compute_gradient(&x.view(), &full_data);
    _grad_evals += 1;

    let mut x_snapshot = x.clone();
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    println!("Starting SVRG optimization:");
    println!(" Parameters: {}", x.len());
    println!(" Dataset size: {}", num_samples);
    println!(" Batch size: {}", batch_size);
    println!(" Update frequency: {}", update_frequency);

    for epoch in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            epoch,
            options.max_iter,
            &options.lr_schedule,
        );

        // Inner loop: stochastic updates around the current snapshot.
        for _inner_iter in 0..update_frequency {
            let batch_indices = generate_batch_indices(num_samples, batch_size, true);
            let batch_data = data_provider.get_batch(&batch_indices);

            // Stochastic gradient at the current iterate on this mini-batch...
            let stoch_grad = grad_func.compute_gradient(&x.view(), &batch_data);
            _grad_evals += 1;

            // ...and at the snapshot, on the same mini-batch.
            let control_grad = grad_func.compute_gradient(&x_snapshot.view(), &batch_data);
            _grad_evals += 1;

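            // Variance-reduced gradient estimate:
            //   g = grad_i(x) - grad_i(x_snapshot) + full_gradient,
            // which is unbiased and has low variance when x is close to the snapshot.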
            let mut svrg_gradient = &stoch_grad - &control_grad + &full_gradient;

            if let Some(clip_threshold) = options.gradient_clip {
                clip_gradients(&mut svrg_gradient, clip_threshold);
            }

            x = &x - &(&svrg_gradient * current_lr);
        }

        // Take a new snapshot and recompute the full gradient for the next epoch.
        x_snapshot = x.clone();
        full_gradient = grad_func.compute_gradient(&x_snapshot.view(), &full_data);
        _grad_evals += 1;

        // Evaluate the full objective once per epoch and track the best parameters.
        let current_loss = grad_func.compute_value(&x.view(), &full_data);
        func_evals += 1;

        if current_loss < best_f {
            best_f = current_loss;
            best_x = x.clone();
        }

        if epoch % 10 == 0 {
            let grad_norm = full_gradient.mapv(|g| g * g).sum().sqrt();
            println!(
                " Epoch {}: loss = {:.6e}, |grad| = {:.3e}, lr = {:.3e}",
                epoch, current_loss, grad_norm, current_lr
            );
        }

        // Converged if the full-gradient norm falls below the tolerance.
        let grad_norm = full_gradient.mapv(|g| g * g).sum().sqrt();
        if grad_norm < options.tol {
            return Ok(OptimizeResult {
                x: best_x,
                fun: best_f,
                nit: epoch,
                func_evals,
                nfev: func_evals,
                success: true,
                message: format!(
                    "SVRG converged: gradient norm {:.2e} < {:.2e}",
                    grad_norm, options.tol
                ),
                jacobian: Some(full_gradient),
                hessian: None,
            });
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "SVRG reached maximum iterations".to_string(),
        jacobian: Some(full_gradient),
        hessian: None,
    })
}

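/// Minimize a stochastic objective with mini-batch SGD and iterate (Polyak)
/// averaging over roughly the last three quarters of the epochs.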
#[allow(dead_code)]
pub fn minimize_mini_batch_sgd<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: SGDOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default batch size: at most 32 samples and at most ~10% of the dataset,
    // clamped to at least 1 so tiny datasets do not yield a zero batch size.
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10)).max(1);
    let batches_per_epoch = num_samples.div_ceil(batch_size);

    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Running (Polyak) average of the iterates, started after a warm-up period.
    let mut x_avg = x.clone();
    let avg_start_epoch = options.max_iter / 4;

    println!("Starting Mini-batch SGD optimization:");
    println!(" Parameters: {}", x.len());
    println!(" Dataset size: {}", num_samples);
    println!(" Batch size: {}", batch_size);
    println!(" Batches per epoch: {}", batches_per_epoch);

    for epoch in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            epoch,
            options.max_iter,
            &options.lr_schedule,
        );

        // Shuffle the sample indices once per epoch and sweep over them in batches.
        let mut all_indices: Vec<usize> = (0..num_samples).collect();
        use scirs2_core::random::seq::SliceRandom;
        all_indices.shuffle(&mut thread_rng());

        let mut _epoch_loss = 0.0;
        let mut epoch_grad_norm = 0.0;

        for batch_idx in 0..batches_per_epoch {
            // Slice out the next mini-batch of shuffled indices.
            let start_idx = batch_idx * batch_size;
            let end_idx = (start_idx + batch_size).min(num_samples);
            let batch_indices = &all_indices[start_idx..end_idx];

            let batch_data = data_provider.get_batch(batch_indices);

            let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
            _grad_evals += 1;

            if let Some(clip_threshold) = options.gradient_clip {
                clip_gradients(&mut gradient, clip_threshold);
            }

            // Gradient descent step on the mini-batch.
            x = &x - &(&gradient * current_lr);

            // Accumulate per-batch statistics for this epoch.
            let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
            epoch_grad_norm += grad_norm;

            let batch_loss = grad_func.compute_value(&x.view(), &batch_data);
            func_evals += 1;
            _epoch_loss += batch_loss;
        }

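        // Incremental Polyak averaging: once past the warm-up epoch, x_avg is the
        // running mean of all iterates seen since averaging started.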
        if epoch >= avg_start_epoch {
            let weight = 1.0 / (epoch - avg_start_epoch + 1) as f64;
            x_avg = &x_avg * (1.0 - weight) + &x * weight;
        }

        // Evaluate the averaged iterate once averaging has started, otherwise the raw iterate.
        let eval_x = if epoch >= avg_start_epoch { &x_avg } else { &x };

        // Evaluate the full objective once per epoch and track the best parameters.
        let full_data = data_provider.get_full_data();
        let current_loss = grad_func.compute_value(&eval_x.view(), &full_data);
        func_evals += 1;

        if current_loss < best_f {
            best_f = current_loss;
            best_x = eval_x.clone();
        }

        if epoch % 10 == 0 {
            let avg_grad_norm = epoch_grad_norm / batches_per_epoch as f64;
            println!(
                " Epoch {}: loss = {:.6e}, avg |grad| = {:.3e}, lr = {:.3e}{}",
                epoch,
                current_loss,
                avg_grad_norm,
                current_lr,
                if epoch >= avg_start_epoch {
                    " (averaged)"
                } else {
                    ""
                }
            );
        }

        // Converged if the average batch-gradient norm falls below the tolerance.
        let avg_grad_norm = epoch_grad_norm / batches_per_epoch as f64;
        if avg_grad_norm < options.tol {
            return Ok(OptimizeResult {
                x: best_x,
                fun: best_f,
                nit: epoch,
                func_evals,
                nfev: func_evals,
                success: true,
                message: format!(
                    "Mini-batch SGD converged: avg gradient norm {:.2e} < {:.2e}",
                    avg_grad_norm, options.tol
                ),
                jacobian: None,
                hessian: None,
            });
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "Mini-batch SGD reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stochastic::InMemoryDataProvider;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::ArrayView1;

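    /// Simple quadratic test objective f(x) = sum(x_i^2), with gradient 2x.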
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_sgd_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = SGDOptions {
            learning_rate: 0.1,
            max_iter: 100,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_sgd(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_svrg_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = SGDOptions {
            learning_rate: 0.05,
            max_iter: 50,
            batch_size: Some(5),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_svrg(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_mini_batch_sgd() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![2.0, -2.0, 1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 200]));

        let options = SGDOptions {
            learning_rate: 0.01,
            max_iter: 100,
            batch_size: Some(10),
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::ExponentialDecay { decay_rate: 0.99 },
            ..Default::default()
        };

        let result = minimize_mini_batch_sgd(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-3);
    }
}