use std::marker::PhantomData;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use rapier3d::prelude::*;
use rlevo_core::environment::{ConstructableEnv, Environment, EnvironmentError, EpisodeStatus, SnapshotMetadata};
use rlevo_core::reward::ScalarReward;
use crate::locomotion::backend::{LocomotionBackend, Rapier3DBackend, Rapier3DWorld};
use crate::locomotion::common::{LocomotionSnapshot, TerminationMode, wrap_to_pi};
use super::action::InvertedPendulumAction;
use super::config::InvertedPendulumConfig;
use super::observation::InvertedPendulumObservation;
use super::state::InvertedPendulumState;
/// Metadata key under which the alive bonus is reported: `1.0` for a healthy step,
/// `0.0` for an unhealthy step and on reset.
pub const METADATA_KEY_ALIVE: &str = "alive";
/// Cart-pole balancing environment: a cart that slides along world x, with a pole hinged
/// about world y on the cart's top face, driven by a horizontal force on the cart.
///
/// Generic over a [`LocomotionBackend`]; the impls in this file target [`Rapier3DBackend`].
#[derive(Debug)]
pub struct InvertedPendulum<B: LocomotionBackend = Rapier3DBackend> {
// Physics world owned by the backend; rebuilt from scratch on every reset.
world: B::World,
// Body/joint handles plus the most recent observation.
state: InvertedPendulumState,
// Configuration used to build the world and to decide reward, termination and truncation.
config: InvertedPendulumConfig,
// RNG used to draw reset noise; re-seeded from `config.seed` in `reset`.
rng: StdRng,
// Number of `step` calls since the last reset; compared against `config.max_steps`.
steps: usize,
// Ties the struct to the backend type parameter without storing a value of it.
_marker: PhantomData<B>,
}
/// The inverted pendulum simulated with the Rapier3D backend.
pub type InvertedPendulumRapier = InvertedPendulum<Rapier3DBackend>;
impl InvertedPendulum<Rapier3DBackend> {
#[must_use]
pub fn with_config(config: InvertedPendulumConfig) -> Self {
let mut rng = StdRng::seed_from_u64(config.seed);
let (world, state) = Self::build_world(&config, &mut rng);
Self {
world,
state,
config,
rng,
steps: 0,
_marker: PhantomData,
}
}
fn build_world(
config: &InvertedPendulumConfig,
rng: &mut StdRng,
) -> (Rapier3DWorld, InvertedPendulumState) {
let mut world = Rapier3DWorld::new(
Vector::new(0.0, 0.0, config.gravity),
config.dt,
config.frame_skip,
);
let n = config.reset_noise_scale;
let init_cart_x: f32 = rng.random_range(-n..=n);
let init_angle: f32 = rng.random_range(-n..=n);
let init_cart_vx: f32 = rng.random_range(-n..=n);
let init_pole_angvel: f32 = rng.random_range(-n..=n);
let cart_z = config.cart_half_extents[2]; let pole_half = config.pole_length * 0.5;
let cart_volume = config.cart_half_extents[0]
* config.cart_half_extents[1]
* config.cart_half_extents[2]
* 8.0;
let cart_density = config.cart_mass / cart_volume.max(f32::EPSILON);
let cart_builder = RigidBodyBuilder::dynamic()
.translation(Vector::new(init_cart_x, 0.0, cart_z))
.linvel(Vector::new(init_cart_vx, 0.0, 0.0))
.enabled_translations(true, false, false)
.enabled_rotations(false, false, false);
let cart = world.add_body(cart_builder);
world.add_collider(
ColliderBuilder::cuboid(
config.cart_half_extents[0],
config.cart_half_extents[1],
config.cart_half_extents[2],
)
.density(cart_density),
cart,
);
let pole_initial_z = cart_z + cart_half_z(config) + pole_half;
let pole_volume = std::f32::consts::PI
* config.pole_radius.powi(2)
* (2.0 * pole_half + (4.0 / 3.0) * config.pole_radius);
let pole_density = config.pole_mass / pole_volume.max(f32::EPSILON);
let pole_builder = RigidBodyBuilder::dynamic()
.translation(Vector::new(init_cart_x, 0.0, pole_initial_z))
.rotation(Vector::new(0.0, init_angle, 0.0))
.angvel(Vector::new(0.0, init_pole_angvel, 0.0))
.enabled_translations(true, true, true)
.enabled_rotations(false, true, false);
let pole = world.add_body(pole_builder);
world.add_collider(
ColliderBuilder::capsule_z(pole_half, config.pole_radius).density(pole_density),
pole,
);
let y_axis: Vector = Vector::new(0.0, 1.0, 0.0);
let joint = RevoluteJointBuilder::new(y_axis)
.local_anchor1(Vector::new(0.0, 0.0, config.cart_half_extents[2]))
.local_anchor2(Vector::new(0.0, 0.0, -pole_half))
.build();
let joint_handle = world.add_impulse_joint(cart, pole, joint);
let state = InvertedPendulumState {
cart,
pole,
joint: joint_handle,
last_obs: InvertedPendulumObservation::default(),
};
(world, state)
}
fn extract_observation(&self) -> InvertedPendulumObservation {
let cart_pose = Rapier3DBackend::get_pose(&self.world, self.state.cart);
let cart_vel = Rapier3DBackend::get_vel(&self.world, self.state.cart);
let pole_pose = Rapier3DBackend::get_pose(&self.world, self.state.pole);
let pole_vel = Rapier3DBackend::get_vel(&self.world, self.state.pole);
let [w, _, y, _] = pole_pose.orientation;
let pole_angle = 2.0 * y.atan2(w);
let pole_angle = wrap_to_pi(pole_angle);
InvertedPendulumObservation([
cart_pose.position[0],
pole_angle,
cart_vel.linear[0],
pole_vel.angular[1],
])
}
fn apply_action(&mut self, action: &InvertedPendulumAction) {
let (lo, hi) = self.config.action_clip;
let clipped = [action.0[0].clamp(lo, hi)];
let torques = self.config.gear.apply(&clipped);
let force = torques[0];
if let Some(cart) = self.world.bodies_mut().get_mut(self.state.cart) {
cart.add_force(Vector::new(force, 0.0, 0.0), true);
}
}
}
impl ConstructableEnv for InvertedPendulum<Rapier3DBackend> {
    /// Builds an environment from the default configuration.
    ///
    /// The `render` flag is accepted for interface compatibility and ignored.
    fn new(_render: bool) -> Self {
        let config = InvertedPendulumConfig::default();
        Self::with_config(config)
    }
}
impl Environment<1, 1, 1> for InvertedPendulum<Rapier3DBackend> {
    type StateType = InvertedPendulumState;
    type ObservationType = InvertedPendulumObservation;
    type ActionType = InvertedPendulumAction;
    type RewardType = ScalarReward;
    type SnapshotType = LocomotionSnapshot<InvertedPendulumObservation>;

