use super::accumulator::AlignedBox;
use super::constants::{
LAYER_STACK_L1_OUT, LAYER_STACK_L2_IN, NNUE_PYTORCH_L1, NNUE_PYTORCH_L2, NNUE_PYTORCH_L3,
NNUE_PYTORCH_NNUE2SCORE, NUM_LAYER_STACK_BUCKETS,
};
use std::io::{self, Read};
/// Rounds `input_dim` up to the next multiple of 32, the row stride used
/// when storing quantized weight matrices.
const fn padded_input(input_dim: usize) -> usize {
    input_dim.next_multiple_of(32)
}
/// One bucket of the layer stack: a small fully-connected head
/// (L1 -> L2 -> scalar output) applied after the feature transformer.
pub struct LayerStackBucket {
    /// L1 biases, one per L1 output neuron.
    pub l1_biases: [i32; LAYER_STACK_L1_OUT],
    /// L1 weights, row-major per output neuron; each row is padded to
    /// `L1_PADDED_INPUT` entries (see the `j * L1_PADDED_INPUT + i`
    /// indexing in `propagate_l1`).
    pub l1_weights: AlignedBox<i8>,
    /// L2 biases, one per L2 output neuron.
    pub l2_biases: [i32; NNUE_PYTORCH_L3],
    /// L2 weights, row-major per output neuron with `L2_PADDED_INPUT` stride.
    pub l2_weights: AlignedBox<i8>,
    /// Bias of the single output neuron.
    pub output_bias: i32,
    /// Output-layer weights, padded to `OUTPUT_PADDED_INPUT` entries.
    pub output_weights: AlignedBox<i8>,
}
impl LayerStackBucket {
const L1_PADDED_INPUT: usize = padded_input(NNUE_PYTORCH_L1);
const L2_PADDED_INPUT: usize = padded_input(LAYER_STACK_L2_IN);
const OUTPUT_PADDED_INPUT: usize = padded_input(NNUE_PYTORCH_L3);
pub fn new() -> Self {
Self {
l1_biases: [0; LAYER_STACK_L1_OUT],
l1_weights: AlignedBox::new_zeroed(LAYER_STACK_L1_OUT * Self::L1_PADDED_INPUT),
l2_biases: [0; NNUE_PYTORCH_L3],
l2_weights: AlignedBox::new_zeroed(NNUE_PYTORCH_L3 * Self::L2_PADDED_INPUT),
output_bias: 0,
output_weights: AlignedBox::new_zeroed(Self::OUTPUT_PADDED_INPUT),
}
}
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut bucket = Self::new();
let mut buf4 = [0u8; 4];
for bias in bucket.l1_biases.iter_mut() {
reader.read_exact(&mut buf4)?;
*bias = i32::from_le_bytes(buf4);
}
{
let total_bytes = LAYER_STACK_L1_OUT * Self::L1_PADDED_INPUT;
let mut temp_buf = vec![0u8; total_bytes];
reader.read_exact(&mut temp_buf)?;
for (i, &byte) in temp_buf.iter().enumerate() {
bucket.l1_weights[i] = byte as i8;
}
}
for bias in bucket.l2_biases.iter_mut() {
reader.read_exact(&mut buf4)?;
*bias = i32::from_le_bytes(buf4);
}
{
let total_bytes = NNUE_PYTORCH_L3 * Self::L2_PADDED_INPUT;
let mut temp_buf = vec![0u8; total_bytes];
reader.read_exact(&mut temp_buf)?;
for (i, &byte) in temp_buf.iter().enumerate() {
bucket.l2_weights[i] = byte as i8;
}
}
reader.read_exact(&mut buf4)?;
bucket.output_bias = i32::from_le_bytes(buf4);
{
let mut temp_buf = vec![0u8; Self::OUTPUT_PADDED_INPUT];
reader.read_exact(&mut temp_buf)?;
for (i, &byte) in temp_buf.iter().enumerate() {
bucket.output_weights[i] = byte as i8;
}
}
Ok(bucket)
}
pub fn propagate(&self, input: &[u8; NNUE_PYTORCH_L1]) -> i32 {
let mut l1_out = [0i32; LAYER_STACK_L1_OUT];
self.propagate_l1(input, &mut l1_out);
let l1_skip = l1_out[NNUE_PYTORCH_L2];
let mut l2_input = [0u8; LAYER_STACK_L2_IN];
for i in 0..NNUE_PYTORCH_L2 {
let val = l1_out[i] >> 6; let clamped = val.clamp(0, 127) as u8;
let sqr = ((clamped as u32 * clamped as u32) >> 7).min(127) as u8;
l2_input[i] = sqr; l2_input[NNUE_PYTORCH_L2 + i] = clamped; }
let mut l2_out = [0i32; NNUE_PYTORCH_L3];
self.propagate_l2(&l2_input, &mut l2_out);
let mut l2_relu = [0u8; NNUE_PYTORCH_L3];
for i in 0..NNUE_PYTORCH_L3 {
let val = l2_out[i] >> 6;
l2_relu[i] = val.clamp(0, 127) as u8;
}
let output = self.propagate_output(&l2_relu);
output + l1_skip
}
#[inline]
fn propagate_l1(&self, input: &[u8; NNUE_PYTORCH_L1], output: &mut [i32; LAYER_STACK_L1_OUT]) {
output.copy_from_slice(&self.l1_biases);
for (i, &in_val) in input.iter().enumerate() {
let in_i32 = in_val as i32;
for (j, out) in output.iter_mut().enumerate() {
let weight_idx = j * Self::L1_PADDED_INPUT + i;
*out += self.l1_weights[weight_idx] as i32 * in_i32;
}
}
}
#[inline]
fn propagate_l2(&self, input: &[u8; LAYER_STACK_L2_IN], output: &mut [i32; NNUE_PYTORCH_L3]) {
output.copy_from_slice(&self.l2_biases);
for (i, &in_val) in input.iter().enumerate() {
let in_i32 = in_val as i32;
for (j, out) in output.iter_mut().enumerate() {
let weight_idx = j * Self::L2_PADDED_INPUT + i;
*out += self.l2_weights[weight_idx] as i32 * in_i32;
}
}
}
#[inline]
fn propagate_output(&self, input: &[u8; NNUE_PYTORCH_L3]) -> i32 {
let mut sum = self.output_bias;
for (i, &in_val) in input.iter().enumerate() {
sum += self.output_weights[i] as i32 * in_val as i32;
}
sum
}
#[cfg(feature = "diagnostics")]
pub fn propagate_with_diagnostics(
&self,
input: &[u8; NNUE_PYTORCH_L1],
) -> (i32, [i32; LAYER_STACK_L1_OUT], i32) {
let mut l1_out = [0i32; LAYER_STACK_L1_OUT];
self.propagate_l1(input, &mut l1_out);
let l1_skip = l1_out[NNUE_PYTORCH_L2];
let mut l2_input = [0u8; LAYER_STACK_L2_IN];
for i in 0..NNUE_PYTORCH_L2 {
let val = l1_out[i] >> 6;
let clamped = val.clamp(0, 127) as u8;
let sqr = ((clamped as u32 * clamped as u32) >> 7).min(127) as u8;
l2_input[i] = sqr;
l2_input[NNUE_PYTORCH_L2 + i] = clamped;
}
let mut l2_out = [0i32; NNUE_PYTORCH_L3];
self.propagate_l2(&l2_input, &mut l2_out);
let mut l2_relu = [0u8; NNUE_PYTORCH_L3];
for i in 0..NNUE_PYTORCH_L3 {
let val = l2_out[i] >> 6;
l2_relu[i] = val.clamp(0, 127) as u8;
}
let output = self.propagate_output(&l2_relu);
let raw_score = output + l1_skip;
(raw_score, l1_out, l1_skip)
}
}
impl Default for LayerStackBucket {
fn default() -> Self {
Self::new()
}
}
/// The full set of layer-stack buckets; one is selected per position
/// (see `compute_bucket_index`).
pub struct LayerStacks {
    /// All buckets, indexed by the bucket index.
    pub buckets: [LayerStackBucket; NUM_LAYER_STACK_BUCKETS],
}
impl LayerStacks {
    /// Creates the full stack with every bucket zero-initialized.
    pub fn new() -> Self {
        Self {
            buckets: std::array::from_fn(|_| LayerStackBucket::new()),
        }
    }

