#![allow(non_camel_case_types)]
mod l1024;
mod l256;
mod l512;
mod l768;
pub use l256::HalfKA_hm_L256;
pub use l512::HalfKA_hm_L512;
pub use l768::HalfKA_hm_L768;
pub use l1024::HalfKA_hm_L1024;
use crate::nnue::accumulator::DirtyPiece;
use crate::nnue::network_halfka_hm::AccumulatorStackHalfKA_hm;
use crate::nnue::spec::{Activation, ArchitectureSpec};
use crate::position::Position;
use crate::types::Value;
/// A loaded HalfKA_hm evaluation network, dispatched at runtime by its L1
/// (feature-transformer output) width. Each variant wraps the network type
/// specialized for that width.
pub enum HalfKA_hmNetwork {
    // 256-wide L1 network.
    L256(HalfKA_hm_L256),
    // 512-wide L1 network.
    L512(HalfKA_hm_L512),
    // 768-wide L1 network.
    L768(HalfKA_hm_L768),
    // 1024-wide L1 network.
    L1024(HalfKA_hm_L1024),
}
impl HalfKA_hmNetwork {
    /// Evaluates `pos` using the accumulator stack variant that matches this
    /// network's L1 width.
    ///
    /// # Panics
    /// Panics if `stack`'s L1 width does not match the network's (callers are
    /// expected to build the stack via `HalfKA_hmStack::from_network`).
    #[inline(always)]
    pub fn evaluate(&self, pos: &Position, stack: &HalfKA_hmStack) -> Value {
        match (self, stack) {
            (Self::L256(net), HalfKA_hmStack::L256(st)) => net.evaluate(pos, st),
            (Self::L512(net), HalfKA_hmStack::L512(st)) => net.evaluate(pos, st),
            (Self::L768(net), HalfKA_hmStack::L768(st)) => net.evaluate(pos, st),
            (Self::L1024(net), HalfKA_hmStack::L1024(st)) => net.evaluate(pos, st),
            (net, st) => unreachable!(
                "L1 mismatch: network={}, stack={}",
                net.l1_size(),
                st.l1_size()
            ),
        }
    }

    /// Fully recomputes the current accumulator from `pos` (no incremental
    /// update), on the matching stack variant.
    ///
    /// # Panics
    /// Panics if `stack`'s L1 width does not match the network's.
    #[inline(always)]
    pub fn refresh_accumulator(&self, pos: &Position, stack: &mut HalfKA_hmStack) {
        match (self, stack) {
            (Self::L256(net), HalfKA_hmStack::L256(st)) => net.refresh_accumulator(pos, st),
            (Self::L512(net), HalfKA_hmStack::L512(st)) => net.refresh_accumulator(pos, st),
            (Self::L768(net), HalfKA_hmStack::L768(st)) => net.refresh_accumulator(pos, st),
            (Self::L1024(net), HalfKA_hmStack::L1024(st)) => net.refresh_accumulator(pos, st),
            // Bind both sides so the panic reports the mismatching widths,
            // consistent with `evaluate`.
            (net, st) => unreachable!(
                "L1 mismatch: network={}, stack={}",
                net.l1_size(),
                st.l1_size()
            ),
        }
    }

    /// Incrementally updates the current accumulator from the accumulator at
    /// `source_idx`, applying the changes described by `dirty`.
    ///
    /// # Panics
    /// Panics if `stack`'s L1 width does not match the network's.
    #[inline(always)]
    pub fn update_accumulator(
        &self,
        pos: &Position,
        dirty: &DirtyPiece,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) {
        match (self, stack) {
            (Self::L256(net), HalfKA_hmStack::L256(st)) => {
                net.update_accumulator(pos, dirty, st, source_idx)
            }
            (Self::L512(net), HalfKA_hmStack::L512(st)) => {
                net.update_accumulator(pos, dirty, st, source_idx)
            }
            (Self::L768(net), HalfKA_hmStack::L768(st)) => {
                net.update_accumulator(pos, dirty, st, source_idx)
            }
            (Self::L1024(net), HalfKA_hmStack::L1024(st)) => {
                net.update_accumulator(pos, dirty, st, source_idx)
            }
            (net, st) => unreachable!(
                "L1 mismatch: network={}, stack={}",
                net.l1_size(),
                st.l1_size()
            ),
        }
    }

