use crate::board::{Board, Stone, BOARD_SIZE, NUM_CELLS};
use crate::features::{
broken_index, compound_index, count_bucket, cross_line_hash, cross_line_index, density_index,
length_bucket, local_density_bucket, lp_rich_index, open_bucket, ps_index, zone_for,
BROKEN_SHAPE_DOUBLE_THREE, BROKEN_SHAPE_JUMP_FOUR, BROKEN_SHAPE_THREE, DENSITY_CAT_LEGAL,
DENSITY_CAT_MY_COUNT, DENSITY_CAT_MY_LOCAL, DENSITY_CAT_OPP_COUNT, DENSITY_CAT_OPP_LOCAL,
MAX_ACTIVE_FEATURES,
};
use crate::heuristic::{scan_line, DIR};
use noru::network::{forward, Accumulator, FeatureDelta, NnueWeights};
use std::sync::OnceLock;
/// Lists the cells that `IncrementalEval::push_move` re-examines after a stone
/// is placed on `mv`.
///
/// The result holds every on-board cell of the 11x11 square centred on `mv`
/// (at most five rows and five columns away), in row-major order.
pub(crate) fn affected_cells(mv: usize) -> Vec<usize> {
    let last = BOARD_SIZE as i32 - 1;
    let row = (mv / BOARD_SIZE) as i32;
    let col = (mv % BOARD_SIZE) as i32;
    // Clamp the window to the board once instead of bounds-testing every cell.
    let rows = (row - 5).max(0)..=(row + 5).min(last);
    let cols = (col - 5).max(0)..=(col + 5).min(last);
    let mut cells = Vec::with_capacity(121);
    for r in rows {
        cells.extend(cols.clone().map(|c| r as usize * BOARD_SIZE + c as usize));
    }
    cells
}
/// Process-wide cache of the compound-feature switch; filled on first use.
static COMPOUND_ENABLED: OnceLock<bool> = OnceLock::new();

