use super::accumulator_layer_stacks::{AccumulatorCacheLayerStacks, AccumulatorStackLayerStacks};
use super::accumulator_stack_variant::AccumulatorStackVariant;
use super::activation::detect_activation_from_arch;
use super::bona_piece::BonaPiece;
use super::bona_piece_halfka_hm::FE_OLD_END;
use super::constants::{MAX_ARCH_LEN, NNUE_VERSION, NNUE_VERSION_HALFKA};
use super::halfka::{HalfKANetwork, HalfKAStack};
use super::halfka_hm::{HalfKA_hmNetwork, HalfKA_hmStack};
use super::halfkp::{HalfKPNetwork, HalfKPStack};
use super::network_layer_stacks::NetworkLayerStacks;
use super::spec::{Activation, FeatureSet};
use super::stats::{count_already_computed, count_refresh, count_update};
use crate::eval::material;
use crate::position::Position;
use crate::types::{Color, PieceType, Value};
use std::cell::Cell;
use std::fs::File;
use std::io::{self, BufReader, Cursor, Read, Seek, SeekFrom};
use std::path::Path;
use std::sync::atomic::{AtomicI32, AtomicPtr, Ordering};
use std::sync::{Arc, LazyLock, RwLock};
// Process-wide NNUE network, installed by `init_nnue`/`init_nnue_from_bytes`
// and handed out as a cheap `Arc` clone by `get_network`.
static NETWORK: LazyLock<RwLock<Option<Arc<NNUENetwork>>>> = LazyLock::new(|| RwLock::new(None));
// FV-scale override; 0 (or negative, clamped on store) means "no override".
static FV_SCALE_OVERRIDE: AtomicI32 = AtomicI32::new(0);
/// Strategy used to select the layer-stack output bucket for a position.
///
/// The discriminants are stable: they are stored in the global
/// `LAYER_STACK_BUCKET_MODE` atomic and decoded in
/// `get_layer_stack_bucket_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerStackBucketMode {
    /// Bucket by king rank (9 buckets).
    KingRank9 = 0,
    /// Bucket by game ply using `LAYER_STACK_PLY_BOUNDS` (9 buckets).
    Ply9 = 1,
    /// Bucket by a 6-feature logistic progress model (8 buckets).
    Progress8 = 2,
    /// Bucket by the 34-feature "Gikou-lite" progress model (8 buckets).
    Progress8Gikou = 3,
    /// Bucket by KP-absolute weighted progress sum (8 buckets).
    Progress8KPAbs = 4,
}
impl LayerStackBucketMode {
pub fn as_str(self) -> &'static str {
match self {
Self::KingRank9 => "kingrank9",
Self::Ply9 => "ply9",
Self::Progress8 => "progress8",
Self::Progress8Gikou => "progress8gikou",
Self::Progress8KPAbs => "progress8kpabs",
}
}
}
/// Default inclusive upper ply bounds for the first 8 `Ply9` buckets; plies
/// above the last bound fall into bucket 8.
pub const LAYER_STACK_PLY9_DEFAULT_BOUNDS: [u16; 8] = [30, 44, 58, 72, 86, 100, 116, 138];
/// Number of input features of the `Progress8` logistic model.
pub const SHOGI_PROGRESS8_NUM_FEATURES: usize = 6;
/// Number of output buckets shared by all `Progress8*` modes.
pub const SHOGI_PROGRESS8_NUM_BUCKETS: usize = 8;
/// Feature names, in the exact order used by
/// `compute_layer_stack_progress8_bucket_index`.
pub const SHOGI_PROGRESS8_FEATURE_ORDER: [&str; SHOGI_PROGRESS8_NUM_FEATURES] = [
    "x_board_non_king",
    "x_hand_total",
    "x_major_board",
    "x_promoted_board",
    "x_stm_king_rank_rel",
    "x_ntm_king_rank_rel",
];
/// Number of input features of the `Progress8Gikou` ("Gikou-lite") model.
pub const SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES: usize = 34;
/// Weight-table size for `Progress8KPAbs`: one row per king square (81),
/// `FE_OLD_END` entries per row.
pub const SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS: usize = 81 * FE_OLD_END;
/// Ascending thresholds splitting the progress sum into 8 buckets via
/// `progress_sum_to_bucket`. The values appear to be logit(k/8) for
/// k = 1..=7 — TODO confirm against the training pipeline.
const PROGRESS_BUCKET_THRESHOLDS: [f32; 7] = [
    -1.945_910_1, -1.098_612_3, -0.510_825_6, 0.0, 0.510_825_6, 1.098_612_3, 1.945_910_1, ];
