use std::convert::TryInto;

use super::accumulator::{DirtyPiece, IndexList, MAX_ACTIVE_FEATURES, MAX_PATH_LENGTH};
use super::constants::NNUE_PYTORCH_L1;
use crate::types::{Color, MAX_PLY, Square};
/// Per-position NNUE first-layer accumulator: one L1-sized row per
/// perspective (indices 0 and 1), 64-byte aligned for aligned vector loads.
#[repr(C, align(64))]
#[derive(Clone)]
pub struct AccumulatorLayerStacks {
    /// Accumulated first-layer values, indexed by perspective.
    pub accumulation: [[i16; NNUE_PYTORCH_L1]; 2],
    /// True once `accumulation` holds valid data for the current position.
    pub computed_accumulation: bool,
    /// True once a score has been computed from this accumulator.
    pub computed_score: bool,
}
impl AccumulatorLayerStacks {
    /// Creates a zeroed accumulator with both "computed" flags cleared.
    pub fn new() -> Self {
        Self {
            computed_score: false,
            computed_accumulation: false,
            accumulation: [[0; NNUE_PYTORCH_L1]; 2],
        }
    }

    /// Borrows the accumulation row for the given perspective (0 or 1).
    #[inline]
    pub fn get(&self, perspective: usize) -> &[i16; NNUE_PYTORCH_L1] {
        let rows = &self.accumulation;
        &rows[perspective]
    }

