use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt::Debug;
/// A rule that derives a new `Output` from the current `Output` and an `Input`.
///
/// The trait fixes only the call shape; what an "update" means is left to
/// the implementor.
pub trait UpdateFunction<Input, Output> {
    /// Produces the next value from `current` and `input`.
    ///
    /// Both arguments are borrowed and the result is returned by value.
    fn update(&self, current: &Output, input: &Input) -> Output;
}
/// A reward value that can be cloned, summed with itself, converted into
/// an `f32`, and debug-printed.
pub trait Reward: Debug + Clone + Into<f32> + std::ops::Add<Output = Self> {
    /// Returns the "zero" reward.
    ///
    /// NOTE(review): presumably the additive identity (`zero() + r == r`);
    /// the trait itself cannot enforce this, so implementors must uphold it.
    fn zero() -> Self;
}
/// A serializable, thread-safe observation whose shape has `R` dimensions.
pub trait Observation<const R: usize>:
    Clone + Debug + Send + Sync + Serialize + for<'de> Deserialize<'de>
{
    /// Number of dimensions; defaults to the const parameter `R`.
    const RANK: usize = R;

    /// Returns the extent of each of the `R` dimensions.
    fn shape() -> [usize; R];
}
/// An environment state with an `R`-dimensional shape that can be observed
/// and checked for validity.
pub trait State<const R: usize>: Clone + Debug + Send + Sync {
    /// Number of dimensions; defaults to the const parameter `R`.
    const RANK: usize = R;

    /// The observation type produced by [`State::observe`].
    type Observation: Observation<R>;

    /// Returns the extent of each of the `R` dimensions.
    fn shape() -> [usize; R];

    /// Produces the observation corresponding to this state.
    fn observe(&self) -> Self::Observation;

    /// Reports whether this state satisfies the implementor's own rules.
    fn is_valid(&self) -> bool;

    /// Total number of elements: the product of all extents in `shape()`.
    ///
    /// For `R == 0` the empty product yields `1`.
    fn numel(&self) -> usize {
        Self::shape().iter().copied().product()
    }
}
/// An action with an `R`-dimensional shape that can be checked for validity.
pub trait Action<const R: usize>: Sized + Clone + Debug {
    /// Number of dimensions; defaults to the const parameter `R`.
    const RANK: usize = R;

    /// Returns the extent of each of the `R` dimensions.
    fn shape() -> [usize; R];

    /// Reports whether this action satisfies the implementor's own rules.
    fn is_valid(&self) -> bool;
}
/// Maps a state and an action to a resulting state.
///
/// `SR` and `AR` are the ranks of the state type `S` and the action type `A`.
pub trait TransitionDynamics<const SR: usize, const AR: usize, S, A>
where
    S: State<SR>,
    A: Action<AR>,
{
    /// Returns the state that results from applying `action` to `state`.
    ///
    /// Both inputs are borrowed; the result is a fresh value.
    fn transition(&self, state: &S, action: &A) -> S;
}
/// Error describing why a value could not be converted to or from a tensor.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorConversionError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl std::fmt::Display for TensorConversionError {
    /// Writes `Invalid tensor conversion: ` followed by the stored message.
    ///
    /// Width and padding flags on the formatter are not applied to the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Invalid tensor conversion: ")?;
        f.write_str(&self.message)
    }
}

/// Marks the type as a standard error; all provided methods keep their defaults.
impl Error for TensorConversionError {}
/// Conversion between a value and a rank-`R` tensor on backend `B`.
pub trait TensorConvertible<const R: usize, B>: Sized
where
    B: Backend,
{
    /// Builds a `Tensor<B, R>` from `self`, given a reference to the backend's device.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> Tensor<B, R>;

    /// Attempts to rebuild a value from `tensor`.
    ///
    /// # Errors
    ///
    /// Returns a [`TensorConversionError`] when the implementor rejects the tensor.
    fn from_tensor(tensor: Tensor<B, R>) -> Result<Self, TensorConversionError>;
}
#[cfg(test)]
mod tests {
use super::*;
/// Minimal `f32`-backed reward used to exercise the `Reward` trait.
#[derive(Clone, Debug, PartialEq)]
struct TestReward(f32);

impl Reward for TestReward {
    /// The zero reward wraps `0.0`.
    fn zero() -> Self {
        Self(0.0)
    }
}

impl std::ops::Add for TestReward {
    type Output = Self;

