use burn::backend::NdArray;
use burn::tensor::Tensor;
use rlevo_core::action::DiscreteAction;
use rlevo_core::base::{
Action, Observation, Reward, State, TensorConversionError, TensorConvertible,
};
use rlevo_core::environment::{
Environment, EnvironmentError, EpisodeStatus, Snapshot, SnapshotBase,
};
use rlevo_reinforcement_learning::memory::PrioritizedExperienceReplayBuilder;
use rlevo_reinforcement_learning::metrics::{AgentStats, PerformanceRecord};
use serde::{Deserialize, Serialize};
use std::ops::Add;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
struct WalkObservation {
position: i32,
}
impl Observation<1> for WalkObservation {
fn shape() -> [usize; 1] {
[1]
}
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for WalkObservation {
#[allow(clippy::cast_precision_loss)]
fn to_tensor(&self, device: &B::Device) -> Tensor<B, 1> {
Tensor::from_floats([self.position as f32], device)
}
fn from_tensor(_t: Tensor<B, 1>) -> Result<Self, TensorConversionError> {
Err(TensorConversionError {
message: "from_tensor not implemented for WalkObservation".into(),
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct WalkState {
position: i32,
}
impl State<1> for WalkState {
type Observation = WalkObservation;
fn shape() -> [usize; 1] {
[1]
}
fn is_valid(&self) -> bool {
(0..=6).contains(&self.position)
}
fn observe(&self) -> Self::Observation {
WalkObservation {
position: self.position,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WalkAction {
Left,
Right,
}
impl Action<1> for WalkAction {
fn shape() -> [usize; 1] {
[1]
}
fn is_valid(&self) -> bool {
true
}
}
impl DiscreteAction<1> for WalkAction {
const ACTION_COUNT: usize = 2;
fn from_index(index: usize) -> Self {
match index {
0 => WalkAction::Left,
1 => WalkAction::Right,
_ => panic!("invalid WalkAction index: {index}"),
}
}
fn to_index(&self) -> usize {
match self {
WalkAction::Left => 0,
WalkAction::Right => 1,
}
}
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for WalkAction {
#[allow(clippy::cast_precision_loss)]
fn to_tensor(&self, device: &B::Device) -> Tensor<B, 1> {
Tensor::from_floats([self.to_index() as f32], device)
}
fn from_tensor(_t: Tensor<B, 1>) -> Result<Self, TensorConversionError> {
Err(TensorConversionError {
message: "from_tensor not implemented for WalkAction".into(),
})
}
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
struct WalkReward(f32);
impl Reward for WalkReward {
fn zero() -> Self {
WalkReward(0.0)
}
}
impl Add for WalkReward {
type Output = Self;
fn add(self, rhs: Self) -> Self {
WalkReward(self.0 + rhs.0)
}
}
impl From<WalkReward> for f32 {
fn from(r: WalkReward) -> f32 {
r.0
}
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for WalkReward {
fn to_tensor(&self, device: &B::Device) -> Tensor<B, 1> {
Tensor::from_floats([self.0], device)
}
fn from_tensor(_t: Tensor<B, 1>) -> Result<Self, TensorConversionError> {
Err(TensorConversionError {
message: "from_tensor not implemented for WalkReward".into(),
})
}
}
struct RandomWalkEnv {
state: WalkState,
steps: usize,
}
impl RandomWalkEnv {
const START: i32 = 3;
const GOAL: i32 = 6;
const MAX_STEPS: usize = 20;
}
impl Environment<1, 1, 1> for RandomWalkEnv {
type StateType = WalkState;
type ObservationType = WalkObservation;
type ActionType = WalkAction;
type RewardType = WalkReward;
type SnapshotType = SnapshotBase<1, WalkObservation, WalkReward>;
fn new(_render: bool) -> Self {
Self {
state: WalkState {
position: Self::START,
},
steps: 0,
}
}
fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
self.state = WalkState {
position: Self::START,
};
self.steps = 0;
Ok(SnapshotBase::running(
self.state.observe(),
WalkReward::zero(),
))
}
fn step(&mut self, action: WalkAction) -> Result<Self::SnapshotType, EnvironmentError> {
let delta = match action {
WalkAction::Left => -1,
WalkAction::Right => 1,
};
let next = self.state.position + delta;
self.steps += 1;
if next < 0 {
self.state = WalkState { position: 0 };
return Ok(SnapshotBase::terminated(
self.state.observe(),
WalkReward(-1.0),
));
}
if next >= Self::GOAL {
self.state = WalkState {
position: Self::GOAL,
};
return Ok(SnapshotBase::terminated(
self.state.observe(),
WalkReward(1.0),
));
}
self.state = WalkState { position: next };
if self.steps >= Self::MAX_STEPS {
Ok(SnapshotBase::truncated(
self.state.observe(),
WalkReward::zero(),
))
} else {
Ok(SnapshotBase::running(
self.state.observe(),
WalkReward::zero(),
))
}
}
}
#[test]
#[allow(clippy::float_cmp)]
fn full_episode_loop_reaches_goal() {
let mut env = RandomWalkEnv::new(false);
let initial = env.reset().expect("reset");
assert_eq!(initial.status(), EpisodeStatus::Running);
assert_eq!(f32::from(*initial.reward()), 0.0);
assert_eq!(initial.observation().position, RandomWalkEnv::START);
let mut last = env.step(WalkAction::Right).expect("step 1");
assert!(!last.is_done());
last = env.step(WalkAction::Right).expect("step 2");
assert!(!last.is_done());
last = env.step(WalkAction::Right).expect("step 3 -> goal");
assert!(last.is_done());
assert!(last.is_terminated());
assert!(!last.is_truncated());
assert_eq!(f32::from(*last.reward()), 1.0);
assert_eq!(last.observation().position, RandomWalkEnv::GOAL);
}
#[test]
fn replay_buffer_sample_batch_tensor_shapes() {
type B = NdArray;
let device = Default::default();
let mut buffer =
PrioritizedExperienceReplayBuilder::<1, 1, WalkObservation, WalkAction, WalkReward>::new()
.with_capacity(32)
.with_alpha(0.6)
.build();
let mut env = RandomWalkEnv::new(false);
let mut snapshot = env.reset().expect("reset");
for i in 0..20 {
let action = if i % 2 == 0 {
WalkAction::Right
} else {
WalkAction::Left
};
let obs_before = *snapshot.observation();
let next = env.step(action).expect("step");
buffer.add(
obs_before,
action,
*next.reward(),
*next.observation(),
next.is_done(),
Some(1.0),
);
snapshot = if next.is_done() {
env.reset().expect("reset after done")
} else {
next
};
}
let batch = buffer
.sample_batch::<2, 2, B>(8, &device)
.expect("sample_batch");
assert_eq!(batch.observations.shape().dims, [8, 1]);
assert_eq!(batch.actions.shape().dims, [8, 1]);
assert_eq!(batch.rewards.shape().dims, [8]);
assert_eq!(batch.next_observations.shape().dims, [8, 1]);
assert_eq!(batch.dones.shape().dims, [8]);
}
#[derive(Debug, Clone, Copy)]
struct EpisodeResult {
score: f32,
duration: usize,
}
impl PerformanceRecord for EpisodeResult {
fn score(&self) -> f32 {
self.score
}
fn duration(&self) -> usize {
self.duration
}
}
#[test]
fn agent_stats_tracks_episodes_and_sliding_window() {
let mut stats = AgentStats::<EpisodeResult>::new(2);
stats.record(EpisodeResult {
score: 5.0,
duration: 10,
});
stats.record(EpisodeResult {
score: 8.0,
duration: 15,
});
stats.record(EpisodeResult {
score: 3.0,
duration: 7,
});
assert_eq!(stats.total_episodes, 3);
assert_eq!(stats.total_steps, 32);
assert_eq!(stats.best_score, Some(8.0));
let avg = stats.avg_score().expect("avg_score present after records");
assert!((avg - 5.5).abs() < 1e-6, "expected 5.5, got {avg}");
assert_eq!(stats.recent_history.len(), 2);
}