use crate::base::{Action, Observation, Reward, State};
use std::collections::BTreeMap;
use std::fmt::Debug;
/// Lifecycle state of an episode as carried by a snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EpisodeStatus {
    /// The episode has not ended; the only status for which `is_done` is `false`.
    Running,
    /// The episode has ended and is reported by `is_terminated`.
    Terminated,
    /// The episode has ended and is reported by `is_truncated`.
    Truncated,
}

impl EpisodeStatus {
    /// Returns `true` when the episode has ended, i.e. for `Terminated` and `Truncated`.
    pub const fn is_done(self) -> bool {
        // `Running` is the single non-final variant, so "done" is its negation.
        !matches!(self, Self::Running)
    }

    /// Returns `true` only for `Terminated`.
    pub const fn is_terminated(self) -> bool {
        match self {
            Self::Terminated => true,
            Self::Running | Self::Truncated => false,
        }
    }

    /// Returns `true` only for `Truncated`.
    pub const fn is_truncated(self) -> bool {
        match self {
            Self::Truncated => true,
            Self::Running | Self::Terminated => false,
        }
    }
}
/// Optional named extras that a snapshot can expose through `Snapshot::metadata`.
///
/// Both maps are keyed by `&'static str` and backed by `BTreeMap`, so iteration
/// happens in key order.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct SnapshotMetadata {
    /// Named scalar values.
    pub components: BTreeMap<&'static str, f32>,
    /// Named three-element arrays (the builder parameter calls them `xyz`).
    pub positions: BTreeMap<&'static str, [f32; 3]>,
}

impl SnapshotMetadata {
    /// Creates metadata with both maps empty; identical to `Default::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder step: stores `value` under `key` in `components` and returns the
    /// updated metadata. An existing entry for `key` is overwritten.
    #[must_use = "builder methods consume `self`; dropping the result discards the metadata"]
    pub fn with(mut self, key: &'static str, value: f32) -> Self {
        self.components.insert(key, value);
        self
    }

    /// Builder step: stores `xyz` under `key` in `positions` and returns the
    /// updated metadata. An existing entry for `key` is overwritten.
    #[must_use = "builder methods consume `self`; dropping the result discards the metadata"]
    pub fn with_position(mut self, key: &'static str, xyz: [f32; 3]) -> Self {
        self.positions.insert(key, xyz);
        self
    }
}
/// Errors returned by environment operations such as `Environment::reset` and
/// `Environment::step`.
#[derive(Debug)]
pub enum EnvironmentError {
    /// An action was rejected; the payload is a human-readable description.
    InvalidAction(String),
    /// Rendering failed; the payload is a human-readable description.
    RenderFailed(String),
    /// An underlying I/O error, also exposed through `Error::source`.
    IoError(std::io::Error),
}

impl std::error::Error for EnvironmentError {
    /// Returns the wrapped `std::io::Error` for `IoError`, and `None` for the
    /// variants that only carry a message.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IoError(io_err) => Some(io_err),
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // deliberate decision about its source here.
            Self::InvalidAction(_) | Self::RenderFailed(_) => None,
        }
    }
}

impl std::fmt::Display for EnvironmentError {
    /// Writes a fixed per-variant prefix followed by the payload's `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidAction(action_error) => {
                write!(f, "Invalid action: {}", action_error)
            }
            Self::RenderFailed(render_error) => {
                write!(f, "Render failed: {}", render_error)
            }
            Self::IoError(io_err) => {
                write!(f, "IO operation failed: {}", io_err)
            }
        }
    }
}

