use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::Array1;
use scirs2_core::random::{Distribution, Rng, Uniform};
/// Policy for choosing an action index given an agent and the current state.
///
/// A strategy owns a single scalar exploration parameter that is reported by
/// `exploration_param` and updated by `decay`.
pub trait ExplorationStrategy {
/// Picks an action index for `state`.
///
/// `agent` provides the greedy choice and the size of the action space, and
/// `rng` is the randomness source an implementation may draw from.
///
/// # Errors
///
/// Propagates any error produced by the implementation or by the agent.
fn select_action<A: RLAgent, R: Rng>(
&self,
agent: &A,
state: &Array1<f64>,
rng: &mut R,
) -> Result<usize>;
/// Advances the exploration schedule by one step.
///
/// The implementations in this file multiply their parameter by a decay rate
/// and clamp the result at a configured minimum.
fn decay(&mut self);
/// Returns the current value of the strategy's exploration parameter.
fn exploration_param(&self) -> f64;
}
/// Minimal view of an agent that an exploration strategy needs.
pub trait RLAgent {
/// Returns the action index the agent itself prefers for `state`.
///
/// # Errors
///
/// Implementation-defined; the strategies in this file forward the error
/// to their caller unchanged.
fn select_greedy_action(&self, state: &Array1<f64>) -> Result<usize>;
/// Number of discrete actions; random actions are drawn from `0..action_dim()`.
fn action_dim(&self) -> usize;
}
/// Epsilon-greedy exploration settings.
///
/// With probability `epsilon` a uniformly random action is taken; otherwise
/// the agent's greedy action is used. Each decay step multiplies `epsilon` by
/// `decay_rate` and clamps it at `epsilon_min`.
pub struct EpsilonGreedy {
    /// Current probability of taking a random action.
    epsilon: f64,
    /// Lower bound that decay never goes below.
    epsilon_min: f64,
    /// Multiplicative factor applied to `epsilon` on each decay step.
    decay_rate: f64,
}

impl EpsilonGreedy {
    /// Builds a strategy after checking that every parameter lies in `[0, 1]`.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::ValueError` naming the first parameter (checked in
    /// the order `epsilon`, `epsilon_min`, `decay_rate`) that falls outside
    /// `[0, 1]`. NaN is outside the range and is therefore rejected as well.
    pub fn new(epsilon: f64, epsilon_min: f64, decay_rate: f64) -> Result<Self> {
        // Table of (value, message) pairs; order matters because only the first
        // violation is reported.
        let bounded = [
            (epsilon, "epsilon must be in [0, 1]"),
            (epsilon_min, "epsilon_min must be in [0, 1]"),
            (decay_rate, "decay_rate must be in [0, 1]"),
        ];
        let violation = bounded
            .iter()
            .find(|entry| !(0.0..=1.0).contains(&entry.0));
        if let Some(&(_, message)) = violation {
            return Err(NumRs2Error::ValueError(message.to_string()));
        }
        Ok(Self {
            epsilon,
            epsilon_min,
            decay_rate,
        })
    }

    /// Current exploration probability.
    pub fn epsilon(&self) -> f64 {
        self.epsilon
    }
}
impl ExplorationStrategy for EpsilonGreedy {
    /// Draws one uniform sample from `[0, 1)`; if it is below `epsilon` a second
    /// draw picks a uniformly random action in `0..agent.action_dim()`, otherwise
    /// the agent's greedy action is returned.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::ValueError` when a uniform distribution cannot be
    /// constructed, and forwards any error from the agent's greedy selection.
    fn select_action<A: RLAgent, R: Rng>(
        &self,
        agent: &A,
        state: &Array1<f64>,
        rng: &mut R,
    ) -> Result<usize> {
        let unit_interval = Uniform::new(0.0, 1.0)
            .map_err(|e| NumRs2Error::ValueError(format!("Uniform distribution error: {}", e)))?;
        let roll: f64 = unit_interval.sample(rng);
        let explore = roll < self.epsilon;

        // Exploit: no further randomness is consumed on this path.
        if !explore {
            return agent.select_greedy_action(state);
        }

        // Explore: every action index is equally likely.
        let random_action = Uniform::new(0, agent.action_dim())
            .map_err(|e| NumRs2Error::ValueError(format!("Uniform distribution error: {}", e)))?;
        Ok(random_action.sample(rng))
    }

    /// Multiplies `epsilon` by `decay_rate`, never going below `epsilon_min`.
    fn decay(&mut self) {
        let shrunk = self.epsilon * self.decay_rate;
        self.epsilon = shrunk.max(self.epsilon_min);
    }

