use super::{core::SciRS2Data, IntegrationError, SciRS2Integration};
use crate::tensor::Tensor;
use crate::Float;
use std::collections::HashMap;
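
/// Common interface for optimizers that update autograd tensors in place.
///
/// `parameters` and `gradients` passed to `step` must have the same length
/// and matching order; parameters are identified positionally (`param_0`,
/// `param_1`, ...), so `initialize` and `step` must see them in the same
/// order. A minimal usage sketch (`ignore`d, since the tensor setup depends
/// on the surrounding crate):
///
/// ```ignore
/// let mut opt = SGDOptimizer::<f32>::new(0.01, 0.9);
/// opt.initialize(&[&weights])?;
/// opt.step(&mut [&mut weights], &[&grad_weights])?;
/// ```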
pub trait AutogradOptimizer<F: Float> {
    fn name(&self) -> &str;
    fn initialize(&mut self, parameters: &[&Tensor<F>]) -> Result<(), IntegrationError>;
    fn step(
        &mut self,
        parameters: &mut [&mut Tensor<F>],
        gradients: &[&Tensor<F>],
    ) -> Result<(), IntegrationError>;
    fn zero_grad(&mut self);
    fn learning_rate(&self) -> f64;
    fn set_learning_rate(&mut self, lr: f64);
    fn state(&self) -> OptimizerState<'_, F>;
    fn set_state(&mut self, state: OptimizerState<'_, F>);
    fn parameter_groups(&self) -> &[ParameterGroup<F>];
    fn add_parameter_group(&mut self, group: ParameterGroup<F>);
}
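
/// Snapshot of an optimizer's mutable state: the global step counter,
/// per-parameter state keyed by positional parameter id, free-form global
/// values, and the hyperparameter configuration.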
#[derive(Debug, Clone)]
pub struct OptimizerState<'a, F: Float> {
    pub step_count: usize,
    pub param_state: HashMap<String, ParameterState<'a, F>>,
    pub global_state: HashMap<String, StateValue>,
    pub config: OptimizerConfig,
}

impl<'a, F: Float> OptimizerState<'a, F> {
    pub fn new() -> Self {
        Self {
            step_count: 0,
            param_state: HashMap::new(),
            global_state: HashMap::new(),
            config: OptimizerConfig::default(),
        }
    }

    pub fn get_param_state(&self, param_id: &str) -> Option<&ParameterState<'a, F>> {
        self.param_state.get(param_id)
    }

    pub fn set_param_state(&mut self, param_id: String, state: ParameterState<'a, F>) {
        self.param_state.insert(param_id, state);
    }

    pub fn increment_step(&mut self) {
        self.step_count += 1;
    }
}

impl<F: Float> Default for OptimizerState<'_, F> {
    fn default() -> Self {
        Self::new()
    }
}
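
/// Per-parameter optimizer buffers: momentum (SGD), exponential moving
/// averages of the gradient and squared gradient (Adam), and the infinity
/// norm accumulator (Adamax-style rules), plus a per-parameter step counter
/// and an extension map for optimizer-specific tensors.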
#[derive(Debug, Clone)]
pub struct ParameterState<'a, F: Float> {
    pub momentum: Option<Tensor<'a, F>>,
    pub exp_avg_sq: Option<Tensor<'a, F>>,
    pub exp_avg: Option<Tensor<'a, F>>,
    pub exp_inf: Option<Tensor<'a, F>>,
    pub step: usize,
    pub extra_state: HashMap<String, Tensor<'a, F>>,
}

impl<F: Float> ParameterState<'_, F> {
    pub fn new() -> Self {
        Self {
            momentum: None,
            exp_avg_sq: None,
            exp_avg: None,
            exp_inf: None,
            step: 0,
            extra_state: HashMap::new(),
        }
    }
}

impl<F: Float> Default for ParameterState<'_, F> {
    fn default() -> Self {
        Self::new()
    }
}
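
/// Dynamically typed scalar value for optimizer state entries that are not
/// tensors (hyperparameters, flags, counters).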
#[derive(Debug, Clone)]
pub enum StateValue {
    Float(f64),
    Int(i64),
    Bool(bool),
    String(String),
    FloatArray(Vec<f64>),
}

impl StateValue {
    pub fn as_float(&self) -> Option<f64> {
        match self {
            StateValue::Float(val) => Some(*val),
            StateValue::Int(val) => Some(*val as f64),
            _ => None,
        }
    }

    pub fn as_int(&self) -> Option<i64> {
        match self {
            StateValue::Int(val) => Some(*val),
            StateValue::Float(val) => Some(*val as i64),
            _ => None,
        }
    }
}
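
/// A named group of parameters sharing a learning rate, weight decay, and
/// extra configuration, assembled with a chained builder:
///
/// ```ignore
/// let group = ParameterGroup::<f32>::new("encoder".to_string(), 1e-3)
///     .add_parameter("param_0".to_string())
///     .weight_decay(1e-4)
///     .config("momentum".to_string(), StateValue::Float(0.9));
/// ```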
#[derive(Debug)]
pub struct ParameterGroup<F: Float> {
    pub parameters: Vec<String>,
    pub learning_rate: f64,
    pub weight_decay: f64,
    pub config: HashMap<String, StateValue>,
    pub name: String,
    _phantom: std::marker::PhantomData<F>,
}

impl<F: Float> ParameterGroup<F> {
    pub fn new(name: String, learning_rate: f64) -> Self {
        Self {
            parameters: Vec::new(),
            learning_rate,
            weight_decay: 0.0,
            config: HashMap::new(),
            name,
            _phantom: std::marker::PhantomData,
        }
    }

    pub fn add_parameter(mut self, param_id: String) -> Self {
        self.parameters.push(param_id);
        self
    }

    pub fn weight_decay(mut self, decay: f64) -> Self {
        self.weight_decay = decay;
        self
    }

    pub fn config(mut self, key: String, value: StateValue) -> Self {
        self.config.insert(key, value);
        self
    }
}
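
/// Hyperparameters shared by the built-in optimizers. Defaults follow common
/// practice: lr 1e-3, momentum 0.9, betas (0.9, 0.999), eps 1e-8.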
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    pub learning_rate: f64,
    pub weight_decay: f64,
    pub momentum: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub amsgrad: bool,
    pub extra_config: HashMap<String, StateValue>,
}

impl Default for OptimizerConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            weight_decay: 0.0,
            momentum: 0.9,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            amsgrad: false,
            extra_config: HashMap::new(),
        }
    }
}
pub struct SGDOptimizer<'a, F: Float> {
    config: OptimizerConfig,
    state: OptimizerState<'a, F>,
    parameter_groups: Vec<ParameterGroup<F>>,
}

impl<F: Float> SGDOptimizer<'_, F> {
    pub fn new(learning_rate: f64, momentum: f64) -> Self {
        let config = OptimizerConfig {
            learning_rate,
            momentum,
            ..Default::default()
        };
        Self {
            config,
            state: OptimizerState::new(),
            parameter_groups: Vec::new(),
        }
    }

    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.config.weight_decay = weight_decay;
        self
    }
}
impl<F: Float> AutogradOptimizer<F> for SGDOptimizer<'_, F> {
    fn name(&self) -> &str {
        "SGD"
    }

    fn initialize(&mut self, parameters: &[&Tensor<F>]) -> Result<(), IntegrationError> {
        // Register empty per-parameter state under positional ids.
        for i in 0..parameters.len() {
            let param_id = format!("param_{i}");
            self.state.set_param_state(param_id, ParameterState::new());
        }
        Ok(())
    }

    fn step(
        &mut self,
        parameters: &mut [&mut Tensor<F>],
        gradients: &[&Tensor<F>],
    ) -> Result<(), IntegrationError> {
        if parameters.len() != gradients.len() {
            return Err(IntegrationError::ModuleCompatibility(
                "Parameter and gradient count mismatch".to_string(),
            ));
        }
        for (i, (param, grad)) in parameters.iter_mut().zip(gradients.iter()).enumerate() {
            let param_id = format!("param_{i}");
            if let Some(param_state) = self.state.param_state.get_mut(&param_id) {
                Self::update_parameter(&self.config, param, grad, param_state)?;
            }
        }
        self.state.increment_step();
        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients live on the autograd tensors, not in the optimizer,
        // so there is nothing to clear here.
    }

    fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn state(&self) -> OptimizerState<'_, F> {
        self.state.clone()
    }

    fn set_state(&mut self, state: OptimizerState<'_, F>) {
        // Per-parameter tensor state cannot be carried across the lifetime
        // boundary here, so only the scalar parts of the snapshot survive.
        self.state = OptimizerState {
            step_count: state.step_count,
            param_state: HashMap::new(),
            global_state: state.global_state,
            config: state.config,
        };
    }

    fn parameter_groups(&self) -> &[ParameterGroup<F>] {
        &self.parameter_groups
    }

    fn add_parameter_group(&mut self, group: ParameterGroup<F>) {
        self.parameter_groups.push(group);
    }
}
impl<'a, F: Float> SGDOptimizer<'a, F> {
    fn update_parameter(
        config: &OptimizerConfig,
        _param: &mut Tensor<F>,
        _grad: &Tensor<F>,
        param_state: &mut ParameterState<'a, F>,
    ) -> Result<(), IntegrationError> {
        // Placeholder: converts the hyperparameters into F and checks which
        // branches apply, but the actual tensor arithmetic
        // (p <- p - lr * (g + weight_decay * p), optionally through a
        // momentum buffer) is not wired up yet.
        let _lr = F::from(config.learning_rate).expect("Failed to convert to float");
        if config.weight_decay > 0.0 {
            let _decay = F::from(config.weight_decay).expect("Failed to convert to float");
        }
        if config.momentum > 0.0 {
            if let Some(_momentum_buffer) = param_state.momentum.as_mut() {
                let _momentum = F::from(config.momentum).expect("Failed to convert to float");
            }
        }
        Ok(())
    }
}
pub struct AdamOptimizer<'a, F: Float> {
    config: OptimizerConfig,
    state: OptimizerState<'a, F>,
    parameter_groups: Vec<ParameterGroup<F>>,
}

impl<F: Float> AdamOptimizer<'_, F> {
    pub fn new(learning_rate: f64, beta1: f64, beta2: f64, eps: f64) -> Self {
        let config = OptimizerConfig {
            learning_rate,
            beta1,
            beta2,
            eps,
            ..Default::default()
        };
        Self {
            config,
            state: OptimizerState::new(),
            parameter_groups: Vec::new(),
        }
    }

    pub fn default_adam(learning_rate: f64) -> Self {
        Self::new(learning_rate, 0.9, 0.999, 1e-8)
    }
}
impl<F: Float> AutogradOptimizer<F> for AdamOptimizer<'_, F> {
    fn name(&self) -> &str {
        "Adam"
    }

    fn initialize(&mut self, parameters: &[&Tensor<F>]) -> Result<(), IntegrationError> {
        // Register empty per-parameter state under positional ids.
        for i in 0..parameters.len() {
            let param_id = format!("param_{i}");
            self.state.set_param_state(param_id, ParameterState::new());
        }
        Ok(())
    }

    fn step(
        &mut self,
        parameters: &mut [&mut Tensor<F>],
        gradients: &[&Tensor<F>],
    ) -> Result<(), IntegrationError> {
        if parameters.len() != gradients.len() {
            return Err(IntegrationError::ModuleCompatibility(
                "Parameter and gradient count mismatch".to_string(),
            ));
        }
        for (i, (param, grad)) in parameters.iter_mut().zip(gradients.iter()).enumerate() {
            let param_id = format!("param_{i}");
            if let Some(param_state) = self.state.param_state.get_mut(&param_id) {
                Self::update_parameter_adam(&self.config, param, grad, param_state)?;
            }
        }
        self.state.increment_step();
        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients live on the autograd tensors, not in the optimizer,
        // so there is nothing to clear here.
    }

    fn learning_rate(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn state(&self) -> OptimizerState<'_, F> {
        self.state.clone()
    }

    fn set_state(&mut self, state: OptimizerState<'_, F>) {
        // As in `SGDOptimizer::set_state`, only the scalar parts of the
        // snapshot survive; per-parameter tensor state is dropped.
        self.state = OptimizerState {
            step_count: state.step_count,
            param_state: HashMap::new(),
            global_state: state.global_state,
            config: state.config,
        };
    }

    fn parameter_groups(&self) -> &[ParameterGroup<F>] {
        &self.parameter_groups
    }

    fn add_parameter_group(&mut self, group: ParameterGroup<F>) {
        self.parameter_groups.push(group);
    }
}
impl<'a, F: Float> AdamOptimizer<'a, F> {
    fn update_parameter_adam(
        _config: &OptimizerConfig,
        _param: &mut Tensor<F>,
        _grad: &Tensor<F>,
        param_state: &mut ParameterState<'a, F>,
    ) -> Result<(), IntegrationError> {
        // Placeholder: only advances the per-parameter step counter; the
        // moment updates and bias-corrected parameter step are not wired up.
        param_state.step += 1;
        Ok(())
    }
}
pub trait LearningRateScheduler<F: Float> {
    fn get_lr(&self) -> f64;
    fn step(&mut self, optimizer: &mut dyn AutogradOptimizer<F>);
    fn step_with_metric(&mut self, optimizer: &mut dyn AutogradOptimizer<F>, metric: f64);
}
pub struct StepLRScheduler {
    initial_lr: f64,
    step_size: usize,
    gamma: f64,
    current_step: usize,
}

impl StepLRScheduler {
    pub fn new(initial_lr: f64, step_size: usize, gamma: f64) -> Self {
        Self {
            initial_lr,
            step_size,
            gamma,
            current_step: 0,
        }
    }
}

impl<F: Float> LearningRateScheduler<F> for StepLRScheduler {
    fn get_lr(&self) -> f64 {
        // Integer division gives the number of completed decay intervals.
        let decay_factor = (self.current_step / self.step_size) as f64;
        self.initial_lr * self.gamma.powf(decay_factor)
    }

    fn step(&mut self, optimizer: &mut dyn AutogradOptimizer<F>) {
        self.current_step += 1;
        let new_lr = <Self as LearningRateScheduler<F>>::get_lr(self);
        optimizer.set_learning_rate(new_lr);
    }

    fn step_with_metric(&mut self, optimizer: &mut dyn AutogradOptimizer<F>, _metric: f64) {
        // Step decay does not use the metric.
        self.step(optimizer);
    }
}
pub struct CosineAnnealingLRScheduler {
    initial_lr: f64,
    min_lr: f64,
    t_max: usize,
    current_step: usize,
}

impl CosineAnnealingLRScheduler {
    pub fn new(initial_lr: f64, t_max: usize, min_lr: f64) -> Self {
        Self {
            initial_lr,
            min_lr,
            t_max,
            current_step: 0,
        }
    }
}

impl<F: Float> LearningRateScheduler<F> for CosineAnnealingLRScheduler {
    fn get_lr(&self) -> f64 {
        let progress = (self.current_step as f64) / (self.t_max as f64);
        let cosine_factor = 0.5 * (1.0 + (std::f64::consts::PI * progress).cos());
        self.min_lr + (self.initial_lr - self.min_lr) * cosine_factor
    }

    fn step(&mut self, optimizer: &mut dyn AutogradOptimizer<F>) {
        self.current_step += 1;
        let new_lr = <Self as LearningRateScheduler<F>>::get_lr(self);
        optimizer.set_learning_rate(new_lr);
    }

    fn step_with_metric(&mut self, optimizer: &mut dyn AutogradOptimizer<F>, _metric: f64) {
        // Cosine annealing does not use the metric.
        self.step(optimizer);
    }
}
pub struct OptimizerFactory;

impl OptimizerFactory {
    pub fn sgd<'a, F: Float>(
        learning_rate: f64,
        momentum: f64,
    ) -> Box<dyn AutogradOptimizer<F> + 'a> {
        Box::new(SGDOptimizer::new(learning_rate, momentum))
    }

    pub fn adam<'a, F: Float>(learning_rate: f64) -> Box<dyn AutogradOptimizer<F> + 'a> {
        Box::new(AdamOptimizer::default_adam(learning_rate))
    }

    pub fn adam_custom<'a, F: Float>(
        learning_rate: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
    ) -> Box<dyn AutogradOptimizer<F> + 'a> {
        Box::new(AdamOptimizer::new(learning_rate, beta1, beta2, eps))
    }
}
impl<F: Float> SciRS2Integration for OptimizerState<'_, F> {
    fn module_name() -> &'static str {
        "scirs2-optim"
    }

    fn module_version() -> &'static str {
        "0.1.0"
    }

    fn check_compatibility() -> Result<(), IntegrationError> {
        if super::check_compatibility("scirs2-autograd", "scirs2-optim")? {
            Ok(())
        } else {
            Err(IntegrationError::ModuleCompatibility(
                "Version mismatch".to_string(),
            ))
        }
    }
}
#[allow(dead_code)]
pub fn optimizer_to_scirs2_data<'a, F: Float>(
    optimizer: &dyn AutogradOptimizer<F>,
) -> SciRS2Data<'a, F> {
    let state = optimizer.state();
    SciRS2Data::new()
        .add_metadata("module_name".to_string(), "scirs2-optim".to_string())
        .add_metadata("optimizer_name".to_string(), optimizer.name().to_string())
        .add_metadata("step_count".to_string(), state.step_count.to_string())
        .add_metadata(
            "learning_rate".to_string(),
            optimizer.learning_rate().to_string(),
        )
}
#[cfg(test)]
mod tests {
    use super::*;
    #[allow(unused_imports)]
    use crate::tensor::Tensor;

    #[test]
    fn test_optimizer_state() {
        let mut state = OptimizerState::<f32>::new();
        let param_state = ParameterState::new();
        state.set_param_state("param_0".to_string(), param_state);
        assert_eq!(state.step_count, 0);
        assert!(state.get_param_state("param_0").is_some());
        state.increment_step();
        assert_eq!(state.step_count, 1);
    }

    #[test]
    fn test_parameter_group() {
        let group = ParameterGroup::<f32>::new("default".to_string(), 0.01)
            .add_parameter("param_0".to_string())
            .weight_decay(1e-4)
            .config("momentum".to_string(), StateValue::Float(0.9));
        assert_eq!(group.learning_rate, 0.01);
        assert_eq!(group.weight_decay, 1e-4);
        assert_eq!(group.parameters.len(), 1);
        assert!(group.config.contains_key("momentum"));
    }

    #[test]
    fn test_sgd_optimizer() {
        let optimizer = SGDOptimizer::<f32>::new(0.01, 0.9);
        assert_eq!(optimizer.name(), "SGD");
        assert_eq!(optimizer.learning_rate(), 0.01);
        assert_eq!(optimizer.state().param_state.len(), 0);
    }

    #[test]
    fn test_adam_optimizer() {
        let optimizer = AdamOptimizer::<f32>::default_adam(0.001);
        assert_eq!(optimizer.name(), "Adam");
        assert_eq!(optimizer.learning_rate(), 0.001);
        assert_eq!(optimizer.state().param_state.len(), 0);
    }

    #[test]
    fn test_learning_rate_scheduler() {
        let mut scheduler = StepLRScheduler::new(0.1, 5, 0.5);
        let mut optimizer = SGDOptimizer::<f64>::new(0.1, 0.0);
        assert_eq!(
            <StepLRScheduler as LearningRateScheduler<f64>>::get_lr(&scheduler),
            0.1f64
        );
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        assert!(<StepLRScheduler as LearningRateScheduler<f64>>::get_lr(&scheduler) < 0.1);
    }

    #[test]
    fn test_cosine_annealing_scheduler() {
        let mut scheduler = CosineAnnealingLRScheduler::new(0.1, 10, 0.0);
        let mut optimizer = SGDOptimizer::<f64>::new(0.1, 0.0);
        let initial_lr: f64 =
            <CosineAnnealingLRScheduler as LearningRateScheduler<f64>>::get_lr(&scheduler);
        for _ in 0..5 {
            scheduler.step(&mut optimizer);
        }
        let halfway_lr =
            <CosineAnnealingLRScheduler as LearningRateScheduler<f64>>::get_lr(&scheduler);
        assert!(halfway_lr < initial_lr);
    }

    #[test]
    #[ignore = "Factory tests skipped due to lifetime complexity"]
    fn test_optimizer_factory() {}

    #[test]
    fn test_scirs2_integration() {
        let state = OptimizerState::<f32>::new();
        assert_eq!(state.step_count, 0);
        assert!(state.param_state.is_empty());
    }
}