use crate::base::Action;
use std::error::Error;
use std::fmt::Debug;
/// An action taken from a finite set that is addressed by one flat index.
///
/// Implementors provide the size of the set and the two index conversions;
/// the provided methods build sampling and enumeration on top of them.
pub trait DiscreteAction<const R: usize>: Action<R> {
    /// Number of distinct actions. The provided methods only ever pass
    /// indices from `0..ACTION_COUNT` to [`Self::from_index`].
    const ACTION_COUNT: usize;

    /// Builds the action associated with `index`.
    fn from_index(index: usize) -> Self;

    /// Returns the index associated with this action.
    fn to_index(&self) -> usize;

    /// Samples one action by drawing an index from `0..ACTION_COUNT` with the
    /// generator returned by `rand::rng()`.
    ///
    /// NOTE(review): with `ACTION_COUNT == 0` the range handed to
    /// `random_range` is empty, which rand documents as a panic.
    fn random() -> Self
    where
        Self: Sized,
    {
        use rand::RngExt;

        let picked = rand::rng().random_range(0..Self::ACTION_COUNT);
        Self::from_index(picked)
    }

    /// Returns every action, ordered by ascending index.
    fn enumerate() -> Vec<Self>
    where
        Self: Sized,
    {
        let mut all = Vec::with_capacity(Self::ACTION_COUNT);
        for index in 0..Self::ACTION_COUNT {
            all.push(Self::from_index(index));
        }
        all
    }
}
/// An action made of `R` independent discrete components, one per axis of
/// `Self::shape()`.
pub trait MultiDiscreteAction<const R: usize>: Action<R> {
    /// Builds the action described by one index per axis.
    fn from_indices(indices: [usize; R]) -> Self;

    /// Returns the per-axis indices describing this action.
    fn to_indices(&self) -> [usize; R];

    /// Samples one action by drawing, axis by axis in order, an index from
    /// `0..dim` where `dim` is that axis' entry in `Self::shape()`.
    fn random() -> Self
    where
        Self: Sized,
    {
        use rand::RngExt;

        let mut rng = rand::rng();
        let mut indices = [0usize; R];
        for (slot, dim) in indices.iter_mut().zip(Self::shape()) {
            *slot = rng.random_range(0..dim);
        }
        Self::from_indices(indices)
    }

    /// Returns every index combination allowed by `Self::shape()`.
    ///
    /// Combinations are produced in lexicographic order with the last axis
    /// varying fastest. If any axis has size 0 the result is empty.
    fn enumerate() -> Vec<Self>
    where
        Self: Sized,
    {
        let dims = Self::shape();
        let capacity: usize = dims.iter().product();
        let mut all = Vec::with_capacity(capacity);

        // A zero-sized axis admits no combination at all.
        if dims.contains(&0) {
            return all;
        }

        let mut cursor = [0usize; R];
        loop {
            all.push(Self::from_indices(cursor));

            // Advance the cursor like an odometer: bump the last axis and
            // carry towards the first whenever an axis wraps around.
            let mut axis = R;
            loop {
                if axis == 0 {
                    // The carry ran past the first axis: everything is emitted.
                    return all;
                }
                axis -= 1;
                cursor[axis] += 1;
                if cursor[axis] < dims[axis] {
                    break;
                }
                cursor[axis] = 0;
            }
        }
    }
}
/// An action represented by a flat sequence of `f32` components.
pub trait ContinuousAction<const R: usize>: Action<R> {
    /// Borrows the components of this action.
    fn as_slice(&self) -> &[f32];

    /// Returns a copy of this action restricted to `min..=max`.
    ///
    /// NOTE(review): the exact clipping rule is left to the implementor;
    /// nothing in this trait enforces it.
    fn clip(&self, min: f32, max: f32) -> Self;

