use super::*;
use atomics::*;
use search_tree::*;
/// Maps game states (`Spec::State`) to search nodes so that equal states can share a single
/// [`SearchNode`] (presumably to merge transpositions — states reached by different move orders).
///
/// # Safety
///
/// The trait is `unsafe` to implement: implementors must uphold a contract that the signatures
/// below cannot express. NOTE(review): that contract is not written down in this module;
/// presumably every node returned from `insert`/`lookup` must be a node that was previously
/// passed to `insert` on this table and that is still alive for `'a`. Confirm against the
/// search code before writing a new implementation.
pub unsafe trait TranspositionTable<Spec: MCTS>: Sync + Sized {
    /// Offers `value` as the node for `key`.
    ///
    /// The implementations in this module return `Some(existing)` only when a node is already
    /// stored for `key`; `None` means `value` was stored or could not be stored at all.
    fn insert<'a>(
        &'a self,
        key: &Spec::State,
        value: &'a SearchNode<Spec>,
        handle: SearchHandle<Spec>,
    ) -> Option<&'a SearchNode<Spec>>;
    /// Returns the node stored for `key`, or `None` if the table has none.
    fn lookup<'a>(
        &'a self,
        key: &Spec::State,
        handle: SearchHandle<Spec>,
    ) -> Option<&'a SearchNode<Spec>>;
    /// Removes all entries. The default implementation does nothing.
    fn clear(&mut self) {}
}
/// The unit type is the "no transposition table" choice: it never stores anything, so both
/// operations always report that no node is known.
// SAFETY: no node is ever returned, so no guarantee about returned references can be broken.
unsafe impl<Spec: MCTS<TranspositionTable = Self>> TranspositionTable<Spec> for () {
    fn insert<'a>(
        &'a self,
        _key: &Spec::State,
        _value: &'a SearchNode<Spec>,
        _handle: SearchHandle<Spec>,
    ) -> Option<&'a SearchNode<Spec>> {
        // Nothing is retained, so there is never an existing node to hand back.
        Option::None
    }

    fn lookup<'a>(
        &'a self,
        _key: &Spec::State,
        _handle: SearchHandle<Spec>,
    ) -> Option<&'a SearchNode<Spec>> {
        // An always-empty table can only miss.
        Option::None
    }
}
/// A 64-bit hash of a key, used by [`ApproxQuadraticProbingHashTable`] in place of the key itself.
///
/// That table compares only these hashes, so two keys with the same hash are treated as the
/// same key. It also uses `0` as its empty-slot marker, so keys hashing to `0` are never stored.
pub trait TranspositionHash {
    /// Returns the 64-bit hash of `self`.
    fn hash(&self) -> u64;
}
/// A fixed-size open-addressing hash table from [`TranspositionHash`] keys to borrowed `V`s.
///
/// Every slot is read and written through atomic operations only, and the type is marked
/// `Sync`. It is approximate: only each key's 64-bit hash is stored and compared, so keys with
/// equal hashes alias. Probing is quadratic and bounded by `PROBE_LIMIT`, so an operation may
/// give up on a key even when the table still has free slots.
pub struct ApproxQuadraticProbingHashTable<K: TranspositionHash, V> {
    // The slots; the length is `capacity`, a power of two.
    arr: Box<[Entry16<K, V>]>,
    // Number of slots in `arr`.
    capacity: usize,
    // `capacity - 1`; `hash & mask` selects a slot.
    mask: usize,
    // Approximate number of claimed slots, updated with relaxed atomics.
    size: AtomicUsize,
}
/// One table slot: a key hash plus a pointer to the stored value.
///
/// `k == 0` means the slot is empty. A non-zero `k` with a null `v` means a key has claimed the
/// slot but its value has not been published yet. `new` asserts the slot is at most 16 bytes.
struct Entry16<K: TranspositionHash, V> {
    // Hash of the key occupying this slot, or 0 if the slot is empty.
    k: AtomicU64,
    // Pointer to the value for `k`; null until it is published.
    v: AtomicPtr<V>,
    // Ties the slot to its key type without storing a `K`.
    _marker: std::marker::PhantomData<K>,
}
impl<K: TranspositionHash, V> Default for Entry16<K, V> {
    /// An empty slot: hash `0` and a null value pointer.
    fn default() -> Self {
        let k = AtomicU64::new(0);
        let v = AtomicPtr::new(std::ptr::null_mut());
        Self {
            k,
            v,
            _marker: std::marker::PhantomData,
        }
    }
}
impl<K: TranspositionHash, V> Clone for Entry16<K, V> {
fn clone(&self) -> Self {
Self {
k: AtomicU64::new(self.k.load(Ordering::Relaxed)),
v: AtomicPtr::new(self.v.load(Ordering::Relaxed)),
_marker: Default::default(),
}
}
}
impl<K: TranspositionHash, V> ApproxQuadraticProbingHashTable<K, V> {
    /// Creates an empty table with exactly `capacity` slots.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is not a power of two (this includes `0`), or if a slot would be
    /// larger than 16 bytes.
    pub fn new(capacity: usize) -> Self {
        assert!(std::mem::size_of::<Entry16<K, V>>() <= 16);
        assert!(
            capacity.is_power_of_two(),
            "the capacity must be a power of 2"
        );
        let arr = vec![Entry16::default(); capacity].into_boxed_slice();
        // With a power-of-two capacity, `hash & mask` equals `hash % capacity`.
        let mask = capacity - 1;
        Self {
            arr,
            mask,
            capacity,
            size: AtomicUsize::default(),
        }
    }

