use crate::error::{RlError, RlResult};
use crate::handle::RlHandle;
/// Helper for working with categorical (discrete) action distributions.
///
/// Stores the number of actions and a softmax temperature; the distributions
/// themselves are passed to each method as `f32` slices.
#[derive(Debug, Clone)]
pub struct CategoricalPolicy {
/// Number of discrete actions; the constructor asserts that it is non-zero.
n_actions: usize,
/// Divisor applied to the max-shifted logits in `softmax` (defaults to `1.0`).
temperature: f32,
}
impl CategoricalPolicy {
    /// Creates a policy over `n_actions` discrete actions with temperature `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if `n_actions` is zero.
    #[must_use]
    pub fn new(n_actions: usize) -> Self {
        assert!(n_actions > 0, "n_actions must be > 0");
        Self {
            n_actions,
            temperature: 1.0,
        }
    }

    /// Returns the policy with its softmax temperature replaced by `t`.
    ///
    /// The value is stored as given, without validation. [`Self::softmax`]
    /// divides the shifted logits by it, so a zero or NaN temperature makes
    /// normalisation fail and `softmax` then returns an error.
    #[must_use]
    pub fn with_temperature(mut self, t: f32) -> Self {
        self.temperature = t;
        self
    }

    /// Number of discrete actions this policy was built for.
    #[must_use]
    #[inline]
    pub fn n_actions(&self) -> usize {
        self.n_actions
    }

    /// Converts `logits` into probabilities with a temperature-scaled softmax.
    ///
    /// Each output is `exp((logit - max_logit) / temperature)` divided by the
    /// sum of those terms.
    ///
    /// # Errors
    ///
    /// * [`RlError::DimensionMismatch`] if `logits.len() != n_actions`.
    /// * [`RlError::InvalidDistribution`] if the normaliser is not a finite,
    ///   strictly positive number (for example a zero or NaN temperature, NaN
    ///   logits, or logits that are all `-inf`); `sum` carries that normaliser.
    pub fn softmax(&self, logits: &[f32]) -> RlResult<Vec<f32>> {
        if logits.len() != self.n_actions {
            return Err(RlError::DimensionMismatch {
                expected: self.n_actions,
                got: logits.len(),
            });
        }
        // Shifting by the largest logit keeps every exponent at or below zero
        // (for a positive temperature), so `exp` cannot overflow.
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = logits
            .iter()
            .map(|&l| ((l - max_logit) / self.temperature).exp())
            .collect();
        let sum: f32 = probs.iter().sum();
        // Dividing by a NaN, infinite, or zero normaliser would hand the caller
        // a vector of NaNs wrapped in `Ok`; report the failure instead.
        if !sum.is_finite() || sum <= 0.0 {
            return Err(RlError::InvalidDistribution { sum, tol: 0.01 });
        }
        for p in &mut probs {
            *p /= sum;
        }
        Ok(probs)
    }

