// chess_vector_engine/nnue.rs

1#![allow(clippy::type_complexity)]
2use candle_core::{Device, Module, Result as CandleResult, Tensor};
3use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
4use chess::{Board, Color, Piece, Square};
5use serde::{Deserialize, Serialize};
6
/// NNUE (Efficiently Updatable Neural Network) for chess position evaluation
/// This implementation is designed to work alongside vector-based position analysis
///
/// The key innovation is that NNUE provides fast, accurate position evaluation while
/// the vector-based system provides strategic pattern recognition and similarity matching.
/// Together they create a hybrid system that's both fast and strategically aware.
pub struct NNUE {
    // Maps raw board features into the first hidden representation.
    feature_transformer: FeatureTransformer,
    // Fully-connected layers applied after the feature transformer.
    hidden_layers: Vec<Linear>,
    // Single-neuron head producing the evaluation score.
    output_layer: Linear,
    // Tensor device; currently always Device::Cpu (see `new`).
    device: Device,
    #[allow(dead_code)]
    // Owns all trainable variables; must stay alive for the optimizer.
    var_map: VarMap,
    // AdamW optimizer over `var_map`; populated during construction.
    optimizer: Option<AdamW>,

    // Integration with vector-based system
    vector_weight: f32, // How much to blend with vector evaluation
    enable_vector_integration: bool,
}
26
/// Feature transformer that efficiently updates when pieces move
/// Uses the standard NNUE approach with king-relative piece positions
struct FeatureTransformer {
    // Dense weight matrix, shape (feature_size, hidden_size).
    weights: Tensor,
    // Bias vector, shape (hidden_size).
    biases: Tensor,
    // Cached feature tensor from the last incremental update, if any.
    accumulated_features: Option<Tensor>,
    king_squares: [Square; 2], // White and black king positions for incremental updates
}
35
/// NNUE configuration optimized for chess vector engine integration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NNUEConfig {
    pub feature_size: usize,      // Input features (768 for king-relative pieces)
    pub hidden_size: usize,       // Hidden layer size (256 typical)
    pub num_hidden_layers: usize, // Number of hidden layers (2-4 typical)
    pub activation: ActivationType, // NOTE(review): not consulted by forward(); ClippedReLU is always used
    pub learning_rate: f32,       // AdamW learning rate used during training
    pub vector_blend_weight: f32, // How much to blend with vector evaluation (0.0-1.0)
    pub enable_incremental_updates: bool, // NOTE(review): not read by the visible code
}
47
/// Supported activation functions for NNUE hidden layers.
/// NOTE(review): `NNUE::forward` applies clipped ReLU unconditionally; the
/// configured variant is currently not consulted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    ClippedReLU, // Clipped ReLU is standard for NNUE
    Sigmoid,
}
54
55impl Default for NNUEConfig {
56    fn default() -> Self {
57        Self {
58            feature_size: 768, // 12 pieces * 64 squares for king-relative
59            hidden_size: 256,
60            num_hidden_layers: 2,
61            activation: ActivationType::ClippedReLU,
62            learning_rate: 0.001,
63            vector_blend_weight: 0.3, // 30% vector, 70% NNUE by default
64            enable_incremental_updates: true,
65        }
66    }
67}
68
69impl NNUEConfig {
70    /// Configuration optimized for hybrid vector-NNUE evaluation
71    pub fn vector_integrated() -> Self {
72        Self {
73            vector_blend_weight: 0.4, // Higher vector influence for strategic awareness
74            ..Default::default()
75        }
76    }
77
78    /// Configuration for pure NNUE evaluation (less vector influence)
79    pub fn nnue_focused() -> Self {
80        Self {
81            vector_blend_weight: 0.1, // Minimal vector influence for speed
82            ..Default::default()
83        }
84    }
85
86    /// Configuration for research and experimentation
87    pub fn experimental() -> Self {
88        Self {
89            feature_size: 1024, // Match vector dimension for alignment
90            hidden_size: 512,
91            num_hidden_layers: 3,
92            vector_blend_weight: 0.5, // Equal blend
93            ..Default::default()
94        }
95    }
96}
97
98impl NNUE {
99    /// Create a new NNUE evaluator with vector integration
100    pub fn new(config: NNUEConfig) -> CandleResult<Self> {
101        let device = Device::Cpu; // Can be upgraded to GPU later
102        let var_map = VarMap::new();
103        let vs = VarBuilder::from_varmap(&var_map, candle_core::DType::F32, &device);
104
105        // Create feature transformer
106        let feature_transformer =
107            FeatureTransformer::new(vs.clone(), config.feature_size, config.hidden_size)?;
108
109        // Create hidden layers
110        let mut hidden_layers = Vec::new();
111        let mut prev_size = config.hidden_size;
112
113        for _i in 0..config.num_hidden_layers {
114            let layer = linear(prev_size, config.hidden_size, vs.pp("Processing..."))?;
115            hidden_layers.push(layer);
116            prev_size = config.hidden_size;
117        }
118
119        // Output layer (single neuron for evaluation)
120        let output_layer = linear(prev_size, 1, vs.pp("output"))?;
121
122        // Initialize optimizer
123        let adamw_params = ParamsAdamW {
124            lr: config.learning_rate as f64,
125            ..Default::default()
126        };
127        let optimizer = Some(AdamW::new(var_map.all_vars(), adamw_params)?);
128
129        Ok(Self {
130            feature_transformer,
131            hidden_layers,
132            output_layer,
133            device,
134            var_map,
135            optimizer,
136            vector_weight: config.vector_blend_weight,
137            enable_vector_integration: true,
138        })
139    }
140
141    /// Evaluate a position using NNUE
142    pub fn evaluate(&mut self, board: &Board) -> CandleResult<f32> {
143        let features = self.extract_features(board)?;
144        let output = self.forward(&features)?;
145
146        // Return evaluation in pawn units (consistent with rest of engine)
147        // Extract the single value from the [1, 1] tensor
148        let eval_pawn_units = output.to_vec2::<f32>()?[0][0];
149
150        Ok(eval_pawn_units)
151    }
152
153    /// Hybrid evaluation combining NNUE with vector-based analysis
154    pub fn evaluate_hybrid(
155        &mut self,
156        board: &Board,
157        vector_eval: Option<f32>,
158    ) -> CandleResult<f32> {
159        let nnue_eval = self.evaluate(board)?;
160
161        if !self.enable_vector_integration || vector_eval.is_none() {
162            return Ok(nnue_eval);
163        }
164
165        let vector_eval = vector_eval.unwrap();
166
167        // Blend evaluations: vector provides strategic insight, NNUE provides tactical precision
168        let blended = (1.0 - self.vector_weight) * nnue_eval + self.vector_weight * vector_eval;
169
170        Ok(blended)
171    }
172
173    /// Extract NNUE features from chess position
174    /// Uses king-relative piece encoding for efficient updates
175    fn extract_features(&self, board: &Board) -> CandleResult<Tensor> {
176        let mut features = vec![0.0f32; 768]; // 12 pieces * 64 squares
177
178        let white_king = board.king_square(Color::White);
179        let black_king = board.king_square(Color::Black);
180
181        // Encode pieces relative to king positions (standard NNUE approach)
182        for square in chess::ALL_SQUARES {
183            if let Some(piece) = board.piece_on(square) {
184                let color = board.color_on(square).unwrap();
185
186                // Get feature indices for this piece relative to both kings
187                let (white_idx, _black_idx) =
188                    self.get_feature_indices(piece, color, square, white_king, black_king);
189
190                // Activate features (only white perspective to fit in 768 features)
191                if let Some(idx) = white_idx {
192                    if idx < 768 {
193                        features[idx] = 1.0;
194                    }
195                }
196                // Skip black perspective for now to avoid index overflow
197                // Real NNUE would use a more sophisticated feature mapping
198            }
199        }
200
201        Tensor::from_vec(features, (1, 768), &self.device)
202    }
203
204    /// Get feature indices for a piece relative to king positions
205    fn get_feature_indices(
206        &self,
207        piece: Piece,
208        color: Color,
209        square: Square,
210        _white_king: Square,
211        _black_king: Square,
212    ) -> (Option<usize>, Option<usize>) {
213        let piece_type_idx = match piece {
214            Piece::Pawn => 0,
215            Piece::Knight => 1,
216            Piece::Bishop => 2,
217            Piece::Rook => 3,
218            Piece::Queen => 4,
219            Piece::King => return (None, None), // Kings not included in features
220        };
221
222        let color_offset = if color == Color::White { 0 } else { 5 };
223        let base_idx = (piece_type_idx + color_offset) * 64;
224
225        // Calculate square index (simplified - real NNUE uses king-relative mapping)
226        let feature_idx = base_idx + square.to_index();
227
228        // Ensure we don't exceed feature bounds
229        if feature_idx < 768 {
230            (Some(feature_idx), Some(feature_idx)) // Same index for both perspectives for simplicity
231        } else {
232            (None, None)
233        }
234    }
235
236    /// Forward pass through the network
237    fn forward(&self, features: &Tensor) -> CandleResult<Tensor> {
238        // Transform features
239        let mut x = self.feature_transformer.forward(features)?;
240
241        // Hidden layers with clipped ReLU activation
242        for layer in &self.hidden_layers {
243            x = layer.forward(&x)?;
244            x = self.clipped_relu(&x)?;
245        }
246
247        // Output layer
248        let output = self.output_layer.forward(&x)?;
249
250        Ok(output)
251    }
252
253    /// Clipped ReLU activation (standard for NNUE)
254    fn clipped_relu(&self, x: &Tensor) -> CandleResult<Tensor> {
255        // Clamp values between 0 and 1 (ReLU then clip at 1)
256        let relu = x.relu()?;
257        relu.clamp(0.0, 1.0)
258    }
259
    /// Train the NNUE network on position data
    ///
    /// Despite the name, this performs one forward/backward/optimizer step
    /// PER position (online updates), not a single accumulated mini-batch
    /// step. Targets are expected in pawn units. Returns the mean
    /// squared-error loss over the slice.
    ///
    /// NOTE(review): an empty `positions` slice yields 0.0 / 0 == NaN.
    pub fn train_batch(&mut self, positions: &[(Board, f32)]) -> CandleResult<f32> {
        let batch_size = positions.len();
        let mut total_loss = 0.0;

        for (board, target_eval) in positions {
            // Extract features
            let features = self.extract_features(board)?;

            // Forward pass
            let prediction = self.forward(&features)?;

            // Create target tensor (target_eval is already in pawn units)
            let target = Tensor::from_vec(vec![*target_eval], (1, 1), &self.device)?;

            // Compute loss (MSE); sum_all collapses the [1, 1] tensor to a scalar
            let diff = (&prediction - &target)?;
            let squared = diff.powf(2.0)?;
            let loss = squared.sum_all()?;

            // Backward pass and optimization
            if let Some(ref mut optimizer) = self.optimizer {
                // Compute gradients
                let grads = loss.backward()?;

                // Step the optimizer with computed gradients
                optimizer.step(&grads)?;
            }

            total_loss += loss.to_scalar::<f32>()?;
        }

        Ok(total_loss / batch_size as f32)
    }
294
295    /// Incremental update when a move is made (NNUE efficiency feature)
296    pub fn update_incrementally(
297        &mut self,
298        board: &Board,
299        _chess_move: chess::ChessMove,
300    ) -> CandleResult<()> {
301        // Update king positions for incremental feature tracking
302        let white_king = board.king_square(Color::White);
303        let black_king = board.king_square(Color::Black);
304        self.feature_transformer.king_squares = [white_king, black_king];
305
306        // For now, we'll re-extract features for simplicity
307        // Real NNUE would incrementally update the accumulator
308        let features = self.extract_features(board)?;
309        self.feature_transformer.accumulated_features = Some(features);
310
311        // In a production implementation, this would efficiently:
312        // 1. Remove features for moved piece from old square
313        // 2. Add features for moved piece on new square
314        // 3. Handle captures, castling, en passant, promotions
315        // 4. Update accumulator without full re-computation (10-100x faster)
316
317        Ok(())
318    }
319
320    /// Set the vector evaluation blend weight
321    pub fn set_vector_weight(&mut self, weight: f32) {
322        self.vector_weight = weight.clamp(0.0, 1.0);
323    }
324
    /// Enable or disable vector integration
    ///
    /// When disabled, `evaluate_hybrid` returns the pure NNUE score even if
    /// a vector evaluation is supplied.
    pub fn set_vector_integration(&mut self, enabled: bool) {
        self.enable_vector_integration = enabled;
    }
329
    /// Get current configuration
    ///
    /// NOTE(review): feature_size, hidden_size, learning_rate, and activation
    /// are reported as hard-coded DEFAULT values, not the ones this instance
    /// was constructed with (the original config is not stored). Only
    /// `num_hidden_layers` and `vector_blend_weight` reflect live state —
    /// fixing this properly would require persisting the config at build time.
    pub fn get_config(&self) -> NNUEConfig {
        NNUEConfig {
            feature_size: 768,
            hidden_size: 256,
            num_hidden_layers: self.hidden_layers.len(),
            activation: ActivationType::ClippedReLU,
            learning_rate: 0.001,
            vector_blend_weight: self.vector_weight,
            enable_incremental_updates: true,
        }
    }
342
343    /// Save the trained model to a file
344    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
345        use std::fs::File;
346        use std::io::Write;
347
348        // Save model weights as safetensors or custom format
349        // For now, save configuration and basic model info
350        let config = self.get_config();
351        let config_json = serde_json::to_string_pretty(&config)?;
352
353        let mut file = File::create(format!("{path}.config"))?;
354        file.write_all(config_json.as_bytes())?;
355
356        // In production, would save actual tensor weights using safetensors
357        println!("Model configuration saved to {path}.config");
358        println!("Note: Full weight serialization requires safetensors integration");
359
360        Ok(())
361    }
362
363    /// Load a trained model from a file  
364    pub fn load_model(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
365        use std::fs;
366
367        // Load model configuration
368        let config_path = format!("{path}.config");
369        if std::path::Path::new(&config_path).exists() {
370            let config_json = fs::read_to_string(config_path)?;
371            let config: NNUEConfig = serde_json::from_str(&config_json)?;
372
373            // Apply loaded configuration
374            self.vector_weight = config.vector_blend_weight;
375            self.enable_vector_integration = true;
376
377            println!("Operation complete");
378            println!("Note: Full weight loading requires safetensors integration");
379        } else {
380            return Err(format!("Model config file not found: {path}.config").into());
381        }
382
383        Ok(())
384    }
385
386    /// Get evaluation statistics for analysis
387    pub fn get_eval_stats(&mut self, positions: &[Board]) -> CandleResult<EvalStats> {
388        let mut stats = EvalStats::new();
389
390        for board in positions {
391            let eval = self.evaluate(board)?; // Simplified for demo
392            stats.add_evaluation(eval);
393        }
394
395        Ok(stats)
396    }
397}
398
399impl FeatureTransformer {
400    fn new(vs: VarBuilder, input_size: usize, output_size: usize) -> CandleResult<Self> {
401        let weights = vs.get((input_size, output_size), "ft_weights")?;
402        let biases = vs.get(output_size, "ft_biases")?;
403
404        Ok(Self {
405            weights,
406            biases,
407            accumulated_features: None,
408            king_squares: [Square::E1, Square::E8], // Default positions
409        })
410    }
411}
412
413impl Module for FeatureTransformer {
414    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
415        // Simple linear transformation (real NNUE uses more efficient accumulator)
416        let output = x.matmul(&self.weights)?;
417        output.broadcast_add(&self.biases)
418    }
419}
420
/// Statistics for NNUE evaluation analysis
#[derive(Debug, Clone)]
pub struct EvalStats {
    pub count: usize, // number of evaluations folded in
    pub mean: f32,    // running mean of all evaluations
    pub min: f32,     // smallest evaluation seen (+inf while empty)
    pub max: f32,     // largest evaluation seen (-inf while empty)
    pub std_dev: f32, // sample standard deviation (see add_evaluation)
}
430
impl EvalStats {
    /// Empty accumulator; min/max start at +/- infinity so the first
    /// evaluation always replaces them.
    fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            std_dev: 0.0,
        }
    }

    /// Fold one evaluation into the running statistics.
    ///
    /// Welford-style incremental update: `delta` is taken against the OLD
    /// mean, the mean is then advanced, and the variance term multiplies
    /// `delta` by the deviation from the NEW mean — the order of these
    /// statements is load-bearing.
    fn add_evaluation(&mut self, eval: f32) {
        self.count += 1;
        self.min = self.min.min(eval);
        self.max = self.max.max(eval);

        // Running mean calculation
        let delta = eval - self.mean;
        self.mean += delta / self.count as f32;

        // Simplified std dev calculation (not numerically stable for large datasets)
        if self.count > 1 {
            let sum_sq =
                (self.count - 1) as f32 * self.std_dev.powi(2) + delta * (eval - self.mean);
            self.std_dev = (sum_sq / (self.count - 1) as f32).sqrt();
        }
    }
}
459
/// Integration helper for combining NNUE with vector-based evaluation
pub struct HybridEvaluator {
    // The neural evaluator; owns its own weights and optimizer.
    nnue: NNUE,
    // Optional callback producing a vector-based score for a position.
    vector_evaluator: Option<Box<dyn Fn(&Board) -> Option<f32>>>,
    // How NNUE and vector scores are combined (see `evaluate`).
    blend_strategy: BlendStrategy,
}
466
/// Strategy for combining NNUE and vector evaluations.
/// NOTE(review): only `Weighted` and `Adaptive` are implemented in
/// `HybridEvaluator::evaluate`; the others fall back to pure NNUE.
#[derive(Debug, Clone)]
pub enum BlendStrategy {
    Weighted(f32),   // Fixed weight blend
    Adaptive,        // Adapt based on position type
    Confidence(f32), // Use vector when NNUE confidence is low
    GamePhase,       // Different blending for opening/middlegame/endgame
}
474
475impl HybridEvaluator {
476    pub fn new(nnue: NNUE, blend_strategy: BlendStrategy) -> Self {
477        Self {
478            nnue,
479            vector_evaluator: None,
480            blend_strategy,
481        }
482    }
483
484    pub fn set_vector_evaluator<F>(&mut self, evaluator: F)
485    where
486        F: Fn(&Board) -> Option<f32> + 'static,
487    {
488        self.vector_evaluator = Some(Box::new(evaluator));
489    }
490
491    pub fn evaluate(&mut self, board: &Board) -> CandleResult<f32> {
492        let nnue_eval = self.nnue.evaluate(board)?;
493
494        let vector_eval = if let Some(ref evaluator) = self.vector_evaluator {
495            evaluator(board)
496        } else {
497            None
498        };
499
500        match self.blend_strategy {
501            BlendStrategy::Weighted(weight) => {
502                if let Some(vector_eval) = vector_eval {
503                    Ok((1.0 - weight) * nnue_eval + weight * vector_eval)
504                } else {
505                    Ok(nnue_eval)
506                }
507            }
508            BlendStrategy::Adaptive => {
509                // Adapt based on position characteristics
510                let is_tactical = self.is_tactical_position(board);
511                let weight = if is_tactical { 0.2 } else { 0.5 }; // Less vector in tactical positions
512
513                if let Some(vector_eval) = vector_eval {
514                    Ok((1.0 - weight) * nnue_eval + weight * vector_eval)
515                } else {
516                    Ok(nnue_eval)
517                }
518            }
519            _ => Ok(nnue_eval), // Other strategies can be implemented
520        }
521    }
522
523    fn is_tactical_position(&self, board: &Board) -> bool {
524        // Simple tactical detection (can be enhanced)
525        board.checkers().popcnt() > 0
526            || chess::MoveGen::new_legal(board).any(|m| board.piece_on(m.get_dest()).is_some())
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;

    // Constructing an NNUE from the default config should succeed.
    #[test]
    fn test_nnue_creation() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config);
        assert!(nnue.is_ok());
    }

    // An untrained network must still evaluate the start position cleanly.
    #[test]
    fn test_nnue_evaluation() {
        let config = NNUEConfig::default();
        let mut nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let eval = nnue.evaluate(&board);
        if eval.is_err() {
            println!("NNUE evaluation error: {:?}", eval.err());
            panic!("NNUE evaluation failed");
        }

        // Starting position should be near 0; the bound is a loose +/-100
        // pawn units for an untrained net (the earlier comment claimed
        // "within 1 pawn", which does not match the 100.0 threshold).
        let eval_value = eval.unwrap();
        assert!(eval_value.abs() < 100.0); // Within 100 pawn units
    }

    // Hybrid blend with a supplied vector score should not error.
    #[test]
    fn test_hybrid_evaluation() {
        let config = NNUEConfig::vector_integrated();
        let mut nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let vector_eval = Some(25.0); // Small advantage
        let hybrid_eval = nnue.evaluate_hybrid(&board, vector_eval);
        assert!(hybrid_eval.is_ok());
    }

    // Feature extraction must produce the expected [1, 768] tensor shape.
    #[test]
    fn test_feature_extraction() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let features = nnue.extract_features(&board);
        assert!(features.is_ok());

        let feature_tensor = features.unwrap();
        assert_eq!(feature_tensor.shape().dims(), &[1, 768]);
    }

    // The HybridEvaluator with a fixed-weight strategy and a stub vector
    // evaluator should evaluate without error.
    #[test]
    fn test_blend_strategies() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config).unwrap();

        let mut evaluator = HybridEvaluator::new(nnue, BlendStrategy::Weighted(0.3));
        evaluator.set_vector_evaluator(|_| Some(50.0));

        let board = Board::default();
        let eval = evaluator.evaluate(&board);
        assert!(eval.is_ok());
    }
}
595}