    /// Creates a table large enough to hold `num` entries while staying at or below the
    /// 2/3 load factor at which `insert` stops claiming new slots.
    ///
    /// The capacity is the smallest power of two that is at least `ceil(1.5 * num)`, and at
    /// least 1.
    ///
    /// # Panics
    ///
    /// Panics if that capacity does not fit in a `usize`.
    pub fn enough_to_hold(num: usize) -> Self {
        // Smallest power of two `c` with `2 * c >= 3 * num`. Computed directly instead of by
        // repeated doubling, which overflows for very large `num` (panicking in debug builds and
        // looping forever or under-sizing in release builds).
        let capacity = num
            .checked_add(num / 2 + num % 2)
            .and_then(usize::checked_next_power_of_two)
            .expect("enough_to_hold: required capacity overflows usize");
        Self::new(capacity)
    }
}
// SAFETY: all shared mutation of the table goes through atomics (`AtomicU64`, `AtomicPtr`,
// `AtomicUsize`), and `clear` requires `&mut self`.
// NOTE(review): these impls hold for every `V`, yet the `TranspositionTable` impl hands out `&V`
// to any thread holding `&self`, which is only sound when `V: Sync`. Consider adding `V: Sync`
// bounds once it is confirmed that `SearchNode<Spec>` satisfies them.
unsafe impl<K: TranspositionHash, V> Sync for ApproxQuadraticProbingHashTable<K, V> {}
unsafe impl<K: TranspositionHash, V> Send for ApproxQuadraticProbingHashTable<K, V> {}
/// The approximate hash table specialised to map `Spec`'s states to its search nodes.
pub type ApproxTable<Spec> =
    ApproxQuadraticProbingHashTable<<Spec as MCTS>::State, SearchNode<Spec>>;
