use super::accumulator::{Aligned, AlignedBox};
use super::accumulator::{DirtyPiece, IndexList, MAX_ACTIVE_FEATURES};
use super::accumulator_layer_stacks::{AccumulatorLayerStacks, AccumulatorStackLayerStacks};
use super::constants::{HALFKA_HM_DIMENSIONS, NNUE_PYTORCH_L1};
use super::features::{FeatureSet, HalfKA_hm_FeatureSet};
use super::leb128::read_compressed_tensor_i16;
use crate::position::Position;
use crate::types::Color;
use std::io::{self, Read};
/// Diverging panic helper for an out-of-range feature index.
///
/// `max` is the number of weight rows available (callers pass
/// `weights.len() / NNUE_PYTORCH_L1`). Marked `#[cold]` and
/// `#[inline(never)]` so this never-taken-in-practice panic path is kept
/// out of line and does not pollute the hot accumulate loops.
#[cold]
#[inline(never)]
fn feature_index_oob(index: usize, max: usize) -> ! {
panic!("Feature index out of range: {index} (max: {max})")
}
/// Feature transformer of the "layer stacks" NNUE network: maps sparse
/// HalfKA_hm input features onto the first dense layer of width
/// `NNUE_PYTORCH_L1`.
///
/// `#[repr(C, align(64))]` pins the struct to a cache-line boundary,
/// which together with the `Aligned`/`AlignedBox` wrappers keeps both
/// fields suitably aligned for vectorized accumulate loops.
#[repr(C, align(64))]
pub struct FeatureTransformerLayerStacks {
// One i16 bias per output neuron (L1 width).
pub biases: Aligned<[i16; NNUE_PYTORCH_L1]>,
// Row-major weight matrix: HALFKA_HM_DIMENSIONS rows of NNUE_PYTORCH_L1
// i16 weights each; row `i` is the delta applied when feature `i` toggles.
pub weights: AlignedBox<i16>,
}
impl FeatureTransformerLayerStacks {
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut biases = [0i16; NNUE_PYTORCH_L1];
let mut buf = [0u8; 2];
for bias in biases.iter_mut() {
reader.read_exact(&mut buf)?;
*bias = i16::from_le_bytes(buf);
}
let weight_size = HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1;
let mut weights = AlignedBox::new_zeroed(weight_size);
for weight in weights.iter_mut() {
reader.read_exact(&mut buf)?;
*weight = i16::from_le_bytes(buf);
}
Ok(Self {
biases: Aligned(biases),
weights,
})
}
pub fn read_leb128<R: Read>(reader: &mut R) -> io::Result<Self> {
let bias_vec = read_compressed_tensor_i16(reader, NNUE_PYTORCH_L1)?;
let mut biases = [0i16; NNUE_PYTORCH_L1];
biases.copy_from_slice(&bias_vec);
let weight_size = HALFKA_HM_DIMENSIONS * NNUE_PYTORCH_L1;
let weight_vec = read_compressed_tensor_i16(reader, weight_size)?;
let mut weights = AlignedBox::new_zeroed(weight_size);
weights.copy_from_slice(&weight_vec);
Ok(Self {
biases: Aligned(biases),
weights,
})
}
pub fn refresh_accumulator(&self, pos: &Position, acc: &mut AccumulatorLayerStacks) {
for perspective in [Color::Black, Color::White] {
let p = perspective as usize;
let accumulation = acc.get_mut(p);
accumulation.copy_from_slice(&self.biases.0);
let active_indices = self.get_active_features(pos, perspective);
for &index in active_indices.iter() {
self.add_weights(accumulation, index);
}
}
acc.computed_accumulation = true;
acc.computed_score = false;
}
pub fn update_accumulator(
&self,
pos: &Position,
dirty_piece: &DirtyPiece,
acc: &mut AccumulatorLayerStacks,
prev_acc: &AccumulatorLayerStacks,
) {
for perspective in [Color::Black, Color::White] {
let p = perspective as usize;
let reset = HalfKA_hm_FeatureSet::needs_refresh(dirty_piece, perspective);
if reset {
let accumulation = acc.get_mut(p);
accumulation.copy_from_slice(&self.biases.0);
let active_indices = self.get_active_features(pos, perspective);
for &index in active_indices.iter() {
self.add_weights(accumulation, index);
}
} else {
let (removed, added) = HalfKA_hm_FeatureSet::collect_changed_indices(
dirty_piece,
perspective,
pos.king_square(perspective),
);
let prev = prev_acc.get(p);
let curr = acc.get_mut(p);
curr.copy_from_slice(prev);
for &index in removed.iter() {
self.sub_weights(curr, index);
}
for &index in added.iter() {
self.add_weights(curr, index);
}
}
}
acc.computed_accumulation = true;
acc.computed_score = false;
}
pub fn forward_update_incremental(
&self,
pos: &Position,
stack: &mut AccumulatorStackLayerStacks,
source_idx: usize,
) -> bool {
let Some(path) = stack.collect_path(source_idx) else {
return false;
};
let source_acc = stack.entry_at(source_idx).accumulator.clone();
{
let current_acc = &mut stack.current_mut().accumulator;
for perspective in [Color::Black, Color::White] {
let p = perspective as usize;
current_acc.get_mut(p).copy_from_slice(source_acc.get(p));
}
}
for &entry_idx in path.iter() {
let dirty_piece = stack.entry_at(entry_idx).dirty_piece;
for perspective in [Color::Black, Color::White] {
debug_assert!(
!dirty_piece.king_moved[perspective.index()],
"King moved between source and current"
);
let king_sq = pos.king_square(perspective);
let (removed, added) = HalfKA_hm_FeatureSet::collect_changed_indices(
&dirty_piece,
perspective,
king_sq,
);
let p = perspective as usize;
let accumulation = stack.current_mut().accumulator.get_mut(p);
for &index in removed.iter() {
self.sub_weights(accumulation, index);
}
for &index in added.iter() {
self.add_weights(accumulation, index);
}
}
}
stack.current_mut().accumulator.computed_accumulation = true;
stack.current_mut().accumulator.computed_score = false;
true
}
#[inline]
fn get_active_features(
&self,
pos: &Position,
perspective: Color,
) -> IndexList<MAX_ACTIVE_FEATURES> {
HalfKA_hm_FeatureSet::collect_active_indices(pos, perspective)
}
#[inline]
fn add_weights(&self, accumulation: &mut [i16; NNUE_PYTORCH_L1], index: usize) {
let Some(offset) = index.checked_mul(NNUE_PYTORCH_L1) else {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
};
let Some(end) = offset.checked_add(NNUE_PYTORCH_L1) else {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
};
if end > self.weights.len() {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
}
let weights = &self.weights[offset..offset + NNUE_PYTORCH_L1];
for (acc, &weight) in accumulation.iter_mut().zip(weights) {
*acc = acc.wrapping_add(weight);
}
}
#[inline]
fn sub_weights(&self, accumulation: &mut [i16; NNUE_PYTORCH_L1], index: usize) {
let Some(offset) = index.checked_mul(NNUE_PYTORCH_L1) else {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
};
let Some(end) = offset.checked_add(NNUE_PYTORCH_L1) else {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
};
if end > self.weights.len() {
feature_index_oob(index, self.weights.len() / NNUE_PYTORCH_L1);
}
let weights = &self.weights[offset..offset + NNUE_PYTORCH_L1];
for (acc, &weight) in accumulation.iter_mut().zip(weights) {
*acc = acc.wrapping_sub(weight);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the compile-time network dimensions this module is
    /// built around.
    #[test]
    fn test_feature_transformer_dimensions() {
        let expected = (1536usize, 73305usize);
        assert_eq!((NNUE_PYTORCH_L1, HALFKA_HM_DIMENSIONS), expected);
    }
}