use super::accumulator::DirtyPiece;
use super::accumulator_layer_stacks::AccumulatorStackLayerStacks;
use super::halfka::HalfKAStack;
use super::halfka_hm::HalfKA_hmStack;
use super::halfkp::HalfKPStack;
use super::network::NNUENetwork;
/// Accumulator stack wrapper with one variant per supported network kind.
///
/// Each variant holds the stack type that `from_network` pairs with the
/// identically named `NNUENetwork` variant.
// The `HalfKA_hm` variant name contains an underscore, which the
// `non_camel_case_types` lint would otherwise flag.
#[allow(non_camel_case_types)]
pub enum AccumulatorStackVariant {
    /// Stack paired with `NNUENetwork::HalfKA`.
    HalfKA(HalfKAStack),
    /// Stack paired with `NNUENetwork::HalfKA_hm`.
    HalfKA_hm(HalfKA_hmStack),
    /// Stack paired with `NNUENetwork::HalfKP`; also the variant produced by `new_default`.
    HalfKP(HalfKPStack),
    /// Stack paired with `NNUENetwork::LayerStacks`.
    LayerStacks(AccumulatorStackLayerStacks),
}
impl AccumulatorStackVariant {
pub fn from_network(network: &NNUENetwork) -> Self {
match network {
NNUENetwork::HalfKA(net) => Self::HalfKA(HalfKAStack::from_network(net)),
NNUENetwork::HalfKA_hm(net) => Self::HalfKA_hm(HalfKA_hmStack::from_network(net)),
NNUENetwork::HalfKP(net) => Self::HalfKP(HalfKPStack::from_network(net)),
NNUENetwork::LayerStacks(_) => Self::LayerStacks(AccumulatorStackLayerStacks::new()),
}
}
pub fn new_default() -> Self {
Self::HalfKP(HalfKPStack::default())
}
pub fn matches_network(&self, network: &NNUENetwork) -> bool {
match (self, network) {
(Self::HalfKA(stack), NNUENetwork::HalfKA(net)) => stack.l1_size() == net.l1_size(),
(Self::HalfKA_hm(stack), NNUENetwork::HalfKA_hm(net)) => {
stack.l1_size() == net.l1_size()
}
(Self::HalfKP(stack), NNUENetwork::HalfKP(net)) => stack.l1_size() == net.l1_size(),
(Self::LayerStacks(_), NNUENetwork::LayerStacks(_)) => true,
_ => false,
}
}
#[inline]
pub fn reset(&mut self) {
match self {
Self::HalfKA(stack) => stack.reset(),
Self::HalfKA_hm(stack) => stack.reset(),
Self::HalfKP(stack) => stack.reset(),
Self::LayerStacks(stack) => stack.reset(),
}
}
#[inline]
pub fn push(&mut self, dirty_piece: DirtyPiece) {
match self {
Self::HalfKA(stack) => stack.push(dirty_piece),
Self::HalfKA_hm(stack) => stack.push(dirty_piece),
Self::HalfKP(stack) => stack.push(dirty_piece),
Self::LayerStacks(stack) => {
stack.push();
stack.current_mut().dirty_piece = dirty_piece;
}
}
}
#[inline]
pub fn pop(&mut self) {
match self {
Self::HalfKA(stack) => stack.pop(),
Self::HalfKA_hm(stack) => stack.pop(),
Self::HalfKP(stack) => stack.pop(),
Self::LayerStacks(stack) => stack.pop(),
}
}
#[inline]
pub fn is_halfkp(&self) -> bool {
matches!(self, Self::HalfKP(_))
}
}
impl Default for AccumulatorStackVariant {
fn default() -> Self {
Self::new_default()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default::default()` must yield the HalfKP variant and none of the others.
    #[test]
    fn test_default_is_halfkp() {
        let stack = AccumulatorStackVariant::default();
        assert!(stack.is_halfkp());
        assert!(matches!(stack, AccumulatorStackVariant::HalfKP(_)));
        assert!(!matches!(stack, AccumulatorStackVariant::LayerStacks(_)));
        assert!(!matches!(stack, AccumulatorStackVariant::HalfKA(_)));
        assert!(!matches!(stack, AccumulatorStackVariant::HalfKA_hm(_)));
    }

    /// `new_default()` must yield the HalfKP variant.
    #[test]
    fn test_new_default_is_halfkp() {
        let stack = AccumulatorStackVariant::new_default();
        assert!(stack.is_halfkp());
        assert!(matches!(stack, AccumulatorStackVariant::HalfKP(_)));
    }

    /// `reset()` must leave the variant untouched.
    #[test]
    fn test_reset_does_not_change_variant() {
        let mut stack = AccumulatorStackVariant::new_default();
        assert!(stack.is_halfkp());
        stack.reset();
        assert!(stack.is_halfkp());
    }

    /// Smoke test: balanced push/pop pairs on the default variant must not panic.
    #[test]
    fn test_push_pop_symmetry() {
        let mut stack = AccumulatorStackVariant::new_default();
        let dirty = DirtyPiece::default();
        stack.reset();
        stack.push(dirty);
        stack.push(dirty);
        stack.pop();
        stack.pop();
    }

    /// Each push advances `current_index()` by one and each pop rewinds it by one.
    #[test]
    fn test_push_pop_index_consistency_halfkp() {
        let mut stack = HalfKPStack::default();
        let dirty = DirtyPiece::default();
        stack.reset();
        let initial_index = stack.current_index();

        stack.push(dirty);
        assert_eq!(stack.current_index(), initial_index + 1);
        stack.push(dirty);
        assert_eq!(stack.current_index(), initial_index + 2);
        stack.push(dirty);
        assert_eq!(stack.current_index(), initial_index + 3);

        stack.pop();
        assert_eq!(stack.current_index(), initial_index + 2);
        stack.pop();
        assert_eq!(stack.current_index(), initial_index + 1);
        stack.pop();
        assert_eq!(stack.current_index(), initial_index);
    }

    /// Same index bookkeeping check as above, for the HalfKA_hm stack.
    #[test]
    fn test_push_pop_index_consistency_halfka_hm() {
        let mut stack = HalfKA_hmStack::default();
        let dirty = DirtyPiece::default();
        stack.reset();
        let initial_index = stack.current_index();

        stack.push(dirty);
        assert_eq!(stack.current_index(), initial_index + 1);
        stack.push(dirty);
        assert_eq!(stack.current_index(), initial_index + 2);

        stack.pop();
        assert_eq!(stack.current_index(), initial_index + 1);
        stack.pop();
        assert_eq!(stack.current_index(), initial_index);
    }

    /// `l1_size()` must report the const-generic width of each HalfKA_hm stack variant.
    #[test]
    fn test_halfka_hm_stack_l1_sizes() {
        use crate::nnue::network_halfka_hm::AccumulatorStackHalfKA_hm;
        let l256_stack = HalfKA_hmStack::L256(AccumulatorStackHalfKA_hm::<256>::new());
        let l512_stack = HalfKA_hmStack::L512(AccumulatorStackHalfKA_hm::<512>::new());
        let l1024_stack = HalfKA_hmStack::L1024(AccumulatorStackHalfKA_hm::<1024>::new());
        assert_eq!(l256_stack.l1_size(), 256);
        assert_eq!(l512_stack.l1_size(), 512);
        assert_eq!(l1024_stack.l1_size(), 1024);
    }

    /// `l1_size()` must report the const-generic width of each HalfKP stack variant.
    #[test]
    fn test_halfkp_stack_l1_sizes() {
        use crate::nnue::network_halfkp::AccumulatorStackHalfKP;
        let l256_stack = HalfKPStack::L256(AccumulatorStackHalfKP::<256>::new());
        let l512_stack = HalfKPStack::L512(AccumulatorStackHalfKP::<512>::new());
        let l1024_stack = HalfKPStack::L1024(AccumulatorStackHalfKP::<1024>::new());
        assert_eq!(l256_stack.l1_size(), 256);
        assert_eq!(l512_stack.l1_size(), 512);
        assert_eq!(l1024_stack.l1_size(), 1024);
    }

    /// Smoke test: `DEPTH` pushes followed by `DEPTH` pops must not panic.
    #[test]
    fn test_deep_push_pop() {
        let mut stack = AccumulatorStackVariant::new_default();
        let dirty = DirtyPiece::default();
        stack.reset();
        const DEPTH: usize = 30;
        for _ in 0..DEPTH {
            stack.push(dirty);
        }
        for _ in 0..DEPTH {
            stack.pop();
        }
    }

    /// Reports the size of the enum and of every payload type, and checks that
    /// the enum is large enough to hold each payload.
    #[test]
    fn test_variant_size() {
        use std::mem::size_of;
        let variant_size = size_of::<AccumulatorStackVariant>();
        let halfka_stack_size = size_of::<HalfKAStack>();
        let halfka_hm_stack_size = size_of::<HalfKA_hmStack>();
        let halfkp_stack_size = size_of::<HalfKPStack>();
        let layer_stacks_size = size_of::<AccumulatorStackLayerStacks>();
        eprintln!("AccumulatorStackVariant size: {variant_size} bytes");
        eprintln!("HalfKAStack size: {halfka_stack_size} bytes");
        eprintln!("HalfKA_hmStack size: {halfka_hm_stack_size} bytes");
        eprintln!("HalfKPStack size: {halfkp_stack_size} bytes");
        eprintln!("LayerStacks size: {layer_stacks_size} bytes");
        assert!(variant_size > 0);
        // An enum must be able to store any of its variants, so it can never be
        // smaller than one of its payload types.
        for payload_size in [
            halfka_stack_size,
            halfka_hm_stack_size,
            halfkp_stack_size,
            layer_stacks_size,
        ] {
            assert!(variant_size >= payload_size);
        }
    }
}