1use burn::tensor::Tensor;
8use burn::tensor::backend::Backend;
9use serde::{Deserialize, Serialize};
10use std::error::Error;
11use std::fmt::Debug;
12
/// A rule for deriving a new `Output` value from the current one and an `Input`.
///
/// The trait is generic over both types, so a single implementor may provide
/// several update rules for different `(Input, Output)` pairs.
pub trait UpdateFunction<Input, Output> {
    /// Computes the next `Output` from `current` and `input`.
    ///
    /// Both arguments are only borrowed and the result is returned by value,
    /// so the signature itself does not let an implementation mutate `current`.
    fn update(&self, current: &Output, input: &Input) -> Output;
}
20
/// A scalar-like reward signal.
///
/// Implementors must be printable (`Debug`), duplicable (`Clone`), convertible
/// into an `f32`, and closed under `+` (adding two rewards yields a reward of
/// the same type).
pub trait Reward: Debug + Clone + Into<f32> + std::ops::Add<Output = Self> {
    /// Returns the "zero" reward.
    ///
    /// Intended as the neutral starting value when summing rewards; the trait
    /// itself cannot enforce that `zero() + r == r`, so implementors are
    /// responsible for upholding it.
    fn zero() -> Self;
}
26
27pub trait Observation<const R: usize>:
31 Debug + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>
32{
33 const RANK: usize = R;
40
41 fn shape() -> [usize; R];
47}
48
49pub trait State<const R: usize>: Debug + Clone + Send + Sync {
51 const RANK: usize = R;
58
59 type Observation: Observation<R>;
60
61 fn shape() -> [usize; R];
67
68 fn observe(&self) -> Self::Observation;
70
71 fn is_valid(&self) -> bool;
80
81 fn numel(&self) -> usize {
98 Self::shape().iter().product()
99 }
100}
101
/// An action an agent can take, described as a rank-`R` array.
///
/// Actions must be sized, cloneable and printable.
pub trait Action<const R: usize>: Sized + Clone + Debug {
    /// Number of dimensions; defaults to the const generic parameter `R`.
    const RANK: usize = R;

    /// Returns one `usize` per dimension (`R` entries in total).
    fn shape() -> [usize; R];

    /// Reports whether this value satisfies the implementor's own validity
    /// rules. The trait does not define what "valid" means.
    fn is_valid(&self) -> bool;
}
150
151pub trait TransitionDynamics<const SR: usize, const AR: usize, S: State<SR>, A: Action<AR>> {
158 fn transition(&self, state: &S, action: &A) -> S;
160}
161
/// Error returned when a value cannot be converted to or from a tensor.
#[derive(Clone, Debug, PartialEq)]
pub struct TensorConversionError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for TensorConversionError {
    /// Writes `Invalid tensor conversion: ` followed by the stored message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "Invalid tensor conversion: {}", message)
    }
}

// No underlying cause is tracked, so the default `Error` methods suffice.
impl Error for TensorConversionError {}
176
177pub trait TensorConvertible<const R: usize, B: Backend>: Sized {
194 fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> Tensor<B, R>;
196
197 fn from_tensor(tensor: Tensor<B, R>) -> Result<Self, TensorConversionError>;
204}
205
#[cfg(test)]
mod tests {
    //! Unit tests exercising the traits of the parent module through small
    //! fixture types: `TestReward` (for `Reward`), `GameState` /
    //! `GameStateObservation` and `GridPosition` / `GridObservation`
    //! (for `State` / `Observation`).

    use super::*;

    // ---------------------------------------------------------------------
    // Reward fixtures
    // ---------------------------------------------------------------------

    /// Minimal `Reward` implementor: a newtype around `f32`.
    #[derive(Clone, Debug, PartialEq)]
    struct TestReward(f32);

    impl Reward for TestReward {
        fn zero() -> Self {
            TestReward(0.0)
        }
    }

    impl std::ops::Add for TestReward {
        type Output = Self;

        fn add(self, other: Self) -> Self {
            TestReward(self.0 + other.0)
        }
    }

    // `From<TestReward> for f32` provides the `Into<f32>` bound required by
    // `Reward` through the standard blanket impl.
    impl From<TestReward> for f32 {
        fn from(reward: TestReward) -> f32 {
            reward.0
        }
    }

    // ---------------------------------------------------------------------
    // Reward tests
    // ---------------------------------------------------------------------

