optirs_core/unified_api.rs

//! Unified API consistent with popular deep learning frameworks.
//!
//! This module provides a unified interface that closely follows the design patterns
//! of popular deep learning frameworks such as PyTorch, TensorFlow, and JAX/Optax.
//!
//! # Design Principles
//!
//! - **Parameter Groups**: Support for different optimization parameters for different layers
//! - **State Management**: Automatic handling of optimizer state
//! - **Framework Consistency**: APIs that feel familiar to PyTorch/TensorFlow users
//! - **Flexible Configuration**: Easy-to-use builder patterns
//! - **Scheduler Integration**: Seamless integration with learning rate schedulers
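//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; it mirrors the unit tests at the
//! bottom of this file and assumes the types defined below):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//!
//! // Build a config with the fluent builder and create an SGD optimizer.
//! let config = OptimizerConfig::new(0.1f64).weight_decay(1e-4);
//! let mut optimizer = OptimizerFactory::sgd(config);
//!
//! // Wrap a tensor as a named parameter, attach a gradient, and take one step.
//! let mut weights = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "weights");
//! weights.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
//! optimizer.step_param(&mut weights).unwrap();
//! ```
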
use crate::error::{OptimError, Result};
use crate::schedulers::LearningRateScheduler;
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;

/// Unified optimizer configuration
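///
/// # Example
///
/// A small sketch of the builder-style configuration (illustrative; the
/// `"momentum"` entry is a custom parameter that `UnifiedSGD` reads when present):
///
/// ```ignore
/// let config = OptimizerConfig::new(0.01f64)
///     .weight_decay(1e-4)
///     .grad_clip(1.0)
///     .param("momentum", 0.9);
/// ```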
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig<A: Float> {
    /// Learning rate
    pub lr: A,
    /// Weight decay (L2 regularization)
    pub weight_decay: A,
    /// Gradient clipping value (optional)
    pub grad_clip: Option<A>,
    /// Additional optimizer-specific parameters
    pub params: HashMap<String, A>,
}

impl<A: Float + Send + Sync> Default for OptimizerConfig<A> {
    fn default() -> Self {
        Self {
            lr: A::from(0.001).unwrap(),
            weight_decay: A::zero(),
            grad_clip: None,
            params: HashMap::new(),
        }
    }
}

impl<A: Float + Send + Sync> OptimizerConfig<A> {
    /// Create a new optimizer configuration with the given learning rate
    pub fn new(lr: A) -> Self {
        Self {
            lr,
            ..Default::default()
        }
    }

    /// Set weight decay
    pub fn weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Set gradient clipping
    pub fn grad_clip(mut self, grad_clip: A) -> Self {
        self.grad_clip = Some(grad_clip);
        self
    }

    /// Add a custom parameter
    pub fn param<S: Into<String>>(mut self, key: S, value: A) -> Self {
        self.params.insert(key.into(), value);
        self
    }

    /// Set multiple parameters at once
    pub fn params(mut self, params: HashMap<String, A>) -> Self {
        self.params.extend(params);
        self
    }
}

/// Parameter tensor wrapper for unified API
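///
/// # Example
///
/// A minimal sketch of the gradient lifecycle (illustrative only):
///
/// ```ignore
/// let mut w = Parameter::new(Array1::from_vec(vec![1.0f64, 2.0]), "w");
/// w.set_grad(Array1::from_vec(vec![0.1, 0.2]));
/// assert!(w.grad().is_some());
/// w.zero_grad();
/// assert!(w.grad().is_none());
/// ```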
#[derive(Debug, Clone)]
pub struct Parameter<A: Float, D: Dimension> {
    /// Parameter data
    pub data: Array<A, D>,
    /// Gradient data (optional)
    pub grad: Option<Array<A, D>>,
    /// Whether this parameter requires gradients
    pub requires_grad: bool,
    /// Parameter name/identifier
    pub name: String,
}

impl<A: Float + ScalarOperand, D: Dimension + Send + Sync> Parameter<A, D> {
    /// Create a new parameter
    pub fn new<S: Into<String>>(data: Array<A, D>, name: S) -> Self {
        Self {
            data,
            grad: None,
            requires_grad: true,
            name: name.into(),
        }
    }

    /// Create a parameter that doesn't require gradients
    pub fn no_grad<S: Into<String>>(data: Array<A, D>, name: S) -> Self {
        Self {
            data,
            grad: None,
            requires_grad: false,
            name: name.into(),
        }
    }

    /// Set gradient for this parameter
    pub fn set_grad(&mut self, grad: Array<A, D>) {
        if self.requires_grad {
            self.grad = Some(grad);
        }
    }

    /// Clear gradients
    pub fn zero_grad(&mut self) {
        self.grad = None;
    }

    /// Get gradient reference
    pub fn grad(&self) -> Option<&Array<A, D>> {
        self.grad.as_ref()
    }

    /// Apply gradient clipping if specified
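    ///
    /// Rescales the gradient so that its L2 norm does not exceed `max_norm`:
    /// when `||g||_2 > max_norm`, every element is multiplied by
    /// `max_norm / ||g||_2`; otherwise the gradient is left unchanged.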
    pub fn clip_grad(&mut self, max_norm: A) -> Result<()> {
        if let Some(ref mut grad) = self.grad {
            let norm = grad
                .iter()
                .map(|x| (*x) * (*x))
                .fold(A::zero(), |acc, x| acc + x)
                .sqrt();
            if norm > max_norm {
                let scale = max_norm / norm;
                grad.mapv_inplace(|x| x * scale);
            }
        }
        Ok(())
    }
}

/// Unified optimizer interface
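///
/// The intended call sequence mirrors PyTorch: clear old gradients with
/// `zero_grad`, compute new gradients and attach them via `Parameter::set_grad`,
/// then apply an update with `step_param` or `step_params`.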
pub trait UnifiedOptimizer<A: Float> {
    /// Get optimizer configuration
    fn config(&self) -> &OptimizerConfig<A>;

    /// Update a single parameter
    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()>
    where
        A: ScalarOperand + Debug;

    /// Update multiple parameters
    fn step_params<D: Dimension>(&mut self, params: &mut [Parameter<A, D>]) -> Result<()>
    where
        A: ScalarOperand + Debug,
    {
        for param in params.iter_mut() {
            self.step_param(param)?;
        }
        Ok(())
    }

    /// Zero gradients for all parameters
    fn zero_grad<D: Dimension>(&self, params: &mut [Parameter<A, D>]) {
        for param in params.iter_mut() {
            param.grad = None;
        }
    }

    /// Update learning rate
    fn set_lr(&mut self, lr: A);

    /// Get current learning rate
    fn get_lr(&self) -> A;

    /// State dictionary for serialization
    fn state_dict(&self) -> HashMap<String, Vec<u8>>;

    /// Load state from dictionary
    fn load_state_dict(&mut self, state_dict: HashMap<String, Vec<u8>>) -> Result<()>;
}

/// SGD optimizer with unified API
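///
/// Plain SGD updates each parameter as `p -= lr * g`. When a `"momentum"` entry
/// is present in the config, a per-parameter buffer is maintained and updated as
/// `buf = momentum * buf + g`, and the step becomes `p -= lr * buf`
/// (PyTorch-style momentum, without dampening or Nesterov).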
#[derive(Debug)]
pub struct UnifiedSGD<A: Float> {
    config: OptimizerConfig<A>,
    momentum_buffers: HashMap<String, Array1<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedSGD<A> {
    /// Create a new SGD optimizer
    pub fn new(config: OptimizerConfig<A>) -> Self {
        Self {
            config,
            momentum_buffers: HashMap::new(),
        }
    }

    /// Create SGD with momentum
    pub fn with_momentum(mut config: OptimizerConfig<A>, momentum: A) -> Self {
        config.params.insert("momentum".to_string(), momentum);
        Self::new(config)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedOptimizer<A> for UnifiedSGD<A> {
    fn config(&self) -> &OptimizerConfig<A> {
        &self.config
    }

    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()> {
        if !param.requires_grad {
            return Ok(());
        }

        // Check gradient exists first
        if param.grad.is_none() {
            return Err(OptimError::InvalidConfig(
                "Parameter has no gradient".to_string(),
            ));
        }

        // Apply gradient clipping if configured
        if let Some(max_norm) = self.config.grad_clip {
            param.clip_grad(max_norm)?;
        }

        // Apply weight decay
        if self.config.weight_decay > A::zero() {
            param
                .data
                .mapv_inplace(|x| x * (A::one() - self.config.weight_decay * self.config.lr));
        }

        // Get gradient safely
        let grad = param.grad.as_ref().unwrap();

        // Get momentum factor
        let momentum = self
            .config
            .params
            .get("momentum")
            .copied()
            .unwrap_or(A::zero());

        if momentum > A::zero() {
            // SGD with momentum
            if let Some(momentum_buffer) = self.momentum_buffers.get_mut(&param.name) {
                // Update momentum buffer
                for (m, g) in momentum_buffer.iter_mut().zip(grad.iter()) {
                    *m = momentum * (*m) + *g;
                }
                // Update parameters
                for (p, m) in param.data.iter_mut().zip(momentum_buffer.iter()) {
                    *p = *p - self.config.lr * (*m);
                }
            } else {
                // Initialize momentum buffer
                let mut momentum_buffer = Array1::zeros(grad.len());
                for (m, g) in momentum_buffer.iter_mut().zip(grad.iter()) {
                    *m = *g;
                }
                // Update parameters
                for (p, m) in param.data.iter_mut().zip(momentum_buffer.iter()) {
                    *p = *p - self.config.lr * (*m);
                }
                self.momentum_buffers
                    .insert(param.name.clone(), momentum_buffer);
            }
        } else {
            // Standard SGD
            for (p, g) in param.data.iter_mut().zip(grad.iter()) {
                *p = *p - self.config.lr * (*g);
            }
        }

        Ok(())
    }

    fn set_lr(&mut self, lr: A) {
        self.config.lr = lr;
    }

    fn get_lr(&self) -> A {
        self.config.lr
    }

    fn state_dict(&self) -> HashMap<String, Vec<u8>> {
        // Simplified state serialization
        HashMap::new()
    }

    fn load_state_dict(&mut self, _state_dict: HashMap<String, Vec<u8>>) -> Result<()> {
        // Simplified state deserialization
        Ok(())
    }
}

/// Adam optimizer with unified API
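///
/// Implements the standard Adam update with bias correction:
/// `m = beta1 * m + (1 - beta1) * g`, `v = beta2 * v + (1 - beta2) * g^2`, then
/// `p -= lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m / (sqrt(v) + eps)`.
/// A multiplicative weight-decay factor is applied after the main update when
/// `weight_decay > 0`. Defaults: `beta1 = 0.9`, `beta2 = 0.999`, `eps = 1e-8`.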
#[derive(Debug)]
pub struct UnifiedAdam<A: Float> {
    config: OptimizerConfig<A>,
    step_count: usize,
    exp_avg: HashMap<String, Array1<A>>,
    exp_avg_sq: HashMap<String, Array1<A>>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedAdam<A> {
    /// Create a new Adam optimizer
    pub fn new(config: OptimizerConfig<A>) -> Self {
        let mut params = config.params.clone();
        params
            .entry("beta1".to_string())
            .or_insert_with(|| A::from(0.9).unwrap());
        params
            .entry("beta2".to_string())
            .or_insert_with(|| A::from(0.999).unwrap());
        params
            .entry("eps".to_string())
            .or_insert_with(|| A::from(1e-8).unwrap());

        Self {
            config: OptimizerConfig { params, ..config },
            step_count: 0,
            exp_avg: HashMap::new(),
            exp_avg_sq: HashMap::new(),
        }
    }

    /// Create Adam with custom betas
    pub fn with_betas(mut config: OptimizerConfig<A>, beta1: A, beta2: A) -> Self {
        config.params.insert("beta1".to_string(), beta1);
        config.params.insert("beta2".to_string(), beta2);
        Self::new(config)
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> UnifiedOptimizer<A> for UnifiedAdam<A> {
    fn config(&self) -> &OptimizerConfig<A> {
        &self.config
    }

    fn step_param<D: Dimension>(&mut self, param: &mut Parameter<A, D>) -> Result<()> {
        if !param.requires_grad {
            return Ok(());
        }

        // Check gradient exists first
        if param.grad.is_none() {
            return Err(OptimError::InvalidConfig(
                "Parameter has no gradient".to_string(),
            ));
        }

        // Apply gradient clipping if configured
        if let Some(max_norm) = self.config.grad_clip {
            param.clip_grad(max_norm)?;
        }

        self.step_count += 1;

        let beta1 = self.config.params["beta1"];
        let beta2 = self.config.params["beta2"];
        let eps = self.config.params["eps"];

        // Get gradient safely
        let grad = param.grad.as_ref().unwrap();

        // Initialize or get existing moment estimates
        let exp_avg = self
            .exp_avg
            .entry(param.name.clone())
            .or_insert_with(|| Array1::zeros(grad.len()));
        let exp_avg_sq = self
            .exp_avg_sq
            .entry(param.name.clone())
            .or_insert_with(|| Array1::zeros(grad.len()));

        // Update biased first and second moment estimates
        for ((exp_avg_val, exp_avg_sq_val), grad_val) in exp_avg
            .iter_mut()
            .zip(exp_avg_sq.iter_mut())
            .zip(grad.iter())
        {
            *exp_avg_val = beta1 * (*exp_avg_val) + (A::one() - beta1) * (*grad_val);
            *exp_avg_sq_val =
                beta2 * (*exp_avg_sq_val) + (A::one() - beta2) * (*grad_val) * (*grad_val);
        }

        // Bias correction
        let bias_correction1 = A::one() - beta1.powi(self.step_count as i32);
        let bias_correction2 = A::one() - beta2.powi(self.step_count as i32);

        let step_size = self.config.lr * (bias_correction2.sqrt() / bias_correction1);

        // Update parameters
        for ((p, exp_avg_val), exp_avg_sq_val) in param
            .data
            .iter_mut()
            .zip(exp_avg.iter())
            .zip(exp_avg_sq.iter())
        {
            let denom = exp_avg_sq_val.sqrt() + eps;
            *p = *p - step_size * (*exp_avg_val) / denom;
        }

        // Apply weight decay after the main update
        if self.config.weight_decay > A::zero() {
            param
                .data
                .mapv_inplace(|x| x * (A::one() - self.config.weight_decay * self.config.lr));
        }

        Ok(())
    }

    fn set_lr(&mut self, lr: A) {
        self.config.lr = lr;
    }

    fn get_lr(&self) -> A {
        self.config.lr
    }

    fn state_dict(&self) -> HashMap<String, Vec<u8>> {
        // Simplified state serialization
        HashMap::new()
    }

    fn load_state_dict(&mut self, _state_dict: HashMap<String, Vec<u8>>) -> Result<()> {
        // Simplified state deserialization
        Ok(())
    }
}

/// Optimizer factory for creating optimizers with unified API
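///
/// # Example
///
/// A small sketch (illustrative; values are arbitrary):
///
/// ```ignore
/// let config = OptimizerConfig::new(0.01f64).weight_decay(1e-4);
/// let sgd = OptimizerFactory::sgd_momentum(config.clone(), 0.9);
/// let adam = OptimizerFactory::adam_custom(config, 0.9, 0.999);
/// ```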
pub struct OptimizerFactory;

impl OptimizerFactory {
    /// Create SGD optimizer
    pub fn sgd<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
    ) -> UnifiedSGD<A> {
        UnifiedSGD::new(config)
    }

    /// Create Adam optimizer
    pub fn adam<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
    ) -> UnifiedAdam<A> {
        UnifiedAdam::new(config)
    }

    /// Create SGD with momentum
    pub fn sgd_momentum<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
        momentum: A,
    ) -> UnifiedSGD<A> {
        UnifiedSGD::with_momentum(config, momentum)
    }

    /// Create Adam with custom parameters
    pub fn adam_custom<A: Float + ScalarOperand + Debug + Send + Sync>(
        config: OptimizerConfig<A>,
        beta1: A,
        beta2: A,
    ) -> UnifiedAdam<A> {
        UnifiedAdam::with_betas(config, beta1, beta2)
    }
}

/// Training loop helper with unified API
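///
/// # Example
///
/// A minimal sketch of one training step (illustrative; gradient computation and
/// the `params` slice are elided, and attaching a scheduler via `with_scheduler`
/// is optional):
///
/// ```ignore
/// let optimizer = OptimizerFactory::sgd(OptimizerConfig::new(0.1f64));
/// let mut training = TrainingLoop::new(optimizer);
///
/// training.zero_grad(&mut params);
/// // ... compute gradients and attach them with `Parameter::set_grad` ...
/// training.step(&mut params).unwrap();
/// ```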
pub struct TrainingLoop<A: Float, O: UnifiedOptimizer<A>> {
    optimizer: O,
    scheduler: Option<Box<dyn LearningRateScheduler<A>>>,
    _phantom: std::marker::PhantomData<A>,
}

impl<A: Float + ScalarOperand + Debug, O: UnifiedOptimizer<A> + Send + Sync> TrainingLoop<A, O> {
    /// Create a new training loop
    pub fn new(optimizer: O) -> Self {
        Self {
            optimizer,
            scheduler: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Add a learning rate scheduler
    pub fn with_scheduler(mut self, scheduler: Box<dyn LearningRateScheduler<A>>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Perform one training step
    pub fn step<D: Dimension>(&mut self, params: &mut [Parameter<A, D>]) -> Result<()> {
        // Update parameters
        self.optimizer.step_params(params)?;

        // Update learning rate if scheduler is present
        if let Some(ref mut scheduler) = self.scheduler {
            let new_lr = scheduler.step();
            self.optimizer.set_lr(new_lr);
        }

        Ok(())
    }

    /// Zero gradients
    pub fn zero_grad<D: Dimension>(&self, params: &mut [Parameter<A, D>]) {
        for param in params.iter_mut() {
            param.grad = None;
        }
    }

    /// Get current learning rate
    pub fn get_lr(&self) -> A {
        self.optimizer.get_lr()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_unified_sgd() {
        let config = OptimizerConfig::new(0.1f64);
        let mut optimizer = UnifiedSGD::new(config);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test_param");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step_param(&mut param).unwrap();

        // Check that parameters were updated correctly
        assert!((param.data[0] - 0.99).abs() < 1e-10);
        assert!((param.data[1] - 1.98).abs() < 1e-10);
        assert!((param.data[2] - 2.97).abs() < 1e-10);
    }

    #[test]
    fn test_unified_adam() {
        let config = OptimizerConfig::new(0.001f64);
        let mut optimizer = UnifiedAdam::new(config);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test_param");
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));

        optimizer.step_param(&mut param).unwrap();

        // Parameters should have been updated (exact values depend on Adam's internal state)
        assert!(param.data[0] < 1.0);
        assert!(param.data[1] < 2.0);
        assert!(param.data[2] < 3.0);
    }

    #[test]
    fn test_optimizer_factory() {
        let config = OptimizerConfig::new(0.01f64).weight_decay(0.0001);
        let _sgd = OptimizerFactory::sgd(config.clone());
        let _adam = OptimizerFactory::adam(config);
    }

    #[test]
    fn test_parameter_operations() {
        let mut param = Parameter::new(Array1::from_vec(vec![1.0, 2.0, 3.0]), "test");

        // Test gradient setting
        param.set_grad(Array1::from_vec(vec![0.1, 0.2, 0.3]));
        assert!(param.grad().is_some());

        // Test gradient clipping
        param.clip_grad(0.1).unwrap();
        let grad = param.grad().unwrap();
        let norm: f64 = grad.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!((norm - 0.1).abs() < 1e-10);

        // Test zero grad
        param.zero_grad();
        assert!(param.grad().is_none());
    }
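
    // Illustrative check of the momentum path in `UnifiedSGD`, using hand-computed
    // values based on the PyTorch-style buffer update implemented above: with
    // lr = 0.1 and momentum = 0.9, the buffer starts at the first gradient (0.5),
    // so the first step gives 1.0 - 0.1 * 0.5 = 0.95; the second step uses
    // buf = 0.9 * 0.5 + 0.5 = 0.95, giving 0.95 - 0.1 * 0.95 = 0.855.
    #[test]
    fn test_unified_sgd_momentum() {
        let config = OptimizerConfig::new(0.1f64);
        let mut optimizer = UnifiedSGD::with_momentum(config, 0.9);

        let mut param = Parameter::new(Array1::from_vec(vec![1.0]), "p");

        param.set_grad(Array1::from_vec(vec![0.5]));
        optimizer.step_param(&mut param).unwrap();
        assert!((param.data[0] - 0.95).abs() < 1e-10);

        param.set_grad(Array1::from_vec(vec![0.5]));
        optimizer.step_param(&mut param).unwrap();
        assert!((param.data[0] - 0.855).abs() < 1e-10);
    }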
}