1use chess::{Board, Color, Piece, Square};
2use ndarray::Array1;
3use rayon::prelude::*;
4
/// Encodes chess positions as fixed-length `f32` feature vectors and compares
/// such vectors (cosine similarity, Euclidean distance).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PositionEncoder {
    /// Length of every vector the encoder produces; shorter feature sets are
    /// zero-padded and longer ones truncated to this size.
    vector_size: usize,
}
11
12impl PositionEncoder {
13 pub fn new(vector_size: usize) -> Self {
14 Self { vector_size }
15 }
16
17 pub fn vector_size(&self) -> usize {
19 self.vector_size
20 }
21
22 pub fn encode(&self, board: &Board) -> Array1<f32> {
24 let mut features = Vec::with_capacity(self.vector_size);
25
26 self.encode_piece_positions(board, &mut features);
34
35 self.encode_game_state(board, &mut features);
37
38 self.encode_material_balance(board, &mut features);
40
41 self.encode_positional_features(board, &mut features);
43
44 self.encode_tactical_patterns(board, &mut features);
46
47 features.resize(self.vector_size, 0.0);
49
50 Array1::from(features)
51 }
52
53 fn encode_piece_positions(&self, board: &Board, features: &mut Vec<f32>) {
55 for square in chess::ALL_SQUARES {
58 let mut square_features = vec![0.0; 12]; if let Some(piece) = board.piece_on(square) {
62 let color = board.color_on(square).unwrap();
63 let piece_idx = match piece {
64 chess::Piece::Pawn => 0,
65 chess::Piece::Knight => 1,
66 chess::Piece::Bishop => 2,
67 chess::Piece::Rook => 3,
68 chess::Piece::Queen => 4,
69 chess::Piece::King => 5,
70 };
71
72 let color_offset = if color == chess::Color::White { 0 } else { 6 };
73 square_features[piece_idx + color_offset] = 1.0;
74 }
75
76 features.extend(square_features);
77 }
78
79 self.encode_piece_interactions(board, features);
81 }
82
83 fn encode_piece_interactions(&self, board: &Board, features: &mut Vec<f32>) {
85 let mut white_attacks = vec![0.0; 6]; let mut black_attacks = vec![0.0; 6];
88
89 for square in chess::ALL_SQUARES {
91 if let Some(piece) = board.piece_on(square) {
92 let color = board.color_on(square).unwrap();
93 let piece_idx = match piece {
94 Piece::Pawn => 0,
95 Piece::Knight => 1,
96 Piece::Bishop => 2,
97 Piece::Rook => 3,
98 Piece::Queen => 4,
99 Piece::King => 5,
100 };
101
102 let attack_value = match piece {
104 Piece::Pawn => 1.0,
105 Piece::Knight => 3.0,
106 Piece::Bishop => 3.0,
107 Piece::Rook => 5.0,
108 Piece::Queen => 9.0,
109 Piece::King => 1.0,
110 };
111
112 if color == Color::White {
113 white_attacks[piece_idx] += attack_value;
114 } else {
115 black_attacks[piece_idx] += attack_value;
116 }
117 }
118 }
119
120 features.extend(white_attacks);
122 features.extend(black_attacks);
123 }
124
125 fn encode_game_state(&self, board: &Board, features: &mut Vec<f32>) {
127 features.push(if board.castle_rights(Color::White).has_kingside() {
129 1.0
130 } else {
131 0.0
132 });
133 features.push(if board.castle_rights(Color::White).has_queenside() {
134 1.0
135 } else {
136 0.0
137 });
138 features.push(if board.castle_rights(Color::Black).has_kingside() {
139 1.0
140 } else {
141 0.0
142 });
143 features.push(if board.castle_rights(Color::Black).has_queenside() {
144 1.0
145 } else {
146 0.0
147 });
148
149 features.push(if board.en_passant().is_some() {
151 1.0
152 } else {
153 0.0
154 });
155
156 features.push(if board.side_to_move() == Color::White {
158 1.0
159 } else {
160 0.0
161 });
162
163 features.push(0.0);
165 }
166
167 fn encode_material_balance(&self, board: &Board, features: &mut Vec<f32>) {
169 let piece_values = [
170 (Piece::Pawn, 1),
171 (Piece::Knight, 3),
172 (Piece::Bishop, 3),
173 (Piece::Rook, 5),
174 (Piece::Queen, 9),
175 (Piece::King, 0),
176 ];
177
178 for (piece, _value) in piece_values {
179 let white_count = board.pieces(piece) & board.color_combined(Color::White);
180 let black_count = board.pieces(piece) & board.color_combined(Color::Black);
181
182 features.push(white_count.popcnt() as f32);
183 features.push(black_count.popcnt() as f32);
184 features.push((white_count.popcnt() as i32 - black_count.popcnt() as i32) as f32);
185 }
186 }
187
188 fn encode_positional_features(&self, board: &Board, features: &mut Vec<f32>) {
190 for color in [Color::White, Color::Black] {
192 let king_square = board.king_square(color);
193 let center_distance = self.distance_to_center(king_square);
195 features.push(center_distance);
196
197 let surrounding_pieces = self.count_surrounding_pieces(board, king_square);
199 features.push(surrounding_pieces as f32);
200 }
201
202 for color in [Color::White, Color::Black] {
204 let mobility = self.calculate_mobility(board, color);
205 features.push(mobility as f32);
206 }
207
208 self.encode_pawn_structure(board, features);
210
211 self.encode_tactical_patterns(board, features);
213
214 self.encode_center_control(board, features);
216
217 self.encode_piece_coordination(board, features);
219 }
220
221 fn distance_to_center(&self, square: Square) -> f32 {
223 let file = square.get_file().to_index() as f32;
224 let rank = square.get_rank().to_index() as f32;
225 let center_file = 3.5;
226 let center_rank = 3.5;
227
228 ((file - center_file).powi(2) + (rank - center_rank).powi(2)).sqrt()
229 }
230
231 fn count_surrounding_pieces(&self, board: &Board, center: Square) -> u32 {
233 let mut count = 0;
234 let center_file = center.get_file().to_index() as i32;
235 let center_rank = center.get_rank().to_index() as i32;
236
237 for file_offset in -1..=1 {
238 for rank_offset in -1..=1 {
239 if file_offset == 0 && rank_offset == 0 {
240 continue;
241 }
242
243 let new_file = center_file + file_offset;
244 let new_rank = center_rank + rank_offset;
245
246 if (0..8).contains(&new_file) && (0..8).contains(&new_rank) {
247 let square = Square::make_square(
248 chess::Rank::from_index(new_rank as usize),
249 chess::File::from_index(new_file as usize),
250 );
251 if board.piece_on(square).is_some() {
252 count += 1;
253 }
254 }
255 }
256 }
257 count
258 }
259
260 fn calculate_mobility(&self, board: &Board, color: Color) -> u32 {
262 let pieces = board.color_combined(color);
264 let mut mobility = 0;
265
266 for _square in *pieces {
267 mobility += 1;
270 }
271
272 mobility
273 }
274
275 fn encode_pawn_structure(&self, board: &Board, features: &mut Vec<f32>) {
277 for color in [Color::White, Color::Black] {
278 let pawns = board.pieces(Piece::Pawn) & board.color_combined(color);
279
280 let mut doubled_pawns = 0;
282 for file in 0..8 {
283 let mut file_pawn_count = 0;
284 for rank in 0..8 {
285 let square = chess::Square::make_square(
286 chess::Rank::from_index(rank),
287 chess::File::from_index(file),
288 );
289 if (pawns & chess::BitBoard::from_square(square)).popcnt() > 0 {
290 file_pawn_count += 1;
291 }
292 }
293 if file_pawn_count > 1 {
294 doubled_pawns += file_pawn_count - 1;
295 }
296 }
297 features.push(doubled_pawns as f32);
298
299 let mut isolated_pawns = 0;
301 for file in 0..8 {
302 let mut file_has_pawn = false;
303 for rank in 0..8 {
304 let square = chess::Square::make_square(
305 chess::Rank::from_index(rank),
306 chess::File::from_index(file),
307 );
308 if (pawns & chess::BitBoard::from_square(square)).popcnt() > 0 {
309 file_has_pawn = true;
310 break;
311 }
312 }
313
314 if file_has_pawn {
315 let mut has_adjacent = false;
317 for adj_file in [file.saturating_sub(1), file + 1] {
318 if adj_file < 8 && adj_file != file {
319 for rank in 0..8 {
320 let adj_square = chess::Square::make_square(
321 chess::Rank::from_index(rank),
322 chess::File::from_index(adj_file),
323 );
324 if (pawns & chess::BitBoard::from_square(adj_square)).popcnt() > 0 {
325 has_adjacent = true;
326 break;
327 }
328 }
329 }
330 if has_adjacent {
331 break;
332 }
333 }
334
335 if !has_adjacent {
336 isolated_pawns += 1;
337 }
338 }
339 }
340 features.push(isolated_pawns as f32);
341 }
342 }
343
344 fn encode_tactical_patterns(&self, board: &Board, features: &mut Vec<f32>) {
346 for color in [Color::White, Color::Black] {
348 let opponent_color = if color == Color::White {
349 Color::Black
350 } else {
351 Color::White
352 };
353
354 let enemy_king_square = board.king_square(opponent_color);
356 let mut potential_pins = 0;
357
358 let rooks_queens = (board.pieces(Piece::Rook) | board.pieces(Piece::Queen))
360 & board.color_combined(color);
361 for square in chess::ALL_SQUARES {
362 if (rooks_queens & chess::BitBoard::from_square(square)).popcnt() > 0
363 && (square.get_rank() == enemy_king_square.get_rank()
364 || square.get_file() == enemy_king_square.get_file())
365 {
366 potential_pins += 1;
367 }
368 }
369
370 let bishops_queens = (board.pieces(Piece::Bishop) | board.pieces(Piece::Queen))
372 & board.color_combined(color);
373 for square in chess::ALL_SQUARES {
374 if (bishops_queens & chess::BitBoard::from_square(square)).popcnt() > 0 {
375 let rank_diff = (square.get_rank().to_index() as i32
376 - enemy_king_square.get_rank().to_index() as i32)
377 .abs();
378 let file_diff = (square.get_file().to_index() as i32
379 - enemy_king_square.get_file().to_index() as i32)
380 .abs();
381 if rank_diff == file_diff && rank_diff > 0 {
382 potential_pins += 1;
383 }
384 }
385 }
386
387 features.push(potential_pins as f32);
388 }
389
390 self.encode_center_control(board, features);
392 self.encode_piece_coordination(board, features);
393 }
394
395 fn encode_center_control(&self, board: &Board, features: &mut Vec<f32>) {
397 let center_squares = [
399 chess::Square::D4,
400 chess::Square::D5,
401 chess::Square::E4,
402 chess::Square::E5,
403 ];
404
405 for color in [Color::White, Color::Black] {
406 let mut center_control = 0.0;
407
408 for &square in ¢er_squares {
409 if let Some(_piece) = board.piece_on(square) {
411 if board.color_on(square) == Some(color) {
412 center_control += 2.0; }
414 }
415
416 let pieces = board.color_combined(color);
418 for piece_square in chess::ALL_SQUARES {
419 if (pieces & chess::BitBoard::from_square(piece_square)).popcnt() > 0 {
420 if let Some(piece) = board.piece_on(piece_square) {
421 let can_attack = match piece {
422 Piece::Pawn => {
423 let rank_diff = (square.get_rank().to_index() as i32
424 - piece_square.get_rank().to_index() as i32)
425 .abs();
426 let file_diff = (square.get_file().to_index() as i32
427 - piece_square.get_file().to_index() as i32)
428 .abs();
429 rank_diff == 1 && file_diff == 1
430 }
431 Piece::Knight => {
432 let rank_diff = (square.get_rank().to_index() as i32
433 - piece_square.get_rank().to_index() as i32)
434 .abs();
435 let file_diff = (square.get_file().to_index() as i32
436 - piece_square.get_file().to_index() as i32)
437 .abs();
438 (rank_diff == 2 && file_diff == 1)
439 || (rank_diff == 1 && file_diff == 2)
440 }
441 _ => false, };
443
444 if can_attack {
445 center_control += 0.5;
446 }
447 }
448 }
449 }
450 }
451
452 features.push(center_control);
453 }
454 }
455
456 fn encode_piece_coordination(&self, board: &Board, features: &mut Vec<f32>) {
458 for color in [Color::White, Color::Black] {
459 let mut coordination_score = 0.0;
460
461 let pieces = board.color_combined(color);
463 for square1 in chess::ALL_SQUARES {
464 if (pieces & chess::BitBoard::from_square(square1)).popcnt() > 0 {
465 for square2 in chess::ALL_SQUARES {
466 if (pieces & chess::BitBoard::from_square(square2)).popcnt() > 0
467 && square1 != square2
468 {
469 let rank_diff = (square1.get_rank().to_index() as i32
471 - square2.get_rank().to_index() as i32)
472 .abs();
473 let file_diff = (square1.get_file().to_index() as i32
474 - square2.get_file().to_index() as i32)
475 .abs();
476
477 if rank_diff <= 2 && file_diff <= 2 {
478 coordination_score += 0.1;
479 }
480 }
481 }
482 }
483 }
484
485 features.push(coordination_score);
486 }
487 }
488
489 pub fn similarity(&self, vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
491 let dot_product = vec1.dot(vec2);
493 let norm1 = vec1.dot(vec1).sqrt();
494 let norm2 = vec2.dot(vec2).sqrt();
495
496 if norm1 == 0.0 || norm2 == 0.0 {
497 0.0
498 } else {
499 dot_product / (norm1 * norm2)
500 }
501 }
502
503 pub fn distance(&self, vec1: &Array1<f32>, vec2: &Array1<f32>) -> f32 {
505 (vec1 - vec2).mapv(|x| x * x).sum().sqrt()
506 }
507
508 pub fn encode_batch(&self, boards: &[Board]) -> Vec<Array1<f32>> {
510 if boards.len() > 10 {
511 boards.par_iter().map(|board| self.encode(board)).collect()
513 } else {
514 boards.iter().map(|board| self.encode(board)).collect()
516 }
517 }
518
519 pub fn batch_similarity(&self, query: &Array1<f32>, vectors: &[Array1<f32>]) -> Vec<f32> {
521 if vectors.len() > 50 {
522 vectors
524 .par_iter()
525 .map(|vec| self.similarity(query, vec))
526 .collect()
527 } else {
528 vectors
530 .iter()
531 .map(|vec| self.similarity(query, vec))
532 .collect()
533 }
534 }
535
536 pub fn pairwise_similarity_matrix(&self, vectors: &[Array1<f32>]) -> Vec<Vec<f32>> {
538 if vectors.len() > 20 {
539 vectors
541 .par_iter()
542 .enumerate()
543 .map(|(i, vec1)| {
544 vectors
545 .iter()
546 .enumerate()
547 .map(|(j, vec2)| {
548 if i == j {
549 1.0 } else {
551 self.similarity(vec1, vec2)
552 }
553 })
554 .collect()
555 })
556 .collect()
557 } else {
558 vectors
560 .iter()
561 .enumerate()
562 .map(|(i, vec1)| {
563 vectors
564 .iter()
565 .enumerate()
566 .map(|(j, vec2)| {
567 if i == j {
568 1.0 } else {
570 self.similarity(vec1, vec2)
571 }
572 })
573 .collect()
574 })
575 .collect()
576 }
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;
    use std::str::FromStr;

    /// FEN of the position after 1. e4 from the standard start.
    const AFTER_E4: &str = "rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1";

    /// Encoder configuration shared by every test in this module.
    fn encoder() -> PositionEncoder {
        PositionEncoder::new(1024)
    }

    #[test]
    fn test_encode_starting_position() {
        let encoded = encoder().encode(&Board::default());

        assert_eq!(encoded.len(), 1024);
        // The start position has pieces on the board, so some feature must be non-zero.
        assert!(encoded.iter().any(|&value| value > 0.0));
    }

    #[test]
    fn test_similarity_identical_positions() {
        let enc = encoder();
        let start = Board::default();
        let (first, second) = (enc.encode(&start), enc.encode(&start));

        let score = enc.similarity(&first, &second);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_different_positions() {
        let enc = encoder();
        let start = enc.encode(&Board::default());
        let after_e4 = enc.encode(&Board::from_str(AFTER_E4).unwrap());

        let score = enc.similarity(&start, &after_e4);
        // One pawn move: clearly related, but not identical.
        assert!(score < 1.0);
        assert!(score > 0.8);
    }
}