    /// Walks forward from `source_idx` performing incremental updates.
    /// Returns whatever the width-specialized implementation reports
    /// (presumably whether the forward update succeeded — see
    /// `forward_update_incremental` on the concrete network types).
    ///
    /// # Panics
    /// Panics if `stack`'s L1 width does not match the network's.
    #[inline(always)]
    pub fn forward_update_incremental(
        &self,
        pos: &Position,
        stack: &mut HalfKA_hmStack,
        source_idx: usize,
    ) -> bool {
        match (self, stack) {
            (Self::L256(net), HalfKA_hmStack::L256(st)) => {
                net.forward_update_incremental(pos, st, source_idx)
            }
            (Self::L512(net), HalfKA_hmStack::L512(st)) => {
                net.forward_update_incremental(pos, st, source_idx)
            }
            (Self::L768(net), HalfKA_hmStack::L768(st)) => {
                net.forward_update_incremental(pos, st, source_idx)
            }
            (Self::L1024(net), HalfKA_hmStack::L1024(st)) => {
                net.forward_update_incremental(pos, st, source_idx)
            }
            (net, st) => unreachable!(
                "L1 mismatch: network={}, stack={}",
                net.l1_size(),
                st.l1_size()
            ),
        }
    }

    /// Reads a network from `reader`, selecting the variant by the `l1`
    /// dimension taken from the file header.
    ///
    /// # Errors
    /// Returns `InvalidData` when `l2`/`l3` are zero (legacy bullet-shogi
    /// export without those header fields) or when `l1` is not one of the
    /// supported widths (256/512/768/1024). I/O errors from `reader`
    /// propagate unchanged.
    pub fn read<R: std::io::Read + std::io::Seek>(
        reader: &mut R,
        l1: usize,
        l2: usize,
        l3: usize,
        activation: Activation,
    ) -> std::io::Result<Self> {
        if l2 == 0 || l3 == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "HalfKA_hm L1={l1} network missing L2/L3 dimensions in header. \
                     This is an old bullet-shogi format that is no longer supported. \
                     Please re-export the model with a newer version of bullet-shogi."
                ),
            ));
        }
        match l1 {
            256 => {
                let net = HalfKA_hm_L256::read(reader, l2, l3, activation)?;
                Ok(Self::L256(net))
            }
            512 => {
                let net = HalfKA_hm_L512::read(reader, l2, l3, activation)?;
                Ok(Self::L512(net))
            }
            768 => {
                let net = HalfKA_hm_L768::read(reader, l2, l3, activation)?;
                Ok(Self::L768(net))
            }
            1024 => {
                let net = HalfKA_hm_L1024::read(reader, l2, l3, activation)?;
                Ok(Self::L1024(net))
            }
            _ => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Unsupported HalfKA_hm L1: {l1}"),
            )),
        }
    }

    /// The L1 width of this network variant.
    pub fn l1_size(&self) -> usize {
        match self {
            Self::L256(_) => 256,
            Self::L512(_) => 512,
            Self::L768(_) => 768,
            Self::L1024(_) => 1024,
        }
    }

    /// Human-readable architecture name, delegated to the concrete network.
    pub fn architecture_name(&self) -> &'static str {
        match self {
            Self::L256(net) => net.architecture_name(),
            Self::L512(net) => net.architecture_name(),
            Self::L768(net) => net.architecture_name(),
            Self::L1024(net) => net.architecture_name(),
        }
    }

    /// The spec (feature set + layer sizes) of this loaded network.
    pub fn architecture_spec(&self) -> ArchitectureSpec {
        match self {
            Self::L256(net) => net.architecture_spec(),
            Self::L512(net) => net.architecture_spec(),
            Self::L768(net) => net.architecture_spec(),
            Self::L1024(net) => net.architecture_spec(),
        }
    }

    /// All architecture specs supported across every width variant,
    /// concatenated in ascending width order.
    pub fn supported_specs() -> Vec<ArchitectureSpec> {
        let mut specs = Vec::new();
        specs.extend_from_slice(HalfKA_hm_L256::SUPPORTED_SPECS);
        specs.extend_from_slice(HalfKA_hm_L512::SUPPORTED_SPECS);
        specs.extend_from_slice(HalfKA_hm_L768::SUPPORTED_SPECS);
        specs.extend_from_slice(HalfKA_hm_L1024::SUPPORTED_SPECS);
        specs
    }
}
/// An accumulator stack whose L1 width mirrors one of the supported
/// `HalfKA_hmNetwork` variants. Must be paired with the matching network
/// variant (the dispatch methods panic on a width mismatch).
pub enum HalfKA_hmStack {
    // 256-wide accumulator stack.
    L256(AccumulatorStackHalfKA_hm<256>),
    // 512-wide accumulator stack.
    L512(AccumulatorStackHalfKA_hm<512>),
    // 768-wide accumulator stack.
    L768(AccumulatorStackHalfKA_hm<768>),
    // 1024-wide accumulator stack.
    L1024(AccumulatorStackHalfKA_hm<1024>),
}
impl HalfKA_hmStack {
pub fn from_network(net: &HalfKA_hmNetwork) -> Self {
match net {
HalfKA_hmNetwork::L256(_) => Self::L256(AccumulatorStackHalfKA_hm::<256>::new()),
HalfKA_hmNetwork::L512(_) => Self::L512(AccumulatorStackHalfKA_hm::<512>::new()),
HalfKA_hmNetwork::L768(_) => Self::L768(AccumulatorStackHalfKA_hm::<768>::new()),
HalfKA_hmNetwork::L1024(_) => Self::L1024(AccumulatorStackHalfKA_hm::<1024>::new()),
}
}
pub fn l1_size(&self) -> usize {
match self {
Self::L256(_) => 256,
Self::L512(_) => 512,
Self::L768(_) => 768,
Self::L1024(_) => 1024,
}
}
pub fn reset(&mut self) {
match self {
Self::L256(s) => s.reset(),
Self::L512(s) => s.reset(),
Self::L768(s) => s.reset(),
Self::L1024(s) => s.reset(),
}
}
pub fn push(&mut self, dirty: DirtyPiece) {
match self {
Self::L256(s) => s.push(dirty),
Self::L512(s) => s.push(dirty),
Self::L768(s) => s.push(dirty),
Self::L1024(s) => s.push(dirty),
}
}
pub fn pop(&mut self) {
match self {
Self::L256(s) => s.pop(),
Self::L512(s) => s.pop(),
Self::L768(s) => s.pop(),
Self::L1024(s) => s.pop(),
}
}
pub fn current_index(&self) -> usize {
match self {
Self::L256(s) => s.current_index(),
Self::L512(s) => s.current_index(),
Self::L768(s) => s.current_index(),
Self::L1024(s) => s.current_index(),
}
}
pub fn find_usable_accumulator(&self) -> Option<(usize, usize)> {
match self {
Self::L256(s) => s.find_usable_accumulator(),
Self::L512(s) => s.find_usable_accumulator(),
Self::L768(s) => s.find_usable_accumulator(),
Self::L1024(s) => s.find_usable_accumulator(),
}
}
#[inline]
pub fn is_current_computed(&self) -> bool {
match self {
Self::L256(s) => s.current().accumulator.computed_accumulation,
Self::L512(s) => s.current().accumulator.computed_accumulation,
Self::L768(s) => s.current().accumulator.computed_accumulation,
Self::L1024(s) => s.current().accumulator.computed_accumulation,
}
}
#[inline]
pub fn current_previous(&self) -> Option<usize> {
match self {
Self::L256(s) => s.current().previous,
Self::L512(s) => s.current().previous,
Self::L768(s) => s.current().previous,
Self::L1024(s) => s.current().previous,
}
}
#[inline]
pub fn is_entry_computed(&self, idx: usize) -> bool {
match self {
Self::L256(s) => s.entry_at(idx).accumulator.computed_accumulation,
Self::L512(s) => s.entry_at(idx).accumulator.computed_accumulation,
Self::L768(s) => s.entry_at(idx).accumulator.computed_accumulation,
Self::L1024(s) => s.entry_at(idx).accumulator.computed_accumulation,
}
}
#[inline]
pub fn current_dirty_piece(&self) -> DirtyPiece {
match self {
Self::L256(s) => s.current().dirty_piece,
Self::L512(s) => s.current().dirty_piece,
Self::L768(s) => s.current().dirty_piece,
Self::L1024(s) => s.current().dirty_piece,
}
}
}
impl Default for HalfKA_hmStack {
fn default() -> Self {
Self::L512(AccumulatorStackHalfKA_hm::<512>::new())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nnue::spec::FeatureSet;
    /// Shared push/pop driver: pushes `depth` entries onto `stack` and pops
    /// them back off, asserting `current_index` after every operation.
    /// Replaces the previously triplicated per-width test bodies.
    fn check_push_pop(stack: &mut HalfKA_hmStack, depth: usize) {
        let dirty = DirtyPiece::default();
        stack.reset();
        let initial_index = stack.current_index();
        for i in 0..depth {
            stack.push(dirty);
            assert_eq!(stack.current_index(), initial_index + i + 1);
        }
        for i in (0..depth).rev() {
            stack.pop();
            assert_eq!(stack.current_index(), initial_index + i);
        }
    }
    #[test]
    fn test_halfka_stack_from_network_l1_size() {
        // Covers every variant (the original omitted L768).
        let stack = HalfKA_hmStack::L256(AccumulatorStackHalfKA_hm::<256>::new());
        assert_eq!(stack.l1_size(), 256);
        let stack = HalfKA_hmStack::L512(AccumulatorStackHalfKA_hm::<512>::new());
        assert_eq!(stack.l1_size(), 512);
        let stack = HalfKA_hmStack::L768(AccumulatorStackHalfKA_hm::<768>::new());
        assert_eq!(stack.l1_size(), 768);
        let stack = HalfKA_hmStack::L1024(AccumulatorStackHalfKA_hm::<1024>::new());
        assert_eq!(stack.l1_size(), 1024);
    }
    #[test]
    fn test_supported_specs_combined() {
        let specs = HalfKA_hmNetwork::supported_specs();
        // NOTE(review): 8 assumes each of the four widths contributes two
        // specs — update if a width's SUPPORTED_SPECS gains entries.
        assert_eq!(specs.len(), 8);
        for spec in &specs {
            assert_eq!(spec.feature_set, FeatureSet::HalfKA_hm);
        }
    }
    #[test]
    fn test_push_pop_index_consistency_l256() {
        check_push_pop(&mut HalfKA_hmStack::L256(AccumulatorStackHalfKA_hm::<256>::new()), 2);
    }
    #[test]
    fn test_push_pop_index_consistency_l512() {
        check_push_pop(&mut HalfKA_hmStack::L512(AccumulatorStackHalfKA_hm::<512>::new()), 1);
    }
    #[test]
    fn test_push_pop_index_consistency_l768() {
        // Added: the 768-wide variant previously had no push/pop coverage.
        check_push_pop(&mut HalfKA_hmStack::L768(AccumulatorStackHalfKA_hm::<768>::new()), 1);
    }
    #[test]
    fn test_push_pop_index_consistency_l1024() {
        check_push_pop(&mut HalfKA_hmStack::L1024(AccumulatorStackHalfKA_hm::<1024>::new()), 1);
    }
    #[test]
    fn test_deep_push_pop() {
        check_push_pop(&mut HalfKA_hmStack::default(), 30);
    }
    #[test]
    fn test_architecture_spec_consistency() {
        for spec in HalfKA_hmNetwork::supported_specs() {
            assert_eq!(spec.feature_set, FeatureSet::HalfKA_hm);
            assert!(spec.l1 == 256 || spec.l1 == 512 || spec.l1 == 768 || spec.l1 == 1024);
            assert!(spec.l2 > 0 && spec.l2 <= 128);
            assert!(spec.l3 > 0 && spec.l3 <= 128);
        }
    }
}