    /// Mutably borrows the accumulation row for the given perspective (0 or 1).
    #[inline]
    pub fn get_mut(&mut self, perspective: usize) -> &mut [i16; NNUE_PYTORCH_L1] {
        let rows = &mut self.accumulation;
        &mut rows[perspective]
    }
}
impl Default for AccumulatorLayerStacks {
fn default() -> Self {
Self::new()
}
}
/// One cached refresh result for a (king square, perspective) pair,
/// 64-byte aligned like the accumulator it snapshots.
#[repr(C, align(64))]
struct AccCacheEntry {
    /// Accumulation snapshot corresponding to `active_indices`.
    accumulation: [i16; NNUE_PYTORCH_L1],
    /// Feature indices active when the snapshot was taken; only the first
    /// `num_active` slots are meaningful.
    active_indices: [u32; MAX_ACTIVE_FEATURES],
    /// Number of valid slots in `active_indices`.
    num_active: u16,
    /// False until the entry is populated (and after invalidation).
    valid: bool,
}
impl AccCacheEntry {
    /// Builds an entry with `valid` cleared; the zeroed payload is a
    /// placeholder and carries no meaning until the entry is populated.
    fn new_invalid() -> Self {
        Self {
            valid: false,
            num_active: 0,
            active_indices: [0; MAX_ACTIVE_FEATURES],
            accumulation: [0; NNUE_PYTORCH_L1],
        }
    }
}
/// Accumulator refresh cache: one entry per king square per perspective,
/// letting a refresh be computed as a diff against the last accumulation
/// seen for the same (king square, perspective).
pub struct AccumulatorCacheLayerStacks {
    /// Indexed by `[king_square][perspective]`.
    entries: Box<[[AccCacheEntry; 2]; Square::NUM]>,
}
impl AccumulatorCacheLayerStacks {
pub fn new() -> Self {
let entries: Vec<[AccCacheEntry; 2]> = (0..Square::NUM)
.map(|_| [AccCacheEntry::new_invalid(), AccCacheEntry::new_invalid()])
.collect();
let boxed: Box<[[AccCacheEntry; 2]]> = entries.into_boxed_slice();
let ptr = Box::into_raw(boxed) as *mut [[AccCacheEntry; 2]; Square::NUM];
let entries = unsafe { Box::from_raw(ptr) };
Self { entries }
}
pub fn invalidate(&mut self) {
for sq_entries in self.entries.iter_mut() {
for entry in sq_entries.iter_mut() {
entry.valid = false;
}
}
}
pub(crate) fn refresh_or_cache<FA, FS>(
&mut self,
king_sq: Square,
perspective: Color,
active: &[u32],
biases: &[i16; NNUE_PYTORCH_L1],
accumulation: &mut [i16; NNUE_PYTORCH_L1],
add_fn: FA,
sub_fn: FS,
) where
FA: Fn(&mut [i16; NNUE_PYTORCH_L1], usize),
FS: Fn(&mut [i16; NNUE_PYTORCH_L1], usize),
{
let entry = &mut self.entries[king_sq.raw() as usize][perspective as usize];
if entry.valid {
accumulation.copy_from_slice(&entry.accumulation);
let cached = &entry.active_indices[..entry.num_active as usize];
Self::apply_diff(cached, active, accumulation, &add_fn, &sub_fn);
} else {
accumulation.copy_from_slice(biases);
for &idx in active {
add_fn(accumulation, idx as usize);
}
}
entry.accumulation.copy_from_slice(accumulation);
let n = active.len().min(MAX_ACTIVE_FEATURES);
entry.active_indices[..n].copy_from_slice(&active[..n]);
entry.num_active = n as u16;
entry.valid = true;
}
#[inline]
fn apply_diff<FA, FS>(
cached: &[u32],
current: &[u32],
accumulation: &mut [i16; NNUE_PYTORCH_L1],
add_fn: &FA,
sub_fn: &FS,
) where
FA: Fn(&mut [i16; NNUE_PYTORCH_L1], usize),
FS: Fn(&mut [i16; NNUE_PYTORCH_L1], usize),
{
let mut ci = 0;
let mut ni = 0;
while ci < cached.len() && ni < current.len() {
let c = cached[ci];
let n = current[ni];
if c < n {
sub_fn(accumulation, c as usize);
ci += 1;
} else if c > n {
add_fn(accumulation, n as usize);
ni += 1;
} else {
ci += 1;
ni += 1;
}
}
while ci < cached.len() {
sub_fn(accumulation, cached[ci] as usize);
ci += 1;
}
while ni < current.len() {
add_fn(accumulation, current[ni] as usize);
ni += 1;
}
}
}
impl Default for AccumulatorCacheLayerStacks {
fn default() -> Self {
Self::new()
}
}
/// One frame of the accumulator stack: the accumulator for a position plus
/// the bookkeeping needed to update it incrementally from an ancestor frame.
pub struct StackEntryLayerStacks {
    /// Accumulator state for this frame's position.
    pub accumulator: AccumulatorLayerStacks,
    /// Piece changes of the move leading into this frame.
    pub dirty_piece: DirtyPiece,
    /// Index of the parent frame in the stack, `None` for the root.
    pub previous: Option<usize>,
    /// Cached progress value — presumably a game-phase/progress score;
    /// confirm against the code that fills it.
    pub progress_sum: f32,
    /// True once `progress_sum` has been computed for this frame.
    pub computed_progress: bool,
}
impl StackEntryLayerStacks {
    /// Creates a fresh frame: cleared accumulator, default dirty-piece record,
    /// no parent link, and progress state reset.
    pub fn new() -> Self {
        let accumulator = AccumulatorLayerStacks::new();
        Self {
            previous: None,
            computed_progress: false,
            progress_sum: 0.0,
            dirty_piece: DirtyPiece::default(),
            accumulator,
        }
    }
}
impl Default for StackEntryLayerStacks {
fn default() -> Self {
Self::new()
}
}
/// Fixed-capacity stack of accumulator frames, one per ply, addressed by a
/// cursor (`current`) that moves with push/pop.
pub struct AccumulatorStackLayerStacks {
    /// Backing storage of `STACK_SIZE` frames.
    entries: Box<[StackEntryLayerStacks]>,
    /// Index of the active (top) frame.
    current: usize,
}
impl AccumulatorStackLayerStacks {
    /// Stack capacity: maximum ply plus headroom — TODO confirm what the
    /// extra 16 slots cover (e.g. extensions past MAX_PLY).
    const STACK_SIZE: usize = (MAX_PLY as usize) + 16;

    /// Creates a stack of `STACK_SIZE` fresh frames with the cursor at 0.
    pub fn new() -> Self {
        let entries: Vec<StackEntryLayerStacks> =
            (0..Self::STACK_SIZE).map(|_| StackEntryLayerStacks::new()).collect();
        Self {
            entries: entries.into_boxed_slice(),
            current: 0,
        }
    }

