/// Input feature set of a network, as named in its architecture string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureSet {
    /// Rendered as `"HalfKP"`.
    HalfKP,
    /// Rendered as `"HalfKA_hm"`.
    // The underscore is kept so the variant reads exactly like the textual
    // name returned by `as_str`; hence the lint exemption.
    #[allow(non_camel_case_types)]
    HalfKA_hm,
    /// Rendered as `"HalfKA"`.
    HalfKA,
    /// Rendered as `"LayerStacks"`.
    LayerStacks,
}

impl FeatureSet {
    /// Returns the textual name of this feature set.
    ///
    /// The returned string is identical to the variant's spelling.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::LayerStacks => "LayerStacks",
            Self::HalfKA => "HalfKA",
            Self::HalfKA_hm => "HalfKA_hm",
            Self::HalfKP => "HalfKP",
        }
    }
}

impl std::fmt::Display for FeatureSet {
    /// Writes the same text as [`FeatureSet::as_str`]; width and fill flags
    /// of the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Activation function variant of a network.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Activation {
    /// Rendered as `"CReLU"`; also the fallback when a header names nothing else.
    CReLU,
    /// Rendered as `"SCReLU"`.
    SCReLU,
    /// Rendered as `"PairwiseCReLU"`.
    PairwiseCReLU,
}

impl Activation {
    /// Returns the textual name of this activation.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::PairwiseCReLU => "PairwiseCReLU",
            Self::SCReLU => "SCReLU",
            Self::CReLU => "CReLU",
        }
    }

    /// Returns the divisor associated with this activation's output width:
    /// `2` for [`Activation::PairwiseCReLU`], `1` for every other variant.
    pub fn output_dim_divisor(&self) -> usize {
        match *self {
            Self::PairwiseCReLU => 2,
            Self::CReLU | Self::SCReLU => 1,
        }
    }

    /// Infers the activation from markers embedded in a header string.
    ///
    /// Markers are checked in priority order: `"-Pairwise"` (which also
    /// covers `"-PairwiseCReLU"`), then `"-SCReLU"`. A string containing
    /// neither yields [`Activation::CReLU`].
    pub fn from_header_suffix(suffix: &str) -> Self {
        const MARKERS: [(&str, Activation); 2] = [
            ("-Pairwise", Activation::PairwiseCReLU),
            ("-SCReLU", Activation::SCReLU),
        ];

        MARKERS
            .iter()
            .find(|(marker, _)| suffix.contains(*marker))
            .map_or(Self::CReLU, |&(_, activation)| activation)
    }
}

impl std::fmt::Display for Activation {
    /// Writes the same text as [`Activation::as_str`]; width and fill flags
    /// of the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Complete description of a network architecture: feature set, the three
/// layer sizes and the activation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ArchitectureSpec {
    /// Input feature set.
    pub feature_set: FeatureSet,
    /// First size in the `<features>-<l1>-<l2>-<l3>-<activation>` name.
    pub l1: usize,
    /// Second size in the name.
    pub l2: usize,
    /// Third size in the name.
    pub l3: usize,
    /// Activation variant.
    pub activation: Activation,
}

impl ArchitectureSpec {
    /// Builds a spec from its five components; usable in `const` contexts.
    pub const fn new(
        feature_set: FeatureSet,
        l1: usize,
        l2: usize,
        l3: usize,
        activation: Activation,
    ) -> Self {
        Self { feature_set, l1, l2, l3, activation }
    }

