#![allow(clippy::type_complexity)]
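//! NNUE (Efficiently Updatable Neural Network) position evaluation, with
//! optional blending of an external "vector" evaluator. The network is a
//! small MLP: a feature transformer over a 768-slot piece-square encoding,
//! a stack of hidden layers with clipped ReLU, and a scalar output head
//! whose value is interpreted in pawn units.
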
use candle_core::{Device, Module, Result as CandleResult, Tensor};
use candle_nn::{linear, AdamW, Linear, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use chess::{Board, Color, Piece, Square};
use serde::{Deserialize, Serialize};

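/// NNUE evaluator: feature transformer plus MLP head, an AdamW optimizer for
/// training, and a weight for blending in an external vector-based score.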
pub struct NNUE {
    feature_transformer: FeatureTransformer,
    hidden_layers: Vec<Linear>,
    output_layer: Linear,
    device: Device,
    #[allow(dead_code)]
    var_map: VarMap,
    optimizer: Option<AdamW>,
    vector_weight: f32,
    enable_vector_integration: bool,
}

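/// Accumulator-style first layer. `accumulated_features` and `king_squares`
/// cache per-position state for incremental updates.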
struct FeatureTransformer {
    weights: Tensor,
    biases: Tensor,
    accumulated_features: Option<Tensor>,
    king_squares: [Square; 2],
}

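/// Network and blending hyperparameters; serializable so a model's
/// configuration can be saved next to its weights.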
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NNUEConfig {
    pub feature_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub activation: ActivationType,
    pub learning_rate: f32,
    pub vector_blend_weight: f32,
    pub enable_incremental_updates: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    ClippedReLU,
    Sigmoid,
}

impl Default for NNUEConfig {
    fn default() -> Self {
        Self {
            feature_size: 768,
            hidden_size: 256,
            num_hidden_layers: 2,
            activation: ActivationType::ClippedReLU,
            learning_rate: 0.001,
            vector_blend_weight: 0.3,
            enable_incremental_updates: true,
        }
    }
}

impl NNUEConfig {
    /// Preset that leans more heavily on the external vector evaluation.
    pub fn vector_integrated() -> Self {
        Self {
            vector_blend_weight: 0.4,
            ..Default::default()
        }
    }

    /// Preset that trusts the raw NNUE output almost exclusively.
    pub fn nnue_focused() -> Self {
        Self {
            vector_blend_weight: 0.1,
            ..Default::default()
        }
    }

    /// Larger experimental network with an even blend.
    pub fn experimental() -> Self {
        Self {
            feature_size: 1024,
            hidden_size: 512,
            num_hidden_layers: 3,
            vector_blend_weight: 0.5,
            ..Default::default()
        }
    }
}

impl NNUE {
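    /// Builds the network on CPU and attaches an AdamW optimizer over every
    /// variable in the `VarMap`.
    ///
    /// A minimal usage sketch (the module path is an assumption; adjust it to
    /// wherever this file lives in your crate):
    /// ```ignore
    /// let mut nnue = NNUE::new(NNUEConfig::default())?;
    /// let eval = nnue.evaluate(&chess::Board::default())?;
    /// println!("startpos eval: {eval:+.2} pawns");
    /// ```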
    pub fn new(config: NNUEConfig) -> CandleResult<Self> {
        let device = Device::Cpu;
        let var_map = VarMap::new();
        let vs = VarBuilder::from_varmap(&var_map, candle_core::DType::F32, &device);

        let feature_transformer =
            FeatureTransformer::new(vs.clone(), config.feature_size, config.hidden_size)?;

        let mut hidden_layers = Vec::new();
        let mut prev_size = config.hidden_size;

        // Each layer needs a distinct variable prefix so the VarMap allocates
        // separate weights per layer instead of aliasing one tensor.
        for i in 0..config.num_hidden_layers {
            let layer = linear(prev_size, config.hidden_size, vs.pp(format!("hidden_{i}")))?;
            hidden_layers.push(layer);
            prev_size = config.hidden_size;
        }

        let output_layer = linear(prev_size, 1, vs.pp("output"))?;

        let adamw_params = ParamsAdamW {
            lr: config.learning_rate as f64,
            ..Default::default()
        };
        let optimizer = Some(AdamW::new(var_map.all_vars(), adamw_params)?);

        Ok(Self {
            feature_transformer,
            hidden_layers,
            output_layer,
            device,
            var_map,
            optimizer,
            vector_weight: config.vector_blend_weight,
            enable_vector_integration: true,
        })
    }

    /// Full (non-incremental) evaluation of `board`, in pawn units.
    pub fn evaluate(&mut self, board: &Board) -> CandleResult<f32> {
        let features = self.extract_features(board)?;
        let output = self.forward(&features)?;

        // The output tensor has shape (1, 1); pull out the scalar.
        let eval_pawn_units = output.to_vec2::<f32>()?[0][0];

        Ok(eval_pawn_units)
    }

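    /// Blends the NNUE score with an externally supplied vector score:
    /// `blended = (1 - w) * nnue + w * vector`, where `w` is `vector_weight`.
    /// Falls back to the pure NNUE score when integration is disabled or no
    /// vector score is available.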
    pub fn evaluate_hybrid(
        &mut self,
        board: &Board,
        vector_eval: Option<f32>,
    ) -> CandleResult<f32> {
        let nnue_eval = self.evaluate(board)?;

        // Fall back to the pure NNUE score when blending is off or the
        // vector evaluator produced nothing.
        let vector_eval = match vector_eval {
            Some(v) if self.enable_vector_integration => v,
            _ => return Ok(nnue_eval),
        };

        let blended = (1.0 - self.vector_weight) * nnue_eval + self.vector_weight * vector_eval;

        Ok(blended)
    }

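    /// Encodes `board` as a sparse one-hot vector: five piece types times two
    /// colors (kings are skipped, see `get_feature_indices`) times 64 squares,
    /// so 640 of the 768 slots are addressable. Worked example: a white knight
    /// on b1 maps to (1 + 0) * 64 + 1 = index 65.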
    fn extract_features(&self, board: &Board) -> CandleResult<Tensor> {
        let mut features = vec![0.0f32; 768];
        let white_king = board.king_square(Color::White);
        let black_king = board.king_square(Color::Black);

        for square in chess::ALL_SQUARES {
            if let Some(piece) = board.piece_on(square) {
                let color = board.color_on(square).unwrap();

                let (white_idx, _black_idx) =
                    self.get_feature_indices(piece, color, square, white_king, black_king);

                if let Some(idx) = white_idx {
                    if idx < 768 {
                        features[idx] = 1.0;
                    }
                }
            }
        }

        Tensor::from_vec(features, (1, 768), &self.device)
    }

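    /// Maps a (piece, color, square) triple to per-perspective feature
    /// indices. Kings get no feature of their own; the unused king-square
    /// arguments are presumably reserved for HalfKP-style, king-relative
    /// bucketing.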
    fn get_feature_indices(
        &self,
        piece: Piece,
        color: Color,
        square: Square,
        _white_king: Square,
        _black_king: Square,
    ) -> (Option<usize>, Option<usize>) {
        let piece_type_idx = match piece {
            Piece::Pawn => 0,
            Piece::Knight => 1,
            Piece::Bishop => 2,
            Piece::Rook => 3,
            Piece::Queen => 4,
            Piece::King => return (None, None),
        };

        let color_offset = if color == Color::White { 0 } else { 5 };
        let base_idx = (piece_type_idx + color_offset) * 64;

        let feature_idx = base_idx + square.to_index();

        if feature_idx < 768 {
            // Both perspectives currently share the same index.
            (Some(feature_idx), Some(feature_idx))
        } else {
            (None, None)
        }
    }

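    /// Forward pass: feature transformer, then the hidden stack with clipped
    /// ReLU activations, then the linear output head.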
    fn forward(&self, features: &Tensor) -> CandleResult<Tensor> {
        let mut x = self.feature_transformer.forward(features)?;

        for layer in &self.hidden_layers {
            x = layer.forward(&x)?;
            x = self.clipped_relu(&x)?;
        }

        let output = self.output_layer.forward(&x)?;

        Ok(output)
    }

    /// Clipped ReLU: min(max(x, 0), 1). The clamp alone covers the ReLU,
    /// so no separate `relu` call is needed.
    fn clipped_relu(&self, x: &Tensor) -> CandleResult<Tensor> {
        x.clamp(0.0, 1.0)
    }

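    /// One supervised pass: per-position squared error against `target_eval`
    /// with an AdamW step after each position, returning the mean loss.
    ///
    /// Hypothetical usage (the position source is an assumption):
    /// ```ignore
    /// let batch: Vec<(Board, f32)> = load_labeled_positions();
    /// let avg_loss = nnue.train_batch(&batch)?;
    /// ```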
    pub fn train_batch(&mut self, positions: &[(Board, f32)]) -> CandleResult<f32> {
        let batch_size = positions.len();
        let mut total_loss = 0.0;

        for (board, target_eval) in positions {
            let features = self.extract_features(board)?;

            let prediction = self.forward(&features)?;

            let target = Tensor::from_vec(vec![*target_eval], (1, 1), &self.device)?;

            // Squared error on the scalar output.
            let diff = (&prediction - &target)?;
            let squared = diff.powf(2.0)?;
            let loss = squared.sum_all()?;

            if let Some(ref mut optimizer) = self.optimizer {
                let grads = loss.backward()?;

                optimizer.step(&grads)?;
            }

            total_loss += loss.to_scalar::<f32>()?;
        }

        Ok(total_loss / batch_size as f32)
    }

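    /// Refreshes the cached king squares and feature accumulator. Despite the
    /// name, this currently recomputes the full feature vector rather than
    /// applying a true add/remove accumulator update for the move.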
    pub fn update_incrementally(
        &mut self,
        board: &Board,
        _chess_move: chess::ChessMove,
    ) -> CandleResult<()> {
        let white_king = board.king_square(Color::White);
        let black_king = board.king_square(Color::Black);
        self.feature_transformer.king_squares = [white_king, black_king];

        let features = self.extract_features(board)?;
        self.feature_transformer.accumulated_features = Some(features);

        Ok(())
    }

    /// Sets the blend weight for vector evaluations, clamped to [0, 1].
    pub fn set_vector_weight(&mut self, weight: f32) {
        self.vector_weight = weight.clamp(0.0, 1.0);
    }

    pub fn set_vector_integration(&mut self, enabled: bool) {
        self.enable_vector_integration = enabled;
    }

    /// Reconstructs a config from the live model. Sizes that are not stored
    /// on the struct (feature size, hidden size, learning rate) are reported
    /// as the defaults, so non-default presets do not round-trip exactly.
    pub fn get_config(&self) -> NNUEConfig {
        NNUEConfig {
            feature_size: 768,
            hidden_size: 256,
            num_hidden_layers: self.hidden_layers.len(),
            activation: ActivationType::ClippedReLU,
            learning_rate: 0.001,
            vector_blend_weight: self.vector_weight,
            enable_incremental_updates: true,
        }
    }

    pub fn save_model(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs::File;
        use std::io::Write;

        let config = self.get_config();
        let config_json = serde_json::to_string_pretty(&config)?;

        let mut file = File::create(format!("{path}.config"))?;
        file.write_all(config_json.as_bytes())?;

        println!("Model configuration saved to {path}.config");
        println!("Note: Full weight serialization requires safetensors integration");

        Ok(())
    }

    pub fn load_model(&mut self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        use std::fs;

        let config_path = format!("{path}.config");
        if std::path::Path::new(&config_path).exists() {
            let config_json = fs::read_to_string(&config_path)?;
            let config: NNUEConfig = serde_json::from_str(&config_json)?;

            self.vector_weight = config.vector_blend_weight;
            self.enable_vector_integration = true;

            println!("Model configuration loaded from {config_path}");
            println!("Note: Full weight loading requires safetensors integration");
        } else {
            return Err(format!("Model config file not found: {path}.config").into());
        }

        Ok(())
    }

    pub fn get_eval_stats(&mut self, positions: &[Board]) -> CandleResult<EvalStats> {
        let mut stats = EvalStats::new();

        for board in positions {
            let eval = self.evaluate(board)?;
            stats.add_evaluation(eval);
        }

        Ok(stats)
    }
}

impl FeatureTransformer {
    fn new(vs: VarBuilder, input_size: usize, output_size: usize) -> CandleResult<Self> {
        // Note: on a fresh VarMap these tensors are zero-initialized; real
        // training would want `get_with_hints` with a random init instead.
        let weights = vs.get((input_size, output_size), "ft_weights")?;
        let biases = vs.get(output_size, "ft_biases")?;

        Ok(Self {
            weights,
            biases,
            accumulated_features: None,
            king_squares: [Square::E1, Square::E8],
        })
    }
}

impl Module for FeatureTransformer {
    fn forward(&self, x: &Tensor) -> CandleResult<Tensor> {
        let output = x.matmul(&self.weights)?;
        output.broadcast_add(&self.biases)
    }
}

/// Running summary statistics over a batch of evaluations.
#[derive(Debug, Clone)]
pub struct EvalStats {
    pub count: usize,
    pub mean: f32,
    pub min: f32,
    pub max: f32,
    pub std_dev: f32,
}

impl EvalStats {
    fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            std_dev: 0.0,
        }
    }

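    /// Single-pass update via Welford's online algorithm:
    /// `mean_n = mean_{n-1} + delta / n` with `delta = x - mean_{n-1}`, and
    /// `M2_n = M2_{n-1} + delta * (x - mean_n)`; the sample standard
    /// deviation is `sqrt(M2_n / (n - 1))`.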
    fn add_evaluation(&mut self, eval: f32) {
        self.count += 1;
        self.min = self.min.min(eval);
        self.max = self.max.max(eval);

        let delta = eval - self.mean;
        self.mean += delta / self.count as f32;

        if self.count > 1 {
            // Recover M2 from the stored sample std-dev: with n - 1 prior
            // samples, M2 = (n - 2) * s^2.
            let m2 = (self.count - 2) as f32 * self.std_dev.powi(2) + delta * (eval - self.mean);
            self.std_dev = (m2 / (self.count - 1) as f32).sqrt();
        }
    }
}

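/// Wraps an `NNUE` together with an optional external vector evaluator and a
/// strategy for blending the two scores.
///
/// Hedged sketch (the pawn-unit-returning closure is a stand-in):
/// ```ignore
/// let nnue = NNUE::new(NNUEConfig::default())?;
/// let mut hybrid = HybridEvaluator::new(nnue, BlendStrategy::Adaptive);
/// hybrid.set_vector_evaluator(|board| heuristic_eval(board));
/// let score = hybrid.evaluate(&chess::Board::default())?;
/// ```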
pub struct HybridEvaluator {
    nnue: NNUE,
    vector_evaluator: Option<Box<dyn Fn(&Board) -> Option<f32>>>,
    blend_strategy: BlendStrategy,
}

#[derive(Debug, Clone)]
pub enum BlendStrategy {
    /// Fixed blend weight applied to the vector score.
    Weighted(f32),
    /// Weight chosen by position character (tactical vs. quiet).
    Adaptive,
    /// Reserved: gate blending on a confidence threshold (not implemented yet).
    Confidence(f32),
    /// Reserved: weight by game phase (not implemented yet).
    GamePhase,
}

impl HybridEvaluator {
    pub fn new(nnue: NNUE, blend_strategy: BlendStrategy) -> Self {
        Self {
            nnue,
            vector_evaluator: None,
            blend_strategy,
        }
    }

    pub fn set_vector_evaluator<F>(&mut self, evaluator: F)
    where
        F: Fn(&Board) -> Option<f32> + 'static,
    {
        self.vector_evaluator = Some(Box::new(evaluator));
    }

    pub fn evaluate(&mut self, board: &Board) -> CandleResult<f32> {
        let nnue_eval = self.nnue.evaluate(board)?;

        let vector_eval = if let Some(ref evaluator) = self.vector_evaluator {
            evaluator(board)
        } else {
            None
        };

        match self.blend_strategy {
            BlendStrategy::Weighted(weight) => {
                if let Some(vector_eval) = vector_eval {
                    Ok((1.0 - weight) * nnue_eval + weight * vector_eval)
                } else {
                    Ok(nnue_eval)
                }
            }
            BlendStrategy::Adaptive => {
                // Trust the NNUE more in tactical positions, the vector
                // evaluator more in quiet ones.
                let is_tactical = self.is_tactical_position(board);
                let weight = if is_tactical { 0.2 } else { 0.5 };
                if let Some(vector_eval) = vector_eval {
                    Ok((1.0 - weight) * nnue_eval + weight * vector_eval)
                } else {
                    Ok(nnue_eval)
                }
            }
            // Confidence and GamePhase strategies are not implemented yet.
            _ => Ok(nnue_eval),
        }
    }

    /// A position counts as tactical if the side to move is in check or has
    /// at least one capture available.
    fn is_tactical_position(&self, board: &Board) -> bool {
        board.checkers().popcnt() > 0
            || chess::MoveGen::new_legal(board).any(|m| board.piece_on(m.get_dest()).is_some())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chess::Board;

    #[test]
    fn test_nnue_creation() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config);
        assert!(nnue.is_ok());
    }

    #[test]
    fn test_nnue_evaluation() {
        let config = NNUEConfig::default();
        let mut nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let eval = nnue.evaluate(&board).expect("NNUE evaluation failed");
        // A freshly initialized network should produce a value in a sane range.
        assert!(eval.abs() < 100.0);
    }

    #[test]
    fn test_hybrid_evaluation() {
        let config = NNUEConfig::vector_integrated();
        let mut nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let vector_eval = Some(25.0);
        let hybrid_eval = nnue.evaluate_hybrid(&board, vector_eval);
        assert!(hybrid_eval.is_ok());
    }

    #[test]
    fn test_feature_extraction() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config).unwrap();
        let board = Board::default();

        let features = nnue.extract_features(&board);
        assert!(features.is_ok());

        let feature_tensor = features.unwrap();
        assert_eq!(feature_tensor.shape().dims(), &[1, 768]);
    }

    #[test]
    fn test_blend_strategies() {
        let config = NNUEConfig::default();
        let nnue = NNUE::new(config).unwrap();

        let mut evaluator = HybridEvaluator::new(nnue, BlendStrategy::Weighted(0.3));
        evaluator.set_vector_evaluator(|_| Some(50.0));

        let board = Board::default();
        let eval = evaluator.evaluate(&board);
        assert!(eval.is_ok());
    }
}