use super::bona_piece::ExtBonaPiece;
use super::constants::{NUM_REFRESH_TRIGGERS, TRANSFORMED_FEATURE_DIMENSIONS};
use super::piece_list::PieceNumber;
use crate::types::{Color, MAX_PLY, Value};
use std::alloc::{Layout, alloc_zeroed, dealloc};
use std::mem::MaybeUninit;
use std::ops::{Deref, DerefMut};
/// Capacity bound for changed-feature lists.
/// NOTE(review): not referenced in this file; presumably the maximum number of
/// features added/removed by one incremental update — confirm against users.
pub const MAX_CHANGED_FEATURES: usize = 16;
/// Capacity bound for active-feature lists.
/// NOTE(review): not referenced in this file; presumably the maximum number of
/// simultaneously active features — confirm against users.
pub const MAX_ACTIVE_FEATURES: usize = 54;
/// Maximum number of entries `AccumulatorStack::collect_path` will return when
/// walking back from the current stack entry to a source entry; longer paths
/// make it return `None`.
pub const MAX_PATH_LENGTH: usize = 8;
/// Inline, fixed-capacity list of at most `N` indices (no heap allocation).
///
/// `N` must fit in a `u8`; the bound is enforced at compile time whenever
/// [`IndexList::new`] is instantiated for a given `N`.
#[derive(Clone, Copy)]
pub struct IndexList<const N: usize> {
    // Invariant: exactly the first `len` slots have been written.
    indices: [MaybeUninit<usize>; N],
    len: u8,
}
impl<const N: usize> IndexList<N> {
    /// Returns an empty list.
    #[inline]
    pub fn new() -> Self {
        // Post-monomorphization check: `len` is a `u8`, so `N` may not exceed 255.
        const { assert!(N <= u8::MAX as usize, "IndexList: N must be <= 255") };
        IndexList {
            len: 0,
            indices: [const { MaybeUninit::uninit() }; N],
        }
    }
    /// Appends `index`, returning `false` (and debug-asserting) when the list is full.
    #[inline]
    #[must_use]
    pub fn push(&mut self, index: usize) -> bool {
        let pos = usize::from(self.len);
        match self.indices.get_mut(pos) {
            Some(slot) => {
                slot.write(index);
                self.len += 1;
                true
            }
            None => {
                debug_assert!(false, "IndexList overflow: capacity={N}, len={pos}");
                false
            }
        }
    }
    /// Iterates over the stored indices in insertion order.
    #[inline]
    pub fn iter(&self) -> impl ExactSizeIterator<Item = &usize> + '_ {
        let written = &self.indices[..self.len()];
        // SAFETY: every slot below `len` was initialized by `push` before `len` grew.
        written.iter().map(|slot| unsafe { slot.assume_init_ref() })
    }
    /// Number of stored indices.
    #[inline]
    pub fn len(&self) -> usize {
        usize::from(self.len)
    }
    /// `true` when nothing has been pushed.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Reverses the stored indices in place (uninitialized slots are untouched).
    #[inline]
    pub fn reverse(&mut self) {
        let used = self.len();
        self.indices[..used].reverse();
    }
}
impl<const N: usize> Default for IndexList<N> {
    /// Same as [`IndexList::new`].
    fn default() -> Self {
        IndexList::<N>::new()
    }
}
/// Transparent-ish wrapper that raises the alignment of `T` to 64 bytes.
///
/// With `repr(C, align(64))` the wrapped value sits at offset 0 and the
/// wrapper's alignment (and therefore its size) is rounded up to 64, so e.g.
/// every `Aligned<[i16; N]>` element of an array starts on a 64-byte boundary.
#[repr(C, align(64))]
#[derive(Clone, Copy)]
pub struct Aligned<T: Copy>(pub T);
impl<T: Copy> Aligned<T> {
    /// Creates a value without per-field initialization.
    ///
    /// Despite the name, the contents are zero-filled. Materializing truly
    /// uninitialized memory as a `T` (`MaybeUninit::uninit().assume_init()`) is
    /// undefined behavior even for plain integer arrays, so zeroing is the
    /// cheapest well-defined equivalent.
    ///
    /// # Safety
    ///
    /// The all-zero bit pattern must be a valid value of `T` (true for integer
    /// and float arrays; false for references, `NonZero*`, most enums, ...).
    #[inline]
    pub unsafe fn new_uninit() -> Self {
        // SAFETY: the caller guarantees all-zero bytes form a valid `T`; the
        // wrapper adds only padding, for which any bytes are valid.
        unsafe { std::mem::zeroed() }
    }
}
impl<T: Default + Copy> Default for Aligned<T> {
    /// Wraps `T::default()`.
    fn default() -> Self {
        Self(T::default())
    }
}
/// Minimum alignment, in bytes, of [`AlignedBox`] allocations.
pub const CACHE_LINE_SIZE: usize = 64;
/// Fixed-length, zero-initialized heap slice whose start is aligned to at
/// least [`CACHE_LINE_SIZE`] bytes. Dereferences to `[T]`; freed on drop.
pub struct AlignedBox<T> {
    // Start of the slice: the allocation described by `layout`, or a dangling
    // (non-null, `T`-aligned) pointer when `layout.size() == 0`.
    ptr: *mut T,
    // Number of `T` elements in the slice.
    len: usize,
    // Layout handed to the allocator; required again by `dealloc`.
    layout: Layout,
}
impl<T: Copy + Default> AlignedBox<T> {
    /// Allocates `len` elements whose bytes are all zero.
    ///
    /// Zero-size requests (`len == 0` or a zero-sized `T`) never reach the
    /// allocator, because `alloc_zeroed` must not be called with a zero-size
    /// layout.
    ///
    /// NOTE(review): the bound is `T: Copy + Default`, but the contents are
    /// zero bytes, not `T::default()`. This is only sound for types whose
    /// all-zero bit pattern is valid (integers, floats, arrays of them).
    ///
    /// # Panics
    ///
    /// Panics if `size_of::<T>() * len` overflows or is rejected by [`Layout`];
    /// calls [`std::alloc::handle_alloc_error`] if the allocation fails.
    pub fn new_zeroed(len: usize) -> Self {
        let size = std::mem::size_of::<T>()
            .checked_mul(len)
            .expect("AlignedBox::new_zeroed: size overflow");
        let align = CACHE_LINE_SIZE.max(std::mem::align_of::<T>());
        let layout = Layout::from_size_align(size, align).expect("Invalid layout").pad_to_align();
        let ptr = if layout.size() == 0 {
            // An empty slice (or a slice of ZSTs) only needs a non-null,
            // `T`-aligned pointer; nothing is allocated, so `Drop` skips `dealloc`.
            std::ptr::NonNull::<T>::dangling().as_ptr()
        } else {
            // SAFETY: `layout` has a non-zero size.
            let raw = unsafe { alloc_zeroed(layout) } as *mut T;
            if raw.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            raw
        };
        Self { ptr, len, layout }
    }
}
impl<T> Deref for AlignedBox<T> {
    type Target = [T];
    fn deref(&self) -> &Self::Target {
        // SAFETY: `ptr` is non-null, aligned for `T`, and covers `len`
        // initialized elements owned by `self`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
    }
}
impl<T> DerefMut for AlignedBox<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: as in `deref`; `&mut self` guarantees exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
    }
}
impl<T> Drop for AlignedBox<T> {
    fn drop(&mut self) {
        // Zero-size boxes hold a dangling pointer that was never allocated.
        if self.layout.size() != 0 {
            // SAFETY: `ptr` came from `alloc_zeroed(self.layout)` and is freed exactly once.
            unsafe {
                dealloc(self.ptr as *mut u8, self.layout);
            }
        }
    }
}
impl<T: Copy + Default> Clone for AlignedBox<T> {
    /// Allocates a fresh aligned buffer of the same length and copies the contents.
    fn clone(&self) -> Self {
        let mut new_box = Self::new_zeroed(self.len);
        new_box.copy_from_slice(self);
        new_box
    }
}
// SAFETY: `AlignedBox<T>` uniquely owns its buffer (like `Box<[T]>`), so it is
// `Send`/`Sync` exactly when `T` is.
unsafe impl<T: Send> Send for AlignedBox<T> {}
unsafe impl<T: Sync> Sync for AlignedBox<T> {}
/// Accumulator state: one row of `TRANSFORMED_FEATURE_DIMENSIONS` `i16` values
/// per perspective (2) and per refresh trigger, plus cached-result flags.
///
/// The struct is 64-byte aligned, and each row is additionally wrapped in
/// [`Aligned`].
#[repr(C, align(64))]
#[derive(Clone)]
pub struct Accumulator {
    /// Stored rows, indexed as `[perspective][trigger]` (see [`Accumulator::get`]).
    pub accumulation: [[Aligned<[i16; TRANSFORMED_FEATURE_DIMENSIONS]>; NUM_REFRESH_TRIGGERS]; 2],
    /// Cached score; presumably meaningful only while `computed_score` is set.
    pub score: Value,
    /// Marks `accumulation` as computed (read by
    /// `AccumulatorStack::find_usable_accumulator`); cleared by [`Accumulator::reset`].
    pub computed_accumulation: bool,
    /// Marks `score` as computed; cleared by [`Accumulator::reset`].
    pub computed_score: bool,
}
impl Default for Accumulator {
    /// All rows zero, score `Value::ZERO`, both flags cleared.
    fn default() -> Self {
        let zero_row = Aligned([0i16; TRANSFORMED_FEATURE_DIMENSIONS]);
        let per_perspective = [zero_row; NUM_REFRESH_TRIGGERS];
        Accumulator {
            accumulation: [per_perspective; 2],
            score: Value::ZERO,
            computed_accumulation: false,
            computed_score: false,
        }
    }
}
impl Accumulator {
    /// Same as [`Accumulator::default`].
    #[inline]
    pub fn new() -> Self {
        <Self as Default>::default()
    }
    /// Clears both "computed" flags; the stored rows and score are left untouched.
    #[inline]
    pub fn reset(&mut self) {
        self.computed_score = false;
        self.computed_accumulation = false;
    }
    /// Borrows the row for `perspective` and `trigger`.
    ///
    /// # Panics
    /// Panics if either index is out of range.
    #[inline]
    pub fn get(
        &self,
        perspective: usize,
        trigger: usize,
    ) -> &[i16; TRANSFORMED_FEATURE_DIMENSIONS] {
        let row = &self.accumulation[perspective][trigger];
        &row.0
    }
    /// Mutably borrows the row for `perspective` and `trigger`.
    ///
    /// # Panics
    /// Panics if either index is out of range.
    #[inline]
    pub fn get_mut(
        &mut self,
        perspective: usize,
        trigger: usize,
    ) -> &mut [i16; TRANSFORMED_FEATURE_DIMENSIONS] {
        let row = &mut self.accumulation[perspective][trigger];
        &mut row.0
    }
}
/// Record of the pieces changed by one position update (presumably a single
/// move); stored into the new entry by `AccumulatorStack::push`.
#[derive(Clone, Copy)]
pub struct DirtyPiece {
    /// Piece numbers of up to two changed pieces.
    pub piece_no: [PieceNumber; 2],
    /// Old/new encodings, presumably parallel to `piece_no`.
    pub changed_piece: [ChangedBonaPiece; 2],
    /// Presumably the number of valid entries in the two arrays above;
    /// `clear` resets it without touching them.
    pub dirty_num: u8,
    /// Per-color flag; any `true` makes `AccumulatorStack::find_usable_accumulator`
    /// refuse incremental reuse through this entry.
    pub king_moved: [bool; Color::NUM],
}
impl DirtyPiece {
    /// Returns a record with no changed pieces and no king moves.
    #[inline]
    pub const fn new() -> Self {
        DirtyPiece {
            dirty_num: 0,
            king_moved: [false; Color::NUM],
            piece_no: [PieceNumber::NONE; 2],
            changed_piece: [ChangedBonaPiece::EMPTY; 2],
        }
    }
    /// Resets the count and the king-move flags; the piece arrays keep stale data.
    #[inline]
    pub fn clear(&mut self) {
        self.king_moved.fill(false);
        self.dirty_num = 0;
    }
}
impl Default for DirtyPiece {
    /// Same as [`DirtyPiece::new`].
    fn default() -> Self {
        DirtyPiece::new()
    }
}
/// Old and new [`ExtBonaPiece`] values for one changed piece.
#[derive(Clone, Copy)]
pub struct ChangedBonaPiece {
    /// Value before the change.
    pub old_piece: ExtBonaPiece,
    /// Value after the change.
    pub new_piece: ExtBonaPiece,
}
impl ChangedBonaPiece {
    /// Placeholder with both sides set to `ExtBonaPiece::ZERO`.
    pub const EMPTY: ChangedBonaPiece = ChangedBonaPiece {
        new_piece: ExtBonaPiece::ZERO,
        old_piece: ExtBonaPiece::ZERO,
    };
}
#[repr(C, align(64))]
#[derive(Default)]
pub struct StackEntry {
pub accumulator: Accumulator,
pub dirty_piece: DirtyPiece,
pub previous: Option<usize>,
}
/// Fixed-capacity stack of `MAX_PLY + 1` [`StackEntry`] slots, allocated once.
/// Pushing overwrites the next slot in place; slots above the top are stale.
pub struct AccumulatorStack {
    // All slots; only `0..=current_idx` are live.
    entries: Box<[StackEntry]>,
    // Index of the top (current) slot.
    current_idx: usize,
}
impl AccumulatorStack {
    /// Number of slots.
    pub const SIZE: usize = (MAX_PLY + 1) as usize;

