#[cfg(feature = "ai")]
use crate::prelude::*;
#[cfg(feature = "ai")]
/// Conversion of a value into a flat vector of `f32` features.
///
/// This trait is only compiled when the `ai` feature is enabled. Each
/// implementation defines its own vector length and feature order. Any
/// consumer of the vector depends on that layout, so changing it is a
/// breaking change for that consumer.
pub trait CanonicalState {
    /// Returns the feature vector describing `self`.
    #[must_use]
    fn to_state_vector(&self) -> Vec<f32>;
}
#[cfg(feature = "ai")]
impl CanonicalState for Match {
    /// Encodes the match as a flat vector of exactly 151 `f32` features.
    ///
    /// Boolean features are 1.0 when set and 0.0 otherwise. Layout (indices
    /// into the returned vector):
    ///
    /// | Index      | Feature                                                           |
    /// |------------|-------------------------------------------------------------------|
    /// | 0          | match length in points / 21                                       |
    /// | 1          | cube play enabled                                                 |
    /// | 2          | beaver rule enabled                                               |
    /// | 3          | raccoon rule enabled                                              |
    /// | 4, 5       | murphy rule: enabled flag, then its numeric setting / 5           |
    /// | 6          | jacobi rule enabled                                               |
    /// | 7          | crawford rule enabled                                             |
    /// | 8          | holland rule enabled                                              |
    /// | 9, 10      | score of player 0, score of player 1, each / match length         |
    /// | 11         | log2(cube value) / 6                                              |
    /// | 12, 13     | cube owner is `Nobody`, cube owner is `Player0`                   |
    /// | 14 ..= 24  | one-hot match state (order below)                                 |
    /// | 25 ..= 144 | 24 board points, 5 features each (owner flag + `encode_checkers`) |
    /// | 145, 146   | board `bar` count for player 0 / player 1, each / 2               |
    /// | 147, 148   | board `off` count for player 0 / player 1, each / 2               |
    /// | 149, 150   | first and second die / 6                                          |
    ///
    /// The match-state slots follow this order: `Start`, `OfferCube(Player0)`,
    /// `OfferCube(Player1)`, `AwaitOfferAcceptance(Player0)`,
    /// `AwaitOfferAcceptance(Player1)`, `AwaitBeaverAcceptance(Player0)`,
    /// `AwaitBeaverAcceptance(Player1)`, `PlayGame(Player0)`, `PlayGame(Player1)`,
    /// `End(Player0)`, `End(Player1)`. Any other state leaves all eleven slots at 0.0.
    ///
    /// For each board point, the owner flag is 0.0 when the point's checker
    /// count is positive and 1.0 otherwise. Empty points therefore also get
    /// 1.0, but their four checker units are all 0.0.
    ///
    /// If the match length is 0, the two score features are 0.0 instead of
    /// the NaN/infinity a division by zero would produce.
    ///
    /// # Panics
    ///
    /// Panics if any of the `Match` rule/state getters used here fails, or if
    /// the board holds fewer than 24 points.
    fn to_state_vector(&self) -> Vec<f32> {
        // Total feature count; must match the layout documented above.
        const STATE_VECTOR_LEN: usize = 151;

        // 1.0 / 0.0 encoding shared by every boolean feature.
        fn flag(on: bool) -> f32 {
            if on {
                1.0
            } else {
                0.0
            }
        }

        let mut v = Vec::with_capacity(STATE_VECTOR_LEN);

        // Match rules (indices 0..=8).
        let points = self.get_points().expect("match points should be readable") as f32;
        v.push(points / 21.0);
        v.push(flag(self.get_cube_play().expect("cube-play rule should be readable")));
        v.push(flag(self.get_beaver().expect("beaver rule should be readable")));
        v.push(flag(self.get_raccoon().expect("raccoon rule should be readable")));
        let murphy = self.get_murphy().expect("murphy rule should be readable");
        v.push(flag(murphy.0));
        v.push(murphy.1 as f32 / 5.0);
        v.push(flag(self.get_jacobi().expect("jacobi rule should be readable")));
        v.push(flag(self.get_crawford().expect("crawford rule should be readable")));
        v.push(flag(self.get_holland().expect("holland rule should be readable")));

        // Scores (indices 9, 10), normalised by the match length. A zero-length
        // match would otherwise divide by zero and put NaN/inf into the vector.
        let score = self.get_score().expect("score should be readable");
        let per_point = |s: f32| if points == 0.0 { 0.0 } else { s / points };
        v.push(per_point(score.0 as f32));
        v.push(per_point(score.1 as f32));

        // Cube value (index 11) and owner (indices 12, 13).
        let cube_value = self.get_cube_value().expect("cube value should be readable") as f32;
        v.push(cube_value.log2() / 6.0);
        let cube_owner = self.get_cube_owner().expect("cube owner should be readable");
        v.push(flag(cube_owner == Player::Nobody));
        v.push(flag(cube_owner == Player::Player0));

        // One-hot match state (indices 14..=24), fetched once instead of per slot.
        let state = self.get_match_state().expect("match state should be readable");
        let encoded_states = [
            MatchState::Start,
            MatchState::OfferCube(Player::Player0),
            MatchState::OfferCube(Player::Player1),
            MatchState::AwaitOfferAcceptance(Player::Player0),
            MatchState::AwaitOfferAcceptance(Player::Player1),
            MatchState::AwaitBeaverAcceptance(Player::Player0),
            MatchState::AwaitBeaverAcceptance(Player::Player1),
            MatchState::PlayGame(Player::Player0),
            MatchState::PlayGame(Player::Player1),
            MatchState::End(Player::Player0),
            MatchState::End(Player::Player1),
        ];
        v.extend(encoded_states.iter().map(|candidate| flag(state == candidate)));

        // Board points (indices 25..=144): owner flag followed by four checker units.
        let board_state = self.get_board();
        for point in &board_state.board[..24] {
            // Positive counts encode as 0.0; empty and negative points as 1.0.
            v.push(flag(!point.is_positive()));
            v.extend(encode_checkers(point.unsigned_abs()));
        }

        // Bar and off counts (indices 145..=148), then dice (149, 150).
        v.extend([
            board_state.bar.0 as f32 / 2.0,
            board_state.bar.1 as f32 / 2.0,
            board_state.off.0 as f32 / 2.0,
            board_state.off.1 as f32 / 2.0,
            board_state.dice.0 as f32 / 6.0,
            board_state.dice.1 as f32 / 6.0,
        ]);

        debug_assert_eq!(v.len(), STATE_VECTOR_LEN, "state vector layout drifted");
        v
    }
}
#[cfg(feature = "ai")]
/// Encodes the number of checkers on one point as four input units.
///
/// The first three units are thresholds: each becomes 1.0 once `count`
/// reaches 1, 2 and 3 respectively. The fourth unit grows linearly with the
/// checkers beyond three, by 0.5 per checker.
///
/// ```text
/// 0 -> [0, 0, 0, 0]      3 -> [1, 1, 1, 0]
/// 1 -> [1, 0, 0, 0]      5 -> [1, 1, 1, 1]
/// 2 -> [1, 1, 0, 0]     10 -> [1, 1, 1, 3.5]
/// ```
fn encode_checkers(count: u8) -> [f32; 4] {
    let reached = |threshold: u8| -> f32 {
        if count >= threshold {
            1.0
        } else {
            0.0
        }
    };
    // `saturating_sub` is 0 for counts up to 3, keeping the overflow unit at 0.0 there.
    let overflow = f32::from(count.saturating_sub(3)) / 2.0;
    [reached(1), reached(2), reached(3), overflow]
}
#[cfg(all(feature = "ai", test))]
mod tests {
    //! Unit tests for `encode_checkers` and the `Match` state-vector layout.
    use super::*;
    #[test]
    fn test_encode_checkers_zero() {
        let result = encode_checkers(0);
        assert_eq!(result, [0.0, 0.0, 0.0, 0.0]);
    }
    #[test]
    fn test_encode_checkers_one() {
        let result = encode_checkers(1);
        assert_eq!(result, [1.0, 0.0, 0.0, 0.0]);
    }
    #[test]
    fn test_encode_checkers_two() {
        let result = encode_checkers(2);
        assert_eq!(result, [1.0, 1.0, 0.0, 0.0]);
    }
    #[test]
    fn test_encode_checkers_three() {
        let result = encode_checkers(3);
        assert_eq!(result, [1.0, 1.0, 1.0, 0.0]);
    }
    #[test]
    fn test_encode_checkers_five() {
        let result = encode_checkers(5);
        assert_eq!(result, [1.0, 1.0, 1.0, 1.0]);
    }
    #[test]
    fn test_encode_checkers_ten() {
        let result = encode_checkers(10);
        assert_eq!(result, [1.0, 1.0, 1.0, 3.5]);
    }
    #[test]
    fn test_to_state_vector_length() {
        let m = Match::new();
        let vec = m.to_state_vector();
        assert_eq!(
            vec.len(),
            151,
            "State vector should have exactly 151 elements"
        );
    }
    #[test]
    fn test_to_state_vector_match_rules() {
        let m = Match::new();
        let vec = m.to_state_vector();
        assert_eq!(vec[0], 7.0 / 21.0, "Points encoding");
        assert_eq!(vec[1], 1.0, "Cube play encoding");
        assert_eq!(vec[2], 0.0, "Beaver encoding");
        assert_eq!(vec[3], 0.0, "Raccoon encoding");
        assert_eq!(vec[4], 0.0, "Murphy enabled encoding");
        assert_eq!(vec[5], 0.0, "Murphy value encoding");
        assert_eq!(vec[6], 0.0, "Jacobi encoding");
        assert_eq!(vec[7], 1.0, "Crawford encoding");
        assert_eq!(vec[8], 0.0, "Holland encoding");
    }
    #[test]
    fn test_to_state_vector_match_points() {
        let mut m = Match::new();
        m.set_points(7).unwrap();
        m.roll(Player::Nobody).unwrap();
        let vec = m.to_state_vector();
        let score = m.get_score().unwrap();
        assert_eq!(vec[9], score.0 as f32 / 7.0, "Player0 score encoding");
        assert_eq!(vec[10], score.1 as f32 / 7.0, "Player1 score encoding");
    }
    #[test]
    fn test_to_state_vector_cube_state() {
        let mut m = Match::new();
        m.set_cube_play(true).unwrap();
        m.roll(Player::Nobody).unwrap();
        let vec = m.to_state_vector();
        let cube_value = m.get_cube_value().unwrap();
        assert_eq!(
            vec[11],
            (cube_value as f32).log2() / 6.0,
            "Cube value encoding"
        );
        let cube_owner = m.get_cube_owner().unwrap();
        if cube_owner == Player::Nobody {
            assert_eq!(vec[12], 1.0, "Cube owner Nobody encoding");
            assert_eq!(vec[13], 0.0, "Cube owner Player0 encoding");
        } else if cube_owner == Player::Player0 {
            assert_eq!(vec[12], 0.0, "Cube owner Nobody encoding");
            assert_eq!(vec[13], 1.0, "Cube owner Player0 encoding");
        } else {
            assert_eq!(vec[12], 0.0, "Cube owner Nobody encoding");
            assert_eq!(vec[13], 0.0, "Cube owner Player0 encoding");
        }
    }
    #[test]
    fn test_to_state_vector_match_state() {
        let m = Match::new();
        let vec = m.to_state_vector();
        let state = m.get_match_state().unwrap();
        // Slot that must be hot for the current state; states outside the
        // encoded set must leave every slot cold.
        let hot_idx: Option<usize> = match state {
            MatchState::Start => Some(14),
            MatchState::OfferCube(Player::Player0) => Some(15),
            MatchState::OfferCube(Player::Player1) => Some(16),
            MatchState::AwaitOfferAcceptance(Player::Player0) => Some(17),
            MatchState::AwaitOfferAcceptance(Player::Player1) => Some(18),
            MatchState::AwaitBeaverAcceptance(Player::Player0) => Some(19),
            MatchState::AwaitBeaverAcceptance(Player::Player1) => Some(20),
            MatchState::PlayGame(Player::Player0) => Some(21),
            MatchState::PlayGame(Player::Player1) => Some(22),
            MatchState::End(Player::Player0) => Some(23),
            MatchState::End(Player::Player1) => Some(24),
            _ => None,
        };
        // Check the whole one-hot span, so a second, wrongly-set slot is also caught.
        for (offset, &value) in vec[14..25].iter().enumerate() {
            let idx = 14 + offset;
            let expected = if Some(idx) == hot_idx { 1.0 } else { 0.0 };
            assert_eq!(value, expected, "Match state one-hot slot {}", idx);
        }
    }
    #[test]
    fn test_to_state_vector_board_points() {
        let m = Match::new();
        let vec = m.to_state_vector();
        let board = m.get_board();
        for i in 0..24 {
            let base_idx = 25 + (i * 5);
            if board.board[i].is_positive() {
                assert_eq!(
                    vec[base_idx], 0.0,
                    "Point {} player indicator (positive)",
                    i
                );
            } else {
                assert_eq!(
                    vec[base_idx], 1.0,
                    "Point {} player indicator (negative)",
                    i
                );
            }
            let checker_count = board.board[i].unsigned_abs();
            let expected_encoding = encode_checkers(checker_count);
            for j in 0..4 {
                assert_eq!(
                    vec[base_idx + 1 + j],
                    expected_encoding[j],
                    "Point {} checker encoding slot {}",
                    i,
                    j
                );
            }
        }
    }
    #[test]
    fn test_to_state_vector_bar_and_off() {
        let m = Match::new();
        let vec = m.to_state_vector();
        let board = m.get_board();
        assert_eq!(vec[145], board.bar.0 as f32 / 2.0, "Bar Player0 encoding");
        assert_eq!(vec[146], board.bar.1 as f32 / 2.0, "Bar Player1 encoding");
        assert_eq!(vec[147], board.off.0 as f32 / 2.0, "Off Player0 encoding");
        assert_eq!(vec[148], board.off.1 as f32 / 2.0, "Off Player1 encoding");
    }
    #[test]
    fn test_to_state_vector_dice() {
        let mut m = Match::new();
        m.roll(Player::Nobody).unwrap();
        let vec = m.to_state_vector();
        let board = m.get_board();
        assert_eq!(vec[149], board.dice.0 as f32 / 6.0, "Die 1 encoding");
        assert_eq!(vec[150], board.dice.1 as f32 / 6.0, "Die 2 encoding");
    }
}