    /// Starts a new episode.
    ///
    /// Re-seeds the RNG from `config.seed`, rebuilds the world, and zeroes the step counter.
    /// Returns a running snapshot with zero reward and an `alive` component of `0.0`.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        // NOTE(review): re-seeding here means every reset draws the same noise, so all
        // episodes start from the same initial state. Confirm this is intended.
        self.rng = StdRng::seed_from_u64(self.config.seed);
        let (fresh_world, mut fresh_state) = Self::build_world(&self.config, &mut self.rng);
        fresh_state.last_obs = InvertedPendulumObservation::default();
        self.world = fresh_world;
        self.state = fresh_state;
        self.steps = 0;

        let initial = self.extract_observation();
        self.state.last_obs = initial;
        Ok(LocomotionSnapshot::running(
            initial,
            ScalarReward(0.0),
            SnapshotMetadata::new().with(METADATA_KEY_ALIVE, 0.0),
        ))
    }

    /// Advances the simulation by one environment step.
    ///
    /// # Errors
    ///
    /// Returns [`EnvironmentError::InvalidAction`] if the action is NaN or infinite. In that
    /// case the world is left untouched.
    ///
    /// The reward is `1.0` while the configured health check passes, otherwise `0.0`. The
    /// episode terminates on an unhealthy state only under `TerminationMode::OnUnhealthy`.
    /// It is truncated once `config.max_steps` steps have elapsed.
    fn step(
        &mut self,
        action: InvertedPendulumAction,
    ) -> Result<Self::SnapshotType, EnvironmentError> {
        let requested = action.0[0];
        if !requested.is_finite() {
            return Err(EnvironmentError::InvalidAction(format!(
                "InvertedPendulum action must be finite, got {}",
                requested
            )));
        }

        self.apply_action(&action);
        Rapier3DBackend::step(&mut self.world);
        self.steps += 1;

        let obs = self.extract_observation();
        self.state.last_obs = obs;

        let healthy = self
            .config
            .healthy
            .is_healthy(0.0, obs.pole_angle(), &obs.0);
        let alive_bonus = if healthy { 1.0 } else { 0.0 };

        let terminate =
            !healthy && matches!(self.config.termination, TerminationMode::OnUnhealthy);
        let status = match (terminate, self.steps >= self.config.max_steps) {
            (true, _) => EpisodeStatus::Terminated,
            (false, true) => EpisodeStatus::Truncated,
            (false, false) => EpisodeStatus::Running,
        };

        let cart_position = [obs.cart_position(), 0.0, self.config.cart_half_extents[2]];
        let meta = SnapshotMetadata::new()
            .with(METADATA_KEY_ALIVE, alive_bonus)
            .with_position("cart", cart_position);
        Ok(LocomotionSnapshot::new(
            obs,
            ScalarReward(alive_bonus),
            status,
            meta,
        ))
    }
}
/// Returns the cart's half-extent along world z, i.e. half of its height.
fn cart_half_z(config: &InvertedPendulumConfig) -> f32 {
    let half_extents = &config.cart_half_extents;
    half_extents[2]
}
impl rlevo_core::render::Locomotion2DPayloadSource for InvertedPendulum<Rapier3DBackend> {
    /// Projects the cart-pole onto the world x–z plane for 2D stick-figure rendering.
    ///
    /// The skeleton has two joints, the hinge on top of the cart and the pole tip, joined
    /// by one bone, so the drawn segment coincides with the physical pole. The pole's centre
    /// of mass is reported as `com`.
    fn locomotion2d_snapshot(&self) -> rlevo_core::render::Locomotion2DSnapshot {
        use rlevo_core::render::{Locomotion2DSnapshot, Point2};

        let cart_pose = Rapier3DBackend::get_pose(&self.world, self.state.cart);
        let pole_pose = Rapier3DBackend::get_pose(&self.world, self.state.pole);

        // The hinge is anchored `cart_half_extents[2]` above the cart's centre (its top face).
        // The cart's rotations are locked, so a plain z offset locates it.
        let hinge_x = cart_pose.position[0];
        let hinge_z = cart_pose.position[2] + self.config.cart_half_extents[2];

        // The pole's centre of mass lies midway along the hinge→tip segment, so the tip is the
        // hinge reflected through it. Reflecting the cart's centre instead would stretch the
        // drawn pole by the cart's half-height and skew its angle when tilted.
        let px = pole_pose.position[0];
        let pz = pole_pose.position[2];
        let tip_x = hinge_x + 2.0 * (px - hinge_x);
        let tip_z = hinge_z + 2.0 * (pz - hinge_z);

        Locomotion2DSnapshot {
            joints: vec![Point2::new(hinge_x, hinge_z), Point2::new(tip_x, tip_z)],
            bones: vec![(0, 1)],
            ground_y: 0.0,
            com: Some(Point2::new(px, pz)),
            contacts: vec![],
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rlevo_core::action::ContinuousAction;
    use rlevo_core::base::Action;
    use rlevo_core::base::Observation;
    use rlevo_core::environment::Snapshot;

    /// Resets `env`, plays `actions`, and returns the final observation.
    ///
    /// Panics on any reset or step error, so a rejected action cannot silently shorten the
    /// rollout.
    fn rollout_from_reset(
        env: &mut InvertedPendulumRapier,
        actions: &[f32],
    ) -> InvertedPendulumObservation {
        let mut last = *env.reset().expect("reset must succeed").observation();
        for &a in actions {
            let snap = env
                .step(InvertedPendulumAction::new(a))
                .expect("finite actions must be accepted");
            last = *snap.observation();
        }
        last
    }

    #[test]
    fn action_shape_and_validity() {
        assert_eq!(InvertedPendulumAction::shape(), [1]);
        assert!(InvertedPendulumAction::new(0.0).is_valid());
        assert!(InvertedPendulumAction::new(3.0).is_valid());
        assert!(!InvertedPendulumAction::new(3.5).is_valid());
        assert!(!InvertedPendulumAction::new(f32::NAN).is_valid());
    }

    #[test]
    fn observation_shape() {
        assert_eq!(InvertedPendulumObservation::shape(), [4]);
    }

    #[test]
    fn reset_returns_running_with_near_zero_obs() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig {
            seed: 7,
            reset_noise_scale: 0.0,
            ..Default::default()
        });
        let snap = env.reset().unwrap();
        assert!(!snap.is_done());
        for v in snap.observation().0 {
            assert!(v.abs() < 1e-5, "zero reset noise should give ~zero obs");
        }
    }

