use super::{BuildEnv, BuildEnvError, EnvStructure, Environment, Successor};
use crate::feedback::Reward;
use crate::logging::StatsLogger;
use crate::spaces::{Indexed, IndexedTypeSpace, IntervalSpace, ProductSpace};
use crate::Prng;
use rand::distributions::{Distribution, Uniform};
use serde::{Deserialize, Serialize};
/// Configuration for building a [`CartPole`] environment through [`BuildEnv`].
#[derive(Debug, Default, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct CartPoleConfig {
    /// Physical constants of the simulated cart and pole.
    pub physics_config: PhysicalConstants,
    /// Episode parameters: push force, termination bounds, and discount factor.
    pub env_config: EnvironmentParams,
}
impl BuildEnv for CartPoleConfig {
    type Observation = CartPolePhysicalState;
    type Action = Push;
    type Feedback = Reward;
    type ObservationSpace = CartPolePhysicalStateSpace;
    type ActionSpace = IndexedTypeSpace<Push>;
    type FeedbackSpace = IntervalSpace<Reward>;
    type Environment = CartPole;

    /// Builds a [`CartPole`] from the stored physics and environment parameters.
    ///
    /// Construction is deterministic, so the random number generator is unused and this
    /// always returns `Ok`.
    fn build_env(&self, _rng: &mut Prng) -> Result<Self::Environment, BuildEnvError> {
        let environment = CartPole::new(self.physics_config, self.env_config);
        Ok(environment)
    }
}
/// Cart-pole balancing environment.
///
/// Each step pushes the cart left or right with a fixed force and yields a reward of 1.
/// The episode terminates once the cart position or pole angle leaves its allowed range.
#[derive(Debug, Default, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct CartPole {
    /// Physical constants, with derived quantities (total weight, inverse total mass,
    /// pole mass times length) precomputed.
    phys: InternalPhysicalConstants,
    /// Push force, termination bounds, and discount factor.
    env: EnvironmentParams,
}
impl CartPole {
#[must_use]
pub fn new(phys: PhysicalConstants, env: EnvironmentParams) -> Self {
Self {
phys: phys.into(),
env,
}
}
}
/// Discrete cart-pole action: push the cart left or right.
///
/// `CartPole::step` maps `Left` to a force of `-action_force` and `Right` to `+action_force`.
// NOTE(review): the `Indexed` derive presumably assigns indices in declaration order, so
// reordering the variants would likely change action indices — confirm before editing.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Indexed, Serialize, Deserialize)]
pub enum Push {
    /// Apply a negative horizontal force to the cart.
    Left,
    /// Apply a positive horizontal force to the cart.
    Right,
}
impl EnvStructure for CartPole {
type ObservationSpace = CartPolePhysicalStateSpace;
type ActionSpace = IndexedTypeSpace<Push>;
type FeedbackSpace = IntervalSpace<Reward>;
fn observation_space(&self) -> Self::ObservationSpace {
let max_pos = self.env.max_pos;
let max_angle = self.env.max_angle;
CartPolePhysicalStateSpace {
cart_position: IntervalSpace::new(-max_pos, max_pos),
cart_velocity: IntervalSpace::default(),
pole_angle: IntervalSpace::new(-max_angle, max_angle),
pole_angular_velocity: IntervalSpace::default(),
}
}
fn action_space(&self) -> Self::ActionSpace {
IndexedTypeSpace::new()
}
fn feedback_space(&self) -> Self::FeedbackSpace {
IntervalSpace::new(Reward(0.0), Reward(1.0))
}
fn discount_factor(&self) -> f64 {
self.env.discount_factor
}
}
impl Environment for CartPole {
    type State = CartPoleInternalState;
    type Observation = CartPolePhysicalState;
    type Action = Push;
    type Feedback = Reward;

    /// Samples a start state.
    ///
    /// Each physical coordinate is drawn independently and uniformly from `[-0.05, 0.05]`.
    /// The cached friction sign starts out positive.
    fn initial_state(&self, rng: &mut Prng) -> Self::State {
        let jitter = Uniform::new_inclusive(-0.05, 0.05);
        // Draw order is significant for reproducibility:
        // position, velocity, angle, angular velocity.
        let cart_position = jitter.sample(rng);
        let cart_velocity = jitter.sample(rng);
        let pole_angle = jitter.sample(rng);
        let pole_angular_velocity = jitter.sample(rng);
        CartPoleInternalState {
            physical: CartPolePhysicalState {
                cart_position,
                cart_velocity,
                pole_angle,
                pole_angular_velocity,
            },
            cached_normal_velocity_is_positive: true,
        }
    }