    /// Reports the current `epsilon`.
    fn exploration_param(&self) -> f64 {
        self.epsilon
    }
}
/// Temperature-based (Boltzmann / softmax) exploration settings.
///
/// Each decay step multiplies the temperature by `decay_rate` and clamps it at
/// `temperature_min`.
pub struct BoltzmannExploration {
    /// Current softmax temperature; validated to be positive and not NaN.
    temperature: f64,
    /// Lower bound that decay never goes below; validated like `temperature`.
    temperature_min: f64,
    /// Multiplicative factor in `[0, 1]` applied on each decay step.
    decay_rate: f64,
}

impl BoltzmannExploration {
    /// Builds a strategy after validating all three parameters.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::ValueError` when `temperature` or `temperature_min`
    /// is NaN, zero or negative, or when `decay_rate` is outside `[0, 1]`
    /// (which includes NaN).
    pub fn new(temperature: f64, temperature_min: f64, decay_rate: f64) -> Result<Self> {
        // `NaN <= 0.0` evaluates to false, so a bare `<= 0.0` test would let NaN
        // through and every later softmax would fail; reject it explicitly.
        if temperature.is_nan() || temperature <= 0.0 {
            return Err(NumRs2Error::ValueError(
                "temperature must be positive".to_string(),
            ));
        }
        if temperature_min.is_nan() || temperature_min <= 0.0 {
            return Err(NumRs2Error::ValueError(
                "temperature_min must be positive".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&decay_rate) {
            return Err(NumRs2Error::ValueError(
                "decay_rate must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self {
            temperature,
            temperature_min,
            decay_rate,
        })
    }

    /// Current temperature.
    pub fn temperature(&self) -> f64 {
        self.temperature
    }

    /// Converts `values` into a probability vector using the current temperature.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::ValueError` for an empty slice and
    /// `NumRs2Error::NumericalError` when the exponentials sum to zero or to a
    /// non-finite number (for example when an input is NaN or infinite).
    fn softmax(&self, values: &[f64]) -> Result<Vec<f64>> {
        if values.is_empty() {
            return Err(NumRs2Error::ValueError(
                "Cannot compute softmax of empty array".to_string(),
            ));
        }
        // Subtracting the maximum keeps every exponent <= 0 so `exp` cannot overflow.
        let peak = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let weights: Vec<f64> = values
            .iter()
            .map(|&v| ((v - peak) / self.temperature).exp())
            .collect();
        let total: f64 = weights.iter().sum();
        if total == 0.0 || !total.is_finite() {
            return Err(NumRs2Error::NumericalError(
                "Softmax computation resulted in invalid sum".to_string(),
            ));
        }
        Ok(weights.into_iter().map(|w| w / total).collect())
    }
}
impl ExplorationStrategy for BoltzmannExploration {
    /// Returns the agent's greedy action for `state`.
    ///
    /// NOTE(review): this does not sample from a softmax distribution. The
    /// `RLAgent` trait exposes no per-action values, so neither `_rng` nor the
    /// `softmax` helper is used here and the temperature has no effect on the
    /// chosen action. Confirm whether that is intended.
    ///
    /// # Errors
    ///
    /// Forwards any error from `RLAgent::select_greedy_action`.
    fn select_action<A: RLAgent, R: Rng>(
        &self,
        agent: &A,
        state: &Array1<f64>,
        _rng: &mut R,
    ) -> Result<usize> {
        let greedy_choice = agent.select_greedy_action(state)?;
        Ok(greedy_choice)
    }

    /// Multiplies the temperature by `decay_rate`, never going below
    /// `temperature_min`.
    fn decay(&mut self) {
        let cooled = self.temperature * self.decay_rate;
        self.temperature = cooled.max(self.temperature_min);
    }

    /// Reports the current temperature.
    fn exploration_param(&self) -> f64 {
        self.temperature
    }
}
/// Running reward standardizer built on Welford's online mean/variance update.
pub struct RewardNormalizer {
    /// Running mean of all rewards seen so far.
    mean: f64,
    /// Running sum of squared deviations from the mean (Welford's `M2`).
    /// Dividing by `count - 1` yields the sample variance.
    var: f64,
    /// Number of rewards folded in so far.
    count: usize,
    /// Small constant added to the standard deviation in `normalize` so the
    /// division stays finite when all rewards are equal.
    epsilon: f64,
}

impl RewardNormalizer {
    /// Creates an empty normalizer; `epsilon` is added to the standard deviation
    /// when normalizing.
    pub fn new(epsilon: f64) -> Self {
        Self {
            mean: 0.0,
            // The accumulator must start at zero: any other seed is carried into
            // every later variance estimate as a bias of `seed / (count - 1)`.
            var: 0.0,
            count: 0,
            epsilon,
        }
    }

    /// Folds one reward into the running statistics.
    pub fn update(&mut self, reward: f64) {
        self.count += 1;
        // Welford: deviation from the old mean times deviation from the new mean.
        let delta = reward - self.mean;
        self.mean += delta / self.count as f64;
        let delta2 = reward - self.mean;
        self.var += delta * delta2;
    }

    /// Standardizes `reward` as `(reward - mean) / (std + epsilon)`.
    ///
    /// With fewer than two samples no variance estimate exists, so the reward is
    /// returned unchanged.
    pub fn normalize(&self, reward: f64) -> f64 {
        if self.count < 2 {
            return reward;
        }
        let std = (self.var / (self.count - 1) as f64).sqrt() + self.epsilon;
        (reward - self.mean) / std
    }

    /// Running mean of the rewards seen so far (0.0 before any update).
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample standard deviation of the rewards seen so far, or 1.0 while fewer
    /// than two samples have been recorded.
    pub fn std(&self) -> f64 {
        if self.count < 2 {
            return 1.0;
        }
        (self.var / (self.count - 1) as f64).sqrt()
    }

    /// Number of rewards recorded since construction or the last reset.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Discards all statistics; `epsilon` is kept.
    pub fn reset(&mut self) {
        self.mean = 0.0;
        self.var = 0.0;
        self.count = 0;
    }
}

impl Default for RewardNormalizer {
    /// Equivalent to `RewardNormalizer::new(1e-8)`.
    fn default() -> Self {
        Self::new(1e-8)
    }
}
/// Sliding-window record of per-episode total reward and length.
///
/// Rewards are accumulated step by step; finishing an episode stores its
/// totals and, once more than `window_size` episodes are held, drops the
/// oldest one.
#[derive(Debug, Clone)]
pub struct EpisodeTracker {
    /// Total reward of each retained episode, oldest first.
    episode_rewards: Vec<f64>,
    /// Step count of each retained episode, oldest first.
    episode_lengths: Vec<usize>,
    /// Reward accumulated in the episode currently in progress.
    current_episode_reward: f64,
    /// Steps taken in the episode currently in progress.
    current_episode_length: usize,
    /// Maximum number of finished episodes to retain.
    window_size: usize,
}

impl EpisodeTracker {
    /// Creates an empty tracker that retains at most `window_size` episodes.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            current_episode_length: 0,
            current_episode_reward: 0.0,
            episode_lengths: Vec::new(),
            episode_rewards: Vec::new(),
        }
    }

    /// Records one step of the in-progress episode.
    pub fn step(&mut self, reward: f64) {
        self.current_episode_length += 1;
        self.current_episode_reward += reward;
    }

    /// Closes the in-progress episode, stores its totals and starts a fresh one.
    pub fn finish_episode(&mut self) {
        // Take the running totals and leave zeros behind for the next episode.
        let finished_reward = std::mem::replace(&mut self.current_episode_reward, 0.0);
        let finished_length = std::mem::replace(&mut self.current_episode_length, 0);
        self.episode_rewards.push(finished_reward);
        self.episode_lengths.push(finished_length);

        // Evict the oldest entry once the window is exceeded.
        let over_capacity = self.episode_rewards.len() > self.window_size;
        if over_capacity {
            self.episode_rewards.remove(0);
            self.episode_lengths.remove(0);
        }
    }

    /// Number of finished episodes currently retained.
    pub fn num_episodes(&self) -> usize {
        self.episode_rewards.len()
    }

    /// Average total reward over the retained episodes, or `None` if there are none.
    pub fn mean_reward(&self) -> Option<f64> {
        match self.episode_rewards.as_slice() {
            [] => None,
            rewards => {
                let total: f64 = rewards.iter().sum();
                Some(total / rewards.len() as f64)
            }
        }
    }

    /// Average length over the retained episodes, or `None` if there are none.
    pub fn mean_length(&self) -> Option<f64> {
        match self.episode_lengths.as_slice() {
            [] => None,
            lengths => {
                let total: usize = lengths.iter().sum();
                Some(total as f64 / lengths.len() as f64)
            }
        }
    }

    /// Total reward of the most recently finished episode.
    pub fn last_reward(&self) -> Option<f64> {
        self.episode_rewards.last().copied()
    }

    /// Length of the most recently finished episode.
    pub fn last_length(&self) -> Option<usize> {
        self.episode_lengths.last().copied()
    }

    /// Retained episode rewards, oldest first.
    pub fn episode_rewards(&self) -> &[f64] {
        self.episode_rewards.as_slice()
    }

    /// Retained episode lengths, oldest first.
    pub fn episode_lengths(&self) -> &[usize] {
        self.episode_lengths.as_slice()
    }

    /// Forgets all finished episodes and the in-progress totals; the window
    /// size is kept.
    pub fn reset(&mut self) {
        self.current_episode_length = 0;
        self.current_episode_reward = 0.0;
        self.episode_lengths.clear();
        self.episode_rewards.clear();
    }
}

impl Default for EpisodeTracker {
    /// Tracker with a window of 100 episodes.
    fn default() -> Self {
        Self::new(100)
    }
}
// Unit tests for the exploration strategies, reward normalizer and episode tracker.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::random::thread_rng;
// Stub agent whose greedy action is always 0 and whose action count is configurable.
struct DummyAgent {
action_dim: usize,
}
impl RLAgent for DummyAgent {
fn select_greedy_action(&self, _state: &Array1<f64>) -> Result<usize> {
Ok(0)
}
fn action_dim(&self) -> usize {
self.action_dim
}
}
// Valid parameters are accepted and epsilon is stored unchanged.
#[test]
fn test_epsilon_greedy_creation() -> Result<()> {
let strategy = EpsilonGreedy::new(1.0, 0.01, 0.995)?;
assert_eq!(strategy.epsilon(), 1.0);
Ok(())
}
// Each parameter is rejected when it falls outside [0, 1].
#[test]
fn test_epsilon_greedy_invalid_params() {
assert!(EpsilonGreedy::new(1.5, 0.01, 0.995).is_err());
assert!(EpsilonGreedy::new(-0.1, 0.01, 0.995).is_err());
assert!(EpsilonGreedy::new(1.0, 1.5, 0.995).is_err());
assert!(EpsilonGreedy::new(1.0, 0.01, 1.5).is_err());
}
// One decay step multiplies epsilon by the rate; many steps stop at the floor.
#[test]
fn test_epsilon_greedy_decay() -> Result<()> {
let mut strategy = EpsilonGreedy::new(1.0, 0.01, 0.9)?;
strategy.decay();
assert!((strategy.epsilon() - 0.9).abs() < 1e-6);
for _ in 0..100 {
strategy.decay();
}
assert!(strategy.epsilon() >= 0.01);
Ok(())
}
// epsilon = 0 must yield the stub's greedy action (0); epsilon = 1 must yield
// some valid index below the action count.
#[test]
fn test_epsilon_greedy_action_selection() -> Result<()> {
let agent = DummyAgent { action_dim: 4 };
let mut rng = thread_rng();
let state = Array1::zeros(2);
let strategy = EpsilonGreedy::new(0.0, 0.0, 1.0)?;
let action = strategy.select_action(&agent, &state, &mut rng)?;
assert_eq!(action, 0);
let strategy = EpsilonGreedy::new(1.0, 1.0, 1.0)?;
let action = strategy.select_action(&agent, &state, &mut rng)?;
assert!(action < 4);
Ok(())
}
// Valid parameters are accepted and the temperature is stored unchanged.
#[test]
fn test_boltzmann_creation() -> Result<()> {
let strategy = BoltzmannExploration::new(1.0, 0.1, 0.99)?;
assert_eq!(strategy.temperature(), 1.0);
Ok(())
}
// Non-positive temperatures and an out-of-range decay rate are rejected.
#[test]
fn test_boltzmann_invalid_params() {
assert!(BoltzmannExploration::new(0.0, 0.1, 0.99).is_err());
assert!(BoltzmannExploration::new(-1.0, 0.1, 0.99).is_err());
assert!(BoltzmannExploration::new(1.0, 0.0, 0.99).is_err());
assert!(BoltzmannExploration::new(1.0, 0.1, 1.5).is_err());
}
// One decay step multiplies the temperature by the rate; many steps stop at the floor.
#[test]
fn test_boltzmann_decay() -> Result<()> {
let mut strategy = BoltzmannExploration::new(10.0, 0.1, 0.9)?;
strategy.decay();
assert!((strategy.temperature() - 9.0).abs() < 1e-6);
for _ in 0..200 {
strategy.decay();
}
assert!(strategy.temperature() >= 0.1);
Ok(())
}
// Softmax output sums to one and preserves the ordering of its inputs.
#[test]
fn test_boltzmann_softmax() -> Result<()> {
let strategy = BoltzmannExploration::new(1.0, 0.1, 0.99)?;
let values = vec![1.0, 2.0, 3.0];
let probs = strategy.softmax(&values)?;
assert_eq!(probs.len(), 3);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-6);
assert!(probs[2] > probs[1]);
assert!(probs[1] > probs[0]);
Ok(())
}
// Softmax of an empty slice is an error rather than an empty vector.
#[test]
fn test_boltzmann_softmax_empty() -> Result<()> {
let strategy = BoltzmannExploration::new(1.0, 0.1, 0.99)?;
let values: Vec<f64> = vec![];
let result = strategy.softmax(&values);
assert!(result.is_err());
Ok(())
}
// A fresh normalizer has zero mean and zero samples.
#[test]
fn test_reward_normalizer_creation() {
let normalizer = RewardNormalizer::new(1e-8);
assert_eq!(normalizer.mean(), 0.0);
assert_eq!(normalizer.count(), 0);
}
// The running mean of 1, 2, 3 is 2.
#[test]
fn test_reward_normalizer_update() {
let mut normalizer = RewardNormalizer::new(1e-8);
normalizer.update(1.0);
normalizer.update(2.0);
normalizer.update(3.0);
assert_eq!(normalizer.count(), 3);
assert!((normalizer.mean() - 2.0).abs() < 1e-6);
}
// 3.0 is the mean of 1..=5, so its normalized value should be close to zero.
// The tolerance is deliberately loose and does not pin the standard deviation.
#[test]
fn test_reward_normalizer_normalize() {
let mut normalizer = RewardNormalizer::new(1e-8);
for i in 1..=5 {
normalizer.update(i as f64);
}
let normalized = normalizer.normalize(3.0);
assert!((normalized - 0.0).abs() < 0.5); }
// Reset returns the sample count and mean to their initial values.
#[test]
fn test_reward_normalizer_reset() {
let mut normalizer = RewardNormalizer::new(1e-8);
normalizer.update(1.0);
normalizer.update(2.0);
normalizer.reset();
assert_eq!(normalizer.count(), 0);
assert_eq!(normalizer.mean(), 0.0);
}
// A fresh tracker holds no episodes and has no mean reward.
#[test]
fn test_episode_tracker_creation() {
let tracker = EpisodeTracker::new(100);
assert_eq!(tracker.num_episodes(), 0);
assert!(tracker.mean_reward().is_none());
}
// Three steps followed by finish_episode record one episode with summed reward.
#[test]
fn test_episode_tracker_step() {
let mut tracker = EpisodeTracker::new(100);
tracker.step(1.0);
tracker.step(2.0);
tracker.step(3.0);
tracker.finish_episode();
assert_eq!(tracker.num_episodes(), 1);
assert_eq!(tracker.last_reward(), Some(6.0));
assert_eq!(tracker.last_length(), Some(3));
}
// Three identical episodes of rewards 1..=10: each totals 55 over 10 steps.
#[test]
fn test_episode_tracker_multiple_episodes() {
let mut tracker = EpisodeTracker::new(100);
for _ in 0..3 {
for i in 1..=10 {
tracker.step(i as f64);
}
tracker.finish_episode();
}
assert_eq!(tracker.num_episodes(), 3);
assert_eq!(tracker.mean_reward(), Some(55.0));
assert_eq!(tracker.mean_length(), Some(10.0));
}
// With a window of 2, only the last two of five episodes (rewards 4 and 5) remain.
#[test]
fn test_episode_tracker_window() {
let mut tracker = EpisodeTracker::new(2);
for episode in 1..=5 {
tracker.step(episode as f64);
tracker.finish_episode();
}
assert_eq!(tracker.num_episodes(), 2); assert_eq!(tracker.last_reward(), Some(5.0));
assert_eq!(tracker.mean_reward(), Some(4.5)); }
// Reset discards all finished episodes.
#[test]
fn test_episode_tracker_reset() {
let mut tracker = EpisodeTracker::new(100);
tracker.step(1.0);
tracker.finish_episode();
tracker.reset();
assert_eq!(tracker.num_episodes(), 0);
assert!(tracker.mean_reward().is_none());
}
// The slice accessors expose the stored episodes oldest first.
#[test]
fn test_episode_tracker_accessors() {
let mut tracker = EpisodeTracker::new(100);
for i in 1..=3 {
tracker.step(i as f64);
tracker.step(i as f64);
tracker.finish_episode();
}
let rewards = tracker.episode_rewards();
assert_eq!(rewards, &[2.0, 4.0, 6.0]);
let lengths = tracker.episode_lengths();
assert_eq!(lengths, &[2, 2, 2]);
}
}