Skip to main content

rlevo_core/
base.rs

1//! Core traits for reinforcement learning abstractions.
2//!
3//! This module defines the foundational vocabulary used throughout `rlevo-core`:
4//! rewards, observations, states, actions, transition dynamics, and tensor
5//! conversion. All other modules depend on these primitives.
6
7use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use serde::{Deserialize, Serialize};
10use std::error::Error;
11use std::fmt::Debug;
12
/// Generic update function: how something evolves over time.
///
/// Parameterized over the input stimulus and the output type it transforms.
///
/// # Type Parameters
///
/// - `Input`: The stimulus supplied to each update.
/// - `Output`: The type of the evolving value. It is borrowed as the current
///   value and returned, by value, as the next one.
pub trait UpdateFunction<Input, Output> {
    /// Computes the next value given the current value and an input.
    ///
    /// Both arguments are borrowed and the receiver is `&self`, so the
    /// signature returns a fresh `Output` rather than mutating `current`.
    fn update(&self, current: &Output, input: &Input) -> Output;
}
20
/// Scalar reward signal produced by an environment on every step.
///
/// A reward type must be cloneable, printable with `Debug`, summable with
/// itself (`a + b` yields the same type) and convertible into an `f32`.
pub trait Reward
where
    Self: Clone + Debug + Into<f32> + std::ops::Add<Output = Self>,
{
    /// Returns the additive identity of this reward type (typically `0.0`),
    /// i.e. the value that leaves any other reward unchanged when added.
    fn zero() -> Self;
}
26
/// The `Observation` trait defines how an agent perceives the world. It
/// represents something that can be observed from the environment.
///
/// Implementors are required to be `Serialize` and `Deserialize` (for every
/// lifetime `'de`) so that observations can be stored in a replay buffer, and
/// `Send + Sync` so they can be moved to and shared between threads.
///
/// # Type Parameters
///
/// - `R`: Rank (number of axes) of the observation space.
pub trait Observation<const R: usize>:
    Debug + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>
{
    /// The rank of this observation space — i.e. the number of axes (tensor
    /// order), *not* the size of any axis.
    ///
    /// "Rank" here means the count of indices needed to address an element
    /// (NumPy `ndim`, Burn's `Tensor<B, R>`), not matrix rank or CP-decomposition
    /// rank. It defaults to the const generic parameter `R`.
    ///
    /// NOTE(review): as an associated-const default this can be overridden by
    /// an implementation; doing so would break the documented equality with `R`.
    const RANK: usize = R;

    /// Returns the size of each axis in this observation space.
    ///
    /// The returned array has length `R` (the rank), where each element is the
    /// cardinality of that axis — the number of possible values along it. All
    /// values must be greater than zero.
    ///
    /// This is an associated function (no `self`), so the shape is a property
    /// of the type rather than of an individual observation.
    fn shape() -> [usize; R];
}
48
49/// The complete state of an environment (Markov property)
50pub trait State<const R: usize>: Debug + Clone + Send + Sync {
51    /// The rank of this state space — i.e. the number of axes (tensor order),
52    /// *not* the size of any axis.
53    ///
54    /// "Rank" here means the count of indices needed to address an element
55    /// (NumPy `ndim`, Burn's `Tensor<B, R>`), not matrix rank or CP-decomposition
56    /// rank. This is automatically set to match the const generic parameter `R`.
57    const RANK: usize = R;
58
59    type Observation: Observation<R>;
60
61    /// Returns the size of each axis in this state space.
62    ///
63    /// The returned array has length `R` (the rank), where each element is the
64    /// cardinality of that axis — the number of possible values along it. All
65    /// values must be greater than zero.
66    fn shape() -> [usize; R];
67
68    /// Generate an observation from this state (may be partial)
69    fn observe(&self) -> Self::Observation;
70
71    /// Validates whether this state satisfies all constraints.
72    ///
73    /// This method checks if the state is legal according to its type's invariants.
74    /// It does **not** check environment-specific legality - that's the environment's responsibility.
75    ///
76    /// # Returns
77    ///
78    /// Returns `true` if the state satisfies all structural constraints, `false` otherwise.
79    fn is_valid(&self) -> bool;
80
81    /// Returns the total number of scalar elements in this state's representation.
82    ///
83    /// This value is critical for:
84    /// - Allocating buffers for state serialization
85    /// - Determining neural network input layer dimensions
86    /// - Validating state transformations (e.g., flattening/unflattening)
87    ///
88    /// # Relationship to Shape
89    ///
90    /// For consistency, `numel()` must equal the product of all dimensions returned by
91    /// [`shape()`](State::shape). The default implementation enforces this by computing
92    /// the product directly. Override only if the state uses a non-product layout.
93    ///
94    /// # Returns
95    ///
96    /// The total number of scalar elements needed to represent this state.
97    fn numel(&self) -> usize {
98        Self::shape().iter().product()
99    }
100}
101
/// Base trait for all action types in reinforcement learning environments.
///
/// This trait defines the minimal interface that all actions must implement, regardless
/// of their underlying representation (discrete, continuous, or hybrid). It ensures actions
/// are debuggable, clonable, and can validate themselves.
///
/// # Design Rationale
///
/// The `Action` trait is intentionally minimal and framework-agnostic:
/// - `Debug`: Required for logging and debugging agents
/// - `Clone`: Actions may be stored in replay buffers or used multiple times
/// - `Sized`: Enables efficient stack allocation and compile-time optimization
/// - `is_valid()`: Allows runtime validation of action constraints
///
/// # Implementing Action
///
/// When implementing this trait, ensure `is_valid()` checks all *structural*
/// constraints of the action type:
/// - Range bounds for numeric values
/// - Finiteness for floating-point values
/// - Structural invariants (e.g., array dimensions)
///
/// Environment-specific rules (e.g., available moves in a game state) are
/// deliberately **not** part of `is_valid()`; the environment is responsible
/// for them.
///
/// # Type Parameters
///
/// - `R`: Rank (number of axes) of the action space.
pub trait Action<const R: usize>: Debug + Clone + Sized {
    /// The rank of this action space — i.e. the number of axes (tensor order),
    /// *not* the size of any axis.
    ///
    /// "Rank" here means the count of indices needed to address an element
    /// (NumPy `ndim`, Burn's `Tensor<B, R>`), not matrix rank or CP-decomposition
    /// rank. It defaults to the const generic parameter `R`.
    const RANK: usize = R;

