use super::accumulator_layer_stacks::AccumulatorStackLayerStacks;
use super::accumulator_stack_variant::AccumulatorStackVariant;
use super::activation::detect_activation_from_arch;
use super::constants::{MAX_ARCH_LEN, NNUE_VERSION, NNUE_VERSION_HALFKA};
use super::halfka::{HalfKANetwork, HalfKAStack};
use super::halfka_hm::{HalfKA_hmNetwork, HalfKA_hmStack};
use super::halfkp::{HalfKPNetwork, HalfKPStack};
use super::network_layer_stacks::NetworkLayerStacks;
use super::spec::{Activation, FeatureSet};
use super::stats::{count_already_computed, count_refresh, count_update};
use crate::eval::material;
use crate::position::Position;
use crate::types::Value;
use std::fs::File;
use std::io::{self, BufReader, Cursor, Read, Seek, SeekFrom};
use std::path::Path;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::OnceLock;
// Globally loaded network; set exactly once via `init_nnue` / `init_nnue_from_bytes`.
static NETWORK: OnceLock<NNUENetwork> = OnceLock::new();
// FV_SCALE override; 0 means "no override" (see `get_fv_scale_override`).
static FV_SCALE_OVERRIDE: AtomicI32 = AtomicI32::new(0);
/// Returns the configured FV_SCALE override, or `None` when unset.
///
/// A stored value of 0 (the initial state, or the result of clamping in
/// [`set_fv_scale_override`]) is treated as "no override".
pub fn get_fv_scale_override() -> Option<i32> {
    let stored = FV_SCALE_OVERRIDE.load(Ordering::Relaxed);
    (stored > 0).then_some(stored)
}
/// Sets the global FV_SCALE override.
///
/// Negative values are clamped to 0, which [`get_fv_scale_override`]
/// interprets as "no override".
pub fn set_fv_scale_override(value: i32) {
    let clamped = if value < 0 { 0 } else { value };
    FV_SCALE_OVERRIDE.store(clamped, Ordering::Relaxed);
}
/// A loaded NNUE network, tagged by its feature-set architecture.
///
/// `LayerStacks` is boxed — presumably to keep the enum itself small;
/// confirm against `NetworkLayerStacks`'s size if this matters.
pub enum NNUENetwork {
    HalfKA(HalfKANetwork),
    // Variant name mirrors the upstream architecture name "HalfKA_hm".
    #[allow(non_camel_case_types)]
    HalfKA_hm(HalfKA_hmNetwork),
    HalfKP(HalfKPNetwork),
    LayerStacks(Box<NetworkLayerStacks>),
}
impl NNUENetwork {
    /// Architecture specs the HalfKP loader accepts.
    pub fn supported_halfkp_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKPNetwork::supported_specs()
    }

    /// Architecture specs the HalfKA_hm loader accepts.
    pub fn supported_halfka_hm_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKA_hmNetwork::supported_specs()
    }

    /// Architecture specs the HalfKA loader accepts.
    pub fn supported_halfka_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKANetwork::supported_specs()
    }

    /// Loads a network from a file on disk (buffered), delegating to [`Self::read`].
    pub fn load<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);
        Self::read(&mut reader)
    }

    /// Reads a network from any seekable byte stream.
    ///
    /// Header layout (little-endian): version (4 bytes), hash (4 bytes,
    /// ignored here), arch-string length (4 bytes), then the arch string.
    /// The concrete architecture is chosen by parsing the arch string and,
    /// for non-LayerStacks nets, cross-checked against the total stream size.
    ///
    /// # Errors
    /// Returns `InvalidData` for an unknown version word, an invalid arch
    /// string length, an unparsable arch string, or a size matching no
    /// known architecture spec.
    pub fn read<R: Read + Seek>(reader: &mut R) -> io::Result<Self> {
        // Total stream length is needed for size-based architecture detection.
        let file_size = reader.seek(SeekFrom::End(0))?;
        reader.seek(SeekFrom::Start(0))?;
        let mut buf4 = [0u8; 4];
        reader.read_exact(&mut buf4)?;
        let version = u32::from_le_bytes(buf4);
        match version {
            NNUE_VERSION | NNUE_VERSION_HALFKA => {
                // First read skips the 4-byte hash field; second reads arch_len.
                reader.read_exact(&mut buf4)?;
                reader.read_exact(&mut buf4)?;
                let arch_len = u32::from_le_bytes(buf4) as usize;
                if arch_len == 0 || arch_len > MAX_ARCH_LEN {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("Invalid arch string length: {arch_len}"),
                    ));
                }
                let mut arch = vec![0u8; arch_len];
                reader.read_exact(&mut arch)?;
                // Lossy conversion: a malformed arch string must not abort loading.
                let arch_str = String::from_utf8_lossy(&arch);
                let activation_str = detect_activation_from_arch(&arch_str);
                let activation = match activation_str {
                    "SCReLU" => Activation::SCReLU,
                    "PairwiseCReLU" => Activation::PairwiseCReLU,
                    _ => Activation::CReLU,
                };
                let parsed = super::spec::parse_architecture(&arch_str).map_err(|msg| {
                    io::Error::new(io::ErrorKind::InvalidData, msg)
                })?;
                // LayerStacks has its own reader that re-parses the whole stream.
                if parsed.feature_set == FeatureSet::LayerStacks {
                    reader.seek(SeekFrom::Start(0))?;
                    let network = NetworkLayerStacks::read(reader)?;
                    return Ok(Self::LayerStacks(Box::new(network)));
                }
                // Detect exact layer sizes from the file size; on failure,
                // report the closest candidate architectures for diagnosis.
                let detection = super::spec::detect_architecture_from_size(
                    file_size,
                    arch_len,
                    Some(parsed.feature_set),
                )
                .ok_or_else(|| {
                    let candidates =
                        super::spec::list_candidate_architectures(file_size, arch_len);
                    let candidates_str: Vec<String> = candidates
                        .iter()
                        .take(5)
                        .map(|(spec, diff)| format!("{} (diff: {:+})", spec.name(), diff))
                        .collect();
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!(
                            "Unknown architecture: file_size={}, arch_len={}, feature_set={}. \
                            Closest candidates: [{}]",
                            file_size,
                            arch_len,
                            parsed.feature_set,
                            candidates_str.join(", ")
                        ),
                    )
                })?;
                // Each per-architecture reader re-reads from the beginning.
                reader.seek(SeekFrom::Start(0))?;
                let l1 = detection.spec.l1;
                let l2 = detection.spec.l2;
                let l3 = detection.spec.l3;
                match detection.spec.feature_set {
                    FeatureSet::HalfKA_hm => {
                        let network = HalfKA_hmNetwork::read(reader, l1, l2, l3, activation)?;
                        Ok(Self::HalfKA_hm(network))
                    }
                    FeatureSet::HalfKA => {
                        let network = HalfKANetwork::read(reader, l1, l2, l3, activation)?;
                        Ok(Self::HalfKA(network))
                    }
                    FeatureSet::HalfKP => {
                        let network = HalfKPNetwork::read(reader, l1, l2, l3, activation)?;
                        Ok(Self::HalfKP(network))
                    }
                    FeatureSet::LayerStacks => {
                        // Handled by the early return above.
                        unreachable!()
                    }
                }
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Unknown NNUE version: {version:#x}. Expected {NNUE_VERSION:#x} (HalfKP) or {NNUE_VERSION_HALFKA:#x} (HalfKA_hm^)"
                ),
            )),
        }
    }

    /// Reads a network from an in-memory byte slice.
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let mut cursor = Cursor::new(bytes);
        Self::read(&mut cursor)
    }

    /// True if this is the LayerStacks architecture.
    pub fn is_layer_stacks(&self) -> bool {
        matches!(self, Self::LayerStacks(_))
    }

    /// True if this is the HalfKA architecture.
    pub fn is_halfka(&self) -> bool {
        matches!(self, Self::HalfKA(_))
    }

    /// True if this is the HalfKA_hm architecture.
    pub fn is_halfka_hm(&self) -> bool {
        matches!(self, Self::HalfKA_hm(_))
    }

    /// True if this is the HalfKP architecture.
    pub fn is_halfkp(&self) -> bool {
        matches!(self, Self::HalfKP(_))
    }

    /// Width of the first (L1) layer.
    pub fn l1_size(&self) -> usize {
        match self {
            Self::HalfKA(net) => net.l1_size(),
            Self::HalfKA_hm(net) => net.l1_size(),
            Self::HalfKP(net) => net.l1_size(),
            // LayerStacks uses a fixed L1 width in this build.
            Self::LayerStacks(_) => 1536,
        }
    }

    /// Human-readable architecture name (e.g. for logging).
    pub fn architecture_name(&self) -> &'static str {
        match self {
            Self::HalfKA(net) => net.architecture_name(),
            Self::HalfKA_hm(net) => net.architecture_name(),
            Self::HalfKP(net) => net.architecture_name(),
            Self::LayerStacks(_) => "LayerStacks",
        }
    }

    /// Full architecture spec of the loaded network.
    pub fn architecture_spec(&self) -> super::spec::ArchitectureSpec {
        match self {
            Self::HalfKA(net) => net.architecture_spec(),
            Self::HalfKA_hm(net) => net.architecture_spec(),
            Self::HalfKP(net) => net.architecture_spec(),
            // LayerStacks has no per-file spec; report fixed L1 and zeroed L2/L3.
            Self::LayerStacks(_) => super::spec::ArchitectureSpec::new(
                super::spec::FeatureSet::LayerStacks,
                1536,
                0,
                0,
                Activation::CReLU,
            ),
        }
    }

    // The methods below delegate to the matching architecture's network and
    // panic when called on any other variant; callers are expected to have
    // dispatched on the architecture first (see `evaluate_dispatch`).

    /// Recomputes the LayerStacks accumulator from scratch. Panics on other variants.
    pub fn refresh_accumulator_layer_stacks(
        &self,
        pos: &Position,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.refresh_accumulator(pos, acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Incrementally updates a LayerStacks accumulator from `prev_acc`. Panics on other variants.
    pub fn update_accumulator_layer_stacks(
        &self,
        pos: &Position,
        dirty_piece: &super::accumulator::DirtyPiece,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
        prev_acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.update_accumulator(pos, dirty_piece, acc, prev_acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Rolls the stack forward from `source_idx`; returns whether it succeeded.
    /// Panics on non-LayerStacks variants.
    pub fn forward_update_incremental_layer_stacks(
        &self,
        pos: &Position,
        stack: &mut AccumulatorStackLayerStacks,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::LayerStacks(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Runs the LayerStacks forward pass. Panics on other variants.
    pub fn evaluate_layer_stacks(
        &self,
        pos: &Position,
        acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) -> Value {
        match self {
            Self::LayerStacks(net) => net.evaluate(pos, acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Recomputes the HalfKA_hm accumulator from scratch. Panics on other variants.
    pub fn refresh_accumulator_halfka_hm(&self, pos: &Position, stack: &mut HalfKA_hmStack) {
        match self {
            Self::HalfKA_hm(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Recomputes the HalfKA accumulator from scratch. Panics on other variants.
    pub fn refresh_accumulator_halfka(&self, pos: &Position, stack: &mut HalfKAStack) {
        match self {
            Self::HalfKA(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Incrementally updates a HalfKA_hm accumulator from entry `source_idx`.
    /// Panics on other variants.
    pub fn update_accumulator_halfka_hm(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKA_hm(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Incrementally updates a HalfKA accumulator from entry `source_idx`.
    /// Panics on other variants.
    pub fn update_accumulator_halfka(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKAStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKA(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Rolls a HalfKA_hm stack forward from `source_idx`; returns success.
    /// Panics on other variants.
    pub fn forward_update_incremental_halfka_hm(
        &self,
        pos: &Position,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKA_hm(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Rolls a HalfKA stack forward from `source_idx`; returns success.
    /// Panics on other variants.
    pub fn forward_update_incremental_halfka(
        &self,
        pos: &Position,
        stack: &mut HalfKAStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKA(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Runs the HalfKA_hm forward pass. Panics on other variants.
    pub fn evaluate_halfka_hm(&self, pos: &Position, stack: &HalfKA_hmStack) -> Value {
        match self {
            Self::HalfKA_hm(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Runs the HalfKA forward pass. Panics on other variants.
    pub fn evaluate_halfka(&self, pos: &Position, stack: &HalfKAStack) -> Value {
        match self {
            Self::HalfKA(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Recomputes the HalfKP accumulator from scratch. Panics on other variants.
    pub fn refresh_accumulator_halfkp(&self, pos: &Position, stack: &mut HalfKPStack) {
        match self {
            Self::HalfKP(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }

    /// Incrementally updates a HalfKP accumulator from entry `source_idx`.
    /// Panics on other variants.
    pub fn update_accumulator_halfkp(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKPStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKP(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }

    /// Rolls a HalfKP stack forward from `source_idx`; returns success.
    /// Panics on other variants.
    pub fn forward_update_incremental_halfkp(
        &self,
        pos: &Position,
        stack: &mut HalfKPStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKP(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }

    /// Runs the HalfKP forward pass. Panics on other variants.
    pub fn evaluate_halfkp(&self, pos: &Position, stack: &HalfKPStack) -> Value {
        match self {
            Self::HalfKP(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }
}
/// Extracts an `fv_scale=N` entry from a comma-separated architecture string.
///
/// Only the first `fv_scale=` entry is considered; if its value fails to
/// parse or falls outside `1..=128`, the result is `None` (later entries
/// are never consulted).
pub fn parse_fv_scale_from_arch(arch_str: &str) -> Option<i32> {
    const FV_SCALE_MIN: i32 = 1;
    const FV_SCALE_MAX: i32 = 128;
    arch_str
        .split(',')
        // Stop at the first component carrying the fv_scale prefix.
        .find_map(|component| component.strip_prefix("fv_scale="))
        .and_then(|raw| raw.parse::<i32>().ok())
        .filter(|scale| (FV_SCALE_MIN..=FV_SCALE_MAX).contains(scale))
}
/// Loads a network from `path` and installs it as the global network.
///
/// # Errors
/// Propagates load failures; returns `AlreadyExists` if a network was
/// already installed.
pub fn init_nnue<P: AsRef<Path>>(path: P) -> io::Result<()> {
    let network = NNUENetwork::load(path)?;
    match NETWORK.set(network) {
        Ok(()) => Ok(()),
        Err(_) => Err(io::Error::new(io::ErrorKind::AlreadyExists, "NNUE already initialized")),
    }
}
/// Parses a network from an in-memory buffer and installs it globally.
///
/// # Errors
/// Propagates parse failures; returns `AlreadyExists` if a network was
/// already installed.
pub fn init_nnue_from_bytes(bytes: &[u8]) -> io::Result<()> {
    let network = NNUENetwork::from_bytes(bytes)?;
    match NETWORK.set(network) {
        Ok(()) => Ok(()),
        Err(_) => Err(io::Error::new(io::ErrorKind::AlreadyExists, "NNUE already initialized")),
    }
}
/// Whether a global network has been installed.
pub fn is_nnue_initialized() -> bool {
    matches!(NETWORK.get(), Some(_))
}
/// Summary of an NNUE file's format, produced by `detect_format` without
/// loading any weights.
#[derive(Debug, Clone)]
pub struct NnueFormatInfo {
    /// Human-readable architecture name, e.g. "HalfKP768" or "LayerStacks".
    pub architecture: String,
    pub l1_dimension: u32,
    pub l2_dimension: u32,
    pub l3_dimension: u32,
    /// Activation function name detected from the arch string.
    pub activation: String,
    /// Version word from the file header.
    pub version: u32,
    /// Raw architecture string embedded in the header.
    pub arch_string: String,
}
/// Inspects an NNUE file header (plus total size) without loading weights.
///
/// `bytes` must contain at least the 12-byte header and the arch string;
/// `file_size` is the full on-disk size, used for size-based architecture
/// detection. Size detection takes precedence over the header because some
/// files carry headers that disagree with their actual layout (see the
/// Aoba test in this module); when it fails, the header's own claim is used.
///
/// # Errors
/// `InvalidData` when the buffer is too small, the arch length is invalid,
/// the arch string does not parse, or the version word is unknown.
pub fn detect_format(bytes: &[u8], file_size: u64) -> io::Result<NnueFormatInfo> {
    // version (4) + hash (4) + arch_len (4)
    const MIN_HEADER_SIZE: usize = 12;
    if bytes.len() < MIN_HEADER_SIZE {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "NNUE file too small: {} bytes (need at least {} for header)",
                bytes.len(),
                MIN_HEADER_SIZE
            ),
        ));
    }
    let version = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
    match version {
        NNUE_VERSION | NNUE_VERSION_HALFKA => {
            // bytes[4..8] is the hash field, ignored here.
            let arch_len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
            if arch_len == 0 || arch_len > MAX_ARCH_LEN {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Invalid arch string length: {} (max: {})", arch_len, MAX_ARCH_LEN),
                ));
            }
            let required_size = MIN_HEADER_SIZE + arch_len;
            if bytes.len() < required_size {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!(
                        "NNUE file too small: {} bytes (need {} for arch string)",
                        bytes.len(),
                        required_size
                    ),
                ));
            }
            let arch_str = String::from_utf8_lossy(&bytes[12..12 + arch_len]).to_string();
            let activation = detect_activation_from_arch(&arch_str).to_string();
            let parsed = super::spec::parse_architecture(&arch_str)
                .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?;
            // Prefer size-based detection; fall back to the header's claim.
            let (l1, l2, l3, feature_set, used_file_size_detection) = if let Some(detection) =
                super::spec::detect_architecture_from_size(
                    file_size,
                    arch_len,
                    Some(parsed.feature_set),
                ) {
                (
                    detection.spec.l1,
                    detection.spec.l2,
                    detection.spec.l3,
                    detection.spec.feature_set,
                    true,
                )
            } else {
                (parsed.l1, parsed.l2, parsed.l3, parsed.feature_set, false)
            };
            // Debug-only diagnostic when we had to trust the header.
            #[cfg(debug_assertions)]
            if !used_file_size_detection {
                eprintln!(
                    "Warning: File size detection failed for size={}. \
                    Falling back to header parsing (may be inaccurate).",
                    file_size
                );
            }
            // Suppress the unused-variable warning in release builds.
            let _ = used_file_size_detection;
            let architecture = match feature_set {
                FeatureSet::LayerStacks => "LayerStacks".to_string(),
                FeatureSet::HalfKA_hm => format!("HalfKA_hm{}", l1),
                FeatureSet::HalfKA => format!("HalfKA{}", l1),
                FeatureSet::HalfKP => format!("HalfKP{}", l1),
            };
            Ok(NnueFormatInfo {
                architecture,
                l1_dimension: l1 as u32,
                l2_dimension: l2 as u32,
                l3_dimension: l3 as u32,
                activation,
                version,
                arch_string: arch_str,
            })
        }
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("Unknown NNUE version: 0x{version:08X}"),
        )),
    }
}
/// Returns the globally installed network, if `init_nnue*` has been called.
pub fn get_network() -> Option<&'static NNUENetwork> {
    NETWORK.get()
}
/// Ensures the current LayerStacks accumulator is computed, then evaluates.
///
/// Three-tier strategy: (1) incremental update from the immediate
/// predecessor when it is computed; (2) otherwise roll forward from any
/// usable earlier accumulator; (3) otherwise a full refresh.
#[inline]
fn update_and_evaluate_layer_stacks(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut AccumulatorStackLayerStacks,
) -> Value {
    let current_entry = stack.current();
    if !current_entry.accumulator.computed_accumulation {
        let mut updated = false;
        // Tier 1: incremental update from the previous entry.
        if let Some(prev_idx) = current_entry.previous {
            let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
            if prev_computed {
                // Copy dirty_piece out first to end the shared borrow of `stack`.
                let dirty_piece = stack.current().dirty_piece;
                let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
                network.update_accumulator_layer_stacks(pos, &dirty_piece, current_acc, prev_acc);
                updated = true;
            }
        }
        // Tier 2: forward-update from an older computed accumulator.
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = network.forward_update_incremental_layer_stacks(pos, stack, source_idx);
            }
        }
        // Tier 3: full recomputation.
        if !updated {
            let acc = &mut stack.current_mut().accumulator;
            network.refresh_accumulator_layer_stacks(pos, acc);
        }
    }
    let acc_ref = &stack.current().accumulator;
    network.evaluate_layer_stacks(pos, acc_ref)
}
/// Ensures the current HalfKA_hm accumulator is computed, then evaluates.
/// Same three-tier strategy as `update_and_evaluate_layer_stacks`:
/// incremental update, forward roll, then full refresh.
#[inline]
fn update_and_evaluate_halfka_hm(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKA_hmStack,
) -> Value {
    if !stack.is_current_computed() {
        let mut updated = false;
        // Tier 1: incremental update from the previous entry.
        if let Some(prev_idx) = stack.current_previous() {
            if stack.is_entry_computed(prev_idx) {
                let dirty = stack.current_dirty_piece();
                network.update_accumulator_halfka_hm(pos, &dirty, stack, prev_idx);
                updated = true;
            }
        }
        // Tier 2: forward-update from an older computed accumulator.
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = network.forward_update_incremental_halfka_hm(pos, stack, source_idx);
            }
        }
        // Tier 3: full recomputation.
        if !updated {
            network.refresh_accumulator_halfka_hm(pos, stack);
        }
    }
    network.evaluate_halfka_hm(pos, stack)
}
/// Ensures the current HalfKA accumulator is computed, then evaluates.
/// Same three-tier strategy as `update_and_evaluate_layer_stacks`.
#[inline]
fn update_and_evaluate_halfka(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKAStack,
) -> Value {
    if !stack.is_current_computed() {
        let mut updated = false;
        // Tier 1: incremental update from the previous entry.
        if let Some(prev_idx) = stack.current_previous() {
            if stack.is_entry_computed(prev_idx) {
                let dirty = stack.current_dirty_piece();
                network.update_accumulator_halfka(pos, &dirty, stack, prev_idx);
                updated = true;
            }
        }
        // Tier 2: forward-update from an older computed accumulator.
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = network.forward_update_incremental_halfka(pos, stack, source_idx);
            }
        }
        // Tier 3: full recomputation.
        if !updated {
            network.refresh_accumulator_halfka(pos, stack);
        }
    }
    network.evaluate_halfka(pos, stack)
}
/// Ensures the current HalfKP accumulator is computed, then evaluates.
/// Same three-tier strategy as `update_and_evaluate_layer_stacks`.
#[inline]
fn update_and_evaluate_halfkp(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKPStack,
) -> Value {
    if !stack.is_current_computed() {
        let mut updated = false;
        // Tier 1: incremental update from the previous entry.
        if let Some(prev_idx) = stack.current_previous() {
            if stack.is_entry_computed(prev_idx) {
                let dirty = stack.current_dirty_piece();
                network.update_accumulator_halfkp(pos, &dirty, stack, prev_idx);
                updated = true;
            }
        }
        // Tier 2: forward-update from an older computed accumulator.
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = network.forward_update_incremental_halfkp(pos, stack, source_idx);
            }
        }
        // Tier 3: full recomputation.
        if !updated {
            network.refresh_accumulator_halfkp(pos, stack);
        }
    }
    network.evaluate_halfkp(pos, stack)
}
/// True when the global network is loaded and is the LayerStacks architecture.
pub fn is_layer_stacks_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_layer_stacks())
}
/// True when the global network is HalfKA_hm with L1 = 256.
pub fn is_halfka_hm_256_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_halfka_hm() && n.l1_size() == 256)
}
/// True when the global network is HalfKA with L1 = 256.
pub fn is_halfka_256_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_halfka() && n.l1_size() == 256)
}
/// True when the global network is HalfKA_hm with L1 = 512.
pub fn is_halfka_hm_512_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_halfka_hm() && n.l1_size() == 512)
}
/// True when the global network is HalfKA with L1 = 512.
pub fn is_halfka_512_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_halfka() && n.l1_size() == 512)
}
/// True when the global network is HalfKA_hm with L1 = 1024.
pub fn is_halfka_hm_1024_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_halfka_hm() && n.l1_size() == 1024)
}
/// True when the global network is HalfKA with L1 = 1024.
pub fn is_halfka_1024_loaded() -> bool {
    // is_some_and replaces map(..).unwrap_or(false) (clippy: map_unwrap_or).
    NETWORK.get().is_some_and(|n| n.is_halfka() && n.l1_size() == 1024)
}
/// Evaluates `pos` with the global LayerStacks network.
///
/// Falls back to material evaluation when no network is installed.
///
/// # Panics
/// Panics if the installed network is not the LayerStacks architecture.
pub fn evaluate_layer_stacks(pos: &Position, stack: &mut AccumulatorStackLayerStacks) -> Value {
    match NETWORK.get() {
        None => material::evaluate_material(pos),
        Some(network) if network.is_layer_stacks() => {
            update_and_evaluate_layer_stacks(network, pos, stack)
        }
        Some(_) => {
            panic!("Non-LayerStacks architecture detected. Use evaluate() with AccumulatorStack.")
        }
    }
}
/// Evaluates `pos` with the global network, dispatching on the caller's
/// accumulator-stack variant. Falls back to material evaluation when no
/// network has been initialized.
pub fn evaluate_dispatch(pos: &Position, stack: &mut AccumulatorStackVariant) -> Value {
    let Some(network) = NETWORK.get() else {
        return material::evaluate_material(pos);
    };
    match stack {
        AccumulatorStackVariant::LayerStacks(s) => {
            update_and_evaluate_layer_stacks(network, pos, s)
        }
        AccumulatorStackVariant::HalfKA(s) => update_and_evaluate_halfka(network, pos, s),
        AccumulatorStackVariant::HalfKA_hm(s) => update_and_evaluate_halfka_hm(network, pos, s),
        AccumulatorStackVariant::HalfKP(s) => update_and_evaluate_halfkp(network, pos, s),
    }
}
/// Brings the current accumulator up to date without evaluating.
///
/// No-op when no network is installed. Dispatches to the matching
/// `update_accumulator_only_*` helper for the stack's architecture.
pub fn ensure_accumulator_computed(pos: &Position, stack: &mut AccumulatorStackVariant) {
    let Some(network) = NETWORK.get() else {
        return;
    };
    match stack {
        AccumulatorStackVariant::LayerStacks(s) => {
            update_accumulator_only_layer_stacks(network, pos, s);
        }
        AccumulatorStackVariant::HalfKA(s) => {
            update_accumulator_only_halfka(network, pos, s);
        }
        AccumulatorStackVariant::HalfKA_hm(s) => {
            update_accumulator_only_halfka_hm(network, pos, s);
        }
        AccumulatorStackVariant::HalfKP(s) => {
            update_accumulator_only_halfkp(network, pos, s);
        }
    }
}
/// Computes the current LayerStacks accumulator (incremental if the
/// predecessor is computed, otherwise full refresh) without evaluating.
/// The `count_*` macros record which path was taken for statistics.
/// Note: unlike the evaluate path, there is no forward-roll tier here.
#[inline]
fn update_accumulator_only_layer_stacks(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut AccumulatorStackLayerStacks,
) {
    let current_entry = stack.current();
    if current_entry.accumulator.computed_accumulation {
        count_already_computed!();
        return;
    }
    let mut updated = false;
    // Incremental update from the previous entry if it is computed.
    if let Some(prev_idx) = current_entry.previous {
        let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
        if prev_computed {
            // Copy dirty_piece out first to end the shared borrow of `stack`.
            let dirty_piece = stack.current().dirty_piece;
            let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
            network.update_accumulator_layer_stacks(pos, &dirty_piece, current_acc, prev_acc);
            count_update!();
            updated = true;
        }
    }
    // Otherwise recompute from scratch.
    if !updated {
        let acc = &mut stack.current_mut().accumulator;
        network.refresh_accumulator_layer_stacks(pos, acc);
        count_refresh!();
    }
}
/// Computes the current HalfKA_hm accumulator (incremental or refresh)
/// without evaluating; `count_*` macros record the path taken.
#[inline]
fn update_accumulator_only_halfka_hm(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKA_hmStack,
) {
    if stack.is_current_computed() {
        count_already_computed!();
        return;
    }
    let mut updated = false;
    // Incremental update from the previous entry if it is computed.
    if let Some(prev_idx) = stack.current_previous() {
        if stack.is_entry_computed(prev_idx) {
            let dirty = stack.current_dirty_piece();
            network.update_accumulator_halfka_hm(pos, &dirty, stack, prev_idx);
            count_update!();
            updated = true;
        }
    }
    // Otherwise recompute from scratch.
    if !updated {
        network.refresh_accumulator_halfka_hm(pos, stack);
        count_refresh!();
    }
}
/// Computes the current HalfKA accumulator (incremental or refresh)
/// without evaluating; `count_*` macros record the path taken.
#[inline]
fn update_accumulator_only_halfka(network: &NNUENetwork, pos: &Position, stack: &mut HalfKAStack) {
    if stack.is_current_computed() {
        count_already_computed!();
        return;
    }
    let mut updated = false;
    // Incremental update from the previous entry if it is computed.
    if let Some(prev_idx) = stack.current_previous() {
        if stack.is_entry_computed(prev_idx) {
            let dirty = stack.current_dirty_piece();
            network.update_accumulator_halfka(pos, &dirty, stack, prev_idx);
            count_update!();
            updated = true;
        }
    }
    // Otherwise recompute from scratch.
    if !updated {
        network.refresh_accumulator_halfka(pos, stack);
        count_refresh!();
    }
}
/// Computes the current HalfKP accumulator (incremental or refresh)
/// without evaluating; `count_*` macros record the path taken.
#[inline]
fn update_accumulator_only_halfkp(network: &NNUENetwork, pos: &Position, stack: &mut HalfKPStack) {
    if stack.is_current_computed() {
        count_already_computed!();
        return;
    }
    let mut updated = false;
    // Incremental update from the previous entry if it is computed.
    if let Some(prev_idx) = stack.current_previous() {
        if stack.is_entry_computed(prev_idx) {
            let dirty = stack.current_dirty_piece();
            network.update_accumulator_halfkp(pos, &dirty, stack, prev_idx);
            count_update!();
            updated = true;
        }
    }
    // Otherwise recompute from scratch.
    if !updated {
        network.refresh_accumulator_halfkp(pos, stack);
        count_refresh!();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::position::SFEN_HIRATE;

    // With no network initialized, evaluation falls back to material counting;
    // the startpos balance should be near zero.
    #[test]
    fn test_evaluate_fallback() {
        let mut pos = Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        let mut stack = AccumulatorStackVariant::new_default();
        let value = evaluate_dispatch(&pos, &mut stack);
        assert!(value.raw().abs() < 1000);
    }

    // Repeated dispatch through the variant stack must not panic in fallback mode.
    #[test]
    fn test_accumulator_stack_variant_fallback() {
        let mut pos = Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        let mut stack = AccumulatorStackVariant::new_default();
        let value1 = evaluate_dispatch(&pos, &mut stack);
        let value2 = evaluate_dispatch(&pos, &mut stack);
        let _ = (value1, value2);
    }

    // Requires a real LayerStacks file; point NNUE_TEST_FILE at it to run.
    #[test]
    #[ignore]
    fn test_nnue_network_auto_detect_layer_stacks() {
        let path = std::env::var("NNUE_TEST_FILE")
            .unwrap_or_else(|_| "/path/to/your/layer_stacks.nnue".to_string());
        let network = match NNUENetwork::load(path) {
            Ok(n) => n,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        assert!(network.is_layer_stacks(), "epoch82.nnue should be detected as LayerStacks");
        assert_eq!(network.architecture_name(), "LayerStacks");
        let mut pos = crate::position::Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        let mut acc = crate::nnue::AccumulatorLayerStacks::new();
        network.refresh_accumulator_layer_stacks(&pos, &mut acc);
        let value = network.evaluate_layer_stacks(&pos, &acc);
        eprintln!("LayerStacks evaluate: {}", value.raw());
        assert!(value.raw().abs() < 1000);
    }

    // Verifies that file-size detection overrides a lying header (AobaNNUE
    // claims 256 in its arch string but is actually a 768-wide HalfKP net).
    #[test]
    #[ignore]
    fn test_detect_format_aoba() {
        let path = std::env::var("NNUE_AOBA_FILE").unwrap_or_else(|_| "AobaNNUE.bin".to_string());
        let bytes = match std::fs::read(&path) {
            Ok(b) => b,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        let file_size = bytes.len() as u64;
        let info = detect_format(&bytes, file_size).expect("Failed to detect format");
        eprintln!("File: {path}");
        eprintln!("Architecture: {}", info.architecture);
        eprintln!(
            "L1: {}, L2: {}, L3: {}",
            info.l1_dimension, info.l2_dimension, info.l3_dimension
        );
        eprintln!("Activation: {}", info.activation);
        eprintln!("Arch string (header): {}", info.arch_string);
        assert_eq!(
            info.architecture, "HalfKP768",
            "Should detect HalfKP768 from file size, not HalfKP256 from header"
        );
        assert_eq!(info.l1_dimension, 768, "L1 should be 768, not 256 from header");
        assert_eq!(info.l2_dimension, 16, "L2 should be 16");
        assert_eq!(info.l3_dimension, 64, "L3 should be 64");
        assert!(
            info.arch_string.contains("256"),
            "Header should claim 256, but file size detection should override it"
        );
    }

    // When size detection fails (unknown size), detect_format should trust
    // the header's own architecture claim.
    #[test]
    fn test_detect_format_fallback_to_header() {
        let unknown_file_size = 12345678u64;
        let mut bytes = Vec::new();
        // version word, then a 4-byte hash placeholder
        bytes.extend_from_slice(&NNUE_VERSION_HALFKA.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        let arch_str = "Features=HalfKA_hm[73305->512x2],l2=8,l3=96";
        let arch_len = arch_str.len() as u32;
        bytes.extend_from_slice(&arch_len.to_le_bytes());
        bytes.extend_from_slice(arch_str.as_bytes());
        let info =
            detect_format(&bytes, unknown_file_size).expect("Should fallback to header parsing");
        assert_eq!(info.architecture, "HalfKA_hm512");
        assert_eq!(info.l1_dimension, 512);
        assert_eq!(info.l2_dimension, 8);
        assert_eq!(info.l3_dimension, 96);
    }

    // Exercises each error branch of detect_format in turn.
    #[test]
    fn test_detect_format_error_cases() {
        // Buffer shorter than the 12-byte header.
        let bytes = vec![0u8; 5];
        let result = detect_format(&bytes, 5);
        assert!(result.is_err(), "Should fail for too small file");
        assert!(
            result.unwrap_err().to_string().contains("too small"),
            "Error message should mention 'too small'"
        );
        // arch_len == 0.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&NNUE_VERSION.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        let result = detect_format(&bytes, 100);
        assert!(result.is_err(), "Should fail for arch_len = 0");
        assert!(
            result.unwrap_err().to_string().contains("Invalid arch string length"),
            "Error message should mention invalid arch string length"
        );
        // arch_len above the allowed maximum.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&NNUE_VERSION.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&(MAX_ARCH_LEN as u32 + 1).to_le_bytes());
        let result = detect_format(&bytes, 100);
        assert!(result.is_err(), "Should fail for arch_len > MAX_ARCH_LEN");
        // Buffer too small to hold the declared arch string.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&NNUE_VERSION.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());
        bytes.extend_from_slice(&100u32.to_le_bytes());
        let result = detect_format(&bytes, 1000);
        assert!(result.is_err(), "Should fail when buffer is too small for arch_str");
        // Unrecognized version word.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&0xDEADBEEFu32.to_le_bytes());
        bytes.extend_from_slice(&[0u8; 100]);
        let result = detect_format(&bytes, 112);
        assert!(result.is_err(), "Should fail for unknown version");
        assert!(
            result.unwrap_err().to_string().contains("Unknown NNUE version"),
            "Error message should mention unknown version"
        );
    }

    // Happy paths and absent/invalid fv_scale entries in real-looking arch strings.
    #[test]
    fn test_parse_fv_scale_from_arch() {
        assert_eq!(
            parse_fv_scale_from_arch(
                "Features=HalfKA_hm^[73305->256x2]-SCReLU,fv_scale=13,qa=127,qb=64,scale=600"
            ),
            Some(13)
        );
        assert_eq!(
            parse_fv_scale_from_arch(
                "Features=HalfKA_hm^[73305->512x2]-SCReLU,fv_scale=20,qa=127,qb=64,scale=400"
            ),
            Some(20)
        );
        assert_eq!(
            parse_fv_scale_from_arch(
                "Features=HalfKA_hm^[73305->1024x2]-SCReLU,fv_scale=16,qa=127,qb=64,scale=508"
            ),
            Some(16)
        );
        assert_eq!(parse_fv_scale_from_arch("Features=HalfKP[125388->256x2]"), None);
        assert_eq!(parse_fv_scale_from_arch("Features=HalfKA_hm^[73305->512x2]"), None);
        assert_eq!(parse_fv_scale_from_arch(""), None);
        assert_eq!(
            parse_fv_scale_from_arch("Features=HalfKA_hm^[73305->256x2],fv_scale=abc"),
            None
        );
    }

    // Boundary values, malformed values, and first-occurrence-wins semantics.
    #[test]
    fn test_parse_fv_scale_edge_cases() {
        assert_eq!(parse_fv_scale_from_arch("fv_scale=1"), Some(1));
        assert_eq!(parse_fv_scale_from_arch("fv_scale=128"), Some(128));
        assert_eq!(parse_fv_scale_from_arch("fv_scale=64"), Some(64));
        assert_eq!(parse_fv_scale_from_arch("fv_scale=0"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=129"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=-1"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=-100"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=99999"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=2147483647"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale= 16"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=16 "), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=10,fv_scale=20"), Some(10));
        assert_eq!(parse_fv_scale_from_arch("fv_scale="), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=16.5"), None);
        assert_eq!(parse_fv_scale_from_arch("my_fv_scale=16"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale_v2=16"), None);
    }

    // Requires the HalfKP 768 eval file in the workspace (or NNUE_HALFKP_768_FILE).
    #[test]
    #[ignore]
    fn test_nnue_halfkp_768_auto_detect() {
        let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .and_then(|p| p.parent())
            .expect("Failed to find workspace root");
        let default_path = workspace_root
            .join("eval/halfkp_768x2-16-64_crelu/AobaNNUE_HalfKP_768x2_16_64_FV_SCALE_40.bin");
        let path = std::env::var("NNUE_HALFKP_768_FILE")
            .unwrap_or_else(|_| default_path.display().to_string());
        let network = match NNUENetwork::load(&path) {
            Ok(n) => n,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        assert!(network.is_halfkp(), "File should be detected as HalfKP");
        assert_eq!(network.l1_size(), 768, "L1 should be 768");
        let spec = network.architecture_spec();
        assert_eq!(spec.l1, 768, "spec.l1 should be 768");
        assert_eq!(spec.l2, 16, "spec.l2 should be 16");
        assert_eq!(spec.l3, 64, "spec.l3 should be 64");
        eprintln!("Successfully loaded HalfKP 768x2-16-64 network");
        eprintln!("Architecture name: {}", network.architecture_name());
        let mut pos = crate::position::Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        use crate::nnue::halfkp::HalfKPStack;
        let mut stack = HalfKPStack::from_network(match &network {
            NNUENetwork::HalfKP(net) => net,
            _ => unreachable!(),
        });
        network.refresh_accumulator_halfkp(&pos, &mut stack);
        let value = network.evaluate_halfkp(&pos, &stack);
        eprintln!("HalfKP 768 evaluate: {}", value.raw());
        assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
    }

    // Requires the HalfKA_hm 256 eval file (or NNUE_HALFKA_HM_256_FILE).
    #[test]
    #[ignore]
    fn test_nnue_halfka_hm_256_auto_detect() {
        let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .and_then(|p| p.parent())
            .expect("Failed to find workspace root");
        let default_path = workspace_root.join("eval/halfka_hm_256x2-32-32_crelu/v28_epoch65.nnue");
        let path = std::env::var("NNUE_HALFKA_HM_256_FILE")
            .unwrap_or_else(|_| default_path.display().to_string());
        let network = match NNUENetwork::load(&path) {
            Ok(n) => n,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        assert!(network.is_halfka_hm(), "File should be detected as HalfKA_hm");
        assert_eq!(network.l1_size(), 256, "L1 should be 256");
        let spec = network.architecture_spec();
        assert_eq!(spec.l1, 256, "spec.l1 should be 256");
        assert_eq!(spec.l2, 32, "spec.l2 should be 32");
        assert_eq!(spec.l3, 32, "spec.l3 should be 32");
        eprintln!("Successfully loaded HalfKA_hm 256x2-32-32 network");
        eprintln!("Architecture name: {}", network.architecture_name());
        let mut pos = crate::position::Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        use crate::nnue::halfka_hm::HalfKA_hmStack;
        let mut stack = HalfKA_hmStack::from_network(match &network {
            NNUENetwork::HalfKA_hm(net) => net,
            _ => unreachable!(),
        });
        network.refresh_accumulator_halfka_hm(&pos, &mut stack);
        let value = network.evaluate_halfka_hm(&pos, &stack);
        eprintln!("HalfKA_hm 256 evaluate: {}", value.raw());
        assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
    }

    // Requires the HalfKA_hm 1024 eval file (or NNUE_HALFKA_HM_1024_FILE).
    #[test]
    #[ignore]
    fn test_nnue_halfka_hm_1024_auto_detect() {
        let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .and_then(|p| p.parent())
            .expect("Failed to find workspace root");
        let default_path = workspace_root.join("eval/halfka_hm_1024x2-8-96_crelu/epoch20_v2.nnue");
        let path = std::env::var("NNUE_HALFKA_HM_1024_FILE")
            .unwrap_or_else(|_| default_path.display().to_string());
        let network = match NNUENetwork::load(&path) {
            Ok(n) => n,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        assert!(network.is_halfka_hm(), "File should be detected as HalfKA_hm");
        assert_eq!(network.l1_size(), 1024, "L1 should be 1024");
        let spec = network.architecture_spec();
        assert_eq!(spec.l1, 1024, "spec.l1 should be 1024");
        assert_eq!(spec.l2, 8, "spec.l2 should be 8");
        assert_eq!(spec.l3, 96, "spec.l3 should be 96");
        eprintln!("Successfully loaded HalfKA_hm 1024x2-8-96 network");
        eprintln!("Architecture name: {}", network.architecture_name());
        let mut pos = crate::position::Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        use crate::nnue::halfka_hm::HalfKA_hmStack;
        let mut stack = HalfKA_hmStack::from_network(match &network {
            NNUENetwork::HalfKA_hm(net) => net,
            _ => unreachable!(),
        });
        network.refresh_accumulator_halfka_hm(&pos, &mut stack);
        let value = network.evaluate_halfka_hm(&pos, &stack);
        eprintln!("HalfKA_hm 1024 evaluate: {}", value.raw());
        assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
    }

    // Requires the suisho5 HalfKP 256 eval file (or NNUE_HALFKP_256_FILE).
    #[test]
    #[ignore]
    fn test_nnue_halfkp_256_suisho5() {
        let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .and_then(|p| p.parent())
            .expect("Failed to find workspace root");
        let default_path = workspace_root.join("eval/halfkp_256x2-32-32_crelu/suisho5.bin");
        let path = std::env::var("NNUE_HALFKP_256_FILE")
            .unwrap_or_else(|_| default_path.display().to_string());
        let network = match NNUENetwork::load(&path) {
            Ok(n) => n,
            Err(e) => {
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        assert!(network.is_halfkp(), "File should be detected as HalfKP");
        assert_eq!(network.l1_size(), 256, "L1 should be 256");
        let spec = network.architecture_spec();
        assert_eq!(spec.l1, 256, "spec.l1 should be 256");
        assert_eq!(spec.l2, 32, "spec.l2 should be 32");
        assert_eq!(spec.l3, 32, "spec.l3 should be 32");
        eprintln!("Successfully loaded HalfKP 256x2-32-32 network (suisho5)");
        eprintln!("Architecture name: {}", network.architecture_name());
        let mut pos = crate::position::Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        use crate::nnue::halfkp::HalfKPStack;
        let mut stack = HalfKPStack::from_network(match &network {
            NNUENetwork::HalfKP(net) => net,
            _ => unreachable!(),
        });
        network.refresh_accumulator_halfkp(&pos, &mut stack);
        let value = network.evaluate_halfkp(&pos, &stack);
        eprintln!("HalfKP 256 evaluate: {}", value.raw());
        assert!(value.raw().abs() < 10000, "Evaluation {} is out of expected range", value.raw());
    }
}