mod cudad;
mod marlin;
pub use cudad::{CudADFormat, CudADFormatIter};
pub use marlin::{MarlinFormat, MarlinFormatIter};
use crate::BulletFormat;
/// A compact, fixed-size (32-byte) chess position stored from the side-to-move's perspective.
///
/// Both constructors in this file (`ChessBoard::from_raw` and the `FromStr` impl) flip the board
/// vertically and swap colours when black is to move, so the side to move always has colour
/// bit 0 in `pcs`.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ChessBoard {
    /// Occupancy bitboard: bit `sq` is set iff square `sq` holds a piece.
    pub occ: u64,
    /// Pieces packed as 4-bit nibbles (low nibble of each byte first), one per set bit of
    /// `occ`, in ascending square order. Bit 3 of a nibble is the colour (0 = side to move);
    /// the low bits are the piece index 0..=5 (the `PNBRQK` order used by `FromStr`).
    pub pcs: [u8; 16],
    /// Evaluation score relative to the side to move (the constructors negate it for black).
    pub score: i16,
    /// Game result for the side to move in half-points: 0 = loss, 1 = draw, 2 = win.
    pub result: u8,
    /// Square of the side-to-move's king. NOTE(review): when no king is present, `from_raw`
    /// stores 64 and `FromStr` leaves the default 0.
    pub ksq: u8,
    /// Square of the opponent's king, additionally mirrored vertically (`sq ^ 56`).
    pub opp_ksq: u8,
    /// Three spare bytes, zeroed by the constructors in this file.
    /// NOTE(review): purpose not visible here — presumably padding up to 32 bytes.
    pub extra: [u8; 3],
}
/// Compile-time check that the `#[repr(C)]` layout is exactly 32 bytes.
const _RIGHT_SIZE: () = assert!(std::mem::size_of::<ChessBoard>() == 32);
impl BulletFormat for ChessBoard {
    /// Each feature is a `(piece, square)` pair, matching what `ChessBoard` iterates over.
    type FeatureType = (u8, u8);

    /// Records of this format carry no file header.
    const HEADER_SIZE: usize = 0;

    /// Returns the stored evaluation (already relative to the side to move).
    fn score(&self) -> i16 {
        let score = self.score;
        score
    }

    /// Maps the half-point result (0, 1, 2) onto `[0, 1]`.
    fn result(&self) -> f32 {
        0.5 * f32::from(self.result)
    }

    /// Returns the half-point result (0, 1, 2) as an index.
    fn result_idx(&self) -> usize {
        self.result.into()
    }

    /// Stores a `[0, 1]` result as half-points (truncating toward zero).
    fn set_result(&mut self, result: f32) {
        let half_points = result * 2.0;
        self.result = half_points as u8;
    }
}
impl IntoIterator for ChessBoard {
    type Item = (u8, u8);
    type IntoIter = BoardIter;

    /// Consumes the board and returns an iterator over its `(piece, square)` pairs,
    /// positioned at the first packed piece.
    fn into_iter(self) -> Self::IntoIter {
        let board = self;
        BoardIter { idx: 0, board }
    }
}
/// Iterator over the `(piece, square)` pairs of a [`ChessBoard`], created by
/// `ChessBoard::into_iter`.
///
/// Squares are visited in ascending order (lowest set bit of `occ` first), and the `idx`-th
/// visited square is paired with the `idx`-th 4-bit nibble of `pcs` (low nibble of each byte
/// first).
pub struct BoardIter {
    /// Copy of the board; its `occ` is consumed bit by bit as the iterator advances.
    board: ChessBoard,
    /// Index of the next nibble to read from `board.pcs`.
    idx: usize,
}

impl Iterator for BoardIter {
    type Item = (u8, u8);

