// rlevo_core/action.rs

//! Action space abstractions for reinforcement learning environments.
//!
//! This module provides a flexible type system for representing agent actions in RL environments.
//! Actions can be discrete (finite choices), multi-discrete (multiple independent discrete choices),
//! or continuous (real-valued vectors).
//!
//! # Design Philosophy
//!
//! The action traits follow a layered design:
//! - [`Action`](crate::base::Action): Base trait providing shape, validation and cloning semantics
//! - [`DiscreteAction`], [`MultiDiscreteAction`], [`ContinuousAction`]: Type-specific extensions
//! - [`BoundedAction`]: A [`ContinuousAction`] that also exposes per-component `[low, high]` bounds
//!
//! # Action Types
//!
//! ## Discrete Actions
//!
//! Discrete actions represent a finite set of mutually exclusive choices (e.g., "move left",
//! "move right", "jump"). They are indexed from `0` to `ACTION_COUNT - 1`.
//!
//! ## Multi-Discrete Actions
//!
//! Multi-discrete actions consist of multiple independent discrete choices, such as selecting
//! both a direction and an attack type simultaneously.
//!
//! ## Continuous Actions
//!
//! Continuous actions are real-valued vectors, typically used for motor control or
//! parameterized actions (e.g., steering angle, throttle). When the valid range of every
//! component is known, implement [`BoundedAction`] to expose those bounds.
//!
//! # Test Suite
//!
//! This module includes a comprehensive test suite covering:
//!
//! ## DiscreteAction Tests (11 tests)
//! - `test_discrete_action_shape`: Verifies shape and dimension constants
//! - `test_discrete_action_count`: Checks action count constant
//! - `test_discrete_action_from_index`: Tests index-to-action conversion
//! - `test_discrete_action_from_index_out_of_bounds`: Validates panic on invalid indices
//! - `test_discrete_action_from_index_negative_like`: Validates panic on a far out-of-range index
//! - `test_discrete_action_to_index`: Tests action-to-index conversion
//! - `test_discrete_action_roundtrip`: Ensures bidirectional conversion consistency
//! - `test_discrete_action_enumerate`: Verifies all actions are enumerated correctly
//! - `test_discrete_action_random`: Tests random action generation
//! - `test_discrete_action_is_valid`: Validates the `is_valid()` predicate
//! - `test_discrete_action_clone_and_debug`: Tests Debug and Clone trait implementations
//!
//! ## MultiDiscreteAction Tests (11 tests)
//! - `test_multidiscrete_action_shape`: Verifies multi-dimensional shape
//! - `test_multidiscrete_action_from_indices`: Tests multi-index conversion
//! - `test_multidiscrete_action_from_indices_*_out_of_bounds`: Validates panic on invalid indices
//! - `test_multidiscrete_action_to_indices`: Tests reverse conversion
//! - `test_multidiscrete_action_roundtrip`: Ensures bidirectional consistency
//! - `test_multidiscrete_action_enumerate`: Verifies all action combinations are enumerated
//! - `test_multidiscrete_action_enumerate_large_space`: Tests scalability with large action spaces
//! - `test_multidiscrete_action_random`: Tests random sampling
//! - `test_multidiscrete_action_is_valid`: Validates constraints
//! - `test_multidiscrete_action_clone_and_debug`: Tests trait implementations
//!
//! ## ContinuousAction Tests (15 tests)
//! - `test_continuous_action_shape`: Verifies shape specification
//! - `test_continuous_action_as_slice`: Tests slice view access
//! - `test_continuous_action_from_slice`: Tests construction from slice
//! - `test_continuous_action_from_slice_wrong_size`: Validates dimension checking
//! - `test_continuous_action_roundtrip`: Ensures slice conversion consistency
//! - `test_continuous_action_clip_*`: Tests clipping behavior (within, exceeds max/min, mixed, extreme)
//! - `test_continuous_action_clip_chaining`: Verifies method chaining
//! - `test_continuous_action_random`: Tests random action generation
//! - `test_continuous_action_is_valid_finite`: Tests that finite values are accepted
//! - `test_continuous_action_is_invalid_*`: Tests that NaN and ±Inf values are rejected
//! - `test_continuous_action_with_zero_values`: Tests edge case with zero values
//! - `test_continuous_action_clone_and_debug`: Tests trait implementations
//!
//! ## InvalidActionError Tests (6 tests)
//! - `test_invalid_action_error_creation`: Tests error instantiation
//! - `test_invalid_action_error_display`: Tests Display trait formatting
//! - `test_invalid_action_error_debug`: Tests Debug trait formatting
//! - `test_invalid_action_error_clone`: Tests Clone implementation
//! - `test_invalid_action_error_equality`: Tests PartialEq implementation
//! - `test_invalid_action_error_is_error`: Tests `std::error::Error` trait compatibility
//!
//! ## Integration Tests (4 tests)
//! - `test_large_discrete_action_space`: Tests with 256 actions
//! - `test_continuous_action_extreme_clip_bounds`: Tests edge cases in clipping
//! - Various clone/debug/trait tests across different action types
83
84use crate::base::Action;
85use std::error::Error;
86use std::fmt::Debug;
87
88/// Trait for discrete actions with a finite, enumerable set of choices.
89///
90/// Discrete actions represent mutually exclusive options that can be indexed by
91/// integers from `0` to `ACTION_COUNT - 1`. Common examples include:
92/// - Game controls (move left/right/jump)
93/// - Categorical decisions (buy/hold/sell)
94/// - Navigation directions (north/south/east/west)
95///
96/// # Type Safety
97///
98/// Implementations should ensure bidirectional conversion between indices and actions:
99/// ```text
100/// ∀ i ∈ [0, ACTION_COUNT): i == from_index(i).to_index()
101/// ∀ a: Action: a == from_index(a.to_index())
102/// ```
103///
104/// # Performance
105///
106/// For performance-critical code, prefer `from_index()` over `random()` when you
107/// already have an index (e.g., from a neural network's argmax). The `random()`
108/// method allocates a thread-local RNG on each call.
109pub trait DiscreteAction<const R: usize>: Action<R> {
110    /// The total number of distinct actions in this action space.
111    ///
112    /// This constant defines the cardinality of the action space. It must be
113    /// greater than zero and remain constant for the lifetime of the program.
114    const ACTION_COUNT: usize;
115
116    /// Constructs an action from its zero-based index.
117    ///
118    /// This method must be the inverse of [`to_index()`](DiscreteAction::to_index).
119    ///
120    /// # Panics
121    ///
122    /// Implementations should panic if `index >= ACTION_COUNT`, as this indicates
123    /// a programming error (out-of-bounds access).
124    fn from_index(index: usize) -> Self;
125
126    /// Converts this action to its zero-based index.
127    ///
128    /// The returned index must be in the range `[0, ACTION_COUNT)` and must be
129    /// the inverse of [`from_index()`](DiscreteAction::from_index).
130    fn to_index(&self) -> usize;
131
132    /// Samples a uniformly random action from this action space.
133    ///
134    /// This is a convenience method for exploration in reinforcement learning.
135    /// It uses thread-local RNG state, so it's safe to call from multiple threads
136    /// but will produce different sequences per thread.
137    ///
138    /// # Performance
139    ///
140    /// If you already have an index from another source (e.g., a neural network
141    /// output), use `from_index()` directly instead of this method.
142    fn random() -> Self
143    where
144        Self: Sized,
145    {
146        use rand::RngExt;
147        let mut rng = rand::rng();
148        let index = rng.random_range(0..Self::ACTION_COUNT);
149        Self::from_index(index)
150    }
151
152    /// Returns a vector containing all possible actions in index order.
153    ///
154    /// This is useful for tabular RL methods (e.g., Q-learning) that need to
155    /// iterate over the entire action space. The returned vector has length
156    /// `ACTION_COUNT` with actions ordered by their index.
157    ///
158    /// # Performance
159    ///
160    /// This allocates a vector of size `ACTION_COUNT`. For large action spaces,
161    /// consider using an iterator pattern instead (not currently provided).
162    fn enumerate() -> Vec<Self>
163    where
164        Self: Sized,
165    {
166        (0..Self::ACTION_COUNT).map(Self::from_index).collect()
167    }
168}
169
170/// Trait for actions with multiple independent discrete dimensions.
171///
172/// Multi-discrete actions represent scenarios where an agent must make several
173/// independent categorical choices simultaneously. Each dimension can have a
174/// different number of options. This is common in:
175/// - Strategy games (select unit + select action + select target)
176/// - Multi-agent coordination (each agent picks a discrete action)
177/// - Parameterized actions (choose action type + intensity level)
178///
179/// # Dimensionality
180///
181/// The const generic `R` (the rank) specifies the number of axes. Each axis
182/// can have a different cardinality, defined by [`shape()`](Action::shape).
183///
184/// The total number of action combinations is the product of all dimension sizes:
185/// ```text
186/// total_actions = ∏ shape()[i]
187/// ```
188///
189/// # Caution: Combinatorial Explosion
190///
191/// Be careful with [`enumerate()`](MultiDiscreteAction::enumerate) on large action spaces.
192/// A 3D action space with dimensions [10, 10, 10] produces 1000 actions, but
193/// [100, 100, 100] produces 1,000,000!
194pub trait MultiDiscreteAction<const R: usize>: Action<R> {
195    /// Constructs an action from multi-dimensional indices.
196    ///
197    /// Each index must be in the range `[0, shape()[i])` for axis `i`.
198    ///
199    /// # Panics
200    ///
201    /// Implementations should panic if any index is out of bounds for its axis.
202    fn from_indices(indices: [usize; R]) -> Self;
203
204    /// Converts this action to its multi-dimensional index representation.
205    ///
206    /// The returned array must satisfy: each element `i` is in `[0, shape()[i])`.
207    /// This method must be the inverse of [`from_indices()`](MultiDiscreteAction::from_indices).
208    fn to_indices(&self) -> [usize; R];
209
210    /// Samples a uniformly random action from this multi-discrete action space.
211    ///
212    /// Each dimension is sampled independently and uniformly from its valid range.
213    /// This is useful for exploration in reinforcement learning.
214    ///
215    /// # Examples
216    ///
217    /// ```rust,ignore
218    /// let random_action = StrategyAction::random();
219    /// assert!(random_action.is_valid());
220    /// ```
221    fn random() -> Self
222    where
223        Self: Sized,
224    {
225        use rand::RngExt;
226        let mut rng = rand::rng();
227        let space = Self::shape();
228        let indices = space.map(|dim| rng.random_range(0..dim));
229        Self::from_indices(indices)
230    }
231
232    /// Returns all possible action combinations in this space.
233    ///
234    /// # Warning
235    ///
236    /// This method generates **all** combinations across all dimensions. The number
237    /// of actions grows multiplicatively with the product of dimension sizes:
238    ///
239    /// - `[10, 10, 10]` → 1,000 actions (manageable)
240    /// - `[50, 50, 50]` → 125,000 actions (marginal)
241    /// - `[100, 100, 100]` → 1,000,000 actions (likely too large)
242    ///
243    /// Use this method only when you need to iterate over the entire action space
244    /// (e.g., for exact policy evaluation in tabular methods).
245    ///
246    /// # Panics
247    ///
248    /// May panic or run out of memory if the action space is too large.
249    fn enumerate() -> Vec<Self>
250    where
251        Self: Sized,
252    {
253        let space = Self::shape();
254        let total: usize = space.iter().product();
255        let mut actions = Vec::with_capacity(total);
256
257        fn generate<const R: usize, T: MultiDiscreteAction<R>>(
258            space: &[usize; R],
259            current: &mut [usize; R],
260            axis: usize,
261            actions: &mut Vec<T>,
262        ) {
263            if axis == R {
264                actions.push(T::from_indices(*current));
265                return;
266            }
267            for i in 0..space[axis] {
268                current[axis] = i;
269                generate(space, current, axis + 1, actions);
270            }
271        }
272
273        let mut current = [0; R];
274        generate(&space, &mut current, 0, &mut actions);
275        actions
276    }
277}
278
279/// Trait for continuous-valued actions represented as real-valued vectors.
280///
281/// Continuous actions are used when the agent's output is a vector of real numbers
282/// rather than discrete choices. Common applications include:
283/// - Robot motor control (joint angles, torques)
284/// - Vehicle control (steering, throttle, brake)
285/// - Continuous parameter tuning (learning rates, temperatures)
286///
287/// # Value Range
288///
289/// Continuous actions typically have bounded ranges (e.g., `[-1, 1]` or `[0, 1]`).
290/// The [`clip()`](ContinuousAction::clip) method enforces these bounds.
291///
292/// # Neural Network Integration
293///
294/// Continuous actions are typically produced by neural networks with `tanh` or
295/// `sigmoid` activation functions. Use [`clip()`](ContinuousAction::clip) to
296/// ensure outputs stay within valid ranges.
297pub trait ContinuousAction<const R: usize>: Action<R> {
298    /// Returns a slice view of this action's component values.
299    ///
300    /// The returned slice must have exactly `RANK` elements. This is used for
301    /// efficient serialization and tensor conversion.
302    fn as_slice(&self) -> &[f32];
303
304    /// Returns a new action with all components clipped to `[min, max]`.
305    ///
306    /// This is essential for ensuring neural network outputs (which may exceed
307    /// valid ranges due to numerical errors or exploration noise) stay within
308    /// acceptable bounds.
309    ///
310    /// # Common Use Cases
311    ///
312    /// - Enforcing action space bounds after neural network output
313    /// - Adding exploration noise while maintaining validity
314    /// - Recovering from numerical instability
315    fn clip(&self, min: f32, max: f32) -> Self;
316
317    /// Samples a random action with components uniformly distributed in `[-1.0, 1.0)`.
318    ///
319    /// The default implementation generates uniform random values in the half-open
320    /// range `[-1.0, 1.0)` using `rand::random_range`. Override this method if you
321    /// need different sampling behavior (e.g., Gaussian noise, domain-specific
322    /// distributions, or bounds derived from [`BoundedAction`]).
323    fn random() -> Self
324    where
325        Self: Sized,
326    {
327        use rand::RngExt;
328        let mut rng = rand::rng();
329        // Default implementation - override for custom behavior
330        let values: Vec<f32> = (0..Self::RANK)
331            .map(|_| rng.random_range(-1.0..1.0))
332            .collect();
333        Self::from_slice(&values)
334    }
335
336    /// Constructs an action from a slice of component values.
337    ///
338    /// The input slice must have exactly `RANK` elements. This is the inverse
339    /// operation of [`as_slice()`](ContinuousAction::as_slice).
340    ///
341    /// # Panics
342    ///
343    /// Implementations should panic if `values.len() != RANK`.
344    fn from_slice(values: &[f32]) -> Self;
345}
346
347/// A [`ContinuousAction`] with statically-known `[low, high]` component bounds.
348///
349/// DDPG and other continuous-control algorithms need the per-component action
350/// bounds to scale/shift neural-network outputs and to sample uniform warm-up
351/// actions. Expose them via associated static methods rather than associated
352/// constants so implementors can still derive bounds from a runtime env config
353/// (e.g. a `max_torque` field) while presenting a uniform API.
354///
355/// # Invariants
356///
357/// - `low()[i] < high()[i]` for every component `i`.
358/// - [`ContinuousAction::clip`] must be a no-op on an action whose components
359///   already lie in `[low, high]`.
360pub trait BoundedAction<const R: usize>: ContinuousAction<R> {
361    /// Returns the per-component lower bounds of this action space.
362    ///
363    /// The returned array must satisfy `low()[i] < high()[i]` for every component
364    /// `i`. See the trait-level invariants for the full contract.
365    fn low() -> [f32; R];
366
367    /// Returns the per-component upper bounds of this action space.
368    ///
369    /// The returned array must satisfy `low()[i] < high()[i]` for every component
370    /// `i`. See the trait-level invariants for the full contract.
371    fn high() -> [f32; R];
372}
373
/// Error indicating an action violated its type's constraints.
///
/// Returned when an action fails validation or when invalid conversions are
/// attempted (e.g., out-of-bounds indices, non-finite float values).
///
/// # Examples
///
/// ```
/// use rlevo_core::action::InvalidActionError;
///
/// let err = InvalidActionError::new("index 5 out of bounds for ACTION_COUNT=4");
/// assert!(err.to_string().contains("index 5"));
///
/// // The public field can still be populated directly.
/// let same = InvalidActionError { message: "index 5 out of bounds for ACTION_COUNT=4".into() };
/// assert_eq!(err, same);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InvalidActionError {
    /// Human-readable description of the constraint that was violated.
    pub message: String,
}