    /// Adds the wrapped floats.
    fn add(self, rhs: Self) -> Self::Output {
        Self(self.0 + rhs.0)
    }
}

impl From<TestReward> for f32 {
    /// Unwraps the inner float.
    fn from(TestReward(value): TestReward) -> Self {
        value
    }
}
/// `zero()` must leave a reward unchanged when added on either side.
#[test]
fn test_reward_zero_is_additive_identity() {
    let identity = TestReward::zero();
    let sample = TestReward(42.5);

    // Zero on the left.
    assert_eq!(identity.clone() + sample.clone(), sample);
    // Zero on the right.
    assert_eq!(sample.clone() + identity, sample);
}
/// Adding two positive rewards sums their inner values.
#[test]
fn test_reward_addition() {
    let sum = TestReward(10.0) + TestReward(25.5);
    assert_eq!(sum, TestReward(35.5));
}
/// Adding a negative reward reduces the total.
#[test]
fn test_reward_negative_addition() {
    let net = TestReward(100.0) + TestReward(-30.0);
    assert_eq!(net, TestReward(70.0));
}
/// Converting a reward with `Into<f32>` yields the wrapped value.
#[test]
fn test_reward_into_f32() {
    let raw = Into::<f32>::into(TestReward(42.5));
    assert_eq!(raw, 42.5);
}
/// The zero reward converts to `0.0`.
#[test]
fn test_reward_zero_into_f32() {
    let raw = Into::<f32>::into(TestReward::zero());
    assert_eq!(raw, 0.0);
}
/// A clone compares equal to the value it was cloned from.
#[test]
fn test_reward_clone() {
    let source = TestReward(123.456);
    let duplicate = source.clone();
    assert_eq!(source, duplicate);
}
/// The `Debug` rendering is non-empty and names the type.
#[test]
fn test_reward_debug() {
    let rendered = format!("{:?}", TestReward(42.0));
    assert!(!rendered.is_empty());
    assert!(rendered.contains("TestReward"));
}
/// Folding a sequence of rewards onto `zero()` yields their sum.
#[test]
fn test_reward_accumulation() {
    let steps = vec![TestReward(10.0), TestReward(20.0), TestReward(15.0)];
    let total = steps
        .into_iter()
        .fold(TestReward::zero(), |running, step| running + step);
    assert_eq!(total, TestReward(45.0));
}
/// `0.1 + 0.2` lands within `1e-6` of `0.3` after conversion to `f32`.
#[test]
fn test_reward_floating_point_precision() {
    let sum: f32 = (TestReward(0.1) + TestReward(0.2)).into();
    let target: f32 = 0.3;
    // Compare with a tolerance rather than exact equality.
    let error = (sum - target).abs();
    assert!(error < 1e-6);
}
/// `(a + b) + c` equals `a + (b + c)` for these sample values.
#[test]
fn test_reward_addition_associativity() {
    let (a, b, c) = (TestReward(5.0), TestReward(10.0), TestReward(15.0));
    let left_grouped = (a.clone() + b.clone()) + c.clone();
    let right_grouped = a + (b + c);
    assert_eq!(left_grouped, right_grouped);
}
/// `a + b` equals `b + a` for these sample values.
#[test]
fn test_reward_addition_commutativity() {
    let (a, b) = (TestReward(7.5), TestReward(12.5));
    let forward = a.clone() + b.clone();
    let backward = b + a;
    assert_eq!(forward, backward);
}
/// Two rewards of `1e6` sum to exactly `2e6`.
#[test]
fn test_reward_large_values() {
    let total: f32 = (TestReward(1e6) + TestReward(1e6)).into();
    assert_eq!(total, 2e6);
}
/// Two rewards of `1e-6` sum to `2e-6` within a tolerance of `1e-7`.
#[test]
fn test_reward_small_values() {
    let total: f32 = (TestReward(1e-6) + TestReward(1e-6)).into();
    let error = (total - 2e-6).abs();
    assert!(error < 1e-7);
}
/// Adding a positive and a negative reward gives the same result in either order.
#[test]
fn test_reward_mixed_signs() {
    let gain = TestReward(10.0);
    let loss = TestReward(-5.0);

    let gain_first: f32 = (gain.clone() + loss.clone()).into();
    let loss_first: f32 = (loss + gain).into();

    assert_eq!(gain_first, 5.0);
    assert_eq!(loss_first, 5.0);
}
/// Flat, three-field view of a `GameState` used by the tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GameStateObservation {
    /// Numeric tag identifying which `GameState` variant was observed.
    state_id: u8,
    /// Level carried by the observed state (0 when the variant has none).
    level: u8,
    /// Score carried by the observed state (0 when the variant has none).
    score: u32,
}

