use crate::board::state::BoardState;
use crate::collections::piecemap::PieceMap;
use crate::error::ParseError;
use crate::game::GameStatus;
use crate::game::GameStatus::Ongoing;
use crate::pieces::Side;
use crate::play::{Play, PlayRecord};
use std::cmp::PartialEq;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Compact summary of one play: who made it, what it was, and whether it captured.
///
/// This is the unit that [`RepetitionTracker`] stores and compares when looking
/// for repeated plays.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(crate) struct ShortPlayRecord {
    /// The side that made the play.
    side: Side,
    /// The play itself.
    play: Play,
    /// `true` if the play captured anything.
    captures: bool,
}
impl<B: BoardState> From<&PlayRecord<B>> for ShortPlayRecord {
    /// Condenses a full play record down to its side, its play, and a single
    /// flag that is set when the record's effects list any captures.
    fn from(record: &PlayRecord<B>) -> Self {
        let captures = !record.effects.captures.is_empty();
        Self {
            side: record.side,
            play: record.play,
            captures,
        }
    }
}
/// Fixed-size ring buffer holding the most recent plays, oldest first at `first_i`.
///
/// `Default` yields an all-`None` buffer with the cursor at slot 0.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
struct PlayRecQueue {
    /// Ring storage; `None` marks a slot that has not been filled.
    queue: [Option<ShortPlayRecord>; 4],
    /// Index of the oldest entry, which is also the slot the next push overwrites.
    first_i: usize,
}
impl PlayRecQueue {
    /// Overwrites the oldest slot with `value` and advances the cursor, so the
    /// ring always holds the last `queue.len()` pushed values.
    pub(crate) fn push(&mut self, value: Option<ShortPlayRecord>) {
        // Derive the wrap point from the array itself rather than hard-coding it,
        // so this stays correct if the ring size in the struct ever changes.
        let len = self.queue.len();
        // `first_i` is always in range when maintained by `push`, but it is a plain
        // `usize` that may come from deserialization; reduce it instead of panicking.
        let slot = self.first_i % len;
        self.queue[slot] = value;
        self.first_i = (slot + 1) % len;
    }

    /// Returns the oldest entry in the ring, i.e. the value pushed `queue.len()`
    /// pushes ago (`None` if that slot has never been filled).
    pub(crate) fn first(&self) -> &Option<ShortPlayRecord> {
        &self.queue[self.first_i % self.queue.len()]
    }
}
/// Tracks how many times each side has repeated a back-and-forth pattern of plays.
///
/// A play counts as a repeat when it is non-capturing and identical (same side,
/// same play) to the play made four plies earlier. Repeats are counted in pairs:
/// the repeat that opens a pair increments the side's counter, and the next one
/// only closes the pair. Any non-repeating or capturing play by a side resets
/// that side's counter and pair state.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RepetitionTracker {
    /// Repetition count for the attacking side.
    pub(crate) attacker_reps: usize,
    /// Repetition count for the defending side.
    pub(crate) defender_reps: usize,
    /// Set after the attacker's pair-opening repeat; cleared when the pair closes or on reset.
    attacker_mid_pair: bool,
    /// Set after the defender's pair-opening repeat; cleared when the pair closes or on reset.
    defender_mid_pair: bool,
    /// The most recent plays, used to find the play from four plies ago.
    recent_plays: PlayRecQueue,
}
impl RepetitionTracker {
    /// Mutable handle to `side`'s "pair opened, awaiting close" flag.
    fn mid_pair_slot(&mut self, side: Side) -> &mut bool {
        match side {
            Side::Attacker => &mut self.attacker_mid_pair,
            Side::Defender => &mut self.defender_mid_pair,
        }
    }

    /// Mutable handle to `side`'s repetition counter.
    fn rep_slot(&mut self, side: Side) -> &mut usize {
        match side {
            Side::Attacker => &mut self.attacker_reps,
            Side::Defender => &mut self.defender_reps,
        }
    }

    /// Whether `side` has opened a repetition pair that has not closed yet.
    fn is_mid_pair(&self, side: Side) -> bool {
        match side {
            Side::Defender => self.defender_mid_pair,
            Side::Attacker => self.attacker_mid_pair,
        }
    }

    /// Flips `side`'s mid-pair flag.
    fn toggle_mid_pair(&mut self, side: Side) {
        let flag = self.mid_pair_slot(side);
        *flag = !*flag;
    }

