#![allow(dead_code)]
#![allow(unused_variables)]
/// A move packed into 10 bits.
///
/// Bits 6..=9 hold the mover's within-color piece id (0..16) and bits 0..=5
/// hold the destination square (0..64).
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Move10(u16);

impl Move10 {
    /// Width of the destination field; the piece id sits directly above it.
    const DEST_BITS: u32 = 6;
    /// Mask selecting the destination field.
    const DEST_MASK: u16 = (1 << Self::DEST_BITS) - 1;
    /// Mask selecting the piece-id field once shifted down.
    const PIECE_MASK: u16 = 0x0F;
    /// Mask covering all 10 meaningful bits.
    const CODE_MASK: u16 = 0x03FF;

    /// Packs a within-color piece id and a destination square.
    ///
    /// # Panics
    /// If `piece_id_within >= 16` or `dest >= 64`.
    pub fn new(piece_id_within: u8, dest: u8) -> Move10 {
        assert!(piece_id_within < 16, "piece_id must fit in 4 bits");
        assert!(dest < 64, "dest must fit in 6 bits");
        let packed = (u16::from(piece_id_within) << Self::DEST_BITS) | u16::from(dest);
        Move10(packed & Self::CODE_MASK)
    }

    /// The within-color piece id (0..16).
    pub fn piece_id(&self) -> u8 {
        let shifted = self.0 >> Self::DEST_BITS;
        (shifted & Self::PIECE_MASK) as u8
    }

    /// The destination square (0..64).
    pub fn dest(&self) -> u8 {
        let low = self.0 & Self::DEST_MASK;
        low as u8
    }
}
/// Bit-sliced move history for up to 128 plies.
///
/// Bit `b` of the 10-bit [`Move10`] code stored for ply `k` lives at bit `k`
/// of `planes[b]`. An empty ply is stored as all-zero bits, which reads back
/// as the same value as `Move10::new(0, 0)`. Callers must track how many plies
/// are in use to tell the two apart.
#[derive(Clone, Debug, Default)]
pub struct MovePlanes {
    pub planes: [u128; 10],
}

impl MovePlanes {
    /// Creates a history with every ply cleared.
    pub fn new() -> MovePlanes {
        MovePlanes { planes: [0u128; 10] }
    }

    /// Clears every ply.
    pub fn clear_all(&mut self) {
        self.planes = [0; 10];
    }

    /// Stores `mv` at ply `k`, or clears that ply when `mv` is `None`.
    ///
    /// # Panics
    /// If `k >= 128`. This is checked in every build profile. Without the check,
    /// builds lacking overflow checks would mask the shift to `k % 128` and
    /// silently overwrite a different ply.
    pub fn write_ply(&mut self, k: usize, mv: Option<Move10>) {
        assert!(k < 128, "ply index must be < 128");
        let ply_bit = 1u128 << k;
        let raw = mv.map_or(0, |m| m.0);
        for (bit_i, plane) in self.planes.iter_mut().enumerate() {
            if (raw >> bit_i) & 1 == 1 {
                *plane |= ply_bit;
            } else {
                *plane &= !ply_bit;
            }
        }
    }

    /// Reads the move code stored at ply `k` (all-zero if the ply is cleared).
    ///
    /// # Panics
    /// If `k >= 128` (checked in every build profile; see [`MovePlanes::write_ply`]).
    pub fn read_ply(&self, k: usize) -> Move10 {
        assert!(k < 128, "ply index must be < 128");
        let raw = self
            .planes
            .iter()
            .enumerate()
            .fold(0u16, |acc, (bit_i, &plane)| {
                acc | ((((plane >> k) & 1) as u16) << bit_i)
            });
        Move10(raw)
    }
}
/// Capture record: bit `i` is set when the piece with global id `i` (0..32) has been captured.
pub type CapturedBits = u32;
/// Occupancy bitboard: bit `sq` is set while square `sq` (0..64) holds a piece.
pub type Occupied = u64;
/// Two-way index between global piece ids (0..32) and board squares (0..64).
///
/// The two arrays are kept mutually consistent: `piece_square[p] == Some(s)`
/// exactly when `square_piece[s] == Some(p)`.
#[derive(Clone, Debug)]
pub struct PieceMapping {
    /// Square of each global piece id, or `None` once it is off the board.
    pub piece_square: [Option<u8>; 32],
    /// Global piece id standing on each square, or `None` if the square is empty.
    pub square_piece: [Option<u8>; 64],
}

