rlevo-environments 0.2.0

RL benchmark environments and landscapes for rlevo (internal crate — use `rlevo` for the full API)
//! Chess move representation and action encoding modeled after AlphaZero Chess.
//!
//! This module defines the core data structures and traits for representing chess moves
//! in a way suitable for reinforcement learning agents. It provides a dense, efficient
//! encoding scheme that maps moves to discrete action indices for use with neural networks.
//!
//! AlphaZero represents the action space of chess as an $8 \times 8 \times 73$ tensor,
//! totaling 4,672 possible move slots.
//!
//! Because a neural network requires a fixed-size output, AlphaZero must provide a value
//! for every possible move a piece could technically make on an $8 \times 8$ board, even
//! if that move is illegal in the current position.
//!
//! # The Action Space ($8 \times 8 \times 73$)
//! The output is structured such that the first two dimensions ($8 \times 8$) represent the
//! square from which a piece is moved. The third dimension (73 planes) represents the type of
//! move being made.
//!
//! **The 73 "Move Type" Planes**:
//! - **56 Queen-like Moves**: These cover moves in 8 directions (N, NE, E, SE, S, SW, W, NW).
//!   For each direction, there are 7 possible distances (1 to 7 squares).
//!   - _Note_: If a pawn moves to the back rank and promotes to a Queen, it is represented in
//!     these planes.
//! - **8 Knight Moves**: These represent the 8 "L" shapes a knight can jump.
//! - **9 Underpromotion Planes**: When a pawn reaches the 8th rank, it can promote to a Knight,
//!   Bishop, or Rook. These 9 planes cover the 3 possible piece types $\times 3$ possible "exit"
//!   directions (capture left, move straight, capture right).
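//!
//! As a rough illustration of the layout above, the sketch below flattens a
//! `(rank, file, plane)` triple into a single index in `0..4672` and back. The
//! helper names `flat_index` and `unflatten` and the row-major ordering are
//! assumptions made for this example; they are not part of this module's API.
//!
//! ```rust
//! const FILES: usize = 8;
//! const PLANES: usize = 73;
//!
//! /// Hypothetical helper: row-major flattening of `(rank, file, plane)`.
//! fn flat_index(rank: usize, file: usize, plane: usize) -> usize {
//!     (rank * FILES + file) * PLANES + plane
//! }
//!
//! /// Hypothetical inverse of `flat_index`.
//! fn unflatten(index: usize) -> (usize, usize, usize) {
//!     let plane = index % PLANES;
//!     let square = index / PLANES;
//!     (square / FILES, square % FILES, plane)
//! }
//! ```
//!
//! Under this assumed ordering, e2 (rank 1, file 4) moving two squares north
//! (plane 1) lands at `flat_index(1, 4, 1) == 877`, and the last slot is
//! `flat_index(7, 7, 72) == 4671`.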
//!
//! **How illegal moves are handled**: The network outputs probabilities for all 4,672 moves.
//! However, before the move is actually selected, AlphaZero applies a mask. It sets the
//! probability of all illegal moves to zero and renormalizes the remaining legal moves so they
//! sum to 1.
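//!
//! The masking step might look like the following minimal sketch. The function
//! name `mask_and_renormalize` and its slice-based signature are hypothetical;
//! they only illustrate the arithmetic and are not provided by this crate.
//!
//! ```rust
//! /// Hypothetical helper: zero out illegal moves, then rescale the rest so
//! /// that the remaining probabilities sum to 1.
//! fn mask_and_renormalize(probs: &[f32], legal: &[bool]) -> Vec<f32> {
//!     let masked: Vec<f32> = probs
//!         .iter()
//!         .zip(legal)
//!         .map(|(&p, &ok)| if ok { p } else { 0.0 })
//!         .collect();
//!     let total: f32 = masked.iter().sum();
//!     if total > 0.0 {
//!         masked.iter().map(|p| p / total).collect()
//!     } else {
//!         // No probability mass on any legal move: return the zeros as-is.
//!         masked
//!     }
//! }
//! ```
//!
//! For instance, probabilities `[0.5, 0.25, 0.25]` with the middle move illegal
//! become roughly `[0.667, 0.0, 0.333]`.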
//!
//! # How the Network Evaluates a Position
//! AlphaZero uses a "dual-head" architecture. After the input (the $8 \times 8 \times 119$
//! tensor) passes through the main body of the residual layers, the network splits into two
//! distinct output heads:
//!
//! **The Policy Head ($\pi$)**
//! The policy head outputs the $8 \times 8 \times 73$ tensor described above. This is a
//! probability distribution over all possible moves.
//! - **Purpose**: It acts as the "intuition". It tells the system which moves look the most
//!   promising to investigate further.
//! - **Training**: It is trained to match the move distributions found by the search tree during
//!   self-play.
//!
//! **The Value Head ($v$)**
//! The value head outputs a single scalar value between -1 and 1.
//! - **Interpretation**:
//!   - $+1$: High confidence of a win.
//!   - $0$: Prediction of a draw.
//!   - $-1$: High confidence of a loss.
//! - **Purpose**: It replaces the traditional "evaluation function" (which might count points for
//!   pieces). Instead of calculating a score like $+1.5$, it predicts the expected outcome of the
//!   game from the current position.
//!
//! # The Synergy: Evaluation + Search
//! The magic of AlphaZero is how these two outputs work together inside the Monte Carlo Tree
//! Search (MCTS):
//! 1. **Prioritization**: When MCTS explores a new branch, it uses the Policy ($\pi$) to decide
//!   which moves to try first. This prevents the engine from wasting time on "human-obvious"
//!   blunders.
//! 2. **Leaf Evaluation**: In old-school MCTS, you would play "random games" until the end to see
//!   who won. AlphaZero doesn't do that. As soon as it hits a new position in its search tree, it
//!   asks the Value Head ($v$) for an estimate.
//! 3. **Backpropagation**: This value estimate is "rippled" back up the tree, updating the
//!   quality score of every move that led to that position.
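//!
//! The three steps can be sketched with per-move statistics as below. The
//! `Edge` type, its fields, and the PUCT-style scoring rule with an exploration
//! constant `c_puct` are assumptions for illustration: the text above does not
//! fix an exact formula, and this crate does not provide these items. The
//! sketch also omits the sign flip between the two players' perspectives.
//!
//! ```rust
//! /// Hypothetical statistics stored for one move in the search tree.
//! struct Edge {
//!     prior: f32,     // policy-head probability for this move
//!     visits: u32,    // number of times the search went through this move
//!     value_sum: f32, // sum of value-head estimates backed up through it
//! }
//!
//! impl Edge {
//!     /// Mean backed-up value; 0 for an unvisited move.
//!     fn mean_value(&self) -> f32 {
//!         if self.visits == 0 {
//!             0.0
//!         } else {
//!             self.value_sum / self.visits as f32
//!         }
//!     }
//!
//!     /// Prioritization: mean value plus a prior-weighted exploration bonus.
//!     fn score(&self, parent_visits: u32, c_puct: f32) -> f32 {
//!         let bonus = c_puct * self.prior * (parent_visits as f32).sqrt()
//!             / (1.0 + self.visits as f32);
//!         self.mean_value() + bonus
//!     }
//!
//!     /// Backpropagation: fold one leaf evaluation into the statistics.
//!     fn backup(&mut self, value: f32) {
//!         self.visits += 1;
//!         self.value_sum += value;
//!     }
//! }
//! ```
//!
//! In a search loop built on this sketch, the move with the highest `score`
//! would be followed at each node, and `backup` would be called on every move
//! along the path once the value head has evaluated the new leaf.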
//!
//!
//! # Action Space Structure
//!
//! The `ChessMove` struct is designed to implement the `Action`, `MultiDiscreteAction`, and
//! `ActionTensorConvertible` traits, enabling integration with the rlevo-core framework:
//!
//! - **Validation**: Moves are validated by ensuring both source and destination squares
//!   are valid (0-63). Legal move validation is delegated to the environment.
//! - **Discretization**: Each move can be converted to/from the indices in (8,8,73), enabling
//!   efficient neural network output layers and policy representations.
//!
//! # Usage Notes
//!
//! `ChessMove` stores the source square, destination square, and an optional
//! promotion piece. The `to_indices` / `from_indices` conversion methods that
//! map a move into the `(8, 8, 73)` action tensor are scaffolded but not yet
//! wired to the `Action` trait — see `compute_move_plane` and
//! `decode_move_plane` for the encoding logic that will back those methods once
//! the `Environment` impl lands.
//!
//! # Implementation Notes
//!
//! - All moves are immutable and copyable (`Copy` trait), suitable for rapid iteration
//!   during agent training.
//! - The fixed action space (8, 8, 73) simplifies policy networks, which output a single
//!   tensor rather than variable-length action lists.
//! - Illegal moves (e.g., moving into check) are encoded but marked as invalid by the
//!   environment. Agents learn to avoid them through reward signals.

use crate::games::chess::board::Square;

/// Promotion piece types for pawn promotion moves.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PromotionPiece {
    /// Promote to Queen (covered in queen-like moves for queen promotion)
    Queen,
    /// Promote to Rook (underpromotion)
    Rook,
    /// Promote to Bishop (underpromotion)
    Bishop,
    /// Promote to Knight (underpromotion)
    Knight,
}

/// Chess move representation compatible with AlphaZero's action space.
///
/// Encodes moves in an 8×8×73 tensor format where:
/// - First dimension: source rank (0-7)
/// - Second dimension: source file (0-7)
/// - Third dimension: move type plane (0-72)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChessMove {
    /// Source square (0-63, where 0 is a1 and 63 is h8)
    pub from: Square,
    /// Destination square (0-63)
    pub to: Square,
    /// Optional promotion piece (None for non-promotion moves, Some for promotions)
    pub promotion: Option<PromotionPiece>,
}

impl ChessMove {
    /// Creates a new chess move from source to destination.
    pub fn new(from: Square, to: Square) -> Self {
        Self {
            from,
            to,
            promotion: None,
        }
    }

    /// Creates a new chess move with a promotion.
    pub fn new_with_promotion(from: Square, to: Square, promotion: PromotionPiece) -> Self {
        Self {
            from,
            to,
            promotion: Some(promotion),
        }
    }

    /// Returns the source rank (0-7).
    #[inline]
    pub fn from_rank(&self) -> u8 {
        self.from.rank()
    }

    /// Returns the source file (0-7).
    #[inline]
    pub fn from_file(&self) -> u8 {
        self.from.file()
    }

    /// Returns the destination rank (0-7).
    #[inline]
    pub fn to_rank(&self) -> u8 {
        self.to.rank()
    }

    /// Returns the destination file (0-7).
    #[inline]
    pub fn to_file(&self) -> u8 {
        self.to.file()
    }

    /// Computes the move plane index (0–72) for this move in the AlphaZero action space.
    ///
    /// The plane index encodes the *type* of move made from the source square:
    ///
    /// | Range   | Category                                    |
    /// |---------|---------------------------------------------|
    /// | 0–55    | Queen-like: 8 directions × 7 distances      |
    /// | 56–63   | Knight: 8 L-shaped jumps                    |
    /// | 64–72   | Underpromotion: 3 pieces × 3 directions     |
    ///
    /// Queen promotions have no separate plane: they are encoded as ordinary
    /// queen-like moves of distance 1 (a push or a diagonal capture onto the
    /// last rank). Underpromotion planes cover promotion to Knight, Bishop,
    /// or Rook only, each with three pawn-exit directions (capture-left,
    /// push-straight, capture-right).
    ///
    /// # Panics
    ///
    /// Panics if the move delta does not match any of the 73 encoded patterns,
    /// i.e. it is not a straight, diagonal, or knight move (this includes a
    /// move whose source and destination coincide). Also panics if an
    /// underpromotion shifts by more than one file and is not knight-shaped.
    fn compute_move_plane(&self) -> usize {
        let from_rank = self.from_rank() as i8;
        let from_file = self.from_file() as i8;
        let to_rank = self.to_rank() as i8;
        let to_file = self.to_file() as i8;

        let delta_rank = to_rank - from_rank;
        let delta_file = to_file - from_file;

        // Check for knight moves (planes 56-63)
        let knight_moves = [
            (2, 1),   // Plane 56
            (1, 2),   // Plane 57
            (-1, 2),  // Plane 58
            (-2, 1),  // Plane 59
            (-2, -1), // Plane 60
            (-1, -2), // Plane 61
            (1, -2),  // Plane 62
            (2, -1),  // Plane 63
        ];

        for (i, &(dr, df)) in knight_moves.iter().enumerate() {
            if delta_rank == dr && delta_file == df {
                return 56 + i;
            }
        }

        // Check for underpromotions (planes 64-72)
        if let Some(promo) = self.promotion {
            // Underpromotions only apply to pawns reaching rank 7 (for white) or rank 0 (for black)
            // We only encode underpromotions (not queen promotions, which use queen-like moves)
            if matches!(
                promo,
                PromotionPiece::Knight | PromotionPiece::Bishop | PromotionPiece::Rook
            ) {
                let piece_offset = match promo {
                    PromotionPiece::Knight => 0,
                    PromotionPiece::Bishop => 3,
                    PromotionPiece::Rook => 6,
                    _ => unreachable!(),
                };

                let direction = match delta_file {
                    -1 => 0, // capture left
                    0 => 1,  // move straight
                    1 => 2,  // capture right
                    _ => panic!("Invalid promotion move"),
                };

                return 64 + piece_offset + direction;
            }
        }

        // Queen-like moves (planes 0-55): 8 directions × 7 distances
        // Direction encoding: N, NE, E, SE, S, SW, W, NW
        let direction_index = if delta_rank > 0 && delta_file == 0 {
            0 // North
        } else if delta_rank > 0 && delta_file > 0 && delta_rank == delta_file {
            1 // Northeast
        } else if delta_rank == 0 && delta_file > 0 {
            2 // East
        } else if delta_rank < 0 && delta_file > 0 && -delta_rank == delta_file {
            3 // Southeast
        } else if delta_rank < 0 && delta_file == 0 {
            4 // South
        } else if delta_rank < 0 && delta_file < 0 && delta_rank == delta_file {
            5 // Southwest
        } else if delta_rank == 0 && delta_file < 0 {
            6 // West
        } else if delta_rank > 0 && delta_file < 0 && delta_rank == -delta_file {
            7 // Northwest
        } else {
            panic!(
                "Invalid queen-like move: delta_rank={}, delta_file={}",
                delta_rank, delta_file
            );
        };

        let distance = delta_rank.abs().max(delta_file.abs()) as usize;
        assert!(
            (1..=7).contains(&distance),
            "Invalid distance: {}",
            distance
        );

        direction_index * 7 + (distance - 1)
    }

    /// Decodes a move plane index (0–72) and source square into a destination square and optional
    /// promotion piece.
    ///
    /// This is the inverse of [`Self::compute_move_plane`]. When the plane falls in the
    /// underpromotion range (64–72) the returned `Option<PromotionPiece>` is `Some`; for all
    /// other planes it is `None`. Queen promotions are recovered as plain queen-like moves
    /// with no promotion marker. Underpromotions are decoded as White pawn moves (one rank
    /// north).
    ///
    /// Destination coordinates are clamped to the board boundary, so callers should validate the
    /// result against the current legal-move list before use.
    ///
    /// # Panics
    ///
    /// Panics if `plane` is greater than 72.
    fn decode_move_plane(from: Square, plane: usize) -> (Square, Option<PromotionPiece>) {
        let from_rank = from.rank() as i8;
        let from_file = from.file() as i8;

        let (delta_rank, delta_file, promotion) = if plane < 56 {
            // Queen-like moves (planes 0-55)
            let direction = plane / 7;
            let distance = (plane % 7 + 1) as i8;

            let (dr, df) = match direction {
                0 => (1, 0),   // North
                1 => (1, 1),   // Northeast
                2 => (0, 1),   // East
                3 => (-1, 1),  // Southeast
                4 => (-1, 0),  // South
                5 => (-1, -1), // Southwest
                6 => (0, -1),  // West
                7 => (1, -1),  // Northwest
                _ => unreachable!(),
            };

            (dr * distance, df * distance, None)
        } else if plane < 64 {
            // Knight moves (planes 56-63)
            let knight_index = plane - 56;
            let (dr, df) = match knight_index {
                0 => (2, 1),
                1 => (1, 2),
                2 => (-1, 2),
                3 => (-2, 1),
                4 => (-2, -1),
                5 => (-1, -2),
                6 => (1, -2),
                7 => (2, -1),
                _ => unreachable!(),
            };
            (dr, df, None)
        } else {
            // Underpromotions (planes 64-72)
            let promo_index = plane - 64;
            let piece = match promo_index / 3 {
                0 => PromotionPiece::Knight,
                1 => PromotionPiece::Bishop,
                2 => PromotionPiece::Rook,
                _ => unreachable!(),
            };
            let direction = promo_index % 3;
            let df = match direction {
                0 => -1, // capture left
                1 => 0,  // move straight
                2 => 1,  // capture right
                _ => unreachable!(),
            };
            // Assume white pawn promotion (moving north)
            (1, df, Some(piece))
        };

        let to_rank = from_rank + delta_rank;
        let to_file = from_file + delta_file;

        // Clamp to board boundaries
        let to_rank = to_rank.clamp(0, 7) as u8;
        let to_file = to_file.clamp(0, 7) as u8;
        let to = to_rank * 8 + to_file;

        (Square(to), promotion)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_square_indices() {
        // Test a1 (square 0)
        let mv = ChessMove::new(Square(0), Square(16)); // a1 to a3
        assert_eq!(mv.from_rank(), 0);
        assert_eq!(mv.from_file(), 0);

        // Test h8 (square 63)
        let mv = ChessMove::new(Square(63), Square(47)); // h8 to h6
        assert_eq!(mv.from_rank(), 7);
        assert_eq!(mv.from_file(), 7);

        // Test e2 (square 12)
        let mv = ChessMove::new(Square(12), Square(28)); // e2 to e4
        assert_eq!(mv.from_rank(), 1);
        assert_eq!(mv.from_file(), 4);
    }

    #[test]
    fn test_knight_move_encoding() {
        // Knight move from e4 (28) to f6 (45): +2 rank, +1 file
        let mv = ChessMove::new(Square(28), Square(45));
        assert_eq!(mv.from_rank(), 3); // rank 3 (e4)
        assert_eq!(mv.from_file(), 4); // file 4 (e-file)
        // `to_indices` / `from_indices` are not wired up yet, so exercise the plane helpers.
        let plane = mv.compute_move_plane();
        assert_eq!(plane, 56); // First knight move plane

        // Reconstruct
        let (to, promotion) = ChessMove::decode_move_plane(mv.from, plane);
        assert_eq!(to, mv.to);
        assert_eq!(promotion, None);
    }

    #[test]
    fn test_queen_move_encoding() {
        // North move: e2 to e4 (2 squares north)
        let mv = ChessMove::new(Square(12), Square(28)); // e2 to e4
        assert_eq!(mv.from_rank(), 1); // rank 1
        assert_eq!(mv.from_file(), 4); // file 4
        // `to_indices` / `from_indices` are not wired up yet, so exercise the plane helpers.
        let plane = mv.compute_move_plane();
        assert_eq!(plane, 1); // North direction, distance 2 (plane 0 + 1)

        // Reconstruct
        let (to, promotion) = ChessMove::decode_move_plane(mv.from, plane);
        assert_eq!(to, mv.to);
        assert_eq!(promotion, None);
    }

    #[test]
    fn test_promotion_encoding() {
        // Pawn promotion: e7 to e8 with knight promotion
        let mv = ChessMove::new_with_promotion(Square(52), Square(60), PromotionPiece::Knight);
        assert_eq!(mv.from_rank(), 6); // rank 6
        assert_eq!(mv.from_file(), 4); // file 4
        // `to_indices` / `from_indices` are not wired up yet, so exercise the plane helpers.
        let plane = mv.compute_move_plane();
        assert_eq!(plane, 65); // Knight underpromotion straight (64 + 0*3 + 1)

        // Reconstruct
        let (to, promotion) = ChessMove::decode_move_plane(mv.from, plane);
        assert_eq!(to, mv.to);
        assert_eq!(promotion, mv.promotion);
    }

    // #[test]
    // fn test_action_space() {
    //     let space = ChessMove.shape();
    //     assert_eq!(space, [8, 8, 73]);
    // }

    // #[test]
    // fn test_is_valid() {
    //     let valid_move = ChessMove::new(Square(12), Square(28));
    //     assert!(valid_move.is_valid());

    //     let invalid_same_square = ChessMove::new(Square(12), Square(12));
    //     assert!(!invalid_same_square.is_valid());
    // }
}