use std::collections::HashMap;
use std::collections::hash_map::Entry;

use yscv_autograd::{Graph, NodeId};
use yscv_tensor::Tensor;

use super::validate::{validate_dampening, validate_lr, validate_momentum};
use super::{LearningRate, OptimError};

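/// Stochastic gradient descent with optional momentum, dampening,
/// weight decay, and Nesterov acceleration.
///
/// Per parameter, one `step` applies (matching the body of `step` below):
///
/// ```text
/// g <- grad + weight_decay * w
/// v <- momentum * v + (1 - dampening) * g
/// w <- w - lr * (g + momentum * v)    // nesterov
/// w <- w - lr * v                     // otherwise
/// ```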
#[derive(Debug, Clone)]
pub struct Sgd {
    lr: f32,
    momentum: f32,
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
    velocity: HashMap<u64, Tensor>,
}

impl Sgd {
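    /// Creates a plain SGD optimizer with the given learning rate and all
    /// optional features disabled.
    ///
    /// A minimal builder sketch; the enclosing crate path is assumed here,
    /// so the example is not compiled:
    ///
    /// ```ignore
    /// let opt = Sgd::new(0.01)?
    ///     .with_momentum(0.9)?
    ///     .with_nesterov(true)?
    ///     .with_weight_decay(1e-4)?;
    /// ```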
    pub fn new(lr: f32) -> Result<Self, OptimError> {
        validate_lr(lr)?;
        Ok(Self {
            lr,
            momentum: 0.0,
            dampening: 0.0,
            weight_decay: 0.0,
            nesterov: false,
            velocity: HashMap::new(),
        })
    }

    /// Sets the momentum factor and re-checks the Nesterov constraint.
    pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
        validate_momentum(momentum)?;
        self.momentum = momentum;
        self.validate_nesterov_constraints()?;
        Ok(self)
    }

    /// Sets the dampening factor applied to incoming gradients.
    pub fn with_dampening(mut self, dampening: f32) -> Result<Self, OptimError> {
        validate_dampening(dampening)?;
        self.dampening = dampening;
        Ok(self)
    }

    /// Sets the L2 weight-decay coefficient, which must be finite and
    /// non-negative.
    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
        if !weight_decay.is_finite() || weight_decay < 0.0 {
            return Err(OptimError::InvalidWeightDecay { weight_decay });
        }
        self.weight_decay = weight_decay;
        Ok(self)
    }

    /// Enables or disables Nesterov momentum, which requires a non-zero
    /// momentum factor.
    pub fn with_nesterov(mut self, nesterov: bool) -> Result<Self, OptimError> {
        self.nesterov = nesterov;
        self.validate_nesterov_constraints()?;
        Ok(self)
    }

    /// Drops all accumulated velocity state.
    pub fn clear_state(&mut self) {
        self.velocity.clear();
    }

    /// Returns the current learning rate.
    pub fn learning_rate(&self) -> f32 {
        self.lr
    }

    /// Replaces the learning rate after validating it.
    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        validate_lr(lr)?;
        self.lr = lr;
        Ok(())
    }

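    /// Applies a single SGD update to `weights` in place. Velocity state is
    /// keyed by `parameter_id` and created lazily on first use; it is reset
    /// if the parameter's shape changes between calls.
    ///
    /// A hedged sketch of one step, assuming a `Tensor::from_vec`
    /// constructor (hypothetical; only `Tensor::zeros` is used in this
    /// file, so the real `yscv_tensor` API may differ):
    ///
    /// ```ignore
    /// let mut w = Tensor::from_vec(vec![2], vec![1.0, 2.0])?;
    /// let g = Tensor::from_vec(vec![2], vec![0.1, 0.1])?;
    /// let mut opt = Sgd::new(0.1)?;
    /// opt.step(0, &mut w, &g)?; // w is now [0.99, 1.99]
    /// ```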
    pub fn step(
        &mut self,
        parameter_id: u64,
        weights: &mut Tensor,
        grad: &Tensor,
    ) -> Result<(), OptimError> {
        if weights.shape() != grad.shape() {
            return Err(OptimError::ShapeMismatch {
                weights: weights.shape().to_vec(),
                grad: grad.shape().to_vec(),
            });
        }

        // Fast path: vanilla SGD needs no scratch buffer and no state.
        if self.weight_decay == 0.0 && self.momentum == 0.0 {
            axpy_neg(weights.data_mut(), grad.data(), self.lr);
            return Ok(());
        }

        let has_wd = self.weight_decay != 0.0;
        // Fold weight decay into the gradient, g <- g + weight_decay * w,
        // copying the gradient only when weight decay is active.
        let adjusted_grad_buf: Vec<f32>;
        let grad_slice: &[f32] = if has_wd {
            let mut buf = grad.data().to_vec();
            let wd = self.weight_decay;
            fma_inplace(&mut buf, weights.data(), wd);
            adjusted_grad_buf = buf;
            &adjusted_grad_buf
        } else {
            grad.data()
        };

        if self.momentum != 0.0 {
            let velocity = match self.velocity.entry(parameter_id) {
                Entry::Occupied(entry) => entry.into_mut(),
                Entry::Vacant(entry) => {
                    let initial = Tensor::zeros(weights.shape().to_vec())?;
                    entry.insert(initial)
                }
            };

            if velocity.shape() != weights.shape() {
                *velocity = Tensor::zeros(weights.shape().to_vec())?;
            }

            let mom = self.momentum;
            // v <- momentum * v + (1 - dampening) * g
            let grad_scale = 1.0 - self.dampening;
            momentum_update(velocity.data_mut(), grad_slice, mom, grad_scale);

            if self.nesterov {
                // Nesterov: w <- w - lr * (g + momentum * v), applied as two
                // axpy passes over the weights.
                axpy_neg(weights.data_mut(), grad_slice, self.lr);
                axpy_neg(weights.data_mut(), velocity.data(), self.lr * mom);
            } else {
                axpy_neg(weights.data_mut(), velocity.data(), self.lr);
            }
        } else {
            axpy_neg(weights.data_mut(), grad_slice, self.lr);
        }
        Ok(())
    }

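    /// Updates the value stored at `node` using its accumulated gradient.
    /// Nodes that do not require gradients are skipped silently; a node that
    /// requires gradients but has none yields `MissingGradient`. The node id
    /// doubles as the velocity key passed to `step`.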
    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
        if !graph.requires_grad(node)? {
            return Ok(());
        }

        let grad = match graph.grad(node)? {
            Some(grad) => grad.clone(),
            None => return Err(OptimError::MissingGradient { node: node.0 }),
        };
        let weights = graph.value_mut(node)?;
        self.step(node.0 as u64, weights, &grad)
    }

    fn validate_nesterov_constraints(&self) -> Result<(), OptimError> {
        if self.nesterov && self.momentum == 0.0 {
            return Err(OptimError::NesterovRequiresMomentum);
        }
        Ok(())
    }
}

impl LearningRate for Sgd {
    fn learning_rate(&self) -> f32 {
        Sgd::learning_rate(self)
    }

    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
        Sgd::set_learning_rate(self, lr)
    }
}

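/// Computes `weights[i] -= lr * grads[i]` for every element.
///
/// Dispatches at runtime to NEON, AVX, or SSE when the CPU supports it,
/// falling back to an unrolled scalar loop. The `cfg!(miri)` guard keeps
/// Miri on the scalar path, since Miri does not model the SIMD intrinsics.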
#[allow(unsafe_code)]
fn axpy_neg(weights: &mut [f32], grads: &[f32], lr: f32) {
    debug_assert_eq!(weights.len(), grads.len());
    let len = weights.len();

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        unsafe { axpy_neg_neon(weights, grads, lr) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        unsafe { axpy_neg_avx(weights, grads, lr) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
        unsafe { axpy_neg_sse(weights, grads, lr) };
        return;
    }

    // Scalar fallback, manually unrolled by four.
    let w_ptr = weights.as_mut_ptr();
    let g_ptr = grads.as_ptr();
    unsafe {
        let mut i = 0usize;
        while i + 4 <= len {
            *w_ptr.add(i) -= lr * *g_ptr.add(i);
            *w_ptr.add(i + 1) -= lr * *g_ptr.add(i + 1);
            *w_ptr.add(i + 2) -= lr * *g_ptr.add(i + 2);
            *w_ptr.add(i + 3) -= lr * *g_ptr.add(i + 3);
            i += 4;
        }
        while i < len {
            *w_ptr.add(i) -= lr * *g_ptr.add(i);
            i += 1;
        }
    }
}

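/// Computes `dst[i] += src[i] * scale` for every element; used above to
/// fold weight decay into a gradient copy. Dispatch mirrors `axpy_neg`.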
#[allow(unsafe_code)]
fn fma_inplace(dst: &mut [f32], src: &[f32], scale: f32) {
    debug_assert_eq!(dst.len(), src.len());
    let len = dst.len();

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        unsafe { fma_inplace_neon(dst, src, scale) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        unsafe { fma_inplace_avx(dst, src, scale) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
        unsafe { fma_inplace_sse(dst, src, scale) };
        return;
    }

    for i in 0..len {
        dst[i] += src[i] * scale;
    }
}

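/// Computes `velocity[i] = momentum * velocity[i] + grad_scale * grad[i]`,
/// where `grad_scale` is `1.0 - dampening`. Dispatch mirrors `axpy_neg`.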
#[allow(unsafe_code)]
fn momentum_update(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    debug_assert_eq!(velocity.len(), grad.len());
    let len = velocity.len();

    #[cfg(target_arch = "aarch64")]
    if !cfg!(miri) && std::arch::is_aarch64_feature_detected!("neon") {
        unsafe { momentum_update_neon(velocity, grad, momentum, grad_scale) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("avx") {
        unsafe { momentum_update_avx(velocity, grad, momentum, grad_scale) };
        return;
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    if !cfg!(miri) && std::is_x86_feature_detected!("sse") {
        unsafe { momentum_update_sse(velocity, grad, momentum, grad_scale) };
        return;
    }

    for i in 0..len {
        velocity[i] = momentum * velocity[i] + grad_scale * grad[i];
    }
}

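// NEON kernels (aarch64), four lanes per iteration. `vfmsq_f32(a, b, c)`
// computes `a - b * c` and `vfmaq_f32(a, b, c)` computes `a + b * c`,
// fusing the multiply with the accumulate.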
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn axpy_neg_neon(weights: &mut [f32], grads: &[f32], lr: f32) {
    use std::arch::aarch64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grads.as_ptr();
    let vlr = vdupq_n_f32(lr);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = vld1q_f32(wp.add(i));
        let g = vld1q_f32(gp.add(i));
        vst1q_f32(wp.add(i), vfmsq_f32(w, g, vlr));
        i += 4;
    }
    while i < len {
        *wp.add(i) -= lr * *gp.add(i);
        i += 1;
    }
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn fma_inplace_neon(dst: &mut [f32], src: &[f32], scale: f32) {
    use std::arch::aarch64::*;
    let len = dst.len();
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    let vs = vdupq_n_f32(scale);
    let mut i = 0usize;
    while i + 4 <= len {
        let d = vld1q_f32(dp.add(i));
        let s = vld1q_f32(sp.add(i));
        vst1q_f32(dp.add(i), vfmaq_f32(d, s, vs));
        i += 4;
    }
    while i < len {
        *dp.add(i) += *sp.add(i) * scale;
        i += 1;
    }
}

#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn momentum_update_neon(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    use std::arch::aarch64::*;
    let len = velocity.len();
    let vp = velocity.as_mut_ptr();
    let gp = grad.as_ptr();
    let vmom = vdupq_n_f32(momentum);
    let vgs = vdupq_n_f32(grad_scale);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = vld1q_f32(vp.add(i));
        let g = vld1q_f32(gp.add(i));
        // v = momentum * v + grad_scale * g, fused in a single vfmaq.
        let result = vfmaq_f32(vmulq_f32(vmom, v), g, vgs);
        vst1q_f32(vp.add(i), result);
        i += 4;
    }
    while i < len {
        *vp.add(i) = momentum * *vp.add(i) + grad_scale * *gp.add(i);
        i += 1;
    }
}

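// AVX kernels (x86/x86_64), eight lanes per iteration. Only the "avx"
// feature is enabled, so these use separate multiply and add rather than
// the fused FMA instructions, which would require the "fma" feature.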
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn axpy_neg_avx(weights: &mut [f32], grads: &[f32], lr: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grads.as_ptr();
    let vlr = _mm256_set1_ps(lr);
    let mut i = 0usize;
    while i + 8 <= len {
        let w = _mm256_loadu_ps(wp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        _mm256_storeu_ps(wp.add(i), _mm256_sub_ps(w, _mm256_mul_ps(g, vlr)));
        i += 8;
    }
    while i < len {
        *wp.add(i) -= lr * *gp.add(i);
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn fma_inplace_avx(dst: &mut [f32], src: &[f32], scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = dst.len();
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    let vs = _mm256_set1_ps(scale);
    let mut i = 0usize;
    while i + 8 <= len {
        let d = _mm256_loadu_ps(dp.add(i));
        let s = _mm256_loadu_ps(sp.add(i));
        _mm256_storeu_ps(dp.add(i), _mm256_add_ps(d, _mm256_mul_ps(s, vs)));
        i += 8;
    }
    while i < len {
        *dp.add(i) += *sp.add(i) * scale;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn momentum_update_avx(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = velocity.len();
    let vp = velocity.as_mut_ptr();
    let gp = grad.as_ptr();
    let vmom = _mm256_set1_ps(momentum);
    let vgs = _mm256_set1_ps(grad_scale);
    let mut i = 0usize;
    while i + 8 <= len {
        let v = _mm256_loadu_ps(vp.add(i));
        let g = _mm256_loadu_ps(gp.add(i));
        let result = _mm256_add_ps(_mm256_mul_ps(vmom, v), _mm256_mul_ps(g, vgs));
        _mm256_storeu_ps(vp.add(i), result);
        i += 8;
    }
    while i < len {
        *vp.add(i) = momentum * *vp.add(i) + grad_scale * *gp.add(i);
        i += 1;
    }
}

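// SSE kernels (x86/x86_64), four lanes per iteration; the dispatch tier
// between AVX and the scalar loop.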
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn axpy_neg_sse(weights: &mut [f32], grads: &[f32], lr: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = weights.len();
    let wp = weights.as_mut_ptr();
    let gp = grads.as_ptr();
    let vlr = _mm_set1_ps(lr);
    let mut i = 0usize;
    while i + 4 <= len {
        let w = _mm_loadu_ps(wp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        _mm_storeu_ps(wp.add(i), _mm_sub_ps(w, _mm_mul_ps(g, vlr)));
        i += 4;
    }
    while i < len {
        *wp.add(i) -= lr * *gp.add(i);
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn fma_inplace_sse(dst: &mut [f32], src: &[f32], scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = dst.len();
    let dp = dst.as_mut_ptr();
    let sp = src.as_ptr();
    let vs = _mm_set1_ps(scale);
    let mut i = 0usize;
    while i + 4 <= len {
        let d = _mm_loadu_ps(dp.add(i));
        let s = _mm_loadu_ps(sp.add(i));
        _mm_storeu_ps(dp.add(i), _mm_add_ps(d, _mm_mul_ps(s, vs)));
        i += 4;
    }
    while i < len {
        *dp.add(i) += *sp.add(i) * scale;
        i += 1;
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[allow(unsafe_code, unsafe_op_in_unsafe_fn)]
unsafe fn momentum_update_sse(velocity: &mut [f32], grad: &[f32], momentum: f32, grad_scale: f32) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;
    let len = velocity.len();
    let vp = velocity.as_mut_ptr();
    let gp = grad.as_ptr();
    let vmom = _mm_set1_ps(momentum);
    let vgs = _mm_set1_ps(grad_scale);
    let mut i = 0usize;
    while i + 4 <= len {
        let v = _mm_loadu_ps(vp.add(i));
        let g = _mm_loadu_ps(gp.add(i));
        let result = _mm_add_ps(_mm_mul_ps(vmom, v), _mm_mul_ps(g, vgs));
        _mm_storeu_ps(vp.add(i), result);
        i += 4;
    }
    while i < len {
        *vp.add(i) = momentum * *vp.add(i) + grad_scale * *gp.add(i);
        i += 1;
    }
}