use std::ptr::NonNull;
use std::mem;
use std::alloc::{Layout, handle_alloc_error, self};
use std::cmp::min;
use std::cell::UnsafeCell;
use super::PreFetchable;
use core::piece_move::BitMove;
use super::prefetch_write;
/// Full 64-bit position hash; the top 16 bits are stored in each `Entry`
/// as a partial key (see `Entry::place` / `TranspositionTable::probe`).
pub type Key = u64;
/// Mask for the upper six bits of `NodeTypeTimeBound::data`, which hold the
/// entry's search generation ("time"); generations step by 4 (see
/// `TranspositionTable::new_search`).
pub const TIME_MASK: u8 = 0b1111_1100;
/// Mask for the lower two bits of `NodeTypeTimeBound::data`, which hold the
/// `NodeBound` discriminant.
pub const NODE_TYPE_MASK: u8 = 0b0000_0011;
/// Number of `Entry`s stored per `Cluster`.
pub const CLUSTER_SIZE: usize = 3;
// Decimal (SI) units: 1000 bytes per KB, not 1024.
const BYTES_PER_KB: usize = 1000;
const BYTES_PER_MB: usize = BYTES_PER_KB * 1000;
const BYTES_PER_GB: usize = BYTES_PER_MB * 1000;
/// Kind of bound a stored score represents. The discriminant occupies the
/// low two bits of `NodeTypeTimeBound::data` (see `NODE_TYPE_MASK`).
#[derive(Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum NodeBound {
/// Slot holds no usable bound (zeroed / empty entry).
NoBound = 0,
/// Score is a lower bound (fail-high).
LowerBound = 1,
/// Score is an upper bound (fail-low).
UpperBound = 2,
/// Score is exact.
Exact = 3,
}
/// Single byte packing an entry's bound type (low two bits, see
/// `NODE_TYPE_MASK`) together with its search generation (upper six bits,
/// see `TIME_MASK`).
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct NodeTypeTimeBound {
// Packed byte: bits 2..8 = generation, bits 0..2 = NodeBound.
data: u8
}
impl NodeTypeTimeBound {
    /// Packs a bound type and a generation into one byte. `time_bound` is
    /// expected to have its low two bits clear (generations step by 4), so
    /// the addition acts as a bitwise combine.
    pub fn create(node_type: NodeBound, time_bound: u8) -> Self {
        let packed = time_bound + (node_type as u8);
        NodeTypeTimeBound { data: packed }
    }

    /// Replaces the stored bound type and ORs the new generation on top.
    ///
    /// NOTE(review): `TIME_MASK` keeps the *old* generation bits and `gen`
    /// is OR-ed over them rather than replacing them — confirm this is the
    /// intended update semantics.
    pub fn update_bound(&mut self, node_type: NodeBound, gen: u8) {
        let kept_time = self.data & TIME_MASK;
        self.data = kept_time | (node_type as u8) | gen;
    }