impl PieceMapping {
    /// Creates a mapping with no pieces on the board.
    pub fn new_empty() -> PieceMapping {
        PieceMapping {
            piece_square: [None; 32],
            square_piece: [None; 64],
        }
    }

    /// Puts piece `pid` on square `sq`.
    ///
    /// # Panics
    /// If `pid` is already on the board, `sq` is already occupied, or either
    /// index is out of range.
    pub fn place_piece(&mut self, pid: u8, sq: u8) {
        let p = usize::from(pid);
        let s = usize::from(sq);
        assert!(
            self.piece_square[p].is_none(),
            "Piece {} was already on the board!", pid
        );
        assert!(
            self.square_piece[s].is_none(),
            "Square {} was already occupied!", sq
        );
        self.piece_square[p] = Some(sq);
        self.square_piece[s] = Some(pid);
    }

    /// Takes piece `pid` off the board. Does nothing if it is already off.
    pub fn remove_piece(&mut self, pid: u8) {
        if let Some(old_sq) = self.piece_square[usize::from(pid)].take() {
            self.square_piece[usize::from(old_sq)] = None;
        }
    }

    /// Moves piece `pid` from its current square to `dest`.
    ///
    /// Moving a piece onto the square it already occupies is a no-op.
    ///
    /// # Panics
    /// If `pid` is not on the board, or `dest` is held by a different piece.
    /// Both checks run before any mutation, so the mapping is unchanged when
    /// this panics.
    pub fn move_piece(&mut self, pid: u8, dest: u8) {
        let p = usize::from(pid);
        let d = usize::from(dest);
        let old_sq = usize::from(
            self.piece_square[p].unwrap_or_else(|| panic!("Piece {} is not on board!", pid)),
        );
        assert!(
            d == old_sq || self.square_piece[d].is_none(),
            "Destination square {} is occupied!", dest
        );
        self.square_piece[old_sq] = None;
        self.square_piece[d] = Some(pid);
        self.piece_square[p] = Some(dest);
    }

    /// The global piece id on `sq`, if any.
    pub fn who_on_square(&self, sq: u8) -> Option<u8> {
        self.square_piece[sq as usize]
    }
}
/// A chess position kept in sync incrementally, plus the move history that produced it.
///
/// Squares are numbered `rank * 8 + file` (0..64). Pieces are addressed by a
/// *global id* `16 * color + within`:
///
/// - Color 0 starts on ranks 0–1 and moves first; color 1 starts on ranks 6–7.
/// - Within a color, ids 0..=7 are the pawns of files 0..=7.
/// - Ids 8/9 are knights and 10/11 bishops.
/// - Ids 12/13 are the rooks starting on file 0 / file 7.
/// - Id 14 is the queen and id 15 the king.
///
/// Moves are not checked for legality beyond what the bookkeeping needs. For
/// example, moving onto a square held by a piece of the mover's own color
/// captures that piece.
#[derive(Clone, Debug)]
pub struct ChessGame {
    /// Bit-packed history: ply `k` holds the `Move10` pushed at that ply.
    pub planes: MovePlanes,
    /// Number of moves pushed so far; `ply % 2` is the color to move.
    pub ply: usize,
    /// Bit `i` is set once the piece with global id `i` has been captured.
    pub captured_bits: CapturedBits,
    /// Bit `sq` is set while square `sq` holds a piece.
    pub occupied: Occupied,
    /// Piece <-> square lookup, kept consistent with `occupied`.
    pub mapping: PieceMapping,
    /// Square passed over by the previous move if it was a two-rank pawn move.
    pub en_passant_target: Option<u8>,
}

impl Default for ChessGame {
    fn default() -> Self {
        Self::new()
    }
}

impl ChessGame {
    /// Within-color id of the king.
    const KING: u8 = 15;
    /// Within-color id of the rook that starts on file 0.
    const QUEENSIDE_ROOK: u8 = 12;
    /// Within-color id of the rook that starts on file 7.
    const KINGSIDE_ROOK: u8 = 13;

