#![allow(non_camel_case_types)]
use std::io::{self, Read, Seek};
use std::marker::PhantomData;
use std::sync::OnceLock;
use super::accumulator::{Aligned, AlignedBox, DirtyPiece, IndexList, MAX_PATH_LENGTH};
use super::activation::FtActivation;
use super::constants::{FV_SCALE_HALFKA, HALFKA_HM_DIMENSIONS, MAX_ARCH_LEN, NNUE_VERSION_HALFKA};
use super::features::{FeatureSet, HalfKA_hm_FeatureSet};
use super::network::{get_fv_scale_override, parse_fv_scale_from_arch};
use crate::position::Position;
use crate::types::{Color, Value};
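/// Returns true when the `NNUE_DEBUG` environment variable is set.
/// The result is cached in a `OnceLock`, so the environment is read once.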
#[inline]
fn nnue_debug_enabled() -> bool {
static NNUE_DEBUG: OnceLock<bool> = OnceLock::new();
*NNUE_DEBUG.get_or_init(|| std::env::var("NNUE_DEBUG").is_ok())
}
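/// Emulates the VNNI `dpbusd` operation on AVX2: multiplies unsigned bytes in
/// `a` by signed bytes in `b`, adds adjacent pairs into (saturating) i16,
/// widens the pairs to i32, and accumulates into `acc`.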
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
unsafe fn m256_add_dpbusd_epi32(
acc: &mut std::arch::x86_64::__m256i,
a: std::arch::x86_64::__m256i,
b: std::arch::x86_64::__m256i,
) {
unsafe {
use std::arch::x86_64::*;
let product = _mm256_maddubs_epi16(a, b);
let product32 = _mm256_madd_epi16(product, _mm256_set1_epi16(1));
*acc = _mm256_add_epi32(*acc, product32);
}
}
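/// Horizontal sum of the eight i32 lanes of an AVX2 vector.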
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
unsafe fn hsum_i32_avx2(v: std::arch::x86_64::__m256i) -> i32 {
unsafe {
use std::arch::x86_64::*;
let hi = _mm256_extracti128_si256(v, 1);
let lo = _mm256_castsi256_si128(v);
let sum128 = _mm_add_epi32(lo, hi);
let hi64 = _mm_unpackhi_epi64(sum128, sum128);
let sum64 = _mm_add_epi32(sum128, hi64);
let hi32 = _mm_shuffle_epi32(sum64, 1);
let sum32 = _mm_add_epi32(sum64, hi32);
_mm_cvtsi128_si32(sum32)
}
}
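/// Horizontal sum of the four i32 lanes of a 128-bit vector (SSE2 intrinsics only).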
#[cfg(all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
))]
#[inline]
unsafe fn hsum_i32_sse2(v: std::arch::x86_64::__m128i) -> i32 {
unsafe {
use std::arch::x86_64::*;
let hi64 = _mm_unpackhi_epi64(v, v);
let sum64 = _mm_add_epi32(v, hi64);
let hi32 = _mm_shuffle_epi32(sum64, 1);
let sum32 = _mm_add_epi32(sum64, hi32);
_mm_cvtsi128_si32(sum32)
}
}
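/// SSSE3 analogue of `m256_add_dpbusd_epi32`, operating on 128-bit vectors.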
#[cfg(all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
))]
#[inline]
unsafe fn m128_add_dpbusd_epi32(
acc: &mut std::arch::x86_64::__m128i,
a: std::arch::x86_64::__m128i,
b: std::arch::x86_64::__m128i,
) {
unsafe {
use std::arch::x86_64::*;
let product = _mm_maddubs_epi16(a, b);
let product32 = _mm_madd_epi16(product, _mm_set1_epi16(1));
*acc = _mm_add_epi32(*acc, product32);
}
}
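/// Feature-transformer accumulator for one position: the L1-wide i16 sums for
/// both perspectives (indexed by `Color as usize`), plus a flag marking
/// whether they are up to date.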
pub struct AccumulatorHalfKA_hm<const L1: usize> {
pub accumulation: [AlignedBox<i16>; 2],
pub computed_accumulation: bool,
}
impl<const L1: usize> AccumulatorHalfKA_hm<L1> {
pub fn new() -> Self {
Self {
accumulation: [AlignedBox::new_zeroed(L1), AlignedBox::new_zeroed(L1)],
computed_accumulation: false,
}
}
pub fn clear(&mut self) {
self.accumulation[0].fill(0);
self.accumulation[1].fill(0);
self.computed_accumulation = false;
}
}
impl<const L1: usize> Default for AccumulatorHalfKA_hm<L1> {
fn default() -> Self {
Self::new()
}
}
impl<const L1: usize> Clone for AccumulatorHalfKA_hm<L1> {
fn clone(&self) -> Self {
Self {
accumulation: [self.accumulation[0].clone(), self.accumulation[1].clone()],
computed_accumulation: self.computed_accumulation,
}
}
}
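/// One node of the accumulator stack: the accumulator itself, the
/// `DirtyPiece` describing the move that led to it, and a link to the parent.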
pub struct AccumulatorEntryHalfKA_hm<const L1: usize> {
pub accumulator: AccumulatorHalfKA_hm<L1>,
pub dirty_piece: DirtyPiece,
pub previous: Option<usize>,
}
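/// Stack of accumulator entries following the search path, linked through
/// `previous` so incremental updates can walk back to a computed ancestor.
///
/// A minimal sketch of the intended call pattern (the `make_move`/
/// `unmake_move` driver here is hypothetical, not part of this module):
///
/// ```ignore
/// let mut stack = AccumulatorStackHalfKA_hm::<512>::new();
/// stack.push(dirty_piece); // on make_move
/// // ... evaluate using stack.top() ...
/// stack.pop(); // on unmake_move
/// ```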
pub struct AccumulatorStackHalfKA_hm<const L1: usize> {
entries: Vec<AccumulatorEntryHalfKA_hm<L1>>,
current_idx: usize,
}
impl<const L1: usize> AccumulatorStackHalfKA_hm<L1> {
pub fn new() -> Self {
let mut entries = Vec::with_capacity(128);
entries.push(AccumulatorEntryHalfKA_hm {
accumulator: AccumulatorHalfKA_hm::new(),
dirty_piece: DirtyPiece::default(),
previous: None,
});
Self {
entries,
current_idx: 0,
}
}
pub fn current(&self) -> &AccumulatorEntryHalfKA_hm<L1> {
&self.entries[self.current_idx]
}
pub fn current_mut(&mut self) -> &mut AccumulatorEntryHalfKA_hm<L1> {
&mut self.entries[self.current_idx]
}
#[inline]
pub fn top(&self) -> &AccumulatorHalfKA_hm<L1> {
&self.entries[self.current_idx].accumulator
}
#[inline]
pub fn top_mut(&mut self) -> &mut AccumulatorHalfKA_hm<L1> {
&mut self.entries[self.current_idx].accumulator
}
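/// Returns the top accumulator mutably together with the source accumulator
/// immutably; `split_at_mut` proves to the borrow checker that the two
/// borrows do not alias.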
#[inline]
pub fn top_and_source(
&mut self,
source_idx: usize,
) -> (&mut AccumulatorHalfKA_hm<L1>, &AccumulatorHalfKA_hm<L1>) {
let current_idx = self.current_idx;
debug_assert!(
source_idx < current_idx,
"source_idx ({source_idx}) must be < current_idx ({current_idx})"
);
let (left, right) = self.entries.split_at_mut(current_idx);
(&mut right[0].accumulator, &left[source_idx].accumulator)
}
pub fn push(&mut self, dirty_piece: DirtyPiece) {
let prev_idx = self.current_idx;
self.current_idx = self.entries.len();
self.entries.push(AccumulatorEntryHalfKA_hm {
accumulator: AccumulatorHalfKA_hm::new(),
dirty_piece,
previous: Some(prev_idx),
});
}
pub fn pop(&mut self) {
if let Some(prev) = self.entries[self.current_idx].previous {
self.current_idx = prev;
}
self.entries.truncate(self.current_idx + 1);
}
pub fn reset(&mut self) {
self.current_idx = 0;
self.entries.truncate(1);
self.entries[0].accumulator.computed_accumulation = false;
self.entries[0].dirty_piece.clear();
self.entries[0].previous = None;
}
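/// Walks back through the stack (at most `MAX_DEPTH` steps, currently the
/// immediate parent) looking for an entry whose accumulator is already
/// computed, giving up if a king move is encountered, since that forces a
/// full refresh. Returns `(entry index, distance)` on success.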
pub fn find_usable_accumulator(&self) -> Option<(usize, usize)> {
const MAX_DEPTH: usize = 1;
let current = &self.entries[self.current_idx];
if current.dirty_piece.king_moved[0] || current.dirty_piece.king_moved[1] {
return None;
}
let mut prev_idx = current.previous?;
let mut depth = 1;
loop {
let prev = &self.entries[prev_idx];
if prev.accumulator.computed_accumulation {
return Some((prev_idx, depth));
}
if depth >= MAX_DEPTH {
return None;
}
let next_prev_idx = prev.previous?;
if prev.dirty_piece.king_moved[0] || prev.dirty_piece.king_moved[1] {
return None;
}
prev_idx = next_prev_idx;
depth += 1;
}
}
pub fn entry_at(&self, idx: usize) -> &AccumulatorEntryHalfKA_hm<L1> {
&self.entries[idx]
}
pub fn entry_at_mut(&mut self, idx: usize) -> &mut AccumulatorEntryHalfKA_hm<L1> {
&mut self.entries[idx]
}
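/// Splits the entries so the accumulator at `prev_idx` can be borrowed
/// immutably while the current accumulator is borrowed mutably. The two
/// indices must differ.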
pub fn get_prev_and_current_accumulators(
&mut self,
prev_idx: usize,
) -> (&AccumulatorHalfKA_hm<L1>, &mut AccumulatorHalfKA_hm<L1>) {
let current_idx = self.current_idx;
debug_assert_ne!(prev_idx, current_idx, "prev_idx and current_idx must differ");
if prev_idx < current_idx {
let (left, right) = self.entries.split_at_mut(current_idx);
(&left[prev_idx].accumulator, &mut right[0].accumulator)
} else {
let (left, right) = self.entries.split_at_mut(prev_idx);
(&right[0].accumulator, &mut left[current_idx].accumulator)
}
}
pub fn current_index(&self) -> usize {
self.current_idx
}
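/// Collects the entry indices on the path from `source_idx` (exclusive) up
/// to the current entry (inclusive), ordered oldest-first. Returns `None` if
/// the path exceeds `MAX_PATH_LENGTH` or the parent chain is broken.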
pub fn collect_path(&self, source_idx: usize) -> Option<IndexList<MAX_PATH_LENGTH>> {
let mut path = IndexList::new();
let mut idx = self.current_idx;
while idx != source_idx {
if !path.push(idx) {
return None;
}
match self.entries[idx].previous {
Some(prev) => idx = prev,
None => return None,
}
}
path.reverse();
Some(path)
}
}
impl<const L1: usize> Default for AccumulatorStackHalfKA_hm<L1> {
fn default() -> Self {
Self::new()
}
}
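/// First network layer: maps the sparse HalfKA_hm feature set to an L1-wide
/// i16 accumulator per perspective. Weights are stored row-major as
/// `[feature_index][L1]`; biases are a single L1-wide row.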
pub struct FeatureTransformerHalfKA_hm<const L1: usize> {
pub biases: Vec<i16>,
pub weights: AlignedBox<i16>,
}
impl<const L1: usize> FeatureTransformerHalfKA_hm<L1> {
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
let input_dim = HALFKA_HM_DIMENSIONS;
let mut biases = vec![0i16; L1];
let mut buf = [0u8; 2];
for bias in biases.iter_mut() {
reader.read_exact(&mut buf)?;
*bias = i16::from_le_bytes(buf);
}
let weight_size = input_dim * L1;
let mut weights = AlignedBox::new_zeroed(weight_size);
for weight in weights.iter_mut() {
reader.read_exact(&mut buf)?;
*weight = i16::from_le_bytes(buf);
}
Ok(Self { biases, weights })
}
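/// Rebuilds both perspectives of `acc` from scratch: start from the biases
/// and add the weight row of every active feature.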
pub fn refresh_accumulator(&self, pos: &Position, acc: &mut AccumulatorHalfKA_hm<L1>) {
for perspective in [Color::Black, Color::White] {
let p = perspective as usize;
let accumulation = &mut acc.accumulation[p];
accumulation.copy_from_slice(&self.biases);
let active_indices = HalfKA_hm_FeatureSet::collect_active_indices(pos, perspective);
for &index in active_indices.iter() {
self.add_weights(accumulation, index);
}
}
acc.computed_accumulation = true;
}
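/// Incrementally updates `acc` from `prev_acc` using the changed features in
/// `dirty_piece`; falls back to a full rebuild for any perspective whose king
/// moved, since the feature indices are keyed on the king square.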
pub fn update_accumulator(
&self,
pos: &Position,
dirty_piece: &DirtyPiece,
acc: &mut AccumulatorHalfKA_hm<L1>,
prev_acc: &AccumulatorHalfKA_hm<L1>,
) {
for perspective in [Color::Black, Color::White] {
let p = perspective as usize;
let reset = HalfKA_hm_FeatureSet::needs_refresh(dirty_piece, perspective);
if reset {
acc.accumulation[p].copy_from_slice(&self.biases);
let active_indices = HalfKA_hm_FeatureSet::collect_active_indices(pos, perspective);
for &index in active_indices.iter() {
self.add_weights(&mut acc.accumulation[p], index);
}
} else {
let (removed, added) = HalfKA_hm_FeatureSet::collect_changed_indices(
dirty_piece,
perspective,
pos.king_square(perspective),
);
acc.accumulation[p].copy_from_slice(&prev_acc.accumulation[p]);
for &index in removed.iter() {
self.sub_weights(&mut acc.accumulation[p], index);
}
for &index in added.iter() {
self.add_weights(&mut acc.accumulation[p], index);
}
}
}
acc.computed_accumulation = true;
}
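/// Applies all dirty pieces between `source_idx` and the current entry on
/// top of a copy of the source accumulator, in one pass. Returns `false` if
/// no usable path exists (the caller should refresh instead).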
pub fn forward_update_incremental(
&self,
pos: &Position,
stack: &mut AccumulatorStackHalfKA_hm<L1>,
source_idx: usize,
) -> bool {
let Some(path) = stack.collect_path(source_idx) else {
return false;
};
{
let (source_acc, current_acc) = stack.get_prev_and_current_accumulators(source_idx);
for p in 0..2 {
current_acc.accumulation[p].copy_from_slice(&source_acc.accumulation[p]);
}
}
let current_idx = stack.current_index();
for &entry_idx in path.iter() {
let dirty_piece = stack.entry_at(entry_idx).dirty_piece;
for perspective in [Color::Black, Color::White] {
debug_assert!(
!dirty_piece.king_moved[perspective.index()],
"King moved between source and current"
);
let king_sq = pos.king_square(perspective);
let (removed, added) = HalfKA_hm_FeatureSet::collect_changed_indices(
&dirty_piece,
perspective,
king_sq,
);
let p = perspective as usize;
let accumulation = &mut stack.entry_at_mut(current_idx).accumulator.accumulation[p];
for &index in removed.iter() {
self.sub_weights(accumulation, index);
}
for &index in added.iter() {
self.add_weights(accumulation, index);
}
}
}
stack.entry_at_mut(current_idx).accumulator.computed_accumulation = true;
true
}
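/// Adds the weight row of feature `index` into `accumulation`, using AVX2
/// (16 i16 lanes) or SSE2 (8 lanes) when available and a scalar loop
/// otherwise. Assumes `L1` is a multiple of the vector width, which holds
/// for every instantiation in this module.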
#[inline]
fn add_weights(&self, accumulation: &mut [i16], index: usize) {
let offset = index * L1;
let weights = &self.weights[offset..offset + L1];
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
unsafe {
use std::arch::x86_64::*;
let acc_ptr = accumulation.as_mut_ptr();
let weight_ptr = weights.as_ptr();
let num_chunks = L1 / 16;
for i in 0..num_chunks {
let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
let weight_vec = _mm256_load_si256(weight_ptr.add(i * 16) as *const __m256i);
let result = _mm256_add_epi16(acc_vec, weight_vec);
_mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
}
}
return;
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "sse2",
not(target_feature = "avx2")
))]
{
unsafe {
use std::arch::x86_64::*;
let acc_ptr = accumulation.as_mut_ptr();
let weight_ptr = weights.as_ptr();
let num_chunks = L1 / 8;
for i in 0..num_chunks {
let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
let weight_vec = _mm_load_si128(weight_ptr.add(i * 8) as *const __m128i);
let result = _mm_add_epi16(acc_vec, weight_vec);
_mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
}
}
return;
}
#[allow(unreachable_code)]
for (acc, &w) in accumulation.iter_mut().zip(weights) {
*acc = acc.wrapping_add(w);
}
}
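/// Mirror of `add_weights`, subtracting the weight row of feature `index`.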
#[inline]
fn sub_weights(&self, accumulation: &mut [i16], index: usize) {
let offset = index * L1;
let weights = &self.weights[offset..offset + L1];
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
unsafe {
use std::arch::x86_64::*;
let acc_ptr = accumulation.as_mut_ptr();
let weight_ptr = weights.as_ptr();
let num_chunks = L1 / 16;
for i in 0..num_chunks {
let acc_vec = _mm256_load_si256(acc_ptr.add(i * 16) as *const __m256i);
let weight_vec = _mm256_load_si256(weight_ptr.add(i * 16) as *const __m256i);
let result = _mm256_sub_epi16(acc_vec, weight_vec);
_mm256_store_si256(acc_ptr.add(i * 16) as *mut __m256i, result);
}
}
return;
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "sse2",
not(target_feature = "avx2")
))]
{
unsafe {
use std::arch::x86_64::*;
let acc_ptr = accumulation.as_mut_ptr();
let weight_ptr = weights.as_ptr();
let num_chunks = L1 / 8;
for i in 0..num_chunks {
let acc_vec = _mm_load_si128(acc_ptr.add(i * 8) as *const __m128i);
let weight_vec = _mm_load_si128(weight_ptr.add(i * 8) as *const __m128i);
let result = _mm_sub_epi16(acc_vec, weight_vec);
_mm_store_si128(acc_ptr.add(i * 8) as *mut __m128i, result);
}
}
return;
}
#[allow(unreachable_code)]
for (acc, &w) in accumulation.iter_mut().zip(weights) {
*acc = acc.wrapping_sub(w);
}
}
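/// Copies the accumulator into `output` with the side to move first,
/// producing the `[stm perspective | opponent perspective]` layout that the
/// later layers expect.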
pub fn transform_raw(
&self,
acc: &AccumulatorHalfKA_hm<L1>,
side_to_move: Color,
output: &mut [i16],
) {
let perspectives = [side_to_move, !side_to_move];
for (p, &perspective) in perspectives.iter().enumerate() {
let out_offset = L1 * p;
let accumulation = &acc.accumulation[perspective as usize];
output[out_offset..out_offset + L1].copy_from_slice(accumulation);
}
}
}
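/// Quantized affine layer with i8 weights and i32 biases. The input
/// dimension is padded to a multiple of 32 so SIMD loads never run off the end.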
pub struct AffineTransformHalfKA_hm<const INPUT: usize, const OUTPUT: usize> {
pub biases: [i32; OUTPUT],
pub weights: AlignedBox<i8>,
}
impl<const INPUT: usize, const OUTPUT: usize> AffineTransformHalfKA_hm<INPUT, OUTPUT> {
const PADDED_INPUT: usize = INPUT.div_ceil(32) * 32;
#[cfg(any(target_feature = "avx2", target_feature = "ssse3"))]
const CHUNK_SIZE: usize = 4;
#[cfg(any(target_feature = "avx2", target_feature = "ssse3"))]
const NUM_INPUT_CHUNKS: usize = Self::PADDED_INPUT / Self::CHUNK_SIZE;
#[cfg(any(target_feature = "avx2", target_feature = "ssse3"))]
#[inline]
const fn should_use_scrambled_weights() -> bool {
if cfg!(all(target_arch = "x86_64", target_feature = "avx2")) {
OUTPUT.is_multiple_of(8)
} else if cfg!(all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
)) {
OUTPUT.is_multiple_of(4)
} else {
false
}
}
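/// Maps a linear (row-major) weight index to the scrambled layout used by
/// loop-inverted propagation: for each 4-byte input chunk, the matching
/// 4-wide weight slices of all outputs are stored contiguously.
///
/// Worked example with `CHUNK_SIZE = 4`, `PADDED_INPUT = 32`, `OUTPUT = 8`:
/// linear index 5 (row 0, column 5) lies in input chunk 1, element 1, so it
/// maps to `1 * 8 * 4 + 0 * 4 + 1 = 33`.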
#[cfg(any(target_feature = "avx2", target_feature = "ssse3"))]
#[inline]
const fn get_weight_index_scrambled(i: usize) -> usize {
(i / Self::CHUNK_SIZE) % Self::NUM_INPUT_CHUNKS * OUTPUT * Self::CHUNK_SIZE
+ i / Self::PADDED_INPUT * Self::CHUNK_SIZE
+ i % Self::CHUNK_SIZE
}
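/// Reads biases, then weights. On x86-64 SIMD builds the weights are placed
/// in scrambled order at load time (when the output width allows it);
/// otherwise they are stored row-major exactly as they appear in the file.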
pub fn read<R: Read>(reader: &mut R) -> io::Result<Self> {
let mut biases = [0i32; OUTPUT];
let mut buf4 = [0u8; 4];
for bias in biases.iter_mut() {
reader.read_exact(&mut buf4)?;
*bias = i32::from_le_bytes(buf4);
}
let weight_size = OUTPUT * Self::PADDED_INPUT;
let mut weights = AlignedBox::new_zeroed(weight_size);
#[cfg(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
)
))]
{
let mut buf1 = [0u8; 1];
for i in 0..weight_size {
reader.read_exact(&mut buf1)?;
let idx = if Self::should_use_scrambled_weights() {
Self::get_weight_index_scrambled(i)
} else {
i
};
weights[idx] = buf1[0] as i8;
}
}
#[cfg(not(any(
all(target_arch = "x86_64", target_feature = "avx2"),
all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
)
)))]
{
let mut row_buf = vec![0u8; Self::PADDED_INPUT];
for o in 0..OUTPUT {
reader.read_exact(&mut row_buf)?;
for i in 0..Self::PADDED_INPUT {
weights[o * Self::PADDED_INPUT + i] = row_buf[i] as i8;
}
}
}
Ok(Self { biases, weights })
}
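/// Computes `output = biases + weights * input` with u8 inputs, dispatching
/// to the SIMD paths when available and to a scalar row-by-row loop otherwise.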
pub fn propagate(&self, input: &[u8], output: &mut [i32; OUTPUT]) {
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
{
unsafe {
self.propagate_avx2_loop_inverted(input, output);
}
return;
}
#[cfg(all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
))]
{
unsafe {
self.propagate_ssse3_loop_inverted(input, output);
}
return;
}
#[allow(unreachable_code)]
{
output.copy_from_slice(&self.biases);
for (j, out) in output.iter_mut().enumerate() {
let weight_offset = j * Self::PADDED_INPUT;
for (i, &in_val) in input.iter().enumerate().take(INPUT) {
*out += self.weights[weight_offset + i] as i32 * in_val as i32;
}
}
}
}
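/// Loop-inverted AVX2 propagation: when the output width is a multiple of 8,
/// the outer loop runs over 4-byte input chunks, broadcasting each chunk and
/// accumulating into `OUTPUT / 8` registers against the scrambled weight
/// layout. Otherwise it falls back to one dot product per output row.
/// The 4-byte chunk reads use `read_unaligned`, so the `&[u8]` input needs
/// no alignment guarantee even though callers pass `Aligned` buffers.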
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline]
#[allow(clippy::needless_range_loop)]
unsafe fn propagate_avx2_loop_inverted(&self, input: &[u8], output: &mut [i32; OUTPUT]) {
unsafe {
use std::arch::x86_64::*;
if OUTPUT.is_multiple_of(8) {
const MAX_REGS: usize = 128;
let num_regs = OUTPUT / 8;
debug_assert!(num_regs <= MAX_REGS);
let mut acc = [_mm256_setzero_si256(); MAX_REGS];
let bias_ptr = self.biases.as_ptr() as *const __m256i;
for k in 0..num_regs {
acc[k] = _mm256_loadu_si256(bias_ptr.add(k));
}
let input32 = input.as_ptr() as *const i32;
let weights_ptr = self.weights.as_ptr();
for i in 0..Self::NUM_INPUT_CHUNKS {
// Unaligned read: the slice type only guarantees u8 alignment.
let in_val = _mm256_set1_epi32(input32.add(i).read_unaligned());
let col = weights_ptr.add(i * OUTPUT * Self::CHUNK_SIZE) as *const __m256i;
for k in 0..num_regs {
m256_add_dpbusd_epi32(&mut acc[k], in_val, _mm256_load_si256(col.add(k)));
}
}
let out_ptr = output.as_mut_ptr() as *mut __m256i;
for k in 0..num_regs {
_mm256_storeu_si256(out_ptr.add(k), acc[k]);
}
} else {
output.copy_from_slice(&self.biases);
let num_chunks = Self::PADDED_INPUT / 32;
let input_ptr = input.as_ptr();
let weight_ptr = self.weights.as_ptr();
for (j, out) in output.iter_mut().enumerate() {
let mut acc_simd = _mm256_setzero_si256();
let row_offset = j * Self::PADDED_INPUT;
for chunk in 0..num_chunks {
let in_vec =
_mm256_loadu_si256(input_ptr.add(chunk * 32) as *const __m256i);
let w_vec = _mm256_load_si256(
weight_ptr.add(row_offset + chunk * 32) as *const __m256i
);
m256_add_dpbusd_epi32(&mut acc_simd, in_vec, w_vec);
}
*out += hsum_i32_avx2(acc_simd);
}
}
}
}
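/// SSSE3 analogue of the AVX2 path, accumulating into `OUTPUT / 4` 128-bit registers.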
#[cfg(all(
target_arch = "x86_64",
target_feature = "ssse3",
not(target_feature = "avx2")
))]
#[inline]
unsafe fn propagate_ssse3_loop_inverted(&self, input: &[u8], output: &mut [i32; OUTPUT]) {
unsafe {
use std::arch::x86_64::*;
// Match should_use_scrambled_weights() so the fast path agrees with the
// layout chosen when the weights were read.
if OUTPUT.is_multiple_of(4) {
const MAX_REGS: usize = 256;
let num_regs = OUTPUT / 4;
debug_assert!(num_regs <= MAX_REGS);
let mut acc = [_mm_setzero_si128(); MAX_REGS];
let bias_ptr = self.biases.as_ptr() as *const __m128i;
for k in 0..num_regs {
acc[k] = _mm_loadu_si128(bias_ptr.add(k));
}
let input32 = input.as_ptr() as *const i32;
let weights_ptr = self.weights.as_ptr();
for i in 0..Self::NUM_INPUT_CHUNKS {
// Unaligned read: the slice type only guarantees u8 alignment.
let in_val = _mm_set1_epi32(input32.add(i).read_unaligned());
let col = weights_ptr.add(i * OUTPUT * Self::CHUNK_SIZE) as *const __m128i;
for k in 0..num_regs {
m128_add_dpbusd_epi32(&mut acc[k], in_val, _mm_load_si128(col.add(k)));
}
}
let out_ptr = output.as_mut_ptr() as *mut __m128i;
for k in 0..num_regs {
_mm_storeu_si128(out_ptr.add(k), acc[k]);
}
} else {
output.copy_from_slice(&self.biases);
let num_chunks = Self::PADDED_INPUT / 16;
let input_ptr = input.as_ptr();
let weight_ptr = self.weights.as_ptr();
for (j, out) in output.iter_mut().enumerate() {
let mut acc_simd = _mm_setzero_si128();
let row_offset = j * Self::PADDED_INPUT;
for chunk in 0..num_chunks {
let in_vec = _mm_loadu_si128(input_ptr.add(chunk * 16) as *const __m128i);
let w_vec = _mm_load_si128(
weight_ptr.add(row_offset + chunk * 16) as *const __m128i
);
m128_add_dpbusd_epi32(&mut acc_simd, in_vec, w_vec);
}
*out += hsum_i32_sse2(acc_simd);
}
}
}
}
}
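/// Complete HalfKA_hm network: feature transformer (L1 per perspective,
/// FT_OUT = L1 * 2 in total), activation into L1_INPUT u8 values (L1 * 2 for
/// CReLU/SCReLU, L1 for pairwise), then two hidden affine layers (L2, L3)
/// and a single scaled output neuron.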
pub struct NetworkHalfKA_hm<
const L1: usize,
const FT_OUT: usize,
const L1_INPUT: usize,
const L2: usize,
const L3: usize,
A: FtActivation,
> {
pub feature_transformer: FeatureTransformerHalfKA_hm<L1>,
pub l1: AffineTransformHalfKA_hm<L1_INPUT, L2>,
pub l2: AffineTransformHalfKA_hm<L2, L3>,
pub output: AffineTransformHalfKA_hm<L3, 1>,
pub fv_scale: i32,
pub qa: i16,
_activation: PhantomData<A>,
}
impl<
const L1: usize,
const FT_OUT: usize,
const L1_INPUT: usize,
const L2: usize,
const L3: usize,
A: FtActivation,
> NetworkHalfKA_hm<L1, FT_OUT, L1_INPUT, L2, L3, A>
{
const _ASSERT_DIMS: () = {
assert!(FT_OUT == L1 * 2, "FT_OUT must equal L1 * 2");
assert!(
L1_INPUT == L1 * 2 || L1_INPUT == L1,
"L1_INPUT must equal L1 * 2 (CReLU/SCReLU) or L1 (Pairwise)"
);
};
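/// Reads a network in the serialized NNUE layout: version tag, file hash,
/// architecture string, then the feature transformer and the network layers,
/// each preceded by a 32-bit section hash that is read and discarded.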
pub fn read<R: Read + Seek>(reader: &mut R) -> io::Result<Self> {
// Force the compile-time dimension checks: an associated const is only
// evaluated when referenced, so without this line the asserts in
// `_ASSERT_DIMS` would be silently skipped.
let _ = Self::_ASSERT_DIMS;
let mut buf4 = [0u8; 4];
reader.read_exact(&mut buf4)?;
let version = u32::from_le_bytes(buf4);
if version != 0x7AF32F16 && version != NNUE_VERSION_HALFKA {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Unknown NNUE version: {version:#x}"),
));
}
reader.read_exact(&mut buf4)?;
reader.read_exact(&mut buf4)?;
let arch_len = u32::from_le_bytes(buf4) as usize;
if arch_len == 0 || arch_len > MAX_ARCH_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid arch string length: {arch_len}"),
));
}
let mut arch = vec![0u8; arch_len];
reader.read_exact(&mut arch)?;
let arch_str = String::from_utf8_lossy(&arch);
if arch_str.contains("Factorizer") {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Unsupported model format: factorized (non-coalesced) HalfKA_hm^ model detected.\n\
This engine only supports coalesced models (73,305 dimensions).\n\
Factorized models (74,934 dimensions) are for training only.\n\n\
To fix: Re-export the model using nnue-pytorch serialize.py:\n\
python serialize.py model.ckpt output.nnue\n\n\
The serialize.py script automatically coalesces factor weights.\n\
Architecture string: {arch_str}"
),
));
}
let fv_scale = parse_fv_scale_from_arch(&arch_str).unwrap_or(FV_SCALE_HALFKA);
let qa = parse_qa_from_arch(&arch_str).unwrap_or(127);
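// Section hash preceding the feature transformer (read and discarded).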
reader.read_exact(&mut buf4)?;
let feature_transformer = FeatureTransformerHalfKA_hm::read(reader)?;
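// Section hash preceding the network layers (read and discarded).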
reader.read_exact(&mut buf4)?;
let l1 = AffineTransformHalfKA_hm::read(reader)?;
let l2 = AffineTransformHalfKA_hm::read(reader)?;
let output = AffineTransformHalfKA_hm::read(reader)?;
Ok(Self {
feature_transformer,
l1,
l2,
output,
fv_scale,
qa,
_activation: PhantomData,
})
}
pub fn refresh_accumulator(&self, pos: &Position, acc: &mut AccumulatorHalfKA_hm<L1>) {
self.feature_transformer.refresh_accumulator(pos, acc);
}
pub fn update_accumulator(
&self,
pos: &Position,
dirty_piece: &DirtyPiece,
acc: &mut AccumulatorHalfKA_hm<L1>,
prev_acc: &AccumulatorHalfKA_hm<L1>,
) {
self.feature_transformer.update_accumulator(pos, dirty_piece, acc, prev_acc);
}
pub fn forward_update_incremental(
&self,
pos: &Position,
stack: &mut AccumulatorStackHalfKA_hm<L1>,
source_idx: usize,
) -> bool {
self.feature_transformer.forward_update_incremental(pos, stack, source_idx)
}
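/// Runs the full forward pass: raw feature-transformer output, activation to
/// u8, two hidden affine layers with activation, then the output neuron
/// divided by `fv_scale` into an engine `Value`.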
pub fn evaluate(&self, pos: &Position, acc: &AccumulatorHalfKA_hm<L1>) -> Value {
let debug = nnue_debug_enabled();
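// SAFETY: the buffer is fully written by `transform_raw` before it is read;
// the same write-before-read pattern holds for the uninitialized buffers below.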
let mut ft_out_i16: Aligned<[i16; FT_OUT]> = unsafe { Aligned::new_uninit() };
self.feature_transformer
.transform_raw(acc, pos.side_to_move(), &mut ft_out_i16.0);
if debug {
let ft_min = ft_out_i16.0.iter().min().copied().unwrap_or(0);
let ft_max = ft_out_i16.0.iter().max().copied().unwrap_or(0);
let ft_sum: i64 = ft_out_i16.0.iter().map(|&x| x as i64).sum();
eprintln!(
"[DEBUG] FT output: min={ft_min}, max={ft_max}, sum={ft_sum}, len={}",
ft_out_i16.0.len()
);
eprintln!("[DEBUG] FT[0..8]: {:?}", &ft_out_i16.0[0..8]);
}
let mut transformed: Aligned<[u8; L1_INPUT]> = unsafe { Aligned::new_uninit() };
A::activate_i16_to_u8(&ft_out_i16.0, &mut transformed.0, self.qa);
if debug {
let t_min = transformed.0.iter().min().copied().unwrap_or(0);
let t_max = transformed.0.iter().max().copied().unwrap_or(0);
let t_sum: u64 = transformed.0.iter().map(|&x| x as u64).sum();
eprintln!(
"[DEBUG] After activation ({} i16→u8): min={t_min}, max={t_max}, sum={t_sum}, len={}",
A::name(),
transformed.0.len()
);
eprintln!("[DEBUG] transformed[0..16]: {:?}", &transformed.0[0..16]);
}
let mut l1_out: Aligned<[i32; L2]> = unsafe { Aligned::new_uninit() };
self.l1.propagate(&transformed.0, &mut l1_out.0);
if debug {
eprintln!("[DEBUG] L1 output: {:?}", &l1_out.0);
eprintln!(
"[DEBUG] L1 biases[0..8]: {:?}",
&self.l1.biases[0..8.min(self.l1.biases.len())]
);
}
#[cfg(debug_assertions)]
for (i, &v) in l1_out.0.iter().enumerate() {
debug_assert!(
v.abs() < 1_000_000,
"L1 output[{i}] = {v} is out of expected range (NetworkHalfKA_hm<{}, {}, {}, {}>)",
L1,
L2,
L3,
A::name()
);
}
let mut l1_relu: Aligned<[u8; L2]> = unsafe { Aligned::new_uninit() };
A::activate_i32_to_u8(&l1_out.0, &mut l1_relu.0);
let mut l2_out: Aligned<[i32; L3]> = unsafe { Aligned::new_uninit() };
self.l2.propagate(&l1_relu.0, &mut l2_out.0);
#[cfg(debug_assertions)]
for (i, &v) in l2_out.0.iter().enumerate() {
debug_assert!(
v.abs() < 1_000_000,
"L2 output[{i}] = {v} is out of expected range (NetworkHalfKA_hm<{}, {}, {}, {}>)",
L1,
L2,
L3,
A::name()
);
}
let mut l2_relu: Aligned<[u8; L3]> = unsafe { Aligned::new_uninit() };
A::activate_i32_to_u8(&l2_out.0, &mut l2_relu.0);
let mut output = [0i32; 1];
self.output.propagate(&l2_relu.0, &mut output);
let fv_scale = get_fv_scale_override().unwrap_or(self.fv_scale);
let eval = output[0] / fv_scale;
debug_assert!(
eval.abs() < 50_000,
"Final evaluation {eval} is out of expected range (NetworkHalfKA_hm<{}, {}, {}, {}>). Raw output: {}",
L1,
L2,
L3,
A::name(),
output[0]
);
Value::new(eval)
}
pub fn activation_name(&self) -> &'static str {
A::name()
}
pub fn new_accumulator(&self) -> AccumulatorHalfKA_hm<L1> {
AccumulatorHalfKA_hm::new()
}
pub fn new_accumulator_stack(&self) -> AccumulatorStackHalfKA_hm<L1> {
AccumulatorStackHalfKA_hm::new()
}
pub fn architecture_name(&self) -> String {
format!("HalfKA_hm^{}x2-{}-{}-{}", L1, L2, L3, A::name())
}
}
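/// Parses the activation quantization constant from the architecture string,
/// e.g. `"...,qa=255,qb=64"` yields `Some(255)`; returns `None` when no
/// `qa=` field is present.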
fn parse_qa_from_arch(arch_str: &str) -> Option<i16> {
if let Some(start) = arch_str.find("qa=") {
let rest = &arch_str[start + 3..];
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
rest[..end].parse().ok()
} else {
None
}
}
use super::activation::CReLU;
pub type HalfKA_hm256CReLU = NetworkHalfKA_hm<256, 512, 512, 32, 32, CReLU>;
pub type HalfKA_hm512_8_64CReLU = NetworkHalfKA_hm<512, 1024, 1024, 8, 64, CReLU>;
pub type HalfKA_hm512CReLU = NetworkHalfKA_hm<512, 1024, 1024, 8, 96, CReLU>;
pub type HalfKA_hm512_32_32CReLU = NetworkHalfKA_hm<512, 1024, 1024, 32, 32, CReLU>;
pub type HalfKA_hm1024_8_64CReLU = NetworkHalfKA_hm<1024, 2048, 2048, 8, 64, CReLU>;
pub type HalfKA_hm1024CReLU = NetworkHalfKA_hm<1024, 2048, 2048, 8, 96, CReLU>;
pub type HalfKA_hm1024_8_32CReLU = NetworkHalfKA_hm<1024, 2048, 2048, 8, 32, CReLU>;
pub type HalfKA_hm768CReLU = NetworkHalfKA_hm<768, 1536, 1536, 16, 64, CReLU>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_accumulator_halfka_256() {
let mut acc = AccumulatorHalfKA_hm::<256>::new();
assert_eq!(acc.accumulation[0].len(), 256);
assert!(!acc.computed_accumulation);
acc.accumulation[0][0] = 100;
acc.computed_accumulation = true;
let cloned = acc.clone();
assert_eq!(cloned.accumulation[0][0], 100);
assert!(cloned.computed_accumulation);
}
#[test]
fn test_accumulator_halfka_512() {
let acc = AccumulatorHalfKA_hm::<512>::new();
assert_eq!(acc.accumulation[0].len(), 512);
}
#[test]
fn test_accumulator_halfka_1024() {
let acc = AccumulatorHalfKA_hm::<1024>::new();
assert_eq!(acc.accumulation[0].len(), 1024);
}
#[test]
fn test_padded_input() {
assert_eq!(AffineTransformHalfKA_hm::<8, 96>::PADDED_INPUT, 32);
assert_eq!(AffineTransformHalfKA_hm::<32, 96>::PADDED_INPUT, 32);
assert_eq!(AffineTransformHalfKA_hm::<33, 96>::PADDED_INPUT, 64);
assert_eq!(AffineTransformHalfKA_hm::<96, 1>::PADDED_INPUT, 96);
assert_eq!(AffineTransformHalfKA_hm::<1024, 8>::PADDED_INPUT, 1024);
assert_eq!(AffineTransformHalfKA_hm::<2048, 8>::PADDED_INPUT, 2048);
}
#[test]
fn test_parse_qa_from_arch() {
assert_eq!(
parse_qa_from_arch(
"Features=HalfKA_hm[73305->512x2],fv_scale=5,l1_input=1024,l2=8,l3=96,qa=255,qb=64"
),
Some(255)
);
assert_eq!(
parse_qa_from_arch(
"Features=HalfKA_hm[73305->512x2],fv_scale=5,l1_input=1024,l2=8,l3=96,qa=127,qb=64"
),
Some(127)
);
assert_eq!(
parse_qa_from_arch("Features=HalfKA_hm[73305->512x2],Network=AffineTransform[1<-96]"),
None
);
}
#[test]
fn test_type_aliases() {
fn _check_halfka_256_crelu(_: HalfKA_hm256CReLU) {}
fn _check_halfka_512_crelu(_: HalfKA_hm512CReLU) {}
fn _check_halfka_1024_crelu(_: HalfKA_hm1024CReLU) {}
}
}