impl From<std::io::Error> for EnvironmentError {
    /// Wraps an I/O error in `EnvironmentError::IoError`, which lets `?` convert it.
    fn from(error: std::io::Error) -> Self {
        Self::IoError(error)
    }
}
pub trait Snapshot<const R: usize>: Debug {
type ObservationType: Observation<R>;
type RewardType: Reward;
fn observation(&self) -> &Self::ObservationType;
fn reward(&self) -> &Self::RewardType;
fn status(&self) -> EpisodeStatus;
fn is_done(&self) -> bool {
self.status().is_done()
}
fn is_terminated(&self) -> bool {
self.status().is_terminated()
}
fn is_truncated(&self) -> bool {
self.status().is_truncated()
}
fn metadata(&self) -> Option<&SnapshotMetadata> {
None
}
}
/// Plain snapshot holding an observation, a reward and an episode status in
/// public fields.
#[derive(Debug, Clone)]
pub struct SnapshotBase<const R: usize, ObservationType, RewardType>
where
    ObservationType: Observation<R>,
    RewardType: Reward,
{
    /// Observation carried by this snapshot.
    pub observation: ObservationType,
    /// Reward carried by this snapshot.
    pub reward: RewardType,
    /// Episode status carried by this snapshot.
    pub status: EpisodeStatus,
}

impl<const R: usize, ObservationType, RewardType> SnapshotBase<R, ObservationType, RewardType>
where
    ObservationType: Observation<R>,
    RewardType: Reward,
{
    /// Shared constructor body: assembles a snapshot with the given status.
    fn with_status(
        observation: ObservationType,
        reward: RewardType,
        status: EpisodeStatus,
    ) -> Self {
        Self {
            observation,
            reward,
            status,
        }
    }

    /// Builds a snapshot whose status is `EpisodeStatus::Running`.
    pub fn running(observation: ObservationType, reward: RewardType) -> Self {
        Self::with_status(observation, reward, EpisodeStatus::Running)
    }

    /// Builds a snapshot whose status is `EpisodeStatus::Terminated`.
    pub fn terminated(observation: ObservationType, reward: RewardType) -> Self {
        Self::with_status(observation, reward, EpisodeStatus::Terminated)
    }

    /// Builds a snapshot whose status is `EpisodeStatus::Truncated`.
    pub fn truncated(observation: ObservationType, reward: RewardType) -> Self {
        Self::with_status(observation, reward, EpisodeStatus::Truncated)
    }
}
/// `Snapshot` implementation that simply exposes the three public fields.
impl<const R: usize, ObservationType, RewardType> Snapshot<R>
    for SnapshotBase<R, ObservationType, RewardType>
where
    ObservationType: Observation<R>,
    RewardType: Reward,
{
    type ObservationType = ObservationType;
    type RewardType = RewardType;

    /// Borrows the `observation` field.
    fn observation(&self) -> &ObservationType {
        let Self { observation, .. } = self;
        observation
    }

    /// Borrows the `reward` field.
    fn reward(&self) -> &RewardType {
        let Self { reward, .. } = self;
        reward
    }

    /// Copies out the `status` field.
    fn status(&self) -> EpisodeStatus {
        let Self { status, .. } = self;
        *status
    }
}
/// An environment that can be reset and advanced one action at a time.
///
/// The const parameters are forwarded unchanged to the associated-type bounds:
/// `R` to `Observation<R>` / `Snapshot<R>`, `SR` to `State<SR>` and `AR` to
/// `Action<AR>`.
pub trait Environment<const R: usize, const SR: usize, const AR: usize> {
    /// State type of the environment; no method of this trait refers to it.
    type StateType: State<SR>;
    /// Observation type carried by the snapshots this environment returns.
    type ObservationType: Observation<R>;
    /// Action type accepted by `step`.
    type ActionType: Action<AR>;
    /// Reward type carried by the snapshots this environment returns.
    type RewardType: Reward;
    /// Snapshot type returned by `reset` and `step`; its observation and reward
    /// types must equal the ones declared above.
    type SnapshotType: Snapshot<
        R,
        ObservationType = Self::ObservationType,
        RewardType = Self::RewardType,
    >;

    /// Resets the environment and returns a snapshot, or an `EnvironmentError`.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError>;