    /// Replaces the stored generation while keeping the bound-type bits.
    pub fn update_time(&mut self, time_bound: u8) {
        let kept_bound = self.data & NODE_TYPE_MASK;
        self.data = kept_bound | time_bound;
    }
}
/// One transposition-table slot. `#[repr(C)]` fixes the field order/layout;
/// do not reorder fields.
#[derive(Clone, PartialEq)]
#[repr(C)]
pub struct Entry {
    /// Upper 16 bits of the position `Key`; zero marks an empty slot.
    pub partial_key: u16,
    /// Best move found for this position.
    pub best_move: BitMove,
    /// Search score.
    pub score: i16,
    /// Static evaluation.
    pub eval: i16,
    /// Search depth the entry was stored at.
    pub depth: i8,
    /// Packed generation + bound byte.
    pub time_node_bound: NodeTypeTimeBound,
}
impl Entry {
/// Returns `true` when the slot holds no usable data: the stored bound is
/// `NoBound` or the partial key is zero (freshly zeroed allocation).
pub fn is_empty(&self) -> bool {
self.node_type() == NodeBound::NoBound || self.partial_key == 0
}
/// Stores search results into this entry.
///
/// The best move is refreshed whenever the key differs from the stored
/// one. The rest of the entry is overwritten only when the key differs,
/// the new bound is exact, or the new depth is deeper than (stored depth
/// minus 4) — so shallow results do not evict much deeper ones.
pub fn place(&mut self, key: Key, best_move: BitMove, score: i16, eval: i16, depth: i16, node_type: NodeBound, gen: u8) {
// Upper 16 bits of the full hash identify the position in this slot.
let partial_key = key.wrapping_shr(48) as u16;
if partial_key != self.partial_key {
self.best_move = best_move;
}
if partial_key != self.partial_key
|| node_type == NodeBound::Exact || depth > self.depth as i16 - 4 {
self.partial_key = partial_key;
self.score = score;
self.eval = eval;
// NOTE(review): `depth` is truncated i16 -> i8 here; assumes search
// depths fit in i8 — confirm at call sites.
self.depth = depth as i8;
self.time_node_bound.update_bound(node_type, gen);
}
}
/// The entry's generation ("time"), i.e. the upper six bits of the
/// packed byte.
pub fn time(&self) -> u8 {
self.time_node_bound.data & TIME_MASK
}
/// Decodes the bound type from the low two bits of the packed byte.
pub fn node_type(&self) -> NodeBound {
match self.time_node_bound.data & NODE_TYPE_MASK {
0 => NodeBound::NoBound,
1 => NodeBound::LowerBound,
2 => NodeBound::UpperBound,
_ => NodeBound::Exact,
}
}
/// Replacement score used by `TranspositionTable::probe`: deeper entries
/// from the current generation score highest; older entries score lower
/// and are replaced first. `& 0b1111_1100` keeps only the generation bits
/// of the (wrap-safe) age difference, discarding the bound bits.
pub fn time_value(&self, curr_time: u8) -> i16 {
let inner: i16 = ((259i16).wrapping_add(curr_time as i16)).wrapping_sub(self.time_node_bound.data as i16) & 0b1111_1100;
(self.depth as i16).wrapping_sub(inner).wrapping_mul(2)
}
}
/// A bucket of `CLUSTER_SIZE` entries sharing one table index. The padding
/// pads the struct to a fixed size — presumably a power-of-two byte count
/// for cache alignment; confirm with `mem::size_of::<Cluster>()`.
#[repr(C)]
pub struct Cluster {
pub entry: [Entry; CLUSTER_SIZE],
pub padding: [u8; 2],
}
/// Lock-less transposition table: a power-of-two array of `Cluster`s,
/// indexed by the low bits of a position `Key`. Interior mutability via
/// `UnsafeCell` lets search threads probe/store through `&self`.
pub struct TranspositionTable {
    /// Pointer to the start of the cluster allocation.
    clusters: UnsafeCell<NonNull<Cluster>>,
    /// Number of clusters allocated; always a power of two.
    cap: UnsafeCell<usize>,
    /// Current search generation; advances in steps of 4 (low 2 bits free).
    time_age: UnsafeCell<u8>,
}
impl TranspositionTable {
    /// Maximum allowed table size, in (decimal) megabytes.
    pub const MAX_SIZE_MB: usize = 100000;

    /// Creates a table whose allocation does not exceed `mb_size` megabytes.
    ///
    /// # Panics
    ///
    /// Panics if `mb_size` is zero.
    pub fn new(mb_size: usize) -> Self {
        assert!(mb_size > 0);
        let mut num_clusters: usize = (mb_size * BYTES_PER_MB) / mem::size_of::<Cluster>();
        // `next_power_of_two() / 2` rounds DOWN to a power of two so the
        // allocation never exceeds the requested size.
        num_clusters = num_clusters.next_power_of_two() / 2;
        TranspositionTable::new_num_clusters(num_clusters)
    }