    /// Returns the dash-separated name, e.g. `"HalfKA_hm-512-8-96-CReLU"`.
    ///
    /// This is the same text the [`std::fmt::Display`] impl produces.
    pub fn name(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for ArchitectureSpec {
    /// Writes `<feature_set>-<l1>-<l2>-<l3>-<activation>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { feature_set, l1, l2, l3, activation } = self;
        write!(f, "{feature_set}-{l1}-{l2}-{l3}-{activation}")
    }
}
/// Feature set and layer sizes recovered from an architecture string by
/// `parse_architecture`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParsedArchitecture {
/// Feature set identified in the architecture string.
pub feature_set: FeatureSet,
/// First layer size; `0` when it could not be parsed from the string.
pub l1: usize,
/// Second layer size; `0` when it could not be parsed from the string.
pub l2: usize,
/// Third layer size; `0` when it could not be parsed from the string.
pub l3: usize,
}
/// Extracts the input dimension count from a `Features=Name[<N>-><M>...]`
/// section of an architecture string.
///
/// The number is the text between the first `[` after `Features=` and the
/// first `->` that follows it.
///
/// Returns `None` when `Features=`, the `[`, or the `->` is missing, or when
/// the text in between is not a valid `usize`.
pub fn parse_feature_input_dimensions(arch_str: &str) -> Option<usize> {
    let (_, after_key) = arch_str.split_once("Features=")?;
    let (_, after_bracket) = after_key.split_once('[')?;
    let (digits, _) = after_bracket.split_once("->")?;
    digits.parse::<usize>().ok()
}
/// Determines the [`FeatureSet`] named by an architecture string.
///
/// Substring checks run in a fixed order, and the first hit wins:
///
/// 1. `"LayerStacks"` or `"->1536x2]"` gives `LayerStacks`.
/// 2. `"HalfKP"` gives `HalfKP`.
/// 3. `"HalfKA_hm"` gives `HalfKA_hm`.
/// 4. `"HalfKA"` is resolved by the input dimension count reported by
///    [`parse_feature_input_dimensions`], because the bare name is used for
///    both `HalfKA` and `HalfKA_hm`.
///
/// # Errors
///
/// Returns a message when no known name is present, when a bare `HalfKA`
/// string carries no parsable input dimensions, or when those dimensions
/// match neither known constant.
pub fn parse_feature_set_from_arch(arch_str: &str) -> Result<FeatureSet, String> {
    use super::constants::{HALFKA_DIMENSIONS, HALFKA_HM_DIMENSIONS};

    let mentions = |needle: &str| arch_str.contains(needle);

    if mentions("LayerStacks") || mentions("->1536x2]") {
        Ok(FeatureSet::LayerStacks)
    } else if mentions("HalfKP") {
        Ok(FeatureSet::HalfKP)
    } else if mentions("HalfKA_hm") {
        // Must be tested before the bare "HalfKA", which is a prefix of it.
        Ok(FeatureSet::HalfKA_hm)
    } else if mentions("HalfKA") {
        match parse_feature_input_dimensions(arch_str) {
            Some(HALFKA_HM_DIMENSIONS) => Ok(FeatureSet::HalfKA_hm),
            Some(HALFKA_DIMENSIONS) => Ok(FeatureSet::HalfKA),
            Some(input_dim) => Err(format!("Unknown HalfKA input dimensions: {input_dim}")),
            None => {
                Err("HalfKA architecture is missing input dimensions in arch string.".to_string())
            }
        }
    } else {
        Err("Unknown feature set in arch string.".to_string())
    }
}
/// Extracts the `(l1, l2, l3)` layer sizes from an architecture string.
///
/// * `l1` is the number between the last `->` preceding the first `x2]` and
///   that `x2]`, ignoring anything from a `/` onwards (so `->512/2x2]`
///   yields `512`).
/// * `l2` / `l3` come from comma-separated `l2=<n>` / `l3=<n>` fields when
///   present.
/// * Any of `l2` / `l3` still zero is filled from the `AffineTransform[out<-in]`
///   entries, provided at least three were found: `l2` is the output size of
///   the last entry in the string and `l3` that of the second-to-last.
///
/// Every value that cannot be determined is reported as `0`.
pub fn parse_arch_dimensions(arch_str: &str) -> (usize, usize, usize) {
    const AFFINE_OPEN: &str = "AffineTransform[";

    // --- l1: "...-><l1>[/k]x2]" -------------------------------------------
    let l1 = arch_str
        .find("x2]")
        .and_then(|marker| {
            let head = &arch_str[..marker];
            let arrow = head.rfind("->")?;
            let tail = &head[arrow + 2..];
            // `split` always yields at least one piece: the text before any '/'.
            let digits = tail.split('/').next().unwrap_or(tail);
            Some(digits.parse::<usize>().unwrap_or(0))
        })
        .unwrap_or(0);

    // --- every "AffineTransform[out<-in]" in order of appearance -----------
    let mut shapes: Vec<(usize, usize)> = Vec::new();
    let mut cursor = 0;
    while let Some(offset) = arch_str[cursor..].find(AFFINE_OPEN) {
        let body_start = cursor + offset + AFFINE_OPEN.len();
        let Some(body_len) = arch_str[body_start..].find(']') else {
            break;
        };
        let body = &arch_str[body_start..body_start + body_len];
        if let Some((out_text, in_text)) = body.split_once("<-") {
            if let (Ok(out), Ok(inp)) = (out_text.parse::<usize>(), in_text.parse::<usize>()) {
                shapes.push((out, inp));
            }
        }
        cursor = body_start + body_len;
    }

    // --- explicit "l2=" / "l3=" fields --------------------------------------
    let mut l2 = 0usize;
    let mut l3 = 0usize;
    for field in arch_str.split(',') {
        let Some((key, value)) = field.split_once('=') else {
            continue;
        };
        let target = match key {
            "l2" => &mut l2,
            "l3" => &mut l3,
            _ => continue,
        };
        if let Ok(parsed) = value.parse::<usize>() {
            *target = parsed;
        }
    }

    // --- fall back to the affine entries, read from the end of the string ---
    let found = shapes.len();
    if found >= 3 {
        if l2 == 0 {
            l2 = shapes[found - 1].0;
        }
        if l3 == 0 {
            l3 = shapes[found - 2].0;
        }
    }

    (l1, l2, l3)
}
/// Extracts the first layer size from a HalfKP-style architecture string.
///
/// Two spellings are recognised:
///
/// * If the string contains `->`, the run of ASCII digits immediately after
///   the first `->` is parsed (`...[125388->256x2]` gives `256`). When `->`
///   is present this branch always decides the result, yielding `0` if no
///   valid number follows.
/// * Otherwise, if the string contains `x2]`, the digits directly before it
///   are parsed (`HalfKP[256x2]` gives `256`). When the text before `x2]`
///   contains a `/`, the digits directly before the last `/` are used
///   instead (`HalfKP[512/2x2]` gives `512`).
///
/// Returns `0` when no size can be determined.
///
/// The digit run is located by trimming, so it is found correctly even when
/// it starts at the beginning of the string or is preceded by a multi-byte
/// character; neither case can cause a slicing panic.
pub fn parse_halfkp_l1(arch_str: &str) -> usize {
    /// Returns the (possibly empty) run of ASCII digits at the end of `s`.
    fn trailing_digits(s: &str) -> &str {
        // The trimmed prefix always ends on a char boundary, so slicing at
        // its length is safe for any UTF-8 input.
        let prefix_len = s.trim_end_matches(|c: char| c.is_ascii_digit()).len();
        &s[prefix_len..]
    }

    if let Some((_, after_arrow)) = arch_str.split_once("->") {
        let digits_len = after_arrow
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(after_arrow.len());
        return after_arrow[..digits_len].parse::<usize>().unwrap_or(0);
    }

    let Some(marker) = arch_str.find("x2]") else {
        return 0;
    };
    let before = &arch_str[..marker];
    // With "<l1>/<k>x2]" the size of interest sits before the last slash.
    let numeric_part = match before.rfind('/') {
        Some(slash) => &before[..slash],
        None => before,
    };
    trailing_digits(numeric_part).parse::<usize>().unwrap_or(0)
}
/// Parses an architecture string into its feature set and layer sizes.
///
/// The sizes come from [`parse_arch_dimensions`]. For a `HalfKP` feature set
/// the first size is replaced by the result of [`parse_halfkp_l1`] whenever
/// that result is non-zero.
///
/// # Errors
///
/// Propagates the error from [`parse_feature_set_from_arch`] when the
/// feature set cannot be identified.
pub fn parse_architecture(arch_str: &str) -> Result<ParsedArchitecture, String> {
    let feature_set = parse_feature_set_from_arch(arch_str)?;
    let (generic_l1, l2, l3) = parse_arch_dimensions(arch_str);

    let l1 = if matches!(feature_set, FeatureSet::HalfKP) {
        match parse_halfkp_l1(arch_str) {
            0 => generic_l1,
            specific => specific,
        }
    } else {
        generic_l1
    };

    Ok(ParsedArchitecture { feature_set, l1, l2, l3 })
}
/// Rounds `n` up to the nearest multiple of 32.
const fn pad32(n: usize) -> usize {
    n.div_ceil(32) * 32
}