    /// Returns the size of each axis in this action space.
    ///
    /// The returned array has length `R` (the rank), where each element is the
    /// cardinality of that axis — the number of possible values along it. All
    /// values must be greater than zero.
    fn shape() -> [usize; R];

    /// Validates whether this action satisfies all constraints.
    ///
    /// This method checks if the action is legal according to its type's invariants.
    /// It does **not** check environment-specific legality (e.g., whether a move
    /// is valid in the current game state)—that's the environment's responsibility.
    ///
    /// # Returns
    ///
    /// Returns `true` if the action satisfies all structural constraints, `false` otherwise.
    fn is_valid(&self) -> bool;
}
150
151/// Deterministic environment transition dynamics: s_{t+1} = f(s_t, a_t).
152///
153/// This trait covers only **deterministic** transitions. Stochastic dynamics
154/// (where the successor state is drawn from a distribution) are not modeled
155/// here; environments with stochastic transitions implement that logic internally
156/// inside [`crate::environment::Environment::step`].
157pub trait TransitionDynamics<const SR: usize, const AR: usize, S: State<SR>, A: Action<AR>> {
158    /// Returns the successor state after applying `action` to `state`.
159    fn transition(&self, state: &S, action: &A) -> S;
160}
161
/// Error returned when a tensor cannot be converted to or from a domain type.
///
/// Equality is total (`Eq`) because the only field is a `String`, so values
/// can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TensorConversionError {
    /// Human-readable description of why the conversion failed.
    pub message: String,
}

/// Renders the error as `Invalid tensor conversion: <message>`.
impl std::fmt::Display for TensorConversionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Invalid tensor conversion: {}", self.message)
    }
}