    /// Classifies `record` against the play from four plies ago.
    ///
    /// Returns `(counts_as_repetition, must_reset)`:
    /// - a non-capturing play equal to the one four plies back gives
    ///   `(true, false)` if it opens a pair or `(false, false)` if it closes one,
    ///   toggling the side's mid-pair flag in both cases;
    /// - anything else gives `(false, true)`.
    fn check_repetition(&mut self, record: ShortPlayRecord) -> (bool, bool) {
        let repeats_earlier_play =
            !record.captures && *self.recent_plays.first() == Some(record);
        if !repeats_earlier_play {
            return (false, true);
        }
        let opens_pair = !self.is_mid_pair(record.side);
        self.toggle_mid_pair(record.side);
        (opens_pair, false)
    }

    /// Returns the current repetition count for `side`.
    pub fn get_repetitions(&self, side: Side) -> usize {
        match side {
            Side::Defender => self.defender_reps,
            Side::Attacker => self.attacker_reps,
        }
    }

    /// Records a play by `side` and updates that side's repetition state.
    ///
    /// A pair-opening repeat increments the counter. A non-repeating or
    /// capturing play zeroes the counter and clears the mid-pair flag. A
    /// pair-closing repeat leaves the counter alone. The play is then appended
    /// to the recent-play history in every case.
    pub fn track_play(&mut self, side: Side, play: Play, captures: bool) {
        let record = ShortPlayRecord { side, play, captures };
        match self.check_repetition(record) {
            (true, _) => *self.rep_slot(record.side) += 1,
            (false, true) => {
                *self.rep_slot(record.side) = 0;
                *self.mid_pair_slot(record.side) = false;
            }
            (false, false) => {}
        }
        self.recent_plays.push(Some(record));
    }
}
/// Snapshot of a game: the board together with turn and bookkeeping data.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GameState<B: BoardState> {
    /// Current board position.
    pub board: B,
    /// The side whose turn it is.
    pub side_to_play: Side,
    /// Per-side repetition tracking.
    pub repetitions: RepetitionTracker,
    /// Number of plays since the last capture (starts at 0).
    pub plays_since_capture: usize,
    /// Current game status (starts as `Ongoing`).
    pub status: GameStatus,
    /// Turn counter (starts at 0).
    pub turn: usize,
}
impl<B: BoardState> GameState<B> {
    /// Creates a fresh game from the board described by `fen_str`, with
    /// `side_to_play` to move first.
    ///
    /// The game starts `Ongoing`, at turn 0, with no plays since capture and
    /// empty repetition tracking.
    ///
    /// # Errors
    ///
    /// Propagates the error from `B::from_fen` if `fen_str` cannot be parsed.
    pub fn new(fen_str: &str, side_to_play: Side) -> Result<Self, ParseError> {
        let board = B::from_fen(fen_str)?;
        let state = Self {
            board,
            side_to_play,
            repetitions: RepetitionTracker::default(),
            plays_since_capture: 0,
            status: Ongoing,
            turn: 0,
        };
        Ok(state)
    }
}
#[cfg(test)]
mod tests {
    use crate::game::state::RepetitionTracker;
    use crate::pieces::Side;
    use crate::play::Play;
    use std::str::FromStr;

    /// Feeds one four-ply shuffle into `tracker`, asserting after every ply that
    /// the moving side's repetition count equals `expected`.
    fn play_shuffle(tracker: &mut RepetitionTracker, plies: [(Side, &str); 4], expected: usize) {
        for (side, notation) in plies {
            let play = Play::from_str(notation).unwrap();
            tracker.track_play(side, play, false);
            assert_eq!(tracker.get_repetitions(side), expected);
        }
    }

    #[test]
    fn test_repetition_tracker() {
        let mut tracker = RepetitionTracker::default();
        let shuffle_a = [
            (Side::Attacker, "a1-b1"),
            (Side::Defender, "a2-b2"),
            (Side::Attacker, "b1-a1"),
            (Side::Defender, "b2-a2"),
        ];
        let shuffle_f = [
            (Side::Attacker, "f1-g1"),
            (Side::Defender, "f2-g2"),
            (Side::Attacker, "g1-f1"),
            (Side::Defender, "g2-f2"),
        ];

        // Repeating the same shuffle bumps both counters once per full cycle.
        for cycle in 0..5 {
            play_shuffle(&mut tracker, shuffle_a, cycle);
        }
        // Switching to a different shuffle resets both counters, which then climb again.
        for cycle in 0..5 {
            play_shuffle(&mut tracker, shuffle_f, cycle);
        }
    }
}