use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::Float;
use scirs2_core::simd_ops::SimdUnifiedOps;

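/// SIMD-accelerated building blocks for first-order optimizers.
///
/// Each method applies one optimizer primitive (SGD step, momentum update,
/// Adam moment estimates, weight decay, gradient norm) over 1-D parameter and
/// gradient views, using `scirs2_core`'s unified SIMD operations for larger
/// arrays and plain scalar loops for short ones.
///
/// # Example
///
/// A minimal sketch of one Adam step built from these primitives (illustrative
/// only; it assumes the trait is in scope and uses step `t = 1` for the bias
/// correction):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
/// let grads = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
/// let (m, v): (Array1<f32>, Array1<f32>) = (Array1::zeros(4), Array1::zeros(4));
/// let (beta1, beta2, lr, eps) = (0.9f32, 0.999f32, 1e-3f32, 1e-8f32);
///
/// let m = f32::simd_adam_first_moment(&m.view(), &grads.view(), beta1);
/// let v = f32::simd_adam_second_moment(&v.view(), &grads.view(), beta2);
/// let m_hat = m.mapv(|x| x / (1.0 - beta1));
/// let v_hat = v.mapv(|x| x / (1.0 - beta2));
/// let new_params =
///     f32::simd_adam_update(&params.view(), &m_hat.view(), &v_hat.view(), lr, eps);
/// ```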
pub trait SimdOptimizer<T: Float> {
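    /// Vanilla SGD step: returns `params - learning_rate * gradients`.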
    fn simd_sgd_update(
        params: &ArrayView1<T>,
        gradients: &ArrayView1<T>,
        learning_rate: T,
    ) -> Array1<T>;

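    /// Momentum SGD step.
    ///
    /// Computes `velocity' = momentum * velocity + learning_rate * gradients`
    /// and returns `(params - velocity', velocity')`.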
    fn simd_momentum_update(
        params: &ArrayView1<T>,
        gradients: &ArrayView1<T>,
        velocity: &ArrayView1<T>,
        learning_rate: T,
        momentum: T,
    ) -> (Array1<T>, Array1<T>);

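    /// Adam first-moment (mean) update: `beta1 * m + (1 - beta1) * gradients`.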
    fn simd_adam_first_moment(m: &ArrayView1<T>, gradients: &ArrayView1<T>, beta1: T) -> Array1<T>;

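    /// Adam second-moment (uncentered variance) update:
    /// `beta2 * v + (1 - beta2) * gradients * gradients`.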
    fn simd_adam_second_moment(v: &ArrayView1<T>, gradients: &ArrayView1<T>, beta2: T)
        -> Array1<T>;

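    /// Adam parameter update from bias-corrected moments:
    /// `params - learning_rate * m_hat / (sqrt(v_hat) + epsilon)`.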
    fn simd_adam_update(
        params: &ArrayView1<T>,
        m_hat: &ArrayView1<T>,
        v_hat: &ArrayView1<T>,
        learning_rate: T,
        epsilon: T,
    ) -> Array1<T>;

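    /// Adds L2 weight decay to the gradients: `gradients + weight_decay * params`.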
    fn simd_weight_decay(
        gradients: &ArrayView1<T>,
        params: &ArrayView1<T>,
        weight_decay: T,
    ) -> Array1<T>;

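    /// L2 norm of the gradient vector.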
    fn simd_gradient_norm(gradients: &ArrayView1<T>) -> T;
}

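// f32 implementation: the SIMD kernels are used once arrays reach 16 elements;
// shorter inputs fall back to plain scalar loops.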
impl SimdOptimizer<f32> for f32 {
    fn simd_sgd_update(
        params: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        learning_rate: f32,
    ) -> Array1<f32> {
        if params.len() >= 16 {
            let scaled_grads = f32::simd_scalar_mul(gradients, learning_rate);
            f32::simd_sub(params, &scaled_grads.view())
        } else {
            params
                .iter()
                .zip(gradients.iter())
                .map(|(&p, &g)| p - learning_rate * g)
                .collect()
        }
    }

    fn simd_momentum_update(
        params: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        velocity: &ArrayView1<f32>,
        learning_rate: f32,
        momentum: f32,
    ) -> (Array1<f32>, Array1<f32>) {
        if params.len() >= 16 {
            let scaled_velocity = f32::simd_scalar_mul(velocity, momentum);
            let scaled_gradients = f32::simd_scalar_mul(gradients, learning_rate);
            let new_velocity = f32::simd_add(&scaled_velocity.view(), &scaled_gradients.view());

            let new_params = f32::simd_sub(params, &new_velocity.view());

            (new_params, new_velocity)
        } else {
            let new_velocity: Array1<f32> = velocity
                .iter()
                .zip(gradients.iter())
                .map(|(&v, &g)| momentum * v + learning_rate * g)
                .collect();

            let new_params: Array1<f32> = params
                .iter()
                .zip(new_velocity.iter())
                .map(|(&p, &v)| p - v)
                .collect();

            (new_params, new_velocity)
        }
    }

    fn simd_adam_first_moment(
        m: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        beta1: f32,
    ) -> Array1<f32> {
        if m.len() >= 16 {
            let scaled_m = f32::simd_scalar_mul(m, beta1);
            let scaled_grads = f32::simd_scalar_mul(gradients, 1.0 - beta1);
            f32::simd_add(&scaled_m.view(), &scaled_grads.view())
        } else {
            m.iter()
                .zip(gradients.iter())
                .map(|(&m_val, &g)| beta1 * m_val + (1.0 - beta1) * g)
                .collect()
        }
    }

    fn simd_adam_second_moment(
        v: &ArrayView1<f32>,
        gradients: &ArrayView1<f32>,
        beta2: f32,
    ) -> Array1<f32> {
        if v.len() >= 16 {
            let scaled_v = f32::simd_scalar_mul(v, beta2);
            let grad_squared = f32::simd_mul(gradients, gradients);
            let scaled_grad_squared = f32::simd_scalar_mul(&grad_squared.view(), 1.0 - beta2);
            f32::simd_add(&scaled_v.view(), &scaled_grad_squared.view())
        } else {
            v.iter()
                .zip(gradients.iter())
                .map(|(&v_val, &g)| beta2 * v_val + (1.0 - beta2) * g * g)
                .collect()
        }
    }

    fn simd_adam_update(
        params: &ArrayView1<f32>,
        m_hat: &ArrayView1<f32>,
        v_hat: &ArrayView1<f32>,
        learning_rate: f32,
        epsilon: f32,
    ) -> Array1<f32> {
        if params.len() >= 16 {
            // The square root is applied in a scalar element-wise pass; the
            // remaining divide/scale/subtract steps use the SIMD kernels.
            let v_hat_sqrt: Array1<f32> = v_hat.iter().map(|&v| v.sqrt() + epsilon).collect();

            let step = f32::simd_div(m_hat, &v_hat_sqrt.view());

            let scaled_step = f32::simd_scalar_mul(&step.view(), learning_rate);

            f32::simd_sub(params, &scaled_step.view())
        } else {
            params
                .iter()
                .zip(m_hat.iter().zip(v_hat.iter()))
                .map(|(&p, (&m, &v))| p - learning_rate * m / (v.sqrt() + epsilon))
                .collect()
        }
    }

    fn simd_weight_decay(
        gradients: &ArrayView1<f32>,
        params: &ArrayView1<f32>,
        weight_decay: f32,
    ) -> Array1<f32> {
        if gradients.len() >= 16 {
            let scaled_params = f32::simd_scalar_mul(params, weight_decay);
            f32::simd_add(gradients, &scaled_params.view())
        } else {
            gradients
                .iter()
                .zip(params.iter())
                .map(|(&g, &p)| g + weight_decay * p)
                .collect()
        }
    }

    fn simd_gradient_norm(gradients: &ArrayView1<f32>) -> f32 {
        if gradients.len() >= 16 {
            f32::simd_dot(gradients, gradients).sqrt()
        } else {
            gradients.iter().map(|&x| x * x).sum::<f32>().sqrt()
        }
    }
}

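// f64 implementation: identical structure to the f32 implementation, with the
// SIMD threshold at 8 elements instead of 16.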
impl SimdOptimizer<f64> for f64 {
    fn simd_sgd_update(
        params: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        learning_rate: f64,
    ) -> Array1<f64> {
        if params.len() >= 8 {
            let scaled_grads = f64::simd_scalar_mul(gradients, learning_rate);
            f64::simd_sub(params, &scaled_grads.view())
        } else {
            params
                .iter()
                .zip(gradients.iter())
                .map(|(&p, &g)| p - learning_rate * g)
                .collect()
        }
    }

    fn simd_momentum_update(
        params: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        velocity: &ArrayView1<f64>,
        learning_rate: f64,
        momentum: f64,
    ) -> (Array1<f64>, Array1<f64>) {
        if params.len() >= 8 {
            let scaled_velocity = f64::simd_scalar_mul(velocity, momentum);
            let scaled_gradients = f64::simd_scalar_mul(gradients, learning_rate);
            let new_velocity = f64::simd_add(&scaled_velocity.view(), &scaled_gradients.view());
            let new_params = f64::simd_sub(params, &new_velocity.view());
            (new_params, new_velocity)
        } else {
            let new_velocity: Array1<f64> = velocity
                .iter()
                .zip(gradients.iter())
                .map(|(&v, &g)| momentum * v + learning_rate * g)
                .collect();
            let new_params: Array1<f64> = params
                .iter()
                .zip(new_velocity.iter())
                .map(|(&p, &v)| p - v)
                .collect();
            (new_params, new_velocity)
        }
    }

    fn simd_adam_first_moment(
        m: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        beta1: f64,
    ) -> Array1<f64> {
        if m.len() >= 8 {
            let scaled_m = f64::simd_scalar_mul(m, beta1);
            let scaled_grads = f64::simd_scalar_mul(gradients, 1.0 - beta1);
            f64::simd_add(&scaled_m.view(), &scaled_grads.view())
        } else {
            m.iter()
                .zip(gradients.iter())
                .map(|(&m_val, &g)| beta1 * m_val + (1.0 - beta1) * g)
                .collect()
        }
    }

    fn simd_adam_second_moment(
        v: &ArrayView1<f64>,
        gradients: &ArrayView1<f64>,
        beta2: f64,
    ) -> Array1<f64> {
        if v.len() >= 8 {
            let scaled_v = f64::simd_scalar_mul(v, beta2);
            let grad_squared = f64::simd_mul(gradients, gradients);
            let scaled_grad_squared = f64::simd_scalar_mul(&grad_squared.view(), 1.0 - beta2);
            f64::simd_add(&scaled_v.view(), &scaled_grad_squared.view())
        } else {
            v.iter()
                .zip(gradients.iter())
                .map(|(&v_val, &g)| beta2 * v_val + (1.0 - beta2) * g * g)
                .collect()
        }
    }

    fn simd_adam_update(
        params: &ArrayView1<f64>,
        m_hat: &ArrayView1<f64>,
        v_hat: &ArrayView1<f64>,
        learning_rate: f64,
        epsilon: f64,
    ) -> Array1<f64> {
        if params.len() >= 8 {
            let v_hat_sqrt: Array1<f64> = v_hat.iter().map(|&v| v.sqrt() + epsilon).collect();
            let step = f64::simd_div(m_hat, &v_hat_sqrt.view());
            let scaled_step = f64::simd_scalar_mul(&step.view(), learning_rate);
            f64::simd_sub(params, &scaled_step.view())
        } else {
            params
                .iter()
                .zip(m_hat.iter().zip(v_hat.iter()))
                .map(|(&p, (&m, &v))| p - learning_rate * m / (v.sqrt() + epsilon))
                .collect()
        }
    }

    fn simd_weight_decay(
        gradients: &ArrayView1<f64>,
        params: &ArrayView1<f64>,
        weight_decay: f64,
    ) -> Array1<f64> {
        if gradients.len() >= 8 {
            let scaled_params = f64::simd_scalar_mul(params, weight_decay);
            f64::simd_add(gradients, &scaled_params.view())
        } else {
            gradients
                .iter()
                .zip(params.iter())
                .map(|(&g, &p)| g + weight_decay * p)
                .collect()
        }
    }

    fn simd_gradient_norm(gradients: &ArrayView1<f64>) -> f64 {
        if gradients.len() >= 8 {
            f64::simd_dot(gradients, gradients).sqrt()
        } else {
            gradients.iter().map(|&x| x * x).sum::<f64>().sqrt()
        }
    }
}

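/// Heuristic used to decide whether the SIMD code paths are worthwhile for an
/// array of `size` elements whose scalar type occupies `dtype_size` bytes
/// (4 for `f32`, 8 for `f64`). Element sizes without a dedicated threshold
/// never select SIMD.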
pub fn should_use_simd(size: usize, dtype_size: usize) -> bool {
    let min_simd_size = match dtype_size {
        4 => 16,         // f32: same threshold as the SIMD branches above
        8 => 8,          // f64
        _ => usize::MAX, // unknown element size: always use the scalar path
    };

    size >= min_simd_size
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_simd_sgd_update_f32() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let learning_rate = 0.1;

        let result = f32::simd_sgd_update(&params.view(), &gradients.view(), learning_rate);

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-6);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_sgd_update_f64() {
        let params = Array1::from_vec(vec![1.0f64, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let learning_rate = 0.1;

        let result = f64::simd_sgd_update(&params.view(), &gradients.view(), learning_rate);

        assert_relative_eq!(result[0], 0.99, epsilon = 1e-10);
        assert_relative_eq!(result[1], 1.98, epsilon = 1e-10);
        assert_relative_eq!(result[2], 2.97, epsilon = 1e-10);
        assert_relative_eq!(result[3], 3.96, epsilon = 1e-10);
    }

    #[test]
    fn test_simd_momentum_update() {
        let params = Array1::from_vec(vec![1.0f32, 2.0, 3.0, 4.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let velocity = Array1::from_vec(vec![0.01, 0.02, 0.03, 0.04]);
        let learning_rate = 0.1;
        let momentum = 0.9;

        let (new_params, new_velocity) = f32::simd_momentum_update(
            &params.view(),
            &gradients.view(),
            &velocity.view(),
            learning_rate,
            momentum,
        );

        assert_relative_eq!(new_velocity[0], 0.9 * 0.01 + 0.1 * 0.1, epsilon = 1e-6);

        assert_relative_eq!(new_params[0], 1.0 - new_velocity[0], epsilon = 1e-6);
    }

    #[test]
    fn test_simd_adam_first_moment() {
        let m = Array1::from_vec(vec![0.01f32, 0.02, 0.03, 0.04]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let beta1 = 0.9;

        let result = f32::simd_adam_first_moment(&m.view(), &gradients.view(), beta1);

        assert_relative_eq!(result[0], 0.9 * 0.01 + 0.1 * 0.1, epsilon = 1e-6);
        assert_relative_eq!(result[1], 0.9 * 0.02 + 0.1 * 0.2, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_adam_second_moment() {
        let v = Array1::from_vec(vec![0.001f32, 0.002, 0.003, 0.004]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let beta2 = 0.999;

        let result = f32::simd_adam_second_moment(&v.view(), &gradients.view(), beta2);

        assert_relative_eq!(result[0], 0.999 * 0.001 + 0.001 * 0.1 * 0.1, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_weight_decay() {
        let gradients = Array1::from_vec(vec![0.1f32, 0.2, 0.3, 0.4]);
        let params = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let weight_decay = 0.01;

        let result = f32::simd_weight_decay(&gradients.view(), &params.view(), weight_decay);

        assert_relative_eq!(result[0], 0.1 + 0.01 * 1.0, epsilon = 1e-6);
        assert_relative_eq!(result[1], 0.2 + 0.01 * 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_simd_gradient_norm() {
        let gradients = Array1::from_vec(vec![3.0f32, 4.0]);
        let norm = f32::simd_gradient_norm(&gradients.view());
        assert_relative_eq!(norm, 5.0, epsilon = 1e-6);

        let gradients_f64 = Array1::from_vec(vec![3.0f64, 4.0]);
        let norm_f64 = f64::simd_gradient_norm(&gradients_f64.view());
        assert_relative_eq!(norm_f64, 5.0, epsilon = 1e-10);
    }

    #[test]
    fn test_should_use_simd() {
        assert!(!should_use_simd(8, 4));
        assert!(should_use_simd(16, 4));
        assert!(should_use_simd(100, 4));
        assert!(!should_use_simd(4, 8));
        assert!(should_use_simd(8, 8));
        assert!(should_use_simd(100, 8));
    }

    #[test]
    fn test_simd_large_array() {
        let size = 1000;
        let params: Array1<f32> = Array1::from_vec((0..size).map(|i| i as f32).collect());
        let gradients: Array1<f32> = Array1::from_vec(vec![0.1; size]);
        let learning_rate = 0.01;

        let result = f32::simd_sgd_update(&params.view(), &gradients.view(), learning_rate);

        for i in 0..size {
            assert_relative_eq!(result[i], (i as f32) - learning_rate * 0.1, epsilon = 1e-6);
        }
    }
}