/// Marks the type as a standard error. No `source` is overridden, so the
/// trait's default (no underlying cause) applies.
impl Error for TensorConversionError {}
176
/// Bidirectional conversion between a domain type and a Burn tensor.
///
/// Implementors must round-trip: `from_tensor(x.to_tensor(device))` equals
/// `Ok(x)` for any valid `x`. Strategies and replay buffers rely on this
/// invariant.
///
/// # Type Parameters
///
/// - `R`: Rank of the tensor produced.
/// - `B`: Burn backend.
///
/// # Errors
///
/// `from_tensor` returns [`TensorConversionError`] when the tensor's shape,
/// dtype, or contents violate the domain type's invariants (see
/// [`State::is_valid`] / [`Action::is_valid`]).
pub trait TensorConvertible<const R: usize, B: Backend>: Sized {
    /// Converts `self` into a tensor on `device`.
    ///
    /// `self` is only borrowed, so the value remains usable afterwards.
    ///
    /// NOTE(review): the device type is named through
    /// `burn::tensor::backend::BackendTypes`, which is not imported in this
    /// file — confirm this path exists in the Burn version in use.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> Tensor<B, R>;

    /// Reconstructs a value from a tensor.
    ///
    /// The tensor is taken by value.
    ///
    /// # Errors
    ///
    /// Returns [`TensorConversionError`] if the tensor's shape or contents
    /// do not describe a valid instance of `Self`.
    fn from_tensor(tensor: Tensor<B, R>) -> Result<Self, TensorConversionError>;
}
205
206#[cfg(test)]
207mod tests {
208    use super::*;
209
210    /// Simple scalar reward implementation for testing
211    #[derive(Clone, Debug, PartialEq)]
212    struct TestReward(f32);
213
214    impl Reward for TestReward {
215        fn zero() -> Self {
216            TestReward(0.0)
217        }
218    }
219
220    impl std::ops::Add for TestReward {
221        type Output = Self;
222
223        fn add(self, other: Self) -> Self {
224            TestReward(self.0 + other.0)
225        }
226    }
227
228    impl From<TestReward> for f32 {
229        fn from(reward: TestReward) -> f32 {
230            reward.0
231        }
232    }
233
234    // ===== Basic Reward Trait Tests =====
235
236    /// Test that zero() creates a neutral element for addition
237    #[test]
238    fn test_reward_zero_is_additive_identity() {
239        let zero = TestReward::zero();
240        let reward = TestReward(42.5);
241
242        // zero + reward should equal reward
243        let result = zero.clone() + reward.clone();
244        assert_eq!(result, reward);
245
246        // reward + zero should equal reward
247        let result = reward.clone() + zero.clone();
248        assert_eq!(result, reward);
249    }
250
251    /// Test that rewards can be added together
252    #[test]
253    fn test_reward_addition() {
254        let reward1 = TestReward(10.0);
255        let reward2 = TestReward(25.5);
256        let result = reward1 + reward2;
257
258        assert_eq!(result, TestReward(35.5));
259    }
260
261    /// Test that negative rewards can be added
262    #[test]
263    fn test_reward_negative_addition() {
264        let positive = TestReward(100.0);
265        let negative = TestReward(-30.0);
266        let result = positive + negative;
267
268        assert_eq!(result, TestReward(70.0));
269    }
270
271    /// Test that rewards can be converted to f32
272    #[test]
273    fn test_reward_into_f32() {
274        let reward = TestReward(42.5);
275        let as_f32: f32 = reward.into();
276
277        assert_eq!(as_f32, 42.5);
278    }
279
280    /// Test that zero reward converts to 0.0
281    #[test]
282    fn test_reward_zero_into_f32() {
283        let zero = TestReward::zero();
284        let as_f32: f32 = zero.into();
285
286        assert_eq!(as_f32, 0.0);
287    }
288
289    /// Test that rewards are cloneable
290    #[test]
291    fn test_reward_clone() {
292        let original = TestReward(123.456);
293        let cloned = original.clone();
294
295        assert_eq!(original, cloned);
296    }
297
298    /// Test that rewards implement Debug
299    #[test]
300    fn test_reward_debug() {
301        let reward = TestReward(42.0);
302        let debug_str = format!("{:?}", reward);
303
304        assert!(!debug_str.is_empty());
305        assert!(debug_str.contains("TestReward"));
306    }
307
308    // ===== Arithmetic Properties Tests =====
309
310    /// Test accumulated reward through chained additions
311    #[test]
312    fn test_reward_accumulation() {
313        let mut accumulated = TestReward::zero();
314        let rewards = vec![TestReward(10.0), TestReward(20.0), TestReward(15.0)];
315
316        for reward in rewards {
317            accumulated = accumulated + reward;
318        }
319
320        assert_eq!(accumulated, TestReward(45.0));
321    }
322
323    /// Test reward trait with floating point precision
324    #[test]
325    fn test_reward_floating_point_precision() {
326        let r1 = TestReward(0.1);
327        let r2 = TestReward(0.2);
328        let result = r1 + r2;
329
330        // Account for floating point imprecision
331        let expected = 0.3;
332        let as_f32: f32 = result.into();
333        assert!((as_f32 - expected).abs() < 1e-6);
334    }
335
336    /// Test addition associativity: (a + b) + c == a + (b + c)
337    #[test]
338    fn test_reward_addition_associativity() {
339        let r1 = TestReward(5.0);
340        let r2 = TestReward(10.0);
341        let r3 = TestReward(15.0);
342
343        let left = (r1.clone() + r2.clone()) + r3.clone();
344        let right = r1 + (r2 + r3);
345
346        assert_eq!(left, right);
347    }
348
349    /// Test addition commutativity: a + b == b + a
350    #[test]
351    fn test_reward_addition_commutativity() {
352        let r1 = TestReward(7.5);
353        let r2 = TestReward(12.5);
354
355        let left = r1.clone() + r2.clone();
356        let right = r2 + r1;
357
358        assert_eq!(left, right);
359    }
360
361    // ===== Special Values Tests =====
362
363    /// Test reward arithmetic with large values
364    #[test]
365    fn test_reward_large_values() {
366        let large1 = TestReward(1e6);
367        let large2 = TestReward(1e6);
368
369        let result = large1 + large2;
370        let result_f32: f32 = result.into();
371
372        assert_eq!(result_f32, 2e6);
373    }
374
375    /// Test reward arithmetic with small values
376    #[test]
377    fn test_reward_small_values() {
378        let small1 = TestReward(1e-6);
379        let small2 = TestReward(1e-6);
380
381        let result = small1 + small2;
382        let result_f32: f32 = result.into();
383
384        assert!((result_f32 - 2e-6).abs() < 1e-7);
385    }
386
387    /// Test mixed positive and negative rewards
388    #[test]
389    fn test_reward_mixed_signs() {
390        let positive = TestReward(10.0);
391        let negative = TestReward(-5.0);
392
393        let pos_then_neg = positive.clone() + negative.clone();
394        let pos_then_neg_f32: f32 = pos_then_neg.into();
395
396        let neg_then_pos = negative.clone() + positive.clone();
397        let neg_then_pos_f32: f32 = neg_then_pos.into();
398
399        assert_eq!(pos_then_neg_f32, 5.0);
400        assert_eq!(neg_then_pos_f32, 5.0);
401    }
402
    // ========================================================================
    // GameState example to test the State trait implementation
    // ========================================================================

    /// Flat, serializable view of a `GameState`: one field per feature.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct GameStateObservation {
        /// Numeric tag identifying which `GameState` variant was observed.
        state_id: u8,
        /// Level being played.
        level: u8,
        /// Final score.
        score: u32,
    }

    impl Observation<1> for GameStateObservation {
        /// A single axis of extent 3.
        fn shape() -> [usize; 1] {
            [3] // 3 features: state_id, level, score
        }
    }
