use std::collections::HashMap;
use std::collections::hash_map::Entry;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
use super::{LearningRate, OptimError};

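/// Per-parameter RAdam state: first and second moment estimates plus the step count.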
#[derive(Debug, Clone)]
struct RAdamState {
    first_moment: Tensor,
    second_moment: Tensor,
    step: u64,
}

impl RAdamState {
    fn new(shape: &[usize]) -> Result<Self, OptimError> {
        Ok(Self {
            first_moment: Tensor::zeros(shape.to_vec())?,
            second_moment: Tensor::zeros(shape.to_vec())?,
            step: 0,
        })
    }

    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
        *self = Self::new(shape)?;
        Ok(())
    }
}

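/// Rectified Adam (RAdam), after Liu et al., "On the Variance of the Adaptive
/// Learning Rate and Beyond" (2019). While the variance of the adaptive
/// learning rate is still intractable (early steps), the update degrades to
/// bias-corrected momentum; afterwards it applies the variance-rectified
/// adaptive step.
///
/// A minimal usage sketch (error handling elided, shape chosen arbitrarily):
///
/// ```ignore
/// let mut opt = RAdam::new(1e-3)?.with_weight_decay(1e-2)?;
/// let mut weights = Tensor::zeros(vec![4])?;
/// let grad = Tensor::zeros(vec![4])?;
/// opt.step(0, &mut weights, &grad)?;
/// ```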
#[derive(Debug, Clone)]
pub struct RAdam {
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    state: HashMap<u64, RAdamState>,
}

impl RAdam {
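    /// Creates an RAdam optimizer with the given learning rate and Adam-style
    /// defaults: β₁ = 0.9, β₂ = 0.999, ε = 1e-8, no weight decay.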
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

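    /// Sets β₁, the decay rate of the first-moment (mean) estimate.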
    pub fn with_beta1(mut self, beta1: f32) -> Result<Self, OptimError> {
        validate_beta1(beta1)?;
        self.beta1 = beta1;
        Ok(self)
    }

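    /// Sets β₂, the decay rate of the second-moment (uncentered variance) estimate.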
    pub fn with_beta2(mut self, beta2: f32) -> Result<Self, OptimError> {
        validate_beta2(beta2)?;
        self.beta2 = beta2;
        Ok(self)
    }

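    /// Sets ε, the denominator term that guards against division by zero.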
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

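    /// Sets the L2 weight-decay coefficient. This is classic (coupled) weight
    /// decay added to the gradient before the moment updates, not the
    /// decoupled AdamW variant.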
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

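    /// Drops all per-parameter state; moments and step counts restart from zero.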
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

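    /// Returns the current learning rate.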
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

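    /// Replaces the learning rate after validating it.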
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

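    /// Applies one RAdam update to `weights` in place, keyed by `parameter_id`.
    ///
    /// Moments follow Adam: m ← β₁·m + (1−β₁)·g and v ← β₂·v + (1−β₂)·g²,
    /// where g includes the L2 weight-decay term. When the rectification term
    /// is tractable the step is lr·r_t·m̂ / (√v̂ + ε); otherwise it is the
    /// un-adapted lr·m̂. Stored state whose shape no longer matches `weights`
    /// is reset.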
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(RAdamState::new(weights.shape())?),
        };
        if state.first_moment.shape() != weights.shape()
            || state.second_moment.shape() != weights.shape()
        {
            state.reset(weights.shape())?;
        }

        state.step = state.step.saturating_add(1);
        let step_f64 = state.step as f64;
        let beta1_f64 = self.beta1 as f64;
        let beta2_f64 = self.beta2 as f64;

        let bias_correction1 = (1.0 - beta1_f64.powf(step_f64)).max(f64::MIN_POSITIVE) as f32;

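        // ρ_∞ = 2/(1 − β₂) − 1 is the maximum length of the approximated
        // simple moving average (SMA); ρ_t is its length at the current step.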
        let rho_inf = 2.0 / (1.0 - beta2_f64) - 1.0;
        let beta2_pow_t = beta2_f64.powf(step_f64);
        let rho_t = rho_inf - 2.0 * step_f64 * beta2_pow_t / (1.0 - beta2_pow_t);

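        // The rectification term is only defined for ρ_t > 4; like the
        // reference implementation, require ρ_t > 5 before trusting the
        // adaptive step.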
        let use_adaptive = rho_t > 5.0;

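        // Variance rectification term:
        // r_t = √( ((ρ_t − 4)(ρ_t − 2)·ρ_∞) / ((ρ_∞ − 4)(ρ_∞ − 2)·ρ_t) ).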
        let (r_t, bias_correction2) = if use_adaptive {
            let bc2 = (1.0 - beta2_pow_t).max(f64::MIN_POSITIVE) as f32;
            let r = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
                / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
                .sqrt() as f32;
            (r, bc2)
        } else {
            (1.0_f32, 1.0_f32)
        };

        let first_moment = state.first_moment.data_mut();
        let second_moment = state.second_moment.data_mut();
        let grad_values = grad.data();
        let weights_data = weights.data_mut();

        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let bias_correction1_inv = 1.0 / bias_correction1;
        let bias_correction2_inv = 1.0 / bias_correction2;
        let lr = self.lr;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;

        radam_update_inner(
            weights_data,
            grad_values,
            first_moment,
            second_moment,
            beta1,
            beta2,
            one_minus_beta1,
            one_minus_beta2,
            bias_correction1_inv,
            bias_correction2_inv,
            lr,
            epsilon,
            weight_decay,
            r_t,
            use_adaptive,
        );

        Ok(())
    }

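    /// Reads the gradient for `node` from the graph and steps the node's
    /// value tensor, using the node id as the optimizer-state key.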
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}

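// Delegate the `LearningRate` trait to the inherent methods above.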
impl LearningRate for RAdam {
    fn learning_rate(&self) -> f32 {
        RAdam::learning_rate(self)
    }

    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        RAdam::set_learning_rate(self, lr)
    }
}

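// Runtime dispatch: prefer NEON on aarch64, then AVX and SSE on x86, falling
// back to the portable scalar loop. Miri cannot execute vendor intrinsics, so
// it always takes the scalar path.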
#[allow(clippy::too_many_arguments, unsafe_code)]
fn radam_update_inner(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        unsafe {
            radam_update_neon(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                weight_decay,
                r_t,
                use_adaptive,
            );
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        unsafe {
            radam_update_avx(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                weight_decay,
                r_t,
                use_adaptive,
            );
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
        unsafe {
            radam_update_sse(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                weight_decay,
                r_t,
                use_adaptive,
            );
        }
        return;
    }

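    // Portable scalar fallback (and the Miri path); identical in structure to
    // the per-element tails of the SIMD kernels.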
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    for i in 0..len {
        unsafe {
            let g = *gp.add(i) + weight_decay * *wp.add(i);
            let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
            let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
            *mp.add(i) = m;
            *vp.add(i) = v;
            let m_hat = m * bc1_inv;
            if use_adaptive {
                let v_hat = v * bc2_inv;
                *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
            } else {
                *wp.add(i) -= lr * m_hat;
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
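/// NEON kernel: four lanes per iteration, plus a scalar tail.
///
/// # Safety
///
/// The caller must ensure the CPU supports NEON; the dispatcher checks this
/// via `is_aarch64_feature_detected!("neon")`.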
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn radam_update_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut i = 0usize;

    if use_adaptive {
        let bc2_v = vdupq_n_f32(bc2_inv);
        let lr_rt_v = vdupq_n_f32(lr * r_t);
        let eps_v = vdupq_n_f32(epsilon);
        while i + 4 <= len {
            let w = vld1q_f32(wp.add(i));
            let raw_g = vld1q_f32(gp.add(i));
            let g = vfmaq_f32(raw_g, wd_v, w);
            let m_old = vld1q_f32(mp.add(i));
            let v_old = vld1q_f32(vp.add(i));
            let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
            let grad_sq = vmulq_f32(g, g);
            let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
            vst1q_f32(mp.add(i), m_new);
            vst1q_f32(vp.add(i), v_new);
            let m_hat = vmulq_f32(m_new, bc1_v);
            let v_hat = vmulq_f32(v_new, bc2_v);
            let update = vdivq_f32(
                vmulq_f32(lr_rt_v, m_hat),
                vaddq_f32(vsqrtq_f32(v_hat), eps_v),
            );
            vst1q_f32(wp.add(i), vsubq_f32(w, update));
            i += 4;
        }
    } else {
        let lr_v = vdupq_n_f32(lr);
        while i + 4 <= len {
            let w = vld1q_f32(wp.add(i));
            let raw_g = vld1q_f32(gp.add(i));
            let g = vfmaq_f32(raw_g, wd_v, w);
            let m_old = vld1q_f32(mp.add(i));
            let v_old = vld1q_f32(vp.add(i));
            let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
            let grad_sq = vmulq_f32(g, g);
            let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
            vst1q_f32(mp.add(i), m_new);
            vst1q_f32(vp.add(i), v_new);
            let m_hat = vmulq_f32(m_new, bc1_v);
            vst1q_f32(wp.add(i), vsubq_f32(w, vmulq_f32(lr_v, m_hat)));
            i += 4;
        }
    }

    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        if use_adaptive {
            let v_hat = v * bc2_inv;
            *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            *wp.add(i) -= lr * m_hat;
        }
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
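/// AVX kernel: eight lanes per iteration, plus a scalar tail.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX; the dispatcher checks this
/// via `is_x86_feature_detected!("avx")`.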
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn radam_update_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut i = 0usize;

    if use_adaptive {
        let bc2_v = _mm256_set1_ps(bc2_inv);
        let lr_rt_v = _mm256_set1_ps(lr * r_t);
        let eps_v = _mm256_set1_ps(epsilon);
        while i + 8 <= len {
            let w = _mm256_loadu_ps(wp.add(i));
            let raw_g = _mm256_loadu_ps(gp.add(i));
            let g = _mm256_add_ps(raw_g, _mm256_mul_ps(wd_v, w));
            let m_old = _mm256_loadu_ps(mp.add(i));
            let v_old = _mm256_loadu_ps(vp.add(i));
            let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
            let grad_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_add_ps(
                _mm256_mul_ps(beta2_v, v_old),
                _mm256_mul_ps(omb2_v, grad_sq),
            );
            _mm256_storeu_ps(mp.add(i), m_new);
            _mm256_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm256_mul_ps(m_new, bc1_v);
            let v_hat = _mm256_mul_ps(v_new, bc2_v);
            let update = _mm256_div_ps(
                _mm256_mul_ps(lr_rt_v, m_hat),
                _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v),
            );
            _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, update));
            i += 8;
        }
    } else {
        let lr_v = _mm256_set1_ps(lr);
        while i + 8 <= len {
            let w = _mm256_loadu_ps(wp.add(i));
            let raw_g = _mm256_loadu_ps(gp.add(i));
            let g = _mm256_add_ps(raw_g, _mm256_mul_ps(wd_v, w));
            let m_old = _mm256_loadu_ps(mp.add(i));
            let v_old = _mm256_loadu_ps(vp.add(i));
            let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
            let grad_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_add_ps(
                _mm256_mul_ps(beta2_v, v_old),
                _mm256_mul_ps(omb2_v, grad_sq),
            );
            _mm256_storeu_ps(mp.add(i), m_new);
            _mm256_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm256_mul_ps(m_new, bc1_v);
            _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, _mm256_mul_ps(lr_v, m_hat)));
            i += 8;
        }
    }

    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        if use_adaptive {
            let v_hat = v * bc2_inv;
            *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            *wp.add(i) -= lr * m_hat;
        }
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
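/// SSE kernel: four lanes per iteration, plus a scalar tail.
///
/// # Safety
///
/// The caller must ensure the CPU supports SSE; the dispatcher checks this
/// via `is_x86_feature_detected!("sse")`.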
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn radam_update_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut i = 0usize;

    if use_adaptive {
        let bc2_v = _mm_set1_ps(bc2_inv);
        let lr_rt_v = _mm_set1_ps(lr * r_t);
        let eps_v = _mm_set1_ps(epsilon);
        while i + 4 <= len {
            let w = _mm_loadu_ps(wp.add(i));
            let raw_g = _mm_loadu_ps(gp.add(i));
            let g = _mm_add_ps(raw_g, _mm_mul_ps(wd_v, w));
            let m_old = _mm_loadu_ps(mp.add(i));
            let v_old = _mm_loadu_ps(vp.add(i));
            let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
            let grad_sq = _mm_mul_ps(g, g);
            let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
            _mm_storeu_ps(mp.add(i), m_new);
            _mm_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm_mul_ps(m_new, bc1_v);
            let v_hat = _mm_mul_ps(v_new, bc2_v);
            let update = _mm_div_ps(
                _mm_mul_ps(lr_rt_v, m_hat),
                _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v),
            );
            _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, update));
            i += 4;
        }
    } else {
        let lr_v = _mm_set1_ps(lr);
        while i + 4 <= len {
            let w = _mm_loadu_ps(wp.add(i));
            let raw_g = _mm_loadu_ps(gp.add(i));
            let g = _mm_add_ps(raw_g, _mm_mul_ps(wd_v, w));
            let m_old = _mm_loadu_ps(mp.add(i));
            let v_old = _mm_loadu_ps(vp.add(i));
            let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
            let grad_sq = _mm_mul_ps(g, g);
            let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
            _mm_storeu_ps(mp.add(i), m_new);
            _mm_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm_mul_ps(m_new, bc1_v);
            _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, _mm_mul_ps(lr_v, m_hat)));
            i += 4;
        }
    }

    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        if use_adaptive {
            let v_hat = v * bc2_inv;
            *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            *wp.add(i) -= lr * m_hat;
        }
        i += 1;
    }
}