    /// Applies `action` and returns the resulting snapshot, or an
    /// `EnvironmentError`.
    fn step(&mut self, action: Self::ActionType) -> Result<Self::SnapshotType, EnvironmentError>;
}
/// Environments that can be built from a single `render` flag.
pub trait ConstructableEnv {
    /// Constructs the environment.
    ///
    /// NOTE(review): `render` presumably toggles rendering; its meaning is left to
    /// implementors (the in-file test mock ignores it) — confirm against real
    /// environments.
    fn new(render: bool) -> Self;
}
#[cfg(test)]
mod tests {
use serde::{Deserialize, Serialize};
use super::*;
use crate::action::DiscreteAction;
/// Observation used by the mock environment: a single integer position.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct MockObservation {
    position: i32,
}

impl Observation<1> for MockObservation {
    /// One-dimensional shape with a single entry.
    fn shape() -> [usize; 1] {
        [1; 1]
    }
}
/// State of the mock environment: an integer position, valid within `0..=6`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MockState {
    position: i32,
}

impl MockState {
    /// Wraps `position` without validating it.
    fn new(position: i32) -> Self {
        MockState { position }
    }

    /// Whether `position` lies in the inclusive range `0..=6`.
    fn is_in_bounds(position: i32) -> bool {
        matches!(position, 0..=6)
    }
}

impl State<1> for MockState {
    type Observation = MockObservation;

    /// Always 7, matching the single entry of `shape()`.
    fn numel(&self) -> usize {
        7
    }

    /// One-dimensional shape `[7]`.
    fn shape() -> [usize; 1] {
        [7]
    }

    /// A state is valid when its position is in bounds.
    fn is_valid(&self) -> bool {
        MockState::is_in_bounds(self.position)
    }

    /// Copies the position into a `MockObservation`.
    fn observe(&self) -> Self::Observation {
        let position = self.position;
        MockObservation { position }
    }
}
/// The two moves understood by the mock environment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MockAction {
    /// Decrease the position by one.
    MoveLeft,
    /// Increase the position by one.
    MoveRight,
}

impl Action<1> for MockAction {
    /// Every mock action is valid.
    fn is_valid(&self) -> bool {
        true
    }

    /// One-dimensional shape with a single entry.
    fn shape() -> [usize; 1] {
        [1]
    }
}

impl DiscreteAction<1> for MockAction {
    const ACTION_COUNT: usize = 2;

    /// Maps `0` to `MoveLeft` and `1` to `MoveRight`; panics on any other index.
    fn from_index(index: usize) -> Self {
        // Lookup table ordered so that position == index.
        const BY_INDEX: [MockAction; 2] = [MockAction::MoveLeft, MockAction::MoveRight];
        match BY_INDEX.get(index) {
            Some(&action) => action,
            None => panic!("Unknown action index: {}", index),
        }
    }

    /// Inverse of `from_index`.
    fn to_index(&self) -> usize {
        match *self {
            Self::MoveLeft => 0,
            Self::MoveRight => 1,
        }
    }
}
use crate::reward::ScalarReward;
/// Minimal one-dimensional walk used to exercise the `Environment` trait.
struct MockEnvironment {
    /// Current position of the agent.
    current_state: MockState,
    /// Number of steps taken since construction or the last reset.
    step_count: usize,
    /// Step count at which a still-running episode is reported as truncated.
    max_steps: usize,
}

impl MockEnvironment {
    /// Position installed by construction and by `reset`.
    const START_STATE: i32 = 3;
    /// Default value for `max_steps`.
    const MAX_STEPS: usize = 20;
    /// Position that ends the episode with a positive reward.
    const GOAL_STATE: i32 = 6;

    /// Builds an environment at the start position with a zeroed step counter.
    /// The render flag is ignored.
    fn with_defaults(_render: bool) -> Self {
        let current_state = MockState::new(Self::START_STATE);
        Self {
            current_state,
            step_count: 0,
            max_steps: Self::MAX_STEPS,
        }
    }
}

impl ConstructableEnv for MockEnvironment {
    /// Delegates to `with_defaults`.
    fn new(render: bool) -> Self {
        MockEnvironment::with_defaults(render)
    }
}
impl Environment<1, 1, 1> for MockEnvironment {
    type StateType = MockState;
    type ObservationType = MockObservation;
    type ActionType = MockAction;
    type RewardType = ScalarReward;
    type SnapshotType = SnapshotBase<1, MockObservation, ScalarReward>;

