//! Unified optimizer API: a shared `OptimizerConfig`, a `Parameter` wrapper,
//! the `UnifiedOptimizer` trait, SGD and Adam implementations, a factory, and
//! a simple `TrainingLoop` helper.

use crate::error::{OptimError, Result};
use crate::schedulers::LearningRateScheduler;
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;

/// Shared configuration used by all unified optimizers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig<A: Float> {
    /// Learning rate.
    pub lr: A,
    /// Weight decay factor; zero disables weight decay.
    pub weight_decay: A,
    /// Optional maximum gradient norm for clipping.
    pub grad_clip: Option<A>,
    /// Optimizer-specific parameters (e.g. `momentum`, `beta1`, `beta2`, `eps`).
    pub params: HashMap<String, A>,
}

impl<A: Float + Send + Sync> Default for OptimizerConfig<A> {
    fn default() -> Self {
        Self {
            lr: A::from(0.001).unwrap(),
            weight_decay: A::zero(),
            grad_clip: None,
            params: HashMap::new(),
        }
    }
}

impl<A: Float + Send + Sync> OptimizerConfig<A> {
    /// Creates a configuration with the given learning rate and defaults elsewhere.
    pub fn new(lr: A) -> Self {
        Self {
            lr,
            ..Default::default()
        }
    }

    /// Sets the weight decay factor.
    pub fn weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the maximum gradient norm for clipping.
    pub fn grad_clip(mut self, grad_clip: A) -> Self {
        self.grad_clip = Some(grad_clip);
        self
    }

    /// Sets a single optimizer-specific parameter.
    pub fn param<S: Into<String>>(mut self, key: S, value: A) -> Self {
        self.params.insert(key.into(), value);
        self
    }

    /// Merges a map of optimizer-specific parameters.
    pub fn params(mut self, params: HashMap<String, A>) -> Self {
        self.params.extend(params);
        self
    }
}

/// A parameter tensor together with its gradient and metadata.
#[derive(Debug, Clone)]
pub struct Parameter<A: Float, D: Dimension> {
    /// The parameter values.
    pub data: Array<A, D>,
    /// The most recently stored gradient, if any.
    pub grad: Option<Array<A, D>>,
    /// Whether this parameter should receive gradient updates.
    pub requires_grad: bool,
    /// Name used to key per-parameter optimizer state.
    pub name: String,
}

impl<A: Float + ScalarOperand, D: Dimension + Send + Sync> Parameter<A, D> {
    /// Creates a trainable parameter from the given data.
    pub fn new<S: Into<String>>(data: Array<A, D>, name: S) -> Self {
        Self {
            data,
            grad: None,
            requires_grad: true,
            name: name.into(),
        }
    }

    /// Creates a frozen parameter that never accumulates gradients.
    pub fn no_grad<S: Into<String>>(data: Array<A, D>, name: S) -> Self {
        Self {
            data,
            grad: None,
            requires_grad: false,
            name: name.into(),
        }
    }

    /// Stores a gradient; ignored if `requires_grad` is false.
    pub fn set_grad(&mut self, grad: Array<A, D>) {
        if self.requires_grad {
            self.grad = Some(grad);
        }
    }

    /// Clears the stored gradient.
    pub fn zero_grad(&mut self) {
        self.grad = None;
    }

    /// Returns the stored gradient, if any.
    pub fn grad(&self) -> Option<&Array<A, D>> {
        self.grad.as_ref()
    }

    /// Rescales the gradient in place so that its L2 norm does not exceed `max_norm`.
    pub fn clip_grad(&mut self, max_norm: A) -> Result<()> {
        if let Some(ref mut grad) = self.grad {
            let norm = grad
                .iter()
                .map(|x| (*x) * (*x))
                .fold(A::zero(), |acc, x| acc + x)
                .sqrt();
            if norm > max_norm {
                let scale = max_norm / norm;
                grad.mapv_inplace(|x| x * scale);
            }
        }
        Ok(())
    }
}

/// Common interface shared by all unified optimizers.
pub trait UnifiedOptimizer<A: Float> {
    /// Returns the optimizer configuration.
    fn config(&self) -> &OptimizerConfig<A>;

    /// Applies one update step to a single parameter.
    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()>
    where
        A: ScalarOperand + Debug;

    /// Applies one update step to each parameter in the slice.
    fn step_params<D: Dimension>(&mut self, params: &mut [Parameter<A, D>]) -> Result<()>
    where
        A: ScalarOperand + Debug,
    {
        for param in params.iter_mut() {
            self.step_param(param)?;
        }
        Ok(())
    }

    /// Clears the gradients of all parameters.
    fn zero_grad<D: Dimension>(&self, params: &mut [Parameter<A, D>]) {
        for param in params.iter_mut() {
            param.grad = None;
        }
    }

    /// Sets the learning rate.
    fn set_lr(&mut self, lr: A);

    /// Returns the current learning rate.
    fn get_lr(&self) -> A;

    /// Serializes optimizer state for checkpointing.
    fn state_dict(&self) -> HashMap<String, Vec<u8>>;

    /// Restores optimizer state from a previously saved state dict.
    fn load_state_dict(&mut self, state_dict: HashMap<String, Vec<u8>>) -> Result<()>;
}

/// SGD optimizer with optional momentum, driven by an `OptimizerConfig`.
#[derive(Debug)]
pub struct UnifiedSGD<A: Float> {
    config: OptimizerConfig<A>,
    momentum_buffers: HashMap<String, Array1<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedSGD<A> {
    /// Creates a new SGD optimizer from the given configuration.
    pub fn new(config: OptimizerConfig<A>) -> Self {
        Self {
            config,
            momentum_buffers: HashMap::new(),
        }
    }

    /// Creates an SGD optimizer with the given momentum factor.
    pub fn with_momentum(mut config: OptimizerConfig<A>, momentum: A) -> Self {
        config.params.insert("momentum".to_string(), momentum);
        Self::new(config)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedOptimizer<A> for UnifiedSGD<A> {
    fn config(&self) -> &OptimizerConfig<A> {
        &self.config
    }

    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()> {
        if !param.requires_grad {
            return Ok(());
        }

        if param.grad.is_none() {
            return Err(OptimError::InvalidConfig(
                "Parameter has no gradient".to_string(),
            ));
        }

        // Optional gradient clipping.
        if let Some(max_norm) = self.config.grad_clip {
            param.clip_grad(max_norm)?;
        }

        // Decoupled weight decay: shrink parameters multiplicatively before the update.
        if self.config.weight_decay > A::zero() {
            param
                .data
                .mapv_inplace(|x| x * (A::one() - self.config.weight_decay * self.config.lr));
        }

        let grad = param.grad.as_ref().unwrap();

        let momentum = self
            .config
            .params
            .get("momentum")
            .copied()
            .unwrap_or(A::zero());

        if momentum > A::zero() {
            if let Some(momentum_buffer) = self.momentum_buffers.get_mut(&param.name) {
                // Update the momentum buffer: m = momentum * m + g, then p = p - lr * m.
                for (m, g) in momentum_buffer.iter_mut().zip(grad.iter()) {
                    *m = momentum * (*m) + *g;
                }
                for (p, m) in param.data.iter_mut().zip(momentum_buffer.iter()) {
                    *p = *p - self.config.lr * (*m);
                }
            } else {
                // First step for this parameter: seed the buffer with the raw gradient.
                let mut momentum_buffer = Array1::zeros(grad.len());
                for (m, g) in momentum_buffer.iter_mut().zip(grad.iter()) {
                    *m = *g;
                }
                for (p, m) in param.data.iter_mut().zip(momentum_buffer.iter()) {
                    *p = *p - self.config.lr * (*m);
                }
                self.momentum_buffers
                    .insert(param.name.clone(), momentum_buffer);
            }
        } else {
            // Plain SGD update: p = p - lr * g.
            for (p, g) in param.data.iter_mut().zip(grad.iter()) {
                *p = *p - self.config.lr * (*g);
            }
        }

        Ok(())
    }

    fn set_lr(&mut self, lr: A) {
        self.config.lr = lr;
    }

    fn get_lr(&self) -> A {
        self.config.lr
    }

    fn state_dict(&self) -> HashMap<String, Vec<u8>> {
        // State serialization is not implemented yet.
        HashMap::new()
    }

    fn load_state_dict(&mut self, _state_dict: HashMap<String, Vec<u8>>) -> Result<()> {
        // State restoration is not implemented yet.
        Ok(())
    }
}

/// Adam optimizer driven by an `OptimizerConfig`.
#[derive(Debug)]
pub struct UnifiedAdam<A: Float> {
    config: OptimizerConfig<A>,
    step_count: usize,
    exp_avg: HashMap<String, Array1<A>>,
    exp_avg_sq: HashMap<String, Array1<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedAdam<A> {
    /// Creates a new Adam optimizer, filling in defaults for `beta1` (0.9),
    /// `beta2` (0.999), and `eps` (1e-8) if they are not already set.
    pub fn new(config: OptimizerConfig<A>) -> Self {
        let mut params = config.params.clone();
        params
            .entry("beta1".to_string())
            .or_insert_with(|| A::from(0.9).unwrap());
        params
            .entry("beta2".to_string())
            .or_insert_with(|| A::from(0.999).unwrap());
        params
            .entry("eps".to_string())
            .or_insert_with(|| A::from(1e-8).unwrap());

        Self {
            config: OptimizerConfig { params, ..config },
            step_count: 0,
            exp_avg: HashMap::new(),
            exp_avg_sq: HashMap::new(),
        }
    }

    /// Creates an Adam optimizer with custom `beta1` and `beta2` coefficients.
    pub fn with_betas(mut config: OptimizerConfig<A>, beta1: A, beta2: A) -> Self {
        config.params.insert("beta1".to_string(), beta1);
        config.params.insert("beta2".to_string(), beta2);
        Self::new(config)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedOptimizer<A> for UnifiedAdam<A> {
    fn config(&self) -> &OptimizerConfig<A> {
        &self.config
    }

    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()> {
        if !param.requires_grad {
            return Ok(());
        }

        if param.grad.is_none() {
            return Err(OptimError::InvalidConfig(
                "Parameter has no gradient".to_string(),
            ));
        }

        // Optional gradient clipping.
        if let Some(max_norm) = self.config.grad_clip {
            param.clip_grad(max_norm)?;
        }

        self.step_count += 1;

        let beta1 = self.config.params["beta1"];
        let beta2 = self.config.params["beta2"];
        let eps = self.config.params["eps"];

        let grad = param.grad.as_ref().unwrap();

        // First and second moment estimates, created lazily per parameter.
        let exp_avg = self
            .exp_avg
            .entry(param.name.clone())
            .or_insert_with(|| Array1::zeros(grad.len()));
        let exp_avg_sq = self
            .exp_avg_sq
            .entry(param.name.clone())
            .or_insert_with(|| Array1::zeros(grad.len()));

        // m = beta1 * m + (1 - beta1) * g;  v = beta2 * v + (1 - beta2) * g^2.
        for ((exp_avg_val, exp_avg_sq_val), grad_val) in exp_avg
            .iter_mut()
            .zip(exp_avg_sq.iter_mut())
            .zip(grad.iter())
        {
            *exp_avg_val = beta1 * (*exp_avg_val) + (A::one() - beta1) * (*grad_val);
            *exp_avg_sq_val =
                beta2 * (*exp_avg_sq_val) + (A::one() - beta2) * (*grad_val) * (*grad_val);
        }

        // Bias corrections folded into the step size.
        let bias_correction1 = A::one() - beta1.powi(self.step_count as i32);
        let bias_correction2 = A::one() - beta2.powi(self.step_count as i32);

        let step_size = self.config.lr * (bias_correction2.sqrt() / bias_correction1);

        for ((p, exp_avg_val), exp_avg_sq_val) in param
            .data
            .iter_mut()
            .zip(exp_avg.iter())
            .zip(exp_avg_sq.iter())
        {
            let denom = exp_avg_sq_val.sqrt() + eps;
            *p = *p - step_size * (*exp_avg_val) / denom;
        }

        // Decoupled weight decay applied after the Adam update.
        if self.config.weight_decay > A::zero() {
            param
                .data
                .mapv_inplace(|x| x * (A::one() - self.config.weight_decay * self.config.lr));
        }

        Ok(())
    }

    fn set_lr(&mut self, lr: A) {
        self.config.lr = lr;
    }

    fn get_lr(&self) -> A {
        self.config.lr
    }

    fn state_dict(&self) -> HashMap<String, Vec<u8>> {
        // State serialization is not implemented yet.
        HashMap::new()
    }

    fn load_state_dict(&mut self, _state_dict: HashMap<String, Vec<u8>>) -> Result<()> {
        // State restoration is not implemented yet.
        Ok(())
    }
}

/// Convenience constructors for the unified optimizers.
pub struct OptimizerFactory;

impl OptimizerFactory {
    /// Creates a plain SGD optimizer.
    pub fn sgd<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
    ) -> UnifiedSGD<A> {
        UnifiedSGD::new(config)
    }

    /// Creates an Adam optimizer with default betas and epsilon.
    pub fn adam<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
    ) -> UnifiedAdam<A> {
        UnifiedAdam::new(config)
    }

    /// Creates an SGD optimizer with momentum.
    pub fn sgd_momentum<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
        momentum: A,
    ) -> UnifiedSGD<A> {
        UnifiedSGD::with_momentum(config, momentum)
    }

    /// Creates an Adam optimizer with custom beta coefficients.
    pub fn adam_custom<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
        beta1: A,
        beta2: A,
    ) -> UnifiedAdam<A> {
        UnifiedAdam::with_betas(config, beta1, beta2)
    }
}

/// Couples an optimizer with an optional learning-rate scheduler.
pub struct TrainingLoop<A: Float, O: UnifiedOptimizer<A>> {
    optimizer: O,
    scheduler: Option<Box<dyn LearningRateScheduler<A>>>,
    _phantom: std::marker::PhantomData<A>,
}

impl<A: Float + ScalarOperand + Debug, O: UnifiedOptimizer<A> + Send + Sync> TrainingLoop<A, O> {
    /// Creates a training loop around the given optimizer.
    pub fn new(optimizer: O) -> Self {
        Self {
            optimizer,
            scheduler: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Attaches a learning-rate scheduler that is stepped after each update.
    pub fn with_scheduler(mut self, scheduler: Box<dyn LearningRateScheduler<A>>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Runs one optimization step over all parameters, then updates the learning rate.
    pub fn step<D: Dimension>(&mut self, params: &mut [Parameter<A, D>]) -> Result<()> {
        self.optimizer.step_params(params)?;

        if let Some(ref mut scheduler) = self.scheduler {
            let new_lr = scheduler.step();
            self.optimizer.set_lr(new_lr);
        }

        Ok(())
    }

    /// Clears the gradients of all parameters.
    pub fn zero_grad<D: Dimension>(&self, params: &mut [Parameter<A, D>]) {
        for param in params.iter_mut() {
            param.grad = None;
        }
    }

    /// Returns the optimizer's current learning rate.
    pub fn get_lr(&self) -> A {
        self.optimizer.get_lr()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

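    // A small sketch of the error path in `step_param` as implemented above: a trainable
    // parameter with no stored gradient yields an error instead of silently skipping the
    // update. Test and parameter names here are illustrative.
    #[test]
    fn test_step_without_gradient_errors() {
        let config = OptimizerConfig::new(0.1f64);
        let mut optimizer = UnifiedSGD::new(config);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "no_grad_yet");
        assert!(optimizer.step_param(&mut param).is_err());
    }
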
    #[test]
    fn test_unified_sgd() {
        let config = OptimizerConfig::new(0.1f64);
        let mut optimizer = UnifiedSGD::new(config);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test_param");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step_param(&mut param).unwrap();

        assert!((param.data[0] - 0.99).abs() < 1e-10);
        assert!((param.data[1] - 1.98).abs() < 1e-10);
        assert!((param.data[2] - 2.97).abs() < 1e-10);
    }
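
    // A minimal sketch exercising the momentum path in `UnifiedSGD::step_param`: the buffer
    // is seeded with the raw gradient on the first step and decayed by the momentum factor
    // afterwards. Expected values assume lr = 0.1 and momentum = 0.9; names are illustrative.
    #[test]
    fn test_unified_sgd_momentum() {
        let config = OptimizerConfig::new(0.1f64);
        let mut optimizer = UnifiedSGD::with_momentum(config, 0.9);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "momentum_param");

        // First step: buffer = grad, so the update is lr * grad.
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        optimizer.step_param(&mut param).unwrap();
        assert!((param.data[0] - 0.99).abs() < 1e-10);

        // Second step: buffer = 0.9 * 0.1 + 0.1 = 0.19, so the update is 0.1 * 0.19.
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        optimizer.step_param(&mut param).unwrap();
        assert!((param.data[0] - 0.971).abs() < 1e-10);
    }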

    #[test]
    fn test_unified_adam() {
        let config = OptimizerConfig::new(0.001f64);
        let mut optimizer = UnifiedAdam::new(config);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test_param");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step_param(&mut param).unwrap();

        assert!(param.data[0] < 1.0);
        assert!(param.data[1] < 2.0);
        assert!(param.data[2] < 3.0);
    }
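
    // A small sketch of the `grad_clip` path: with a maximum norm set in the config,
    // `step_param` rescales the gradient in place before applying the update. The gradient
    // [3, 4, 0] (norm 5) is clipped to norm 0.1 here; test and parameter names are illustrative.
    #[test]
    fn test_grad_clip_config() {
        let config = OptimizerConfig::new(0.1f64).grad_clip(0.1);
        let mut optimizer = UnifiedSGD::new(config);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "clipped_param");
        param.set_grad(Array1::from_vec(vec![3.0, 4.0, 0.0]));

        optimizer.step_param(&mut param).unwrap();

        // The stored gradient now has norm 0.1 and the update used the clipped values.
        let grad = param.grad().unwrap();
        let norm: f64 = grad.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 0.1).abs() < 1e-10);
        assert!((param.data[0] - (1.0 - 0.1 * 0.06)).abs() < 1e-10);
    }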

    #[test]
    fn test_optimizer_factory() {
        let config = OptimizerConfig::new(0.01f64).weight_decay(0.0001);
        let _sgd = OptimizerFactory::sgd(config.clone());
        let _adam = OptimizerFactory::adam(config);
    }
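
    // A short sketch of the `OptimizerConfig` builder: chained setters plus the free-form
    // `params` map used for optimizer-specific values such as `momentum`. Test name is
    // illustrative.
    #[test]
    fn test_config_builder() {
        let config = OptimizerConfig::new(0.01f64)
            .weight_decay(0.0001)
            .grad_clip(1.0)
            .param("momentum", 0.9);

        assert!((config.lr - 0.01).abs() < 1e-12);
        assert!((config.weight_decay - 0.0001).abs() < 1e-12);
        assert_eq!(config.grad_clip, Some(1.0));
        assert!((config.params["momentum"] - 0.9).abs() < 1e-12);
    }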

    #[test]
    fn test_parameter_operations() {
        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test");

        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        assert!(param.grad().is_some());

        param.clip_grad(0.1).unwrap();
        let grad = param.grad().unwrap();
        let norm: f64 = grad.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 0.1).abs() < 1e-10);

        param.zero_grad();
        assert!(param.grad().is_none());
    }
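
    // A minimal sketch of driving updates through `TrainingLoop` (no scheduler attached)
    // and of the `requires_grad` flag: frozen parameters ignore `set_grad` and are skipped
    // by the optimizer step. Test and parameter names are illustrative.
    #[test]
    fn test_training_loop_and_no_grad() {
        let config = OptimizerConfig::new(0.1f64);
        let optimizer = UnifiedSGD::new(config);
        let mut training_loop = TrainingLoop::new(optimizer);

        let mut params = vec![Parameter::new(Array1::from_vec(vec![1.0, 2.0]), "weights")];
        params[0].set_grad(Array1::from_vec(vec![0.1, 0.1]));

        training_loop.step(&mut params).unwrap();
        assert!((params[0].data[0] - 0.99).abs() < 1e-10);
        assert!((training_loop.get_lr() - 0.1).abs() < 1e-12);

        training_loop.zero_grad(&mut params);
        assert!(params[0].grad().is_none());

        // A frozen parameter never stores a gradient and is left untouched by `step`.
        let mut frozen = Parameter::no_grad(Array1::from_vec(vec![1.0, 2.0]), "frozen");
        frozen.set_grad(Array1::from_vec(vec![0.5, 0.5]));
        assert!(frozen.grad().is_none());

        let mut frozen_params = [frozen];
        training_loop.step(&mut frozen_params).unwrap();
        assert!((frozen_params[0].data[0] - 1.0).abs() < 1e-12);
    }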
}