use super::accumulator::Aligned;
use super::accumulator_layer_stacks::{AccumulatorLayerStacks, AccumulatorStackLayerStacks};
use super::constants::{MAX_ARCH_LEN, NNUE_PYTORCH_L1, NNUE_VERSION_HALFKA};
use super::feature_transformer_layer_stacks::FeatureTransformerLayerStacks;
use super::layer_stacks::{LayerStacks, compute_bucket_index, sqr_clipped_relu_transform};
use crate::position::Position;
use crate::types::{Color, Value};
#[cfg(feature = "diagnostics")]
use log::info;
use std::fs::File;
use std::io::{self, BufReader, Cursor, Read, Seek};
use std::path::Path;
/// A complete nnue-pytorch "layer stacks" network: the input feature
/// transformer plus the bucketed output layers, loaded together from a
/// single `.nnue` file (HalfKA format — see `NNUE_VERSION_HALFKA`).
pub struct NetworkLayerStacks {
    /// Input layer: transforms board features into the accumulator.
    pub feature_transformer: FeatureTransformerLayerStacks,
    /// Bucketed output layers applied to the transformed accumulator.
    pub layer_stacks: LayerStacks,
}
impl NetworkLayerStacks {
    /// Loads a layer-stacks NNUE network from a file on disk.
    ///
    /// # Errors
    /// Returns an I/O error if the file cannot be opened, or any
    /// validation/parsing error produced by [`Self::read`].
    pub fn load<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);
        Self::read(&mut reader)
    }

    /// Parses a network from a seekable byte stream.
    ///
    /// File layout (all integers little-endian):
    /// version (u32) · hash (u32, unchecked) · arch-string length (u32) ·
    /// arch string · feature-transformer hash (u32, unchecked) ·
    /// feature-transformer weights (LEB128-compressed) · layer stacks.
    /// A well-formed coalesced model ends exactly after the layer stacks;
    /// any trailing bytes are rejected as a factorized model.
    ///
    /// # Errors
    /// `InvalidData` on version mismatch, implausible arch-string length,
    /// a factorized (non-coalesced) model, or trailing data; otherwise any
    /// underlying read error is propagated.
    pub fn read<R: Read + Seek>(reader: &mut R) -> io::Result<Self> {
        let mut buf4 = [0u8; 4];
        // Header field 1: format version; must match the nnue-pytorch
        // HalfKA version this engine was built against.
        reader.read_exact(&mut buf4)?;
        let version = u32::from_le_bytes(buf4);
        if version != NNUE_VERSION_HALFKA {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Invalid NNUE version for nnue-pytorch: {version:#x}, expected {NNUE_VERSION_HALFKA:#x}"
                ),
            ));
        }
        // Header field 2: whole-network hash (currently not verified).
        reader.read_exact(&mut buf4)?;
        let _hash = u32::from_le_bytes(buf4);
        // Header field 3: architecture string length, bounds-checked before
        // allocating so a corrupt file cannot request a huge buffer.
        reader.read_exact(&mut buf4)?;
        let arch_len = u32::from_le_bytes(buf4) as usize;
        if arch_len == 0 || arch_len > MAX_ARCH_LEN {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Invalid arch string length: {arch_len} (max: {MAX_ARCH_LEN})"),
            ));
        }
        let mut arch = vec![0u8; arch_len];
        reader.read_exact(&mut arch)?;
        let arch_str = String::from_utf8_lossy(&arch);
        // Factorized models carry extra virtual-feature weights this engine
        // cannot consume; reject with an actionable re-export hint.
        if arch_str.contains("Factorizer") {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Unsupported model format: factorized (non-coalesced) model detected.\n\
                     This engine only supports coalesced models.\n\n\
                     To fix: Re-export the model using nnue-pytorch serialize.py:\n\
                     python serialize.py model.ckpt output.nnue\n\n\
                     Architecture string: {arch_str}"
                ),
            ));
        }
        // Feature-transformer section hash (currently not verified).
        reader.read_exact(&mut buf4)?;
        let _ft_hash = u32::from_le_bytes(buf4);
        let feature_transformer = FeatureTransformerLayerStacks::read_leb128(reader)?;
        let layer_stacks = LayerStacks::read(reader)?;
        // Probe one byte past the parsed payload: Ok(0) or UnexpectedEof
        // means a clean end of file, a readable byte means trailing data.
        // Interrupted reads are retried so a spurious EINTR does not
        // reject a valid file.
        let mut probe = [0u8; 1];
        let has_trailing_data = loop {
            match reader.read(&mut probe) {
                Ok(n) => break n != 0,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break false,
                Err(e) => return Err(e),
            }
        };
        if has_trailing_data {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "NNUE file has unexpected trailing data.\n\
                 This likely indicates a factorized (non-coalesced) model.\n\
                 This engine only supports coalesced models.\n\n\
                 To fix: Re-export the model using nnue-pytorch serialize.py:\n\
                 python serialize.py model.ckpt output.nnue\n\n\
                 The serialize.py script automatically coalesces factor weights.",
            ));
        }
        #[cfg(feature = "diagnostics")]
        {
            Self::log_load_diagnostics(&feature_transformer, &layer_stacks);
        }
        Ok(Self {
            feature_transformer,
            layer_stacks,
        })
    }

    /// Logs summary statistics of freshly loaded weights so a bad file
    /// (all-zero weights, wild value ranges) is visible at startup.
    #[cfg(feature = "diagnostics")]
    fn log_load_diagnostics(ft: &FeatureTransformerLayerStacks, ls: &LayerStacks) {
        let bias_sum: i64 = ft.biases.0.iter().map(|&x| x as i64).sum();
        let weight_min = ft.weights.iter().copied().min().unwrap_or(0);
        let weight_max = ft.weights.iter().copied().max().unwrap_or(0);
        let weight_nonzero: usize = ft.weights.iter().filter(|&&x| x != 0).count();
        let weight_total = ft.weights.len();
        info!("[NNUE Load] FT bias sum: {bias_sum}");
        info!("[NNUE Load] FT weight: min={weight_min}, max={weight_max}");
        info!(
            "[NNUE Load] FT weight nonzero: {weight_nonzero}/{weight_total} ({:.2}%)",
            weight_nonzero as f64 / weight_total as f64 * 100.0
        );
        let l1_biases = &ls.buckets[0].l1_biases;
        info!("[NNUE Load] LayerStacks bucket0 l1_biases: {l1_biases:?}");
    }

    /// Parses a network from an in-memory byte slice (see [`Self::read`]).
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let mut cursor = Cursor::new(bytes);
        Self::read(&mut cursor)
    }

    /// Evaluates `pos` from the side-to-move's perspective using the
    /// already-refreshed accumulator `acc`, returning the final score.
    pub fn evaluate(&self, pos: &Position, acc: &AccumulatorLayerStacks) -> Value {
        let side_to_move = pos.side_to_move();
        // Order the two accumulator halves as (side to move, opponent).
        let (us_acc, them_acc) = if side_to_move == Color::Black {
            (acc.get(Color::Black as usize), acc.get(Color::White as usize))
        } else {
            (acc.get(Color::White as usize), acc.get(Color::Black as usize))
        };
        // SAFETY: the buffer is only read after sqr_clipped_relu_transform
        // fills it — assumes that transform writes all NNUE_PYTORCH_L1
        // bytes (NOTE(review): confirm against its contract).
        let mut transformed: Aligned<[u8; NNUE_PYTORCH_L1]> = unsafe { Aligned::new_uninit() };
        sqr_clipped_relu_transform(us_acc, them_acc, &mut transformed.0);
        // Select the output bucket from both kings' ranks.
        let f_king = pos.king_square(side_to_move);
        let e_king = pos.king_square(!side_to_move);
        let (f_rank, e_rank) =
            crate::nnue::layer_stacks::compute_king_ranks(side_to_move, f_king, e_king);
        let bucket_index = compute_bucket_index(f_rank, e_rank);
        let score = self.layer_stacks.evaluate(bucket_index, &transformed.0);
        Value::new(score)
    }

    /// Same as [`Self::evaluate`] but logs every intermediate stage
    /// (accumulator stats, transform output, bucket choice, raw score)
    /// for debugging a suspect network or accumulator path.
    #[cfg(feature = "diagnostics")]
    pub fn evaluate_with_diagnostics(&self, pos: &Position, acc: &AccumulatorLayerStacks) -> Value {
        // `info!` comes from the module-level cfg(diagnostics) import; the
        // previous function-local `use log::info;` was redundant.
        let side_to_move = pos.side_to_move();
        let (us_acc, them_acc) = if side_to_move == Color::Black {
            (acc.get(Color::Black as usize), acc.get(Color::White as usize))
        } else {
            (acc.get(Color::White as usize), acc.get(Color::Black as usize))
        };
        let us_min = us_acc.iter().copied().min().unwrap_or(0);
        let us_max = us_acc.iter().copied().max().unwrap_or(0);
        let us_first_half_positive: usize = us_acc[0..768].iter().filter(|&&x| x > 0).count();
        let us_second_half_positive: usize = us_acc[768..1536].iter().filter(|&&x| x > 0).count();
        info!("[NNUE Eval] us_acc: min={us_min}, max={us_max}");
        info!(
            "[NNUE Eval] us_acc positive: first_half={us_first_half_positive}/768, second_half={us_second_half_positive}/768"
        );
        info!("[NNUE Eval] us_acc first 16: {:?}", &us_acc[0..16]);
        // A plain zero-initialized buffer is fine here; the diagnostic path
        // is not performance-sensitive.
        let mut transformed = [0u8; NNUE_PYTORCH_L1];
        sqr_clipped_relu_transform(us_acc, them_acc, &mut transformed);
        let transformed_nonzero: usize = transformed.iter().filter(|&&x| x > 0).count();
        let transformed_sum: u64 = transformed.iter().map(|&x| x as u64).sum();
        info!("[NNUE Eval] transformed: nonzero={transformed_nonzero}/1536, sum={transformed_sum}");
        info!("[NNUE Eval] transformed first 32: {:?}", &transformed[0..32]);
        let f_king = pos.king_square(side_to_move);
        let e_king = pos.king_square(!side_to_move);
        let (f_rank, e_rank) =
            crate::nnue::layer_stacks::compute_king_ranks(side_to_move, f_king, e_king);
        let bucket_index = compute_bucket_index(f_rank, e_rank);
        info!(
            "[NNUE Eval] f_king_rank={f_rank}, e_king_rank={e_rank}, bucket_index={bucket_index}"
        );
        let (raw_score, l1_out, l1_skip) =
            self.layer_stacks.evaluate_raw_with_diagnostics(bucket_index, &transformed);
        info!("[NNUE Eval] l1_out (16 elements): {l1_out:?}");
        info!("[NNUE Eval] l1_skip: {l1_skip}");
        info!("[NNUE Eval] raw_score (with skip): {raw_score}");
        let score = raw_score / super::constants::NNUE_PYTORCH_NNUE2SCORE;
        let score_float = raw_score as f64 / super::constants::NNUE_PYTORCH_NNUE2SCORE as f64;
        info!("[NNUE Eval] score: {score} (float: {score_float:.4})");
        Value::new(score)
    }

    /// Recomputes the accumulator for `pos` from scratch.
    pub fn refresh_accumulator(&self, pos: &Position, acc: &mut AccumulatorLayerStacks) {
        self.feature_transformer.refresh_accumulator(pos, acc);
    }

    /// Incrementally updates `acc` from `prev_acc` using the pieces
    /// changed by the last move (`dirty_piece`).
    pub fn update_accumulator(
        &self,
        pos: &Position,
        dirty_piece: &super::accumulator::DirtyPiece,
        acc: &mut AccumulatorLayerStacks,
        prev_acc: &AccumulatorLayerStacks,
    ) {
        self.feature_transformer.update_accumulator(pos, dirty_piece, acc, prev_acc);
    }

    /// Attempts a forward incremental update on the accumulator stack
    /// starting from `source_idx`; the return value is whatever the
    /// feature transformer reports (presumably success — see its docs).
    pub fn forward_update_incremental(
        &self,
        pos: &Position,
        stack: &mut AccumulatorStackLayerStacks,
        source_idx: usize,
    ) -> bool {
        self.feature_transformer.forward_update_incremental(pos, stack, source_idx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nnue::constants::NNUE_PYTORCH_NNUE2SCORE;
    use crate::position::{Position, SFEN_HIRATE};

    /// Sanity-checks the compile-time constants this module assumes:
    /// L1 width of 1536 and an eval-to-score divisor of 600.
    #[test]
    fn test_network_dimensions() {
        assert_eq!(NNUE_PYTORCH_L1, 1536);
        assert_eq!(NNUE_PYTORCH_NNUE2SCORE, 600);
    }

    /// Manual end-to-end diagnostic against a real network file.
    /// Ignored by default; set NNUE_TEST_FILE to a coalesced
    /// layer-stacks .nnue file and run with `--ignored`. Prints weight,
    /// accumulator, and transform statistics, then scores a handful of
    /// hand-picked positions.
    #[test]
    #[ignore]
    fn test_load_layer_stacks_file() {
        use crate::nnue::layer_stacks::{compute_bucket_index, sqr_clipped_relu_transform};
        let path = std::env::var("NNUE_TEST_FILE")
            .unwrap_or_else(|_| "/path/to/your/layer_stacks.nnue".to_string());
        // Skip (rather than fail) when no network file is available.
        let network = match NetworkLayerStacks::load(path) {
            Ok(n) => n,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        // Feature-transformer weight statistics: an all-zero or extreme
        // distribution indicates a bad load.
        let bias_sum: i64 = network.feature_transformer.biases.0.iter().map(|&x| x as i64).sum();
        eprintln!("FT bias sum: {bias_sum}");
        let weight_sample: Vec<i16> = network.feature_transformer.weights[0..10].to_vec();
        eprintln!("FT weight sample (first 10): {weight_sample:?}");
        let weight_total = network.feature_transformer.weights.len();
        let weight_nonzero: usize =
            network.feature_transformer.weights.iter().filter(|&&x| x != 0).count();
        eprintln!("FT weight total: {weight_total}, nonzero: {weight_nonzero}");
        let mid_offset = weight_total / 2;
        let weight_mid_sample: Vec<i16> =
            network.feature_transformer.weights[mid_offset..mid_offset + 10].to_vec();
        eprintln!("FT weight sample (mid): {weight_mid_sample:?}");
        // Locate the first nonzero weight and map it back to its feature
        // index (weights are laid out NNUE_PYTORCH_L1 per feature).
        let first_nonzero_pos = network.feature_transformer.weights.iter().position(|&x| x != 0);
        if let Some(pos) = first_nonzero_pos {
            let sample_end = (pos + 10).min(weight_total);
            let first_nonzero_sample: Vec<i16> =
                network.feature_transformer.weights[pos..sample_end].to_vec();
            eprintln!("First nonzero at position {pos}, sample: {first_nonzero_sample:?}");
            let feature_idx = pos / NNUE_PYTORCH_L1;
            eprintln!(" -> Feature index: {feature_idx}");
        }
        let l1_bias_sample: Vec<i32> = network.layer_stacks.buckets[0].l1_biases.to_vec();
        eprintln!("L1 bias (bucket 0): {l1_bias_sample:?}");
        // Build the initial position and inspect which features are active
        // for Black, plus the weights behind the first active feature.
        let mut pos = Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        use crate::nnue::features::{FeatureSet, HalfKA_hm_FeatureSet};
        use crate::types::Color;
        let active_black = HalfKA_hm_FeatureSet::collect_active_indices(&pos, Color::Black);
        eprintln!("Active features for Black: {} features", active_black.len());
        let first_5: Vec<usize> = active_black.iter().take(5).copied().collect();
        eprintln!(" First 5 indices: {first_5:?}");
        if let Some(&first_idx) = active_black.iter().next() {
            let offset = first_idx * NNUE_PYTORCH_L1;
            eprintln!(" Weight offset for feature {first_idx}: {offset}");
            if offset + 10 <= weight_total {
                let active_weight_sample: Vec<i16> =
                    network.feature_transformer.weights[offset..offset + 10].to_vec();
                eprintln!(" Weight sample for first active feature: {active_weight_sample:?}");
            }
        }
        // Refresh both perspectives of the accumulator and report value
        // distribution per half (the two 768-wide halves of each side).
        let mut acc = AccumulatorLayerStacks::new();
        network.refresh_accumulator(&pos, &mut acc);
        let black_acc = acc.get(0);
        let white_acc = acc.get(1);
        let black_acc_sum: i64 = black_acc.iter().map(|&x| x as i64).sum();
        let white_acc_sum: i64 = white_acc.iter().map(|&x| x as i64).sum();
        eprintln!("Black acc sum: {black_acc_sum}, White acc sum: {white_acc_sum}");
        eprintln!("Black acc sample (first 10): {:?}", &black_acc[0..10]);
        let black_min = black_acc.iter().copied().min().unwrap_or(0);
        let black_max = black_acc.iter().copied().max().unwrap_or(0);
        let black_positive: usize = black_acc.iter().filter(|&&x| x > 0).count();
        eprintln!("Black acc: min={black_min}, max={black_max}, positive={black_positive}/1536");
        let first_half = &black_acc[0..768];
        let second_half = &black_acc[768..1536];
        let first_positive: usize = first_half.iter().filter(|&&x| x > 0).count();
        let second_positive: usize = second_half.iter().filter(|&&x| x > 0).count();
        eprintln!(
            "First half positive: {first_positive}/768, Second half positive: {second_positive}/768"
        );
        // Count pairs where both halves are positive — presumably relevant
        // to the pairwise squared-clipped-ReLU transform (confirm against
        // sqr_clipped_relu_transform's definition).
        let mut pairs_both_positive = 0usize;
        for i in 0..768 {
            if first_half[i] > 0 && second_half[i] > 0 {
                pairs_both_positive += 1;
            }
        }
        eprintln!("Pairs where both halves > 0: {pairs_both_positive}/768");
        // Run the activation transform and the bucket selection manually,
        // mirroring NetworkLayerStacks::evaluate step by step.
        let mut transformed = [0u8; NNUE_PYTORCH_L1];
        sqr_clipped_relu_transform(black_acc, white_acc, &mut transformed);
        let transformed_sum: u64 = transformed.iter().map(|&x| x as u64).sum();
        let transformed_nonzero: usize = transformed.iter().filter(|&&x| x > 0).count();
        eprintln!("Transformed sum: {transformed_sum}, nonzero count: {transformed_nonzero}");
        eprintln!("Transformed sample (first 20): {:?}", &transformed[0..20]);
        let side_to_move = pos.side_to_move();
        let f_king = pos.king_square(side_to_move);
        let e_king = pos.king_square(!side_to_move);
        let (f_rank, e_rank) =
            crate::nnue::layer_stacks::compute_king_ranks(side_to_move, f_king, e_king);
        let bucket_index = compute_bucket_index(f_rank, e_rank);
        eprintln!("King ranks: f={f_rank}, e={e_rank}, bucket index: {bucket_index}");
        let raw_score = network.layer_stacks.evaluate_raw(bucket_index, &transformed);
        eprintln!("Raw score (before /600): {raw_score}");
        // The initial position should score close to zero for any sane net.
        let value = network.evaluate(&pos, &acc);
        eprintln!("Initial position score: {}", value.raw());
        assert!(value.raw().abs() < 1000, "Score {} is out of expected range", value.raw());
        // Spot-check material imbalances: (label, SFEN) pairs covering a
        // pawn up/down, a rook handicap, and a bishop in hand.
        eprintln!("\n=== Various positions ===");
        let test_positions = [
            ("初期局面", "lnsgkgsnl/1r5b1/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b - 1"),
            ("後手1歩得", "lnsgkgsnl/1r5b1/ppppppppp/9/9/9/PPPPPPPP1/1B5R1/LNSGKGSNL b p 1"),
            ("先手1歩得", "lnsgkgsnl/1r5b1/pppppppp1/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b P 1"),
            ("後手飛車落ち", "lnsgkgsnl/7b1/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b - 1"),
            ("先手角得", "lnsgkgsnl/1r7/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b B 1"),
        ];
        for (name, sfen) in test_positions {
            pos.set_sfen(sfen).unwrap();
            network.refresh_accumulator(&pos, &mut acc);
            let (us_acc, them_acc) = (acc.get(0), acc.get(1));
            let mut transformed = [0u8; NNUE_PYTORCH_L1];
            sqr_clipped_relu_transform(us_acc, them_acc, &mut transformed);
            let stm = pos.side_to_move();
            let f_k = pos.king_square(stm);
            let e_k = pos.king_square(!stm);
            let (f_r, e_r) = crate::nnue::layer_stacks::compute_king_ranks(stm, f_k, e_k);
            let bucket_idx = compute_bucket_index(f_r, e_r);
            let raw = network.layer_stacks.evaluate_raw(bucket_idx, &transformed);
            let val = network.evaluate(&pos, &acc);
            eprintln!("{:15}: {:6} (raw: {:6})", name, val.raw(), raw);
        }
    }
}