/// Shared size formula behind the `network_payload_*` functions.
///
/// The result is the sum of:
///
/// * `l1 * 2` and `input_dimensions * l1 * 2` for the feature transformer
///   bias and weights,
/// * `l2 * 4` and `pad32(l1 * 2) * l2` for the first layer bias and weights,
/// * `l3 * 4` and `pad32(l2) * l3` for the second layer bias and weights,
/// * `4` and `l3` for the output bias and weights.
///
/// NOTE(review): the factors 2, 4 and 1 look like per-element byte widths
/// (16-bit, 32-bit and 8-bit values), which would make the result a byte
/// count; confirm against the file writer.
const fn network_payload(input_dimensions: usize, l1: usize, l2: usize, l3: usize) -> u64 {
    // Widen before multiplying so the products are computed in 64 bits even
    // on targets where `usize` is 32 bits wide. `usize -> u64` never
    // truncates on supported targets.
    let inputs = input_dimensions as u64;
    let l1_wide = l1 as u64;
    let l2_wide = l2 as u64;
    let l3_wide = l3 as u64;

    let ft_bias = l1_wide * 2;
    let ft_weight = inputs * l1_wide * 2;
    let l1_bias = l2_wide * 4;
    let l1_weight = pad32(l1 * 2) as u64 * l2_wide;
    let l2_bias = l3_wide * 4;
    let l2_weight = pad32(l2) as u64 * l3_wide;
    let output_bias = 4;
    let output_weight = l3_wide;

    ft_bias + ft_weight + l1_bias + l1_weight + l2_bias + l2_weight + output_bias + output_weight
}