    /// `zero()` must behave as both a left and a right identity for `+`.
    #[test]
    fn test_reward_zero_is_additive_identity() {
        let zero = TestReward::zero();
        let reward = TestReward(42.5);

        // Left identity: 0 + r == r.
        let result = zero.clone() + reward.clone();
        assert_eq!(result, reward);

        // Right identity: r + 0 == r. `zero` is not used afterwards, so it is
        // moved instead of cloned.
        let result = reward.clone() + zero;
        assert_eq!(result, reward);
    }

    /// Adding two positive rewards sums their inner values.
    #[test]
    fn test_reward_addition() {
        let reward1 = TestReward(10.0);
        let reward2 = TestReward(25.5);
        let result = reward1 + reward2;

        assert_eq!(result, TestReward(35.5));
    }

    /// Adding a negative reward reduces the total.
    #[test]
    fn test_reward_negative_addition() {
        let positive = TestReward(100.0);
        let negative = TestReward(-30.0);
        let result = positive + negative;

        assert_eq!(result, TestReward(70.0));
    }

    /// Conversion into `f32` exposes the wrapped value.
    #[test]
    fn test_reward_into_f32() {
        let reward = TestReward(42.5);
        let as_f32: f32 = reward.into();

        assert_eq!(as_f32, 42.5);
    }

    /// The zero reward converts to `0.0`.
    #[test]
    fn test_reward_zero_into_f32() {
        let zero = TestReward::zero();
        let as_f32: f32 = zero.into();

        assert_eq!(as_f32, 0.0);
    }

    /// A clone compares equal to its source.
    #[test]
    fn test_reward_clone() {
        let original = TestReward(123.456);
        let cloned = original.clone();

        assert_eq!(original, cloned);
    }

    /// The derived `Debug` output is non-empty and names the type.
    #[test]
    fn test_reward_debug() {
        let reward = TestReward(42.0);
        let debug_str = format!("{:?}", reward);

        assert!(!debug_str.is_empty());
        assert!(debug_str.contains("TestReward"));
    }

    /// Rewards can be summed up starting from `zero()`.
    #[test]
    fn test_reward_accumulation() {
        let mut accumulated = TestReward::zero();
        // A fixed-size array is enough here; no heap allocation is needed.
        let rewards = [TestReward(10.0), TestReward(20.0), TestReward(15.0)];

        for reward in rewards {
            accumulated = accumulated + reward;
        }

        assert_eq!(accumulated, TestReward(45.0));
    }

    /// Sums of non-representable decimals are compared with a tolerance.
    #[test]
    fn test_reward_floating_point_precision() {
        let r1 = TestReward(0.1);
        let r2 = TestReward(0.2);
        let result = r1 + r2;

        let expected = 0.3;
        let as_f32: f32 = result.into();
        assert!((as_f32 - expected).abs() < 1e-6);
    }

    /// `(a + b) + c == a + (b + c)` for exactly representable values.
    #[test]
    fn test_reward_addition_associativity() {
        let r1 = TestReward(5.0);
        let r2 = TestReward(10.0);
        let r3 = TestReward(15.0);

        let left = (r1.clone() + r2.clone()) + r3.clone();
        let right = r1 + (r2 + r3);

        assert_eq!(left, right);
    }

    /// `a + b == b + a`.
    #[test]
    fn test_reward_addition_commutativity() {
        let r1 = TestReward(7.5);
        let r2 = TestReward(12.5);

        let left = r1.clone() + r2.clone();
        let right = r2 + r1;

        assert_eq!(left, right);
    }

    /// Large magnitudes that are still exactly representable add exactly.
    #[test]
    fn test_reward_large_values() {
        let large1 = TestReward(1e6);
        let large2 = TestReward(1e6);

        let result = large1 + large2;
        let result_f32: f32 = result.into();

        assert_eq!(result_f32, 2e6);
    }

    /// Small magnitudes add up within a tolerance.
    #[test]
    fn test_reward_small_values() {
        let small1 = TestReward(1e-6);
        let small2 = TestReward(1e-6);

        let result = small1 + small2;
        let result_f32: f32 = result.into();

        assert!((result_f32 - 2e-6).abs() < 1e-7);
    }

