//! RAdam (Rectified Adam) optimizer — `yscv_optim/radam.rs`.
1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter optimizer state: exponential moving averages of the
/// gradient and squared gradient, plus the number of updates applied.
#[derive(Debug, Clone)]
struct RAdamState {
    // EMA of gradients (`m_t` in the RAdam paper).
    first_moment: Tensor,
    // EMA of squared gradients (`v_t` in the RAdam paper).
    second_moment: Tensor,
    // Number of `step` calls applied to this parameter (`t`).
    step: u64,
}
16
17impl RAdamState {
18    fn new(shape: &[usize]) -> Result<Self, OptimError> {
19        Ok(Self {
20            first_moment: Tensor::zeros(shape.to_vec())?,
21            second_moment: Tensor::zeros(shape.to_vec())?,
22            step: 0,
23        })
24    }
25
26    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27        *self = Self::new(shape)?;
28        Ok(())
29    }
30}
31
32/// RAdam (Rectified Adam) optimizer with variance rectification.
/// RAdam (Rectified Adam) optimizer with variance rectification.
///
/// Hyper-parameters are set via [`RAdam::new`] and the `with_*` builders;
/// per-parameter moment state lives in `state`, keyed by caller-supplied
/// parameter ids.
#[derive(Debug, Clone)]
pub struct RAdam {
    // Learning rate (validated by `validate_lr`).
    lr: f32,
    // First-moment decay factor in `[0, 1)`.
    beta1: f32,
    // Second-moment decay factor in `[0, 1)`.
    beta2: f32,
    // Denominator fuzz term; finite and `> 0`.
    epsilon: f32,
    // L2 weight decay factor in `[0, +inf)`.
    weight_decay: f32,
    // Per-parameter moment state keyed by parameter id.
    state: HashMap<u64, RAdamState>,
}
42
impl RAdam {
    /// Creates RAdam with required learning rate.
    ///
    /// Remaining hyper-parameters start at the customary defaults
    /// (`beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, no weight decay)
    /// and can be overridden with the `with_*` builder methods.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Sets beta1 factor in `[0, 1)`.
    pub fn with_beta1(mut self, beta1: f32) -> Result<Self, OptimError> {
        validate_beta1(beta1)?;
        self.beta1 = beta1;
        Ok(self)
    }

    /// Sets beta2 factor in `[0, 1)`.
    pub fn with_beta2(mut self, beta2: f32) -> Result<Self, OptimError> {
        validate_beta2(beta2)?;
        self.beta2 = beta2;
        Ok(self)
    }

