use crate::error::{RlError, RlResult};
/// Hyper-parameters read by [`dqn_loss`] (and, through it, [`double_dqn_loss`]).
#[derive(Debug, Clone, Copy)]
pub struct DqnConfig {
/// Discount factor: multiplies the bootstrapped next-state value when the TD
/// target `reward + gamma * max_q_next * (1 - done)` is formed.
pub gamma: f32,
/// Loss shape selector: `Some(k)` uses the Huber loss with threshold `k`,
/// `None` uses the plain squared error `0.5 * delta^2`.
pub huber_kappa: Option<f32>,
/// NOTE(review): not read by any function visible in this file — the loss
/// functions here always multiply by the `is_weights` slice they are given.
/// Presumably callers use this flag to decide whether to pass real
/// importance-sampling weights or all-ones; confirm against callers.
pub use_is_weights: bool,
}
impl Default for DqnConfig {
    /// Builds the default configuration: `gamma = 0.99`, Huber loss with
    /// threshold `1.0`, and `use_is_weights` switched off.
    fn default() -> Self {
        let gamma = 0.99_f32;
        let huber_kappa = Some(1.0_f32);
        let use_is_weights = false;
        Self {
            gamma,
            huber_kappa,
            use_is_weights,
        }
    }
}
/// Output of a loss evaluation, as produced by [`dqn_loss`].
#[derive(Debug, Clone)]
pub struct DqnLoss {
/// Weighted per-sample losses, summed and divided by the batch size.
pub loss: f32,
/// Absolute TD error `|target - q_sa[i]|` of every sample, in batch order.
pub td_errors: Vec<f32>,
}
/// Huber loss of a single residual `delta` with threshold `kappa`.
///
/// Quadratic (`0.5 * delta^2`) while `|delta| <= kappa`, and linear
/// (`kappa * (|delta| - 0.5 * kappa)`) once the residual exceeds the threshold.
#[inline]
fn huber(delta: f32, kappa: f32) -> f32 {
    let magnitude = delta.abs();
    let quadratic = 0.5 * delta * delta;
    let linear = kappa * (magnitude - 0.5 * kappa);
    // Keep the `<=` orientation: any comparison involving NaN is false and
    // therefore falls through to the linear branch.
    if magnitude <= kappa {
        quadratic
    } else {
        linear
    }
}
/// Computes the one-step DQN temporal-difference loss over a batch.
///
/// For every sample `i`:
/// * target = `rewards[i] + cfg.gamma * max_q_next[i] * (1.0 - dones[i])`
/// * TD error `delta` = `target - q_sa[i]`
/// * per-sample loss = Huber loss with threshold `k` when `cfg.huber_kappa`
///   is `Some(k)`, otherwise `0.5 * delta^2`
///
/// Each per-sample loss is multiplied by `is_weights[i]`; the returned `loss`
/// is the weighted sum divided by the batch size and `td_errors` holds
/// `|delta|` for every sample in batch order.
///
/// `cfg.use_is_weights` is not consulted here: the weights are always
/// applied, so pass all-ones to obtain an unweighted mean.
///
/// An empty batch yields `loss == 0.0` with no TD errors instead of the NaN
/// that dividing by a zero batch size would produce.
///
/// # Errors
/// Returns [`RlError::DimensionMismatch`] when `rewards`, `max_q_next`,
/// `dones` or `is_weights` differs in length from `q_sa`. `expected` is
/// `q_sa.len()` and `got` is the length of the first offending slice, checked
/// in argument order.
pub fn dqn_loss(
    q_sa: &[f32],
    rewards: &[f32],
    max_q_next: &[f32],
    dones: &[f32],
    is_weights: &[f32],
    cfg: DqnConfig,
) -> RlResult<DqnLoss> {
    let b = q_sa.len();

    // Report the real length of the first slice that disagrees with the batch size.
    let lengths = [
        rewards.len(),
        max_q_next.len(),
        dones.len(),
        is_weights.len(),
    ];
    if let Some(got) = lengths.iter().copied().find(|&len| len != b) {
        return Err(RlError::DimensionMismatch { expected: b, got });
    }

    // The mean over zero samples would be 0.0 / 0.0 = NaN; return a neutral result.
    if b == 0 {
        return Ok(DqnLoss {
            loss: 0.0,
            td_errors: Vec::new(),
        });
    }

    let mut loss = 0.0_f32;
    let mut td_errors = Vec::with_capacity(b);
    let samples = q_sa
        .iter()
        .zip(rewards)
        .zip(max_q_next)
        .zip(dones)
        .zip(is_weights);
    for ((((&q, &reward), &q_next), &done), &weight) in samples {
        // `done == 1.0` removes the bootstrapped term for terminal transitions.
        let target = reward + cfg.gamma * q_next * (1.0 - done);
        let delta = target - q;
        td_errors.push(delta.abs());
        let elem_loss = match cfg.huber_kappa {
            Some(kappa) => huber(delta, kappa),
            None => 0.5 * delta * delta,
        };
        loss += weight * elem_loss;
    }
    loss /= b as f32;
    Ok(DqnLoss { loss, td_errors })
}
/// Computes the Double-DQN loss over a batch.
///
/// `q_next_online` and `q_next_target` are flattened row-major tables holding
/// `n_actions` next-state Q-values per sample. For every sample the action
/// with the largest value in the `q_next_online` row is selected, the value
/// of that action is read from the matching `q_next_target` row, and the
/// resulting per-sample values are passed to [`dqn_loss`] as `max_q_next`
/// together with the remaining arguments.
///
/// Action selection compares with `partial_cmp`; incomparable pairs (NaN) are
/// treated as equal, and among equal maxima the last index wins.
///
/// # Errors
/// Returns [`RlError::DimensionMismatch`] when
/// * `rewards`, `dones` or `is_weights` differs in length from `q_sa`, or
///   either Q-table's length differs from `q_sa.len() * n_actions`;
///   `expected`/`got` describe the first offending slice in argument order;
/// * `n_actions` is zero while the batch is non-empty, since there is no
///   action to select; `expected` is then `q_sa.len()` (at least one value
///   per sample) and `got` is `0`.
///
/// Any error produced by [`dqn_loss`] is propagated unchanged.
// Mirrors `dqn_loss` plus the two flattened Q-tables and their row width.
#[allow(clippy::too_many_arguments)]
pub fn double_dqn_loss(
    q_sa: &[f32],
    rewards: &[f32],
    q_next_online: &[f32],
    q_next_target: &[f32],
    n_actions: usize,
    dones: &[f32],
    is_weights: &[f32],
    cfg: DqnConfig,
) -> RlResult<DqnLoss> {
    let b = q_sa.len();
    // Saturate rather than overflow: a saturated product can never equal a
    // real slice length, so absurd shapes surface as a mismatch error.
    let flat_len = b.saturating_mul(n_actions);

    let length_checks = [
        (b, rewards.len()),
        (flat_len, q_next_online.len()),
        (flat_len, q_next_target.len()),
        (b, dones.len()),
        (b, is_weights.len()),
    ];
    if let Some((expected, got)) = length_checks
        .iter()
        .copied()
        .find(|&(want, have)| want != have)
    {
        return Err(RlError::DimensionMismatch { expected, got });
    }

    // With zero actions the Q-tables are empty and indexing them would panic.
    if b > 0 && n_actions == 0 {
        return Err(RlError::DimensionMismatch {
            expected: b,
            got: 0,
        });
    }

    let max_q_next: Vec<f32> = (0..b)
        .map(|i| {
            let start = i * n_actions;
            let online_row = &q_next_online[start..start + n_actions];
            // Greedy action according to the online network.
            let a_star = online_row
                .iter()
                .enumerate()
                .max_by(|(_, x), (_, y)| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal))
                .map_or(0, |(j, _)| j);
            // Value of that action according to the target network.
            q_next_target[start + a_star]
        })
        .collect();

    dqn_loss(q_sa, rewards, &max_q_next, dones, is_weights, cfg)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Undiscounted squared-error configuration shared by the exact-value tests.
    fn mse_cfg() -> DqnConfig {
        DqnConfig {
            gamma: 1.0,
            huber_kappa: None,
            use_is_weights: false,
        }
    }

    /// Inside the threshold the Huber loss is `0.5 * delta^2`.
    #[test]
    fn huber_small_delta() {
        let l = huber(0.5, 1.0);
        let gap = (l - 0.125).abs();
        assert!(gap < 1e-5, "huber(0.5)={l}");
    }

    /// Beyond the threshold the Huber loss is `kappa * (|delta| - 0.5 * kappa)`.
    #[test]
    fn huber_large_delta() {
        let l = huber(2.0, 1.0);
        let gap = (l - 1.5).abs();
        assert!(gap < 1e-5, "huber(2.0)={l}");
    }

    /// target = 1 + 1 * 1 * (1 - 0) = 2 equals `q`, so every TD error is zero.
    #[test]
    fn dqn_loss_zero_when_perfect() {
        let q = [2.0_f32; 4];
        let r = [1.0_f32; 4];
        let max_q_next = [1.0_f32; 4];
        let done = [0.0_f32; 4];
        let w = [1.0_f32; 4];
        let out = dqn_loss(&q, &r, &max_q_next, &done, &w, mse_cfg()).unwrap();
        assert!(out.loss.abs() < 1e-5, "loss should be 0, got {}", out.loss);
    }

    /// A prediction of zero against a positive target must cost something.
    #[test]
    fn dqn_loss_positive_when_off() {
        let q = [0.0_f32; 4];
        let r = [1.0_f32; 4];
        let max_q_next = [1.0_f32; 4];
        let done = [0.0_f32; 4];
        let w = [1.0_f32; 4];
        let out = dqn_loss(&q, &r, &max_q_next, &done, &w, DqnConfig::default()).unwrap();
        assert!(out.loss > 0.0, "loss should be > 0");
    }

    /// With `done = 1` the huge next-state value is masked out:
    /// target = 2, delta = 2, loss = 0.5 * 2^2 = 2.
    #[test]
    fn dqn_loss_done_stops_future() {
        let q = [0.0_f32];
        let r = [2.0_f32];
        let max_q_next = [100.0_f32];
        let done = [1.0_f32];
        let w = [1.0_f32];
        let out = dqn_loss(&q, &r, &max_q_next, &done, &w, mse_cfg()).unwrap();
        assert!((out.loss - 2.0).abs() < 1e-5, "loss={}", out.loss);
    }

    /// One TD error is reported per sample.
    #[test]
    fn dqn_td_errors_correct_length() {
        const N: usize = 8;
        let q = [0.0_f32; N];
        let r = [1.0_f32; N];
        let max_q = [0.5_f32; N];
        let done = [0.0_f32; N];
        let w = [1.0_f32; N];
        let out = dqn_loss(&q, &r, &max_q, &done, &w, DqnConfig::default()).unwrap();
        assert_eq!(out.td_errors.len(), N);
    }

    /// The online row peaks at index 1, so the target value used is 3.0 (not
    /// the target row's own maximum of 10.0): delta = 3, loss = 0.5 * 9 = 4.5.
    #[test]
    fn double_dqn_selects_online_argmax() {
        let q_sa = [0.0_f32];
        let r = [0.0_f32];
        let q_next_on = [1.0_f32, 5.0, 2.0];
        let q_next_tgt = [10.0_f32, 3.0, 10.0];
        let done = [0.0_f32];
        let w = [1.0_f32];
        let out = double_dqn_loss(
            &q_sa,
            &r,
            &q_next_on,
            &q_next_tgt,
            3,
            &done,
            &w,
            mse_cfg(),
        )
        .unwrap();
        assert!((out.loss - 4.5).abs() < 1e-4, "loss={}", out.loss);
    }

    /// A rewards slice shorter than the batch must be rejected.
    #[test]
    fn dqn_dimension_mismatch() {
        let q = [0.0_f32; 4];
        let r = [1.0_f32; 3];
        let max_q = [0.5_f32; 4];
        let done = [0.0_f32; 4];
        let w = [1.0_f32; 4];
        let outcome = dqn_loss(&q, &r, &max_q, &done, &w, DqnConfig::default());
        assert!(outcome.is_err());
    }
}