use noru::config::{Activation, NnueConfig};
/// Side length of the square Gomoku board.
pub const BOARD_SIZE: usize = 15;
/// Number of board squares (`BOARD_SIZE * BOARD_SIZE` = 225).
pub const NUM_SQUARES: usize = BOARD_SIZE * BOARD_SIZE;

// Feature-block base offsets. They are spelled out as literals; the layout
// `const _` assertions that follow them in this file re-derive each offset from
// the block sizes, so an accidental edit fails to compile.

/// Start of the PS block (`2 * PS_PER_PERSP` features, one half per perspective).
pub const PS_BASE: usize = 0;
/// Start of the LP block (`2 * LP_PER_PERSP` features).
pub const LP_BASE: usize = 450;
/// Start of the compound block (`2 * COMPOUND_PER_PERSP` features).
pub const COMPOUND_BASE: usize = 2754;
/// Start of the density block (`DENSITY_NUM_CATEGORIES * DENSITY_NUM_BUCKETS`
/// features; this block is not split per perspective).
pub const DENSITY_BASE: usize = 2854;
/// Start of the cross-line block (`2 * CROSS_LINE_PER_PERSP` features).
pub const CROSS_LINE_BASE: usize = 2904;
/// Start of the broken-pattern block (`2 * BROKEN_PER_PERSP` features).
pub const BROKEN_BASE: usize = 3416;
/// First index past every defined block; no index helper here maps into
/// `RESERVED_BASE..TOTAL_FEATURE_SIZE`.
pub const RESERVED_BASE: usize = 3848;
/// Total input width, used as `GOMOKU_NNUE_CONFIG.feature_size`.
pub const TOTAL_FEATURE_SIZE: usize = 4096;

/// PS features per perspective: one per board square.
pub const PS_PER_PERSP: usize = NUM_SQUARES;
/// Equal to `PS_PER_PERSP` (225).
/// NOTE(review): this is the size of one PS half, not half of
/// `TOTAL_FEATURE_SIZE`; confirm this is what callers expect.
pub const HALF_FEATURE_SIZE: usize = PS_PER_PERSP;

/// LP length buckets. `length_bucket` only produces 0..=5; 6 and 7 are never
/// produced by it.
pub const LP_NUM_LENGTH: usize = 8;
/// LP openness buckets (`open_bucket` yields 0..=3).
pub const LP_NUM_OPEN: usize = 4;
/// Direction buckets (presumably the four line axes of the board).
pub const LP_NUM_DIR: usize = 4;
/// Zone buckets (`zone_for` yields a 3x3 grid of zones, 0..=8).
pub const LP_NUM_ZONE: usize = 9;
/// LP features per perspective (8 * 4 * 4 * 9 = 1152).
pub const LP_PER_PERSP: usize = LP_NUM_LENGTH * LP_NUM_OPEN * LP_NUM_DIR * LP_NUM_ZONE;

/// Compound ids per perspective accepted by `compound_index`.
pub const COMPOUND_PER_PERSP: usize = 50;

/// Number of density categories (the `DENSITY_CAT_*` ids below).
pub const DENSITY_NUM_CATEGORIES: usize = 5;
/// Buckets per density category (`count_bucket` yields 0..=9 and
/// `local_density_bucket` saturates at 9).
pub const DENSITY_NUM_BUCKETS: usize = 10;
// Category ids accepted by `density_index`.
pub const DENSITY_CAT_MY_COUNT: usize = 0;
pub const DENSITY_CAT_OPP_COUNT: usize = 1;
pub const DENSITY_CAT_MY_LOCAL: usize = 2;
pub const DENSITY_CAT_OPP_LOCAL: usize = 3;
pub const DENSITY_CAT_LEGAL: usize = 4;

/// Number of broken-pattern shapes (the `BROKEN_SHAPE_*` ids below).
pub const BROKEN_NUM_SHAPES: usize = 3;
/// Broken-pattern openness buckets (2, unlike the 4 used by LP features).
pub const BROKEN_NUM_OPEN: usize = 2;
/// Broken-pattern features per perspective (3 * 2 * 4 * 9 = 216).
pub const BROKEN_PER_PERSP: usize =
    BROKEN_NUM_SHAPES * BROKEN_NUM_OPEN * LP_NUM_DIR * LP_NUM_ZONE;
// Shape ids accepted by `broken_index`.
pub const BROKEN_SHAPE_THREE: usize = 0;
pub const BROKEN_SHAPE_JUMP_FOUR: usize = 1;
pub const BROKEN_SHAPE_DOUBLE_THREE: usize = 2;

