// optirs_learned/domain_optimizers/mod.rs
1//! Domain-Specific Optimizers
2//!
3//! This module provides specialized optimizers tailored for specific deep learning
4//! domains: computer vision, natural language processing, and attention mechanisms.
5//!
6//! Each optimizer incorporates domain knowledge to improve convergence and
7//! generalization compared to generic optimizers.
8
9pub mod attention_optimizer;
10pub mod cv_optimizer;
11pub mod nlp_optimizer;
12
13pub use attention_optimizer::AttentionOptimizer;
14pub use cv_optimizer::CVOptimizer;
15pub use nlp_optimizer::NLPOptimizer;
16
17use crate::error::Result;
18use scirs2_core::ndarray::Array1;
19use scirs2_core::numeric::Float;
20use std::fmt::Debug;
21
/// Information about the current state of an optimizer.
///
/// A read-only snapshot returned by [`AdvancedOptimizer::get_state`],
/// exposing internal bookkeeping for logging and diagnostics.
#[derive(Debug, Clone)]
pub struct OptimizerStateInfo<T: Float + Debug + Send + Sync + 'static> {
    /// Total number of optimization steps taken so far.
    pub step_count: usize,
    /// Current effective learning rate (the value reported by
    /// [`AdvancedOptimizer::get_learning_rate`]).
    pub current_lr: T,
    /// Exponential moving average of gradient norms.
    pub grad_norm_ema: T,
}
32
/// Trait for domain-specific advanced optimizers.
///
/// Provides a common interface for optimizers that incorporate domain knowledge
/// (e.g., spatial awareness for CV, layer-wise decay for NLP, head-wise scaling
/// for attention) on top of standard gradient-based updates.
///
/// Implementors: [`CVOptimizer`], [`NLPOptimizer`], [`AttentionOptimizer`].
pub trait AdvancedOptimizer<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Perform one optimization step, returning the updated parameters.
    ///
    /// Takes `&mut self` because implementations update internal state
    /// (step counters, moment estimates) as a side effect.
    ///
    /// # Arguments
    /// * `params` - Current parameter values
    /// * `gradients` - Gradient of the loss with respect to `params`
    ///
    /// # Returns
    /// Updated parameter values after applying the optimizer step.
    ///
    /// # Errors
    /// Returns an error if the step cannot be applied (implementation-defined;
    /// e.g., shape mismatch between `params` and `gradients`).
    fn step(&mut self, params: &Array1<T>, gradients: &Array1<T>) -> Result<Array1<T>>;

    /// Get the current effective learning rate (may differ from the base rate
    /// if the implementation applies scheduling or adaptation).
    fn get_learning_rate(&self) -> T;

    /// Set the base learning rate.
    fn set_learning_rate(&mut self, lr: T);

    /// Get the optimizer name (a human-readable identifier for logging).
    fn name(&self) -> &str;

    /// Get a snapshot of the optimizer's internal state.
    /// See [`OptimizerStateInfo`] for the reported fields.
    fn get_state(&self) -> OptimizerStateInfo<T>;
}
61
62/// Compute the L2 norm of an array.
63pub(crate) fn l2_norm<T: Float + Debug + Send + Sync + 'static>(arr: &Array1<T>) -> T {
64 let sum_sq = arr.iter().fold(T::zero(), |acc, &x| acc + x * x);
65 sum_sq.sqrt()
66}
67
68/// Clip a gradient array so that its L2 norm does not exceed `max_norm`.
69/// Returns a new array (clipped copy) if the norm exceeds the threshold,
70/// otherwise returns a clone.
71pub(crate) fn clip_grad_norm<T: Float + Debug + Send + Sync + 'static>(
72 grad: &Array1<T>,
73 max_norm: T,
74) -> Array1<T> {
75 let norm = l2_norm(grad);
76 if norm > max_norm && norm > T::zero() {
77 let scale = max_norm / norm;
78 grad.mapv(|g| g * scale)
79 } else {
80 grad.clone()
81 }
82}