    /// Returns the physical part of the state.
    ///
    /// # Panics
    /// Panics if the cart position or pole angle lies outside its allowed closed interval.
    /// `step` terminates before producing such a state, so this signals a broken invariant.
    fn observe(&self, state: &Self::State, _: &mut Prng) -> Self::Observation {
        let physical = state.physical;
        let position_ok =
            (-self.env.max_pos..=self.env.max_pos).contains(&physical.cart_position);
        let angle_ok = (-self.env.max_angle..=self.env.max_angle).contains(&physical.pole_angle);
        assert!(
            position_ok && angle_ok,
            "out-of-bounds state should not have been produced"
        );
        physical
    }

    /// Applies one push and yields a unit reward.
    ///
    /// The episode terminates as soon as `|cart_position| > max_pos` or
    /// `|pole_angle| > max_angle`. The out-of-bounds state itself is never returned.
    fn step(
        &self,
        state: Self::State,
        action: &Self::Action,
        _: &mut Prng,
        _: &mut dyn StatsLogger,
    ) -> (Successor<Self::State>, Self::Feedback) {
        let force = match action {
            Push::Right => self.env.action_force,
            Push::Left => -self.env.action_force,
        };
        let next = self.phys.next_state(&state, force);

        let cart_escaped = next.physical.cart_position.abs() > self.env.max_pos;
        let pole_fell = next.physical.pole_angle.abs() > self.env.max_angle;
        let successor = if cart_escaped || pole_fell {
            Successor::Terminate
        } else {
            Successor::Continue(next)
        };

        // Every step, including the terminating one, earns the same unit reward.
        let unit_reward = 1.0;
        (successor, unit_reward.into())
    }
}
/// User-facing physical constants of the cart-pole system.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct PhysicalConstants {
    /// Gravitational acceleration `g`.
    pub gravity: f64,
    /// Mass of the cart `m_c`.
    pub mass_cart: f64,
    /// Mass of the pole `m_p`.
    pub mass_pole: f64,
    /// Pole length parameter `l` used in the dynamics (named as half the pole's length).
    pub length_half_pole: f64,
    /// Cart friction coefficient `μ_c`; multiplied by the cart's normal force.
    pub friction_cart: f64,
    /// Pole friction coefficient `μ_p`; scales the pole angular-velocity damping term.
    pub friction_pole: f64,
    /// Step size of each Euler integration step.
    pub time_step: f64,
}
impl Default for PhysicalConstants {
    /// A 1.0-mass cart carrying a 0.1-mass pole with length parameter 0.5, under gravity
    /// 9.8. Both the cart and pole friction coefficients are 0.01, and the time step is 0.02.
    fn default() -> Self {
        Self {
            time_step: 0.02,
            gravity: 9.8,
            mass_cart: 1.0,
            mass_pole: 0.1,
            length_half_pole: 0.5,
            friction_cart: 0.01,
            friction_pole: 0.01,
        }
    }
}
/// Episode-level parameters of the cart-pole environment.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct EnvironmentParams {
    /// Magnitude of the force applied by each push; the sign comes from the [`Push`] action.
    pub action_force: f64,
    /// The episode terminates once `|cart_position|` exceeds this bound.
    pub max_pos: f64,
    /// The episode terminates once `|pole_angle|` exceeds this bound (radians).
    pub max_angle: f64,
    /// Value reported by `EnvStructure::discount_factor`.
    pub discount_factor: f64,
}
impl Default for EnvironmentParams {
    /// A push of 10.0, termination at `|x| > 2.4` or `|θ| > 12°`, and a discount of 0.99.
    fn default() -> Self {
        let twelve_degrees = 12.0_f64.to_radians();
        Self {
            discount_factor: 0.99,
            max_angle: twelve_degrees,
            max_pos: 2.4,
            action_force: 10.0,
        }
    }
}
/// [`PhysicalConstants`] together with quantities derived from them once, up front.
///
/// This type is serialized as the plain `PhysicalConstants` (per the `serde` attribute).
/// On deserialization, the derived fields are rebuilt through `From<PhysicalConstants>`.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[serde(into = "PhysicalConstants", from = "PhysicalConstants")]
struct InternalPhysicalConstants {
    /// The original user-facing constants.
    c: PhysicalConstants,
    /// `gravity * (mass_cart + mass_pole)`.
    total_weight: f64,
    /// `1 / (mass_cart + mass_pole)`.
    inv_total_mass: f64,
    /// `mass_pole * length_half_pole`.
    mass_length_pole: f64,
}
impl Default for InternalPhysicalConstants {
fn default() -> Self {
PhysicalConstants::default().into()
}
}
impl From<PhysicalConstants> for InternalPhysicalConstants {
fn from(c: PhysicalConstants) -> Self {
let total_mass = c.mass_cart + c.mass_pole;
let total_weight = c.gravity * total_mass;
let inv_total_mass = total_mass.recip();
let mass_length_pole = c.mass_pole * c.length_half_pole;
Self {
c,
total_weight,
inv_total_mass,
mass_length_pole,
}
}
}
impl From<InternalPhysicalConstants> for PhysicalConstants {
fn from(ic: InternalPhysicalConstants) -> Self {
ic.c
}
}
/// Observable physical state of the cart-pole system.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct CartPolePhysicalState {
    /// Horizontal cart position `x`.
    pub cart_position: f64,
    /// Cart velocity `ẋ`.
    pub cart_velocity: f64,
    /// Pole angle `θ` in radians (passed to `sin_cos` by the dynamics).
    pub pole_angle: f64,
    /// Pole angular velocity `θ̇`.
    pub pole_angular_velocity: f64,
}
/// Observation space of [`CartPolePhysicalState`]: one interval per physical coordinate.
///
/// `CartPole::observation_space` bounds position and angle by the termination limits.
/// It leaves both velocities at `IntervalSpace::default()`.
// The `ProductSpace` derive must be imported from `crate::spaces`, next to `Indexed`.
#[derive(Debug, Copy, Clone, PartialEq, ProductSpace, Serialize, Deserialize)]
#[element(CartPolePhysicalState)]
pub struct CartPolePhysicalStateSpace {
    /// Space of `CartPolePhysicalState::cart_position`.
    pub cart_position: IntervalSpace<f64>,
    /// Space of `CartPolePhysicalState::cart_velocity`.
    pub cart_velocity: IntervalSpace<f64>,
    /// Space of `CartPolePhysicalState::pole_angle`.
    pub pole_angle: IntervalSpace<f64>,
    /// Space of `CartPolePhysicalState::pole_angular_velocity`.
    pub pole_angular_velocity: IntervalSpace<f64>,
}
/// Full environment state: the observable physical state plus bookkeeping for the
/// cart-friction sign.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub struct CartPoleInternalState {
    /// Cart and pole coordinates; this is what `observe` returns.
    physical: CartPolePhysicalState,
    /// Whether `normal_force * cart_velocity` had a positive sign (`is_sign_positive`) when
    /// the previous step computed it. The next step uses it as the first guess for the sign
    /// of the cart friction. The initial state sets it to `true`.
    cached_normal_velocity_is_positive: bool,
}
impl InternalPhysicalConstants {
    /// Advances the system by one `time_step` with `applied_force` acting on the cart.
    ///
    /// Accelerations follow the cart-pole equations with cart and pole friction, in the form
    /// given by Florian, "Correct equations for the dynamics of the cart-pole system":
    ///
    /// * `θ̈` comes from `angular_acceleration`.
    /// * The cart normal force `N_c` comes from `normal_force`.
    /// * `ẍ = (F + m_p·l·(θ̇²·sin θ − θ̈·cos θ) − μ_c·N_c·sgn(N_c·ẋ)) / (m_c + m_p)`.
    ///
    /// The state then advances by one explicit (forward) Euler step. Both positions use the
    /// pre-step velocities, and both velocities use the accelerations computed here.
    ///
    /// `sgn(N_c·ẋ)` depends on `θ̈`, which in turn depends on that sign. The sign cached from
    /// the previous step is used as a first guess. If the resulting `N_c·ẋ` has the other
    /// sign, `θ̈` and `N_c` are recomputed once with the flipped friction sign.
    ///
    /// Returns the successor state. Its cached sign is the one obtained from the first guess.
    pub fn next_state(
        &self,
        state: &CartPoleInternalState,
        applied_force: f64,
    ) -> CartPoleInternalState {
        let phys = &state.physical;
        let (sin_angle, cos_angle) = phys.pole_angle.sin_cos();
        let angular_velocity_squared = phys.pole_angular_velocity * phys.pole_angular_velocity;

        // First guess for μ_c·sgn(N_c·ẋ): the sign observed on the previous step.
        let guessed_friction = if state.cached_normal_velocity_is_positive {
            self.c.friction_cart
        } else {
            -self.c.friction_cart
        };
        let (guessed_angular_acceleration, guessed_normal_force) = self.pole_dynamics(
            phys,
            applied_force,
            guessed_friction,
            angular_velocity_squared,
            sin_angle,
            cos_angle,
        );
        let normal_velocity_is_positive =
            (guessed_normal_force * phys.cart_velocity).is_sign_positive();

        // If the guess disagrees with the sign it produced, redo the computation once with
        // the opposite friction sign.
        let (signed_cart_friction, angular_acceleration, normal_force) =
            if normal_velocity_is_positive == state.cached_normal_velocity_is_positive {
                (
                    guessed_friction,
                    guessed_angular_acceleration,
                    guessed_normal_force,
                )
            } else {
                let flipped_friction = -guessed_friction;
                let (angular_acceleration, normal_force) = self.pole_dynamics(
                    phys,
                    applied_force,
                    flipped_friction,
                    angular_velocity_squared,
                    sin_angle,
                    cos_angle,
                );
                (flipped_friction, angular_acceleration, normal_force)
            };

        // Horizontal force the pole exerts on the cart: m_p·l·(θ̇²·sin θ − θ̈·cos θ).
        let force_pole = self.mass_length_pole
            * (angular_velocity_squared * sin_angle - angular_acceleration * cos_angle);
        let force_friction = -signed_cart_friction * normal_force;
        let net_force = applied_force + force_pole + force_friction;
        let cart_acceleration = net_force * self.inv_total_mass;

        // Explicit Euler: positions advance with the velocities from *before* this step.
        let dt = self.c.time_step;
        CartPoleInternalState {
            physical: CartPolePhysicalState {
                cart_position: phys.cart_position + dt * phys.cart_velocity,
                cart_velocity: phys.cart_velocity + dt * cart_acceleration,
                pole_angle: phys.pole_angle + dt * phys.pole_angular_velocity,
                pole_angular_velocity: phys.pole_angular_velocity + dt * angular_acceleration,
            },
            cached_normal_velocity_is_positive: normal_velocity_is_positive,
        }
    }