/// Hash buckets for the cross-line feature; `cross_line_hash` returns `0..256`.
pub const CROSS_LINE_BUCKETS: usize = 256;
/// Cross-line features per perspective (one per hash bucket).
pub const CROSS_LINE_PER_PERSP: usize = CROSS_LINE_BUCKETS;

/// NOTE(review): not referenced in this module; presumably an upper bound on the
/// number of simultaneously active features — confirm against its users.
pub const MAX_ACTIVE_FEATURES: usize = 1536;

// Compile-time range checks. The index helpers only `debug_assert!` their
// arguments, which disappear in release builds, so an id or bucket that drifts
// out of range would otherwise alias a neighbouring feature silently.
const _: () = assert!(DENSITY_CAT_MY_COUNT < DENSITY_NUM_CATEGORIES);
const _: () = assert!(DENSITY_CAT_OPP_COUNT < DENSITY_NUM_CATEGORIES);
const _: () = assert!(DENSITY_CAT_MY_LOCAL < DENSITY_NUM_CATEGORIES);
const _: () = assert!(DENSITY_CAT_OPP_LOCAL < DENSITY_NUM_CATEGORIES);
const _: () = assert!(DENSITY_CAT_LEGAL < DENSITY_NUM_CATEGORIES);
const _: () = assert!(BROKEN_SHAPE_THREE < BROKEN_NUM_SHAPES);
const _: () = assert!(BROKEN_SHAPE_JUMP_FOUR < BROKEN_NUM_SHAPES);
const _: () = assert!(BROKEN_SHAPE_DOUBLE_THREE < BROKEN_NUM_SHAPES);
// `length_bucket` tops out at 5, `open_bucket` at 3, `count_bucket` at 9.
const _: () = assert!(LP_NUM_LENGTH > 5);
const _: () = assert!(LP_NUM_OPEN > 3);
const _: () = assert!(DENSITY_NUM_BUCKETS > 9);
// `zone_for` splits each clamped axis into 5-cell bands and packs them as
// `band_row * 3 + band_col`; bands must stay below 3 and the result must fit.
const _: () = assert!((BOARD_SIZE - 1) / 5 < 3);
const _: () = assert!((BOARD_SIZE - 1) / 5 * 3 + (BOARD_SIZE - 1) / 5 < LP_NUM_ZONE);
// `cross_line_hash` keeps the top 8 bits of a 64-bit hash.
const _: () = assert!(CROSS_LINE_BUCKETS == 1 << 8);
/// Network shape used for Gomoku evaluation.
///
/// Input width is `TOTAL_FEATURE_SIZE`, the accumulator size is 512 and
/// `hidden_sizes` lists a single hidden layer of 64.
/// NOTE(review): the exact meaning of each field and of `Activation::CReLU`
/// (presumably a clipped ReLU) is defined by `noru::config`, not here.
pub const GOMOKU_NNUE_CONFIG: NnueConfig = NnueConfig {
    activation: Activation::CReLU,
    hidden_sizes: std::borrow::Cow::Borrowed(&[64]),
    accumulator_size: 512,
    feature_size: TOTAL_FEATURE_SIZE,
};
// Compile-time layout check: every feature block starts exactly where the
// previous one ends. PS, LP, compound, cross-line and broken blocks hold one
// half per perspective (hence `2 *`); the density block is shared. Everything
// must fit inside `TOTAL_FEATURE_SIZE`.
const _: () = {
    assert!(LP_BASE == PS_BASE + 2 * PS_PER_PERSP);
    assert!(COMPOUND_BASE == LP_BASE + 2 * LP_PER_PERSP);
    assert!(DENSITY_BASE == COMPOUND_BASE + 2 * COMPOUND_PER_PERSP);
    assert!(CROSS_LINE_BASE == DENSITY_BASE + DENSITY_NUM_CATEGORIES * DENSITY_NUM_BUCKETS);
    assert!(BROKEN_BASE == CROSS_LINE_BASE + 2 * CROSS_LINE_PER_PERSP);
    assert!(RESERVED_BASE == BROKEN_BASE + 2 * BROKEN_PER_PERSP);
    assert!(RESERVED_BASE <= TOTAL_FEATURE_SIZE);
};
/// Index of the PS feature for `square` as seen from `perspective`.
///
/// `perspective` must be 0 or 1 and `square` below `NUM_SQUARES`; both are
/// checked only in debug builds.
#[inline]
pub fn ps_index(perspective: usize, square: usize) -> usize {
    debug_assert!(perspective < 2);
    debug_assert!(square < NUM_SQUARES);
    let half_start = PS_BASE + perspective * PS_PER_PERSP;
    half_start + square
}
/// Index of the LP feature for (`perspective`, `length`, `open`, `dir`, `zone`).
///
/// Inside one perspective's half, features are laid out row-major over
/// (length, open, dir, zone). Argument ranges are checked only in debug builds.
#[inline]
pub fn lp_rich_index(
    perspective: usize,
    length: usize,
    open: usize,
    dir: usize,
    zone: usize,
) -> usize {
    debug_assert!(perspective < 2);
    debug_assert!(length < LP_NUM_LENGTH);
    debug_assert!(open < LP_NUM_OPEN);
    debug_assert!(dir < LP_NUM_DIR);
    debug_assert!(zone < LP_NUM_ZONE);
    // Row-major offset written in Horner form.
    let within = ((length * LP_NUM_OPEN + open) * LP_NUM_DIR + dir) * LP_NUM_ZONE + zone;
    LP_BASE + perspective * LP_PER_PERSP + within
}
/// Index of compound feature `combo_id` for `perspective`.
///
/// `perspective` must be 0 or 1 and `combo_id` below `COMPOUND_PER_PERSP`
/// (debug-checked only).
#[inline]
pub fn compound_index(perspective: usize, combo_id: usize) -> usize {
    debug_assert!(perspective < 2);
    debug_assert!(combo_id < COMPOUND_PER_PERSP);
    let offset = perspective * COMPOUND_PER_PERSP + combo_id;
    COMPOUND_BASE + offset
}
/// Index of the density feature for `category` and `bucket`.
///
/// This block has no perspective split. `category` must be below
/// `DENSITY_NUM_CATEGORIES` and `bucket` below `DENSITY_NUM_BUCKETS`
/// (debug-checked only).
#[inline]
pub fn density_index(category: usize, bucket: usize) -> usize {
    debug_assert!(category < DENSITY_NUM_CATEGORIES);
    debug_assert!(bucket < DENSITY_NUM_BUCKETS);
    let slot = category * DENSITY_NUM_BUCKETS + bucket;
    DENSITY_BASE + slot
}
/// Index of cross-line hash `bucket` for `perspective`.
///
/// `perspective` must be 0 or 1 and `bucket` below `CROSS_LINE_BUCKETS`
/// (debug-checked only).
#[inline]
pub fn cross_line_index(perspective: usize, bucket: usize) -> usize {
    debug_assert!(perspective < 2);
    debug_assert!(bucket < CROSS_LINE_BUCKETS);
    let half_start = CROSS_LINE_BASE + perspective * CROSS_LINE_PER_PERSP;
    half_start + bucket
}
/// Index of the broken-pattern feature for (`perspective`, `shape`, `open`, `dir`, `zone`).
///
/// Inside one perspective's half, features are laid out row-major over
/// (shape, open, dir, zone). `open` ranges over `BROKEN_NUM_OPEN` values, not
/// the LP openness range. Argument ranges are checked only in debug builds.
#[inline]
pub fn broken_index(
    perspective: usize,
    shape: usize,
    open: usize,
    dir: usize,
    zone: usize,
) -> usize {
    debug_assert!(perspective < 2);
    debug_assert!(shape < BROKEN_NUM_SHAPES);
    debug_assert!(open < BROKEN_NUM_OPEN);
    debug_assert!(dir < LP_NUM_DIR);
    debug_assert!(zone < LP_NUM_ZONE);
    // Row-major offset written in Horner form.
    let within = ((shape * BROKEN_NUM_OPEN + open) * LP_NUM_DIR + dir) * LP_NUM_ZONE + zone;
    BROKEN_BASE + perspective * BROKEN_PER_PERSP + within
}
/// Hashes a 3x3 cell patch into one of 256 buckets (`0..256`).
///
/// The patch is first reduced to its canonical form under the eight rotations
/// and reflections of the square, so symmetric patches share a bucket. Only the
/// low two bits of each cell contribute, because the packing step masks them.
/// The canonical value is then Fibonacci-hashed: multiplied by 2^64/phi, keeping
/// the top 8 bits.
#[inline]
pub fn cross_line_hash(my_cells: [u8; 9]) -> usize {
    const GOLDEN_RATIO_MULTIPLIER: u64 = 0x9E37_79B9_7F4A_7C15;
    const KEPT_BITS: u32 = 8;
    let canonical = d4_canonical_3x3(my_cells);
    let mixed = canonical.wrapping_mul(GOLDEN_RATIO_MULTIPLIER);
    (mixed >> (u64::BITS - KEPT_BITS)) as usize
}
/// Packs nine cells into an 18-bit integer, two bits per cell, with `c[0]` in
/// the most significant position. Bits above the low two of each cell are ignored.
#[inline]
fn pack_3x3(c: &[u8; 9]) -> u64 {
    c.iter()
        .fold(0u64, |packed, &cell| (packed << 2) | u64::from(cell & 0b11))
}
/// Rotates a row-major 3x3 patch a quarter turn clockwise: the top-left cell
/// moves to the top-right, and the left column becomes the top row.
#[inline]
fn rotate_3x3(c: [u8; 9]) -> [u8; 9] {
    let [nw, n, ne, w, mid, e, sw, s, se] = c;
    [sw, w, nw, s, mid, n, se, e, ne]
}
/// Mirrors a row-major 3x3 patch left-to-right by reversing each row of three.
#[inline]
fn mirror_3x3(c: [u8; 9]) -> [u8; 9] {
    let [nw, n, ne, w, mid, e, sw, s, se] = c;
    [ne, n, nw, e, mid, w, se, s, sw]
}
/// Canonical packed form of a 3x3 patch under the eight symmetries of the square.
///
/// Packs the patch and its three quarter-turn rotations, plus the left-right
/// mirror and that mirror's three rotations, with `pack_3x3`, and returns the
/// smallest packed value. Every symmetric variant of a patch therefore maps to
/// the same number.
#[inline]
fn d4_canonical_3x3(cells: [u8; 9]) -> u64 {
    // Packed values fit in 18 bits, so any real candidate beats this start value.
    let mut lowest = u64::MAX;
    for seed in [cells, mirror_3x3(cells)] {
        let mut grid = seed;
        for quarter_turn in 0..4 {
            if quarter_turn > 0 {
                grid = rotate_3x3(grid);
            }
            lowest = lowest.min(pack_3x3(&grid));
        }
    }
    lowest
}
/// Maps a run length to a length bucket.
///
/// Counts 1..=5 map to buckets 0..=4. Every other count, including 0 and
/// anything of 6 or more, maps to bucket 5.
#[inline]
pub fn length_bucket(count: u32) -> usize {
    if (1..=5).contains(&count) {
        (count - 1) as usize
    } else {
        5
    }
}
/// Encodes two openness flags as 0..=3: bit 0 is `open_front`, bit 1 is `open_back`.
#[inline]
pub fn open_bucket(open_front: bool, open_back: bool) -> usize {
    match (open_front, open_back) {
        (false, false) => 0,
        (true, false) => 1,
        (false, true) => 2,
        (true, true) => 3,
    }
}
#[inline]
pub fn zone_for(row: i32, col: i32) -> usize {
let r = (row.clamp(0, BOARD_SIZE as i32 - 1) / 5) as usize;
let c = (col.clamp(0, BOARD_SIZE as i32 - 1) / 5) as usize;
r * 3 + c
}
/// Coarse bucket (0..=9) for a count, with ranges that widen as the count grows.
///
/// The buckets are 0 | 1-3 | 4-7 | 8-15 | 16-30 | 31-60 | 61-100 | 101-150 |
/// 151-200 | >200.
#[inline]
pub fn count_bucket(n: u32) -> usize {
    // Inclusive upper bound of buckets 0..=8; anything above the last is bucket 9.
    const UPPER_BOUNDS: [u32; 9] = [0, 3, 7, 15, 30, 60, 100, 150, 200];
    UPPER_BOUNDS
        .iter()
        .position(|&hi| n <= hi)
        .unwrap_or(UPPER_BOUNDS.len())
}
/// Uses a small count directly as its bucket, saturating at the last density
/// bucket (`DENSITY_NUM_BUCKETS - 1`).
#[inline]
pub fn local_density_bucket(n: u32) -> usize {
    std::cmp::min(n as usize, DENSITY_NUM_BUCKETS - 1)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Block sizes and offsets agree with the hard-coded layout.
    #[test]
    fn layout_sizes() {
        assert_eq!(PS_PER_PERSP * 2, LP_BASE);
        assert_eq!(LP_PER_PERSP, 1152);
        assert_eq!(LP_PER_PERSP * 2 + LP_BASE, COMPOUND_BASE);
        assert_eq!(
            DENSITY_BASE + DENSITY_NUM_CATEGORIES * DENSITY_NUM_BUCKETS,
            CROSS_LINE_BASE
        );
        assert_eq!(CROSS_LINE_BASE, 2904);
        assert_eq!(CROSS_LINE_PER_PERSP * 2 + CROSS_LINE_BASE, BROKEN_BASE);
        assert_eq!(BROKEN_BASE, 3416);
        assert_eq!(BROKEN_PER_PERSP, 216);
        assert_eq!(BROKEN_PER_PERSP * 2 + BROKEN_BASE, RESERVED_BASE);
        assert_eq!(RESERVED_BASE, 3848);
        assert!(RESERVED_BASE < TOTAL_FEATURE_SIZE);
    }

    #[test]
    fn ps_indexing() {
        assert_eq!(ps_index(0, 0), 0);
        assert_eq!(ps_index(0, 224), 224);
        assert_eq!(ps_index(1, 0), 225);
        assert_eq!(ps_index(1, 224), 449);
    }

    /// Every PS index is distinct and stays below the LP block.
    #[test]
    fn ps_in_range_and_unique() {
        let mut seen = HashSet::new();
        for p in 0..2 {
            for sq in 0..PS_PER_PERSP {
                let idx = ps_index(p, sq);
                assert!(idx < LP_BASE, "ps index {idx} spills into the LP block");
                assert!(seen.insert(idx), "dup at {p},{sq}");
            }
        }
        assert_eq!(seen.len(), PS_PER_PERSP * 2);
    }

    #[test]
    fn lp_rich_in_range_and_unique() {
        let mut seen = HashSet::new();
        for p in 0..2 {
            for l in 0..LP_NUM_LENGTH {
                for o in 0..LP_NUM_OPEN {
                    for d in 0..LP_NUM_DIR {
                        for z in 0..LP_NUM_ZONE {
                            let idx = lp_rich_index(p, l, o, d, z);
                            assert!(idx >= LP_BASE && idx < COMPOUND_BASE);
                            assert!(seen.insert(idx), "dup at {p},{l},{o},{d},{z}");
                        }
                    }
                }
            }
        }
        assert_eq!(seen.len(), LP_PER_PERSP * 2);
    }

    /// Density indices are distinct and confined to the density block itself.
    /// The upper bound is `CROSS_LINE_BASE`, not `RESERVED_BASE`, so overlap
    /// into the cross-line or broken blocks is caught.
    #[test]
    fn density_index_in_range() {
        let mut seen = HashSet::new();
        for c in 0..DENSITY_NUM_CATEGORIES {
            for b in 0..DENSITY_NUM_BUCKETS {
                let idx = density_index(c, b);
                assert!(idx >= DENSITY_BASE && idx < CROSS_LINE_BASE);
                assert!(seen.insert(idx), "dup at {c},{b}");
            }
        }
        assert_eq!(seen.len(), DENSITY_NUM_CATEGORIES * DENSITY_NUM_BUCKETS);
    }

    #[test]
    fn zone_grid() {
        assert_eq!(zone_for(0, 0), 0);
        assert_eq!(zone_for(7, 7), 4);
        assert_eq!(zone_for(14, 14), 8);
        assert_eq!(zone_for(0, 14), 2);
        assert_eq!(zone_for(14, 0), 6);
        // Band edges.
        assert_eq!(zone_for(4, 5), 1);
        assert_eq!(zone_for(5, 4), 3);
        assert_eq!(zone_for(9, 10), 5);
        // Off-board coordinates clamp to the nearest edge zone.
        assert_eq!(zone_for(-3, -3), 0);
        assert_eq!(zone_for(100, 100), 8);
        assert_eq!(zone_for(-1, 100), 2);
    }

    #[test]
    fn open_bucket_combinations() {
        assert_eq!(open_bucket(false, false), 0);
        assert_eq!(open_bucket(true, false), 1);
        assert_eq!(open_bucket(false, true), 2);
        assert_eq!(open_bucket(true, true), 3);
    }

    #[test]
    fn length_bucket_mapping() {
        assert_eq!(length_bucket(1), 0);
        assert_eq!(length_bucket(2), 1);
        assert_eq!(length_bucket(3), 2);
        assert_eq!(length_bucket(4), 3);
        assert_eq!(length_bucket(5), 4);
        assert_eq!(length_bucket(6), 5);
        assert_eq!(length_bucket(99), 5);
    }
}