    /// Puts the agent back on `START_STATE`, zeroes the step counter and returns
    /// a running snapshot with reward `0.0`.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        self.current_state = MockState::new(Self::START_STATE);
        self.step_count = 0;
        Ok(SnapshotBase::running(
            self.current_state.observe(),
            ScalarReward(0.0),
        ))
    }

    /// Moves the agent one cell left or right and reports the outcome.
    ///
    /// * Moving below `0` or above `6` clamps the position to that boundary,
    ///   yields reward `-1.0` and terminates the episode.
    /// * Landing on `GOAL_STATE` yields reward `1.0` and terminates the episode.
    /// * Any other move yields reward `0.0`; the episode is reported as truncated
    ///   once `step_count` reaches `max_steps`, otherwise as running.
    ///
    /// # Errors
    ///
    /// Returns `EnvironmentError::InvalidAction` when `action.is_valid()` is false.
    fn step(
        &mut self,
        action: Self::ActionType,
    ) -> Result<Self::SnapshotType, EnvironmentError> {
        if !action.is_valid() {
            // `EnvironmentError`'s `Display` already prefixes "Invalid action: ",
            // so the payload names only the offending action to avoid repeating it.
            return Err(EnvironmentError::InvalidAction(format!("{:?}", action)));
        }

        // Exhaustive match: a new action variant must be handled explicitly
        // instead of silently falling into the "move right" branch.
        let next_position = match action {
            MockAction::MoveLeft => self.current_state.position - 1,
            MockAction::MoveRight => self.current_state.position + 1,
        };

        let (new_state, reward, terminated) = if next_position < 0 {
            (MockState::new(0), -1.0, true)
        } else if next_position > 6 {
            (MockState::new(6), -1.0, true)
        } else {
            let reached_goal = next_position == Self::GOAL_STATE;
            let reward = if reached_goal { 1.0 } else { 0.0 };
            (MockState::new(next_position), reward, reached_goal)
        };

        self.current_state = new_state;
        self.step_count += 1;

        // Termination takes precedence over truncation when both apply.
        let status = if terminated {
            EpisodeStatus::Terminated
        } else if self.step_count >= self.max_steps {
            EpisodeStatus::Truncated
        } else {
            EpisodeStatus::Running
        };

        Ok(SnapshotBase {
            observation: new_state.observe(),
            reward: ScalarReward(reward),
            status,
        })
    }
}
/// Hand-written `Snapshot` implementor carrying two extra bookkeeping fields,
/// used to check that the trait works for types other than `SnapshotBase`.
#[derive(Debug, Clone)]
pub struct CustomSnapshot {
    observation: MockObservation,
    reward: ScalarReward,
    status: EpisodeStatus,
    /// Extra field that is not part of the `Snapshot` trait.
    step_count: usize,
    /// Extra field that is not part of the `Snapshot` trait.
    cumulative_reward: f32,
}

impl Snapshot<1> for CustomSnapshot {
    type ObservationType = MockObservation;
    type RewardType = ScalarReward;

    /// Borrows the stored observation.
    fn observation(&self) -> &Self::ObservationType {
        let Self { observation, .. } = self;
        observation
    }

    /// Borrows the stored reward.
    fn reward(&self) -> &Self::RewardType {
        let Self { reward, .. } = self;
        reward
    }