    /// Creates a stack with every slot default-initialized and slot 0 on top.
    pub fn new() -> Self {
        let entries = (0..Self::SIZE)
            .map(|_| StackEntry::default())
            .collect::<Vec<StackEntry>>()
            .into_boxed_slice();
        AccumulatorStack {
            entries,
            current_idx: 0,
        }
    }
    /// Returns to slot 0 and makes it a root: flags invalidated, dirty piece
    /// cleared, no parent. Stored accumulator rows are kept.
    pub fn reset(&mut self) {
        self.current_idx = 0;
        let root = &mut self.entries[0];
        root.accumulator.reset();
        root.dirty_piece.clear();
        root.previous = None;
    }
    /// Borrows the top slot.
    #[inline]
    pub fn current(&self) -> &StackEntry {
        self.entry_at(self.current_idx)
    }
    /// Mutably borrows the top slot.
    #[inline]
    pub fn current_mut(&mut self) -> &mut StackEntry {
        let idx = self.current_idx;
        self.entry_at_mut(idx)
    }
    /// Borrows slot `idx` (panics if out of range).
    #[inline]
    pub fn entry_at(&self, idx: usize) -> &StackEntry {
        &self.entries[idx]
    }
    /// Mutably borrows slot `idx` (panics if out of range).
    #[inline]
    pub fn entry_at_mut(&mut self, idx: usize) -> &mut StackEntry {
        &mut self.entries[idx]
    }
    /// Index of the top slot.
    #[inline]
    pub fn current_index(&self) -> usize {
        self.current_idx
    }
    /// Moves to the next slot, links it to the previous top, invalidates its
    /// cached flags and stores `dirty_piece` in it.
    ///
    /// # Panics
    /// Debug builds assert the stack is not full; in release builds the slot
    /// index itself panics when out of range.
    #[inline]
    pub fn push(&mut self, dirty_piece: DirtyPiece) {
        let parent = self.current_idx;
        let child = parent + 1;
        self.current_idx = child;
        debug_assert!(self.current_idx < Self::SIZE, "AccumulatorStack overflow");
        let slot = &mut self.entries[child];
        slot.previous = Some(parent);
        slot.accumulator.reset();
        slot.dirty_piece = dirty_piece;
    }
    /// Moves back one slot. Popping the root is a logic error (debug-asserted).
    #[inline]
    pub fn pop(&mut self) {
        debug_assert!(self.current_idx > 0, "AccumulatorStack underflow");
        self.current_idx -= 1;
    }
    /// Simultaneously borrows slot `prev_idx`'s accumulator immutably and the
    /// top slot's accumulator mutably. `prev_idx` must be below the top index.
    #[inline]
    pub fn get_prev_and_current_accumulators(
        &mut self,
        prev_idx: usize,
    ) -> (&Accumulator, &mut Accumulator) {
        let cur_idx = self.current_idx;
        debug_assert!(prev_idx < cur_idx, "prev_idx ({prev_idx}) must be < cur_idx ({cur_idx})");
        // `below` holds slots `0..cur_idx`; `from_top` starts at the top slot.
        let (below, from_top) = self.entries.split_at_mut(cur_idx);
        let source = &below[prev_idx].accumulator;
        let target = &mut from_top[0].accumulator;
        (source, target)
    }
    /// Searches the ancestors of the top slot for one whose accumulator is
    /// computed, without crossing an entry where a king moved.
    ///
    /// Returns `(index, depth)`, where depth 1 is the direct parent. Only
    /// depths up to the local `MAX_DEPTH` (currently 1) are considered.
    /// Returns `None` if the top slot recorded a king move, has no parent,
    /// or no suitable ancestor exists within that depth.
    pub fn find_usable_accumulator(&self) -> Option<(usize, usize)> {
        const MAX_DEPTH: usize = 1;
        let king_moved = |entry: &StackEntry| {
            let flags = &entry.dirty_piece.king_moved;
            flags[Color::Black.index()] || flags[Color::White.index()]
        };
        let current = &self.entries[self.current_idx];
        if king_moved(current) {
            return None;
        }
        let mut candidate = current.previous?;
        for depth in 1..=MAX_DEPTH {
            let entry = &self.entries[candidate];
            if entry.accumulator.computed_accumulation {
                return Some((candidate, depth));
            }
            if depth == MAX_DEPTH {
                break;
            }
            // Walking further back means applying `entry`'s own change too,
            // which is impossible if a king moved there.
            let parent = entry.previous?;
            if king_moved(entry) {
                return None;
            }
            candidate = parent;
        }
        None
    }
    /// Collects slot indices from just after `source_idx` up to and including
    /// the top slot, oldest first (empty when `source_idx` is the top).
    ///
    /// Returns `None` if the distance exceeds [`MAX_PATH_LENGTH`] or the parent
    /// chain does not reach `source_idx`. The second case also fails a debug
    /// assertion.
    pub fn collect_path(&self, source_idx: usize) -> Option<IndexList<MAX_PATH_LENGTH>> {
        let distance = self.current_idx.saturating_sub(source_idx);
        if distance > MAX_PATH_LENGTH {
            return None;
        }
        let mut path = IndexList::<MAX_PATH_LENGTH>::new();
        let mut idx = self.current_idx;
        while idx != source_idx {
            if !path.push(idx) {
                debug_assert!(false, "collect_path overflow: MAX_PATH_LENGTH={MAX_PATH_LENGTH}");
                return None;
            }
            let Some(parent) = self.entries[idx].previous else {
                debug_assert!(
                    false,
                    "Path broken: expected to reach source_idx={source_idx} but got None at idx={idx}"
                );
                return None;
            };
            idx = parent;
        }
        // Indices were gathered newest-first; callers want oldest-first.
        path.reverse();
        Some(path)
    }
}
impl Default for AccumulatorStack {
    /// Same as [`AccumulatorStack::new`].
    fn default() -> Self {
        AccumulatorStack::new()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the accumulator types and the accumulator stack.
    use super::*;

    #[test]
    fn test_accumulator_new() {
        let acc = Accumulator::new();
        assert!(!acc.computed_accumulation);
        assert!(!acc.computed_score);
        assert_eq!(acc.score, Value::ZERO);
    }

    #[test]
    fn test_accumulator_reset() {
        let mut acc = Accumulator::new();
        acc.computed_accumulation = true;
        acc.computed_score = true;
        acc.reset();
        assert!(!acc.computed_accumulation);
        assert!(!acc.computed_score);
    }

    #[test]
    fn test_accumulator_get() {
        let mut acc = Accumulator::new();
        acc.accumulation[0][0].0[0] = 100;
        acc.accumulation[1][0].0[0] = 200;
        assert_eq!(acc.get(0, 0)[0], 100);
        assert_eq!(acc.get(1, 0)[0], 200);
    }

    #[test]
    fn test_accumulator_alignment() {
        let acc = Accumulator::new();
        let addr = &acc as *const _ as usize;
        assert_eq!(addr % 64, 0);
    }

    #[test]
    fn test_dirty_piece_new() {
        let dp = DirtyPiece::new();
        assert_eq!(dp.dirty_num, 0);
        assert!(!dp.king_moved[0]);
        assert!(!dp.king_moved[1]);
    }

    #[test]
    fn test_accumulator_stack_push_pop() {
        let mut stack = AccumulatorStack::new();
        assert_eq!(stack.current_index(), 0);
        stack.push(DirtyPiece::new());
        assert_eq!(stack.current_index(), 1);
        assert_eq!(stack.current().previous, Some(0));
        stack.push(DirtyPiece::new());
        assert_eq!(stack.current_index(), 2);
        assert_eq!(stack.current().previous, Some(1));
        stack.pop();
        assert_eq!(stack.current_index(), 1);
        stack.pop();
        assert_eq!(stack.current_index(), 0);
    }

    #[test]
    fn test_accumulator_stack_push_resets_reused_slot() {
        // A slot left "computed" by an earlier branch must be invalidated on reuse.
        let mut stack = AccumulatorStack::new();
        stack.push(DirtyPiece::new());
        stack.current_mut().accumulator.computed_accumulation = true;
        stack.pop();
        stack.push(DirtyPiece::new());
        assert!(!stack.current().accumulator.computed_accumulation);
        assert_eq!(stack.current().previous, Some(0));
    }

    #[test]
    fn test_accumulator_stack_reset() {
        let mut stack = AccumulatorStack::new();
        stack.push(DirtyPiece::new());
        stack.push(DirtyPiece::new());
        stack.current_mut().accumulator.computed_accumulation = true;
        stack.reset();
        assert_eq!(stack.current_index(), 0);
        assert!(!stack.current().accumulator.computed_accumulation);
        assert_eq!(stack.current().previous, None);
    }

    #[test]
    fn test_accumulator_stack_get_prev_and_current() {
        let mut stack = AccumulatorStack::new();
        stack.current_mut().accumulator.accumulation[0][0].0[0] = 7;
        stack.push(DirtyPiece::new());
        let (prev, cur) = stack.get_prev_and_current_accumulators(0);
        cur.accumulation[0][0].0[0] = prev.accumulation[0][0].0[0] + 1;
        assert_eq!(stack.current().accumulator.get(0, 0)[0], 8);
        assert_eq!(stack.entry_at(0).accumulator.get(0, 0)[0], 7);
    }

    #[test]
    fn test_accumulator_stack_find_usable() {
        let mut stack = AccumulatorStack::new();
        stack.current_mut().accumulator.computed_accumulation = true;
        stack.push(DirtyPiece::new());
        let result = stack.find_usable_accumulator();
        assert!(result.is_some());
        let (idx, depth) = result.unwrap();
        assert_eq!(idx, 0);
        assert_eq!(depth, 1);
    }

    #[test]
    fn test_accumulator_stack_find_usable_parent_not_computed() {
        let mut stack = AccumulatorStack::new();
        stack.push(DirtyPiece::new());
        assert!(stack.find_usable_accumulator().is_none());
    }

    #[test]
    fn test_accumulator_stack_find_usable_exceeds_max_depth() {
        let mut stack = AccumulatorStack::new();
        stack.current_mut().accumulator.computed_accumulation = true;
        stack.push(DirtyPiece::new());
        stack.push(DirtyPiece::new());
        let result = stack.find_usable_accumulator();
        assert!(result.is_none());
    }

    #[test]
    fn test_accumulator_stack_find_usable_with_king_move() {
        let mut stack = AccumulatorStack::new();
        stack.current_mut().accumulator.computed_accumulation = true;
        let mut dp = DirtyPiece::new();
        dp.king_moved[Color::Black.index()] = true;
        stack.push(dp);
        let result = stack.find_usable_accumulator();
        assert!(result.is_none());
    }

    #[test]
    fn test_accumulator_stack_find_usable_with_white_king_move() {
        let mut stack = AccumulatorStack::new();
        stack.current_mut().accumulator.computed_accumulation = true;
        let mut dp = DirtyPiece::new();
        dp.king_moved[Color::White.index()] = true;
        stack.push(dp);
        assert!(stack.find_usable_accumulator().is_none());
    }

    #[test]
    fn test_accumulator_stack_collect_path() {
        let mut stack = AccumulatorStack::new();
        for _ in 0..3 {
            stack.push(DirtyPiece::new());
        }
        let path = stack.collect_path(0).expect("a 3-step path fits");
        assert_eq!(path.iter().copied().collect::<Vec<_>>(), vec![1, 2, 3]);
        let empty = stack.collect_path(3).expect("source == current yields an empty path");
        assert!(empty.is_empty());
    }

    #[test]
    fn test_accumulator_stack_collect_path_length_limit() {
        // Assumes AccumulatorStack::SIZE > MAX_PATH_LENGTH + 1.
        let mut stack = AccumulatorStack::new();
        for _ in 0..MAX_PATH_LENGTH {
            stack.push(DirtyPiece::new());
        }
        let full = stack.collect_path(0).expect("a path of exactly MAX_PATH_LENGTH fits");
        assert_eq!(full.len(), MAX_PATH_LENGTH);
        stack.push(DirtyPiece::new());
        assert!(stack.collect_path(0).is_none());
    }
}