#[allow(dead_code)]
use super::{
PolicyNetwork, RLOptimizationMetrics, RLOptimizerConfig, RLScheduler, ScheduleType,
TrajectoryBatch, ValueNetwork,
};
use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::collections::HashMap;
/// Policy-gradient algorithm selected by `PolicyGradientConfig::method`.
///
/// `PolicyGradientOptimizer::update` dispatches on this value as follows:
/// `Reinforce` and `PPOClip` run their own update rules, while `PPOAdaptiveKL` and `TRPO`
/// currently reuse the clipped-PPO update. `ActorCritic`, `IMPALA` and `A3C` are rejected
/// with a "Method not implemented" error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolicyGradientMethod {
    /// REINFORCE: log-probabilities weighted by returns, or by advantages when a baseline is used.
    Reinforce,
    /// Actor-critic; not dispatched by the optimizer yet.
    ActorCritic,
    /// PPO with a clipped probability-ratio surrogate objective.
    PPOClip,
    /// PPO with an adaptive KL penalty; currently handled like `PPOClip`.
    PPOAdaptiveKL,
    /// Trust Region Policy Optimization; currently handled like `PPOClip`.
    TRPO,
    /// IMPALA; not dispatched by the optimizer yet.
    IMPALA,
    /// Asynchronous advantage actor-critic; not dispatched by the optimizer yet.
    A3C,
}
/// Configuration for `PolicyGradientOptimizer`.
#[derive(Debug, Clone)]
pub struct PolicyGradientConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Shared RL settings. The update paths read the discount factor, GAE lambda, epoch count,
    /// mini-batch size, loss coefficients and gradient-norm limit from here.
    pub base_config: RLOptimizerConfig<T>,
    /// Update rule that `PolicyGradientOptimizer::update` dispatches to.
    pub method: PolicyGradientMethod,
    /// Settings for the PPO update paths.
    pub ppo_config: PPOConfig<T>,
    /// TRPO settings.
    pub trpo_config: TRPOConfig<T>,
    /// Policy learning-rate scheduler; cloned into the optimizer when it is constructed.
    pub policy_scheduler: Option<RLScheduler<T>>,
    /// Value learning-rate scheduler; cloned into the optimizer when it is constructed.
    pub value_scheduler: Option<RLScheduler<T>>,
    /// When true and a value network is present, REINFORCE weights log-probabilities by
    /// advantages instead of returns.
    pub use_baseline: bool,
    /// Importance-sampling toggle. NOTE(review): not read by the optimizer in this file.
    pub importance_sampling: bool,
    /// Presumably the cap on importance-sampling ratios. NOTE(review): not read by the
    /// optimizer in this file.
    pub max_is_ratio: T,
}
/// Hyper-parameters for the PPO update paths.
#[derive(Debug, Clone)]
pub struct PPOConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Probability ratios are clipped to `[1 - clip_epsilon, 1 + clip_epsilon]`.
    pub clip_epsilon: T,
    /// Dual-clip toggle. NOTE(review): not read by the optimizer in this file.
    pub dual_clip: bool,
    /// Whether the value prediction is clipped around the stored value estimate.
    pub value_clip: bool,
    /// Maximum deviation of the clipped value prediction from the stored value estimate.
    pub value_clip_range: T,
    /// Target KL divergence. With `early_stop_on_kl`, the update stops once its approximate-KL
    /// statistic exceeds twice this value.
    pub target_kl: T,
    /// Initial KL penalty coefficient, copied into the optimizer when it is constructed.
    pub kl_coeff: T,
    /// Adaptation factor for the KL coefficient. NOTE(review): not read in this file.
    pub kl_coeff_adapt_factor: T,
    /// Enables KL-based early stopping of the PPO epochs.
    pub early_stop_on_kl: bool,
}
/// TRPO hyper-parameters.
///
/// NOTE(review): the TRPO update in this file delegates to clipped PPO, so none of these
/// fields are consulted yet. The descriptions below are inferred from the field names.
#[derive(Debug, Clone)]
pub struct TRPOConfig<T>
where
    T: Float + Debug + Send + Sync + 'static,
{
    /// Presumably the trust-region bound on the KL divergence per update.
    pub max_kl: T,
    /// Presumably the step shrink factor used during backtracking line search.
    pub backtrack_factor: T,
    /// Presumably the maximum number of line-search backtracks.
    pub max_backtracks: usize,
    /// Presumably the number of conjugate-gradient iterations.
    pub cg_iters: usize,
    /// Presumably the damping added to the Fisher-vector product.
    pub cg_damping: T,
    /// Presumably the conjugate-gradient convergence tolerance.
    pub cg_tolerance: T,
    /// Presumably toggles natural-gradient steps.
    pub use_natural_gradients: bool,
}
impl<T> Default for PolicyGradientConfig<T>
where
    T: Float + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
{
    /// Defaults to clipped PPO with constant learning rates of 3e-4 (policy) and 1e-3 (value).
    /// The baseline is enabled, importance sampling is off and `max_is_ratio` is 2.0.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero when the cast fails.
        let cast = |value: f64| T::from(value).unwrap_or_else(T::zero);

        let policy_scheduler = RLScheduler::new(cast(3e-4), ScheduleType::Constant);
        let value_scheduler = RLScheduler::new(cast(1e-3), ScheduleType::Constant);

        Self {
            base_config: RLOptimizerConfig::default(),
            method: PolicyGradientMethod::PPOClip,
            ppo_config: PPOConfig::default(),
            trpo_config: TRPOConfig::default(),
            policy_scheduler: Some(policy_scheduler),
            value_scheduler: Some(value_scheduler),
            use_baseline: true,
            importance_sampling: false,
            max_is_ratio: cast(2.0),
        }
    }
}
impl<T> Default for PPOConfig<T>
where
    T: Float + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
{
    /// Default PPO settings:
    /// - ratio and value clipping at 0.2, with value clipping and KL early stopping enabled;
    /// - target KL 0.01;
    /// - KL coefficient 0.2 with adaptation factor 1.5;
    /// - dual clipping off.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero when the cast fails.
        let cast = |value: f64| T::from(value).unwrap_or_else(T::zero);
        Self {
            clip_epsilon: cast(0.2),
            dual_clip: false,
            value_clip: true,
            value_clip_range: cast(0.2),
            target_kl: cast(0.01),
            kl_coeff: cast(0.2),
            kl_coeff_adapt_factor: cast(1.5),
            early_stop_on_kl: true,
        }
    }
}
impl<T> Default for TRPOConfig<T>
where
    T: Float + Debug + Send + Sync + 'static + scirs2_core::numeric::FromPrimitive,
{
    /// Default TRPO settings:
    /// - max KL 0.01;
    /// - backtrack factor 0.5, with 10 backtracks;
    /// - 10 CG iterations, damping 0.1 and tolerance 1e-8;
    /// - natural gradients enabled.
    fn default() -> Self {
        // Converts an f64 literal into `T`, falling back to zero when the cast fails.
        let cast = |value: f64| T::from(value).unwrap_or_else(T::zero);
        Self {
            max_kl: cast(0.01),
            backtrack_factor: cast(0.5),
            max_backtracks: 10,
            cg_iters: 10,
            cg_damping: cast(0.1),
            cg_tolerance: cast(1e-8),
            use_natural_gradients: true,
        }
    }
}
/// Policy-gradient optimizer.
///
/// It owns:
/// - a policy network and an optional value network;
/// - their learning-rate schedulers;
/// - the most recent metrics;
/// - a bounded buffer of trajectory batches.
pub struct PolicyGradientOptimizer<T, P, V>
where
    T: Float + Debug,
    P: PolicyNetwork<T>,
    V: ValueNetwork<T>,
{
    /// Full configuration, including the selected method.
    config: PolicyGradientConfig<T>,
    /// Network whose parameters receive the policy updates.
    policy_network: P,
    /// Optional critic used for bootstrapping, value loss and baselines.
    value_network: Option<V>,
    /// Working copy of `config.policy_scheduler`.
    policy_scheduler: Option<RLScheduler<T>>,
    /// Working copy of `config.value_scheduler`.
    value_scheduler: Option<RLScheduler<T>>,
    /// Metrics from the most recent update.
    metrics: RLOptimizationMetrics<T>,
    /// Number of completed update calls.
    update_count: usize,
    /// KL penalty coefficient, initialised from `config.ppo_config.kl_coeff`.
    kl_coeff: T,
    /// Trajectory batches collected by `add_trajectory`.
    trajectory_buffer: Vec<TrajectoryBatch<T>>,
    /// Maximum number of batches retained in `trajectory_buffer`.
    max_buffer_size: usize,
}
impl<T, P, V> PolicyGradientOptimizer<T, P, V>
where
    T: Float
        + Send
        + Sync
        + ScalarOperand
        + std::ops::AddAssign
        + std::iter::Sum
        + scirs2_core::numeric::FromPrimitive,
    P: PolicyNetwork<T>,
    V: ValueNetwork<T>,
{
    /// Creates an optimizer from `config`, a policy network and an optional value network.
    ///
    /// The PPO KL coefficient and both learning-rate schedulers are copied out of `config`.
    /// The trajectory buffer starts empty and holds at most 1000 batches.
    pub fn new(
        config: PolicyGradientConfig<T>,
        policy_network: P,
        value_network: Option<V>,
    ) -> Self {
        let kl_coeff = config.ppo_config.kl_coeff;
        let policy_scheduler = config.policy_scheduler.clone();
        let value_scheduler = config.value_scheduler.clone();
        Self {
            config,
            policy_network,
            value_network,
            policy_scheduler,
            value_scheduler,
            metrics: RLOptimizationMetrics::default(),
            update_count: 0,
            kl_coeff,
            trajectory_buffer: Vec::new(),
            max_buffer_size: 1000,
        }
    }

    /// Runs one optimisation step on `trajectory` using the configured method.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` for methods without an implementation
    /// (`ActorCritic`, `IMPALA`, `A3C`). Otherwise propagates any error from the selected update.
    pub fn update(&mut self, trajectory: TrajectoryBatch<T>) -> Result<RLOptimizationMetrics<T>> {
        // Exhaustive on purpose: adding a variant should force a decision here.
        match self.config.method {
            PolicyGradientMethod::PPOClip => self.update_ppo_clip(trajectory),
            PolicyGradientMethod::PPOAdaptiveKL => self.update_ppo_adaptive_kl(trajectory),
            PolicyGradientMethod::TRPO => self.update_trpo(trajectory),
            PolicyGradientMethod::Reinforce => self.update_reinforce(trajectory),
            PolicyGradientMethod::ActorCritic
            | PolicyGradientMethod::IMPALA
            | PolicyGradientMethod::A3C => Err(OptimError::InvalidConfig(
                "Method not implemented".to_string(),
            )),
        }
    }

    /// Clipped-PPO update.
    ///
    /// Steps:
    /// 1. Compute advantages (bootstrapped from the value of the last observation).
    /// 2. Run `n_epochs` passes over mini-batches, optimising the clipped surrogate plus the
    ///    value and entropy terms.
    /// 3. Record losses, clip fraction and approximate KL, averaged over the mini-batch
    ///    updates actually applied.
    ///
    /// When `early_stop_on_kl` is set, all remaining epochs are skipped as soon as one
    /// mini-batch's approximate KL exceeds twice `target_kl`.
    ///
    /// # Errors
    ///
    /// - `OptimError::InvalidConfig` if `mini_batchsize` is zero or the trajectory is empty.
    /// - Any error from network evaluation, advantage computation or parameter updates.
    fn update_ppo_clip(
        &mut self,
        mut trajectory: TrajectoryBatch<T>,
    ) -> Result<RLOptimizationMetrics<T>> {
        let n_epochs = self.config.base_config.n_epochs;
        let mini_batch_size = self.config.base_config.mini_batchsize;
        if mini_batch_size == 0 {
            return Err(OptimError::InvalidConfig(
                "mini_batchsize must be greater than zero".to_string(),
            ));
        }
        // Bootstrapping reads the last observation row, which does not exist for an empty batch.
        if trajectory.observations.nrows() == 0 {
            return Err(OptimError::InvalidConfig(
                "Cannot run a PPO update on an empty trajectory".to_string(),
            ));
        }

        // Loop-invariant hyper-parameters, read once instead of per element.
        let clip_eps = self.config.ppo_config.clip_epsilon;
        let ratio_low = T::one() - clip_eps;
        let ratio_high = T::one() + clip_eps;
        let value_clip = self.config.ppo_config.value_clip;
        let value_clip_range = self.config.ppo_config.value_clip_range;
        let early_stop_on_kl = self.config.ppo_config.early_stop_on_kl;
        let kl_stop_threshold = self.config.ppo_config.target_kl * (T::one() + T::one());
        let value_loss_coeff = self.config.base_config.value_loss_coeff;
        let entropy_coeff = self.config.base_config.entropy_coeff;

        let next_value = self.bootstrap_value(&trajectory)?;
        trajectory.compute_advantages(
            self.config.base_config.discount_factor,
            self.config.base_config.gae_lambda,
            next_value,
        )?;

        let mut total_policy_loss = T::zero();
        let mut total_value_loss = T::zero();
        let mut total_entropy_loss = T::zero();
        let mut total_clip_fraction = T::zero();
        let mut total_kl = T::zero();
        // Mini-batch updates actually applied; fewer than nominal when stopping early.
        let mut n_updates = 0usize;

        'epochs: for _epoch in 0..n_epochs {
            for mini_batch in trajectory.get_mini_batches(mini_batch_size) {
                let policy_eval = self
                    .policy_network
                    .evaluate_actions(&mini_batch.observations, &mini_batch.actions)?;
                let log_ratio = &policy_eval.log_probs - &mini_batch.log_probs;
                let ratio = log_ratio.mapv(|x| x.exp());
                let clipped_ratio = ratio.mapv(|r| r.max(ratio_low).min(ratio_high));
                let surr1 = &ratio * &mini_batch.advantages;
                let surr2 = &clipped_ratio * &mini_batch.advantages;
                let batch_rows = mini_batch.observations.nrows();

                // Pessimistic (element-wise minimum) clipped surrogate, negated for minimisation.
                let policy_loss = -Self::mean_over(
                    surr1
                        .iter()
                        .zip(surr2.iter())
                        .map(|(&s1, &s2)| s1.min(s2))
                        .sum::<T>(),
                    batch_rows,
                );
                let entropy_loss = Self::negative_mean(policy_eval.entropy.iter());

                let value_loss = match self.value_network {
                    Some(ref value_net) => {
                        let predicted_values =
                            value_net.evaluate_value(&mini_batch.observations)?;
                        if value_clip {
                            // Keep the new prediction within ±value_clip_range of the stored
                            // value and take the larger squared error of the two.
                            let value_pred_clipped = &mini_batch.values
                                + (&predicted_values - &mini_batch.values).mapv(|diff| {
                                    diff.max(-value_clip_range).min(value_clip_range)
                                });
                            let unclipped_sq =
                                (&predicted_values - &mini_batch.returns).mapv(|x| x * x);
                            let clipped_sq =
                                (&value_pred_clipped - &mini_batch.returns).mapv(|x| x * x);
                            Self::mean_over(
                                unclipped_sq
                                    .iter()
                                    .zip(clipped_sq.iter())
                                    .map(|(&v1, &v2)| v1.max(v2))
                                    .sum::<T>(),
                                batch_rows,
                            )
                        } else {
                            (&predicted_values - &mini_batch.returns)
                                .mapv(|x| x * x)
                                .mean()
                                .unwrap_or(T::zero())
                        }
                    }
                    None => T::zero(),
                };

                let total_loss =
                    policy_loss + value_loss_coeff * value_loss + entropy_coeff * entropy_loss;
                self.update_networks_with_loss(total_loss, policy_loss, value_loss)?;
                n_updates += 1;

                total_policy_loss = total_policy_loss + policy_loss;
                total_value_loss = total_value_loss + value_loss;
                total_entropy_loss = total_entropy_loss + entropy_loss;

                let n_clipped = ratio
                    .iter()
                    .filter(|&&r| r < ratio_low || r > ratio_high)
                    .count();
                total_clip_fraction = total_clip_fraction
                    + Self::mean_over(Self::count_as_float(n_clipped), ratio.len());

                // Mean squared log-ratio, used as this mini-batch's approximate-KL statistic.
                let batch_kl = log_ratio.mapv(|x| x * x).mean().unwrap_or(T::zero());
                total_kl = total_kl + batch_kl;

                // Leave every remaining epoch, not just this one, once the policy moved too far.
                if early_stop_on_kl && batch_kl > kl_stop_threshold {
                    break 'epochs;
                }
            }
        }

        self.step_schedulers();
        self.update_count += 1;

        self.metrics.policy_loss = Self::mean_over(total_policy_loss, n_updates);
        self.metrics.value_loss = Self::mean_over(total_value_loss, n_updates);
        self.metrics.entropy_loss = Self::mean_over(total_entropy_loss, n_updates);
        self.metrics.total_loss = self.metrics.policy_loss
            + value_loss_coeff * self.metrics.value_loss
            + entropy_coeff * self.metrics.entropy_loss;
        self.metrics.clip_fraction = Some(Self::mean_over(total_clip_fraction, n_updates));
        self.metrics.kl_divergence = Some(Self::mean_over(total_kl, n_updates));
        Ok(self.metrics.clone())
    }

    /// PPO with an adaptive KL penalty.
    ///
    /// Placeholder: currently identical to the clipped-PPO update, and `kl_coeff` is not adapted.
    fn update_ppo_adaptive_kl(
        &mut self,
        trajectory: TrajectoryBatch<T>,
    ) -> Result<RLOptimizationMetrics<T>> {
        self.update_ppo_clip(trajectory)
    }

    /// TRPO update. Placeholder: currently identical to the clipped-PPO update.
    fn update_trpo(&mut self, trajectory: TrajectoryBatch<T>) -> Result<RLOptimizationMetrics<T>> {
        self.update_ppo_clip(trajectory)
    }

    /// Single-pass REINFORCE update over the whole trajectory.
    ///
    /// Log-probabilities are weighted by advantages when `use_baseline` is set and a value
    /// network exists; otherwise they are weighted by returns.
    ///
    /// NOTE(review): `compute_advantages` is not called here, so the batch's stored
    /// advantages/returns are used as-is. Confirm that callers populate them.
    fn update_reinforce(
        &mut self,
        trajectory: TrajectoryBatch<T>,
    ) -> Result<RLOptimizationMetrics<T>> {
        let policy_eval = self
            .policy_network
            .evaluate_actions(&trajectory.observations, &trajectory.actions)?;
        let entropy_loss = Self::negative_mean(policy_eval.entropy.iter());
        let policy_loss = if self.config.use_baseline && self.value_network.is_some() {
            -(policy_eval.log_probs * trajectory.advantages)
                .mean()
                .unwrap_or(T::zero())
        } else {
            -(policy_eval.log_probs * trajectory.returns)
                .mean()
                .unwrap_or(T::zero())
        };
        let total_loss = policy_loss + self.config.base_config.entropy_coeff * entropy_loss;
        self.update_networks_with_loss(total_loss, policy_loss, T::zero())?;

        // Advance schedules and the counter exactly like the PPO path does.
        self.step_schedulers();
        self.update_count += 1;

        self.metrics.policy_loss = policy_loss;
        self.metrics.entropy_loss = entropy_loss;
        self.metrics.total_loss = total_loss;
        Ok(self.metrics.clone())
    }

    /// Actor-critic update; currently delegates to REINFORCE.
    // Not dispatched from `update` yet: `ActorCritic` there still returns an error.
    #[allow(dead_code)]
    fn update_actor_critic(
        &mut self,
        trajectory: TrajectoryBatch<T>,
    ) -> Result<RLOptimizationMetrics<T>> {
        self.update_reinforce(trajectory)
    }

    /// Value estimate of the trajectory's last observation, or zero without a value network.
    /// The caller guarantees the trajectory has at least one row.
    fn bootstrap_value(&self, trajectory: &TrajectoryBatch<T>) -> Result<T> {
        match self.value_network {
            Some(ref value_net) => {
                // The `-1..` row slice is already a 1×obs_dim batch; no extra copy is needed.
                let last_obs = trajectory.observations.slice(s![-1.., ..]).to_owned();
                Ok(value_net.evaluate_value(&last_obs)?[0])
            }
            None => Ok(T::zero()),
        }
    }

    /// Steps both learning-rate schedulers (when present) and records the new rates.
    fn step_schedulers(&mut self) {
        if let Some(scheduler) = self.policy_scheduler.as_mut() {
            self.metrics.policy_lr = scheduler.step();
        }
        if let Some(scheduler) = self.value_scheduler.as_mut() {
            self.metrics.value_lr = scheduler.step();
        }
    }

    /// Negated mean of `values`; zero for an empty input instead of `0 / 0 = NaN`.
    fn negative_mean<'a, I>(values: I) -> T
    where
        I: IntoIterator<Item = &'a T>,
        T: 'a,
    {
        let (sum, count) = values
            .into_iter()
            .fold((T::zero(), 0usize), |(acc, n), &v| (acc + v, n + 1));
        if count == 0 {
            T::zero()
        } else {
            -sum / Self::count_as_float(count)
        }
    }

    /// `total / count`, or zero when `count` is zero (avoids NaN metrics).
    fn mean_over(total: T, count: usize) -> T {
        if count == 0 {
            T::zero()
        } else {
            total / Self::count_as_float(count)
        }
    }

    /// Converts an element count to `T`.
    fn count_as_float(count: usize) -> T {
        T::from(count).expect("a usize count is always representable by a Float type")
    }

    /// Computes, clips and applies gradients for the policy and (if present) value networks,
    /// then records the clipped gradient norms in the metrics.
    ///
    /// `_total_loss` is accepted for interface symmetry but not used.
    fn update_networks_with_loss(
        &mut self,
        _total_loss: T,
        policy_loss: T,
        value_loss: T,
    ) -> Result<()> {
        let max_norm = self.config.base_config.max_grad_norm;
        let policy_gradients = self.compute_policy_gradients(policy_loss)?;
        let clipped_policy_grads = self.clip_gradients(&policy_gradients, max_norm)?;
        let clipped_value_grads = if self.value_network.is_some() {
            let value_gradients = self.compute_value_gradients(value_loss)?;
            Some(self.clip_gradients(&value_gradients, max_norm)?)
        } else {
            None
        };

        self.update_policy_parameters(&clipped_policy_grads)?;
        if let Some(ref value_grads) = clipped_value_grads {
            self.update_value_parameters(value_grads)?;
        }

        self.metrics.policy_grad_norm = self.compute_gradient_norm(&clipped_policy_grads);
        if let Some(ref value_grads) = clipped_value_grads {
            self.metrics.value_grad_norm = self.compute_gradient_norm(value_grads);
        }
        Ok(())
    }

    /// Placeholder policy "gradients".
    ///
    /// Every parameter receives the uniform vector `loss / len`. No back-propagation through
    /// the network happens here.
    fn compute_policy_gradients(&self, loss: T) -> Result<HashMap<String, Array1<T>>> {
        let mut gradients = HashMap::new();
        for (param_name, param_values) in self.policy_network.get_parameters() {
            let n = param_values.len();
            let grad = Array1::ones(n) * loss / Self::count_as_float(n);
            gradients.insert(param_name, grad);
        }
        Ok(gradients)
    }

    /// Placeholder value "gradients" computed like `compute_policy_gradients`.
    /// Returns an empty map when there is no value network.
    fn compute_value_gradients(&self, loss: T) -> Result<HashMap<String, Array1<T>>> {
        let mut gradients = HashMap::new();
        if let Some(ref value_net) = self.value_network {
            for (param_name, param_values) in value_net.get_parameters() {
                let n = param_values.len();
                let grad = Array1::ones(n) * loss / Self::count_as_float(n);
                gradients.insert(param_name, grad);
            }
        }
        Ok(gradients)
    }

    /// Scales all gradients by one common factor so that their joint L2 norm is at most
    /// `max_norm`. Gradients already within the limit are returned unchanged.
    fn clip_gradients(
        &self,
        gradients: &HashMap<String, Array1<T>>,
        max_norm: T,
    ) -> Result<HashMap<String, Array1<T>>> {
        let total_norm = self.compute_gradient_norm(gradients);
        let clip_factor = if total_norm > max_norm {
            max_norm / total_norm
        } else {
            T::one()
        };
        Ok(gradients
            .iter()
            .map(|(param_name, grad)| (param_name.clone(), grad * clip_factor))
            .collect())
    }

    /// Forwards `gradients` to the policy network's `update_parameters`.
    fn update_policy_parameters(&mut self, gradients: &HashMap<String, Array1<T>>) -> Result<()> {
        self.policy_network.update_parameters(gradients)?;
        Ok(())
    }

    /// Forwards `gradients` to the value network's `update_parameters`, if one exists.
    fn update_value_parameters(&mut self, gradients: &HashMap<String, Array1<T>>) -> Result<()> {
        if let Some(ref mut value_net) = self.value_network {
            value_net.update_parameters(gradients)?;
        }
        Ok(())
    }

    /// Joint L2 norm over every gradient vector in the map.
    fn compute_gradient_norm(&self, gradients: &HashMap<String, Array1<T>>) -> T {
        gradients
            .values()
            .map(|grad| grad.iter().map(|&g| g * g).sum::<T>())
            .fold(T::zero(), |acc, sq| acc + sq)
            .sqrt()
    }

    /// Metrics recorded by the most recent update.
    pub fn get_metrics(&self) -> &RLOptimizationMetrics<T> {
        &self.metrics
    }

    /// Appends `trajectory` to the buffer, evicting the oldest batch once the buffer holds
    /// more than `max_buffer_size` batches.
    pub fn add_trajectory(&mut self, trajectory: TrajectoryBatch<T>) {
        self.trajectory_buffer.push(trajectory);
        if self.trajectory_buffer.len() > self.max_buffer_size {
            self.trajectory_buffer.remove(0);
        }
    }

    /// Concatenates all buffered trajectories and runs `update` on the result.
    /// The buffer is left intact; call `clear_buffer` to drop it.
    ///
    /// # Errors
    ///
    /// - `OptimError::InvalidConfig` when the buffer is empty or the buffered batches have
    ///   inconsistent dimensions.
    /// - Any error from `update`.
    pub fn update_from_buffer(&mut self) -> Result<RLOptimizationMetrics<T>> {
        if self.trajectory_buffer.is_empty() {
            return Err(OptimError::InvalidConfig(
                "No trajectories in buffer".to_string(),
            ));
        }
        let combined = self.combine_trajectories()?;
        self.update(combined)
    }

    /// Stacks every buffered trajectory row-wise into one `TrajectoryBatch`.
    fn combine_trajectories(&self) -> Result<TrajectoryBatch<T>> {
        let first = self.trajectory_buffer.first().ok_or_else(|| {
            OptimError::InvalidConfig("No trajectories to combine".to_string())
        })?;
        let obs_dim = first.observations.ncols();
        let action_dim = first.actions.ncols();
        // `assign` below panics on a column mismatch, so reject inconsistent batches up front.
        if self.trajectory_buffer.iter().any(|t| {
            t.observations.ncols() != obs_dim || t.actions.ncols() != action_dim
        }) {
            return Err(OptimError::InvalidConfig(format!(
                "Inconsistent trajectory dimensions in buffer (expected {} observation and {} action columns)",
                obs_dim, action_dim
            )));
        }

        let total_size: usize = self
            .trajectory_buffer
            .iter()
            .map(|t| t.observations.nrows())
            .sum();
        let mut combined_obs = Array2::zeros((total_size, obs_dim));
        let mut combined_actions = Array2::zeros((total_size, action_dim));
        let mut combined_log_probs = Array1::zeros(total_size);
        let mut combined_rewards = Array1::zeros(total_size);
        let mut combined_values = Array1::zeros(total_size);
        let mut combined_dones = Vec::with_capacity(total_size);

        let mut offset = 0;
        for trajectory in &self.trajectory_buffer {
            let size = trajectory.observations.nrows();
            combined_obs
                .slice_mut(s![offset..offset + size, ..])
                .assign(&trajectory.observations);
            combined_actions
                .slice_mut(s![offset..offset + size, ..])
                .assign(&trajectory.actions);
            combined_log_probs
                .slice_mut(s![offset..offset + size])
                .assign(&trajectory.log_probs);
            combined_rewards
                .slice_mut(s![offset..offset + size])
                .assign(&trajectory.rewards);
            combined_values
                .slice_mut(s![offset..offset + size])
                .assign(&trajectory.values);
            // Iterate rather than `as_slice()`, which is `None` for non-contiguous arrays.
            combined_dones.extend(trajectory.dones.iter().cloned());
            offset += size;
        }

        let combined_dones_array = Array1::from_vec(combined_dones);
        TrajectoryBatch::new(
            combined_obs,
            combined_actions,
            combined_log_probs,
            combined_rewards,
            combined_values,
            combined_dones_array,
        )
    }

    /// Drops every buffered trajectory.
    pub fn clear_buffer(&mut self) {
        self.trajectory_buffer.clear();
    }
}
use scirs2_core::ndarray::s;