    /// Sets epsilon value, must be finite and `> 0`.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Sets L2 weight decay factor in `[0, +inf)`.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        // No validate helper is imported for weight decay, so the
        // range/finiteness check is done inline.
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Drops optimizer state (for example when restarting training).
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Returns current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Overrides current learning rate.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one update to raw tensor weights.
    ///
    /// `parameter_id` keys the per-parameter moment state; passing the same
    /// id for tensors of different shapes resets that state.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::ShapeMismatch` when `weights` and `grad`
    /// disagree in shape, and propagates errors from allocating fresh
    /// state tensors.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        // Fetch or lazily create the moment buffers for this parameter.
        // `entry` is used so the fallible creation (`?`) only runs on a miss.
        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(RAdamState::new(weights.shape())?),
        };
        // If the parameter was re-shaped since the last step, restart the
        // state rather than mixing moments across shapes.
        if state.first_moment.shape() != weights.shape()
            || state.second_moment.shape() != weights.shape()
        {
            state.reset(weights.shape())?;
        }

        // Saturate rather than wrap if the step counter ever maxes out.
        state.step = state.step.saturating_add(1);
        // Bias corrections and rho are computed in f64: beta^t quickly
        // loses precision in f32 as t grows.
        let step_f64 = state.step as f64;
        let beta1_f64 = self.beta1 as f64;
        let beta2_f64 = self.beta2 as f64;

        // Clamped away from zero so the later reciprocal is always finite.
        let bias_correction1 = (1.0 - beta1_f64.powf(step_f64)).max(f64::MIN_POSITIVE) as f32;

        // rho_inf = 2 / (1 - beta2) - 1
        let rho_inf = 2.0 / (1.0 - beta2_f64) - 1.0;
        // rho_t = rho_inf - 2 * t * beta2^t / (1 - beta2^t)
        let beta2_pow_t = beta2_f64.powf(step_f64);
        let rho_t = rho_inf - 2.0 * step_f64 * beta2_pow_t / (1.0 - beta2_pow_t);

        // Variance rectification is only applied once rho_t > 5 (the
        // threshold also used by PyTorch's RAdam); before that the update
        // degenerates to bias-corrected momentum SGD.
        let use_adaptive = rho_t > 5.0;

        let (r_t, bias_correction2) = if use_adaptive {
            let bc2 = (1.0 - beta2_pow_t).max(f64::MIN_POSITIVE) as f32;
            // Rectification term r_t from the RAdam paper.
            let r = ((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
                / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
                .sqrt() as f32;
            (r, bc2)
        } else {
            (1.0_f32, 1.0_f32)
        };

        let first_moment = state.first_moment.data_mut();
        let second_moment = state.second_moment.data_mut();
        let grad_values = grad.data();
        let weights_data = weights.data_mut();

        // Pre-compute loop-invariant scalars for the update kernel.
        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let bias_correction1_inv = 1.0 / bias_correction1;
        let bias_correction2_inv = 1.0 / bias_correction2;
        let lr = self.lr;
        let epsilon = self.epsilon;
        let weight_decay = self.weight_decay;

        radam_update_inner(
            weights_data,
            grad_values,
            first_moment,
            second_moment,
            beta1,
            beta2,
            one_minus_beta1,
            one_minus_beta2,
            bias_correction1_inv,
            bias_correction2_inv,
            lr,
            epsilon,
            weight_decay,
            r_t,
            use_adaptive,
        );

        Ok(())
    }

    /// Applies one update to a trainable graph node by its `NodeId`.
    ///
    /// Nodes that do not require gradients are silently skipped.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::MissingGradient` when the node requires a
    /// gradient but none has been computed.
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        // Clone the gradient so the immutable borrow of `graph` ends
        // before `value_mut` takes a mutable one.
        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}
