use crate::board::{Board, Stone, BOARD_SIZE, NUM_CELLS};
use crate::features::{
broken_index, compound_index, count_bucket, cross_line_hash, cross_line_index, density_index,
length_bucket, local_density_bucket, lp_rich_index, open_bucket, ps_index, zone_for,
BROKEN_SHAPE_DOUBLE_THREE, BROKEN_SHAPE_JUMP_FOUR, BROKEN_SHAPE_THREE, DENSITY_CAT_LEGAL,
DENSITY_CAT_MY_COUNT, DENSITY_CAT_MY_LOCAL, DENSITY_CAT_OPP_COUNT, DENSITY_CAT_OPP_LOCAL,
MAX_ACTIVE_FEATURES,
};
use crate::heuristic::{scan_line, DIR};
use noru::network::{forward, Accumulator, NnueWeights};
use std::sync::OnceLock;
/// Process-wide cache for [`compound_enabled`]; filled on the first call.
static COMPOUND_ENABLED: OnceLock<bool> = OnceLock::new();

/// Reports whether compound-threat features should be emitted.
///
/// Compound features are on unless `NORU_NO_COMPOUND` can be read with
/// `std::env::var`, i.e. it is set to a valid-Unicode value. An unset or
/// non-Unicode variable leaves them enabled. The environment is consulted only
/// once per process; later changes to the variable have no effect.
fn compound_enabled() -> bool {
    let enabled = COMPOUND_ENABLED.get_or_init(|| {
        let opted_out = std::env::var("NORU_NO_COMPOUND").is_ok();
        !opted_out
    });
    *enabled
}
pub fn compute_active_features(board: &Board) -> (Vec<usize>, Vec<usize>) {
let (my_bb, opp_bb) = match board.side_to_move {
Stone::Black => (&board.black, &board.white),
Stone::White => (&board.white, &board.black),
};
let mut stm = Vec::with_capacity(MAX_ACTIVE_FEATURES);
let mut nstm = Vec::with_capacity(MAX_ACTIVE_FEATURES);
for sq in 0..NUM_CELLS {
if my_bb.get(sq) {
stm.push(ps_index(0, sq));
nstm.push(ps_index(1, sq));
} else if opp_bb.get(sq) {
stm.push(ps_index(1, sq));
nstm.push(ps_index(0, sq));
}
}
for idx in 0..NUM_CELLS {
let row = (idx / BOARD_SIZE) as i32;
let col = (idx % BOARD_SIZE) as i32;
for (dir_idx, &(dr, dc)) in DIR.iter().enumerate() {
if my_bb.get(idx) && is_line_start(my_bb, row, col, dr, dc) {
let info = scan_line(my_bb, opp_bb, row, col, dr, dc);
let z = zone_for(row, col);
let len = length_bucket(info.count);
let op = open_bucket(info.open_front, info.open_back);
stm.push(lp_rich_index(0, len, op, dir_idx, z));
nstm.push(lp_rich_index(1, len, op, dir_idx, z));
}
if opp_bb.get(idx) && is_line_start(opp_bb, row, col, dr, dc) {
let info = scan_line(opp_bb, my_bb, row, col, dr, dc);
let z = zone_for(row, col);
let len = length_bucket(info.count);
let op = open_bucket(info.open_front, info.open_back);
stm.push(lp_rich_index(1, len, op, dir_idx, z));
nstm.push(lp_rich_index(0, len, op, dir_idx, z));
}
}
}
if compound_enabled() {
compute_compound_threats(my_bb, opp_bb, &mut stm, &mut nstm);
}
let my_count = my_bb.count_ones();
let opp_count = opp_bb.count_ones();
push_density(&mut stm, &mut nstm, DENSITY_CAT_MY_COUNT, count_bucket(my_count));
push_density(&mut stm, &mut nstm, DENSITY_CAT_OPP_COUNT, count_bucket(opp_count));
let (my_local, opp_local) = local_density(board);
push_density(
&mut stm,
&mut nstm,
DENSITY_CAT_MY_LOCAL,
local_density_bucket(my_local),
);
push_density(
&mut stm,
&mut nstm,
DENSITY_CAT_OPP_LOCAL,
local_density_bucket(opp_local),
);
let legal = (NUM_CELLS as u32).saturating_sub(my_count + opp_count);
push_density(&mut stm, &mut nstm, DENSITY_CAT_LEGAL, count_bucket(legal));
for sq in 0..NUM_CELLS {
if !my_bb.get(sq) && !opp_bb.get(sq) {
continue;
}
let row = (sq / BOARD_SIZE) as i32;
let col = (sq % BOARD_SIZE) as i32;
let stm_cells = collect_3x3(my_bb, opp_bb, row, col);
let stm_bucket = cross_line_hash(stm_cells);
stm.push(cross_line_index(0, stm_bucket));
nstm.push(cross_line_index(1, stm_bucket));
let nstm_cells = swap_mine_opp(stm_cells);
let nstm_bucket = cross_line_hash(nstm_cells);
stm.push(cross_line_index(1, nstm_bucket));
nstm.push(cross_line_index(0, nstm_bucket));
}
for idx in 0..NUM_CELLS {
let row = (idx / BOARD_SIZE) as i32;
let col = (idx % BOARD_SIZE) as i32;
if my_bb.get(idx) {
for (dir_idx, &(dr, dc)) in DIR.iter().enumerate() {
detect_broken_and_push(
my_bb, opp_bb, row, col, dr, dc, dir_idx, 0, 1, &mut stm, &mut nstm,
);
}
}
if opp_bb.get(idx) {
for (dir_idx, &(dr, dc)) in DIR.iter().enumerate() {
detect_broken_and_push(
opp_bb, my_bb, row, col, dr, dc, dir_idx, 1, 0, &mut stm, &mut nstm,
);
}
}
}
(stm, nstm)
}
/// Emits a broken-shape feature for the run starting at (`row`, `col`) in
/// direction (`dr`, `dc`), if those stones form a recognised gapped shape.
///
/// `stones` is the side whose shape is examined and `opp` the other side. Work is
/// done only when (`row`, `col`) is the first of `stones` along the direction,
/// meaning the cell behind it is off-board or not one of `stones`. This way each
/// run is classified once.
///
/// An 11-cell window centred on the stone is encoded as `0` empty, `1` `stones`,
/// `2` `opp` or off-board, and handed to [`classify_broken_shape`]. On a match, a
/// `broken_index` built from `perspective_mine` is pushed to `stm`, and one built
/// from `perspective_opp` is pushed to `nstm`. Both also carry shape, openness,
/// `dir_idx` and the zone of the start cell.
// The flat parameter list mirrors the feature-index arguments; bundling them into
// a struct would add a type used by a single caller.
#[allow(clippy::too_many_arguments)]
fn detect_broken_and_push(
    stones: &crate::board::BitBoard,
    opp: &crate::board::BitBoard,
    row: i32,
    col: i32,
    dr: i32,
    dc: i32,
    dir_idx: usize,
    perspective_mine: usize,
    perspective_opp: usize,
    stm: &mut Vec<usize>,
    nstm: &mut Vec<usize>,
) {
    let n = BOARD_SIZE as i32;
    let on_board = |r: i32, c: i32| (0..n).contains(&r) && (0..n).contains(&c);

    // Not the first stone of its run: the run is handled from its real start.
    let (back_r, back_c) = (row - dr, col - dc);
    if on_board(back_r, back_c) && stones.get(back_r as usize * BOARD_SIZE + back_c as usize) {
        return;
    }

    // Slot 5 is (row, col); slots 0..5 lie behind it, 6..11 ahead of it.
    let line: [u8; 11] = std::array::from_fn(|slot| {
        let off = slot as i32 - 5;
        let (r, c) = (row + dr * off, col + dc * off);
        if !on_board(r, c) {
            return 2;
        }
        let cell = r as usize * BOARD_SIZE + c as usize;
        if stones.get(cell) {
            1
        } else if opp.get(cell) {
            2
        } else {
            0
        }
    });

    let zone = zone_for(row, col);
    let Some((shape, is_open)) = classify_broken_shape(&line) else {
        return;
    };
    let open_bucket = if is_open { 1 } else { 0 };
    stm.push(broken_index(perspective_mine, shape, open_bucket, dir_idx, zone));
    nstm.push(broken_index(perspective_opp, shape, open_bucket, dir_idx, zone));
}
/// Classifies the gapped ("broken") shape that starts at the centre of `line`.
///
/// `line` is an 11-cell window along one direction. Index 5 is the starting stone,
/// 0..5 lie behind it and 6..11 ahead of it. Cell codes are `0` empty, `1` own
/// stone, `2` opponent stone or off-board.
///
/// The scan runs ahead of the centre, counting own stones and the empty cells
/// that sit *between* them. It stops at the first blocker (`2`), at the first
/// pair of consecutive empties, or at the end of the window. An empty cell that
/// ends the scan is an end of the shape, not an internal gap, so it is never
/// counted. This includes the single empty before a blocker in `X X _ X _ O`.
///
/// Returns `Some((shape, is_open))` for:
/// - 3 stones, 1 internal gap  → `BROKEN_SHAPE_THREE`
/// - 4 stones, 1 internal gap  → `BROKEN_SHAPE_JUMP_FOUR`
/// - 3 stones, 2 internal gaps → `BROKEN_SHAPE_DOUBLE_THREE`
///
/// `is_open` requires two things. The cell directly behind the centre must be
/// empty. The scan ahead must end on empty space not cut off by a blocker (two
/// empties in a row, or an empty at the window edge). Contiguous runs and all
/// other stone/gap counts return `None`.
fn classify_broken_shape(line: &[u8; 11]) -> Option<(usize, bool)> {
    debug_assert!(line[5] == 1);
    let open_left = line[4] == 0;

    let mut mine_right = 0u32;
    let mut empties = 0u32;
    let mut prev_was_empty = false;
    let mut blocked = false;
    for &cell in &line[6..] {
        match cell {
            2 => {
                blocked = true;
                break;
            }
            // Second empty in a row: the shape has ended in open space.
            0 if prev_was_empty => break,
            0 => {
                empties += 1;
                prev_was_empty = true;
            }
            _ => {
                mine_right += 1;
                prev_was_empty = false;
            }
        }
    }

    // When the scan stopped right after an empty cell, that cell terminates the
    // shape (open space, window edge, or the cell before a blocker) rather than
    // separating stones. `prev_was_empty` implies `empties >= 1`, so no underflow.
    let gaps = empties - u32::from(prev_was_empty);
    let right_open = prev_was_empty && !blocked;
    if gaps == 0 {
        return None;
    }

    let total_mine = 1 + mine_right;
    let is_open = open_left && right_open;
    match (total_mine, gaps) {
        (3, 1) => Some((BROKEN_SHAPE_THREE, is_open)),
        (4, 1) => Some((BROKEN_SHAPE_JUMP_FOUR, is_open)),
        (3, 2) => Some((BROKEN_SHAPE_DOUBLE_THREE, is_open)),
        _ => None,
    }
}
/// Encodes the 3x3 neighbourhood centred on (`row`, `col`) in row-major order.
///
/// Cell codes: `0` empty, `1` stone in `my_bb`, `2` stone in `opp_bb`,
/// `3` off-board. Index 4 is the centre cell itself.
#[inline]
fn collect_3x3(
    my_bb: &crate::board::BitBoard,
    opp_bb: &crate::board::BitBoard,
    row: i32,
    col: i32,
) -> [u8; 9] {
    let side = BOARD_SIZE as i32;
    std::array::from_fn(|k| {
        // k = (dr + 1) * 3 + (dc + 1) with dr, dc in -1..=1.
        let nr = row + k as i32 / 3 - 1;
        let nc = col + k as i32 % 3 - 1;
        if !(0..side).contains(&nr) || !(0..side).contains(&nc) {
            return 3;
        }
        let cell = nr as usize * BOARD_SIZE + nc as usize;
        if my_bb.get(cell) {
            1
        } else if opp_bb.get(cell) {
            2
        } else {
            0
        }
    })
}
/// Returns the patch as seen by the other side: own (`1`) and opponent (`2`)
/// codes are exchanged, while every other value (empty `0`, off-board `3`, or
/// anything else) passes through unchanged.
#[inline]
fn swap_mine_opp(c: [u8; 9]) -> [u8; 9] {
    c.map(|cell| match cell {
        2 => 1,
        1 => 2,
        unchanged => unchanged,
    })
}
/// Returns `true` when (`row`, `col`) begins a run of `bb` stones in direction
/// (`dr`, `dc`). That holds when the preceding cell along the direction is
/// off-board or does not hold a `bb` stone.
#[inline]
fn is_line_start(bb: &crate::board::BitBoard, row: i32, col: i32, dr: i32, dc: i32) -> bool {
    let (prev_r, prev_c) = (row - dr, col - dc);
    let side = BOARD_SIZE as i32;
    let prev_on_board = (0..side).contains(&prev_r) && (0..side).contains(&prev_c);
    !prev_on_board || !bb.get(prev_r as usize * BOARD_SIZE + prev_c as usize)
}
/// Appends the density feature for (`cat`, `bucket`) to both perspective lists.
/// Density features are colour-agnostic, so both lists receive the same index.
#[inline]
fn push_density(stm: &mut Vec<usize>, nstm: &mut Vec<usize>, cat: usize, bucket: usize) {
    let feature = density_index(cat, bucket);
    for list in [stm, nstm] {
        list.push(feature);
    }
}
/// Single-direction threat grade of the run through a stone, as assigned by
/// `classify_threat` from the run length and its number of open ends.
///
/// Variants are declared weakest to strongest. The derived `Ord` follows
/// declaration order, and `compound_combo_id` relies on that when it sorts
/// threats strongest-first, so do not reorder them.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Threat {
    /// Anything `classify_threat` does not grade as a threat.
    None,
    /// Two stones, two open ends.
    OpenTwo,
    /// Three stones, one open end.
    ClosedThree,
    /// Three stones, two open ends.
    OpenThree,
    /// Four stones, one open end.
    ClosedFour,
    /// Four stones, two open ends.
    OpenFour,
    /// Five or more stones, regardless of open ends.
    Five,
}
/// Grades a run of `count` consecutive stones with `open_ends` empty ends.
///
/// Five or more stones are always [`Threat::Five`]. Fours and threes are split
/// into open (two free ends) and closed (one free end). Twos count only when
/// both ends are open. Everything else is [`Threat::None`].
fn classify_threat(count: u32, open_ends: u32) -> Threat {
    if count >= 5 {
        return Threat::Five;
    }
    match (open_ends, count) {
        (2, 4) => Threat::OpenFour,
        (1, 4) => Threat::ClosedFour,
        (2, 3) => Threat::OpenThree,
        (1, 3) => Threat::ClosedThree,
        (2, 2) => Threat::OpenTwo,
        _ => Threat::None,
    }
}
/// Maps the four per-direction threats on one stone to a compound-feature id.
///
/// Threats are ordered strongest-first. Let `a` be the [`threat_rank`] of the
/// strongest and `b` that of the second strongest, so `a <= b`.
/// - Fewer than two real threats → `None` (single threats are not compound).
/// - Exactly two → a pair id in `6..=26`. Unordered rank pairs `(a, b)` are
///   enumerated row by row, starting at 6.
/// - Three or more → a triple id `27 + a`, in `27..=32`.
fn compound_combo_id(threats: &[Threat; 4]) -> Option<usize> {
    // First pair id for each strongest-threat rank `a`; row `a` holds `6 - a` ids.
    const PAIR_ROW_START: [usize; 6] = [6, 12, 17, 21, 24, 26];

    let mut ranked = *threats;
    ranked.sort_unstable_by(|x, y| y.cmp(x)); // strongest first
    let [first, second, third, _] = ranked;
    if first == Threat::None || second == Threat::None {
        return None;
    }

    let top = threat_rank(first);
    if third != Threat::None {
        return Some(27 + top);
    }
    let row_start = *PAIR_ROW_START.get(top)?;
    Some(row_start + (threat_rank(second) - top))
}
/// Strength rank of a threat, from `0` for [`Threat::Five`] (strongest) to `6`
/// for [`Threat::None`]. This is the reverse of the enum's declaration order.
fn threat_rank(t: Threat) -> usize {
    match t {
        Threat::None => 6,
        Threat::OpenTwo => 5,
        Threat::ClosedThree => 4,
        Threat::OpenThree => 3,
        Threat::ClosedFour => 2,
        Threat::OpenFour => 1,
        Threat::Five => 0,
    }
}
/// Appends one compound-threat feature per stone that carries at least two
/// simultaneous threats across the directions in `DIR`.
///
/// For every stone of either colour, each direction is scanned with `scan_line`
/// and graded with [`classify_threat`]. The per-direction grades are then folded
/// into a combo id by [`compound_combo_id`]. An own stone is pushed as
/// perspective 0 to `stm` and 1 to `nstm`; an opponent stone the other way round.
/// Within a cell, the own-stone feature is pushed before the opponent one.
fn compute_compound_threats(
    my_bb: &crate::board::BitBoard,
    opp_bb: &crate::board::BitBoard,
    stm: &mut Vec<usize>,
    nstm: &mut Vec<usize>,
) {
    let sides = [(my_bb, opp_bb, 0, 1), (opp_bb, my_bb, 1, 0)];
    for idx in 0..NUM_CELLS {
        let row = (idx / BOARD_SIZE) as i32;
        let col = (idx % BOARD_SIZE) as i32;
        for &(stones, blockers, stm_persp, nstm_persp) in &sides {
            if !stones.get(idx) {
                continue;
            }
            let mut threats = [Threat::None; 4];
            for (dir_idx, &(dr, dc)) in DIR.iter().enumerate() {
                let info = scan_line(stones, blockers, row, col, dr, dc);
                let open_ends = info.open_front as u32 + info.open_back as u32;
                threats[dir_idx] = classify_threat(info.count, open_ends);
            }
            let Some(combo) = compound_combo_id(&threats) else {
                continue;
            };
            stm.push(compound_index(stm_persp, combo));
            nstm.push(compound_index(nstm_persp, combo));
        }
    }
}
/// Counts stones of each side among the (up to eight) on-board neighbours of the
/// last move.
///
/// Returns `(mine, theirs)` relative to `board.side_to_move`, or `(0, 0)` when
/// there is no last move. The last-move cell itself is not counted.
fn local_density(board: &Board) -> (u32, u32) {
    let Some(last) = board.last_move else {
        return (0, 0);
    };
    let (mine, theirs) = match board.side_to_move {
        Stone::Black => (&board.black, &board.white),
        Stone::White => (&board.white, &board.black),
    };
    let side = BOARD_SIZE as i32;
    let (r0, c0) = ((last / BOARD_SIZE) as i32, (last % BOARD_SIZE) as i32);

    let mut counts = (0u32, 0u32);
    for (dr, dc) in (-1..=1).flat_map(|dr| (-1..=1).map(move |dc| (dr, dc))) {
        if (dr, dc) == (0, 0) {
            continue;
        }
        let (r, c) = (r0 + dr, c0 + dc);
        if !(0..side).contains(&r) || !(0..side).contains(&c) {
            continue;
        }
        let cell = r as usize * BOARD_SIZE + c as usize;
        counts.0 += u32::from(mine.get(cell));
        counts.1 += u32::from(theirs.get(cell));
    }
    counts
}
/// Evaluates `board` from scratch. It builds a fresh accumulator from the
/// position's active features and runs the network's forward pass on it.
pub fn evaluate(board: &Board, weights: &NnueWeights) -> i32 {
    let mut accumulator = Accumulator::new(&weights.feature_bias);
    let (stm, nstm) = compute_active_features(board);
    accumulator.refresh(weights, &stm, &nstm);
    forward(&accumulator, weights)
}
pub struct IncrementalEval {
pub accumulator: Accumulator,
stack: Vec<Accumulator>,
}
impl IncrementalEval {
pub fn new(weights: &NnueWeights) -> Self {
Self {
accumulator: Accumulator::new(&weights.feature_bias),
stack: Vec::with_capacity(225),
}
}
pub fn refresh(&mut self, board: &Board, weights: &NnueWeights) {
let (stm_feats, nstm_feats) = compute_active_features(board);
self.accumulator.refresh(weights, &stm_feats, &nstm_feats);
}
pub fn push_move(&mut self, board: &Board, _mv: usize, weights: &NnueWeights) {
self.stack.push(self.accumulator.clone());
self.refresh(board, weights);
}
pub fn pop_move(&mut self) {
if let Some(prev) = self.stack.pop() {
self.accumulator = prev;
}
}
pub fn eval(&self, weights: &NnueWeights) -> i32 {
forward(&self.accumulator, weights)
}
}
// Unit tests for feature extraction and evaluation. Move indices are written as
// `row * 15 + col`, which assumes a 15x15 board (BOARD_SIZE == 15).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::board::Board;
    use crate::features::{
        broken_index, compound_index, BROKEN_SHAPE_DOUBLE_THREE, BROKEN_SHAPE_JUMP_FOUR,
        BROKEN_SHAPE_THREE, GOMOKU_NNUE_CONFIG, HALF_FEATURE_SIZE, LP_BASE, MAX_ACTIVE_FEATURES,
        PS_BASE, TOTAL_FEATURE_SIZE,
    };

    // With no stones, every stone-driven family is empty; only the five
    // push_density calls in compute_active_features contribute.
    #[test]
    fn empty_board_has_only_density_features() {
        let board = Board::new();
        let (stm, nstm) = compute_active_features(&board);
        assert_eq!(stm.len(), 5);
        assert_eq!(nstm.len(), 5);
    }

    // All-zero weights must evaluate to exactly 0.
    #[test]
    fn evaluate_zero_weights() {
        let board = Board::new();
        let weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
        assert_eq!(evaluate(&board, &weights), 0);
    }

    // Two adjacent stones of one colour on row 7 must yield a line-pattern feature.
    // `2 * 1152` is presumably the size of the LP-Rich range for two
    // perspectives — TODO confirm against features.rs.
    #[test]
    fn features_include_lp_after_two_in_row() {
        let mut board = Board::new();
        board.make_move(7 * 15 + 7); board.make_move(0); board.make_move(7 * 15 + 8); let (stm, _) = compute_active_features(&board);
        let has_lp = stm
            .iter()
            .any(|&f| f >= LP_BASE && f < LP_BASE + 2 * 1152);
        assert!(has_lp, "should have LP-Rich features after 2-in-row");
    }

    // Every emitted index must fall inside the feature space.
    #[test]
    fn all_features_within_range() {
        let mut board = Board::new();
        for sq in [112, 0, 113, 1, 114, 15, 100, 50] {
            board.make_move(sq);
        }
        let (stm, nstm) = compute_active_features(&board);
        for &f in stm.iter().chain(nstm.iter()) {
            assert!(f < TOTAL_FEATURE_SIZE, "feature {f} >= {TOTAL_FEATURE_SIZE}");
        }
    }

    // Fill every empty cell in index order (the worst case for stone-driven
    // features) and check both lists stay within MAX_ACTIVE_FEATURES.
    #[test]
    fn active_features_under_cap() {
        let mut board = Board::new();
        for sq in 0..NUM_CELLS {
            if board.is_empty(sq) {
                board.make_move(sq);
            }
        }
        let (stm, nstm) = compute_active_features(&board);
        assert!(stm.len() <= MAX_ACTIVE_FEATURES, "stm len={}", stm.len());
        assert!(nstm.len() <= MAX_ACTIVE_FEATURES, "nstm len={}", nstm.len());
    }

    // push_move followed by pop_move must restore the exact pre-move evaluation.
    // Non-zero weights on the first 20 features of each half make the check
    // sensitive to accumulator changes.
    #[test]
    fn push_pop_consistency() {
        let mut weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
        let acc_size = GOMOKU_NNUE_CONFIG.accumulator_size;
        for sq in 0..20 {
            for i in 0..acc_size {
                weights.feature_weights[sq][i] = ((sq * 7 + i) % 13) as i16 - 6;
                weights.feature_weights[sq + HALF_FEATURE_SIZE][i] =
                    ((sq * 3 + i) % 11) as i16 - 5;
            }
        }
        let mut board = Board::new();
        let mut inc = IncrementalEval::new(&weights);
        inc.refresh(&board, &weights);
        let before = inc.eval(&weights);
        board.make_move(112);
        inc.push_move(&board, 112, &weights);
        board.undo_move();
        inc.pop_move();
        assert_eq!(before, inc.eval(&weights));
        // Keeps the PS_BASE import used.
        let _ = PS_BASE;
    }

    // The first mover builds a horizontal three (7,5)-(7,7) and a vertical three
    // (6,7)-(8,7) crossing at (7,7); the other side plays far away on row 0.
    // (7,7) is not the start of the horizontal run, so this checks that compound
    // detection scans every stone, not only run starts. The test expects
    // perspective 1, i.e. the first mover is the opponent of the side to move
    // after 9 moves. Combo id 21 is OpenThree+OpenThree in compound_combo_id.
    #[test]
    fn compound_catches_double_three_at_non_line_start_stone() {
        let mut board = Board::new();
        board.make_move(7 * 15 + 5); board.make_move(0); board.make_move(7 * 15 + 6); board.make_move(1); board.make_move(7 * 15 + 7); board.make_move(2); board.make_move(6 * 15 + 7); board.make_move(3); board.make_move(8 * 15 + 7);
        let (stm, _) = compute_active_features(&board);
        let expected = compound_index(1, 21);
        assert!(
            stm.contains(&expected),
            "stm should contain opponent's O3+O3 compound at the non-line-start \
crossing stone (7,7); expected feature index {expected} missing.\n\
stm={stm:?}"
        );
    }

    // Stones at (7,5), (7,6), (7,8): `X X _ X` with open ends → open broken three
    // reported from the run start (7,5). Direction index 0 is assumed to be the
    // horizontal entry of DIR — TODO confirm.
    #[test]
    fn broken_three_detected_open() {
        let mut board = Board::new();
        board.make_move(7 * 15 + 5); board.make_move(0); board.make_move(7 * 15 + 6); board.make_move(1); board.make_move(7 * 15 + 8);
        let (stm, _) = compute_active_features(&board);
        let zone = zone_for(7, 5);
        let expected = broken_index(1, BROKEN_SHAPE_THREE, 1, 0, zone);
        assert!(
            stm.contains(&expected),
            "expected broken three (open) feature {expected} missing; stm={stm:?}"
        );
    }

    // Stones at (7,5), (7,6), (7,7), (7,9): `X X X _ X` → open jump four.
    #[test]
    fn jump_four_detected() {
        let mut board = Board::new();
        board.make_move(7 * 15 + 5); board.make_move(0);
        board.make_move(7 * 15 + 6); board.make_move(1);
        board.make_move(7 * 15 + 7); board.make_move(2);
        board.make_move(7 * 15 + 9);
        let (stm, _) = compute_active_features(&board);
        let zone = zone_for(7, 5);
        let expected = broken_index(1, BROKEN_SHAPE_JUMP_FOUR, 1, 0, zone);
        assert!(
            stm.contains(&expected),
            "expected jump four (open) feature {expected} missing; stm={stm:?}"
        );
    }

    // Stones at (7,5), (7,7), (7,9): `X _ X _ X` → open double broken three.
    #[test]
    fn double_broken_three_detected() {
        let mut board = Board::new();
        board.make_move(7 * 15 + 5); board.make_move(0);
        board.make_move(7 * 15 + 7); board.make_move(1);
        board.make_move(7 * 15 + 9);
        let (stm, _) = compute_active_features(&board);
        let zone = zone_for(7, 5);
        let expected = broken_index(1, BROKEN_SHAPE_DOUBLE_THREE, 1, 0, zone);
        assert!(
            stm.contains(&expected),
            "expected double broken three (open) feature {expected} missing; stm={stm:?}"
        );
    }

    // A lone open three (7,6)-(7,8) is a single threat. compound_combo_id returns
    // None unless at least two directions carry a threat, so no compound feature
    // may appear for it.
    #[test]
    fn compound_excludes_single_threat() {
        let mut board = Board::new();
        board.make_move(7 * 15 + 6); board.make_move(0); board.make_move(7 * 15 + 7); board.make_move(1); board.make_move(7 * 15 + 8);
        let (stm, _) = compute_active_features(&board);
        let single_o3 = compound_index(1, 3);
        assert!(
            !stm.contains(&single_o3),
            "compound should skip single O3 (already handled by LP-Rich); \
unexpected single-threat compound feature {single_o3} found in stm={stm:?}"
        );
    }
}