Skip to main content

timecat/utils/
moves.rs

1use super::*;
2
3#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
4#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)]
5pub struct Move {
6    source: Square,
7    dest: Square,
8    promotion: Option<PieceType>,
9}
10
11impl Move {
12    #[inline]
13    pub const unsafe fn new_unchecked(
14        source: Square,
15        dest: Square,
16        promotion: Option<PieceType>,
17    ) -> Self {
18        Self {
19            source,
20            dest,
21            promotion,
22        }
23    }
24
25    #[inline]
26    pub const fn new(source: Square, dest: Square, promotion: Option<PieceType>) -> Result<Self> {
27        let move_ = unsafe { Self::new_unchecked(source, dest, promotion) };
28        if source.to_int() == dest.to_int() {
29            return Err(TimecatError::SameSourceAndDestination { move_ });
30        }
31        if let Some(Pawn) | Some(King) = promotion {
32            return Err(TimecatError::InvalidPromotion { move_ });
33        }
34        Ok(move_)
35    }
36
37    #[inline]
38    pub const fn get_source(&self) -> Square {
39        self.source
40    }
41
42    #[inline]
43    pub const fn get_dest(&self) -> Square {
44        self.dest
45    }
46
47    #[inline]
48    pub const fn get_promotion(&self) -> Option<PieceType> {
49        self.promotion
50    }
51
52    #[inline]
53    pub fn from_uci(uci: &str) -> Result<Self> {
54        uci.parse()
55    }
56
57    pub fn from_san(position: &ChessPosition, san: &str) -> Result<Self> {
58        // TODO: Make the logic better
59        let san = san.trim().replace('0', "O");
60        for move_ in position.generate_legal_moves() {
61            if move_.san(position).unwrap() == san {
62                return Ok(move_);
63            }
64        }
65        Err(TimecatError::InvalidSanMoveString { s: san.into() })
66    }
67
68    pub fn from_lan(position: &ChessPosition, lan: &str) -> Result<Self> {
69        // TODO: Make the logic better
70        let lan = lan.trim().replace('0', "O");
71        for move_ in position.generate_legal_moves() {
72            if move_.lan(position).unwrap() == lan {
73                return Ok(move_);
74            }
75        }
76        Err(TimecatError::InvalidLanMoveString { s: lan.into() })
77    }
78
79    pub fn algebraic_without_suffix(
80        self,
81        position: &ChessPosition,
82        long: bool,
83    ) -> Result<Cow<'static, str>> {
84        let source = self.get_source();
85        let dest = self.get_dest();
86
87        // Castling.
88        if position.is_castling(self) {
89            return if dest.get_file() < source.get_file() {
90                Ok(Cow::Borrowed("O-O-O"))
91            } else {
92                Ok(Cow::Borrowed("O-O"))
93            };
94        }
95
96        let piece = position.get_piece_type_at(source).ok_or_else(|| {
97            TimecatError::InvalidSanOrLanMove {
98                valid_or_null_move: self.into(),
99                fen: position.get_fen().into(),
100            }
101        })?;
102        let capture = position.is_capture(self);
103        let mut san = if piece == Pawn {
104            String::new()
105        } else {
106            piece.to_colored_piece_str(White).into()
107        };
108
109        if long {
110            write_unchecked!(san, "{}", source);
111        } else if piece != Pawn {
112            // Get ambiguous move candidates.
113            // Relevant candidates: not exactly the current move,
114            // but to the same square.
115            let mut others = BitBoard::EMPTY;
116            let from_mask =
117                position.get_piece_mask(piece) & position.self_occupied() & !source.to_bitboard();
118            let to_mask = dest.to_bitboard();
119            for candidate in position.generate_masked_legal_moves(from_mask, to_mask) {
120                others |= candidate.get_source().to_bitboard();
121            }
122
123            // Disambiguate.
124            if !others.is_empty() {
125                let (mut row, mut column) = (false, false);
126                if !(others & source.get_rank_bb()).is_empty() {
127                    column = true;
128                }
129                if !(others & source.get_file_bb()).is_empty() {
130                    row = true;
131                } else {
132                    column = true;
133                }
134                if column {
135                    san.push(
136                        "abcdefgh"
137                            .chars()
138                            .nth(source.get_file().to_index())
139                            .unwrap(),
140                    );
141                }
142                if row {
143                    write_unchecked!(san, "{}", source.get_rank().to_index() + 1);
144                }
145            }
146        } else if capture {
147            san.push(
148                "abcdefgh"
149                    .chars()
150                    .nth(source.get_file().to_index())
151                    .unwrap(),
152            );
153        }
154
155        // Captures.
156        if capture {
157            san.push('x');
158        } else if long {
159            san.push('-');
160        }
161
162        // Destination square.
163        write_unchecked!(san, "{}", dest);
164
165        // Promotion.
166        if let Some(promotion) = self.get_promotion() {
167            write_unchecked!(san, "={}", promotion.to_colored_piece_str(White));
168        }
169
170        Ok(san.into())
171    }
172
173    pub fn algebraic_and_new_position(
174        self,
175        position: &ChessPosition,
176        long: bool,
177    ) -> Result<(Cow<'static, str>, ChessPosition)> {
178        let san = self.algebraic_without_suffix(position, long)?;
179
180        // Look ahead for check or checkmate.
181        let new_position = position.make_move_new(self);
182        let is_checkmate = new_position.is_checkmate();
183
184        // Add check or checkmate suffix.
185        let san = if is_checkmate {
186            san + "#"
187        } else if new_position.is_check() {
188            san + "+"
189        } else {
190            san
191        };
192        Ok((san, new_position))
193    }
194
195    #[cfg(feature = "pyo3")]
196    fn from_py_move(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
197        let source = ob.getattr("from_square")?.extract()?;
198        let dest = ob.getattr("to_square")?.extract()?;
199        let promotion = ob.getattr("promotion")?.extract()?;
200        Ok(Self::new(source, dest, promotion)?)
201    }
202}
203
// Expands to the `InvalidUciMoveString` error carrying an owned copy of the UCI text
// bound to `$text`.
macro_rules! generate_move_error {
    ($text:ident) => {
        TimecatError::InvalidUciMoveString {
            s: $text.to_string().into(),
        }
    };
}
211
212impl FromStr for Move {
213    type Err = TimecatError;
214
215    fn from_str(mut s: &str) -> Result<Self> {
216        s = s.trim();
217        if s.len() > 5 {
218            return Err(generate_move_error!(s));
219        }
220        let source = s
221            .get(0..2)
222            .ok_or_else(|| generate_move_error!(s))?
223            .parse()
224            .map_err(|_| generate_move_error!(s))?;
225        let dest = s
226            .get(2..4)
227            .ok_or_else(|| generate_move_error!(s))?
228            .parse()
229            .map_err(|_| generate_move_error!(s))?;
230        let promotion = (s.len() > 4)
231            .then(|| {
232                s.get(4..)
233                    .ok_or_else(|| generate_move_error!(s))?
234                    .parse()
235                    .map_err(|_| generate_move_error!(s))
236            })
237            .transpose()?;
238        Self::new(source, dest, promotion)
239    }
240}
241
242impl fmt::Display for Move {
243    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
244        match self.promotion {
245            Some(piece) => write!(f, "{}{}{}", self.source, self.dest, piece),
246            None => write!(f, "{}{}", self.source, self.dest),
247        }
248    }
249}
250
#[cfg(feature = "pyo3")]
impl<'source> FromPyObject<'source> for Move {
    /// Accepts either a UCI string (e.g. `"e2e4"`) or an object exposing `from_square`,
    /// `to_square` and `promotion` attributes; the string form is tried first.
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        let parsed_from_text = ob
            .extract::<&str>()
            .ok()
            .and_then(|move_text| move_text.parse::<Self>().ok());
        match parsed_from_text {
            Some(move_) => Ok(move_),
            None => Self::from_py_move(ob).map_err(|_| {
                Pyo3Error::Pyo3TypeConversionError {
                    from: ob.to_string().into(),
                    to: std::any::type_name::<Self>().into(),
                }
                .into()
            }),
        }
    }
}
269
270#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
271#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash)]
272#[repr(transparent)]
273pub struct ValidOrNullMove(Option<Move>);
274
275impl ValidOrNullMove {
276    #[expect(non_upper_case_globals)]
277    pub const NullMove: Self = Self(None);
278
279    #[inline]
280    pub const unsafe fn new_unchecked(
281        source: Square,
282        dest: Square,
283        promotion: Option<PieceType>,
284    ) -> Self {
285        Self(Some(Move::new_unchecked(source, dest, promotion)))
286    }
287
288    #[inline]
289    pub fn new(source: Square, dest: Square, promotion: Option<PieceType>) -> Result<Self> {
290        Ok(Self(Some(Move::new(source, dest, promotion)?)))
291    }
292
293    #[inline]
294    pub const fn into_inner(&self) -> Option<&Move> {
295        self.0.as_ref()
296    }
297
298    #[inline]
299    pub const fn into_inner_mut(&mut self) -> Option<&mut Move> {
300        self.0.as_mut()
301    }
302
303    #[inline]
304    pub const fn is_null(&self) -> bool {
305        self.into_inner().is_none()
306    }
307
308    #[inline]
309    pub fn get_source(&self) -> Option<Square> {
310        self.into_inner().map(|move_| move_.source)
311    }
312
313    #[inline]
314    pub fn get_dest(&self) -> Option<Square> {
315        self.into_inner().map(|move_| move_.dest)
316    }
317
318    #[inline]
319    pub fn get_promotion(&self) -> Option<PieceType> {
320        self.into_inner()?.promotion
321    }
322
323    pub fn from_san(position: &ChessPosition, san: &str) -> Result<Self> {
324        // TODO: Make the logic better
325        let san = san.trim();
326        if san == "--" || san == "0000" {
327            return Ok(Self::NullMove);
328        }
329        Ok(Move::from_san(position, san)?.into())
330    }
331
332    pub fn from_lan(position: &ChessPosition, lan: &str) -> Result<Self> {
333        // TODO: Make the logic better
334        let lan = lan.trim();
335        if lan == "--" || lan == "0000" {
336            return Ok(Self::NullMove);
337        }
338        Ok(Move::from_lan(position, lan)?.into())
339    }
340
341    #[inline]
342    pub fn algebraic_without_suffix(
343        self,
344        position: &ChessPosition,
345        long: bool,
346    ) -> Result<Cow<'static, str>> {
347        self.map_or(Ok(Cow::Borrowed("--")), |move_| {
348            move_.algebraic_without_suffix(position, long)
349        })
350    }
351
352    #[inline]
353    pub fn algebraic_and_new_position(
354        self,
355        position: &ChessPosition,
356        long: bool,
357    ) -> Result<(Cow<'static, str>, ChessPosition)> {
358        self.map_or_else(
359            || Ok((Cow::Borrowed("--"), position.null_move()?)),
360            |move_| move_.algebraic_and_new_position(position, long),
361        )
362    }
363}
364
365impl Default for ValidOrNullMove {
366    fn default() -> Self {
367        Self::NullMove
368    }
369}
370
371impl fmt::Display for ValidOrNullMove {
372    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
373        if let Some(move_) = self.into_inner() {
374            write!(f, "{}", move_)
375        } else {
376            write!(f, "--")
377        }
378    }
379}
380
381impl FromStr for ValidOrNullMove {
382    type Err = TimecatError;
383
384    #[inline]
385    fn from_str(s: &str) -> Result<Self> {
386        Ok(match s.trim() {
387            "--" | "0000" => Self::NullMove,
388            s_trimmed => Move::from_str(s_trimmed)?.into(),
389        })
390    }
391}
392
393impl From<Move> for ValidOrNullMove {
394    #[inline]
395    fn from(value: Move) -> Self {
396        Some(value).into()
397    }
398}
399
400impl From<&Move> for ValidOrNullMove {
401    #[inline]
402    fn from(value: &Move) -> Self {
403        Some(value).into()
404    }
405}
406
407impl From<Option<Move>> for ValidOrNullMove {
408    #[inline]
409    fn from(value: Option<Move>) -> Self {
410        Self(value)
411    }
412}
413
414impl From<Option<&Move>> for ValidOrNullMove {
415    #[inline]
416    fn from(value: Option<&Move>) -> Self {
417        value.copied().into()
418    }
419}
420
421impl Deref for ValidOrNullMove {
422    type Target = Option<Move>;
423
424    #[inline]
425    fn deref(&self) -> &Self::Target {
426        &self.0
427    }
428}
429
430impl DerefMut for ValidOrNullMove {
431    #[inline]
432    fn deref_mut(&mut self) -> &mut Self::Target {
433        &mut self.0
434    }
435}
436
#[cfg(feature = "pyo3")]
impl<'source> FromPyObject<'source> for ValidOrNullMove {
    /// Accepts anything that converts to a [`Move`]. Failing that, accepts a string
    /// parsed by this type's `FromStr` (which also recognises `--`/`0000`).
    fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
        if let Ok(move_) = ob.extract::<Move>() {
            return Ok(Self::from(move_));
        }
        match ob.extract::<&str>().ok().map(str::parse::<Self>) {
            Some(Ok(valid_or_null_move)) => Ok(valid_or_null_move),
            _ => Err(Pyo3Error::Pyo3TypeConversionError {
                from: ob.to_string().into(),
                to: std::any::type_name::<Self>().into(),
            }
            .into()),
        }
    }
}
455
456#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
457#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
458pub struct WeightedMove {
459    pub move_: Move,
460    pub weight: MoveWeight,
461}
462
463impl PartialOrd for WeightedMove {
464    #[inline]
465    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
466        Some(self.cmp(other))
467    }
468}
469
470impl Ord for WeightedMove {
471    #[inline]
472    fn cmp(&self, other: &Self) -> Ordering {
473        self.weight.cmp(&other.weight)
474    }
475}
476
477impl WeightedMove {
478    #[inline]
479    pub fn new(move_: Move, weight: MoveWeight) -> Self {
480        Self { move_, weight }
481    }
482}
483
484impl fmt::Display for WeightedMove {
485    #[inline]
486    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
487        write!(f, "({}, {})", self.move_, self.weight)
488    }
489}
490
/// Side of the board a castling move goes to.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CastleMoveType {
    /// King-side castling (`O-O`).
    KingSide,
    /// Queen-side castling (`O-O-O`).
    QueenSide,
}
497
498#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
499#[derive(Clone, Copy, PartialEq, Eq, Default, Debug, Hash)]
500pub enum MoveType {
501    Capture {
502        is_en_passant: bool,
503    },
504    Castle(CastleMoveType),
505    DoublePawnPush,
506    Promotion(PieceType),
507    #[default]
508    Other,
509}
510
/// A move (possibly the null move) bundled with its [`MoveType`] and a check flag.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MoveWithInfo {
    /// The move itself; may be the null move.
    valid_or_null_move: ValidOrNullMove,
    /// Classification of the move.
    type_: MoveType,
    /// NOTE(review): presumably whether the move gives check — confirm against where
    /// this struct is constructed.
    is_check: bool,
}