use rlevo_core::base::Observation;
use rlevo_core::environment::{EpisodeStatus, Snapshot, SnapshotMetadata};
use rlevo_core::reward::ScalarReward;
/// Flags choosing which optional components are included in a locomotion observation.
///
/// Each field gates the component named by its `include_` suffix. Use one of the preset
/// constructors ([`ObservationComponents::minimal`], [`ObservationComponents::ant_default`],
/// [`ObservationComponents::humanoid_default`]) or build the struct directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ObservationComponents {
    /// Include the `xy_position` component.
    pub include_xy_position: bool,
    /// Include the `cinert` component.
    pub include_cinert: bool,
    /// Include the `cvel` component.
    pub include_cvel: bool,
    /// Include the `qfrc_actuator` component.
    pub include_qfrc_actuator: bool,
    /// Include the `cfrc_ext` component.
    pub include_cfrc_ext: bool,
}

impl ObservationComponents {
    /// Preset with every optional component disabled.
    #[must_use]
    pub const fn minimal() -> Self {
        Self {
            include_xy_position: false,
            include_cinert: false,
            include_cvel: false,
            include_qfrc_actuator: false,
            include_cfrc_ext: false,
        }
    }

    /// Preset that enables only `cfrc_ext` on top of [`ObservationComponents::minimal`].
    #[must_use]
    pub const fn ant_default() -> Self {
        Self {
            include_cfrc_ext: true,
            ..Self::minimal()
        }
    }

    /// Preset that enables every component except `xy_position`.
    #[must_use]
    pub const fn humanoid_default() -> Self {
        Self {
            include_cinert: true,
            include_cvel: true,
            include_qfrc_actuator: true,
            include_cfrc_ext: true,
            ..Self::minimal()
        }
    }
}

impl Default for ObservationComponents {
    /// Same as [`ObservationComponents::minimal`].
    fn default() -> Self {
        Self::minimal()
    }
}
/// Range-based health predicate evaluated by [`HealthyCheck::is_healthy`].
///
/// Every range is an inclusive `(lo, hi)` interval; `None` disables that particular check.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HealthyCheck {
    /// Allowed interval for the `torso_z` argument.
    pub z_range: Option<(f32, f32)>,
    /// Allowed interval for the `torso_angle` argument.
    pub angle_range: Option<(f32, f32)>,
    /// Allowed interval applied to every element of the `state` slice; when set, any
    /// non-finite element also fails the check.
    pub state_range: Option<(f32, f32)>,
}

impl HealthyCheck {
    /// A check with all three ranges disabled.
    ///
    /// [`HealthyCheck::is_healthy`] still reports a non-finite torso height or angle as
    /// unhealthy with this configuration.
    #[must_use]
    pub const fn none() -> Self {
        Self {
            z_range: None,
            angle_range: None,
            state_range: None,
        }
    }

    /// Returns `true` when the torso readings and state satisfy every configured range.
    ///
    /// * `torso_z` / `torso_angle` must be finite, and lie inside `z_range` /
    ///   `angle_range` when those are set.
    /// * When `state_range` is set, every element of `state` must be finite and inside it;
    ///   an empty `state` passes.
    #[must_use]
    pub fn is_healthy(self, torso_z: f32, torso_angle: f32, state: &[f32]) -> bool {
        // Non-finite torso readings are rejected before any range is consulted.
        if !(torso_z.is_finite() && torso_angle.is_finite()) {
            return false;
        }

        // Phrased as "strictly outside" so that a NaN bound never rejects a value.
        let outside = |value: f32, (lo, hi): (f32, f32)| value < lo || value > hi;

        if self.z_range.is_some_and(|bounds| outside(torso_z, bounds))
            || self.angle_range.is_some_and(|bounds| outside(torso_angle, bounds))
        {
            return false;
        }

        match self.state_range {
            Some((lo, hi)) => state
                .iter()
                .all(|&value| value.is_finite() && value >= lo && value <= hi),
            None => true,
        }
    }
}

impl Default for HealthyCheck {
    /// Same as [`HealthyCheck::none`].
    fn default() -> Self {
        Self::none()
    }
}
/// How an environment reacts to an unhealthy state.
///
/// This type only carries the choice; the consuming environment decides what each
/// variant does (NOTE(review): interpretation lives outside this file — confirm there).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TerminationMode {
    /// The default. Judging by the name, an unhealthy state ends the episode.
    #[default]
    OnUnhealthy,
    /// Judging by the name, the episode is never ended because of health.
    Never,
}
/// Per-actuator multipliers applied element-wise to an action by [`Gear::apply`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Gear<const N: usize>([f32; N]);

impl<const N: usize> Gear<N> {
    /// Wraps one gear value per actuator.
    #[must_use]
    pub const fn new(values: [f32; N]) -> Self {
        Self(values)
    }

    /// Returns the per-actuator gear values.
    #[must_use]
    pub const fn values(&self) -> &[f32; N] {
        &self.0
    }