    /// Creates a game at the starting position with an empty history.
    pub fn new() -> ChessGame {
        let mut g = ChessGame {
            planes: MovePlanes::new(),
            ply: 0,
            captured_bits: 0,
            occupied: 0,
            mapping: PieceMapping::new_empty(),
            en_passant_target: None,
        };
        g.set_starting_position();
        g
    }

    /// Resets the board to the starting position.
    ///
    /// Clears captures, occupancy, the mapping and the en-passant target. It
    /// does not touch `planes` or `ply`; `pop_move` relies on that to replay
    /// the history on top of a fresh board.
    pub fn set_starting_position(&mut self) {
        // Files of within-color ids 8..=15: knights, bishops, rooks, queen, king.
        const BACK_RANK_FILES: [u8; 8] = [1, 6, 2, 5, 0, 7, 3, 4];
        self.captured_bits = 0;
        self.occupied = 0;
        self.mapping = PieceMapping::new_empty();
        self.en_passant_target = None;
        for color in 0..2u8 {
            let base_pid = 16 * color;
            let (pawn_rank, back_rank) = if color == 0 { (1u8, 0u8) } else { (6, 7) };
            for file in 0..8u8 {
                self.place(base_pid + file, pawn_rank * 8 + file);
            }
            for (&file, within) in BACK_RANK_FILES.iter().zip(8u8..) {
                self.place(base_pid + within, back_rank * 8 + file);
            }
        }
    }

    /// Plays piece `piece_id_within` of the color to move onto `dest` and
    /// appends the move to the history.
    ///
    /// # Panics
    /// In any of these cases:
    /// - `piece_id_within >= 16` or `dest >= 64`.
    /// - 128 plies have already been pushed.
    /// - The piece is no longer on the board.
    /// - A castling move's king or rook landing square is occupied, or that
    ///   rook is gone.
    pub fn push_move(&mut self, piece_id_within: u8, dest: u8) {
        assert!(piece_id_within < 16, "piece_id must be 0..15");
        assert!(dest < 64, "dest must be 0..63");
        let k = self.ply;
        assert!(k < 128, "Cannot exceed 128 plies");
        let color = (k % 2) as u8;
        let global_pid = 16 * color + piece_id_within;
        let src = self.mapping.piece_square[usize::from(global_pid)]
            .unwrap_or_else(|| panic!("Piece {} not on board!", global_pid));
        self.apply_move(color, piece_id_within, src, dest);
        self.planes.write_ply(k, Some(Move10::new(piece_id_within, dest)));
        self.ply += 1;
    }

    /// Undoes the most recent move. Does nothing when the history is empty.
    ///
    /// The position is rebuilt by replaying the remaining history from the
    /// starting position. Each call therefore costs O(ply), and undoing a whole
    /// game is quadratic in its length (at most 128 plies).
    pub fn pop_move(&mut self) {
        if self.ply == 0 {
            return;
        }
        self.ply -= 1;
        let new_k = self.ply;
        self.planes.write_ply(new_k, None);
        self.set_starting_position();
        for i in 0..new_k {
            let mv10 = self.planes.read_ply(i);
            let color = (i % 2) as u8;
            let pid_within = mv10.piece_id();
            let global_pid = 16 * color + pid_within;
            let src = self.mapping.piece_square[usize::from(global_pid)]
                .unwrap_or_else(|| panic!("(Replay) piece {} missing at ply {}", global_pid, i));
            self.apply_move(color, pid_within, src, mv10.dest());
        }
    }

    /// Puts global piece `pid` on `sq` and sets the matching occupancy bit.
    fn place(&mut self, pid: u8, sq: u8) {
        self.mapping.place_piece(pid, sq);
        self.occupied |= 1u64 << sq;
    }

    /// Moves global piece `pid` to `dest` via [`PieceMapping::move_piece`] and
    /// mirrors the change in `occupied`.
    ///
    /// # Panics
    /// Propagates `move_piece`'s panics: `pid` is not on the board, or `dest`
    /// is held by another piece.
    fn relocate(&mut self, pid: u8, dest: u8) {
        let from = self.mapping.piece_square[usize::from(pid)];
        self.mapping.move_piece(pid, dest);
        if let Some(from) = from {
            self.occupied &= !(1u64 << from);
        }
        self.occupied |= 1u64 << dest;
    }