impl Observation<1> for GameStateObservation {
    /// One dimension of length 3, matching the three fields above.
    fn shape() -> [usize; 1] {
        [3]
    }
}
/// Simple three-phase game state used to exercise the `State` trait.
#[derive(Debug, Clone, PartialEq)]
enum GameState {
    /// No game in progress.
    Menu,
    /// A game in progress at the given level.
    Playing { level: u8 },
    /// A finished game with its final score.
    GameOver { score: u32 },
}

impl State<1> for GameState {
    type Observation = GameStateObservation;

    /// Flattens the state into `(state_id, level, score)`.
    ///
    /// `state_id` is 0 for `Menu`, 1 for `Playing`, 2 for `GameOver`; fields
    /// a variant does not carry are reported as 0.
    fn observe(&self) -> Self::Observation {
        match self {
            GameState::Menu => GameStateObservation {
                state_id: 0,
                level: 0,
                score: 0,
            },
            GameState::Playing { level } => GameStateObservation {
                state_id: 1,
                level: *level,
                score: 0,
            },
            GameState::GameOver { score } => GameStateObservation {
                state_id: 2,
                level: 0,
                score: *score,
            },
        }
    }

    /// One dimension of length 3.
    fn shape() -> [usize; 1] {
        [3]
    }

    /// `Playing` is valid only for levels 1 through 10 inclusive; `Menu` and
    /// `GameOver` are always valid.
    fn is_valid(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here
        // instead of silently falling into an "always valid" catch-all.
        match self {
            GameState::Playing { level } => (1..=10).contains(level),
            GameState::Menu | GameState::GameOver { .. } => true,
        }
    }

