use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Random};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
/// Update rule applied by [`OnlineOptimizer`] on every `online_update` call.
///
/// Only `AdaptiveSGD` carries its own starting learning rate; every other variant starts
/// from a fixed built-in value chosen by `OnlineOptimizer::new`.
#[derive(Debug, Clone)]
pub enum OnlineLearningStrategy {
    /// Gradient descent whose step is rescaled by `adaptation_method`.
    AdaptiveSGD {
        /// Learning rate the optimizer starts with.
        initial_lr: f64,
        /// Rule that rescales the step over time.
        adaptation_method: LearningRateAdaptation,
    },
    /// Diagonal Newton-style step using `gradient^2 + damping` as the curvature estimate.
    OnlineNewton {
        /// Added to the squared gradient so the denominator stays away from zero.
        damping: f64,
        /// NOTE(review): not read by `OnlineOptimizer` at present.
        hessian_window: usize,
    },
    /// FTRL-style update computed from the running sum of gradients.
    FTRL {
        /// Coordinates whose summed-gradient magnitude is at most this value become zero.
        l1_regularization: f64,
        /// Added to each coordinate's denominator.
        l2_regularization: f64,
        /// Exponent `p` in the step-count decay `lr / t^p`.
        learning_rate_power: f64,
    },
    /// Mirror descent under the selected mirror map.
    MirrorDescent {
        /// Geometry of the update.
        mirror_function: MirrorFunction,
        /// Regularization strength; scales the soft-threshold used by `MirrorFunction::L1`.
        regularization: f64,
    },
    /// Multi-task variant.
    ///
    /// NOTE(review): neither field is read by `OnlineOptimizer`; updates are plain SGD.
    AdaptiveMultiTask {
        /// Similarity threshold between tasks (currently unused).
        similarity_threshold: f64,
        /// Whether learning rates adapt per task (currently unused).
        task_lr_adaptation: bool,
    },
}
/// How `OnlineLearningStrategy::AdaptiveSGD` rescales each gradient step.
#[derive(Debug, Clone)]
pub enum LearningRateAdaptation {
    /// Divide the step by `epsilon + sqrt(sum of squared gradients)`.
    AdaGrad { epsilon: f64 },
    /// Divide the step by `sqrt(moving average of squared gradients + epsilon)`;
    /// `decay` weights the previous average.
    RMSprop { decay: f64, epsilon: f64 },
    /// Bias-corrected first/second moment estimates with decay rates `beta1`/`beta2`;
    /// `epsilon` is added after the square root.
    Adam { beta1: f64, beta2: f64, epsilon: f64 },
    /// Multiply the learning rate by `decay_rate` before every step.
    ExponentialDecay { decay_rate: f64 },
    /// Use `lr / t^power` at (1-based) step `t`.
    InverseScaling { power: f64 },
}
/// Mirror map selected for `OnlineLearningStrategy::MirrorDescent`.
#[derive(Debug, Clone)]
pub enum MirrorFunction {
    /// Squared Euclidean norm: a plain gradient step.
    Euclidean,
    /// Negative entropy; the optimizer renormalises parameters so they sum to one.
    Entropy,
    /// Gradient step followed by soft-thresholding towards zero.
    L1,
    /// Nuclear norm. NOTE(review): currently handled exactly like `Euclidean`.
    Nuclear,
}
/// Continual-learning strategy applied by [`LifelongOptimizer`] after each task update.
#[derive(Debug, Clone)]
pub enum LifelongStrategy {
    /// Elastic Weight Consolidation configuration.
    ElasticWeightConsolidation {
        /// Weight handed to the EWC hook on every update.
        importance_weight: f64,
        /// Number of samples intended for Fisher estimation (not read in this module).
        fisher_samples: usize,
    },
    /// Progressive-network configuration.
    ProgressiveNetworks {
        /// Strength of lateral connections (not read in this module).
        lateral_strength: f64,
        /// When to add a new column (not read in this module).
        growth_strategy: ColumnGrowthStrategy,
    },
    /// Replay-memory configuration.
    MemoryAugmented {
        /// Requested replay-buffer size.
        memory_size: usize,
        /// Which stored example is evicted when the buffer is full.
        update_strategy: MemoryUpdateStrategy,
    },
    /// Meta-learning configuration (fields not read in this module).
    MetaLearning {
        meta_lr: f64,
        inner_steps: usize,
        task_embedding_size: usize,
    },
    /// Gradient Episodic Memory configuration (fields not read in this module).
    GradientEpisodicMemory {
        memory_per_task: usize,
        violation_tolerance: f64,
    },
}
/// Policy for growing columns under `LifelongStrategy::ProgressiveNetworks`.
///
/// NOTE(review): no code in this module reads this value yet.
#[derive(Debug, Clone)]
pub enum ColumnGrowthStrategy {
    /// One new column per task.
    PerTask,
    /// Grow based on a performance `threshold`.
    PerformanceBased { threshold: f64 },
    /// Grow every `interval` units.
    FixedInterval { interval: usize },
}
/// Eviction policy used when the replay buffer is full.
#[derive(Debug, Clone)]
pub enum MemoryUpdateStrategy {
    /// Evict the oldest example.
    FIFO,
    /// Evict an example at a uniformly random position.
    Random,
    /// Evict the example with the lowest importance score.
    ImportanceBased,
    /// Diversity-based eviction. NOTE(review): currently behaves like `FIFO`.
    GradientDiversity,
}
/// Optimizer that updates its parameters one gradient at a time.
///
/// The meaning of the accumulators depends on the selected [`OnlineLearningStrategy`].
#[derive(Debug)]
pub struct OnlineOptimizer<A: Float, D: Dimension> {
    /// Update rule used by every `online_update` call.
    strategy: OnlineLearningStrategy,
    /// Current parameter values.
    parameters: Array<A, D>,
    /// Per-coordinate running statistic: summed squared gradients (AdaGrad), moving average
    /// of squared gradients (RMSprop), first moment (Adam) or summed gradients (FTRL).
    gradient_accumulator: Array<A, D>,
    /// Second-moment estimate; allocated only for Adam.
    second_moment_accumulator: Option<Array<A, D>>,
    /// Learning rate currently in effect.
    current_lr: A,
    /// Number of updates applied so far.
    step_count: usize,
    /// Bounded window of the most recent losses.
    performance_history: VecDeque<A>,
    /// Accumulated regret estimate.
    regret_bound: A,
}
/// Optimizer that trains a sequence of tasks, one [`OnlineOptimizer`] per task.
#[derive(Debug)]
pub struct LifelongOptimizer<A: Float, D: Dimension> {
    /// Continual-learning strategy applied after each task update.
    strategy: LifelongStrategy,
    /// Per-task optimizers keyed by task id.
    task_optimizers: HashMap<String, OnlineOptimizer<A, D>>,
    // Cross-task state is initialised but not yet read by any strategy hook.
    #[allow(dead_code)]
    shared_knowledge: SharedKnowledge<A, D>,
    /// Pairwise task relationships.
    task_graph: TaskGraph,
    /// Replay buffer of past examples.
    memory_buffer: MemoryBuffer<A, D>,
    /// Id of the task receiving updates, if any.
    current_task: Option<String>,
    /// Losses recorded per task id.
    task_performance: HashMap<String, Vec<A>>,
}
/// Knowledge shared across tasks by a [`LifelongOptimizer`].
///
/// Every field is only initialised in this module and never read back, hence the
/// per-field `allow(dead_code)`.
#[derive(Debug)]
pub struct SharedKnowledge<A: Float, D: Dimension> {
    #[allow(dead_code)] // reserved: Fisher information estimate
    fisher_information: Option<Array<A, D>>,
    #[allow(dead_code)] // reserved: per-parameter importance
    important_parameters: Option<Array<A, D>>,
    #[allow(dead_code)] // reserved: embedding per task id
    task_embeddings: HashMap<String, Array1<A>>,
    #[allow(dead_code)] // reserved: transfer weight per (source, target) task pair
    transfer_weights: HashMap<(String, String), A>,
    #[allow(dead_code)] // reserved: meta-learned parameters
    meta_parameters: Option<Array1<A>>,
}
/// Relationships between tasks known to a [`LifelongOptimizer`].
#[derive(Debug)]
pub struct TaskGraph {
    /// Similarity score per ordered task-id pair; looked up in both orders.
    task_similarities: HashMap<(String, String), f64>,
    #[allow(dead_code)] // reserved: prerequisite tasks per task id, not read yet
    task_dependencies: HashMap<String, Vec<String>>,
    #[allow(dead_code)] // reserved: cluster label per task id, not read yet
    task_clusters: HashMap<String, String>,
}
/// Bounded replay buffer of past examples.
///
/// `examples` and `importance_scores` are pushed and removed together, so index `i`
/// in one always corresponds to index `i` in the other.
#[derive(Debug)]
pub struct MemoryBuffer<A: Float, D: Dimension> {
    /// Stored examples, oldest first.
    examples: VecDeque<MemoryExample<A, D>>,
    /// Size at which eviction kicks in.
    max_size: usize,
    /// Policy choosing which example to evict.
    update_strategy: MemoryUpdateStrategy,
    /// Importance score of each stored example.
    importance_scores: VecDeque<A>,
}
/// A single example stored in a [`MemoryBuffer`].
#[derive(Debug, Clone)]
pub struct MemoryExample<A: Float, D: Dimension> {
    /// Input array.
    pub input: Array<A, D>,
    /// Target array.
    pub target: Array<A, D>,
    /// Id of the task that produced the example.
    pub task_id: String,
    /// Importance score attached when the example was stored.
    pub importance: A,
    /// Gradient captured alongside the example, if any.
    pub gradient: Option<Array<A, D>>,
}
/// Summary reported by `OnlineOptimizer::get_performance_metrics`.
///
/// NOTE(review): `lr_stability` and `memory_efficiency` are currently filled with fixed
/// constants, and `adaptation_speed` with the raw step count.
#[derive(Debug, Clone)]
pub struct OnlinePerformanceMetrics<A: Float> {
    /// Accumulated regret.
    pub cumulative_regret: A,
    /// Mean of the recorded losses.
    pub average_loss: A,
    /// Learning-rate stability indicator.
    pub lr_stability: A,
    /// Adaptation-speed indicator.
    pub adaptation_speed: A,
    /// Memory-efficiency indicator.
    pub memory_efficiency: A,
}
impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
    OnlineOptimizer<A, D>
{
    /// Maximum number of recent losses kept in `performance_history`; the oldest is dropped
    /// first.
    const HISTORY_CAPACITY: usize = 1000;

    /// Converts an `f64` hyper-parameter into the optimizer's scalar type.
    ///
    /// # Panics
    /// Panics if `value` cannot be represented in `A`.
    fn scalar(value: f64) -> A {
        A::from(value).expect("unwrap failed")
    }

    /// Creates an optimizer for `strategy`, starting from `initial_parameters`.
    ///
    /// The accumulators are zero-filled with the parameter shape. A second-moment buffer is
    /// allocated only for `AdaptiveSGD` with `Adam`. The starting learning rate is
    /// `initial_lr` for `AdaptiveSGD`; otherwise it is 0.01 (OnlineNewton, MirrorDescent),
    /// 0.1 (FTRL) or 0.001 (AdaptiveMultiTask).
    ///
    /// # Panics
    /// Panics if the learning rate cannot be represented in `A`.
    pub fn new(strategy: OnlineLearningStrategy, initial_parameters: Array<A, D>) -> Self {
        let shape = initial_parameters.raw_dim();
        let uses_adam = matches!(
            &strategy,
            OnlineLearningStrategy::AdaptiveSGD {
                adaptation_method: LearningRateAdaptation::Adam { .. },
                ..
            }
        );
        let second_moment_accumulator = if uses_adam {
            Some(Array::zeros(shape.clone()))
        } else {
            None
        };
        let initial_lr = match &strategy {
            OnlineLearningStrategy::AdaptiveSGD { initial_lr, .. } => *initial_lr,
            OnlineLearningStrategy::OnlineNewton { .. } => 0.01,
            OnlineLearningStrategy::FTRL { .. } => 0.1,
            OnlineLearningStrategy::MirrorDescent { .. } => 0.01,
            OnlineLearningStrategy::AdaptiveMultiTask { .. } => 0.001,
        };
        Self {
            strategy,
            parameters: initial_parameters,
            gradient_accumulator: Array::zeros(shape),
            second_moment_accumulator,
            current_lr: Self::scalar(initial_lr),
            step_count: 0,
            performance_history: VecDeque::new(),
            regret_bound: A::zero(),
        }
    }

    /// Applies one online step with `gradient` and records `loss`.
    ///
    /// The step counter and loss history are updated before the strategy-specific update.
    /// Regret is updated afterwards.
    ///
    /// # Errors
    /// Propagates errors from the strategy update (currently only the entropy mirror map
    /// can fail).
    pub fn online_update(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        self.step_count += 1;
        self.performance_history.push_back(loss);
        if self.performance_history.len() > Self::HISTORY_CAPACITY {
            self.performance_history.pop_front();
        }
        // Cloning the small strategy enum releases the borrow on `self`, so the helpers
        // below can take `&mut self`.
        match self.strategy.clone() {
            OnlineLearningStrategy::AdaptiveSGD {
                adaptation_method, ..
            } => self.adaptive_sgd_update(gradient, &adaptation_method)?,
            OnlineLearningStrategy::OnlineNewton { damping, .. } => {
                self.online_newton_update(gradient, damping)?
            }
            OnlineLearningStrategy::FTRL {
                l1_regularization,
                l2_regularization,
                learning_rate_power,
            } => self.ftrl_update(
                gradient,
                l1_regularization,
                l2_regularization,
                learning_rate_power,
            )?,
            OnlineLearningStrategy::MirrorDescent {
                mirror_function,
                regularization,
            } => self.mirror_descent_update(gradient, &mirror_function, regularization)?,
            OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
                self.adaptive_multitask_update(gradient)?
            }
        }
        self.update_regret_bound(loss);
        Ok(())
    }

    /// Gradient step whose scale is chosen by `adaptation`.
    fn adaptive_sgd_update(
        &mut self,
        gradient: &Array<A, D>,
        adaptation: &LearningRateAdaptation,
    ) -> Result<()> {
        match adaptation {
            LearningRateAdaptation::AdaGrad { epsilon } => {
                // Accumulate squared gradients over every step seen so far.
                self.gradient_accumulator = &self.gradient_accumulator + &gradient.mapv(|g| g * g);
                let eps = Self::scalar(*epsilon);
                let denom = self.gradient_accumulator.mapv(|acc| eps + acc.sqrt());
                self.parameters = &self.parameters - &(gradient / &denom * self.current_lr);
            }
            LearningRateAdaptation::RMSprop { decay, epsilon } => {
                let rho = Self::scalar(*decay);
                let one_minus_rho = A::one() - rho;
                self.gradient_accumulator = &self.gradient_accumulator * rho
                    + &gradient.mapv(|g| g * g * one_minus_rho);
                let eps = Self::scalar(*epsilon);
                let denom = self.gradient_accumulator.mapv(|acc| (acc + eps).sqrt());
                self.parameters = &self.parameters - &(gradient / &denom * self.current_lr);
            }
            LearningRateAdaptation::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                let b1 = Self::scalar(*beta1);
                let b2 = Self::scalar(*beta2);
                let eps = Self::scalar(*epsilon);
                let t = A::from(self.step_count).expect("unwrap failed");
                // First moment lives in `gradient_accumulator`.
                self.gradient_accumulator =
                    &self.gradient_accumulator * b1 + gradient * (A::one() - b1);
                // Second moment buffer is allocated by `new` for Adam.
                if let Some(second_moment) = self.second_moment_accumulator.as_mut() {
                    let one_minus_b2 = A::one() - b2;
                    *second_moment =
                        &*second_moment * b2 + &gradient.mapv(|g| g * g * one_minus_b2);
                    let m_hat = &self.gradient_accumulator / (A::one() - b1.powf(t));
                    let v_hat = &*second_moment / (A::one() - b2.powf(t));
                    let denom = v_hat.mapv(|v| v.sqrt() + eps);
                    self.parameters = &self.parameters - &(m_hat / denom * self.current_lr);
                }
            }
            LearningRateAdaptation::ExponentialDecay { decay_rate } => {
                // The decay is applied before the step, so the first step already uses
                // `initial_lr * decay_rate`.
                self.current_lr = self.current_lr * Self::scalar(*decay_rate);
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
            LearningRateAdaptation::InverseScaling { power } => {
                let t = A::from(self.step_count).expect("unwrap failed");
                let lr_t = self.current_lr / t.powf(Self::scalar(*power));
                self.parameters = &self.parameters - gradient * lr_t;
            }
        }
        Ok(())
    }

    /// Diagonal Newton-style step with curvature estimate `g^2 + damping`.
    fn online_newton_update(&mut self, gradient: &Array<A, D>, damping: f64) -> Result<()> {
        let damping_val = Self::scalar(damping);
        let curvature = gradient.mapv(|g| g * g + damping_val);
        let newton_step = gradient / curvature;
        self.parameters = &self.parameters - &newton_step * self.current_lr;
        Ok(())
    }

    /// FTRL-style closed-form update from the running gradient sum `z`.
    ///
    /// Each coordinate becomes `-sign(z) * (|z| - l1) / (l2 + sqrt(|z|)) * lr / t^lr_power`,
    /// or zero when `|z| <= l1`.
    /// NOTE(review): textbook FTRL-Proximal uses the square root of the summed *squared*
    /// gradients in the denominator; this keeps the original formula unchanged.
    fn ftrl_update(
        &mut self,
        gradient: &Array<A, D>,
        l1_reg: f64,
        l2_reg: f64,
        lr_power: f64,
    ) -> Result<()> {
        self.gradient_accumulator = &self.gradient_accumulator + gradient;
        let t = A::from(self.step_count).expect("unwrap failed");
        let learning_rate = self.current_lr / t.powf(Self::scalar(lr_power));
        let l1 = Self::scalar(l1_reg);
        let l2 = Self::scalar(l2_reg);
        self.parameters = self.gradient_accumulator.mapv(|z| {
            let magnitude = z.abs();
            if magnitude <= l1 {
                A::zero()
            } else {
                let sign = if z > A::zero() { A::one() } else { -A::one() };
                -sign * (magnitude - l1) / (l2 + magnitude.sqrt())
            }
        }) * learning_rate;
        Ok(())
    }

    /// Mirror-descent step under `mirror_fn`.
    ///
    /// # Errors
    /// For `MirrorFunction::Entropy`, returns `OptimError::InvalidConfig` if any parameter
    /// is negative. It also errors if the updated weights cannot be normalised (their sum
    /// is zero or not finite). In both cases the parameters are left unchanged.
    fn mirror_descent_update(
        &mut self,
        gradient: &Array<A, D>,
        mirror_fn: &MirrorFunction,
        regularization: f64,
    ) -> Result<()> {
        match mirror_fn {
            MirrorFunction::Euclidean | MirrorFunction::Nuclear => {
                // NOTE(review): Nuclear has no dedicated implementation yet.
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
            MirrorFunction::Entropy => {
                // Exponentiated-gradient step: p_i <- p_i * exp(-lr * g_i), renormalised to
                // sum to one. A uniform shift of every log-weight (which is all a constant
                // regularization term contributes here) cancels in the normalisation, so it
                // is not applied.
                if self.parameters.iter().any(|&p| p < A::zero()) {
                    return Err(OptimError::InvalidConfig(
                        "Entropy mirror descent requires non-negative parameters".to_string(),
                    ));
                }
                let lr = self.current_lr;
                let updated = &self.parameters * &gradient.mapv(|g| A::exp(-lr * g));
                let total = updated.sum();
                if !(total > A::zero() && total.is_finite()) {
                    return Err(OptimError::InvalidConfig(
                        "Entropy mirror descent produced weights that cannot be normalised"
                            .to_string(),
                    ));
                }
                self.parameters = updated / total;
            }
            MirrorFunction::L1 => {
                // Gradient step followed by soft-thresholding (proximal operator of L1).
                let threshold = self.current_lr * Self::scalar(regularization);
                self.parameters = (&self.parameters - gradient * self.current_lr).mapv(|p| {
                    if p.abs() <= threshold {
                        A::zero()
                    } else {
                        p - p.signum() * threshold
                    }
                });
            }
        }
        Ok(())
    }

    /// Plain SGD step; the multi-task settings are not consulted yet.
    fn adaptive_multitask_update(&mut self, gradient: &Array<A, D>) -> Result<()> {
        self.parameters = &self.parameters - gradient * self.current_lr;
        Ok(())
    }

    /// Adds `max(0, loss - best recorded loss)` to the cumulative regret.
    ///
    /// NaN losses never panic. `Float::min` skips NaN when finding the best loss, and a
    /// NaN regret clamps to zero via `Float::max`.
    fn update_regret_bound(&mut self, loss: A) {
        let best_loss = self
            .performance_history
            .iter()
            .copied()
            .fold(A::infinity(), |best, l| best.min(l));
        let regret = loss - best_loss;
        self.regret_bound = self.regret_bound + regret.max(A::zero());
    }

    /// Current parameter values.
    pub fn parameters(&self) -> &Array<A, D> {
        &self.parameters
    }

    /// Summarises recorded losses and regret.
    ///
    /// NOTE(review): `lr_stability` and `memory_efficiency` are fixed constants and
    /// `adaptation_speed` is the raw step count.
    pub fn get_performance_metrics(&self) -> OnlinePerformanceMetrics<A> {
        let average_loss = if self.performance_history.is_empty() {
            A::zero()
        } else {
            let total: A = self.performance_history.iter().copied().sum();
            total / A::from(self.performance_history.len()).expect("unwrap failed")
        };
        OnlinePerformanceMetrics {
            cumulative_regret: self.regret_bound,
            average_loss,
            lr_stability: Self::scalar(1.0),
            adaptation_speed: Self::scalar(self.step_count as f64),
            memory_efficiency: Self::scalar(0.8),
        }
    }
}
impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
LifelongOptimizer<A, D>
{
pub fn new(strategy: LifelongStrategy) -> Self {
Self {
strategy,
task_optimizers: HashMap::new(),
shared_knowledge: SharedKnowledge {
fisher_information: None,
important_parameters: None,
task_embeddings: HashMap::new(),
transfer_weights: HashMap::new(),
meta_parameters: None,
},
task_graph: TaskGraph {
task_similarities: HashMap::new(),
task_dependencies: HashMap::new(),
task_clusters: HashMap::new(),
},
memory_buffer: MemoryBuffer {
examples: VecDeque::new(),
max_size: 1000,
update_strategy: MemoryUpdateStrategy::FIFO,
importance_scores: VecDeque::new(),
},
current_task: None,
task_performance: HashMap::new(),
}
}
pub fn start_task(&mut self, task_id: String, initial_parameters: Array<A, D>) -> Result<()> {
self.current_task = Some(task_id.clone());
let online_strategy = OnlineLearningStrategy::AdaptiveSGD {
initial_lr: 0.001,
adaptation_method: LearningRateAdaptation::Adam {
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
},
};
let task_optimizer = OnlineOptimizer::new(online_strategy, initial_parameters);
self.task_optimizers.insert(task_id.clone(), task_optimizer);
self.task_performance.insert(task_id, Vec::new());
Ok(())
}
pub fn update_current_task(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
let task_id = self
.current_task
.as_ref()
.ok_or_else(|| OptimError::InvalidConfig("No current task set".to_string()))?
.clone();
if let Some(optimizer) = self.task_optimizers.get_mut(&task_id) {
optimizer.online_update(gradient, loss)?;
}
if let Some(performance) = self.task_performance.get_mut(&task_id) {
performance.push(loss);
}
match &self.strategy {
LifelongStrategy::ElasticWeightConsolidation {
importance_weight, ..
} => {
self.apply_ewc_regularization(gradient, *importance_weight)?;
}
LifelongStrategy::ProgressiveNetworks { .. } => {
self.apply_progressive_networks(gradient)?;
}
LifelongStrategy::MemoryAugmented { .. } => {
self.update_memory_buffer(gradient, loss)?;
}
LifelongStrategy::MetaLearning { .. } => {
self.apply_meta_learning(gradient)?;
}
LifelongStrategy::GradientEpisodicMemory { .. } => {
self.apply_gem_constraints(gradient)?;
}
}
Ok(())
}
fn apply_ewc_regularization(
&mut self,
gradient: &Array<A, D>,
_importance_weight: f64,
) -> Result<()> {
Ok(())
}
fn apply_progressive_networks(&mut self, gradient: &Array<A, D>) -> Result<()> {
Ok(())
}
fn update_memory_buffer(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
if let Some(task_id) = &self.current_task {
let example = MemoryExample {
input: Array::zeros(gradient.raw_dim()), target: Array::zeros(gradient.raw_dim()), task_id: task_id.clone(),
importance: loss,
gradient: Some(gradient.clone()),
};
if self.memory_buffer.examples.len() >= self.memory_buffer.max_size {
match self.memory_buffer.update_strategy {
MemoryUpdateStrategy::FIFO => {
self.memory_buffer.examples.pop_front();
self.memory_buffer.importance_scores.pop_front();
}
MemoryUpdateStrategy::Random => {
let idx = thread_rng().gen_range(0..self.memory_buffer.examples.len());
self.memory_buffer.examples.remove(idx);
self.memory_buffer.importance_scores.remove(idx);
}
MemoryUpdateStrategy::ImportanceBased => {
if let Some(min_idx) = self
.memory_buffer
.importance_scores
.iter()
.enumerate()
.min_by(|a, b| a.1.partial_cmp(b.1).expect("unwrap failed"))
.map(|(idx, _)| idx)
{
self.memory_buffer.examples.remove(min_idx);
self.memory_buffer.importance_scores.remove(min_idx);
}
}
MemoryUpdateStrategy::GradientDiversity => {
self.memory_buffer.examples.pop_front();
self.memory_buffer.importance_scores.pop_front();
}
}
}
self.memory_buffer.examples.push_back(example);
self.memory_buffer.importance_scores.push_back(loss);
}
Ok(())
}
fn apply_meta_learning(&mut self, gradient: &Array<A, D>) -> Result<()> {
Ok(())
}
fn apply_gem_constraints(&mut self, gradient: &Array<A, D>) -> Result<()> {
Ok(())
}
pub fn compute_task_similarity(&self, task1: &str, task2: &str) -> f64 {
self.task_graph
.task_similarities
.get(&(task1.to_string(), task2.to_string()))
.or_else(|| {
self.task_graph
.task_similarities
.get(&(task2.to_string(), task1.to_string()))
})
.copied()
.unwrap_or(0.0)
}
pub fn get_lifelong_stats(&self) -> LifelongStats<A> {
let num_tasks = self.task_optimizers.len();
let avg_performance = if self.task_performance.is_empty() {
A::zero()
} else {
let total_performance: A = self.task_performance.values().flatten().copied().sum();
let total_samples = self
.task_performance
.values()
.map(|v| v.len())
.sum::<usize>();
if total_samples > 0 {
total_performance / A::from(total_samples).expect("unwrap failed")
} else {
A::zero()
}
};
LifelongStats {
num_tasks,
average_performance: avg_performance,
memory_usage: self.memory_buffer.examples.len(),
transfer_efficiency: A::from(0.8).expect("unwrap failed"), catastrophic_forgetting: A::from(0.1).expect("unwrap failed"), }
}
}
/// Summary reported by `LifelongOptimizer::get_lifelong_stats`.
///
/// NOTE(review): `transfer_efficiency` and `catastrophic_forgetting` are currently filled
/// with fixed constants rather than measured values.
#[derive(Debug, Clone)]
pub struct LifelongStats<A: Float> {
    /// Number of tasks that have an optimizer.
    pub num_tasks: usize,
    /// Mean of all recorded losses across tasks.
    pub average_performance: A,
    /// Number of examples held in the replay buffer.
    pub memory_usage: usize,
    /// Transfer-efficiency indicator.
    pub transfer_efficiency: A,
    /// Forgetting indicator.
    pub catastrophic_forgetting: A,
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Shared `[1.0, 2.0, 3.0]` starting point.
    fn initial_params() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0])
    }

    /// Shared `[0.1, 0.2, 0.3]` gradient.
    fn sample_gradient() -> Array1<f64> {
        Array1::from_vec(vec![0.1, 0.2, 0.3])
    }

    #[test]
    fn test_online_optimizer_creation() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
        };
        let optimizer = OnlineOptimizer::new(strategy, initial_params());
        assert_eq!(optimizer.step_count, 0);
        assert_relative_eq!(optimizer.current_lr, 0.01, epsilon = 1e-6);
    }

    #[test]
    fn test_online_update() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.1,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
        };
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params());
        optimizer
            .online_update(&sample_gradient(), 0.5)
            .expect("unwrap failed");
        assert_eq!(optimizer.step_count, 1);
        assert_eq!(optimizer.performance_history.len(), 1);
        assert_relative_eq!(optimizer.performance_history[0], 0.5, epsilon = 1e-6);
        // The decay is applied before the step: lr = 0.1 * 0.99, params -= g * lr.
        assert_relative_eq!(optimizer.current_lr, 0.099, epsilon = 1e-12);
        let want = [0.9901, 1.9802, 2.9703];
        for (actual, expected) in optimizer.parameters().iter().zip(want.iter()) {
            assert_relative_eq!(*actual, *expected, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_lifelong_optimizer_creation() {
        let strategy = LifelongStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };
        let optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        assert_eq!(optimizer.task_optimizers.len(), 0);
        assert!(optimizer.current_task.is_none());
    }

    #[test]
    fn test_task_management() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };
        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        optimizer
            .start_task("task1".to_string(), initial_params())
            .expect("unwrap failed");
        assert_eq!(optimizer.current_task, Some("task1".to_string()));
        assert!(optimizer.task_optimizers.contains_key("task1"));
        assert!(optimizer.task_performance.contains_key("task1"));
    }

    #[test]
    fn test_memory_buffer_update() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 2,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };
        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        optimizer.memory_buffer.max_size = 2;
        optimizer
            .start_task("task1".to_string(), initial_params())
            .expect("unwrap failed");
        let gradient = sample_gradient();
        optimizer
            .update_current_task(&gradient, 0.5)
            .expect("unwrap failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 1);
        optimizer
            .update_current_task(&gradient, 0.6)
            .expect("unwrap failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
        optimizer
            .update_current_task(&gradient, 0.7)
            .expect("unwrap failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
        // FIFO keeps the two most recent examples, in insertion order.
        let scores: Vec<f64> = optimizer
            .memory_buffer
            .importance_scores
            .iter()
            .copied()
            .collect();
        assert_eq!(scores, vec![0.6, 0.7]);
        assert!(optimizer
            .memory_buffer
            .examples
            .iter()
            .all(|example| example.task_id == "task1"));
        assert_eq!(optimizer.task_performance["task1"].len(), 3);
    }

    #[test]
    fn test_performance_metrics() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params());
        optimizer.performance_history.push_back(0.8);
        optimizer.performance_history.push_back(0.6);
        optimizer.performance_history.push_back(0.4);
        optimizer.regret_bound = 0.5;
        let metrics = optimizer.get_performance_metrics();
        assert_relative_eq!(metrics.cumulative_regret, 0.5, epsilon = 1e-6);
        assert_relative_eq!(metrics.average_loss, 0.6, epsilon = 1e-6);
    }

    #[test]
    fn test_regret_accumulates_against_best_loss() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 1.0 },
        };
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params());
        let gradient = sample_gradient();
        for &loss in &[0.5, 0.3, 0.4] {
            optimizer
                .online_update(&gradient, loss)
                .expect("unwrap failed");
        }
        // Only the third step is worse than the best so far: 0.4 - 0.3.
        assert_relative_eq!(
            optimizer.get_performance_metrics().cumulative_regret,
            0.1,
            epsilon = 1e-12
        );
    }

    #[test]
    fn test_lifelong_stats() {
        let strategy = LifelongStrategy::MetaLearning {
            meta_lr: 0.001,
            inner_steps: 5,
            task_embedding_size: 64,
        };
        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        optimizer
            .start_task("task1".to_string(), initial_params())
            .expect("unwrap failed");
        optimizer
            .start_task("task2".to_string(), initial_params())
            .expect("unwrap failed");
        optimizer
            .task_performance
            .get_mut("task1")
            .expect("unwrap failed")
            .extend(vec![0.8, 0.7]);
        optimizer
            .task_performance
            .get_mut("task2")
            .expect("unwrap failed")
            .extend(vec![0.9, 0.8]);
        let stats = optimizer.get_lifelong_stats();
        assert_eq!(stats.num_tasks, 2);
        assert_relative_eq!(stats.average_performance, 0.8, epsilon = 1e-6);
    }

    #[test]
    fn test_task_similarity_is_symmetric() {
        let strategy = LifelongStrategy::GradientEpisodicMemory {
            memory_per_task: 10,
            violation_tolerance: 0.0,
        };
        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        optimizer
            .task_graph
            .task_similarities
            .insert(("a".to_string(), "b".to_string()), 0.7);
        assert_relative_eq!(optimizer.compute_task_similarity("a", "b"), 0.7);
        assert_relative_eq!(optimizer.compute_task_similarity("b", "a"), 0.7);
        assert_relative_eq!(optimizer.compute_task_similarity("a", "c"), 0.0);
    }

    #[test]
    fn test_learning_rate_adaptations() {
        let strategies = vec![
            LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
            LearningRateAdaptation::RMSprop {
                decay: 0.9,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
            LearningRateAdaptation::InverseScaling { power: 0.5 },
        ];
        for adaptation in strategies {
            let label = format!("{:?}", adaptation);
            let strategy = OnlineLearningStrategy::AdaptiveSGD {
                initial_lr: 0.01,
                adaptation_method: adaptation,
            };
            let start = initial_params();
            let mut optimizer = OnlineOptimizer::new(strategy, start.clone());
            let result = optimizer.online_update(&sample_gradient(), 0.5);
            assert!(result.is_ok());
            assert_eq!(optimizer.step_count, 1);
            // Every gradient component is positive, so every parameter must decrease.
            for (after, before) in optimizer.parameters().iter().zip(start.iter()) {
                assert!(after < before, "{} did not move parameters downhill", label);
            }
        }
    }
}