mod frozen_lake_map;
pub use self::frozen_lake_map::*;
mod frozen_lake_scene;
pub use self::frozen_lake_scene::*;
mod grid;
pub use self::grid::*;
mod reward_grid;
pub use self::reward_grid::*;
mod spawn_frozen_lake;
pub use self::spawn_frozen_lake::*;
mod translate_grid;
pub use self::translate_grid::*;
use crate::prelude::*;
use beet_flow::prelude::*;
use beet_core::prelude::*;
/// Plugin that wires the frozen lake environment into an [`App`].
///
/// See the `Plugin` implementation for the exact plugins, systems, assets
/// and reflected types that get registered.
pub struct FrozenLakePlugin;
impl Plugin for FrozenLakePlugin {
fn build(&self, app: &mut App) {
app.add_plugins((
RlSessionPlugin::<FrozenLakeEpParams>::default(),
))
.add_systems(Update, (
translate_grid.in_set(TickSet),
reward_grid.in_set(PostTickSet)))
.add_systems(
Update,
spawn_frozen_lake_episode
.in_set(PostTickSet),
)
.init_asset::<QTable<GridPos, GridDirection>>()
.init_asset_loader::<QTableLoader<GridPos, GridDirection>>()
.register_type::<GridPos>()
.register_type::<RlSession<FrozenLakeEpParams>>()
.register_type::<QTable<GridPos, GridDirection>>()
.register_type::<GridPos>()
.register_type::<GridDirection>()
.register_type::<GridToWorld>()
;
let world = app.world_mut();
world.register_component::<GridPos>();
world.register_component::<GridDirection>();
}
}
/// Episode parameters for a frozen lake reinforcement learning session.
#[derive(Debug, Clone, Reflect)]
pub struct FrozenLakeEpParams {
/// Q-learning parameters; `n_training_episodes` supplies the episode count
/// reported by the `EpisodeParams` implementation.
pub learn_params: QLearnParams,
/// The frozen lake map used for the episodes.
pub map: FrozenLakeMap,
/// Grid-to-world mapping; the `Default` implementation derives it from `map`.
pub grid_to_world: GridToWorld,
}
impl Default for FrozenLakeEpParams {
    /// Builds parameters around the default four-by-four map, default
    /// Q-learning parameters, and a [`GridToWorld`] derived from that map.
    fn default() -> Self {
        let map = FrozenLakeMap::default_four_by_four();
        let learn_params = QLearnParams::default();
        // The mapping borrows the map, so it is computed before `map` is
        // moved into the struct.
        // NOTE(review): the meaning of the `4.0` argument is not visible
        // here; see `GridToWorld::from_frozen_lake_map`.
        let grid_to_world = GridToWorld::from_frozen_lake_map(&map, 4.0);

        Self {
            learn_params,
            map,
            grid_to_world,
        }
    }
}
impl EpisodeParams for FrozenLakeEpParams {
    /// Returns the number of episodes, taken from
    /// `learn_params.n_training_episodes`.
    fn num_episodes(&self) -> u32 {
        let Self { learn_params, .. } = self;
        learn_params.n_training_episodes
    }
}
/// Q-table specialised for the frozen lake environment: keyed by [`GridPos`]
/// states and [`GridDirection`] actions.
pub type FrozenLakeQTable = QTable<GridPos, GridDirection>;
/// Marker type that selects the frozen lake Q-table types through its
/// `RlSessionTypes` implementation.
#[derive(Debug, Reflect)]
pub struct FrozenLakeQTableSession;
/// Binds the frozen lake session to its concrete state, action, policy,
/// environment and episode parameter types.
impl RlSessionTypes for FrozenLakeQTableSession {
    /// Parameters describing each training episode.
    type EpisodeParams = FrozenLakeEpParams;
    /// States are positions on the grid.
    type State = GridPos;
    /// Actions are grid directions.
    type Action = GridDirection;
    /// Environment parameterised by the same state and action types.
    type Env = QTableEnv<GridPos, GridDirection>;
    /// The policy is the frozen lake Q-table.
    type QLearnPolicy = FrozenLakeQTable;
}