    /// Scales each action component by its gear value: `torque[i] = action[i] * gear[i]`.
    #[must_use]
    pub fn apply(&self, action: &[f32; N]) -> [f32; N] {
        // Build the output directly instead of zero-filling a buffer and writing it through
        // an index loop (which clippy reports as `needless_range_loop`).
        std::array::from_fn(|i| action[i] * self.0[i])
    }
}
/// Quadratic control cost: `weight * Σ action[i]²`.
///
/// Doubling the action quadruples the cost, up to floating-point rounding. An empty action
/// yields `weight * 0.0`.
#[must_use]
pub fn ctrl_cost<const N: usize>(weight: f32, action: &[f32; N]) -> f32 {
    let squared_norm = action.iter().fold(0.0_f32, |acc, &u| acc + u * u);
    weight * squared_norm
}
/// Returns `true` when no element of `state` is NaN or infinite; an empty slice passes.
#[must_use]
pub fn is_finite_state(state: &[f32]) -> bool {
    state.iter().copied().all(f32::is_finite)
}
/// Clamps `contact_cost` into the inclusive interval `range = (min, max)`.
///
/// A NaN `contact_cost` comes back as NaN.
///
/// # Panics
///
/// Panics if `min > max` or if either bound is NaN (behavior of [`f32::clamp`]).
#[must_use]
pub fn clip_contact_cost(contact_cost: f32, range: (f32, f32)) -> f32 {
    let (min_cost, max_cost) = range;
    f32::clamp(contact_cost, min_cost, max_cost)
}
/// Wraps an angle in radians into the half-open interval `(-π, π]`.
///
/// `-π` maps to `+π`. A NaN input returns NaN.
#[must_use]
pub fn wrap_to_pi(angle: f32) -> f32 {
    use std::f32::consts::{PI, TAU};

    // `%` keeps the sign of `angle`, so the remainder lies in (-TAU, TAU) and needs at
    // most one shift by a full turn.
    let remainder = angle % TAU;
    if remainder > PI {
        remainder - TAU
    } else if remainder <= -PI {
        remainder + TAU
    } else {
        remainder
    }
}
/// A [`Snapshot<1>`] bundling one step's observation, scalar reward, episode status and
/// metadata.
#[derive(Debug, Clone)]
pub struct LocomotionSnapshot<O>
where
    O: Observation<1> + Clone,
{
    observation: O,
    reward: ScalarReward,
    status: EpisodeStatus,
    metadata: SnapshotMetadata,
}

impl<O> LocomotionSnapshot<O>
where
    O: Observation<1> + Clone,
{
    /// Builds a snapshot from all four parts, with an explicit `status`.
    #[must_use]
    pub fn new(
        observation: O,
        reward: ScalarReward,
        status: EpisodeStatus,
        metadata: SnapshotMetadata,
    ) -> Self {
        Self {
            observation,
            reward,
            status,
            metadata,
        }
    }

    /// Builds a snapshot whose status is [`EpisodeStatus::Running`].
    #[must_use]
    pub fn running(observation: O, reward: ScalarReward, metadata: SnapshotMetadata) -> Self {
        Self {
            observation,
            reward,
            status: EpisodeStatus::Running,
            metadata,
        }
    }

    /// Builds a snapshot whose status is [`EpisodeStatus::Terminated`].
    #[must_use]
    pub fn terminated(observation: O, reward: ScalarReward, metadata: SnapshotMetadata) -> Self {
        Self {
            observation,
            reward,
            status: EpisodeStatus::Terminated,
            metadata,
        }
    }

    /// Builds a snapshot whose status is [`EpisodeStatus::Truncated`].
    #[must_use]
    pub fn truncated(observation: O, reward: ScalarReward, metadata: SnapshotMetadata) -> Self {
        Self {
            observation,
            reward,
            status: EpisodeStatus::Truncated,
            metadata,
        }
    }
}

