use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::Float;

use crate::error::Result;
use crate::optimizers::Optimizer;
use crate::simd_optimizer::SimdOptimizer;

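/// SIMD-accelerated stochastic gradient descent (SGD) optimizer with optional
/// momentum and weight decay.
///
/// `step` applies the conventional SGD update, which the `SimdOptimizer`
/// kernels are assumed to implement elementwise:
///
/// ```text
/// g_t     = grad_t + weight_decay * theta_{t-1}   (if weight_decay > 0)
/// v_t     = momentum * v_{t-1} + g_t              (if momentum > 0)
/// theta_t = theta_{t-1} - learning_rate * v_t
/// ```
///
/// Usage sketch (marked `ignore` because the public import path depends on
/// how the crate re-exports this module):
///
/// ```ignore
/// let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0]);
/// let grads = Array1::from_vec(vec![0.1f32, 0.1, 0.1]);
///
/// let mut opt = SimdSGD::new(0.1).with_momentum(0.9).with_weight_decay(1e-4);
/// let updated = opt.step(&params, &grads)?;
/// ```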
#[derive(Debug, Clone)]
pub struct SimdSGD<A: Float> {
    /// Step size applied to each update.
    learning_rate: A,
    /// Momentum coefficient (zero disables momentum).
    momentum: A,
    /// L2 regularization coefficient (zero disables weight decay).
    weight_decay: A,
    /// Velocity buffer, allocated lazily on the first call to `step`.
    velocity: Option<Array1<A>>,
}

impl<A: Float> SimdSGD<A> {
    /// Creates a new SGD optimizer with the given learning rate and no
    /// momentum or weight decay.
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            momentum: A::zero(),
            weight_decay: A::zero(),
            velocity: None,
        }
    }

    /// Creates a new SGD optimizer with explicit momentum and weight decay.
    pub fn new_with_config(learning_rate: A, momentum: A, weight_decay: A) -> Self {
        Self {
            learning_rate,
            momentum,
            weight_decay,
            velocity: None,
        }
    }

    /// Sets the momentum coefficient in place.
    pub fn set_momentum(&mut self, momentum: A) -> &mut Self {
        self.momentum = momentum;
        self
    }

    /// Builder-style variant of [`set_momentum`](Self::set_momentum).
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.momentum = momentum;
        self
    }

    /// Returns the momentum coefficient.
    pub fn get_momentum(&self) -> A {
        self.momentum
    }

    /// Returns the learning rate.
    pub fn learning_rate(&self) -> A {
        self.learning_rate
    }

    /// Sets the weight decay coefficient in place.
    pub fn set_weight_decay(&mut self, weight_decay: A) -> &mut Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Builder-style variant of [`set_weight_decay`](Self::set_weight_decay).
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Returns the weight decay coefficient.
    pub fn get_weight_decay(&self) -> A {
        self.weight_decay
    }

    /// Clears the velocity buffer, restarting momentum accumulation.
    pub fn reset(&mut self) {
        self.velocity = None;
    }
}

impl Optimizer<f32, scirs2_core::ndarray::Ix1> for SimdSGD<f32> {
    fn step(&mut self, params: &Array1<f32>, gradients: &Array1<f32>) -> Result<Array1<f32>> {
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_view = params.view();
        let gradients_view = gradients.view();

        // Fold L2 weight decay into the gradient: g = g + weight_decay * theta.
        let adjusted_gradients = if self.weight_decay > 0.0 {
            f32::simd_weight_decay(&gradients_view, &params_view, self.weight_decay)
        } else {
            gradients.to_owned()
        };

        // Lazily allocate the velocity buffer on first use.
        if self.velocity.is_none() {
            self.velocity = Some(Array1::zeros(params.len()));
        }

        let velocity = self.velocity.as_mut().unwrap();

        // Re-allocate if the parameter dimensionality changed between steps.
        if velocity.len() != params.len() {
            *velocity = Array1::zeros(params.len());
        }

        let new_params = if self.momentum > 0.0 {
            // Momentum update: v = momentum * v + g; theta = theta - lr * v.
            let (updated_params, updated_velocity) = f32::simd_momentum_update(
                &params_view,
                &adjusted_gradients.view(),
                &velocity.view(),
                self.learning_rate,
                self.momentum,
            );
            *velocity = updated_velocity;
            updated_params
        } else {
            // Plain SGD: theta = theta - lr * g.
            f32::simd_sgd_update(&params_view, &adjusted_gradients.view(), self.learning_rate)
        };

        Ok(new_params)
    }

    fn get_learning_rate(&self) -> f32 {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: f32) {
        self.learning_rate = learning_rate;
    }
}

impl Optimizer<f64, scirs2_core::ndarray::Ix1> for SimdSGD<f64> {
    fn step(&mut self, params: &Array1<f64>, gradients: &Array1<f64>) -> Result<Array1<f64>> {
        if params.shape() != gradients.shape() {
            return Err(crate::error::OptimError::DimensionMismatch(format!(
                "Incompatible shapes: parameters have shape {:?}, gradients have shape {:?}",
                params.shape(),
                gradients.shape()
            )));
        }

        let params_view = params.view();
        let gradients_view = gradients.view();

        // Fold L2 weight decay into the gradient: g = g + weight_decay * theta.
        let adjusted_gradients = if self.weight_decay > 0.0 {
            f64::simd_weight_decay(&gradients_view, &params_view, self.weight_decay)
        } else {
            gradients.to_owned()
        };

        // Lazily allocate the velocity buffer on first use.
        if self.velocity.is_none() {
            self.velocity = Some(Array1::zeros(params.len()));
        }

        let velocity = self.velocity.as_mut().unwrap();

        // Re-allocate if the parameter dimensionality changed between steps.
        if velocity.len() != params.len() {
            *velocity = Array1::zeros(params.len());
        }

        let new_params = if self.momentum > 0.0 {
            // Momentum update: v = momentum * v + g; theta = theta - lr * v.
            let (updated_params, updated_velocity) = f64::simd_momentum_update(
                &params_view,
                &adjusted_gradients.view(),
                &velocity.view(),
                self.learning_rate,
                self.momentum,
            );
            *velocity = updated_velocity;
            updated_params
        } else {
            // Plain SGD: theta = theta - lr * g.
            f64::simd_sgd_update(&params_view, &adjusted_gradients.view(), self.learning_rate)
        };

        Ok(new_params)
    }

    fn get_learning_rate(&self) -> f64 {
        self.learning_rate
    }

    fn set_learning_rate(&mut self, learning_rate: f64) {
        self.learning_rate = learning_rate;
    }
}
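
// Note: the two `Optimizer` impls above are deliberately concrete (`f32` and
// `f64`) rather than generic over `A: Float`, since the `SimdOptimizer`
// kernels (`simd_weight_decay`, `simd_momentum_update`, `simd_sgd_update`)
// are invoked through these machine-width float types.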

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_sgd_basic() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new(0.1);
        let result = optimizer.step(&params, &gradients).unwrap();

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-6);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_momentum() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new_with_config(0.1, 0.9, 0.0);

        // First step seeds the velocity buffer.
        let result1 = optimizer.step(&params, &gradients).unwrap();

        // Second step with the same gradient accumulates momentum, so the
        // parameters keep moving in the same direction.
        let result2 = optimizer.step(&result1, &gradients).unwrap();

        assert!(result2[0] < result1[0]);
    }

    #[test]
    fn test_simd_sgd_weight_decay() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new_with_config(0.1, 0.0, 0.01);
        let result = optimizer.step(&params, &gradients).unwrap();

        // Decay is folded into the gradient: g = 0.1 + 0.01 * 1.0.
        let expected_grad = 0.1 + 0.01 * 1.0;
        assert_relative_eq!(result[0], 1.0 - 0.1 * expected_grad, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_large_array() {
        let size = 1000;
        let params: Array1<f32> = Array1::from_vec((0..size).map(|i| i as f32).collect());
        let gradients: Array1<f32> = Array1::from_elem(size, 0.1);

        let mut optimizer = SimdSGD::new(0.01);
        let result = optimizer.step(&params, &gradients).unwrap();

        for i in 0..size {
            assert_relative_eq!(result[i], (i as f32) - 0.01 * 0.1, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simd_sgd_f64() {
        let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new(0.1);
        let result = optimizer.step(&params, &gradients).unwrap();

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-10);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-10);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-10);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_sgd_reset() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);

        let mut optimizer = SimdSGD::new_with_config(0.1, 0.9, 0.0);

        // Stepping with momentum allocates the velocity buffer...
        let _ = optimizer.step(&params, &gradients).unwrap();
        assert!(optimizer.velocity.is_some());

        // ...and reset drops it.
        optimizer.reset();
        assert!(optimizer.velocity.is_none());
    }
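
    // Exercises the shape-mismatch guard at the top of `step`: mismatched
    // parameter/gradient shapes should surface as an `Err` (the
    // `DimensionMismatch` variant constructed above) rather than a panic.
    #[test]
    fn test_simd_sgd_shape_mismatch() {
        let params = Array1::from_vec(vec![1.0f32, 2.0]);
        let gradients = Array1::from_vec(vec![0.1f32, 0.2, 0.3]);

        let mut optimizer = SimdSGD::new(0.1);
        assert!(optimizer.step(&params, &gradients).is_err());
    }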
}