use super::accumulator::{Aligned, AlignedBox};
use super::accumulator::{DirtyPiece, IndexList, MAX_ACTIVE_FEATURES, MAX_CHANGED_FEATURES};
use super::accumulator_layer_stacks::{
AccumulatorCacheLayerStacks, AccumulatorLayerStacks, AccumulatorStackLayerStacks,
};
use super::bona_piece::BonaPiece;
use super::bona_piece_halfka_hm::{halfka_index, is_hm_mirror, king_bucket, pack_bonapiece};
use super::constants::{HALFKA_HM_DIMENSIONS, NNUE_PYTORCH_L1};
use super::features::{Feature, FeatureSet, HalfKA_hm, HalfKA_hm_FeatureSet};
use super::leb128::read_compressed_tensor_i16_all;
use crate::position::Position;
use crate::types::Color;
use std::io::{self, Read};
use std::mem::MaybeUninit;
/// Cold panic path for an out-of-range feature index.
///
/// Kept out of line (`#[cold]`, `#[inline(never)]`) so the bounds checks in
/// `weight_row` stay cheap on the hot path.
#[cold]
#[inline(never)]
fn feature_index_oob(index: usize, max: usize) -> ! {
    panic!("Feature index out of range: {} (max: {})", index, max)
}
/// Collects the feature indices removed/added by `dirty_piece` for the given
/// perspective and king square, delegating to the `HalfKA_hm` feature type.
#[inline]
fn append_changed_indices(
    dirty_piece: &DirtyPiece,
    perspective: Color,
    king_sq: crate::types::Square,
    removed: &mut IndexList<MAX_CHANGED_FEATURES>,
    added: &mut IndexList<MAX_CHANGED_FEATURES>,
) {
    <HalfKA_hm as Feature>::append_changed_indices(dirty_piece, perspective, king_sq, removed, added);
}
/// Collects every currently-active feature index for `perspective`,
/// delegating to the `HalfKA_hm` feature type.
#[inline]
fn append_active_indices(
    pos: &Position,
    perspective: Color,
    active: &mut IndexList<MAX_ACTIVE_FEATURES>,
) {
    <HalfKA_hm as Feature>::append_active_indices(pos, perspective, active)
}
/// Maps a `BonaPiece` to its HalfKA_hm feature index relative to the
/// perspective's king: the piece is packed with the horizontal-mirror flag
/// for that king square and combined with the king bucket.
#[inline]
fn feature_index_from_bona_piece(
    bp: BonaPiece,
    perspective: Color,
    king_sq: crate::types::Square,
) -> usize {
    let mirrored = is_hm_mirror(king_sq, perspective);
    halfka_index(
        king_bucket(king_sq, perspective),
        pack_bonapiece(bp, mirrored),
    )
}
/// Input (feature-transformer) layer of the layer-stacks NNUE network:
/// per-output biases plus one weight row per HalfKA_hm feature.
#[repr(C, align(64))]
pub struct FeatureTransformerLayerStacks {
    /// Bias vector; copied into an accumulator before feature rows are added
    /// (see `refresh_accumulator`).
    pub biases: Aligned<[i16; NNUE_PYTORCH_L1]>,
    /// Row-major weights: `HALFKA_HM_DIMENSIONS` rows of `NNUE_PYTORCH_L1`
    /// i16s each (see `weight_row`). NOTE(review): the aligned SIMD
    /// loads in `add_weights`/`sub_weights` assume this storage is
    /// 64-byte aligned — presumably guaranteed by `AlignedBox`; confirm.
    pub weights: AlignedBox<i16>,
}
impl FeatureTransformerLayerStacks {
/// Reads the transformer parameters from an uncompressed stream:
/// `NNUE_PYTORCH_L1` bias values followed by
/// `HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1` weight values, each a
/// little-endian `i16`.
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
    // Pulls one little-endian i16 off the stream.
    fn next_i16<R: Read>(r: &mut R) -> io::Result<i16> {
        let mut bytes = [0u8; 2];
        r.read_exact(&mut bytes)?;
        Ok(i16::from_le_bytes(bytes))
    }
    let mut biases = [0i16; NNUE_PYTORCH_L1];
    for slot in biases.iter_mut() {
        *slot = next_i16(reader)?;
    }
    let mut weights = AlignedBox::new_zeroed(HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1);
    for slot in weights.iter_mut() {
        *slot = next_i16(reader)?;
    }
    Ok(Self {
        biases: Aligned(biases),
        weights,
    })
}
/// Reads the transformer parameters from a LEB128-compressed stream.
///
/// Two on-disk layouts are accepted:
/// * one tensor holding biases immediately followed by weights, or
/// * two tensors, biases first and weights second.
///
/// Any other tensor size is rejected with `InvalidData`.
pub fn read_leb128<R: Read>(reader: &mut R) -> io::Result<Self> {
    let weight_size = HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1;
    let total_size = NNUE_PYTORCH_L1 + weight_size;
    let first_block = read_compressed_tensor_i16_all(reader)?;
    // Assembles the struct from already-validated bias/weight slices.
    let build = |bias_src: &[i16], weight_src: &[i16]| {
        let mut biases = [0i16; NNUE_PYTORCH_L1];
        biases.copy_from_slice(bias_src);
        let mut weights = AlignedBox::new_zeroed(weight_size);
        weights.copy_from_slice(weight_src);
        Self {
            biases: Aligned(biases),
            weights,
        }
    };
    match first_block.len() {
        // Combined layout: biases then weights in a single tensor.
        n if n == total_size => Ok(build(
            &first_block[..NNUE_PYTORCH_L1],
            &first_block[NNUE_PYTORCH_L1..],
        )),
        // Split layout: this tensor is the biases; weights follow.
        n if n == NNUE_PYTORCH_L1 => {
            let weights_block = read_compressed_tensor_i16_all(reader)?;
            if weights_block.len() != weight_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "FT weights block size mismatch: got {}, expected {}",
                        weights_block.len(),
                        weight_size
                    ),
                ));
            }
            Ok(build(&first_block, &weights_block))
        }
        n => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "Unexpected LEB128 tensor size: got {}, expected {} or {}",
                n, NNUE_PYTORCH_L1, total_size
            ),
        )),
    }
}
/// Rebuilds both perspectives of `acc` from scratch: biases plus one
/// weight row per active feature of the position.
pub fn refresh_accumulator(&self, pos: &Position, acc: &mut AccumulatorLayerStacks) {
    for perspective in [Color::Black, Color::White] {
        let side = perspective as usize;
        let accumulation = acc.get_mut(side);
        // Start from the bias vector, then add every active feature row.
        accumulation.copy_from_slice(&self.biases.0);
        let mut active = IndexList::new();
        append_active_indices(pos, perspective, &mut active);
        active
            .iter()
            .for_each(|index| self.add_weights(accumulation, index));
    }
    acc.computed_accumulation = true;
    acc.computed_score = false;
}
/// Incrementally updates `acc` from `prev_acc` using the move described by
/// `dirty_piece`, falling back to a full rebuild for a perspective whose
/// feature set requires a refresh (e.g. its king moved).
///
/// Fix: the `removed`/`added` changed-index lists used to be computed
/// unconditionally, even when `try_apply_dirty_piece_fast` then handled the
/// update without them; they are now built only on the fallback path.
pub fn update_accumulator(
    &self,
    pos: &Position,
    dirty_piece: &DirtyPiece,
    acc: &mut AccumulatorLayerStacks,
    prev_acc: &AccumulatorLayerStacks,
) {
    for perspective in [Color::Black, Color::White] {
        let p = perspective as usize;
        if HalfKA_hm_FeatureSet::needs_refresh(dirty_piece, perspective) {
            // Full rebuild: biases plus every active feature row.
            let accumulation = acc.get_mut(p);
            accumulation.copy_from_slice(&self.biases.0);
            let mut active_indices = IndexList::new();
            append_active_indices(pos, perspective, &mut active_indices);
            for index in active_indices.iter() {
                self.add_weights(accumulation, index);
            }
        } else {
            let king_sq = pos.king_square(perspective);
            let curr = acc.get_mut(p);
            curr.copy_from_slice(prev_acc.get(p));
            // Try the fused 1- or 2-change fast path first; only compute
            // the full changed-index lists when it declines.
            if !self.try_apply_dirty_piece_fast(curr, dirty_piece, perspective, king_sq) {
                let mut removed = IndexList::new();
                let mut added = IndexList::new();
                append_changed_indices(
                    dirty_piece,
                    perspective,
                    king_sq,
                    &mut removed,
                    &mut added,
                );
                for index in removed.iter() {
                    self.sub_weights(curr, index);
                }
                for index in added.iter() {
                    self.add_weights(curr, index);
                }
            }
        }
    }
    acc.computed_accumulation = true;
    acc.computed_score = false;
}
/// Like `update_accumulator`, but routes a full refresh through the
/// king-bucket accumulator cache instead of rebuilding from the biases.
///
/// Fix: the `removed`/`added` changed-index lists used to be computed
/// unconditionally, even when `try_apply_dirty_piece_fast` then handled the
/// update without them; they are now built only on the fallback path.
pub fn update_accumulator_with_cache(
    &self,
    pos: &Position,
    dirty_piece: &DirtyPiece,
    acc: &mut AccumulatorLayerStacks,
    prev_acc: &AccumulatorLayerStacks,
    cache: &mut AccumulatorCacheLayerStacks,
) {
    for perspective in [Color::Black, Color::White] {
        let p = perspective as usize;
        if HalfKA_hm_FeatureSet::needs_refresh(dirty_piece, perspective) {
            self.refresh_perspective_with_cache(pos, perspective, acc.get_mut(p), cache);
        } else {
            let king_sq = pos.king_square(perspective);
            let curr = acc.get_mut(p);
            curr.copy_from_slice(prev_acc.get(p));
            // Try the fused 1- or 2-change fast path first; only compute
            // the full changed-index lists when it declines.
            if !self.try_apply_dirty_piece_fast(curr, dirty_piece, perspective, king_sq) {
                let mut removed = IndexList::new();
                let mut added = IndexList::new();
                append_changed_indices(
                    dirty_piece,
                    perspective,
                    king_sq,
                    &mut removed,
                    &mut added,
                );
                for index in removed.iter() {
                    self.sub_weights(curr, index);
                }
                for index in added.iter() {
                    self.add_weights(curr, index);
                }
            }
        }
    }
    acc.computed_accumulation = true;
    acc.computed_score = false;
}
/// Rebuilds both perspectives of `acc`, letting the king-bucket cache turn
/// each rebuild into a diff against the cached entry when possible.
pub fn refresh_accumulator_with_cache(
    &self,
    pos: &Position,
    acc: &mut AccumulatorLayerStacks,
    cache: &mut AccumulatorCacheLayerStacks,
) {
    for perspective in [Color::Black, Color::White] {
        let side = perspective as usize;
        self.refresh_perspective_with_cache(pos, perspective, acc.get_mut(side), cache);
    }
    acc.computed_accumulation = true;
    acc.computed_score = false;
}
/// Refreshes one perspective of an accumulator via the cache: the active
/// feature indices are gathered, sorted, and handed to
/// `AccumulatorCacheLayerStacks::refresh_or_cache`, which either diffs
/// against a cached entry or rebuilds and stores a new one.
///
/// Fix: replaced the `MaybeUninit` buffer + `from_raw_parts_mut` unsafe
/// block with a plain zero-initialized array — zeroing a
/// `MAX_ACTIVE_FEATURES`-sized stack buffer is negligible and the safe
/// slice is behaviorally identical.
fn refresh_perspective_with_cache(
    &self,
    pos: &Position,
    perspective: Color,
    accumulation: &mut [i16; NNUE_PYTORCH_L1],
    cache: &mut AccumulatorCacheLayerStacks,
) {
    let king_sq = pos.king_square(perspective);
    let mut active_indices = IndexList::new();
    append_active_indices(pos, perspective, &mut active_indices);
    // The cache stores feature sets in sorted order, so sort before the diff.
    let mut sorted_buf = [0u32; MAX_ACTIVE_FEATURES];
    let len = active_indices.len();
    for (slot, idx) in sorted_buf[..len].iter_mut().zip(active_indices.iter()) {
        *slot = idx as u32;
    }
    let sorted = &mut sorted_buf[..len];
    sorted.sort_unstable();
    cache.refresh_or_cache(
        king_sq,
        perspective,
        sorted,
        &self.biases.0,
        accumulation,
        |acc, idx| self.add_weights(acc, idx),
        |acc, idx| self.sub_weights(acc, idx),
    );
}
/// Brings the stack's current accumulator up to date by replaying every
/// dirty-piece delta on the path from `source_idx` to the top.
///
/// Returns `false` (leaving the current accumulator untouched) when no
/// usable path exists. Callers must guarantee no king moved along the path
/// (debug-asserted), since the king square is read from the *current*
/// position for every replayed delta.
///
/// Fix: the `removed`/`added` changed-index lists used to be computed for
/// every entry even when `try_apply_dirty_piece_fast` then handled the
/// delta without them; they are now built only on the fallback path.
pub fn forward_update_incremental(
    &self,
    pos: &Position,
    stack: &mut AccumulatorStackLayerStacks,
    source_idx: usize,
) -> bool {
    let Some(path) = stack.collect_path(source_idx) else {
        return false;
    };
    // Seed the current accumulator from the source entry.
    let source_acc = stack.entry_at(source_idx).accumulator.clone();
    {
        let current_acc = &mut stack.current_mut().accumulator;
        for perspective in [Color::Black, Color::White] {
            let p = perspective as usize;
            current_acc.get_mut(p).copy_from_slice(source_acc.get(p));
        }
    }
    // Replay each delta along the path, oldest first.
    for entry_idx in path.iter() {
        let dirty_piece = stack.entry_at(entry_idx).dirty_piece;
        for perspective in [Color::Black, Color::White] {
            debug_assert!(
                !dirty_piece.king_moved[perspective.index()],
                "King moved between source and current"
            );
            let king_sq = pos.king_square(perspective);
            let p = perspective as usize;
            let accumulation = stack.current_mut().accumulator.get_mut(p);
            // Try the fused 1- or 2-change fast path first; only compute
            // the full changed-index lists when it declines.
            if !self.try_apply_dirty_piece_fast(accumulation, &dirty_piece, perspective, king_sq) {
                let mut removed = IndexList::new();
                let mut added = IndexList::new();
                append_changed_indices(
                    &dirty_piece,
                    perspective,
                    king_sq,
                    &mut removed,
                    &mut added,
                );
                for index in removed.iter() {
                    self.sub_weights(accumulation, index);
                }
                for index in added.iter() {
                    self.add_weights(accumulation, index);
                }
            }
        }
    }
    stack.current_mut().accumulator.computed_accumulation = true;
    stack.current_mut().accumulator.computed_score = false;
    true
}
/// Adds weight row `index` element-wise into `accumulation`
/// (`acc[i] += row[i]` for all `NNUE_PYTORCH_L1` lanes; the scalar
/// fallback uses wrapping i16 arithmetic).
///
/// Exactly one SIMD path is compiled in, selected by target features; the
/// chunk counts are hard-coded for NNUE_PYTORCH_L1 == 1536
/// (48×32 i16 AVX-512, 96×16 AVX2, 192×8 SSE2/wasm-simd128).
/// NOTE(review): if NNUE_PYTORCH_L1 ever changes, these counts must change
/// with it.
#[inline]
fn add_weights(&self, accumulation: &mut [i16; NNUE_PYTORCH_L1], index: usize) {
    let weights = self.weight_row(index);
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx512f",
        target_feature = "avx512bw"
    ))]
    {
        // SAFETY: both buffers hold exactly 1536 i16s (48 chunks of 32) and
        // the aligned load/store assumes 64-byte alignment of the
        // `Aligned`/`AlignedBox` storage — confirm if storage changes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..48 {
                let acc_vec = _mm512_load_si512(acc_ptr.add(i * 32) as *const __m512i);
                let weight_vec = _mm512_load_si512(weight_ptr.add(i * 32) as *const __m512i);
                let result = _mm512_add_epi16(acc_vec, weight_vec);
                _mm512_store_si512(acc_ptr.add(i * 32) as *mut __m512i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx2",
        not(target_feature = "avx512bw")
    ))]
    {
        // SAFETY: 96 chunks of 16 i16 cover exactly 1536 lanes; aligned
        // access as above.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..96 {
                let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
                let weight_vec = _mm256_load_si256(weight_ptr.add(i * 16) as *const __m256i);
                let result = _mm256_add_epi16(acc_vec, weight_vec);
                _mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "sse2",
        not(target_feature = "avx2")
    ))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes; aligned
        // access as above.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..192 {
                let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
                let weight_vec = _mm_load_si128(weight_ptr.add(i * 8) as *const __m128i);
                let result = _mm_add_epi16(acc_vec, weight_vec);
                _mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
            }
        }
        return;
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes; v128 loads
        // are unaligned-tolerant on wasm.
        unsafe {
            use std::arch::wasm32::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..192 {
                let acc_vec = v128_load(acc_ptr.add(i * 8) as *const v128);
                let weight_vec = v128_load(weight_ptr.add(i * 8) as *const v128);
                let result = i16x8_add(acc_vec, weight_vec);
                v128_store(acc_ptr.add(i * 8) as *mut v128, result);
            }
        }
        return;
    }
    // Portable scalar fallback — also the reference semantics for the tests.
    #[allow(unreachable_code)]
    for (acc, &weight) in accumulation.iter_mut().zip(weights) {
        *acc = acc.wrapping_add(weight);
    }
}
#[inline]
fn weight_row(&self, index: usize) -> &[i16] {
let Some(offset) = index.checked_mul(NNUE_PYTORCH_L1) else {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
};
let Some(end) = offset.checked_add(NNUE_PYTORCH_L1) else {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
};
if end > self.weights.len() {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
}
&self.weights[offset..end]
}
/// Attempts the fused update path for the common dirty-piece shapes:
/// one changed piece (quiet move) or two (capture), with every involved
/// `BonaPiece` on the board/hand (non-ZERO). Returns `false` without
/// touching `accumulation` when the shape doesn't qualify, so the caller
/// must fall back to the generic removed/added-list update.
#[inline]
fn try_apply_dirty_piece_fast(
    &self,
    accumulation: &mut [i16; NNUE_PYTORCH_L1],
    dirty_piece: &DirtyPiece,
    perspective: Color,
    king_sq: crate::types::Square,
) -> bool {
    // (old, new) BonaPieces of change slot `idx`, seen from `perspective`.
    let pick = |idx: usize| {
        let cp = &dirty_piece.changed_piece[idx];
        if perspective == Color::Black {
            (cp.old_piece.fb, cp.new_piece.fb)
        } else {
            (cp.old_piece.fw, cp.new_piece.fw)
        }
    };
    let feat = |bp| feature_index_from_bona_piece(bp, perspective, king_sq);
    match dirty_piece.dirty_num as usize {
        1 => {
            let (old0, new0) = pick(0);
            if old0 == BonaPiece::ZERO || new0 == BonaPiece::ZERO {
                return false;
            }
            self.apply_sub_add_fused(accumulation, feat(old0), feat(new0));
            true
        }
        2 => {
            let (old0, new0) = pick(0);
            let (old1, new1) = pick(1);
            if [old0, new0, old1, new1].contains(&BonaPiece::ZERO) {
                return false;
            }
            self.apply_double_sub_add_fused(
                accumulation,
                feat(old0),
                feat(new0),
                feat(old1),
                feat(new1),
            );
            true
        }
        _ => false,
    }
}
/// Fused single-piece update: `acc[i] = acc[i] - row[sub][i] + row[add][i]`
/// in one pass, saving a full sweep versus calling `sub_weights` then
/// `add_weights`. Same SIMD-path selection and 1536-lane chunking as
/// `add_weights`.
#[inline]
fn apply_sub_add_fused(
    &self,
    accumulation: &mut [i16; NNUE_PYTORCH_L1],
    sub_index: usize,
    add_index: usize,
) {
    let sub_weights = self.weight_row(sub_index);
    let add_weights = self.weight_row(add_index);
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx512f",
        target_feature = "avx512bw"
    ))]
    {
        // SAFETY: all three buffers hold exactly 1536 i16s (48 chunks of
        // 32); aligned access relies on the Aligned/AlignedBox storage.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr = sub_weights.as_ptr();
            let add_ptr = add_weights.as_ptr();
            for i in 0..48 {
                let acc_vec = _mm512_load_si512(acc_ptr.add(i * 32) as *const __m512i);
                let sub_vec = _mm512_load_si512(sub_ptr.add(i * 32) as *const __m512i);
                let add_vec = _mm512_load_si512(add_ptr.add(i * 32) as *const __m512i);
                let result = _mm512_add_epi16(_mm512_sub_epi16(acc_vec, sub_vec), add_vec);
                _mm512_store_si512(acc_ptr.add(i * 32) as *mut __m512i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx2",
        not(target_feature = "avx512bw")
    ))]
    {
        // SAFETY: 96 chunks of 16 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr = sub_weights.as_ptr();
            let add_ptr = add_weights.as_ptr();
            for i in 0..96 {
                let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
                let sub_vec = _mm256_load_si256(sub_ptr.add(i * 16) as *const __m256i);
                let add_vec = _mm256_load_si256(add_ptr.add(i * 16) as *const __m256i);
                let result = _mm256_add_epi16(_mm256_sub_epi16(acc_vec, sub_vec), add_vec);
                _mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "sse2",
        not(target_feature = "avx2")
    ))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr = sub_weights.as_ptr();
            let add_ptr = add_weights.as_ptr();
            for i in 0..192 {
                let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
                let sub_vec = _mm_load_si128(sub_ptr.add(i * 8) as *const __m128i);
                let add_vec = _mm_load_si128(add_ptr.add(i * 8) as *const __m128i);
                let result = _mm_add_epi16(_mm_sub_epi16(acc_vec, sub_vec), add_vec);
                _mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
            }
        }
        return;
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::wasm32::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr = sub_weights.as_ptr();
            let add_ptr = add_weights.as_ptr();
            for i in 0..192 {
                let acc_vec = v128_load(acc_ptr.add(i * 8) as *const v128);
                let sub_vec = v128_load(sub_ptr.add(i * 8) as *const v128);
                let add_vec = v128_load(add_ptr.add(i * 8) as *const v128);
                let result = i16x8_add(i16x8_sub(acc_vec, sub_vec), add_vec);
                v128_store(acc_ptr.add(i * 8) as *mut v128, result);
            }
        }
        return;
    }
    // Portable scalar fallback with explicit wrapping semantics.
    #[allow(unreachable_code)]
    for ((acc, &sub_weight), &add_weight) in
        accumulation.iter_mut().zip(sub_weights.iter()).zip(add_weights.iter())
    {
        *acc = acc.wrapping_sub(sub_weight).wrapping_add(add_weight);
    }
}
/// Fused two-piece (capture) update:
/// `acc[i] = acc[i] - row[sub0][i] + row[add0][i] - row[sub1][i] + row[add1][i]`
/// in a single pass over the 1536 lanes. Same SIMD-path selection and
/// chunking as `add_weights`.
#[inline]
fn apply_double_sub_add_fused(
    &self,
    accumulation: &mut [i16; NNUE_PYTORCH_L1],
    sub_index0: usize,
    add_index0: usize,
    sub_index1: usize,
    add_index1: usize,
) {
    let sub_weights0 = self.weight_row(sub_index0);
    let add_weights0 = self.weight_row(add_index0);
    let sub_weights1 = self.weight_row(sub_index1);
    let add_weights1 = self.weight_row(add_index1);
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx512f",
        target_feature = "avx512bw"
    ))]
    {
        // SAFETY: all five buffers hold exactly 1536 i16s (48 chunks of
        // 32); aligned access relies on the Aligned/AlignedBox storage.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr0 = sub_weights0.as_ptr();
            let add_ptr0 = add_weights0.as_ptr();
            let sub_ptr1 = sub_weights1.as_ptr();
            let add_ptr1 = add_weights1.as_ptr();
            for i in 0..48 {
                let acc_vec = _mm512_load_si512(acc_ptr.add(i * 32) as *const __m512i);
                let sub_vec0 = _mm512_load_si512(sub_ptr0.add(i * 32) as *const __m512i);
                let add_vec0 = _mm512_load_si512(add_ptr0.add(i * 32) as *const __m512i);
                let sub_vec1 = _mm512_load_si512(sub_ptr1.add(i * 32) as *const __m512i);
                let add_vec1 = _mm512_load_si512(add_ptr1.add(i * 32) as *const __m512i);
                // (acc - sub0 + add0) + (add1 - sub1): wrapping i16 adds are
                // associative/commutative, so this matches the scalar order.
                let result = _mm512_add_epi16(
                    _mm512_add_epi16(_mm512_sub_epi16(acc_vec, sub_vec0), add_vec0),
                    _mm512_sub_epi16(add_vec1, sub_vec1),
                );
                _mm512_store_si512(acc_ptr.add(i * 32) as *mut __m512i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx2",
        not(target_feature = "avx512bw")
    ))]
    {
        // SAFETY: 96 chunks of 16 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr0 = sub_weights0.as_ptr();
            let add_ptr0 = add_weights0.as_ptr();
            let sub_ptr1 = sub_weights1.as_ptr();
            let add_ptr1 = add_weights1.as_ptr();
            for i in 0..96 {
                let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
                let sub_vec0 = _mm256_load_si256(sub_ptr0.add(i * 16) as *const __m256i);
                let add_vec0 = _mm256_load_si256(add_ptr0.add(i * 16) as *const __m256i);
                let sub_vec1 = _mm256_load_si256(sub_ptr1.add(i * 16) as *const __m256i);
                let add_vec1 = _mm256_load_si256(add_ptr1.add(i * 16) as *const __m256i);
                let result = _mm256_add_epi16(
                    _mm256_add_epi16(_mm256_sub_epi16(acc_vec, sub_vec0), add_vec0),
                    _mm256_sub_epi16(add_vec1, sub_vec1),
                );
                _mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "sse2",
        not(target_feature = "avx2")
    ))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr0 = sub_weights0.as_ptr();
            let add_ptr0 = add_weights0.as_ptr();
            let sub_ptr1 = sub_weights1.as_ptr();
            let add_ptr1 = add_weights1.as_ptr();
            for i in 0..192 {
                let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
                let sub_vec0 = _mm_load_si128(sub_ptr0.add(i * 8) as *const __m128i);
                let add_vec0 = _mm_load_si128(add_ptr0.add(i * 8) as *const __m128i);
                let sub_vec1 = _mm_load_si128(sub_ptr1.add(i * 8) as *const __m128i);
                let add_vec1 = _mm_load_si128(add_ptr1.add(i * 8) as *const __m128i);
                let result = _mm_add_epi16(
                    _mm_add_epi16(_mm_sub_epi16(acc_vec, sub_vec0), add_vec0),
                    _mm_sub_epi16(add_vec1, sub_vec1),
                );
                _mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
            }
        }
        return;
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::wasm32::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let sub_ptr0 = sub_weights0.as_ptr();
            let add_ptr0 = add_weights0.as_ptr();
            let sub_ptr1 = sub_weights1.as_ptr();
            let add_ptr1 = add_weights1.as_ptr();
            for i in 0..192 {
                let acc_vec = v128_load(acc_ptr.add(i * 8) as *const v128);
                let sub_vec0 = v128_load(sub_ptr0.add(i * 8) as *const v128);
                let add_vec0 = v128_load(add_ptr0.add(i * 8) as *const v128);
                let sub_vec1 = v128_load(sub_ptr1.add(i * 8) as *const v128);
                let add_vec1 = v128_load(add_ptr1.add(i * 8) as *const v128);
                let result = i16x8_add(
                    i16x8_add(i16x8_sub(acc_vec, sub_vec0), add_vec0),
                    i16x8_sub(add_vec1, sub_vec1),
                );
                v128_store(acc_ptr.add(i * 8) as *mut v128, result);
            }
        }
        return;
    }
    // Portable scalar fallback with explicit wrapping semantics.
    #[allow(unreachable_code)]
    for ((((acc, &sub_weight0), &add_weight0), &sub_weight1), &add_weight1) in accumulation
        .iter_mut()
        .zip(sub_weights0.iter())
        .zip(add_weights0.iter())
        .zip(sub_weights1.iter())
        .zip(add_weights1.iter())
    {
        *acc = acc
            .wrapping_sub(sub_weight0)
            .wrapping_add(add_weight0)
            .wrapping_sub(sub_weight1)
            .wrapping_add(add_weight1);
    }
}
/// Subtracts weight row `index` element-wise from `accumulation`
/// (`acc[i] -= row[i]` for all `NNUE_PYTORCH_L1` lanes; the scalar
/// fallback uses wrapping i16 arithmetic). Mirror of `add_weights` — same
/// SIMD-path selection and 1536-lane chunking.
#[inline]
fn sub_weights(&self, accumulation: &mut [i16; NNUE_PYTORCH_L1], index: usize) {
    let weights = self.weight_row(index);
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx512f",
        target_feature = "avx512bw"
    ))]
    {
        // SAFETY: both buffers hold exactly 1536 i16s (48 chunks of 32);
        // aligned access relies on the Aligned/AlignedBox storage.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..48 {
                let acc_vec = _mm512_load_si512(acc_ptr.add(i * 32) as *const __m512i);
                let weight_vec = _mm512_load_si512(weight_ptr.add(i * 32) as *const __m512i);
                let result = _mm512_sub_epi16(acc_vec, weight_vec);
                _mm512_store_si512(acc_ptr.add(i * 32) as *mut __m512i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "avx2",
        not(target_feature = "avx512bw")
    ))]
    {
        // SAFETY: 96 chunks of 16 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..96 {
                let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
                let weight_vec = _mm256_load_si256(weight_ptr.add(i * 16) as *const __m256i);
                let result = _mm256_sub_epi16(acc_vec, weight_vec);
                _mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
            }
        }
        return;
    }
    #[cfg(all(
        target_arch = "x86_64",
        target_feature = "sse2",
        not(target_feature = "avx2")
    ))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::x86_64::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..192 {
                let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
                let weight_vec = _mm_load_si128(weight_ptr.add(i * 8) as *const __m128i);
                let result = _mm_sub_epi16(acc_vec, weight_vec);
                _mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
            }
        }
        return;
    }
    #[cfg(all(target_arch = "wasm32", target_feature = "simd128"))]
    {
        // SAFETY: 192 chunks of 8 i16 cover exactly 1536 lanes.
        unsafe {
            use std::arch::wasm32::*;
            let acc_ptr = accumulation.as_mut_ptr();
            let weight_ptr = weights.as_ptr();
            for i in 0..192 {
                let acc_vec = v128_load(acc_ptr.add(i * 8) as *const v128);
                let weight_vec = v128_load(weight_ptr.add(i * 8) as *const v128);
                let result = i16x8_sub(acc_vec, weight_vec);
                v128_store(acc_ptr.add(i * 8) as *mut v128, result);
            }
        }
        return;
    }
    // Portable scalar fallback — also the reference semantics for the tests.
    #[allow(unreachable_code)]
    for (acc, &weight) in accumulation.iter_mut().zip(weights) {
        *acc = acc.wrapping_sub(weight);
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nnue::accumulator::ChangedBonaPiece;
    use crate::nnue::bona_piece::ExtBonaPiece;
    use crate::nnue::piece_list::PieceNumber;
    use crate::types::{File, Piece, PieceType, Rank, Square};
    // Transformer with all-zero biases/weights; tests fill specific rows.
    fn make_test_transformer() -> FeatureTransformerLayerStacks {
        FeatureTransformerLayerStacks {
            biases: Aligned([0; NNUE_PYTORCH_L1]),
            weights: AlignedBox::new_zeroed(HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1),
        }
    }
    // Fills weight row `index` with a deterministic seed-derived pattern so
    // add/sub mismatches show up in the comparisons below.
    fn fill_weight_row(ft: &mut FeatureTransformerLayerStacks, index: usize, seed: i16) {
        let start = index * NNUE_PYTORCH_L1;
        for (i, slot) in ft.weights[start..start + NNUE_PYTORCH_L1].iter_mut().enumerate() {
            *slot = seed.wrapping_add((i % 29) as i16);
        }
    }
    // Reference path: full removed/added index lists via sub/add weights.
    // The fast path must produce byte-identical accumulators.
    fn apply_generic(
        ft: &FeatureTransformerLayerStacks,
        accumulation: &mut [i16; NNUE_PYTORCH_L1],
        dirty_piece: &DirtyPiece,
        perspective: Color,
        king_sq: Square,
    ) {
        let mut removed = IndexList::new();
        let mut added = IndexList::new();
        append_changed_indices(dirty_piece, perspective, king_sq, &mut removed, &mut added);
        for index in removed.iter() {
            ft.sub_weights(accumulation, index);
        }
        for index in added.iter() {
            ft.add_weights(accumulation, index);
        }
    }
    // Pins the constants the hard-coded SIMD chunk counts depend on.
    #[test]
    fn test_feature_transformer_dimensions() {
        assert_eq!(NNUE_PYTORCH_L1, 1536);
        assert_eq!(HALFKA_HM_DIMENSIONS, 73305);
    }
    // Quiet pawn move, black perspective: fast path == generic path.
    #[test]
    fn test_try_apply_dirty_piece_fast_matches_generic_single_move() {
        let king_sq = Square::new(File::File5, Rank::Rank9);
        let mut ft = make_test_transformer();
        let mut dirty_piece = DirtyPiece::new();
        dirty_piece.dirty_num = 1;
        dirty_piece.piece_no[0] = PieceNumber(0);
        dirty_piece.changed_piece[0] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::from_board(
                Piece::B_PAWN,
                Square::new(File::File7, Rank::Rank7),
            ),
            new_piece: ExtBonaPiece::from_board(
                Piece::B_PAWN,
                Square::new(File::File7, Rank::Rank6),
            ),
        };
        let old_index = feature_index_from_bona_piece(
            dirty_piece.changed_piece[0].old_piece.fb,
            Color::Black,
            king_sq,
        );
        let new_index = feature_index_from_bona_piece(
            dirty_piece.changed_piece[0].new_piece.fb,
            Color::Black,
            king_sq,
        );
        fill_weight_row(&mut ft, old_index, 11);
        fill_weight_row(&mut ft, new_index, 37);
        let mut generic = Aligned([5i16; NNUE_PYTORCH_L1]);
        let mut fast = Aligned([5i16; NNUE_PYTORCH_L1]);
        apply_generic(&ft, &mut generic.0, &dirty_piece, Color::Black, king_sq);
        assert!(ft.try_apply_dirty_piece_fast(&mut fast.0, &dirty_piece, Color::Black, king_sq));
        assert_eq!(generic.0, fast.0);
    }
    // Pawn capture (2 dirty pieces, one going to hand), black perspective.
    #[test]
    fn test_try_apply_dirty_piece_fast_matches_generic_capture() {
        let king_sq = Square::new(File::File5, Rank::Rank9);
        let mut ft = make_test_transformer();
        let mut dirty_piece = DirtyPiece::new();
        dirty_piece.dirty_num = 2;
        dirty_piece.piece_no[0] = PieceNumber(0);
        dirty_piece.changed_piece[0] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::from_board(
                Piece::B_PAWN,
                Square::new(File::File2, Rank::Rank4),
            ),
            new_piece: ExtBonaPiece::from_board(
                Piece::B_PAWN,
                Square::new(File::File2, Rank::Rank3),
            ),
        };
        dirty_piece.piece_no[1] = PieceNumber(1);
        dirty_piece.changed_piece[1] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::from_board(
                Piece::W_PAWN,
                Square::new(File::File2, Rank::Rank3),
            ),
            new_piece: ExtBonaPiece::from_hand(Color::Black, PieceType::Pawn, 1),
        };
        let indices = [
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[0].old_piece.fb,
                Color::Black,
                king_sq,
            ),
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[0].new_piece.fb,
                Color::Black,
                king_sq,
            ),
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[1].old_piece.fb,
                Color::Black,
                king_sq,
            ),
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[1].new_piece.fb,
                Color::Black,
                king_sq,
            ),
        ];
        // Distinct seeds per row so a row mix-up cannot cancel out.
        for (seed, &index) in [13i16, 29, 43, 71].iter().zip(indices.iter()) {
            fill_weight_row(&mut ft, index, *seed);
        }
        let mut generic = Aligned([7i16; NNUE_PYTORCH_L1]);
        let mut fast = Aligned([7i16; NNUE_PYTORCH_L1]);
        apply_generic(&ft, &mut generic.0, &dirty_piece, Color::Black, king_sq);
        assert!(ft.try_apply_dirty_piece_fast(&mut fast.0, &dirty_piece, Color::Black, king_sq));
        assert_eq!(generic.0, fast.0);
    }
    // Quiet pawn move, white perspective (exercises the `fw` branch).
    #[test]
    fn test_try_apply_dirty_piece_fast_matches_generic_single_move_white() {
        let king_sq = Square::new(File::File5, Rank::Rank1);
        let mut ft = make_test_transformer();
        let mut dirty_piece = DirtyPiece::new();
        dirty_piece.dirty_num = 1;
        dirty_piece.piece_no[0] = PieceNumber(0);
        dirty_piece.changed_piece[0] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::from_board(
                Piece::W_PAWN,
                Square::new(File::File3, Rank::Rank3),
            ),
            new_piece: ExtBonaPiece::from_board(
                Piece::W_PAWN,
                Square::new(File::File3, Rank::Rank4),
            ),
        };
        let old_index = feature_index_from_bona_piece(
            dirty_piece.changed_piece[0].old_piece.fw,
            Color::White,
            king_sq,
        );
        let new_index = feature_index_from_bona_piece(
            dirty_piece.changed_piece[0].new_piece.fw,
            Color::White,
            king_sq,
        );
        fill_weight_row(&mut ft, old_index, 19);
        fill_weight_row(&mut ft, new_index, 53);
        let mut generic = Aligned([5i16; NNUE_PYTORCH_L1]);
        let mut fast = Aligned([5i16; NNUE_PYTORCH_L1]);
        apply_generic(&ft, &mut generic.0, &dirty_piece, Color::White, king_sq);
        assert!(ft.try_apply_dirty_piece_fast(&mut fast.0, &dirty_piece, Color::White, king_sq));
        assert_eq!(generic.0, fast.0);
    }
    // Bishop takes pawn, white perspective.
    #[test]
    fn test_try_apply_dirty_piece_fast_matches_generic_capture_white() {
        let king_sq = Square::new(File::File5, Rank::Rank1);
        let mut ft = make_test_transformer();
        let mut dirty_piece = DirtyPiece::new();
        dirty_piece.dirty_num = 2;
        dirty_piece.piece_no[0] = PieceNumber(0);
        dirty_piece.changed_piece[0] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::from_board(
                Piece::W_BISHOP,
                Square::new(File::File8, Rank::Rank2),
            ),
            new_piece: ExtBonaPiece::from_board(
                Piece::W_BISHOP,
                Square::new(File::File3, Rank::Rank7),
            ),
        };
        dirty_piece.piece_no[1] = PieceNumber(1);
        dirty_piece.changed_piece[1] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::from_board(
                Piece::B_PAWN,
                Square::new(File::File3, Rank::Rank7),
            ),
            new_piece: ExtBonaPiece::from_hand(Color::White, PieceType::Pawn, 1),
        };
        let indices = [
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[0].old_piece.fw,
                Color::White,
                king_sq,
            ),
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[0].new_piece.fw,
                Color::White,
                king_sq,
            ),
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[1].old_piece.fw,
                Color::White,
                king_sq,
            ),
            feature_index_from_bona_piece(
                dirty_piece.changed_piece[1].new_piece.fw,
                Color::White,
                king_sq,
            ),
        ];
        // Distinct seeds per row so a row mix-up cannot cancel out.
        for (seed, &index) in [17i16, 31, 47, 67].iter().zip(indices.iter()) {
            fill_weight_row(&mut ft, index, *seed);
        }
        let mut generic = Aligned([7i16; NNUE_PYTORCH_L1]);
        let mut fast = Aligned([7i16; NNUE_PYTORCH_L1]);
        apply_generic(&ft, &mut generic.0, &dirty_piece, Color::White, king_sq);
        assert!(ft.try_apply_dirty_piece_fast(&mut fast.0, &dirty_piece, Color::White, king_sq));
        assert_eq!(generic.0, fast.0);
    }
    // A change with a ZERO BonaPiece (drop from hand) must decline the fast
    // path so the caller uses the generic removed/added update.
    #[test]
    fn test_try_apply_dirty_piece_fast_returns_false_for_hand_only_change() {
        let king_sq = Square::new(File::File5, Rank::Rank9);
        let ft = make_test_transformer();
        let mut dirty_piece = DirtyPiece::new();
        dirty_piece.dirty_num = 1;
        dirty_piece.piece_no[0] = PieceNumber(0);
        dirty_piece.changed_piece[0] = ChangedBonaPiece {
            old_piece: ExtBonaPiece::ZERO,
            new_piece: ExtBonaPiece::from_hand(Color::Black, PieceType::Pawn, 1),
        };
        let mut accumulation = Aligned([0i16; NNUE_PYTORCH_L1]);
        assert!(!ft.try_apply_dirty_piece_fast(
            &mut accumulation.0,
            &dirty_piece,
            Color::Black,
            king_sq,
        ));
    }
}