203
204impl LearningRate for RAdam {
205    fn learning_rate(&self) -> f32 {
206        RAdam::learning_rate(self)
207    }
208
209    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
210        RAdam::set_learning_rate(self, lr)
211    }
212}
213
/// SIMD-accelerated RAdam parameter update.
///
/// Dispatches to a NEON / AVX / SSE kernel when the CPU supports it
/// (skipped under Miri, which cannot execute vendor intrinsics) and
/// otherwise runs a safe scalar loop. All four slices must have equal
/// length; `RAdam::step` guarantees this via its shape check.
#[allow(clippy::too_many_arguments, unsafe_code)]
fn radam_update_inner(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    debug_assert!(
        grad.len() == weights.len()
            && first_moment.len() == weights.len()
            && second_moment.len() == weights.len()
    );

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        // SAFETY: `neon` support was just verified at runtime, and the
        // equal-length contract is upheld by the caller.
        unsafe {
            radam_update_neon(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                weight_decay,
                r_t,
                use_adaptive,
            );
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        // SAFETY: `avx` support was just verified at runtime, and the
        // equal-length contract is upheld by the caller.
        unsafe {
            radam_update_avx(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                weight_decay,
                r_t,
                use_adaptive,
            );
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
        // SAFETY: `sse` support was just verified at runtime, and the
        // equal-length contract is upheld by the caller.
        unsafe {
            radam_update_sse(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                weight_decay,
                r_t,
                use_adaptive,
            );
        }
        return;
    }

    // Safe scalar fallback. The lock-step `zip` lets LLVM elide bounds
    // checks and auto-vectorize, so raw pointers buy nothing here.
    for ((w, &g_raw), (m, v)) in weights
        .iter_mut()
        .zip(grad)
        .zip(first_moment.iter_mut().zip(second_moment.iter_mut()))
    {
        // L2 weight decay folded into the gradient.
        let g = g_raw + weight_decay * *w;
        let m_new = beta1 * *m + one_minus_beta1 * g;
        let v_new = beta2 * *v + one_minus_beta2 * g * g;
        *m = m_new;
        *v = v_new;
        let m_hat = m_new * bc1_inv;
        if use_adaptive {
            // Rectified adaptive update.
            let v_hat = v_new * bc2_inv;
            *w -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            // Warm-up phase: bias-corrected momentum only.
            *w -= lr * m_hat;
        }
    }
}

// ── NEON implementation ─────────────────────────────────────────────────

/// NEON (4-lane) RAdam kernel; same math as the scalar fallback in
/// `radam_update_inner`, with a scalar tail for the leftover elements.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available and that
/// all four slices have equal length (indices up to `weights.len()` are
/// read/written in every slice).
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn radam_update_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let wd_v = vdupq_n_f32(weight_decay);
    let mut i = 0usize;

    // The adaptive/non-adaptive split is hoisted out of the loop so the
    // branch is decided once per call, not once per lane group.
    if use_adaptive {
        let bc2_v = vdupq_n_f32(bc2_inv);
        let lr_rt_v = vdupq_n_f32(lr * r_t);
        let eps_v = vdupq_n_f32(epsilon);
        while i + 4 <= len {
            let w = vld1q_f32(wp.add(i));
            let raw_g = vld1q_f32(gp.add(i));
            // g = raw_g + weight_decay * w (fused multiply-add).
            let g = vfmaq_f32(raw_g, wd_v, w);
            let m_old = vld1q_f32(mp.add(i));
            let v_old = vld1q_f32(vp.add(i));
            let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
            let grad_sq = vmulq_f32(g, g);
            let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
            vst1q_f32(mp.add(i), m_new);
            vst1q_f32(vp.add(i), v_new);
            let m_hat = vmulq_f32(m_new, bc1_v);
            let v_hat = vmulq_f32(v_new, bc2_v);
            let update = vdivq_f32(
                vmulq_f32(lr_rt_v, m_hat),
                vaddq_f32(vsqrtq_f32(v_hat), eps_v),
            );
            vst1q_f32(wp.add(i), vsubq_f32(w, update));
            i += 4;
        }
    } else {
        let lr_v = vdupq_n_f32(lr);
        while i + 4 <= len {
            let w = vld1q_f32(wp.add(i));
            let raw_g = vld1q_f32(gp.add(i));
            let g = vfmaq_f32(raw_g, wd_v, w);
            let m_old = vld1q_f32(mp.add(i));
            let v_old = vld1q_f32(vp.add(i));
            let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
            let grad_sq = vmulq_f32(g, g);
            let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
            vst1q_f32(mp.add(i), m_new);
            vst1q_f32(vp.add(i), v_new);
            let m_hat = vmulq_f32(m_new, bc1_v);
            vst1q_f32(wp.add(i), vsubq_f32(w, vmulq_f32(lr_v, m_hat)));
            i += 4;
        }
    }

    // Scalar tail for the final `len % 4` elements.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        if use_adaptive {
            let v_hat = v * bc2_inv;
            *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            *wp.add(i) -= lr * m_hat;
        }
        i += 1;
    }
}

// ── AVX implementation ──────────────────────────────────────────────────

/// AVX (8-lane) RAdam kernel; same math as the scalar fallback in
/// `radam_update_inner`, with a scalar tail for the leftover elements.
///
/// # Safety
///
/// The caller must ensure the `avx` target feature is available and that
/// all four slices have equal length (indices up to `weights.len()` are
/// read/written in every slice).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn radam_update_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let wd_v = _mm256_set1_ps(weight_decay);
    let mut i = 0usize;

    // Branch on the update mode once per call rather than per iteration.
    if use_adaptive {
        let bc2_v = _mm256_set1_ps(bc2_inv);
        let lr_rt_v = _mm256_set1_ps(lr * r_t);
        let eps_v = _mm256_set1_ps(epsilon);
        while i + 8 <= len {
            let w = _mm256_loadu_ps(wp.add(i));
            let raw_g = _mm256_loadu_ps(gp.add(i));
            // g = raw_g + weight_decay * w.
            let g = _mm256_add_ps(raw_g, _mm256_mul_ps(wd_v, w));
            let m_old = _mm256_loadu_ps(mp.add(i));
            let v_old = _mm256_loadu_ps(vp.add(i));
            let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
            let grad_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_add_ps(
                _mm256_mul_ps(beta2_v, v_old),
                _mm256_mul_ps(omb2_v, grad_sq),
            );
            _mm256_storeu_ps(mp.add(i), m_new);
            _mm256_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm256_mul_ps(m_new, bc1_v);
            let v_hat = _mm256_mul_ps(v_new, bc2_v);
            let update = _mm256_div_ps(
                _mm256_mul_ps(lr_rt_v, m_hat),
                _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v),
            );
            _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, update));
            i += 8;
        }
    } else {
        let lr_v = _mm256_set1_ps(lr);
        while i + 8 <= len {
            let w = _mm256_loadu_ps(wp.add(i));
            let raw_g = _mm256_loadu_ps(gp.add(i));
            let g = _mm256_add_ps(raw_g, _mm256_mul_ps(wd_v, w));
            let m_old = _mm256_loadu_ps(mp.add(i));
            let v_old = _mm256_loadu_ps(vp.add(i));
            let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
            let grad_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_add_ps(
                _mm256_mul_ps(beta2_v, v_old),
                _mm256_mul_ps(omb2_v, grad_sq),
            );
            _mm256_storeu_ps(mp.add(i), m_new);
            _mm256_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm256_mul_ps(m_new, bc1_v);
            _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, _mm256_mul_ps(lr_v, m_hat)));
            i += 8;
        }
    }

    // Scalar tail for the final `len % 8` elements.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        if use_adaptive {
            let v_hat = v * bc2_inv;
            *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            *wp.add(i) -= lr * m_hat;
        }
        i += 1;
    }
}

