use crate::error::{RlError, RlResult};
use crate::handle::RlHandle;
/// Action-space description used to post-process the output of a deterministic policy.
///
/// The methods implemented on this type validate the length of an action slice
/// against `action_dim` and clamp every component into `[action_low, action_high]`,
/// optionally after adding Gaussian noise.
#[derive(Debug, Clone)]
pub struct DeterministicPolicy {
/// Required length of every action slice handed to the methods of this type.
action_dim: usize,
/// Lower clamp bound applied to each action component.
action_low: f32,
/// Upper clamp bound applied to each action component.
action_high: f32,
}
impl DeterministicPolicy {
    /// Creates a policy description for `action_dim` components bounded to `[-1.0, 1.0]`.
    ///
    /// # Panics
    ///
    /// Panics if `action_dim` is zero (see [`Self::with_bounds`]).
    #[must_use]
    pub fn new(action_dim: usize) -> Self {
        Self::with_bounds(action_dim, -1.0, 1.0)
    }

    /// Creates a policy description with explicit per-component bounds.
    ///
    /// # Panics
    ///
    /// Panics if `action_dim` is zero, or if `action_low < action_high` does not
    /// hold (which also rejects NaN bounds).
    #[must_use]
    pub fn with_bounds(action_dim: usize, action_low: f32, action_high: f32) -> Self {
        assert!(action_dim > 0, "action_dim must be > 0");
        assert!(action_low < action_high, "action_low must be < action_high");
        Self {
            action_dim,
            action_low,
            action_high,
        }
    }

    /// Returns the number of action components this policy expects.
    #[must_use]
    #[inline]
    pub fn action_dim(&self) -> usize {
        self.action_dim
    }

    /// Clamps every component of `action` into `[action_low, action_high]`.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] if `action.len()` differs from
    /// the configured action dimension.
    pub fn clip_action(&self, action: &[f32]) -> RlResult<Vec<f32>> {
        self.ensure_dim(action)?;
        let mut clipped = Vec::with_capacity(action.len());
        for &component in action {
            clipped.push(component.clamp(self.action_low, self.action_high));
        }
        Ok(clipped)
    }

    /// Adds Gaussian noise scaled by `sigma` to each component of `action`,
    /// then clamps the result into `[action_low, action_high]`.
    ///
    /// Two values are drawn from the generator exposed by `handle.rng_mut()`
    /// per component and combined with the Box–Muller formula.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] if `action.len()` differs from
    /// the configured action dimension; the generator is not touched in that case.
    pub fn exploration_action(
        &self,
        action: &[f32],
        sigma: f32,
        handle: &mut RlHandle,
    ) -> RlResult<Vec<f32>> {
        self.ensure_dim(action)?;
        let rng = handle.rng_mut();
        let mut noisy = Vec::with_capacity(action.len());
        for &component in action {
            // Draw order matters for reproducibility: u1 first, then u2.
            let raw_u1 = rng.next_f32();
            let u2 = rng.next_f32();
            let noise = Self::unit_gaussian(raw_u1, u2);
            noisy.push((component + sigma * noise).clamp(self.action_low, self.action_high));
        }
        Ok(noisy)
    }

    /// Adds clipped Gaussian noise to each component of `action`.
    ///
    /// The noise `sigma * n` is first clamped into `[-clip_c, clip_c]`, added
    /// to the component, and the sum is then clamped into
    /// `[action_low, action_high]`.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] if `action.len()` differs from
    /// the configured action dimension; the generator is not touched in that case.
    ///
    /// # Panics
    ///
    /// Panics if `clip_c` is negative or NaN, because `f32::clamp` requires
    /// `min <= max` with neither bound being NaN.
    pub fn smooth_target_action(
        &self,
        action: &[f32],
        sigma: f32,
        clip_c: f32,
        handle: &mut RlHandle,
    ) -> RlResult<Vec<f32>> {
        self.ensure_dim(action)?;
        let rng = handle.rng_mut();
        let mut smoothed = Vec::with_capacity(action.len());
        for &component in action {
            // Draw order matters for reproducibility: u1 first, then u2.
            let raw_u1 = rng.next_f32();
            let u2 = rng.next_f32();
            let noise_raw = Self::unit_gaussian(raw_u1, u2);
            let noise = (sigma * noise_raw).clamp(-clip_c, clip_c);
            smoothed.push((component + noise).clamp(self.action_low, self.action_high));
        }
        Ok(smoothed)
    }

    /// Verifies that `action` has exactly `action_dim` components.
    fn ensure_dim(&self, action: &[f32]) -> RlResult<()> {
        if action.len() == self.action_dim {
            Ok(())
        } else {
            Err(RlError::DimensionMismatch {
                expected: self.action_dim,
                got: action.len(),
            })
        }
    }

