pub mod filter;
use filter::BoardFilter;
pub(super) mod piece;
pub(super) use piece::*;
pub(super) mod piecedefn;
pub(super) use piecedefn::*;
pub(super) mod square;
pub(super) use square::*;
use std::collections::HashMap;
use std::fmt;
/// A chess position: the piece placement plus the side-to-move, castling,
/// en-passant and move-counter values carried by a six-field FEN string.
#[derive(Clone, Debug)]
pub struct Board {
/// Piece on every square, indexed `board[col][row]`. `Board::new` allocates
/// 8 columns of 8 squares; FEN loading stores the last FEN rank in row 0.
board: Vec<Vec<Piece>>,
/// `true` when White is the side to move.
pub(super) white_to_move: bool,
/// Set when the FEN castling field contains `K`.
pub(super) white_king_side_castling: bool,
/// Set when the FEN castling field contains `Q`.
pub(super) white_queen_side_castling: bool,
/// Set when the FEN castling field contains `k`.
pub(super) black_king_side_castling: bool,
/// Set when the FEN castling field contains `q`.
pub(super) black_queen_side_castling: bool,
/// En-passant target square, or `None` when the FEN field is `-`.
pub(super) enpassant_target: Option<Square>,
/// Halfmove clock (fifth FEN field); `Board::new` starts it at 0.
pub(super) halfmove_clock: usize,
/// Fullmove number (sixth FEN field); `Board::new` starts it at 1.
pub(super) fullmove_number: usize,
}
/// Two boards are equal when every state field matches and all squares hold
/// the same piece.
impl PartialEq for Board {
    fn eq(&self, other: &Self) -> bool {
        // Cheap scalar fields first; the full grid comparison runs last.
        // Comparing the vectors directly covers all 64 squares (the previous
        // `0..7` loops never looked at row 7 or column 7).
        self.white_to_move == other.white_to_move
            && self.white_king_side_castling == other.white_king_side_castling
            && self.white_queen_side_castling == other.white_queen_side_castling
            && self.black_king_side_castling == other.black_king_side_castling
            && self.black_queen_side_castling == other.black_queen_side_castling
            && self.enpassant_target == other.enpassant_target
            && self.halfmove_clock == other.halfmove_clock
            && self.fullmove_number == other.fullmove_number
            && self.board == other.board
    }
}
/// Marks `Board`'s `PartialEq` comparison as a full equivalence relation.
impl Eq for Board {}
/// Renders the board as eight text lines, row 7 first and columns 0 to 7 left
/// to right: upper-case letters for White, lower-case for Black and `.` for an
/// empty square. Every line, including the last, ends with `\n`.
impl fmt::Display for Board {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Eight rows of eight glyphs plus one newline each.
        let mut diagram = String::with_capacity(72);
        for row in (0..8).rev() {
            for col in 0..8 {
                let glyph = match self.get_coord(col, row).name {
                    Piece::WK => 'K',
                    Piece::WQ => 'Q',
                    Piece::WR => 'R',
                    Piece::WB => 'B',
                    Piece::WN => 'N',
                    Piece::WP => 'P',
                    Piece::BK => 'k',
                    Piece::BQ => 'q',
                    Piece::BR => 'r',
                    Piece::BB => 'b',
                    Piece::BN => 'n',
                    Piece::BP => 'p',
                    Piece::NO => '.',
                };
                diagram.push(glyph);
            }
            diagram.push('\n');
        }
        f.write_str(&diagram)
    }
}
impl Board {
/// Creates an empty position: 8 columns of 8 `Piece::NO` squares, White to
/// move, no castling rights, no en-passant target, halfmove clock 0 and
/// fullmove number 1.
pub fn new() -> Board {
    let empty_column = vec![Piece::NO; 8];
    Board {
        board: vec![empty_column; 8],
        white_to_move: true,
        white_king_side_castling: false,
        white_queen_side_castling: false,
        black_king_side_castling: false,
        black_queen_side_castling: false,
        enpassant_target: None,
        halfmove_clock: 0,
        fullmove_number: 1,
    }
}
/// Builds a position from a FEN string.
///
/// Accepts either the piece-placement field alone (every other value keeps
/// the default from `Board::new`) or all six space-separated FEN fields.
/// Any side-to-move value other than `w`/`W` is read as Black to move.
/// Halfmove and fullmove counters that do not parse as numbers are ignored,
/// leaving the defaults in place.
///
/// # Errors
///
/// * `"Invalid FEN description"` when the field count is neither 1 nor 6, or
///   the placement field is rejected by `setup_board_from_fen`.
/// * `"Invalid character in FEN description"` for an unknown placement
///   character.
/// * `"Invalid e.p. target"` when the en-passant field is neither `-` nor a
///   square accepted by `Square::from` (this used to panic).
pub fn from_fen(fen: &str) -> Result<Board, &'static str> {
    let fields: Vec<&str> = fen.split(' ').collect();
    if fields.len() != 1 && fields.len() != 6 {
        return Err("Invalid FEN description");
    }