    /// Computes the pole angular acceleration `θ̈` and the cart normal force `N_c`.
    ///
    /// `signed_cart_friction` is `±friction_cart`.
    fn pole_dynamics(
        &self,
        state: &CartPolePhysicalState,
        applied_force: f64,
        signed_cart_friction: f64,
        angular_velocity_squared: f64,
        sin_angle: f64,
        cos_angle: f64,
    ) -> (f64, f64) {
        let angular_acceleration = self.angular_acceleration(
            state,
            applied_force,
            signed_cart_friction,
            angular_velocity_squared,
            sin_angle,
            cos_angle,
        );
        let normal_force = self.normal_force(
            angular_acceleration,
            angular_velocity_squared,
            sin_angle,
            cos_angle,
        );
        (angular_acceleration, normal_force)
    }

    /// Pole angular acceleration for a given `μ̃ = signed_cart_friction`.
    ///
    /// The result is
    /// `θ̈ = [g·sin θ + cos θ·(α + g·μ̃) − β] / [l·(4/3 − m_p·cos θ·(cos θ − μ̃)/(m_c + m_p))]`,
    /// where:
    /// * `α = (−F − m_p·l·θ̇²·(sin θ + μ̃·cos θ)) / (m_c + m_p)`
    /// * `β = μ_p·θ̇ / (m_p·l)`
    fn angular_acceleration(
        &self,
        state: &CartPolePhysicalState,
        applied_force: f64,
        signed_cart_friction: f64,
        angular_velocity_squared: f64,
        sin_angle: f64,
        cos_angle: f64,
    ) -> f64 {
        let alpha = (-applied_force
            - self.mass_length_pole
                * angular_velocity_squared
                * (sin_angle + signed_cart_friction * cos_angle))
            * self.inv_total_mass;
        let beta = self.c.friction_pole * state.pole_angular_velocity / self.mass_length_pole;
        let numerator = self.c.gravity * sin_angle
            + cos_angle * (alpha + self.c.gravity * signed_cart_friction)
            - beta;
        let denominator = self.c.length_half_pole
            * (4.0 / 3.0
                - self.c.mass_pole
                    * cos_angle
                    * self.inv_total_mass
                    * (cos_angle - signed_cart_friction));
        numerator / denominator
    }

    /// Normal force on the cart: `N_c = (m_c + m_p)·g − m_p·l·(θ̈·sin θ + θ̇²·cos θ)`.
    fn normal_force(
        &self,
        angular_acceleration: f64,
        angular_velocity_squared: f64,
        sin_angle: f64,
        cos_angle: f64,
    ) -> f64 {
        self.total_weight
            - self.mass_length_pole
                * (angular_acceleration * sin_angle + angular_velocity_squared * cos_angle)
    }
}
#[cfg(test)]
mod tests {
    use super::super::testing;
    use super::*;

    /// Runs the default environment through the shared structured-environment checks.
    #[test]
    fn run_default() {
        let env = CartPole::default();
        testing::check_structured_env(&env, 1000, 0);
    }
}