    /// Creates a table sized from a number of entries.
    ///
    /// NOTE(review): this multiplies by `CLUSTER_SIZE` instead of dividing,
    /// so (before power-of-two rounding) the table holds `CLUSTER_SIZE²`
    /// times `num_entries` entries. Kept as-is for caller compatibility —
    /// confirm whether a division was intended.
    pub fn new_num_entries(num_entries: usize) -> Self {
        TranspositionTable::new_num_clusters(num_entries * CLUSTER_SIZE)
    }

    /// Creates a table with at least `num_clusters` clusters, rounded up to
    /// the next power of two.
    pub fn new_num_clusters(num_clusters: usize) -> Self {
        TranspositionTable::create(num_clusters.next_power_of_two())
    }

    /// Allocates a zeroed table of exactly `size` clusters.
    ///
    /// `size` must be a non-zero power of two — required by the index mask
    /// in `cluster()`.
    fn create(size: usize) -> Self {
        assert_eq!(size.count_ones(), 1);
        assert!(size > 0);
        TranspositionTable {
            clusters: UnsafeCell::new(alloc_room(size)),
            cap: UnsafeCell::new(size),
            time_age: UnsafeCell::new(0),
        }
    }

    /// Installs a fresh allocation sized for `mb_size` megabytes.
    ///
    /// # Safety
    ///
    /// The current cluster pointer is overwritten WITHOUT being freed, so
    /// this must only be called on a table that does not own a live
    /// allocation; otherwise memory is leaked.
    pub unsafe fn uninitialized_init(&self, mb_size: usize) {
        let mut num_clusters: usize = (mb_size * BYTES_PER_MB) / mem::size_of::<Cluster>();
        num_clusters = num_clusters.next_power_of_two() / 2;
        self.re_alloc(num_clusters);
    }

    /// Current allocation size in kilobytes (decimal).
    #[inline(always)]
    pub fn size_kilobytes(&self) -> usize {
        (mem::size_of::<Cluster>() * self.num_clusters()) / BYTES_PER_KB
    }

    /// Current allocation size in megabytes (decimal).
    #[inline(always)]
    pub fn size_megabytes(&self) -> usize {
        (mem::size_of::<Cluster>() * self.num_clusters()) / BYTES_PER_MB
    }

    /// Current allocation size in gigabytes (decimal).
    #[inline(always)]
    pub fn size_gigabytes(&self) -> usize {
        (mem::size_of::<Cluster>() * self.num_clusters()) / BYTES_PER_GB
    }

    /// Number of clusters currently allocated (always a power of two).
    #[inline(always)]
    pub fn num_clusters(&self) -> usize {
        unsafe {
            *self.cap.get()
        }
    }

    /// Number of entries currently allocated.
    #[inline(always)]
    pub fn num_entries(&self) -> usize {
        self.num_clusters() * CLUSTER_SIZE
    }

    /// Resizes to at least `size` clusters (rounded up to a power of two),
    /// discarding all stored entries.
    ///
    /// # Safety
    ///
    /// Invalidates every `&mut Entry` previously returned by `probe`.
    pub unsafe fn resize_round_up(&self, size: usize) {
        self.resize(size.next_power_of_two());
    }

    /// Resizes the table to at most `mb_size` megabytes and returns the
    /// resulting size in megabytes.
    ///
    /// # Safety
    ///
    /// Invalidates every `&mut Entry` previously returned by `probe`.
    ///
    /// # Panics
    ///
    /// Panics if `mb_size` is zero.
    pub unsafe fn resize_to_megabytes(&self, mb_size: usize) -> usize {
        assert!(mb_size > 0);
        let mut num_clusters: usize = (mb_size * BYTES_PER_MB) / mem::size_of::<Cluster>();
        num_clusters = num_clusters.next_power_of_two() / 2;
        self.resize(num_clusters);
        self.size_megabytes()
    }

