use super::accumulator::{Aligned, AlignedBox};
use super::accumulator::{DirtyPiece, IndexList, MAX_ACTIVE_FEATURES, MAX_CHANGED_FEATURES};
use super::accumulator_layer_stacks::{
AccumulatorCacheLayerStacks, AccumulatorLayerStacks, AccumulatorStackLayerStacks,
};
use super::constants::{HALFKA_HM_DIMENSIONS, NNUE_PYTORCH_L1};
use super::features::{Feature, FeatureSet, HalfKA_hm, HalfKA_hm_FeatureSet};
use super::leb128::read_compressed_tensor_i16_all;
use crate::position::Position;
use crate::types::Color;
use std::io::{self, Read};
use std::mem::MaybeUninit;
/// Panic path for an out-of-range feature index.
///
/// Kept `#[cold]`/`#[inline(never)]` so the hot `add_weights`/`sub_weights`
/// bounds checks compile down to a cheap compare-and-branch to a stub.
#[cold]
#[inline(never)]
fn feature_index_oob(idx: usize, limit: usize) -> ! {
    panic!("Feature index out of range: {} (max: {})", idx, limit)
}
/// Collects the feature indices removed and added by a single move's
/// `DirtyPiece` record, as seen from `side` with its king on `ksq`.
///
/// Thin monomorphized forwarder to the `HalfKA_hm` feature implementation.
#[inline]
fn append_changed_indices(
    dp: &DirtyPiece,
    side: Color,
    ksq: crate::types::Square,
    out_removed: &mut IndexList<MAX_CHANGED_FEATURES>,
    out_added: &mut IndexList<MAX_CHANGED_FEATURES>,
) {
    <HalfKA_hm as Feature>::append_changed_indices(dp, side, ksq, out_removed, out_added);
}
/// Collects all currently-active feature indices for `side` in `position`.
///
/// Thin monomorphized forwarder to the `HalfKA_hm` feature implementation.
#[inline]
fn append_active_indices(
    position: &Position,
    side: Color,
    out_active: &mut IndexList<MAX_ACTIVE_FEATURES>,
) {
    <HalfKA_hm as Feature>::append_active_indices(position, side, out_active);
}
/// Input feature transformer of the "layer stacks" NNUE network: maps active
/// HalfKA_hm feature indices into an `NNUE_PYTORCH_L1`-wide i16 accumulator.
/// Declared `align(64)` so the SIMD accumulate paths can use aligned
/// loads/stores on the bias array.
#[repr(C, align(64))]
pub struct FeatureTransformerLayerStacks {
    /// Per-output biases; every accumulator refresh starts from a copy of
    /// these (see `refresh_accumulator`).
    pub biases: Aligned<[i16; NNUE_PYTORCH_L1]>,
    /// Row-major weight matrix: one `NNUE_PYTORCH_L1`-wide row per feature
    /// index, `HALFKA_HM_DIMENSIONS` rows total (sized in `read`; rows
    /// addressed as `index * NNUE_PYTORCH_L1` in `add_weights`).
    pub weights: AlignedBox<i16>,
}
impl FeatureTransformerLayerStacks {
    /// Reads an uncompressed transformer from a little-endian byte stream:
    /// `NNUE_PYTORCH_L1` i16 biases followed by
    /// `HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1` i16 weights.
    ///
    /// # Errors
    /// Propagates any I/O error from `reader`, including `UnexpectedEof`
    /// when the stream is shorter than the expected tensor sizes.
    pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
        let mut biases = [0i16; NNUE_PYTORCH_L1];
        // 2-byte scratch buffer reused for every i16 read; callers should
        // hand in a buffered reader, since this issues one read per value.
        let mut buf = [0u8; 2];
        for bias in biases.iter_mut() {
            reader.read_exact(&mut buf)?;
            *bias = i16::from_le_bytes(buf);
        }
        let weight_size = HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1;
        let mut weights = AlignedBox::new_zeroed(weight_size);
        for weight in weights.iter_mut() {
            reader.read_exact(&mut buf)?;
            *weight = i16::from_le_bytes(buf);
        }
        Ok(Self {
            biases: Aligned(biases),
            weights,
        })
    }
    /// Reads a LEB128-compressed transformer. Two on-disk layouts are
    /// accepted:
    /// 1. one fused tensor holding biases immediately followed by weights
    ///    (`NNUE_PYTORCH_L1 + weight_size` values), or
    /// 2. two tensors: biases first, then weights.
    ///
    /// # Errors
    /// Returns `InvalidData` when a decoded tensor has an unexpected length,
    /// and propagates decoder I/O errors.
    pub fn read_leb128<R: Read>(reader: &mut R) -> io::Result<Self> {
        let weight_size = HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1;
        let total_size = NNUE_PYTORCH_L1 + weight_size;
        let first_block = read_compressed_tensor_i16_all(reader)?;
        // Layout 1: biases and weights fused into a single tensor.
        if first_block.len() == total_size {
            let mut biases = [0i16; NNUE_PYTORCH_L1];
            biases.copy_from_slice(&first_block[..NNUE_PYTORCH_L1]);
            let mut weights = AlignedBox::new_zeroed(weight_size);
            weights.copy_from_slice(&first_block[NNUE_PYTORCH_L1..]);
            return Ok(Self {
                biases: Aligned(biases),
                weights,
            });
        }
        // Layout 2: the first tensor is exactly the biases; the weights
        // follow as a second compressed tensor.
        if first_block.len() == NNUE_PYTORCH_L1 {
            let weights_block = read_compressed_tensor_i16_all(reader)?;
            if weights_block.len() != weight_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "FT weights block size mismatch: got {}, expected {}",
                        weights_block.len(),
                        weight_size
                    ),
                ));
            }
            let mut biases = [0i16; NNUE_PYTORCH_L1];
            biases.copy_from_slice(&first_block);
            let mut weights = AlignedBox::new_zeroed(weight_size);
            weights.copy_from_slice(&weights_block);
            return Ok(Self {
                biases: Aligned(biases),
                weights,
            });
        }
        // Neither layout matched: report both acceptable sizes.
        Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "Unexpected LEB128 tensor size: got {}, expected {} or {}",
                first_block.len(),
                NNUE_PYTORCH_L1,
                total_size
            ),
        ))
    }
    /// Rebuilds both perspectives of `acc` from scratch: start from the
    /// biases, then add the weight row of every active feature.
    /// Marks the accumulation computed and the cached score stale.
    pub fn refresh_accumulator(&self, pos: &Position, acc: &mut AccumulatorLayerStacks) {
        for perspective in [Color::Black, Color::White] {
            let p = perspective as usize;
            let accumulation = acc.get_mut(p);
            accumulation.copy_from_slice(&self.biases.0);
            let mut active_indices = IndexList::new();
            append_active_indices(pos, perspective, &mut active_indices);
            for &index in active_indices.iter() {
                self.add_weights(accumulation, index);
            }
        }
        acc.computed_accumulation = true;
        acc.computed_score = false;
    }
    /// Updates `acc` from `prev_acc` using the one-move `dirty_piece`
    /// deltas. A perspective falls back to a full rebuild when the feature
    /// set reports it needs a refresh (presumably when that side's king
    /// moved — see `HalfKA_hm_FeatureSet::needs_refresh`; confirm there).
    pub fn update_accumulator(
        &self,
        pos: &Position,
        dirty_piece: &DirtyPiece,
        acc: &mut AccumulatorLayerStacks,
        prev_acc: &AccumulatorLayerStacks,
    ) {
        for perspective in [Color::Black, Color::White] {
            let p = perspective as usize;
            let reset = HalfKA_hm_FeatureSet::needs_refresh(dirty_piece, perspective);
            if reset {
                // Full rebuild: biases + every active feature row.
                let accumulation = acc.get_mut(p);
                accumulation.copy_from_slice(&self.biases.0);
                let mut active_indices = IndexList::new();
                append_active_indices(pos, perspective, &mut active_indices);
                for &index in active_indices.iter() {
                    self.add_weights(accumulation, index);
                }
            } else {
                // Incremental: copy the previous accumulation and apply
                // the removed/added feature rows for this move only.
                let mut removed = IndexList::new();
                let mut added = IndexList::new();
                append_changed_indices(
                    dirty_piece,
                    perspective,
                    pos.king_square(perspective),
                    &mut removed,
                    &mut added,
                );
                let prev = prev_acc.get(p);
                let curr = acc.get_mut(p);
                curr.copy_from_slice(prev);
                for &index in removed.iter() {
                    self.sub_weights(curr, index);
                }
                for &index in added.iter() {
                    self.add_weights(curr, index);
                }
            }
        }
        acc.computed_accumulation = true;
        acc.computed_score = false;
    }
    /// Same as `update_accumulator`, but a perspective that needs a full
    /// rebuild goes through the king-bucket cache instead of recomputing
    /// from the biases.
    pub fn update_accumulator_with_cache(
        &self,
        pos: &Position,
        dirty_piece: &DirtyPiece,
        acc: &mut AccumulatorLayerStacks,
        prev_acc: &AccumulatorLayerStacks,
        cache: &mut AccumulatorCacheLayerStacks,
    ) {
        for perspective in [Color::Black, Color::White] {
            let p = perspective as usize;
            let reset = HalfKA_hm_FeatureSet::needs_refresh(dirty_piece, perspective);
            if reset {
                self.refresh_perspective_with_cache(pos, perspective, acc.get_mut(p), cache);
            } else {
                // Incremental path, identical to `update_accumulator`.
                let mut removed = IndexList::new();
                let mut added = IndexList::new();
                append_changed_indices(
                    dirty_piece,
                    perspective,
                    pos.king_square(perspective),
                    &mut removed,
                    &mut added,
                );
                let prev = prev_acc.get(p);
                let curr = acc.get_mut(p);
                curr.copy_from_slice(prev);
                for &index in removed.iter() {
                    self.sub_weights(curr, index);
                }
                for &index in added.iter() {
                    self.add_weights(curr, index);
                }
            }
        }
        acc.computed_accumulation = true;
        acc.computed_score = false;
    }
    /// Cache-assisted variant of `refresh_accumulator`: both perspectives
    /// are refreshed via `refresh_perspective_with_cache`.
    pub fn refresh_accumulator_with_cache(
        &self,
        pos: &Position,
        acc: &mut AccumulatorLayerStacks,
        cache: &mut AccumulatorCacheLayerStacks,
    ) {
        for perspective in [Color::Black, Color::White] {
            let p = perspective as usize;
            self.refresh_perspective_with_cache(pos, perspective, acc.get_mut(p), cache);
        }
        acc.computed_accumulation = true;
        acc.computed_score = false;
    }
    /// Refreshes one perspective through the accumulator cache: gathers the
    /// active feature indices, sorts them (the cache diffing expects a
    /// sorted list — confirm in `AccumulatorCacheLayerStacks`), and lets
    /// the cache either reuse a stored entry or rebuild and store one.
    fn refresh_perspective_with_cache(
        &self,
        pos: &Position,
        perspective: Color,
        accumulation: &mut [i16; NNUE_PYTORCH_L1],
        cache: &mut AccumulatorCacheLayerStacks,
    ) {
        let king_sq = pos.king_square(perspective);
        let mut active_indices = IndexList::new();
        append_active_indices(pos, perspective, &mut active_indices);
        // Stack-allocated scratch; only the first `len` slots get written.
        let mut sorted_buf = [const { MaybeUninit::<u32>::uninit() }; MAX_ACTIVE_FEATURES];
        let len = active_indices.len();
        for (slot, &idx) in sorted_buf[..len].iter_mut().zip(active_indices.iter()) {
            slot.write(idx as u32);
        }
        // SAFETY: the zip above initialized exactly the first `len`
        // elements, and `MaybeUninit<u32>` has the same layout as `u32`,
        // so viewing those `len` slots as `u32` is sound.
        let sorted =
            unsafe { std::slice::from_raw_parts_mut(sorted_buf.as_mut_ptr() as *mut u32, len) };
        sorted.sort_unstable();
        cache.refresh_or_cache(
            king_sq,
            perspective,
            sorted,
            &self.biases.0,
            accumulation,
            |acc, idx| self.add_weights(acc, idx),
            |acc, idx| self.sub_weights(acc, idx),
        );
    }
    /// Replays dirty-piece deltas from the accumulator at `source_idx`
    /// forward to the current stack entry. Returns `false` (doing nothing)
    /// when no path can be collected; the caller must then refresh instead.
    ///
    /// Uses the *current* king squares for index computation, which is only
    /// valid if no king moved along the path — debug-asserted per entry.
    pub fn forward_update_incremental(
        &self,
        pos: &Position,
        stack: &mut AccumulatorStackLayerStacks,
        source_idx: usize,
    ) -> bool {
        let Some(path) = stack.collect_path(source_idx) else {
            return false;
        };
        // Clone the source accumulator so the immutable borrow of `stack`
        // ends before we mutably borrow the current entry below.
        let source_acc = stack.entry_at(source_idx).accumulator.clone();
        {
            let current_acc = &mut stack.current_mut().accumulator;
            for perspective in [Color::Black, Color::White] {
                let p = perspective as usize;
                current_acc.get_mut(p).copy_from_slice(source_acc.get(p));
            }
        }
        // Apply each intermediate move's removed/added rows in order.
        for &entry_idx in path.iter() {
            let dirty_piece = stack.entry_at(entry_idx).dirty_piece;
            for perspective in [Color::Black, Color::White] {
                debug_assert!(
                    !dirty_piece.king_moved[perspective.index()],
                    "King moved between source and current"
                );
                let king_sq = pos.king_square(perspective);
                let mut removed = IndexList::new();
                let mut added = IndexList::new();
                append_changed_indices(
                    &dirty_piece,
                    perspective,
                    king_sq,
                    &mut removed,
                    &mut added,
                );
                let p = perspective as usize;
                let accumulation = stack.current_mut().accumulator.get_mut(p);
                for &index in removed.iter() {
                    self.sub_weights(accumulation, index);
                }
                for &index in added.iter() {
                    self.add_weights(accumulation, index);
                }
            }
        }
        stack.current_mut().accumulator.computed_accumulation = true;
        stack.current_mut().accumulator.computed_score = false;
        true
    }
    /// Adds weight row `index` into `accumulation`, lane-wise over all
    /// `NNUE_PYTORCH_L1` i16 values.
    ///
    /// # Panics
    /// Diverts to `feature_index_oob` when `index` does not name a complete
    /// row inside `self.weights` (overflowing or out-of-bounds offsets).
    #[inline]
    fn add_weights(&self, accumulation: &mut [i16; NNUE_PYTORCH_L1], index: usize) {
        // Bounds are fully validated up front so the SIMD paths below can
        // rely on `weights` being exactly one NNUE_PYTORCH_L1-long row.
        let Some(offset) = index.checked_mul(NNUE_PYTORCH_L1) else {
            feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
        };
        let Some(end) = offset.checked_add(NNUE_PYTORCH_L1) else {
            feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
        };
        if end > self.weights.len() {
            feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
        }
        let weights = &self.weights[offset..offset + NNUE_PYTORCH_L1];
        // AVX-512 path: 48 iterations x 32 i16 lanes = 1536 (NNUE_PYTORCH_L1).
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx512f",
            target_feature = "avx512bw"
        ))]
        {
            // SAFETY: in-bounds by the loop shape (48 * 32 == 1536) and
            // the bounds check above. The aligned load/store intrinsics
            // additionally require 64-byte alignment of both pointers:
            // `offset` is a multiple of 1536 i16 (3072 bytes) into an
            // `AlignedBox`, and `accumulation` is assumed to come from
            // 64-byte-aligned accumulator storage — confirm both in the
            // accumulator module.
            unsafe {
                use std::arch::x86_64::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..48 {
                    let acc_vec = _mm512_load_si512(acc_ptr.add(i * 32) as *const __m512i);
                    let weight_vec = _mm512_load_si512(weight_ptr.add(i * 32) as *const __m512i);
                    let result = _mm512_add_epi16(acc_vec, weight_vec);
                    _mm512_store_si512(acc_ptr.add(i * 32) as *mut __m512i, result);
                }
            }
            return;
        }
        // AVX2 path: 96 iterations x 16 i16 lanes = 1536.
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            not(target_feature = "avx512bw")
        ))]
        {
            // SAFETY: in-bounds (96 * 16 == 1536); aligned intrinsics need
            // 32-byte alignment, satisfied as in the AVX-512 note above.
            unsafe {
                use std::arch::x86_64::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..96 {
                    let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
                    let weight_vec = _mm256_load_si256(weight_ptr.add(i * 16) as *const __m256i);
                    let result = _mm256_add_epi16(acc_vec, weight_vec);
                    _mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
                }
            }
            return;
        }
        // SSE2 path: 192 iterations x 8 i16 lanes = 1536.
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse2",
            not(target_feature = "avx2")
        ))]
        {
            // SAFETY: in-bounds (192 * 8 == 1536); aligned intrinsics need
            // 16-byte alignment, satisfied as in the AVX-512 note above.
            unsafe {
                use std::arch::x86_64::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..192 {
                    let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
                    let weight_vec = _mm_load_si128(weight_ptr.add(i * 8) as *const __m128i);
                    let result = _mm_add_epi16(acc_vec, weight_vec);
                    _mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
                }
            }
            return;
        }
        // WebAssembly SIMD path: 192 iterations x 8 i16 lanes = 1536.
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            // SAFETY: in-bounds (192 * 8 == 1536); wasm v128 loads/stores
            // have no alignment requirement.
            unsafe {
                use std::arch::wasm32::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..192 {
                    let acc_vec = v128_load(acc_ptr.add(i * 8) as *const v128);
                    let weight_vec = v128_load(weight_ptr.add(i * 8) as *const v128);
                    let result = i16x8_add(acc_vec, weight_vec);
                    v128_store(acc_ptr.add(i * 8) as *mut v128, result);
                }
            }
            return;
        }
        // Scalar fallback (the only path on targets without any SIMD cfg
        // above). Wrapping adds match the modular i16 lane arithmetic of
        // the `add_epi16`/`i16x8_add` intrinsics.
        #[allow(unreachable_code)]
        for (acc, &weight) in accumulation.iter_mut().zip(weights) {
            *acc = acc.wrapping_add(weight);
        }
    }
    /// Subtracts weight row `index` from `accumulation` — exact mirror of
    /// `add_weights` with subtraction; see that method for the bounds and
    /// alignment reasoning behind each SIMD path.
    ///
    /// # Panics
    /// Diverts to `feature_index_oob` on an out-of-range `index`.
    #[inline]
    fn sub_weights(&self, accumulation: &mut [i16; NNUE_PYTORCH_L1], index: usize) {
        let Some(offset) = index.checked_mul(NNUE_PYTORCH_L1) else {
            feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
        };
        let Some(end) = offset.checked_add(NNUE_PYTORCH_L1) else {
            feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
        };
        if end > self.weights.len() {
            feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
        }
        let weights = &self.weights[offset..offset + NNUE_PYTORCH_L1];
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx512f",
            target_feature = "avx512bw"
        ))]
        {
            // SAFETY: same in-bounds and 64-byte-alignment argument as the
            // AVX-512 block in `add_weights`.
            unsafe {
                use std::arch::x86_64::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..48 {
                    let acc_vec = _mm512_load_si512(acc_ptr.add(i * 32) as *const __m512i);
                    let weight_vec = _mm512_load_si512(weight_ptr.add(i * 32) as *const __m512i);
                    let result = _mm512_sub_epi16(acc_vec, weight_vec);
                    _mm512_store_si512(acc_ptr.add(i * 32) as *mut __m512i, result);
                }
            }
            return;
        }
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "avx2",
            not(target_feature = "avx512bw")
        ))]
        {
            // SAFETY: same argument as the AVX2 block in `add_weights`.
            unsafe {
                use std::arch::x86_64::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..96 {
                    let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
                    let weight_vec = _mm256_load_si256(weight_ptr.add(i * 16) as *const __m256i);
                    let result = _mm256_sub_epi16(acc_vec, weight_vec);
                    _mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
                }
            }
            return;
        }
        #[cfg(all(
            target_arch = "x86_64",
            target_feature = "sse2",
            not(target_feature = "avx2")
        ))]
        {
            // SAFETY: same argument as the SSE2 block in `add_weights`.
            unsafe {
                use std::arch::x86_64::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..192 {
                    let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
                    let weight_vec = _mm_load_si128(weight_ptr.add(i * 8) as *const __m128i);
                    let result = _mm_sub_epi16(acc_vec, weight_vec);
                    _mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
                }
            }
            return;
        }
        #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
        {
            // SAFETY: same argument as the wasm block in `add_weights`.
            unsafe {
                use std::arch::wasm32::*;
                let acc_ptr = accumulation.as_mut_ptr();
                let weight_ptr = weights.as_ptr();
                for i in 0..192 {
                    let acc_vec = v128_load(acc_ptr.add(i * 8) as *const v128);
                    let weight_vec = v128_load(weight_ptr.add(i * 8) as *const v128);
                    let result = i16x8_sub(acc_vec, weight_vec);
                    v128_store(acc_ptr.add(i * 8) as *mut v128, result);
                }
            }
            return;
        }
        // Scalar fallback; wrapping subtraction mirrors the SIMD lanes.
        #[allow(unreachable_code)]
        for (acc, &weight) in accumulation.iter_mut().zip(weights) {
            *acc = acc.wrapping_sub(weight);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pin the network dimensions that this file's loaders and SIMD loop
    /// counts are written against.
    #[test]
    fn test_feature_transformer_dimensions() {
        // HalfKA_hm input features, then the L1 accumulator width.
        assert_eq!(HALFKA_HM_DIMENSIONS, 73305);
        assert_eq!(NNUE_PYTORCH_L1, 1536);
    }
}