rlevo_core/action.rs
1//! Action space abstractions for reinforcement learning environments.
2//!
3//! This module provides a flexible type system for representing agent actions in RL environments.
4//! Actions can be discrete (finite choices), multi-discrete (multiple independent discrete choices),
5//! or continuous (real-valued vectors).
6//!
7//! # Design Philosophy
8//!
9//! The action traits follow a layered design:
10//! - [`Action`](crate::base::Action): Base trait providing validation and cloning semantics
11//! - [`DiscreteAction`], [`MultiDiscreteAction`], [`ContinuousAction`]: Type-specific extensions
12//!
13//! # Action Types
14//!
15//! ## Discrete Actions
16//!
17//! Discrete actions represent a finite set of mutually exclusive choices (e.g., "move left",
18//! "move right", "jump"). They are indexed from `0` to `ACTION_COUNT - 1`.
19//!
20//! ## Multi-Discrete Actions
21//!
22//! Multi-discrete actions consist of multiple independent discrete choices, such as selecting
23//! both a direction and an attack type simultaneously.
24//!
25//! ## Continuous Actions
26//!
27//! Continuous actions are real-valued vectors, typically used for motor control or
28//! parametrized actions (e.g., steering angle, throttle).
29//!
30//! # Test Suite
31//!
32//! This module includes a comprehensive test suite covering:
33//!
34//! ## DiscreteAction Tests (10 tests)
35//! - `test_discrete_action_shape`: Verifies shape and dimension constants
36//! - `test_discrete_action_count`: Checks action count constant
37//! - `test_discrete_action_from_index`: Tests index-to-action conversion
38//! - `test_discrete_action_from_index_out_of_bounds`: Validates panic on invalid indices
39//! - `test_discrete_action_to_index`: Tests action-to-index conversion
40//! - `test_discrete_action_roundtrip`: Ensures bidirectional conversion consistency
41//! - `test_discrete_action_enumerate`: Verifies all actions are enumerated correctly
42//! - `test_discrete_action_random`: Tests random action generation
43//! - `test_discrete_action_is_valid`: Validates the `is_valid()` predicate
44//! - `test_discrete_action_clone_and_debug`: Tests Debug and Clone trait implementations
45//!
46//! ## MultiDiscreteAction Tests (11 tests)
47//! - `test_multidiscrete_action_shape`: Verifies multi-dimensional shape
48//! - `test_multidiscrete_action_from_indices`: Tests multi-index conversion
49//! - `test_multidiscrete_action_from_indices_*_out_of_bounds`: Validates panic on invalid indices
50//! - `test_multidiscrete_action_to_indices`: Tests reverse conversion
51//! - `test_multidiscrete_action_roundtrip`: Ensures bidirectional consistency
52//! - `test_multidiscrete_action_enumerate`: Verifies all action combinations are enumerated
53//! - `test_multidiscrete_action_enumerate_large_space`: Tests scalability with large action spaces
54//! - `test_multidiscrete_action_random`: Tests random sampling
55//! - `test_multidiscrete_action_is_valid`: Validates constraints
56//! - `test_multidiscrete_action_clone_and_debug`: Tests trait implementations
57//!
58//! ## ContinuousAction Tests (15 tests)
59//! - `test_continuous_action_shape`: Verifies shape specification
60//! - `test_continuous_action_as_slice`: Tests slice view access
61//! - `test_continuous_action_from_slice`: Tests construction from slice
62//! - `test_continuous_action_from_slice_wrong_size`: Validates dimension checking
63//! - `test_continuous_action_roundtrip`: Ensures slice conversion consistency
64//! - `test_continuous_action_clip_*`: Tests clipping behavior (within, exceeds max/min, mixed, extreme)
65//! - `test_continuous_action_clip_chaining`: Verifies method chaining
66//! - `test_continuous_action_random`: Tests random action generation
67//! - `test_continuous_action_is_valid_*`: Tests validity checking (finite, NaN, Inf)
68//! - `test_continuous_action_with_zero_values`: Tests edge case with zero values
69//! - `test_continuous_action_clone_and_debug`: Tests trait implementations
70//!
71//! ## InvalidActionError Tests (6 tests)
72//! - `test_invalid_action_error_creation`: Tests error instantiation
73//! - `test_invalid_action_error_display`: Tests Display trait formatting
74//! - `test_invalid_action_error_debug`: Tests Debug trait formatting
75//! - `test_invalid_action_error_clone`: Tests Clone implementation
76//! - `test_invalid_action_error_equality`: Tests PartialEq implementation
77//! - `test_invalid_action_error_is_error`: Tests std::error::Error trait compatibility
78//!
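//! ## Integration Tests (4 tests)
//! - `test_large_discrete_action_space`: Tests with 256 actions
//! - `test_continuous_action_extreme_clip_bounds`: Tests edge cases in clipping
//! - Various clone/debug/trait tests across different action types
//!
//! ## BoundedAction Tests (2 tests)
//! - `test_bounded_action_low_strictly_below_high`: Checks `low()[i] < high()[i]` for every component
//! - `test_bounded_action_clip_is_noop_inside_bounds`: Verifies clipping leaves in-bounds actions unchanged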
83
84use crate::base::Action;
85use std::error::Error;
86use std::fmt::Debug;
87
88/// Trait for discrete actions with a finite, enumerable set of choices.
89///
90/// Discrete actions represent mutually exclusive options that can be indexed by
91/// integers from `0` to `ACTION_COUNT - 1`. Common examples include:
92/// - Game controls (move left/right/jump)
93/// - Categorical decisions (buy/hold/sell)
94/// - Navigation directions (north/south/east/west)
95///
96/// # Type Safety
97///
98/// Implementations should ensure bidirectional conversion between indices and actions:
99/// ```text
100/// ∀ i ∈ [0, ACTION_COUNT): i == from_index(i).to_index()
101/// ∀ a: Action: a == from_index(a.to_index())
102/// ```
103///
104/// # Performance
105///
106/// For performance-critical code, prefer `from_index()` over `random()` when you
107/// already have an index (e.g., from a neural network's argmax). The `random()`
108/// method allocates a thread-local RNG on each call.
109pub trait DiscreteAction<const R: usize>: Action<R> {
110 /// The total number of distinct actions in this action space.
111 ///
112 /// This constant defines the cardinality of the action space. It must be
113 /// greater than zero and remain constant for the lifetime of the program.
114 const ACTION_COUNT: usize;
115
116 /// Constructs an action from its zero-based index.
117 ///
118 /// This method must be the inverse of [`to_index()`](DiscreteAction::to_index).
119 ///
120 /// # Panics
121 ///
122 /// Implementations should panic if `index >= ACTION_COUNT`, as this indicates
123 /// a programming error (out-of-bounds access).
124 fn from_index(index: usize) -> Self;
125
126 /// Converts this action to its zero-based index.
127 ///
128 /// The returned index must be in the range `[0, ACTION_COUNT)` and must be
129 /// the inverse of [`from_index()`](DiscreteAction::from_index).
130 fn to_index(&self) -> usize;
131
132 /// Samples a uniformly random action from this action space.
133 ///
134 /// This is a convenience method for exploration in reinforcement learning.
135 /// It uses thread-local RNG state, so it's safe to call from multiple threads
136 /// but will produce different sequences per thread.
137 ///
138 /// # Performance
139 ///
140 /// If you already have an index from another source (e.g., a neural network
141 /// output), use `from_index()` directly instead of this method.
142 fn random() -> Self
143 where
144 Self: Sized,
145 {
146 use rand::RngExt;
147 let mut rng = rand::rng();
148 let index = rng.random_range(0..Self::ACTION_COUNT);
149 Self::from_index(index)
150 }
151
152 /// Returns a vector containing all possible actions in index order.
153 ///
154 /// This is useful for tabular RL methods (e.g., Q-learning) that need to
155 /// iterate over the entire action space. The returned vector has length
156 /// `ACTION_COUNT` with actions ordered by their index.
157 ///
158 /// # Performance
159 ///
160 /// This allocates a vector of size `ACTION_COUNT`. For large action spaces,
161 /// consider using an iterator pattern instead (not currently provided).
162 fn enumerate() -> Vec<Self>
163 where
164 Self: Sized,
165 {
166 (0..Self::ACTION_COUNT).map(Self::from_index).collect()
167 }
168}
169
170/// Trait for actions with multiple independent discrete dimensions.
171///
172/// Multi-discrete actions represent scenarios where an agent must make several
173/// independent categorical choices simultaneously. Each dimension can have a
174/// different number of options. This is common in:
175/// - Strategy games (select unit + select action + select target)
176/// - Multi-agent coordination (each agent picks a discrete action)
177/// - Parameterized actions (choose action type + intensity level)
178///
179/// # Dimensionality
180///
181/// The const generic `R` (the rank) specifies the number of axes. Each axis
182/// can have a different cardinality, defined by [`shape()`](Action::shape).
183///
184/// The total number of action combinations is the product of all dimension sizes:
185/// ```text
186/// total_actions = ∏ shape()[i]
187/// ```
188///
189/// # Caution: Combinatorial Explosion
190///
191/// Be careful with [`enumerate()`](MultiDiscreteAction::enumerate) on large action spaces.
192/// A 3D action space with dimensions [10, 10, 10] produces 1000 actions, but
193/// [100, 100, 100] produces 1,000,000!
194pub trait MultiDiscreteAction<const R: usize>: Action<R> {
195 /// Constructs an action from multi-dimensional indices.
196 ///
197 /// Each index must be in the range `[0, shape()[i])` for axis `i`.
198 ///
199 /// # Panics
200 ///
201 /// Implementations should panic if any index is out of bounds for its axis.
202 fn from_indices(indices: [usize; R]) -> Self;
203
204 /// Converts this action to its multi-dimensional index representation.
205 ///
206 /// The returned array must satisfy: each element `i` is in `[0, shape()[i])`.
207 /// This method must be the inverse of [`from_indices()`](MultiDiscreteAction::from_indices).
208 fn to_indices(&self) -> [usize; R];
209
210 /// Samples a uniformly random action from this multi-discrete action space.
211 ///
212 /// Each dimension is sampled independently and uniformly from its valid range.
213 /// This is useful for exploration in reinforcement learning.
214 ///
215 /// # Examples
216 ///
217 /// ```rust,ignore
218 /// let random_action = StrategyAction::random();
219 /// assert!(random_action.is_valid());
220 /// ```
221 fn random() -> Self
222 where
223 Self: Sized,
224 {
225 use rand::RngExt;
226 let mut rng = rand::rng();
227 let space = Self::shape();
228 let indices = space.map(|dim| rng.random_range(0..dim));
229 Self::from_indices(indices)
230 }
231
232 /// Returns all possible action combinations in this space.
233 ///
234 /// # Warning
235 ///
236 /// This method generates **all** combinations across all dimensions. The number
237 /// of actions grows multiplicatively with the product of dimension sizes:
238 ///
239 /// - `[10, 10, 10]` → 1,000 actions (manageable)
240 /// - `[50, 50, 50]` → 125,000 actions (marginal)
241 /// - `[100, 100, 100]` → 1,000,000 actions (likely too large)
242 ///
243 /// Use this method only when you need to iterate over the entire action space
244 /// (e.g., for exact policy evaluation in tabular methods).
245 ///
246 /// # Panics
247 ///
248 /// May panic or run out of memory if the action space is too large.
249 fn enumerate() -> Vec<Self>
250 where
251 Self: Sized,
252 {
253 let space = Self::shape();
254 let total: usize = space.iter().product();
255 let mut actions = Vec::with_capacity(total);
256
257 fn generate<const R: usize, T: MultiDiscreteAction<R>>(
258 space: &[usize; R],
259 current: &mut [usize; R],
260 axis: usize,
261 actions: &mut Vec<T>,
262 ) {
263 if axis == R {
264 actions.push(T::from_indices(*current));
265 return;
266 }
267 for i in 0..space[axis] {
268 current[axis] = i;
269 generate(space, current, axis + 1, actions);
270 }
271 }
272
273 let mut current = [0; R];
274 generate(&space, &mut current, 0, &mut actions);
275 actions
276 }
277}
278
279/// Trait for continuous-valued actions represented as real-valued vectors.
280///
281/// Continuous actions are used when the agent's output is a vector of real numbers
282/// rather than discrete choices. Common applications include:
283/// - Robot motor control (joint angles, torques)
284/// - Vehicle control (steering, throttle, brake)
285/// - Continuous parameter tuning (learning rates, temperatures)
286///
287/// # Value Range
288///
289/// Continuous actions typically have bounded ranges (e.g., `[-1, 1]` or `[0, 1]`).
290/// The [`clip()`](ContinuousAction::clip) method enforces these bounds.
291///
292/// # Neural Network Integration
293///
294/// Continuous actions are typically produced by neural networks with `tanh` or
295/// `sigmoid` activation functions. Use [`clip()`](ContinuousAction::clip) to
296/// ensure outputs stay within valid ranges.
297pub trait ContinuousAction<const R: usize>: Action<R> {
298 /// Returns a slice view of this action's component values.
299 ///
300 /// The returned slice must have exactly `RANK` elements. This is used for
301 /// efficient serialization and tensor conversion.
302 fn as_slice(&self) -> &[f32];
303
304 /// Returns a new action with all components clipped to `[min, max]`.
305 ///
306 /// This is essential for ensuring neural network outputs (which may exceed
307 /// valid ranges due to numerical errors or exploration noise) stay within
308 /// acceptable bounds.
309 ///
310 /// # Common Use Cases
311 ///
312 /// - Enforcing action space bounds after neural network output
313 /// - Adding exploration noise while maintaining validity
314 /// - Recovering from numerical instability
315 fn clip(&self, min: f32, max: f32) -> Self;
316
317 /// Samples a random action with components uniformly distributed in `[-1.0, 1.0)`.
318 ///
319 /// The default implementation generates uniform random values in the half-open
320 /// range `[-1.0, 1.0)` using `rand::random_range`. Override this method if you
321 /// need different sampling behavior (e.g., Gaussian noise, domain-specific
322 /// distributions, or bounds derived from [`BoundedAction`]).
323 fn random() -> Self
324 where
325 Self: Sized,
326 {
327 use rand::RngExt;
328 let mut rng = rand::rng();
329 // Default implementation - override for custom behavior
330 let values: Vec<f32> = (0..Self::RANK)
331 .map(|_| rng.random_range(-1.0..1.0))
332 .collect();
333 Self::from_slice(&values)
334 }
335
336 /// Constructs an action from a slice of component values.
337 ///
338 /// The input slice must have exactly `RANK` elements. This is the inverse
339 /// operation of [`as_slice()`](ContinuousAction::as_slice).
340 ///
341 /// # Panics
342 ///
343 /// Implementations should panic if `values.len() != RANK`.
344 fn from_slice(values: &[f32]) -> Self;
345}
346
/// A [`ContinuousAction`] that also reports per-component `[low, high]` bounds.
///
/// DDPG and other continuous-control algorithms need the per-component action
/// bounds to scale/shift neural-network outputs and to sample uniform warm-up
/// actions. The bounds are exposed through associated functions rather than
/// associated constants so implementors can still derive them from a runtime
/// env config (e.g. a `max_torque` field) while presenting a uniform API.
///
/// Both functions take no receiver, so the bounds describe the action *type*
/// as a whole, not an individual action value.
///
/// # Invariants
///
/// - `low()[i] < high()[i]` for every component `i`.
/// - [`ContinuousAction::clip`] must be a no-op on an action whose components
///   already lie in `[low, high]`.
///
/// These invariants are a contract on implementors; the trait itself does not
/// check them.
pub trait BoundedAction<const R: usize>: ContinuousAction<R> {
    /// Returns the per-component lower bounds of this action space.
    ///
    /// Element `i` is the smallest permitted value of component `i`. The
    /// returned array must satisfy `low()[i] < high()[i]` for every component
    /// `i`. See the trait-level invariants for the full contract.
    fn low() -> [f32; R];

    /// Returns the per-component upper bounds of this action space.
    ///
    /// Element `i` is the largest permitted value of component `i`. The
    /// returned array must satisfy `low()[i] < high()[i]` for every component
    /// `i`. See the trait-level invariants for the full contract.
    fn high() -> [f32; R];
}
373
/// Error indicating an action violated its type's constraints.
///
/// Returned when an action fails validation or when invalid conversions are
/// attempted (e.g., out-of-bounds indices, non-finite float values).
///
/// The type is `Eq` and `Hash` in addition to `PartialEq`, so errors can be
/// compared inside `Result`s with `assert_eq!`, deduplicated in a `HashSet`,
/// or used as map keys.
///
/// # Examples
///
/// ```
/// use rlevo_core::action::InvalidActionError;
///
/// let err = InvalidActionError { message: "index 5 out of bounds for ACTION_COUNT=4".into() };
/// assert!(err.to_string().contains("index 5"));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct InvalidActionError {
    /// Human-readable description of the constraint that was violated.
    pub message: String,
}

/// Formats the error as `Invalid action: <message>`.
impl std::fmt::Display for InvalidActionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Invalid action: {}", self.message)
    }
}

// No underlying cause is tracked, so the default `source()` (`None`) is kept.
impl Error for InvalidActionError {}
400
401// ============================================================================
402// Tests
403// ============================================================================
404
#[cfg(test)]
mod tests {
    //! Unit tests for the action traits and [`InvalidActionError`].
    //!
    //! The fixtures below are small hand-written implementations of each
    //! trait; the tests exercise both the required methods of those fixtures
    //! and the default methods provided by the traits.

    use super::*;

    // ========================================================================
    // Test Implementations
    // ========================================================================

    /// Simple discrete action with 4 possible choices.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum SimpleDiscreteAction {
        Left,
        Right,
        Up,
        Down,
    }

    impl Action<1> for SimpleDiscreteAction {
        fn shape() -> [usize; 1] {
            [4]
        }

        fn is_valid(&self) -> bool {
            true // All variants are always valid
        }
    }

    impl DiscreteAction<1> for SimpleDiscreteAction {
        const ACTION_COUNT: usize = 4;

        fn from_index(index: usize) -> Self {
            match index {
                0 => SimpleDiscreteAction::Left,
                1 => SimpleDiscreteAction::Right,
                2 => SimpleDiscreteAction::Up,
                3 => SimpleDiscreteAction::Down,
                _ => panic!("Index out of bounds: {index}"),
            }
        }

        fn to_index(&self) -> usize {
            match self {
                SimpleDiscreteAction::Left => 0,
                SimpleDiscreteAction::Right => 1,
                SimpleDiscreteAction::Up => 2,
                SimpleDiscreteAction::Down => 3,
            }
        }
    }

    /// Multi-discrete action with 2 dimensions.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct MultiActionTest {
        direction: usize, // 0-3
        intensity: usize, // 0-2
    }

    impl Action<2> for MultiActionTest {
        fn shape() -> [usize; 2] {
            [4, 3]
        }

        fn is_valid(&self) -> bool {
            self.direction < 4 && self.intensity < 3
        }
    }

    impl MultiDiscreteAction<2> for MultiActionTest {
        fn from_indices(indices: [usize; 2]) -> Self {
            let [direction, intensity] = indices;
            assert!(direction < 4, "Direction index out of bounds: {direction}");
            assert!(intensity < 3, "Intensity index out of bounds: {intensity}");
            MultiActionTest {
                direction,
                intensity,
            }
        }

        fn to_indices(&self) -> [usize; 2] {
            [self.direction, self.intensity]
        }
    }

    /// Continuous action with 3 dimensions (e.g., 3D velocity).
    #[derive(Debug, Clone)]
    struct ContinuousActionTest {
        values: [f32; 3],
    }

    impl Action<3> for ContinuousActionTest {
        fn shape() -> [usize; 3] {
            [1, 1, 1] // Continuous dimensions typically have size 1
        }

        fn is_valid(&self) -> bool {
            self.values.iter().all(|v| v.is_finite())
        }
    }

    impl ContinuousAction<3> for ContinuousActionTest {
        fn as_slice(&self) -> &[f32] {
            &self.values
        }

        fn clip(&self, min: f32, max: f32) -> Self {
            // Map the fixed-size array in place; no heap allocation needed.
            ContinuousActionTest {
                values: self.values.map(|v| v.max(min).min(max)),
            }
        }

        fn from_slice(values: &[f32]) -> Self {
            assert_eq!(values.len(), 3, "Expected 3 values, got {}", values.len());
            ContinuousActionTest {
                values: [values[0], values[1], values[2]],
            }
        }
    }

    impl BoundedAction<3> for ContinuousActionTest {
        fn low() -> [f32; 3] {
            [-1.0, -1.0, -1.0]
        }

        fn high() -> [f32; 3] {
            [1.0, 1.0, 1.0]
        }
    }

    // ========================================================================
    // DiscreteAction Tests
    // ========================================================================

    #[test]
    fn test_discrete_action_shape() {
        assert_eq!(SimpleDiscreteAction::shape(), [4]);
        assert_eq!(SimpleDiscreteAction::RANK, 1);
    }

    #[test]
    fn test_discrete_action_count() {
        assert_eq!(SimpleDiscreteAction::ACTION_COUNT, 4);
    }

    #[test]
    fn test_discrete_action_from_index() {
        assert_eq!(
            SimpleDiscreteAction::from_index(0),
            SimpleDiscreteAction::Left
        );
        assert_eq!(
            SimpleDiscreteAction::from_index(1),
            SimpleDiscreteAction::Right
        );
        assert_eq!(
            SimpleDiscreteAction::from_index(2),
            SimpleDiscreteAction::Up
        );
        assert_eq!(
            SimpleDiscreteAction::from_index(3),
            SimpleDiscreteAction::Down
        );
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_discrete_action_from_index_out_of_bounds() {
        // First index past the end of the space.
        SimpleDiscreteAction::from_index(4);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_discrete_action_from_index_negative_like() {
        // `usize` cannot be negative; `usize::MAX` is the value a wrapped `-1`
        // would produce, so it stands in for a "negative" index.
        SimpleDiscreteAction::from_index(usize::MAX);
    }

    #[test]
    fn test_discrete_action_to_index() {
        assert_eq!(SimpleDiscreteAction::Left.to_index(), 0);
        assert_eq!(SimpleDiscreteAction::Right.to_index(), 1);
        assert_eq!(SimpleDiscreteAction::Up.to_index(), 2);
        assert_eq!(SimpleDiscreteAction::Down.to_index(), 3);
    }

    #[test]
    fn test_discrete_action_roundtrip() {
        // index -> action -> index
        for i in 0..SimpleDiscreteAction::ACTION_COUNT {
            let action = SimpleDiscreteAction::from_index(i);
            assert_eq!(action.to_index(), i);
        }

        // action -> index -> action
        for action in SimpleDiscreteAction::enumerate() {
            assert_eq!(SimpleDiscreteAction::from_index(action.to_index()), action);
        }
    }

    #[test]
    fn test_discrete_action_enumerate() {
        let actions = SimpleDiscreteAction::enumerate();
        assert_eq!(actions.len(), 4);
        assert_eq!(
            actions,
            vec![
                SimpleDiscreteAction::Left,
                SimpleDiscreteAction::Right,
                SimpleDiscreteAction::Up,
                SimpleDiscreteAction::Down
            ]
        );
    }

    #[test]
    fn test_discrete_action_random() {
        for _ in 0..100 {
            let action = SimpleDiscreteAction::random();
            let index = action.to_index();
            assert!(index < SimpleDiscreteAction::ACTION_COUNT);
        }
    }

    #[test]
    fn test_discrete_action_is_valid() {
        for i in 0..SimpleDiscreteAction::ACTION_COUNT {
            let action = SimpleDiscreteAction::from_index(i);
            assert!(action.is_valid());
        }
    }

    // ========================================================================
    // MultiDiscreteAction Tests
    // ========================================================================

    #[test]
    fn test_multidiscrete_action_shape() {
        assert_eq!(MultiActionTest::shape(), [4, 3]);
        assert_eq!(MultiActionTest::RANK, 2);
    }

    #[test]
    fn test_multidiscrete_action_from_indices() {
        let action = MultiActionTest::from_indices([0, 0]);
        assert_eq!(action.direction, 0);
        assert_eq!(action.intensity, 0);

        let action = MultiActionTest::from_indices([3, 2]);
        assert_eq!(action.direction, 3);
        assert_eq!(action.intensity, 2);
    }

    #[test]
    #[should_panic(expected = "Direction index out of bounds")]
    fn test_multidiscrete_action_from_indices_direction_out_of_bounds() {
        MultiActionTest::from_indices([4, 0]);
    }

    #[test]
    #[should_panic(expected = "Intensity index out of bounds")]
    fn test_multidiscrete_action_from_indices_intensity_out_of_bounds() {
        MultiActionTest::from_indices([0, 3]);
    }

    #[test]
    fn test_multidiscrete_action_to_indices() {
        let action = MultiActionTest::from_indices([2, 1]);
        assert_eq!(action.to_indices(), [2, 1]);
    }

    #[test]
    fn test_multidiscrete_action_roundtrip() {
        for d in 0..4 {
            for i in 0..3 {
                let action = MultiActionTest::from_indices([d, i]);
                assert_eq!(action.to_indices(), [d, i]);
            }
        }
    }

    #[test]
    fn test_multidiscrete_action_enumerate() {
        let actions = MultiActionTest::enumerate();
        // 4 directions × 3 intensities = 12 total actions
        assert_eq!(actions.len(), 12);

        // Row-major order: the intensity axis varies fastest.
        for (idx, action) in actions.iter().enumerate() {
            let expected_d = idx / 3;
            let expected_i = idx % 3;
            assert_eq!(action.direction, expected_d);
            assert_eq!(action.intensity, expected_i);
        }
    }

    #[test]
    fn test_multidiscrete_action_enumerate_large_space() {
        // Test with 3D space: [5, 5, 5] = 125 total actions
        #[derive(Debug, Clone)]
        struct LargeMultiAction([usize; 3]);

        impl Action<3> for LargeMultiAction {
            fn shape() -> [usize; 3] {
                [5, 5, 5]
            }

            fn is_valid(&self) -> bool {
                // Compare against `shape()` so the limits live in one place.
                self.0
                    .iter()
                    .zip(Self::shape())
                    .all(|(&value, limit)| value < limit)
            }
        }

        impl MultiDiscreteAction<3> for LargeMultiAction {
            fn from_indices(indices: [usize; 3]) -> Self {
                for (i, &idx) in indices.iter().enumerate() {
                    assert!(idx < 5, "Index {i} out of bounds");
                }
                LargeMultiAction(indices)
            }

            fn to_indices(&self) -> [usize; 3] {
                self.0
            }
        }

        let actions = LargeMultiAction::enumerate();
        assert_eq!(actions.len(), 125);
        assert!(actions.iter().all(LargeMultiAction::is_valid));
    }

    #[test]
    fn test_multidiscrete_action_random() {
        for _ in 0..100 {
            let action = MultiActionTest::random();
            assert!(action.is_valid());
            let indices = action.to_indices();
            assert!(indices[0] < 4);
            assert!(indices[1] < 3);
        }
    }

    #[test]
    fn test_multidiscrete_action_is_valid() {
        // Valid actions
        assert!(MultiActionTest::from_indices([0, 0]).is_valid());
        assert!(MultiActionTest::from_indices([3, 2]).is_valid());

        // Invalid actions created directly, bypassing `from_indices`
        let invalid = MultiActionTest {
            direction: 5,
            intensity: 0,
        };
        assert!(!invalid.is_valid());

        let invalid = MultiActionTest {
            direction: 0,
            intensity: 5,
        };
        assert!(!invalid.is_valid());
    }

    // ========================================================================
    // ContinuousAction Tests
    // ========================================================================

    #[test]
    fn test_continuous_action_shape() {
        assert_eq!(ContinuousActionTest::shape(), [1, 1, 1]);
        assert_eq!(ContinuousActionTest::RANK, 3);
    }

    #[test]
    fn test_continuous_action_as_slice() {
        let action = ContinuousActionTest {
            values: [0.5, -0.3, 1.0],
        };
        let slice = action.as_slice();
        assert_eq!(slice.len(), 3);
        assert_eq!(slice, &[0.5, -0.3, 1.0]);
    }

    #[test]
    fn test_continuous_action_from_slice() {
        let values = [0.1, 0.2, 0.3];
        let action = ContinuousActionTest::from_slice(&values);
        assert_eq!(action.values, values);
    }

    #[test]
    #[should_panic(expected = "Expected 3 values")]
    fn test_continuous_action_from_slice_wrong_size() {
        let values = [0.1, 0.2];
        ContinuousActionTest::from_slice(&values);
    }

    #[test]
    fn test_continuous_action_roundtrip() {
        let original = [0.5, -0.3, 0.9];
        let action = ContinuousActionTest::from_slice(&original);
        assert_eq!(action.as_slice(), &original);
    }

    #[test]
    fn test_continuous_action_clip_within_bounds() {
        let action = ContinuousActionTest {
            values: [0.0, 0.5, -0.5],
        };
        let clipped = action.clip(-1.0, 1.0);
        assert_eq!(clipped.values, [0.0, 0.5, -0.5]);
    }

    #[test]
    fn test_continuous_action_clip_exceeds_max() {
        let action = ContinuousActionTest {
            values: [2.0, 1.5, 3.0],
        };
        let clipped = action.clip(-1.0, 1.0);
        assert_eq!(clipped.values, [1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_continuous_action_clip_exceeds_min() {
        let action = ContinuousActionTest {
            values: [-2.0, -1.5, -3.0],
        };
        let clipped = action.clip(-1.0, 1.0);
        assert_eq!(clipped.values, [-1.0, -1.0, -1.0]);
    }

    #[test]
    fn test_continuous_action_clip_mixed() {
        let action = ContinuousActionTest {
            values: [2.0, 0.5, -2.0],
        };
        let clipped = action.clip(-1.0, 1.0);
        assert_eq!(clipped.values, [1.0, 0.5, -1.0]);
    }

    #[test]
    fn test_continuous_action_random() {
        for _ in 0..100 {
            let action = ContinuousActionTest::random();
            assert!(action.is_valid());
            for &value in action.as_slice() {
                // The default sampler promises the half-open range [-1.0, 1.0).
                assert!((-1.0..1.0).contains(&value));
                assert!(value.is_finite());
            }
        }
    }

    #[test]
    fn test_continuous_action_is_valid_finite() {
        let action = ContinuousActionTest {
            values: [0.5, -0.3, 1.0],
        };
        assert!(action.is_valid());
    }

    #[test]
    fn test_continuous_action_is_invalid_nan() {
        let action = ContinuousActionTest {
            values: [f32::NAN, 0.5, 1.0],
        };
        assert!(!action.is_valid());
    }

    #[test]
    fn test_continuous_action_is_invalid_inf() {
        let action = ContinuousActionTest {
            values: [f32::INFINITY, 0.5, 1.0],
        };
        assert!(!action.is_valid());

        let action = ContinuousActionTest {
            values: [f32::NEG_INFINITY, 0.5, 1.0],
        };
        assert!(!action.is_valid());
    }

    // ========================================================================
    // InvalidActionError Tests
    // ========================================================================

    #[test]
    fn test_invalid_action_error_creation() {
        let error = InvalidActionError {
            message: String::from("Index out of bounds"),
        };
        assert_eq!(error.message, "Index out of bounds");
    }

    #[test]
    fn test_invalid_action_error_display() {
        let error = InvalidActionError {
            message: String::from("Invalid value"),
        };
        let displayed = error.to_string();
        assert_eq!(displayed, "Invalid action: Invalid value");
    }

    #[test]
    fn test_invalid_action_error_debug() {
        let error = InvalidActionError {
            message: String::from("Test error"),
        };
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("Test error"));
    }

    #[test]
    fn test_invalid_action_error_clone() {
        let error = InvalidActionError {
            message: String::from("Original"),
        };
        let cloned = error.clone();
        assert_eq!(error, cloned);
    }

    #[test]
    fn test_invalid_action_error_equality() {
        let error1 = InvalidActionError {
            message: String::from("Same error"),
        };
        let error2 = InvalidActionError {
            message: String::from("Same error"),
        };
        let error3 = InvalidActionError {
            message: String::from("Different error"),
        };

        assert_eq!(error1, error2);
        assert_ne!(error1, error3);
    }

    #[test]
    fn test_invalid_action_error_is_error() {
        let error: Box<dyn Error> = Box::new(InvalidActionError {
            message: String::from("Test"),
        });
        // Usable as a `std::error::Error` trait object.
        assert_eq!(error.to_string(), "Invalid action: Test");
    }

    // ========================================================================
    // Integration Tests
    // ========================================================================

    #[test]
    fn test_discrete_action_clone_and_debug() {
        let action = SimpleDiscreteAction::Left;
        // The type is `Copy`, so a plain copy is used instead of `.clone()`.
        let cloned = action;
        assert_eq!(action, cloned);

        let debug_str = format!("{action:?}");
        assert!(debug_str.contains("Left"));
    }

    #[test]
    fn test_multidiscrete_action_clone_and_debug() {
        let action = MultiActionTest::from_indices([1, 2]);
        // The type is `Copy`, so a plain copy is used instead of `.clone()`.
        let cloned = action;
        assert_eq!(action, cloned);

        let debug_str = format!("{action:?}");
        assert!(debug_str.contains("direction"));
    }

    #[test]
    fn test_continuous_action_clone_and_debug() {
        let action = ContinuousActionTest {
            values: [0.1, 0.2, 0.3],
        };
        let cloned = action.clone();
        assert_eq!(action.as_slice(), cloned.as_slice());

        let debug_str = format!("{action:?}");
        assert!(debug_str.contains("values"));
    }

    #[test]
    fn test_continuous_action_clip_chaining() {
        let action = ContinuousActionTest {
            values: [2.0, -3.0, 0.5],
        };
        let clipped = action.clip(-2.0, 2.0).clip(-1.0, 1.0);
        assert_eq!(clipped.values, [1.0, -1.0, 0.5]);
    }

    #[test]
    fn test_large_discrete_action_space() {
        // Test with 256 actions
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        struct LargeDiscreteAction(u8);

        impl Action<1> for LargeDiscreteAction {
            fn shape() -> [usize; 1] {
                [256]
            }

            fn is_valid(&self) -> bool {
                true
            }
        }

        impl DiscreteAction<1> for LargeDiscreteAction {
            const ACTION_COUNT: usize = 256;

            fn from_index(index: usize) -> Self {
                // Checked conversion: fails loudly instead of truncating.
                let byte = u8::try_from(index).expect("index must be below 256");
                LargeDiscreteAction(byte)
            }

            fn to_index(&self) -> usize {
                usize::from(self.0)
            }
        }

        // Enumerate should produce all 256 actions
        let actions = LargeDiscreteAction::enumerate();
        assert_eq!(actions.len(), 256);

        // Verify roundtrip for a few samples
        for i in [0, 1, 127, 255] {
            let action = LargeDiscreteAction::from_index(i);
            assert_eq!(action.to_index(), i);
        }
    }

    #[test]
    fn test_continuous_action_with_zero_values() {
        let action = ContinuousActionTest {
            values: [0.0, 0.0, 0.0],
        };
        assert!(action.is_valid());
        assert_eq!(action.as_slice(), &[0.0, 0.0, 0.0]);

        let clipped = action.clip(-1.0, 1.0);
        assert_eq!(clipped.values, [0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_continuous_action_extreme_clip_bounds() {
        let action = ContinuousActionTest {
            values: [100.0, -100.0, 0.0],
        };

        // Infinite bounds never clip.
        let clipped = action.clip(f32::NEG_INFINITY, f32::INFINITY);
        assert_eq!(clipped.values, [100.0, -100.0, 0.0]);

        // A degenerate interval collapses every component onto it.
        let clipped = action.clip(0.0, 0.0);
        assert_eq!(clipped.values, [0.0, 0.0, 0.0]);
    }

    // ========================================================================
    // BoundedAction Tests
    // ========================================================================

    #[test]
    fn test_bounded_action_low_strictly_below_high() {
        let low = ContinuousActionTest::low();
        let high = ContinuousActionTest::high();
        for (i, (lo, hi)) in low.iter().zip(high.iter()).enumerate() {
            assert!(lo < hi, "bound {i}: low >= high");
        }
    }

    #[test]
    fn test_bounded_action_clip_is_noop_inside_bounds() {
        // The fixture uses the same bounds for every component, so the first
        // component's bounds can be passed as the scalar clip range.
        let low = ContinuousActionTest::low();
        let high = ContinuousActionTest::high();

        // Boundary points must survive clipping unchanged.
        let at_low = ContinuousActionTest::from_slice(&low);
        let at_high = ContinuousActionTest::from_slice(&high);
        assert_eq!(at_low.clip(low[0], high[0]).as_slice(), &low);
        assert_eq!(at_high.clip(low[0], high[0]).as_slice(), &high);

        // So must a point strictly inside the bounds.
        let interior = [0.0, 0.25, -0.5];
        let inside = ContinuousActionTest::from_slice(&interior);
        assert_eq!(inside.clip(low[0], high[0]).as_slice(), &interior);
    }
}
1082}