    /// Draws an action index from `probs` using the Gumbel-max trick.
    ///
    /// One uniform value is taken from the handle's RNG per entry of `probs`;
    /// the index maximising `ln(p_i) - ln(-ln(u_i))` is returned. The length of
    /// `probs` is not compared against `n_actions`.
    ///
    /// # Errors
    ///
    /// * [`RlError::EmptyDistribution`] if `probs` is empty.
    /// * [`RlError::InvalidDistribution`] if the entries do not sum to `1.0`
    ///   within `0.01`, or if the sum is NaN.
    pub fn sample_action(&self, probs: &[f32], handle: &mut RlHandle) -> RlResult<usize> {
        // Largest `f32` strictly below 1.0. `1.0 - 1e-10` rounds back to 1.0 in
        // single precision, which would let `u == 1.0` through and turn the
        // Gumbel noise into +inf for every action.
        const MAX_UNIFORM: f32 = 1.0 - f32::EPSILON / 2.0;

        if probs.is_empty() {
            return Err(RlError::EmptyDistribution);
        }
        let sum: f32 = probs.iter().sum();
        // A NaN sum compares false against any tolerance, so test it explicitly.
        if sum.is_nan() || (sum - 1.0).abs() > 0.01 {
            return Err(RlError::InvalidDistribution { sum, tol: 0.01 });
        }
        let rng = handle.rng_mut();
        let mut best_idx = 0;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &p) in probs.iter().enumerate() {
            // Keep `u` strictly inside (0, 1) so both logarithms stay finite.
            // NOTE(review): presumably `next_f32` yields values in [0, 1] — confirm.
            let u = (rng.next_f32() + 1e-10).min(MAX_UNIFORM);
            let gumbel = -((-u.ln()).ln());
            // Zero-probability entries score -inf and can never win.
            let val = p.ln() + gumbel;
            if val > best_val {
                best_val = val;
                best_idx = i;
            }
        }
        Ok(best_idx)
    }

    /// Applies [`Self::softmax`] to `logits` and samples from the result with
    /// [`Self::sample_action`].
    ///
    /// # Errors
    ///
    /// Propagates any error from either step.
    pub fn sample_from_logits(&self, logits: &[f32], handle: &mut RlHandle) -> RlResult<usize> {
        let probs = self.softmax(logits)?;
        self.sample_action(&probs, handle)
    }

    /// Natural log of the probability assigned to `action`.
    ///
    /// The probability is floored at `1e-10` before taking the logarithm, so
    /// the result is always finite.
    ///
    /// # Errors
    ///
    /// [`RlError::DimensionMismatch`] if `probs.len() != n_actions`, or if
    /// `action` is not a valid index (in that case `got` holds the action).
    pub fn log_prob(&self, probs: &[f32], action: usize) -> RlResult<f32> {
        if probs.len() != self.n_actions {
            return Err(RlError::DimensionMismatch {
                expected: self.n_actions,
                got: probs.len(),
            });
        }
        // Checked lookup: an out-of-range action is a caller error, not a panic.
        let &p = probs.get(action).ok_or(RlError::DimensionMismatch {
            expected: self.n_actions,
            got: action,
        })?;
        Ok(p.max(1e-10).ln())
    }

    /// Batched [`Self::log_prob`] over a row-major `probs_batch`.
    ///
    /// `probs_batch` holds `actions.len()` consecutive rows of `n_actions`
    /// probabilities; element `b` of the result is the floored log-probability
    /// of `actions[b]` in row `b`.
    ///
    /// # Errors
    ///
    /// [`RlError::DimensionMismatch`] if `probs_batch.len()` is not
    /// `actions.len() * n_actions`, or if any action is `>= n_actions` (in
    /// that case `got` holds the offending action).
    pub fn log_prob_batch(&self, probs_batch: &[f32], actions: &[usize]) -> RlResult<Vec<f32>> {
        let batch_size = actions.len();
        if probs_batch.len() != batch_size * self.n_actions {
            return Err(RlError::DimensionMismatch {
                expected: batch_size * self.n_actions,
                got: probs_batch.len(),
            });
        }
        let mut out = Vec::with_capacity(batch_size);
        // Walk the batch row by row so an oversized action index cannot reach
        // into a neighbouring row, as flat `b * n + a` indexing would allow.
        for (row, &a) in probs_batch.chunks_exact(self.n_actions).zip(actions) {
            let &p = row.get(a).ok_or(RlError::DimensionMismatch {
                expected: self.n_actions,
                got: a,
            })?;
            out.push(p.max(1e-10).ln());
        }
        Ok(out)
    }

    /// Shannon entropy of `probs` in nats: `-sum(p * ln p)`.
    ///
    /// Entries that are not strictly positive are skipped.
    ///
    /// # Errors
    ///
    /// [`RlError::DimensionMismatch`] if `probs.len() != n_actions`.
    pub fn entropy(&self, probs: &[f32]) -> RlResult<f32> {
        if probs.len() != self.n_actions {
            return Err(RlError::DimensionMismatch {
                expected: self.n_actions,
                got: probs.len(),
            });
        }
        let h: f32 = probs
            .iter()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum();
        Ok(h)
    }

    /// Index of the largest entry of `probs`.
    ///
    /// Pairs that cannot be ordered (NaN) are treated as equal, and among
    /// equal maxima the last index is returned. The length of `probs` is not
    /// compared against `n_actions`.
    ///
    /// # Errors
    ///
    /// [`RlError::EmptyDistribution`] if `probs` is empty.
    pub fn greedy_action(&self, probs: &[f32]) -> RlResult<usize> {
        if probs.is_empty() {
            return Err(RlError::EmptyDistribution);
        }
        let (idx, _) = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .expect("probs is non-empty");
        Ok(idx)
    }

    /// Kullback-Leibler divergence `KL(p || q) = sum(p_i * (ln p_i - ln q_i))`.
    ///
    /// Terms with `p_i <= 0` are skipped and each `q_i` is floored at `1e-10`.
    /// Only the two lengths are compared with each other; neither is checked
    /// against `n_actions`.
    ///
    /// # Errors
    ///
    /// [`RlError::DimensionMismatch`] if `p.len() != q.len()`.
    pub fn kl_divergence(&self, p: &[f32], q: &[f32]) -> RlResult<f32> {
        if p.len() != q.len() {
            return Err(RlError::DimensionMismatch {
                expected: p.len(),
                got: q.len(),
            });
        }
        let kl: f32 = p
            .iter()
            .zip(q.iter())
            .filter(|&(&pi, _)| pi > 0.0)
            .map(|(&pi, &qi)| pi * (pi.ln() - qi.max(1e-10).ln()))
            .sum();
        Ok(kl)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-action policy with the default temperature, shared by most tests.
    fn policy() -> CategoricalPolicy {
        CategoricalPolicy::new(4)
    }

    /// Softmax output must be normalised.
    #[test]
    fn softmax_sums_to_one() {
        let s: f32 = policy()
            .softmax(&[1.0, 2.0, 3.0, 4.0])
            .unwrap()
            .iter()
            .sum();
        assert!((s - 1.0).abs() < 1e-5, "softmax sum={s}");
    }

    /// The largest logit must map to the largest probability.
    #[test]
    fn softmax_max_logit_has_max_prob() {
        let probs = policy().softmax(&[1.0, 5.0, 2.0, 3.0]).unwrap();
        let max_idx = (0..probs.len())
            .max_by(|&i, &j| probs[i].partial_cmp(&probs[j]).unwrap())
            .unwrap();
        assert_eq!(max_idx, 1, "max logit at idx=1 should have max prob");
    }

    /// A logit slice of the wrong length is rejected.
    #[test]
    fn softmax_dimension_mismatch() {
        let outcome = policy().softmax(&[1.0, 2.0]);
        assert!(outcome.is_err());
    }

    /// Every sampled action is a valid index.
    #[test]
    fn sample_in_range() {
        let p = policy();
        let mut handle = RlHandle::default_handle();
        let probs = p.softmax(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        (0..100).for_each(|_| {
            let a = p.sample_action(&probs, &mut handle).unwrap();
            assert!(a < 4, "action out of range: {a}");
        });
    }

    /// A sharply peaked distribution is sampled almost deterministically.
    #[test]
    fn sample_deterministic_peaked() {
        let p = CategoricalPolicy::new(3).with_temperature(0.01);
        let mut handle = RlHandle::default_handle();
        let probs = p.softmax(&[0.0, 0.0, 10.0]).unwrap();
        let mut counts = [0_usize; 3];
        (0..50).for_each(|_| {
            counts[p.sample_action(&probs, &mut handle).unwrap()] += 1;
        });
        assert!(
            counts[2] > 45,
            "peaked dist should mostly pick action 2, counts={counts:?}"
        );
    }

    /// Probabilities that do not sum to one are rejected.
    #[test]
    fn sample_invalid_distribution() {
        let mut handle = RlHandle::default_handle();
        let outcome = policy().sample_action(&[0.1; 4], &mut handle);
        assert!(outcome.is_err());
    }

    /// Under a uniform distribution over four actions, log-prob is ln(1/4).
    #[test]
    fn log_prob_uniform_distribution() {
        let expected = -2.0_f32.ln() * 2.0;
        let lp = CategoricalPolicy::new(4).log_prob(&[0.25; 4], 0).unwrap();
        assert!((lp - expected).abs() < 1e-5, "log_prob={lp}");
    }

    /// Batched log-probs pick the right column from each row.
    #[test]
    fn log_prob_batch_correct() {
        let flat_probs = [0.7, 0.3, 0.4, 0.6];
        let chosen = [0_usize, 1];
        let lps = CategoricalPolicy::new(2)
            .log_prob_batch(&flat_probs, &chosen)
            .unwrap();
        let (first, second) = (lps[0], lps[1]);
        assert!((first - 0.7_f32.ln()).abs() < 1e-5);
        assert!((second - 0.6_f32.ln()).abs() < 1e-5);
    }

    /// The uniform distribution has higher entropy than a peaked one.
    #[test]
    fn entropy_uniform_is_max() {
        let p = CategoricalPolicy::new(4);
        let h_u = p.entropy(&[0.25; 4]).unwrap();
        let h_p = p.entropy(&[0.97, 0.01, 0.01, 0.01]).unwrap();
        assert!(h_u > h_p, "uniform entropy {h_u} should be > peaked {h_p}");
    }

    /// A one-hot distribution has zero entropy.
    #[test]
    fn entropy_deterministic_is_zero() {
        let h = CategoricalPolicy::new(2).entropy(&[1.0, 0.0]).unwrap();
        assert!(
            h.abs() < 1e-6,
            "deterministic entropy should be ≈0, got {h}"
        );
    }

    /// Greedy selection returns the index of the largest probability.
    #[test]
    fn greedy_selects_max() {
        let chosen = policy().greedy_action(&[0.1, 0.5, 0.3, 0.1]).unwrap();
        assert_eq!(chosen, 1, "greedy should pick idx=1");
    }

    /// KL divergence of a distribution with itself is zero.
    #[test]
    fn kl_div_identical_distributions_zero() {
        let probs = [0.25_f32; 4];
        let kl = policy().kl_divergence(&probs, &probs).unwrap();
        assert!(kl.abs() < 1e-5, "KL(p||p) should be 0, got {kl}");
    }

    /// KL divergence between two different distributions is positive.
    #[test]
    fn kl_div_non_negative() {
        let (p_dist, q_dist) = ([0.4, 0.3, 0.2, 0.1], [0.1, 0.2, 0.3, 0.4]);
        let kl = policy().kl_divergence(&p_dist, &q_dist).unwrap();
        assert!(kl > 0.0, "KL divergence should be > 0");
    }
}