impl InvalidActionError {
    /// Creates an error carrying the given human-readable message.
    ///
    /// Accepts anything convertible into a `String` (e.g. `&str` or `String`).
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}
392
393impl std::fmt::Display for InvalidActionError {
394    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
395        write!(f, "Invalid action: {}", self.message)
396    }
397}
398
399impl Error for InvalidActionError {}
400
401// ============================================================================
402// Tests
403// ============================================================================
404
405#[cfg(test)]
406mod tests {
407    use super::*;
408
409    // ========================================================================
410    // Test Implementations
411    // ========================================================================
412
413    /// Simple discrete action with 4 possible choices.
414    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
415    enum SimpleDiscreteAction {
416        Left,
417        Right,
418        Up,
419        Down,
420    }
421
422    impl Action<1> for SimpleDiscreteAction {
423        fn shape() -> [usize; 1] {
424            [4]
425        }
426
427        fn is_valid(&self) -> bool {
428            true // All variants are always valid
429        }
430    }
431
432    impl DiscreteAction<1> for SimpleDiscreteAction {
433        const ACTION_COUNT: usize = 4;
434
435        fn from_index(index: usize) -> Self {
436            match index {
437                0 => SimpleDiscreteAction::Left,
438                1 => SimpleDiscreteAction::Right,
439                2 => SimpleDiscreteAction::Up,
440                3 => SimpleDiscreteAction::Down,
441                _ => panic!("Index out of bounds: {}", index),
442            }
443        }
444
445        fn to_index(&self) -> usize {
446            match self {
447                SimpleDiscreteAction::Left => 0,
448                SimpleDiscreteAction::Right => 1,
449                SimpleDiscreteAction::Up => 2,
450                SimpleDiscreteAction::Down => 3,
451            }
452        }
453    }
454
455    /// Multi-discrete action with 2 dimensions.
456    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
457    struct MultiActionTest {
458        direction: usize, // 0-3
459        intensity: usize, // 0-2
460    }
461
462    impl Action<2> for MultiActionTest {
463        fn shape() -> [usize; 2] {
464            [4, 3]
465        }
466
467        fn is_valid(&self) -> bool {
468            self.direction < 4 && self.intensity < 3
469        }
470    }
471
472    impl MultiDiscreteAction<2> for MultiActionTest {
473        fn from_indices(indices: [usize; 2]) -> Self {
474            if indices[0] >= 4 {
475                panic!("Direction index out of bounds: {}", indices[0]);
476            }
477            if indices[1] >= 3 {
478                panic!("Intensity index out of bounds: {}", indices[1]);
479            }
480            MultiActionTest {
481                direction: indices[0],
482                intensity: indices[1],
483            }
484        }
485
486        fn to_indices(&self) -> [usize; 2] {
487            [self.direction, self.intensity]
488        }
489    }
490
491    /// Continuous action with 3 dimensions (e.g., 3D velocity).
492    #[derive(Debug, Clone)]
493    struct ContinuousActionTest {
494        values: [f32; 3],
495    }
496
497    impl Action<3> for ContinuousActionTest {
498        fn shape() -> [usize; 3] {
499            [1, 1, 1] // Continuous dimensions typically have size 1
500        }
501
502        fn is_valid(&self) -> bool {
503            self.values.iter().all(|v| v.is_finite())
504        }
505    }
506
507    impl ContinuousAction<3> for ContinuousActionTest {
508        fn as_slice(&self) -> &[f32] {
509            &self.values
510        }
511
512        fn clip(&self, min: f32, max: f32) -> Self {
513            let clipped = self
514                .values
515                .iter()
516                .map(|&v| v.max(min).min(max))
517                .collect::<Vec<_>>();
518            ContinuousActionTest {
519                values: [clipped[0], clipped[1], clipped[2]],
520            }
521        }
522
523        fn from_slice(values: &[f32]) -> Self {
524            assert_eq!(values.len(), 3, "Expected 3 values, got {}", values.len());
525            ContinuousActionTest {
526                values: [values[0], values[1], values[2]],
527            }
528        }
529    }
530
531    impl BoundedAction<3> for ContinuousActionTest {
532        fn low() -> [f32; 3] {
533            [-1.0, -1.0, -1.0]
534        }
535
536        fn high() -> [f32; 3] {
537            [1.0, 1.0, 1.0]
538        }
539    }
540
541    // ========================================================================
542    // DiscreteAction Tests
543    // ========================================================================
544
545    #[test]
546    fn test_discrete_action_shape() {
547        assert_eq!(SimpleDiscreteAction::shape(), [4]);
548        assert_eq!(SimpleDiscreteAction::RANK, 1);
549    }
550
551    #[test]
552    fn test_discrete_action_count() {
553        assert_eq!(SimpleDiscreteAction::ACTION_COUNT, 4);
554    }
555
556    #[test]
557    fn test_discrete_action_from_index() {
558        assert_eq!(
559            SimpleDiscreteAction::from_index(0),
560            SimpleDiscreteAction::Left
561        );
562        assert_eq!(
563            SimpleDiscreteAction::from_index(1),
564            SimpleDiscreteAction::Right
565        );
566        assert_eq!(
567            SimpleDiscreteAction::from_index(2),
568            SimpleDiscreteAction::Up
569        );
570        assert_eq!(
571            SimpleDiscreteAction::from_index(3),
572            SimpleDiscreteAction::Down
573        );
574    }
575
576    #[test]
577    #[should_panic(expected = "Index out of bounds")]
578    fn test_discrete_action_from_index_out_of_bounds() {
579        SimpleDiscreteAction::from_index(4);
580    }
581
582    #[test]
583    #[should_panic(expected = "Index out of bounds")]
584    fn test_discrete_action_from_index_negative_like() {
585        // Note: usize can't be negative, but we test the boundary
586        SimpleDiscreteAction::from_index(100);
587    }
588
589    #[test]
590    fn test_discrete_action_to_index() {
591        assert_eq!(SimpleDiscreteAction::Left.to_index(), 0);
592        assert_eq!(SimpleDiscreteAction::Right.to_index(), 1);
593        assert_eq!(SimpleDiscreteAction::Up.to_index(), 2);
594        assert_eq!(SimpleDiscreteAction::Down.to_index(), 3);
595    }
596
597    #[test]
598    fn test_discrete_action_roundtrip() {
599        // Test bidirectional conversion
600        for i in 0..SimpleDiscreteAction::ACTION_COUNT {
601            let action = SimpleDiscreteAction::from_index(i);
602            assert_eq!(action.to_index(), i);
603        }
604    }
605
606    #[test]
607    fn test_discrete_action_enumerate() {
608        let actions = SimpleDiscreteAction::enumerate();
609        assert_eq!(actions.len(), 4);
610        assert_eq!(
611            actions,
612            vec![
613                SimpleDiscreteAction::Left,
614                SimpleDiscreteAction::Right,
615                SimpleDiscreteAction::Up,
616                SimpleDiscreteAction::Down
617            ]
618        );
619    }
620
621    #[test]
622    fn test_discrete_action_random() {
623        for _ in 0..100 {
624            let action = SimpleDiscreteAction::random();
625            let index = action.to_index();
626            assert!(index < SimpleDiscreteAction::ACTION_COUNT);
627        }
628    }
629
630    #[test]
631    fn test_discrete_action_is_valid() {
632        for i in 0..SimpleDiscreteAction::ACTION_COUNT {
633            let action = SimpleDiscreteAction::from_index(i);
634            assert!(action.is_valid());
635        }
636    }
637
638    // ========================================================================
639    // MultiDiscreteAction Tests
640    // ========================================================================
641
642    #[test]
643    fn test_multidiscrete_action_shape() {
644        assert_eq!(MultiActionTest::shape(), [4, 3]);
645        assert_eq!(MultiActionTest::RANK, 2);
646    }
647
648    #[test]
649    fn test_multidiscrete_action_from_indices() {
650        let action = MultiActionTest::from_indices([0, 0]);
651        assert_eq!(action.direction, 0);
652        assert_eq!(action.intensity, 0);
653
654        let action = MultiActionTest::from_indices([3, 2]);
655        assert_eq!(action.direction, 3);
656        assert_eq!(action.intensity, 2);
657    }
658
659    #[test]
660    #[should_panic(expected = "Direction index out of bounds")]
661    fn test_multidiscrete_action_from_indices_direction_out_of_bounds() {
662        MultiActionTest::from_indices([4, 0]);
663    }
664
665    #[test]
666    #[should_panic(expected = "Intensity index out of bounds")]
667    fn test_multidiscrete_action_from_indices_intensity_out_of_bounds() {
668        MultiActionTest::from_indices([0, 3]);
669    }
670
671    #[test]
672    fn test_multidiscrete_action_to_indices() {
673        let action = MultiActionTest::from_indices([2, 1]);
674        assert_eq!(action.to_indices(), [2, 1]);
675    }
676
677    #[test]
678    fn test_multidiscrete_action_roundtrip() {
679        for d in 0..4 {
680            for i in 0..3 {
681                let action = MultiActionTest::from_indices([d, i]);
682                assert_eq!(action.to_indices(), [d, i]);
683            }
684        }
685    }
686
687    #[test]
688    fn test_multidiscrete_action_enumerate() {
689        let actions = MultiActionTest::enumerate();
690        // 4 directions × 3 intensities = 12 total actions
691        assert_eq!(actions.len(), 12);
692
693        // Verify all combinations are present
694        for (idx, action) in actions.iter().enumerate() {
695            let expected_d = idx / 3;
696            let expected_i = idx % 3;
697            assert_eq!(action.direction, expected_d);
698            assert_eq!(action.intensity, expected_i);
699        }
700    }
701
702    #[test]
703    fn test_multidiscrete_action_enumerate_large_space() {
704        // Test with 3D space: [5, 5, 5] = 125 total actions
705        #[derive(Debug, Clone)]
706        struct LargeMultiAction([usize; 3]);
707
708        impl Action<3> for LargeMultiAction {
709            fn shape() -> [usize; 3] {
710                [5, 5, 5]
711            }
712
713            fn is_valid(&self) -> bool {
714                self.0.iter().enumerate().all(|(i, &v)| v < [5, 5, 5][i])
715            }
716        }
717
718        impl MultiDiscreteAction<3> for LargeMultiAction {
719            fn from_indices(indices: [usize; 3]) -> Self {
720                for (i, &idx) in indices.iter().enumerate() {
721                    assert!(idx < 5, "Index {} out of bounds", i);
722                }
723                LargeMultiAction(indices)
724            }
725
726            fn to_indices(&self) -> [usize; 3] {
727                self.0
728            }
729        }
730
731        let actions = LargeMultiAction::enumerate();
732        assert_eq!(actions.len(), 125);
733    }
734
735    #[test]
736    fn test_multidiscrete_action_random() {
737        for _ in 0..100 {
738            let action = MultiActionTest::random();
739            assert!(action.is_valid());
740            let indices = action.to_indices();
741            assert!(indices[0] < 4);
742            assert!(indices[1] < 3);
743        }
744    }
745
746    #[test]
747    fn test_multidiscrete_action_is_valid() {
748        // Valid actions
749        assert!(MultiActionTest::from_indices([0, 0]).is_valid());
750        assert!(MultiActionTest::from_indices([3, 2]).is_valid());
751
752        // Invalid actions created directly
753        let invalid = MultiActionTest {
754            direction: 5,
755            intensity: 0,
756        };
757        assert!(!invalid.is_valid());
758
759        let invalid = MultiActionTest {
760            direction: 0,
761            intensity: 5,
762        };
763        assert!(!invalid.is_valid());
764    }
765
766    // ========================================================================
767    // ContinuousAction Tests
768    // ========================================================================
769
770    #[test]
771    fn test_continuous_action_shape() {
772        assert_eq!(ContinuousActionTest::shape(), [1, 1, 1]);
773        assert_eq!(ContinuousActionTest::RANK, 3);
774    }
775
776    #[test]
777    fn test_continuous_action_as_slice() {
778        let action = ContinuousActionTest {
779            values: [0.5, -0.3, 1.0],
780        };
781        let slice = action.as_slice();
782        assert_eq!(slice.len(), 3);
783        assert_eq!(slice, &[0.5, -0.3, 1.0]);
784    }
785
786    #[test]
787    fn test_continuous_action_from_slice() {
788        let values = [0.1, 0.2, 0.3];
789        let action = ContinuousActionTest::from_slice(&values);
790        assert_eq!(action.values, values);
791    }
792
793    #[test]
794    #[should_panic(expected = "Expected 3 values")]
795    fn test_continuous_action_from_slice_wrong_size() {
796        let values = [0.1, 0.2];
797        ContinuousActionTest::from_slice(&values);
798    }
799
800    #[test]
801    fn test_continuous_action_roundtrip() {
802        let original = [0.5, -0.3, 0.9];
803        let action = ContinuousActionTest::from_slice(&original);
804        assert_eq!(action.as_slice(), &original);
805    }
806
807    #[test]
808    fn test_continuous_action_clip_within_bounds() {
809        let action = ContinuousActionTest {
810            values: [0.0, 0.5, -0.5],
811        };
812        let clipped = action.clip(-1.0, 1.0);
813        assert_eq!(clipped.values, [0.0, 0.5, -0.5]);
814    }
815
816    #[test]
817    fn test_continuous_action_clip_exceeds_max() {
818        let action = ContinuousActionTest {
819            values: [2.0, 1.5, 3.0],
820        };
821        let clipped = action.clip(-1.0, 1.0);
822        assert_eq!(clipped.values, [1.0, 1.0, 1.0]);
823    }
824
825    #[test]
826    fn test_continuous_action_clip_exceeds_min() {
827        let action = ContinuousActionTest {
828            values: [-2.0, -1.5, -3.0],
829        };
830        let clipped = action.clip(-1.0, 1.0);
831        assert_eq!(clipped.values, [-1.0, -1.0, -1.0]);
832    }
833
834    #[test]
835    fn test_continuous_action_clip_mixed() {
836        let action = ContinuousActionTest {
837            values: [2.0, 0.5, -2.0],
838        };
839        let clipped = action.clip(-1.0, 1.0);
840        assert_eq!(clipped.values, [1.0, 0.5, -1.0]);
841    }
842
843    #[test]
844    fn test_continuous_action_random() {
845        for _ in 0..100 {
846            let action = ContinuousActionTest::random();
847            assert!(action.is_valid());
848            for &value in action.as_slice() {
849                assert!((-1.0..=1.0).contains(&value));
850                assert!(value.is_finite());
851            }
852        }
853    }
854
855    #[test]
856    fn test_continuous_action_is_valid_finite() {
857        let action = ContinuousActionTest {
858            values: [0.5, -0.3, 1.0],
859        };
860        assert!(action.is_valid());
861    }
862
863    #[test]
864    fn test_continuous_action_is_invalid_nan() {
865        let action = ContinuousActionTest {
866            values: [f32::NAN, 0.5, 1.0],
867        };
868        assert!(!action.is_valid());
869    }
870
871    #[test]
872    fn test_continuous_action_is_invalid_inf() {
873        let action = ContinuousActionTest {
874            values: [f32::INFINITY, 0.5, 1.0],
875        };
876        assert!(!action.is_valid());
877
878        let action = ContinuousActionTest {
879            values: [f32::NEG_INFINITY, 0.5, 1.0],
880        };
881        assert!(!action.is_valid());
882    }
883
884    // ========================================================================
885    // InvalidActionError Tests
886    // ========================================================================
887
888    #[test]
889    fn test_invalid_action_error_creation() {
890        let error = InvalidActionError {
891            message: String::from("Index out of bounds"),
892        };
893        assert_eq!(error.message, "Index out of bounds");
894    }
895
896    #[test]
897    fn test_invalid_action_error_display() {
898        let error = InvalidActionError {
899            message: String::from("Invalid value"),
900        };
901        let displayed = format!("{}", error);
902        assert_eq!(displayed, "Invalid action: Invalid value");
903    }
904
905    #[test]
906    fn test_invalid_action_error_debug() {
907        let error = InvalidActionError {
908            message: String::from("Test error"),
909        };
910        let debug_str = format!("{:?}", error);
911        assert!(debug_str.contains("Test error"));
912    }
913
914    #[test]
915    fn test_invalid_action_error_clone() {
916        let error = InvalidActionError {
917            message: String::from("Original"),
918        };
919        let cloned = error.clone();
920        assert_eq!(error, cloned);
921    }
922
923    #[test]
924    fn test_invalid_action_error_equality() {
925        let error1 = InvalidActionError {
926            message: String::from("Same error"),
927        };
928        let error2 = InvalidActionError {
929            message: String::from("Same error"),
930        };
931        let error3 = InvalidActionError {
932            message: String::from("Different error"),
933        };
934
935        assert_eq!(error1, error2);
936        assert_ne!(error1, error3);
937    }
938
939    #[test]
940    fn test_invalid_action_error_is_error() {
941        let error: Box<dyn Error> = Box::new(InvalidActionError {
942            message: String::from("Test"),
943        });
944        // Should be able to use as std::error::Error trait object
945        let _msg = error.to_string();
946    }
947
948    // ========================================================================
949    // Integration Tests
950    // ========================================================================
951
952    #[test]
953    fn test_discrete_action_clone_and_debug() {
954        let action = SimpleDiscreteAction::Left;
955        let cloned = action;
956        assert_eq!(action, cloned);
957
958        let debug_str = format!("{:?}", action);
959        assert!(debug_str.contains("Left"));
960    }
961
962    #[test]
963    fn test_multidiscrete_action_clone_and_debug() {
964        let action = MultiActionTest::from_indices([1, 2]);
965        let cloned = action;
966        assert_eq!(action, cloned);
967
968        let debug_str = format!("{:?}", action);
969        assert!(debug_str.contains("direction"));
970    }
971
972    #[test]
973    fn test_continuous_action_clone_and_debug() {
974        let action = ContinuousActionTest {
975            values: [0.1, 0.2, 0.3],
976        };
977        let cloned = action.clone();
978        assert_eq!(action.as_slice(), cloned.as_slice());
979
980        let debug_str = format!("{:?}", action);
981        assert!(debug_str.contains("values"));
982    }
983
984    #[test]
985    fn test_continuous_action_clip_chaining() {
986        let action = ContinuousActionTest {
987            values: [2.0, -3.0, 0.5],
988        };
989        let clipped = action.clip(-2.0, 2.0).clip(-1.0, 1.0);
990        assert_eq!(clipped.values, [1.0, -1.0, 0.5]);
991    }
992
993    #[test]
994    fn test_large_discrete_action_space() {
995        // Test with 256 actions
996        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
997        struct LargeDiscreteAction(u8);
998
999        impl Action<1> for LargeDiscreteAction {
1000            fn shape() -> [usize; 1] {
1001                [256]
1002            }
1003
1004            fn is_valid(&self) -> bool {
1005                true
1006            }
1007        }
1008
1009        impl DiscreteAction<1> for LargeDiscreteAction {
1010            const ACTION_COUNT: usize = 256;
1011
1012            fn from_index(index: usize) -> Self {
1013                assert!(index < 256);
1014                LargeDiscreteAction(index as u8)
1015            }
1016
1017            fn to_index(&self) -> usize {
1018                self.0 as usize
1019            }
1020        }
1021
1022        // Enumerate should produce all 256 actions
1023        let actions = LargeDiscreteAction::enumerate();
1024        assert_eq!(actions.len(), 256);
1025
1026        // Verify roundtrip for a few samples
1027        for i in [0, 1, 127, 255] {
1028            let action = LargeDiscreteAction::from_index(i);
1029            assert_eq!(action.to_index(), i);
1030        }
1031    }
1032
1033    #[test]
1034    fn test_continuous_action_with_zero_values() {
1035        let action = ContinuousActionTest {
1036            values: [0.0, 0.0, 0.0],
1037        };
1038        assert!(action.is_valid());
1039        assert_eq!(action.as_slice(), &[0.0, 0.0, 0.0]);
1040
1041        let clipped = action.clip(-1.0, 1.0);
1042        assert_eq!(clipped.values, [0.0, 0.0, 0.0]);
1043    }
1044
1045    #[test]
1046    fn test_continuous_action_extreme_clip_bounds() {
1047        let action = ContinuousActionTest {
1048            values: [100.0, -100.0, 0.0],
1049        };
1050
1051        let clipped = action.clip(f32::NEG_INFINITY, f32::INFINITY);
1052        assert_eq!(clipped.values, [100.0, -100.0, 0.0]);
1053
1054        let clipped = action.clip(0.0, 0.0);
1055        assert_eq!(clipped.values, [0.0, 0.0, 0.0]);
1056    }
1057
1058    // ========================================================================
1059    // BoundedAction Tests
1060    // ========================================================================
1061
1062    #[test]
1063    fn test_bounded_action_low_strictly_below_high() {
1064        let low = ContinuousActionTest::low();
1065        let high = ContinuousActionTest::high();
1066        for i in 0..3 {
1067            assert!(low[i] < high[i], "bound {i}: low >= high");
1068        }
1069    }
1070
1071    #[test]
1072    fn test_bounded_action_clip_is_noop_inside_bounds() {
1073        // Construct an action at the low/high bounds: clip(low, high) must
1074        // return the same components.
1075        let low = ContinuousActionTest::low();
1076        let high = ContinuousActionTest::high();
1077        let at_low = ContinuousActionTest::from_slice(&low);
1078        let at_high = ContinuousActionTest::from_slice(&high);
1079        assert_eq!(at_low.clip(low[0], high[0]).as_slice(), &low);
1080        assert_eq!(at_high.clip(low[0], high[0]).as_slice(), &high);
1081    }
1082}