#![allow(clippy::use_self)] mod bandits;
mod builders;
mod cartpole;
mod chain;
mod mdps;
mod memory;
pub mod meta;
mod multiagent;
mod partition;
#[cfg(test)]
pub mod testing;
mod wrappers;
pub use bandits::{
Bandit, BernoulliBandit, DeterministicBandit, OneHotBandits, UniformBernoulliBandits,
};
pub use builders::{BuildEnv, BuildEnvDist, BuildEnvError, CloneBuild};
pub use cartpole::{CartPole, CartPoleConfig};
pub use chain::Chain;
pub use mdps::DirichletRandomMdps;
pub use memory::MemoryGame;
pub use meta::MetaEnv;
pub use multiagent::fruit::{self, FruitGame};
pub use multiagent::views::{FirstPlayerView, SecondPlayerView};
pub use partition::PartitionGame;
pub use wrappers::{
LatentStepLimit, StructurePreservingWrapper, VisibleStepLimit, WithLatentStepLimit,
WithVisibleStepLimit, Wrap, Wrapped,
};
use crate::agents::Actor;
use crate::feedback::Reward;
use crate::logging::StatsLogger;
use crate::simulation::{SimSeed, Steps};
use crate::spaces::{IntervalSpace, Space};
use crate::Prng;
use serde::{Deserialize, Serialize};
use std::borrow::Borrow;
use std::f64;
/// An environment whose state is held outside the environment object.
///
/// Every method takes `&self` (or consumes `self` in [`Environment::run`]) and the state is
/// passed in and returned explicitly, so the environment value itself is never mutated by
/// stepping.
pub trait Environment {
/// State type threaded through `initial_state`, `observe` and `step`.
type State;
/// Value returned by `observe` for a given state.
type Observation;
/// Action type accepted (by reference) by `step`.
type Action;
/// Feedback value returned by `step` alongside the successor.
type Feedback;
/// Produce a starting state, using `rng` as the source of randomness.
fn initial_state(&self, rng: &mut Prng) -> Self::State;
/// Produce the observation associated with `state`, using `rng` as the source of randomness.
fn observe(&self, state: &Self::State, rng: &mut Prng) -> Self::Observation;
/// Perform one step from `state` with `action`.
///
/// Consumes `state` and returns the resulting [`Successor`] together with the feedback for
/// this step. `logger` is handed to the implementation so it can record statistics.
fn step(
&self,
state: Self::State,
action: &Self::Action,
rng: &mut Prng,
logger: &mut dyn StatsLogger,
) -> (Successor<Self::State>, Self::Feedback);
/// Consume this environment and pair it with `actor` in a seeded [`Steps`] value.
///
/// Simply forwards to `Steps::new_seeded`; only available when `Self: Sized`.
fn run<T, L>(self, actor: T, seed: SimSeed, logger: L) -> Steps<Self, T, Prng, L>
where
T: Actor<Self::Observation, Self::Action>,
L: StatsLogger,
Self: Sized,
{
Steps::new_seeded(self, actor, seed, logger)
}
}
/// Implements [`Environment`] for a pointer-like type by forwarding every item to the pointee.
///
/// `$pointer` must be a type written in terms of the generic parameter `T` that dereferences to
/// `T`; the forwarding calls rely on deref coercion from `&$pointer` to `&T`.
macro_rules! impl_wrapped_environment {
    ($pointer:ty) => {
        impl<T> Environment for $pointer
        where
            T: Environment + ?Sized,
        {
            type State = <T as Environment>::State;
            type Observation = <T as Environment>::Observation;
            type Action = <T as Environment>::Action;
            type Feedback = <T as Environment>::Feedback;

            fn initial_state(&self, rng: &mut Prng) -> Self::State {
                <T as Environment>::initial_state(self, rng)
            }

            fn observe(&self, state: &Self::State, rng: &mut Prng) -> Self::Observation {
                <T as Environment>::observe(self, state, rng)
            }

            fn step(
                &self,
                state: Self::State,
                action: &Self::Action,
                rng: &mut Prng,
                logger: &mut dyn StatsLogger,
            ) -> (Successor<Self::State>, Self::Feedback) {
                <T as Environment>::step(self, state, action, rng, logger)
            }
        }
    };
}

