use std::sync::Arc;
use super::accumulator::DirtyPiece;
use super::accumulator_layer_stacks::AccumulatorStackLayerStacks;
use super::accumulator_stack_variant::AccumulatorStackVariant;
use super::halfka::HalfKAStack;
use super::halfka_hm::HalfKA_hmStack;
use super::halfkp::HalfKPStack;
use super::network::NNUENetwork;
use super::spec::ArchitectureSpec;
use crate::position::Position;
use crate::types::Value;
/// NNUE evaluator pairing a shared network with its own accumulator stack.
///
/// The network is held behind an `Arc` so several evaluators can share it
/// (`clone_for_thread` clones only the `Arc` and builds a fresh stack). The
/// stack variant is expected to mirror the network variant; the dispatch
/// methods treat a mismatch as `unreachable!`.
pub struct NNUEEvaluator {
/// Shared network (presumably read-only weights; verify no interior mutation).
net: Arc<NNUENetwork>,
/// Per-evaluator accumulator stack, built with `AccumulatorStackVariant::from_network`.
stack: AccumulatorStackVariant,
}
impl NNUEEvaluator {
    /// Creates an evaluator for `net` and initializes its accumulator for `pos`.
    ///
    /// The accumulator stack is built with `AccumulatorStackVariant::from_network`, so its
    /// variant matches the network's architecture. It is then [`reset`](Self::reset) to `pos`.
    pub fn new_with_position(net: Arc<NNUENetwork>, pos: &Position) -> Self {
        let stack = AccumulatorStackVariant::from_network(&net);
        let mut evaluator = Self { net, stack };
        evaluator.reset(pos);
        evaluator
    }

    /// Creates a new evaluator that shares this evaluator's network.
    ///
    /// Only the `Arc` is cloned. The new evaluator builds its own accumulator stack and
    /// initializes it for `pos`, so no accumulator state is copied from `self`. It delegates
    /// to [`new_with_position`](Self::new_with_position) so both constructors stay in sync.
    pub fn clone_for_thread(&self, pos: &Position) -> Self {
        Self::new_with_position(Arc::clone(&self.net), pos)
    }

    /// Clears the accumulator stack and recomputes the current accumulator for `pos`.
    pub fn reset(&mut self, pos: &Position) {
        self.stack.reset();
        self.refresh_accumulator(pos);
    }

    /// Pushes a new accumulator-stack entry described by `dirty_piece`.
    ///
    /// This only forwards to the stack. The current accumulator is brought up to date later
    /// by [`evaluate`](Self::evaluate), which updates it if it is not yet computed.
    #[inline]
    pub fn push(&mut self, dirty_piece: DirtyPiece) {
        self.stack.push(dirty_piece);
    }

    /// Pops the most recently pushed accumulator-stack entry.
    #[inline]
    pub fn pop(&mut self) {
        self.stack.pop();
    }

    /// Evaluates `pos`, first making sure the current accumulator is computed.
    ///
    /// An incremental update is attempted before falling back to a full refresh.
    ///
    /// # Panics
    ///
    /// Panics if the network and stack variants disagree (see [`evaluate_only`](Self::evaluate_only)).
    #[inline(always)]
    pub fn evaluate(&mut self, pos: &Position) -> Value {
        self.ensure_accumulator_computed(pos);
        self.evaluate_only(pos)
    }

    /// Recomputes the current accumulator for `pos` without resetting the stack.
    pub fn refresh(&mut self, pos: &Position) {
        self.refresh_accumulator(pos);
    }

    /// Evaluates `pos` using the current accumulator exactly as it stands.
    ///
    /// Unlike [`evaluate`](Self::evaluate), this never updates the accumulator. It should
    /// therefore only be called once the current accumulator has been computed for `pos`
    /// (e.g. via `evaluate` or [`refresh`](Self::refresh)).
    ///
    /// # Panics
    ///
    /// Panics if the network and stack variants disagree. This should be impossible, because
    /// every constructor builds the stack with `AccumulatorStackVariant::from_network`.
    #[inline(always)]
    pub fn evaluate_only(&self, pos: &Position) -> Value {
        match (&*self.net, &self.stack) {
            (NNUENetwork::HalfKA(net), AccumulatorStackVariant::HalfKA(st)) => {
                net.evaluate(pos, st)
            }
            (NNUENetwork::HalfKA_hm(net), AccumulatorStackVariant::HalfKA_hm(st)) => {
                net.evaluate(pos, st)
            }
            (NNUENetwork::HalfKP(net), AccumulatorStackVariant::HalfKP(st)) => {
                net.evaluate(pos, st)
            }
            (NNUENetwork::LayerStacks(net), AccumulatorStackVariant::LayerStacks(st)) => {
                // The LayerStacks network consumes the accumulator itself, not the stack.
                net.evaluate(pos, &st.current().accumulator)
            }
            _ => unreachable!("Network/Stack type mismatch"),
        }
    }

