//! Tabular Q-learning optimizer that treats optimization as a sequential
//! decision problem: states are discretized (objective value, step) pairs and
//! actions are parameter-update strategies.

use super::{utils, OptimizationAction, OptimizationState, RLOptimizationConfig, RLOptimizer};
use crate::error::{OptimizeError, OptimizeResult};
use crate::result::OptimizeResults;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::random::{rng, Rng, RngExt};
use std::collections::HashMap;

/// Q-learning optimizer with a tabular state-action value function.
#[derive(Debug, Clone)]
pub struct QLearningOptimizer {
    config: RLOptimizationConfig,
    /// Tabular Q-values keyed by a discretized state-action string.
    q_table: HashMap<String, f64>,
    /// Current epsilon for epsilon-greedy selection; decays after each episode.
    exploration_rate: f64,
    /// Best parameters observed across all episodes.
    best_params: Array1<f64>,
    /// Best objective value observed across all episodes.
    best_objective: f64,
}

impl QLearningOptimizer {
    /// Creates a new optimizer for a problem with `num_params` parameters.
    pub fn new(config: RLOptimizationConfig, num_params: usize) -> Self {
        let exploration_rate = config.exploration_rate;
        Self {
            config,
            q_table: HashMap::new(),
            exploration_rate,
            best_params: Array1::zeros(num_params),
            best_objective: f64::INFINITY,
        }
    }

    /// Discretizes a (state, action) pair into a hashable key. The objective
    /// value is bucketed at a resolution of 0.1 and the step count in buckets
    /// of 10, so nearby states share Q-values.
    fn state_action_key(&self, state: &OptimizationState, action: &OptimizationAction) -> String {
        let obj_bucket = (state.objective_value * 10.0) as i32;
        let step_bucket = state.step / 10;
        let action_id = match action {
            OptimizationAction::GradientStep { .. } => 0,
            OptimizationAction::RandomPerturbation { .. } => 1,
            OptimizationAction::MomentumUpdate { .. } => 2,
            OptimizationAction::AdaptiveLearningRate { .. } => 3,
            OptimizationAction::ResetToBest => 4,
            OptimizationAction::Terminate => 5,
        };
        format!("{}_{}__{}", obj_bucket, step_bucket, action_id)
    }

    fn get_q_value(&self, state: &OptimizationState, action: &OptimizationAction) -> f64 {
        let key = self.state_action_key(state, action);
        *self.q_table.get(&key).unwrap_or(&0.0)
    }

    fn update_q_value(
        &mut self,
        state: &OptimizationState,
        action: &OptimizationAction,
        new_value: f64,
    ) {
        let key = self.state_action_key(state, action);
        self.q_table.insert(key, new_value);
    }

    /// Enumerates the discrete action set with fixed hyperparameters.
    fn get_possible_actions(&self) -> Vec<OptimizationAction> {
        vec![
            OptimizationAction::GradientStep {
                learning_rate: 0.01,
            },
            OptimizationAction::RandomPerturbation { magnitude: 0.1 },
            OptimizationAction::MomentumUpdate { momentum: 0.9 },
            OptimizationAction::AdaptiveLearningRate { factor: 0.5 },
            OptimizationAction::ResetToBest,
            OptimizationAction::Terminate,
        ]
    }
}

impl RLOptimizer for QLearningOptimizer {
    fn config(&self) -> &RLOptimizationConfig {
        &self.config
    }

    /// Epsilon-greedy action selection: explore uniformly at random with
    /// probability `exploration_rate`, otherwise exploit the action with the
    /// highest Q-value in the current state.
    fn select_action(&mut self, state: &OptimizationState) -> OptimizationAction {
        if rng().random_range(0.0..1.0) < self.exploration_rate {
            // Explore: pick a random action.
            let actions = self.get_possible_actions();
            let idx = rng().random_range(0..actions.len());
            actions[idx].clone()
        } else {
            // Exploit: scan the action set for the highest Q-value.
            let actions = self.get_possible_actions();
            let mut best_action = actions[0].clone();
            let mut best_q = self.get_q_value(state, &best_action);
            for action in &actions[1..] {
                let q_value = self.get_q_value(state, action);
                if q_value > best_q {
                    best_q = q_value;
                    best_action = action.clone();
                }
            }
            best_action
        }
    }

    fn update(&mut self, experience: &super::Experience) -> Result<(), OptimizeError> {
        // Standard tabular Q-learning update:
        // Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        let current_q = self.get_q_value(&experience.state, &experience.action);
        let max_next_q = if experience.done {
            // Terminal states have no future value.
            0.0
        } else {
            let actions = self.get_possible_actions();
            actions
                .iter()
                .map(|a| self.get_q_value(&experience.next_state, a))
                .fold(f64::NEG_INFINITY, f64::max)
        };
        let target = experience.reward + self.config.discount_factor * max_next_q;
        let new_q = current_q + self.config.learning_rate * (target - current_q);
        self.update_q_value(&experience.state, &experience.action, new_q);
        Ok(())
    }

    fn run_episode<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut current_params = initial_params.to_owned();
        let mut current_state = utils::create_state(current_params.clone(), objective, 0, None);
        let mut momentum = Array1::zeros(initial_params.len());

        for step in 0..self.config.max_steps_per_episode {
            let action = self.select_action(&current_state);
            let new_params =
                utils::apply_action(&current_state, &action, &self.best_params, &mut momentum);
            let new_state =
                utils::create_state(new_params, objective, step + 1, Some(&current_state));

            // Reward is the decrease in objective value (positive = improvement).
            let reward = current_state.objective_value - new_state.objective_value;
            let experience = super::Experience {
                state: current_state.clone(),
                action: action.clone(),
                reward,
                next_state: new_state.clone(),
                done: utils::should_terminate(&new_state, self.config.max_steps_per_episode),
            };
            self.update(&experience)?;

            // Track the best solution seen so far across all episodes.
            if new_state.objective_value < self.best_objective {
                self.best_objective = new_state.objective_value;
                self.best_params = new_state.parameters.clone();
            }

            current_state = new_state;
            current_params = current_state.parameters.clone();

            if utils::should_terminate(&current_state, self.config.max_steps_per_episode)
                || matches!(action, OptimizationAction::Terminate)
            {
                break;
            }
        }

        // Decay exploration after each episode, clamped to the configured minimum.
        self.exploration_rate = (self.exploration_rate * self.config.exploration_decay)
            .max(self.config.min_exploration_rate);

        Ok(OptimizeResults::<f64> {
            x: current_params,
            fun: current_state.objective_value,
            success: current_state.convergence_metrics.relative_objective_change < 1e-6,
            nit: current_state.step,
            message: "Q-learning episode completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
            nfev: current_state.step,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: if current_state.convergence_metrics.relative_objective_change < 1e-6 {
                0
            } else {
                1
            },
        })
    }

    fn train<F>(
        &mut self,
        objective: &F,
        initial_params: &ArrayView1<f64>,
    ) -> OptimizeResult<OptimizeResults<f64>>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        let mut best_result = OptimizeResults::<f64> {
            x: initial_params.to_owned(),
            fun: f64::INFINITY,
            success: false,
            nit: 0,
            nfev: 0,
            njev: 0,
            nhev: 0,
            maxcv: 0,
            status: 0,
            message: "Training not completed".to_string(),
            jac: None,
            hess: None,
            constr: None,
        };

        for _episode in 0..self.config.num_episodes {
            let result = self.run_episode(objective, initial_params)?;
            if result.fun < best_result.fun {
                best_result = result;
            }
        }

        // Report the best solution tracked across all episodes, which may be
        // better than the final state of the best single episode.
        best_result.x = self.best_params.clone();
        best_result.fun = self.best_objective;
        best_result.message = "Q-learning training completed".to_string();
        Ok(best_result)
    }

    fn reset(&mut self) {
        self.q_table.clear();
        self.exploration_rate = self.config.exploration_rate;
        self.best_objective = f64::INFINITY;
        self.best_params.fill(0.0);
    }
}

#[allow(dead_code)]
pub fn placeholder() {}
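
// The tests below are a minimal usage sketch, not a definitive test suite.
// They assume `RLOptimizationConfig` implements `Default`; if it does not,
// construct the config explicitly with the fields the optimizer reads
// (learning_rate, discount_factor, exploration_rate, exploration_decay,
// min_exploration_rate, max_steps_per_episode, num_episodes).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn state_action_key_is_consistent_with_bucketing() {
        let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
        let optimizer = QLearningOptimizer::new(RLOptimizationConfig::default(), 1);
        let state = utils::create_state(Array1::from_vec(vec![1.5]), &objective, 0, None);
        let key = optimizer.state_action_key(&state, &OptimizationAction::Terminate);
        // The key encodes the objective bucket (0.1 resolution), the step
        // bucket (width 10), and the action id (5 = Terminate).
        let expected = format!(
            "{}_{}__5",
            (state.objective_value * 10.0) as i32,
            state.step / 10
        );
        assert_eq!(key, expected);
    }

    #[test]
    fn training_on_a_sphere_function_returns_a_finite_best() {
        // Smoke test on f(x) = sum(x_i^2): training should complete without
        // error and report a finite best objective value.
        let objective = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();
        let initial = Array1::from_vec(vec![1.0, -2.0]);
        let mut optimizer =
            QLearningOptimizer::new(RLOptimizationConfig::default(), initial.len());
        let result = optimizer
            .train(&objective, &initial.view())
            .expect("training should not error");
        assert!(result.fun.is_finite());
    }
}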