    /// Combines two raw draws into one Box–Muller variate (cosine branch).
    ///
    /// `raw_u1` is nudged away from zero so that `ln` stays finite. The result
    /// is standard-normal assuming both inputs are uniform on `[0, 1)` —
    /// TODO confirm against the generator behind `RlHandle::rng_mut`.
    fn unit_gaussian(raw_u1: f32, u2: f32) -> f32 {
        let u1 = (raw_u1 + 1e-10).min(1.0 - 1e-10);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
    }
}
/// Stateful, temporally correlated noise generator.
///
/// Each call to `sample` moves every component of `state` toward the matching
/// entry of `mu` by a fraction `theta` and adds Gaussian noise scaled by `sigma`.
#[derive(Debug, Clone)]
pub struct OrnsteinUhlenbeck {
/// Number of noise components produced per sample.
action_dim: usize,
/// Per-component value the state is pulled toward (all zeros after `new`).
mu: Vec<f32>,
/// Multiplier on `(mu - state)` in the update, i.e. the pull-back strength.
theta: f32,
/// Multiplier on the Gaussian noise term in the update.
sigma: f32,
/// Current noise value per component; returned (as a copy) by `sample`.
state: Vec<f32>,
}
impl OrnsteinUhlenbeck {
    /// Creates a process with `action_dim` components, a zero mean vector and
    /// a zero initial state.
    ///
    /// `theta` scales the pull toward the mean and `sigma` scales the noise
    /// added on every [`Self::sample`] call.
    #[must_use]
    pub fn new(action_dim: usize, theta: f32, sigma: f32) -> Self {
        let mu = vec![0.0; action_dim];
        let state = vec![0.0; action_dim];
        Self {
            action_dim,
            mu,
            theta,
            sigma,
            state,
        }
    }

    /// Sets every component of the internal state back to zero.
    pub fn reset(&mut self) {
        for x in &mut self.state {
            *x = 0.0;
        }
    }

    /// Advances the process by one step and returns a copy of the new state.
    ///
    /// Gaussian variates are produced in pairs with the Box–Muller formula
    /// from values drawn via `handle.rng_mut()`; when the dimension is odd the
    /// sine half of the last pair is discarded. Each state component is then
    /// updated as `x += theta * (mu - x) + sigma * w`.
    pub fn sample(&mut self, handle: &mut RlHandle) -> Vec<f32> {
        let rng = handle.rng_mut();
        let dim = self.action_dim;
        let mut normals: Vec<f32> = Vec::with_capacity(dim);

        while normals.len() < dim {
            // Nudge u1 away from zero so that `ln` stays finite.
            let u1 = (rng.next_f32() + 1e-10).min(1.0 - 1e-10);
            let u2 = rng.next_f32();
            let radius = (-2.0 * u1.ln()).sqrt();
            let angle = 2.0 * std::f32::consts::PI * u2;

            normals.push(radius * angle.cos());
            // The second variate of the pair is only kept if there is room for it.
            if normals.len() < dim {
                normals.push(radius * angle.sin());
            }
        }

        for ((x, &mu), &w) in self.state.iter_mut().zip(&self.mu).zip(&normals) {
            *x += self.theta * (mu - *x) + self.sigma * w;
        }
        self.state.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Components outside the default `[-1, 1]` bounds are clamped to the bounds.
    #[test]
    fn clip_action_within_bounds() {
        let policy = DeterministicPolicy::new(3);
        let raw = [-2.0, 0.0, 2.0];
        let result = policy.clip_action(&raw).unwrap();
        assert_eq!(result, vec![-1.0, 0.0, 1.0]);
    }

    /// Noisy exploration actions never leave the default bounds.
    #[test]
    fn exploration_action_stays_in_bounds() {
        let policy = DeterministicPolicy::new(4);
        let mut handle = RlHandle::default_handle();
        let base = [0.0_f32; 4];
        for _ in 0..100 {
            let noisy = policy.exploration_action(&base, 0.3, &mut handle).unwrap();
            for v in &noisy {
                assert!(
                    (-1.0..=1.0).contains(v),
                    "exploration action out of bounds: {v}"
                );
            }
        }
    }

    /// Smoothed target actions never leave the default bounds.
    #[test]
    fn smooth_target_action_within_bounds() {
        let policy = DeterministicPolicy::new(2);
        let mut handle = RlHandle::default_handle();
        let base = [0.5_f32, -0.5];
        for _ in 0..100 {
            let smoothed = policy
                .smooth_target_action(&base, 0.2, 0.5, &mut handle)
                .unwrap();
            for v in &smoothed {
                assert!(
                    (-1.0..=1.0).contains(v),
                    "smoothed action out of bounds: {v}"
                );
            }
        }
    }

    /// A slice of the wrong length is rejected instead of being clipped.
    #[test]
    fn clip_action_dimension_error() {
        let policy = DeterministicPolicy::new(3);
        let too_short = [0.0_f32; 2];
        let outcome = policy.clip_action(&too_short);
        assert!(outcome.is_err());
    }

    /// One sample has exactly as many components as the process dimension.
    #[test]
    fn ou_sample_correct_dim() {
        let mut process = OrnsteinUhlenbeck::new(4, 0.15, 0.2);
        let mut handle = RlHandle::default_handle();
        let drawn = process.sample(&mut handle);
        assert_eq!(drawn.len(), 4);
    }

    /// `reset` clears whatever state a previous sample left behind.
    #[test]
    fn ou_reset_zeroes_state() {
        let mut process = OrnsteinUhlenbeck::new(2, 0.15, 0.2);
        let mut handle = RlHandle::default_handle();
        let _ = process.sample(&mut handle);
        process.reset();
        assert_eq!(process.state, vec![0.0; 2]);
    }

    /// Starting far from the mean, the state is pulled back toward zero.
    #[test]
    fn ou_mean_reversion() {
        let mut process = OrnsteinUhlenbeck::new(1, 1.0, 0.01);
        process.state = vec![100.0];
        let mut handle = RlHandle::default_handle();
        for _ in 0..50 {
            let _ = process.sample(&mut handle);
        }
        let final_state = process.state[0];
        assert!(
            final_state.abs() < 5.0,
            "OU should mean-revert, state={}",
            final_state
        );
    }
}