1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter optimizer state: Adam-style moment estimates plus the
/// number of update steps applied so far.
#[derive(Debug, Clone)]
struct LambState {
    /// Exponential moving average of the gradient (Adam's `m`).
    first_moment: Tensor,
    /// Exponential moving average of the squared gradient (Adam's `v`).
    second_moment: Tensor,
    /// Count of completed steps; drives the bias corrections.
    step: u64,
}
16
17impl LambState {
18 fn new(shape: &[usize]) -> Result<Self, OptimError> {
19 Ok(Self {
20 first_moment: Tensor::zeros(shape.to_vec())?,
21 second_moment: Tensor::zeros(shape.to_vec())?,
22 step: 0,
23 })
24 }
25
26 fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27 *self = Self::new(shape)?;
28 Ok(())
29 }
30}
31
/// LAMB optimizer: Adam-style moment updates combined with a per-parameter
/// trust ratio that rescales the learning rate by `||w|| / ||update||`.
#[derive(Debug, Clone)]
pub struct Lamb {
    /// Base learning rate; scaled per parameter by the trust ratio.
    lr: f32,
    /// Decay rate for the first-moment estimate.
    beta1: f32,
    /// Decay rate for the second-moment estimate.
    beta2: f32,
    /// Denominator offset guarding against division by zero.
    epsilon: f32,
    /// Decoupled weight-decay coefficient (0.0 disables it).
    weight_decay: f32,
    /// Per-parameter state, keyed by caller-supplied parameter id.
    state: HashMap<u64, LambState>,
}
45
impl Lamb {
    /// Creates a LAMB optimizer with the given base learning rate and
    /// defaults of `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-6` and no
    /// weight decay.
    ///
    /// # Errors
    /// Returns an error when `lr` fails validation.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-6,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Builder-style override of the moment decay rates.
    ///
    /// # Errors
    /// Returns an error when either beta fails validation.
    pub fn with_betas(mut self, beta1: f32, beta2: f32) -> Result<Self, OptimError> {
        validate_beta1(beta1)?;
        validate_beta2(beta2)?;
        self.beta1 = beta1;
        self.beta2 = beta2;
        Ok(self)
    }

