optirs_core/optimizers/sgd.rs
// Stochastic Gradient Descent optimizer

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

// SciRS2 Integration - CRITICAL for OptiRS functionality
use scirs2_core::ScientificNumber;
use scirs2_optimize::stochastic::{minimize_sgd, SGDOptions};

use crate::error::Result;
use crate::optimizers::Optimizer;

/// Stochastic Gradient Descent optimizer
///
/// Implements the classic SGD algorithm with support for momentum and weight decay.
///
/// Formula:
/// v_t = momentum * v_{t-1} + learning_rate * (gradient + weight_decay * param)
/// param_t = param_{t-1} - v_t
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::optimizers::{SGD, Optimizer};
///
/// // Initialize parameters and gradients
/// let params = Array1::zeros(5);
/// let gradients = Array1::from_vec(vec![0.1, 0.2, -0.3, 0.0, 0.5]);
///
/// // Create an SGD optimizer with learning rate 0.01 and momentum 0.9
/// let mut optimizer = SGD::new_with_config(0.01, 0.9, 0.0);
///
/// // Update parameters
/// let new_params = optimizer.step(&params, &gradients).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SGD<A: Float + ScalarOperand + Debug> {
    /// Learning rate
    learning_rate: A,
    /// Momentum factor (0.0 means no momentum)
    momentum: A,
    /// Weight decay factor (L2 regularization)
    weight_decay: A,
    /// Velocity (momentum state)
    velocity: Option<Vec<Array<A, scirs2_core::ndarray::IxDyn>>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> SGD<A> {
    /// Creates a new SGD optimizer with the given learning rate and no momentum/weight decay
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
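    ///
    /// # Examples
    ///
    /// A minimal usage sketch (the learning rate value is illustrative):
    ///
    /// ```
    /// use optirs_core::optimizers::SGD;
    ///
    /// // Plain SGD with learning rate 0.01, no momentum or weight decay
    /// let optimizer = SGD::new(0.01);
    /// assert_eq!(optimizer.learning_rate(), 0.01);
    /// ```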
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            momentum: A::zero(),
            weight_decay: A::zero(),
            velocity: None,
        }
    }

    /// Creates a new SGD optimizer with the full configuration
    ///
    /// # Arguments
    ///
    /// * `learning_rate` - The learning rate for parameter updates
    /// * `momentum` - The momentum factor (0.0 means no momentum)
    /// * `weight_decay` - The weight decay factor (L2 regularization)
    pub fn new_with_config(learning_rate: A, momentum: A, weight_decay: A) -> Self {
        Self {
            learning_rate,
            momentum,
            weight_decay,
            velocity: None,
        }
    }

    /// Sets the momentum factor
    ///
    /// # Arguments
    ///
    /// * `momentum` - The momentum factor (0.0 means no momentum)
    pub fn set_momentum(&mut self, momentum: A) -> &mut Self {
        self.momentum = momentum;
        self
    }

    /// Builder method to set momentum and return self
    ///
    /// # Arguments
    ///
    /// * `momentum` - The momentum factor (0.0 means no momentum)
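    ///
    /// # Examples
    ///
    /// A minimal sketch of chaining the builder methods (values are illustrative):
    ///
    /// ```
    /// use optirs_core::optimizers::SGD;
    ///
    /// // SGD with momentum 0.9 and a small weight decay
    /// let optimizer = SGD::new(0.01).with_momentum(0.9).with_weight_decay(1e-4);
    /// assert_eq!(optimizer.get_momentum(), 0.9);
    /// assert_eq!(optimizer.get_weight_decay(), 1e-4);
    /// ```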
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.momentum = momentum;
        self
    }

    /// Gets the current momentum factor
    pub fn get_momentum(&self) -> A {
        self.momentum
    }

    /// Gets the current learning rate
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the weight decay factor
    ///
    /// # Arguments
    ///
    /// * `weight_decay` - The weight decay factor (L2 regularization)
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder method to set weight decay and return self
    ///
    /// # Arguments
    ///
    /// * `weight_decay` - The weight decay factor (L2 regularization)
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Gets the current weight decay factor
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }
}

impl<A, D> Optimizer<A, D> for SGD<A>
where
    A: Float + ScalarOperand + Debug + Send + Sync,
    D: Dimension,
{
    fn step(&mut self, params: &Array<A, D>, gradients: &Array<A, D>) -> Result<Array<A, D>> {
        // Convert to dynamic dimension for storage in velocity
        let params_dyn = params.to_owned().into_dyn();
        let gradients_dyn = gradients.to_owned().into_dyn();

        // Initialize velocity if this is the first step
        if self.velocity.is_none() {
            self.velocity = Some(vec![Array::zeros(params_dyn.raw_dim())]);
        }

        let velocity = self.velocity.as_mut().unwrap();

        // Ensure we have velocity for this parameter set
        if velocity.is_empty() {
            velocity.push(Array::zeros(params_dyn.raw_dim()));
        } else if velocity[0].raw_dim() != params_dyn.raw_dim() {
            // If the parameter dimensions have changed, reset velocity
            velocity[0] = Array::zeros(params_dyn.raw_dim());
        }

        // Apply weight decay (L2 regularization) to the gradients if needed
        let adjusted_gradients = if self.weight_decay > A::zero() {
            &gradients_dyn + &(&params_dyn * self.weight_decay)
        } else {
            gradients_dyn
        };

        // Update velocity: v_t = momentum * v_{t-1} + learning_rate * adjusted_gradient
        if self.momentum > A::zero() {
            velocity[0] =
                &velocity[0] * self.momentum + &(&adjusted_gradients * self.learning_rate);
        } else {
            velocity[0] = &adjusted_gradients * self.learning_rate;
        }

        // Update parameters: param_t = param_{t-1} - v_t
        let updated_params = &params_dyn - &velocity[0];

        // Convert back to the original dimension
        Ok(updated_params.into_dimensionality::<D>().unwrap())
    }

    fn get_learning_rate(&self) -> A {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: A) {
        self.learning_rate = learning_rate;
    }
}
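
// A minimal test sketch of the update rule documented above:
// v_t = momentum * v_{t-1} + learning_rate * gradient, param_t = param_{t-1} - v_t.
// Test names, input values, and tolerances are illustrative, not part of the original API.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn plain_sgd_step_matches_formula() {
        // No momentum, no weight decay: param_t = param - lr * grad
        let mut optimizer = SGD::new(0.1_f64);
        let params = Array1::from_vec(vec![1.0, 2.0]);
        let gradients = Array1::from_vec(vec![0.5, -0.5]);

        let updated = optimizer.step(&params, &gradients).unwrap();

        assert!((updated[0] - 0.95).abs() < 1e-12);
        assert!((updated[1] - 2.05).abs() < 1e-12);
    }

    #[test]
    fn momentum_accumulates_velocity() {
        // With momentum 0.9: v_1 = lr * g, v_2 = 0.9 * v_1 + lr * g
        let mut optimizer = SGD::new_with_config(0.1_f64, 0.9, 0.0);
        let params = Array1::from_vec(vec![1.0]);
        let gradients = Array1::from_vec(vec![1.0]);

        // First step: v_1 = 0.1, param = 1.0 - 0.1 = 0.9
        let step1 = optimizer.step(&params, &gradients).unwrap();
        assert!((step1[0] - 0.9).abs() < 1e-12);

        // Second step from the first result: v_2 = 0.9 * 0.1 + 0.1 = 0.19, param = 0.71
        let step2 = optimizer.step(&step1, &gradients).unwrap();
        assert!((step2[0] - 0.71).abs() < 1e-12);
    }
}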