    let mut board = Board::new();
    board.setup_board_from_fen(fields[0])?;

    if fields.len() == 6 {
        board.white_to_move = fields[1].eq_ignore_ascii_case("w");

        let castling = fields[2];
        board.white_king_side_castling = castling.contains('K');
        board.white_queen_side_castling = castling.contains('Q');
        board.black_king_side_castling = castling.contains('k');
        board.black_queen_side_castling = castling.contains('q');

        board.enpassant_target = match fields[3] {
            "-" => None,
            target => match Square::from(target) {
                Ok(square) => Some(square),
                // Malformed input is the caller's problem, not a bug here:
                // report it instead of aborting the program.
                Err(_) => return Err("Invalid e.p. target"),
            },
        };

        // Deliberately lenient: unparsable counters keep their defaults.
        if let Ok(count) = fields[4].parse() {
            board.halfmove_clock = count;
        }
        if let Ok(count) = fields[5].parse() {
            board.fullmove_number = count;
        }
    }
    Ok(board)
}
/// Serialises the position as a six-field FEN string.
///
/// Rows are written from row 7 down to row 0, with runs of empty squares
/// collapsed into a digit. When neither side has any castling right the
/// castling field is written as `-`, so the result always has exactly six
/// space-separated fields and can be fed back to `from_fen`.
pub fn to_fen(&self) -> String {
    let mut fen = String::new();

    // Field 1: piece placement.
    for row in (0..=7).rev() {
        let mut blanks = 0usize;
        for col in 0..=7 {
            let piece = self.board[col][row];
            if piece == Piece::NO {
                blanks += 1;
                continue;
            }
            if blanks > 0 {
                fen.push_str(&blanks.to_string());
                blanks = 0;
            }
            fen.push_str(&piece_to_string(&piece));
        }
        if blanks > 0 {
            fen.push_str(&blanks.to_string());
        }
        if row != 0 {
            fen.push('/');
        }
    }

    // Field 2: side to move.
    fen.push_str(if self.white_to_move { " w " } else { " b " });

    // Field 3: castling availability, `-` when nobody may castle.
    let rights = [
        (self.white_king_side_castling, 'K'),
        (self.white_queen_side_castling, 'Q'),
        (self.black_king_side_castling, 'k'),
        (self.black_queen_side_castling, 'q'),
    ];
    let castling: String = rights
        .iter()
        .filter(|right| right.0)
        .map(|right| right.1)
        .collect();
    if castling.is_empty() {
        fen.push('-');
    } else {
        fen.push_str(&castling);
    }

    // Field 4: en-passant target.
    match &self.enpassant_target {
        Some(square) => fen.push_str(&format!(" {} ", square)),
        None => fen.push_str(" - "),
    }

    // Fields 5 and 6: halfmove clock and fullmove number.
    fen.push_str(&format!("{} {}", self.halfmove_clock, self.fullmove_number));
    fen
}
/// Returns the position described by the standard starting FEN string.
pub fn start_position() -> Board {
    const START_FEN: &str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
    Board::from_fen(START_FEN).unwrap()
}
/// Returns a fresh `BoardFilter` created with `BoardFilter::new()`.
// NOTE(review): presumably the entry point for describing what a board must
// contain — confirm against the `filter` module.
pub fn must_have() -> BoardFilter {
BoardFilter::new()
}
/// Names the side to move: `"White"` or `"Black"`.
pub fn player_to_move(&self) -> String {
    let side = if self.white_to_move { "White" } else { "Black" };
    side.to_string()
}
/// Returns the stored White king-side castling flag.
pub fn can_white_castle_king_side(&self) -> bool {
self.white_king_side_castling
}
/// Returns the stored White queen-side castling flag.
pub fn can_white_castle_queen_side(&self) -> bool {
self.white_queen_side_castling
}
/// Returns the stored Black king-side castling flag.
pub fn can_black_castle_king_side(&self) -> bool {
self.black_king_side_castling
}
/// Returns the stored Black queen-side castling flag.
pub fn can_black_castle_queen_side(&self) -> bool {
self.black_queen_side_castling
}
/// Returns the en-passant target square formatted as text, or an empty
/// string when the position has no en-passant target.
pub fn enpassant_target(&self) -> String {
    self.enpassant_target
        .as_ref()
        .map_or_else(String::new, |square| square.to_string())
}
/// Returns the stored halfmove clock (fifth FEN field when loaded from FEN).
pub fn halfmove_clock(&self) -> usize {
self.halfmove_clock
}
/// Returns the stored fullmove number (sixth FEN field when loaded from FEN).
pub fn fullmove_number(&self) -> usize {
self.fullmove_number
}
/// Counts how many of each piece are on the board.
///
/// The map always has an entry for all twelve piece kinds, with 0 for kinds
/// that are absent; empty squares (`Piece::NO`) are not counted.
fn piece_counts(&self) -> HashMap<Piece, usize> {
    let tracked = [
        Piece::WK,
        Piece::WQ,
        Piece::WB,
        Piece::WN,
        Piece::WR,
        Piece::WP,
        Piece::BK,
        Piece::BQ,
        Piece::BB,
        Piece::BN,
        Piece::BR,
        Piece::BP,
    ];
    // Seed every kind with zero so absent pieces still show up in the map.
    let mut tally: HashMap<Piece, usize> = tracked.iter().map(|&kind| (kind, 0)).collect();
    for &occupant in self.board.iter().flatten() {
        if occupant != Piece::NO {
            *tally.entry(occupant).or_insert(0) += 1;
        }
    }
    tally
}
/// Reports whether `piece`, standing on `start`, can reach `finish` by its
/// movement pattern, dispatching to the per-piece helper.
///
/// Only kings, queens, bishops, knights and rooks are handled; pawns and the
/// empty piece always yield `false`. The piece actually standing on `start`
/// is not consulted.
pub(super) fn can_reach(&self, piece: &Piece, start: &Square, finish: &Square) -> bool {
    match *piece {
        Piece::WK | Piece::BK => self.king_can_reach(start, finish),
        Piece::WQ | Piece::BQ => self.queen_can_reach(start, finish),
        Piece::WB | Piece::BB => self.bishop_can_reach(start, finish),
        Piece::WN | Piece::BN => self.knight_can_reach(start, finish),
        Piece::WR | Piece::BR => self.rook_can_reach(start, finish),
        Piece::WP | Piece::BP | Piece::NO => false,
    }
}
/// A king reaches `finish` when both `step_h` and `step_v` are at most 1
/// (which includes staying on the same square).
fn king_can_reach(&self, start: &Square, finish: &Square) -> bool {
    if step_h(start, finish) > 1 {
        return false;
    }
    step_v(start, finish) <= 1
}
/// A queen reaches `finish` when either a rook or a bishop could; the rook
/// test runs first.
fn queen_can_reach(&self, start: &Square, finish: &Square) -> bool {
    let along_line = self.rook_can_reach(start, finish);
    along_line || self.bishop_can_reach(start, finish)
}
/// A bishop reaches `finish` when the horizontal and vertical step counts
/// are equal and every square strictly between the two is empty.
///
/// The occupant of `finish` itself is not examined. When `start == finish`
/// there is nothing in between, so the answer is `true`; the previous
/// `1..=(h - 1)` range could underflow in that case if the step type is
/// unsigned.
// NOTE(review): assumes `step_h`/`step_v` return the absolute column/row
// distance, so the offsets below never leave the board — confirm in `square`.
fn bishop_can_reach(&self, start: &Square, finish: &Square) -> bool {
    let h = step_h(start, finish);
    let v = step_v(start, finish);
    if h != v {
        return false;
    }
    let towards_higher_col = finish.col > start.col;
    let towards_higher_row = finish.row > start.row;
    // Walk the diagonal one square at a time, excluding both end points.
    (1..h).all(|i| {
        let offset = i as usize;
        let col = if towards_higher_col {
            start.col + offset
        } else {
            start.col - offset
        };
        let row = if towards_higher_row {
            start.row + offset
        } else {
            start.row - offset
        };
        self.get_coord(col, row).name == Piece::NO
    })
}
/// A knight reaches `finish` when the step counts are (2, 1) or (1, 2);
/// intervening pieces are irrelevant.
fn knight_can_reach(&self, start: &Square, finish: &Square) -> bool {
    match (step_h(start, finish), step_v(start, finish)) {
        (2, 1) | (1, 2) => true,
        _ => false,
    }
}
/// A rook reaches `finish` when both squares share a column or a row and
/// every square strictly between them is empty.
///
/// The occupant of `finish` itself is not examined. When `start == finish`
/// there is nothing in between, so the answer is `true` (matching
/// `king_can_reach`); the previous `max - 1` bound underflowed `usize` for
/// that input on row 0.
fn rook_can_reach(&self, start: &Square, finish: &Square) -> bool {
    if start.col == finish.col {
        let low = std::cmp::min(start.row, finish.row);
        let high = std::cmp::max(start.row, finish.row);
        // Half-open range: squares strictly between the end points.
        (low + 1..high).all(|row| self.get_coord(start.col, row).name == Piece::NO)
    } else if start.row == finish.row {
        let low = std::cmp::min(start.col, finish.col);
        let high = std::cmp::max(start.col, finish.col);
        (low + 1..high).all(|col| self.get_coord(col, start.row).name == Piece::NO)
    } else {
        false
    }
}
/// Collects every square whose occupant is one of `pieces`.
///
/// Results are ordered column by column (column 0 first) and, within a
/// column, from row 0 upwards. Takes a slice so callers may pass an array,
/// a slice or a borrowed `Vec`.
fn find_pieces(&self, pieces: &[Piece]) -> Vec<PieceDefn> {
    let mut found = Vec::new();
    for (col, column) in self.board.iter().enumerate() {
        for (row, occupant) in column.iter().enumerate() {
            if pieces.contains(occupant) {
                found.push(PieceDefn {
                    name: *occupant,
                    square: Square { col, row },
                });
            }
        }
    }
    found
}
/// Lists every White piece on the board together with its square.
fn white_pieces(&self) -> Vec<PieceDefn> {
    let white_kinds = vec![
        Piece::WK,
        Piece::WQ,
        Piece::WR,
        Piece::WB,
        Piece::WN,
        Piece::WP,
    ];
    self.find_pieces(&white_kinds)
}
/// Lists every Black piece on the board together with its square.
fn black_pieces(&self) -> Vec<PieceDefn> {
    let black_kinds = vec![
        Piece::BK,
        Piece::BQ,
        Piece::BR,
        Piece::BB,
        Piece::BN,
        Piece::BP,
    ];
    self.find_pieces(&black_kinds)
}
/// Puts `piece` on the square at (`col`, `row`), replacing whatever was
/// there. Panics if either index is outside the board.
pub(super) fn set_coord(&mut self, col: usize, row: usize, piece: &Piece) {
    let cell = &mut self.board[col][row];
    *cell = *piece;
}
/// Puts `piece` on `square`; a `Square`-based wrapper around `set_coord`.
pub(super) fn set_square(&mut self, square: &Square, piece: &Piece) {
    let (col, row) = (square.col, square.row);
    self.set_coord(col, row, piece)
}
/// Tests whether moving `piece` from `start` to `finish` would leave the
/// mover's own king in check.
///
/// The move is tried on a clone: `start` is emptied and `piece` is written
/// to `finish`, overwriting any occupant. No other square is touched, so a
/// castling rook or a pawn captured en passant is not accounted for here.
/// The king examined is that of the side to move in `self`.
pub(super) fn king_left_in_check(
    &self,
    piece: &Piece,
    start: &Square,
    finish: &Square,
) -> bool {
    let mut after_move = self.clone();
    after_move.set_square(start, &Piece::NO);
    after_move.set_square(finish, piece);
    if !self.white_to_move {
        return after_move.is_black_king_in_check();
    }
    after_move.is_white_king_in_check()
}
/// Reports whether any White piece can reach the Black king's square
/// according to `can_reach`.
///
/// Returns `false` when there is no Black king on the board (this used to
/// panic on an empty lookup). Only the first king found is examined.
fn is_black_king_in_check(&self) -> bool {
    match self.locations_of(&Piece::BK).first() {
        Some(black_king) => self
            .white_pieces()
            .iter()
            .any(|defn| self.can_reach(&defn.name, &defn.square, black_king)),
        // A board without a king cannot be in check.
        None => false,
    }
}
/// Reports whether any Black piece can reach the White king's square
/// according to `can_reach`.
///
/// Returns `false` when there is no White king on the board (this used to
/// panic on an empty lookup). Only the first king found is examined.
fn is_white_king_in_check(&self) -> bool {
    match self.locations_of(&Piece::WK).first() {
        Some(white_king) => self
            .black_pieces()
            .iter()
            .any(|defn| self.can_reach(&defn.name, &defn.square, white_king)),
        // A board without a king cannot be in check.
        None => false,
    }
}
/// Lists every square holding `piece` by calling `locations_within` with an
/// empty hint, which every square name contains.
fn locations_of(&self, piece: &Piece) -> Vec<Square> {
self.locations_within(piece, "")
}
/// Lists the squares that hold `piece` and whose textual name contains
/// `hint`.
///
/// `hint` is lower-cased before matching, so `"F"` and `"f"` behave alike and
/// an empty hint matches every square. Results are ordered column by column,
/// row 0 first within each column.
pub(super) fn locations_within(&self, piece: &Piece, hint: &str) -> Vec<Square> {
    let wanted = hint.to_lowercase();
    (0..=7usize)
        .flat_map(|col| (0..=7usize).map(move |row| Square { col, row }))
        .filter(|candidate| {
            candidate.to_string().contains(&wanted)
                && self.board[candidate.col][candidate.row] == *piece
        })
        .collect()
}
/// Returns the occupant of (`col`, `row`) paired with its square. Panics if
/// either index is outside the board.
pub(super) fn get_coord(&self, col: usize, row: usize) -> PieceDefn {
    let name = self.board[col][row];
    let square = Square { col, row };
    PieceDefn { name, square }
}
/// Returns the occupant of `square` paired with its square; a `Square`-based
/// wrapper around `get_coord`.
pub(super) fn get_square(&self, square: &Square) -> PieceDefn {
self.get_coord(square.col, square.row)
}
/// Fills the board from the piece-placement field of a FEN string.
///
/// The field must consist of eight `/`-separated ranks. The first rank in the
/// text is stored in row 7 and the last in row 0; within a rank, squares are
/// filled from column 0 upwards and a digit skips that many columns.
/// A rank describing fewer than eight squares is accepted and leaves the
/// remaining squares untouched.
///
/// # Errors
///
/// * `"Invalid FEN description"` when there are not exactly eight ranks, or
///   a rank describes more than eight squares (this used to panic with an
///   out-of-bounds index, or pass silently for digits).
/// * `"Invalid character in FEN description"` for any character that is
///   neither a piece letter nor a digit from 1 to 8.
///
/// On error the board may already have been partially filled.
fn setup_board_from_fen(&mut self, fen: &str) -> Result<(), &'static str> {
    let ranks: Vec<&str> = fen.split('/').collect();
    if ranks.len() != 8 {
        return Err("Invalid FEN description");
    }
    for (fen_row, rank) in ranks.iter().enumerate() {
        // FEN lists the top rank first, whereas row 0 is the bottom rank.
        let row = 7 - fen_row;
        let mut col = 0;
        for ch in rank.chars() {
            match ch {
                'K' | 'k' | 'Q' | 'q' | 'R' | 'r' | 'B' | 'b' | 'N' | 'n' | 'P' | 'p' => {
                    if col > 7 {
                        return Err("Invalid FEN description");
                    }
                    self.board[col][row] = piece_from(&ch.to_string());
                    col += 1;
                }
                '1'..='8' => {
                    col += (ch as u8 - b'0') as usize;
                    if col > 8 {
                        return Err("Invalid FEN description");
                    }
                }
                _ => return Err("Invalid character in FEN description"),
            }
        }
    }
    Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
