use super::accumulator::{DirtyPiece, IndexList, MAX_PATH_LENGTH};
use super::constants::NNUE_PYTORCH_L1;
use crate::types::MAX_PLY;
/// Accumulator values for two perspectives plus two validity flags.
///
/// `repr(C)` fixes the field order below, and `align(64)` forces 64-byte
/// alignment (presumably for cache-line / SIMD-friendly loads — confirm).
#[repr(C, align(64))]
#[derive(Clone)]
pub struct AccumulatorLayerStacks {
    /// One row of `NNUE_PYTORCH_L1` values per perspective (row 0 and row 1).
    pub accumulation: [[i16; NNUE_PYTORCH_L1]; 2],
    /// Set when `accumulation` holds usable values; `find_usable_accumulator`
    /// treats an ancestor with this flag as a valid update source.
    pub computed_accumulation: bool,
    /// Second validity flag, cleared alongside `computed_accumulation`;
    /// presumably marks a cached evaluation — confirm against its readers.
    pub computed_score: bool,
}

impl AccumulatorLayerStacks {
    /// Builds an accumulator with every value zero and both flags cleared.
    pub fn new() -> Self {
        let zero_row = [0; NNUE_PYTORCH_L1];
        Self {
            accumulation: [zero_row; 2],
            computed_accumulation: false,
            computed_score: false,
        }
    }

    /// Shared view of the row for `perspective` (0 or 1).
    ///
    /// # Panics
    /// If `perspective` is not 0 or 1.
    #[inline]
    pub fn get(&self, perspective: usize) -> &[i16; NNUE_PYTORCH_L1] {
        let rows = &self.accumulation;
        &rows[perspective]
    }

    /// Mutable view of the row for `perspective` (0 or 1).
    ///
    /// # Panics
    /// If `perspective` is not 0 or 1.
    #[inline]
    pub fn get_mut(&mut self, perspective: usize) -> &mut [i16; NNUE_PYTORCH_L1] {
        let rows = &mut self.accumulation;
        &mut rows[perspective]
    }
}

impl Default for AccumulatorLayerStacks {
    /// Same as [`AccumulatorLayerStacks::new`].
    fn default() -> Self {
        AccumulatorLayerStacks::new()
    }
}
/// One slot of the accumulator stack.
pub struct StackEntryLayerStacks {
    /// Accumulator state owned by this slot.
    pub accumulator: AccumulatorLayerStacks,
    /// Piece-change record for this slot; its `king_moved` flags are consulted
    /// by `find_usable_accumulator` (remaining semantics defined in `DirtyPiece`).
    pub dirty_piece: DirtyPiece,
    /// Index of the parent slot in the stack; `None` for the root.
    pub previous: Option<usize>,
}

impl StackEntryLayerStacks {
    /// Builds a root-like slot: fresh accumulator, default dirty piece, no parent.
    pub fn new() -> Self {
        let accumulator = AccumulatorLayerStacks::new();
        let dirty_piece = DirtyPiece::default();
        Self {
            accumulator,
            dirty_piece,
            previous: None,
        }
    }
}

impl Default for StackEntryLayerStacks {
    /// Same as [`StackEntryLayerStacks::new`].
    fn default() -> Self {
        StackEntryLayerStacks::new()
    }
}
/// Fixed-capacity stack of [`StackEntryLayerStacks`] slots.
///
/// All slots are allocated once in [`AccumulatorStackLayerStacks::new`];
/// `push`/`pop` only move the `current` cursor and reinitialise the slot
/// being entered.
pub struct AccumulatorStackLayerStacks {
    /// Preallocated slots; length is `STACK_SIZE`.
    entries: Box<[StackEntryLayerStacks]>,
    /// Index of the active slot.
    current: usize,
}

impl AccumulatorStackLayerStacks {
    /// Number of preallocated slots: `MAX_PLY` plus 16 slots of headroom.
    const STACK_SIZE: usize = (MAX_PLY as usize) + 16;

    /// Allocates `STACK_SIZE` fresh slots with the cursor on slot 0.
    pub fn new() -> Self {
        let mut slots = Vec::with_capacity(Self::STACK_SIZE);
        slots.resize_with(Self::STACK_SIZE, StackEntryLayerStacks::new);
        Self {
            entries: slots.into_boxed_slice(),
            current: 0,
        }
    }

    /// Shared reference to the active slot.
    #[inline]
    pub fn current(&self) -> &StackEntryLayerStacks {
        let idx = self.current;
        &self.entries[idx]
    }

    /// Mutable reference to the active slot.
    #[inline]
    pub fn current_mut(&mut self) -> &mut StackEntryLayerStacks {
        let idx = self.current;
        &mut self.entries[idx]
    }