    /// Borrows the frame at the cursor.
    #[inline]
    pub fn current(&self) -> &StackEntryLayerStacks {
        &self.entries[self.current]
    }

    /// Mutably borrows the frame at the cursor.
    #[inline]
    pub fn current_mut(&mut self) -> &mut StackEntryLayerStacks {
        &mut self.entries[self.current]
    }

    /// Returns the cursor index.
    #[inline]
    pub fn current_index(&self) -> usize {
        self.current
    }

    /// Borrows the frame at `index`; panics if out of bounds.
    #[inline]
    pub fn entry_at(&self, index: usize) -> &StackEntryLayerStacks {
        &self.entries[index]
    }

    /// Mutably borrows the frame at `index`; panics if out of bounds.
    #[inline]
    pub fn entry_at_mut(&mut self, index: usize) -> &mut StackEntryLayerStacks {
        &mut self.entries[index]
    }

    /// Advances the cursor and resets the new top frame: links it back to the
    /// previous frame and clears every computed/dirty flag so stale state
    /// from an earlier search cannot leak in.
    #[inline]
    pub fn push(&mut self) {
        let prev = self.current;
        self.current += 1;
        debug_assert!(self.current < Self::STACK_SIZE);
        self.entries[self.current].previous = Some(prev);
        self.entries[self.current].accumulator.computed_accumulation = false;
        self.entries[self.current].accumulator.computed_score = false;
        self.entries[self.current].dirty_piece = DirtyPiece::default();
        self.entries[self.current].computed_progress = false;
    }

    /// Retreats the cursor by one; the popped frame's data is left in place
    /// (it is fully reset by the next `push`).
    #[inline]
    pub fn pop(&mut self) {
        debug_assert!(self.current > 0);
        self.current -= 1;
    }

    /// Returns the accumulator at `prev_idx` (shared) and the current frame's
    /// accumulator (mutable) at the same time, which a plain double index
    /// would not borrow-check. Requires `prev_idx < current`.
    #[inline]
    pub fn get_prev_and_current_accumulators(
        &mut self,
        prev_idx: usize,
    ) -> (&AccumulatorLayerStacks, &mut AccumulatorLayerStacks) {
        let cur_idx = self.current;
        debug_assert!(prev_idx < cur_idx, "prev_idx ({prev_idx}) must be < cur_idx ({cur_idx})");
        // `left` covers [0, cur_idx); `right[0]` is exactly the current frame.
        let (left, right) = self.entries.split_at_mut(cur_idx);
        (&left[prev_idx].accumulator, &mut right[0].accumulator)
    }

    /// Rewinds the cursor to the root frame and clears its computed state
    /// and parent link.
    #[inline]
    pub fn reset(&mut self) {
        self.current = 0;
        self.entries[0].accumulator.computed_accumulation = false;
        self.entries[0].accumulator.computed_score = false;
        self.entries[0].previous = None;
        self.entries[0].computed_progress = false;
    }

    /// Walks up to `MAX_DEPTH` ancestors of the current frame looking for one
    /// with a computed accumulation, returning `(ancestor_index, distance)`.
    /// Returns `None` if either king moved anywhere on the path (a king move
    /// changes the feature basis, so an incremental update is not possible)
    /// or no computed ancestor is found within the depth limit.
    ///
    /// NOTE(review): an ancestor's `dirty_piece` is checked only *after* its
    /// `computed_accumulation` flag — presumably because a computed source's
    /// own incoming move need not be replayed by the update path; confirm
    /// against the incremental-update code before restructuring this loop.
    pub fn find_usable_accumulator(&self) -> Option<(usize, usize)> {
        const MAX_DEPTH: usize = 4;
        let current = &self.entries[self.current];
        if current.dirty_piece.king_moved[0] || current.dirty_piece.king_moved[1] {
            return None;
        }
        let mut prev_idx = current.previous?;
        let mut depth = 1;
        loop {
            let prev = &self.entries[prev_idx];
            if prev.accumulator.computed_accumulation {
                return Some((prev_idx, depth));
            }
            if depth >= MAX_DEPTH {
                return None;
            }
            let next_prev_idx = prev.previous?;
            if prev.dirty_piece.king_moved[0] || prev.dirty_piece.king_moved[1] {
                return None;
            }
            prev_idx = next_prev_idx;
            depth += 1;
        }
    }