418
419    #[derive(Debug, Clone, PartialEq)]
420    enum GameState {
421        Menu,
422        Playing { level: u8 },
423        GameOver { score: u32 },
424    }
425
426    impl State<1> for GameState {
427        type Observation = GameStateObservation;
428
429        fn observe(&self) -> Self::Observation {
430            match self {
431                GameState::Menu => GameStateObservation {
432                    state_id: 0,
433                    level: 0,
434                    score: 0,
435                },
436                GameState::Playing { level } => GameStateObservation {
437                    state_id: 1,
438                    level: *level,
439                    score: 0,
440                },
441                GameState::GameOver { score } => GameStateObservation {
442                    state_id: 2,
443                    level: 0,
444                    score: *score,
445                },
446            }
447        }
448
449        fn shape() -> [usize; 1] {
450            [3] // 3 features: state_id, level, score
451        }
452
453        fn is_valid(&self) -> bool {
454            match self {
455                GameState::Playing { level } => *level > 0 && *level <= 10,
456                _ => true,
457            }
458        }
459
460        fn numel(&self) -> usize {
461            // Encode as 3 features: [state_id, level, score]
462            3
463        }
464    }
465
466    /// Test state validation for each state variant
467    #[test]
468    fn test_game_state_validation() {
469        // Menu state should always be valid
470        let menu_state = GameState::Menu;
471        assert!(menu_state.is_valid(), "Menu state should always be valid");
472
473        // GameOver state should always be valid
474        let game_over_state = GameState::GameOver { score: 1000 };
475        assert!(
476            game_over_state.is_valid(),
477            "GameOver state should always be valid"
478        );
479
480        // Playing state with valid levels should be valid
481        for level in 1..=10 {
482            let playing_state = GameState::Playing { level };
483            assert!(
484                playing_state.is_valid(),
485                "Playing state with level {} should be valid",
486                level
487            );
488        }
489
490        // Playing state with invalid levels should be invalid
491        let invalid_levels = [0, 11, 255];
492        for level in invalid_levels {
493            let invalid_state = GameState::Playing { level };
494            assert!(
495                !invalid_state.is_valid(),
496                "Playing state with level {} should be invalid",
497                level
498            );
499        }
500    }
501
502    /// Test that numel returns 3 for all state variants
503    #[test]
504    fn test_game_state_numel() {
505        let test_states = [
506            GameState::Menu,
507            GameState::Playing { level: 5 },
508            GameState::GameOver { score: 1000 },
509        ];
510
511        for state in test_states {
512            assert_eq!(
513                state.numel(),
514                3,
515                "Number of elements should be 3 for all states"
516            );
517        }
518    }
519
520    /// Test that shape returns [3] for all state variants
521    #[test]
522    fn test_game_state_shape() {
523        let test_states = [
524            GameState::Menu,
525            GameState::Playing { level: 5 },
526            GameState::GameOver { score: 1000 },
527        ];
528
529        for _state in test_states {
530            assert_eq!(
531                GameState::shape(),
532                [3],
533                "Shape should be [3] for all states"
534            );
535        }
536    }
537
538    /// Test the invariant: numel() should equal product of shape()
539    #[test]
540    fn test_game_state_consistency() {
541        let test_states = [
542            GameState::Menu,
543            GameState::Playing { level: 5 },
544            GameState::GameOver { score: 1000 },
545        ];
546
547        for state in test_states {
548            let numel = state.numel();
549            let shape_product: usize = GameState::shape().iter().product();
550            assert_eq!(
551                numel, shape_product,
552                "numel({}) should equal shape product({})",
553                numel, shape_product
554            );
555        }
556    }
557
558    /// Test that filtering states by validity works correctly
559    #[test]
560    fn test_game_state_filtering() {
561        let states = vec![
562            GameState::Menu,
563            GameState::Playing { level: 5 },
564            GameState::Playing { level: 0 }, // Invalid
565            GameState::GameOver { score: 1000 },
566        ];
567
568        let valid_states: Vec<_> = states.into_iter().filter(|s| s.is_valid()).collect();
569
570        assert_eq!(
571            valid_states.len(),
572            3,
573            "Should have 3 valid states out of 4 total"
574        );
575        assert!(
576            valid_states.iter().all(|s| s.is_valid()),
577            "All filtered states should be valid"
578        );
579
580        // Verify the invalid state was filtered out
581        assert!(
582            !valid_states.contains(&GameState::Playing { level: 0 }),
583            "Invalid playing state should be filtered out"
584        );
585    }
586
587    /// Test edge cases for Playing state level bounds
588    #[test]
589    fn test_playing_state_edge_cases() {
590        // Test boundary values
591        let min_valid_level = GameState::Playing { level: 1 };
592        assert!(
593            min_valid_level.is_valid(),
594            "Level 1 should be valid (minimum valid)"
595        );
596
597        let max_valid_level = GameState::Playing { level: 10 };
598        assert!(
599            max_valid_level.is_valid(),
600            "Level 10 should be valid (maximum valid)"
601        );
602
603        let below_min = GameState::Playing { level: 0 };
604        assert!(
605            !below_min.is_valid(),
606            "Level 0 should be invalid (below minimum)"
607        );
608
609        let above_max = GameState::Playing { level: 11 };
610        assert!(
611            !above_max.is_valid(),
612            "Level 11 should be invalid (above maximum)"
613        );
614    }
615
616    /// Test that observe() generates correct observations for each state variant
617    #[test]
618    fn test_game_state_observe() {
619        // Test Menu state observation
620        let menu_state = GameState::Menu;
621        let menu_obs = menu_state.observe();
622        assert_eq!(menu_obs.state_id, 0, "Menu state should have state_id 0");
623        assert_eq!(menu_obs.level, 0, "Menu state should have level 0");
624        assert_eq!(menu_obs.score, 0, "Menu state should have score 0");
625
626        // Test Playing state observation
627        let playing_state = GameState::Playing { level: 5 };
628        let playing_obs = playing_state.observe();
629        assert_eq!(
630            playing_obs.state_id, 1,
631            "Playing state should have state_id 1"
632        );
633        assert_eq!(playing_obs.level, 5, "Playing state should preserve level");
634        assert_eq!(playing_obs.score, 0, "Playing state should have score 0");
635
636        // Test GameOver state observation
637        let game_over_state = GameState::GameOver { score: 1000 };
638        let game_over_obs = game_over_state.observe();
639        assert_eq!(
640            game_over_obs.state_id, 2,
641            "GameOver state should have state_id 2"
642        );
643        assert_eq!(game_over_obs.level, 0, "GameOver state should have level 0");
644        assert_eq!(
645            game_over_obs.score, 1000,
646            "GameOver state should preserve score"
647        );
648    }
649
650    /// Test GameStateObservation shape
651    #[test]
652    fn test_game_state_observation_shape() {
653        assert_eq!(
654            GameStateObservation::shape(),
655            [3],
656            "GameStateObservation should have shape [3]"
657        );
658        assert_eq!(
659            GameStateObservation::RANK,
660            1,
661            "GameStateObservation should have rank 1"
662        );
663    }
664
    // ========================================================================
    // GridPosition example to test the State trait implementation
    // ========================================================================

    /// Serializable observation of a grid position: just its coordinates.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct GridObservation {
        /// Horizontal coordinate.
        x: i32,
        /// Vertical coordinate.
        y: i32,
    }

    impl Observation<1> for GridObservation {
        /// A single axis of extent 2.
        fn shape() -> [usize; 1] {
            [2] // 2 coordinates: x, y
        }
    }
