// optirs_learned/domain_optimizers/mod.rs
//! Domain-Specific Optimizers
//!
//! This module provides specialized optimizers tailored for specific deep learning
//! domains: computer vision, natural language processing, and attention mechanisms.
//!
//! Each optimizer incorporates domain knowledge to improve convergence and
//! generalization compared to generic optimizers.

9pub mod attention_optimizer;
10pub mod cv_optimizer;
11pub mod nlp_optimizer;
12
13pub use attention_optimizer::AttentionOptimizer;
14pub use cv_optimizer::CVOptimizer;
15pub use nlp_optimizer::NLPOptimizer;
16
17use crate::error::Result;
18use scirs2_core::ndarray::Array1;
19use scirs2_core::numeric::Float;
20use std::fmt::Debug;
21
22/// Information about the current state of an optimizer.
23#[derive(Debug, Clone)]
24pub struct OptimizerStateInfo<T: Float + Debug + Send + Sync + 'static> {
25    /// Total number of optimization steps taken
26    pub step_count: usize,
27    /// Current effective learning rate
28    pub current_lr: T,
29    /// Exponential moving average of gradient norms
30    pub grad_norm_ema: T,
31}
32
33/// Trait for domain-specific advanced optimizers.
34///
35/// Provides a common interface for optimizers that incorporate domain knowledge
36/// (e.g., spatial awareness for CV, layer-wise decay for NLP, head-wise scaling
37/// for attention) on top of standard gradient-based updates.
38pub trait AdvancedOptimizer<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
39    /// Perform one optimization step, returning the updated parameters.
40    ///
41    /// # Arguments
42    /// * `params` - Current parameter values
43    /// * `gradients` - Gradient of the loss with respect to `params`
44    ///
45    /// # Returns
46    /// Updated parameter values after applying the optimizer step.
47    fn step(&mut self, params: &Array1<T>, gradients: &Array1<T>) -> Result<Array1<T>>;
48
49    /// Get the current effective learning rate.
50    fn get_learning_rate(&self) -> T;
51
52    /// Set the base learning rate.
53    fn set_learning_rate(&mut self, lr: T);
54
55    /// Get the optimizer name.
56    fn name(&self) -> &str;
57
58    /// Get a snapshot of the optimizer's internal state.
59    fn get_state(&self) -> OptimizerStateInfo<T>;
60}
61
62/// Compute the L2 norm of an array.
63pub(crate) fn l2_norm<T: Float + Debug + Send + Sync + 'static>(arr: &Array1<T>) -> T {
64    let sum_sq = arr.iter().fold(T::zero(), |acc, &x| acc + x * x);
65    sum_sq.sqrt()
66}
67
68/// Clip a gradient array so that its L2 norm does not exceed `max_norm`.
69/// Returns a new array (clipped copy) if the norm exceeds the threshold,
70/// otherwise returns a clone.
71pub(crate) fn clip_grad_norm<T: Float + Debug + Send + Sync + 'static>(
72    grad: &Array1<T>,
73    max_norm: T,
74) -> Array1<T> {
75    let norm = l2_norm(grad);
76    if norm > max_norm && norm > T::zero() {
77        let scale = max_norm / norm;
78        grad.mapv(|g| g * scale)
79    } else {
80        grad.clone()
81    }
82}