// Shared references and boxes to an environment are themselves environments.
impl_wrapped_environment!(&'_ T);
impl_wrapped_environment!(Box<T>);
/// Marker trait for an [`Environment`] whose feedback type is [`Reward`].
///
/// It declares no items of its own; it only names the bound `Environment<Feedback = Reward>`.
pub trait Pomdp: Environment<Feedback = Reward> {}
// Blanket implementation: every (sized) environment with `Reward` feedback is a `Pomdp`.
impl<T: Environment<Feedback = Reward>> Pomdp for T {}
/// Describes the external structure of an environment: its spaces and discount factor.
///
/// Each accessor returns the space by value rather than by reference.
pub trait EnvStructure {
/// Type of the space returned by `observation_space`.
type ObservationSpace: Space;
/// Type of the space returned by `action_space`.
type ActionSpace: Space;
/// Type of the space returned by `feedback_space`.
type FeedbackSpace: Space;
/// The observation space, returned by value.
fn observation_space(&self) -> Self::ObservationSpace;
/// The action space, returned by value.
fn action_space(&self) -> Self::ActionSpace;
/// The feedback space, returned by value.
fn feedback_space(&self) -> Self::FeedbackSpace;
/// The discount factor as an `f64`.
// NOTE(review): presumably expected to lie in [0, 1]; nothing here enforces a range — confirm.
fn discount_factor(&self) -> f64;
}
macro_rules! impl_wrapped_env_structure {
($wrapper:ty) => {
impl<T: EnvStructure + ?Sized> EnvStructure for $wrapper {
type ObservationSpace = T::ObservationSpace;
type ActionSpace = T::ActionSpace;
type FeedbackSpace = T::FeedbackSpace;
fn observation_space(&self) -> Self::ObservationSpace {
T::observation_space(self)
}
fn action_space(&self) -> Self::ActionSpace {
T::action_space(self)
}
fn feedback_space(&self) -> Self::FeedbackSpace {
T::feedback_space(self)
}
fn discount_factor(&self) -> f64 {
T::discount_factor(self)
}
}
};
}
impl_wrapped_env_structure!(&'_ T);
impl_wrapped_env_structure!(Box<T>);
/// Marker trait for an [`EnvStructure`] whose feedback space is `IntervalSpace<Reward>`.
///
/// It declares no items of its own; it only names that bound.
pub trait PomdpStructure: EnvStructure<FeedbackSpace = IntervalSpace<Reward>> {}
// Blanket implementation for every (sized) type that satisfies the bound.
impl<T: EnvStructure<FeedbackSpace = IntervalSpace<Reward>>> PomdpStructure for T {}
/// An [`Environment`] that also provides a matching [`EnvStructure`].
///
/// "Matching" means the environment's `Observation`, `Action` and `Feedback` types are exactly
/// the `Element` types of the corresponding structure spaces. The trait has no items of its own.
pub trait StructuredEnvironment:
EnvStructure
+ Environment<
Observation = <Self::ObservationSpace as Space>::Element,
Action = <Self::ActionSpace as Space>::Element,
Feedback = <Self::FeedbackSpace as Space>::Element,
>
{
}
// Blanket implementation: anything satisfying the bounds above (including unsized types, via
// `?Sized`) is automatically a `StructuredEnvironment`.
impl<T> StructuredEnvironment for T where
T: EnvStructure
+ Environment<
Observation = <Self::ObservationSpace as Space>::Element,
Action = <Self::ActionSpace as Space>::Element,
Feedback = <Self::FeedbackSpace as Space>::Element,
> + ?Sized
{
}
/// What follows a step: either the episode continues, or it ends in one of two ways.
///
/// `T` is the type carried by `Interrupt`. `U` is the type carried by `Continue` and defaults
/// to `T`; the aliases `RefSuccessor` and `PartialSuccessor` pick other choices of `U`.
/// `episode_done` treats every variant other than `Continue` as the end of the episode.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Successor<T, U = T> {
/// The episode goes on; carries the next value.
Continue(U),
/// The episode is over and no value is carried.
Terminate,
/// The episode is over but a value of type `T` is still carried.
// NOTE(review): presumably the state the episode would have continued from had it not been
// cut short — confirm against callers.
Interrupt(T),
}
impl<T, U> Successor<T, U> {
#[allow(clippy::missing_const_for_fn)] #[inline]
pub fn into_continue(self) -> Option<U> {
match self {
Self::Continue(s) => Some(s),
_ => None,
}
}
#[allow(clippy::missing_const_for_fn)] #[inline]
pub fn into_interrupt(self) -> Option<T> {
match self {
Self::Interrupt(s) => Some(s),
_ => None,
}
}
#[inline]
pub const fn episode_done(&self) -> bool {
!matches!(self, Successor::Continue(_))
}
#[allow(clippy::missing_const_for_fn)] #[inline]
pub fn into_partial(self) -> PartialSuccessor<T> {
match self {
Self::Continue(_) => Successor::Continue(()),
Self::Terminate => Successor::Terminate,
Self::Interrupt(s) => Successor::Interrupt(s),
}
}
#[inline]
pub fn map_continue<F, V>(self, f: F) -> Successor<T, V>
where
F: FnOnce(U) -> V,
{
match self {
Self::Continue(s) => Successor::Continue(f(s)),
Self::Terminate => Successor::Terminate,
Self::Interrupt(s) => Successor::Interrupt(s),
}
}
#[allow(clippy::missing_const_for_fn)] #[inline]
pub fn into_partial_continue(self) -> (PartialSuccessor<T>, Option<U>) {
match self {
Self::Continue(o) => (Successor::Continue(()), Some(o)),
Self::Terminate => (Successor::Terminate, None),
Self::Interrupt(o) => (Successor::Interrupt(o), None),
}
}
}
impl<T> Successor<T> {
    /// Applies `f` to whichever payload is present (`Continue` or `Interrupt`).
    ///
    /// `Terminate` carries nothing, so `f` is not called for it.
    #[inline]
    pub fn map<U, F>(self, f: F) -> Successor<U>
    where
        F: FnOnce(T) -> U,
    {
        match self {
            Self::Terminate => Successor::Terminate,
            Self::Interrupt(inner) => Successor::Interrupt(f(inner)),
            Self::Continue(inner) => Successor::Continue(f(inner)),
        }
    }

    /// Turns `Continue(s)` into `Interrupt(s)` when `f(&s)` is true.
    ///
    /// `Terminate` and `Interrupt` are returned as-is without calling `f`.
    #[must_use]
    #[inline]
    pub fn then_interrupt_if<F: FnOnce(&T) -> bool>(self, f: F) -> Self {
        // Bail out early for the variants that are already final.
        let state = match self {
            Self::Continue(state) => state,
            Self::Terminate => return Self::Terminate,
            Self::Interrupt(state) => return Self::Interrupt(state),
        };
        if f(&state) {
            Self::Interrupt(state)
        } else {
            Self::Continue(state)
        }
    }

    /// Consumes `self`, yielding the payload of `Continue` or `Interrupt`; `None` on `Terminate`.
    #[allow(clippy::missing_const_for_fn)]
    #[inline]
    pub fn into_inner(self) -> Option<T> {
        match self {
            Self::Terminate => None,
            Self::Interrupt(inner) | Self::Continue(inner) => Some(inner),
        }
    }
}
impl<T, U> Successor<T, U>
where
    U: Borrow<T>,
{
    /// Borrows the payload, producing a successor of references.
    ///
    /// A `Continue` payload is viewed as `&T` through its [`Borrow<T>`] implementation; an
    /// `Interrupt` payload is borrowed directly.
    #[inline]
    pub fn as_ref(&self) -> Successor<&T> {
        match self {
            Self::Terminate => Successor::Terminate,
            Self::Interrupt(state) => Successor::Interrupt(state),
            Self::Continue(next) => Successor::Continue(<U as Borrow<T>>::borrow(next)),
        }
    }
}
impl<T, U> Successor<&'_ T, &'_ U>
where
    T: Clone,
    U: Clone,
{
    /// Converts a successor of references into an owned one by cloning whichever payload is
    /// present.
    #[must_use]
    #[inline]
    pub fn cloned(self) -> Successor<T, U> {
        match self {
            Self::Terminate => Successor::Terminate,
            Self::Interrupt(state) => Successor::Interrupt(T::clone(state)),
            Self::Continue(next) => Successor::Continue(U::clone(next)),
        }
    }
}
impl<T: Clone> Successor<T, &'_ T> {
#[inline]
pub fn into_owned(self) -> Successor<T> {
match self {
Self::Continue(s) => Successor::Continue(s.clone()),
Self::Terminate => Successor::Terminate,
Self::Interrupt(s) => Successor::Interrupt(s),
}
}
}
/// A [`Successor`] whose `Continue` payload is a reference `&'a T` while `Interrupt` owns a `T`.
pub type RefSuccessor<'a, T> = Successor<T, &'a T>;
/// A [`Successor`] whose `Continue` variant carries no data (`()`), while `Interrupt` owns a `T`.
pub type PartialSuccessor<T> = Successor<T, ()>;
/// A plain-data [`EnvStructure`]: the three spaces and the discount factor stored as fields.
///
/// The feedback space type `FS` defaults to `IntervalSpace<Reward>`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct StoredEnvStructure<OS, AS, FS = IntervalSpace<Reward>> {
/// Stored observation space.
pub observation_space: OS,
/// Stored action space.
pub action_space: AS,
/// Stored feedback space.
pub feedback_space: FS,
/// Stored discount factor.
pub discount_factor: f64,
}
impl<OS, AS, FS> StoredEnvStructure<OS, AS, FS> {
pub const fn new(
observation_space: OS,
action_space: AS,
feedback_space: FS,
discount_factor: f64,
) -> Self {
Self {
observation_space,
action_space,
feedback_space,
discount_factor,
}
}
}
/// Exposes the stored fields through the [`EnvStructure`] interface.
///
/// Spaces are returned by value, so each accessor clones the stored space.
impl<OS: Space + Clone, AS: Space + Clone, FS: Space + Clone> EnvStructure
    for StoredEnvStructure<OS, AS, FS>
{
    type ObservationSpace = OS;
    type ActionSpace = AS;
    type FeedbackSpace = FS;

    fn observation_space(&self) -> Self::ObservationSpace {
        <OS as Clone>::clone(&self.observation_space)
    }

    fn action_space(&self) -> Self::ActionSpace {
        <AS as Clone>::clone(&self.action_space)
    }

    fn feedback_space(&self) -> Self::FeedbackSpace {
        <FS as Clone>::clone(&self.feedback_space)
    }

    fn discount_factor(&self) -> f64 {
        self.discount_factor
    }
}
/// Snapshots the structure of any [`EnvStructure`] into a `StoredEnvStructure`.
impl<E: EnvStructure + ?Sized> From<&E>
    for StoredEnvStructure<E::ObservationSpace, E::ActionSpace, E::FeedbackSpace>
{
    /// Queries each accessor of `env` once, in field order, and stores the results.
    fn from(env: &E) -> Self {
        let observation_space = E::observation_space(env);
        let action_space = E::action_space(env);
        let feedback_space = E::feedback_space(env);
        let discount_factor = E::discount_factor(env);
        Self {
            observation_space,
            action_space,
            feedback_space,
            discount_factor,
        }
    }
}
/// A source from which environments can be sampled.
///
/// The four associated types mirror those of [`Environment`]; every sampled environment must
/// use exactly these types.
pub trait EnvDistribution {
/// State type of the sampled environments.
type State;
/// Observation type of the sampled environments.
type Observation;
/// Action type of the sampled environments.
type Action;
/// Feedback type of the sampled environments.
type Feedback;
/// Concrete environment type produced by `sample_environment`.
type Environment: Environment<
State = Self::State,
Observation = Self::Observation,
Action = Self::Action,
Feedback = Self::Feedback,
>;
/// Produce an environment, using `rng` as the source of randomness.
fn sample_environment(&self, rng: &mut Prng) -> Self::Environment;
}
/// An [`EnvDistribution`] that also provides a matching [`EnvStructure`].
///
/// "Matching" means the distribution's `Observation`, `Action` and `Feedback` types are exactly
/// the `Element` types of the corresponding structure spaces. The trait has no items of its own.
pub trait StructuredEnvDist:
EnvStructure
+ EnvDistribution<
Observation = <Self::ObservationSpace as Space>::Element,
Action = <Self::ActionSpace as Space>::Element,
Feedback = <Self::FeedbackSpace as Space>::Element,
>
{
}
// Blanket implementation: anything satisfying the bounds above (including unsized types, via
// `?Sized`) is automatically a `StructuredEnvDist`.
impl<T> StructuredEnvDist for T where
T: EnvStructure
+ EnvDistribution<
Observation = <Self::ObservationSpace as Space>::Element,
Action = <Self::ActionSpace as Space>::Element,
Feedback = <Self::FeedbackSpace as Space>::Element,
> + ?Sized
{
}