use crate::error::OptimizeResult;
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::{rng, Rng, RngExt};
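// Reinforcement-learning approaches to optimization. Each submodule
// implements one strategy; the shared types below define the common
// state/action/reward interface they build on.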
pub mod actor_critic;
pub mod bandit_optimization;
pub mod evolutionary_strategies;
pub mod meta_learning;
pub mod policy_gradient;
pub mod q_learning_optimization;
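// Re-export everything at the module root; some names overlap between
// submodules, hence the `ambiguous_glob_reexports` allowances.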
#[allow(ambiguous_glob_reexports)]
pub use actor_critic::*;
#[allow(ambiguous_glob_reexports)]
pub use bandit_optimization::*;
#[allow(ambiguous_glob_reexports)]
pub use evolutionary_strategies::*;
#[allow(ambiguous_glob_reexports)]
pub use meta_learning::*;
#[allow(ambiguous_glob_reexports)]
pub use policy_gradient::*;
#[allow(ambiguous_glob_reexports)]
pub use q_learning_optimization::*;
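/// Shared hyperparameters for the RL-based optimizers.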
#[derive(Debug, Clone)]
pub struct RLOptimizationConfig {
pub num_episodes: usize,
pub max_steps_per_episode: usize,
pub learning_rate: f64,
pub discount_factor: f64,
pub exploration_rate: f64,
pub exploration_decay: f64,
pub min_exploration_rate: f64,
pub batch_size: usize,
pub memory_size: usize,
pub use_experience_replay: bool,
}
impl Default for RLOptimizationConfig {
fn default() -> Self {
Self {
num_episodes: 1000,
max_steps_per_episode: 100,
learning_rate: 0.001,
discount_factor: 0.99,
exploration_rate: 0.1,
exploration_decay: 0.995,
min_exploration_rate: 0.01,
batch_size: 32,
memory_size: 10000,
use_experience_replay: true,
}
}
}
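/// Snapshot of an optimization trajectory at one step; this is the "state"
/// an RL agent observes.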
#[derive(Debug, Clone)]
pub struct OptimizationState {
pub parameters: Array1<f64>,
pub objective_value: f64,
pub gradient: Option<Array1<f64>>,
pub step: usize,
pub objective_history: Vec<f64>,
pub convergence_metrics: ConvergenceMetrics,
}
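/// Diagnostics describing how close the trajectory is to convergence.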
#[derive(Debug, Clone)]
pub struct ConvergenceMetrics {
pub relative_objective_change: f64,
pub gradient_norm: Option<f64>,
pub parameter_change_norm: f64,
pub steps_since_improvement: usize,
}
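/// Actions an agent can take to advance the optimization.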
#[derive(Debug, Clone)]
pub enum OptimizationAction {
GradientStep { learning_rate: f64 },
RandomPerturbation { magnitude: f64 },
MomentumUpdate { momentum: f64 },
AdaptiveLearningRate { factor: f64 },
ResetToBest,
Terminate,
}
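/// One (state, action, reward, next-state) transition, as stored for
/// experience replay.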
#[derive(Debug, Clone)]
pub struct Experience {
pub state: OptimizationState,
pub action: OptimizationAction,
pub reward: f64,
pub next_state: OptimizationState,
pub done: bool,
}
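/// Common interface for RL-driven optimizers.
///
/// A typical training call looks like the sketch below, where `MyAgent` is a
/// hypothetical implementor of this trait:
///
/// ```ignore
/// let mut agent = MyAgent::new(RLOptimizationConfig::default());
/// let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
/// let x0 = Array1::zeros(2);
/// let result = agent.train(&objective, &x0.view())?;
/// ```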
pub trait RLOptimizer {
fn config(&self) -> &RLOptimizationConfig;
fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction;
fn update(&mut self, experience: &Experience) -> OptimizeResult<()>;
fn run_episode<F>(
&mut self,
objective: &F,
initial_params: &ArrayView1<f64>,
) -> OptimizeResult<OptimizeResults<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64;
fn train<F>(
&mut self,
objective: &F,
initial_params: &ArrayView1<f64>,
) -> OptimizeResult<OptimizeResults<f64>>
where
F: Fn(&ArrayView1<f64>) -> f64;
fn reset(&mut self);
}
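/// Maps a (state, action, next-state) transition to a scalar reward.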
pub trait RewardFunction {
fn compute_reward(
&self,
prev_state: &OptimizationState,
action: &OptimizationAction,
new_state: &OptimizationState,
) -> f64;
}
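/// Reward shaped around objective improvement: a positive term for each
/// decrease, a small per-step penalty, and a bonus once the relative
/// objective change stalls below 1e-6.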
#[derive(Debug, Clone)]
pub struct ImprovementReward {
pub improvement_scale: f64,
pub step_penalty: f64,
pub convergence_bonus: f64,
}
impl Default for ImprovementReward {
fn default() -> Self {
Self {
improvement_scale: 10.0,
step_penalty: 0.01,
convergence_bonus: 1.0,
}
}
}
impl RewardFunction for ImprovementReward {
fn compute_reward(
&self,
prev_state: &OptimizationState,
_action: &OptimizationAction,
new_state: &OptimizationState,
) -> f64 {
let improvement = prev_state.objective_value - new_state.objective_value;
let improvement_reward = self.improvement_scale * improvement;
let step_penalty = -self.step_penalty;
let convergence_bonus = if new_state.convergence_metrics.relative_objective_change < 1e-6 {
self.convergence_bonus
} else {
0.0
};
improvement_reward + step_penalty + convergence_bonus
}
}
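/// Fixed-capacity circular buffer of transitions for experience replay.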
#[derive(Debug, Clone)]
pub struct ExperienceBuffer {
pub buffer: Vec<Experience>,
pub max_size: usize,
pub position: usize,
}
impl ExperienceBuffer {
pub fn new(max_size: usize) -> Self {
Self {
buffer: Vec::with_capacity(max_size),
max_size,
position: 0,
}
}
pub fn add(&mut self, experience: Experience) {
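        // Fill until capacity, then overwrite the oldest entry in circular order.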
if self.buffer.len() < self.max_size {
self.buffer.push(experience);
} else {
self.buffer[self.position] = experience;
self.position = (self.position + 1) % self.max_size;
}
}
    /// Sample `batch_size` experiences uniformly at random, with
    /// replacement, so a batch may contain duplicates.
    pub fn sample_batch(&self, batch_size: usize) -> Vec<Experience> {
        let mut batch = Vec::with_capacity(batch_size);
        for _ in 0..batch_size.min(self.buffer.len()) {
            let idx = rng().random_range(0..self.buffer.len());
            batch.push(self.buffer[idx].clone());
        }
        batch
    }
pub fn size(&self) -> usize {
self.buffer.len()
}
}
pub mod utils {
use super::*;
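    /// Build an `OptimizationState` for `parameters`, deriving convergence
    /// metrics from `prev_state` when one is available.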
pub fn create_state<F>(
parameters: Array1<f64>,
objective: &F,
step: usize,
prev_state: Option<&OptimizationState>,
) -> OptimizationState
where
F: Fn(&ArrayView1<f64>) -> f64,
{
        let objective_value = objective(&parameters.view());
let convergence_metrics = if let Some(prev) = prev_state {
let relative_change = (prev.objective_value - objective_value).abs()
/ (prev.objective_value.abs() + 1e-12);
let param_change = if parameters.len() == prev.parameters.len() {
                (&parameters - &prev.parameters)
.mapv(|x| x * x)
.sum()
.sqrt()
} else {
parameters.mapv(|x| x * x).sum().sqrt()
};
let steps_since_improvement = if objective_value < prev.objective_value {
0
} else {
prev.convergence_metrics.steps_since_improvement + 1
};
ConvergenceMetrics {
relative_objective_change: relative_change,
gradient_norm: None,
parameter_change_norm: param_change,
steps_since_improvement,
}
} else {
ConvergenceMetrics {
relative_objective_change: f64::INFINITY,
gradient_norm: None,
parameter_change_norm: 0.0,
steps_since_improvement: 0,
}
};
let mut objective_history = prev_state
.map(|s| s.objective_history.clone())
.unwrap_or_default();
objective_history.push(objective_value);
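        // Keep only the 10 most recent objective values.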
if objective_history.len() > 10 {
objective_history.remove(0);
}
OptimizationState {
parameters,
objective_value,
            gradient: None,
            step,
objective_history,
convergence_metrics,
}
}
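    /// Apply `action` to produce new candidate parameters. No analytic
    /// gradient is available here, so gradient-like actions use random
    /// direction estimates as stochastic surrogates.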
pub fn apply_action(
state: &OptimizationState,
action: &OptimizationAction,
best_params: &Array1<f64>,
momentum: &mut Array1<f64>,
) -> Array1<f64> {
match action {
OptimizationAction::GradientStep { learning_rate } => {
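                // Random step scaled by the learning rate (stochastic
                // surrogate for a true gradient step).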
let mut new_params = state.parameters.clone();
for i in 0..new_params.len() {
let step = (scirs2_core::random::rng().random::<f64>() - 0.5) * learning_rate;
new_params[i] += step;
}
new_params
}
OptimizationAction::RandomPerturbation { magnitude } => {
let mut new_params = state.parameters.clone();
for i in 0..new_params.len() {
let perturbation =
(scirs2_core::random::rng().random::<f64>() - 0.5) * 2.0 * magnitude;
new_params[i] += perturbation;
}
new_params
}
OptimizationAction::MomentumUpdate {
momentum: momentum_coeff,
} => {
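                // Resize the momentum buffer on dimension change, then blend
                // a random gradient estimate into the running momentum.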
if momentum.len() != state.parameters.len() {
*momentum = Array1::zeros(state.parameters.len());
}
for i in 0..momentum.len().min(state.parameters.len()) {
let gradient_estimate =
(scirs2_core::random::rng().random::<f64>() - 0.5) * 0.1;
momentum[i] =
momentum_coeff * momentum[i] + (1.0 - momentum_coeff) * gradient_estimate;
}
&state.parameters - &*momentum
}
            OptimizationAction::AdaptiveLearningRate { factor } => {
                // Scale a decaying base step size by the requested factor.
                let step_size = factor * 0.01 / (1.0 + state.step as f64 * 0.01);
                let direction = Array1::from(vec![step_size; state.parameters.len()]);
                &state.parameters - &direction
            }
OptimizationAction::ResetToBest => best_params.clone(),
OptimizationAction::Terminate => state.parameters.clone(),
}
}
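    /// Terminate when the step budget is exhausted, the relative objective
    /// change drops below 1e-8, or there has been no improvement for more
    /// than 50 consecutive steps.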
pub fn should_terminate(state: &OptimizationState, max_steps: usize) -> bool {
state.step >= max_steps
|| state.convergence_metrics.relative_objective_change < 1e-8
|| state.convergence_metrics.steps_since_improvement > 50
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_optimization_state_creation() {
let params = Array1::from(vec![1.0, 2.0]);
let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let state = utils::create_state(params, &objective, 0, None);
assert_eq!(state.parameters.len(), 2);
assert_eq!(state.objective_value, 5.0);
assert_eq!(state.step, 0);
}
#[test]
fn test_experience_buffer() {
let mut buffer = ExperienceBuffer::new(5);
let params = Array1::from(vec![1.0]);
let objective = |x: &ArrayView1<f64>| x[0].powi(2);
let state = utils::create_state(params.clone(), &objective, 0, None);
let experience = Experience {
state: state.clone(),
action: OptimizationAction::GradientStep {
learning_rate: 0.01,
},
reward: 1.0,
next_state: state,
done: false,
};
buffer.add(experience);
assert_eq!(buffer.size(), 1);
let batch = buffer.sample_batch(1);
assert_eq!(batch.len(), 1);
}
#[test]
fn test_improvement_reward() {
let reward_fn = ImprovementReward::default();
let params1 = Array1::from(vec![2.0]);
let params2 = Array1::from(vec![1.0]);
let objective = |x: &ArrayView1<f64>| x[0].powi(2);
let state1 = utils::create_state(params1, &objective, 0, None);
let state2 = utils::create_state(params2, &objective, 1, Some(&state1));
let action = OptimizationAction::GradientStep { learning_rate: 0.1 };
let reward = reward_fn.compute_reward(&state1, &action, &state2);
assert!(reward > 0.0);
}
#[test]
fn test_action_application() {
let params = Array1::from(vec![1.0, 2.0]);
let objective = |x: &ArrayView1<f64>| x[0].powi(2) + x[1].powi(2);
let state = utils::create_state(params.clone(), &objective, 0, None);
let mut momentum = Array1::zeros(2);
let action = OptimizationAction::RandomPerturbation { magnitude: 0.1 };
        let new_params = utils::apply_action(&state, &action, &params, &mut momentum);
assert_eq!(new_params.len(), 2);
assert!(new_params != state.parameters);
}
#[test]
fn test_termination_condition() {
let params = Array1::from(vec![1.0]);
let objective = |x: &ArrayView1<f64>| x[0].powi(2);
let state = utils::create_state(params, &objective, 100, None);
assert!(utils::should_terminate(&state, 50));
}
#[test]
fn test_convergence_metrics() {
let params1 = Array1::from(vec![2.0]);
let params2 = Array1::from(vec![1.9]);
let objective = |x: &ArrayView1<f64>| x[0].powi(2);
let state1 = utils::create_state(params1, &objective, 0, None);
let state2 = utils::create_state(params2, &objective, 1, Some(&state1));
assert!(state2.convergence_metrics.relative_objective_change > 0.0);
assert!(state2.convergence_metrics.parameter_change_norm > 0.0);
        assert_eq!(state2.convergence_metrics.steps_since_improvement, 0);
    }
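    // End-to-end sketch wiring the utilities above into one episode loop; the
    // fixed `GradientStep` policy stands in for a learned agent.
    #[test]
    fn test_manual_episode_loop() {
        let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
        let reward_fn = ImprovementReward::default();
        let mut buffer = ExperienceBuffer::new(64);
        let mut momentum = Array1::zeros(2);
        let mut state = utils::create_state(Array1::from(vec![1.0, 2.0]), &objective, 0, None);
        let mut best_params = state.parameters.clone();
        let mut best_value = state.objective_value;
        for step in 1..=20 {
            let action = OptimizationAction::GradientStep { learning_rate: 0.1 };
            let new_params = utils::apply_action(&state, &action, &best_params, &mut momentum);
            let next_state = utils::create_state(new_params, &objective, step, Some(&state));
            let reward = reward_fn.compute_reward(&state, &action, &next_state);
            let done = utils::should_terminate(&next_state, 20);
            buffer.add(Experience {
                state: state.clone(),
                action,
                reward,
                next_state: next_state.clone(),
                done,
            });
            if next_state.objective_value < best_value {
                best_value = next_state.objective_value;
                best_params = next_state.parameters.clone();
            }
            state = next_state;
            if done {
                break;
            }
        }
        assert!(buffer.size() > 0);
        assert!(best_value <= 5.0);
    }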
}