chess-lib 0.1.3

A chess move generator library.
Documentation
// Based on: https://analog-hors.github.io/site/magic-bitboards/

use rand::RngCore;
use rand_pcg::Pcg64Mcg;

use crate::engine::bitboard::{is_empty, is_occupied, popcnt, Bitboard, EMPTY};
use crate::engine::common::Square;
use crate::engine::magics_entries::{
    BISHOP_MAGICS, BISHOP_TABLE_SIZE, ROOK_MAGICS, ROOK_TABLE_SIZE,
};

/// Rook slider: the four orthogonal ray directions (one delta component is
/// always zero).
pub const ROOK: Slider = Slider {
    deltas: [(1, 0), (0, -1), (-1, 0), (0, 1)],
};

/// Bishop slider: the four diagonal ray directions (both delta components
/// are non-zero).
pub const BISHOP: Slider = Slider {
    deltas: [(1, 1), (1, -1), (-1, -1), (-1, 1)],
};

/// Rook attack table shared by all 64 squares; filled by `make_table` and
/// addressed with `magic_index_with_offset`.
pub type RookTable = Vec<Bitboard>;
/// Bishop attack table shared by all 64 squares; filled by `make_table` and
/// addressed with `magic_index_with_offset`.
pub type BishopTable = Vec<Bitboard>;

/// Precomputed slider attack tables, as returned by `load_magics_table`.
pub struct Tables {
    /// Rook attacks; `ROOK_TABLE_SIZE` entries.
    pub rook: RookTable,
    /// Bishop attacks; `BISHOP_TABLE_SIZE` entries.
    pub bishop: BishopTable,
}

/// Builds the rook and bishop attack tables from the precomputed magic
/// entries in `magics_entries`.
///
/// The rook table is built first, then the bishop table.
pub fn load_magics_table() -> Tables {
    let rook = make_table(ROOK_TABLE_SIZE, ROOK, ROOK_MAGICS);
    let bishop = make_table(BISHOP_TABLE_SIZE, BISHOP, BISHOP_MAGICS);

    Tables { rook, bishop }
}

/// A sliding piece, described by the four ray directions it travels along.
///
/// Each `(dx, dy)` delta is handed to `Square::try_offset` to step one
/// square further along a ray. The payload is four small integer pairs, so
/// the type is cheap to copy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Slider {
    deltas: [(i8, i8); 4],
}

impl Slider {
    /// Returns every square this slider reaches from `square` given the
    /// occupancy `blockers`.
    ///
    /// Along each ray, the first occupied square is still included (it may be
    /// a capture), and the ray stops there. If `square` itself is set in
    /// `blockers`, the result is empty.
    fn moves(&self, square: Square, blockers: Bitboard) -> Bitboard {
        // An occupied origin blocks every ray before its first step.
        if is_occupied(blockers, square) {
            return EMPTY;
        }

        let mut reachable = EMPTY;
        for &(dx, dy) in self.deltas.iter() {
            // Squares strictly beyond `square` along this ray, up to the board edge.
            let ray = std::iter::successors(square.try_offset(dx, dy), |&sq| {
                sq.try_offset(dx, dy)
            });
            for target in ray {
                reachable |= target.bitboard();
                if is_occupied(blockers, target) {
                    break;
                }
            }
        }

        reachable
    }

    /// Returns the relevant-blocker mask for `square`.
    ///
    /// The mask holds every square along each ray except the farthest one
    /// (the last square before the edge) and except `square` itself. A piece
    /// on that farthest square cannot shorten the ray, so it never changes
    /// `moves`.
    fn relevant_blockers(&self, square: Square) -> Bitboard {
        let mut mask = EMPTY;

        for &(dx, dy) in self.deltas.iter() {
            let mut ray = std::iter::successors(square.try_offset(dx, dy), |&sq| {
                sq.try_offset(dx, dy)
            })
            .peekable();

            while let Some(current) = ray.next() {
                // Keep a square only if the ray continues past it.
                if ray.peek().is_some() {
                    mask |= current.bitboard();
                }
            }
        }

        // The origin is never part of its own mask.
        mask & !square.bitboard()
    }
}

/// Magic-hash parameters for one square.
///
/// An occupancy is hashed as `((occupancy & mask) * magic) >> shift`, and
/// the result is added to `offset` to address the shared attack table.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct MagicEntry {
    /// Relevant-blocker mask; only these occupancy bits feed the hash.
    pub mask: Bitboard,
    /// Multiplier applied to the masked occupancy (wrapping).
    pub magic: u64,
    /// Right shift applied to the product; equals `64 - index_bits`.
    pub shift: u8,
    /// Start of this square's slice inside the shared attack table.
    pub offset: u32,
}