    /// Returns the network's architecture name.
    pub fn architecture_name(&self) -> &'static str {
        self.net.architecture_name()
    }

    /// Returns the network's architecture specification.
    pub fn architecture_spec(&self) -> ArchitectureSpec {
        self.net.architecture_spec()
    }

    /// Returns the shared network handle.
    pub fn network(&self) -> &Arc<NNUENetwork> {
        &self.net
    }

    /// Returns the L1 size reported by the network.
    pub fn l1_size(&self) -> usize {
        self.net.l1_size()
    }

    /// Recomputes the current accumulator through the network's `refresh_accumulator`.
    ///
    /// For `LayerStacks`, only the current entry's accumulator is handed to the network.
    fn refresh_accumulator(&mut self, pos: &Position) {
        match (&*self.net, &mut self.stack) {
            (NNUENetwork::HalfKA(net), AccumulatorStackVariant::HalfKA(st)) => {
                net.refresh_accumulator(pos, st);
            }
            (NNUENetwork::HalfKA_hm(net), AccumulatorStackVariant::HalfKA_hm(st)) => {
                net.refresh_accumulator(pos, st);
            }
            (NNUENetwork::HalfKP(net), AccumulatorStackVariant::HalfKP(st)) => {
                net.refresh_accumulator(pos, st);
            }
            (NNUENetwork::LayerStacks(net), AccumulatorStackVariant::LayerStacks(st)) => {
                net.refresh_accumulator(pos, &mut st.current_mut().accumulator);
            }
            _ => unreachable!("Network/Stack type mismatch"),
        }
    }

    /// Makes sure the current accumulator is computed, dispatching on the architecture.
    fn ensure_accumulator_computed(&mut self, pos: &Position) {
        match (&*self.net, &mut self.stack) {
            (NNUENetwork::HalfKA(net), AccumulatorStackVariant::HalfKA(st)) => {
                Self::update_halfka_accumulator(net, pos, st);
            }
            (NNUENetwork::HalfKA_hm(net), AccumulatorStackVariant::HalfKA_hm(st)) => {
                Self::update_halfka_hm_accumulator(net, pos, st);
            }
            (NNUENetwork::HalfKP(net), AccumulatorStackVariant::HalfKP(st)) => {
                Self::update_halfkp_accumulator(net, pos, st);
            }
            (NNUENetwork::LayerStacks(net), AccumulatorStackVariant::LayerStacks(st)) => {
                Self::update_layer_stacks_accumulator(net, pos, st);
            }
            _ => unreachable!("Network/Stack type mismatch"),
        }
    }

    /// Brings the current HalfKA accumulator up to date.
    ///
    /// Tries, in order:
    /// 1. nothing, if the current entry is already computed;
    /// 2. an incremental update from the previous entry, if that entry is computed;
    /// 3. `forward_update_incremental` from the entry returned by `find_usable_accumulator`
    ///    (which may decline by returning `false`);
    /// 4. a full refresh.
    #[inline]
    fn update_halfka_accumulator(
        net: &super::halfka::HalfKANetwork,
        pos: &Position,
        stack: &mut HalfKAStack,
    ) {
        if stack.is_current_computed() {
            return;
        }
        let mut updated = false;
        if let Some(prev_idx) = stack.current_previous() {
            if stack.is_entry_computed(prev_idx) {
                let dirty = stack.current_dirty_piece();
                net.update_accumulator(pos, &dirty, stack, prev_idx);
                updated = true;
            }
        }
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = net.forward_update_incremental(pos, stack, source_idx);
            }
        }
        if !updated {
            net.refresh_accumulator(pos, stack);
        }
    }

    /// Brings the current HalfKA_hm accumulator up to date.
    ///
    /// Same strategy as `update_halfka_accumulator`: previous-entry update, then
    /// `forward_update_incremental` from `find_usable_accumulator`, then a full refresh.
    #[inline]
    fn update_halfka_hm_accumulator(
        net: &super::halfka_hm::HalfKA_hmNetwork,
        pos: &Position,
        stack: &mut HalfKA_hmStack,
    ) {
        if stack.is_current_computed() {
            return;
        }
        let mut updated = false;
        if let Some(prev_idx) = stack.current_previous() {
            if stack.is_entry_computed(prev_idx) {
                let dirty = stack.current_dirty_piece();
                net.update_accumulator(pos, &dirty, stack, prev_idx);
                updated = true;
            }
        }
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = net.forward_update_incremental(pos, stack, source_idx);
            }
        }
        if !updated {
            net.refresh_accumulator(pos, stack);
        }
    }

    /// Brings the current HalfKP accumulator up to date.
    ///
    /// Same strategy as `update_halfka_accumulator`: previous-entry update, then
    /// `forward_update_incremental` from `find_usable_accumulator`, then a full refresh.
    #[inline]
    fn update_halfkp_accumulator(
        net: &super::halfkp::HalfKPNetwork,
        pos: &Position,
        stack: &mut HalfKPStack,
    ) {
        if stack.is_current_computed() {
            return;
        }
        let mut updated = false;
        if let Some(prev_idx) = stack.current_previous() {
            if stack.is_entry_computed(prev_idx) {
                let dirty = stack.current_dirty_piece();
                net.update_accumulator(pos, &dirty, stack, prev_idx);
                updated = true;
            }
        }
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = net.forward_update_incremental(pos, stack, source_idx);
            }
        }
        if !updated {
            net.refresh_accumulator(pos, stack);
        }
    }

    /// Brings the current LayerStacks accumulator up to date.
    ///
    /// Same strategy as the other architectures. The difference is that the network works on
    /// individual accumulators, which are borrowed together through
    /// `get_prev_and_current_accumulators`.
    #[inline]
    fn update_layer_stacks_accumulator(
        net: &super::network_layer_stacks::NetworkLayerStacks,
        pos: &Position,
        stack: &mut AccumulatorStackLayerStacks,
    ) {
        let current_entry = stack.current();
        if current_entry.accumulator.computed_accumulation {
            return;
        }
        let mut updated = false;
        if let Some(prev_idx) = current_entry.previous {
            let prev_computed = stack.entry_at(prev_idx).accumulator.computed_accumulation;
            if prev_computed {
                // Copy the dirty piece out before taking the mutable borrow of both accumulators.
                let dirty_piece = stack.current().dirty_piece;
                let (prev_acc, current_acc) = stack.get_prev_and_current_accumulators(prev_idx);
                net.update_accumulator(pos, &dirty_piece, current_acc, prev_acc);
                updated = true;
            }
        }
        if !updated {
            if let Some((source_idx, _depth)) = stack.find_usable_accumulator() {
                updated = net.forward_update_incremental(pos, stack, source_idx);
            }
        }
        if !updated {
            let acc = &mut stack.current_mut().accumulator;
            net.refresh_accumulator(pos, acc);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nnue::halfka::HalfKANetwork;
    use crate::nnue::halfka_hm::HalfKA_hmNetwork;
    use crate::nnue::halfkp::HalfKPNetwork;
    use crate::nnue::network_halfka::AccumulatorStackHalfKA;
    use crate::nnue::network_halfka_hm::AccumulatorStackHalfKA_hm;
    use crate::nnue::network_halfkp::AccumulatorStackHalfKP;
    use crate::nnue::spec::FeatureSet;

    /// Pushes `dirty` onto `stack` `depth` times, then pops the same number of entries.
    fn push_then_pop(stack: &mut AccumulatorStackVariant, dirty: DirtyPiece, depth: usize) {
        for _ in 0..depth {
            stack.push(dirty);
        }
        for _ in 0..depth {
            stack.pop();
        }
    }

    #[test]
    fn test_evaluator_construction() {
        // The default stack variant is expected to be HalfKP.
        assert!(AccumulatorStackVariant::new_default().is_halfkp());
    }

    #[test]
    fn test_stack_push_pop() {
        let mut stack = AccumulatorStackVariant::new_default();
        let dirty = DirtyPiece::default();
        stack.reset();
        push_then_pop(&mut stack, dirty, 2);
    }

    #[test]
    fn test_evaluator_size() {
        use std::mem::size_of;
        let evaluator_size = size_of::<NNUEEvaluator>();
        let arc_size = size_of::<Arc<NNUENetwork>>();
        let stack_size = size_of::<AccumulatorStackVariant>();
        // Informational only: sizes are printed so layout regressions are easy to spot.
        eprintln!("NNUEEvaluator size: {evaluator_size} bytes");
        eprintln!("Arc<NNUENetwork> size: {arc_size} bytes");
        eprintln!("AccumulatorStackVariant size: {stack_size} bytes");
        assert!(evaluator_size > 0);
    }

    #[test]
    fn test_stack_variant_type_checking() {
        let halfka_nm_l256 = AccumulatorStackVariant::HalfKA(HalfKAStack::L256(
            AccumulatorStackHalfKA::<256>::new(),
        ));
        let halfka_nm_l512 = AccumulatorStackVariant::HalfKA(HalfKAStack::L512(
            AccumulatorStackHalfKA::<512>::new(),
        ));
        let halfka_nm_l1024 = AccumulatorStackVariant::HalfKA(HalfKAStack::L1024(
            AccumulatorStackHalfKA::<1024>::new(),
        ));
        let halfka_hm_l256 = AccumulatorStackVariant::HalfKA_hm(HalfKA_hmStack::L256(
            AccumulatorStackHalfKA_hm::<256>::new(),
        ));
        let halfka_hm_l512 = AccumulatorStackVariant::HalfKA_hm(HalfKA_hmStack::L512(
            AccumulatorStackHalfKA_hm::<512>::new(),
        ));
        let halfka_hm_l1024 = AccumulatorStackVariant::HalfKA_hm(HalfKA_hmStack::L1024(
            AccumulatorStackHalfKA_hm::<1024>::new(),
        ));
        let halfkp_l256 = AccumulatorStackVariant::HalfKP(HalfKPStack::L256(
            AccumulatorStackHalfKP::<256>::new(),
        ));
        let halfkp_l512 = AccumulatorStackVariant::HalfKP(HalfKPStack::L512(
            AccumulatorStackHalfKP::<512>::new(),
        ));

        // Each feature-set family must land in its own variant; only HalfKP reports is_halfkp.
        for stack in [&halfka_nm_l256, &halfka_nm_l512, &halfka_nm_l1024] {
            assert!(matches!(stack, AccumulatorStackVariant::HalfKA(_)));
            assert!(!stack.is_halfkp());
        }
        for stack in [&halfka_hm_l256, &halfka_hm_l512, &halfka_hm_l1024] {
            assert!(matches!(stack, AccumulatorStackVariant::HalfKA_hm(_)));
            assert!(!stack.is_halfkp());
        }
        for stack in [&halfkp_l256, &halfkp_l512] {
            assert!(matches!(stack, AccumulatorStackVariant::HalfKP(_)));
            assert!(stack.is_halfkp());
        }
    }

    #[test]
    fn test_all_variants_push_pop_consistency() {
        let dirty = DirtyPiece::default();

        let mut halfka = AccumulatorStackVariant::HalfKA(HalfKAStack::L512(
            AccumulatorStackHalfKA::<512>::new(),
        ));
        halfka.reset();
        push_then_pop(&mut halfka, dirty, 2);

        let mut halfka_hm = AccumulatorStackVariant::HalfKA_hm(HalfKA_hmStack::L512(
            AccumulatorStackHalfKA_hm::<512>::new(),
        ));
        halfka_hm.reset();
        push_then_pop(&mut halfka_hm, dirty, 2);

        let mut halfkp = AccumulatorStackVariant::HalfKP(HalfKPStack::L512(
            AccumulatorStackHalfKP::<512>::new(),
        ));
        halfkp.reset();
        push_then_pop(&mut halfkp, dirty, 2);

        let mut layer_stacks =
            AccumulatorStackVariant::LayerStacks(AccumulatorStackLayerStacks::new());
        layer_stacks.reset();
        push_then_pop(&mut layer_stacks, dirty, 2);
    }

    #[test]
    fn test_deep_search_simulation() {
        const MAX_DEPTH: usize = 30;
        const ROUNDS: usize = 5;
        let mut stack = AccumulatorStackVariant::new_default();
        let dirty = DirtyPiece::default();
        stack.reset();
        for _ in 0..ROUNDS {
            push_then_pop(&mut stack, dirty, MAX_DEPTH);
        }
    }

    #[test]
    fn test_network_enum_coverage() {
        let halfka_specs = HalfKANetwork::supported_specs();
        assert_eq!(halfka_specs.len(), 15);
        let halfka_hm_specs = HalfKA_hmNetwork::supported_specs();
        assert_eq!(halfka_hm_specs.len(), 15);
        let halfkp_specs = HalfKPNetwork::supported_specs();
        assert_eq!(halfkp_specs.len(), 18);

        for spec in &halfka_specs {
            assert_eq!(
                spec.feature_set,
                FeatureSet::HalfKA,
                "HalfKA spec has wrong feature_set: {spec:?}"
            );
        }
        for spec in &halfka_hm_specs {
            assert_eq!(
                spec.feature_set,
                FeatureSet::HalfKA_hm,
                "HalfKA_hm spec has wrong feature_set: {spec:?}"
            );
        }
        for spec in &halfkp_specs {
            assert_eq!(
                spec.feature_set,
                FeatureSet::HalfKP,
                "HalfKP spec has wrong feature_set: {spec:?}"
            );
        }
    }
}