yscv_optim/adamw.rs

use std::collections::HashMap;
use std::collections::hash_map::Entry;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
use super::{LearningRate, OptimError};

#[derive(Debug, Clone)]
struct AdamWState {
    first_moment: Tensor,
    second_moment: Tensor,
    step: u64,
}

impl AdamWState {
    fn new(shape: &[usize]) -> Result<Self, OptimError> {
        Ok(Self {
            first_moment: Tensor::zeros(shape.to_vec())?,
            second_moment: Tensor::zeros(shape.to_vec())?,
            step: 0,
        })
    }

    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
        *self = Self::new(shape)?;
        Ok(())
    }
}

/// AdamW optimizer with decoupled weight decay.
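///
/// # Example
///
/// A minimal usage sketch of the builder flow. The import path assumes these
/// types are re-exported at the crate root; adjust if the actual re-exports
/// differ.
///
/// ```no_run
/// use yscv_optim::{AdamW, OptimError};
///
/// fn build() -> Result<AdamW, OptimError> {
///     AdamW::new(1e-3)?
///         .with_beta1(0.9)?
///         .with_weight_decay(0.01)
/// }
/// ```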
#[derive(Debug, Clone)]
pub struct AdamW {
    lr: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    state: HashMap<u64, AdamWState>,
}

impl AdamW {
    /// Creates an AdamW optimizer with the given learning rate.
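    ///
    /// Defaults: `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`,
    /// `weight_decay = 0.0`.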
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Sets the beta1 (first-moment decay) factor in `[0, 1)`.
    pub fn with_beta1(mut self, beta1: f32) -> Result<Self, OptimError> {
        validate_beta1(beta1)?;
        self.beta1 = beta1;
        Ok(self)
    }

    /// Sets the beta2 (second-moment decay) factor in `[0, 1)`.
    pub fn with_beta2(mut self, beta2: f32) -> Result<Self, OptimError> {
        validate_beta2(beta2)?;
        self.beta2 = beta2;
        Ok(self)
    }

    /// Sets the epsilon value; must be finite and `> 0`.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Sets the decoupled weight-decay factor in `[0, +inf)`.
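    ///
    /// The decay is decoupled in the AdamW sense: weights are scaled by
    /// `1 - lr * weight_decay` before the Adam step, rather than folding the
    /// decay into the gradient as L2 regularization would.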
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Drops optimizer state (for example when restarting training).
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Overrides the current learning rate.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one update to raw tensor weights.
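    ///
    /// `parameter_id` keys the per-parameter moment state, so callers must
    /// pass a stable id for each weight tensor across steps.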
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(AdamWState::new(weights.shape())?),
        };
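        // A shape change under a reused parameter id invalidates the stored
        // moments, so restart them from zero.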
        if state.first_moment.shape() != weights.shape()
            || state.second_moment.shape() != weights.shape()
        {
            state.reset(weights.shape())?;
        }

        state.step = state.step.saturating_add(1);
        let step_f64 = state.step as f64;
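        // Standard Adam bias corrections, computed in f64 for accuracy:
        //   bc1 = 1 - beta1^t,  bc2 = 1 - beta2^t
        // clamped away from zero so the reciprocals below stay finite.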
        let bias_correction1 =
            (1.0 - (self.beta1 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
        let bias_correction2 =
            (1.0 - (self.beta2 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;

        let first_moment = state.first_moment.data_mut();
        let second_moment = state.second_moment.data_mut();
        let grad_values = grad.data();
        let weights_data = weights.data_mut();

        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let bias_correction1_inv = 1.0 / bias_correction1;
        let bias_correction2_inv = 1.0 / bias_correction2;
        let lr = self.lr;
        let epsilon = self.epsilon;
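        // Decoupled decay folds `w *= 1 - lr * wd` into a single multiply;
        // the kernel applies it only when weight decay is nonzero.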
        let decay_factor = 1.0 - lr * self.weight_decay;
        let has_weight_decay = self.weight_decay != 0.0;

        adamw_update_inner(
            weights_data,
            grad_values,
            first_moment,
            second_moment,
            beta1,
            beta2,
            one_minus_beta1,
            one_minus_beta2,
            bias_correction1_inv,
            bias_correction2_inv,
            lr,
            epsilon,
            decay_factor,
            has_weight_decay,
        );

        Ok(())
    }

    /// Applies one update to a trainable graph node by its `NodeId`.
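    ///
    /// Nodes that do not require gradients are skipped silently; a node that
    /// requires gradients but has no recorded gradient yields
    /// `MissingGradient`.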
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }
}

impl LearningRate for AdamW {
    fn learning_rate(&self) -> f32 {
        AdamW::learning_rate(self)
    }

    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        AdamW::set_learning_rate(self, lr)
    }
}

/// SIMD-accelerated AdamW parameter update with decoupled weight decay.
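///
/// Per element, this is the standard AdamW step:
///
/// ```text
/// m = beta1 * m + (1 - beta1) * g
/// v = beta2 * v + (1 - beta2) * g^2
/// w = w * decay_factor - lr * (m * bc1_inv) / (sqrt(v * bc2_inv) + epsilon)
/// ```
///
/// where `decay_factor = 1 - lr * weight_decay` carries the decoupled decay
/// and `bc1_inv`, `bc2_inv` are the reciprocal bias corrections.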
#[allow(clippy::too_many_arguments, unsafe_code)]
fn adamw_update_inner(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    let len = weights.len();

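    // Dispatch in order of preference: NEON on aarch64, then AVX (8 lanes),
    // then SSE (4 lanes), falling back to the scalar loop at the bottom.
    // Runtime feature detection is skipped under Miri, which cannot execute
    // vendor intrinsics.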
    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        unsafe {
            adamw_update_neon(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                decay_factor,
                has_weight_decay,
            );
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        unsafe {
            adamw_update_avx(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                decay_factor,
                has_weight_decay,
            );
        }
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
        unsafe {
            adamw_update_sse(
                weights,
                grad,
                first_moment,
                second_moment,
                beta1,
                beta2,
                one_minus_beta1,
                one_minus_beta2,
                bc1_inv,
                bc2_inv,
                lr,
                epsilon,
                decay_factor,
                has_weight_decay,
            );
        }
        return;
    }

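    // Scalar fallback. Safe indexed access suffices here: the SIMD paths
    // above cover the hot case, and the slices were shape-checked by `step`.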
    for i in 0..len {
        let g = grad[i];
        let m = beta1 * first_moment[i] + one_minus_beta1 * g;
        let v = beta2 * second_moment[i] + one_minus_beta2 * g * g;
        first_moment[i] = m;
        second_moment[i] = v;
        let m_hat = m * bc1_inv;
        let v_hat = v * bc2_inv;
        let w = if has_weight_decay {
            weights[i] * decay_factor
        } else {
            weights[i]
        };
        weights[i] = w - lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}

// ── NEON implementation ─────────────────────────────────────────────────

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adamw_update_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let lr_v = vdupq_n_f32(lr);
    let eps_v = vdupq_n_f32(epsilon);
    let decay_v = vdupq_n_f32(decay_factor);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let g = vld1q_f32(gp.add(i));
        let m_old = vld1q_f32(mp.add(i));
        let v_old = vld1q_f32(vp.add(i));
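        // vfmaq_f32(a, b, c) computes a + b * c, so this fused form yields
        // beta1 * m_old + (1 - beta1) * g (and likewise for v below).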
        let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
        let grad_sq = vmulq_f32(g, g);
        let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
        vst1q_f32(mp.add(i), m_new);
        vst1q_f32(vp.add(i), v_new);
        let m_hat = vmulq_f32(m_new, bc1_v);
        let v_hat = vmulq_f32(v_new, bc2_v);
        let update = vdivq_f32(vmulq_f32(m_hat, lr_v), vaddq_f32(vsqrtq_f32(v_hat), eps_v));
        let w_decayed = if has_weight_decay {
            vmulq_f32(w, decay_v)
        } else {
            w
        };
        vst1q_f32(wp.add(i), vsubq_f32(w_decayed, update));
        i += 4;
    }
    while i < len {
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let w = *wp.add(i);
        let w = if has_weight_decay {
            w * decay_factor
        } else {
            w
        };
        *wp.add(i) = w - lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}

// ── AVX implementation ──────────────────────────────────────────────────

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adamw_update_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let lr_v = _mm256_set1_ps(lr);
    let eps_v = _mm256_set1_ps(epsilon);
    let decay_v = _mm256_set1_ps(decay_factor);
    let mut i = 0usize;
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        let m_old = _mm256_loadu_ps(mp.add(i));
        let v_old = _mm256_loadu_ps(vp.add(i));
        let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
        let grad_sq = _mm256_mul_ps(g, g);
        let v_new = _mm256_add_ps(
            _mm256_mul_ps(beta2_v, v_old),
            _mm256_mul_ps(omb2_v, grad_sq),
        );
        _mm256_storeu_ps(mp.add(i), m_new);
        _mm256_storeu_ps(vp.add(i), v_new);
        let m_hat = _mm256_mul_ps(m_new, bc1_v);
        let v_hat = _mm256_mul_ps(v_new, bc2_v);
        let update = _mm256_div_ps(
            _mm256_mul_ps(m_hat, lr_v),
            _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v),
        );
        let w_decayed = if has_weight_decay {
            _mm256_mul_ps(w, decay_v)
        } else {
            w
        };
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w_decayed, update));
        i += 8;
    }
    while i < len {
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let w = *wp.add(i);
        let w = if has_weight_decay {
            w * decay_factor
        } else {
            w
        };
        *wp.add(i) = w - lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}

// ── SSE implementation ──────────────────────────────────────────────────

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adamw_update_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let lr_v = _mm_set1_ps(lr);
    let eps_v = _mm_set1_ps(epsilon);
    let decay_v = _mm_set1_ps(decay_factor);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        let m_old = _mm_loadu_ps(mp.add(i));
        let v_old = _mm_loadu_ps(vp.add(i));
        let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
        let grad_sq = _mm_mul_ps(g, g);
        let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
        _mm_storeu_ps(mp.add(i), m_new);
        _mm_storeu_ps(vp.add(i), v_new);
        let m_hat = _mm_mul_ps(m_new, bc1_v);
        let v_hat = _mm_mul_ps(v_new, bc2_v);
        let update = _mm_div_ps(
            _mm_mul_ps(m_hat, lr_v),
            _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v),
        );
        let w_decayed = if has_weight_decay {
            _mm_mul_ps(w, decay_v)
        } else {
            w
        };
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w_decayed, update));
        i += 4;
    }
    while i < len {
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let w = *wp.add(i);
        let w = if has_weight_decay {
            w * decay_factor
        } else {
            w
        };
        *wp.add(i) = w - lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
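
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal smoke-test sketch. It assumes `Tensor::zeros`, `data`, and
    // `data_mut` behave as used in `step` above, and that `OptimError`
    // implements `Debug` (needed for the `Result`-returning test).
    #[test]
    fn step_moves_weights_against_gradient() -> Result<(), OptimError> {
        let mut opt = AdamW::new(1e-3)?.with_weight_decay(0.01)?;
        let mut weights = Tensor::zeros(vec![4])?;
        let mut grad = Tensor::zeros(vec![4])?;
        grad.data_mut().fill(1.0);
        opt.step(0, &mut weights, &grad)?;
        // With zero-initialized weights and a positive gradient, the first
        // AdamW step should push every weight below zero.
        assert!(weights.data().iter().all(|&w| w < 0.0));
        Ok(())
    }
}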