1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_beta1, validate_beta2, validate_epsilon, validate_lr};
8use super::{LearningRate, OptimError};
9
/// Per-parameter optimizer state for AdamW: exponential moving averages of
/// the gradient and the squared gradient, plus the step counter that drives
/// bias correction.
#[derive(Debug, Clone)]
struct AdamWState {
    // First-moment estimate (EMA of gradients); same shape as the parameter.
    first_moment: Tensor,
    // Second-moment estimate (EMA of squared gradients); same shape as the parameter.
    second_moment: Tensor,
    // Number of updates applied so far; used to compute bias corrections.
    step: u64,
}
16
17impl AdamWState {
18 fn new(shape: &[usize]) -> Result<Self, OptimError> {
19 Ok(Self {
20 first_moment: Tensor::zeros(shape.to_vec())?,
21 second_moment: Tensor::zeros(shape.to_vec())?,
22 step: 0,
23 })
24 }
25
26 fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27 *self = Self::new(shape)?;
28 Ok(())
29 }
30}
31
/// AdamW optimizer: Adam with decoupled weight decay applied directly to
/// the weights (multiplied by `1 - lr * weight_decay`) rather than folded
/// into the gradient.
#[derive(Debug, Clone)]
pub struct AdamW {
    // Learning rate (step size).
    lr: f32,
    // Decay rate for the first-moment EMA; defaults to 0.9 in `new`.
    beta1: f32,
    // Decay rate for the second-moment EMA; defaults to 0.999 in `new`.
    beta2: f32,
    // Small denominator term guarding against division by zero; defaults to 1e-8.
    epsilon: f32,
    // Decoupled weight-decay coefficient; 0.0 (disabled) by default.
    weight_decay: f32,
    // Per-parameter moment/step state keyed by caller-supplied parameter id.
    state: HashMap<u64, AdamWState>,
}
42
impl AdamW {
    /// Creates an AdamW optimizer with the given learning rate and the
    /// conventional defaults: `beta1 = 0.9`, `beta2 = 0.999`,
    /// `epsilon = 1e-8`, and weight decay disabled (`0.0`).
    ///
    /// # Errors
    /// Returns an error if `lr` fails validation.
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.0,
            state: HashMap::new(),
        })
    }

    /// Builder-style setter for the first-moment decay rate.
    ///
    /// # Errors
    /// Returns an error if `beta1` fails validation.
    pub fn with_beta1(mut self, beta1: f32) -> Result<Self, OptimError> {
        validate_beta1(beta1)?;
        self.beta1 = beta1;
        Ok(self)
    }

    /// Builder-style setter for the second-moment decay rate.
    ///
    /// # Errors
    /// Returns an error if `beta2` fails validation.
    pub fn with_beta2(mut self, beta2: f32) -> Result<Self, OptimError> {
        validate_beta2(beta2)?;
        self.beta2 = beta2;
        Ok(self)
    }

    /// Builder-style setter for the denominator stabilizer `epsilon`.
    ///
    /// # Errors
    /// Returns an error if `epsilon` fails validation.
    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
        validate_epsilon(epsilon)?;
        self.epsilon = epsilon;
        Ok(self)
    }

    /// Builder-style setter for the decoupled weight-decay coefficient.
    ///
    /// # Errors
    /// Returns `InvalidWeightDecay` if the value is negative, NaN, or infinite.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Discards all per-parameter moment state; the next `step` for each
    /// parameter starts from zeroed moments and step count.
    pub fn clear_state(&mut self) {
        self.state.clear();
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Replaces the learning rate after validating it (e.g. for LR schedules).
    ///
    /// # Errors
    /// Returns an error if `lr` fails validation.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

    /// Applies one AdamW update to `weights` given `grad`.
    ///
    /// Moment state is tracked per `parameter_id`: it is created lazily on
    /// first use and re-initialized if the parameter's shape changes.
    ///
    /// # Errors
    /// Returns `ShapeMismatch` when `weights` and `grad` have different
    /// shapes, or propagates errors from allocating the moment state.
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        // Manual entry match (instead of `or_insert_with`) because state
        // creation is fallible and `?` must propagate the error.
        let state = match self.state.entry(parameter_id) {
            Entry::Occupied(entry) => entry.into_mut(),
            Entry::Vacant(entry) => entry.insert(AdamWState::new(weights.shape())?),
        };
        // A stale entry (parameter re-registered with a new shape) is reset
        // rather than updated with mismatched buffers.
        if state.first_moment.shape() != weights.shape()
            || state.second_moment.shape() != weights.shape()
        {
            state.reset(weights.shape())?;
        }

        // saturating_add: never wraps back to 0, which would break bias correction.
        state.step = state.step.saturating_add(1);
        let step_f64 = state.step as f64;
        // Bias corrections 1 - beta^t, computed in f64 for precision and
        // clamped to MIN_POSITIVE so the reciprocal below can never divide
        // by zero (e.g. when beta is extremely close to 1).
        let bias_correction1 =
            (1.0 - (self.beta1 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;
        let bias_correction2 =
            (1.0 - (self.beta2 as f64).powf(step_f64)).max(f64::MIN_POSITIVE) as f32;

        let first_moment = state.first_moment.data_mut();
        let second_moment = state.second_moment.data_mut();
        let grad_values = grad.data();
        let weights_data = weights.data_mut();

        // Hoist scalars out of self so the inner kernel takes plain values.
        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = 1.0 - beta1;
        let one_minus_beta2 = 1.0 - beta2;
        let bias_correction1_inv = 1.0 / bias_correction1;
        let bias_correction2_inv = 1.0 / bias_correction2;
        let lr = self.lr;
        let epsilon = self.epsilon;
        // Decoupled weight decay: weights are scaled by (1 - lr * wd)
        // before the Adam step, applied only when decay is non-zero.
        let decay_factor = 1.0 - lr * self.weight_decay;
        let has_weight_decay = self.weight_decay != 0.0;

        adamw_update_inner(
            weights_data,
            grad_values,
            first_moment,
            second_moment,
            beta1,
            beta2,
            one_minus_beta1,
            one_minus_beta2,
            bias_correction1_inv,
            bias_correction2_inv,
            lr,
            epsilon,
            decay_factor,
            has_weight_decay,
        );

        Ok(())
    }

    /// Updates a graph node's value in place using its accumulated gradient.
    ///
    /// Nodes that do not require gradients are silently skipped.
    ///
    /// # Errors
    /// Returns `MissingGradient` if the node requires a gradient but none
    /// has been computed, or propagates graph lookup errors.
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        // Clone the gradient to end the immutable borrow of `graph` before
        // taking the mutable borrow for the value below.
        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        // NOTE(review): node.0 is cast to u64 for the state key — assumes
        // node ids are unique and non-negative; confirm NodeId's inner type.
        self.step(node.0 as u64, weights, &grad)
    }
}
185
186impl LearningRate for AdamW {
187 fn learning_rate(&self) -> f32 {
188 AdamW::learning_rate(self)
189 }
190
191 fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
192 AdamW::set_learning_rate(self, lr)
193 }
194}
195
196#[allow(clippy::too_many_arguments, unsafe_code)]
198fn adamw_update_inner(
199 weights: &mut [f32],
200 grad: &[f32],
201 first_moment: &mut [f32],
202 second_moment: &mut [f32],
203 beta1: f32,
204 beta2: f32,
205 one_minus_beta1: f32,
206 one_minus_beta2: f32,
207 bc1_inv: f32,
208 bc2_inv: f32,
209 lr: f32,
210 epsilon: f32,
211 decay_factor: f32,
212 has_weight_decay: bool,
213) {
214 let len = weights.len();
215
216 #[cfg(target_arch = "aarch64")]
217 if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
218 unsafe {
219 adamw_update_neon(
220 weights,
221 grad,
222 first_moment,
223 second_moment,
224 beta1,
225 beta2,
226 one_minus_beta1,
227 one_minus_beta2,
228 bc1_inv,
229 bc2_inv,
230 lr,
231 epsilon,
232 decay_factor,
233 has_weight_decay,
234 );
235 }
236 return;
237 }
238
239 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
240 if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
241 unsafe {
242 adamw_update_avx(
243 weights,
244 grad,
245 first_moment,
246 second_moment,
247 beta1,
248 beta2,
249 one_minus_beta1,
250 one_minus_beta2,
251 bc1_inv,
252 bc2_inv,
253 lr,
254 epsilon,
255 decay_factor,
256 has_weight_decay,
257 );
258 }
259 return;
260 }
261
262 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
263 if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
264 unsafe {
265 adamw_update_sse(
266 weights,
267 grad,
268 first_moment,
269 second_moment,
270 beta1,
271 beta2,
272 one_minus_beta1,
273 one_minus_beta2,
274 bc1_inv,
275 bc2_inv,
276 lr,
277 epsilon,
278 decay_factor,
279 has_weight_decay,
280 );
281 }
282 return;
283 }
284
285 let wp = weights.as_mut_ptr();
286 let gp = grad.as_ptr();
287 let mp = first_moment.as_mut_ptr();
288 let vp = second_moment.as_mut_ptr();
289 for i in 0..len {
290 unsafe {
291 let g = *gp.add(i);
292 let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
293 let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
294 *mp.add(i) = m;
295 *vp.add(i) = v;
296 let m_hat = m * bc1_inv;
297 let v_hat = v * bc2_inv;
298 let w = *wp.add(i);
299 let w = if has_weight_decay {
300 w * decay_factor
301 } else {
302 w
303 };
304 *wp.add(i) = w - lr * m_hat / (v_hat.sqrt() + epsilon);
305 }
306 }
307}
308
/// NEON-vectorized AdamW update: 4 f32 lanes per iteration, with a scalar
/// tail loop for the remaining `len % 4` elements.
///
/// # Safety
///
/// The caller must ensure the `neon` target feature is available at
/// runtime, and that `grad`, `first_moment`, and `second_moment` each have
/// at least `weights.len()` elements — both loops index all four buffers
/// up to `weights.len()` through raw pointers.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adamw_update_neon(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast every scalar coefficient into a 4-lane vector once.
    let beta1_v = vdupq_n_f32(beta1);
    let beta2_v = vdupq_n_f32(beta2);
    let omb1_v = vdupq_n_f32(one_minus_beta1);
    let omb2_v = vdupq_n_f32(one_minus_beta2);
    let bc1_v = vdupq_n_f32(bc1_inv);
    let bc2_v = vdupq_n_f32(bc2_inv);
    let lr_v = vdupq_n_f32(lr);
    let eps_v = vdupq_n_f32(epsilon);
    let decay_v = vdupq_n_f32(decay_factor);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let g = vld1q_f32(gp.add(i));
        let m_old = vld1q_f32(mp.add(i));
        let v_old = vld1q_f32(vp.add(i));
        // vfmaq_f32(a, b, c) computes a + b * c, so this is
        // g * (1 - beta1) + m_old * beta1.
        let m_new = vfmaq_f32(vmulq_f32(g, omb1_v), m_old, beta1_v);
        let grad_sq = vmulq_f32(g, g);
        // g^2 * (1 - beta2) + v_old * beta2.
        let v_new = vfmaq_f32(vmulq_f32(grad_sq, omb2_v), v_old, beta2_v);
        vst1q_f32(mp.add(i), m_new);
        vst1q_f32(vp.add(i), v_new);
        // Bias-corrected moments via precomputed reciprocals.
        let m_hat = vmulq_f32(m_new, bc1_v);
        let v_hat = vmulq_f32(v_new, bc2_v);
        let update = vdivq_f32(vmulq_f32(m_hat, lr_v), vaddq_f32(vsqrtq_f32(v_hat), eps_v));
        // Decoupled weight decay: scale weights before subtracting the step.
        let w_decayed = if has_weight_decay {
            vmulq_f32(w, decay_v)
        } else {
            w
        };
        vst1q_f32(wp.add(i), vsubq_f32(w_decayed, update));
        i += 4;
    }
    // Scalar tail: identical arithmetic for the final len % 4 elements.
    while i < len {
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let w = *wp.add(i);
        let w = if has_weight_decay {
            w * decay_factor
        } else {
            w
        };
        *wp.add(i) = w - lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
383
/// AVX-vectorized AdamW update: 8 f32 lanes per iteration, with a scalar
/// tail loop for the remaining `len % 8` elements.
///
/// # Safety
///
/// The caller must ensure the `avx` target feature is available at
/// runtime, and that `grad`, `first_moment`, and `second_moment` each have
/// at least `weights.len()` elements — both loops index all four buffers
/// up to `weights.len()` through raw pointers.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adamw_update_avx(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast every scalar coefficient into an 8-lane vector once.
    let beta1_v = _mm256_set1_ps(beta1);
    let beta2_v = _mm256_set1_ps(beta2);
    let omb1_v = _mm256_set1_ps(one_minus_beta1);
    let omb2_v = _mm256_set1_ps(one_minus_beta2);
    let bc1_v = _mm256_set1_ps(bc1_inv);
    let bc2_v = _mm256_set1_ps(bc2_inv);
    let lr_v = _mm256_set1_ps(lr);
    let eps_v = _mm256_set1_ps(epsilon);
    let decay_v = _mm256_set1_ps(decay_factor);
    let mut i = 0usize;
    while i + 8 <= len {
        // Unaligned loads: the slices carry no alignment guarantee.
        let w = _mm256_loadu_ps(wp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        let m_old = _mm256_loadu_ps(mp.add(i));
        let v_old = _mm256_loadu_ps(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g (mul+add; AVX lacks FMA).
        let m_new = _mm256_add_ps(_mm256_mul_ps(beta1_v, m_old), _mm256_mul_ps(omb1_v, g));
        let grad_sq = _mm256_mul_ps(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2.
        let v_new = _mm256_add_ps(
            _mm256_mul_ps(beta2_v, v_old),
            _mm256_mul_ps(omb2_v, grad_sq),
        );
        _mm256_storeu_ps(mp.add(i), m_new);
        _mm256_storeu_ps(vp.add(i), v_new);
        // Bias-corrected moments via precomputed reciprocals.
        let m_hat = _mm256_mul_ps(m_new, bc1_v);
        let v_hat = _mm256_mul_ps(v_new, bc2_v);
        let update = _mm256_div_ps(
            _mm256_mul_ps(m_hat, lr_v),
            _mm256_add_ps(_mm256_sqrt_ps(v_hat), eps_v),
        );
        // Decoupled weight decay: scale weights before subtracting the step.
        let w_decayed = if has_weight_decay {
            _mm256_mul_ps(w, decay_v)
        } else {
            w
        };
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w_decayed, update));
        i += 8;
    }
    // Scalar tail: identical arithmetic for the final len % 8 elements.
    while i < len {
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let w = *wp.add(i);
        let w = if has_weight_decay {
            w * decay_factor
        } else {
            w
        };
        *wp.add(i) = w - lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}
467
/// SSE-vectorized AdamW update: 4 f32 lanes per iteration, with a scalar
/// tail loop for the remaining `len % 4` elements. Fallback for x86 CPUs
/// without AVX.
///
/// # Safety
///
/// The caller must ensure the `sse` target feature is available at
/// runtime, and that `grad`, `first_moment`, and `second_moment` each have
/// at least `weights.len()` elements — both loops index all four buffers
/// up to `weights.len()` through raw pointers.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(clippy::too_many_arguments, unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn adamw_update_sse(
    weights: &mut [f32],
    grad: &[f32],
    first_moment: &mut [f32],
    second_moment: &mut [f32],
    beta1: f32,
    beta2: f32,
    one_minus_beta1: f32,
    one_minus_beta2: f32,
    bc1_inv: f32,
    bc2_inv: f32,
    lr: f32,
    epsilon: f32,
    decay_factor: f32,
    has_weight_decay: bool,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grad.as_ptr();
    let mp = first_moment.as_mut_ptr();
    let vp = second_moment.as_mut_ptr();
    // Broadcast every scalar coefficient into a 4-lane vector once.
    let beta1_v = _mm_set1_ps(beta1);
    let beta2_v = _mm_set1_ps(beta2);
    let omb1_v = _mm_set1_ps(one_minus_beta1);
    let omb2_v = _mm_set1_ps(one_minus_beta2);
    let bc1_v = _mm_set1_ps(bc1_inv);
    let bc2_v = _mm_set1_ps(bc2_inv);
    let lr_v = _mm_set1_ps(lr);
    let eps_v = _mm_set1_ps(epsilon);
    let decay_v = _mm_set1_ps(decay_factor);
    let mut i = 0usize;
    while i + 4 <= len {
        // Unaligned loads: the slices carry no alignment guarantee.
        let w = _mm_loadu_ps(wp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        let m_old = _mm_loadu_ps(mp.add(i));
        let v_old = _mm_loadu_ps(vp.add(i));
        // m = beta1 * m_old + (1 - beta1) * g.
        let m_new = _mm_add_ps(_mm_mul_ps(beta1_v, m_old), _mm_mul_ps(omb1_v, g));
        let grad_sq = _mm_mul_ps(g, g);
        // v = beta2 * v_old + (1 - beta2) * g^2.
        let v_new = _mm_add_ps(_mm_mul_ps(beta2_v, v_old), _mm_mul_ps(omb2_v, grad_sq));
        _mm_storeu_ps(mp.add(i), m_new);
        _mm_storeu_ps(vp.add(i), v_new);
        // Bias-corrected moments via precomputed reciprocals.
        let m_hat = _mm_mul_ps(m_new, bc1_v);
        let v_hat = _mm_mul_ps(v_new, bc2_v);
        let update = _mm_div_ps(
            _mm_mul_ps(m_hat, lr_v),
            _mm_add_ps(_mm_sqrt_ps(v_hat), eps_v),
        );
        // Decoupled weight decay: scale weights before subtracting the step.
        let w_decayed = if has_weight_decay {
            _mm_mul_ps(w, decay_v)
        } else {
            w
        };
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w_decayed, update));
        i += 4;
    }
    // Scalar tail: identical arithmetic for the final len % 4 elements.
    while i < len {
        let g = *gp.add(i);
        let m = beta1 * *mp.add(i) + one_minus_beta1 * g;
        let v = beta2 * *vp.add(i) + one_minus_beta2 * g * g;
        *mp.add(i) = m;
        *vp.add(i) = v;
        let w = *wp.add(i);
        let w = if has_weight_decay {
            w * decay_factor
        } else {
            w
        };
        *wp.add(i) = w - lr * (m * bc1_inv) / ((v * bc2_inv).sqrt() + epsilon);
        i += 1;
    }
}