    /// Samples an action with `Self::RANK` components, each drawn from the
    /// half-open range `-1.0..1.0`, and hands them to [`Self::from_slice`].
    fn random() -> Self
    where
        Self: Sized,
    {
        use rand::RngExt;

        let mut rng = rand::rng();
        let mut values: Vec<f32> = Vec::with_capacity(Self::RANK);
        for _ in 0..Self::RANK {
            values.push(rng.random_range(-1.0..1.0));
        }
        Self::from_slice(&values)
    }

    /// Builds an action from a slice of components.
    fn from_slice(values: &[f32]) -> Self;
}
/// A continuous action that additionally exposes one `f32` pair of bounds per
/// axis (`R` values each).
///
/// NOTE(review): nothing in this trait ties the bounds to
/// `ContinuousAction::random` or `ContinuousAction::clip`; callers presumably
/// combine them themselves. Confirm against the users of this trait.
pub trait BoundedAction<const R: usize>: ContinuousAction<R> {
/// Returns the `R` lower bounds (presumably inclusive, one per component).
fn low() -> [f32; R];
/// Returns the `R` upper bounds (presumably inclusive, one per component).
fn high() -> [f32; R];
}
/// Error value that carries a free-form message and displays as
/// `Invalid action: <message>`.
#[derive(Debug, Clone, PartialEq)]
pub struct InvalidActionError {
    /// Text appended after the `Invalid action: ` prefix when displayed.
    pub message: String,
}

impl std::fmt::Display for InvalidActionError {
    /// Writes the fixed prefix followed by the stored message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Invalid action: ")?;
        f.write_str(&self.message)
    }
}