    /// Frees the old allocation (using the old capacity) and installs a new
    /// zeroed one of `size` clusters.
    unsafe fn resize(&self, size: usize) {
        assert_eq!(size.count_ones(), 1);
        assert!(size > 0);
        self.de_alloc();
        self.re_alloc(size);
    }

    /// Clears the table by re-allocating zeroed storage of the same size.
    ///
    /// # Safety
    ///
    /// Invalidates every `&mut Entry` previously returned by `probe`.
    pub unsafe fn clear(&self) {
        let size = self.cap.get();
        self.resize(*size);
    }

    /// Advances to a new search generation. The generation lives in the top
    /// six bits of each entry's packed byte, so it steps by 4 and wraps.
    #[inline]
    pub fn new_search(&self) {
        unsafe {
            let c = self.time_age.get();
            *c = (*c).wrapping_add(4);
        }
    }

    /// Current generation ("time age"); always a multiple of 4.
    #[inline]
    pub fn time_age(&self) -> u8 {
        unsafe {
            *self.time_age.get()
        }
    }

    /// Number of generation increments so far (time age >> 2).
    /// (Name keeps the original "cylces" spelling for API compatibility.)
    #[inline]
    pub fn time_age_cylces(&self) -> u8 {
        unsafe {
            (*self.time_age.get()).wrapping_shr(2)
        }
    }

    /// Probes the table for `key`.
    ///
    /// Returns `(true, entry)` on a partial-key hit. Otherwise returns
    /// `(false, entry)` where `entry` is an empty slot in the key's cluster,
    /// or — if the cluster is full — the least-valuable entry, ready to be
    /// overwritten by the caller.
    pub fn probe(&self, key: Key) -> (bool, &mut Entry) {
        // The upper 16 bits of the key are what each entry stores.
        let partial_key: u16 = (key).wrapping_shr(48) as u16;
        unsafe {
            let cluster: *mut Cluster = self.cluster(key);
            let init_entry: *mut Entry = cluster_first_entry(cluster);
            // Pass 1: look for an empty slot or an exact partial-key match.
            for i in 0..CLUSTER_SIZE {
                let entry_ptr: *mut Entry = init_entry.offset(i as isize);
                let entry: &mut Entry = &mut (*entry_ptr);
                if entry.partial_key == 0 || entry.partial_key == partial_key {
                    // Refresh the generation of a hit from an older search so
                    // it is not preferentially replaced later.
                    if entry.time() != self.time_age() && entry.partial_key != 0 {
                        entry.time_node_bound.update_time(self.time_age());
                    }
                    return (entry.partial_key != 0, entry);
                }
            }
            // Pass 2: cluster is full of other keys; choose the entry with
            // the lowest replacement score (old and shallow entries first).
            let mut replacement: *mut Entry = init_entry;
            let mut replacement_score: i16 = (&*replacement).time_value(self.time_age());
            for i in 1..CLUSTER_SIZE {
                let entry_ptr: *mut Entry = init_entry.offset(i as isize);
                let entry_score: i16 = (&*entry_ptr).time_value(self.time_age());
                if entry_score < replacement_score {
                    replacement = entry_ptr;
                    replacement_score = entry_score;
                }
            }
            (false, &mut (*replacement))
        }
    }

    /// Maps a key to its cluster. `num_clusters` is a power of two, so the
    /// mask keeps the low bits of the key.
    #[inline]
    fn cluster(&self, key: Key) -> *mut Cluster {
        let index: usize = ((self.num_clusters() - 1) as u64 & key) as usize;
        unsafe {
            (*self.clusters.get()).as_ptr().offset(index as isize)
        }
    }