    /// Deserializes every bucket in order. Each bucket is preceded by a
    /// 4-byte little-endian value in the stream.
    ///
    /// # Errors
    /// Returns any I/O error from `reader`.
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let mut stacks = Self::new();
        for slot in stacks.buckets.iter_mut() {
            let mut hash_bytes = [0u8; 4];
            reader.read_exact(&mut hash_bytes)?;
            // NOTE(review): the per-bucket fc_hash is read but never checked
            // against an expected value — confirm this is intentional.
            let _fc_hash = u32::from_le_bytes(hash_bytes);
            *slot = LayerStackBucket::read(reader)?;
        }
        Ok(stacks)
    }

    /// Evaluates the chosen bucket and scales the raw network output down
    /// to the engine's score units.
    pub fn evaluate(&self, bucket_index: usize, input: &[u8; NNUE_PYTORCH_L1]) -> i32 {
        debug_assert!(bucket_index < NUM_LAYER_STACK_BUCKETS);
        self.buckets[bucket_index].propagate(input) / NNUE_PYTORCH_NNUE2SCORE
    }

    /// Evaluates the chosen bucket and returns the raw (unscaled) score.
    pub fn evaluate_raw(&self, bucket_index: usize, input: &[u8; NNUE_PYTORCH_L1]) -> i32 {
        debug_assert!(bucket_index < NUM_LAYER_STACK_BUCKETS);
        let bucket = &self.buckets[bucket_index];
        bucket.propagate(input)
    }

    /// Raw evaluation that also exposes L1 internals for debugging.
    #[cfg(feature = "diagnostics")]
    pub fn evaluate_raw_with_diagnostics(
        &self,
        bucket_index: usize,
        input: &[u8; NNUE_PYTORCH_L1],
    ) -> (i32, [i32; LAYER_STACK_L1_OUT], i32) {
        debug_assert!(bucket_index < NUM_LAYER_STACK_BUCKETS);
        let bucket = &self.buckets[bucket_index];
        bucket.propagate_with_diagnostics(input)
    }
}
impl Default for LayerStacks {
fn default() -> Self {
Self::new()
}
}
/// Squared clipped-ReLU feature transform over both accumulator perspectives.
///
/// For each index `i` in the first half, the value at `i` is paired with the
/// value at `half + i`: both are clipped to `[0, 127]`, multiplied, rescaled
/// by `2^-7`, and saturated at 127. Results for `us_acc` land in the lower
/// half of `output`, results for `them_acc` in the upper half.
pub fn sqr_clipped_relu_transform(
    us_acc: &[i16; NNUE_PYTORCH_L1],
    them_acc: &[i16; NNUE_PYTORCH_L1],
    output: &mut [u8; NNUE_PYTORCH_L1],
) {
    let half = NNUE_PYTORCH_L1 / 2;
    // Clip both operands, multiply, rescale and saturate.
    let squash = |lo: i16, hi: i16| -> u8 {
        let a = (lo as i32).clamp(0, 127) as u32;
        let b = (hi as i32).clamp(0, 127) as u32;
        ((a * b) >> 7).min(127) as u8
    };
    for i in 0..half {
        output[i] = squash(us_acc[i], us_acc[half + i]);
        output[half + i] = squash(them_acc[i], them_acc[half + i]);
    }
}
/// Maps the two kings' normalized ranks onto a layer-stack bucket index.
///
/// Each rank (clamped to 0..=8) falls into one of three zones of three
/// ranks each; the friendly zone is weighted by 3, giving a 3x3 grid of
/// buckets, capped at `NUM_LAYER_STACK_BUCKETS - 1`.
pub fn compute_bucket_index(f_king_rank: usize, e_king_rank: usize) -> usize {
    // Zones: ranks 0-2 -> 0, 3-5 -> 1, 6-8 -> 2.
    let f_zone = f_king_rank.min(8) / 3;
    let e_zone = e_king_rank.min(8) / 3;
    (f_zone * 3 + e_zone).min(NUM_LAYER_STACK_BUCKETS - 1)
}
/// Normalizes both kings' board ranks to the side-to-move's perspective.
///
/// When Black is to move, the friendly rank is used as-is and the enemy
/// rank is mirrored (`8 - rank`); when White is to move, the mirroring is
/// swapped. Returns `(friendly_rank, enemy_rank)`.
pub fn compute_king_ranks(
    side_to_move: crate::types::Color,
    f_king_sq: crate::types::Square,
    e_king_sq: crate::types::Square,
) -> (usize, usize) {
    use crate::types::Color;
    let f = f_king_sq.rank() as usize;
    let e = e_king_sq.rank() as usize;
    if side_to_move == Color::Black {
        (f, 8 - e)
    } else {
        (8 - f, e)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity check: a fresh bucket has the expected bias array lengths.
    #[test]
    fn test_layer_stack_bucket_new() {
        let bucket = LayerStackBucket::new();
        assert_eq!(bucket.l1_biases.len(), LAYER_STACK_L1_OUT);
        assert_eq!(bucket.l2_biases.len(), NNUE_PYTORCH_L3);
    }

    // Sanity check: the stack holds one bucket per configured index.
    #[test]
    fn test_layer_stacks_new() {
        let stacks = LayerStacks::new();
        assert_eq!(stacks.buckets.len(), NUM_LAYER_STACK_BUCKETS);
    }

    // Exercises the 3x3 zone mapping, including the out-of-range clamp.
    #[test]
    fn test_bucket_index() {
        assert_eq!(compute_bucket_index(0, 0), 0);
        assert_eq!(compute_bucket_index(1, 1), 0);
        assert_eq!(compute_bucket_index(2, 2), 0);
        assert_eq!(compute_bucket_index(3, 3), 4);
        assert_eq!(compute_bucket_index(6, 6), 8);
        assert_eq!(compute_bucket_index(8, 8), 8);
        assert_eq!(compute_bucket_index(0, 8), 2);
        assert_eq!(compute_bucket_index(8, 0), 6);
        // Ranks above 8 are clamped before zone lookup.
        assert_eq!(compute_bucket_index(10, 10), 8);
    }

    // Starting (hirate) position: both kings on their home ranks should
    // normalize to rank 8 from Black's perspective and select bucket 8.
    #[test]
    fn test_compute_king_ranks_hirate() {
        use crate::position::{Position, SFEN_HIRATE};
        use crate::types::Color;
        let mut pos = Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        assert_eq!(pos.side_to_move(), Color::Black);
        let f_king_sq = pos.king_square(Color::Black); let e_king_sq = pos.king_square(Color::White);
        let (f_rank, e_rank) = compute_king_ranks(Color::Black, f_king_sq, e_king_sq);
        assert_eq!(f_rank, 8, "f_rank for Black in hirate");
        assert_eq!(e_rank, 8, "e_rank for Black in hirate");
        assert_eq!(compute_bucket_index(f_rank, e_rank), 8);
    }

    // Custom positions: verifies rank normalization for both sides to move
    // and that the resulting bucket indices differ accordingly.
    #[test]
    fn test_compute_king_ranks_positions() {
        use crate::position::Position;
        use crate::types::Color;
        let mut pos = Position::new();
        pos.set_sfen("4k4/9/9/9/4K4/9/9/9/9 b - 1").unwrap();
        let f_king_sq = pos.king_square(Color::Black); let e_king_sq = pos.king_square(Color::White);
        let (f_rank, e_rank) = compute_king_ranks(Color::Black, f_king_sq, e_king_sq);
        assert_eq!(f_rank, 4, "f_rank for Black king at 5e");
        assert_eq!(e_rank, 8, "e_rank for White king at 5a");
        assert_eq!(compute_bucket_index(f_rank, e_rank), 5);
        // Same board, White to move: perspective flips the normalization.
        let mut pos2 = Position::new();
        pos2.set_sfen("4k4/9/9/9/4K4/9/9/9/9 w - 1").unwrap();
        let (f_rank2, e_rank2) = compute_king_ranks(
            Color::White,
            pos2.king_square(Color::White),
            pos2.king_square(Color::Black),
        );
        assert_eq!(f_rank2, 8, "f_rank for White king at 5a");
        assert_eq!(e_rank2, 4, "e_rank for Black king at 5e");
        assert_eq!(compute_bucket_index(f_rank2, e_rank2), 7);
    }

    // Checks the transform at three fixed points: all zeros -> all zeros,
    // all 127 -> 126 (127*127 >> 7), and negative inputs -> clamped to 0.
    #[test]
    fn test_sqr_clipped_relu_transform_basic() {
        let mut us_acc = [0i16; NNUE_PYTORCH_L1];
        let mut them_acc = [0i16; NNUE_PYTORCH_L1];
        let mut output = [0u8; NNUE_PYTORCH_L1];
        sqr_clipped_relu_transform(&us_acc, &them_acc, &mut output);
        assert!(
            output.iter().all(|&x| x == 0),
            "all zeros input should produce all zeros output"
        );
        let half = NNUE_PYTORCH_L1 / 2;
        for i in 0..half {
            us_acc[i] = 127;
            us_acc[half + i] = 127;
            them_acc[i] = 127;
            them_acc[half + i] = 127;
        }
        sqr_clipped_relu_transform(&us_acc, &them_acc, &mut output);
        for (i, &val) in output.iter().enumerate().take(NNUE_PYTORCH_L1) {
            assert_eq!(val, 126, "max input should produce 126 at index {i}");
        }
        for i in 0..NNUE_PYTORCH_L1 {
            us_acc[i] = -100;
            them_acc[i] = -100;
        }
        sqr_clipped_relu_transform(&us_acc, &them_acc, &mut output);
        assert!(output.iter().all(|&x| x == 0), "negative input should be clamped to 0");
    }
}