impl MagicEntry {
    /// Builds an entry from its raw parts.
    ///
    /// This is a `const fn`, so entries can be constructed in constant
    /// context (for example, inside a generated `const` table).
    #[must_use]
    pub const fn new(mask: Bitboard, magic: u64, shift: u8, offset: u32) -> Self {
        Self {
            mask,
            magic,
            shift,
            offset,
        }
    }
}

/// Hashes `blockers` into a square-local table index.
///
/// Only bits selected by `entry.mask` participate. The masked value is
/// multiplied (wrapping) by the magic, and the top `64 - shift` bits become
/// the index.
fn magic_index(entry: &MagicEntry, blockers: Bitboard) -> usize {
    let relevant = blockers & entry.mask;
    (relevant.wrapping_mul(entry.magic) >> entry.shift) as usize
}

pub fn magic_index_with_offset(entry: &MagicEntry, blockers: Bitboard) -> usize {
    let blockers = blockers & entry.mask;
    let hash = blockers.wrapping_mul(entry.magic);
    let index = (hash >> entry.shift) as usize;
    entry.offset as usize + index
}

/// Searches a magic for every square and prints Rust source to stdout.
///
/// The output declares a `{slider_name}_MAGICS` array of `MagicEntry`
/// literals and a `{slider_name}_TABLE_SIZE` constant. Presumably this is
/// how `crate::engine::magics_entries` is (re)generated — confirm.
///
/// Each square uses as many index bits as its relevant-blocker mask has set
/// bits. Its `offset` is the running total of the sub-table sizes before it.
pub fn find_and_print_all_magics(slider: &Slider, slider_name: &str, rng: &mut Pcg64Mcg) {
    println!("pub const {}_MAGICS: &[MagicEntry; 64] = &[", slider_name);

    // Each square's slice starts where the previous one ended.
    let mut offset: usize = 0;
    for square in Square::iter() {
        let mask = slider.relevant_blockers(square);
        let (entry, table) = find_magic(slider, square, popcnt(mask) as u8, rng);
        println!(
            "    MagicEntry {{ mask: 0x{:016X}, magic: 0x{:016X}, shift: {}, offset: {} }},",
            entry.mask, entry.magic, entry.shift, offset
        );
        offset += table.len();
    }

    println!("];");
    println!("pub const {}_TABLE_SIZE: usize = {};", slider_name, offset);
}

/// Draws random magic candidates until one yields a collision-free table.
///
/// Returns the winning entry (with `offset` 0) together with its
/// square-local table of `1 << index_bits` entries. This never gives up: if
/// no magic exists for `index_bits`, it loops forever.
fn find_magic(
    slider: &Slider,
    square: Square,
    index_bits: u8,
    rng: &mut Pcg64Mcg,
) -> (MagicEntry, Vec<Bitboard>) {
    let mask = slider.relevant_blockers(square);
    let shift = 64 - index_bits;

    loop {
        // ANDing three random words gives a sparse candidate: each bit is set
        // with probability 1/8.
        let candidate = rng.next_u64() & rng.next_u64() & rng.next_u64();
        let entry = MagicEntry::new(mask, candidate, shift, 0);
        match try_to_make_table(slider, square, &entry) {
            Ok(table) => return (entry, table),
            Err(TableFillError) => {}
        }
    }
}

/// Returned by `try_to_make_table` when a candidate magic sends two blocker
/// sets with different move sets to the same slot (a destructive collision).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TableFillError;

/// Fills a square-local attack table using `magic_entry`, or fails on a
/// destructive collision.
///
/// Every subset of `magic_entry.mask` is hashed with `magic_index`. Two
/// subsets may share a slot only if they produce identical moves. An empty
/// bitboard marks an unused slot.
fn try_to_make_table(
    slider: &Slider,
    square: Square,
    magic_entry: &MagicEntry,
) -> Result<Vec<Bitboard>, TableFillError> {
    let slots = 1 << (64 - magic_entry.shift);
    let mut table = vec![EMPTY; slots];
    let mask = magic_entry.mask;

    // Carry-Rippler enumeration of all subsets of `mask`: it starts at the
    // empty set and stops when it wraps back around to it.
    let mut subset = EMPTY;
    loop {
        let expected = slider.moves(square, subset);
        let index = magic_index(magic_entry, subset);

        let existing = table[index];
        if !is_empty(existing) && existing != expected {
            return Err(TableFillError);
        }
        // Either the slot was free or it already held the same moves.
        table[index] = expected;

        subset = subset.wrapping_sub(mask) & mask;
        if is_empty(subset) {
            break;
        }
    }

    Ok(table)
}

