1use chess::{Board, Color, Piece, Square};
2use ndarray::Array1;
3use rayon::prelude::*;
4
/// Encodes chess positions into fixed-length `f32` feature vectors and compares
/// such vectors (cosine similarity, Euclidean distance).
///
/// The only configuration is the output vector length.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PositionEncoder {
    /// Length of every vector produced by `encode`.
    vector_size: usize,
}
11
12impl PositionEncoder {
13 pub fn new(vector_size: usize) -> Self {
14 Self { vector_size }
15 }
16
17 pub fn vector_size(&self) -> usize {
19 self.vector_size
20 }
21
22 pub fn encode(&self, board: &Board) -> Array1<f32> {
24 let mut features = vec![0.0f32; self.vector_size];
26 let mut offset = 0;
27
28 offset = self.encode_piece_positions_fast(board, &mut features, offset);
31
32 offset = self.encode_game_state_fast(board, &mut features, offset);
34
35 offset = self.encode_material_balance_fast(board, &mut features, offset);
37
38 offset = self.encode_positional_features_fast(board, &mut features, offset);
40
41 self.encode_tactical_patterns_fast(board, &mut features, offset);
43
44 Array1::from(features)
45 }
46
47 fn encode_piece_positions(&self, board: &Board, features: &mut Vec<f32>) {
49 for square in chess::ALL_SQUARES {
52 let mut square_features = vec![0.0; 12]; if let Some(piece) = board.piece_on(square) {
56 let color = board.color_on(square).unwrap();
57 let piece_idx = match piece {
58 chess::Piece::Pawn => 0,
59 chess::Piece::Knight => 1,
60 chess::Piece::Bishop => 2,
61 chess::Piece::Rook => 3,
62 chess::Piece::Queen => 4,
63 chess::Piece::King => 5,
64 };
65
66 let color_offset = if color == chess::Color::White { 0 } else { 6 };
67 square_features[piece_idx + color_offset] = 1.0;
68 }
69
70 features.extend(square_features);
71 }
72
73 self.encode_piece_interactions(board, features);
75 }
76
77 fn encode_piece_interactions(&self, board: &Board, features: &mut Vec<f32>) {
79 let mut white_attacks = vec![0.0; 6]; let mut black_attacks = vec![0.0; 6];
82
83 for square in chess::ALL_SQUARES {
85 if let Some(piece) = board.piece_on(square) {
86 let color = board.color_on(square).unwrap();
87 let piece_idx = match piece {
88 Piece::Pawn => 0,
89 Piece::Knight => 1,
90 Piece::Bishop => 2,
91 Piece::Rook => 3,
92 Piece::Queen => 4,
93 Piece::King => 5,
94 };
95
96 let attack_value = match piece {
98 Piece::Pawn => 1.0,
99 Piece::Knight => 3.0,
100 Piece::Bishop => 3.0,
101 Piece::Rook => 5.0,
102 Piece::Queen => 9.0,
103 Piece::King => 1.0,
104 };
105
106 if color == Color::White {
107 white_attacks[piece_idx] += attack_value;
108 } else {
109 black_attacks[piece_idx] += attack_value;
110 }
111 }
112 }
113
114 features.extend(white_attacks);
116 features.extend(black_attacks);
117 }
118
119 fn encode_game_state(&self, board: &Board, features: &mut Vec<f32>) {
121 features.push(if board.castle_rights(Color::White).has_kingside() {
123 1.0
124 } else {
125 0.0
126 });
127 features.push(if board.castle_rights(Color::White).has_queenside() {
128 1.0
129 } else {
130 0.0
131 });
132 features.push(if board.castle_rights(Color::Black).has_kingside() {
133 1.0
134 } else {
135 0.0
136 });
137 features.push(if board.castle_rights(Color::Black).has_queenside() {
138 1.0
139 } else {
140 0.0
141 });
142
143 features.push(if board.en_passant().is_some() {
145 1.0
146 } else {
147 0.0
148 });
149
150 features.push(if board.side_to_move() == Color::White {
152 1.0
153 } else {
154 0.0
155 });
156
157 features.push(0.0);
159 }
160
161 fn encode_material_balance(&self, board: &Board, features: &mut Vec<f32>) {
163 let piece_values = [
164 (Piece::Pawn, 1),
165 (Piece::Knight, 3),
166 (Piece::Bishop, 3),
167 (Piece::Rook, 5),
168 (Piece::Queen, 9),
169 (Piece::King, 0),
170 ];
171
172 for (piece, _value) in piece_values {
173 let white_count = board.pieces(piece) & board.color_combined(Color::White);
174 let black_count = board.pieces(piece) & board.color_combined(Color::Black);
175
176 features.push(white_count.popcnt() as f32);
177 features.push(black_count.popcnt() as f32);
178 features.push((white_count.popcnt() as i32 - black_count.popcnt() as i32) as f32);
179 }
180 }
181
182 fn encode_positional_features(&self, board: &Board, features: &mut Vec<f32>) {
184 for color in [Color::White, Color::Black] {
186 let king_square = board.king_square(color);
187 let center_distance = self.distance_to_center(king_square);
189 features.push(center_distance);
190
191 let surrounding_pieces = self.count_surrounding_pieces(board, king_square);
193 features.push(surrounding_pieces as f32);
194 }
195
196 for color in [Color::White, Color::Black] {
198 let mobility = self.calculate_mobility(board, color);
199 features.push(mobility as f32);
200 }
201
202 self.encode_pawn_structure(board, features);
204
205 self.encode_tactical_patterns(board, features);
207
208 self.encode_center_control(board, features);
210
211 self.encode_piece_coordination(board, features);
213 }
214
215 fn distance_to_center(&self, square: Square) -> f32 {
217 let file = square.get_file().to_index() as f32;
218 let rank = square.get_rank().to_index() as f32;
219 let center_file = 3.5;
220 let center_rank = 3.5;
221
222 ((file - center_file).powi(2) + (rank - center_rank).powi(2)).sqrt()
223 }
224
225 fn count_surrounding_pieces(&self, board: &Board, center: Square) -> u32 {
227 let mut count = 0;
228 let center_file = center.get_file().to_index() as i32;
229 let center_rank = center.get_rank().to_index() as i32;
230
231 for file_offset in -1..=1 {
232 for rank_offset in -1..=1 {
233 if file_offset == 0 && rank_offset == 0 {
234 continue;
235 }
236
237 let new_file = center_file + file_offset;
238 let new_rank = center_rank + rank_offset;
239
240 if (0..8).contains(&new_file) && (0..8).contains(&new_rank) {
241 let square = Square::make_square(
242 chess::Rank::from_index(new_rank as usize),
243 chess::File::from_index(new_file as usize),
244 );
245 if board.piece_on(square).is_some() {
246 count += 1;
247 }
248 }
249 }
250 }
251 count
252 }
253
254 fn calculate_mobility(&self, board: &Board, color: Color) -> u32 {
256 let pieces = board.color_combined(color);
258 let mut mobility = 0;
259
260 for _square in *pieces {
261 mobility += 1;
264 }
265
266 mobility
267 }
268
269 fn encode_pawn_structure(&self, board: &Board, features: &mut Vec<f32>) {
271 for color in [Color::White, Color::Black] {
272 let pawns = board.pieces(Piece::Pawn) & board.color_combined(color);
273
274 let mut doubled_pawns = 0;
276 for file in 0..8 {
277 let mut file_pawn_count = 0;
278 for rank in 0..8 {
279 let square = chess::Square::make_square(
280 chess::Rank::from_index(rank),
281 chess::File::from_index(file),
282 );
283 if (pawns & chess::BitBoard::from_square(square)).popcnt() > 0 {
284 file_pawn_count += 1;
285 }
286 }
287 if file_pawn_count > 1 {
288 doubled_pawns += file_pawn_count - 1;
289 }
290 }
291 features.push(doubled_pawns as f32);
292
293 let mut isolated_pawns = 0;
295 for file in 0..8 {
296 let mut file_has_pawn = false;
297 for rank in 0..8 {
298 let square = chess::Square::make_square(
299 chess::Rank::from_index(rank),
300 chess::File::from_index(file),
301 );
302 if (pawns & chess::BitBoard::from_square(square)).popcnt() > 0 {
303 file_has_pawn = true;
304 break;
305 }
306 }
307
308 if file_has_pawn {
309 let mut has_adjacent = false;
311 for adj_file in [file.saturating_sub(1), file + 1] {
312 if adj_file < 8 && adj_file != file {
313 for rank in 0..8 {
314 let adj_square = chess::Square::make_square(
315 chess::Rank::from_index(rank),
316 chess::File::from_index(adj_file),
317 );
318 if (pawns & chess::BitBoard::from_square(adj_square)).popcnt() > 0 {
319 has_adjacent = true;
320 break;
321 }
322 }
323 }
324 if has_adjacent {
325 break;
326 }
327 }
328
329 if !has_adjacent {
330 isolated_pawns += 1;
331 }
332 }
333 }
334 features.push(isolated_pawns as f32);
335 }
336 }
337
338 fn encode_tactical_patterns(&self, board: &Board, features: &mut Vec<f32>) {
340 for color in [Color::White, Color::Black] {
342 let opponent_color = if color == Color::White {
343 Color::Black
344 } else {
345 Color::White
346 };
347
348 let enemy_king_square = board.king_square(opponent_color);
350 let mut potential_pins = 0;
351
352 let rooks_queens = (board.pieces(Piece::Rook) | board.pieces(Piece::Queen))
354 & board.color_combined(color);
355 for square in chess::ALL_SQUARES {
356 if (rooks_queens & chess::BitBoard::from_square(square)).popcnt() > 0
357 && (square.get_rank() == enemy_king_square.get_rank()
358 || square.get_file() == enemy_king_square.get_file())
359 {
360 potential_pins += 1;
361 }
362 }
363
364 let bishops_queens = (board.pieces(Piece::Bishop) | board.pieces(Piece::Queen))
366 & board.color_combined(color);
367 for square in chess::ALL_SQUARES {
368 if (bishops_queens & chess::BitBoard::from_square(square)).popcnt() > 0 {
369 let rank_diff = (square.get_rank().to_index() as i32
370 - enemy_king_square.get_rank().to_index() as i32)
371 .abs();
372 let file_diff = (square.get_file().to_index() as i32
373 - enemy_king_square.get_file().to_index() as i32)
374 .abs();
375 if rank_diff == file_diff && rank_diff > 0 {
376 potential_pins += 1;
377 }
378 }
379 }
380
381 features.push(potential_pins as f32);
382 }
383
384 self.encode_center_control(board, features);
386 self.encode_piece_coordination(board, features);
387 }
388
389 fn encode_center_control(&self, board: &Board, features: &mut Vec<f32>) {
391 let center_squares = [
393 chess::Square::D4,
394 chess::Square::D5,
395 chess::Square::E4,
396 chess::Square::E5,
397 ];
398
399 for color in [Color::White, Color::Black] {
400 let mut center_control = 0.0;
401
402 for &square in ¢er_squares {
403 if let Some(_piece) = board.piece_on(square) {
405 if board.color_on(square) == Some(color) {
406 center_control += 2.0; }
408 }
409
410 let pieces = board.color_combined(color);
412 for piece_square in chess::ALL_SQUARES {
413 if (pieces & chess::BitBoard::from_square(piece_square)).popcnt() > 0 {
414 if let Some(piece) = board.piece_on(piece_square) {
415 let can_attack = match piece {
416 Piece::Pawn => {
417 let rank_diff = (square.get_rank().to_index() as i32
418 - piece_square.get_rank().to_index() as i32)
419 .abs();
420 let file_diff = (square.get_file().to_index() as i32
421 - piece_square.get_file().to_index() as i32)
422 .abs();
423 rank_diff == 1 && file_diff == 1
424 }
425 Piece::Knight => {
426 let rank_diff = (square.get_rank().to_index() as i32
427 - piece_square.get_rank().to_index() as i32)
428 .abs();
429 let file_diff = (square.get_file().to_index() as i32
430 - piece_square.get_file().to_index() as i32)
431 .abs();
432 (rank_diff == 2 && file_diff == 1)
433 || (rank_diff == 1 && file_diff == 2)
434 }
435 _ => false, };
437
438 if can_attack {
439 center_control += 0.5;
440 }
441 }
442 }
443 }
444 }
445
446 features.push(center_control);
447 }
448 }
449
450 fn encode_piece_coordination(&self, board: &Board, features: &mut Vec<f32>) {
452 for color in [Color::White, Color::Black] {
453 let mut coordination_score = 0.0;
454
455 let pieces = board.color_combined(color);
457 for square1 in chess::ALL_SQUARES {
458 if (pieces & chess::BitBoard::from_square(square1)).popcnt() > 0 {
459 for square2 in chess::ALL_SQUARES {
460 if (pieces & chess::BitBoard::from_square(square2)).popcnt() > 0
461 && square1 != square2
462 {
463 let rank_diff = (square1.get_rank().to_index() as i32
465 - square2.get_rank().to_index() as i32)
466 .abs();
467 let file_diff = (square1.get_file().to_index() as i32
468 - square2.get_file().to_index() as i32)
469 .abs();
470
471 if rank_diff <= 2 && file_diff <= 2 {
472 coordination_score += 0.1;
473 }
474 }
475 }
476 }
477 }
478
479 features.push(coordination_score);
480 }
481 }
482
483 pub fn similarity(&self, vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
485 let dot_product = vec1.dot(vec2);
487 let norm1 = vec1.dot(vec1).sqrt();
488 let norm2 = vec2.dot(vec2).sqrt();
489
490 if norm1 == 0.0 || norm2 == 0.0 {
491 0.0
492 } else {
493 dot_product / (norm1 * norm2)
494 }
495 }
496
497 pub fn distance(&self, vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
499 (vec1 - vec2).mapv(|x| x * x).sum().sqrt()
500 }
501
502 pub fn encode_batch(&self, boards: &[Board]) -> Vec<Array1<f32>> {
504 if boards.len() > 10 {
505 boards.par_iter().map(|board| self.encode(board)).collect()
507 } else {
508 boards.iter().map(|board| self.encode(board)).collect()
510 }
511 }
512
513 pub fn batch_similarity(&self, query: &Array1<f32>, vectors: &[Array1<f32>]) -> Vec<f32> {
515 if vectors.len() > 50 {
516 vectors
518 .par_iter()
519 .map(|vec| self.similarity(query, vec))
520 .collect()
521 } else {
522 vectors
524 .iter()
525 .map(|vec| self.similarity(query, vec))
526 .collect()
527 }
528 }
529
530 pub fn pairwise_similarity_matrix(&self, vectors: &[Array1<f32>]) -> Vec<Vec<f32>> {
532 if vectors.len() > 20 {
533 vectors
535 .par_iter()
536 .enumerate()
537 .map(|(i, vec1)| {
538 vectors
539 .iter()
540 .enumerate()
541 .map(|(j, vec2)| {
542 if i == j {
543 1.0 } else {
545 self.similarity(vec1, vec2)
546 }
547 })
548 .collect()
549 })
550 .collect()
551 } else {
552 vectors
554 .iter()
555 .enumerate()
556 .map(|(i, vec1)| {
557 vectors
558 .iter()
559 .enumerate()
560 .map(|(j, vec2)| {
561 if i == j {
562 1.0 } else {
564 self.similarity(vec1, vec2)
565 }
566 })
567 .collect()
568 })
569 .collect()
570 }
571 }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// FEN of the position after 1. e4.
    const AFTER_E4: &str = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1";

    #[test]
    fn test_encode_starting_position() {
        let encoder = PositionEncoder::new(1024);
        let vector = encoder.encode(&Board::default());

        assert_eq!(vector.len(), 1024);
        assert!(vector.iter().any(|&x| x > 0.0));
    }

    #[test]
    fn test_similarity_identical_positions() {
        let encoder = PositionEncoder::new(1024);
        let board = Board::default();
        let vec1 = encoder.encode(&board);
        let vec2 = encoder.encode(&board);

        let similarity = encoder.similarity(&vec1, &vec2);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_different_positions() {
        let encoder = PositionEncoder::new(1024);
        let board1 = Board::default();
        let board2 = Board::from_str(AFTER_E4).unwrap();

        let vec1 = encoder.encode(&board1);
        let vec2 = encoder.encode(&board2);

        let similarity = encoder.similarity(&vec1, &vec2);
        assert!(similarity < 1.0);
        assert!(similarity > 0.8);
    }

    #[test]
    fn test_similarity_with_zero_vector_is_zero() {
        let encoder = PositionEncoder::new(4);
        let zero = Array1::<f32>::zeros(4);
        let ones = Array1::from(vec![1.0f32; 4]);

        assert_eq!(encoder.similarity(&zero, &ones), 0.0);
        assert_eq!(encoder.similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_distance_between_known_vectors() {
        let encoder = PositionEncoder::new(2);
        let origin = Array1::from(vec![0.0f32, 0.0]);
        let point = Array1::from(vec![3.0f32, 4.0]);

        assert!((encoder.distance(&origin, &point) - 5.0).abs() < 1e-6);
        assert_eq!(encoder.distance(&point, &point), 0.0);
    }

    #[test]
    fn test_encode_batch_matches_single_encode() {
        // More than 10 boards, so `encode_batch` takes its parallel path.
        let encoder = PositionEncoder::new(1024);
        let boards = vec![Board::default(); 12];
        let expected = encoder.encode(&Board::default());

        let batch = encoder.encode_batch(&boards);
        assert_eq!(batch.len(), boards.len());
        assert!(batch.iter().all(|vector| *vector == expected));
    }

    #[test]
    fn test_pairwise_similarity_matrix_shape_diagonal_and_symmetry() {
        let encoder = PositionEncoder::new(1024);
        let boards = [Board::default(), Board::from_str(AFTER_E4).unwrap()];
        let vectors = encoder.encode_batch(&boards);

        let matrix = encoder.pairwise_similarity_matrix(&vectors);
        assert_eq!(matrix.len(), 2);
        assert!(matrix.iter().all(|row| row.len() == 2));
        assert_eq!(matrix[0][0], 1.0);
        assert_eq!(matrix[1][1], 1.0);
        assert!((matrix[0][1] - matrix[1][0]).abs() < 1e-6);
    }
}
618
619impl PositionEncoder {
620 fn encode_piece_positions_fast(&self, board: &Board, features: &mut [f32], offset: usize) -> usize {
624 let mut idx = offset;
625
626 const PIECE_INDICES: [usize; 6] = [0, 1, 2, 3, 4, 5]; for square in chess::ALL_SQUARES {
630 if let Some(piece) = board.piece_on(square) {
631 let color = board.color_on(square).unwrap();
632 let piece_idx = match piece {
633 chess::Piece::Pawn => 0,
634 chess::Piece::Knight => 1,
635 chess::Piece::Bishop => 2,
636 chess::Piece::Rook => 3,
637 chess::Piece::Queen => 4,
638 chess::Piece::King => 5,
639 };
640
641 let color_offset = if color == chess::Color::White { 0 } else { 6 };
642 let feature_idx = idx + piece_idx + color_offset;
643
644 if feature_idx < features.len() {
645 features[feature_idx] = 1.0;
646 }
647 }
648 idx += 12; }
650
651 offset + 768 }
653
654 fn encode_game_state_fast(&self, board: &Board, features: &mut [f32], offset: usize) -> usize {
656 let mut idx = offset;
657
658 if idx + 7 < features.len() {
659 features[idx] = if board.castle_rights(chess::Color::White).has_kingside() { 1.0 } else { 0.0 };
661 features[idx + 1] = if board.castle_rights(chess::Color::White).has_queenside() { 1.0 } else { 0.0 };
662 features[idx + 2] = if board.castle_rights(chess::Color::Black).has_kingside() { 1.0 } else { 0.0 };
663 features[idx + 3] = if board.castle_rights(chess::Color::Black).has_queenside() { 1.0 } else { 0.0 };
664
665 features[idx + 4] = if board.en_passant().is_some() { 1.0 } else { 0.0 };
667
668 features[idx + 5] = if board.side_to_move() == chess::Color::White { 1.0 } else { 0.0 };
670
671 features[idx + 6] = 0.0; }
674
675 offset + 7
676 }
677
678 fn encode_material_balance_fast(&self, board: &Board, features: &mut [f32], offset: usize) -> usize {
680 let mut idx = offset;
681
682 if idx + 12 < features.len() {
683 let piece_values = [1.0, 3.0, 3.0, 5.0, 9.0, 0.0]; for (piece_type, &_value) in [chess::Piece::Pawn, chess::Piece::Knight, chess::Piece::Bishop,
687 chess::Piece::Rook, chess::Piece::Queen, chess::Piece::King].iter().zip(&piece_values) {
688 let white_count = (board.pieces(*piece_type) & board.color_combined(chess::Color::White)).popcnt() as f32;
689 let black_count = (board.pieces(*piece_type) & board.color_combined(chess::Color::Black)).popcnt() as f32;
690
691 features[idx] = white_count / 8.0; features[idx + 1] = black_count / 8.0;
693 idx += 2;
694 }
695 }
696
697 offset + 12
698 }
699
700 fn encode_positional_features_fast(&self, board: &Board, features: &mut [f32], offset: usize) -> usize {
702 let mut idx = offset;
703
704 if idx + 4 < features.len() {
705 let white_king_square = board.king_square(chess::Color::White);
707 let black_king_square = board.king_square(chess::Color::Black);
708
709 features[idx] = white_king_square.get_file().to_index() as f32 / 7.0;
710 features[idx + 1] = white_king_square.get_rank().to_index() as f32 / 7.0;
711 features[idx + 2] = black_king_square.get_file().to_index() as f32 / 7.0;
712 features[idx + 3] = black_king_square.get_rank().to_index() as f32 / 7.0;
713 }
714
715 offset + 4
716 }
717
718 fn encode_tactical_patterns_fast(&self, board: &Board, features: &mut [f32], offset: usize) -> usize {
720 let mut idx = offset;
721
722 if idx + 2 < features.len() {
723 let white_pieces = board.color_combined(chess::Color::White).popcnt() as f32;
725 let black_pieces = board.color_combined(chess::Color::Black).popcnt() as f32;
726
727 features[idx] = white_pieces / 16.0;
728 features[idx + 1] = black_pieces / 16.0;
729 }
730
731 offset + 2
732 }
733}