    /// Collects the frame indices from `source_idx` (exclusive) up to the
    /// cursor (inclusive), ordered oldest-first. Returns `None` if the path
    /// exceeds `MAX_PATH_LENGTH` or `source_idx` is not an ancestor of the
    /// current frame.
    pub fn collect_path(&self, source_idx: usize) -> Option<IndexList<MAX_PATH_LENGTH>> {
        self.collect_path_internal(source_idx)
    }

    /// Implementation of [`Self::collect_path`]: follows `previous` links
    /// downward, then reverses the collected list into oldest-first order.
    fn collect_path_internal(&self, source_idx: usize) -> Option<IndexList<MAX_PATH_LENGTH>> {
        let mut path = IndexList::new();
        let mut idx = self.current;
        while idx != source_idx {
            // `push` returns false when the list is full — path too long.
            if !path.push(idx) {
                return None;
            }
            match self.entries[idx].previous {
                Some(prev) => idx = prev,
                // Hit the root without finding `source_idx`.
                None => return None,
            }
        }
        path.reverse();
        Some(path)
    }
}
impl Default for AccumulatorStackLayerStacks {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh accumulator is unflagged and correctly sized.
    #[test]
    fn test_accumulator_new() {
        let acc = AccumulatorLayerStacks::new();
        assert!(!acc.computed_accumulation);
        assert_eq!(acc.accumulation[0].len(), NNUE_PYTORCH_L1);
    }

    /// Push links the new frame to its parent; pop rewinds the cursor.
    #[test]
    fn test_stack_push_pop() {
        let mut stack = AccumulatorStackLayerStacks::new();
        assert_eq!(stack.current_index(), 0);
        stack.push();
        assert_eq!(stack.current_index(), 1);
        assert_eq!(stack.current().previous, Some(0));
        stack.pop();
        assert_eq!(stack.current_index(), 0);
    }

    /// Every entry of a freshly created cache starts invalid.
    #[test]
    fn test_cache_new_is_invalid() {
        let cache = AccumulatorCacheLayerStacks::new();
        for sq in 0..Square::NUM {
            let king_sq = unsafe { Square::from_u8_unchecked(sq as u8) };
            for perspective in [Color::Black, Color::White] {
                let entry = &cache.entries[king_sq.raw() as usize][perspective as usize];
                assert!(!entry.valid);
            }
        }
    }

    /// `invalidate` clears the `valid` flag on every entry.
    #[test]
    fn test_cache_invalidate() {
        let mut cache = AccumulatorCacheLayerStacks::new();
        cache.entries[0][0].valid = true;
        cache.entries[40][1].valid = true;
        cache.invalidate();
        assert!(!cache.entries[0][0].valid);
        assert!(!cache.entries[40][1].valid);
    }

    // BUGFIX: the four `apply_diff` calls below passed the mojibake token
    // `¤t` (a mis-decoded `&current`), which is not valid Rust; each is
    // restored to `&current`.

    /// Mixed add/remove: 9 - 1 - 5 + 2 + 4 = 9.
    #[test]
    fn test_apply_diff_basic() {
        let mut acc = [0i16; NNUE_PYTORCH_L1];
        acc[0] = 9;
        let cached = [1u32, 3, 5];
        let current = [2u32, 3, 4];
        AccumulatorCacheLayerStacks::apply_diff(
            &cached,
            &current,
            &mut acc,
            &|a, idx| a[0] = a[0].wrapping_add(idx as i16),
            &|a, idx| a[0] = a[0].wrapping_sub(idx as i16),
        );
        assert_eq!(acc[0], 9);
    }