/// Publishes `v` into `ptr` if the slot is still null.
///
/// Returns `None` when `v` was written, or `Some(existing)` with the value another writer
/// already stored. The write uses `Release` and the failed read uses `Acquire`, so a reader that
/// obtains the pointer also sees the pointee's initialisation.
fn get_or_write<'a, V>(ptr: &AtomicPtr<V>, v: &'a V) -> Option<&'a V> {
    let candidate = v as *const V as *mut V;
    match ptr.compare_exchange(
        std::ptr::null_mut(),
        candidate,
        Ordering::Release,
        Ordering::Acquire,
    ) {
        // The slot was empty and now holds `v`: there was no previous value.
        Ok(_) => None,
        // SAFETY: non-null pointers in these slots were written by this function from a `&V`;
        // keeping that value alive for `'a` is part of the table's contract with its callers.
        Err(existing) => unsafe { existing.as_ref() },
    }
}
/// Turns a possibly-null raw pointer into an optional reference.
///
/// The lifetime `'a` is chosen by the caller and is not checked here: callers must pass either
/// null or a pointer to a `V` that stays valid for `'a`.
fn convert<'a, V>(ptr: *const V) -> Option<&'a V> {
    // SAFETY: `as_ref` maps null to `None`; validity of a non-null pointer for `'a` is the
    // caller's obligation, as documented above.
    unsafe { ptr.as_ref() }
}
/// Maximum number of slots `insert` and `lookup` examine for one key before giving up.
const PROBE_LIMIT: usize = 16;
/// The approximate table as a [`TranspositionTable`]: states are keyed by their
/// [`TranspositionHash`] alone, so states with equal hashes share one node.
unsafe impl<Spec> TranspositionTable<Spec> for ApproxTable<Spec>
where
    Spec::State: TranspositionHash,
    Spec: MCTS,
{
    /// Offers `value` as the node for `key`.
    ///
    /// Returns `Some(node)` if a node is already stored under `key`'s hash; in that case `value`
    /// is not stored. Returns `None` if `value` was stored, or if it could not be stored because
    /// the hash is `0` (the empty-slot marker) or no usable slot was found within `PROBE_LIMIT`
    /// probes. Once more than 2/3 of the slots are claimed, no new slots are taken and this
    /// behaves like `lookup`.
    fn insert<'a>(
        &'a self,
        key: &Spec::State,
        value: &'a SearchNode<Spec>,
        handle: SearchHandle<Spec>,
    ) -> Option<&'a SearchNode<Spec>> {
        // Keep the load factor at or below 2/3 so probe sequences stay short.
        if self.size.load(Ordering::Relaxed) * 3 > self.capacity * 2 {
            return self.lookup(key, handle);
        }
        let my_hash = key.hash();
        // 0 marks an empty slot, so a key hashing to 0 cannot be stored.
        if my_hash == 0 {
            return None;
        }
        let mut posn = my_hash as usize & self.mask;
        for inc in 1..=PROBE_LIMIT {
            // SAFETY: `posn` is always masked with `self.mask == self.arr.len() - 1`.
            let entry = unsafe { self.arr.get_unchecked(posn) };
            let key_here = entry.k.load(Ordering::Relaxed);
            if key_here == my_hash {
                // Fast path: if the value is already published, skip the CAS.
                let value_here = entry.v.load(Ordering::Acquire);
                if !value_here.is_null() {
                    // SAFETY: non-null slot values were stored by `get_or_write` from a node
                    // passed to `insert`; keeping it alive is part of the trait's contract.
                    return unsafe { Some(&*value_here) };
                }
                // The key is claimed but its value is not written yet: race to write ours.
                return get_or_write(&entry.v, value);
            }
            if key_here == 0 {
                match entry
                    .k
                    .compare_exchange(0, my_hash, Ordering::Relaxed, Ordering::Relaxed)
                {
                    // We claimed a fresh slot, so only we count it.
                    Ok(_) => {
                        self.size.fetch_add(1, Ordering::Relaxed);
                        return get_or_write(&entry.v, value);
                    }
                    // Another thread just claimed this slot for the same key and has already
                    // counted it; counting it again here would overstate `size`.
                    Err(winner) if winner == my_hash => {
                        return get_or_write(&entry.v, value);
                    }
                    // The slot was taken by a different key: keep probing.
                    Err(_) => {}
                }
            }
            // Quadratic probing: offsets from the home slot are 1, 3, 6, 10, ...
            posn = (posn + inc) & self.mask;
        }
        None
    }

    /// Marks every slot empty and resets the size to zero.
    ///
    /// Only the stored pointers are cleared; the nodes they referenced are not touched.
    fn clear(&mut self) {
        for entry in self.arr.iter_mut() {
            *entry.k.get_mut() = 0;
            *entry.v.get_mut() = std::ptr::null_mut();
        }
        *self.size.get_mut() = 0;
    }

    /// Returns the node stored under `key`'s hash, if any.
    ///
    /// Returns `None` when no slot holds the hash within `PROBE_LIMIT` probes, or when the slot
    /// holding it has not had its value published yet.
    fn lookup<'a>(
        &'a self,
        key: &Spec::State,
        _: SearchHandle<Spec>,
    ) -> Option<&'a SearchNode<Spec>> {
        let my_hash = key.hash();
        // `insert` never stores hash 0 (it marks empty slots), so there is nothing to find.
        if my_hash == 0 {
            return None;
        }
        let mut posn = my_hash as usize & self.mask;
        for inc in 1..=PROBE_LIMIT {
            // SAFETY: `posn` is always masked with `self.mask == self.arr.len() - 1`.
            let entry = unsafe { self.arr.get_unchecked(posn) };
            let key_here = entry.k.load(Ordering::Relaxed);
            if key_here == my_hash {
                return convert(entry.v.load(Ordering::Acquire));
            }
            // `insert` claims the first empty slot on a key's probe path, and slots are only
            // freed all at once by `clear`, so an empty slot means the key is absent.
            if key_here == 0 {
                return None;
            }
            posn = (posn + inc) & self.mask;
        }
        None
    }
}