use rlevo_core::action::ContinuousAction;
use rlevo_core::base::Action;
use serde::{Deserialize, Serialize};
/// Continuous action made of three `f32` components: `steer`, `gas`, `brake`.
///
/// An action is valid (see `is_valid`) when every component is finite,
/// `steer` lies in `[-1.0, 1.0]`, and `gas` and `brake` lie in `[0.0, 1.0]`.
///
/// The struct is `#[repr(C)]` so the three fields are stored contiguously in
/// declaration order; `as_slice` relies on that to expose all of them as one
/// `&[f32]`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[repr(C)]
pub struct CarRacingAction {
    /// Steering component; valid range `[-1.0, 1.0]`.
    pub steer: f32,
    /// Gas component; valid range `[0.0, 1.0]`.
    pub gas: f32,
    /// Brake component; valid range `[0.0, 1.0]`.
    pub brake: f32,
}

// Compile-time guard for the `unsafe` slice view in `as_slice`: the struct
// must be exactly three tightly packed `f32`s with `f32` alignment.
const _: () = assert!(
    std::mem::size_of::<CarRacingAction>()
        == CarRacingAction::N_COMPONENTS * std::mem::size_of::<f32>()
        && std::mem::align_of::<CarRacingAction>() == std::mem::align_of::<f32>()
);

impl CarRacingAction {
    /// Number of scalar components (`steer`, `gas`, `brake`).
    const N_COMPONENTS: usize = 3;

    /// Returns `true` when all three components are finite and inside their
    /// ranges: `steer` in `[-1.0, 1.0]`, `gas` and `brake` in `[0.0, 1.0]`.
    fn components_valid(steer: f32, gas: f32, brake: f32) -> bool {
        steer.is_finite()
            && steer.abs() <= 1.0
            && gas.is_finite()
            && (0.0..=1.0).contains(&gas)
            && brake.is_finite()
            && (0.0..=1.0).contains(&brake)
    }
}

impl Action<1> for CarRacingAction {
    /// The action is a rank-1 vector with three entries.
    fn shape() -> [usize; 1] {
        [Self::N_COMPONENTS]
    }

    /// Checks the current field values against the per-component ranges.
    fn is_valid(&self) -> bool {
        Self::components_valid(self.steer, self.gas, self.brake)
    }
}

impl ContinuousAction<1> for CarRacingAction {
    /// Borrows the action as `[steer, gas, brake]`.
    ///
    /// The returned slice always has three elements, matching `shape()` and
    /// the input expected by `from_slice`.
    fn as_slice(&self) -> &[f32] {
        // SAFETY: `CarRacingAction` is `#[repr(C)]` and contains exactly three
        // `f32` fields, so it has the same layout as `[f32; 3]` (checked by
        // the const assertion above). The pointer is derived from `self`, so
        // it is valid for reads of the whole struct for as long as the
        // returned borrow lives.
        unsafe {
            std::slice::from_raw_parts(std::ptr::from_ref(self).cast::<f32>(), Self::N_COMPONENTS)
        }
    }

    /// Returns a copy with every component clamped to the same `[min, max]`
    /// interval.
    ///
    /// # Panics
    ///
    /// Panics if `min > max` or if either bound is NaN (behaviour of
    /// `f32::clamp`).
    fn clip(&self, min: f32, max: f32) -> Self {
        Self {
            steer: self.steer.clamp(min, max),
            gas: self.gas.clamp(min, max),
            brake: self.brake.clamp(min, max),
        }
    }

    /// Builds an action from the first three entries of `values`
    /// (`steer`, `gas`, `brake`, in that order). Extra entries are ignored
    /// and no range validation is performed.
    ///
    /// # Panics
    ///
    /// Panics if `values` has fewer than three elements.
    fn from_slice(values: &[f32]) -> Self {
        assert!(
            values.len() >= Self::N_COMPONENTS,
            "CarRacingAction needs 3 values"
        );
        Self {
            steer: values[0],
            gas: values[1],
            brake: values[2],
        }
    }
}
impl CarRacingAction {
    /// Copies the components into an owned array ordered
    /// `[steer, gas, brake]`.
    pub fn as_array(&self) -> [f32; 3] {
        let &Self { steer, gas, brake } = self;
        [steer, gas, brake]
    }

    /// Draws an action from `rng` with `steer` sampled from `-1.0..=1.0` and
    /// `gas` and `brake` each sampled from `0.0..=1.0`.
    ///
    /// The generator is consumed in the order steer, gas, brake.
    pub fn random_valid(rng: &mut rand::rngs::StdRng) -> Self {
        use rand::RngExt;

        let steer: f32 = rng.random_range(-1.0..=1.0);
        let gas: f32 = rng.random_range(0.0..=1.0);
        let brake: f32 = rng.random_range(0.0..=1.0);

        Self { steer, gas, brake }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor used by the validity tests below.
    fn action(steer: f32, gas: f32, brake: f32) -> CarRacingAction {
        CarRacingAction { steer, gas, brake }
    }

    /// The action reports a single dimension of length three.
    #[test]
    fn test_shape() {
        let shape = CarRacingAction::shape();
        assert_eq!(shape, [3]);
    }

    /// In-range components are accepted.
    #[test]
    fn test_valid_action() {
        let candidate = action(0.5, 0.3, 0.0);
        assert!(candidate.is_valid());
    }

    /// Gas below zero is rejected.
    #[test]
    fn test_d5_negative_gas() {
        let candidate = action(0.0, -0.1, 0.0);
        assert!(!candidate.is_valid());
    }

    /// Steering beyond 1.0 is rejected.
    #[test]
    fn test_d5_steer_out_of_range() {
        let candidate = action(1.5, 0.0, 0.0);
        assert!(!candidate.is_valid());
    }

    /// Brake below zero is rejected.
    #[test]
    fn test_d5_brake_negative() {
        let candidate = action(0.0, 0.0, -0.1);
        assert!(!candidate.is_valid());
    }

    /// `from_slice` maps indices 0, 1, 2 to steer, gas, brake.
    #[test]
    fn test_from_slice() {
        let near = |got: f32, want: f32| (got - want).abs() < 1e-6;

        let a = CarRacingAction::from_slice(&[0.1, 0.5, 0.2]);

        assert!(near(a.steer, 0.1));
        assert!(near(a.gas, 0.5));
        assert!(near(a.brake, 0.2));
    }
}