    /// Always 3, in agreement with `shape()`.
    fn numel(&self) -> usize {
        3
    }
}
/// Checks `is_valid` for every variant, including in-range and out-of-range levels.
#[test]
fn test_game_state_validation() {
    // Variants that carry no level.
    let menu = GameState::Menu;
    assert!(menu.is_valid(), "Menu state should always be valid");
    let finished = GameState::GameOver { score: 1000 };
    assert!(finished.is_valid(), "GameOver state should always be valid");

    // Every level from 1 to 10 must be accepted.
    (1..=10).for_each(|level| {
        let candidate = GameState::Playing { level };
        assert!(
            candidate.is_valid(),
            "Playing state with level {} should be valid",
            level
        );
    });

    // Levels outside that range must be rejected.
    [0, 11, 255].iter().for_each(|&level| {
        let candidate = GameState::Playing { level };
        assert!(
            !candidate.is_valid(),
            "Playing state with level {} should be invalid",
            level
        );
    });
}
/// `numel()` reports 3 for one sample of each variant.
#[test]
fn test_game_state_numel() {
    let samples = [
        GameState::Menu,
        GameState::Playing { level: 5 },
        GameState::GameOver { score: 1000 },
    ];
    samples.iter().for_each(|sample| {
        assert_eq!(
            sample.numel(),
            3,
            "Number of elements should be 3 for all states"
        );
    });
}
/// `GameState::shape()` is `[3]`.
///
/// `shape()` is an associated function that takes no `self`, so a single
/// assertion covers every variant; no instances need to be constructed.
#[test]
fn test_game_state_shape() {
    assert_eq!(
        GameState::shape(),
        [3],
        "Shape should be [3] for all states"
    );
}
/// `numel()` agrees with the product of `shape()` for one sample of each variant.
#[test]
fn test_game_state_consistency() {
    let samples = [
        GameState::Menu,
        GameState::Playing { level: 5 },
        GameState::GameOver { score: 1000 },
    ];
    samples.iter().for_each(|sample| {
        let reported = sample.numel();
        let from_shape = GameState::shape().iter().product::<usize>();
        assert_eq!(
            reported, from_shape,
            "numel({}) should equal shape product({})",
            reported, from_shape
        );
    });
}
/// Filtering by `is_valid` drops exactly the one invalid state.
#[test]
fn test_game_state_filtering() {
    let candidates = vec![
        GameState::Menu,
        GameState::Playing { level: 5 },
        // Level 0 is out of range, so this entry must be removed.
        GameState::Playing { level: 0 },
        GameState::GameOver { score: 1000 },
    ];

    let mut survivors = Vec::new();
    for candidate in candidates {
        if candidate.is_valid() {
            survivors.push(candidate);
        }
    }

    assert_eq!(
        survivors.len(),
        3,
        "Should have 3 valid states out of 4 total"
    );
    assert!(
        survivors.iter().all(|kept| kept.is_valid()),
        "All filtered states should be valid"
    );
    let rejected = GameState::Playing { level: 0 };
    assert!(
        survivors.iter().all(|kept| *kept != rejected),
        "Invalid playing state should be filtered out"
    );
}
/// Boundary levels: 1 and 10 are valid, 0 and 11 are not.
#[test]
fn test_playing_state_edge_cases() {
    // (level, expected validity, failure message)
    let cases = [
        (1, true, "Level 1 should be valid (minimum valid)"),
        (10, true, "Level 10 should be valid (maximum valid)"),
        (0, false, "Level 0 should be invalid (below minimum)"),
        (11, false, "Level 11 should be invalid (above maximum)"),
    ];
    for &(level, expected, message) in cases.iter() {
        let state = GameState::Playing { level };
        let verdict = state.is_valid();
        assert!(verdict == expected, "{}", message);
    }
}
/// `observe()` encodes each variant as the expected `(state_id, level, score)` triple.
#[test]
fn test_game_state_observe() {
    {
        let GameStateObservation {
            state_id,
            level,
            score,
        } = GameState::Menu.observe();
        assert_eq!(state_id, 0, "Menu state should have state_id 0");
        assert_eq!(level, 0, "Menu state should have level 0");
        assert_eq!(score, 0, "Menu state should have score 0");
    }

    {
        let playing = GameState::Playing { level: 5 };
        let GameStateObservation {
            state_id,
            level,
            score,
        } = playing.observe();
        assert_eq!(state_id, 1, "Playing state should have state_id 1");
        assert_eq!(level, 5, "Playing state should preserve level");
        assert_eq!(score, 0, "Playing state should have score 0");
    }

    {
        let finished = GameState::GameOver { score: 1000 };
        let GameStateObservation {
            state_id,
            level,
            score,
        } = finished.observe();
        assert_eq!(state_id, 2, "GameOver state should have state_id 2");
        assert_eq!(level, 0, "GameOver state should have level 0");
        assert_eq!(score, 1000, "GameOver state should preserve score");
    }
}
/// `GameStateObservation` reports shape `[3]` and rank 1.
#[test]
fn test_game_state_observation_shape() {
    let shape = GameStateObservation::shape();
    let rank = GameStateObservation::RANK;

    assert_eq!(shape, [3], "GameStateObservation should have shape [3]");
    assert_eq!(rank, 1, "GameStateObservation should have rank 1");
}
/// Coordinates exposed when observing a `GridPosition`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct GridObservation {
    /// Horizontal coordinate.
    x: i32,
    /// Vertical coordinate.
    y: i32,
}

impl Observation<1> for GridObservation {
    /// One dimension of length 2, matching the two fields above.
    fn shape() -> [usize; 1] {
        [2]
    }
}
/// A point on a bounded grid, used to exercise the `State` trait.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct GridPosition {
    /// Horizontal coordinate.
    x: i32,
    /// Vertical coordinate.
    y: i32,
    /// Exclusive upper bound for `x`.
    max_x: i32,
    /// Exclusive upper bound for `y`.
    max_y: i32,
}

impl State<1> for GridPosition {
    type Observation = GridObservation;

    /// Exposes only the coordinates; the bounds are not part of the observation.
    fn observe(&self) -> Self::Observation {
        let &GridPosition { x, y, .. } = self;
        GridObservation { x, y }
    }

    /// One dimension of length 2.
    fn shape() -> [usize; 1] {
        [2]
    }