    fn next(&mut self) -> Option<Self::Item> {
        if self.board.occ == 0 {
            return None;
        }
        let square = self.board.occ.trailing_zeros() as u8;
        let piece = (self.board.pcs[self.idx / 2] >> (4 * (self.idx & 1))) & 0b1111;
        // Clear the lowest set bit so the next call moves on to the following square.
        self.board.occ &= self.board.occ - 1;
        self.idx += 1;
        Some((piece, square))
    }

    /// The number of remaining items is exactly the number of bits still set in `occ`,
    /// which lets `collect` and friends preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `count_ones` is at most 64, so this cast is lossless.
        let remaining = self.board.occ.count_ones() as usize;
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for BoardIter {}
impl ChessBoard {
    /// Occupancy bitboard of all pieces.
    pub fn occ(&self) -> u64 {
        self.occ
    }

    /// Square of the side-to-move's king.
    pub fn our_ksq(&self) -> u8 {
        self.ksq
    }

    /// Square of the opponent's king, mirrored vertically (`sq ^ 56`).
    pub fn opp_ksq(&self) -> u8 {
        self.opp_ksq
    }

    /// The three spare bytes stored alongside the position.
    pub fn extra(&self) -> [u8; 3] {
        self.extra
    }

    /// Builds a board from raw bitboards.
    ///
    /// `bbs[0]` / `bbs[1]` are the white / black colour bitboards, and `bbs[2..8]` are the
    /// bitboards for piece indices 0..=5 (the same `PNBRQK` order used by `FromStr`), so
    /// `bbs[7]` holds the kings. `result` is 1.0 / 0.5 / 0.0 from white's point of view.
    ///
    /// When `stm == 1` (black to move), the board is flipped vertically (`swap_bytes`), the
    /// colours are swapped, the score is negated and the result inverted, so the stored
    /// values are relative to the side to move.
    ///
    /// # Errors
    /// Returns an error if the position holds more than 32 pieces (they cannot be packed
    /// into `pcs`), or if an occupied square appears in none of the piece bitboards.
    pub fn from_raw(
        mut bbs: [u64; 8],
        stm: usize,
        mut score: i16,
        mut result: f32,
    ) -> Result<Self, String> {
        if stm == 1 {
            for bb in bbs.iter_mut() {
                *bb = bb.swap_bytes();
            }
            bbs.swap(0, 1);
            // `saturating_neg` avoids the overflow panic on `-i16::MIN`.
            score = score.saturating_neg();
            result = 1.0 - result;
        }

        let occ = bbs[0] | bbs[1];

        // `pcs` holds 32 nibbles; any more pieces would index past its end.
        let count = occ.count_ones();
        if count > 32 {
            return Err(format!("Too many pieces ({count})!"));
        }

        let mut pcs = [0; 16];
        let mut idx = 0;
        let mut occ2 = occ;
        while occ2 > 0 {
            let sq = occ2.trailing_zeros();
            let bit = 1u64 << sq;
            occ2 &= occ2 - 1;
            let colour = u8::from((bit & bbs[1]) > 0) << 3;
            let piece = bbs[2..]
                .iter()
                .position(|&bb| (bit & bb) > 0)
                .ok_or_else(|| "No Piece Found!".to_string())?;
            let pc = colour | piece as u8;
            pcs[idx / 2] |= pc << (4 * (idx & 1));
            idx += 1;
        }

        let result = (2.0 * result) as u8;
        // NOTE(review): with no king on the board, `trailing_zeros` yields 64 here.
        let ksq = (bbs[0] & bbs[7]).trailing_zeros() as u8;
        // The opponent's king square is stored mirrored vertically.
        let opp_ksq = ((bbs[1] & bbs[7]).trailing_zeros() as u8) ^ 56;

        Ok(Self {
            occ,
            pcs,
            score,
            result,
            ksq,
            opp_ksq,
            extra: [0; 3],
        })
    }
}
impl std::str::FromStr for ChessBoard {
    type Err = String;