    /// Installs a fresh zeroed allocation of `size` clusters and records the
    /// new capacity. Does NOT free the previous allocation.
    unsafe fn re_alloc(&self, size: usize) {
        *self.clusters.get() = alloc_room(size);
        // BUG FIX: the capacity must track the new allocation; previously it
        // was left stale, so `cluster()` masked with the old size and could
        // index outside the new allocation after a resize.
        *self.cap.get() = size;
    }

    /// Frees the current allocation with a layout matching `alloc_room`.
    unsafe fn de_alloc(&self) {
        // BUG FIX: the layout size must be the allocation's byte size
        // (cluster count * size_of::<Cluster>()), matching `alloc_room`.
        // Passing the bare cluster count mismatches the allocation layout,
        // which is undefined behavior for `dealloc`.
        let bytes = *self.cap.get() * mem::size_of::<Cluster>();
        let layout = Layout::from_size_align(bytes, 2).unwrap();
        let ptr: *mut u8 = (*self.clusters.get()).as_ptr() as *mut u8;
        alloc::dealloc(ptr, layout);
    }

    /// Rough occupancy estimate for the current generation, in percent,
    /// sampled over at most 333 clusters.
    pub fn hash_percent(&self) -> f64 {
        unsafe {
            let clusters_scanned: u64 = min((*self.cap.get() - 1) as u64, 333);
            let mut hits: f64 = 0.0;
            for i in 0..clusters_scanned {
                // `cluster` masks the key, so key `i + 1` selects cluster i+1.
                let cluster = self.cluster(i + 1);
                let init_entry: *mut Entry = cluster_first_entry(cluster);
                for e in 0..CLUSTER_SIZE {
                    let entry_ptr: *mut Entry = init_entry.offset(e as isize);
                    let entry: &Entry = &(*entry_ptr);
                    // Count entries touched during the current generation.
                    if entry.time() == self.time_age() {
                        hits += 1.0;
                    }
                }
            }
            (hits * 100.0) / (clusters_scanned * CLUSTER_SIZE as u64) as f64
        }
    }
}
// NOTE(review): the table uses `UnsafeCell` with no internal synchronization,
// so concurrent probes/stores from multiple threads can race on entries.
// Presumably lock-less shared access is an accepted engine trade-off —
// confirm before relying on entry data being race-free.
unsafe impl Sync for TranspositionTable {}
impl PreFetchable for TranspositionTable {
/// Prefetches the cluster for `key` in anticipation of a write. The index
/// computation mirrors `TranspositionTable::cluster`.
#[inline(always)]
fn prefetch(&self, key: u64) {
let index: usize = ((self.num_clusters() - 1) as u64 & key) as usize;
unsafe {
let ptr = (*self.clusters.get()).as_ptr().offset(index as isize);
prefetch_write(ptr);
};
}
}
impl Drop for TranspositionTable {
/// Frees the cluster allocation when the table is dropped.
fn drop(&mut self) {
unsafe {self.de_alloc();}
}
}
/// Returns a raw pointer to the first `Entry` of `cluster`.
///
/// # Safety
///
/// `cluster` must point to a valid, initialized `Cluster`.
#[inline]
unsafe fn cluster_first_entry(cluster: *mut Cluster) -> *mut Entry {
    (*cluster).entry.as_mut_ptr()
}
/// Allocates zeroed storage for `size` clusters and returns a non-null
/// pointer to it. Aborts via `handle_alloc_error` on allocation failure.
#[inline]
fn alloc_room(size: usize) -> NonNull<Cluster> {
    unsafe {
        let bytes = size * mem::size_of::<Cluster>();
        let layout = Layout::from_size_align(bytes, 2).unwrap();
        // Zeroed memory doubles as "all entries empty" (partial_key == 0).
        let raw: *mut u8 = alloc::alloc_zeroed(layout);
        match NonNull::new(raw) {
            Some(nn) => nn.cast(),
            None => handle_alloc_error(layout),
        }
    }
}
#[cfg(test)]
mod tests {
extern crate rand;
use super::*;
use std::thread::sleep;
use std::time::Duration;
use std::sync::atomic::Ordering;
use std::sync::atomic::compiler_fence;
// Cluster counts used as test sizes (2 << 24 and 2 << 20 clusters; the
// names describe rough allocation scale, not exact byte sizes).
const HALF_GIG: usize = 2 << 24;
const THIRTY_MB: usize = 2 << 20;
// Smoke test: allocate a tiny table and probe it once.
#[test]
fn tt_alloc_realloc() {
let size: usize = 8;
let tt = TranspositionTable::create(size);
assert_eq!(tt.num_clusters(), size);
let key = create_key(32, 44);
let (_found,_entry) = tt.probe(key);
sleep(Duration::from_millis(1));
}
// Cluster and entry counts round up to the next power of two.
#[test]
fn tt_test_sizes() {
let tt = TranspositionTable::new_num_clusters(100);
assert_eq!(tt.num_clusters(), (100 as usize).next_power_of_two());
assert_eq!(tt.num_entries(), (100 as usize).next_power_of_two() * CLUSTER_SIZE);
compiler_fence(Ordering::Release);
sleep(Duration::from_millis(1));
}
// Stress test: a million random probes + writes must not fault.
#[test]
fn tt_null_ptr() {
let size: usize = 2 << 20;
let tt = TranspositionTable::new_num_clusters(size);
for x in 0..1_000_000 as u64 {
let key: u64 = rand::random::<u64>();
{
let (_found, entry) = tt.probe(key);
entry.depth = (x % 0b1111_1111) as i8;
entry.partial_key = key.wrapping_shr(48) as u16;
}
tt.new_search();
}
compiler_fence(Ordering::Release);
sleep(Duration::from_millis(1));
}
// Insert / hit / replacement behavior within a single cluster.
#[test]
fn tt_basic_insert() {
let tt = TranspositionTable::new_num_clusters(THIRTY_MB);
let partial_key_1: u16 = 17773;
let key_index: u64 = 0x5556;
let key_1 = create_key(partial_key_1, 0x5556);
let (found, entry) = tt.probe(key_1);
assert!(!found);
entry.partial_key = partial_key_1;
entry.depth = 2;
let (found, entry) = tt.probe(key_1);
assert!(found);
// Still "empty" by is_empty(): the bound bits were never set (NoBound),
// even though the partial key now matches.
assert!(entry.is_empty());
assert_eq!(entry.partial_key,partial_key_1);
assert_eq!(entry.depth,2);
let partial_key_2: u16 = 8091;
let partial_key_3: u16 = 12;
// Same low bits (key_index) => same cluster, different partial keys.
let key_2: u64 = create_key(partial_key_2, key_index);
let key_3: u64 = create_key(partial_key_3, key_index);
let (found, entry) = tt.probe(key_2);
assert!(!found);
assert!(entry.is_empty());
entry.partial_key = partial_key_2;
entry.depth = 3;
let (found, entry) = tt.probe(key_3);
assert!(!found);
assert!(entry.is_empty());
entry.partial_key = partial_key_3;
entry.depth = 6;
let partial_key_4: u16 = 18;
let key_4: u64 = create_key(partial_key_4, key_index);
// Cluster is now full: probe must hand back the shallowest (lowest
// time_value) entry for replacement — the depth-2 entry from key_1.
let (found, entry) = tt.probe(key_4);
assert!(!found);
assert_eq!(entry.partial_key, partial_key_1);
assert_eq!(entry.depth, 2);
compiler_fence(Ordering::Release);
sleep(Duration::from_millis(1));
}
// Builds a key whose top 16 bits are `partial_key` and whose low 48 bits
// come from `full_key` (the low bits pick the cluster).
fn create_key(partial_key: u16, full_key: u64) -> u64 {
(partial_key as u64).wrapping_shl(48) | (full_key & 0x0000_FFFF_FFFF_FFFF)
}
}