    /// Empty cache set: every current index is added.
    #[test]
    fn test_apply_diff_all_added() {
        let mut acc = [0i16; NNUE_PYTORCH_L1];
        let cached: [u32; 0] = [];
        let current = [10u32, 20, 30];
        AccumulatorCacheLayerStacks::apply_diff(
            &cached,
            &current,
            &mut acc,
            &|a, idx| a[0] = a[0].wrapping_add(idx as i16),
            &|a, idx| a[0] = a[0].wrapping_sub(idx as i16),
        );
        assert_eq!(acc[0], 60);
    }

    /// Empty current set: every cached index is subtracted.
    #[test]
    fn test_apply_diff_all_removed() {
        let mut acc = [0i16; NNUE_PYTORCH_L1];
        acc[0] = 60;
        let cached = [10u32, 20, 30];
        let current: [u32; 0] = [];
        AccumulatorCacheLayerStacks::apply_diff(
            &cached,
            &current,
            &mut acc,
            &|a, idx| a[0] = a[0].wrapping_add(idx as i16),
            &|a, idx| a[0] = a[0].wrapping_sub(idx as i16),
        );
        assert_eq!(acc[0], 0);
    }

    /// Identical sets: the accumulator is untouched.
    #[test]
    fn test_apply_diff_identical() {
        let mut acc = [0i16; NNUE_PYTORCH_L1];
        acc[0] = 42;
        let cached = [1u32, 2, 3, 4, 5];
        let current = [1u32, 2, 3, 4, 5];
        AccumulatorCacheLayerStacks::apply_diff(
            &cached,
            &current,
            &mut acc,
            &|a, idx| a[0] = a[0].wrapping_add(idx as i16),
            &|a, idx| a[0] = a[0].wrapping_sub(idx as i16),
        );
        assert_eq!(acc[0], 42);
    }

    /// Cold start: accumulator = biases + sum of active indices; the entry
    /// is cached afterwards.
    #[test]
    fn test_refresh_or_cache_cold_start() {
        let mut cache = AccumulatorCacheLayerStacks::new();
        let king_sq = Square::SQ_55;
        let perspective = Color::Black;
        let mut biases = [0i16; NNUE_PYTORCH_L1];
        biases[0] = 100;
        biases[1] = 200;
        let active = [5u32, 10, 15];
        let mut accumulation = [0i16; NNUE_PYTORCH_L1];
        cache.refresh_or_cache(
            king_sq,
            perspective,
            &active,
            &biases,
            &mut accumulation,
            |acc, idx| acc[0] = acc[0].wrapping_add(idx as i16),
            |acc, idx| acc[0] = acc[0].wrapping_sub(idx as i16),
        );
        assert_eq!(accumulation[0], 130);
        assert_eq!(accumulation[1], 200);
        let entry = &cache.entries[king_sq.raw() as usize][perspective as usize];
        assert!(entry.valid);
        assert_eq!(entry.num_active, 3);
    }

    /// Warm hit: the second refresh is computed as a diff (30 - 15 + 20 = 35).
    #[test]
    fn test_refresh_or_cache_hit() {
        let mut cache = AccumulatorCacheLayerStacks::new();
        let king_sq = Square::SQ_55;
        let perspective = Color::Black;
        let biases = [0i16; NNUE_PYTORCH_L1];
        let active1 = [5u32, 10, 15];
        let mut acc1 = [0i16; NNUE_PYTORCH_L1];
        cache.refresh_or_cache(
            king_sq,
            perspective,
            &active1,
            &biases,
            &mut acc1,
            |acc, idx| acc[0] = acc[0].wrapping_add(idx as i16),
            |acc, idx| acc[0] = acc[0].wrapping_sub(idx as i16),
        );
        assert_eq!(acc1[0], 30);
        let active2 = [5u32, 10, 20];
        let mut acc2 = [0i16; NNUE_PYTORCH_L1];
        cache.refresh_or_cache(
            king_sq,
            perspective,
            &active2,
            &biases,
            &mut acc2,
            |acc, idx| acc[0] = acc[0].wrapping_add(idx as i16),
            |acc, idx| acc[0] = acc[0].wrapping_sub(idx as i16),
        );
        assert_eq!(acc2[0], 35);
    }
}