/// Expected payload size of a HalfKP network (125388 input dimensions) with
/// the given layer sizes. Usable in `const` contexts.
pub const fn network_payload_halfkp(l1: usize, l2: usize, l3: usize) -> u64 {
    const HALFKP_DIMENSIONS: usize = 125388;
    network_payload(HALFKP_DIMENSIONS, l1, l2, l3)
}

/// Expected payload size of a HalfKA_hm network (73305 input dimensions)
/// with the given layer sizes. Usable in `const` contexts.
pub const fn network_payload_halfka_hm(l1: usize, l2: usize, l3: usize) -> u64 {
    const HALFKA_HM_DIMENSIONS: usize = 73305;
    network_payload(HALFKA_HM_DIMENSIONS, l1, l2, l3)
}

/// Expected payload size of a HalfKA network (138510 input dimensions) with
/// the given layer sizes. Usable in `const` contexts.
pub const fn network_payload_halfka(l1: usize, l2: usize, l3: usize) -> u64 {
    const HALFKA_DIMENSIONS: usize = 138510;
    network_payload(HALFKA_DIMENSIONS, l1, l2, l3)
}
/// Outcome of matching a file size against the known payload table in
/// `detect_architecture_from_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArchDetectionResult {
/// The matched architecture; size-based detection always reports `CReLU`
/// as the activation.
pub spec: ArchitectureSpec,
/// `false` when the data after the header is exactly the expected payload;
/// `true` when it is 8 or more bytes longer.
pub has_hash: bool,
/// `true` when 1 to 63 further bytes follow the payload plus the 8 extra
/// bytes; `false` otherwise.
pub is_bullet_format: bool,
}
/// Table of recognised architectures as
/// `(feature_set, l1, l2, l3, expected_payload)` tuples.
///
/// `expected_payload` is computed at compile time by the matching
/// `network_payload_*` function. Lookups scan the table front to back and
/// return on the first match, so row order is significant.
///
/// NOTE(review): HalfKP has no `1024, 8, 96` row although HalfKA_hm and
/// HalfKA both do; confirm this omission is intentional.
const KNOWN_PAYLOADS: &[(FeatureSet, usize, usize, usize, u64)] = &[
// HalfKP
(FeatureSet::HalfKP, 256, 32, 32, network_payload_halfkp(256, 32, 32)),
(FeatureSet::HalfKP, 512, 8, 64, network_payload_halfkp(512, 8, 64)),
(FeatureSet::HalfKP, 512, 8, 96, network_payload_halfkp(512, 8, 96)),
(FeatureSet::HalfKP, 512, 32, 32, network_payload_halfkp(512, 32, 32)),
(FeatureSet::HalfKP, 768, 16, 64, network_payload_halfkp(768, 16, 64)),
(FeatureSet::HalfKP, 1024, 8, 32, network_payload_halfkp(1024, 8, 32)),
(FeatureSet::HalfKP, 1024, 8, 64, network_payload_halfkp(1024, 8, 64)),
// HalfKA_hm
(FeatureSet::HalfKA_hm, 256, 32, 32, network_payload_halfka_hm(256, 32, 32)),
(FeatureSet::HalfKA_hm, 512, 8, 64, network_payload_halfka_hm(512, 8, 64)),
(FeatureSet::HalfKA_hm, 512, 8, 96, network_payload_halfka_hm(512, 8, 96)),
(FeatureSet::HalfKA_hm, 512, 32, 32, network_payload_halfka_hm(512, 32, 32)),
(FeatureSet::HalfKA_hm, 768, 16, 64, network_payload_halfka_hm(768, 16, 64)),
(FeatureSet::HalfKA_hm, 1024, 8, 32, network_payload_halfka_hm(1024, 8, 32)),
(FeatureSet::HalfKA_hm, 1024, 8, 64, network_payload_halfka_hm(1024, 8, 64)),
(FeatureSet::HalfKA_hm, 1024, 8, 96, network_payload_halfka_hm(1024, 8, 96)),
// HalfKA
(FeatureSet::HalfKA, 256, 32, 32, network_payload_halfka(256, 32, 32)),
(FeatureSet::HalfKA, 512, 8, 64, network_payload_halfka(512, 8, 64)),
(FeatureSet::HalfKA, 512, 8, 96, network_payload_halfka(512, 8, 96)),
(FeatureSet::HalfKA, 512, 32, 32, network_payload_halfka(512, 32, 32)),
(FeatureSet::HalfKA, 768, 16, 64, network_payload_halfka(768, 16, 64)),
(FeatureSet::HalfKA, 1024, 8, 32, network_payload_halfka(1024, 8, 32)),
(FeatureSet::HalfKA, 1024, 8, 64, network_payload_halfka(1024, 8, 64)),
(FeatureSet::HalfKA, 1024, 8, 96, network_payload_halfka(1024, 8, 96)),
];
/// Identifies a known architecture purely from the size of a network file.
///
/// The header is taken to be `12 + arch_len` bytes; whatever remains of
/// `file_size` is compared against each row of `KNOWN_PAYLOADS` in table
/// order. When `feature_set_hint` is given, rows for other feature sets are
/// skipped. The first row whose payload fits one of these patterns wins:
///
/// | bytes beyond the expected payload | `has_hash` | `is_bullet_format` |
/// |-----------------------------------|------------|--------------------|
/// | 0                                 | `false`    | `false`            |
/// | 8                                 | `true`     | `false`            |
/// | 9 ..= 71 (8 plus 1..=63)          | `true`     | `true`             |
///
/// The reported activation is always `CReLU`, since size alone cannot
/// distinguish activations.
///
/// Returns `None` when the file is shorter than the header or no row matches.
pub fn detect_architecture_from_size(
    file_size: u64,
    arch_len: usize,
    feature_set_hint: Option<FeatureSet>,
) -> Option<ArchDetectionResult> {
    let header_size = 12 + arch_len as u64;
    let base = file_size.checked_sub(header_size)?;

    for &(feature_set, l1, l2, l3, expected_payload) in KNOWN_PAYLOADS {
        if feature_set_hint.is_some_and(|hint| hint != feature_set) {
            continue;
        }

        // Fewer bytes than the payload itself can never match this row.
        let Some(surplus) = base.checked_sub(expected_payload) else {
            continue;
        };

        let (has_hash, is_bullet_format) = match surplus {
            0 => (false, false),
            8 => (true, false),
            // 8 extra bytes followed by 1 to 63 more.
            9..=71 => (true, true),
            _ => continue,
        };

        return Some(ArchDetectionResult {
            spec: ArchitectureSpec::new(feature_set, l1, l2, l3, Activation::CReLU),
            has_hash,
            is_bullet_format,
        });
    }

    None
}
/// Lists the known architectures whose expected size is closest to the
/// given file size.
///
/// After removing a `12 + arch_len` byte header from `file_size`, every row
/// of `KNOWN_PAYLOADS` contributes two candidates: the signed difference to
/// the bare payload, and the signed difference to the payload plus 8 bytes.
/// A positive difference means the file is larger than expected.
///
/// Candidates are ordered by absolute difference (the relative order of
/// equal differences is unspecified) and at most ten are returned. Every
/// spec carries `CReLU` as its activation. An empty list is returned when
/// the file is shorter than the header.
pub fn list_candidate_architectures(
    file_size: u64,
    arch_len: usize,
) -> Vec<(ArchitectureSpec, i64)> {
    let header_size = 12 + arch_len as u64;
    let Some(base) = file_size.checked_sub(header_size) else {
        return Vec::new();
    };
    let base = base as i64;

    let mut ranked: Vec<(ArchitectureSpec, i64)> = Vec::with_capacity(KNOWN_PAYLOADS.len() * 2);
    for &(feature_set, l1, l2, l3, expected_payload) in KNOWN_PAYLOADS {
        let spec = ArchitectureSpec::new(feature_set, l1, l2, l3, Activation::CReLU);
        ranked.push((spec, base - expected_payload as i64));
        ranked.push((spec, base - (expected_payload + 8) as i64));
    }

    ranked.sort_unstable_by_key(|&(_, diff)| diff.abs());
    ranked.truncate(10);
    ranked
}
// Unit tests for the name rendering, architecture-string parsing, payload
// size formulas and size-based architecture detection defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Each feature set renders as its own variant spelling.
#[test]
fn test_feature_set_display() {
assert_eq!(FeatureSet::HalfKP.as_str(), "HalfKP");
assert_eq!(FeatureSet::HalfKA_hm.as_str(), "HalfKA_hm");
assert_eq!(FeatureSet::HalfKA.as_str(), "HalfKA");
assert_eq!(FeatureSet::LayerStacks.as_str(), "LayerStacks");
}
// Each activation renders as its own variant spelling.
#[test]
fn test_activation_display() {
assert_eq!(Activation::CReLU.as_str(), "CReLU");
assert_eq!(Activation::SCReLU.as_str(), "SCReLU");
assert_eq!(Activation::PairwiseCReLU.as_str(), "PairwiseCReLU");
}
// Only the pairwise activation has a divisor of 2.
#[test]
fn test_activation_output_dim_divisor() {
assert_eq!(Activation::CReLU.output_dim_divisor(), 1);
assert_eq!(Activation::SCReLU.output_dim_divisor(), 1);
assert_eq!(Activation::PairwiseCReLU.output_dim_divisor(), 2);
}
// No marker falls back to CReLU; both pairwise spellings are accepted.
#[test]
fn test_activation_from_header_suffix() {
assert_eq!(
Activation::from_header_suffix("Features=HalfKA_hm[73305->512x2]"),
Activation::CReLU
);
assert_eq!(
Activation::from_header_suffix("Features=HalfKA_hm[73305->512x2]-SCReLU"),
Activation::SCReLU
);
assert_eq!(
Activation::from_header_suffix("Features=HalfKA_hm[73305->512/2x2]-Pairwise"),
Activation::PairwiseCReLU
);
assert_eq!(
Activation::from_header_suffix("Features=HalfKA_hm[73305->512/2x2]-PairwiseCReLU"),
Activation::PairwiseCReLU
);
}
// The name is the five components joined by dashes.
#[test]
fn test_architecture_spec_name() {
let spec = ArchitectureSpec::new(FeatureSet::HalfKA_hm, 512, 8, 96, Activation::CReLU);
assert_eq!(spec.name(), "HalfKA_hm-512-8-96-CReLU");
let spec2 = ArchitectureSpec::new(FeatureSet::HalfKP, 256, 32, 32, Activation::SCReLU);
assert_eq!(spec2.name(), "HalfKP-256-32-32-SCReLU");
}
// A bare "HalfKA" name is resolved by its input dimension count: 138510 is
// HalfKA, 73305 is HalfKA_hm.
#[test]
fn test_parse_feature_set_from_arch() {
assert_eq!(
parse_feature_set_from_arch(
"Features=HalfKA_hm[73305->512x2],Network=AffineTransform[1<-96]"
)
.unwrap(),
FeatureSet::HalfKA_hm
);
assert_eq!(
parse_feature_set_from_arch(
"Features=HalfKA[138510->512x2],Network=AffineTransform[1<-96]"
)
.unwrap(),
FeatureSet::HalfKA
);
assert_eq!(
parse_feature_set_from_arch(
"Features=HalfKA[73305->512x2],Network=AffineTransform[1<-96]"
)
.unwrap(),
FeatureSet::HalfKA_hm
);
assert_eq!(
parse_feature_set_from_arch("Features=HalfKP[125388->256x2]").unwrap(),
FeatureSet::HalfKP
);
}
// A bare "HalfKA" without a "[N->" section cannot be disambiguated.
#[test]
fn test_parse_feature_set_from_arch_missing_dimensions() {
let err = parse_feature_set_from_arch("Features=HalfKA,Network=AffineTransform[1<-96]")
.unwrap_err();
assert!(err.contains("missing input dimensions"));
}
// Covers sizes taken from nested AffineTransform entries, from explicit
// "l2="/"l3=" fields, from the "N/2x2]" spelling, and unparsable input.
#[test]
fn test_parse_arch_dimensions() {
let arch = "Features=HalfKA_hm[73305->512x2],Network=AffineTransform[1<-96](ClippedReLU[96](AffineTransform[96<-8](ClippedReLU[8](AffineTransform[8<-1024](InputSlice[1024(0:1024)])))))";
assert_eq!(parse_arch_dimensions(arch), (512, 8, 96));
let arch = "Features=HalfKA_hm[73305->1024x2],Network=AffineTransform[1<-96](ClippedReLU[96](AffineTransform[96<-8](ClippedReLU[8](AffineTransform[8<-2048](InputSlice[2048(0:2048)])))))";
assert_eq!(parse_arch_dimensions(arch), (1024, 8, 96));
let arch = "Features=HalfKA_hm^[73305->512x2]-SCReLU,fv_scale=13,l2=8,l3=96,qa=127,qb=64";
assert_eq!(parse_arch_dimensions(arch), (512, 8, 96));
let arch = "Features=HalfKA_hm^[73305->1024x2]-SCReLU,fv_scale=16,l2=8,l3=96,qa=127,qb=64";
assert_eq!(parse_arch_dimensions(arch), (1024, 8, 96));
let arch = "Features=HalfKA_hm^[73305->256x2]-SCReLU,fv_scale=13,l2=32,l3=32,qa=127,qb=64";
assert_eq!(parse_arch_dimensions(arch), (256, 32, 32));
let arch = "Features=HalfKP[125388->256x2]";
assert_eq!(parse_arch_dimensions(arch), (256, 0, 0));
let arch = "Features=HalfKA_hm[73305->512/2x2]-Pairwise,fv_scale=10,l1_input=512,l2=8,l3=96,qa=255,qb=64,scale=1600,pairwise=true";
assert_eq!(parse_arch_dimensions(arch), (512, 8, 96));
let arch = "Features=HalfKA_hm[73305->256/2x2]-Pairwise,fv_scale=10,l1_input=256,l2=32,l3=32,qa=255,qb=64";
assert_eq!(parse_arch_dimensions(arch), (256, 32, 32));
assert_eq!(parse_arch_dimensions("unknown"), (0, 0, 0));
assert_eq!(parse_arch_dimensions(""), (0, 0, 0));
}
// Pins the HalfKP size formula to known totals.
#[test]
fn test_network_payload_halfkp() {
let payload = network_payload_halfkp(768, 16, 64);
assert_eq!(payload, 192_624_516);
let payload = network_payload_halfkp(256, 32, 32);
assert_eq!(payload, 64_216_868);
}
// Pins the HalfKA_hm size formula to known totals.
#[test]
fn test_network_payload_halfka_hm() {
let payload = network_payload_halfka_hm(256, 32, 32);
assert_eq!(payload, 37_550_372);
let payload = network_payload_halfka_hm(512, 8, 96);
assert_eq!(payload, 75_077_124);
}
// 196-byte header + HalfKP 768-16-64 payload + 8 extra bytes.
#[test]
fn test_detect_architecture_from_size_nn_bin() {
let result = detect_architecture_from_size(192_624_720, 184, Some(FeatureSet::HalfKP));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.spec.feature_set, FeatureSet::HalfKP);
assert_eq!(result.spec.l1, 768);
assert_eq!(result.spec.l2, 16);
assert_eq!(result.spec.l3, 64);
assert!(result.has_hash);
assert!(!result.is_bullet_format);
}
// 190-byte header + HalfKP 256-32-32 payload + 8 extra bytes.
#[test]
fn test_detect_architecture_from_size_suisho5() {
let result = detect_architecture_from_size(64_217_066, 178, Some(FeatureSet::HalfKP));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.spec.feature_set, FeatureSet::HalfKP);
assert_eq!(result.spec.l1, 256);
assert_eq!(result.spec.l2, 32);
assert_eq!(result.spec.l3, 32);
assert!(result.has_hash);
assert!(!result.is_bullet_format);
}
// Same file size as the nn_bin case, but without a feature set hint.
#[test]
fn test_detect_architecture_from_size_no_hint() {
let result = detect_architecture_from_size(192_624_720, 184, None);
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.spec.l1, 768);
assert_eq!(result.spec.l2, 16);
assert_eq!(result.spec.l3, 64);
assert!(!result.is_bullet_format);
}
// A size matching no table row yields no result.
#[test]
fn test_detect_architecture_from_size_unknown() {
let result = detect_architecture_from_size(12345, 100, None);
assert!(result.is_none());
}
// Header + payload with no extra bytes at all: detected, but has_hash is false.
#[test]
fn test_detect_architecture_hash_without() {
let result = detect_architecture_from_size(192_624_712, 184, Some(FeatureSet::HalfKP));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.spec.l1, 768);
assert_eq!(result.spec.l2, 16);
assert_eq!(result.spec.l3, 64);
assert!(!result.has_hash); assert!(!result.is_bullet_format);
}
// 117-byte header + HalfKA 512-8-96 payload + 8 + 63 bytes, the largest
// excess still accepted as bullet format.
#[test]
fn test_detect_bullet_shogi_halfka_512_8_96() {
let result = detect_architecture_from_size(141_847_232, 105, Some(FeatureSet::HalfKA));
assert!(result.is_some(), "bullet-shogi HalfKA 512-8-96 should be detected");
let result = result.unwrap();
assert_eq!(result.spec.feature_set, FeatureSet::HalfKA);
assert_eq!(result.spec.l1, 512);
assert_eq!(result.spec.l2, 8);
assert_eq!(result.spec.l3, 96);
assert!(result.has_hash);
assert!(result.is_bullet_format, "Should be detected as bullet-shogi format");
}
// Sweeps the padding after the 8 extra bytes: 0 is not bullet format,
// 1..=63 is bullet format, and 64 is rejected.
#[test]
fn test_detect_bullet_shogi_various_paddings() {
let payload = network_payload_halfka(512, 8, 96);
let arch_len = 100usize;
let header_size = 12 + arch_len as u64;
let file_size_no_padding = header_size + payload + 8;
let result =
detect_architecture_from_size(file_size_no_padding, arch_len, Some(FeatureSet::HalfKA));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.spec.l1, 512);
assert!(!result.is_bullet_format, "padding=0 should be nnue-pytorch format");
for padding in 1..64u64 {
let file_size_with_padding = header_size + payload + 8 + padding;
let result = detect_architecture_from_size(
file_size_with_padding,
arch_len,
Some(FeatureSet::HalfKA),
);
assert!(result.is_some(), "Should detect with padding={padding}");
let result = result.unwrap();
assert_eq!(result.spec.l1, 512, "L1 should be 512 with padding={padding}");
assert!(result.is_bullet_format, "padding={padding} should be bullet-shogi format");
}
let file_size_too_much_padding = header_size + payload + 8 + 64;
let result = detect_architecture_from_size(
file_size_too_much_padding,
arch_len,
Some(FeatureSet::HalfKA),
);
assert!(result.is_none(), "padding=64 should not be detected");
}
// Pins the HalfKA size formula to a known total.
#[test]
fn test_network_payload_halfka() {
let payload = network_payload_halfka(512, 8, 96);
assert_eq!(payload, 141_847_044);
}
}