thread_local! {
    // One-shot, per-thread bucket cache consumed (and cleared) by
    // `compute_layer_stack_progress8kpabs_bucket_index`.
    static CACHED_PROGRESS_BUCKET: Cell<Option<usize>> = const { Cell::new(None) };
}
/// Feature names, in the exact order used by
/// `compute_layer_stack_progress8gikou_bucket_index`: the 6 base features,
/// then king-distance histograms (all pieces / majors, per side, to own and
/// opposing king, 3 distance bins each), then per-side hand counts.
pub const SHOGI_PROGRESS_GIKOU_LITE_FEATURE_ORDER: [&str; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES] = [
    "x_board_non_king",
    "x_hand_total",
    "x_major_board",
    "x_promoted_board",
    "x_stm_king_rank_rel",
    "x_ntm_king_rank_rel",
    "x_stm_all_to_own_king_d1",
    "x_stm_all_to_own_king_d2",
    "x_stm_all_to_own_king_d3p",
    "x_stm_all_to_opp_king_d1",
    "x_stm_all_to_opp_king_d2",
    "x_stm_all_to_opp_king_d3p",
    "x_ntm_all_to_own_king_d1",
    "x_ntm_all_to_own_king_d2",
    "x_ntm_all_to_own_king_d3p",
    "x_ntm_all_to_opp_king_d1",
    "x_ntm_all_to_opp_king_d2",
    "x_ntm_all_to_opp_king_d3p",
    "x_stm_major_to_own_king_d1",
    "x_stm_major_to_own_king_d2",
    "x_stm_major_to_own_king_d3p",
    "x_stm_major_to_opp_king_d1",
    "x_stm_major_to_opp_king_d2",
    "x_stm_major_to_opp_king_d3p",
    "x_ntm_major_to_own_king_d1",
    "x_ntm_major_to_own_king_d2",
    "x_ntm_major_to_own_king_d3p",
    "x_ntm_major_to_opp_king_d1",
    "x_ntm_major_to_opp_king_d2",
    "x_ntm_major_to_opp_king_d3p",
    "x_stm_hand_total",
    "x_ntm_hand_total",
    "x_stm_hand_major",
    "x_ntm_hand_major",
];
/// Coefficients of the 6-feature logistic progress model (`Progress8` mode).
/// Features are standardized as `(x - mean) / std` before applying weights.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerStackProgressCoeff {
    /// Per-feature mean used for standardization.
    pub mean: [f32; SHOGI_PROGRESS8_NUM_FEATURES],
    /// Per-feature standard deviation (non-positive entries are treated as 1).
    pub std: [f32; SHOGI_PROGRESS8_NUM_FEATURES],
    /// Logistic-regression weights applied to the standardized features.
    pub weights: [f32; SHOGI_PROGRESS8_NUM_FEATURES],
    /// Logistic-regression intercept.
    pub bias: f32,
    /// `[min, max]` clamp applied to the pre-sigmoid score.
    pub z_clip: [f32; 2],
}
/// Coefficients of the 34-feature "Gikou-lite" logistic progress model
/// (`Progress8Gikou` mode). Same standardization scheme as
/// `LayerStackProgressCoeff`, just with more features.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerStackProgressCoeffGikouLite {
    /// Per-feature mean used for standardization.
    pub mean: [f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
    /// Per-feature standard deviation (non-positive entries are treated as 1).
    pub std: [f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
    /// Logistic-regression weights applied to the standardized features.
    pub weights: [f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
    /// Logistic-regression intercept.
    pub bias: f32,
    /// `[min, max]` clamp applied to the pre-sigmoid score.
    pub z_clip: [f32; 2],
}
impl LayerStackProgressCoeffGikouLite {
pub const fn new(
mean: [f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
std: [f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
weights: [f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
bias: f32,
z_clip: [f32; 2],
) -> Self {
Self {
mean,
std,
weights,
bias,
z_clip,
}
}
}
impl LayerStackProgressCoeff {
pub const fn new(
mean: [f32; SHOGI_PROGRESS8_NUM_FEATURES],
std: [f32; SHOGI_PROGRESS8_NUM_FEATURES],
weights: [f32; SHOGI_PROGRESS8_NUM_FEATURES],
bias: f32,
z_clip: [f32; 2],
) -> Self {
Self {
mean,
std,
weights,
bias,
z_clip,
}
}
}
impl Default for LayerStackProgressCoeff {
    /// Baked-in coefficients for the 6-feature model. Presumably fitted
    /// offline on game data — TODO confirm provenance; feature order matches
    /// `SHOGI_PROGRESS8_FEATURE_ORDER`.
    fn default() -> Self {
        Self {
            mean: [30.12, 8.45, 2.18, 1.63, 6.71, 6.24],
            std: [3.77, 4.02, 0.66, 1.40, 1.31, 1.27],
            weights: [-0.81, 0.56, -0.32, 0.48, 0.11, -0.09],
            bias: -0.15,
            z_clip: [-8.0, 8.0],
        }
    }
}
impl Default for LayerStackProgressCoeffGikouLite {
    /// Neutral coefficients: zero weights and bias with unit std, so the
    /// model outputs sigmoid(0) = 0.5 (bucket 4) until real coefficients are
    /// installed via `set_layer_stack_progress_coeff_gikou_lite`.
    fn default() -> Self {
        Self {
            mean: [0.0; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
            std: [1.0; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
            weights: [0.0; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES],
            bias: 0.0,
            z_clip: [-8.0, 8.0],
        }
    }
}
// Active bucket mode, stored as the enum discriminant.
static LAYER_STACK_BUCKET_MODE: AtomicI32 = AtomicI32::new(LayerStackBucketMode::KingRank9 as i32);
// Mutable copy of the ply bounds; one atomic per slot so readers never block.
static LAYER_STACK_PLY_BOUNDS: [AtomicI32; 8] = [
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[0] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[1] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[2] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[3] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[4] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[5] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[6] as i32),
    AtomicI32::new(LAYER_STACK_PLY9_DEFAULT_BOUNDS[7] as i32),
];
// Current coefficients for the two logistic progress models.
static LAYER_STACK_PROGRESS_COEFF: LazyLock<RwLock<LayerStackProgressCoeff>> =
    LazyLock::new(|| RwLock::new(LayerStackProgressCoeff::default()));
static LAYER_STACK_PROGRESS_COEFF_GIKOU_LITE: LazyLock<RwLock<LayerStackProgressCoeffGikouLite>> =
    LazyLock::new(|| RwLock::new(LayerStackProgressCoeffGikouLite::default()));
// All-zero fallback table returned while no KP-abs weights are installed.
static LAYER_STACK_PROGRESS_KP_ABS_ZERO_WEIGHTS: [f32; SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS] =
    [0.0; SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS];
// Pointer to a leaked Box<[f32]> of SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS
// elements, or null (see set/reset_layer_stack_progress_kpabs_weights).
static LAYER_STACK_PROGRESS_KP_ABS_PTR: AtomicPtr<f32> = AtomicPtr::new(std::ptr::null_mut());
/// Returns the configured FV-scale override, or `None` when unset
/// (the stored sentinel value is 0).
pub fn get_fv_scale_override() -> Option<i32> {
    let stored = FV_SCALE_OVERRIDE.load(Ordering::Relaxed);
    (stored > 0).then_some(stored)
}
/// Sets the FV-scale override; negative inputs are clamped to 0, which
/// means "no override".
pub fn set_fv_scale_override(value: i32) {
    let clamped = if value < 0 { 0 } else { value };
    FV_SCALE_OVERRIDE.store(clamped, Ordering::Relaxed);
}
/// Decodes the globally configured bucket mode, falling back to
/// `KingRank9` for any unrecognized stored value.
pub fn get_layer_stack_bucket_mode() -> LayerStackBucketMode {
    match LAYER_STACK_BUCKET_MODE.load(Ordering::Relaxed) {
        1 => LayerStackBucketMode::Ply9,
        2 => LayerStackBucketMode::Progress8,
        3 => LayerStackBucketMode::Progress8Gikou,
        4 => LayerStackBucketMode::Progress8KPAbs,
        _ => LayerStackBucketMode::KingRank9,
    }
}
/// Stores `mode` (as its discriminant) as the global bucket mode.
pub fn set_layer_stack_bucket_mode(mode: LayerStackBucketMode) {
    LAYER_STACK_BUCKET_MODE.store(mode as i32, Ordering::Relaxed);
}
/// Snapshot of the current `Ply9` bucket bounds; negative stored values
/// (not produced by `set_layer_stack_ply_bounds`) read back as 0.
pub fn get_layer_stack_ply_bounds() -> [u16; 8] {
    std::array::from_fn(|slot| {
        let raw = LAYER_STACK_PLY_BOUNDS[slot].load(Ordering::Relaxed);
        if raw.is_negative() { 0 } else { raw as u16 }
    })
}
/// Installs a new set of `Ply9` bucket bounds.
pub fn set_layer_stack_ply_bounds(bounds: [u16; 8]) {
    for (i, slot) in LAYER_STACK_PLY_BOUNDS.iter().enumerate() {
        slot.store(i32::from(bounds[i]), Ordering::Relaxed);
    }
}
/// Copy of the current `Progress8` coefficients; a poisoned lock is
/// recovered rather than propagated (the data is plain-old-data).
pub fn get_layer_stack_progress_coeff() -> LayerStackProgressCoeff {
    *LAYER_STACK_PROGRESS_COEFF.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Replaces the `Progress8` coefficients; a poisoned lock is recovered
/// rather than propagated (the data is plain-old-data).
pub fn set_layer_stack_progress_coeff(coeff: LayerStackProgressCoeff) {
    *LAYER_STACK_PROGRESS_COEFF.write().unwrap_or_else(|poisoned| poisoned.into_inner()) = coeff;
}
/// Copy of the current `Progress8Gikou` coefficients; a poisoned lock is
/// recovered rather than propagated (the data is plain-old-data).
pub fn get_layer_stack_progress_coeff_gikou_lite() -> LayerStackProgressCoeffGikouLite {
    *LAYER_STACK_PROGRESS_COEFF_GIKOU_LITE
        .read()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Replaces the `Progress8Gikou` coefficients; a poisoned lock is
/// recovered rather than propagated (the data is plain-old-data).
pub fn set_layer_stack_progress_coeff_gikou_lite(coeff: LayerStackProgressCoeffGikouLite) {
    *LAYER_STACK_PROGRESS_COEFF_GIKOU_LITE
        .write()
        .unwrap_or_else(|poisoned| poisoned.into_inner()) = coeff;
}
/// Returns the installed `Progress8KPAbs` weight table, or an all-zero
/// table when none has been set.
///
/// NOTE(review): the returned `'static` slice is only valid while the
/// current table stays installed — `set_`/`reset_` free the old buffer, so a
/// caller holding this slice across such a call would dangle. Verify all
/// call sites re-fetch rather than cache.
pub fn get_layer_stack_progress_kpabs_weights() -> &'static [f32] {
    let ptr = LAYER_STACK_PROGRESS_KP_ABS_PTR.load(Ordering::Relaxed);
    if ptr.is_null() {
        &LAYER_STACK_PROGRESS_KP_ABS_ZERO_WEIGHTS
    } else {
        // SAFETY: non-null values stored in LAYER_STACK_PROGRESS_KP_ABS_PTR
        // always come from a leaked Box<[f32]> whose length was verified to
        // be SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS in
        // set_layer_stack_progress_kpabs_weights.
        unsafe { std::slice::from_raw_parts(ptr.cast_const(), SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS) }
    }
}
/// Installs a new `Progress8KPAbs` weight table, freeing any previously
/// installed one.
///
/// # Errors
/// Returns an error if `weights` does not contain exactly
/// `SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS` entries.
pub fn set_layer_stack_progress_kpabs_weights(weights: Box<[f32]>) -> Result<(), String> {
    if weights.len() != SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS {
        return Err(format!(
            "progress8kpabs weights length mismatch: got {}, expected {}",
            weights.len(),
            SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS
        ));
    }
    // Leak the box so readers can hand out 'static slices to it.
    let leaked = Box::leak(weights);
    let old_ptr = LAYER_STACK_PROGRESS_KP_ABS_PTR.swap(leaked.as_mut_ptr(), Ordering::Relaxed);
    if !old_ptr.is_null() {
        // SAFETY: a non-null old pointer was produced by Box::leak of a
        // Box<[f32]> of exactly SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS elements,
        // so reconstructing and dropping the box is sound.
        // NOTE(review): this frees the buffer immediately after the swap —
        // a concurrent reader that obtained the old slice via
        // get_layer_stack_progress_kpabs_weights could still be using it.
        // Confirm updates only happen while no search threads are running.
        unsafe {
            drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                old_ptr,
                SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS,
            )));
        }
    }
    Ok(())
}
/// Removes any installed `Progress8KPAbs` weight table, reverting readers
/// to the all-zero fallback, and frees the old buffer.
pub fn reset_layer_stack_progress_kpabs_weights() {
    let old_ptr = LAYER_STACK_PROGRESS_KP_ABS_PTR.swap(std::ptr::null_mut(), Ordering::Relaxed);
    if !old_ptr.is_null() {
        // SAFETY: a non-null pointer here was produced by Box::leak of a
        // Box<[f32]> of exactly SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS elements
        // (see set_layer_stack_progress_kpabs_weights), so reconstructing
        // and dropping the box is sound. Same concurrent-reader caveat as in
        // set_layer_stack_progress_kpabs_weights applies.
        unsafe {
            drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                old_ptr,
                SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS,
            )));
        }
    }
}
/// A loaded NNUE network, dispatched by feature-set architecture.
pub enum NNUENetwork {
    /// HalfKA feature set.
    HalfKA(HalfKANetwork),
    #[allow(non_camel_case_types)]
    /// HalfKA_hm (half-mirrored) feature set.
    HalfKA_hm(HalfKA_hmNetwork),
    /// HalfKP feature set.
    HalfKP(HalfKPNetwork),
    /// Layer-stacks architecture (boxed: the network is large).
    LayerStacks(Box<NetworkLayerStacks>),
}
impl NNUENetwork {
    /// Architecture specs the HalfKP backend can load.
    pub fn supported_halfkp_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKPNetwork::supported_specs()
    }
    /// Architecture specs the HalfKA_hm backend can load.
    pub fn supported_halfka_hm_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKA_hmNetwork::supported_specs()
    }
    /// Architecture specs the HalfKA backend can load.
    pub fn supported_halfka_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKANetwork::supported_specs()
    }
    /// Loads a network from a file on disk via a buffered reader.
    pub fn load<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);
        Self::read(&mut reader)
    }
    /// Parses a network from a seekable reader.
    ///
    /// Header layout (little-endian u32 each): version, a second 4-byte
    /// field that is read and discarded (presumably a hash — TODO confirm),
    /// the architecture-string length, then the string itself. The total
    /// stream length is then used to pick the concrete architecture, since
    /// differently-sized networks can share a header.
    ///
    /// # Errors
    /// `InvalidData` for an unknown version, a bad arch-string length, or
    /// when no known architecture matches the stream size.
    pub fn read<R: Read + Seek>(reader: &mut R) -> io::Result<Self> {
        // Total size is needed for size-based architecture detection.
        let file_size = reader.seek(SeekFrom::End(0))?;
        reader.seek(SeekFrom::Start(0))?;
        let mut buf4 = [0u8; 4];
        reader.read_exact(&mut buf4)?;
        let version = u32::from_le_bytes(buf4);
        match version {
            NNUE_VERSION | NNUE_VERSION_HALFKA => {
                // First read skips the 4-byte field after the version; the
                // second read yields the arch-string length.
                reader.read_exact(&mut buf4)?; reader.read_exact(&mut buf4)?; let arch_len = u32::from_le_bytes(buf4) as usize;
                if arch_len == 0 || arch_len > MAX_ARCH_LEN {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("Invalid arch string length: {arch_len}"),
                    ));
                }
                let mut arch = vec![0u8; arch_len];
                reader.read_exact(&mut arch)?;
                let arch_str = String::from_utf8_lossy(&arch);
                // The activation function is encoded in the arch string.
                let activation_str = detect_activation_from_arch(&arch_str);
                let activation = match activation_str {
                    "SCReLU" => Activation::SCReLU,
                    "PairwiseCReLU" => Activation::PairwiseCReLU,
                    _ => Activation::CReLU,
                };
                let parsed = super::spec::parse_architecture(&arch_str)
                    .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?;
                // LayerStacks has its own full-file reader; hand the rewound
                // stream over to it.
                if parsed.feature_set == FeatureSet::LayerStacks {
                    reader.seek(SeekFrom::Start(0))?;
                    let network = NetworkLayerStacks::read(reader)?;
                    return Ok(Self::LayerStacks(Box::new(network)));
                }
                // Resolve the concrete layer sizes from the file size; on
                // failure, report the closest size candidates to aid debugging.
                let detection = super::spec::detect_architecture_from_size(
                    file_size,
                    arch_len,
                    Some(parsed.feature_set),
                )
                .ok_or_else(|| {
                    let candidates = super::spec::list_candidate_architectures(file_size, arch_len);
                    let candidates_str: Vec<String> = candidates
                        .iter()
                        .take(5)
                        .map(|(spec, diff)| format!("{} (diff: {:+})", spec.name(), diff))
                        .collect();
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!(
                            "Unknown architecture: file_size={}, arch_len={}, feature_set={}. \
Closest candidates: [{}]",
                            file_size,
                            arch_len,
                            parsed.feature_set,
                            candidates_str.join(", ")
                        ),
                    )
                })?;
                // Each backend re-reads the whole stream from the start.
                reader.seek(SeekFrom::Start(0))?;
                let l1 = detection.spec.l1;
                let l2 = detection.spec.l2;
                let l3 = detection.spec.l3;
                match detection.spec.feature_set {
                    FeatureSet::HalfKA_hm => {
                        let network = HalfKA_hmNetwork::read(reader, l1, l2, l3, activation)?;
                        Ok(Self::HalfKA_hm(network))
                    }
                    FeatureSet::HalfKA => {
                        let network = HalfKANetwork::read(reader, l1, l2, l3, activation)?;
                        Ok(Self::HalfKA(network))
                    }
                    FeatureSet::HalfKP => {
                        let network = HalfKPNetwork::read(reader, l1, l2, l3, activation)?;
                        Ok(Self::HalfKP(network))
                    }
                    FeatureSet::LayerStacks => {
                        // Handled by the early return above.
                        unreachable!()
                    }
                }
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Unknown NNUE version: {version:#x}. Expected {NNUE_VERSION:#x} (HalfKP) or {NNUE_VERSION_HALFKA:#x} (HalfKA_hm^)"
                ),
            )),
        }
    }
    /// Parses a network from an in-memory byte slice.
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let mut cursor = Cursor::new(bytes);
        Self::read(&mut cursor)
    }
    /// Whether this network uses the LayerStacks architecture.
    pub fn is_layer_stacks(&self) -> bool {
        matches!(self, Self::LayerStacks(_))
    }
    /// Whether this network uses the HalfKA architecture.
    pub fn is_halfka(&self) -> bool {
        matches!(self, Self::HalfKA(_))
    }
    /// Whether this network uses the HalfKA_hm architecture.
    pub fn is_halfka_hm(&self) -> bool {
        matches!(self, Self::HalfKA_hm(_))
    }
    /// Whether this network uses the HalfKP architecture.
    pub fn is_halfkp(&self) -> bool {
        matches!(self, Self::HalfKP(_))
    }
    /// First hidden-layer width. LayerStacks is fixed at 1536.
    pub fn l1_size(&self) -> usize {
        match self {
            Self::HalfKA(net) => net.l1_size(),
            Self::HalfKA_hm(net) => net.l1_size(),
            Self::HalfKP(net) => net.l1_size(),
            Self::LayerStacks(_) => 1536,
        }
    }
    /// Human-readable architecture name.
    pub fn architecture_name(&self) -> &'static str {
        match self {
            Self::HalfKA(net) => net.architecture_name(),
            Self::HalfKA_hm(net) => net.architecture_name(),
            Self::HalfKP(net) => net.architecture_name(),
            Self::LayerStacks(_) => "LayerStacks",
        }
    }
    /// Full architecture spec. For LayerStacks a synthetic spec is built
    /// (l1 = 1536, l2 = l3 = 0, CReLU).
    pub fn architecture_spec(&self) -> super::spec::ArchitectureSpec {
        match self {
            Self::HalfKA(net) => net.architecture_spec(),
            Self::HalfKA_hm(net) => net.architecture_spec(),
            Self::HalfKP(net) => net.architecture_spec(),
            Self::LayerStacks(_) => super::spec::ArchitectureSpec::new(
                super::spec::FeatureSet::LayerStacks,
                1536,
                0,
                0,
                Activation::CReLU,
            ),
        }
    }
    /// Fully recomputes `acc` for `pos`. Panics unless this is a
    /// LayerStacks network.
    pub fn refresh_accumulator_layer_stacks(
        &self,
        pos: &Position,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.refresh_accumulator(pos, acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Incrementally updates `acc` from `prev_acc` using `dirty_piece`.
    /// Panics unless this is a LayerStacks network.
    pub fn update_accumulator_layer_stacks(
        &self,
        pos: &Position,
        dirty_piece: &super::accumulator::DirtyPiece,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
        prev_acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.update_accumulator(pos, dirty_piece, acc, prev_acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Cache-assisted variant of `refresh_accumulator_layer_stacks`.
    /// Panics unless this is a LayerStacks network.
    pub fn refresh_accumulator_layer_stacks_with_cache(
        &self,
        pos: &Position,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
        cache: &mut AccumulatorCacheLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.refresh_accumulator_with_cache(pos, acc, cache),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Cache-assisted variant of `update_accumulator_layer_stacks`.
    /// Panics unless this is a LayerStacks network.
    pub fn update_accumulator_layer_stacks_with_cache(
        &self,
        pos: &Position,
        dirty_piece: &super::accumulator::DirtyPiece,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
        prev_acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
        cache: &mut AccumulatorCacheLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => {
                net.update_accumulator_with_cache(pos, dirty_piece, acc, prev_acc, cache)
            }
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Forward-updates the stack from `source_idx`; returns whether the
    /// update succeeded. Panics unless this is a LayerStacks network.
    pub fn forward_update_incremental_layer_stacks(
        &self,
        pos: &Position,
        stack: &mut AccumulatorStackLayerStacks,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::LayerStacks(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Evaluates `pos` from a computed accumulator. Panics unless this is a
    /// LayerStacks network.
    pub fn evaluate_layer_stacks(
        &self,
        pos: &Position,
        acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) -> Value {
        match self {
            Self::LayerStacks(net) => net.evaluate(pos, acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Evaluates with an explicitly chosen output bucket. Panics unless
    /// this is a LayerStacks network.
    pub fn evaluate_layer_stacks_with_bucket(
        &self,
        pos: &Position,
        acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
        bucket_index: usize,
    ) -> Value {
        match self {
            Self::LayerStacks(net) => net.evaluate_with_bucket(pos, acc, bucket_index),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }
    /// Fully recomputes the HalfKA_hm stack. Panics on other architectures.
    pub fn refresh_accumulator_halfka_hm(&self, pos: &Position, stack: &mut HalfKA_hmStack) {
        match self {
            Self::HalfKA_hm(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }
    /// Fully recomputes the HalfKA stack. Panics on other architectures.
    pub fn refresh_accumulator_halfka(&self, pos: &Position, stack: &mut HalfKAStack) {
        match self {
            Self::HalfKA(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }
    /// Incrementally updates the HalfKA_hm stack from entry `source_idx`.
    /// Panics on other architectures.
    pub fn update_accumulator_halfka_hm(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKA_hm(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }
    /// Incrementally updates the HalfKA stack from entry `source_idx`.
    /// Panics on other architectures.
    pub fn update_accumulator_halfka(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKAStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKA(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }
    /// Forward-updates the HalfKA_hm stack; returns success. Panics on
    /// other architectures.
    pub fn forward_update_incremental_halfka_hm(
        &self,
        pos: &Position,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKA_hm(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }
    /// Forward-updates the HalfKA stack; returns success. Panics on other
    /// architectures.
    pub fn forward_update_incremental_halfka(
        &self,
        pos: &Position,
        stack: &mut HalfKAStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKA(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }
    /// Evaluates `pos` with the HalfKA_hm network. Panics on other
    /// architectures.
    pub fn evaluate_halfka_hm(&self, pos: &Position, stack: &HalfKA_hmStack) -> Value {
        match self {
            Self::HalfKA_hm(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }
    /// Evaluates `pos` with the HalfKA network. Panics on other
    /// architectures.
    pub fn evaluate_halfka(&self, pos: &Position, stack: &HalfKAStack) -> Value {
        match self {
            Self::HalfKA(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }
    /// Fully recomputes the HalfKP stack. Panics on other architectures.
    pub fn refresh_accumulator_halfkp(&self, pos: &Position, stack: &mut HalfKPStack) {
        match self {
            Self::HalfKP(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }
    /// Incrementally updates the HalfKP stack from entry `source_idx`.
    /// Panics on other architectures.
    pub fn update_accumulator_halfkp(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKPStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKP(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }
    /// Forward-updates the HalfKP stack; returns success. Panics on other
    /// architectures.
    pub fn forward_update_incremental_halfkp(
        &self,
        pos: &Position,
        stack: &mut HalfKPStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKP(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }
    /// Evaluates `pos` with the HalfKP network. Panics on other
    /// architectures.
    pub fn evaluate_halfkp(&self, pos: &Position, stack: &HalfKPStack) -> Value {
        match self {
            Self::HalfKP(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }
}
/// Extracts an `fv_scale=N` setting from a comma-separated architecture
/// string.
///
/// Only the first `fv_scale=` entry is considered; a value that fails to
/// parse or falls outside `1..=128` yields `None`.
pub fn parse_fv_scale_from_arch(arch_str: &str) -> Option<i32> {
    const FV_SCALE_MIN: i32 = 1;
    const FV_SCALE_MAX: i32 = 128;
    arch_str
        .split(',')
        .find_map(|part| part.strip_prefix("fv_scale="))
        .and_then(|value| value.parse::<i32>().ok())
        .filter(|scale| (FV_SCALE_MIN..=FV_SCALE_MAX).contains(scale))
}
/// Parses a bucket-mode name (case-insensitive, surrounding whitespace
/// ignored) into a `LayerStackBucketMode`; the inverse of
/// `LayerStackBucketMode::as_str`.
pub fn parse_layer_stack_bucket_mode(value: &str) -> Option<LayerStackBucketMode> {
    const ALL_MODES: [LayerStackBucketMode; 5] = [
        LayerStackBucketMode::KingRank9,
        LayerStackBucketMode::Ply9,
        LayerStackBucketMode::Progress8,
        LayerStackBucketMode::Progress8Gikou,
        LayerStackBucketMode::Progress8KPAbs,
    ];
    let normalized = value.trim().to_ascii_lowercase();
    ALL_MODES.into_iter().find(|mode| mode.as_str() == normalized)
}
/// Parses a comma-separated list of exactly 8 ply bounds; empty tokens
/// (including those left by stray commas) are skipped, surrounding
/// whitespace is ignored.
///
/// # Errors
/// Returns a message naming the first unparsable token, or reporting a
/// token count other than 8.
pub fn parse_layer_stack_ply_bounds_csv(text: &str) -> Result<[u16; 8], String> {
    let values = text
        .split(',')
        .map(str::trim)
        .filter(|t| !t.is_empty())
        .map(|t| t.parse::<u16>().map_err(|e| format!("invalid LS_PLY_BOUNDS value '{t}': {e}")))
        .collect::<Result<Vec<u16>, String>>()?;
    <[u16; 8]>::try_from(values).map_err(|v: Vec<u16>| {
        format!(
            "LS_PLY_BOUNDS requires exactly 8 comma-separated values (got {})",
            v.len()
        )
    })
}
/// Renders ply bounds as the comma-separated form accepted by
/// `parse_layer_stack_ply_bounds_csv`.
pub fn format_layer_stack_ply_bounds(bounds: [u16; 8]) -> String {
    bounds.map(|bound| bound.to_string()).join(",")
}
/// Maps a game ply to one of 9 buckets: index of the first bound that is
/// >= ply, or 8 when the ply exceeds every bound. Negative plies count
/// as 0.
pub fn compute_layer_stack_ply9_bucket_index(game_ply: i32, bounds: [u16; 8]) -> usize {
    let ply = u16::try_from(game_ply.max(0)).unwrap_or(u16::MAX);
    bounds.iter().position(|&bound| ply <= bound).unwrap_or(bounds.len())
}
/// Computes the `Progress8` bucket (0..=7) for `pos`: six positional
/// features are standardized, combined linearly, squashed through a
/// sigmoid, and the resulting probability is split into 8 equal bins.
///
/// Feature order matches `SHOGI_PROGRESS8_FEATURE_ORDER`.
pub fn compute_layer_stack_progress8_bucket_index(
    pos: &Position,
    side_to_move: Color,
    coeff: LayerStackProgressCoeff,
) -> usize {
    // Non-king pieces on the board.
    let board_non_king = (pos.occupied().count() - pos.pieces_pt(PieceType::King).count()) as f32;
    let hand_black = pos.hand(Color::Black);
    let hand_white = pos.hand(Color::White);
    // Total pieces in hand, both sides.
    let hand_total = PieceType::HAND_PIECES
        .iter()
        .map(|&pt| hand_black.count(pt) + hand_white.count(pt))
        .sum::<u32>() as f32;
    // Major pieces (bishops/rooks and their promotions) on the board.
    let major_board = (pos.pieces_pt(PieceType::Bishop).count()
        + pos.pieces_pt(PieceType::Rook).count()
        + pos.pieces_pt(PieceType::Horse).count()
        + pos.pieces_pt(PieceType::Dragon).count()) as f32;
    // Promoted pieces on the board.
    let promoted_board = (pos.pieces_pt(PieceType::ProPawn).count()
        + pos.pieces_pt(PieceType::ProLance).count()
        + pos.pieces_pt(PieceType::ProKnight).count()
        + pos.pieces_pt(PieceType::ProSilver).count()
        + pos.pieces_pt(PieceType::Horse).count()
        + pos.pieces_pt(PieceType::Dragon).count()) as f32;
    let f_king_rank = pos.king_square(side_to_move).rank().index() as f32;
    let e_king_rank = pos.king_square(!side_to_move).rank().index() as f32;
    // King ranks relative to each side's own perspective (mirrored for White).
    let (stm_king_rank_rel, ntm_king_rank_rel) = match side_to_move {
        Color::Black => (f_king_rank, 8.0 - e_king_rank),
        Color::White => (8.0 - f_king_rank, e_king_rank),
    };
    let x = [
        board_non_king,
        hand_total,
        major_board,
        promoted_board,
        stm_king_rank_rel,
        ntm_king_rank_rel,
    ];
    // Linear score over standardized features; non-positive stds fall back
    // to 1 to avoid division by zero.
    let mut z = coeff.bias;
    for (i, &feature) in x.iter().enumerate() {
        let std = if coeff.std[i] > 0.0 {
            coeff.std[i]
        } else {
            1.0
        };
        let x_norm = (feature - coeff.mean[i]) / std;
        z += coeff.weights[i] * x_norm;
    }
    // Clamp the pre-sigmoid score (z_clip bounds may be given in either order).
    let z_min = coeff.z_clip[0].min(coeff.z_clip[1]);
    let z_max = coeff.z_clip[0].max(coeff.z_clip[1]);
    let z_clamped = z.clamp(z_min, z_max);
    // Sigmoid, then split [0, 1] into 8 equal-width buckets.
    let p = (1.0 / (1.0 + (-z_clamped).exp())).clamp(0.0, 1.0);
    let raw = (p * SHOGI_PROGRESS8_NUM_BUCKETS as f32).floor() as i32;
    raw.clamp(0, (SHOGI_PROGRESS8_NUM_BUCKETS - 1) as i32) as usize
}
/// Chebyshev (king-move) distance between two squares: the larger of the
/// file delta and rank delta.
#[inline]
fn chebyshev_distance(a: crate::types::Square, b: crate::types::Square) -> u8 {
    let file_delta = a.file().index().abs_diff(b.file().index());
    let rank_delta = a.rank().index().abs_diff(b.rank().index());
    if file_delta > rank_delta { file_delta as u8 } else { rank_delta as u8 }
}
/// Collapses a king distance into three bins: 0 for d <= 1, 1 for d == 2,
/// and 2 for everything farther.
#[inline]
fn distance_bin(d: u8) -> usize {
    match d {
        0 | 1 => 0,
        2 => 1,
        _ => 2,
    }
}
/// Whether `pt` is a major piece: bishop, rook, or one of their promotions.
#[inline]
fn is_major_piece(pt: PieceType) -> bool {
    matches!(pt, PieceType::Bishop | PieceType::Rook | PieceType::Horse | PieceType::Dragon)
}
/// Computes the `Progress8Gikou` bucket (0..=7) for `pos` using the
/// 34-feature "Gikou-lite" model: the 6 base `Progress8` features, plus
/// per-side king-distance histograms and hand counts, standardized,
/// linearly combined, sigmoided, and binned into 8 equal buckets.
///
/// Feature order matches `SHOGI_PROGRESS_GIKOU_LITE_FEATURE_ORDER`.
pub fn compute_layer_stack_progress8gikou_bucket_index(
    pos: &Position,
    side_to_move: Color,
    coeff: LayerStackProgressCoeffGikouLite,
) -> usize {
    let mut x = [0.0f32; SHOGI_PROGRESS_GIKOU_LITE_NUM_FEATURES];
    // Features 0-5: identical to the base Progress8 features.
    let board_non_king = (pos.occupied().count() - pos.pieces_pt(PieceType::King).count()) as f32;
    let hand_black = pos.hand(Color::Black);
    let hand_white = pos.hand(Color::White);
    let hand_total = PieceType::HAND_PIECES
        .iter()
        .map(|&pt| hand_black.count(pt) + hand_white.count(pt))
        .sum::<u32>() as f32;
    let major_board = (pos.pieces_pt(PieceType::Bishop).count()
        + pos.pieces_pt(PieceType::Rook).count()
        + pos.pieces_pt(PieceType::Horse).count()
        + pos.pieces_pt(PieceType::Dragon).count()) as f32;
    let promoted_board = (pos.pieces_pt(PieceType::ProPawn).count()
        + pos.pieces_pt(PieceType::ProLance).count()
        + pos.pieces_pt(PieceType::ProKnight).count()
        + pos.pieces_pt(PieceType::ProSilver).count()
        + pos.pieces_pt(PieceType::Horse).count()
        + pos.pieces_pt(PieceType::Dragon).count()) as f32;
    let f_king_rank = pos.king_square(side_to_move).rank().index() as f32;
    let e_king_rank = pos.king_square(!side_to_move).rank().index() as f32;
    // King ranks relative to each side's own perspective (mirrored for White).
    let (stm_king_rank_rel, ntm_king_rank_rel) = match side_to_move {
        Color::Black => (f_king_rank, 8.0 - e_king_rank),
        Color::White => (8.0 - f_king_rank, e_king_rank),
    };
    x[0] = board_non_king;
    x[1] = hand_total;
    x[2] = major_board;
    x[3] = promoted_board;
    x[4] = stm_king_rank_rel;
    x[5] = ntm_king_rank_rel;
    // Features 6-29: Chebyshev-distance histograms of non-king pieces to
    // both kings. Layout: all pieces at offsets 6 (stm) / 12 (ntm), major
    // pieces at 18 (stm) / 24 (ntm); within each group, 3 own-king bins
    // followed by 3 opposing-king bins.
    let stm_king = pos.king_square(side_to_move);
    let ntm_king = pos.king_square(!side_to_move);
    for sq in pos.occupied().iter() {
        let pc = pos.piece_on(sq);
        if pc.is_none() {
            continue;
        }
        let pt = pc.piece_type();
        if pt == PieceType::King {
            continue;
        }
        let is_stm_piece = pc.color() == side_to_move;
        let side_offset = if is_stm_piece { 6usize } else { 12usize };
        let major_offset = if is_stm_piece { 18usize } else { 24usize };
        let own_king = if is_stm_piece { stm_king } else { ntm_king };
        let opp_king = if is_stm_piece { ntm_king } else { stm_king };
        let own_bin = distance_bin(chebyshev_distance(sq, own_king));
        let opp_bin = distance_bin(chebyshev_distance(sq, opp_king));
        x[side_offset + own_bin] += 1.0;
        x[side_offset + 3 + opp_bin] += 1.0;
        if is_major_piece(pt) {
            x[major_offset + own_bin] += 1.0;
            x[major_offset + 3 + opp_bin] += 1.0;
        }
    }
    // Features 30-33: per-side hand totals and major-piece hand counts.
    let stm_hand = pos.hand(side_to_move);
    let ntm_hand = pos.hand(!side_to_move);
    x[30] = PieceType::HAND_PIECES.iter().map(|&pt| stm_hand.count(pt)).sum::<u32>() as f32;
    x[31] = PieceType::HAND_PIECES.iter().map(|&pt| ntm_hand.count(pt)).sum::<u32>() as f32;
    x[32] = (stm_hand.count(PieceType::Bishop) + stm_hand.count(PieceType::Rook)) as f32;
    x[33] = (ntm_hand.count(PieceType::Bishop) + ntm_hand.count(PieceType::Rook)) as f32;
    // Linear score over standardized features; non-positive stds fall back
    // to 1 to avoid division by zero.
    let mut z = coeff.bias;
    for (i, &feature) in x.iter().enumerate() {
        let std = if coeff.std[i] > 0.0 {
            coeff.std[i]
        } else {
            1.0
        };
        let x_norm = (feature - coeff.mean[i]) / std;
        z += coeff.weights[i] * x_norm;
    }
    // Clamp pre-sigmoid score, sigmoid, then split [0, 1] into 8 buckets.
    let z_min = coeff.z_clip[0].min(coeff.z_clip[1]);
    let z_max = coeff.z_clip[0].max(coeff.z_clip[1]);
    let z_clamped = z.clamp(z_min, z_max);
    let p = (1.0 / (1.0 + (-z_clamped).exp())).clamp(0.0, 1.0);
    let raw = (p * SHOGI_PROGRESS8_NUM_BUCKETS as f32).floor() as i32;
    raw.clamp(0, (SHOGI_PROGRESS8_NUM_BUCKETS - 1) as i32) as usize
}
pub fn compute_layer_stack_progress8kpabs_bucket_index(
pos: &Position,
_side_to_move: Color,
weights: &[f32],
) -> usize {
let cached = CACHED_PROGRESS_BUCKET.with(|c| c.replace(None));
if let Some(bucket) = cached {
return bucket;
}
let sum = compute_progress8kpabs_sum(pos, weights);
progress_sum_to_bucket(sum)
}
/// Computes the raw `Progress8KPAbs` score for `pos`: the sum of
/// per-(king square, BonaPiece) weights over both kings' perspectives,
/// covering board pieces (kings excluded) and pieces in hand.
///
/// `weights` must contain exactly `SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS`
/// (81 * FE_OLD_END) entries. NOTE(review): this is only checked via
/// `debug_assert` — release builds rely on the caller upholding it.
pub fn compute_progress8kpabs_sum(pos: &Position, weights: &[f32]) -> f32 {
    debug_assert_eq!(
        weights.len(),
        SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS,
        "progress8kpabs weights length mismatch"
    );
    let sq_bk = pos.king_square(Color::Black).index();
    // The White king square is inverted so both perspectives index the
    // weight table the same way.
    let sq_wk = pos.king_square(Color::White).inverse().index();
    // SAFETY: king-square indices are < 81 and the slice holds
    // 81 * FE_OLD_END weights (debug-asserted above), so both sub-ranges
    // are in bounds.
    let weights_b = unsafe { weights.get_unchecked(sq_bk * FE_OLD_END..(sq_bk + 1) * FE_OLD_END) };
    let weights_w = unsafe { weights.get_unchecked(sq_wk * FE_OLD_END..(sq_wk + 1) * FE_OLD_END) };
    let mut sum = 0.0f32;
    // Board pieces: add the weight of each non-king piece from both the
    // Black-king and White-king perspectives.
    for sq in pos.occupied().iter() {
        let pc = pos.piece_on(sq);
        if pc.is_none() || pc.piece_type() == PieceType::King {
            continue;
        }
        let bp_b = BonaPiece::from_piece_square(pc, sq, Color::Black);
        if bp_b != BonaPiece::ZERO {
            sum += weights_b[bp_b.value() as usize];
        }
        let bp_w = BonaPiece::from_piece_square(pc, sq, Color::White);
        if bp_w != BonaPiece::ZERO {
            sum += weights_w[bp_w.value() as usize];
        }
    }
    // Hand pieces: one BonaPiece per copy held (counts are 1-based).
    for owner in [Color::Black, Color::White] {
        let hand = pos.hand(owner);
        for &pt in &PieceType::HAND_PIECES {
            let count = hand.count(pt);
            for c in 1..=count {
                let c_u8 = u8::try_from(c).expect("hand count fits in u8");
                let bp_b = BonaPiece::from_hand_piece(Color::Black, owner, pt, c_u8);
                if bp_b != BonaPiece::ZERO {
                    sum += weights_b[bp_b.value() as usize];
                }
                let bp_w = BonaPiece::from_hand_piece(Color::White, owner, pt, c_u8);
                if bp_w != BonaPiece::ZERO {
                    sum += weights_w[bp_w.value() as usize];
                }
            }
        }
    }
    sum
}
/// Incrementally updates a `Progress8KPAbs` sum: subtracts the weights of
/// each changed piece's old BonaPieces and adds those of the new ones, for
/// both king perspectives.
///
/// `sq_bk`/`sq_wk` must be valid (already-perspective-adjusted) king square
/// indices < 81, and `weights` must hold exactly
/// `SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS` entries. NOTE(review): only valid
/// when neither king moved, since the weight rows are keyed by king square
/// — confirm callers fall back to a full recompute on king moves.
#[inline]
pub fn update_progress8kpabs_sum_diff(
    prev_sum: f32,
    dirty_piece: &super::accumulator::DirtyPiece,
    sq_bk: usize,
    sq_wk: usize,
    weights: &[f32],
) -> f32 {
    debug_assert!(sq_bk < 81, "sq_bk out of range: {sq_bk}");
    debug_assert!(sq_wk < 81, "sq_wk out of range: {sq_wk}");
    debug_assert_eq!(
        weights.len(),
        SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS,
        "progress8kpabs weights length mismatch"
    );
    // SAFETY: sq_bk/sq_wk < 81 and the slice holds 81 * FE_OLD_END weights
    // (both debug-asserted above), so the sub-ranges are in bounds. Release
    // builds rely on the caller contract.
    let weights_b = unsafe { weights.get_unchecked(sq_bk * FE_OLD_END..(sq_bk + 1) * FE_OLD_END) };
    let weights_w = unsafe { weights.get_unchecked(sq_wk * FE_OLD_END..(sq_wk + 1) * FE_OLD_END) };
    let mut sum = prev_sum;
    for i in 0..dirty_piece.dirty_num as usize {
        debug_assert!(i < dirty_piece.changed_piece.len());
        // SAFETY: dirty_num never exceeds changed_piece.len()
        // (debug-asserted above; invariant of DirtyPiece — TODO confirm).
        let changed = unsafe { dirty_piece.changed_piece.get_unchecked(i) };
        // Remove the old piece's contribution from both perspectives.
        let old_fb = changed.old_piece.fb;
        if old_fb != BonaPiece::ZERO {
            let idx = old_fb.value() as usize;
            debug_assert!(idx < weights_b.len());
            // SAFETY: BonaPiece values index within FE_OLD_END
            // (debug-asserted above).
            sum -= unsafe { *weights_b.get_unchecked(idx) };
        }
        let old_fw = changed.old_piece.fw;
        if old_fw != BonaPiece::ZERO {
            let idx = old_fw.value() as usize;
            debug_assert!(idx < weights_w.len());
            // SAFETY: as above.
            sum -= unsafe { *weights_w.get_unchecked(idx) };
        }
        // Add the new piece's contribution to both perspectives.
        let new_fb = changed.new_piece.fb;
        if new_fb != BonaPiece::ZERO {
            let idx = new_fb.value() as usize;
            debug_assert!(idx < weights_b.len());
            // SAFETY: as above.
            sum += unsafe { *weights_b.get_unchecked(idx) };
        }
        let new_fw = changed.new_piece.fw;
        if new_fw != BonaPiece::ZERO {
            let idx = new_fw.value() as usize;
            debug_assert!(idx < weights_w.len());
            // SAFETY: as above.
            sum += unsafe { *weights_w.get_unchecked(idx) };
        }
    }
    sum
}
/// Maps a raw progress sum to a bucket index (0..=7) by counting how many
/// of the ascending `PROGRESS_BUCKET_THRESHOLDS` the sum reaches.
/// Equivalent to a partition-point search because the thresholds are
/// sorted, so `sum >= t` holds for exactly a prefix of them.
#[inline]
pub fn progress_sum_to_bucket(sum: f32) -> usize {
    PROGRESS_BUCKET_THRESHOLDS.iter().filter(|&&threshold| sum >= threshold).count()
}
/// Loads a network from `path` and installs it as the global NNUE network.
///
/// # Errors
/// Propagates any I/O or parse error from `NNUENetwork::load`.
pub fn init_nnue<P: AsRef<Path>>(path: P) -> io::Result<()> {
    let loaded = NNUENetwork::load(path)?;
    let mut slot = NETWORK.write().expect("NNUE lock poisoned");
    *slot = Some(Arc::new(loaded));
    Ok(())
}
/// Parses a network from an in-memory buffer and installs it as the global
/// NNUE network.
///
/// # Errors
/// Propagates any parse error from `NNUENetwork::from_bytes`.
pub fn init_nnue_from_bytes(bytes: &[u8]) -> io::Result<()> {
    let loaded = NNUENetwork::from_bytes(bytes)?;
    let mut slot = NETWORK.write().expect("NNUE lock poisoned");
    *slot = Some(Arc::new(loaded));
    Ok(())
}
/// Removes the global NNUE network, if any. Existing `Arc` handles handed
/// out by `get_network` keep their network alive.
pub fn clear_nnue() {
    let mut slot = NETWORK.write().expect("NNUE lock poisoned");
    slot.take();
}
/// Whether a global NNUE network is currently installed.
pub fn is_nnue_initialized() -> bool {
    let slot = NETWORK.read().expect("NNUE lock poisoned");
    slot.is_some()
}
/// Summary of an NNUE file's header, produced by `detect_format` without
/// loading the full network.
#[derive(Debug, Clone)]
pub struct NnueFormatInfo {
    /// Human-readable architecture label, e.g. "HalfKP256" or "LayerStacks".
    pub architecture: String,
    /// Width of the first hidden layer.
    pub l1_dimension: u32,
    /// Width of the second hidden layer (0 when unknown/not applicable).
    pub l2_dimension: u32,
    /// Width of the third hidden layer (0 when unknown/not applicable).
    pub l3_dimension: u32,
    /// Activation name detected from the arch string.
    pub activation: String,
    /// Raw version word from the file header.
    pub version: u32,
    /// Raw architecture string from the file header.
    pub arch_string: String,
}
/// Inspects an NNUE header and reports its format without loading weights.
///
/// `bytes` only needs to cover the header (version, 4 skipped bytes,
/// arch-string length, arch string); `file_size` is the size of the whole
/// file and is used to disambiguate layer dimensions. When size-based
/// detection fails, the dimensions parsed from the arch string are used as
/// a fallback.
///
/// # Errors
/// `InvalidData` when the buffer is too short, the version is unknown, or
/// the arch-string length is out of range.
pub fn detect_format(bytes: &[u8], file_size: u64) -> io::Result<NnueFormatInfo> {
    // version (4) + skipped field (4) + arch_len (4).
    const MIN_HEADER_SIZE: usize = 12;
    if bytes.len() < MIN_HEADER_SIZE {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "NNUE file too small: {} bytes (need at least {} for header)",
                bytes.len(),
                MIN_HEADER_SIZE
            ),
        ));
    }
    let version = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    match version {
        NNUE_VERSION | NNUE_VERSION_HALFKA => {
            // Bytes 4..8 are skipped (see NNUENetwork::read); 8..12 hold the
            // arch-string length.
            let arch_len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
            if arch_len == 0 || arch_len > MAX_ARCH_LEN {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Invalid arch string length: {} (max: {})", arch_len, MAX_ARCH_LEN),
                ));
            }
            let required_size = MIN_HEADER_SIZE + arch_len;
            if bytes.len() < required_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "NNUE file too small: {} bytes (need {} for arch string)",
                        bytes.len(),
                        required_size
                    ),
                ));
            }
            let arch_str = String::from_utf8_lossy(&bytes[12..12 + arch_len]).to_string();
            let activation = detect_activation_from_arch(&arch_str).to_string();
            let parsed = super::spec::parse_architecture(&arch_str)
                .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?;
            // Prefer size-based detection; fall back to the header-parsed
            // dimensions when it fails.
            let (l1, l2, l3, feature_set, used_file_size_detection) = if let Some(detection) =
                super::spec::detect_architecture_from_size(
                    file_size,
                    arch_len,
                    Some(parsed.feature_set),
                ) {
                (
                    detection.spec.l1,
                    detection.spec.l2,
                    detection.spec.l3,
                    detection.spec.feature_set,
                    true,
                )
            } else {
                (parsed.l1, parsed.l2, parsed.l3, parsed.feature_set, false)
            };
            #[cfg(debug_assertions)]
            if !used_file_size_detection {
                eprintln!(
                    "Warning: File size detection failed for size={}. \
Falling back to header parsing (may be inaccurate).",
                    file_size
                );
            }
            // Silences the unused-variable warning in release builds, where
            // the debug-only block above is compiled out.
            let _ = used_file_size_detection;
            let architecture = match feature_set {
                FeatureSet::LayerStacks => "LayerStacks".to_string(),
                FeatureSet::HalfKA_hm => format!("HalfKA_hm{}", l1),
                FeatureSet::HalfKA => format!("HalfKA{}", l1),
                FeatureSet::HalfKP => format!("HalfKP{}", l1),
            };
            Ok(NnueFormatInfo {
                architecture,
                l1_dimension: l1 as u32,
                l2_dimension: l2 as u32,
                l3_dimension: l3 as u32,
                activation,
                version,
                arch_string: arch_str,
            })
        }
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Unknown NNUE version: 0x{version:08X}"),
        )),
    }
}
/// Returns a shared handle to the global NNUE network, or `None` if no
/// network has been loaded.
pub fn get_network() -> Option<Arc<NNUENetwork>> {
    let guard = NETWORK.read().expect("NNUE lock poisoned");
    guard.as_ref().map(Arc::clone)
}
/// Brings the current LayerStacks accumulator up to date, then runs the
/// network's output layers for `pos`.
///
/// Update strategy, cheapest first:
/// 1. incremental update from the immediate previous stack entry (if computed),
/// 2. forward incremental update from the nearest usable ancestor,
/// 3. full refresh from scratch.
#[inline]
fn update_and_evaluate_layer_stacks(
network: &NNUENetwork,
pos: &Position,
stack: &mut AccumulatorStackLayerStacks,
) -> Value {
let current_entry = stack.current();
if !current_entry.accumulator.computed_accumulation {
let mut updated = false;
// 1) Incremental update from the direct parent entry.
if let Some(prev_idx) = current_entry.previous {
let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
if prev_computed {
// Copy the dirty-piece record before taking mutable borrows below.
let dirty_piece = stack.current().dirty_piece;
let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
network.update_accumulator_layer_stacks(pos, &dirty_piece, current_acc, prev_acc);
updated = true;
}
}
// 2) No computed parent: try forward-updating from a deeper ancestor.
if !updated && let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
updated = network.forward_update_incremental_layer_stacks(pos, stack, source_idx);
}
// 3) Last resort: recompute the accumulator from scratch.
if !updated {
let acc = &mut stack.current_mut().accumulator;
network.refresh_accumulator_layer_stacks(pos, acc);
}
}
// In Progress8KPAbs bucket mode, compute the bucket and stash it in a
// thread-local for the evaluation path to pick up.
if get_layer_stack_bucket_mode() == LayerStackBucketMode::Progress8KPAbs {
let bucket = ensure_progress_bucket(pos, stack);
CACHED_PROGRESS_BUCKET.with(|c| c.set(Some(bucket)));
}
let acc_ref = &stack.current().accumulator;
network.evaluate_layer_stacks(pos, acc_ref)
}
/// Same as `update_and_evaluate_layer_stacks`, but routes accumulator updates
/// and refreshes through `acc_cache` when one is provided.
#[inline]
fn update_and_evaluate_layer_stacks_cached(
network: &NNUENetwork,
pos: &Position,
stack: &mut AccumulatorStackLayerStacks,
acc_cache: &mut Option<AccumulatorCacheLayerStacks>,
) -> Value {
let current_entry = stack.current();
if !current_entry.accumulator.computed_accumulation {
let mut updated = false;
// 1) Incremental update from the direct parent entry.
if let Some(prev_idx) = current_entry.previous {
let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
if prev_computed {
// Copy the dirty-piece record before taking mutable borrows below.
let dirty_piece = stack.current().dirty_piece;
let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
// Cache-aware path when a cache is available, plain path otherwise.
if let Some(cache) = acc_cache {
network.update_accumulator_layer_stacks_with_cache(
pos,
&dirty_piece,
current_acc,
prev_acc,
cache,
);
} else {
network.update_accumulator_layer_stacks(
pos,
&dirty_piece,
current_acc,
prev_acc,
);
}
updated = true;
}
}
// 2) No computed parent: try forward-updating from a deeper ancestor.
if !updated && let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
updated = network.forward_update_incremental_layer_stacks(pos, stack, source_idx);
}
// 3) Last resort: recompute the accumulator from scratch.
if !updated {
let acc = &mut stack.current_mut().accumulator;
if let Some(cache) = acc_cache {
network.refresh_accumulator_layer_stacks_with_cache(pos, acc, cache);
} else {
network.refresh_accumulator_layer_stacks(pos, acc);
}
}
}
// In Progress8KPAbs bucket mode, compute the bucket and stash it in a
// thread-local for the evaluation path to pick up.
if get_layer_stack_bucket_mode() == LayerStackBucketMode::Progress8KPAbs {
let bucket = ensure_progress_bucket(pos, stack);
CACHED_PROGRESS_BUCKET.with(|c| c.set(Some(bucket)));
}
let acc_ref = &stack.current().accumulator;
network.evaluate_layer_stacks(pos, acc_ref)
}
/// Returns the Progress8KPAbs layer-stack bucket for the current stack entry,
/// computing (or incrementally updating) the cached progress sum on demand.
///
/// When neither king moved and the previous entry's sum is available, the sum
/// is diff-updated cheaply; otherwise it is recomputed from scratch. The
/// result is cached on the stack entry so repeated calls are free.
#[inline]
fn ensure_progress_bucket(pos: &Position, stack: &mut AccumulatorStackLayerStacks) -> usize {
    if !stack.current().computed_progress {
        let weights = get_layer_stack_progress_kpabs_weights();
        let current_entry = stack.current();
        // BUG FIX: this line previously read `¤t_entry.dirty_piece`
        // (mojibake: the HTML entity `&curren;` swallowed `&curr` of
        // `&current_entry`), which does not compile.
        let dirty = &current_entry.dirty_piece;
        let king_moved = dirty.king_moved[0] || dirty.king_moved[1];
        // Fast path: diff-update from the parent entry when no king moved and
        // the parent's progress sum is already computed.
        if !king_moved
            && let Some(prev_idx) = current_entry.previous
            && stack.entry_at(prev_idx).computed_progress
        {
            let prev_sum = stack.entry_at(prev_idx).progress_sum;
            let sq_bk = pos.king_square(Color::Black).index();
            // White's king square is mirrored into Black's perspective.
            let sq_wk = pos.king_square(Color::White).inverse().index();
            let new_sum = update_progress8kpabs_sum_diff(prev_sum, dirty, sq_bk, sq_wk, weights);
            let entry = stack.current_mut();
            entry.progress_sum = new_sum;
            entry.computed_progress = true;
        }
        // Slow path: full recomputation (king moved, or no usable parent).
        if !stack.current().computed_progress {
            let sum = compute_progress8kpabs_sum(pos, weights);
            let entry = stack.current_mut();
            entry.progress_sum = sum;
            entry.computed_progress = true;
        }
    }
    progress_sum_to_bucket(stack.current().progress_sum)
}
#[cfg(not(feature = "layerstack-only"))]
#[inline]
fn update_and_evaluate_halfka_hm(
network: &NNUENetwork,
pos: &Position,
stack: &mut HalfKA_hmStack,
) -> Value {
if !stack.is_current_computed() {
let mut updated = false;
if let Some(prev_idx) = stack.current_previous()
&& stack.is_entry_computed(prev_idx)
{
let dirty = stack.current_dirty_piece();
network.update_accumulator_halfka_hm(pos, &dirty, stack, prev_idx);
updated = true;
}
if !updated && let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
updated = network.forward_update_incremental_halfka_hm(pos, stack, source_idx);
}
if !updated {
network.refresh_accumulator_halfka_hm(pos, stack);
}
}
network.evaluate_halfka_hm(pos, stack)
}
#[cfg(not(feature = "layerstack-only"))]
#[inline]
fn update_and_evaluate_halfka(
network: &NNUENetwork,
pos: &Position,
stack: &mut HalfKAStack,
) -> Value {
if !stack.is_current_computed() {
let mut updated = false;
if let Some(prev_idx) = stack.current_previous()
&& stack.is_entry_computed(prev_idx)
{
let dirty = stack.current_dirty_piece();
network.update_accumulator_halfka(pos, &dirty, stack, prev_idx);
updated = true;
}
if !updated && let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
updated = network.forward_update_incremental_halfka(pos, stack, source_idx);
}
if !updated {
network.refresh_accumulator_halfka(pos, stack);
}
}
network.evaluate_halfka(pos, stack)
}
#[cfg(not(feature = "layerstack-only"))]
#[inline]
fn update_and_evaluate_halfkp(
network: &NNUENetwork,
pos: &Position,
stack: &mut HalfKPStack,
) -> Value {
if !stack.is_current_computed() {
let mut updated = false;
if let Some(prev_idx) = stack.current_previous()
&& stack.is_entry_computed(prev_idx)
{
let dirty = stack.current_dirty_piece();
network.update_accumulator_halfkp(pos, &dirty, stack, prev_idx);
updated = true;
}
if !updated && let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
updated = network.forward_update_incremental_halfkp(pos, stack, source_idx);
}
if !updated {
network.refresh_accumulator_halfkp(pos, stack);
}
}
network.evaluate_halfkp(pos, stack)
}
/// True when the loaded network uses the LayerStacks architecture.
pub fn is_layer_stacks_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_layer_stacks(),
        None => false,
    }
}
/// True when the loaded network is HalfKA_hm with L1 width 256.
pub fn is_halfka_hm_256_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_halfka_hm() && net.l1_size() == 256,
        None => false,
    }
}
/// True when the loaded network is HalfKA with L1 width 256.
pub fn is_halfka_256_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_halfka() && net.l1_size() == 256,
        None => false,
    }
}
/// True when the loaded network is HalfKA_hm with L1 width 512.
pub fn is_halfka_hm_512_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_halfka_hm() && net.l1_size() == 512,
        None => false,
    }
}
/// True when the loaded network is HalfKA with L1 width 512.
pub fn is_halfka_512_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_halfka() && net.l1_size() == 512,
        None => false,
    }
}
/// True when the loaded network is HalfKA_hm with L1 width 1024.
pub fn is_halfka_hm_1024_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_halfka_hm() && net.l1_size() == 1024,
        None => false,
    }
}
/// True when the loaded network is HalfKA with L1 width 1024.
pub fn is_halfka_1024_loaded() -> bool {
    match get_network() {
        Some(net) => net.is_halfka() && net.l1_size() == 1024,
        None => false,
    }
}
/// Evaluates `pos` with the globally loaded LayerStacks network, updating the
/// given accumulator stack as needed.
///
/// # Panics
/// Panics when no network is loaded and MaterialLevel is not set, or when the
/// loaded network is not a LayerStacks architecture.
pub fn evaluate_layer_stacks(pos: &Position, stack: &mut AccumulatorStackLayerStacks) -> Value {
    // Material-only mode bypasses NNUE entirely.
    if material::is_material_enabled() {
        return material::evaluate_material(pos);
    }
    let network = match get_network() {
        Some(n) => n,
        None => panic!(
            "NNUE network not loaded and MaterialLevel not set. \
             Use 'setoption name EvalFile' or 'setoption name MaterialLevel'."
        ),
    };
    if !network.is_layer_stacks() {
        panic!("Non-LayerStacks architecture detected. Use evaluate() with AccumulatorStack.");
    }
    update_and_evaluate_layer_stacks(&network, pos, stack)
}
/// Evaluates `pos` using whichever accumulator-stack variant is active,
/// dispatching to the matching network architecture.
///
/// Falls back to material-only evaluation when MaterialLevel is enabled.
///
/// # Panics
/// Panics when no network is loaded and MaterialLevel is not set.
pub fn evaluate_dispatch(
pos: &Position,
stack: &mut AccumulatorStackVariant,
acc_cache: &mut Option<AccumulatorCacheLayerStacks>,
) -> Value {
if material::is_material_enabled() {
return material::evaluate_material(pos);
}
let Some(network) = get_network() else {
panic!(
"NNUE network not loaded and MaterialLevel not set. \
Use 'setoption name EvalFile' or 'setoption name MaterialLevel'."
);
};
match stack {
// Only the LayerStacks path uses the optional accumulator cache.
AccumulatorStackVariant::LayerStacks(s) => {
update_and_evaluate_layer_stacks_cached(&network, pos, s, acc_cache)
}
#[cfg(not(feature = "layerstack-only"))]
AccumulatorStackVariant::HalfKA(s) => update_and_evaluate_halfka(&network, pos, s),
#[cfg(not(feature = "layerstack-only"))]
AccumulatorStackVariant::HalfKA_hm(s) => update_and_evaluate_halfka_hm(&network, pos, s),
#[cfg(not(feature = "layerstack-only"))]
AccumulatorStackVariant::HalfKP(s) => update_and_evaluate_halfkp(&network, pos, s),
// In layerstack-only builds the non-LayerStacks variants cannot occur.
#[cfg(feature = "layerstack-only")]
AccumulatorStackVariant::HalfKA(_)
| AccumulatorStackVariant::HalfKA_hm(_)
| AccumulatorStackVariant::HalfKP(_) => {
unreachable!("layerstack-only build: only LayerStacks variant is supported")
}
}
}
/// Brings the current accumulator entry up to date without running the
/// output layers.
///
/// Silently returns when no network is loaded.
pub fn ensure_accumulator_computed(
pos: &Position,
stack: &mut AccumulatorStackVariant,
acc_cache: &mut Option<AccumulatorCacheLayerStacks>,
) {
let Some(network) = get_network() else {
return;
};
match stack {
// Only the LayerStacks path uses the optional accumulator cache.
AccumulatorStackVariant::LayerStacks(s) => {
update_accumulator_only_layer_stacks_cached(&network, pos, s, acc_cache);
}
#[cfg(not(feature = "layerstack-only"))]
AccumulatorStackVariant::HalfKA(s) => {
update_accumulator_only_halfka(&network, pos, s);
}
#[cfg(not(feature = "layerstack-only"))]
AccumulatorStackVariant::HalfKA_hm(s) => {
update_accumulator_only_halfka_hm(&network, pos, s);
}
#[cfg(not(feature = "layerstack-only"))]
AccumulatorStackVariant::HalfKP(s) => {
update_accumulator_only_halfkp(&network, pos, s);
}
// In layerstack-only builds the non-LayerStacks variants cannot occur.
#[cfg(feature = "layerstack-only")]
AccumulatorStackVariant::HalfKA(_)
| AccumulatorStackVariant::HalfKA_hm(_)
| AccumulatorStackVariant::HalfKP(_) => {
unreachable!("layerstack-only build: only LayerStacks variant is supported")
}
}
}
/// Updates (or refreshes) the current LayerStacks accumulator without
/// evaluating, optionally routing through `acc_cache`; records stats counters.
///
/// NOTE(review): unlike the evaluate path, this does not attempt
/// `find_usable_accumulator` forward updates — only parent-diff or refresh.
#[inline]
fn update_accumulator_only_layer_stacks_cached(
network: &NNUENetwork,
pos: &Position,
stack: &mut AccumulatorStackLayerStacks,
acc_cache: &mut Option<AccumulatorCacheLayerStacks>,
) {
let current_entry = stack.current();
// Nothing to do if the accumulator is already computed.
if current_entry.accumulator.computed_accumulation {
count_already_computed!();
return;
}
let mut updated = false;
// Cheap path: incremental update from a computed parent entry.
if let Some(prev_idx) = current_entry.previous {
let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
if prev_computed {
// Copy the dirty-piece record before taking mutable borrows below.
let dirty_piece = stack.current().dirty_piece;
let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
if let Some(cache) = acc_cache {
network.update_accumulator_layer_stacks_with_cache(
pos,
&dirty_piece,
current_acc,
prev_acc,
cache,
);
} else {
network.update_accumulator_layer_stacks(pos, &dirty_piece, current_acc, prev_acc);
}
count_update!();
updated = true;
}
}
// Fallback: full refresh from scratch.
if !updated {
let acc = &mut stack.current_mut().accumulator;
if let Some(cache) = acc_cache {
network.refresh_accumulator_layer_stacks_with_cache(pos, acc, cache);
} else {
network.refresh_accumulator_layer_stacks(pos, acc);
}
count_refresh!();
}
}
#[cfg(not(feature = "layerstack-only"))]
#[inline]
fn update_accumulator_only_halfka_hm(
network: &NNUENetwork,
pos: &Position,
stack: &mut HalfKA_hmStack,
) {
if stack.is_current_computed() {
count_already_computed!();
return;
}
let mut updated = false;
if let Some(prev_idx) = stack.current_previous()
&& stack.is_entry_computed(prev_idx)
{
let dirty = stack.current_dirty_piece();
network.update_accumulator_halfka_hm(pos, &dirty, stack, prev_idx);
count_update!();
updated = true;
}
if !updated {
network.refresh_accumulator_halfka_hm(pos, stack);
count_refresh!();
}
}
#[cfg(not(feature = "layerstack-only"))]
#[inline]
fn update_accumulator_only_halfka(network: &NNUENetwork, pos: &Position, stack: &mut HalfKAStack) {
if stack.is_current_computed() {
count_already_computed!();
return;
}
let mut updated = false;
if let Some(prev_idx) = stack.current_previous()
&& stack.is_entry_computed(prev_idx)
{
let dirty = stack.current_dirty_piece();
network.update_accumulator_halfka(pos, &dirty, stack, prev_idx);
count_update!();
updated = true;
}
if !updated {
network.refresh_accumulator_halfka(pos, stack);
count_refresh!();
}
}
#[cfg(not(feature = "layerstack-only"))]
#[inline]
fn update_accumulator_only_halfkp(network: &NNUENetwork, pos: &Position, stack: &mut HalfKPStack) {
if stack.is_current_computed() {
count_already_computed!();
return;
}
let mut updated = false;
if let Some(prev_idx) = stack.current_previous()
&& stack.is_entry_computed(prev_idx)
{
let dirty = stack.current_dirty_piece();
network.update_accumulator_halfkp(pos, &dirty, stack, prev_idx);
count_update!();
updated = true;
}
if !updated {
network.refresh_accumulator_halfkp(pos, stack);
count_refresh!();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::position::SFEN_HIRATE;
// Without a loaded network, evaluate_dispatch falls back to a near-zero
// evaluation for the starting position.
#[test]
fn test_evaluate_fallback() {
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let mut stack = AccumulatorStackVariant::new_default();
let value = evaluate_dispatch(&pos, &mut stack, &mut None);
assert!(value.raw().abs() < 1000);
}
// Two consecutive dispatches on the same stack must not panic.
#[test]
fn test_accumulator_stack_variant_fallback() {
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let mut stack = AccumulatorStackVariant::new_default();
let value1 = evaluate_dispatch(&pos, &mut stack, &mut None);
let value2 = evaluate_dispatch(&pos, &mut stack, &mut None);
let _ = (value1, value2);
}
// Ignored by default: requires a real LayerStacks .nnue file on disk
// (path via NNUE_TEST_FILE); skips gracefully when the file is absent.
#[test]
#[ignore]
fn test_nnue_network_auto_detect_layer_stacks() {
let path = std::env::var("NNUE_TEST_FILE")
.unwrap_or_else(|_| "/path/to/your/layer_stacks.nnue".to_string());
let network = match NNUENetwork::load(path) {
Ok(n) => n,
Err(e) => {
eprintln!("Skipping test: {e}");
return;
}
};
assert!(network.is_layer_stacks(), "epoch82.nnue should be detected as LayerStacks");
assert_eq!(network.architecture_name(), "LayerStacks");
let mut pos = crate::position::Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let mut acc = crate::nnue::AccumulatorLayerStacks::new();
network.refresh_accumulator_layer_stacks(&pos, &mut acc);
let value = network.evaluate_layer_stacks(&pos, &acc);
eprintln!("LayerStacks evaluate: {}", value.raw());
assert!(value.raw().abs() < 1000);
}
// Ignored by default: needs the real AobaNNUE file (path via NNUE_AOBA_FILE).
// Verifies that file-size detection overrides a misleading header arch string.
#[test]
#[ignore]
fn test_detect_format_aoba() {
let path = std::env::var("NNUE_AOBA_FILE").unwrap_or_else(|_| "AobaNNUE.bin".to_string());
let bytes = match std::fs::read(&path) {
Ok(b) => b,
Err(e) => {
eprintln!("Skipping test: {e}");
return;
}
};
let file_size = bytes.len() as u64;
let info = detect_format(&bytes, file_size).expect("Failed to detect format");
eprintln!("File: {path}");
eprintln!("Architecture: {}", info.architecture);
eprintln!(
"L1: {}, L2: {}, L3: {}",
info.l1_dimension, info.l2_dimension, info.l3_dimension
);
eprintln!("Activation: {}", info.activation);
eprintln!("Arch string (header): {}", info.arch_string);
assert_eq!(
info.architecture, "HalfKP768",
"Should detect HalfKP768 from file size, not HalfKP256 from header"
);
assert_eq!(info.l1_dimension, 768, "L1 should be 768, not 256 from header");
assert_eq!(info.l2_dimension, 16, "L2 should be 16");
assert_eq!(info.l3_dimension, 64, "L3 should be 64");
assert!(
info.arch_string.contains("256"),
"Header should claim 256, but file size detection should override it"
);
}
// When size-based detection has nothing to match, detect_format must fall
// back to parsing the dimensions out of the header arch string.
#[test]
fn test_detect_format_fallback_to_header() {
    let unknown_file_size = 12345678u64;
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&NNUE_VERSION_HALFKA.to_le_bytes());
    bytes.extend_from_slice(&0u32.to_le_bytes());
    let arch_str = "Features=HalfKA_hm[73305->512x2],l2=8,l3=96";
    bytes.extend_from_slice(&(arch_str.len() as u32).to_le_bytes());
    bytes.extend_from_slice(arch_str.as_bytes());
    let info =
        detect_format(&bytes, unknown_file_size).expect("Should fallback to header parsing");
    assert_eq!(info.architecture, "HalfKA_hm512");
    assert_eq!(info.l1_dimension, 512);
    assert_eq!(info.l2_dimension, 8);
    assert_eq!(info.l3_dimension, 96);
}
// Exercises every error branch of detect_format and checks the messages.
#[test]
fn test_detect_format_error_cases() {
    // Case 1: buffer shorter than the 12-byte header.
    let bytes = vec![0u8; 5];
    let result = detect_format(&bytes, 5);
    assert!(result.is_err(), "Should fail for too small file");
    assert!(
        result.unwrap_err().to_string().contains("too small"),
        "Error message should mention 'too small'"
    );
    // Case 2: arch_len of zero is rejected.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&NNUE_VERSION.to_le_bytes());
    bytes.extend_from_slice(&0u32.to_le_bytes());
    bytes.extend_from_slice(&0u32.to_le_bytes());
    let result = detect_format(&bytes, 100);
    assert!(result.is_err(), "Should fail for arch_len = 0");
    assert!(
        result.unwrap_err().to_string().contains("Invalid arch string length"),
        "Error message should mention invalid arch string length"
    );
    // Case 3: arch_len above MAX_ARCH_LEN is rejected.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&NNUE_VERSION.to_le_bytes());
    bytes.extend_from_slice(&0u32.to_le_bytes());
    bytes.extend_from_slice(&(MAX_ARCH_LEN as u32 + 1).to_le_bytes());
    let result = detect_format(&bytes, 100);
    assert!(result.is_err(), "Should fail for arch_len > MAX_ARCH_LEN");
    // Case 4: header promises a longer arch string than the buffer holds.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&NNUE_VERSION.to_le_bytes());
    bytes.extend_from_slice(&0u32.to_le_bytes());
    bytes.extend_from_slice(&100u32.to_le_bytes());
    let result = detect_format(&bytes, 1000);
    assert!(result.is_err(), "Should fail when buffer is too small for arch_str");
    // Case 5: unrecognized version word.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&0xDEADBEEFu32.to_le_bytes());
    bytes.extend_from_slice(&[0u8; 100]);
    let result = detect_format(&bytes, 112);
    assert!(result.is_err(), "Should fail for unknown version");
    assert!(
        result.unwrap_err().to_string().contains("Unknown NNUE version"),
        "Error message should mention unknown version"
    );
}
// fv_scale is parsed out of real-world arch strings; absent or malformed
// values yield None.
#[test]
fn test_parse_fv_scale_from_arch() {
assert_eq!(
parse_fv_scale_from_arch(
"Features=HalfKA_hm^[73305->256x2]-SCReLU,fv_scale=13,qa=127,qb=64,scale=600"
),
Some(13)
);
assert_eq!(
parse_fv_scale_from_arch(
"Features=HalfKA_hm^[73305->512x2]-SCReLU,fv_scale=20,qa=127,qb=64,scale=400"
),
Some(20)
);
assert_eq!(
parse_fv_scale_from_arch(
"Features=HalfKA_hm^[73305->1024x2]-SCReLU,fv_scale=16,qa=127,qb=64,scale=508"
),
Some(16)
);
assert_eq!(parse_fv_scale_from_arch("Features=HalfKP[125388->256x2]"), None);
assert_eq!(parse_fv_scale_from_arch("Features=HalfKA_hm^[73305->512x2]"), None);
assert_eq!(parse_fv_scale_from_arch(""), None);
assert_eq!(
parse_fv_scale_from_arch("Features=HalfKA_hm^[73305->256x2],fv_scale=abc"),
None
);
}
// Boundary behavior: the accepted range appears to be 1..=128; whitespace,
// non-integers, and lookalike keys are rejected; first occurrence wins.
#[test]
fn test_parse_fv_scale_edge_cases() {
assert_eq!(parse_fv_scale_from_arch("fv_scale=1"), Some(1));
assert_eq!(parse_fv_scale_from_arch("fv_scale=128"), Some(128));
assert_eq!(parse_fv_scale_from_arch("fv_scale=64"), Some(64));
assert_eq!(parse_fv_scale_from_arch("fv_scale=0"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=129"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=-1"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=-100"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=99999"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=2147483647"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale= 16"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=16 "), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=10,fv_scale=20"), Some(10));
assert_eq!(parse_fv_scale_from_arch("fv_scale="), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale=16.5"), None);
assert_eq!(parse_fv_scale_from_arch("my_fv_scale=16"), None);
assert_eq!(parse_fv_scale_from_arch("fv_scale_v2=16"), None);
}
// Mode names are matched case-insensitively and with surrounding whitespace
// trimmed; unknown names yield None.
#[test]
fn test_parse_layer_stack_bucket_mode() {
assert_eq!(
parse_layer_stack_bucket_mode("kingrank9"),
Some(LayerStackBucketMode::KingRank9)
);
assert_eq!(parse_layer_stack_bucket_mode("ply9"), Some(LayerStackBucketMode::Ply9));
assert_eq!(parse_layer_stack_bucket_mode("PLY9"), Some(LayerStackBucketMode::Ply9));
assert_eq!(
parse_layer_stack_bucket_mode("progress8"),
Some(LayerStackBucketMode::Progress8)
);
assert_eq!(
parse_layer_stack_bucket_mode("progress8gikou"),
Some(LayerStackBucketMode::Progress8Gikou)
);
assert_eq!(
parse_layer_stack_bucket_mode("progress8kpabs"),
Some(LayerStackBucketMode::Progress8KPAbs)
);
assert_eq!(
parse_layer_stack_bucket_mode(" kingrank9 "),
Some(LayerStackBucketMode::KingRank9)
);
assert_eq!(parse_layer_stack_bucket_mode("unknown"), None);
}
// Exactly eight comma-separated integers are required; whitespace around
// values is tolerated, wrong counts and non-numbers are errors.
#[test]
fn test_parse_layer_stack_ply_bounds_csv() {
assert_eq!(
parse_layer_stack_ply_bounds_csv("30,44,58,72,86,100,116,138").unwrap(),
[30, 44, 58, 72, 86, 100, 116, 138]
);
assert_eq!(
parse_layer_stack_ply_bounds_csv(" 30, 44, 58, 72, 86, 100, 116, 138 ").unwrap(),
[30, 44, 58, 72, 86, 100, 116, 138]
);
assert!(parse_layer_stack_ply_bounds_csv("30,44,58").is_err());
assert!(parse_layer_stack_ply_bounds_csv("30,44,58,72,86,100,116,abc").is_err());
}
// Bounds are inclusive upper limits per bucket; out-of-range plies clamp to
// the first/last of the nine buckets.
#[test]
fn test_compute_layer_stack_ply9_bucket_index() {
let bounds = LAYER_STACK_PLY9_DEFAULT_BOUNDS;
assert_eq!(compute_layer_stack_ply9_bucket_index(0, bounds), 0);
assert_eq!(compute_layer_stack_ply9_bucket_index(30, bounds), 0);
assert_eq!(compute_layer_stack_ply9_bucket_index(31, bounds), 1);
assert_eq!(compute_layer_stack_ply9_bucket_index(138, bounds), 7);
assert_eq!(compute_layer_stack_ply9_bucket_index(139, bounds), 8);
assert_eq!(compute_layer_stack_ply9_bucket_index(400, bounds), 8);
assert_eq!(compute_layer_stack_ply9_bucket_index(-5, bounds), 0);
}
// Default coefficients must map the starting position into a valid bucket.
#[test]
fn test_compute_layer_stack_progress8_bucket_index_range() {
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let coeff = LayerStackProgressCoeff::default();
let b = compute_layer_stack_progress8_bucket_index(&pos, pos.side_to_move(), coeff);
assert!(b <= 7, "progress8 bucket must be in 0..=7, got {b}");
}
// Default Gikou-lite coefficients must also produce a valid bucket.
#[test]
fn test_compute_layer_stack_progress8gikou_bucket_index_range() {
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let coeff = LayerStackProgressCoeffGikouLite::default();
let b = compute_layer_stack_progress8gikou_bucket_index(&pos, pos.side_to_move(), coeff);
assert!(b <= 7, "progress8gikou bucket must be in 0..=7, got {b}");
}
// All-zero weights give sum 0, i.e. sigmoid 0.5, which lands in bucket 4.
#[test]
fn test_compute_layer_stack_progress8kpabs_bucket_index_range() {
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let weights = vec![0.0f32; SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS];
let b = compute_layer_stack_progress8kpabs_bucket_index(&pos, pos.side_to_move(), &weights);
assert_eq!(b, 4, "zero-weight progress8kpabs should map to the middle bucket");
}
// The precomputed thresholds must bucket raw sums identically to applying
// sigmoid and flooring into SHOGI_PROGRESS8_NUM_BUCKETS buckets.
#[test]
fn test_progress_bucket_thresholds_match_sigmoid() {
let sigmoid_bucket = |sum: f32| -> usize {
let p = (1.0 / (1.0 + (-sum).exp())).clamp(0.0, 1.0);
let raw = (p * SHOGI_PROGRESS8_NUM_BUCKETS as f32).floor() as i32;
raw.clamp(0, (SHOGI_PROGRESS8_NUM_BUCKETS - 1) as i32) as usize
};
let threshold_bucket = |sum: f32| -> usize {
PROGRESS_BUCKET_THRESHOLDS
.iter()
.filter(|&&t| sum >= t)
.count()
.min(SHOGI_PROGRESS8_NUM_BUCKETS - 1)
};
// Sample sums spanning well below/above all thresholds.
for &sum in &[
-10.0, -5.0, -3.0, -2.5, -1.5, -0.8, -0.3, 0.0, 0.3, 0.8, 1.5, 2.5, 3.0, 5.0, 10.0,
] {
assert_eq!(sigmoid_bucket(sum), threshold_bucket(sum), "mismatch at sum={sum}");
}
}
// The incremental diff update of the progress sum must agree with a full
// recomputation after every move of a short opening line (random weights,
// xorshift-seeded for determinism). King moves reset to the full sum, since
// the diff path is only valid when neither king moved.
#[test]
fn test_progress8kpabs_diff_update() {
use crate::types::Move;
let mut weights = vec![0.0f32; SHOGI_PROGRESS_KP_ABS_NUM_WEIGHTS];
let mut rng: u64 = 12345;
for w in weights.iter_mut() {
rng ^= rng << 13;
rng ^= rng >> 7;
rng ^= rng << 17;
*w = ((rng as i64 % 1000) as f32) / 1000.0;
}
let mut pos = Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
let sum0 = compute_progress8kpabs_sum(&pos, &weights);
let moves_usi = [
"7g7f", "3c3d", "2g2f", "8c8d", "2f2e", "8d8e", "6i7h", "4a3b",
];
let mut prev_sum = sum0;
for &mv_str in &moves_usi {
let mv = Move::from_usi(mv_str).expect("valid move");
let gives_check = pos.gives_check(mv);
let dirty = pos.do_move(mv, gives_check);
let expected_sum = compute_progress8kpabs_sum(&pos, &weights);
let expected_bucket = progress_sum_to_bucket(expected_sum);
if dirty.king_moved[0] || dirty.king_moved[1] {
prev_sum = expected_sum;
} else {
let sq_bk = pos.king_square(Color::Black).index();
let sq_wk = pos.king_square(Color::White).inverse().index();
let diff_sum =
update_progress8kpabs_sum_diff(prev_sum, &dirty, sq_bk, sq_wk, &weights);
let diff_bucket = progress_sum_to_bucket(diff_sum);
assert!(
(diff_sum - expected_sum).abs() < 1e-5,
"sum mismatch after {mv_str}: diff={diff_sum}, expected={expected_sum}"
);
assert_eq!(diff_bucket, expected_bucket, "bucket mismatch after {mv_str}");
prev_sum = diff_sum;
}
}
}
// Ignored by default: needs a real HalfKP 768x2-16-64 file on disk
// (override via NNUE_HALFKP_768_FILE); skips gracefully when absent.
#[test]
#[ignore]
fn test_nnue_halfkp_768_auto_detect() {
let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.and_then(|p| p.parent())
.expect("Failed to find workspace root");
let default_path = workspace_root
.join("eval/halfkp_768x2-16-64_crelu/AobaNNUE_HalfKP_768x2_16_64_FV_SCALE_40.bin");
let path = std::env::var("NNUE_HALFKP_768_FILE")
.unwrap_or_else(|_| default_path.display().to_string());
let network = match NNUENetwork::load(&path) {
Ok(n) => n,
Err(e) => {
eprintln!("Skipping test: {e}");
return;
}
};
assert!(network.is_halfkp(), "File should be detected as HalfKP");
assert_eq!(network.l1_size(), 768, "L1 should be 768");
let spec = network.architecture_spec();
assert_eq!(spec.l1, 768, "spec.l1 should be 768");
assert_eq!(spec.l2, 16, "spec.l2 should be 16");
assert_eq!(spec.l3, 64, "spec.l3 should be 64");
eprintln!("Successfully loaded HalfKP 768x2-16-64 network");
eprintln!("Architecture name: {}", network.architecture_name());
let mut pos = crate::position::Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
use crate::nnue::halfkp::HalfKPStack;
let mut stack = HalfKPStack::from_network(match &network {
NNUENetwork::HalfKP(net) => net,
_ => unreachable!(),
});
network.refresh_accumulator_halfkp(&pos, &mut stack);
let value = network.evaluate_halfkp(&pos, &stack);
eprintln!("HalfKP 768 evaluate: {}", value.raw());
assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
}
// Ignored by default: needs a real HalfKA_hm 256x2-32-32 file on disk
// (override via NNUE_HALFKA_HM_256_FILE); skips gracefully when absent.
#[test]
#[ignore]
fn test_nnue_halfka_hm_256_auto_detect() {
let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.and_then(|p| p.parent())
.expect("Failed to find workspace root");
let default_path = workspace_root.join("eval/halfka_hm_256x2-32-32_crelu/v28_epoch65.nnue");
let path = std::env::var("NNUE_HALFKA_HM_256_FILE")
.unwrap_or_else(|_| default_path.display().to_string());
let network = match NNUENetwork::load(&path) {
Ok(n) => n,
Err(e) => {
eprintln!("Skipping test: {e}");
return;
}
};
assert!(network.is_halfka_hm(), "File should be detected as HalfKA_hm");
assert_eq!(network.l1_size(), 256, "L1 should be 256");
let spec = network.architecture_spec();
assert_eq!(spec.l1, 256, "spec.l1 should be 256");
assert_eq!(spec.l2, 32, "spec.l2 should be 32");
assert_eq!(spec.l3, 32, "spec.l3 should be 32");
eprintln!("Successfully loaded HalfKA_hm 256x2-32-32 network");
eprintln!("Architecture name: {}", network.architecture_name());
let mut pos = crate::position::Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
use crate::nnue::halfka_hm::HalfKA_hmStack;
let mut stack = HalfKA_hmStack::from_network(match &network {
NNUENetwork::HalfKA_hm(net) => net,
_ => unreachable!(),
});
network.refresh_accumulator_halfka_hm(&pos, &mut stack);
let value = network.evaluate_halfka_hm(&pos, &stack);
eprintln!("HalfKA_hm 256 evaluate: {}", value.raw());
assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
}
// Ignored by default: needs a real HalfKA_hm 1024x2-8-96 file on disk
// (override via NNUE_HALFKA_HM_1024_FILE); skips gracefully when absent.
#[test]
#[ignore]
fn test_nnue_halfka_hm_1024_auto_detect() {
let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.and_then(|p| p.parent())
.expect("Failed to find workspace root");
let default_path = workspace_root.join("eval/halfka_hm_1024x2-8-96_crelu/epoch20_v2.nnue");
let path = std::env::var("NNUE_HALFKA_HM_1024_FILE")
.unwrap_or_else(|_| default_path.display().to_string());
let network = match NNUENetwork::load(&path) {
Ok(n) => n,
Err(e) => {
eprintln!("Skipping test: {e}");
return;
}
};
assert!(network.is_halfka_hm(), "File should be detected as HalfKA_hm");
assert_eq!(network.l1_size(), 1024, "L1 should be 1024");
let spec = network.architecture_spec();
assert_eq!(spec.l1, 1024, "spec.l1 should be 1024");
assert_eq!(spec.l2, 8, "spec.l2 should be 8");
assert_eq!(spec.l3, 96, "spec.l3 should be 96");
eprintln!("Successfully loaded HalfKA_hm 1024x2-8-96 network");
eprintln!("Architecture name: {}", network.architecture_name());
let mut pos = crate::position::Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
use crate::nnue::halfka_hm::HalfKA_hmStack;
let mut stack = HalfKA_hmStack::from_network(match &network {
NNUENetwork::HalfKA_hm(net) => net,
_ => unreachable!(),
});
network.refresh_accumulator_halfka_hm(&pos, &mut stack);
let value = network.evaluate_halfka_hm(&pos, &stack);
eprintln!("HalfKA_hm 1024 evaluate: {}", value.raw());
assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
}
// Ignored by default: needs the real suisho5 HalfKP 256x2-32-32 file on disk
// (override via NNUE_HALFKP_256_FILE); skips gracefully when absent.
#[test]
#[ignore]
fn test_nnue_halfkp_256_suisho5() {
let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.and_then(|p| p.parent())
.expect("Failed to find workspace root");
let default_path = workspace_root.join("eval/halfkp_256x2-32-32_crelu/suisho5.bin");
let path = std::env::var("NNUE_HALFKP_256_FILE")
.unwrap_or_else(|_| default_path.display().to_string());
let network = match NNUENetwork::load(&path) {
Ok(n) => n,
Err(e) => {
eprintln!("Skipping test: {e}");
return;
}
};
assert!(network.is_halfkp(), "File should be detected as HalfKP");
assert_eq!(network.l1_size(), 256, "L1 should be 256");
let spec = network.architecture_spec();
assert_eq!(spec.l1, 256, "spec.l1 should be 256");
assert_eq!(spec.l2, 32, "spec.l2 should be 32");
assert_eq!(spec.l3, 32, "spec.l3 should be 32");
eprintln!("Successfully loaded HalfKP 256x2-32-32 network (suisho5)");
eprintln!("Architecture name: {}", network.architecture_name());
let mut pos = crate::position::Position::new();
pos.set_sfen(SFEN_HIRATE).unwrap();
use crate::nnue::halfkp::HalfKPStack;
let mut stack = HalfKPStack::from_network(match &network {
NNUENetwork::HalfKP(net) => net,
_ => unreachable!(),
});
network.refresh_accumulator_halfkp(&pos, &mut stack);
let value = network.evaluate_halfkp(&pos, &stack);
eprintln!("HalfKP 256 evaluate: {}", value.raw());
assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
}
}