// A freshly created board must report `Piece::NO` on all 64 squares.
#[test]
fn test_must_initialise_empty_board() {
    let board = Board::new();
    for col in 0..8 {
        for row in 0..8 {
            let expected = PieceDefn {
                name: Piece::NO,
                square: Square { col, row },
            };
            assert_eq!(expected, board.get_coord(col, row));
        }
    }
}
// Strings with a bad field count or a malformed placement field are rejected.
#[test]
fn test_invalid_fen() {
    let malformed = ["1 2 3 4 5 6", "1", "1 2 3"];
    for fen in malformed.iter() {
        assert_eq!(Err("Invalid FEN description"), Board::from_fen(fen));
    }
}
// Asserts that the square at (`col`, `row`) holds the piece named by `piece`.
fn test_piece(board: &Board, col: usize, row: usize, piece: &str) {
    let expected = PieceDefn {
        name: piece_from(piece),
        square: Square { col, row },
    };
    assert_eq!(expected, board.get_coord(col, row));
}
// Loading a six-field FEN string must populate the squares and every state
// field. A parse failure now aborts the test with the returned error message
// instead of a bare `assert!(false)`.
#[test]
fn test_setup_fen() {
    let board = Board::from_fen("3k4/7p/8/8/8/8/2P5/R7 b Kq - 5 46")
        .expect("a well-formed FEN string must parse");
    test_piece(&board, 0, 0, "R");
    test_piece(&board, 2, 1, "P");
    test_piece(&board, 7, 6, "p");
    test_piece(&board, 3, 7, "k");
    assert!(!board.white_to_move);
    assert!(board.white_king_side_castling);
    assert!(!board.white_queen_side_castling);
    assert!(!board.black_king_side_castling);
    assert!(board.black_queen_side_castling);
    assert_eq!(None, board.enpassant_target);
    assert_eq!(5, board.halfmove_clock);
    assert_eq!(46, board.fullmove_number);
}
// Parsing a FEN string and serialising it again must give back the input.
#[test]
fn test_to_fen() {
    let round_trips = [
        "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",
        "3k4/7p/8/8/8/8/2P5/R7 b Kq - 5 46",
    ];
    for &fen in round_trips.iter() {
        let position = Board::from_fen(fen).unwrap();
        assert_eq!(fen, position.to_fen());
    }
}
// The start position has full castling rights, default counters, both back
// rows in R-N-B-Q-K-B-N-R order, pawns on rows 1 and 6, and nothing between.
#[test]
fn test_makes_start() {
    let board = Board::start_position();
    assert!(board.white_to_move);
    assert!(board.white_king_side_castling);
    assert!(board.white_queen_side_castling);
    assert!(board.black_king_side_castling);
    assert!(board.black_queen_side_castling);
    assert_eq!(None, board.enpassant_target);
    assert_eq!(0, board.halfmove_clock);
    assert_eq!(1, board.fullmove_number);

    // White's back row uses the upper-case letter, Black's the lower-case one.
    for (col, letter) in "RNBQKBNR".chars().enumerate() {
        test_piece(&board, col, 0, &letter.to_string());
        test_piece(&board, col, 7, &letter.to_ascii_lowercase().to_string());
    }

    for col in 0..8 {
        test_piece(&board, col, 1, "P");
        test_piece(&board, col, 6, "p");
        for row in 2..6 {
            let empty = PieceDefn {
                name: Piece::NO,
                square: Square { col, row },
            };
            assert_eq!(empty, board.get_coord(col, row));
        }
    }
}
// `piece_counts` must list all twelve piece kinds with the right totals, both
// for the start position and for a sparse board.
#[test]
fn test_piece_counts() {
    let cases = vec![
        (
            Board::start_position(),
            [
                (Piece::WR, 2),
                (Piece::WB, 2),
                (Piece::WN, 2),
                (Piece::WK, 1),
                (Piece::WQ, 1),
                (Piece::WP, 8),
                (Piece::BR, 2),
                (Piece::BB, 2),
                (Piece::BN, 2),
                (Piece::BK, 1),
                (Piece::BQ, 1),
                (Piece::BP, 8),
            ],
        ),
        (
            Board::from_fen("3k4/7p/8/8/8/8/2P5/R7").unwrap(),
            [
                (Piece::WR, 1),
                (Piece::WB, 0),
                (Piece::WN, 0),
                (Piece::WK, 0),
                (Piece::WQ, 0),
                (Piece::WP, 1),
                (Piece::BR, 0),
                (Piece::BB, 0),
                (Piece::BN, 0),
                (Piece::BK, 1),
                (Piece::BQ, 0),
                (Piece::BP, 1),
            ],
        ),
    ];
    for (board, expected) in cases.iter() {
        let counts = board.piece_counts();
        assert_eq!(12, counts.keys().len());
        for (piece, count) in expected.iter() {
            let actual: usize = *counts.get(piece).unwrap_or(&0);
            assert_eq!(*count, actual);
        }
    }
}
// White's bishops start on c1 and f1; hints narrow the result by square name
// and are matched case-insensitively.
#[test]
fn test_locations() {
    let start = Board::start_position();
    let c1 = Square { col: 2, row: 0 };
    let f1 = Square { col: 5, row: 0 };
    let both = vec![c1, f1.clone()];

    assert_eq!(both, start.locations_of(&Piece::WB));
    assert!(start.locations_within(&Piece::WB, "Q").is_empty());
    assert_eq!(vec![f1], start.locations_within(&Piece::WB, "f"));
    assert_eq!(both, start.locations_within(&Piece::WB, "1"));
}
// All four knights are found, ordered by column and then by row.
#[test]
fn test_find_pieces() {
    let knights = Board::start_position().find_pieces(&vec![Piece::WN, Piece::BN]);
    assert_eq!(4, knights.len());

    let mut squares: Vec<String> = Vec::new();
    for knight in &knights {
        squares.push(knight.square.to_string());
    }
    assert_eq!(vec!["b1", "b8", "g1", "g8"], squares);
}
// Reachability fixtures: (piece, start square, finish square, expected
// result, piece-placement FEN). Each board has the tested piece standing on
// its start square; two entries previously showed a rook there instead.
const DEFNS: &[(&str, &str, &str, bool, &str)] = &[
    ("R", "e4", "e8", true, "8/8/8/8/4R3/8/8/8"),
    ("R", "e4", "d8", false, "8/8/8/8/4R3/8/8/8"),
    ("r", "e4", "c4", true, "8/8/8/8/4r3/8/8/8"),
    ("B", "d4", "h8", true, "8/8/8/8/3B4/8/8/8"),
    ("b", "g8", "a2", true, "6b1/8/8/8/8/8/8/8"),
    ("B", "d4", "d8", false, "8/8/8/8/3B4/8/8/8"),
    ("b", "b4", "e1", false, "8/8/8/8/1b6/2N5/8/4K3"),
    ("Q", "e4", "e8", true, "8/8/8/8/4Q3/8/8/8"),
    ("Q", "e4", "d8", false, "8/8/8/8/4Q3/8/8/8"),
    ("Q", "e4", "h1", true, "8/8/8/8/4Q3/8/8/8"),
    ("Q", "e4", "a1", false, "8/8/8/8/4Q3/8/8/8"),
    ("K", "e1", "f1", true, "8/8/8/8/8/8/8/4K3"),
    ("K", "e1", "f2", true, "8/8/8/8/8/8/8/4K3"),
    ("K", "e1", "f3", false, "8/8/8/8/8/8/8/4K3"),
    ("K", "e1", "g1", false, "8/8/8/8/8/8/8/4K3"),
    ("K", "e4", "d3", true, "8/8/8/8/4K3/8/8/8"),
    ("N", "g1", "f3", true, "8/8/8/8/8/8/8/6N1"),
    ("N", "g1", "g3", false, "8/8/8/8/8/8/8/6N1"),
    ("n", "e5", "d7", true, "8/8/8/4n3/8/8/8/8"),
];
// Runs every reachability fixture in `DEFNS` through `Board::can_reach`.
#[test]
fn test_reach_squares() {
    for &(piece, from, to, reachable, fen) in DEFNS {
        let board = Board::from_fen(fen).expect("Internal error in tests");
        let start = Square::from(from).unwrap();
        let finish = Square::from(to).unwrap();
        assert_eq!(
            reachable,
            board.can_reach(&piece_from(piece), &start, &finish)
        );
    }
}
}