// chess/fen.rs — parsing and formatting of FEN (Forsyth–Edwards Notation) position strings.

1use std::fmt;
2use std::str::FromStr;
3
4use crate::board::BoardState;
5use crate::errors::FenParseError;
6use crate::log_and_return_error;
7use crate::movegen::{MovegenFlags, Piece, PieceColour, PieceType, Square};
8use crate::position::{Pos64, Position, ABOVE_BELOW};
9
10#[derive(Debug, Clone, Copy)]
11pub struct FEN {
12    pos64: Pos64,
13    side: PieceColour,
14    movegen_flags: MovegenFlags,
15    halfmove_count: u32,
16    move_count: u32,
17}
18
19impl FromStr for FEN {
20    type Err = FenParseError;
21
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        let fen_vec: Vec<&str> = s.split(' ').collect();
24        // check if the FEN string has the correct number of fields, accept the last two as optional with default values given in BoardState
25        if fen_vec.len() < 4 || fen_vec.len() > 6 {
26            return Err(FenParseError(format!(
27                "Invalid number of fields in FEN string: {}. Expected at least 4, max 6",
28                fen_vec.len()
29            )));
30        }
31        let mut fen = Self::new();
32        // first field of FEN defines the piece positions
33        fen.parse_pos_field(fen_vec[0])?;
34        // second filed of FEN defines which side it is to move, either 'w' or 'b'
35        fen.parse_side_field(fen_vec[1])?;
36        // third field of FEN defines castling flags
37        fen.parse_castling_flags(fen_vec[2])?;
38        // fourth field of FEN defines en passant flag, it gives notation of the square the pawn jumped over
39        fen.parse_en_passant_flag(fen_vec[3])?;
40        // set last two fields if they exist, otherwise default values are 0 and 1 already set in new()
41        fen.parse_halfmove_move_count(fen_vec.get(4).copied(), fen_vec.get(5).copied())?;
42
43        Ok(fen)
44    }
45}
46
47impl fmt::Display for FEN {
48    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49        let mut fen_str = String::new();
50
51        let mut empty_count: i32 = 0;
52        for (idx, sq) in self.pos64.iter().enumerate() {
53            match sq {
54                Square::Piece(p) => {
55                    if empty_count > 0 {
56                        fen_str.push_str(empty_count.to_string().as_str());
57                        empty_count = 0;
58                    }
59
60                    match p.ptype {
61                        PieceType::Pawn => match p.pcolour {
62                            PieceColour::White => fen_str.push('P'),
63                            PieceColour::Black => fen_str.push('p'),
64                        },
65                        PieceType::Knight => match p.pcolour {
66                            PieceColour::White => fen_str.push('N'),
67                            PieceColour::Black => fen_str.push('n'),
68                        },
69                        PieceType::Bishop => match p.pcolour {
70                            PieceColour::White => fen_str.push('B'),
71                            PieceColour::Black => fen_str.push('b'),
72                        },
73                        PieceType::Rook => match p.pcolour {
74                            PieceColour::White => fen_str.push('R'),
75                            PieceColour::Black => fen_str.push('r'),
76                        },
77                        PieceType::Queen => match p.pcolour {
78                            PieceColour::White => fen_str.push('Q'),
79                            PieceColour::Black => fen_str.push('q'),
80                        },
81                        PieceType::King => match p.pcolour {
82                            PieceColour::White => fen_str.push('K'),
83                            PieceColour::Black => fen_str.push('k'),
84                        },
85                    }
86                }
87                Square::Empty => {
88                    empty_count += 1;
89                }
90            }
91
92            // new rank insert '/', except when at last index, then only insert empty count if it's > 0
93            if (idx + 1) % 8 == 0 {
94                if empty_count > 0 {
95                    fen_str.push_str(empty_count.to_string().as_str());
96                    empty_count = 0;
97                }
98                if idx != 63 {
99                    fen_str.push('/');
100                }
101            }
102        }
103        fen_str.push(' ');
104
105        match self.side {
106            PieceColour::White => fen_str.push('w'),
107            PieceColour::Black => fen_str.push('b'),
108        }
109        fen_str.push(' ');
110
111        if self.movegen_flags.white_castle_short {
112            fen_str.push('K');
113        }
114        if self.movegen_flags.white_castle_long {
115            fen_str.push('Q');
116        }
117        if self.movegen_flags.black_castle_short {
118            fen_str.push('k');
119        }
120        if self.movegen_flags.black_castle_long {
121            fen_str.push('q');
122        }
123        if !(self.movegen_flags.white_castle_short
124            || self.movegen_flags.white_castle_long
125            || self.movegen_flags.black_castle_short
126            || self.movegen_flags.black_castle_long)
127        {
128            fen_str.push('-');
129        }
130        fen_str.push(' ');
131
132        match self.movegen_flags.en_passant {
133            Some(idx) => {
134                if self.side == PieceColour::White {
135                    fen_str.push_str(index_to_notation(idx - ABOVE_BELOW).as_str());
136                } else {
137                    fen_str.push_str(index_to_notation(idx + ABOVE_BELOW).as_str());
138                }
139            }
140            None => {
141                fen_str.push('-');
142            }
143        }
144        fen_str.push(' ');
145        fen_str.push_str(&format!("{} {}", self.halfmove_count, self.move_count));
146
147        write!(f, "{}", fen_str)
148    }
149}
150
151impl From<&BoardState> for FEN {
152    fn from(board_state: &BoardState) -> Self {
153        let mut fen = Self::from(board_state.position());
154        fen.halfmove_count = board_state.halfmove_count();
155        fen.move_count = board_state.move_count();
156        fen
157    }
158}
159
160impl From<&Position> for FEN {
161    fn from(pos: &Position) -> Self {
162        // default halfmove and move count to 0 and 1 respectively as Position does not store this information
163        Self {
164            pos64: pos.pos64,
165            side: pos.side,
166            movegen_flags: pos.movegen_flags,
167            halfmove_count: 0,
168            move_count: 1,
169        }
170    }
171}
172
173impl FEN {
174    fn new() -> Self {
175        Self {
176            pos64: Pos64::default(),
177            side: PieceColour::White,
178            movegen_flags: MovegenFlags::default(),
179            halfmove_count: 0,
180            move_count: 1,
181        }
182    }
183
184    pub fn pos64(&self) -> Pos64 {
185        self.pos64
186    }
187
188    pub fn side(&self) -> PieceColour {
189        self.side
190    }
191
192    pub fn movegen_flags(&self) -> MovegenFlags {
193        self.movegen_flags
194    }
195
196    pub fn halfmove_count(&self) -> u32 {
197        self.halfmove_count
198    }
199
200    pub fn move_count(&self) -> u32 {
201        self.move_count
202    }
203
204    fn parse_pos_field(&mut self, field: &str) -> Result<(), FenParseError> {
205        let mut pos = Pos64::default();
206        let mut rank_start_idx = 0;
207        // check for multiple kings, should be the only issue in terms of pieces on the board
208        let mut wking_num = 0;
209        let mut bking_num = 0;
210        for rank in field.split('/') {
211            // check to see if there is 8 squares in a rank.
212            let mut square_count = 0;
213            for c in rank.chars() {
214                if c.is_ascii_digit() {
215                    let num = c.to_digit(10).unwrap();
216                    square_count += num;
217                } else {
218                    square_count += 1;
219                }
220            }
221            if square_count != 8 {
222                return Err(FenParseError(format!(
223                    "Invalid number of squares in rank: {}. Expected 8, got {}",
224                    rank, square_count
225                )));
226            }
227
228            let mut i = 0;
229            for c in rank.chars() {
230                let square = match c {
231                    'p' => Square::Piece(Piece {
232                        pcolour: PieceColour::Black,
233                        ptype: PieceType::Pawn,
234                    }),
235                    'P' => Square::Piece(Piece {
236                        pcolour: PieceColour::White,
237                        ptype: PieceType::Pawn,
238                    }),
239                    'r' => Square::Piece(Piece {
240                        pcolour: PieceColour::Black,
241                        ptype: PieceType::Rook,
242                    }),
243                    'R' => Square::Piece(Piece {
244                        pcolour: PieceColour::White,
245                        ptype: PieceType::Rook,
246                    }),
247                    'n' => Square::Piece(Piece {
248                        pcolour: PieceColour::Black,
249                        ptype: PieceType::Knight,
250                    }),
251                    'N' => Square::Piece(Piece {
252                        pcolour: PieceColour::White,
253                        ptype: PieceType::Knight,
254                    }),
255                    'b' => Square::Piece(Piece {
256                        pcolour: PieceColour::Black,
257                        ptype: PieceType::Bishop,
258                    }),
259                    'B' => Square::Piece(Piece {
260                        pcolour: PieceColour::White,
261                        ptype: PieceType::Bishop,
262                    }),
263                    'q' => Square::Piece(Piece {
264                        pcolour: PieceColour::Black,
265                        ptype: PieceType::Queen,
266                    }),
267                    'Q' => Square::Piece(Piece {
268                        pcolour: PieceColour::White,
269                        ptype: PieceType::Queen,
270                    }),
271                    'k' => {
272                        bking_num += 1;
273                        Square::Piece(Piece {
274                            pcolour: PieceColour::Black,
275                            ptype: PieceType::King,
276                        })
277                    }
278                    'K' => {
279                        wking_num += 1;
280                        Square::Piece(Piece {
281                            pcolour: PieceColour::White,
282                            ptype: PieceType::King,
283                        })
284                    }
285                    x if x.is_ascii_digit() => {
286                        for _ in 0..x.to_digit(10).unwrap() {
287                            pos[i + rank_start_idx] = Square::Empty;
288                            i += 1;
289                        }
290                        continue; // skip the below square assignment for pieces
291                    }
292                    other => {
293                        let err = FenParseError(format!("Invalid char in first field: {}", other));
294                        log_and_return_error!(err)
295                    }
296                };
297                pos[i + rank_start_idx] = square;
298                i += 1;
299            }
300            rank_start_idx += 8; // next rank
301        }
302
303        if wking_num > 1 || bking_num > 1 {
304            let err = FenParseError(format!(
305                "Multiple kings (white: {}, black: {}) in FEN field: {}",
306                wking_num, bking_num, field
307            ));
308            log_and_return_error!(err)
309        }
310
311        self.pos64 = pos;
312        Ok(())
313    }
314
315    fn parse_side_field(&mut self, field: &str) -> Result<(), FenParseError> {
316        match field {
317            "w" => {
318                self.side = PieceColour::White;
319            }
320            "b" => {
321                self.side = PieceColour::Black;
322            }
323            other => {
324                return Err(FenParseError(format!(
325                    "Invalid second field: {}. Expected 'w' or 'b'",
326                    other
327                )));
328            }
329        }
330        Ok(())
331    }
332
333    fn parse_castling_flags(&mut self, field: &str) -> Result<(), FenParseError> {
334        for c in field.chars() {
335            match c {
336                'q' => {
337                    self.movegen_flags.black_castle_long = true;
338                }
339                'Q' => {
340                    self.movegen_flags.white_castle_long = true;
341                }
342                'k' => {
343                    self.movegen_flags.black_castle_short = true;
344                }
345                'K' => {
346                    self.movegen_flags.white_castle_short = true;
347                }
348                '-' => {}
349                other => {
350                    return Err(FenParseError(format!(
351                        "Invalid char in third field: {}",
352                        other
353                    )));
354                }
355            }
356        }
357        Ok(())
358    }
359
360    fn parse_en_passant_flag(&mut self, field: &str) -> Result<(), FenParseError> {
361        if field != "-" {
362            let ep_mv_idx = notation_to_index(field)?;
363
364            // error if index is out of bounds. FEN defines the index behind the pawn that moved, so valid indexes are only 16->47 (excluded top and bottom two ranks)
365            if !(16..=47).contains(&ep_mv_idx) {
366                return Err(FenParseError(format!(
367                    "Invalid en passant square: {}. Index is out of bounds",
368                    field
369                )));
370            }
371
372            // in our struct however, we store the idx of the pawn to be captured
373            let ep_flag = if self.side == PieceColour::White {
374                ep_mv_idx + ABOVE_BELOW
375            } else {
376                ep_mv_idx - ABOVE_BELOW
377            };
378            self.movegen_flags.en_passant = Some(ep_flag);
379
380            // set polyglot en passant flag if the ep_flag is beside a pawn of side to move colour
381            if self.pos64.polyglot_is_pawn_beside(ep_flag, self.side) {
382                self.movegen_flags.polyglot_en_passant = Some(ep_flag);
383            }
384        }
385        Ok(())
386    }
387
388    fn parse_halfmove_move_count(
389        &mut self,
390        hm_field: Option<&str>,
391        m_field: Option<&str>,
392    ) -> Result<(), FenParseError> {
393        if let Some(hm) = hm_field {
394            self.halfmove_count = if let Ok(halfmove_count) = hm.parse::<u32>() {
395                halfmove_count
396            } else {
397                let err = FenParseError(format!("Error parsing halfmove count: {}", hm));
398                log_and_return_error!(err)
399            };
400        };
401
402        if let Some(m) = m_field {
403            self.move_count = if let Ok(move_count) = m.parse::<u32>() {
404                move_count
405            } else {
406                let err = FenParseError(format!("Error parsing move count: {}", m));
407                log_and_return_error!(err)
408            };
409        }
410        Ok(())
411    }
412}
413
414#[inline]
415fn notation_to_index(n: &str) -> Result<usize, FenParseError> {
416    if n.len() != 2
417        || n.chars().next().unwrap() < 'a'
418        || n.chars().next().unwrap() > 'h'
419        || n.chars().nth(1).unwrap() < '1'
420        || n.chars().nth(1).unwrap() > '8'
421    {
422        log_and_return_error!(FenParseError(format!(
423            "Invalid notation ({}) when converting to index:",
424            n
425        )))
426    }
427    let file: char = n.chars().next().unwrap();
428    let rank: char = n.chars().nth(1).unwrap();
429    let rank_starts = [56, 48, 40, 32, 24, 16, 8, 0]; // 1st to 8th rank starting indexes
430    let file_offset = match file {
431        'a' => 0,
432        'b' => 1,
433        'c' => 2,
434        'd' => 3,
435        'e' => 4,
436        'f' => 5,
437        'g' => 6,
438        'h' => 7,
439        _ => unreachable!(), // see error checking at start of function
440    };
441    let rank_digit = rank.to_digit(10).unwrap();
442    Ok(file_offset + rank_starts[(rank_digit - 1) as usize])
443}
444
/// Converts a board index (a8 = 0, h8 = 7, a1 = 56, h1 = 63) into algebraic notation,
/// e.g. `44` -> `"e3"`.
///
/// Only indices `0..64` are meaningful. Indices of 72 and above panic, because the rank
/// computation underflows.
#[inline]
fn index_to_notation(i: usize) -> String {
    const FILES: [char; 8] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'];
    let file = FILES[i % 8];
    // Row 0 holds the 8th rank, so the rank number counts down as the row increases.
    let rank_num = 8 - i / 8;
    let rank = char::from_digit(rank_num.try_into().unwrap(), 10).unwrap();
    [file, rank].iter().collect()
}
462
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fen_from_str_valid() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
        let fen = FEN::from_str(fen_str).unwrap();
        assert_eq!(fen.to_string(), fen_str);
    }

    #[test]
    fn test_fen_from_str_invalid_fields() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_from_str_too_many_fields() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1 extra";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_from_str_optional_counters_default() {
        let fen = FEN::from_str("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq -").unwrap();
        assert_eq!(fen.halfmove_count(), 0);
        assert_eq!(fen.move_count(), 1);
    }

    #[test]
    fn test_fen_from_str_invalid_piece_positions() {
        let fen_str = "rnbqkbnr/pppppppp/0/8/8/8/PPPPPPPP/RNBQKBNKK w KQkq - 0 1";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_from_str_invalid_side() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR xw KQkq - 0 1";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_from_str_invalid_castling_flags() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KdQkq - 0 1";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_round_trip_partial_castling() {
        let fen_str = "r3k2r/8/8/8/8/8/8/R3K2R b Kq - 5 20";
        assert_eq!(FEN::from_str(fen_str).unwrap().to_string(), fen_str);
    }

    #[test]
    fn test_fen_from_str_invalid_en_passant() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq x2 0 1";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_round_trip_en_passant_white_to_move() {
        let fen_str = "rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq e6 0 2";
        assert_eq!(FEN::from_str(fen_str).unwrap().to_string(), fen_str);
    }

    #[test]
    fn test_fen_round_trip_en_passant_black_to_move() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1";
        assert_eq!(FEN::from_str(fen_str).unwrap().to_string(), fen_str);
    }

    #[test]
    fn test_fen_from_str_invalid_halfmove_count() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - x 1";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_from_str_invalid_move_count() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 x";
        assert!(FEN::from_str(fen_str).is_err());
    }

    #[test]
    fn test_fen_to_string() {
        let fen = FEN::new();
        let fen_str = "8/8/8/8/8/8/8/8 w - - 0 1";
        assert_eq!(fen.to_string(), fen_str);
    }

    #[test]
    fn test_fen_from_board_state() {
        let board_state = BoardState::new_starting();
        let fen = FEN::from(&board_state);
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
        assert_eq!(fen.to_string(), fen_str);
    }

    #[test]
    fn test_fen_to_board_state() {
        let fen_str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
        let fen = FEN::from_str(fen_str).unwrap();
        let board_state: BoardState = fen.into();
        let fen_from_board = FEN::from(&board_state);
        assert_eq!(fen_from_board.to_string(), fen_str);
    }

    #[test]
    fn test_notation_to_index() {
        assert_eq!(notation_to_index("a1").unwrap(), 56);
        assert_eq!(notation_to_index("h8").unwrap(), 7);
        assert_eq!(notation_to_index("d4").unwrap(), 35);
        assert!(notation_to_index("i9").is_err());
        assert!(notation_to_index("a9").is_err());
        assert!(notation_to_index("z1").is_err());
    }

    #[test]
    fn test_index_to_notation() {
        assert_eq!(index_to_notation(56), "a1");
        assert_eq!(index_to_notation(7), "h8");
        assert_eq!(index_to_notation(35), "d4");
    }

    #[test]
    fn test_notation_index_round_trip() {
        for i in 0..64 {
            assert_eq!(notation_to_index(&index_to_notation(i)).unwrap(), i);
        }
    }
}