use crate::nnue::accumulator::DirtyPiece;
use crate::nnue::network_halfkp::AccumulatorStackHalfKP;
use crate::nnue::spec::{Activation, ArchitectureSpec, FeatureSet};
use crate::position::Position;
use crate::types::Value;
use crate::nnue::aliases::{
HalfKP512CReLU, HalfKP512Pairwise, HalfKP512SCReLU, HalfKP512_32_32CReLU,
HalfKP512_32_32Pairwise, HalfKP512_32_32SCReLU,
};
// Instantiates the L1 = 512 HalfKP network family via `define_l1_variants!`.
//
// The invocation yields the `HalfKPL512` enum together with its
// `SUPPORTED_SPECS` table: one spec per `variants` entry, in the order the
// entries are written here (the tests in this file index specs positionally).
// NOTE(review): the full set of generated items is defined by the macro
// elsewhere in the crate — confirm against `define_l1_variants!`.
crate::define_l1_variants!(
enum HalfKPL512,
feature_set HalfKP,
l1 512,
acc crate::nnue::network_halfkp::AccumulatorHalfKP<512>,
stack AccumulatorStackHalfKP<512>,
variants {
// Entry shape: (l2, l3, activation, label) => enum variant : network alias type.
// NOTE(review): `label` presumably contributes the activation part of the
// architecture name (names start with "HalfKP-512-") — confirm in the macro.
// L2 = 8, L3 = 96 heads.
(8, 96, CReLU, "CReLU") => CReLU8x96 : HalfKP512CReLU,
(8, 96, SCReLU, "SCReLU") => SCReLU8x96 : HalfKP512SCReLU,
(8, 96, PairwiseCReLU, "Pairwise") => Pairwise8x96 : HalfKP512Pairwise,
// L2 = 32, L3 = 32 heads.
(32, 32, CReLU, "CReLU") => CReLU32x32 : HalfKP512_32_32CReLU,
(32, 32, SCReLU, "SCReLU") => SCReLU32x32 : HalfKP512_32_32SCReLU,
(32, 32, PairwiseCReLU, "Pairwise") => Pairwise32x32 : HalfKP512_32_32Pairwise,
}
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every generated spec, in declaration order, to the
    /// `(l2, l3, activation)` triple written in the `define_l1_variants!`
    /// invocation, and checks the shared feature set and L1 width.
    ///
    /// Checking the whole table (instead of spot-checking two indices)
    /// catches reordered, dropped, or mistyped entries anywhere in the list.
    #[test]
    fn test_supported_specs() {
        let expected = [
            (8, 96, Activation::CReLU),
            (8, 96, Activation::SCReLU),
            (8, 96, Activation::PairwiseCReLU),
            (32, 32, Activation::CReLU),
            (32, 32, Activation::SCReLU),
            (32, 32, Activation::PairwiseCReLU),
        ];
        assert_eq!(HalfKPL512::SUPPORTED_SPECS.len(), expected.len());
        for (i, (spec, &(l2, l3, activation))) in
            HalfKPL512::SUPPORTED_SPECS.iter().zip(expected.iter()).enumerate()
        {
            assert_eq!(spec.feature_set, FeatureSet::HalfKP, "spec #{i}: feature set");
            assert_eq!(spec.l1, 512, "spec #{i}: l1");
            assert_eq!(
                (spec.l2, spec.l3, spec.activation),
                (l2, l3, activation),
                "spec #{i}: (l2, l3, activation)"
            );
        }
    }

    /// No two entries may share the same `(l2, l3, activation)` triple; a
    /// duplicate would make two enum variants describe the same architecture.
    #[test]
    fn test_specs_are_unique() {
        for (i, a) in HalfKPL512::SUPPORTED_SPECS.iter().enumerate() {
            for b in HalfKPL512::SUPPORTED_SPECS.iter().skip(i + 1) {
                assert_ne!(
                    (a.l2, a.l3, a.activation),
                    (b.l2, b.l3, b.activation),
                    "duplicate spec in HalfKPL512::SUPPORTED_SPECS"
                );
            }
        }
    }

    /// Every spec in this family uses L1 = 512.
    #[test]
    fn test_l1_size() {
        for spec in HalfKPL512::SUPPORTED_SPECS.iter() {
            assert_eq!(spec.l1, 512);
        }
    }

    /// Architecture names carry the feature-set and L1 prefix.
    #[test]
    fn test_architecture_name_format() {
        for spec in HalfKPL512::SUPPORTED_SPECS.iter() {
            let name = spec.name();
            assert!(
                name.starts_with("HalfKP-512-"),
                "Architecture name should start with 'HalfKP-512-', got: {name}"
            );
        }
    }

    /// Pairwise activation halves the L1 output dimension; the others keep it.
    #[test]
    fn test_activation_output_dim_divisor() {
        for spec in HalfKPL512::SUPPORTED_SPECS.iter() {
            match spec.activation {
                Activation::CReLU | Activation::SCReLU => {
                    assert_eq!(spec.activation.output_dim_divisor(), 1);
                }
                Activation::PairwiseCReLU => {
                    assert_eq!(spec.activation.output_dim_divisor(), 2);
                }
            }
        }
    }

    /// Both declared (L2, L3) head shapes are present.
    #[test]
    fn test_multiple_l2_l3_combinations() {
        let combinations: Vec<_> =
            HalfKPL512::SUPPORTED_SPECS.iter().map(|s| (s.l2, s.l3)).collect();
        assert!(combinations.contains(&(8, 96)), "Should support L2=8, L3=96");
        assert!(combinations.contains(&(32, 32)), "Should support L2=32, L3=32");
    }

    /// Every activation kind appears in at least one spec.
    #[test]
    fn test_all_activations_present() {
        let activations: Vec<_> =
            HalfKPL512::SUPPORTED_SPECS.iter().map(|s| s.activation).collect();
        assert!(activations.contains(&Activation::CReLU));
        assert!(activations.contains(&Activation::SCReLU));
        assert!(activations.contains(&Activation::PairwiseCReLU));
    }
}