    /// Parses a text record of the form `<fen> | <score> | <wdl>`.
    ///
    /// Only the piece-placement and side-to-move fields of the FEN are used; any side-to-move
    /// other than `b` is treated as white. When black is to move, the board is flipped
    /// vertically with colours swapped, the score is negated and the result inverted, so the
    /// stored values are relative to the side to move.
    /// Accepted `wdl` spellings: `1.0`/`[1.0]`/`1`, `0.5`/`[0.5]`/`1/2`, `0.0`/`[0.0]`/`0`.
    ///
    /// # Errors
    /// Returns an error for missing fields, more than 8 ranks, a piece beyond the 8th file,
    /// more than 32 pieces (the error is then the whole input line), an unparsable score or
    /// an unknown result string. The last two cases also print the offending line to stdout.
    fn from_str(s: &str) -> Result<Self, String> {
        let split: Vec<_> = s.split('|').collect();
        let fen = split[0];
        let score = split.get(1).ok_or("Malformed!")?.trim();
        let wdl = split.get(2).ok_or("Malformed!")?.trim();

        let parts: Vec<&str> = fen.split_whitespace().collect();
        let board_str = *parts.first().ok_or("Malformed FEN!")?;
        let stm_str = *parts.get(1).ok_or("Malformed FEN!")?;
        let stm = u8::from(stm_str == "b");

        // A FEN has at most 8 ranks; more would underflow `7 - i` or push squares past 63.
        let rows: Vec<&str> = board_str.split('/').collect();
        if rows.len() > 8 {
            return Err(String::from("Malformed FEN!"));
        }

        let mut board = Self::default();
        let mut idx = 0;

        // Parses one FEN rank into board rank `i` (0 = first rank before any flip).
        let mut parse_row = |i: usize, row: &str| -> Result<(), String> {
            let mut col = 0;
            for ch in row.chars() {
                if ('1'..='8').contains(&ch) {
                    col += ch.to_digit(10).expect("hard coded") as usize;
                } else if let Some(mut piece) = "PNBRQKpnbrqk".chars().position(|el| el == ch) {
                    // A piece past the 8th file would spill into the next rank or off the board.
                    if col >= 8 {
                        return Err(String::from("Malformed FEN!"));
                    }
                    // `pcs` holds at most 32 nibbles.
                    if idx >= 32 {
                        return Err(s.to_string());
                    }

                    let mut square = 8 * i + col;
                    // Colour goes in bit 3, piece type in the low bits.
                    piece = ((piece / 6) << 3) | (piece % 6);
                    if stm == 1 {
                        piece ^= 8;
                        square ^= 56;
                    }
                    if piece == 5 {
                        board.ksq = square as u8;
                    }
                    if piece == 13 {
                        board.opp_ksq = (square as u8) ^ 56;
                    }
                    board.occ |= 1 << square;
                    board.pcs[idx / 2] |= (piece as u8) << (4 * (idx & 1));
                    idx += 1;
                    col += 1;
                }
            }
            Ok(())
        };

        // Ranks are fed so that squares are visited in ascending order, matching how
        // `pcs` is read back (lowest set bit of `occ` first).
        if stm == 1 {
            for (i, &row) in rows.iter().enumerate() {
                parse_row(7 - i, row)?;
            }
        } else {
            for (i, &row) in rows.iter().rev().enumerate() {
                parse_row(i, row)?;
            }
        }

        board.score = match score.parse::<i16>() {
            Ok(x) => x,
            Err(_) => {
                println!("{s}");
                return Err(String::from("Bad score!"));
            }
        };

        board.result = match wdl {
            "1.0" | "[1.0]" | "1" => 2,
            "0.5" | "[0.5]" | "1/2" => 1,
            "0.0" | "[0.0]" | "0" => 0,
            _ => {
                println!("{s}");
                return Err(String::from("Bad game result!"));
            }
        };

        if stm == 1 {
            // `saturating_neg` avoids the overflow panic on `-i16::MIN`.
            board.score = board.score.saturating_neg();
            board.result = 2 - board.result;
        }

        Ok(board)
    }
}