chess_vector_engine/
position_encoder.rs

1use chess::{Board, Color, Piece, Square};
2use ndarray::Array1;
3use rayon::prelude::*;
4
/// Converts chess positions into fixed-length `f32` feature vectors.
///
/// The encoder itself is stateless apart from the requested output length,
/// so it is cheap to clone and can be shared freely between callers.
#[derive(Debug, Clone)]
pub struct PositionEncoder {
    /// Number of elements in every vector produced by the encoder.
    vector_size: usize,
}
11
12impl PositionEncoder {
13    pub fn new(vector_size: usize) -> Self {
14        Self { vector_size }
15    }
16
17    /// Get the vector size
18    pub fn vector_size(&self) -> usize {
19        self.vector_size
20    }
21
22    /// Encode a chess position into a vector
23    pub fn encode(&self, board: &Board) -> Array1<f32> {
24        let mut features = Vec::with_capacity(self.vector_size);
25
26        // Basic encoding strategy:
27        // 1. Piece positions (64 squares * 12 piece types = 768 features)
28        // 2. Game state features (castling, en passant, etc.)
29        // 3. Material balance
30        // 4. Positional features
31
32        // 1. Piece position encoding
33        self.encode_piece_positions(board, &mut features);
34
35        // 2. Game state
36        self.encode_game_state(board, &mut features);
37
38        // 3. Material balance
39        self.encode_material_balance(board, &mut features);
40
41        // 4. Basic positional features
42        self.encode_positional_features(board, &mut features);
43
44        // 5. Tactical pattern features
45        self.encode_tactical_patterns(board, &mut features);
46
47        // Pad or truncate to desired size
48        features.resize(self.vector_size, 0.0);
49
50        Array1::from(features)
51    }
52
53    /// Encode piece positions on the board using dense representation
54    fn encode_piece_positions(&self, board: &Board, features: &mut Vec<f32>) {
55        // Enhanced encoding: 64 squares * 12 piece types (6 pieces * 2 colors) = 768 features
56        // This creates more distinctive representations
57        for square in chess::ALL_SQUARES {
58            // One-hot encoding for each piece type and color
59            let mut square_features = vec![0.0; 12]; // 6 pieces * 2 colors
60
61            if let Some(piece) = board.piece_on(square) {
62                let color = board.color_on(square).unwrap();
63                let piece_idx = match piece {
64                    chess::Piece::Pawn => 0,
65                    chess::Piece::Knight => 1,
66                    chess::Piece::Bishop => 2,
67                    chess::Piece::Rook => 3,
68                    chess::Piece::Queen => 4,
69                    chess::Piece::King => 5,
70                };
71
72                let color_offset = if color == chess::Color::White { 0 } else { 6 };
73                square_features[piece_idx + color_offset] = 1.0;
74            }
75
76            features.extend(square_features);
77        }
78
79        // Add piece interaction features - attacks/defends relationships
80        self.encode_piece_interactions(board, features);
81    }
82
83    /// Encode piece interactions (attacks, defends)
84    fn encode_piece_interactions(&self, board: &Board, features: &mut Vec<f32>) {
85        // Count attacks by piece type for each color
86        let mut white_attacks = vec![0.0; 6]; // pawn, knight, bishop, rook, queen, king
87        let mut black_attacks = vec![0.0; 6];
88
89        // Simplified attack counting - in practice would use chess engine's attack detection
90        for square in chess::ALL_SQUARES {
91            if let Some(piece) = board.piece_on(square) {
92                let color = board.color_on(square).unwrap();
93                let piece_idx = match piece {
94                    Piece::Pawn => 0,
95                    Piece::Knight => 1,
96                    Piece::Bishop => 2,
97                    Piece::Rook => 3,
98                    Piece::Queen => 4,
99                    Piece::King => 5,
100                };
101
102                // Simple attack count based on piece mobility
103                let attack_value = match piece {
104                    Piece::Pawn => 1.0,
105                    Piece::Knight => 3.0,
106                    Piece::Bishop => 3.0,
107                    Piece::Rook => 5.0,
108                    Piece::Queen => 9.0,
109                    Piece::King => 1.0,
110                };
111
112                if color == Color::White {
113                    white_attacks[piece_idx] += attack_value;
114                } else {
115                    black_attacks[piece_idx] += attack_value;
116                }
117            }
118        }
119
120        // Add attack features (12 more features)
121        features.extend(white_attacks);
122        features.extend(black_attacks);
123    }
124
125    /// Encode game state (castling rights, en passant, etc.)
126    fn encode_game_state(&self, board: &Board, features: &mut Vec<f32>) {
127        // Castling rights (4 features)
128        features.push(if board.castle_rights(Color::White).has_kingside() {
129            1.0
130        } else {
131            0.0
132        });
133        features.push(if board.castle_rights(Color::White).has_queenside() {
134            1.0
135        } else {
136            0.0
137        });
138        features.push(if board.castle_rights(Color::Black).has_kingside() {
139            1.0
140        } else {
141            0.0
142        });
143        features.push(if board.castle_rights(Color::Black).has_queenside() {
144            1.0
145        } else {
146            0.0
147        });
148
149        // En passant
150        features.push(if board.en_passant().is_some() {
151            1.0
152        } else {
153            0.0
154        });
155
156        // Side to move
157        features.push(if board.side_to_move() == Color::White {
158            1.0
159        } else {
160            0.0
161        });
162
163        // Halfmove clock (simplified - just use 0 for now)
164        features.push(0.0);
165    }
166
167    /// Encode material balance
168    fn encode_material_balance(&self, board: &Board, features: &mut Vec<f32>) {
169        let piece_values = [
170            (Piece::Pawn, 1),
171            (Piece::Knight, 3),
172            (Piece::Bishop, 3),
173            (Piece::Rook, 5),
174            (Piece::Queen, 9),
175            (Piece::King, 0),
176        ];
177
178        for (piece, _value) in piece_values {
179            let white_count = board.pieces(piece) & board.color_combined(Color::White);
180            let black_count = board.pieces(piece) & board.color_combined(Color::Black);
181
182            features.push(white_count.popcnt() as f32);
183            features.push(black_count.popcnt() as f32);
184            features.push((white_count.popcnt() as i32 - black_count.popcnt() as i32) as f32);
185        }
186    }
187
188    /// Encode basic positional features
189    fn encode_positional_features(&self, board: &Board, features: &mut Vec<f32>) {
190        // King safety (distance to center, surrounded pieces)
191        for color in [Color::White, Color::Black] {
192            let king_square = board.king_square(color);
193            // Distance from center
194            let center_distance = self.distance_to_center(king_square);
195            features.push(center_distance);
196
197            // Number of pieces around king (3x3 area)
198            let surrounding_pieces = self.count_surrounding_pieces(board, king_square);
199            features.push(surrounding_pieces as f32);
200        }
201
202        // Piece mobility (simplified)
203        for color in [Color::White, Color::Black] {
204            let mobility = self.calculate_mobility(board, color);
205            features.push(mobility as f32);
206        }
207
208        // Add pawn structure features
209        self.encode_pawn_structure(board, features);
210
211        // Add tactical patterns
212        self.encode_tactical_patterns(board, features);
213
214        // Add center control
215        self.encode_center_control(board, features);
216
217        // Add piece coordination patterns
218        self.encode_piece_coordination(board, features);
219    }
220
221    /// Calculate distance from square to center of board
222    fn distance_to_center(&self, square: Square) -> f32 {
223        let file = square.get_file().to_index() as f32;
224        let rank = square.get_rank().to_index() as f32;
225        let center_file = 3.5;
226        let center_rank = 3.5;
227
228        ((file - center_file).powi(2) + (rank - center_rank).powi(2)).sqrt()
229    }
230
231    /// Count pieces in 3x3 area around a square
232    fn count_surrounding_pieces(&self, board: &Board, center: Square) -> u32 {
233        let mut count = 0;
234        let center_file = center.get_file().to_index() as i32;
235        let center_rank = center.get_rank().to_index() as i32;
236
237        for file_offset in -1..=1 {
238            for rank_offset in -1..=1 {
239                if file_offset == 0 && rank_offset == 0 {
240                    continue;
241                }
242
243                let new_file = center_file + file_offset;
244                let new_rank = center_rank + rank_offset;
245
246                if (0..8).contains(&new_file) && (0..8).contains(&new_rank) {
247                    let square = Square::make_square(
248                        chess::Rank::from_index(new_rank as usize),
249                        chess::File::from_index(new_file as usize),
250                    );
251                    if board.piece_on(square).is_some() {
252                        count += 1;
253                    }
254                }
255            }
256        }
257        count
258    }
259
260    /// Calculate basic mobility for a color
261    fn calculate_mobility(&self, board: &Board, color: Color) -> u32 {
262        // Simplified: count number of pieces that can move
263        let pieces = board.color_combined(color);
264        let mut mobility = 0;
265
266        for _square in *pieces {
267            // This is a simplified mobility calculation
268            // In a real implementation, you'd generate all legal moves
269            mobility += 1;
270        }
271
272        mobility
273    }
274
275    /// Encode pawn structure features
276    fn encode_pawn_structure(&self, board: &Board, features: &mut Vec<f32>) {
277        for color in [Color::White, Color::Black] {
278            let pawns = board.pieces(Piece::Pawn) & board.color_combined(color);
279
280            // Count doubled pawns (simplified)
281            let mut doubled_pawns = 0;
282            for file in 0..8 {
283                let mut file_pawn_count = 0;
284                for rank in 0..8 {
285                    let square = chess::Square::make_square(
286                        chess::Rank::from_index(rank),
287                        chess::File::from_index(file),
288                    );
289                    if (pawns & chess::BitBoard::from_square(square)).popcnt() > 0 {
290                        file_pawn_count += 1;
291                    }
292                }
293                if file_pawn_count > 1 {
294                    doubled_pawns += file_pawn_count - 1;
295                }
296            }
297            features.push(doubled_pawns as f32);
298
299            // Count isolated pawns (simplified)
300            let mut isolated_pawns = 0;
301            for file in 0..8 {
302                let mut file_has_pawn = false;
303                for rank in 0..8 {
304                    let square = chess::Square::make_square(
305                        chess::Rank::from_index(rank),
306                        chess::File::from_index(file),
307                    );
308                    if (pawns & chess::BitBoard::from_square(square)).popcnt() > 0 {
309                        file_has_pawn = true;
310                        break;
311                    }
312                }
313
314                if file_has_pawn {
315                    // Check adjacent files
316                    let mut has_adjacent = false;
317                    for adj_file in [file.saturating_sub(1), file + 1] {
318                        if adj_file < 8 && adj_file != file {
319                            for rank in 0..8 {
320                                let adj_square = chess::Square::make_square(
321                                    chess::Rank::from_index(rank),
322                                    chess::File::from_index(adj_file),
323                                );
324                                if (pawns & chess::BitBoard::from_square(adj_square)).popcnt() > 0 {
325                                    has_adjacent = true;
326                                    break;
327                                }
328                            }
329                        }
330                        if has_adjacent {
331                            break;
332                        }
333                    }
334
335                    if !has_adjacent {
336                        isolated_pawns += 1;
337                    }
338                }
339            }
340            features.push(isolated_pawns as f32);
341        }
342    }
343
344    /// Encode tactical patterns
345    fn encode_tactical_patterns(&self, board: &Board, features: &mut Vec<f32>) {
346        // Count pins, forks, and other tactical motifs (simplified)
347        for color in [Color::White, Color::Black] {
348            let opponent_color = if color == Color::White {
349                Color::Black
350            } else {
351                Color::White
352            };
353
354            // Count potential pins by looking at pieces on same lines as enemy king
355            let enemy_king_square = board.king_square(opponent_color);
356            let mut potential_pins = 0;
357
358            // Check for pieces that could pin along ranks/files
359            let rooks_queens = (board.pieces(Piece::Rook) | board.pieces(Piece::Queen))
360                & board.color_combined(color);
361            for square in chess::ALL_SQUARES {
362                if (rooks_queens & chess::BitBoard::from_square(square)).popcnt() > 0
363                    && (square.get_rank() == enemy_king_square.get_rank()
364                        || square.get_file() == enemy_king_square.get_file())
365                {
366                    potential_pins += 1;
367                }
368            }
369
370            // Check for pieces that could pin along diagonals
371            let bishops_queens = (board.pieces(Piece::Bishop) | board.pieces(Piece::Queen))
372                & board.color_combined(color);
373            for square in chess::ALL_SQUARES {
374                if (bishops_queens & chess::BitBoard::from_square(square)).popcnt() > 0 {
375                    let rank_diff = (square.get_rank().to_index() as i32
376                        - enemy_king_square.get_rank().to_index() as i32)
377                        .abs();
378                    let file_diff = (square.get_file().to_index() as i32
379                        - enemy_king_square.get_file().to_index() as i32)
380                        .abs();
381                    if rank_diff == file_diff && rank_diff > 0 {
382                        potential_pins += 1;
383                    }
384                }
385            }
386
387            features.push(potential_pins as f32);
388        }
389
390        // Add center control and piece coordination features
391        self.encode_center_control(board, features);
392        self.encode_piece_coordination(board, features);
393    }
394
395    /// Encode center control
396    fn encode_center_control(&self, board: &Board, features: &mut Vec<f32>) {
397        // Check control of central squares (d4, d5, e4, e5)
398        let center_squares = [
399            chess::Square::D4,
400            chess::Square::D5,
401            chess::Square::E4,
402            chess::Square::E5,
403        ];
404
405        for color in [Color::White, Color::Black] {
406            let mut center_control = 0.0;
407
408            for &square in &center_squares {
409                // Check if we have a piece on this square
410                if let Some(_piece) = board.piece_on(square) {
411                    if board.color_on(square) == Some(color) {
412                        center_control += 2.0; // Extra weight for occupying center
413                    }
414                }
415
416                // Count pieces that could attack this square (simplified)
417                let pieces = board.color_combined(color);
418                for piece_square in chess::ALL_SQUARES {
419                    if (pieces & chess::BitBoard::from_square(piece_square)).popcnt() > 0 {
420                        if let Some(piece) = board.piece_on(piece_square) {
421                            let can_attack = match piece {
422                                Piece::Pawn => {
423                                    let rank_diff = (square.get_rank().to_index() as i32
424                                        - piece_square.get_rank().to_index() as i32)
425                                        .abs();
426                                    let file_diff = (square.get_file().to_index() as i32
427                                        - piece_square.get_file().to_index() as i32)
428                                        .abs();
429                                    rank_diff == 1 && file_diff == 1
430                                }
431                                Piece::Knight => {
432                                    let rank_diff = (square.get_rank().to_index() as i32
433                                        - piece_square.get_rank().to_index() as i32)
434                                        .abs();
435                                    let file_diff = (square.get_file().to_index() as i32
436                                        - piece_square.get_file().to_index() as i32)
437                                        .abs();
438                                    (rank_diff == 2 && file_diff == 1)
439                                        || (rank_diff == 1 && file_diff == 2)
440                                }
441                                _ => false, // Simplified - would need more complex logic for sliding pieces
442                            };
443
444                            if can_attack {
445                                center_control += 0.5;
446                            }
447                        }
448                    }
449                }
450            }
451
452            features.push(center_control);
453        }
454    }
455
456    /// Encode piece coordination patterns
457    fn encode_piece_coordination(&self, board: &Board, features: &mut Vec<f32>) {
458        for color in [Color::White, Color::Black] {
459            let mut coordination_score = 0.0;
460
461            // Count pieces defending each other
462            let pieces = board.color_combined(color);
463            for square1 in chess::ALL_SQUARES {
464                if (pieces & chess::BitBoard::from_square(square1)).popcnt() > 0 {
465                    for square2 in chess::ALL_SQUARES {
466                        if (pieces & chess::BitBoard::from_square(square2)).popcnt() > 0
467                            && square1 != square2
468                        {
469                            // Simplified check for mutual protection
470                            let rank_diff = (square1.get_rank().to_index() as i32
471                                - square2.get_rank().to_index() as i32)
472                                .abs();
473                            let file_diff = (square1.get_file().to_index() as i32
474                                - square2.get_file().to_index() as i32)
475                                .abs();
476
477                            if rank_diff <= 2 && file_diff <= 2 {
478                                coordination_score += 0.1;
479                            }
480                        }
481                    }
482                }
483            }
484
485            features.push(coordination_score);
486        }
487    }
488
489    /// Calculate similarity between two position vectors
490    pub fn similarity(&self, vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
491        // Cosine similarity
492        let dot_product = vec1.dot(vec2);
493        let norm1 = vec1.dot(vec1).sqrt();
494        let norm2 = vec2.dot(vec2).sqrt();
495
496        if norm1 == 0.0 || norm2 == 0.0 {
497            0.0
498        } else {
499            dot_product / (norm1 * norm2)
500        }
501    }
502
503    /// Calculate Euclidean distance between two vectors
504    pub fn distance(&self, vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
505        (vec1 - vec2).mapv(|x| x * x).sum().sqrt()
506    }
507
508    /// Encode multiple positions in parallel
509    pub fn encode_batch(&self, boards: &[Board]) -> Vec<Array1<f32>> {
510        if boards.len() > 10 {
511            // Use parallel processing for larger batches
512            boards.par_iter().map(|board| self.encode(board)).collect()
513        } else {
514            // Use sequential processing for smaller batches
515            boards.iter().map(|board| self.encode(board)).collect()
516        }
517    }
518
519    /// Calculate similarities between a query vector and multiple position vectors in parallel
520    pub fn batch_similarity(&self, query: &Array1<f32>, vectors: &[Array1<f32>]) -> Vec<f32> {
521        if vectors.len() > 50 {
522            // Use parallel processing for larger batches
523            vectors
524                .par_iter()
525                .map(|vec| self.similarity(query, vec))
526                .collect()
527        } else {
528            // Use sequential processing for smaller batches
529            vectors
530                .iter()
531                .map(|vec| self.similarity(query, vec))
532                .collect()
533        }
534    }
535
536    /// Calculate pairwise similarities between all vectors in parallel
537    pub fn pairwise_similarity_matrix(&self, vectors: &[Array1<f32>]) -> Vec<Vec<f32>> {
538        if vectors.len() > 20 {
539            // Use parallel processing for larger matrices
540            vectors
541                .par_iter()
542                .enumerate()
543                .map(|(i, vec1)| {
544                    vectors
545                        .iter()
546                        .enumerate()
547                        .map(|(j, vec2)| {
548                            if i == j {
549                                1.0 // Self-similarity
550                            } else {
551                                self.similarity(vec1, vec2)
552                            }
553                        })
554                        .collect()
555                })
556                .collect()
557        } else {
558            // Use sequential processing for smaller matrices
559            vectors
560                .iter()
561                .enumerate()
562                .map(|(i, vec1)| {
563                    vectors
564                        .iter()
565                        .enumerate()
566                        .map(|(j, vec2)| {
567                            if i == j {
568                                1.0 // Self-similarity
569                            } else {
570                                self.similarity(vec1, vec2)
571                            }
572                        })
573                        .collect()
574                })
575                .collect()
576        }
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// The encoded start position has the requested length and is not all zeros.
    #[test]
    fn test_encode_starting_position() {
        let encoded = PositionEncoder::new(1024).encode(&Board::default());

        assert_eq!(encoded.len(), 1024);

        // The start position contains pieces, so some feature must be positive.
        assert!(encoded.iter().any(|&value| value > 0.0));
    }

    /// Encoding the same position twice yields a cosine similarity of 1.
    #[test]
    fn test_similarity_identical_positions() {
        let encoder = PositionEncoder::new(1024);
        let start = Board::default();

        let first = encoder.encode(&start);
        let second = encoder.encode(&start);

        let score = encoder.similarity(&first, &second);
        assert!((score - 1.0).abs() < 1e-6);
    }

    /// Positions one move apart are similar, but not identical.
    #[test]
    fn test_similarity_different_positions() {
        let encoder = PositionEncoder::new(1024);
        let start = Board::default();
        let after_e4 =
            Board::from_str("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1").unwrap();

        let start_vec = encoder.encode(&start);
        let after_e4_vec = encoder.encode(&after_e4);

        let score = encoder.similarity(&start_vec, &after_e4_vec);
        assert!(score < 1.0);
        assert!(score > 0.8); // Should still be quite similar (only one move difference)
    }
}
623}