/// Reports whether compound-threat features are emitted.
///
/// Compound features are on unless the environment variable
/// `NORU_NO_COMPOUND` is set. The environment is read once; later changes to
/// the variable are not observed.
fn compound_enabled() -> bool {
    // `var_os` is used rather than `var`: `var` also fails for a value that is
    // not valid Unicode, which would wrongly count a set variable as unset.
    *COMPOUND_ENABLED.get_or_init(|| std::env::var_os("NORU_NO_COMPOUND").is_none())
}
/// Builds the full active-feature lists for `board` from scratch.
///
/// Returns `(stm, nstm)`: the feature indices seen from the side to move and
/// from the other side. Per-stone features come first (side-to-move stones,
/// then the other side's), followed by the board-wide density features.
pub fn compute_active_features(board: &Board) -> (Vec<usize>, Vec<usize>) {
    let (my_bb, opp_bb) = match board.side_to_move {
        Stone::White => (&board.white, &board.black),
        Stone::Black => (&board.black, &board.white),
    };
    let compound_on = compound_enabled();
    let mut stm = Vec::with_capacity(MAX_ACTIVE_FEATURES);
    let mut nstm = Vec::with_capacity(MAX_ACTIVE_FEATURES);
    my_bb
        .iter_ones()
        .chain(opp_bb.iter_ones())
        .for_each(|sq| features_from_cell(my_bb, opp_bb, sq, compound_on, &mut stm, &mut nstm));
    push_density_features(board, my_bb, opp_bb, &mut stm, &mut nstm);
    (stm, nstm)
}
/// Appends every feature contributed by the stone on `sq` to `stm` / `nstm`.
///
/// `my_bb` holds the side-to-move's stones and `opp_bb` the other side's.
/// An empty `sq` contributes nothing. For a stone, features are emitted in
/// this order:
///
/// 1. the piece-square feature;
/// 2. one line feature per direction in which the stone starts a line;
/// 3. when `compound_on`, at most one compound-threat feature;
/// 4. two cross-line (3x3 neighbourhood) features;
/// 5. broken-shape features, one direction at a time.
///
/// Every push is mirrored: whatever goes to `stm` with one perspective goes to
/// `nstm` at the same position with the other perspective. That keeps the two
/// lists element-wise interchangeable when the side to move flips, which the
/// slice comparison in `IncrementalEval::push_move` relies on to skip cells
/// whose features did not change.
#[inline]
pub(crate) fn features_from_cell(
    my_bb: &crate::board::BitBoard,
    opp_bb: &crate::board::BitBoard,
    sq: usize,
    compound_on: bool,
    stm: &mut Vec<usize>,
    nstm: &mut Vec<usize>,
) {
    let row = (sq / BOARD_SIZE) as i32;
    let col = (sq % BOARD_SIZE) as i32;
    // Perspective 0 = stone of the side to move, 1 = stone of the other side.
    let (stones, opponent, persp_mine, persp_opp) = if my_bb.get(sq) {
        (my_bb, opp_bb, 0, 1)
    } else if opp_bb.get(sq) {
        (opp_bb, my_bb, 1, 0)
    } else {
        return;
    };

    stm.push(ps_index(persp_mine, sq));
    nstm.push(ps_index(persp_opp, sq));

    // One scan per direction serves both the line feature (emitted only at a
    // line start) and the compound-threat classification (every stone).
    let zone = zone_for(row, col);
    let mut threats = [Threat::None; 4];
    for (dir_idx, &(dr, dc)) in DIR.iter().enumerate() {
        let line_start = is_line_start(stones, row, col, dr, dc);
        if !line_start && !compound_on {
            continue;
        }
        let info = scan_line(stones, opponent, row, col, dr, dc);
        if line_start {
            let len = length_bucket(info.count);
            let op = open_bucket(info.open_front, info.open_back);
            stm.push(lp_rich_index(persp_mine, len, op, dir_idx, zone));
            nstm.push(lp_rich_index(persp_opp, len, op, dir_idx, zone));
        }
        if compound_on {
            let open = info.open_front as u32 + info.open_back as u32;
            threats[dir_idx] = classify_threat(info.count, open);
        }
    }
    if compound_on {
        if let Some(combo) = compound_combo_id(&threats) {
            stm.push(compound_index(persp_mine, combo));
            nstm.push(compound_index(persp_opp, combo));
        }
    }

    // Cross-line features: the 3x3 neighbourhood hashed as seen by the side to
    // move, and again with the two colours exchanged.
    let stm_cells = collect_3x3(my_bb, opp_bb, row, col);
    let stm_bucket = cross_line_hash(stm_cells);
    let nstm_bucket = cross_line_hash(swap_mine_opp(stm_cells));
    stm.push(cross_line_index(0, stm_bucket));
    stm.push(cross_line_index(1, nstm_bucket));
    // `nstm` receives the same two features in mirror order. After a side flip
    // the two buckets trade places, so this ordering is what makes the new
    // `stm` list equal the old `nstm` list element for element.
    nstm.push(cross_line_index(0, nstm_bucket));
    nstm.push(cross_line_index(1, stm_bucket));

    for (dir_idx, &(dr, dc)) in DIR.iter().enumerate() {
        detect_broken_and_push(
            stones, opponent, row, col, dr, dc, dir_idx, persp_mine, persp_opp, stm, nstm,
        );
    }
}
/// Appends the five board-wide density features to both feature lists.
///
/// The features are, in order: side-to-move stone count, other-side stone
/// count, the two local densities reported by `local_density`, and the number
/// of remaining empty cells.
fn push_density_features(
    board: &Board,
    my_bb: &crate::board::BitBoard,
    opp_bb: &crate::board::BitBoard,
    stm: &mut Vec<usize>,
    nstm: &mut Vec<usize>,
) {
    let my_count = my_bb.count_ones();
    let opp_count = opp_bb.count_ones();
    let (my_local, opp_local) = local_density(board);
    let legal = (NUM_CELLS as u32).saturating_sub(my_count + opp_count);

    let entries = [
        (DENSITY_CAT_MY_COUNT, count_bucket(my_count)),
        (DENSITY_CAT_OPP_COUNT, count_bucket(opp_count)),
        (DENSITY_CAT_MY_LOCAL, local_density_bucket(my_local)),
        (DENSITY_CAT_OPP_LOCAL, local_density_bucket(opp_local)),
        (DENSITY_CAT_LEGAL, count_bucket(legal)),
    ];
    for (cat, bucket) in entries {
        push_density(stm, nstm, cat, bucket);
    }
}
/// Emits a broken-shape feature for the line through (`row`, `col`) along
/// (`dr`, `dc`), if `classify_broken_shape` recognises one.
///
/// Nothing is emitted when the previous cell along the direction holds one of
/// `stones`, so each line is reported only from its first stone. The line is
/// sampled five cells to either side; cells off the board are encoded like
/// opponent stones.
// The argument list mirrors the call site in `features_from_cell`.
#[allow(clippy::too_many_arguments)]
fn detect_broken_and_push(
    stones: &crate::board::BitBoard,
    opp: &crate::board::BitBoard,
    row: i32,
    col: i32,
    dr: i32,
    dc: i32,
    dir_idx: usize,
    perspective_mine: usize,
    perspective_opp: usize,
    stm: &mut Vec<usize>,
    nstm: &mut Vec<usize>,
) {
    let on_board = |r: i32, c: i32| {
        (0..BOARD_SIZE as i32).contains(&r) && (0..BOARD_SIZE as i32).contains(&c)
    };
    let index_of = |r: i32, c: i32| (r as usize) * BOARD_SIZE + c as usize;

    // Only the first stone of a line reports the shape.
    let (prev_r, prev_c) = (row - dr, col - dc);
    if on_board(prev_r, prev_c) && stones.get(index_of(prev_r, prev_c)) {
        return;
    }

    // Encoding: 0 = empty, 1 = own stone, 2 = opponent stone or off the board.
    let mut line = [2u8; 11];
    for (slot, off) in line.iter_mut().zip(-5i32..=5) {
        let (nr, nc) = (row + dr * off, col + dc * off);
        if !on_board(nr, nc) {
            continue;
        }
        let idx = index_of(nr, nc);
        *slot = if stones.get(idx) {
            1
        } else if opp.get(idx) {
            2
        } else {
            0
        };
    }

    let zone = zone_for(row, col);
    let Some((shape, is_open)) = classify_broken_shape(&line) else {
        return;
    };
    let openness = if is_open { 1 } else { 0 };
    stm.push(broken_index(perspective_mine, shape, openness, dir_idx, zone));
    nstm.push(broken_index(perspective_opp, shape, openness, dir_idx, zone));
}
/// Classifies the gapped ("broken") shape that starts at the centre of `line`.
///
/// `line` is an 11-cell window with the anchoring stone at index 5, encoded as
/// 0 = empty, 1 = own stone, 2 = opponent stone or off the board. Only the
/// cell before the anchor (index 4) and the five cells after it are inspected.
///
/// The scan walks forward from the anchor, counting own stones and single
/// empty cells between them. It stops at a blocker (2), at a second
/// consecutive empty cell, or at the end of the window.
///
/// Returns `Some((shape, is_open))` for
/// * three stones with one internal gap (`BROKEN_SHAPE_THREE`),
/// * four stones with one internal gap (`BROKEN_SHAPE_JUMP_FOUR`),
/// * three stones with two internal gaps (`BROKEN_SHAPE_DOUBLE_THREE`),
///
/// and `None` otherwise, including for solid lines with no internal gap.
/// `is_open` is true when the cell before the anchor and the cell right after
/// the last stone are both empty.
///
/// An empty cell directly behind the last stone is the open end of the shape,
/// not a gap, however the scan stopped. This includes the case where the scan
/// stops on a blocker: `X X X _ O` is a solid three (no feature), and
/// `X _ X X _ O` is a broken three rather than a double-broken three.
fn classify_broken_shape(line: &[u8; 11]) -> Option<(usize, bool)> {
    debug_assert!(line[5] == 1);
    let open_left = line[4] == 0;

    let mut mine_right = 0u32;
    let mut gap_count = 0u32;
    let mut prev_was_empty = false;
    for &c in &line[6..] {
        match c {
            // Blocked: the shape cannot extend any further.
            2 => break,
            0 => {
                if prev_was_empty {
                    // Two empties in a row end the shape.
                    break;
                }
                gap_count += 1;
                prev_was_empty = true;
            }
            _ => {
                mine_right += 1;
                prev_was_empty = false;
            }
        }
    }

    // If the last cell counted was empty it trails the final stone: it is the
    // open end, so take it back out of the gap count.
    let right_open = prev_was_empty;
    if right_open {
        gap_count -= 1;
    }
    if gap_count == 0 {
        return None;
    }

    let total_mine = 1 + mine_right;
    let is_open = open_left && right_open;
    match (total_mine, gap_count) {
        (3, 1) => Some((BROKEN_SHAPE_THREE, is_open)),
        (4, 1) => Some((BROKEN_SHAPE_JUMP_FOUR, is_open)),
        (3, 2) => Some((BROKEN_SHAPE_DOUBLE_THREE, is_open)),
        _ => None,
    }
}
/// Encodes the 3x3 neighbourhood centred on (`row`, `col`) in row-major order.
///
/// Cell codes: 0 = empty, 1 = stone in `my_bb`, 2 = stone in `opp_bb`,
/// 3 = off the board.
#[inline]
fn collect_3x3(
    my_bb: &crate::board::BitBoard,
    opp_bb: &crate::board::BitBoard,
    row: i32,
    col: i32,
) -> [u8; 9] {
    let limit = BOARD_SIZE as i32;
    let offsets = (-1i32..=1).flat_map(|dr| (-1i32..=1).map(move |dc| (dr, dc)));
    let mut cells = [0u8; 9];
    for (slot, (dr, dc)) in cells.iter_mut().zip(offsets) {
        let (nr, nc) = (row + dr, col + dc);
        let inside = (0..limit).contains(&nr) && (0..limit).contains(&nc);
        *slot = if !inside {
            3
        } else {
            let idx = (nr as usize) * BOARD_SIZE + (nc as usize);
            if my_bb.get(idx) {
                1
            } else if opp_bb.get(idx) {
                2
            } else {
                0
            }
        };
    }
    cells
}
/// Returns the neighbourhood `c` with codes 1 and 2 exchanged; every other
/// code is left as it is.
#[inline]
fn swap_mine_opp(c: [u8; 9]) -> [u8; 9] {
    c.map(|code| match code {
        1 => 2,
        2 => 1,
        other => other,
    })
}
/// Reports whether (`row`, `col`) is the first stone of its line along
/// (`dr`, `dc`): true when the previous cell is off the board or holds no
/// stone of `bb`.
#[inline]
fn is_line_start(bb: &crate::board::BitBoard, row: i32, col: i32, dr: i32, dc: i32) -> bool {
    let limit = BOARD_SIZE as i32;
    let (prev_r, prev_c) = (row - dr, col - dc);
    let prev_on_board = (0..limit).contains(&prev_r) && (0..limit).contains(&prev_c);
    !(prev_on_board && bb.get(prev_r as usize * BOARD_SIZE + prev_c as usize))
}
/// Appends the density feature for (`cat`, `bucket`) to both lists; the same
/// index is used for either perspective.
#[inline]
fn push_density(stm: &mut Vec<usize>, nstm: &mut Vec<usize>, cat: usize, bucket: usize) {
    let feature = density_index(cat, bucket);
    for side in [stm, nstm] {
        side.push(feature);
    }
}
/// Strength of the line through a stone in one direction, as assigned by
/// `classify_threat`.
///
/// Variants are declared weakest first, so the derived ordering ranks `Five`
/// highest; `compound_combo_id` sorts by that ordering.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Threat {
    /// Nothing worth reporting.
    None,
    /// Two stones, both ends open.
    OpenTwo,
    /// Three stones, one end open.
    ClosedThree,
    /// Three stones, both ends open.
    OpenThree,
    /// Four stones, one end open.
    ClosedFour,
    /// Four stones, both ends open.
    OpenFour,
    /// Five or more stones.
    Five,
}
/// Maps a line of `count` stones with `open_ends` open ends (0, 1 or 2) to a
/// [`Threat`].
///
/// Five or more stones are always `Five`. Fours and threes need at least one
/// open end, twos need both; everything else is `Threat::None`.
fn classify_threat(count: u32, open_ends: u32) -> Threat {
    let both_open = open_ends == 2;
    let one_open = open_ends == 1;
    match count {
        5.. => Threat::Five,
        4 if both_open => Threat::OpenFour,
        4 if one_open => Threat::ClosedFour,
        3 if both_open => Threat::OpenThree,
        3 if one_open => Threat::ClosedThree,
        2 if both_open => Threat::OpenTwo,
        _ => Threat::None,
    }
}
/// Turns the four per-direction threats of one stone into a compound id.
///
/// Returns `None` unless at least two directions carry a threat. With exactly
/// two, the id is taken from the pair table (6..=26), keyed by the strongest
/// and second-strongest threat. With three or more, the id is
/// `27 + rank of the strongest threat` (27..=32).
fn compound_combo_id(threats: &[Threat; 4]) -> Option<usize> {
    // First pair id for each rank of the strongest threat; the partner's rank
    // is never lower, so pairs for one row are laid out consecutively.
    const DUAL_BASE: [usize; 6] = [6, 12, 17, 21, 24, 26];
    const TRIPLE_BASE: usize = 27;

    let mut ordered = *threats;
    ordered.sort_unstable_by(|a, b| b.cmp(a));
    let [first, second, third, _] = ordered;

    // Sorted strongest first, so a missing second threat also covers the case
    // of no threat at all.
    if second == Threat::None {
        return None;
    }
    let first_rank = threat_rank(first);
    if third != Threat::None {
        return Some(TRIPLE_BASE + first_rank.min(5));
    }
    let second_rank = threat_rank(second);
    Some(DUAL_BASE[first_rank] + (second_rank - first_rank))
}
/// Ranks a threat for id computation: 0 for the strongest (`Five`) up to 6
/// for `Threat::None`.
fn threat_rank(t: Threat) -> usize {
    match t {
        Threat::None => 6,
        Threat::OpenTwo => 5,
        Threat::ClosedThree => 4,
        Threat::OpenThree => 3,
        Threat::ClosedFour => 2,
        Threat::OpenFour => 1,
        Threat::Five => 0,
    }
}
/// Counts the stones on the (up to eight) cells surrounding the last move.
///
/// Returns `(mine, theirs)` relative to the side to move, or `(0, 0)` when
/// `board.last_move` is `None`.
fn local_density(board: &Board) -> (u32, u32) {
    let Some(mv) = board.last_move else {
        return (0, 0);
    };
    let (my_bb, opp_bb) = match board.side_to_move {
        Stone::White => (&board.white, &board.black),
        Stone::Black => (&board.black, &board.white),
    };
    let limit = BOARD_SIZE as i32;
    let centre_r = (mv / BOARD_SIZE) as i32;
    let centre_c = (mv % BOARD_SIZE) as i32;

    let (mut mine, mut theirs) = (0u32, 0u32);
    for dr in -1i32..=1 {
        for dc in -1i32..=1 {
            if (dr, dc) == (0, 0) {
                continue;
            }
            let (nr, nc) = (centre_r + dr, centre_c + dc);
            if !(0..limit).contains(&nr) || !(0..limit).contains(&nc) {
                continue;
            }
            let idx = (nr as usize) * BOARD_SIZE + nc as usize;
            mine += u32::from(my_bb.get(idx));
            theirs += u32::from(opp_bb.get(idx));
        }
    }
    (mine, theirs)
}
/// Evaluates `board` from scratch: computes every active feature, refreshes a
/// new accumulator with them and runs the network forward pass.
pub fn evaluate(board: &Board, weights: &NnueWeights) -> i32 {
    let (stm, nstm) = compute_active_features(board);
    let mut accumulator = Accumulator::new(&weights.feature_bias);
    accumulator.refresh(weights, &stm, &nstm);
    forward(&accumulator, weights)
}
/// Evaluator state that is updated move by move instead of being recomputed
/// from the whole board.
pub struct IncrementalEval {
/// Accumulator that `eval` passes to `forward`.
pub accumulator: Accumulator,
/// Cached feature lists per cell, indexed by cell; each entry is `(stm, nstm)`.
cell_features: Vec<(Vec<usize>, Vec<usize>)>,
/// Cached density features as `(stm, nstm)`.
density_features: (Vec<usize>, Vec<usize>),
/// One undo record per `push_move`, consumed by `pop_move`.
stack: Vec<UndoRecord>,
}
/// Snapshot taken by `IncrementalEval::push_move` so `pop_move` can restore
/// the previous state.
struct UndoRecord {
/// Accumulator as it was before the move was applied.
accumulator: Accumulator,
/// `(cell, stm, nstm)` for each cell whose cached features were replaced;
/// the lists are the ones held after the perspective swap.
cell_changes: Vec<(usize, Vec<usize>, Vec<usize>)>,
/// Density features as they were before the move was applied.
density: (Vec<usize>, Vec<usize>),
}
impl IncrementalEval {
    /// Creates an evaluator with an accumulator built from the feature bias
    /// and empty feature caches. Call [`refresh`](Self::refresh) to load a
    /// position.
    pub fn new(weights: &NnueWeights) -> Self {
        let accumulator = Accumulator::new(&weights.feature_bias);
        let cell_features = vec![(Vec::new(), Vec::new()); NUM_CELLS];
        Self {
            accumulator,
            cell_features,
            density_features: (Vec::new(), Vec::new()),
            stack: Vec::with_capacity(225),
        }
    }