// ── SSE implementation ──────────────────────────────────────────────────

/// SSE (4-lane) RAdam kernel; same math as the scalar fallback in
/// `radam_update_inner`, with a scalar tail for the leftover elements.
///
/// # Safety
///
/// The caller must ensure the `sse` target feature is available and that
/// all four slices have equal length (indices up to `weights.len()` are
/// read/written in every slice).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn radam_update_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    weight_decay: f32,
    r_t: f32,
    use_adaptive: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let wd_v = _mm_set1_ps(weight_decay);
    let mut i = 0usize;

    // Branch on the update mode once per call rather than per iteration.
    if use_adaptive {
        let bc2_v = _mm_set1_ps(bc2_inv);
        let lr_rt_v = _mm_set1_ps(lr * r_t);
        let eps_v = _mm_set1_ps(epsilon);
        while i + 4 <= len {
            let w = _mm_loadu_ps(wp.add(i));
            let raw_g = _mm_loadu_ps(gp.add(i));
            // g = raw_g + weight_decay * w.
            let g = _mm_add_ps(raw_g, _mm_mul_ps(wd_v, w));
            let m_old = _mm_loadu_ps(mp.add(i));
            let v_old = _mm_loadu_ps(vp.add(i));
            let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
            let grad_sq = _mm_mul_ps(g, g);
            let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
            _mm_storeu_ps(mp.add(i), m_new);
            _mm_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm_mul_ps(m_new, bc1_v);
            let v_hat = _mm_mul_ps(v_new, bc2_v);
            let update = _mm_div_ps(
                _mm_mul_ps(lr_rt_v, m_hat),
                _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v),
            );
            _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, update));
            i += 4;
        }
    } else {
        let lr_v = _mm_set1_ps(lr);
        while i + 4 <= len {
            let w = _mm_loadu_ps(wp.add(i));
            let raw_g = _mm_loadu_ps(gp.add(i));
            let g = _mm_add_ps(raw_g, _mm_mul_ps(wd_v, w));
            let m_old = _mm_loadu_ps(mp.add(i));
            let v_old = _mm_loadu_ps(vp.add(i));
            let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
            let grad_sq = _mm_mul_ps(g, g);
            let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
            _mm_storeu_ps(mp.add(i), m_new);
            _mm_storeu_ps(vp.add(i), v_new);
            let m_hat = _mm_mul_ps(m_new, bc1_v);
            _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, _mm_mul_ps(lr_v, m_hat)));
            i += 4;
        }
    }

    // Scalar tail for the final `len % 4` elements.
    while i < len {
        let g = *gp.add(i) + weight_decay * *wp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let m_hat = m * bc1_inv;
        if use_adaptive {
            let v_hat = v * bc2_inv;
            *wp.add(i) -= lr * r_t * m_hat / (v_hat.sqrt() + epsilon);
        } else {
            *wp.add(i) -= lr * m_hat;
        }
        i += 1;
    }
}