/// Builds the shared attack table for `slider` from precomputed `magics`.
///
/// For every square, all subsets of its mask are enumerated. The resulting
/// moves are stored at `magic_index_with_offset`. Indexing panics if an
/// entry's slice does not fit in `table_size`.
fn make_table(table_size: usize, slider: Slider, magics: &[MagicEntry; 64]) -> Vec<Bitboard> {
    let mut table = vec![EMPTY; table_size];

    for square in Square::iter() {
        let entry = &magics[square.index() as usize];

        // Carry-Rippler walk over every subset of the mask, starting at the
        // empty set and stopping once it wraps back around to it.
        let mut subset = EMPTY;
        loop {
            let slot = magic_index_with_offset(entry, subset);
            table[slot] = slider.moves(square, subset);

            subset = subset.wrapping_sub(entry.mask) & entry.mask;
            if is_empty(subset) {
                break;
            }
        }
    }

    table
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::bitboard::set_bit;
    use crate::engine::common::{File, Rank};

    // The expected bitboards below are written one rank per byte: `<< (8 * r)`
    // selects the byte for rank r + 1. Within a byte, the lowest bit is file A,
    // so `0b00001000` is the D file. The assertions rely on that layout.

    /// The loaded tables are exactly as long as the generated size constants.
    #[test]
    fn test_load_magic_tables() {
        let tables = load_magics_table();

        assert_eq!(tables.bishop.len(), BISHOP_TABLE_SIZE);
        assert_eq!(tables.rook.len(), ROOK_TABLE_SIZE);
    }

    /// A rook on D4 with blockers on B4 and D2 includes the blocker squares
    /// themselves but nothing beyond them (A4 and D1 are absent).
    #[test]
    fn test_slider_moves_as_rook() {
        let square = Square::new(File::D, Rank::Four);
        let blockers = set_bit(set_bit(EMPTY, File::B, Rank::Four), File::D, Rank::Two);
        let expected = 0b00001000 << 8
            | 0b00001000 << (8 * 2)
            | 0b11110110 << (3 * 8)
            | 0b00001000 << (8 * 4)
            | 0b00001000 << (8 * 5)
            | 0b00001000 << (8 * 6)
            | 0b00001000 << (8 * 7);

        let moves = ROOK.moves(square, blockers);

        assert_eq!(moves, expected);
    }

    /// A bishop on D4: E5 blocks the up-right diagonal. C6 is not on any of
    /// D4's diagonals, so it has no effect.
    #[test]
    fn test_slider_moves_as_bishop() {
        let square = Square::new(File::D, Rank::Four);
        let blockers = set_bit(set_bit(EMPTY, File::C, Rank::Six), File::E, Rank::Five);
        let expected = 0b01000001
            | 0b00100010 << 8
            | 0b00010100 << (8 * 2)
            | 0b00010100 << (8 * 4)
            | 0b00000010 << (8 * 5)
            | 0b00000001 << (8 * 6);

        let moves = BISHOP.moves(square, blockers);

        assert_eq!(moves, expected);
    }

    /// The rook mask for D4 excludes the origin and the four edge squares
    /// (A4, H4, D1, D8).
    #[test]
    fn test_blockers_as_rook() {
        let square = Square::new(File::D, Rank::Four);
        let expected = 0b00001000 << 8
            | 0b00001000 << (8 * 2)
            | 0b01110110 << (3 * 8)
            | 0b00001000 << (8 * 4)
            | 0b00001000 << (8 * 5)
            | 0b00001000 << (8 * 6);

        let blockers = ROOK.relevant_blockers(square);

        assert_eq!(blockers, expected)
    }

    /// The bishop mask for D4 excludes the origin and the diagonal end
    /// squares on the board edge (A1, G1, A7, H8).
    #[test]
    fn test_blockers_as_bishop() {
        let square = Square::new(File::D, Rank::Four);
        let expected = 0b00100010 << 8
            | 0b00010100 << (8 * 2)
            | 0b00010100 << (8 * 4)
            | 0b00100010 << (8 * 5)
            | 0b01000000 << (8 * 6);

        let blockers = BISHOP.relevant_blockers(square);

        assert_eq!(blockers, expected)
    }

    // The two tests below pin the exact magic found with seed 0. They will
    // break if the number or order of RNG draws in `find_magic` changes.
    // 12 index bits gives a 4096-entry table (shift 52).

    #[test]
    fn test_find_magic_as_rook() {
        let (magic_entry, bitboards) = find_magic(
            &ROOK,
            Square::new(File::D, Rank::Four),
            12,
            &mut Pcg64Mcg::new(0),
        );

        assert_eq!(magic_entry.mask, 2260632246683648);
        assert_eq!(magic_entry.magic, 4616191825688924160);
        assert_eq!(magic_entry.shift, 52);

        assert_eq!(bitboards.len(), 4096);
    }

    #[test]
    fn test_find_magic_as_bishop() {
        let (magic_entry, bitboards) = find_magic(
            &BISHOP,
            Square::new(File::D, Rank::Four),
            12,
            &mut Pcg64Mcg::new(0),
        );

        assert_eq!(magic_entry.mask, 18051867805491712);
        assert_eq!(magic_entry.magic, 288309553940959272);
        assert_eq!(magic_entry.shift, 52);

        assert_eq!(bitboards.len(), 4096);
    }
}