    /// Rebuilds all cached features and the accumulator from `board` and
    /// discards the undo history.
    pub fn refresh(&mut self, board: &Board, weights: &NnueWeights) {
        let (my_bb, opp_bb) = match board.side_to_move {
            Stone::White => (&board.white, &board.black),
            Stone::Black => (&board.black, &board.white),
        };
        let compound_on = compound_enabled();

        for (stm, nstm) in self.cell_features.iter_mut() {
            stm.clear();
            nstm.clear();
        }
        for sq in my_bb.iter_ones().chain(opp_bb.iter_ones()) {
            let (stm, nstm) = &mut self.cell_features[sq];
            features_from_cell(my_bb, opp_bb, sq, compound_on, stm, nstm);
        }

        let (dens_stm, dens_nstm) = &mut self.density_features;
        dens_stm.clear();
        dens_nstm.clear();
        push_density_features(board, my_bb, opp_bb, dens_stm, dens_nstm);

        let (all_stm, all_nstm) = self.collect_all_features();
        self.accumulator.refresh(weights, &all_stm, &all_nstm);
        self.stack.clear();
    }

    /// Concatenates the cached per-cell features (in cell order) followed by
    /// the density features, separately for each perspective.
    fn collect_all_features(&self) -> (Vec<usize>, Vec<usize>) {
        let mut stm = Vec::with_capacity(MAX_ACTIVE_FEATURES);
        let mut nstm = Vec::with_capacity(MAX_ACTIVE_FEATURES);
        let groups = self
            .cell_features
            .iter()
            .chain(std::iter::once(&self.density_features));
        for (group_stm, group_nstm) in groups {
            stm.extend_from_slice(group_stm);
            nstm.extend_from_slice(group_nstm);
        }
        (stm, nstm)
    }