    /// Index of the active slot.
    #[inline]
    pub fn current_index(&self) -> usize {
        self.current
    }

    /// Shared reference to the slot at `index`.
    ///
    /// # Panics
    /// If `index >= STACK_SIZE`.
    #[inline]
    pub fn entry_at(&self, index: usize) -> &StackEntryLayerStacks {
        &self.entries[index]
    }

    /// Mutable reference to the slot at `index`.
    ///
    /// # Panics
    /// If `index >= STACK_SIZE`.
    #[inline]
    pub fn entry_at_mut(&mut self, index: usize) -> &mut StackEntryLayerStacks {
        &mut self.entries[index]
    }

    /// Moves the cursor up one slot and reinitialises that slot.
    ///
    /// The new slot's `previous` points at the old cursor, both accumulator
    /// flags are cleared and `dirty_piece` is reset to its default. The
    /// accumulation values themselves are not cleared.
    ///
    /// # Panics
    /// When the stack is full: via `debug_assert!` in debug builds, via slice
    /// indexing in release builds.
    #[inline]
    pub fn push(&mut self) {
        let parent = self.current;
        self.current = parent + 1;
        debug_assert!(self.current < Self::STACK_SIZE);
        let slot = &mut self.entries[self.current];
        slot.previous = Some(parent);
        slot.accumulator.computed_accumulation = false;
        slot.accumulator.computed_score = false;
        slot.dirty_piece = DirtyPiece::default();
    }

    /// Moves the cursor down one slot. The slot left behind is not modified.
    ///
    /// NOTE: only checked by `debug_assert!`; in release builds (overflow
    /// checks off by default) popping at slot 0 wraps the cursor.
    #[inline]
    pub fn pop(&mut self) {
        debug_assert!(self.current > 0);
        self.current -= 1;
    }

    /// Borrows the accumulator at `prev_idx` immutably and the active one
    /// mutably at the same time.
    ///
    /// # Panics
    /// If `prev_idx >= current_index()` (debug assertion in debug builds,
    /// slice indexing in release builds).
    #[inline]
    pub fn get_prev_and_current_accumulators(
        &mut self,
        prev_idx: usize,
    ) -> (&AccumulatorLayerStacks, &mut AccumulatorLayerStacks) {
        let cur_idx = self.current;
        debug_assert!(prev_idx < cur_idx, "prev_idx ({prev_idx}) must be < cur_idx ({cur_idx})");
        // Everything strictly below the cursor goes to `below`; the cursor slot
        // is the first element of `from_cursor`, so the two borrows are disjoint.
        let (below, from_cursor) = self.entries.split_at_mut(cur_idx);
        let source = &below[prev_idx].accumulator;
        let target = &mut from_cursor[0].accumulator;
        (source, target)
    }

    /// Returns the cursor to slot 0 and marks that slot as an uncomputed root.
    #[inline]
    pub fn reset(&mut self) {
        self.current = 0;
        let root = &mut self.entries[0];
        root.previous = None;
        root.accumulator.computed_accumulation = false;
        root.accumulator.computed_score = false;
    }

    /// Finds an ancestor slot whose accumulator is already computed.
    ///
    /// Searches at most `MAX_DEPTH` (currently 1) parent links from the active
    /// slot. It gives up (returns `None`) if the active slot, or any slot it
    /// would have to step through, records a king move for either side. On
    /// success it returns `(ancestor_index, depth)`.
    pub fn find_usable_accumulator(&self) -> Option<(usize, usize)> {
        const MAX_DEPTH: usize = 1;
        let king_moved =
            |slot: &StackEntryLayerStacks| slot.dirty_piece.king_moved[0] || slot.dirty_piece.king_moved[1];

        let head = &self.entries[self.current];
        if king_moved(head) {
            return None;
        }

        let mut candidate = head.previous?;
        for depth in 1..=MAX_DEPTH {
            let slot = &self.entries[candidate];
            if slot.accumulator.computed_accumulation {
                return Some((candidate, depth));
            }
            // Stepping further back would need to replay this slot's change,
            // which is not possible across a king move.
            if king_moved(slot) {
                return None;
            }
            candidate = slot.previous?;
        }
        None
    }

    /// Collects the slot indices from just after `source_idx` up to and
    /// including the active slot, ordered from the one nearest `source_idx`
    /// to the active slot.
    ///
    /// Returns `None` if `source_idx` is not reached by following `previous`
    /// links, or if `IndexList::push` rejects an index (presumably because
    /// `MAX_PATH_LENGTH` is exhausted — confirm in `IndexList`).
    pub fn collect_path(&self, source_idx: usize) -> Option<IndexList<MAX_PATH_LENGTH>> {
        self.collect_path_internal(source_idx)
    }

