use super::core::{
GridSnapshot,
action::GridAction,
agent::AgentState,
build_snapshot,
direction::Direction,
dynamics::{StepOutcome, apply_action},
entity::Entity,
grid::Grid,
render::render_ascii,
reward::success_reward,
state::GridState,
};
use rand::SeedableRng;
use rand::rngs::StdRng;
use rlevo_core::environment::{Environment, EnvironmentError};
use rlevo_core::reward::ScalarReward;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
use std::str::FromStr;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum DistShiftVariant {
#[default]
One,
Two,
}
impl DistShiftVariant {
#[must_use]
pub const fn lava_row(self) -> i32 {
match self {
Self::One => 3,
Self::Two => 5,
}
}
}
impl FromStr for DistShiftVariant {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.trim().to_ascii_lowercase().as_str() {
"one" | "1" | "distshift1" => Ok(Self::One),
"two" | "2" | "distshift2" => Ok(Self::Two),
other => Err(format!("unknown variant `{other}`")),
}
}
}
const WIDTH: usize = 9;
const HEIGHT: usize = 7;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct DistShiftConfig {
pub variant: DistShiftVariant,
pub max_steps: usize,
pub seed: u64,
}
impl DistShiftConfig {
#[must_use]
pub const fn new(variant: DistShiftVariant, max_steps: usize, seed: u64) -> Self {
Self {
variant,
max_steps,
seed,
}
}
}
impl Default for DistShiftConfig {
fn default() -> Self {
Self {
variant: DistShiftVariant::One,
max_steps: 100,
seed: 0,
}
}
}
impl FromStr for DistShiftConfig {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut cfg = Self::default();
for (idx, raw) in s.trim().split(',').map(str::trim).enumerate() {
if raw.is_empty() {
continue;
}
if let Some((key, value)) = raw.split_once('=') {
match key.trim() {
"variant" => cfg.variant = value.trim().parse()?,
"max_steps" => {
cfg.max_steps = value
.trim()
.parse()
.map_err(|e| format!("max_steps: {e}"))?;
}
"seed" => cfg.seed = value.trim().parse().map_err(|e| format!("seed: {e}"))?,
other => return Err(format!("unknown key `{other}`")),
}
} else {
match idx {
0 => cfg.variant = raw.parse()?,
1 => cfg.max_steps = raw.parse().map_err(|e| format!("max_steps: {e}"))?,
2 => cfg.seed = raw.parse().map_err(|e| format!("seed: {e}"))?,
_ => return Err(format!("unexpected positional value `{raw}`")),
}
}
}
Ok(cfg)
}
}
#[derive(Debug)]
pub struct DistShiftEnv {
state: GridState,
config: DistShiftConfig,
steps: usize,
render: bool,
_rng: StdRng,
}
impl DistShiftEnv {
#[must_use]
pub fn with_config(config: DistShiftConfig, render: bool) -> Self {
let rng = StdRng::seed_from_u64(config.seed);
let state = Self::build(&config);
Self {
state,
config,
steps: 0,
render,
_rng: rng,
}
}
#[must_use]
pub const fn config(&self) -> &DistShiftConfig {
&self.config
}
#[must_use]
pub const fn steps(&self) -> usize {
self.steps
}
#[must_use]
pub const fn state(&self) -> &GridState {
&self.state
}
#[must_use]
pub fn ascii(&self) -> String {
render_ascii(&self.state.grid, &self.state.agent)
}
fn build(config: &DistShiftConfig) -> GridState {
let mut grid = Grid::new(WIDTH, HEIGHT);
grid.draw_walls();
let lava_row = config.variant.lava_row();
for x in 2..=6 {
grid.set(x, lava_row, Entity::Lava);
}
grid.set(7, 1, Entity::Goal);
let agent = AgentState::new(1, 1, Direction::East);
GridState::new(grid, agent)
}
fn emit(&self, reward: f32, done: bool) -> GridSnapshot {
if self.render {
println!("{}", self.ascii());
}
build_snapshot(&self.state, reward, done)
}
}
impl Display for DistShiftEnv {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"DistShiftEnv(variant={:?}, step={}/{})",
self.config.variant, self.steps, self.config.max_steps
)
}
}
impl Environment<3, 3, 1> for DistShiftEnv {
type StateType = GridState;
type ObservationType = super::core::GridObservation;
type ActionType = GridAction;
type RewardType = ScalarReward;
type SnapshotType = GridSnapshot;
fn new(render: bool) -> Self {
Self::with_config(DistShiftConfig::default(), render)
}
fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
self.state = Self::build(&self.config);
self.steps = 0;
self._rng = StdRng::seed_from_u64(self.config.seed);
Ok(self.emit(0.0, false))
}
fn step(&mut self, action: Self::ActionType) -> Result<Self::SnapshotType, EnvironmentError> {
self.steps += 1;
let outcome = apply_action(&mut self.state.grid, &mut self.state.agent, action);
let (reward, done) = match outcome {
StepOutcome::ReachedGoal => (success_reward(self.steps, self.config.max_steps), true),
StepOutcome::HitLava => (0.0, true),
_ => {
let done = self.steps >= self.config.max_steps;
(0.0, done)
}
};
Ok(self.emit(reward, done))
}
}
#[cfg(test)]
mod tests {
use super::*;
use rlevo_core::environment::Snapshot;
#[test]
fn default_config_is_variant_one() {
let cfg = DistShiftConfig::default();
assert_eq!(cfg.variant, DistShiftVariant::One);
assert_eq!(cfg.max_steps, 100);
}
#[test]
fn variant_lava_row() {
assert_eq!(DistShiftVariant::One.lava_row(), 3);
assert_eq!(DistShiftVariant::Two.lava_row(), 5);
}
#[test]
fn fromstr_variant_aliases() {
assert_eq!(
"one".parse::<DistShiftVariant>().unwrap(),
DistShiftVariant::One
);
assert_eq!(
"two".parse::<DistShiftVariant>().unwrap(),
DistShiftVariant::Two
);
assert_eq!(
"1".parse::<DistShiftVariant>().unwrap(),
DistShiftVariant::One
);
assert_eq!(
"2".parse::<DistShiftVariant>().unwrap(),
DistShiftVariant::Two
);
assert!("three".parse::<DistShiftVariant>().is_err());
}
#[test]
fn fromstr_config_keyvalue() {
let cfg: DistShiftConfig = "variant=two,max_steps=50,seed=9".parse().unwrap();
assert_eq!(cfg.variant, DistShiftVariant::Two);
assert_eq!(cfg.max_steps, 50);
assert_eq!(cfg.seed, 9);
}
#[test]
fn build_variant_one_places_lava_in_row_three() {
let env = DistShiftEnv::with_config(DistShiftConfig::default(), false);
for x in 2..=6 {
assert_eq!(env.state().grid.get(x, 3), Entity::Lava);
}
assert_eq!(env.state().grid.get(7, 1), Entity::Goal);
assert_eq!(env.state().agent.x, 1);
assert_eq!(env.state().agent.y, 1);
}
#[test]
fn build_variant_two_places_lava_in_row_five() {
let env =
DistShiftEnv::with_config(DistShiftConfig::new(DistShiftVariant::Two, 100, 0), false);
for x in 2..=6 {
assert_eq!(env.state().grid.get(x, 5), Entity::Lava);
assert_eq!(env.state().grid.get(x, 3), Entity::Empty);
}
}
#[test]
fn optimal_rollout_along_top_row_reaches_goal() {
let mut env = DistShiftEnv::with_config(DistShiftConfig::default(), false);
env.reset().unwrap();
let mut last = None;
for _ in 0..6 {
last = Some(env.step(GridAction::Forward).unwrap());
}
let snap = last.unwrap();
assert!(snap.is_done());
let reward: f32 = (*snap.reward()).into();
assert!(reward > 0.9);
}
#[test]
fn walking_into_lava_terminates_zero_reward() {
let mut env = DistShiftEnv::with_config(DistShiftConfig::default(), false);
env.reset().unwrap();
env.step(GridAction::Forward).unwrap(); env.step(GridAction::Forward).unwrap(); env.step(GridAction::TurnRight).unwrap();
env.step(GridAction::Forward).unwrap(); let snap = env.step(GridAction::Forward).unwrap(); assert!(snap.is_done());
let reward: f32 = (*snap.reward()).into();
assert_eq!(reward, 0.0);
}
#[test]
fn reset_is_deterministic_between_variants() {
let cfg_a = DistShiftConfig::new(DistShiftVariant::One, 100, 0);
let cfg_b = DistShiftConfig::new(DistShiftVariant::Two, 100, 0);
let mut a = DistShiftEnv::with_config(cfg_a, false);
let mut b = DistShiftEnv::with_config(cfg_b, false);
let sa = a.reset().unwrap();
let sb = b.reset().unwrap();
assert_ne!(
a.state().grid.get(2, 3),
b.state().grid.get(2, 3),
"variant One has lava at row 3 but variant Two does not"
);
let _ = (sa, sb);
}
#[test]
fn unknown_variant_errors() {
assert!("variant=wat".parse::<DistShiftConfig>().is_err());
}
}