// No `source()` override: the error does not wrap another error.
impl Error for InvalidActionError {}
#[cfg(test)]
mod tests {
    use super::*;

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// Four-way direction used to exercise `DiscreteAction`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum SimpleDiscreteAction {
        Left,
        Right,
        Up,
        Down,
    }

    impl Action<1> for SimpleDiscreteAction {
        fn shape() -> [usize; 1] {
            [4]
        }

        fn is_valid(&self) -> bool {
            true
        }
    }

    impl DiscreteAction<1> for SimpleDiscreteAction {
        const ACTION_COUNT: usize = 4;

        fn from_index(index: usize) -> Self {
            match index {
                0 => Self::Left,
                1 => Self::Right,
                2 => Self::Up,
                3 => Self::Down,
                other => panic!("Index out of bounds: {}", other),
            }
        }

        fn to_index(&self) -> usize {
            match *self {
                Self::Left => 0,
                Self::Right => 1,
                Self::Up => 2,
                Self::Down => 3,
            }
        }
    }

    /// Two-axis action (4 directions x 3 intensities) used to exercise
    /// `MultiDiscreteAction`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct MultiActionTest {
        direction: usize,
        intensity: usize,
    }

    impl Action<2> for MultiActionTest {
        fn shape() -> [usize; 2] {
            [4, 3]
        }

        fn is_valid(&self) -> bool {
            let direction_ok = self.direction < 4;
            let intensity_ok = self.intensity < 3;
            direction_ok && intensity_ok
        }
    }

    impl MultiDiscreteAction<2> for MultiActionTest {
        fn from_indices(indices: [usize; 2]) -> Self {
            let [direction, intensity] = indices;
            assert!(
                direction < 4,
                "Direction index out of bounds: {}",
                direction
            );
            assert!(
                intensity < 3,
                "Intensity index out of bounds: {}",
                intensity
            );
            Self {
                direction,
                intensity,
            }
        }

        fn to_indices(&self) -> [usize; 2] {
            [self.direction, self.intensity]
        }
    }

    /// Three-component action used to exercise `ContinuousAction` and
    /// `BoundedAction`.
    #[derive(Debug, Clone)]
    struct ContinuousActionTest {
        values: [f32; 3],
    }

    impl Action<3> for ContinuousActionTest {
        fn shape() -> [usize; 3] {
            [1, 1, 1]
        }

        fn is_valid(&self) -> bool {
            for value in &self.values {
                if !value.is_finite() {
                    return false;
                }
            }
            true
        }
    }

    impl ContinuousAction<3> for ContinuousActionTest {
        fn as_slice(&self) -> &[f32] {
            &self.values
        }

        fn clip(&self, min: f32, max: f32) -> Self {
            Self {
                values: self.values.map(|v| v.max(min).min(max)),
            }
        }

        fn from_slice(values: &[f32]) -> Self {
            assert_eq!(values.len(), 3, "Expected 3 values, got {}", values.len());
            let mut copied = [0.0f32; 3];
            copied.copy_from_slice(values);
            Self { values: copied }
        }
    }

    impl BoundedAction<3> for ContinuousActionTest {
        fn low() -> [f32; 3] {
            [-1.0; 3]
        }

        fn high() -> [f32; 3] {
            [1.0; 3]
        }
    }

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    /// Runs `values` through the continuous fixture's `clip(-1.0, 1.0)`.
    fn clip_to_unit(values: [f32; 3]) -> [f32; 3] {
        ContinuousActionTest { values }.clip(-1.0, 1.0).values
    }

    /// Builds an `InvalidActionError` holding `message`.
    fn make_error(message: &str) -> InvalidActionError {
        InvalidActionError {
            message: String::from(message),
        }
    }

    // ------------------------------------------------------------------
    // DiscreteAction
    // ------------------------------------------------------------------

    #[test]
    fn test_discrete_action_shape() {
        let shape = SimpleDiscreteAction::shape();
        assert_eq!(shape, [4]);
        assert_eq!(SimpleDiscreteAction::RANK, 1);
    }

    #[test]
    fn test_discrete_action_count() {
        let count = SimpleDiscreteAction::ACTION_COUNT;
        assert_eq!(count, 4);
    }

    #[test]
    fn test_discrete_action_from_index() {
        let expected = [
            SimpleDiscreteAction::Left,
            SimpleDiscreteAction::Right,
            SimpleDiscreteAction::Up,
            SimpleDiscreteAction::Down,
        ];
        for (index, want) in expected.iter().copied().enumerate() {
            assert_eq!(SimpleDiscreteAction::from_index(index), want);
        }
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_discrete_action_from_index_out_of_bounds() {
        // First index past the valid range.
        let _ = SimpleDiscreteAction::from_index(4);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_discrete_action_from_index_negative_like() {
        // An index far outside the valid range.
        let _ = SimpleDiscreteAction::from_index(100);
    }

    #[test]
    fn test_discrete_action_to_index() {
        let cases = [
            (SimpleDiscreteAction::Left, 0),
            (SimpleDiscreteAction::Right, 1),
            (SimpleDiscreteAction::Up, 2),
            (SimpleDiscreteAction::Down, 3),
        ];
        for (action, want) in cases.iter().copied() {
            assert_eq!(action.to_index(), want);
        }
    }

    #[test]
    fn test_discrete_action_roundtrip() {
        for index in 0..SimpleDiscreteAction::ACTION_COUNT {
            assert_eq!(SimpleDiscreteAction::from_index(index).to_index(), index);
        }
    }

    #[test]
    fn test_discrete_action_enumerate() {
        let actions = SimpleDiscreteAction::enumerate();
        assert_eq!(actions.len(), 4);
        assert_eq!(
            actions,
            [
                SimpleDiscreteAction::Left,
                SimpleDiscreteAction::Right,
                SimpleDiscreteAction::Up,
                SimpleDiscreteAction::Down,
            ]
        );
    }

    #[test]
    fn test_discrete_action_random() {
        for _ in 0..100 {
            let index = SimpleDiscreteAction::random().to_index();
            assert!(index < SimpleDiscreteAction::ACTION_COUNT);
        }
    }

    #[test]
    fn test_discrete_action_is_valid() {
        for index in 0..SimpleDiscreteAction::ACTION_COUNT {
            assert!(SimpleDiscreteAction::from_index(index).is_valid());
        }
    }

    // ------------------------------------------------------------------
    // MultiDiscreteAction
    // ------------------------------------------------------------------

    #[test]
    fn test_multidiscrete_action_shape() {
        let shape = MultiActionTest::shape();
        assert_eq!(shape, [4, 3]);
        assert_eq!(MultiActionTest::RANK, 2);
    }

    #[test]
    fn test_multidiscrete_action_from_indices() {
        // Smallest and largest valid combination.
        for [direction, intensity] in [[0, 0], [3, 2]] {
            let action = MultiActionTest::from_indices([direction, intensity]);
            assert_eq!(action.direction, direction);
            assert_eq!(action.intensity, intensity);
        }
    }

    #[test]
    #[should_panic(expected = "Direction index out of bounds")]
    fn test_multidiscrete_action_from_indices_direction_out_of_bounds() {
        let _ = MultiActionTest::from_indices([4, 0]);
    }

    #[test]
    #[should_panic(expected = "Intensity index out of bounds")]
    fn test_multidiscrete_action_from_indices_intensity_out_of_bounds() {
        let _ = MultiActionTest::from_indices([0, 3]);
    }

    #[test]
    fn test_multidiscrete_action_to_indices() {
        let indices = MultiActionTest::from_indices([2, 1]).to_indices();
        assert_eq!(indices, [2, 1]);
    }

    #[test]
    fn test_multidiscrete_action_roundtrip() {
        for direction in 0..4 {
            for intensity in 0..3 {
                let indices = [direction, intensity];
                assert_eq!(MultiActionTest::from_indices(indices).to_indices(), indices);
            }
        }
    }

    #[test]
    fn test_multidiscrete_action_enumerate() {
        let actions = MultiActionTest::enumerate();
        assert_eq!(actions.len(), 12);

        // Expected order: direction-major, intensity varying fastest.
        let expected = (0..4).flat_map(|d| (0..3).map(move |i| (d, i)));
        for (action, (want_direction, want_intensity)) in actions.iter().zip(expected) {
            assert_eq!(action.direction, want_direction);
            assert_eq!(action.intensity, want_intensity);
        }
    }

    #[test]
    fn test_multidiscrete_action_enumerate_large_space() {
        #[derive(Debug, Clone)]
        struct LargeMultiAction([usize; 3]);

        impl Action<3> for LargeMultiAction {
            fn shape() -> [usize; 3] {
                [5, 5, 5]
            }

            fn is_valid(&self) -> bool {
                self.0.iter().all(|&value| value < 5)
            }
        }

        impl MultiDiscreteAction<3> for LargeMultiAction {
            fn from_indices(indices: [usize; 3]) -> Self {
                indices
                    .iter()
                    .enumerate()
                    .for_each(|(i, &idx)| assert!(idx < 5, "Index {} out of bounds", i));
                Self(indices)
            }

            fn to_indices(&self) -> [usize; 3] {
                self.0
            }
        }

        // 5 * 5 * 5 combinations.
        assert_eq!(LargeMultiAction::enumerate().len(), 125);
    }

    #[test]
    fn test_multidiscrete_action_random() {
        for _ in 0..100 {
            let action = MultiActionTest::random();
            assert!(action.is_valid());

            let [direction, intensity] = action.to_indices();
            assert!(direction < 4);
            assert!(intensity < 3);
        }
    }

    #[test]
    fn test_multidiscrete_action_is_valid() {
        for indices in [[0, 0], [3, 2]] {
            assert!(MultiActionTest::from_indices(indices).is_valid());
        }

        // Built directly so the bounds checks in `from_indices` are bypassed.
        for (direction, intensity) in [(5, 0), (0, 5)] {
            let invalid = MultiActionTest {
                direction,
                intensity,
            };
            assert!(!invalid.is_valid());
        }
    }

    // ------------------------------------------------------------------
    // ContinuousAction
    // ------------------------------------------------------------------

    #[test]
    fn test_continuous_action_shape() {
        let shape = ContinuousActionTest::shape();
        assert_eq!(shape, [1, 1, 1]);
        assert_eq!(ContinuousActionTest::RANK, 3);
    }

    #[test]
    fn test_continuous_action_as_slice() {
        let action = ContinuousActionTest {
            values: [0.5, -0.3, 1.0],
        };
        assert_eq!(action.as_slice().len(), 3);
        assert_eq!(action.as_slice(), &[0.5, -0.3, 1.0]);
    }

    #[test]
    fn test_continuous_action_from_slice() {
        let values = [0.1, 0.2, 0.3];
        assert_eq!(ContinuousActionTest::from_slice(&values).values, values);
    }

    #[test]
    #[should_panic(expected = "Expected 3 values")]
    fn test_continuous_action_from_slice_wrong_size() {
        let too_short = [0.1, 0.2];
        let _ = ContinuousActionTest::from_slice(&too_short);
    }

    #[test]
    fn test_continuous_action_roundtrip() {
        let original = [0.5, -0.3, 0.9];
        let action = ContinuousActionTest::from_slice(&original);
        assert_eq!(action.as_slice(), &original);
    }

    #[test]
    fn test_continuous_action_clip_within_bounds() {
        assert_eq!(clip_to_unit([0.0, 0.5, -0.5]), [0.0, 0.5, -0.5]);
    }

    #[test]
    fn test_continuous_action_clip_exceeds_max() {
        assert_eq!(clip_to_unit([2.0, 1.5, 3.0]), [1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_continuous_action_clip_exceeds_min() {
        assert_eq!(clip_to_unit([-2.0, -1.5, -3.0]), [-1.0, -1.0, -1.0]);
    }

    #[test]
    fn test_continuous_action_clip_mixed() {
        assert_eq!(clip_to_unit([2.0, 0.5, -2.0]), [1.0, 0.5, -1.0]);
    }

    #[test]
    fn test_continuous_action_random() {
        for _ in 0..100 {
            let action = ContinuousActionTest::random();
            assert!(action.is_valid());
            for &component in action.as_slice() {
                assert!((-1.0..=1.0).contains(&component));
                assert!(component.is_finite());
            }
        }
    }

    #[test]
    fn test_continuous_action_is_valid_finite() {
        let action = ContinuousActionTest {
            values: [0.5, -0.3, 1.0],
        };
        assert!(action.is_valid());
    }

    #[test]
    fn test_continuous_action_is_invalid_nan() {
        let action = ContinuousActionTest {
            values: [f32::NAN, 0.5, 1.0],
        };
        assert!(!action.is_valid());
    }

    #[test]
    fn test_continuous_action_is_invalid_inf() {
        for non_finite in [f32::INFINITY, f32::NEG_INFINITY] {
            let action = ContinuousActionTest {
                values: [non_finite, 0.5, 1.0],
            };
            assert!(!action.is_valid());
        }
    }

    // ------------------------------------------------------------------
    // InvalidActionError
    // ------------------------------------------------------------------

    #[test]
    fn test_invalid_action_error_creation() {
        let error = make_error("Index out of bounds");
        assert_eq!(error.message, "Index out of bounds");
    }

    #[test]
    fn test_invalid_action_error_display() {
        let displayed = format!("{}", make_error("Invalid value"));
        assert_eq!(displayed, "Invalid action: Invalid value");
    }

    #[test]
    fn test_invalid_action_error_debug() {
        let debug_str = format!("{:?}", make_error("Test error"));
        assert!(debug_str.contains("Test error"));
    }

    #[test]
    fn test_invalid_action_error_clone() {
        let error = make_error("Original");
        let cloned = error.clone();
        assert_eq!(error, cloned);
    }

    #[test]
    fn test_invalid_action_error_equality() {
        let first = make_error("Same error");
        let second = make_error("Same error");
        let other = make_error("Different error");
        assert_eq!(first, second);
        assert_ne!(first, other);
    }

    #[test]
    fn test_invalid_action_error_is_error() {
        // Compiles only if `InvalidActionError` implements `Error`.
        let error: Box<dyn Error> = Box::new(make_error("Test"));
        let _msg = error.to_string();
    }

    // ------------------------------------------------------------------
    // Derived traits on the fixtures
    // ------------------------------------------------------------------

    #[test]
    fn test_discrete_action_clone_and_debug() {
        let action = SimpleDiscreteAction::Left;
        let cloned = action;
        assert_eq!(action, cloned);
        assert!(format!("{:?}", action).contains("Left"));
    }

    #[test]
    fn test_multidiscrete_action_clone_and_debug() {
        let action = MultiActionTest::from_indices([1, 2]);
        let cloned = action;
        assert_eq!(action, cloned);
        assert!(format!("{:?}", action).contains("direction"));
    }

    #[test]
    fn test_continuous_action_clone_and_debug() {
        let action = ContinuousActionTest {
            values: [0.1, 0.2, 0.3],
        };
        let cloned = action.clone();
        assert_eq!(action.as_slice(), cloned.as_slice());
        assert!(format!("{:?}", action).contains("values"));
    }

    // ------------------------------------------------------------------
    // Edge cases
    // ------------------------------------------------------------------

    #[test]
    fn test_continuous_action_clip_chaining() {
        let action = ContinuousActionTest {
            values: [2.0, -3.0, 0.5],
        };
        let clipped = action.clip(-2.0, 2.0).clip(-1.0, 1.0);
        assert_eq!(clipped.values, [1.0, -1.0, 0.5]);
    }

    #[test]
    fn test_large_discrete_action_space() {
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct LargeDiscreteAction(u8);

        impl Action<1> for LargeDiscreteAction {
            fn shape() -> [usize; 1] {
                [256]
            }

            fn is_valid(&self) -> bool {
                true
            }
        }

        impl DiscreteAction<1> for LargeDiscreteAction {
            const ACTION_COUNT: usize = 256;

            fn from_index(index: usize) -> Self {
                assert!(index < 256);
                // Lossless: the assertion above keeps `index` within `u8`.
                Self(index as u8)
            }

            fn to_index(&self) -> usize {
                usize::from(self.0)
            }
        }

        assert_eq!(LargeDiscreteAction::enumerate().len(), 256);
        for index in [0usize, 1, 127, 255] {
            assert_eq!(LargeDiscreteAction::from_index(index).to_index(), index);
        }
    }

    #[test]
    fn test_continuous_action_with_zero_values() {
        let zeros = [0.0, 0.0, 0.0];
        let action = ContinuousActionTest { values: zeros };
        assert!(action.is_valid());
        assert_eq!(action.as_slice(), &zeros);
        assert_eq!(action.clip(-1.0, 1.0).values, zeros);
    }

    #[test]
    fn test_continuous_action_extreme_clip_bounds() {
        let action = ContinuousActionTest {
            values: [100.0, -100.0, 0.0],
        };

        // Infinite bounds leave every component untouched.
        let unbounded = action.clip(f32::NEG_INFINITY, f32::INFINITY);
        assert_eq!(unbounded.values, [100.0, -100.0, 0.0]);

        // A degenerate interval collapses every component onto it.
        let collapsed = action.clip(0.0, 0.0);
        assert_eq!(collapsed.values, [0.0, 0.0, 0.0]);
    }

    // ------------------------------------------------------------------
    // BoundedAction
    // ------------------------------------------------------------------

    #[test]
    fn test_bounded_action_low_strictly_below_high() {
        let low = ContinuousActionTest::low();
        let high = ContinuousActionTest::high();
        for (i, (lo, hi)) in low.iter().zip(high.iter()).enumerate() {
            assert!(lo < hi, "bound {i}: low >= high");
        }
    }

    #[test]
    fn test_bounded_action_clip_is_noop_inside_bounds() {
        let low = ContinuousActionTest::low();
        let high = ContinuousActionTest::high();
        let (min, max) = (low[0], high[0]);

        let at_low = ContinuousActionTest::from_slice(&low);
        let at_high = ContinuousActionTest::from_slice(&high);
        assert_eq!(at_low.clip(min, max).as_slice(), &low);
        assert_eq!(at_high.clip(min, max).as_slice(), &high);
    }
}