    /// Captures whatever piece stands on `sq`, if any.
    ///
    /// The piece is recorded in `captured_bits`, removed from `mapping`, and
    /// its occupancy bit is cleared.
    fn capture_on(&mut self, sq: u8) {
        if let Some(victim) = self.mapping.who_on_square(sq) {
            self.captured_bits |= 1u32 << victim;
            self.mapping.remove_piece(victim);
            self.occupied &= !(1u64 << sq);
        }
    }

    /// The square one rank behind `sq` from `color`'s point of view, i.e.
    /// toward that color's home rank.
    ///
    /// For color 0 on rank 0 this wraps to a value >= 64, which never equals
    /// a real destination square.
    fn behind(color: u8, sq: u8) -> u8 {
        if color == 0 {
            sq.wrapping_sub(8)
        } else {
            sq + 8
        }
    }

    /// Applies one move to the board state: mapping, occupancy, captures and
    /// en-passant target.
    ///
    /// The move is by `color`'s piece `piece_id_within`, currently on `src`.
    /// `planes` and `ply` are left alone, so `push_move` and the replay in
    /// `pop_move` share exactly the same rules.
    fn apply_move(&mut self, color: u8, piece_id_within: u8, src: u8, dest: u8) {
        let global_pid = 16 * color + piece_id_within;
        let src_file = (src % 8) as i8;
        let src_rank = (src / 8) as i8;
        let dst_file = (dest % 8) as i8;
        let dst_rank = (dest / 8) as i8;

        // Castling is encoded as the king moving two files away from its start square.
        if piece_id_within == Self::KING {
            let king_start: u8 = if color == 0 { 4 } else { 60 };
            if src == king_start && (dst_file - src_file).abs() == 2 {
                // The rook lands on the square the king passed over.
                let (rook_within, rook_dst) = if dst_file > src_file {
                    (Self::KINGSIDE_ROOK, king_start + 1)
                } else {
                    (Self::QUEENSIDE_ROOK, king_start - 1)
                };
                // Going through `relocate` makes a missing rook or an occupied
                // landing square panic. Without it, `mapping` and `occupied`
                // would silently fall out of sync.
                self.relocate(global_pid, dest);
                self.relocate(16 * color + rook_within, rook_dst);
                self.en_passant_target = None;
                return;
            }
        }

        let is_pawn = piece_id_within < 8;
        let diagonal_step =
            (dst_file - src_file).abs() == 1 && (dst_rank - src_rank).abs() == 1;
        if is_pawn
            && diagonal_step
            && self.mapping.who_on_square(dest).is_none()
            && self.en_passant_target == Some(dest)
        {
            // En passant: the captured pawn sits one rank behind the landing square.
            self.capture_on(Self::behind(color, dest));
        }
        self.capture_on(dest);
        self.relocate(global_pid, dest);

        // A two-rank pawn move exposes the square it passed over to en passant.
        self.en_passant_target = if is_pawn && (dst_rank - src_rank).abs() == 2 {
            Some(Self::behind(color, dest))
        } else {
            None
        };
    }
}
/// Packs a board coordinate into a square index `rank * 8 + file` (0..64).
///
/// # Panics
/// If `rank` or `file` is 8 or more.
pub fn encode_square(rank: u8, file: u8) -> u8 {
    let on_board = rank < 8 && file < 8;
    assert!(on_board, "rank/file must be 0..7");
    rank * 8 + file
}
/// Builds a 5-bit piece code.
///
/// Layout:
/// - bit 4: `color & 1`.
/// - bit 3: the pawn flag.
/// - bits 0..=2 for pawns: the low three bits of `spawn_side_or_rank`.
/// - bits 0..=2 for other pieces: the side bit `spawn_side_or_rank & 1` in
///   bit 2, plus the two-bit `kind_or_unused & 0b11`.
///
/// Out-of-range inputs are masked, not rejected.
pub fn encode_piece(color: u8, is_pawn: bool, spawn_side_or_rank: u8, kind_or_unused: u8) -> u8 {
    let low_bits = match is_pawn {
        true => spawn_side_or_rank & 0b111,
        false => ((spawn_side_or_rank & 1) << 2) | (kind_or_unused & 0b11),
    };
    ((color & 1) << 4) | (u8::from(is_pawn) << 3) | low_bits
}
pub fn init_chess_positions() -> Vec<(u8, u8)> {
let mut v = Vec::new();
let white_back: [(u8, u8); 8] = [
(encode_piece(0, false, 1, 3), encode_square(0, 0)), (encode_piece(0, false, 1, 2), encode_square(0, 1)), (encode_piece(0, false, 1, 1), encode_square(0, 2)), (encode_piece(0, false, 1, 0), encode_square(0, 3)), (encode_piece(0, false, 0, 0), encode_square(0, 4)), (encode_piece(0, false, 0, 1), encode_square(0, 5)), (encode_piece(0, false, 0, 2), encode_square(0, 6)), (encode_piece(0, false, 0, 3), encode_square(0, 7)), ];
v.extend_from_slice(&white_back);
let white_pawn_id = encode_piece(0, true, 1, 0);
for file in 0..8 {
v.push((white_pawn_id, encode_square(1, file)));
}
let black_pawn_id = encode_piece(1, true, 6, 0);
for file in 0..8 {
v.push((black_pawn_id, encode_square(6, file)));
}
let black_back: [(u8, u8); 8] = [
(encode_piece(1, false, 1, 3), encode_square(7, 0)), (encode_piece(1, false, 1, 2), encode_square(7, 1)), (encode_piece(1, false, 1, 1), encode_square(7, 2)), (encode_piece(1, false, 1, 0), encode_square(7, 3)), (encode_piece(1, false, 0, 0), encode_square(7, 4)), (encode_piece(1, false, 0, 1), encode_square(7, 5)), (encode_piece(1, false, 0, 2), encode_square(7, 6)), (encode_piece(1, false, 0, 3), encode_square(7, 7)), ];
v.extend_from_slice(&black_back);
v
}
// Unit tests for move packing, the bit-plane history, the piece/square
// mapping, push/pop replay (including castling and en passant), and the
// piece/square encoders used by `init_chess_positions`.
#[cfg(test)]
mod tests {
use super::*;
// Every (piece id, destination) pair survives packing and unpacking.
#[test]
fn test_move10_roundtrip() {
for pid in 0..16 {
for dst in 0..64 {
let mv = Move10::new(pid, dst);
assert_eq!(mv.piece_id(), pid);
assert_eq!(mv.dest(), dst);
}
}
}
// Moves written to distinct plies read back unchanged; clearing a ply makes it
// read back as the all-zero code.
#[test]
fn test_move_planes_write_read() {
let mut planes = MovePlanes::new();
let codes = [
Move10::new(3, 10),
Move10::new(15, 63),
Move10::new(0, 0),
Move10::new(7, 31),
];
for (i, &mv) in codes.iter().enumerate() {
planes.write_ply(i, Some(mv));
}
for (i, &mv) in codes.iter().enumerate() {
let got = planes.read_ply(i);
assert_eq!(got, mv);
}
planes.write_ply(2, None);
let cleared = planes.read_ply(2);
assert_eq!(cleared.0, 0);
}
// place/move/remove keep both directions of the mapping in sync.
#[test]
fn test_piece_mapping_basic() {
let mut pm = PieceMapping::new_empty();
pm.place_piece(5, 20);
assert_eq!(pm.piece_square[5], Some(20));
assert_eq!(pm.who_on_square(20), Some(5));
pm.move_piece(5, 30);
assert_eq!(pm.piece_square[5], Some(30));
assert_eq!(pm.who_on_square(20), None);
assert_eq!(pm.who_on_square(30), Some(5));
pm.remove_piece(5);
assert_eq!(pm.piece_square[5], None);
assert_eq!(pm.who_on_square(30), None);
}
// Color 0 empties squares 5 and 6, then castles on the file-7 side:
// king 4 -> 6 and rook id 13 7 -> 5. Five pops must restore the start squares.
#[test]
fn test_white_kingside_castling() {
let mut game = ChessGame::new();
// Knight (id 9) 6 -> 22, vacating square 6.
game.push_move(9, 22);
// Color 1 pawn (global id 16) 48 -> 40.
game.push_move(0, 40);
// Bishop (id 11) 5 -> 12, vacating square 5. Square 12 still holds pawn
// id 4, and push_move does not check colors, so that own pawn is captured.
game.push_move(11, 12);
// Color 1 pawn 40 -> 32.
game.push_move(0, 32);
// King on 4 and rook 13 on 7 before castling; then king 4 -> 6 castles.
assert_eq!(game.mapping.piece_square[15], Some(4)); assert_eq!(game.mapping.piece_square[13], Some(7)); game.push_move(15, 6);
assert_eq!(game.mapping.piece_square[15], Some(6));
assert_eq!(game.mapping.piece_square[13], Some(5));
for _ in 0..5 {
game.pop_move();
}
assert_eq!(game.mapping.piece_square[15], Some(4)); assert_eq!(game.mapping.piece_square[13], Some(7)); assert_eq!(game.ply, 0);
}
// Color 1 empties squares 61 and 62, then castles on the file-7 side:
// king 60 -> 62 and rook id 29 63 -> 61. Six pops must restore the start squares.
#[test]
fn test_black_kingside_castling() {
let mut game = ChessGame::new();
// Color 0 pawn (id 0) 8 -> 16.
game.push_move(0, 16);
// Knight (global id 25) 62 -> 55. This captures its own pawn (global id 23)
// on 55, because push_move does not check colors.
game.push_move(9, 55);
game.push_move(0, 24);
// Bishop (global id 27) 61 -> 54, likewise capturing own pawn 22 on 54.
game.push_move(11, 54);
game.push_move(1, 17);
// Ply 5 is color 1's turn: king 60 -> 62 castles.
assert_eq!(game.mapping.piece_square[31], Some(60)); assert_eq!(game.mapping.piece_square[29], Some(63)); game.push_move(15, 62);
assert_eq!(game.mapping.piece_square[31], Some(62));
assert_eq!(game.mapping.piece_square[29], Some(61));
for _ in 0..6 {
game.pop_move();
}
assert_eq!(game.mapping.piece_square[31], Some(60)); assert_eq!(game.mapping.piece_square[29], Some(63)); assert_eq!(game.ply, 0);
}
// A two-rank pawn move sets the en-passant target. The opponent's diagonal
// move onto that empty square then captures the pawn behind it. Popping
// everything restores the start.
#[test]
fn test_en_passant_capture() {
let mut ep_game = ChessGame::new();
// Pawn id 4: 12 -> 28 passes over square 20.
ep_game.push_move(4, 28);
assert_eq!(ep_game.en_passant_target, Some(20));
ep_game.push_move(0, 40);
ep_game.push_move(4, 36);
// Color 1 pawn (global id 19) 51 -> 35 passes over square 43.
ep_game.push_move(3, 35);
assert_eq!(ep_game.en_passant_target, Some(43));
// Pawn id 4: 36 -> 43 captures en passant; the pawn removed from 35 is global id 16 + 3.
ep_game.push_move(4, 43);
assert_eq!(ep_game.captured_bits & (1u32 << (16 + 3)), 1u32 << (16 + 3));
assert_eq!(ep_game.mapping.who_on_square(43), Some(4));
for _ in 0..5 {
ep_game.pop_move();
}
assert_eq!(ep_game.mapping.who_on_square(43), None);
assert_eq!(ep_game.en_passant_target, None);
assert_eq!(ep_game.ply, 0);
}
// Corner squares and one interior square of the rank/file packing.
#[test]
fn test_encode_square_basic() {
assert_eq!(encode_square(0, 0), 0b000_000); assert_eq!(encode_square(0, 7), 0b000_111); assert_eq!(encode_square(7, 0), 0b111_000); assert_eq!(encode_square(7, 7), 0b111_111); assert_eq!(encode_square(3, 4), (3 << 3) | 4); }
// Every color/side/kind combination of the non-pawn piece code.
#[test]
fn test_encode_piece_non_pawns() {
assert_eq!(encode_piece(0, false, 0, 0), 0x00);
assert_eq!(encode_piece(0, false, 1, 0), 0x04);
assert_eq!(encode_piece(0, false, 0, 1), 0x01);
assert_eq!(encode_piece(0, false, 1, 1), 0x05);
assert_eq!(encode_piece(0, false, 0, 2), 0x02);
assert_eq!(encode_piece(0, false, 1, 2), 0x06);
assert_eq!(encode_piece(0, false, 0, 3), 0x03);
assert_eq!(encode_piece(0, false, 1, 3), 0x07);
assert_eq!(encode_piece(1, false, 0, 0), 0x10);
assert_eq!(encode_piece(1, false, 1, 0), 0x14);
assert_eq!(encode_piece(1, false, 0, 1), 0x11);
assert_eq!(encode_piece(1, false, 1, 1), 0x15);
assert_eq!(encode_piece(1, false, 0, 2), 0x12);
assert_eq!(encode_piece(1, false, 1, 2), 0x16);
assert_eq!(encode_piece(1, false, 0, 3), 0x13);
assert_eq!(encode_piece(1, false, 1, 3), 0x17);
}
// Pawn codes carry the home rank in the low three bits.
// NOTE(review): the loops repeat the same pure call and add no coverage
// beyond the first assertion of each color.
#[test]
fn test_encode_piece_pawns() {
assert_eq!(encode_piece(0, true, 1, 0), 0x09);
for _ in 0..8 {
assert_eq!(encode_piece(0, true, 1, 0), 0x09);
}
assert_eq!(encode_piece(1, true, 6, 0), 0x1E);
for _ in 0..8 {
assert_eq!(encode_piece(1, true, 6, 0), 0x1E);
}
}
// Groups the starting layout by piece code and checks the squares of each code.
#[test]
fn test_init_chess_positions() {
let pos = init_chess_positions();
assert_eq!(pos.len(), 32);
let mut map: std::collections::BTreeMap<u8, Vec<u8>> = Default::default();
for &(pid, sq) in &pos {
map.entry(pid).or_default().push(sq);
}
assert_eq!(map.get(&0x00), Some(&vec![encode_square(0, 4)]));
assert_eq!(map.get(&0x04), Some(&vec![encode_square(0, 3)]));
let mut w_rooks = map.get(&0x03).unwrap().clone();
w_rooks.sort();
assert_eq!(w_rooks, vec![encode_square(0, 7)]);
assert_eq!(map.get(&0x07), Some(&vec![encode_square(0, 0)]));
assert_eq!(map.get(&0x02), Some(&vec![encode_square(0, 6)]));
assert_eq!(map.get(&0x06), Some(&vec![encode_square(0, 1)]));
assert_eq!(map.get(&0x01), Some(&vec![encode_square(0, 5)]));
assert_eq!(map.get(&0x05), Some(&vec![encode_square(0, 2)]));
let wpawn_sqs: Vec<u8> = (0..8).map(|f| encode_square(1, f)).collect();
let mut got_wpawn_sqs = map.get(&0x09).unwrap().clone();
got_wpawn_sqs.sort();
assert_eq!(got_wpawn_sqs, wpawn_sqs);
let bpawn_sqs: Vec<u8> = (0..8).map(|f| encode_square(6, f)).collect();
let mut got_bpawn_sqs = map.get(&0x1E).unwrap().clone();
got_bpawn_sqs.sort();
assert_eq!(got_bpawn_sqs, bpawn_sqs);
assert_eq!(map.get(&0x10), Some(&vec![encode_square(7, 4)]));
assert_eq!(map.get(&0x14), Some(&vec![encode_square(7, 3)]));
assert_eq!(map.get(&0x13), Some(&vec![encode_square(7, 7)]));
assert_eq!(map.get(&0x17), Some(&vec![encode_square(7, 0)]));
assert_eq!(map.get(&0x12), Some(&vec![encode_square(7, 6)]));
assert_eq!(map.get(&0x16), Some(&vec![encode_square(7, 1)]));
assert_eq!(map.get(&0x11), Some(&vec![encode_square(7, 5)]));
assert_eq!(map.get(&0x15), Some(&vec![encode_square(7, 2)]));
}
}