    /// Implementation of [`Self::collect_path`].
    fn collect_path_internal(&self, source_idx: usize) -> Option<IndexList<MAX_PATH_LENGTH>> {
        let mut path = IndexList::new();
        let mut cursor = self.current;
        while cursor != source_idx {
            if !path.push(cursor) {
                return None;
            }
            cursor = self.entries[cursor].previous?;
        }
        // The walk ran from the active slot backwards; flip to forward order.
        path.reverse();
        Some(path)
    }
}

impl Default for AccumulatorStackLayerStacks {
    /// Same as [`AccumulatorStackLayerStacks::new`].
    fn default() -> Self {
        AccumulatorStackLayerStacks::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh accumulator is zero-filled with both validity flags cleared.
    #[test]
    fn test_accumulator_new() {
        let acc = AccumulatorLayerStacks::new();
        assert!(!acc.computed_accumulation);
        assert!(!acc.computed_score);
        assert_eq!(acc.accumulation[0].len(), NNUE_PYTORCH_L1);
        assert!(acc.accumulation.iter().flatten().all(|&v| v == 0));
    }

    /// Push links the new slot to its parent; pop restores the cursor.
    #[test]
    fn test_stack_push_pop() {
        let mut stack = AccumulatorStackLayerStacks::new();
        assert_eq!(stack.current_index(), 0);
        stack.push();
        assert_eq!(stack.current_index(), 1);
        assert_eq!(stack.current().previous, Some(0));
        stack.pop();
        assert_eq!(stack.current_index(), 0);
    }

    /// Re-entering a slot after pop must not inherit its stale flags.
    #[test]
    fn test_push_clears_stale_flags() {
        let mut stack = AccumulatorStackLayerStacks::new();
        stack.push();
        stack.current_mut().accumulator.computed_accumulation = true;
        stack.current_mut().accumulator.computed_score = true;
        stack.pop();
        stack.push();
        assert!(!stack.current().accumulator.computed_accumulation);
        assert!(!stack.current().accumulator.computed_score);
        assert_eq!(stack.current().previous, Some(0));
    }

    /// Reset returns to slot 0 and marks it as an uncomputed root.
    #[test]
    fn test_reset_returns_to_root() {
        let mut stack = AccumulatorStackLayerStacks::new();
        stack.push();
        stack.push();
        stack.entry_at_mut(0).accumulator.computed_accumulation = true;
        stack.entry_at_mut(0).accumulator.computed_score = true;
        stack.reset();
        assert_eq!(stack.current_index(), 0);
        assert!(!stack.current().accumulator.computed_accumulation);
        assert!(!stack.current().accumulator.computed_score);
        assert_eq!(stack.current().previous, None);
    }

    /// The split borrow yields the parent read-only and the active slot writable.
    #[test]
    fn test_prev_and_current_accumulators_are_distinct() {
        let mut stack = AccumulatorStackLayerStacks::new();
        stack.entry_at_mut(0).accumulator.get_mut(0)[0] = 7;
        stack.push();
        let (prev, cur) = stack.get_prev_and_current_accumulators(0);
        cur.get_mut(0)[0] = prev.get(0)[0] + 1;
        assert_eq!(stack.current().accumulator.get(0)[0], 8);
        assert_eq!(stack.entry_at(0).accumulator.get(0)[0], 7);
    }

    /// The parent is usable only once computed and only without a king move.
    #[test]
    fn test_find_usable_accumulator() {
        let mut stack = AccumulatorStackLayerStacks::new();
        stack.push();
        stack.current_mut().dirty_piece.king_moved[0] = false;
        stack.current_mut().dirty_piece.king_moved[1] = false;
        assert_eq!(stack.find_usable_accumulator(), None);
        stack.entry_at_mut(0).accumulator.computed_accumulation = true;
        assert_eq!(stack.find_usable_accumulator(), Some((0, 1)));
        stack.current_mut().dirty_piece.king_moved[0] = true;
        assert_eq!(stack.find_usable_accumulator(), None);
    }

    /// Paths exist only toward ancestors reachable via `previous` links.
    #[test]
    fn test_collect_path() {
        let mut stack = AccumulatorStackLayerStacks::new();
        stack.push();
        assert!(stack.collect_path(0).is_some());
        // Source equal to the active slot yields an empty (but present) path.
        assert!(stack.collect_path(1).is_some());
        // Index 5 is never reached walking down from slot 1.
        assert!(stack.collect_path(5).is_none());
    }
}