    /// Builder-style override of the decoupled weight-decay coefficient.
    ///
    /// # Errors
    /// Returns `OptimError::InvalidWeightDecay` when `weight_decay` is
    /// negative, NaN, or infinite.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Builder-style override of the denominator epsilon.
    ///
    /// # Errors
    /// Returns an error when `epsilon` fails validation.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Drops all accumulated per-parameter state (moments and step counts).
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Current base learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Sets the base learning rate.
    ///
    /// # Errors
    /// Returns an error when `lr` fails validation.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one LAMB update to `weights` in place using `grad`.
    ///
    /// Per-parameter state is keyed by `parameter_id`: it is created lazily
    /// on first use and re-initialized if the parameter's shape changes
    /// between calls.
    ///
    /// # Errors
    /// Returns `OptimError::ShapeMismatch` when `weights` and `grad` have
    /// different shapes, or propagates errors from state (re)allocation.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        // Lazily create state for parameters seen for the first time.
        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(LambState::new(weights.shape())?),
        };
        // A stale shape means the parameter was resized: start over rather
        // than mixing moments across incompatible layouts.
        if state.first_moment.shape() != weights.shape()
            || state.second_moment.shape() != weights.shape()
        {
            state.reset(weights.shape())?;
        }

        // `saturating_add` keeps the counter valid even at u64::MAX.
        state.step = state.step.saturating_add(1);
        // Bias corrections are computed in f64 for precision and clamped to
        // the smallest positive value so the reciprocals below cannot divide
        // by zero.
        let step_f64 = state.step as f64;
        let bias_correction1 =
            (1.0 - (self.beta1 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
        let bias_correction2 =
            (1.0 - (self.beta2 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;

        // Split out raw slices so the passes below can borrow moments and
        // weights simultaneously.
        let first_moment = state.first_moment.data_mut();
        let second_moment = state.second_moment.data_mut();
        let grad_data = grad.data();
        let weights_data = weights.data_mut();

        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let bias_correction1_inv = 1.0 / bias_correction1;
        let bias_correction2_inv = 1.0 / bias_correction2;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;

        // Pass 1: update the moment buffers in place and accumulate the
        // squared norms of the weights and the raw update direction.
        let (w_norm_sq, step_norm_sq) = lamb_pass1_inner(
            weights_data,
            grad_data,
            first_moment,
            second_moment,
            beta1,
            beta2,
            one_minus_beta1,
            one_minus_beta2,
            bias_correction1_inv,
            bias_correction2_inv,
            epsilon,
            weight_decay,
        );

        // Layer-wise trust ratio ||w|| / ||update||; fall back to 1.0 when
        // either norm is zero (e.g. freshly zeroed weights or a zero update).
        let w_norm = w_norm_sq.sqrt();
        let step_norm = step_norm_sq.sqrt();
        let trust_ratio = if w_norm > 0.0 && step_norm > 0.0 {
            w_norm / step_norm
        } else {
            1.0
        };
        let scaled_lr = self.lr * trust_ratio;

        // Pass 2: apply the trust-ratio-scaled update to the weights.
        lamb_pass2_inner(
            weights_data,
            first_moment,
            second_moment,
            bias_correction1_inv,
            bias_correction2_inv,
            scaled_lr,
            epsilon,
            weight_decay,
        );

        Ok(())
    }

    /// Applies one LAMB step to an autograd node's value using its stored
    /// gradient. Nodes that do not require gradients are silently skipped.
    ///
    /// # Errors
    /// Returns `OptimError::MissingGradient` when the node requires a
    /// gradient but none is present, or propagates graph/`step` errors.
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        // Clone the gradient so the shared borrow of `graph` ends before the
        // mutable borrow taken by `value_mut` below.
        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}
199
200impl LearningRate for Lamb {
201 fn learning_rate(&self) -> f32 {
202 Lamb::learning_rate(self)
203 }
204
205 fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
206 Lamb::set_learning_rate(self, lr)
207 }
208}
209
210#[allow(clippy::too_many_arguments, unsafe_code)]
214fn lamb_pass1_inner(
215 weights: &mut [f32],
216 grad: &[f32],
217 first_moment: &mut [f32],
218 second_moment: &mut [f32],
219 beta1: f32,
220 beta2: f32,
221 one_minus_beta1: f32,
222 one_minus_beta2: f32,
223 bc1_inv: f32,
224 bc2_inv: f32,
225 epsilon: f32,
226 weight_decay: f32,
227) -> (f32, f32) {
228 #[cfg(target_arch = "aarch64")]
229 if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
230 return unsafe {
231 lamb_pass1_neon(
232 weights,
233 grad,
234 first_moment,
235 second_moment,
236 beta1,
237 beta2,
238 one_minus_beta1,
239 one_minus_beta2,
240 bc1_inv,
241 bc2_inv,
242 epsilon,
243 weight_decay,
244 )
245 };
246 }
247
248 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
249 if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
250 return unsafe {
251 lamb_pass1_avx(
252 weights,
253 grad,
254 first_moment,
255 second_moment,
256 beta1,
257 beta2,
258 one_minus_beta1,
259 one_minus_beta2,
260 bc1_inv,
261 bc2_inv,
262 epsilon,
263 weight_decay,
264 )
265 };
266 }
267
268 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
269 if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
270 return unsafe {
271 lamb_pass1_sse(
272 weights,
273 grad,
274 first_moment,
275 second_moment,
276 beta1,
277 beta2,
278 one_minus_beta1,
279 one_minus_beta2,
280 bc1_inv,
281 bc2_inv,
282 epsilon,
283 weight_decay,
284 )
285 };
286 }
287
288 let len = weights.len();
289 let wp = weights.as_mut_ptr();
290 let gp = grad.as_ptr();
291 let mp = first_moment.as_mut_ptr();
292 let vp = second_moment.as_mut_ptr();
293 let mut w_norm_sq: f32 = 0.0;
294 let mut step_norm_sq: f32 = 0.0;
295 for i in 0..len {
296 unsafe {
297 let w = *wp.add(i);
298 let g = *gp.add(i);
299 let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
300 let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
301 *mp.add(i) = m;
302 *vp.add(i) = v;
303 let m_hat = m * bc1_inv;
304 let v_hat = v * bc2_inv;
305 let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
306 w_norm_sq += w * w;
307 step_norm_sq += s * s;
308 }
309 }
310 (w_norm_sq, step_norm_sq)
311}
312
313#[allow(clippy::too_many_arguments, unsafe_code)]
315fn lamb_pass2_inner(
316 weights: &mut [f32],
317 first_moment: &[f32],
318 second_moment: &[f32],
319 bc1_inv: f32,
320 bc2_inv: f32,
321 scaled_lr: f32,
322 epsilon: f32,
323 weight_decay: f32,
324) {
325 #[cfg(target_arch = "aarch64")]
326 if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
327 unsafe {
328 lamb_pass2_neon(
329 weights,
330 first_moment,
331 second_moment,
332 bc1_inv,
333 bc2_inv,
334 scaled_lr,
335 epsilon,
336 weight_decay,
337 );
338 }
339 return;
340 }
341
342 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
343 if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
344 unsafe {
345 lamb_pass2_avx(
346 weights,
347 first_moment,
348 second_moment,
349 bc1_inv,
350 bc2_inv,
351 scaled_lr,
352 epsilon,
353 weight_decay,
354 );
355 }
356 return;
357 }
358
359 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
360 if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
361 unsafe {
362 lamb_pass2_sse(
363 weights,
364 first_moment,
365 second_moment,
366 bc1_inv,
367 bc2_inv,
368 scaled_lr,
369 epsilon,
370 weight_decay,
371 );
372 }
373 return;
374 }
375
376 let len = weights.len();
377 let wp = weights.as_mut_ptr();
378 let mp = first_moment.as_ptr();
379 let vp = second_moment.as_ptr();
380 for i in 0..len {
381 unsafe {
382 let m_hat = *mp.add(i) * bc1_inv;
383 let v_hat = *vp.add(i) * bc2_inv;
384 let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
385 *wp.add(i) -= scaled_lr * s;
386 }
387 }
388}
389
/// NEON implementation of the first LAMB pass: updates the moment buffers
/// in place and returns `(||w||^2, ||s||^2)`. Processes 4 lanes per
/// iteration with fused multiply-adds, then finishes the remainder with a
/// scalar tail that uses the identical recurrence.
///
/// # Safety
/// The caller must ensure the `neon` target feature is available at
/// runtime and that all four slices have the same length — the pointer
/// arithmetic below indexes every slice up to `weights.len()`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass1_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    epsilon: f32,
    weight_decay: f32,
) -> (f32, f32) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast all scalar coefficients into 4-lane vectors.
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let eps_v = vdupq_n_f32(epsilon);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut w_norm_acc = vdupq_n_f32(0.0);
    let mut s_norm_acc = vdupq_n_f32(0.0);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let g = vld1q_f32(gp.add(i));
        let m_old = vld1q_f32(mp.add(i));
        let v_old = vld1q_f32(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g    (vfmaq_f32(a, b, c) = a + b * c)
        let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
        let grad_sq = vmulq_f32(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2
        let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
        vst1q_f32(mp.add(i), m_new);
        vst1q_f32(vp.add(i), v_new);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let m_hat = vmulq_f32(m_new, bc1_v);
        let v_hat = vmulq_f32(v_new, bc2_v);
        let s = vfmaq_f32(
            vdivq_f32(m_hat, vaddq_f32(vsqrtq_f32(v_hat), eps_v)),
            wd_v,
            w,
        );
        w_norm_acc = vfmaq_f32(w_norm_acc, w, w);
        s_norm_acc = vfmaq_f32(s_norm_acc, s, s);
        i += 4;
    }
    // Horizontal add of the lane accumulators, then the scalar remainder.
    let mut w_norm_sq = vaddvq_f32(w_norm_acc);
    let mut step_norm_sq = vaddvq_f32(s_norm_acc);
    while i < len {
        let w = *wp.add(i);
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
        w_norm_sq += w * w;
        step_norm_sq += s * s;
        i += 1;
    }
    (w_norm_sq, step_norm_sq)
}
465
/// NEON implementation of the second LAMB pass: recomputes the update
/// direction from the stored moments and subtracts `scaled_lr * s` from the
/// weights, 4 lanes at a time with a scalar tail.
///
/// # Safety
/// The caller must ensure the `neon` target feature is available at
/// runtime and that all three slices have the same length.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass2_neon(
    weights: &mut [f32],
    first_moment: &[f32],
    second_moment: &[f32],
    bc1_inv: f32,
    bc2_inv: f32,
    scaled_lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let mp = first_moment.as_ptr();
    let vp = second_moment.as_ptr();
    // Broadcast scalar coefficients into 4-lane vectors.
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let lr_v = vdupq_n_f32(scaled_lr);
    let eps_v = vdupq_n_f32(epsilon);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let m_hat = vmulq_f32(vld1q_f32(mp.add(i)), bc1_v);
        let v_hat = vmulq_f32(vld1q_f32(vp.add(i)), bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = vfmaq_f32(
            vdivq_f32(m_hat, vaddq_f32(vsqrtq_f32(v_hat), eps_v)),
            wd_v,
            w,
        );
        // w -= scaled_lr * s
        vst1q_f32(wp.add(i), vsubq_f32(w, vmulq_f32(lr_v, s)));
        i += 4;
    }
    // Scalar tail using the identical recurrence.
    while i < len {
        let m_hat = *mp.add(i) * bc1_inv;
        let v_hat = *vp.add(i) * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
        *wp.add(i) -= scaled_lr * s;
        i += 1;
    }
}
510
/// AVX implementation of the first LAMB pass: updates the moment buffers in
/// place and returns `(||w||^2, ||s||^2)`. Processes 8 lanes per iteration,
/// then finishes the remainder with a scalar tail.
///
/// # Safety
/// The caller must ensure the `avx` target feature is available at runtime
/// and that all four slices have the same length — the pointer arithmetic
/// below indexes every slice up to `weights.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass1_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    epsilon: f32,
    weight_decay: f32,
) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast all scalar coefficients into 8-lane vectors.
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let eps_v = _mm256_set1_ps(epsilon);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut w_norm_acc = _mm256_setzero_ps();
    let mut s_norm_acc = _mm256_setzero_ps();
    let mut i = 0usize;
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        let m_old = _mm256_loadu_ps(mp.add(i));
        let v_old = _mm256_loadu_ps(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
        let grad_sq = _mm256_mul_ps(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm256_add_ps(
            _mm256_mul_ps(beta2_v, v_old),
            _mm256_mul_ps(omb2_v, grad_sq),
        );
        _mm256_storeu_ps(mp.add(i), m_new);
        _mm256_storeu_ps(vp.add(i), v_new);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let m_hat = _mm256_mul_ps(m_new, bc1_v);
        let v_hat = _mm256_mul_ps(v_new, bc2_v);
        let s = _mm256_add_ps(
            _mm256_div_ps(m_hat, _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v)),
            _mm256_mul_ps(wd_v, w),
        );
        w_norm_acc = _mm256_add_ps(w_norm_acc, _mm256_mul_ps(w, w));
        s_norm_acc = _mm256_add_ps(s_norm_acc, _mm256_mul_ps(s, s));
        i += 8;
    }
    // Horizontal reduction of the 8-lane accumulator: fold the 256-bit
    // register into 128 bits, then pairwise-add down to a single lane.
    // (`_mm_movehdup_ps` is SSE3, which AVX guarantees.)
    let w_lo = _mm256_castps256_ps128(w_norm_acc);
    let w_hi = _mm256_extractf128_ps(w_norm_acc, 1);
    let w_sum4 = _mm_add_ps(w_lo, w_hi);
    let w_shuf = _mm_movehdup_ps(w_sum4);
    let w_sum2 = _mm_add_ps(w_sum4, w_shuf);
    let w_shuf2 = _mm_movehl_ps(w_sum2, w_sum2);
    let mut w_norm_sq = _mm_cvtss_f32(_mm_add_ss(w_sum2, w_shuf2));

    let s_lo = _mm256_castps256_ps128(s_norm_acc);
    let s_hi = _mm256_extractf128_ps(s_norm_acc, 1);
    let s_sum4 = _mm_add_ps(s_lo, s_hi);
    let s_shuf = _mm_movehdup_ps(s_sum4);
    let s_sum2 = _mm_add_ps(s_sum4, s_shuf);
    let s_shuf2 = _mm_movehl_ps(s_sum2, s_sum2);
    let mut step_norm_sq = _mm_cvtss_f32(_mm_add_ss(s_sum2, s_shuf2));

    // Scalar tail using the identical recurrence.
    while i < len {
        let w = *wp.add(i);
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
        w_norm_sq += w * w;
        step_norm_sq += s * s;
        i += 1;
    }
    (w_norm_sq, step_norm_sq)
}
606
/// AVX implementation of the second LAMB pass: recomputes the update
/// direction from the stored moments and subtracts `scaled_lr * s` from the
/// weights, 8 lanes at a time with a scalar tail.
///
/// # Safety
/// The caller must ensure the `avx` target feature is available at runtime
/// and that all three slices have the same length.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass2_avx(
    weights: &mut [f32],
    first_moment: &[f32],
    second_moment: &[f32],
    bc1_inv: f32,
    bc2_inv: f32,
    scaled_lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let mp = first_moment.as_ptr();
    let vp = second_moment.as_ptr();
    // Broadcast scalar coefficients into 8-lane vectors.
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let lr_v = _mm256_set1_ps(scaled_lr);
    let eps_v = _mm256_set1_ps(epsilon);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut i = 0usize;
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let m_hat = _mm256_mul_ps(_mm256_loadu_ps(mp.add(i)), bc1_v);
        let v_hat = _mm256_mul_ps(_mm256_loadu_ps(vp.add(i)), bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = _mm256_add_ps(
            _mm256_div_ps(m_hat, _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v)),
            _mm256_mul_ps(wd_v, w),
        );
        // w -= scaled_lr * s
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, _mm256_mul_ps(lr_v, s)));
        i += 8;
    }
    // Scalar tail using the identical recurrence.
    while i < len {
        let m_hat = *mp.add(i) * bc1_inv;
        let v_hat = *vp.add(i) * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
        *wp.add(i) -= scaled_lr * s;
        i += 1;
    }
}
653
/// SSE implementation of the first LAMB pass: updates the moment buffers in
/// place and returns `(||w||^2, ||s||^2)`. Processes 4 lanes per iteration,
/// then finishes the remainder with a scalar tail.
///
/// Uses only SSE1 intrinsics: the horizontal reductions are done with
/// `_mm_shuffle_ps` instead of `_mm_movehdup_ps`, which is an SSE3
/// instruction and therefore not guaranteed by the `sse` runtime check in
/// `lamb_pass1_inner` (executing it on a CPU without SSE3 would be UB).
///
/// # Safety
/// The caller must ensure the `sse` target feature is available at runtime
/// and that all four slices have the same length — the pointer arithmetic
/// below indexes every slice up to `weights.len()`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass1_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    epsilon: f32,
    weight_decay: f32,
) -> (f32, f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast all scalar coefficients into 4-lane vectors.
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let eps_v = _mm_set1_ps(epsilon);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut w_norm_acc = _mm_setzero_ps();
    let mut s_norm_acc = _mm_setzero_ps();
    let mut i = 0usize;
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        let m_old = _mm_loadu_ps(mp.add(i));
        let v_old = _mm_loadu_ps(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g
        let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
        let grad_sq = _mm_mul_ps(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2
        let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
        _mm_storeu_ps(mp.add(i), m_new);
        _mm_storeu_ps(vp.add(i), v_new);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let m_hat = _mm_mul_ps(m_new, bc1_v);
        let v_hat = _mm_mul_ps(v_new, bc2_v);
        let s = _mm_add_ps(
            _mm_div_ps(m_hat, _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v)),
            _mm_mul_ps(wd_v, w),
        );
        w_norm_acc = _mm_add_ps(w_norm_acc, _mm_mul_ps(w, w));
        s_norm_acc = _mm_add_ps(s_norm_acc, _mm_mul_ps(s, s));
        i += 4;
    }
    // Horizontal sum of the 4 accumulator lanes with SSE1-only shuffles:
    // mask 0b11_11_01_01 produces [a1, a1, a3, a3], pairing odd lanes with
    // even ones exactly like the SSE3 `_mm_movehdup_ps` would.
    let w_shuf = _mm_shuffle_ps::<0b11_11_01_01>(w_norm_acc, w_norm_acc);
    let w_sum2 = _mm_add_ps(w_norm_acc, w_shuf);
    let w_shuf2 = _mm_movehl_ps(w_sum2, w_sum2);
    let mut w_norm_sq = _mm_cvtss_f32(_mm_add_ss(w_sum2, w_shuf2));

    let s_shuf = _mm_shuffle_ps::<0b11_11_01_01>(s_norm_acc, s_norm_acc);
    let s_sum2 = _mm_add_ps(s_norm_acc, s_shuf);
    let s_shuf2 = _mm_movehl_ps(s_sum2, s_sum2);
    let mut step_norm_sq = _mm_cvtss_f32(_mm_add_ss(s_sum2, s_shuf2));

    // Scalar tail using the identical recurrence.
    while i < len {
        let w = *wp.add(i);
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * w;
        w_norm_sq += w * w;
        step_norm_sq += s * s;
        i += 1;
    }
    (w_norm_sq, step_norm_sq)
}
740
/// SSE implementation of the second LAMB pass: recomputes the update
/// direction from the stored moments and subtracts `scaled_lr * s` from the
/// weights, 4 lanes at a time with a scalar tail. Uses only SSE1
/// intrinsics, matching the `sse` feature it enables.
///
/// # Safety
/// The caller must ensure the `sse` target feature is available at runtime
/// and that all three slices have the same length.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn lamb_pass2_sse(
    weights: &mut [f32],
    first_moment: &[f32],
    second_moment: &[f32],
    bc1_inv: f32,
    bc2_inv: f32,
    scaled_lr: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let mp = first_moment.as_ptr();
    let vp = second_moment.as_ptr();
    // Broadcast scalar coefficients into 4-lane vectors.
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let lr_v = _mm_set1_ps(scaled_lr);
    let eps_v = _mm_set1_ps(epsilon);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let m_hat = _mm_mul_ps(_mm_loadu_ps(mp.add(i)), bc1_v);
        let v_hat = _mm_mul_ps(_mm_loadu_ps(vp.add(i)), bc2_v);
        // s = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
        let s = _mm_add_ps(
            _mm_div_ps(m_hat, _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v)),
            _mm_mul_ps(wd_v, w),
        );
        // w -= scaled_lr * s
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, _mm_mul_ps(lr_v, s)));
        i += 4;
    }
    // Scalar tail using the identical recurrence.
    while i < len {
        let m_hat = *mp.add(i) * bc1_inv;
        let v_hat = *vp.add(i) * bc2_inv;
        let s = m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *wp.add(i);
        *wp.add(i) -= scaled_lr * s;
        i += 1;
    }
}