#![allow(non_camel_case_types)]
use crate::nnue::accumulator::DirtyPiece;
use crate::nnue::network_halfka::AccumulatorStackHalfKA;
use crate::nnue::spec::{Activation, ArchitectureSpec, FeatureSet};
use crate::position::Position;
use crate::types::Value;
use crate::nnue::aliases::{HalfKA1024_8_32CReLU, HalfKA1024_8_64CReLU, HalfKA1024CReLU};
// Declares the `HalfKA_L1024` enum through `define_l1_variants!`. It covers
// networks on the HalfKA feature set with a 1024-wide L1 layer, plus one
// variant per entry in the `variants` list.
//
// The name `HalfKA_L1024` is not UpperCamelCase. That is presumably why the
// file carries `#![allow(non_camel_case_types)]`.
//
// The expansion exposes `HalfKA_L1024::SUPPORTED_SPECS`, which the tests
// below read. Its entries follow the order of the `variants` list, and
// `test_supported_specs` pins that order by index. Reordering the list
// therefore requires updating that test.
//
// NOTE(review): imports such as `DirtyPiece`, `ArchitectureSpec`,
// `Position` and `Value` are not referenced in the visible code. They are
// presumably consumed by the macro expansion — confirm against the macro
// definition before removing any of them.
crate::define_l1_variants!(
enum HalfKA_L1024,
feature_set HalfKA,
l1 1024,
// Presumably the per-L1 accumulator type and its stack used for
// incremental updates — confirm against the macro definition.
acc crate::nnue::network_halfka::AccumulatorHalfKA<1024>,
stack AccumulatorStackHalfKA<1024>,
// Entry shape (inferred from the tests): `(L2, L3, activation,
// activation-name string) => VariantName : network type alias`.
// The meaning of the string argument is an assumption — TODO confirm.
variants {
(8, 64, CReLU, "CReLU") => CReLU8x64 : HalfKA1024_8_64CReLU,
// NOTE(review): this entry uses the undimensioned alias `HalfKA1024CReLU`.
// Confirm that alias really is the L2=8 / L3=96 configuration.
(8, 96, CReLU, "CReLU") => CReLU8x96 : HalfKA1024CReLU,
(8, 32, CReLU, "CReLU") => CReLU8x32 : HalfKA1024_8_32CReLU,
}
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins how many specs are declared and the order in which they are
    /// emitted, checking the first entry field by field.
    #[test]
    fn test_supported_specs() {
        let specs = &HalfKA_L1024::SUPPORTED_SPECS;
        assert_eq!(specs.len(), 3);

        let head = &specs[0];
        assert_eq!(head.feature_set, FeatureSet::HalfKA);
        assert_eq!(head.l1, 1024);
        assert_eq!(head.l2, 8);
        assert_eq!(head.l3, 64);
        assert_eq!(head.activation, Activation::CReLU);

        // Only the hidden-layer sizes of the remaining entries are pinned here.
        let tail_dims = [(8, 96), (8, 32)];
        for (spec, (l2, l3)) in specs[1..].iter().zip(tail_dims) {
            assert_eq!(spec.l2, l2);
            assert_eq!(spec.l3, l3);
        }
    }

    /// Every declared spec shares the same 1024-wide L1 layer.
    #[test]
    fn test_l1_size() {
        HalfKA_L1024::SUPPORTED_SPECS
            .iter()
            .for_each(|spec| assert_eq!(spec.l1, 1024));
    }

    /// Spec names must begin with the feature-set / L1 prefix.
    #[test]
    fn test_architecture_name_format() {
        const EXPECTED_PREFIX: &str = "HalfKA-1024-";
        for spec in HalfKA_L1024::SUPPORTED_SPECS {
            let name = spec.name();
            let has_prefix = name.starts_with(EXPECTED_PREFIX);
            assert!(
                has_prefix,
                "Architecture name should start with 'HalfKA-1024-', got: {name}"
            );
        }
    }

    /// All specs use CReLU, whose output dimension divisor is 1.
    #[test]
    fn test_activation_output_dim_divisor() {
        for spec in HalfKA_L1024::SUPPORTED_SPECS {
            assert_eq!(spec.activation, Activation::CReLU);
            let divisor = spec.activation.output_dim_divisor();
            assert_eq!(divisor, 1);
        }
    }

    /// Each expected (L2, L3) pair must be declared somewhere, independent of
    /// its position in the list.
    #[test]
    fn test_multiple_l2_l3_combinations() {
        let declared: Vec<_> = HalfKA_L1024::SUPPORTED_SPECS
            .iter()
            .map(|spec| (spec.l2, spec.l3))
            .collect();
        let expectations = [
            ((8, 64), "Should support L2=8, L3=64"),
            ((8, 96), "Should support L2=8, L3=96"),
            ((8, 32), "Should support L2=8, L3=32"),
        ];
        for (pair, message) in expectations {
            assert!(declared.contains(&pair), "{message}");
        }
    }
}