    /// Updates the evaluator for the move `mv`; `board` is read for the new
    /// side to move, the stone sets and the last move, so it is expected to be
    /// the position with `mv` already played.
    ///
    /// The perspectives are swapped first (accumulator and every cached list),
    /// then only the cells returned by `affected_cells(mv)` and the density
    /// features are recomputed and applied as deltas. An undo record is pushed
    /// for [`pop_move`](Self::pop_move).
    pub fn push_move(&mut self, board: &Board, mv: usize, weights: &NnueWeights) {
        let mut undo = UndoRecord {
            accumulator: self.accumulator.clone(),
            cell_changes: Vec::new(),
            density: self.density_features.clone(),
        };

        // Swap perspectives on the accumulator and every cached list.
        self.accumulator.swap();
        for (stm, nstm) in self.cell_features.iter_mut() {
            std::mem::swap(stm, nstm);
        }
        {
            let (dens_stm, dens_nstm) = &mut self.density_features;
            std::mem::swap(dens_stm, dens_nstm);
        }

        let (my_bb, opp_bb) = match board.side_to_move {
            Stone::White => (&board.white, &board.black),
            Stone::Black => (&board.black, &board.white),
        };
        let compound_on = compound_enabled();

        // Scratch buffers reused for every cell in the window.
        let mut fresh_stm: Vec<usize> = Vec::with_capacity(16);
        let mut fresh_nstm: Vec<usize> = Vec::with_capacity(16);
        for cell in affected_cells(mv) {
            fresh_stm.clear();
            fresh_nstm.clear();
            features_from_cell(my_bb, opp_bb, cell, compound_on, &mut fresh_stm, &mut fresh_nstm);

            let cached = &mut self.cell_features[cell];
            if cached.0 == fresh_stm && cached.1 == fresh_nstm {
                continue;
            }
            apply_delta_by_chunks(
                &mut self.accumulator,
                weights,
                &fresh_stm,
                &cached.0,
                &fresh_nstm,
                &cached.1,
            );
            // Install the new lists and keep the replaced ones for undo.
            let replaced_stm = std::mem::replace(&mut cached.0, fresh_stm.clone());
            let replaced_nstm = std::mem::replace(&mut cached.1, fresh_nstm.clone());
            undo.cell_changes.push((cell, replaced_stm, replaced_nstm));
        }

        let mut dens_stm: Vec<usize> = Vec::with_capacity(8);
        let mut dens_nstm: Vec<usize> = Vec::with_capacity(8);
        push_density_features(board, my_bb, opp_bb, &mut dens_stm, &mut dens_nstm);
        let density_unchanged =
            dens_stm == self.density_features.0 && dens_nstm == self.density_features.1;
        if !density_unchanged {
            apply_delta_by_chunks(
                &mut self.accumulator,
                weights,
                &dens_stm,
                &self.density_features.0,
                &dens_nstm,
                &self.density_features.1,
            );
            self.density_features = (dens_stm, dens_nstm);
        }

        self.stack.push(undo);
    }

