#![allow(clippy::type_complexity)]
use candle_core::{Device, Module, Result as CandleResult, Tensor};
use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use chess::{Board, Color, Piece, Square};
use serde::{Deserialize, Serialize};

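/// Efficiently updatable neural network (NNUE) evaluator.
///
/// The architecture follows the classic NNUE shape: a sparse 768-dimensional
/// piece-square input is transformed into a dense hidden vector that can be
/// maintained incrementally as moves are made, followed by a small stack of
/// fully connected layers and a scalar output head.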
pub struct NNUE {
    feature_transformer: FeatureTransformer,
    hidden_layers: Vec<Linear>,
    output_layer: Linear,
    device: Device,
    #[allow(dead_code)]
    var_map: VarMap,
    optimizer: Option<AdamW>,
    vector_weight: f32,
    enable_vector_integration: bool,
    weights_loaded: bool,
    training_version: u32,
}

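/// First layer of the network, maintained as an incremental accumulator.
///
/// `accumulated_features` caches `biases + sum(weights[i])` over all active
/// feature indices `i`, so applying a move only needs a handful of row
/// additions and subtractions instead of a full matrix multiply.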
struct FeatureTransformer {
    weights: Tensor,
    biases: Tensor,
    accumulated_features: Option<Tensor>,
    king_squares: [Square; 2],
}

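/// Tunable hyperparameters for the network shape, training, and how heavily
/// the NNUE score is blended with external evaluators.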
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NNUEConfig {
    pub feature_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub activation: ActivationType,
    pub learning_rate: f32,
    pub vector_blend_weight: f32,
    pub enable_incremental_updates: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    ClippedReLU,
    Sigmoid,
}

impl Default for NNUEConfig {
    fn default() -> Self {
        Self {
            feature_size: 768,
            hidden_size: 256,
            num_hidden_layers: 2,
            activation: ActivationType::ClippedReLU,
            learning_rate: 0.001,
            vector_blend_weight: 0.3,
            enable_incremental_updates: true,
        }
    }
}

impl NNUEConfig {
    pub fn vector_integrated() -> Self {
        Self {
            vector_blend_weight: 0.4,
            ..Default::default()
        }
    }

    pub fn nnue_focused() -> Self {
        Self {
            vector_blend_weight: 0.1,
            ..Default::default()
        }
    }

    pub fn experimental() -> Self {
        Self {
            feature_size: 1024,
            hidden_size: 512,
            num_hidden_layers: 3,
            vector_blend_weight: 0.5,
            ..Default::default()
        }
    }
}

impl NNUE {
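    /// Builds a fresh NNUE on the CPU with newly initialized parameters.
    ///
    /// A minimal usage sketch (the `nnue` module path is an assumption; use
    /// whatever path this crate actually exposes):
    ///
    /// ```ignore
    /// use nnue::{NNUEConfig, NNUE};
    ///
    /// let mut net = NNUE::new(NNUEConfig::default())?;
    /// let score = net.evaluate(&chess::Board::default())?;
    /// println!("startpos: {score:.2}");
    /// ```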
    pub fn new(config: NNUEConfig) -> CandleResult<Self> {
        Self::new_with_weights(config, None)
    }

    pub fn new_with_weights(
        config: NNUEConfig,
        weights: Option<std::collections::HashMap<String, candle_core::Tensor>>,
    ) -> CandleResult<Self> {
        let device = Device::Cpu;

        if let Some(weight_map) = weights {
            println!("🔄 Creating NNUE with pre-loaded weights...");
            return Self::create_with_loaded_weights(config, weight_map, device);
        }

        let var_map = VarMap::new();
        let vs = VarBuilder::from_varmap(&var_map, candle_core::DType::F32, &device);

        let feature_transformer =
            FeatureTransformer::new(vs.clone(), config.feature_size, config.hidden_size)?;

        let mut hidden_layers = Vec::new();
        let mut prev_size = config.hidden_size;

        for i in 0..config.num_hidden_layers {
            // Name layers "hidden_{i}" so they line up with saved weight keys.
            let layer = linear(prev_size, config.hidden_size, vs.pp(format!("hidden_{}", i)))?;
            hidden_layers.push(layer);
            prev_size = config.hidden_size;
        }

        let output_layer = linear(prev_size, 1, vs.pp("output"))?;

        let adamw_params = ParamsAdamW {
            lr: config.learning_rate as f64,
            ..Default::default()
        };
        let optimizer = Some(AdamW::new(var_map.all_vars(), adamw_params)?);

        Ok(Self {
            feature_transformer,
            hidden_layers,
            output_layer,
            device,
            var_map,
            optimizer,
            vector_weight: config.vector_blend_weight,
            enable_vector_integration: true,
            weights_loaded: false,
            training_version: 0,
        })
    }

    fn create_with_loaded_weights(
        config: NNUEConfig,
        weights: std::collections::HashMap<String, candle_core::Tensor>,
        device: Device,
    ) -> CandleResult<Self> {
        println!("✨ Creating custom NNUE with direct weight application...");

        let var_map = VarMap::new();

        let feature_transformer = if let (Some(ft_weights), Some(ft_biases)) = (
            weights.get("feature_transformer.weights"),
            weights.get("feature_transformer.biases"),
        ) {
            FeatureTransformer {
                weights: ft_weights.clone(),
                biases: ft_biases.clone(),
                accumulated_features: None,
                king_squares: [chess::Square::E1, chess::Square::E8],
            }
        } else {
            println!("⚠️ Feature transformer weights not found, using random initialization");
            let vs = VarBuilder::from_varmap(&var_map, candle_core::DType::F32, &device);
            FeatureTransformer::new(vs, config.feature_size, config.hidden_size)?
        };

        let vs = VarBuilder::from_varmap(&var_map, candle_core::DType::F32, &device);
        let mut hidden_layers = Vec::new();
        let mut prev_size = config.hidden_size;

        for i in 0..config.num_hidden_layers {
            let layer = linear(prev_size, config.hidden_size, vs.pp(format!("hidden_{}", i)))?;
            hidden_layers.push(layer);
            prev_size = config.hidden_size;

            let weight_key = format!("hidden_layer_{}.weight", i);
            let bias_key = format!("hidden_layer_{}.bias", i);
            if weights.contains_key(&weight_key) && weights.contains_key(&bias_key) {
                println!(
                    "  📋 Hidden layer {} weights available but not applied (candle-nn limitation)",
                    i
                );
            }
        }

        let output_layer = linear(prev_size, 1, vs.pp("output"))?;
        if weights.contains_key("output_layer.weight") && weights.contains_key("output_layer.bias")
        {
            println!("  📋 Output layer weights available but not applied (candle-nn limitation)");
        }

        let adamw_params = ParamsAdamW {
            lr: config.learning_rate as f64,
            ..Default::default()
        };
        let optimizer = Some(AdamW::new(var_map.all_vars(), adamw_params)?);

        println!("✅ Custom NNUE created with partial weight loading");
        println!("📝 Feature transformer: ✅ Applied");
        println!("📝 Hidden layers: ⚠️ Not applied (candle-nn limitation)");
        println!("📝 Output layer: ⚠️ Not applied (candle-nn limitation)");

        Ok(Self {
            feature_transformer,
            hidden_layers,
            output_layer,
            device,
            var_map,
            optimizer,
            vector_weight: config.vector_blend_weight,
            enable_vector_integration: true,
            weights_loaded: true,
            training_version: 0,
        })
    }

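    /// Full (non-incremental) evaluation: extracts all features from scratch
    /// and runs a complete forward pass. Slower than `evaluate_optimized`,
    /// but independent of accumulator state.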
    pub fn evaluate(&mut self, board: &Board) -> CandleResult<f32> {
        let features = self.extract_features(board)?;
        let output = self.forward(&features)?;

        let eval_pawn_units = output.to_vec2::<f32>()?[0][0];

        Ok(eval_pawn_units)
    }

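    /// Fast-path evaluation that reuses the cached accumulator when present,
    /// skipping feature extraction and the first-layer matrix multiply.
    /// Falls back to building the accumulator on first use.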
    pub fn evaluate_optimized(&mut self, board: &Board) -> CandleResult<f32> {
        if let Some(ref accumulated) = self.feature_transformer.accumulated_features {
            // Clipped-ReLU on the cached accumulator.
            let activated = accumulated.clamp(0.0, 1.0)?;

            let mut hidden_output = activated;
            for layer in &self.hidden_layers {
                hidden_output = layer.forward(&hidden_output)?;
                hidden_output = hidden_output.clamp(0.0, 1.0)?;
            }

            let output = self.output_layer.forward(&hidden_output)?;
            let eval_raw = output.get(0)?.to_scalar::<f32>()?;

            // Same output scaling as forward_optimized.
            return Ok(eval_raw * 600.0);
        }

        // No accumulator yet: build one, then take the fast path above.
        self.initialize_accumulator(board)?;
        self.evaluate_optimized(board)
    }

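    /// Rebuilds the accumulator from scratch for the given position:
    /// `acc = biases + sum(weights[feature])` over every piece on the board.
    /// Called whenever no valid accumulator exists.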
    fn initialize_accumulator(&mut self, board: &Board) -> CandleResult<()> {
        let mut accumulator = self.feature_transformer.biases.clone();

        let white_king = board.king_square(Color::White);
        let black_king = board.king_square(Color::Black);

        for square in chess::ALL_SQUARES {
            if let Some(piece) = board.piece_on(square) {
                let color = board.color_on(square).unwrap();

                if let Some(feature_idx) = self.feature_transformer.get_feature_index_for_piece(
                    piece, color, square, white_king, black_king,
                ) {
                    if feature_idx < 768 {
                        let piece_weights = self.feature_transformer.weights.get(feature_idx)?;
                        accumulator = accumulator.add(&piece_weights)?;
                    }
                }
            }
        }

        self.feature_transformer.accumulated_features = Some(accumulator);
        self.feature_transformer.king_squares = [white_king, black_king];

        Ok(())
    }

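    /// Incrementally updates the accumulator for a move: the captured
    /// piece's feature row is subtracted, the mover's row is moved from
    /// source to destination, and a promotion swaps the pawn for the
    /// promoted piece. This O(1) update is what gives NNUE its speed.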
    pub fn update_after_move(
        &mut self,
        chess_move: chess::ChessMove,
        board_before: &Board,
        board_after: &Board,
    ) -> CandleResult<()> {
        let moved_piece = board_before.piece_on(chess_move.get_source()).unwrap();
        let piece_color = board_before.color_on(chess_move.get_source()).unwrap();

        let white_king_after = board_after.king_square(Color::White);
        let black_king_after = board_after.king_square(Color::Black);

        if let Some(captured_piece) = board_before.piece_on(chess_move.get_dest()) {
            let captured_color = board_before.color_on(chess_move.get_dest()).unwrap();

            if let Some(captured_idx) = self.feature_transformer.get_feature_index_for_piece(
                captured_piece,
                captured_color,
                chess_move.get_dest(),
                white_king_after,
                black_king_after,
            ) {
                if captured_idx < 768 && self.feature_transformer.accumulated_features.is_some() {
                    let captured_weights = self.feature_transformer.weights.get(captured_idx)?;
                    let accumulator =
                        self.feature_transformer.accumulated_features.as_mut().unwrap();
                    *accumulator = accumulator.sub(&captured_weights)?;
                }
            }
        }

        self.feature_transformer.incremental_update(
            moved_piece,
            piece_color,
            chess_move.get_source(),
            chess_move.get_dest(),
            white_king_after,
            black_king_after,
        )?;

        if let Some(promoted_piece) = chess_move.get_promotion() {
            // The incremental update above moved the pawn onto the promotion
            // square; remove that pawn feature before adding the promoted piece,
            // otherwise the accumulator keeps a phantom pawn.
            if let Some(pawn_idx) = self.feature_transformer.get_feature_index_for_piece(
                moved_piece,
                piece_color,
                chess_move.get_dest(),
                white_king_after,
                black_king_after,
            ) {
                if pawn_idx < 768 && self.feature_transformer.accumulated_features.is_some() {
                    let pawn_weights = self.feature_transformer.weights.get(pawn_idx)?;
                    let accumulator =
                        self.feature_transformer.accumulated_features.as_mut().unwrap();
                    *accumulator = accumulator.sub(&pawn_weights)?;
                }
            }

            if let Some(promoted_idx) = self.feature_transformer.get_feature_index_for_piece(
                promoted_piece,
                piece_color,
                chess_move.get_dest(),
                white_king_after,
                black_king_after,
            ) {
                if promoted_idx < 768 && self.feature_transformer.accumulated_features.is_some() {
                    let promoted_weights = self.feature_transformer.weights.get(promoted_idx)?;
                    let accumulator =
                        self.feature_transformer.accumulated_features.as_mut().unwrap();
                    *accumulator = accumulator.add(&promoted_weights)?;
                }
            }
        }

        Ok(())
    }

    pub fn evaluate_batch(&mut self, boards: &[Board]) -> CandleResult<Vec<f32>> {
        let mut results = Vec::with_capacity(boards.len());

        for board in boards {
            let eval = self.evaluate_optimized(board)?;
            results.push(eval);
        }

        Ok(results)
    }

    pub fn evaluate_from_features(&mut self, features: &Tensor) -> CandleResult<f32> {
        self.forward_optimized(features)
    }

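    /// Blends the NNUE score with optional vector-similarity and tactical
    /// scores using game-phase-dependent weights. With vector integration
    /// disabled, this is identical to `evaluate_optimized`.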
    pub fn evaluate_hybrid(
        &mut self,
        board: &Board,
        vector_eval: Option<f32>,
        tactical_eval: Option<f32>,
    ) -> CandleResult<f32> {
        let nnue_eval = self.evaluate_optimized(board)?;

        if !self.enable_vector_integration {
            return Ok(nnue_eval);
        }

        let blend_weights =
            self.calculate_blend_weights(board, nnue_eval, vector_eval, tactical_eval)?;

        let mut final_eval = blend_weights.nnue_weight * nnue_eval;

        if let Some(vector_eval) = vector_eval {
            final_eval += blend_weights.vector_weight * vector_eval;
        }

        if let Some(tactical_eval) = tactical_eval {
            final_eval += blend_weights.tactical_weight * tactical_eval;
        }

        Ok(final_eval)
    }

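    /// Chooses blend weights by game phase (and tactical sharpness in the
    /// middlegame), then normalizes them so they always sum to 1.0.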
    fn calculate_blend_weights(
        &self,
        board: &Board,
        _nnue_eval: f32,
        _vector_eval: Option<f32>,
        _tactical_eval: Option<f32>,
    ) -> CandleResult<BlendWeights> {
        let mut nnue_weight = 0.7;
        let mut vector_weight = 0.2;
        let mut tactical_weight = 0.1;

        let material_count = self.count_material(board);
        let game_phase = self.detect_game_phase(material_count);

        match game_phase {
            GamePhase::Opening => {
                vector_weight = 0.4;
                nnue_weight = 0.5;
                tactical_weight = 0.1;
            }
            GamePhase::Middlegame => {
                if self.is_tactical_position(board) {
                    tactical_weight = 0.3;
                    nnue_weight = 0.5;
                    vector_weight = 0.2;
                } else {
                    nnue_weight = 0.6;
                    vector_weight = 0.25;
                    tactical_weight = 0.15;
                }
            }
            GamePhase::Endgame => {
                nnue_weight = 0.8;
                vector_weight = 0.15;
                tactical_weight = 0.05;
            }
        }

        let total_weight = nnue_weight + vector_weight + tactical_weight;

        Ok(BlendWeights {
            nnue_weight: nnue_weight / total_weight,
            vector_weight: vector_weight / total_weight,
            tactical_weight: tactical_weight / total_weight,
        })
    }

    fn detect_game_phase(&self, material_count: u32) -> GamePhase {
        // The starting position has exactly 78 points of non-king material,
        // so the opening test must be inclusive or it can never fire.
        if material_count >= 78 {
            GamePhase::Opening
        } else if material_count > 30 {
            GamePhase::Middlegame
        } else {
            GamePhase::Endgame
        }
    }

    fn count_material(&self, board: &Board) -> u32 {
        let mut material = 0;
        for square in chess::ALL_SQUARES {
            if let Some(piece) = board.piece_on(square) {
                material += match piece {
                    Piece::Pawn => 1,
                    Piece::Knight => 3,
                    Piece::Bishop => 3,
                    Piece::Rook => 5,
                    Piece::Queen => 9,
                    Piece::King => 0,
                };
            }
        }
        material
    }

    fn is_tactical_position(&self, board: &Board) -> bool {
        let moves = chess::MoveGen::new_legal(board);
        let capture_count = moves
            .filter(|mv| board.piece_on(mv.get_dest()).is_some())
            .count();

        capture_count > 3 || board.checkers().popcnt() > 0
    }

    pub fn benchmark_performance(
        &mut self,
        positions: &[Board],
        iterations: usize,
    ) -> Result<NNUEBenchmarkResult, Box<dyn std::error::Error>> {
        use std::time::Instant;

        println!("🚀 NNUE Performance Benchmark");
        println!("Positions: {}, Iterations: {}", positions.len(), iterations);

        // Pass 1: full feature extraction + forward pass on every call.
        let start = Instant::now();
        for _ in 0..iterations {
            for board in positions {
                let _ = self.evaluate(board)?;
            }
        }
        let standard_duration = start.elapsed();

        // Pass 2: accumulator fast path (built once, then reused).
        let start = Instant::now();
        for _ in 0..iterations {
            for board in positions {
                let _ = self.evaluate_optimized(board)?;
            }
        }
        let optimized_duration = start.elapsed();

        // Pass 3: accumulator rebuilt per position, then the fast path.
        let start = Instant::now();
        for _ in 0..iterations {
            for board in positions {
                self.initialize_accumulator(board).ok();
                let _ = self.evaluate_optimized(board)?;
            }
        }
        let incremental_duration = start.elapsed();

        let total_evaluations = positions.len() * iterations;

        let standard_nps = total_evaluations as f64 / standard_duration.as_secs_f64();
        let optimized_nps = total_evaluations as f64 / optimized_duration.as_secs_f64();
        let incremental_nps = total_evaluations as f64 / incremental_duration.as_secs_f64();

        Ok(NNUEBenchmarkResult {
            total_evaluations,
            standard_nps,
            optimized_nps,
            incremental_nps,
            speedup_optimized: optimized_nps / standard_nps,
            speedup_incremental: incremental_nps / standard_nps,
        })
    }
}

#[derive(Debug, Clone)]
pub struct NNUEBenchmarkResult {
    pub total_evaluations: usize,
    pub standard_nps: f64,
    pub optimized_nps: f64,
    pub incremental_nps: f64,
    pub speedup_optimized: f64,
    pub speedup_incremental: f64,
}

impl NNUE {
    fn extract_features(&self, board: &Board) -> CandleResult<Tensor> {
        let mut features = vec![0.0f32; 768];

        let white_king = board.king_square(Color::White);
        let black_king = board.king_square(Color::Black);

        for square in chess::ALL_SQUARES {
            if let Some(piece) = board.piece_on(square) {
                let color = board.color_on(square).unwrap();

                let (white_idx, _black_idx) =
                    self.get_feature_indices(piece, color, square, white_king, black_king);

                if let Some(idx) = white_idx {
                    if idx < 768 {
                        features[idx] = 1.0;
                    }
                }
            }
        }

        Tensor::from_vec(features, (1, 768), &self.device)
    }

    #[allow(dead_code)]
    fn extract_features_optimized(&self, board: &Board) -> CandleResult<Tensor> {
        let mut features = [0.0f32; 768];

        let white_king = board.king_square(Color::White);
        let black_king = board.king_square(Color::Black);

        let white_king_idx = white_king.to_index();
        let black_king_idx = black_king.to_index();

        // Iterate only over occupied squares via the combined bitboard.
        let occupied = board.combined();
        for square_idx in 0..64 {
            if (occupied.0 & (1u64 << square_idx)) != 0 {
                // SAFETY: square_idx is always in 0..64, a valid square index.
                let square = unsafe { Square::new(square_idx) };

                if let Some(piece) = board.piece_on(square) {
                    let color = board.color_on(square).unwrap();

                    let feature_idx = self.get_feature_index_fast(
                        piece,
                        color,
                        square_idx as usize,
                        white_king_idx,
                        black_king_idx,
                    );

                    if feature_idx < 768 {
                        features[feature_idx] = 1.0;
                    }
                }
            }
        }

        Tensor::from_slice(&features, (1, 768), &self.device)
    }

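    /// Branch-free feature index: `(piece + color_offset) * 64 + square`,
    /// shifted by a coarse king bucket. Indices of 768 or above fall outside
    /// the current feature table and are discarded by the caller.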
    fn get_feature_index_fast(
        &self,
        piece: Piece,
        color: Color,
        square_idx: usize,
        white_king_idx: usize,
        _black_king_idx: usize,
    ) -> usize {
        let piece_idx = match piece {
            Piece::Pawn => 0,
            Piece::Knight => 1,
            Piece::Bishop => 2,
            Piece::Rook => 3,
            Piece::Queen => 4,
            Piece::King => 5,
        };

        let color_offset = if color == Color::White { 0 } else { 6 };
        let king_bucket = white_king_idx / 8;

        (piece_idx + color_offset) * 64 + square_idx + (king_bucket % 4) * 384
    }

    fn forward_optimized(&self, features: &Tensor) -> CandleResult<f32> {
        let transformed = self.feature_transformer.forward_optimized(features)?;
        let activated = transformed.clamp(0.0, 1.0)?;

        let mut hidden_output = activated;
        for layer in &self.hidden_layers {
            hidden_output = layer.forward(&hidden_output)?;
            hidden_output = hidden_output.clamp(0.0, 1.0)?;
        }

        let output = self.output_layer.forward(&hidden_output)?;
        let eval_raw = output.get(0)?.get(0)?.to_scalar::<f32>()?;

        // Scale the raw network output into the engine's evaluation units.
        Ok(eval_raw * 600.0)
    }

    fn get_feature_indices(
        &self,
        piece: Piece,
        color: Color,
        square: Square,
        _white_king: Square,
        _black_king: Square,
    ) -> (Option<usize>, Option<usize>) {
        let piece_type_idx = match piece {
            Piece::Pawn => 0,
            Piece::Knight => 1,
            Piece::Bishop => 2,
            Piece::Rook => 3,
            Piece::Queen => 4,
            // Kings are not encoded in this simplified feature set.
            Piece::King => return (None, None),
        };

        let color_offset = if color == Color::White { 0 } else { 5 };
        let base_idx = (piece_type_idx + color_offset) * 64;
        let feature_idx = base_idx + square.to_index();

        if feature_idx < 768 {
            (Some(feature_idx), Some(feature_idx))
        } else {
            (None, None)
        }
    }

    fn forward(&self, features: &Tensor) -> CandleResult<Tensor> {
        let mut x = self.feature_transformer.forward(features)?;

        for layer in &self.hidden_layers {
            x = layer.forward(&x)?;
            x = self.clipped_relu(&x)?;
        }

        self.output_layer.forward(&x)
    }

    fn clipped_relu(&self, x: &Tensor) -> CandleResult<Tensor> {
        let relu = x.relu()?;
        relu.clamp(0.0, 1.0)
    }

    pub fn train_batch(&mut self, positions: &[(Board, f32)]) -> CandleResult<f32> {
        let batch_size = positions.len();
        let mut total_loss = 0.0;

        for (board, target_eval) in positions {
            let features = self.extract_features(board)?;
            let prediction = self.forward(&features)?;
            let target = Tensor::from_vec(vec![*target_eval], (1, 1), &self.device)?;

            // Squared-error loss on a single position, stepped immediately.
            let diff = (&prediction - &target)?;
            let squared = diff.powf(2.0)?;
            let loss = squared.sum_all()?;

            if let Some(ref mut optimizer) = self.optimizer {
                let grads = loss.backward()?;
                optimizer.step(&grads)?;
            }

            total_loss += loss.to_scalar::<f32>()?;
        }

        Ok(total_loss / batch_size as f32)
    }

    pub fn update_incrementally(
        &mut self,
        board: &Board,
        _chess_move: chess::ChessMove,
    ) -> CandleResult<()> {
        let white_king = board.king_square(Color::White);
        let black_king = board.king_square(Color::Black);
        self.feature_transformer.king_squares = [white_king, black_king];

        // Refresh the accumulator from the full position. Storing raw input
        // features here would leave the accumulator in the wrong space (input
        // vs. hidden), so a full rebuild is the safe option; true per-move
        // deltas go through update_after_move.
        self.initialize_accumulator(board)?;

        Ok(())
    }

    pub fn set_vector_weight(&mut self, weight: f32) {
        self.vector_weight = weight.clamp(0.0, 1.0);
    }

    pub fn are_weights_loaded(&self) -> bool {
        self.weights_loaded
    }

    pub fn quick_fix_training(&mut self, positions: &[(Board, f32)]) -> CandleResult<f32> {
        if self.weights_loaded {
            println!("📝 Weights were loaded, skipping quick training");
            return Ok(0.0);
        }

        println!("⚡ Running quick NNUE training to fix evaluation blindness...");
        let loss = self.train_batch(positions)?;
        println!("✅ Quick training completed with loss: {:.4}", loss);
        Ok(loss)
    }

    pub fn incremental_train(
        &mut self,
        positions: &[(Board, f32)],
        preserve_best: bool,
    ) -> CandleResult<f32> {
        let initial_loss = if preserve_best {
            let mut total_loss = 0.0;
            for (board, target_eval) in positions {
                let prediction = self.evaluate(board)?;
                let diff = prediction - target_eval;
                total_loss += diff * diff;
            }
            total_loss / positions.len() as f32
        } else {
            f32::MAX
        };

        println!(
            "🔄 Starting incremental training (v{})...",
            self.training_version + 1
        );
        if preserve_best {
            println!("📊 Baseline loss: {:.4}", initial_loss);
        }

        let original_weights = if preserve_best {
            Some((
                self.feature_transformer.weights.clone(),
                self.feature_transformer.biases.clone(),
            ))
        } else {
            None
        };

        let final_loss = self.train_batch(positions)?;

        if preserve_best && final_loss > initial_loss {
            println!(
                "⚠️ Training made model worse ({:.4} > {:.4}), reverting...",
                final_loss, initial_loss
            );
            if let Some((orig_weights, orig_biases)) = original_weights {
                self.feature_transformer.weights = orig_weights;
                self.feature_transformer.biases = orig_biases;
            }
            return Ok(initial_loss);
        }

        println!(
            "✅ Incremental training improved model: {:.4} -> {:.4}",
            if preserve_best { initial_loss } else { 0.0 },
            final_loss
        );
        Ok(final_loss)
    }

    pub fn set_vector_integration(&mut self, enabled: bool) {
        self.enable_vector_integration = enabled;
    }

    pub fn get_config(&self) -> NNUEConfig {
        // NOTE: feature/hidden sizes and learning rate are reported as the
        // defaults; the struct does not currently track them per instance.
        NNUEConfig {
            feature_size: 768,
            hidden_size: 256,
            num_hidden_layers: self.hidden_layers.len(),
            activation: ActivationType::ClippedReLU,
            learning_rate: 0.001,
            vector_blend_weight: self.vector_weight,
            enable_incremental_updates: true,
        }
    }

    pub fn save_model(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs::File;
        use std::io::Write;

        let config = self.get_config();
        let config_json = serde_json::to_string_pretty(&config)?;
        let mut file = File::create(format!("{path}.config"))?;
        file.write_all(config_json.as_bytes())?;
        println!("Model configuration saved to {path}.config");

        let mut weights_info = Vec::new();

        let ft_weights_shape = self.feature_transformer.weights.shape().dims().to_vec();
        let ft_biases_shape = self.feature_transformer.biases.shape().dims().to_vec();
        let ft_weights_data = self
            .feature_transformer
            .weights
            .flatten_all()?
            .to_vec1::<f32>()?;
        let ft_biases_data = self.feature_transformer.biases.to_vec1::<f32>()?;

        weights_info.push((
            "feature_transformer.weights".to_string(),
            ft_weights_shape,
            ft_weights_data,
        ));
        weights_info.push((
            "feature_transformer.biases".to_string(),
            ft_biases_shape,
            ft_biases_data,
        ));

        for (i, layer) in self.hidden_layers.iter().enumerate() {
            let weight_shape = layer.weight().shape().dims().to_vec();
            let bias_shape = layer.bias().unwrap().shape().dims().to_vec();
            let weight_data = layer.weight().flatten_all()?.to_vec1::<f32>()?;
            let bias_data = layer.bias().unwrap().to_vec1::<f32>()?;

            weights_info.push((
                format!("hidden_layer_{}.weight", i),
                weight_shape,
                weight_data,
            ));
            weights_info.push((format!("hidden_layer_{}.bias", i), bias_shape, bias_data));
        }

        let output_weight_shape = self.output_layer.weight().shape().dims().to_vec();
        let output_bias_shape = self.output_layer.bias().unwrap().shape().dims().to_vec();
        let output_weight_data = self.output_layer.weight().flatten_all()?.to_vec1::<f32>()?;
        let output_bias_data = self.output_layer.bias().unwrap().to_vec1::<f32>()?;

        weights_info.push((
            "output_layer.weight".to_string(),
            output_weight_shape,
            output_weight_data,
        ));
        weights_info.push((
            "output_layer.bias".to_string(),
            output_bias_shape,
            output_bias_data,
        ));

        let version = self.training_version + 1;

        let weights_json = serde_json::to_string(&weights_info)?;

        std::fs::write(format!("{path}.weights"), &weights_json)?;

        if version > 1 {
            std::fs::write(format!("{path}_v{version}.weights"), &weights_json)?;
            println!("💾 Versioned backup saved: {path}_v{version}.weights");
        }

        self.training_version = version;

        println!(
            "✅ Full model with weights saved to {path}.weights (v{})",
            version
        );
        println!("📊 Saved {} tensor parameters", weights_info.len());
        println!(
            "📝 Note: Using JSON serialization (can be upgraded to safetensors for production)"
        );

        Ok(())
    }

    pub fn load_model(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs;

        let config_path = format!("{path}.config");
        if !std::path::Path::new(&config_path).exists() {
            return Err(format!("Model config file not found: {path}.config").into());
        }

        let config_json = fs::read_to_string(config_path)?;
        let config: NNUEConfig = serde_json::from_str(&config_json)?;

        self.vector_weight = config.vector_blend_weight;
        self.enable_vector_integration = true;
        self.weights_loaded = false;
        println!("✅ Configuration loaded from {path}.config");

        let weights_path = format!("{path}.weights");
        if std::path::Path::new(&weights_path).exists() {
            let weights_json = fs::read_to_string(weights_path)?;
            let weights_info: Vec<(String, Vec<usize>, Vec<f32>)> =
                serde_json::from_str(&weights_json)?;

            println!("🧠 Loading trained neural network weights...");

            let mut loaded_weights = std::collections::HashMap::new();

            for (name, shape, data) in &weights_info {
                println!(
                    "  ✅ Loaded {}: shape {:?}, {} parameters",
                    name,
                    shape,
                    data.len()
                );

                let tensor =
                    candle_core::Tensor::from_vec(data.clone(), shape.as_slice(), &self.device)?;
                loaded_weights.insert(name.clone(), tensor);
            }

            let config = self.get_config();
            let new_nnue = Self::new_with_weights(config, Some(loaded_weights))?;

            self.feature_transformer = new_nnue.feature_transformer;
            self.weights_loaded = true;

            // Detect the highest versioned backup on disk to resume numbering.
            let mut detected_version = 1;
            for v in 2..=100 {
                if std::path::Path::new(&format!("{path}_v{v}.weights")).exists() {
                    detected_version = v;
                }
            }
            self.training_version = detected_version;

            println!(
                "  ✅ NNUE reconstructed with loaded weights (detected v{})",
                detected_version
            );
            println!("  📝 Feature transformer weights: ✅ Applied");
            println!("  📝 Hidden/output layers: ⚠️ candle-nn limitation remains");
            println!("  💾 Next training will create v{}", detected_version + 1);

            println!("✅ Neural network weights loaded successfully");
            println!("📊 Loaded {} tensor parameters", weights_info.len());
            println!(
                "📝 Note: Weight application to network requires deeper candle-nn integration"
            );
        } else {
            println!("⚠️ No weights file found at {path}.weights");
            println!("   Model will use fresh random weights");
            self.weights_loaded = false;
        }

        Ok(())
    }

    #[allow(dead_code)]
    fn apply_loaded_weights(
        &mut self,
        weights: std::collections::HashMap<String, candle_core::Tensor>,
    ) -> CandleResult<()> {
        if let (Some(ft_weights), Some(ft_biases)) = (
            weights.get("feature_transformer.weights"),
            weights.get("feature_transformer.biases"),
        ) {
            self.feature_transformer.weights = ft_weights.clone();
            self.feature_transformer.biases = ft_biases.clone();
            println!("  ✅ Applied feature transformer weights");
        }

        for (i, _layer) in self.hidden_layers.iter_mut().enumerate() {
            let weight_key = format!("hidden_layer_{}.weight", i);
            let bias_key = format!("hidden_layer_{}.bias", i);

            if let (Some(_weight), Some(_bias)) = (weights.get(&weight_key), weights.get(&bias_key))
            {
                println!(
                    "  ⚠️ Hidden layer {} weights loaded but not applied (candle-nn limitation)",
                    i
                );
            }
        }

        if let (Some(_weight), Some(_bias)) = (
            weights.get("output_layer.weight"),
            weights.get("output_layer.bias"),
        ) {
            println!("  ⚠️ Output layer weights loaded but not applied (candle-nn limitation)");
        }

        println!("  📝 Note: Full weight application requires candle-nn API enhancements");

        Ok(())
    }

    pub fn recreate_with_loaded_weights(
        &mut self,
        weights: std::collections::HashMap<String, candle_core::Tensor>,
    ) -> CandleResult<()> {
        let new_var_map = VarMap::new();
        let _vs = VarBuilder::from_varmap(&new_var_map, candle_core::DType::F32, &self.device);

        for (name, _tensor) in weights {
            println!("  🔄 Attempting to set {}", name);
        }

        println!("  ⚠️ Weight recreation not fully implemented yet");

        Ok(())
    }

    pub fn get_eval_stats(&mut self, positions: &[Board]) -> CandleResult<EvalStats> {
        let mut stats = EvalStats::new();

        for board in positions {
            let eval = self.evaluate(board)?;
            stats.add_evaluation(eval);
        }

        Ok(stats)
    }
}

impl FeatureTransformer {
    fn new(vs: VarBuilder, input_size: usize, output_size: usize) -> CandleResult<Self> {
        let weights = vs.get((input_size, output_size), "ft_weights")?;
        let biases = vs.get(output_size, "ft_biases")?;

        Ok(Self {
            weights,
            biases,
            accumulated_features: None,
            king_squares: [Square::E1, Square::E8],
        })
    }

    fn forward_optimized(&self, x: &Tensor) -> CandleResult<Tensor> {
        let output = x.matmul(&self.weights)?;
        output.broadcast_add(&self.biases)
    }

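    /// The core NNUE trick for a quiet move: subtract the weight row of the
    /// vacated feature and add the row of the newly occupied one. Captures
    /// and promotions are handled by the caller (see update_after_move).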
    fn incremental_update(
        &mut self,
        moved_piece: Piece,
        piece_color: Color,
        from_square: Square,
        to_square: Square,
        white_king: Square,
        black_king: Square,
    ) -> CandleResult<()> {
        if self.accumulated_features.is_none() {
            self.accumulated_features = Some(self.biases.clone());
        }

        let from_idx = self.get_feature_index_for_piece(
            moved_piece,
            piece_color,
            from_square,
            white_king,
            black_king,
        );
        let to_idx = self.get_feature_index_for_piece(
            moved_piece,
            piece_color,
            to_square,
            white_king,
            black_king,
        );

        if let (Some(from_feature), Some(to_feature)) = (from_idx, to_idx) {
            if from_feature < 768 && to_feature < 768 {
                let from_weights = self.weights.get(from_feature)?;
                let to_weights = self.weights.get(to_feature)?;

                if let Some(ref mut accumulator) = self.accumulated_features {
                    *accumulator = accumulator.sub(&from_weights)?.add(&to_weights)?;
                }
            }
        }

        self.king_squares = [white_king, black_king];

        Ok(())
    }

    fn get_feature_index_for_piece(
        &self,
        piece: Piece,
        color: Color,
        square: Square,
        white_king: Square,
        black_king: Square,
    ) -> Option<usize> {
        let piece_idx = match piece {
            Piece::Pawn => 0,
            Piece::Knight => 1,
            Piece::Bishop => 2,
            Piece::Rook => 3,
            Piece::Queen => 4,
            Piece::King => 5,
        };

        let color_offset = if color == Color::White { 0 } else { 6 };
        let square_idx = square.to_index();

        let king_square = if color == Color::White {
            white_king
        } else {
            black_king
        };
        let king_bucket = self.get_king_bucket(king_square);

        // NOTE: buckets 2 and 3 (king on ranks 5-8) index past the 768-entry
        // table, so those features are filtered out here and by the callers.
        let feature_idx = king_bucket * 384 + (piece_idx + color_offset) * 64 + square_idx;

        if feature_idx < 768 {
            Some(feature_idx)
        } else {
            None
        }
    }

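    /// Maps the king square onto a 2x2 grid (queenside/kingside x lower/upper
    /// half of the board), yielding bucket indices 0..=3.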
    fn get_king_bucket(&self, king_square: Square) -> usize {
        let square_idx = king_square.to_index();
        let file = square_idx % 8;
        let rank = square_idx / 8;

        let file_bucket = if file < 4 { 0 } else { 1 };
        let rank_bucket = if rank < 4 { 0 } else { 1 };

        file_bucket + rank_bucket * 2
    }

    #[allow(dead_code)]
    fn reset_accumulator(&mut self) {
        self.accumulated_features = None;
    }
}

impl Module for FeatureTransformer {
    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let output = x.matmul(&self.weights)?;
        output.broadcast_add(&self.biases)
    }
}

#[derive(Debug, Clone)]
pub struct EvalStats {
    pub count: usize,
    pub mean: f32,
    pub min: f32,
    pub max: f32,
    pub std_dev: f32,
}

impl EvalStats {
    fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            std_dev: 0.0,
        }
    }

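    /// Streaming update of count/min/max/mean/std-dev using a Welford-style
    /// online recurrence, so stats never need a second pass over the data.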
    fn add_evaluation(&mut self, eval: f32) {
        self.count += 1;
        self.min = self.min.min(eval);
        self.max = self.max.max(eval);

        let delta = eval - self.mean;
        self.mean += delta / self.count as f32;

        if self.count > 1 {
            let sum_sq =
                (self.count - 1) as f32 * self.std_dev.powi(2) + delta * (eval - self.mean);
            self.std_dev = (sum_sq / (self.count - 1) as f32).sqrt();
        }
    }
}

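/// Combines the NNUE with an optional external evaluator (for example a
/// vector-similarity scorer) under a configurable blending strategy.
///
/// A usage sketch; the constant closure is a stand-in for a real evaluator:
///
/// ```ignore
/// let nnue = NNUE::new(NNUEConfig::default())?;
/// let mut hybrid = HybridEvaluator::new(nnue, BlendStrategy::Weighted(0.3));
/// hybrid.set_vector_evaluator(|_board| Some(25.0));
/// let eval = hybrid.evaluate(&chess::Board::default())?;
/// ```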
pub struct HybridEvaluator {
    nnue: NNUE,
    vector_evaluator: Option<Box<dyn Fn(&Board) -> Option<f32>>>,
    blend_strategy: BlendStrategy,
}

#[derive(Debug, Clone)]
pub enum BlendStrategy {
    Weighted(f32),
    Adaptive,
    Confidence(f32),
    GamePhase,
}

#[derive(Debug, Clone)]
pub struct BlendWeights {
    pub nnue_weight: f32,
    pub vector_weight: f32,
    pub tactical_weight: f32,
}

#[derive(Debug, Clone, PartialEq)]
pub enum GamePhase {
    Opening,
    Middlegame,
    Endgame,
}

impl HybridEvaluator {
    pub fn new(nnue: NNUE, blend_strategy: BlendStrategy) -> Self {
        Self {
            nnue,
            vector_evaluator: None,
            blend_strategy,
        }
    }

    pub fn set_vector_evaluator<F>(&mut self, evaluator: F)
    where
        F: Fn(&Board) -> Option<f32> + 'static,
    {
        self.vector_evaluator = Some(Box::new(evaluator));
    }

    pub fn evaluate(&mut self, board: &Board) -> CandleResult<f32> {
        let nnue_eval = self.nnue.evaluate(board)?;

        let vector_eval = if let Some(ref evaluator) = self.vector_evaluator {
            evaluator(board)
        } else {
            None
        };

        match self.blend_strategy {
            BlendStrategy::Weighted(weight) => {
                if let Some(vector_eval) = vector_eval {
                    Ok((1.0 - weight) * nnue_eval + weight * vector_eval)
                } else {
                    Ok(nnue_eval)
                }
            }
            BlendStrategy::Adaptive => {
                // Trust the NNUE more in sharp positions, the vector
                // evaluator more in quiet ones.
                let is_tactical = self.is_tactical_position(board);
                let weight = if is_tactical { 0.2 } else { 0.5 };

                if let Some(vector_eval) = vector_eval {
                    Ok((1.0 - weight) * nnue_eval + weight * vector_eval)
                } else {
                    Ok(nnue_eval)
                }
            }
            // Remaining strategies fall back to the plain NNUE score.
            _ => Ok(nnue_eval),
        }
    }

    fn is_tactical_position(&self, board: &Board) -> bool {
        board.checkers().popcnt() > 0
            || chess::MoveGen::new_legal(board).any(|m| board.piece_on(m.get_dest()).is_some())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;

    #[test]
    fn test_nnue_creation() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config);
        assert!(nnue.is_ok());
    }

    #[test]
    fn test_nnue_evaluation() {
        let config = NNUEConfig::default();
        let mut nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let eval = nnue.evaluate(&board);
        if eval.is_err() {
            println!("NNUE evaluation error: {:?}", eval.err());
            panic!("NNUE evaluation failed");
        }

        let eval_value = eval.unwrap();
        // A fresh network should produce a roughly neutral evaluation.
        assert!(eval_value.abs() < 100.0);
    }

    #[test]
    fn test_hybrid_evaluation() {
        let config = NNUEConfig::vector_integrated();
        let mut nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let vector_eval = Some(25.0);
        let hybrid_eval = nnue.evaluate_hybrid(&board, vector_eval, None);
        assert!(hybrid_eval.is_ok());
    }

    #[test]
    fn test_feature_extraction() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let features = nnue.extract_features(&board);
        assert!(features.is_ok());

        let feature_tensor = features.unwrap();
        assert_eq!(feature_tensor.shape().dims(), &[1, 768]);
    }

    #[test]
    fn test_blend_strategies() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config).unwrap();

        let mut evaluator = HybridEvaluator::new(nnue, BlendStrategy::Weighted(0.3));
        evaluator.set_vector_evaluator(|_| Some(50.0));

        let board = Board::default();
        let eval = evaluator.evaluate(&board);
        assert!(eval.is_ok());
    }
}