//! Proximal Policy Optimisation (PPO) building blocks: a fixed-capacity
//! rollout buffer, Generalised Advantage Estimation (GAE), and the clipped
//! policy / value losses with basic training diagnostics.

use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use std::fmt::Debug;
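/// Hyperparameters for PPO.
///
/// * `clip_range` / `clip_range_vf` - clipping epsilons for the policy ratio
///   and the value prediction (a non-positive `clip_range_vf` disables value
///   clipping).
/// * `value_coeff` / `entropy_coeff` - weights of the value loss and the
///   entropy bonus in the combined objective.
/// * `gamma` / `lam` - discount factor and GAE lambda.
/// * `max_grad_norm`, `n_epochs`, `mini_batch_size`, `normalise_advantages` -
///   settings intended for the surrounding training loop; they are not used
///   directly by the loss functions in this module.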
#[derive(Debug, Clone)]
pub struct PPOConfig {
pub clip_range: f64,
pub clip_range_vf: f64,
pub value_coeff: f64,
pub entropy_coeff: f64,
pub gamma: f64,
pub lam: f64,
pub max_grad_norm: f64,
pub n_epochs: usize,
pub mini_batch_size: usize,
pub normalise_advantages: bool,
}
impl Default for PPOConfig {
fn default() -> Self {
Self {
clip_range: 0.2,
clip_range_vf: 0.2,
value_coeff: 0.5,
entropy_coeff: 0.01,
gamma: 0.99,
lam: 0.95,
max_grad_norm: 0.5,
n_epochs: 4,
mini_batch_size: 64,
normalise_advantages: true,
}
}
}
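/// A single environment transition together with the quantities PPO needs
/// for its update: the value estimate and action log-probability recorded at
/// collection time, plus the GAE advantage and discounted return filled in
/// later by `compute_gae`.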
#[derive(Debug, Clone)]
pub struct RolloutStep<F: Float + Debug> {
pub obs: Array1<F>,
pub action: usize,
pub reward: F,
pub value: F,
pub log_prob: F,
pub done: bool,
pub advantage: F,
pub returns: F,
}
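/// Fixed-capacity rollout storage. Transitions are appended with `push`
/// until the buffer is full, advantages and returns are computed in place by
/// `compute_gae`, and the accessor methods expose the collected data as
/// `ndarray` arrays for the update step.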
#[derive(Debug, Clone)]
pub struct PPOBuffer<F: Float + Debug> {
pub steps: Vec<RolloutStep<F>>,
pub capacity: usize,
pub obs_dim: usize,
pub config: PPOConfig,
}
impl<F> PPOBuffer<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive + Clone,
{
pub fn new(config: PPOConfig, capacity: usize, obs_dim: usize) -> Self {
Self {
steps: Vec::with_capacity(capacity),
capacity,
obs_dim,
config,
}
}
pub fn is_full(&self) -> bool {
self.steps.len() >= self.capacity
}
pub fn len(&self) -> usize {
self.steps.len()
}
pub fn is_empty(&self) -> bool {
self.steps.is_empty()
}
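    /// Records one transition. Fails if the buffer is already full or if the
    /// observation length does not match `obs_dim`; the scalar inputs are
    /// converted from `f64` into the buffer's float type `F`.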
pub fn push(
&mut self,
obs: Array1<F>,
action: usize,
reward: f64,
value: f64,
log_prob: f64,
done: bool,
) -> Result<()> {
if self.is_full() {
return Err(NeuralError::InvalidArgument(
"PPOBuffer: buffer is full".to_string(),
));
}
if obs.len() != self.obs_dim {
return Err(NeuralError::DimensionMismatch(format!(
"PPOBuffer: obs dim {} != expected {}",
obs.len(),
self.obs_dim
)));
}
let reward_f = F::from_f64(reward).ok_or_else(|| {
NeuralError::ComputationError("PPOBuffer: cannot convert reward".to_string())
})?;
let value_f = F::from_f64(value).ok_or_else(|| {
NeuralError::ComputationError("PPOBuffer: cannot convert value".to_string())
})?;
let log_prob_f = F::from_f64(log_prob).ok_or_else(|| {
NeuralError::ComputationError("PPOBuffer: cannot convert log_prob".to_string())
})?;
self.steps.push(RolloutStep {
obs,
action,
reward: reward_f,
value: value_f,
log_prob: log_prob_f,
done,
advantage: F::zero(),
returns: F::zero(),
});
Ok(())
}
pub fn reset(&mut self) {
self.steps.clear();
}
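    /// Stacks every stored observation into a `(len, obs_dim)` matrix.
    /// Returns an error if the buffer is empty.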
pub fn obs_matrix(&self) -> Result<Array2<F>> {
let t = self.steps.len();
if t == 0 {
return Err(NeuralError::InvalidArgument(
"obs_matrix: buffer is empty".to_string(),
));
}
let mut mat = Array2::zeros((t, self.obs_dim));
for (i, step) in self.steps.iter().enumerate() {
for j in 0..self.obs_dim {
mat[[i, j]] = step.obs[j];
}
}
Ok(mat)
}
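    /// Per-step advantages (populated by `compute_gae`).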
pub fn advantages(&self) -> Array1<F> {
Array1::from_iter(self.steps.iter().map(|s| s.advantage))
}
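    /// Per-step discounted returns (populated by `compute_gae`).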
pub fn returns_array(&self) -> Array1<F> {
Array1::from_iter(self.steps.iter().map(|s| s.returns))
}
pub fn old_log_probs(&self) -> Array1<F> {
Array1::from_iter(self.steps.iter().map(|s| s.log_prob))
}
pub fn old_values(&self) -> Array1<F> {
Array1::from_iter(self.steps.iter().map(|s| s.value))
}
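    /// Standardises the stored advantages to zero mean and unit variance,
    /// with a small epsilon for numerical stability. A no-op when fewer than
    /// two steps are stored.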
pub fn normalise_advantages(&mut self) -> Result<()> {
let t = self.steps.len();
if t < 2 {
return Ok(());
}
let eps = F::from_f64(1e-8).ok_or_else(|| {
NeuralError::ComputationError("normalise_advantages: cannot convert eps".to_string())
})?;
let t_f = F::from_usize(t)
.ok_or_else(|| NeuralError::ComputationError("cannot convert t".to_string()))?;
let mut sum = F::zero();
for s in &self.steps {
sum += s.advantage;
}
let mean = sum / t_f;
let mut sq_sum = F::zero();
for s in &self.steps {
let diff = s.advantage - mean;
sq_sum += diff * diff;
}
let std_dev = (sq_sum / t_f + eps).sqrt();
for s in &mut self.steps {
s.advantage = (s.advantage - mean) / std_dev;
}
Ok(())
}
}
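/// Computes Generalised Advantage Estimation (GAE) in place, filling the
/// `advantage` and `returns` fields of every stored step. Working backwards
/// through the rollout:
///
/// ```text
/// delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
/// A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
/// R_t     = A_t + V(s_t)
/// ```
///
/// `last_value` bootstraps the value of the state following the final step;
/// `gamma` and `lam` are usually taken from `PPOConfig`.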
pub fn compute_gae<F>(
buf: &mut PPOBuffer<F>,
last_value: f64,
gamma: f64,
lam: f64,
) -> Result<()>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive + Clone,
{
let t = buf.steps.len();
if t == 0 {
return Err(NeuralError::InvalidArgument(
"compute_gae: buffer is empty".to_string(),
));
}
let gamma_f = F::from_f64(gamma).ok_or_else(|| {
NeuralError::ComputationError("compute_gae: cannot convert gamma".to_string())
})?;
let lam_f = F::from_f64(lam).ok_or_else(|| {
NeuralError::ComputationError("compute_gae: cannot convert lam".to_string())
})?;
let last_val_f = F::from_f64(last_value).ok_or_else(|| {
NeuralError::ComputationError("compute_gae: cannot convert last_value".to_string())
})?;
let mut gae = F::zero();
for i in (0..t).rev() {
let next_non_terminal = if buf.steps[i].done {
F::zero()
} else {
F::one()
};
let next_value = if i + 1 < t {
buf.steps[i + 1].value
} else {
last_val_f
};
let delta = buf.steps[i].reward
+ gamma_f * next_value * next_non_terminal
- buf.steps[i].value;
gae = delta + gamma_f * lam_f * next_non_terminal * gae;
buf.steps[i].advantage = gae;
buf.steps[i].returns = gae + buf.steps[i].value;
}
Ok(())
}
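/// Clipped PPO policy loss (the negated surrogate objective):
///
/// ```text
/// L = -mean_t( min( r_t * A_t, clamp(r_t, 1 - eps, 1 + eps) * A_t ) )
/// ```
///
/// where `r_t = exp(log_prob_new - log_prob_old)` is the probability ratio
/// and `eps` is `clip_range`.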
pub fn ppo_clip_loss<F>(
log_probs_new: &Array1<F>,
log_probs_old: &Array1<F>,
advantages: &Array1<F>,
clip_range: f64,
) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let t = log_probs_new.len();
if t == 0 {
return Err(NeuralError::InvalidArgument(
"ppo_clip_loss: empty arrays".to_string(),
));
}
if log_probs_old.len() != t || advantages.len() != t {
return Err(NeuralError::DimensionMismatch(format!(
"ppo_clip_loss: length mismatch {t} / {} / {}",
log_probs_old.len(),
advantages.len()
)));
}
let eps = F::from_f64(clip_range).ok_or_else(|| {
NeuralError::ComputationError("ppo_clip_loss: cannot convert clip_range".to_string())
})?;
let one = F::one();
let clip_lo = one - eps;
let clip_hi = one + eps;
let mut total = F::zero();
for i in 0..t {
let log_ratio = log_probs_new[i] - log_probs_old[i];
let ratio = log_ratio.exp();
let surr1 = ratio * advantages[i];
let surr2 = ratio.max(clip_lo).min(clip_hi) * advantages[i];
total += surr1.min(surr2);
}
let t_f = F::from_usize(t)
.ok_or_else(|| NeuralError::ComputationError("cannot convert t".to_string()))?;
Ok(-total / t_f)
}
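/// Mean squared error of the value predictions against the returns, scaled
/// by 0.5. When `clip_range_vf` is positive, the new value is first clipped
/// to stay within that range of the old prediction and the per-element
/// maximum of the clipped and unclipped squared errors is used (pessimistic
/// value clipping).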
pub fn value_loss<F>(
values_new: &Array1<F>,
values_old: &Array1<F>,
returns: &Array1<F>,
clip_range_vf: f64,
) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let t = values_new.len();
if t == 0 {
return Err(NeuralError::InvalidArgument(
"value_loss: empty arrays".to_string(),
));
}
if values_old.len() != t || returns.len() != t {
return Err(NeuralError::DimensionMismatch(format!(
"value_loss: length mismatch {t} / {} / {}",
values_old.len(),
returns.len()
)));
}
let half = F::from_f64(0.5).ok_or_else(|| {
NeuralError::ComputationError("value_loss: cannot convert 0.5".to_string())
})?;
    // A non-positive clip_range_vf disables value clipping entirely.
    let clip = clip_range_vf > 0.0;
let eps = if clip {
F::from_f64(clip_range_vf).ok_or_else(|| {
NeuralError::ComputationError(
"value_loss: cannot convert clip_range_vf".to_string(),
)
})?
} else {
F::zero()
};
let mut total = F::zero();
for i in 0..t {
let err_unclipped = values_new[i] - returns[i];
let loss_unclipped = err_unclipped * err_unclipped;
let loss = if clip {
let v_clipped = (values_new[i])
.max(values_old[i] - eps)
.min(values_old[i] + eps);
let err_clipped = v_clipped - returns[i];
let loss_clipped = err_clipped * err_clipped;
loss_unclipped.max(loss_clipped)
} else {
loss_unclipped
};
total += loss;
}
let t_f = F::from_usize(t)
.ok_or_else(|| NeuralError::ComputationError("cannot convert t".to_string()))?;
Ok(half * total / t_f)
}
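/// Loss terms and diagnostics returned by `ppo_loss`.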
#[derive(Debug, Clone)]
pub struct PPOLossOutput<F: Float + Debug> {
pub total_loss: F,
pub policy_loss: F,
pub vf_loss: F,
pub entropy: F,
pub clip_fraction: F,
pub mean_ratio: F,
}
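/// Combined PPO objective
/// `policy_loss + value_coeff * value_loss - entropy_coeff * entropy`,
/// reported together with the clip fraction and mean probability ratio as
/// diagnostics.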
pub fn ppo_loss<F>(
log_probs_new: &Array1<F>,
log_probs_old: &Array1<F>,
advantages: &Array1<F>,
values_new: &Array1<F>,
values_old: &Array1<F>,
returns: &Array1<F>,
entropy: F,
config: &PPOConfig,
) -> Result<PPOLossOutput<F>>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let policy_loss = ppo_clip_loss(log_probs_new, log_probs_old, advantages, config.clip_range)?;
let vf_loss = value_loss(values_new, values_old, returns, config.clip_range_vf)?;
let c_v = F::from_f64(config.value_coeff).ok_or_else(|| {
NeuralError::ComputationError("ppo_loss: cannot convert value_coeff".to_string())
})?;
let c_e = F::from_f64(config.entropy_coeff).ok_or_else(|| {
NeuralError::ComputationError("ppo_loss: cannot convert entropy_coeff".to_string())
})?;
let total = policy_loss + c_v * vf_loss - c_e * entropy;
let (clip_frac, mean_ratio) = clip_diagnostics(log_probs_new, log_probs_old, config.clip_range)?;
Ok(PPOLossOutput {
total_loss: total,
policy_loss,
vf_loss,
entropy,
clip_fraction: clip_frac,
mean_ratio,
})
}
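/// Fraction of probability ratios falling outside
/// `[1 - clip_range, 1 + clip_range]`, plus the mean ratio, for logging.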
fn clip_diagnostics<F>(
log_probs_new: &Array1<F>,
log_probs_old: &Array1<F>,
clip_range: f64,
) -> Result<(F, F)>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let t = log_probs_new.len();
let eps = F::from_f64(clip_range).ok_or_else(|| {
NeuralError::ComputationError("clip_diagnostics: cannot convert clip_range".to_string())
})?;
let one = F::one();
let lo = one - eps;
let hi = one + eps;
let mut clipped = 0usize;
let mut sum_ratio = F::zero();
for i in 0..t {
let ratio = (log_probs_new[i] - log_probs_old[i]).exp();
sum_ratio += ratio;
if ratio < lo || ratio > hi {
clipped += 1;
}
}
let t_f = F::from_usize(t)
.ok_or_else(|| NeuralError::ComputationError("cannot convert t".to_string()))?;
let clip_frac = F::from_usize(clipped)
.ok_or_else(|| NeuralError::ComputationError("cannot convert clipped".to_string()))?
/ t_f;
let mean_ratio = sum_ratio / t_f;
Ok((clip_frac, mean_ratio))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ppo_buffer_push_and_gae() {
let config = PPOConfig::default();
let mut buf = PPOBuffer::<f64>::new(config.clone(), 5, 4);
for i in 0..5 {
let obs = Array1::from(vec![i as f64; 4]);
buf.push(obs, 0, 1.0, 0.5, -0.3, i == 4).expect("push");
}
assert!(buf.is_full());
assert_eq!(buf.len(), 5);
compute_gae(&mut buf, 0.0, config.gamma, config.lam).expect("gae");
for step in &buf.steps {
assert!(step.advantage.is_finite());
assert!(step.returns.is_finite());
}
}
#[test]
fn test_gae_terminal_step() {
let config = PPOConfig::default();
let mut buf = PPOBuffer::<f64>::new(config.clone(), 3, 2);
buf.push(Array1::zeros(2), 0, 1.0, 1.0, 0.0, false).expect("push");
buf.push(Array1::zeros(2), 0, 1.0, 1.0, 0.0, true).expect("push");
buf.push(Array1::zeros(2), 0, 1.0, 1.0, 0.0, false).expect("push");
compute_gae(&mut buf, 0.0, 0.99, 0.95).expect("gae");
assert!(buf.steps[2].advantage.is_finite());
}
#[test]
fn test_ppo_clip_loss_unclipped() {
let log_probs = Array1::from(vec![0.0_f64, 0.0, 0.0]);
let advantages = Array1::from(vec![1.0_f64, 2.0, 3.0]);
let loss = ppo_clip_loss(&log_probs, &log_probs, &advantages, 0.2).expect("loss");
assert!((loss - (-2.0)).abs() < 1e-9, "loss={loss}");
}
#[test]
fn test_ppo_clip_loss_clipped() {
let log_probs_new = Array1::from(vec![5.0_f64]);
let log_probs_old = Array1::from(vec![0.0_f64]);
let advantages = Array1::from(vec![1.0_f64]);
let loss = ppo_clip_loss(&log_probs_new, &log_probs_old, &advantages, 0.2).expect("loss");
assert!((loss - (-1.2)).abs() < 1e-9, "loss={loss}");
}
#[test]
fn test_value_loss_no_clip() {
let v_new = Array1::from(vec![1.0_f64, 2.0]);
let v_old = Array1::from(vec![1.5_f64, 2.5]);
let returns = Array1::from(vec![2.0_f64, 3.0]);
let loss = value_loss(&v_new, &v_old, &returns, 0.0).expect("loss");
assert!((loss - 0.5).abs() < 1e-9, "loss={loss}");
}
#[test]
fn test_value_loss_clipped() {
        let v_new = Array1::from(vec![5.0_f64]);
        let v_old = Array1::from(vec![1.0_f64]);
let returns = Array1::from(vec![2.0_f64]);
let loss = value_loss(&v_new, &v_old, &returns, 0.2).expect("loss");
assert!((loss - 4.5).abs() < 1e-9, "loss={loss}");
}
#[test]
fn test_normalise_advantages() {
let config = PPOConfig::default();
let mut buf = PPOBuffer::<f64>::new(config.clone(), 4, 2);
for _ in 0..4 {
buf.push(Array1::zeros(2), 0, 1.0, 0.5, 0.0, false).expect("push");
}
compute_gae(&mut buf, 0.0, 0.99, 0.95).expect("gae");
buf.steps[0].advantage = 1.0;
buf.steps[1].advantage = 2.0;
buf.steps[2].advantage = 3.0;
buf.steps[3].advantage = 4.0;
buf.normalise_advantages().expect("normalise");
let adv = buf.advantages();
let mean: f64 = adv.sum() / adv.len() as f64;
assert!(mean.abs() < 1e-6, "mean={mean}");
}
#[test]
fn test_ppo_loss_combined() {
let config = PPOConfig::default();
let t = 8;
let log_p_new = Array1::from(vec![-1.0_f64; t]);
let log_p_old = Array1::from(vec![-1.0_f64; t]);
let adv = Array1::from(vec![1.0_f64; t]);
let v_new = Array1::from(vec![1.0_f64; t]);
let v_old = Array1::from(vec![1.0_f64; t]);
let rets = Array1::from(vec![1.5_f64; t]);
let entropy = 0.5_f64;
let out = ppo_loss(&log_p_new, &log_p_old, &adv, &v_new, &v_old, &rets, entropy, &config)
.expect("ppo_loss");
assert!(out.total_loss.is_finite());
assert!(out.policy_loss.is_finite());
assert!(out.vf_loss.is_finite());
assert!(out.clip_fraction.is_finite());
}
#[test]
fn test_buffer_overflow_error() {
let config = PPOConfig::default();
let mut buf = PPOBuffer::<f64>::new(config, 1, 2);
buf.push(Array1::zeros(2), 0, 0.0, 0.0, 0.0, false).expect("push");
let result = buf.push(Array1::zeros(2), 0, 0.0, 0.0, 0.0, false);
assert!(result.is_err());
}
#[test]
fn test_obs_dim_mismatch_error() {
let config = PPOConfig::default();
let mut buf = PPOBuffer::<f64>::new(config, 4, 4);
let wrong_obs = Array1::zeros(3);
let result = buf.push(wrong_obs, 0, 0.0, 0.0, 0.0, false);
assert!(result.is_err());
}
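    // Illustrative end-to-end sketch (not part of the original suite): collect
    // a short rollout, compute GAE, normalise advantages, and evaluate the
    // combined PPO loss. The policy and value outputs are made-up constants,
    // chosen only so that every quantity stays finite.
    #[test]
    fn test_full_update_pipeline_sketch() {
        let config = PPOConfig::default();
        let mut buf = PPOBuffer::<f64>::new(config.clone(), 4, 3);
        for i in 0..4 {
            let obs = Array1::from(vec![i as f64; 3]);
            buf.push(obs, i % 2, 1.0, 0.4, -0.7, i == 3).expect("push");
        }
        compute_gae(&mut buf, 0.0, config.gamma, config.lam).expect("gae");
        if config.normalise_advantages {
            buf.normalise_advantages().expect("normalise");
        }
        assert_eq!(buf.obs_matrix().expect("obs").shape(), &[4, 3]);
        // Pretend the current network reproduces the old rollout outputs, so
        // the probability ratios are exactly 1 and nothing gets clipped.
        let out = ppo_loss(
            &buf.old_log_probs(),
            &buf.old_log_probs(),
            &buf.advantages(),
            &buf.old_values(),
            &buf.old_values(),
            &buf.returns_array(),
            0.5,
            &config,
        )
        .expect("ppo_loss");
        assert!(out.total_loss.is_finite());
        assert!((out.mean_ratio - 1.0).abs() < 1e-12);
    }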
}