679
680    #[derive(Debug, Clone, Serialize, Deserialize)]
681    struct GridPosition {
682        x: i32,
683        y: i32,
684        max_x: i32,
685        max_y: i32,
686    }
687
688    impl State<1> for GridPosition {
689        type Observation = GridObservation;
690
691        fn observe(&self) -> Self::Observation {
692            GridObservation {
693                x: self.x,
694                y: self.y,
695            }
696        }
697
698        fn shape() -> [usize; 1] {
699            [2] // 2 coordinates: x, y
700        }
701
702        fn is_valid(&self) -> bool {
703            self.x >= 0 && self.y >= 0 && self.x < self.max_x && self.y < self.max_y
704        }
705
706        fn numel(&self) -> usize {
707            2 // x and y coordinates
708        }
709    }
710
711    impl GridPosition {
712        /// Flatten the grid position to a vector of f32 values
713        fn flatten(&self) -> Vec<f32> {
714            vec![
715                self.x as f32,
716                self.y as f32,
717                self.max_x as f32,
718                self.max_y as f32,
719            ]
720        }
721    }
722
723    /// Test GridPosition validation
724    #[test]
725    fn test_grid_position_validation() {
726        let valid = GridPosition {
727            x: 5,
728            y: 3,
729            max_x: 10,
730            max_y: 10,
731        };
732        assert!(valid.is_valid(), "x, y should be valid.");
733        //
734        let invalid = GridPosition {
735            x: 15,
736            y: 3,
737            max_x: 10,
738            max_y: 10,
739        };
740        assert!(
741            !invalid.is_valid(),
742            "x is larger than max_x and therefore invalid."
743        );
744    }
745
746    /// Test GridPosition flatten
747    #[test]
748    fn test_grid_position_flattening() {
749        let pos1 = GridPosition {
750            x: 3,
751            y: 7,
752            max_x: 10,
753            max_y: 10,
754        };
755        let pos2 = GridPosition {
756            x: 0,
757            y: 0,
758            max_x: 10,
759            max_y: 10,
760        };
761        let pos3 = GridPosition {
762            x: 9,
763            y: 9,
764            max_x: 10,
765            max_y: 10,
766        };
767        let flat1 = pos1.flatten();
768        let flat2 = pos2.flatten();
769        let flat3 = pos3.flatten();
770
771        assert_eq!(flat1, vec![3.0, 7.0, 10.0, 10.0]);
772        assert_eq!(flat2, vec![0.0, 0.0, 10.0, 10.0]);
773        assert_eq!(flat3, vec![9.0, 9.0, 10.0, 10.0]);
774    }
775
776    /// Test that observe() generates correct observations for GridPosition
777    #[test]
778    fn test_grid_position_observe() {
779        let pos = GridPosition {
780            x: 5,
781            y: 3,
782            max_x: 10,
783            max_y: 10,
784        };
785        let obs = pos.observe();
786        assert_eq!(obs.x, 5, "Observation should preserve x coordinate");
787        assert_eq!(obs.y, 3, "Observation should preserve y coordinate");
788
789        // Test with different positions
790        let origin = GridPosition {
791            x: 0,
792            y: 0,
793            max_x: 10,
794            max_y: 10,
795        };
796        let origin_obs = origin.observe();
797        assert_eq!(origin_obs.x, 0, "Origin observation should have x = 0");
798        assert_eq!(origin_obs.y, 0, "Origin observation should have y = 0");
799
800        // Test with edge position
801        let edge = GridPosition {
802            x: 9,
803            y: 9,
804            max_x: 10,
805            max_y: 10,
806        };
807        let edge_obs = edge.observe();
808        assert_eq!(edge_obs.x, 9, "Edge observation should have x = 9");
809        assert_eq!(edge_obs.y, 9, "Edge observation should have y = 9");
810    }
811
812    /// Test GridObservation shape
813    #[test]
814    fn test_grid_observation_shape() {
815        assert_eq!(
816            GridObservation::shape(),
817            [2],
818            "GridObservation should have shape [2]"
819        );
820        assert_eq!(
821            GridObservation::RANK,
822            1,
823            "GridObservation should have rank 1"
824        );
825    }
826
827    /// Test that GridPosition numel matches shape product
828    #[test]
829    fn test_grid_position_consistency() {
830        let pos = GridPosition {
831            x: 5,
832            y: 3,
833            max_x: 10,
834            max_y: 10,
835        };
836        let numel = pos.numel();
837        let shape_product: usize = GridPosition::shape().iter().product();
838        assert_eq!(
839            numel, shape_product,
840            "numel should equal shape product for GridPosition"
841        );
842        assert_eq!(numel, 2, "GridPosition should have numel of 2");
843    }
844
845    /// Test State trait const RANK value
846    #[test]
847    fn test_state_rank_constant() {
848        assert_eq!(
849            <GameState as State<1>>::RANK,
850            1,
851            "GameState should have RANK = 1"
852        );
853        assert_eq!(
854            <GridPosition as State<1>>::RANK,
855            1,
856            "GridPosition should have RANK = 1"
857        );
858    }
859}