impl<O> Snapshot<1> for LocomotionSnapshot<O>
where
    O: Observation<1> + Clone,
{
    type ObservationType = O;
    type RewardType = ScalarReward;

    /// Borrows the stored observation.
    fn observation(&self) -> &O {
        &self.observation
    }

    /// Borrows the stored reward.
    fn reward(&self) -> &ScalarReward {
        &self.reward
    }

    /// Returns the stored episode status by value.
    fn status(&self) -> EpisodeStatus {
        self.status
    }

    /// Always `Some`: this snapshot type always carries metadata.
    fn metadata(&self) -> Option<&SnapshotMetadata> {
        let metadata = &self.metadata;
        Some(metadata)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn observation_components_ant_default_matches_spec() {
        let comps = ObservationComponents::ant_default();
        assert!(!comps.include_xy_position);
        assert!(comps.include_cfrc_ext);
        assert!(!comps.include_cinert);
    }

    #[test]
    fn observation_components_humanoid_default_includes_everything_but_xy() {
        let comps = ObservationComponents::humanoid_default();
        assert!(!comps.include_xy_position);
        assert!(comps.include_cinert);
        assert!(comps.include_cvel);
        assert!(comps.include_qfrc_actuator);
        assert!(comps.include_cfrc_ext);
    }

    #[test]
    fn defaults_match_named_presets() {
        assert_eq!(
            ObservationComponents::default(),
            ObservationComponents::minimal()
        );
        assert_eq!(HealthyCheck::default(), HealthyCheck::none());
    }

    #[test]
    fn healthy_check_none_is_always_healthy() {
        let check = HealthyCheck::none();
        assert!(check.is_healthy(1e9, 1e9, &[1e9]));
    }

    #[test]
    fn healthy_check_none_still_rejects_non_finite_torso() {
        let check = HealthyCheck::none();
        assert!(!check.is_healthy(f32::NAN, 0.0, &[]));
        assert!(!check.is_healthy(0.0, f32::INFINITY, &[]));
    }

    #[test]
    fn healthy_check_z_range_gates_height() {
        let check = HealthyCheck {
            z_range: Some((0.2, 1.0)),
            ..HealthyCheck::none()
        };
        assert!(check.is_healthy(0.5, 0.0, &[]));
        assert!(!check.is_healthy(0.1, 0.0, &[]));
        assert!(!check.is_healthy(1.5, 0.0, &[]));
    }

    #[test]
    fn healthy_check_hopper_style_all_three_ranges() {
        let check = HealthyCheck {
            z_range: Some((0.7, f32::INFINITY)),
            angle_range: Some((-0.2, 0.2)),
            state_range: Some((-100.0, 100.0)),
        };
        assert!(check.is_healthy(1.25, 0.0, &[1.0, -2.0, 3.0]));
        // Torso too low.
        assert!(!check.is_healthy(0.5, 0.0, &[1.0]));
        // Torso tilted past `angle_range`.
        assert!(!check.is_healthy(1.25, 0.5, &[1.0]));
        // State element outside `state_range`.
        assert!(!check.is_healthy(1.25, 0.0, &[150.0]));
        // Non-finite torso height.
        assert!(!check.is_healthy(f32::NAN, 0.0, &[]));
        // Non-finite state element while `state_range` is active.
        assert!(!check.is_healthy(1.25, 0.0, &[f32::NAN]));
    }

    #[test]
    fn gear_apply_scales_each_component() {
        let gear = Gear::<3>::new([2.0, 3.0, 4.0]);
        assert_eq!(gear.values(), &[2.0, 3.0, 4.0]);
        let torque = gear.apply(&[1.0, 2.0, -0.5]);
        assert_eq!(torque, [2.0, 6.0, -2.0]);
    }

    #[test]
    fn ctrl_cost_is_quadratic() {
        let a = [1.0, 2.0, 3.0];
        let c1 = ctrl_cost(0.5, &a);
        // 0.5 * (1 + 4 + 9), exactly representable.
        assert_eq!(c1, 7.0);
        let a2 = [2.0, 4.0, 6.0];
        let c2 = ctrl_cost(0.5, &a2);
        assert!(
            (c2 - 4.0 * c1).abs() < 1e-5,
            "ctrl_cost(2a) must equal 4·ctrl_cost(a); got {c1} vs {c2}"
        );
    }

    #[test]
    fn is_finite_state_trips_on_nan_inf() {
        assert!(is_finite_state(&[0.0, 1.0, -1.0]));
        assert!(!is_finite_state(&[0.0, f32::NAN]));
        assert!(!is_finite_state(&[0.0, f32::INFINITY]));
    }

    #[test]
    fn clip_contact_cost_respects_range() {
        assert_eq!(clip_contact_cost(5.0, (0.0, 10.0)), 5.0);
        assert_eq!(clip_contact_cost(15.0, (0.0, 10.0)), 10.0);
        assert_eq!(clip_contact_cost(-1.0, (0.0, 10.0)), 0.0);
    }

    #[test]
    fn termination_mode_default_is_on_unhealthy() {
        assert_eq!(TerminationMode::default(), TerminationMode::OnUnhealthy);
    }

    #[test]
    fn wrap_to_pi_canonical_values() {
        use std::f32::consts::PI;
        assert!((wrap_to_pi(0.0) - 0.0).abs() < 1e-6);
        assert!((wrap_to_pi(PI) - PI).abs() < 1e-6);
        // The output interval is half-open, (-π, π], so -π maps to +π.
        assert!((wrap_to_pi(-PI) - PI).abs() < 1e-6);
        assert!((wrap_to_pi(3.0 * PI / 2.0) - (-PI / 2.0)).abs() < 1e-5);
        assert!((wrap_to_pi(-3.0 * PI / 2.0) - (PI / 2.0)).abs() < 1e-5);
        assert!((wrap_to_pi(4.0 * PI) - 0.0).abs() < 1e-5);
    }

    #[test]
    fn wrap_to_pi_output_stays_in_half_open_interval() {
        use std::f32::consts::PI;
        for k in -8_i8..=8 {
            let angle = f32::from(k) * 0.9;
            let wrapped = wrap_to_pi(angle);
            assert!(
                wrapped > -PI && wrapped <= PI,
                "wrap_to_pi({angle}) = {wrapped} is outside (-π, π]"
            );
        }
    }
}