    /// Copies out the stored status.
    fn status(&self) -> EpisodeStatus {
        let Self { status, .. } = self;
        *status
    }
}
/// `SnapshotBase::running` stores its arguments and reports a running episode.
#[test]
fn test_snapshot_base_creation() {
    let expected_observation = MockObservation { position: 42 };
    let snapshot = SnapshotBase::running(expected_observation, ScalarReward(1.5));

    assert_eq!(snapshot.observation(), &expected_observation);
    assert_eq!(snapshot.reward(), &ScalarReward(1.5));
    assert!(!snapshot.is_done());
    assert_eq!(snapshot.status(), EpisodeStatus::Running);
}
/// `SnapshotBase::terminated` is done and terminated, but not truncated.
#[test]
fn test_snapshot_base_terminal() {
    let snapshot =
        SnapshotBase::terminated(MockObservation { position: 0 }, ScalarReward(-1.0));

    assert!(snapshot.is_done());
    assert!(snapshot.is_terminated());
    assert!(!snapshot.is_truncated());
    assert_eq!(snapshot.reward(), &ScalarReward(-1.0));
}
/// A cloned snapshot exposes the same observation, reward and done flag.
#[test]
fn test_snapshot_base_clone() {
    let original = SnapshotBase::running(MockObservation { position: 10 }, ScalarReward(0.5));
    let copy = original.clone();

    assert_eq!(original.observation(), copy.observation());
    assert_eq!(original.reward(), copy.reward());
    assert_eq!(original.is_done(), copy.is_done());
}
/// The derived `Debug` output names the type and shows every field.
#[test]
fn test_snapshot_debug() {
    let snapshot = SnapshotBase::terminated(MockObservation { position: 5 }, ScalarReward(2.0));
    let rendered = format!("{:?}", snapshot);

    let expected_fragments = [
        "SnapshotBase",
        "position: 5",
        "reward: ScalarReward(2.0)",
        "Terminated",
    ];
    for fragment in &expected_fragments {
        assert!(rendered.contains(*fragment));
    }
}
/// A hand-written `Snapshot` implementor works through the trait accessors and
/// keeps its extra fields.
#[test]
fn test_custom_snapshot_trait_impl() {
    let observation = MockObservation { position: 1 };
    let snapshot = CustomSnapshot {
        observation,
        reward: ScalarReward(10.0),
        status: EpisodeStatus::Running,
        step_count: 5,
        cumulative_reward: 25.0,
    };

    // Trait-provided view.
    assert_eq!(snapshot.observation().position, 1);
    assert_eq!(snapshot.reward(), &ScalarReward(10.0));
    assert!(!snapshot.is_done());

    // Fields outside the trait.
    assert_eq!(snapshot.step_count, 5);
    assert_eq!(snapshot.cumulative_reward, 25.0);
}
/// A freshly constructed mock environment has taken no steps.
#[test]
fn test_environment_creation() {
    let MockEnvironment { step_count, .. } = MockEnvironment::new(false);
    assert_eq!(step_count, 0);
}
/// `reset` returns a running snapshot at position 3 with zero reward.
#[test]
fn test_environment_reset() {
    let mut env = MockEnvironment::new(false);
    let initial = env.reset().expect("Reset should succeed");

    assert_eq!(initial.observation().position, 3);
    assert_eq!(initial.reward(), &ScalarReward(0.0));
    assert!(!initial.is_done());
}
/// One step to the right from the start moves to position 4 with zero reward.
#[test]
fn test_environment_step_valid_action() {
    let mut env = MockEnvironment::new(false);
    env.reset().expect("Reset should succeed");

    let after_step = env
        .step(MockAction::MoveRight)
        .expect("Step with valid action should succeed");

    assert_eq!(after_step.observation().position, 4);
    assert_eq!(after_step.reward(), &ScalarReward(0.0));
}
/// Walking right from position 0 finishes the episode exactly on the sixth step.
#[test]
fn test_environment_episode_termination() {
    let mut env = MockEnvironment::new(false);
    env.reset().expect("Reset should succeed");
    // Start from the left edge so that six moves are needed to reach the goal.
    env.current_state.position = 0;

    for step_index in 0..6 {
        let snapshot = env
            .step(MockAction::MoveRight)
            .expect("Step should succeed");
        let is_final_step = step_index >= 5;
        if is_final_step {
            assert!(
                snapshot.is_done(),
                "Episode should be done upon reaching goal"
            );
        } else {
            assert!(
                !snapshot.is_done(),
                "Episode should not be done before reaching goal"
            );
        }
    }
}
/// A second `reset` restores the start position and clears the step counter,
/// whatever happened during the previous episode.
#[test]
fn test_environment_reset_clears_state() {
    let mut env = MockEnvironment::new(false);
    env.reset().expect("Reset should succeed");

    for _ in 0..5 {
        // The mock keeps accepting steps after termination, so none of these may fail.
        env.step(MockAction::MoveRight)
            .expect("Step should succeed");
    }
    // Confirms there is state to clear, which makes the post-reset checks meaningful.
    assert_eq!(env.step_count, 5);

    let snapshot = env.reset().expect("Second reset should succeed");
    assert_eq!(snapshot.observation().position, 3);
    assert!(!snapshot.is_done());
    assert_eq!(env.step_count, 0, "reset should clear the step counter");
}
/// `Display` for `InvalidAction` contains both the prefix and the payload.
#[test]
fn test_environment_error_display() {
    let rendered = EnvironmentError::InvalidAction("test action".to_string()).to_string();

    assert!(rendered.contains("Invalid action"));
    assert!(rendered.contains("test action"));
}
/// Converting a `std::io::Error` yields the `IoError` variant.
#[test]
fn test_environment_error_io_conversion() {
    let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
    let converted: EnvironmentError = io_error.into();

    assert!(
        matches!(converted, EnvironmentError::IoError(_)),
        "Expected IoError variant"
    );
}
/// `IoError` exposes the wrapped I/O error through `Error::source`.
#[test]
fn test_environment_error_source() {
    use std::error::Error;

    let wrapped = EnvironmentError::IoError(std::io::Error::new(
        std::io::ErrorKind::PermissionDenied,
        "access denied",
    ));

    assert!(wrapped.source().is_some());
}
/// Runs three consecutive episodes and checks that each one starts fresh and
/// ends the same way, so `reset` leaves nothing behind between episodes.
#[test]
fn test_environment_multiple_episodes() {
    let mut env = MockEnvironment::new(false);

    for episode in 0..3 {
        let mut snapshot = env.reset().expect("Reset should succeed");
        assert_eq!(
            snapshot.observation().position,
            3,
            "episode {} should start at the start position",
            episode
        );
        assert!(!snapshot.is_done());

        let mut step = 0;
        while !snapshot.is_done() && step < 5 {
            snapshot = env
                .step(MockAction::MoveRight)
                .expect("Step should succeed");
            step += 1;
        }

        // From position 3 the goal at 6 is exactly three moves to the right.
        assert_eq!(step, 3, "episode {} should end after three steps", episode);
        assert!(snapshot.is_terminated());
        assert_eq!(snapshot.observation().position, 6);
        assert_eq!(snapshot.reward(), &ScalarReward(1.0));
    }
}
/// The reward read from a snapshot converts into the `f32` it was built from.
#[test]
fn test_snapshot_reward_conversion() {
    let snapshot = SnapshotBase::running(MockObservation { position: 1 }, ScalarReward(42.5));

    let reward = *snapshot.reward();
    let as_float: f32 = reward.into();

    assert_eq!(as_float, 42.5);
}
/// Default metadata has no components and no positions.
#[test]
fn test_metadata_default_is_empty() {
    let SnapshotMetadata {
        components,
        positions,
    } = SnapshotMetadata::default();

    assert!(components.is_empty());
    assert!(positions.is_empty());
}
/// The builder methods accumulate entries in both maps.
#[test]
fn test_metadata_builder_components_and_positions() {
    let with_components = SnapshotMetadata::new()
        .with("forward", 1.25)
        .with("ctrl", -0.1);
    let meta = with_components
        .with_position("torso", [0.5, 0.0, 1.1])
        .with_position("com", [0.4, 0.0, 0.9]);

    assert_eq!(meta.components.len(), 2);
    assert_eq!(meta.components.get("forward"), Some(&1.25));

    assert_eq!(meta.positions.len(), 2);
    assert_eq!(meta.positions.get("torso"), Some(&[0.5, 0.0, 1.1]));
}
}