use crate::error::{RlError, RlResult};
/// Hyperparameters read by [`ppo_loss`] when combining the policy, value and
/// entropy terms into a single scalar loss.
#[derive(Debug, Clone, Copy)]
pub struct PpoConfig {
/// Half-width of the probability-ratio clip range: the ratio is clamped to
/// `[1 - clip_eps, 1 + clip_eps]` in the clipped surrogate term, and samples
/// with `|ratio - 1| > clip_eps` are counted towards the clip fraction.
pub clip_eps: f32,
/// Multiplier applied to the mean value loss in the total loss.
pub value_coeff: f32,
/// Multiplier applied to the mean entropy, which is subtracted from the
/// total loss.
pub entropy_coeff: f32,
/// Optional value-function clipping range. With `Some(v)` the change of the
/// value prediction relative to the old prediction is clamped to `[-v, v]`
/// and the larger of the clipped / unclipped squared errors is used; with
/// `None` the plain squared error is used.
pub value_clip: Option<f32>,
}
impl Default for PpoConfig {
fn default() -> Self {
Self {
clip_eps: 0.2,
value_coeff: 0.5,
entropy_coeff: 0.01,
value_clip: None,
}
}
}
/// Scalar results produced by [`ppo_loss`]; every field is a mean over the
/// batch.
#[derive(Debug, Clone)]
pub struct PpoLoss {
/// Combined loss: `policy_loss + value_coeff * value_loss - entropy_coeff * entropy`.
pub total: f32,
/// Negated mean of the clipped surrogate objective.
pub policy_loss: f32,
/// Mean squared value error (the larger of the clipped / unclipped errors
/// when value clipping is enabled).
pub value_loss: f32,
/// Mean of the per-sample entropies passed in by the caller.
pub entropy: f32,
/// Fraction of samples whose probability ratio satisfied
/// `|ratio - 1| > clip_eps`.
pub clip_fraction: f32,
/// Mean of `log_prob_old - log_prob_new` over the batch.
pub approx_kl: f32,
}
/// Computes the clipped-surrogate PPO loss over a batch of samples.
///
/// For each sample `i`:
/// - `ratio = exp(log_probs_new[i] - log_probs_old[i])`
/// - policy term: `min(ratio * adv, clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv)`
/// - value term: `(value_preds[i] - value_targets[i])^2`, or — when
///   `cfg.value_clip` is `Some(v)` — the larger of that and the squared error
///   of the prediction whose change from `old_vpreds[i]` is clamped to `[-v, v]`.
///
/// The returned `total` is
/// `-mean(policy) + cfg.value_coeff * mean(value) - cfg.entropy_coeff * mean(entropies)`.
///
/// `old_vpreds` is only read when value clipping is enabled, but its length is
/// always validated.
///
/// An empty batch yields an all-zero [`PpoLoss`] rather than dividing by zero
/// and returning NaN in every field.
///
/// # Errors
///
/// Returns [`RlError::DimensionMismatch`] when any slice's length differs from
/// `log_probs_new.len()`. `expected` is that reference length and `got` is the
/// length of the first offending slice, in parameter order.
///
/// # Panics
///
/// `f32::clamp` panics if its lower bound exceeds its upper bound or a bound
/// is NaN, so a negative or NaN `cfg.clip_eps` / `cfg.value_clip` panics on a
/// non-empty batch.
// The loss needs seven parallel per-sample slices plus the config.
#[allow(clippy::too_many_arguments)]
pub fn ppo_loss(
    log_probs_new: &[f32],
    log_probs_old: &[f32],
    advantages: &[f32],
    value_preds: &[f32],
    value_targets: &[f32],
    entropies: &[f32],
    old_vpreds: &[f32],
    cfg: PpoConfig,
) -> RlResult<PpoLoss> {
    let b = log_probs_new.len();

    // Report the real length of the first slice that disagrees with the batch
    // size so the error is actionable.
    let other_lens = [
        log_probs_old.len(),
        advantages.len(),
        value_preds.len(),
        value_targets.len(),
        entropies.len(),
        old_vpreds.len(),
    ];
    if let Some(got) = other_lens.iter().copied().find(|&len| len != b) {
        return Err(RlError::DimensionMismatch { expected: b, got });
    }

    // Averaging over zero samples would produce NaN; an empty batch
    // contributes nothing instead.
    if b == 0 {
        return Ok(PpoLoss {
            total: 0.0,
            policy_loss: 0.0,
            value_loss: 0.0,
            entropy: 0.0,
            clip_fraction: 0.0,
            approx_kl: 0.0,
        });
    }

    let b_f = b as f32;
    // Loop-invariant clip bounds for the probability ratio.
    let ratio_lo = 1.0 - cfg.clip_eps;
    let ratio_hi = 1.0 + cfg.clip_eps;

    let mut policy_loss = 0.0_f32;
    let mut value_loss = 0.0_f32;
    let mut entropy_sum = 0.0_f32;
    let mut clip_count = 0_usize;
    let mut kl_sum = 0.0_f32;

    // All slices were verified to have length `b`, so indexing cannot go out
    // of bounds.
    for i in 0..b {
        let ratio = (log_probs_new[i] - log_probs_old[i]).exp();
        let adv = advantages[i];

        // Pessimistic (minimum) of the unclipped and clipped surrogate.
        let surr_unclipped = ratio * adv;
        let surr_clipped = ratio.clamp(ratio_lo, ratio_hi) * adv;
        policy_loss += surr_unclipped.min(surr_clipped);

        if (ratio - 1.0).abs() > cfg.clip_eps {
            clip_count += 1;
        }

        let vf_unclipped = (value_preds[i] - value_targets[i]).powi(2);
        value_loss += match cfg.value_clip {
            Some(vclip) => {
                // Limit how far the prediction may move from the old one, then
                // keep the larger (pessimistic) of the two squared errors.
                let clipped_pred =
                    old_vpreds[i] + (value_preds[i] - old_vpreds[i]).clamp(-vclip, vclip);
                let vf_clipped = (clipped_pred - value_targets[i]).powi(2);
                vf_unclipped.max(vf_clipped)
            }
            None => vf_unclipped,
        };

        entropy_sum += entropies[i];
        kl_sum += log_probs_old[i] - log_probs_new[i];
    }

    let policy_loss_mean = policy_loss / b_f;
    let value_loss_mean = value_loss / b_f;
    let entropy_mean = entropy_sum / b_f;
    // The surrogate is an objective to maximise, hence the negation.
    let total =
        -policy_loss_mean + cfg.value_coeff * value_loss_mean - cfg.entropy_coeff * entropy_mean;

    Ok(PpoLoss {
        total,
        policy_loss: -policy_loss_mean,
        value_loss: value_loss_mean,
        entropy: entropy_mean,
        clip_fraction: clip_count as f32 / b_f,
        approx_kl: kl_sum / b_f,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `ppo_loss` on a uniform batch of `b` samples with the given
    /// probability ratio and advantage. Value predictions equal their targets
    /// and every entropy is `0.5`; the default config is used.
    fn make_ppo_batch(b: usize, ratio: f32, adv: f32) -> PpoLoss {
        // Old log-prob is 0, so the new log-prob alone determines the ratio.
        let lp_new = vec![ratio.ln(); b];
        let lp_old = vec![0.0_f32; b];
        let adv_v = vec![adv; b];
        let vp = vec![1.0_f32; b];
        let vt = vec![1.0_f32; b];
        let ent = vec![0.5_f32; b];
        let ovp = vec![1.0_f32; b];
        ppo_loss(
            &lp_new,
            &lp_old,
            &adv_v,
            &vp,
            &vt,
            &ent,
            &ovp,
            PpoConfig::default(),
        )
        .unwrap()
    }

    #[test]
    fn ppo_loss_zero_value_loss_when_match() {
        let l = make_ppo_batch(4, 1.0, 1.0);
        assert!(l.value_loss.abs() < 1e-5, "v_loss={}", l.value_loss);
    }

    /// A ratio far above the clip range with a positive advantage must be
    /// capped at `(1 + clip_eps) * adv`, giving a policy loss of `-1.2`.
    #[test]
    fn ppo_loss_clips_large_ratio_with_positive_advantage() {
        let l = make_ppo_batch(8, 10.0, 1.0);
        assert!(l.total.is_finite(), "loss should be finite");
        assert!(
            (l.policy_loss + 1.2).abs() < 1e-5,
            "policy_loss={}",
            l.policy_loss
        );
    }

    #[test]
    fn ppo_clip_fraction_all_clipped() {
        let l = make_ppo_batch(8, 10.0, 1.0);
        assert!(
            (l.clip_fraction - 1.0).abs() < 1e-5,
            "clip_fraction={}",
            l.clip_fraction
        );
    }

    #[test]
    fn ppo_clip_fraction_none_clipped() {
        let l = make_ppo_batch(8, 1.0, 1.0);
        assert!(l.clip_fraction < 1e-5, "clip_fraction={}", l.clip_fraction);
    }

    #[test]
    fn ppo_dimension_mismatch() {
        let b = vec![0.0_f32; 4];
        let b3 = vec![0.0_f32; 3];
        assert!(ppo_loss(&b, &b, &b, &b, &b, &b, &b3, PpoConfig::default()).is_err());
    }

    /// With value clipping enabled, a prediction that already hits the target
    /// but moved far from the old prediction is still penalised via the
    /// clipped error: (1.0 + 0.1 - 2.0)^2 = 0.81.
    #[test]
    fn ppo_value_clip_uses_larger_error() {
        let b = 4;
        let zeros = vec![0.0_f32; b];
        let vp = vec![2.0_f32; b];
        let vt = vec![2.0_f32; b];
        let ovp = vec![1.0_f32; b];
        let cfg = PpoConfig {
            value_clip: Some(0.1),
            ..PpoConfig::default()
        };
        let l = ppo_loss(&zeros, &zeros, &zeros, &vp, &vt, &zeros, &ovp, cfg).unwrap();
        assert!(
            (l.value_loss - 0.81).abs() < 1e-5,
            "value_loss={}",
            l.value_loss
        );
    }

    #[test]
    fn ppo_entropy_reduces_loss() {
        let b = 8;
        let lp_new = vec![0.0_f32; b];
        let lp_old = vec![0.0_f32; b];
        let adv = vec![0.0_f32; b];
        let vp = vec![0.0_f32; b];
        let vt = vec![0.0_f32; b];
        let ovp = vec![0.0_f32; b];
        let ent_low = vec![0.1_f32; b];
        let ent_high = vec![2.0_f32; b];
        let l_low = ppo_loss(
            &lp_new,
            &lp_old,
            &adv,
            &vp,
            &vt,
            &ent_low,
            &ovp,
            PpoConfig::default(),
        )
        .unwrap();
        let l_high = ppo_loss(
            &lp_new,
            &lp_old,
            &adv,
            &vp,
            &vt,
            &ent_high,
            &ovp,
            PpoConfig::default(),
        )
        .unwrap();
        assert!(
            l_high.total < l_low.total,
            "higher entropy should reduce total loss"
        );
    }
}