use super::accumulator::Aligned;
use super::constants::{
LAYER_STACK_L1_OUT, LAYER_STACK_L2_IN, NNUE_PYTORCH_L1, NNUE_PYTORCH_L2, NNUE_PYTORCH_L3,
NUM_LAYER_STACK_BUCKETS,
};
use super::layers::AffineTransform;
use std::io::{self, Read};
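// Affine-layer input widths, rounded up to the padding `AffineTransform` expects.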
const L2_PADDED_INPUT: usize = super::layers::padded_input(LAYER_STACK_L2_IN);
const OUTPUT_PADDED_INPUT: usize = super::layers::padded_input(NNUE_PYTORCH_L3);
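/// Scalar reference for the squared clipped ReLU used by the tests:
/// `clamp((x * x) >> 19, 0, 127)` per element.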
#[cfg(test)]
fn sqr_clipped_relu_explicit<const DIM: usize>(input: &[i32; DIM], output: &mut [u8; DIM]) {
for i in 0..DIM {
output[i] = ((input[i] as i64 * input[i] as i64) >> 19).clamp(0, 127) as u8;
}
}
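/// One layer-stack bucket: an L1 affine layer whose last output is a
/// skip-connection neuron, an L2 affine layer over the concatenated
/// squared/plain clipped-ReLU activations, and a single-output final layer.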
pub struct LayerStackBucket {
pub l1: AffineTransform<NNUE_PYTORCH_L1, LAYER_STACK_L1_OUT>,
pub l2: AffineTransform<LAYER_STACK_L2_IN, NNUE_PYTORCH_L3>,
pub output: AffineTransform<NNUE_PYTORCH_L3, 1>,
}
impl LayerStackBucket {
pub fn new() -> Self {
Self {
l1: AffineTransform::new(),
l2: AffineTransform::new(),
output: AffineTransform::new(),
}
}
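    /// Reads the three affine layers in serialized order: L1, L2, output.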
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
let l1 = AffineTransform::read(reader)?;
let l2 = AffineTransform::read(reader)?;
let output = AffineTransform::read(reader)?;
Ok(Self { l1, l2, output })
}
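    /// Full forward pass. The raw skip neuron `l1_out[NNUE_PYTORCH_L2]` is
    /// added to the final affine output to form the bucket score.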
pub fn propagate(&self, input: &[u8; NNUE_PYTORCH_L1]) -> i32 {
let mut l1_out = [0i32; LAYER_STACK_L1_OUT];
let mut l2_input = Aligned([0u8; L2_PADDED_INPUT]);
let mut l2_out = [0i32; NNUE_PYTORCH_L3];
let mut l2_relu = Aligned([0u8; OUTPUT_PADDED_INPUT]);
let mut output_arr = [0i32; 1];
self.l1.propagate(input, &mut l1_out);
let l1_skip = l1_out[NNUE_PYTORCH_L2];
l1_sqr_clipped_relu_activation(&l1_out, &mut l2_input.0);
self.l2.propagate(&l2_input.0, &mut l2_out);
clipped_relu_i32_to_u8(&l2_out, &mut l2_relu.0);
self.output.propagate(&l2_relu.0, &mut output_arr);
output_arr[0] + l1_skip
}
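    /// Like `propagate`, but also returns the raw L1 output and the skip term
    /// for diagnostics builds.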
#[cfg(feature = "diagnostics")]
pub fn propagate_with_diagnostics(
&self,
input: &[u8; NNUE_PYTORCH_L1],
) -> (i32, [i32; LAYER_STACK_L1_OUT], i32) {
let mut l1_out = [0i32; LAYER_STACK_L1_OUT];
let mut l2_input = Aligned([0u8; L2_PADDED_INPUT]);
let mut l2_out = [0i32; NNUE_PYTORCH_L3];
let mut l2_relu = Aligned([0u8; OUTPUT_PADDED_INPUT]);
let mut output_arr = [0i32; 1];
self.l1.propagate(input, &mut l1_out);
        let l1_skip = l1_out[NNUE_PYTORCH_L2];
        l1_sqr_clipped_relu_activation(&l1_out, &mut l2_input.0);
self.l2.propagate(&l2_input.0, &mut l2_out);
clipped_relu_i32_to_u8(&l2_out, &mut l2_relu.0);
self.output.propagate(&l2_relu.0, &mut output_arr);
let raw_score = output_arr[0] + l1_skip;
(raw_score, l1_out, l1_skip)
}
}
impl Default for LayerStackBucket {
fn default() -> Self {
Self::new()
}
}
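/// The full set of layer-stack buckets; one is selected per position from
/// the two king ranks (see `compute_bucket_index`).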
pub struct LayerStacks {
pub buckets: [LayerStackBucket; NUM_LAYER_STACK_BUCKETS],
}
impl LayerStacks {
pub fn new() -> Self {
Self {
buckets: std::array::from_fn(|_| LayerStackBucket::new()),
}
}
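    /// Reads all buckets, skipping the 32-bit FC hash that precedes each one.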
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut stacks = Self::new();
let mut buf4 = [0u8; 4];
for bucket in stacks.buckets.iter_mut() {
reader.read_exact(&mut buf4)?;
let _fc_hash = u32::from_le_bytes(buf4);
*bucket = LayerStackBucket::read(reader)?;
}
Ok(stacks)
}
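    /// Propagates `input` through the bucket at `bucket_index`. The index is
    /// bounds-checked only in debug builds, so callers must keep it below
    /// `NUM_LAYER_STACK_BUCKETS`.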
pub fn evaluate_raw(&self, bucket_index: usize, input: &[u8; NNUE_PYTORCH_L1]) -> i32 {
debug_assert!(bucket_index < NUM_LAYER_STACK_BUCKETS);
unsafe { self.buckets.get_unchecked(bucket_index) }.propagate(input)
}
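    /// Diagnostics variant of `evaluate_raw`; uses checked indexing.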
#[cfg(feature = "diagnostics")]
pub fn evaluate_raw_with_diagnostics(
&self,
bucket_index: usize,
input: &[u8; NNUE_PYTORCH_L1],
) -> (i32, [i32; LAYER_STACK_L1_OUT], i32) {
debug_assert!(bucket_index < NUM_LAYER_STACK_BUCKETS);
self.buckets[bucket_index].propagate_with_diagnostics(input)
}
}
impl Default for LayerStacks {
fn default() -> Self {
Self::new()
}
}
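/// Builds the L2 input from the raw L1 output: bytes `0..NNUE_PYTORCH_L2`
/// hold the squared clipped ReLU `clamp((x * x) >> 19, 0, 127)` and bytes
/// `NNUE_PYTORCH_L2..2 * NNUE_PYTORCH_L2` hold the plain clipped ReLU
/// `clamp(x >> 6, 0, 127)` of the same L1 values; trailing padding stays zero.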
#[inline]
fn l1_sqr_clipped_relu_activation(l1_out: &[i32; LAYER_STACK_L1_OUT], l2_input: &mut [u8]) {
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max127 = _mm256_set1_epi32(127);
let in_ptr = l1_out.as_ptr();
let out_ptr = l2_input.as_mut_ptr();
            // Any |v| >= 8192 already saturates to 127 after the >> 19 shift,
            // so clamp to that range first; this also keeps v * v inside i32,
            // where a plain _mm256_mullo_epi32 would overflow for |v| > 46340
            // and diverge from the scalar (i64-based) reference.
            let sqr_min = _mm256_set1_epi32(-8192);
            let sqr_max = _mm256_set1_epi32(8192);
            // First pass: squared clipped ReLU into the lower half.
            for chunk in 0..2 {
                let offset = chunk * 8;
                let v = _mm256_loadu_si256(in_ptr.add(offset) as *const __m256i);
                let v_clamped = _mm256_min_epi32(_mm256_max_epi32(v, sqr_min), sqr_max);
                let sqr = _mm256_mullo_epi32(v_clamped, v_clamped);
                let sqr_shifted = _mm256_srai_epi32(sqr, 19);
                let sqr_result = _mm256_min_epi32(_mm256_max_epi32(sqr_shifted, zero), max127);
                let sqr_16 = _mm256_packs_epi32(sqr_result, sqr_result);
                let sqr_8 = _mm256_packus_epi16(sqr_16, sqr_16);
                let sqr_lo = _mm256_castsi256_si128(sqr_8);
                let sqr_hi = _mm256_extracti128_si256(sqr_8, 1);
                let sqr_combined = _mm_unpacklo_epi32(sqr_lo, sqr_hi);
                _mm_storel_epi64(out_ptr.add(offset) as *mut __m128i, sqr_combined);
            }
            // Second pass: plain clipped ReLU into the upper half. This must
            // run after the squared pass: the chunk-1 squared store covers
            // byte NNUE_PYTORCH_L2, which has to end up holding
            // relu(l1_out[0]), not the square of the skip neuron.
            for chunk in 0..2 {
                let offset = chunk * 8;
                let v = _mm256_loadu_si256(in_ptr.add(offset) as *const __m256i);
                let relu_shifted = _mm256_srai_epi32(v, 6);
                let relu_result = _mm256_min_epi32(_mm256_max_epi32(relu_shifted, zero), max127);
                let relu_16 = _mm256_packs_epi32(relu_result, relu_result);
                let relu_8 = _mm256_packus_epi16(relu_16, relu_16);
                let relu_lo = _mm256_castsi256_si128(relu_8);
                let relu_hi = _mm256_extracti128_si256(relu_8, 1);
                let relu_combined = _mm_unpacklo_epi32(relu_lo, relu_hi);
                _mm_storel_epi64(
                    out_ptr.add(NNUE_PYTORCH_L2 + offset) as *mut __m128i,
                    relu_combined,
                );
            }
            // The relu pass also writes the skip neuron's activation into the
            // padding byte at 2 * NNUE_PYTORCH_L2; zero it again so this path
            // stays byte-identical to the scalar fallback below.
            for byte in l2_input.iter_mut().skip(2 * NNUE_PYTORCH_L2) {
                *byte = 0;
            }
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
{
for (i, &val) in l1_out.iter().enumerate().take(NNUE_PYTORCH_L2) {
let input_val = val as i64;
let sqr = ((input_val * input_val) >> 19).clamp(0, 127) as u8;
let clamped = (val >> 6).clamp(0, 127) as u8;
l2_input[i] = sqr;
l2_input[NNUE_PYTORCH_L2 + i] = clamped;
}
}
}
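/// Clipped ReLU from the L2 affine output to u8: `clamp(x >> 6, 0, 127)`.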
#[inline]
fn clipped_relu_i32_to_u8(input: &[i32; NNUE_PYTORCH_L3], output: &mut [u8]) {
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max127 = _mm256_set1_epi32(127);
let in_ptr = input.as_ptr();
let out_ptr = output.as_mut_ptr();
for chunk in 0..4 {
let offset = chunk * 8;
let v = _mm256_loadu_si256(in_ptr.add(offset) as *const __m256i);
let shifted = _mm256_srai_epi32(v, 6);
let clamped = _mm256_min_epi32(_mm256_max_epi32(shifted, zero), max127);
let packed16 = _mm256_packs_epi32(clamped, clamped);
let packed8 = _mm256_packus_epi16(packed16, packed16);
let lo = _mm256_castsi256_si128(packed8);
let hi = _mm256_extracti128_si256(packed8, 1);
let combined = _mm_unpacklo_epi32(lo, hi);
_mm_storel_epi64(out_ptr.add(offset) as *mut __m128i, combined);
}
}
}
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
{
for (out, &val) in output.iter_mut().zip(input.iter()) {
*out = (val >> 6).clamp(0, 127) as u8;
}
}
}
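/// Feature-transformer activation: for each perspective (own accumulator
/// first, then the opponent's), element `i` of the first half is clamped to
/// [0, 127], multiplied by the clamped element `half + i`, and the product is
/// shifted right by 7 to give one u8 output.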
pub fn sqr_clipped_relu_transform(
us_acc: &[i16; NNUE_PYTORCH_L1],
them_acc: &[i16; NNUE_PYTORCH_L1],
output: &mut [u8; NNUE_PYTORCH_L1],
) {
let half = NNUE_PYTORCH_L1 / 2;
#[cfg(all(
target_arch = "x86_64",
target_feature = "avx512f",
target_feature = "avx512bw"
))]
{
unsafe {
use std::arch::x86_64::*;
let zero = _mm512_setzero_si512();
let max127 = _mm512_set1_epi16(127);
for (acc, out_offset) in [(us_acc.as_ptr(), 0usize), (them_acc.as_ptr(), half)] {
let acc_a = acc;
let acc_b = acc.add(half);
let out_ptr = output.as_mut_ptr().add(out_offset);
for i in 0..(half / 32) {
let offset = i * 32;
let va = _mm512_load_si512(acc_a.add(offset) as *const __m512i);
let vb = _mm512_load_si512(acc_b.add(offset) as *const __m512i);
let a = _mm512_min_epi16(_mm512_max_epi16(va, zero), max127);
let b = _mm512_min_epi16(_mm512_max_epi16(vb, zero), max127);
let prod = _mm512_mullo_epi16(a, b);
let shifted = _mm512_srli_epi16(prod, 7);
let packed = _mm512_packus_epi16(shifted, zero);
let perm = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
let fixed = _mm512_permutexvar_epi64(perm, packed);
_mm256_storeu_si256(
out_ptr.add(offset) as *mut __m256i,
_mm512_castsi512_si256(fixed),
);
}
}
}
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "avx2",
not(all(target_feature = "avx512f", target_feature = "avx512bw"))
))]
{
unsafe {
use std::arch::x86_64::*;
let zero = _mm256_setzero_si256();
let max127 = _mm256_set1_epi16(127);
for (acc, out_offset) in [(us_acc.as_ptr(), 0usize), (them_acc.as_ptr(), half)] {
let acc_a = acc;
let acc_b = acc.add(half);
let out_ptr = output.as_mut_ptr().add(out_offset);
for i in 0..(half / 32) {
let offset = i * 32;
let va0 = _mm256_load_si256(acc_a.add(offset) as *const __m256i);
let vb0 = _mm256_load_si256(acc_b.add(offset) as *const __m256i);
let a0 = _mm256_min_epi16(_mm256_max_epi16(va0, zero), max127);
let b0 = _mm256_min_epi16(_mm256_max_epi16(vb0, zero), max127);
let shifted0 = _mm256_srli_epi16(_mm256_mullo_epi16(a0, b0), 7);
let va1 = _mm256_load_si256(acc_a.add(offset + 16) as *const __m256i);
let vb1 = _mm256_load_si256(acc_b.add(offset + 16) as *const __m256i);
let a1 = _mm256_min_epi16(_mm256_max_epi16(va1, zero), max127);
let b1 = _mm256_min_epi16(_mm256_max_epi16(vb1, zero), max127);
let shifted1 = _mm256_srli_epi16(_mm256_mullo_epi16(a1, b1), 7);
let packed = _mm256_packus_epi16(shifted0, shifted1);
let fixed = _mm256_permute4x64_epi64(packed, 0xD8);
_mm256_storeu_si256(out_ptr.add(offset) as *mut __m256i, fixed);
}
}
}
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "sse2",
not(target_feature = "avx2")
))]
{
unsafe {
use std::arch::x86_64::*;
let zero = _mm_setzero_si128();
let max127 = _mm_set1_epi16(127);
for (acc, out_offset) in [(us_acc.as_ptr(), 0usize), (them_acc.as_ptr(), half)] {
let acc_a = acc;
let acc_b = acc.add(half);
let out_ptr = output.as_mut_ptr().add(out_offset);
for i in 0..(half / 16) {
let offset = i * 16;
let va0 = _mm_load_si128(acc_a.add(offset) as *const __m128i);
let vb0 = _mm_load_si128(acc_b.add(offset) as *const __m128i);
let a0 = _mm_min_epi16(_mm_max_epi16(va0, zero), max127);
let b0 = _mm_min_epi16(_mm_max_epi16(vb0, zero), max127);
let shifted0 = _mm_srli_epi16(_mm_mullo_epi16(a0, b0), 7);
let va1 = _mm_load_si128(acc_a.add(offset + 8) as *const __m128i);
let vb1 = _mm_load_si128(acc_b.add(offset + 8) as *const __m128i);
let a1 = _mm_min_epi16(_mm_max_epi16(va1, zero), max127);
let b1 = _mm_min_epi16(_mm_max_epi16(vb1, zero), max127);
let shifted1 = _mm_srli_epi16(_mm_mullo_epi16(a1, b1), 7);
let packed = _mm_packus_epi16(shifted0, shifted1);
_mm_storeu_si128(out_ptr.add(offset) as *mut __m128i, packed);
}
}
}
}
#[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
{
unsafe {
use std::arch::wasm32::*;
let zero = i16x8_splat(0);
let max127 = i16x8_splat(127);
for (acc, out_offset) in [(us_acc.as_ptr(), 0usize), (them_acc.as_ptr(), half)] {
let acc_a = acc;
let acc_b = acc.add(half);
let out_ptr = output.as_mut_ptr().add(out_offset);
for i in 0..(half / 16) {
let offset = i * 16;
let va0 = v128_load(acc_a.add(offset) as *const v128);
let vb0 = v128_load(acc_b.add(offset) as *const v128);
let a0 = i16x8_min(i16x8_max(va0, zero), max127);
let b0 = i16x8_min(i16x8_max(vb0, zero), max127);
let shifted0 = u16x8_shr(i16x8_mul(a0, b0), 7);
let va1 = v128_load(acc_a.add(offset + 8) as *const v128);
let vb1 = v128_load(acc_b.add(offset + 8) as *const v128);
let a1 = i16x8_min(i16x8_max(va1, zero), max127);
let b1 = i16x8_min(i16x8_max(vb1, zero), max127);
let shifted1 = u16x8_shr(i16x8_mul(a1, b1), 7);
let packed = u8x16_narrow_i16x8(shifted0, shifted1);
v128_store(out_ptr.add(offset) as *mut v128, packed);
}
}
}
}
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "sse2"),
all(target_arch = "wasm32", target_feature = "simd128")
)))]
{
for i in 0..half {
let us_a = (us_acc[i] as i32).clamp(0, 127) as u32;
let us_b = (us_acc[half + i] as i32).clamp(0, 127) as u32;
let us_prod = ((us_a * us_b) >> 7).min(127);
output[i] = us_prod as u8;
let them_a = (them_acc[i] as i32).clamp(0, 127) as u32;
let them_b = (them_acc[half + i] as i32).clamp(0, 127) as u32;
let them_prod = ((them_a * them_b) >> 7).min(127);
output[half + i] = them_prod as u8;
}
}
}
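/// Maps the two perspective-adjusted king ranks to a layer-stack bucket:
/// each rank falls into one of three zones (0-2, 3-5, 6-8), combined as
/// `3 * friendly_zone + enemy_zone` and capped at the last bucket.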
pub fn compute_bucket_index(f_king_rank: usize, e_king_rank: usize) -> usize {
const F_TO_INDEX: [usize; 9] = [0, 0, 0, 3, 3, 3, 6, 6, 6];
const E_TO_INDEX: [usize; 9] = [0, 0, 0, 1, 1, 1, 2, 2, 2];
let f_idx = F_TO_INDEX[f_king_rank.min(8)];
let e_idx = E_TO_INDEX[e_king_rank.min(8)];
(f_idx + e_idx).min(NUM_LAYER_STACK_BUCKETS - 1)
}
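/// Returns the (friendly, enemy) king ranks viewed from the side to move, so
/// that rank 8 is each king's home rank in the starting position.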
pub fn compute_king_ranks(
side_to_move: crate::types::Color,
f_king_sq: crate::types::Square,
e_king_sq: crate::types::Square,
) -> (usize, usize) {
use crate::types::Color;
    let f_rank = if side_to_move == Color::Black {
        f_king_sq.rank() as usize
    } else {
        8 - f_king_sq.rank() as usize
    };
    let e_rank = if side_to_move == Color::Black {
        8 - e_king_sq.rank() as usize
    } else {
        e_king_sq.rank() as usize
    };
(f_rank, e_rank)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::nnue::accumulator::Aligned;
use crate::nnue::layers::ClippedReLU;
#[test]
fn test_layer_stack_bucket_new() {
let bucket = LayerStackBucket::new();
assert_eq!(bucket.l1.biases.len(), LAYER_STACK_L1_OUT);
assert_eq!(bucket.l2.biases.len(), NNUE_PYTORCH_L3);
}
#[test]
fn test_layer_stacks_new() {
let stacks = LayerStacks::new();
assert_eq!(stacks.buckets.len(), NUM_LAYER_STACK_BUCKETS);
}
#[test]
fn test_bucket_index() {
assert_eq!(compute_bucket_index(0, 0), 0);
assert_eq!(compute_bucket_index(1, 1), 0);
assert_eq!(compute_bucket_index(2, 2), 0);
assert_eq!(compute_bucket_index(3, 3), 4);
assert_eq!(compute_bucket_index(6, 6), 8);
assert_eq!(compute_bucket_index(8, 8), 8);
assert_eq!(compute_bucket_index(0, 8), 2);
assert_eq!(compute_bucket_index(8, 0), 6);
assert_eq!(compute_bucket_index(10, 10), 8);
}
#[test]
fn test_compute_king_ranks_hirate() {
use crate::position::{Position, SFEN_HIRATE};
use crate::types::Color;
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
assert_eq!(pos.side_to_move(), Color::Black);
        let f_king_sq = pos.king_square(Color::Black);
        let e_king_sq = pos.king_square(Color::White);
let (f_rank, e_rank) = compute_king_ranks(Color::Black, f_king_sq, e_king_sq);
assert_eq!(f_rank, 8, "f_rank for Black in hirate");
assert_eq!(e_rank, 8, "e_rank for Black in hirate");
assert_eq!(compute_bucket_index(f_rank, e_rank), 8);
}
#[test]
fn test_compute_king_ranks_positions() {
use crate::position::Position;
use crate::types::Color;
let mut pos = Position::new();
pos.set_sfen("4k4/9/9/9/4K4/9/9/9/9 b - 1").unwrap();
        let f_king_sq = pos.king_square(Color::Black);
        let e_king_sq = pos.king_square(Color::White);
let (f_rank, e_rank) = compute_king_ranks(Color::Black, f_king_sq, e_king_sq);
assert_eq!(f_rank, 4, "f_rank for Black king at 5e");
assert_eq!(e_rank, 8, "e_rank for White king at 5a");
assert_eq!(compute_bucket_index(f_rank, e_rank), 5);
let mut pos2 = Position::new();
pos2.set_sfen("4k4/9/9/9/4K4/9/9/9/9 w - 1").unwrap();
let (f_rank2, e_rank2) = compute_king_ranks(
Color::White,
pos2.king_square(Color::White),
pos2.king_square(Color::Black),
);
assert_eq!(f_rank2, 8, "f_rank for White king at 5a");
assert_eq!(e_rank2, 4, "e_rank for Black king at 5e");
assert_eq!(compute_bucket_index(f_rank2, e_rank2), 7);
}
#[test]
fn test_l1_sqr_clipped_relu_boundary() {
fn sqr_clipped_relu(input: i32) -> u8 {
((input as i64 * input as i64) >> 19).clamp(0, 127) as u8
}
assert_eq!(sqr_clipped_relu(0), 0);
        assert_eq!(sqr_clipped_relu(64), 0);
        assert_eq!(sqr_clipped_relu(724), 0);
        assert_eq!(sqr_clipped_relu(8128), 126);
assert_eq!(sqr_clipped_relu(8192), 127);
assert_eq!(sqr_clipped_relu(8256), 127);
assert_eq!(sqr_clipped_relu(20000), 127);
assert_eq!(sqr_clipped_relu(-8192), 127);
        // Propagating a bucket whose L1 biases sit on the saturation boundary
        // must not panic or overflow.
        let mut bucket_with_biases = LayerStackBucket::new();
        bucket_with_biases.l1.biases[0] = 8192;
        bucket_with_biases.l1.biases[1] = 8128;
        let input = Aligned([0u8; NNUE_PYTORCH_L1]);
        let _ = bucket_with_biases.propagate(&input.0);
    }
#[test]
fn test_sqr_clipped_relu_transform_basic() {
use super::super::accumulator::Aligned;
let mut us_acc = Aligned([0i16; NNUE_PYTORCH_L1]);
let mut them_acc = Aligned([0i16; NNUE_PYTORCH_L1]);
let mut output = Aligned([0u8; NNUE_PYTORCH_L1]);
sqr_clipped_relu_transform(&us_acc.0, &them_acc.0, &mut output.0);
assert!(
output.0.iter().all(|&x| x == 0),
"all zeros input should produce all zeros output"
);
let half = NNUE_PYTORCH_L1 / 2;
for i in 0..half {
us_acc.0[i] = 127;
us_acc.0[half + i] = 127;
them_acc.0[i] = 127;
them_acc.0[half + i] = 127;
}
sqr_clipped_relu_transform(&us_acc.0, &them_acc.0, &mut output.0);
for (i, &val) in output.0.iter().enumerate().take(NNUE_PYTORCH_L1) {
assert_eq!(val, 126, "max input should produce 126 at index {i}");
}
for i in 0..NNUE_PYTORCH_L1 {
us_acc.0[i] = -100;
them_acc.0[i] = -100;
}
sqr_clipped_relu_transform(&us_acc.0, &them_acc.0, &mut output.0);
        assert!(
            output.0.iter().all(|&x| x == 0),
            "negative input should be clamped to 0"
        );
}
#[test]
fn test_layer_stack_l2_input_matches_scalar_reference() {
let cases = [
[
-50000, -40000, -33000, -32768, -32000, -1000, 0, 64, 724, 8128, 8192, 8256, 20000,
32767, 40000, 50000,
],
[
-1, 1, 63, 127, 128, 255, 256, 4096, 8191, 8192, 16384, 24576, 32768, 40000, 65535,
70000,
],
];
for l1_out in cases {
let mut l1_relu = [0u8; LAYER_STACK_L1_OUT];
let mut l2_input_opt = Aligned([0u8; L2_PADDED_INPUT]);
let mut l2_sqr = [0u8; LAYER_STACK_L1_OUT];
ClippedReLU::<LAYER_STACK_L1_OUT>::propagate(&l1_out, &mut l1_relu);
sqr_clipped_relu_explicit::<LAYER_STACK_L1_OUT>(&l1_out, &mut l2_sqr);
l2_input_opt.0[..LAYER_STACK_L1_OUT].copy_from_slice(&l2_sqr);
l2_input_opt.0[NNUE_PYTORCH_L2..NNUE_PYTORCH_L2 + NNUE_PYTORCH_L2]
.copy_from_slice(&l1_relu[..NNUE_PYTORCH_L2]);
let mut l2_input_ref = Aligned([0u8; L2_PADDED_INPUT]);
for (i, &val) in l1_out.iter().enumerate().take(NNUE_PYTORCH_L2) {
let input_val = i64::from(val);
l2_input_ref.0[i] = ((input_val * input_val) >> 19).clamp(0, 127) as u8;
l2_input_ref.0[NNUE_PYTORCH_L2 + i] = (val >> 6).clamp(0, 127) as u8;
}
assert_eq!(
l2_input_opt.0, l2_input_ref.0,
"optimized l2_input must match scalar reference for l1_out={l1_out:?}"
);
}
}
#[test]
fn test_layer_stack_l2_relu_matches_scalar_reference() {
let input = [
-50000, -40000, -33000, -32768, -32000, -1000, -1, 0, 1, 63, 64, 127, 128, 255, 256,
4096, 8191, 8192, 16384, 24576, 32767, 32768, 40000, 50000, 65535, 70000, 80000, 90000,
100000, 110000, 120000, 130000,
];
let mut opt = [0u8; NNUE_PYTORCH_L3];
let mut reference = [0u8; NNUE_PYTORCH_L3];
ClippedReLU::<NNUE_PYTORCH_L3>::propagate(&input, &mut opt);
for (dst, &value) in reference.iter_mut().zip(input.iter()) {
*dst = (value >> 6).clamp(0, 127) as u8;
}
assert_eq!(opt, reference);
}
#[test]
fn test_layer_stack_bucket_propagate_matches_scalar_reference() {
fn affine_from_bytes<const INPUT_DIM: usize, const OUTPUT_DIM: usize>(
biases: [i32; OUTPUT_DIM],
weights: &[i8],
) -> AffineTransform<INPUT_DIM, OUTPUT_DIM> {
let mut bytes = Vec::with_capacity(OUTPUT_DIM * 4 + weights.len());
for bias in biases {
bytes.extend_from_slice(&bias.to_le_bytes());
}
for &weight in weights {
bytes.push(weight as u8);
}
AffineTransform::<INPUT_DIM, OUTPUT_DIM>::read(&mut &bytes[..]).unwrap()
}
fn scalar_reference(bucket: &LayerStackBucket, input: &[u8; NNUE_PYTORCH_L1]) -> i32 {
let mut l1_out = [0i32; LAYER_STACK_L1_OUT];
bucket.l1.propagate(input, &mut l1_out);
let l1_skip = l1_out[NNUE_PYTORCH_L2];
let mut l2_input = Aligned([0u8; L2_PADDED_INPUT]);
for (i, &val) in l1_out.iter().enumerate().take(NNUE_PYTORCH_L2) {
let input_val = i64::from(val);
l2_input.0[i] = ((input_val * input_val) >> 19).clamp(0, 127) as u8;
l2_input.0[NNUE_PYTORCH_L2 + i] = (val >> 6).clamp(0, 127) as u8;
}
let mut l2_out = [0i32; NNUE_PYTORCH_L3];
bucket.l2.propagate(&l2_input.0, &mut l2_out);
let mut l2_relu = Aligned([0u8; OUTPUT_PADDED_INPUT]);
for (dst, &val) in l2_relu.0.iter_mut().zip(l2_out.iter()) {
*dst = (val >> 6).clamp(0, 127) as u8;
}
let mut output_arr = [0i32; 1];
bucket.output.propagate(&l2_relu.0, &mut output_arr);
output_arr[0] + l1_skip
}
let l1_biases = [
-50000, -40000, -33000, -32768, -32000, -1000, 0, 64, 724, 8128, 8192, 8256, 20000,
32767, 40000, 50000,
];
let l1_weights = vec![0i8; LAYER_STACK_L1_OUT * NNUE_PYTORCH_L1];
let mut l2_biases = [0i32; NNUE_PYTORCH_L3];
for (i, bias) in l2_biases.iter_mut().enumerate() {
*bias = (i as i32 - 16) * 37;
}
let mut l2_weights = vec![0i8; NNUE_PYTORCH_L3 * L2_PADDED_INPUT];
for (i, weight) in l2_weights.iter_mut().enumerate() {
*weight = ((i as i32 % 7) - 3) as i8;
}
let output_biases = [123i32; 1];
let mut output_weights = vec![0i8; OUTPUT_PADDED_INPUT];
for (i, weight) in output_weights.iter_mut().enumerate() {
*weight = ((i as i32 % 5) - 2) as i8;
}
let bucket = LayerStackBucket {
l1: affine_from_bytes::<NNUE_PYTORCH_L1, LAYER_STACK_L1_OUT>(l1_biases, &l1_weights),
l2: affine_from_bytes::<LAYER_STACK_L2_IN, NNUE_PYTORCH_L3>(l2_biases, &l2_weights),
output: affine_from_bytes::<NNUE_PYTORCH_L3, 1>(output_biases, &output_weights),
};
let input = Aligned([0u8; NNUE_PYTORCH_L1]);
let mut l1_out = [0i32; LAYER_STACK_L1_OUT];
let mut l1_relu = [0u8; LAYER_STACK_L1_OUT];
let mut l2_input_opt = Aligned([0u8; L2_PADDED_INPUT]);
let mut l2_input_ref = Aligned([0u8; L2_PADDED_INPUT]);
let mut l2_sqr = [0u8; LAYER_STACK_L1_OUT];
let mut l2_out = [0i32; NNUE_PYTORCH_L3];
let mut l2_relu_opt = Aligned([0u8; OUTPUT_PADDED_INPUT]);
let mut l2_relu_ref = Aligned([0u8; OUTPUT_PADDED_INPUT]);
bucket.l1.propagate(&input.0, &mut l1_out);
ClippedReLU::<LAYER_STACK_L1_OUT>::propagate(&l1_out, &mut l1_relu);
sqr_clipped_relu_explicit::<LAYER_STACK_L1_OUT>(&l1_out, &mut l2_sqr);
l2_input_opt.0[..LAYER_STACK_L1_OUT].copy_from_slice(&l2_sqr);
l2_input_opt.0[NNUE_PYTORCH_L2..NNUE_PYTORCH_L2 + NNUE_PYTORCH_L2]
.copy_from_slice(&l1_relu[..NNUE_PYTORCH_L2]);
for (i, &val) in l1_out.iter().enumerate().take(NNUE_PYTORCH_L2) {
let input_val = i64::from(val);
l2_input_ref.0[i] = ((input_val * input_val) >> 19).clamp(0, 127) as u8;
l2_input_ref.0[NNUE_PYTORCH_L2 + i] = (val >> 6).clamp(0, 127) as u8;
}
assert_eq!(l2_input_opt.0, l2_input_ref.0);
bucket.l2.propagate(&l2_input_opt.0, &mut l2_out);
ClippedReLU::<NNUE_PYTORCH_L3>::propagate(&l2_out, &mut l2_relu_opt.0);
for (dst, &val) in l2_relu_ref.0.iter_mut().zip(l2_out.iter()) {
*dst = (val >> 6).clamp(0, 127) as u8;
}
assert_eq!(l2_relu_opt.0, l2_relu_ref.0);
let mut output_arr = [0i32; 1];
bucket.output.propagate(&l2_relu_opt.0, &mut output_arr);
let optimized_inline = output_arr[0] + l1_out[NNUE_PYTORCH_L2];
let optimized = bucket.propagate(&input.0);
let reference = scalar_reference(&bucket, &input.0);
assert_eq!(optimized_inline, reference);
assert_eq!(optimized, reference);
}
}