use crate::error::{RlError, RlResult};
/// Hyperparameters consumed by the SAC loss functions.
///
/// The type is a small bundle of `Copy` scalars, so it is passed by value.
/// `PartialEq` is derived so configurations can be compared directly (for
/// example in tests).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SacConfig {
    /// Discount factor multiplying the bootstrapped soft value in the critic target.
    pub gamma: f32,
    /// Entropy coefficient weighting the log-probability term in the critic
    /// target and in the actor loss.
    pub alpha: f32,
    /// Entropy target that is added to the mean log-probability in the
    /// temperature loss.
    pub target_entropy: f32,
    /// NOTE(review): not read by any function visible in this file; presumably
    /// tells the caller whether `alpha` is tuned automatically — confirm.
    pub auto_alpha: bool,
}
impl Default for SacConfig {
fn default() -> Self {
Self {
gamma: 0.99,
alpha: 0.2,
target_entropy: -1.0,
auto_alpha: true,
}
}
}
/// Bundle of the scalar SAC losses together with the per-sample TD errors.
///
/// `PartialEq` is derived so results can be compared directly (for example
/// in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct SacLoss {
    /// Weighted mean squared-error loss of the critic.
    pub critic_loss: f32,
    /// Mean of `alpha * log_pi - min_q` over the batch.
    pub actor_loss: f32,
    /// Loss used to adjust the entropy temperature.
    pub alpha_loss: f32,
    /// Absolute TD error of every sample, in input order.
    pub td_errors: Vec<f32>,
}
/// Computes the weighted critic loss and the per-sample absolute TD errors.
///
/// For every sample `i` the soft Bellman target is
/// `rewards[i] + gamma * (1 - dones[i]) * (min_q_next[i] - alpha * log_pi_next[i])`
/// and `delta` is that target minus `q_sa[i]`. The returned loss is the mean
/// over the batch of `is_weights[i] * 0.5 * delta^2`.
///
/// # Arguments
///
/// * `q_sa` - current Q estimates; its length defines the batch size.
/// * `rewards` - per-sample rewards.
/// * `dones` - termination flags as floats; `1.0` removes the bootstrap term.
/// * `min_q_next` - next-state Q values (the name suggests the minimum over
///   the target critics — confirm with the caller).
/// * `log_pi_next` - log-probabilities of the next actions.
/// * `is_weights` - per-sample multipliers applied to each squared error
///   (the name suggests importance-sampling weights — confirm with the caller).
/// * `cfg` - supplies `gamma` and `alpha`.
///
/// # Returns
///
/// `(loss, td_errors)` where `td_errors[i]` is `|delta_i|`.
///
/// # Errors
///
/// Returns [`RlError::DimensionMismatch`] when
/// * any slice differs in length from `q_sa` (`expected` is `q_sa.len()`,
///   `got` is the length of the first offending slice in argument order), or
/// * the batch is empty (`expected: 1, got: 0`), because the mean would be
///   a division by zero.
pub fn sac_critic_loss(
    q_sa: &[f32],
    rewards: &[f32],
    dones: &[f32],
    min_q_next: &[f32],
    log_pi_next: &[f32],
    is_weights: &[f32],
    cfg: SacConfig,
) -> RlResult<(f32, Vec<f32>)> {
    let b = q_sa.len();

    // Report the real length of the first slice that disagrees with `q_sa`.
    let other_lens = [
        rewards.len(),
        dones.len(),
        min_q_next.len(),
        log_pi_next.len(),
        is_weights.len(),
    ];
    if let Some(&got) = other_lens.iter().find(|&&len| len != b) {
        return Err(RlError::DimensionMismatch { expected: b, got });
    }

    // An empty batch has no mean; reject it instead of returning NaN.
    if b == 0 {
        return Err(RlError::DimensionMismatch {
            expected: 1,
            got: 0,
        });
    }

    let mut loss = 0.0_f32;
    let mut td_errors = Vec::with_capacity(b);
    // All slices were validated to have length `b`, so indexing cannot panic.
    for i in 0..b {
        // Terminal transitions (done == 1) do not bootstrap from the next state.
        let mask = 1.0 - dones[i];
        let soft_value = min_q_next[i] - cfg.alpha * log_pi_next[i];
        let target = rewards[i] + cfg.gamma * mask * soft_value;
        let delta = target - q_sa[i];
        td_errors.push(delta.abs());
        loss += is_weights[i] * 0.5 * delta * delta;
    }

    Ok((loss / b as f32, td_errors))
}
/// Computes the actor loss as the batch mean of `alpha * log_pi[i] - min_q[i]`.
///
/// # Arguments
///
/// * `log_pi` - log-probabilities of the sampled actions; its length defines
///   the batch size.
/// * `min_q` - Q values for the same samples (the name suggests the minimum
///   over the critics — confirm with the caller).
/// * `cfg` - supplies the entropy coefficient `alpha`.
///
/// # Errors
///
/// Returns [`RlError::DimensionMismatch`] when
/// * `min_q` differs in length from `log_pi`, or
/// * the batch is empty (`expected: 1, got: 0`), because the mean would be
///   `0.0 / 0.0`.
pub fn sac_actor_loss(log_pi: &[f32], min_q: &[f32], cfg: SacConfig) -> RlResult<f32> {
    let b = log_pi.len();
    if min_q.len() != b {
        return Err(RlError::DimensionMismatch {
            expected: b,
            got: min_q.len(),
        });
    }
    // An empty batch has no mean; reject it instead of returning NaN.
    if b == 0 {
        return Err(RlError::DimensionMismatch {
            expected: 1,
            got: 0,
        });
    }

    let total: f32 = log_pi
        .iter()
        .zip(min_q)
        .map(|(&lp, &q)| cfg.alpha * lp - q)
        .sum();
    Ok(total / b as f32)
}
/// Computes the temperature loss `-exp(log_alpha) * (mean(log_pi) + target_entropy)`.
///
/// # Arguments
///
/// * `log_pi` - log-probabilities of the sampled actions; must be non-empty.
/// * `log_alpha` - natural logarithm of the temperature; it is exponentiated
///   here, so the effective temperature is always positive.
/// * `cfg` - supplies `target_entropy`.
///
/// # Errors
///
/// Returns [`RlError::DimensionMismatch`] with `expected: 1, got: 0` when
/// `log_pi` is empty.
pub fn sac_temperature_loss(log_pi: &[f32], log_alpha: f32, cfg: SacConfig) -> RlResult<f32> {
    let n = log_pi.len();
    if n == 0 {
        return Err(RlError::DimensionMismatch {
            expected: 1,
            got: 0,
        });
    }

    let lp_total: f32 = log_pi.iter().sum();
    // Gap between the batch mean log-probability and the negated entropy target.
    let entropy_gap = lp_total / n as f32 + cfg.target_entropy;
    let temperature = log_alpha.exp();
    Ok(-temperature * entropy_gap)
}
/// Evaluates the critic, actor and temperature losses for one batch and
/// bundles them into a [`SacLoss`].
///
/// The first six slices and `cfg` are forwarded to [`sac_critic_loss`];
/// `log_pi` and `min_q` go to [`sac_actor_loss`]; `log_pi` and `log_alpha`
/// go to [`sac_temperature_loss`].
///
/// # Errors
///
/// Propagates the first error raised by the component losses, evaluated in
/// the order critic, actor, temperature.
// The three component losses take their inputs as separate slices, so the
// combined entry point necessarily has a long parameter list.
#[allow(clippy::too_many_arguments)]
pub fn sac_loss(
    q_sa: &[f32],
    rewards: &[f32],
    dones: &[f32],
    min_q_next: &[f32],
    log_pi_next: &[f32],
    is_weights: &[f32],
    log_pi: &[f32],
    min_q: &[f32],
    log_alpha: f32,
    cfg: SacConfig,
) -> RlResult<SacLoss> {
    let critic = sac_critic_loss(
        q_sa,
        rewards,
        dones,
        min_q_next,
        log_pi_next,
        is_weights,
        cfg,
    )?;
    let (critic_loss, td_errors) = critic;

    // Field initialisers run top to bottom, so the actor loss is evaluated
    // before the temperature loss.
    Ok(SacLoss {
        critic_loss,
        actor_loss: sac_actor_loss(log_pi, min_q, cfg)?,
        alpha_loss: sac_temperature_loss(log_pi, log_alpha, cfg)?,
        td_errors,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With `alpha = 0` and `gamma = 1` the target is `r + min_q_next`, so a
    /// critic that already predicts that value must have zero loss.
    #[test]
    fn critic_loss_zero_when_perfect() {
        let cfg = SacConfig {
            gamma: 1.0,
            alpha: 0.0,
            ..SacConfig::default()
        };
        let q = vec![2.0_f32; 4];
        let r = vec![1.0_f32; 4];
        let d = vec![0.0_f32; 4];
        let min_q_next = vec![1.0_f32; 4];
        let lp_next = vec![0.0_f32; 4];
        let w = vec![1.0_f32; 4];
        let (loss, _) = sac_critic_loss(&q, &r, &d, &min_q_next, &lp_next, &w, cfg).unwrap();
        assert!(loss.abs() < 1e-5, "critic loss={loss}");
    }

    /// Slices of different lengths must be rejected, not indexed.
    #[test]
    fn critic_loss_rejects_length_mismatch() {
        let four = vec![0.0_f32; 4];
        let three = vec![0.0_f32; 3];
        let res = sac_critic_loss(
            &four,
            &three,
            &four,
            &four,
            &four,
            &four,
            SacConfig::default(),
        );
        assert!(res.is_err());
    }

    #[test]
    fn actor_loss_negative_with_high_q() {
        let cfg = SacConfig {
            alpha: 0.1,
            ..SacConfig::default()
        };
        let log_pi = vec![-1.0_f32; 4];
        let min_q = vec![10.0_f32; 4];
        let l = sac_actor_loss(&log_pi, &min_q, cfg).unwrap();
        assert!(l < 0.0, "actor loss should be negative with high Q");
    }

    #[test]
    fn actor_loss_rejects_length_mismatch() {
        let log_pi = vec![-1.0_f32; 4];
        let min_q = vec![10.0_f32; 3];
        assert!(sac_actor_loss(&log_pi, &min_q, SacConfig::default()).is_err());
    }

    /// The policy entropy is `-mean(log_pi)`; with `log_pi = 1` it equals the
    /// target of `-1`, so the temperature loss must vanish.
    #[test]
    fn temperature_loss_zero_at_target_entropy() {
        let cfg = SacConfig {
            target_entropy: -1.0,
            ..SacConfig::default()
        };
        let log_pi = vec![1.0_f32; 4];
        let l = sac_temperature_loss(&log_pi, 0.0, cfg).unwrap();
        assert!(l.abs() < 1e-6, "temperature loss={l}");
    }

    /// With `log_pi = -1` the entropy (1) exceeds the target (-1); the loss is
    /// `-alpha * (-2)`, i.e. finite and positive even for a tiny `alpha`.
    #[test]
    fn temperature_loss_positive_when_entropy_above_target() {
        let cfg = SacConfig {
            target_entropy: -1.0,
            ..SacConfig::default()
        };
        let log_pi = vec![-1.0_f32; 4];
        let l = sac_temperature_loss(&log_pi, -10.0, cfg).unwrap();
        assert!(l.is_finite(), "temperature loss should be finite");
        assert!(l > 0.0, "temperature loss={l}");
    }

    #[test]
    fn temperature_loss_empty_error() {
        assert!(sac_temperature_loss(&[], 0.0, SacConfig::default()).is_err());
    }

    #[test]
    fn sac_loss_all_finite() {
        let b = 8;
        let q = vec![1.0_f32; b];
        let r = vec![0.5_f32; b];
        let d = vec![0.0_f32; b];
        let min_qn = vec![1.0_f32; b];
        let lp_next = vec![-0.5_f32; b];
        let w = vec![1.0_f32; b];
        let lp = vec![-0.5_f32; b];
        let min_q = vec![1.0_f32; b];
        let l = sac_loss(
            &q,
            &r,
            &d,
            &min_qn,
            &lp_next,
            &w,
            &lp,
            &min_q,
            // Very small temperature: alpha = exp(-10).
            -10.0,
            SacConfig::default(),
        )
        .unwrap();
        assert!(l.critic_loss.is_finite());
        assert!(l.actor_loss.is_finite());
        assert!(l.alpha_loss.is_finite());
        assert_eq!(l.td_errors.len(), b);
    }
}