use rlevo_core::environment::{EpisodeStatus, Snapshot, SnapshotMetadata};
use rlevo_core::reward::ScalarReward;
use super::observation::LunarLanderObservation;
/// Key under which the `shaping` value passed to the [`LunarLanderSnapshot`] constructors is
/// stored in the snapshot's [`SnapshotMetadata`].
pub const METADATA_KEY_SHAPING: &str = "shaping";
/// Snapshot bundling a Lunar Lander observation, a scalar reward, the episode status and
/// metadata.
///
/// Construct it with [`LunarLanderSnapshot::running`], [`LunarLanderSnapshot::terminated`] or
/// [`LunarLanderSnapshot::truncated`]. Because `metadata` is private, code outside this module
/// cannot build one with a struct literal.
#[derive(Debug, Clone)]
pub struct LunarLanderSnapshot {
/// Observation carried by this snapshot.
pub observation: LunarLanderObservation,
/// Reward carried by this snapshot.
pub reward: ScalarReward,
/// Episode status (running, terminated or truncated, depending on the constructor used).
pub status: EpisodeStatus,
// Kept private so it is only ever populated by the constructors, which record the shaping
// value under `METADATA_KEY_SHAPING`.
metadata: SnapshotMetadata,
}
impl LunarLanderSnapshot {
    /// Creates a snapshot whose status is [`EpisodeStatus::Running`].
    ///
    /// `obs` and `reward` are stored unchanged; `shaping` is recorded in the metadata under
    /// [`METADATA_KEY_SHAPING`].
    pub fn running(obs: LunarLanderObservation, reward: ScalarReward, shaping: f32) -> Self {
        Self::make(obs, reward, EpisodeStatus::Running, shaping)
    }

    /// Creates a snapshot whose status is [`EpisodeStatus::Terminated`].
    ///
    /// `obs` and `reward` are stored unchanged; `shaping` is recorded in the metadata under
    /// [`METADATA_KEY_SHAPING`].
    pub fn terminated(obs: LunarLanderObservation, reward: ScalarReward, shaping: f32) -> Self {
        Self::make(obs, reward, EpisodeStatus::Terminated, shaping)
    }

    /// Creates a snapshot whose status is [`EpisodeStatus::Truncated`].
    ///
    /// `obs` and `reward` are stored unchanged; `shaping` is recorded in the metadata under
    /// [`METADATA_KEY_SHAPING`].
    pub fn truncated(obs: LunarLanderObservation, reward: ScalarReward, shaping: f32) -> Self {
        Self::make(obs, reward, EpisodeStatus::Truncated, shaping)
    }

    /// Common constructor behind the three public status-specific ones.
    ///
    /// Starts from empty [`SnapshotMetadata`], attaches `shaping` under
    /// [`METADATA_KEY_SHAPING`], and assembles the snapshot.
    fn make(
        observation: LunarLanderObservation,
        reward: ScalarReward,
        status: EpisodeStatus,
        shaping: f32,
    ) -> Self {
        let metadata = SnapshotMetadata::new().with(METADATA_KEY_SHAPING, shaping);
        Self {
            observation,
            reward,
            status,
            metadata,
        }
    }
}
// NOTE(review): the meaning of the trait's const parameter `1` is defined in rlevo_core and
// is not visible here — confirm there before relying on it.
impl Snapshot<1> for LunarLanderSnapshot {
    type RewardType = ScalarReward;
    type ObservationType = LunarLanderObservation;

    /// Episode status chosen at construction (returned by value).
    fn status(&self) -> EpisodeStatus {
        self.status
    }

    /// Borrow of the stored reward.
    fn reward(&self) -> &Self::RewardType {
        &self.reward
    }

    /// Borrow of the stored observation.
    fn observation(&self) -> &Self::ObservationType {
        &self.observation
    }

    /// Metadata attached by the constructors; this type always returns `Some`.
    fn metadata(&self) -> Option<&SnapshotMetadata> {
        Some(&self.metadata)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shaping value handed to a constructor must be readable back from the metadata.
    #[test]
    fn test_metadata_shaping_key_present() {
        let snap = LunarLanderSnapshot::running(
            LunarLanderObservation::default(),
            ScalarReward(1.0),
            42.5,
        );
        let meta = snap.metadata().expect("metadata must be Some");
        let shaping = meta
            .components
            .get(METADATA_KEY_SHAPING)
            .copied()
            .expect("metadata must contain the shaping key");
        assert!((shaping - 42.5).abs() < 1e-6);
    }

    /// Each constructor must produce the matching episode status.
    #[test]
    fn test_status_variants() {
        let obs = LunarLanderObservation::default();

        let running = LunarLanderSnapshot::running(obs.clone(), ScalarReward(0.0), 0.0);
        let terminated = LunarLanderSnapshot::terminated(obs.clone(), ScalarReward(0.0), 0.0);
        let truncated = LunarLanderSnapshot::truncated(obs, ScalarReward(0.0), 0.0);

        assert!(!running.status().is_done());
        assert!(terminated.status().is_terminated());
        assert!(truncated.status().is_truncated());
    }
}