    #[test]
    fn ctrl_cost_not_paid() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig::default());
        env.reset().unwrap();
        let snap = env.step(InvertedPendulumAction::new(3.0)).unwrap();
        let total: f32 = snap.metadata().unwrap().components.values().sum();
        assert!((total - snap.reward().0).abs() < 1e-5);
    }

    #[test]
    fn reward_roundtrip_matches_components() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig::default());
        env.reset().unwrap();
        for _ in 0..5 {
            let snap = env.step(InvertedPendulumAction::new(0.0)).unwrap();
            let meta = snap.metadata().unwrap();
            let total: f32 = meta.components.values().sum();
            assert!(
                (total - snap.reward().0).abs() < 1e-5,
                "components sum ({total}) must equal reward ({})",
                snap.reward().0
            );
        }
    }

    #[test]
    fn terminates_when_pole_angle_leaves_band() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig {
            reset_noise_scale: 0.0,
            max_steps: 2000,
            ..Default::default()
        });
        env.reset().unwrap();
        let mut terminated = false;
        let mut max_abs_angle: f32 = 0.0;
        let mut cart_x_max: f32 = 0.0;
        for i in 0..2000 {
            let action = if i < 20 { 3.0 } else { 0.0 };
            let snap = env.step(InvertedPendulumAction::new(action)).unwrap();
            max_abs_angle = max_abs_angle.max(snap.observation().pole_angle().abs());
            cart_x_max = cart_x_max.max(snap.observation().cart_position().abs());
            if snap.is_terminated() {
                terminated = true;
                break;
            }
        }
        assert!(
            terminated,
            "pushing the cart must eventually drop the pole outside (-0.2, 0.2); \
             max |angle| observed = {max_abs_angle}, max |cart_x| = {cart_x_max}"
        );
    }

    #[test]
    fn truncates_at_max_steps() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig {
            max_steps: 5,
            termination: TerminationMode::Never,
            reset_noise_scale: 0.0,
            ..Default::default()
        });
        env.reset().unwrap();
        let mut status = EpisodeStatus::Running;
        for _ in 0..5 {
            let snap = env.step(InvertedPendulumAction::new(0.0)).unwrap();
            status = snap.status();
        }
        assert_eq!(status, EpisodeStatus::Truncated);
    }

    #[test]
    fn determinism_across_reset() {
        let cfg = InvertedPendulumConfig {
            seed: 123,
            ..Default::default()
        };
        let actions = [0.0, 1.0, -1.0, 0.5, 0.0];

        // Two independently constructed environments with the same seed must agree.
        let mut env_a = InvertedPendulumRapier::with_config(cfg.clone());
        let mut env_b = InvertedPendulumRapier::with_config(cfg);
        let first = rollout_from_reset(&mut env_a, &actions);
        assert_eq!(first, rollout_from_reset(&mut env_b, &actions));

        // Resetting the same environment must replay the identical trajectory.
        assert_eq!(first, rollout_from_reset(&mut env_a, &actions));
    }

    #[test]
    fn invalid_action_is_error() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig::default());
        env.reset().unwrap();
        let bad = InvertedPendulumAction::new(f32::NAN);
        assert!(env.step(bad).is_err());
    }

    #[test]
    fn action_clip_at_boundaries() {
        let a = InvertedPendulumAction::new(10.0).clip(-3.0, 3.0);
        assert_eq!(a.0[0], 3.0);
        let a = InvertedPendulumAction::new(-10.0).clip(-3.0, 3.0);
        assert_eq!(a.0[0], -3.0);
    }

    #[test]
    fn obs_is_finite_after_rollout() {
        let mut env = InvertedPendulumRapier::with_config(InvertedPendulumConfig::default());
        env.reset().unwrap();
        for _ in 0..50 {
            let snap = env.step(InvertedPendulumAction::new(0.1)).unwrap();
            assert!(snap.observation().is_finite());
            if snap.is_done() {
                break;
            }
        }
    }
}