    /// Reverts the most recent [`push_move`](Self::push_move). Does nothing
    /// when there is no move to undo.
    pub fn pop_move(&mut self) {
        let Some(UndoRecord {
            accumulator,
            cell_changes,
            density,
        }) = self.stack.pop()
        else {
            return;
        };
        self.accumulator = accumulator;
        self.density_features = density;
        // The recorded lists are in the swapped perspective, so put them back
        // first and then swap every cell back.
        for (cell, stm, nstm) in cell_changes {
            self.cell_features[cell] = (stm, nstm);
        }
        for (stm, nstm) in self.cell_features.iter_mut() {
            std::mem::swap(stm, nstm);
        }
    }

    /// Runs the network forward pass on the current accumulator.
    pub fn eval(&self, weights: &NnueWeights) -> i32 {
        forward(&self.accumulator, weights)
    }
}
/// Applies the change from the old feature lists to the new ones onto `acc`.
///
/// Each perspective's added and removed features are found with
/// `multiset_diff` and split into pieces of at most
/// `noru::network::MAX_FEATURE_DELTA` entries, which are applied one
/// incremental update at a time.
///
/// # Panics
///
/// Panics if `FeatureDelta::from_slices` rejects a chunk.
fn apply_delta_by_chunks(
    acc: &mut Accumulator,
    weights: &NnueWeights,
    new_stm: &[usize],
    old_stm: &[usize],
    new_nstm: &[usize],
    old_nstm: &[usize],
) {
    const MAX_FD: usize = noru::network::MAX_FEATURE_DELTA;

    let (stm_add, stm_rem) = multiset_diff(new_stm, old_stm);
    let (nstm_add, nstm_rem) = multiset_diff(new_nstm, old_nstm);
    let stm_chunks = chunk_pairs(&stm_add, &stm_rem, MAX_FD);
    let nstm_chunks = chunk_pairs(&nstm_add, &nstm_rem, MAX_FD);

    // The two perspectives may need different numbers of chunks; the shorter
    // side is padded with empty deltas.
    let rounds = stm_chunks.len().max(nstm_chunks.len());
    for round in 0..rounds {
        let (stm_in, stm_out) = stm_chunks
            .get(round)
            .copied()
            .unwrap_or((&[][..], &[][..]));
        let (nstm_in, nstm_out) = nstm_chunks
            .get(round)
            .copied()
            .unwrap_or((&[][..], &[][..]));
        let stm_delta = FeatureDelta::from_slices(stm_in, stm_out).expect("stm chunk overflow");
        let nstm_delta =
            FeatureDelta::from_slices(nstm_in, nstm_out).expect("nstm chunk overflow");
        acc.update_incremental(weights, &stm_delta, &nstm_delta);
    }
}
/// Splits `add` and `rem` into parallel pieces of at most `max_chunk` items.
///
/// The i-th pair holds the i-th piece of each slice; once a slice is used up
/// its piece is empty. At least one pair is always returned, even when both
/// inputs are empty.
///
/// # Panics
///
/// Panics if `max_chunk` is zero.
fn chunk_pairs<'a>(
    add: &'a [usize],
    rem: &'a [usize],
    max_chunk: usize,
) -> Vec<(&'a [usize], &'a [usize])> {
    let pieces_needed = |len: usize| len.div_ceil(max_chunk);
    let total = pieces_needed(add.len()).max(pieces_needed(rem.len())).max(1);

    // Piece `k` of `src`, clamped so that an exhausted slice yields `&[]`.
    let piece = |src: &'a [usize], k: usize| -> &'a [usize] {
        let from = (k * max_chunk).min(src.len());
        let to = ((k + 1) * max_chunk).min(src.len());
        &src[from..to]
    };
    (0..total).map(|k| (piece(add, k), piece(rem, k))).collect()
}
/// Computes the multiset difference between two feature lists.
///
/// Returns `(add, rem)`: `add` holds the entries `new` has over `old`, `rem`
/// the entries `old` has over `new`, each repeated by its surplus count.
/// Entries present equally often in both appear in neither. Both outputs are
/// sorted ascending, so the result is the same on every run.
fn multiset_diff(new: &[usize], old: &[usize]) -> (Vec<usize>, Vec<usize>) {
    // Sorting and merging avoids hashing and keeps the output order fixed; the
    // lists compared here are short.
    let mut new_sorted = new.to_vec();
    let mut old_sorted = old.to_vec();
    new_sorted.sort_unstable();
    old_sorted.sort_unstable();

    let mut add = Vec::new();
    let mut rem = Vec::new();
    let (mut i, mut j) = (0, 0);
    while i < new_sorted.len() && j < old_sorted.len() {
        match new_sorted[i].cmp(&old_sorted[j]) {
            std::cmp::Ordering::Less => {
                add.push(new_sorted[i]);
                i += 1;
            }
            std::cmp::Ordering::Greater => {
                rem.push(old_sorted[j]);
                j += 1;
            }
            // One occurrence on each side cancels out.
            std::cmp::Ordering::Equal => {
                i += 1;
                j += 1;
            }
        }
    }
    add.extend_from_slice(&new_sorted[i..]);
    rem.extend_from_slice(&old_sorted[j..]);
    (add, rem)
}
// Unit tests for feature extraction and for the incremental evaluator.
#[cfg(test)]
mod tests {
use super::*;
use crate::board::Board;
use crate::features::{
broken_index, compound_index, BROKEN_SHAPE_DOUBLE_THREE, BROKEN_SHAPE_JUMP_FOUR,
BROKEN_SHAPE_THREE, GOMOKU_NNUE_CONFIG, HALF_FEATURE_SIZE, LP_BASE, MAX_ACTIVE_FEATURES,
PS_BASE, TOTAL_FEATURE_SIZE,
};
// With no stones on the board only the five density features are active.
#[test]
fn empty_board_has_only_density_features() {
let board = Board::new();
let (stm, nstm) = compute_active_features(&board);
assert_eq!(stm.len(), 5);
assert_eq!(nstm.len(), 5);
}
// All-zero weights must evaluate the empty board to 0.
#[test]
fn evaluate_zero_weights() {
let board = Board::new();
let weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
assert_eq!(evaluate(&board, &weights), 0);
}
// Two adjacent stones on row 7 (cells 7*15+7 and 7*15+8) with a reply at
// cell 0 in between must produce at least one line feature.
#[test]
fn features_include_lp_after_two_in_row() {
let mut board = Board::new();
board.make_move(7 * 15 + 7); board.make_move(0); board.make_move(7 * 15 + 8); let (stm, _) = compute_active_features(&board);
// NOTE(review): `2 * 1152` is presumably the size of the LP-Rich index
// range for both perspectives; confirm against the features module.
let has_lp = stm
.iter()
.any(|&f| f >= LP_BASE && f < LP_BASE + 2 * 1152);
assert!(has_lp, "should have LP-Rich features after 2-in-row");
}
// Every emitted feature index must be below TOTAL_FEATURE_SIZE.
#[test]
fn all_features_within_range() {
let mut board = Board::new();
for sq in [112, 0, 113, 1, 114, 15, 100, 50] {
board.make_move(sq);
}
let (stm, nstm) = compute_active_features(&board);
for &f in stm.iter().chain(nstm.iter()) {
assert!(f < TOTAL_FEATURE_SIZE, "feature {f} >= {TOTAL_FEATURE_SIZE}");
}
}
// Plays on every cell that is still empty, in index order, and checks that
// the feature lists stay within MAX_ACTIVE_FEATURES.
#[test]
fn active_features_under_cap() {
let mut board = Board::new();
for sq in 0..NUM_CELLS {
if board.is_empty(sq) {
board.make_move(sq);
}
}
let (stm, nstm) = compute_active_features(&board);
assert!(stm.len() <= MAX_ACTIVE_FEATURES, "stm len={}", stm.len());
assert!(nstm.len() <= MAX_ACTIVE_FEATURES, "nstm len={}", nstm.len());
}
// A push followed by a pop must restore the evaluation seen before the push.
#[test]
fn push_pop_consistency() {
let mut weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
let acc_size = GOMOKU_NNUE_CONFIG.accumulator_size;
// Give the first 20 feature rows of each half small non-zero weights.
for sq in 0..20 {
for i in 0..acc_size {
weights.feature_weights[sq][i] = ((sq * 7 + i) % 13) as i16 - 6;
weights.feature_weights[sq + HALF_FEATURE_SIZE][i] =
((sq * 3 + i) % 11) as i16 - 5;
}
}
let mut board = Board::new();
let mut inc = IncrementalEval::new(&weights);
inc.refresh(&board, &weights);
let before = inc.eval(&weights);
board.make_move(112);
inc.push_move(&board, 112, &weights);
board.undo_move();
inc.pop_move();
assert_eq!(before, inc.eval(&weights));
// Keeps the PS_BASE import in use.
let _ = PS_BASE;
}
// After every move of a clustered sequence, and after every undo, the
// incremental evaluation must equal a from-scratch evaluation.
#[test]
fn incremental_matches_full_refresh() {
let mut weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
let acc_size = GOMOKU_NNUE_CONFIG.accumulator_size;
// Deterministic pseudo-random weights covering every feature row.
for f in 0..TOTAL_FEATURE_SIZE {
for i in 0..acc_size {
weights.feature_weights[f][i] =
((f.wrapping_mul(13).wrapping_add(i) % 31) as i16) - 15;
}
weights.feature_bias[i_mod(f, acc_size)] =
((f.wrapping_mul(7) % 19) as i16) - 9;
}
let moves = [
112, 113, 97, 98, 127, 128, 111, 114, 96, 99, 126, 129, 82, 83, 84, 85, 100, 101, 115, 116,
];
let mut board = Board::new();
let mut inc = IncrementalEval::new(&weights);
inc.refresh(&board, &weights);
let initial = inc.eval(&weights);
assert_eq!(initial, evaluate(&board, &weights), "refresh mismatch at empty");
for (i, &mv) in moves.iter().enumerate() {
if !board.is_empty(mv) {
continue;
}
board.make_move(mv);
inc.push_move(&board, mv, &weights);
let inc_val = inc.eval(&weights);
let full_val = evaluate(&board, &weights);
assert_eq!(
inc_val, full_val,
"mismatch after move {} (ply {}): inc={} full={}",
mv, i + 1, inc_val, full_val
);
}
// NOTE(review): undoes moves.len() times, which assumes no move was
// skipped by the is_empty guard above (true for this list: all distinct).
for _ in 0..moves.len() {
board.undo_move();
inc.pop_move();
let inc_val = inc.eval(&weights);
let full_val = evaluate(&board, &weights);
assert_eq!(inc_val, full_val, "mismatch during undo");
}
}
// Helper: index `f` wrapped into the accumulator length `acc`.
fn i_mod(f: usize, acc: usize) -> usize {
f % acc
}
// Same check as above with moves scattered over the board, including the
// corner cells 0, 14, 210 and 224.
#[test]
fn incremental_matches_full_refresh_far_apart() {
let mut weights = NnueWeights::zeros(GOMOKU_NNUE_CONFIG);
let acc_size = GOMOKU_NNUE_CONFIG.accumulator_size;
for f in 0..TOTAL_FEATURE_SIZE {
for i in 0..acc_size {
weights.feature_weights[f][i] =
((f.wrapping_mul(17).wrapping_add(i) % 37) as i16) - 18;
}
}
for i in 0..acc_size {
weights.feature_bias[i] = ((i % 23) as i16) - 11;
}
let moves = [0, 224, 14, 210, 112, 30, 200, 58, 101, 150, 7, 217];
let mut board = Board::new();
let mut inc = IncrementalEval::new(&weights);
inc.refresh(&board, &weights);
assert_eq!(inc.eval(&weights), evaluate(&board, &weights));
for (i, &mv) in moves.iter().enumerate() {
if !board.is_empty(mv) {
continue;
}
board.make_move(mv);
inc.push_move(&board, mv, &weights);
let inc_val = inc.eval(&weights);
let full_val = evaluate(&board, &weights);
assert_eq!(
inc_val, full_val,
"far-apart mismatch after move {} (ply {}): inc={} full={}",
mv, i + 1, inc_val, full_val
);
}
// NOTE(review): like the clustered test, this assumes no move was skipped.
for _ in 0..moves.len() {
board.undo_move();
inc.pop_move();
assert_eq!(inc.eval(&weights), evaluate(&board, &weights), "undo mismatch");
}
}
// Random playouts against a weights file loaded from disk; ignored by
// default because the file is not part of the test fixtures.
#[test]
#[ignore = "requires a real NNUE weights file (env NORU_TEST_WEIGHTS or default models/gomoku_v13_broken_rapfi.bin)"]
fn incremental_matches_full_refresh_real_weights() {
use crate::board::GameResult;
use noru::trainer::SimpleRng;
let path = std::env::var("NORU_TEST_WEIGHTS").unwrap_or_else(|_| {
let manifest = env!("CARGO_MANIFEST_DIR");
format!("{}/models/gomoku_v13_broken_rapfi.bin", manifest)
});
let data = std::fs::read(&path)
.unwrap_or_else(|e| panic!("failed to read weights from {path}: {e}"));
let weights = NnueWeights::load_from_bytes(&data, Some(GOMOKU_NNUE_CONFIG.clone()))
.unwrap_or_else(|e| panic!("load_from_bytes failed for {path}: {e}"));
// Fixed seed so a failing trial can be reproduced.
let mut rng = SimpleRng::new(2026);
for trial in 0..100 {
let mut board = Board::new();
let mut inc = IncrementalEval::new(&weights);
inc.refresh(&board, &weights);
for ply in 0..160 {
if board.game_result() != GameResult::Ongoing {
break;
}
let moves = board.candidate_moves();
if moves.is_empty() {
break;
}
let mv = moves[rng.next_usize(moves.len())];
board.make_move(mv);
inc.push_move(&board, mv, &weights);
let inc_val = inc.eval(&weights);
let full_val = evaluate(&board, &weights);
assert_eq!(
inc_val, full_val,
"trial {trial} ply {ply} (move {mv}): inc={inc_val} full={full_val}"
);
}
}
}
// The first player builds a horizontal three on row 7 (cols 5-7) and a
// vertical three on column 7 (rows 6-8), crossing at (7,7). That stone does
// not start the horizontal line, yet the compound feature must be emitted.
#[test]
fn compound_catches_double_three_at_non_line_start_stone() {
let mut board = Board::new();
board.make_move(7 * 15 + 5); board.make_move(0); board.make_move(7 * 15 + 6); board.make_move(1); board.make_move(7 * 15 + 7); board.make_move(2); board.make_move(6 * 15 + 7); board.make_move(3); board.make_move(8 * 15 + 7);
let (stm, _) = compute_active_features(&board);
// Perspective 1: after nine moves the stones belong to the side not to move.
let expected = compound_index(1, 21);
assert!(
stm.contains(&expected),
"stm should contain opponent's O3+O3 compound at the non-line-start \
crossing stone (7,7); expected feature index {expected} missing.\n\
stm={stm:?}"
);
}
// Row 7 with stones at cols 5, 6 and 8: a three with one gap, open at both ends.
#[test]
fn broken_three_detected_open() {
let mut board = Board::new();
board.make_move(7 * 15 + 5); board.make_move(0); board.make_move(7 * 15 + 6); board.make_move(1); board.make_move(7 * 15 + 8);
let (stm, _) = compute_active_features(&board);
let zone = zone_for(7, 5);
let expected = broken_index(1, BROKEN_SHAPE_THREE, 1, 0, zone);
assert!(
stm.contains(&expected),
"expected broken three (open) feature {expected} missing; stm={stm:?}"
);
}
// Row 7 with stones at cols 5, 6, 7 and 9: a four with one gap.
#[test]
fn jump_four_detected() {
let mut board = Board::new();
board.make_move(7 * 15 + 5); board.make_move(0);
board.make_move(7 * 15 + 6); board.make_move(1);
board.make_move(7 * 15 + 7); board.make_move(2);
board.make_move(7 * 15 + 9);
let (stm, _) = compute_active_features(&board);
let zone = zone_for(7, 5);
let expected = broken_index(1, BROKEN_SHAPE_JUMP_FOUR, 1, 0, zone);
assert!(
stm.contains(&expected),
"expected jump four (open) feature {expected} missing; stm={stm:?}"
);
}
// Row 7 with stones at cols 5, 7 and 9: a three with two gaps.
#[test]
fn double_broken_three_detected() {
let mut board = Board::new();
board.make_move(7 * 15 + 5); board.make_move(0);
board.make_move(7 * 15 + 7); board.make_move(1);
board.make_move(7 * 15 + 9);
let (stm, _) = compute_active_features(&board);
let zone = zone_for(7, 5);
let expected = broken_index(1, BROKEN_SHAPE_DOUBLE_THREE, 1, 0, zone);
assert!(
stm.contains(&expected),
"expected double broken three (open) feature {expected} missing; stm={stm:?}"
);
}
// A single open three (row 7, cols 6-8) must not produce a compound feature.
#[test]
fn compound_excludes_single_threat() {
let mut board = Board::new();
board.make_move(7 * 15 + 6); board.make_move(0); board.make_move(7 * 15 + 7); board.make_move(1); board.make_move(7 * 15 + 8);
let (stm, _) = compute_active_features(&board);
let single_o3 = compound_index(1, 3);
assert!(
!stm.contains(&single_o3),
"compound should skip single O3 (already handled by LP-Rich); \
unexpected single-threat compound feature {single_o3} found in stm={stm:?}"
);
}
}