use super::accumulator_layer_stacks::AccumulatorStackLayerStacks;
use super::accumulator_stack_variant::AccumulatorStackVariant;
use super::activation::detect_activation_from_arch;
use super::constants::{MAX_ARCH_LEN, NNUE_VERSION, NNUE_VERSION_HALFKA};
use super::halfka::{HalfKANetwork, HalfKAStack};
use super::halfka_hm::{HalfKA_hmNetwork, HalfKA_hmStack};
use super::halfkp::{HalfKPNetwork, HalfKPStack};
use super::network_layer_stacks::NetworkLayerStacks;
use super::spec::{Activation, FeatureSet};
use crate::eval::material;
use crate::position::Position;
use crate::types::Value;
use std::fs::File;
use std::io::{self, BufReader, Cursor, Read, Seek, SeekFrom};
use std::path::Path;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::OnceLock;
/// Process-wide NNUE network, set at most once via `init_nnue` /
/// `init_nnue_from_bytes`; `None` until initialization.
static NETWORK: OnceLock<NNUENetwork> = OnceLock::new();
/// Optional FV-scale override; 0 (the initial value) encodes "no override"
/// (see `get_fv_scale_override` / `set_fv_scale_override`).
static FV_SCALE_OVERRIDE: AtomicI32 = AtomicI32::new(0);
/// Returns the FV-scale override if one has been set, `None` otherwise.
///
/// The backing atomic uses 0 as the "not set" sentinel, so only strictly
/// positive values are reported.
pub fn get_fv_scale_override() -> Option<i32> {
    let raw = FV_SCALE_OVERRIDE.load(Ordering::Relaxed);
    (raw > 0).then_some(raw)
}
/// Sets the global FV-scale override.
///
/// Negative inputs are clamped to 0, which the getter treats as "no
/// override".
pub fn set_fv_scale_override(value: i32) {
    let clamped = if value < 0 { 0 } else { value };
    FV_SCALE_OVERRIDE.store(clamped, Ordering::Relaxed);
}
/// A loaded NNUE network, tagged by its feature-set architecture.
pub enum NNUENetwork {
    /// HalfKA feature set.
    HalfKA(HalfKANetwork),
    /// HalfKA_hm feature set (the `HalfKA_hm^` family in arch strings).
    #[allow(non_camel_case_types)]
    HalfKA_hm(HalfKA_hmNetwork),
    /// HalfKP feature set.
    HalfKP(HalfKPNetwork),
    /// LayerStacks architecture; boxed because the network is large.
    LayerStacks(Box<NetworkLayerStacks>),
}
impl NNUENetwork {
    /// Architecture specs the HalfKP backend knows how to load.
    pub fn supported_halfkp_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKPNetwork::supported_specs()
    }

    /// Architecture specs the HalfKA_hm backend knows how to load.
    pub fn supported_halfka_hm_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKA_hmNetwork::supported_specs()
    }

    /// Architecture specs the HalfKA backend knows how to load.
    pub fn supported_halfka_specs() -> Vec<super::spec::ArchitectureSpec> {
        HalfKANetwork::supported_specs()
    }

    /// Opens `path` and reads a network from it (see [`Self::read`]).
    pub fn load<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);
        Self::read(&mut reader)
    }

    /// Reads a network from `reader`, auto-detecting the architecture.
    ///
    /// Header layout as consumed here: a 4-byte little-endian version word,
    /// a second 4-byte field that is skipped (presumably a hash -- TODO
    /// confirm), then a 4-byte architecture-string length followed by the
    /// string itself. After sniffing the header the reader is rewound to
    /// offset 0, because each backend's `read` consumes the full stream
    /// including the header.
    ///
    /// # Errors
    /// Returns `InvalidData` for an unknown version word, an implausible
    /// arch-string length, or an arch string that fails to parse.
    pub fn read<R: Read + Seek>(reader: &mut R) -> io::Result<Self> {
        let mut buf4 = [0u8; 4];
        reader.read_exact(&mut buf4)?;
        let version = u32::from_le_bytes(buf4);
        match version {
            NNUE_VERSION | NNUE_VERSION_HALFKA => {
                // First read discards the 4-byte field after the version;
                // second read picks up the arch-string length.
                reader.read_exact(&mut buf4)?;
                reader.read_exact(&mut buf4)?;
                let arch_len = u32::from_le_bytes(buf4) as usize;
                if arch_len == 0 || arch_len > MAX_ARCH_LEN {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("Invalid arch string length: {arch_len}"),
                    ));
                }
                let mut arch = vec![0u8; arch_len];
                reader.read_exact(&mut arch)?;
                let arch_str = String::from_utf8_lossy(&arch);
                // Rewind: the per-architecture readers start from byte 0.
                reader.seek(SeekFrom::Start(0))?;
                let activation_str = detect_activation_from_arch(&arch_str);
                // Unrecognized activation names fall back to CReLU.
                let activation = match activation_str {
                    "SCReLU" => Activation::SCReLU,
                    "PairwiseCReLU" => Activation::PairwiseCReLU,
                    _ => Activation::CReLU,
                };
                let parsed = super::spec::parse_architecture(&arch_str).map_err(|msg| {
                    io::Error::new(io::ErrorKind::InvalidData, msg)
                })?;
                // Dispatch to the backend that matches the parsed feature set.
                match parsed.feature_set {
                    FeatureSet::LayerStacks => {
                        let network = NetworkLayerStacks::read(reader)?;
                        Ok(Self::LayerStacks(Box::new(network)))
                    }
                    FeatureSet::HalfKA_hm => {
                        let network = HalfKA_hmNetwork::read(
                            reader,
                            parsed.l1,
                            parsed.l2,
                            parsed.l3,
                            activation,
                        )?;
                        Ok(Self::HalfKA_hm(network))
                    }
                    FeatureSet::HalfKA => {
                        let network =
                            HalfKANetwork::read(reader, parsed.l1, parsed.l2, parsed.l3, activation)?;
                        Ok(Self::HalfKA(network))
                    }
                    FeatureSet::HalfKP => {
                        let network =
                            HalfKPNetwork::read(reader, parsed.l1, parsed.l2, parsed.l3, activation)?;
                        Ok(Self::HalfKP(network))
                    }
                }
            }
            _ => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Unknown NNUE version: {version:#x}. Expected {NNUE_VERSION:#x} (HalfKP) or {NNUE_VERSION_HALFKA:#x} (HalfKA_hm^)"
                ),
            )),
        }
    }

    /// Reads a network from an in-memory byte buffer (see [`Self::read`]).
    pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
        let mut cursor = Cursor::new(bytes);
        Self::read(&mut cursor)
    }

    /// True if this network uses the LayerStacks architecture.
    pub fn is_layer_stacks(&self) -> bool {
        matches!(self, Self::LayerStacks(_))
    }

    /// True if this network uses the HalfKA architecture.
    pub fn is_halfka(&self) -> bool {
        matches!(self, Self::HalfKA(_))
    }

    /// True if this network uses the HalfKA_hm architecture.
    pub fn is_halfka_hm(&self) -> bool {
        matches!(self, Self::HalfKA_hm(_))
    }

    /// True if this network uses the HalfKP architecture.
    pub fn is_halfkp(&self) -> bool {
        matches!(self, Self::HalfKP(_))
    }

    /// Width of the first hidden layer.
    pub fn l1_size(&self) -> usize {
        match self {
            Self::HalfKA(net) => net.l1_size(),
            Self::HalfKA_hm(net) => net.l1_size(),
            Self::HalfKP(net) => net.l1_size(),
            // LayerStacks has a fixed L1 width in this build.
            Self::LayerStacks(_) => 1536,
        }
    }

    /// Human-readable architecture label.
    pub fn architecture_name(&self) -> &'static str {
        match self {
            Self::HalfKA(net) => net.architecture_name(),
            Self::HalfKA_hm(net) => net.architecture_name(),
            Self::HalfKP(net) => net.architecture_name(),
            Self::LayerStacks(_) => "LayerStacks",
        }
    }

    /// Full architecture spec of the loaded network.
    pub fn architecture_spec(&self) -> super::spec::ArchitectureSpec {
        match self {
            Self::HalfKA(net) => net.architecture_spec(),
            Self::HalfKA_hm(net) => net.architecture_spec(),
            Self::HalfKP(net) => net.architecture_spec(),
            // LayerStacks is reported with fixed dimensions; l2/l3 are 0
            // and the activation slot is filled with CReLU as a placeholder.
            Self::LayerStacks(_) => super::spec::ArchitectureSpec::new(
                super::spec::FeatureSet::LayerStacks,
                1536,
                0,
                0,
                Activation::CReLU,
            ),
        }
    }

    /// Recomputes a LayerStacks accumulator from scratch for `pos`.
    /// Panics if the loaded network is not LayerStacks.
    pub fn refresh_accumulator_layer_stacks(
        &self,
        pos: &Position,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.refresh_accumulator(pos, acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Incrementally updates `acc` from `prev_acc` using `dirty_piece`.
    /// Panics if the loaded network is not LayerStacks.
    pub fn update_accumulator_layer_stacks(
        &self,
        pos: &Position,
        dirty_piece: &super::accumulator::DirtyPiece,
        acc: &mut super::accumulator_layer_stacks::AccumulatorLayerStacks,
        prev_acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) {
        match self {
            Self::LayerStacks(net) => net.update_accumulator(pos, dirty_piece, acc, prev_acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Forward-updates the stack's current entry from the entry at
    /// `source_idx`; returns whether the update succeeded.
    /// Panics if the loaded network is not LayerStacks.
    pub fn forward_update_incremental_layer_stacks(
        &self,
        pos: &Position,
        stack: &mut AccumulatorStackLayerStacks,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::LayerStacks(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Evaluates `pos` using a computed LayerStacks accumulator.
    /// Panics if the loaded network is not LayerStacks.
    pub fn evaluate_layer_stacks(
        &self,
        pos: &Position,
        acc: &super::accumulator_layer_stacks::AccumulatorLayerStacks,
    ) -> Value {
        match self {
            Self::LayerStacks(net) => net.evaluate(pos, acc),
            _ => panic!("This method is only for LayerStacks architecture."),
        }
    }

    /// Recomputes the HalfKA_hm stack's current accumulator from scratch.
    /// Panics if the loaded network is not HalfKA_hm.
    pub fn refresh_accumulator_halfka_hm(&self, pos: &Position, stack: &mut HalfKA_hmStack) {
        match self {
            Self::HalfKA_hm(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Recomputes the HalfKA stack's current accumulator from scratch.
    /// Panics if the loaded network is not HalfKA.
    pub fn refresh_accumulator_halfka(&self, pos: &Position, stack: &mut HalfKAStack) {
        match self {
            Self::HalfKA(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Incrementally updates the HalfKA_hm stack's current entry from the
    /// entry at `source_idx`. Panics if the loaded network is not HalfKA_hm.
    pub fn update_accumulator_halfka_hm(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKA_hm(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Incrementally updates the HalfKA stack's current entry from the
    /// entry at `source_idx`. Panics if the loaded network is not HalfKA.
    pub fn update_accumulator_halfka(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKAStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKA(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Forward-updates the HalfKA_hm stack from `source_idx`; returns
    /// whether it succeeded. Panics if the network is not HalfKA_hm.
    pub fn forward_update_incremental_halfka_hm(
        &self,
        pos: &Position,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKA_hm(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Forward-updates the HalfKA stack from `source_idx`; returns whether
    /// it succeeded. Panics if the network is not HalfKA.
    pub fn forward_update_incremental_halfka(
        &self,
        pos: &Position,
        stack: &mut HalfKAStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKA(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Evaluates `pos` with the HalfKA_hm stack's current accumulator.
    /// Panics if the loaded network is not HalfKA_hm.
    pub fn evaluate_halfka_hm(&self, pos: &Position, stack: &HalfKA_hmStack) -> Value {
        match self {
            Self::HalfKA_hm(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKA_hm architecture."),
        }
    }

    /// Evaluates `pos` with the HalfKA stack's current accumulator.
    /// Panics if the loaded network is not HalfKA.
    pub fn evaluate_halfka(&self, pos: &Position, stack: &HalfKAStack) -> Value {
        match self {
            Self::HalfKA(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKA architecture."),
        }
    }

    /// Recomputes the HalfKP stack's current accumulator from scratch.
    /// Panics if the loaded network is not HalfKP.
    pub fn refresh_accumulator_halfkp(&self, pos: &Position, stack: &mut HalfKPStack) {
        match self {
            Self::HalfKP(net) => net.refresh_accumulator(pos, stack),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }

    /// Incrementally updates the HalfKP stack's current entry from the
    /// entry at `source_idx`. Panics if the loaded network is not HalfKP.
    pub fn update_accumulator_halfkp(
        &self,
        pos: &Position,
        dirty: &super::accumulator::DirtyPiece,
        stack: &mut HalfKPStack,
        source_idx: usize,
    ) {
        match self {
            Self::HalfKP(net) => net.update_accumulator(pos, dirty, stack, source_idx),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }

    /// Forward-updates the HalfKP stack from `source_idx`; returns whether
    /// it succeeded. Panics if the network is not HalfKP.
    pub fn forward_update_incremental_halfkp(
        &self,
        pos: &Position,
        stack: &mut HalfKPStack,
        source_idx: usize,
    ) -> bool {
        match self {
            Self::HalfKP(net) => net.forward_update_incremental(pos, stack, source_idx),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }

    /// Evaluates `pos` with the HalfKP stack's current accumulator.
    /// Panics if the loaded network is not HalfKP.
    pub fn evaluate_halfkp(&self, pos: &Position, stack: &HalfKPStack) -> Value {
        match self {
            Self::HalfKP(net) => net.evaluate(pos, stack),
            _ => panic!("This method is only for HalfKP architecture."),
        }
    }
}
/// Extracts an `fv_scale=<n>` setting from a comma-separated arch string.
///
/// Only the first `fv_scale=` token is considered (first match wins); a
/// value that fails to parse or falls outside 1..=128 yields `None`.
pub fn parse_fv_scale_from_arch(arch_str: &str) -> Option<i32> {
    arch_str
        .split(',')
        .find_map(|token| token.strip_prefix("fv_scale="))?
        .parse::<i32>()
        .ok()
        .filter(|scale| (1..=128).contains(scale))
}
/// Loads a network from `path` and installs it as the global network.
///
/// # Errors
/// Propagates I/O and parse errors from loading, and returns
/// `AlreadyExists` if a network has been installed before.
pub fn init_nnue<P: AsRef<Path>>(path: P) -> io::Result<()> {
    let network = NNUENetwork::load(path)?;
    match NETWORK.set(network) {
        Ok(()) => Ok(()),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::AlreadyExists,
            "NNUE already initialized",
        )),
    }
}
/// Parses a network from an in-memory buffer and installs it globally.
///
/// # Errors
/// Propagates parse errors, and returns `AlreadyExists` if a network has
/// been installed before.
pub fn init_nnue_from_bytes(bytes: &[u8]) -> io::Result<()> {
    let network = NNUENetwork::from_bytes(bytes)?;
    if NETWORK.set(network).is_ok() {
        Ok(())
    } else {
        Err(io::Error::new(io::ErrorKind::AlreadyExists, "NNUE already initialized"))
    }
}
pub fn is_nnue_initialized() -> bool {
NETWORK.get().is_some()
}
/// Summary of an NNUE file's header, as produced by `detect_format`.
#[derive(Debug, Clone)]
pub struct NnueFormatInfo {
    // Human-readable architecture label, e.g. "HalfKA_hm256" or "LayerStacks".
    pub architecture: String,
    // First hidden-layer width parsed from the arch string.
    pub l1_dimension: u32,
    // Second hidden-layer width (0 where the architecture reports none).
    pub l2_dimension: u32,
    // Third hidden-layer width (0 where the architecture reports none).
    pub l3_dimension: u32,
    // Activation name detected from the arch string (e.g. "SCReLU").
    pub activation: String,
    // Raw version word from the file header.
    pub version: u32,
    // Full architecture string embedded in the file.
    pub arch_string: String,
}
pub fn detect_format(bytes: &[u8]) -> io::Result<NnueFormatInfo> {
if bytes.len() < 12 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"NNUE file too small (need at least 12 bytes for header)",
));
}
let version = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
match version {
NNUE_VERSION | NNUE_VERSION_HALFKA => {
let arch_len = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]) as usize;
if arch_len == 0 || arch_len > MAX_ARCH_LEN {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid arch string length: {arch_len}"),
));
}
if bytes.len() < 12 + arch_len {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("NNUE file too small (need {} bytes for arch string)", 12 + arch_len),
));
}
let arch_str = String::from_utf8_lossy(&bytes[12..12 + arch_len]).to_string();
let activation = detect_activation_from_arch(&arch_str).to_string();
let parsed = super::spec::parse_architecture(&arch_str)
.map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?;
let architecture = match parsed.feature_set {
FeatureSet::LayerStacks => "LayerStacks".to_string(),
FeatureSet::HalfKA_hm => format!("HalfKA_hm{}", parsed.l1),
FeatureSet::HalfKA => format!("HalfKA{}", parsed.l1),
FeatureSet::HalfKP => format!("HalfKP{}", parsed.l1),
};
Ok(NnueFormatInfo {
architecture,
l1_dimension: parsed.l1 as u32,
l2_dimension: parsed.l2 as u32,
l3_dimension: parsed.l3 as u32,
activation,
version,
arch_string: arch_str,
})
}
_ => Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Unknown NNUE version: 0x{version:08X}"),
)),
}
}
/// Returns the globally installed network, or `None` before initialization.
pub fn get_network() -> Option<&'static NNUENetwork> {
    NETWORK.get()
}
/// Ensures the stack's current LayerStacks accumulator is computed, then
/// evaluates `pos` with it.
///
/// Update strategies are tried cheapest-first:
/// 1. incremental update from the immediate previous entry, if computed;
/// 2. forward update from the nearest usable earlier accumulator;
/// 3. full refresh from scratch.
#[inline]
fn update_and_evaluate_layer_stacks(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut AccumulatorStackLayerStacks,
) -> Value {
    let current_entry = stack.current();
    if !current_entry.accumulator.computed_accumulation {
        let mut updated = false;
        if let Some(prev_idx) = current_entry.previous {
            let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
            if prev_computed {
                // Copy the dirty-piece info out first so it does not
                // conflict with the split mutable borrow below.
                let dirty_piece = stack.current().dirty_piece;
                let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
                network.update_accumulator_layer_stacks(pos, &dirty_piece, current_acc, prev_acc);
                updated = true;
            }
        }
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = network.forward_update_incremental_layer_stacks(pos, stack, source_idx);
            }
        }
        if !updated {
            // No usable ancestor: recompute the accumulator from scratch.
            let acc = &mut stack.current_mut().accumulator;
            network.refresh_accumulator_layer_stacks(pos, acc);
        }
    }
    let acc_ref = &stack.current().accumulator;
    network.evaluate_layer_stacks(pos, acc_ref)
}
/// Brings the HalfKA_hm stack's current accumulator up to date (trying the
/// cheapest strategy first: one-step update from the parent entry, forward
/// update from an older usable entry, full refresh) and evaluates `pos`.
#[inline]
fn update_and_evaluate_halfka_hm(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKA_hmStack,
) -> Value {
    if !stack.is_current_computed() {
        let stepped_from_parent = match stack.current_previous() {
            Some(parent) if stack.is_entry_computed(parent) => {
                let dp = stack.current_dirty_piece();
                network.update_accumulator_halfka_hm(pos, &dp, stack, parent);
                true
            }
            _ => false,
        };
        let brought_up_to_date = if stepped_from_parent {
            true
        } else if let Some((source, _depth)) = stack.find_usable_accumulator() {
            network.forward_update_incremental_halfka_hm(pos, stack, source)
        } else {
            false
        };
        if !brought_up_to_date {
            network.refresh_accumulator_halfka_hm(pos, stack);
        }
    }
    network.evaluate_halfka_hm(pos, stack)
}
/// Brings the HalfKA stack's current accumulator up to date (trying the
/// cheapest strategy first: one-step update from the parent entry, forward
/// update from an older usable entry, full refresh) and evaluates `pos`.
#[inline]
fn update_and_evaluate_halfka(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKAStack,
) -> Value {
    if !stack.is_current_computed() {
        let stepped_from_parent = match stack.current_previous() {
            Some(parent) if stack.is_entry_computed(parent) => {
                let dp = stack.current_dirty_piece();
                network.update_accumulator_halfka(pos, &dp, stack, parent);
                true
            }
            _ => false,
        };
        let brought_up_to_date = if stepped_from_parent {
            true
        } else if let Some((source, _depth)) = stack.find_usable_accumulator() {
            network.forward_update_incremental_halfka(pos, stack, source)
        } else {
            false
        };
        if !brought_up_to_date {
            network.refresh_accumulator_halfka(pos, stack);
        }
    }
    network.evaluate_halfka(pos, stack)
}
/// Brings the HalfKP stack's current accumulator up to date (trying the
/// cheapest strategy first: one-step update from the parent entry, forward
/// update from an older usable entry, full refresh) and evaluates `pos`.
#[inline]
fn update_and_evaluate_halfkp(
    network: &NNUENetwork,
    pos: &Position,
    stack: &mut HalfKPStack,
) -> Value {
    if !stack.is_current_computed() {
        let stepped_from_parent = match stack.current_previous() {
            Some(parent) if stack.is_entry_computed(parent) => {
                let dp = stack.current_dirty_piece();
                network.update_accumulator_halfkp(pos, &dp, stack, parent);
                true
            }
            _ => false,
        };
        let brought_up_to_date = if stepped_from_parent {
            true
        } else if let Some((source, _depth)) = stack.find_usable_accumulator() {
            network.forward_update_incremental_halfkp(pos, stack, source)
        } else {
            false
        };
        if !brought_up_to_date {
            network.refresh_accumulator_halfkp(pos, stack);
        }
    }
    network.evaluate_halfkp(pos, stack)
}
pub fn is_layer_stacks_loaded() -> bool {
NETWORK.get().map(|n| n.is_layer_stacks()).unwrap_or(false)
}
pub fn is_halfka_hm_256_loaded() -> bool {
NETWORK.get().map(|n| n.is_halfka_hm() && n.l1_size() == 256).unwrap_or(false)
}
pub fn is_halfka_256_loaded() -> bool {
NETWORK.get().map(|n| n.is_halfka() && n.l1_size() == 256).unwrap_or(false)
}
pub fn is_halfka_hm_512_loaded() -> bool {
NETWORK.get().map(|n| n.is_halfka_hm() && n.l1_size() == 512).unwrap_or(false)
}
pub fn is_halfka_512_loaded() -> bool {
NETWORK.get().map(|n| n.is_halfka() && n.l1_size() == 512).unwrap_or(false)
}
pub fn is_halfka_hm_1024_loaded() -> bool {
NETWORK.get().map(|n| n.is_halfka_hm() && n.l1_size() == 1024).unwrap_or(false)
}
pub fn is_halfka_1024_loaded() -> bool {
NETWORK.get().map(|n| n.is_halfka() && n.l1_size() == 1024).unwrap_or(false)
}
/// Evaluates `pos` with the global LayerStacks network, falling back to
/// material evaluation when no network is loaded.
///
/// Panics if the loaded network is not LayerStacks.
pub fn evaluate_layer_stacks(pos: &Position, stack: &mut AccumulatorStackLayerStacks) -> Value {
    match NETWORK.get() {
        None => material::evaluate_material(pos),
        Some(network) => {
            if !network.is_layer_stacks() {
                panic!("Non-LayerStacks architecture detected. Use evaluate() with AccumulatorStack.");
            }
            update_and_evaluate_layer_stacks(network, pos, stack)
        }
    }
}
/// Evaluates `pos`, routing to the update/evaluate path that matches the
/// accumulator-stack variant; falls back to material evaluation when no
/// network is loaded.
pub fn evaluate_dispatch(pos: &Position, stack: &mut AccumulatorStackVariant) -> Value {
    match NETWORK.get() {
        None => material::evaluate_material(pos),
        Some(net) => match stack {
            AccumulatorStackVariant::LayerStacks(s) => {
                update_and_evaluate_layer_stacks(net, pos, s)
            }
            AccumulatorStackVariant::HalfKA(s) => update_and_evaluate_halfka(net, pos, s),
            AccumulatorStackVariant::HalfKA_hm(s) => update_and_evaluate_halfka_hm(net, pos, s),
            AccumulatorStackVariant::HalfKP(s) => update_and_evaluate_halfkp(net, pos, s),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::position::SFEN_HIRATE;

    // The start position is materially balanced, so evaluation should stay
    // near zero whether a network is loaded or the material fallback runs.
    #[test]
    fn test_evaluate_fallback() {
        let mut pos = Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        let mut stack = AccumulatorStackVariant::new_default();
        let value = evaluate_dispatch(&pos, &mut stack);
        assert!(value.raw().abs() < 1000);
    }

    // Evaluating the same position twice through the same stack must not
    // crash; results are intentionally not compared here.
    #[test]
    fn test_accumulator_stack_variant_fallback() {
        let mut pos = Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        let mut stack = AccumulatorStackVariant::new_default();
        let value1 = evaluate_dispatch(&pos, &mut stack);
        let value2 = evaluate_dispatch(&pos, &mut stack);
        let _ = (value1, value2);
    }

    // Requires a real LayerStacks network file; point NNUE_TEST_FILE at one
    // and run with `--ignored`.
    #[test]
    #[ignore]
    fn test_nnue_network_auto_detect_layer_stacks() {
        let path = std::env::var("NNUE_TEST_FILE")
            .unwrap_or_else(|_| "/path/to/your/layer_stacks.nnue".to_string());
        let network = match NNUENetwork::load(path) {
            Ok(n) => n,
            Err(e) => {
                // Missing or unreadable file: skip rather than fail.
                eprintln!("Skipping test: {e}");
                return;
            }
        };
        assert!(network.is_layer_stacks(), "epoch82.nnue should be detected as LayerStacks");
        assert_eq!(network.architecture_name(), "LayerStacks");
        let mut pos = crate::position::Position::new();
        pos.set_sfen(SFEN_HIRATE).unwrap();
        let mut acc = crate::nnue::AccumulatorLayerStacks::new();
        network.refresh_accumulator_layer_stacks(&pos, &mut acc);
        let value = network.evaluate_layer_stacks(&pos, &acc);
        eprintln!("LayerStacks evaluate: {}", value.raw());
        assert!(value.raw().abs() < 1000);
    }

    // Realistic arch strings: fv_scale is extracted from its comma-separated
    // token, and absent settings yield None.
    #[test]
    fn test_parse_fv_scale_from_arch() {
        assert_eq!(
            parse_fv_scale_from_arch(
                "Features=HalfKA_hm^[73305->256x2]-SCReLU,fv_scale=13,qa=127,qb=64,scale=600"
            ),
            Some(13)
        );
        assert_eq!(
            parse_fv_scale_from_arch(
                "Features=HalfKA_hm^[73305->512x2]-SCReLU,fv_scale=20,qa=127,qb=64,scale=400"
            ),
            Some(20)
        );
        assert_eq!(
            parse_fv_scale_from_arch(
                "Features=HalfKA_hm^[73305->1024x2]-SCReLU,fv_scale=16,qa=127,qb=64,scale=508"
            ),
            Some(16)
        );
        assert_eq!(parse_fv_scale_from_arch("Features=HalfKP[125388->256x2]"), None);
        assert_eq!(parse_fv_scale_from_arch("Features=HalfKA_hm^[73305->512x2]"), None);
        assert_eq!(parse_fv_scale_from_arch(""), None);
        assert_eq!(
            parse_fv_scale_from_arch("Features=HalfKA_hm^[73305->256x2],fv_scale=abc"),
            None
        );
    }

    // Boundary and malformed inputs: accepted range is 1..=128, whitespace
    // is not trimmed, and only the first fv_scale token counts.
    #[test]
    fn test_parse_fv_scale_edge_cases() {
        assert_eq!(parse_fv_scale_from_arch("fv_scale=1"), Some(1));
        assert_eq!(parse_fv_scale_from_arch("fv_scale=128"), Some(128));
        assert_eq!(parse_fv_scale_from_arch("fv_scale=64"), Some(64));
        assert_eq!(parse_fv_scale_from_arch("fv_scale=0"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=129"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=-1"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=-100"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=99999"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=2147483647"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale= 16"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=16 "), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=10,fv_scale=20"), Some(10));
        assert_eq!(parse_fv_scale_from_arch("fv_scale="), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale=16.5"), None);
        assert_eq!(parse_fv_scale_from_arch("my_fv_scale=16"), None);
        assert_eq!(parse_fv_scale_from_arch("fv_scale_v2=16"), None);
    }
}