    /// Valid when `0 <= x < max_x` and `0 <= y < max_y`.
    fn is_valid(&self) -> bool {
        let x_in_bounds = (0..self.max_x).contains(&self.x);
        let y_in_bounds = (0..self.max_y).contains(&self.y);
        x_in_bounds && y_in_bounds
    }

    /// Always 2, in agreement with `shape()`.
    fn numel(&self) -> usize {
        2
    }
}

impl GridPosition {
    /// Returns `[x, y, max_x, max_y]` as `f32` values.
    ///
    /// Note that this has four elements, whereas `shape()`/`numel()` describe two.
    fn flatten(&self) -> Vec<f32> {
        [self.x, self.y, self.max_x, self.max_y]
            .iter()
            .map(|&component| component as f32)
            .collect()
    }
}
/// A position inside the bounds is valid; one with `x` beyond `max_x` is not.
#[test]
fn test_grid_position_validation() {
    // Both samples share y = 3 on a 10x10 grid and differ only in x.
    let with_x = |x: i32| GridPosition {
        x,
        y: 3,
        max_x: 10,
        max_y: 10,
    };

    let inside = with_x(5);
    assert!(inside.is_valid(), "x, y should be valid.");

    let outside = with_x(15);
    assert!(
        !outside.is_valid(),
        "x is larger than max_x and therefore invalid."
    );
}
/// `flatten()` yields `[x, y, max_x, max_y]` as floats for three sample positions.
#[test]
fn test_grid_position_flattening() {
    let on_grid = |x: i32, y: i32| GridPosition {
        x,
        y,
        max_x: 10,
        max_y: 10,
    };

    // ((x, y), expected flattened form)
    let cases: [((i32, i32), [f32; 4]); 3] = [
        ((3, 7), [3.0, 7.0, 10.0, 10.0]),
        ((0, 0), [0.0, 0.0, 10.0, 10.0]),
        ((9, 9), [9.0, 9.0, 10.0, 10.0]),
    ];
    for &((x, y), expected) in cases.iter() {
        assert_eq!(on_grid(x, y).flatten(), expected.to_vec());
    }
}
/// `observe()` copies the coordinates for an interior point, the origin, and a far corner.
#[test]
fn test_grid_position_observe() {
    let observe_at = |x: i32, y: i32| {
        let position = GridPosition {
            x,
            y,
            max_x: 10,
            max_y: 10,
        };
        position.observe()
    };

    let interior = observe_at(5, 3);
    assert_eq!(interior.x, 5, "Observation should preserve x coordinate");
    assert_eq!(interior.y, 3, "Observation should preserve y coordinate");

    let origin = observe_at(0, 0);
    assert_eq!(origin.x, 0, "Origin observation should have x = 0");
    assert_eq!(origin.y, 0, "Origin observation should have y = 0");

    let corner = observe_at(9, 9);
    assert_eq!(corner.x, 9, "Edge observation should have x = 9");
    assert_eq!(corner.y, 9, "Edge observation should have y = 9");
}
/// `GridObservation` reports shape `[2]` and rank 1.
#[test]
fn test_grid_observation_shape() {
    let shape = GridObservation::shape();
    let rank = GridObservation::RANK;

    assert_eq!(shape, [2], "GridObservation should have shape [2]");
    assert_eq!(rank, 1, "GridObservation should have rank 1");
}
/// `numel()` matches the product of `shape()` and equals 2.
#[test]
fn test_grid_position_consistency() {
    let sample = GridPosition {
        x: 5,
        y: 3,
        max_x: 10,
        max_y: 10,
    };

    let reported = sample.numel();
    let from_shape = GridPosition::shape().iter().product::<usize>();

    assert_eq!(
        reported, from_shape,
        "numel should equal shape product for GridPosition"
    );
    assert_eq!(reported, 2, "GridPosition should have numel of 2");
}
/// Both test states inherit `RANK = 1` from their `State<1>` implementation.
#[test]
fn test_state_rank_constant() {
    // (observed rank, failure message)
    let ranks = [
        (
            <GameState as State<1>>::RANK,
            "GameState should have RANK = 1",
        ),
        (
            <GridPosition as State<1>>::RANK,
            "GridPosition should have RANK = 1",
        ),
    ];
    for &(rank, message) in ranks.iter() {
        assert_eq!(rank, 1, "{}", message);
    }
}
}