    /// Mixed-sign addition gives the same result in either operand order.
    #[test]
    fn test_reward_mixed_signs() {
        let positive = TestReward(10.0);
        let negative = TestReward(-5.0);

        let pos_then_neg = positive.clone() + negative.clone();
        let pos_then_neg_f32: f32 = pos_then_neg.into();

        // Last use of both operands: move them rather than cloning again.
        let neg_then_pos = negative + positive;
        let neg_then_pos_f32: f32 = neg_then_pos.into();

        assert_eq!(pos_then_neg_f32, 5.0);
        assert_eq!(neg_then_pos_f32, 5.0);
    }

    // ---------------------------------------------------------------------
    // GameState fixtures
    // ---------------------------------------------------------------------

    /// Flat, three-field observation of a `GameState`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct GameStateObservation {
        state_id: u8,
        level: u8,
        score: u32,
    }

    impl Observation<1> for GameStateObservation {
        fn shape() -> [usize; 1] {
            // One entry per field: state_id, level, score.
            [3]
        }
    }

    /// Enum-based state fixture with per-variant payloads.
    #[derive(Debug, Clone, PartialEq)]
    enum GameState {
        Menu,
        Playing { level: u8 },
        GameOver { score: u32 },
    }

    impl State<1> for GameState {
        type Observation = GameStateObservation;

        /// Encodes the variant as `state_id` (Menu = 0, Playing = 1,
        /// GameOver = 2) and fills fields the variant does not carry with 0.
        fn observe(&self) -> Self::Observation {
            match self {
                GameState::Menu => GameStateObservation {
                    state_id: 0,
                    level: 0,
                    score: 0,
                },
                GameState::Playing { level } => GameStateObservation {
                    state_id: 1,
                    level: *level,
                    score: 0,
                },
                GameState::GameOver { score } => GameStateObservation {
                    state_id: 2,
                    level: 0,
                    score: *score,
                },
            }
        }

        fn shape() -> [usize; 1] {
            [3]
        }

        /// `Playing` is valid only for levels 1 through 10 inclusive; the
        /// other variants are always valid.
        fn is_valid(&self) -> bool {
            // Exhaustive on purpose: adding a variant must force a decision
            // here instead of silently being treated as valid.
            match self {
                GameState::Playing { level } => (1..=10).contains(level),
                GameState::Menu | GameState::GameOver { .. } => true,
            }
        }

        /// Overrides the default with a constant; kept equal to the product
        /// of `shape()` (checked by `test_game_state_consistency`).
        fn numel(&self) -> usize {
            3
        }
    }

    // ---------------------------------------------------------------------
    // GameState tests
    // ---------------------------------------------------------------------

    /// Covers every variant plus in-range and out-of-range levels.
    #[test]
    fn test_game_state_validation() {
        let menu_state = GameState::Menu;
        assert!(menu_state.is_valid(), "Menu state should always be valid");

        let game_over_state = GameState::GameOver { score: 1000 };
        assert!(
            game_over_state.is_valid(),
            "GameOver state should always be valid"
        );

        for level in 1..=10 {
            let playing_state = GameState::Playing { level };
            assert!(
                playing_state.is_valid(),
                "Playing state with level {} should be valid",
                level
            );
        }

        // Just below, just above, and far above the valid range.
        let invalid_levels = [0, 11, 255];
        for level in invalid_levels {
            let invalid_state = GameState::Playing { level };
            assert!(
                !invalid_state.is_valid(),
                "Playing state with level {} should be invalid",
                level
            );
        }
    }

    /// `numel` is the same for every variant.
    #[test]
    fn test_game_state_numel() {
        let test_states = [
            GameState::Menu,
            GameState::Playing { level: 5 },
            GameState::GameOver { score: 1000 },
        ];

        for state in test_states {
            assert_eq!(
                state.numel(),
                3,
                "Number of elements should be 3 for all states"
            );
        }
    }

    /// `shape` is an associated function, so it cannot depend on the variant;
    /// a single check covers every state.
    #[test]
    fn test_game_state_shape() {
        assert_eq!(
            GameState::shape(),
            [3],
            "Shape should be [3] for all states"
        );
    }

    /// The overridden `numel` agrees with the product of `shape()`.
    #[test]
    fn test_game_state_consistency() {
        let test_states = [
            GameState::Menu,
            GameState::Playing { level: 5 },
            GameState::GameOver { score: 1000 },
        ];

        for state in test_states {
            let numel = state.numel();
            let shape_product: usize = GameState::shape().iter().product();
            assert_eq!(
                numel, shape_product,
                "numel({}) should equal shape product({})",
                numel, shape_product
            );
        }
    }

    /// `is_valid` can be used as a filter predicate over a collection.
    #[test]
    fn test_game_state_filtering() {
        let states = vec![
            GameState::Menu,
            GameState::Playing { level: 5 },
            // Invalid: level 0 is below the accepted range.
            GameState::Playing { level: 0 },
            GameState::GameOver { score: 1000 },
        ];

        let valid_states: Vec<_> = states.into_iter().filter(|s| s.is_valid()).collect();

        assert_eq!(
            valid_states.len(),
            3,
            "Should have 3 valid states out of 4 total"
        );
        assert!(
            valid_states.iter().all(|s| s.is_valid()),
            "All filtered states should be valid"
        );

        assert!(
            !valid_states.contains(&GameState::Playing { level: 0 }),
            "Invalid playing state should be filtered out"
        );
    }

    /// Boundary values of the valid level range.
    #[test]
    fn test_playing_state_edge_cases() {
        let min_valid_level = GameState::Playing { level: 1 };
        assert!(
            min_valid_level.is_valid(),
            "Level 1 should be valid (minimum valid)"
        );

        let max_valid_level = GameState::Playing { level: 10 };
        assert!(
            max_valid_level.is_valid(),
            "Level 10 should be valid (maximum valid)"
        );

        let below_min = GameState::Playing { level: 0 };
        assert!(
            !below_min.is_valid(),
            "Level 0 should be invalid (below minimum)"
        );

        let above_max = GameState::Playing { level: 11 };
        assert!(
            !above_max.is_valid(),
            "Level 11 should be invalid (above maximum)"
        );
    }

    /// `observe` encodes each variant as documented on the impl.
    #[test]
    fn test_game_state_observe() {
        let menu_state = GameState::Menu;
        let menu_obs = menu_state.observe();
        assert_eq!(menu_obs.state_id, 0, "Menu state should have state_id 0");
        assert_eq!(menu_obs.level, 0, "Menu state should have level 0");
        assert_eq!(menu_obs.score, 0, "Menu state should have score 0");

        let playing_state = GameState::Playing { level: 5 };
        let playing_obs = playing_state.observe();
        assert_eq!(
            playing_obs.state_id, 1,
            "Playing state should have state_id 1"
        );
        assert_eq!(playing_obs.level, 5, "Playing state should preserve level");
        assert_eq!(playing_obs.score, 0, "Playing state should have score 0");

        let game_over_state = GameState::GameOver { score: 1000 };
        let game_over_obs = game_over_state.observe();
        assert_eq!(
            game_over_obs.state_id, 2,
            "GameOver state should have state_id 2"
        );
        assert_eq!(game_over_obs.level, 0, "GameOver state should have level 0");
        assert_eq!(
            game_over_obs.score, 1000,
            "GameOver state should preserve score"
        );
    }

    /// Shape and default `RANK` of the observation type.
    #[test]
    fn test_game_state_observation_shape() {
        assert_eq!(
            GameStateObservation::shape(),
            [3],
            "GameStateObservation should have shape [3]"
        );
        assert_eq!(
            GameStateObservation::RANK,
            1,
            "GameStateObservation should have rank 1"
        );
    }

    // ---------------------------------------------------------------------
    // GridPosition fixtures
    // ---------------------------------------------------------------------

    /// Observation of a grid position: just the coordinates.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct GridObservation {
        x: i32,
        y: i32,
    }

    impl Observation<1> for GridObservation {
        fn shape() -> [usize; 1] {
            // One entry per coordinate: x, y.
            [2]
        }
    }

    /// Struct-based state fixture: a position plus exclusive upper bounds.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct GridPosition {
        x: i32,
        y: i32,
        max_x: i32,
        max_y: i32,
    }

    impl State<1> for GridPosition {
        type Observation = GridObservation;

        /// Exposes only the coordinates; the bounds are not observed.
        fn observe(&self) -> Self::Observation {
            GridObservation {
                x: self.x,
                y: self.y,
            }
        }

        fn shape() -> [usize; 1] {
            [2]
        }

        /// Valid when `0 <= x < max_x` and `0 <= y < max_y`.
        fn is_valid(&self) -> bool {
            self.x >= 0 && self.y >= 0 && self.x < self.max_x && self.y < self.max_y
        }

        fn numel(&self) -> usize {
            2
        }
    }

    impl GridPosition {
        /// Returns all four fields as `f32`, in declaration order.
        ///
        /// Note that this yields four values while `shape()`/`numel()` report
        /// two; it is a standalone helper, not part of the `State` contract.
        /// The `as f32` casts are exact only for magnitudes up to 2^24.
        fn flatten(&self) -> Vec<f32> {
            vec![
                self.x as f32,
                self.y as f32,
                self.max_x as f32,
                self.max_y as f32,
            ]
        }
    }

    // ---------------------------------------------------------------------
    // GridPosition tests
    // ---------------------------------------------------------------------

    /// In-bounds positions are valid; `x >= max_x` is not.
    #[test]
    fn test_grid_position_validation() {
        let valid = GridPosition {
            x: 5,
            y: 3,
            max_x: 10,
            max_y: 10,
        };
        assert!(valid.is_valid(), "x, y should be valid.");

        let invalid = GridPosition {
            x: 15,
            y: 3,
            max_x: 10,
            max_y: 10,
        };
        assert!(
            !invalid.is_valid(),
            "x is larger than max_x and therefore invalid."
        );
    }

    /// `flatten` emits `[x, y, max_x, max_y]` as floats.
    #[test]
    fn test_grid_position_flattening() {
        let pos1 = GridPosition {
            x: 3,
            y: 7,
            max_x: 10,
            max_y: 10,
        };
        let pos2 = GridPosition {
            x: 0,
            y: 0,
            max_x: 10,
            max_y: 10,
        };
        let pos3 = GridPosition {
            x: 9,
            y: 9,
            max_x: 10,
            max_y: 10,
        };
        let flat1 = pos1.flatten();
        let flat2 = pos2.flatten();
        let flat3 = pos3.flatten();

        assert_eq!(flat1, vec![3.0, 7.0, 10.0, 10.0]);
        assert_eq!(flat2, vec![0.0, 0.0, 10.0, 10.0]);
        assert_eq!(flat3, vec![9.0, 9.0, 10.0, 10.0]);
    }

    /// `observe` copies the coordinates for interior, origin and edge cells.
    #[test]
    fn test_grid_position_observe() {
        let pos = GridPosition {
            x: 5,
            y: 3,
            max_x: 10,
            max_y: 10,
        };
        let obs = pos.observe();
        assert_eq!(obs.x, 5, "Observation should preserve x coordinate");
        assert_eq!(obs.y, 3, "Observation should preserve y coordinate");

        let origin = GridPosition {
            x: 0,
            y: 0,
            max_x: 10,
            max_y: 10,
        };
        let origin_obs = origin.observe();
        assert_eq!(origin_obs.x, 0, "Origin observation should have x = 0");
        assert_eq!(origin_obs.y, 0, "Origin observation should have y = 0");

        let edge = GridPosition {
            x: 9,
            y: 9,
            max_x: 10,
            max_y: 10,
        };
        let edge_obs = edge.observe();
        assert_eq!(edge_obs.x, 9, "Edge observation should have x = 9");
        assert_eq!(edge_obs.y, 9, "Edge observation should have y = 9");
    }

    /// Shape and default `RANK` of the grid observation type.
    #[test]
    fn test_grid_observation_shape() {
        assert_eq!(
            GridObservation::shape(),
            [2],
            "GridObservation should have shape [2]"
        );
        assert_eq!(
            GridObservation::RANK,
            1,
            "GridObservation should have rank 1"
        );
    }

    /// The overridden `numel` agrees with the product of `shape()`.
    #[test]
    fn test_grid_position_consistency() {
        let pos = GridPosition {
            x: 5,
            y: 3,
            max_x: 10,
            max_y: 10,
        };
        let numel = pos.numel();
        let shape_product: usize = GridPosition::shape().iter().product();
        assert_eq!(
            numel, shape_product,
            "numel should equal shape product for GridPosition"
        );
        assert_eq!(numel, 2, "GridPosition should have numel of 2");
    }

    /// The default `RANK` constant mirrors the const generic argument.
    #[test]
    fn test_state_rank_constant() {
        assert_eq!(
            <GameState as State<1>>::RANK,
            1,
            "GameState should have RANK = 1"
        );
        assert_eq!(
            <GridPosition as State<1>>::RANK,
            1,
            "GridPosition should have RANK = 1"
        );
    }
}