1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
//! Module for the `TranspositionTable`, a type of hash-map where Zobrist Keys map to information about a position.
//!
//! A [`TranspositionTable`] is a structure to quickly lookup chess positions and determine information from them.
//! It maps from Board positions to information such as the evaluation of that position, the best move found so far,
//! the depth that move was found at, etc.
//!
//! Specifically, a [`TranspositionTable`] maps from a u64 to an [`Entry`].
//!
//! This table uses no locks, so multiple threads can access it concurrently and quickly. However, because
//! accesses are not synchronized, there is a risk of collisions & over-writes when using it with multiple threads.
//! Furthermore, Keys (generated by a Zobrist hash) are not guaranteed to uniquely map to a specific chess position,
//! and unique keys are not guaranteed to map to unique buckets inside the table. The chances of collision are
//! extremely low, but it's still something to take into account when using the transposition table.
//!
//! # Examples
//!
//! Here, we create a new [`TranspositionTable`] with room for at least 40,000 entries, and probe it for a key.
//! Because this table is empty, `found` will be false. Next, we write the data for an [`Entry`], and then probe
//! again for the same key. This time `found` will be true, and the returned entry holds the data we wrote.
//!
//! ```ignore
//! let tt = TranspositionTable::new_num_entries(40000);
//! let prng = PRNG::init(932445561);
//!
//! let key: u64 = prng.rand();
//! let (found, entry): (bool, &mut Entry) = tt.probe(key);
//! assert!(!found);
//! entry.place(key, BitMove::new(0x555), 3, 4, 3, NodeBound::Exact, tt.time_age());
//! let (found, entry) = tt.probe(key);
//! assert!(found);
//! ```
//!
//! [`TranspositionTable`]: ../../tools/tt/struct.TranspositionTable.html
//! [`Entry`]: ../../tools/tt/struct.Entry.html

use std::ptr::NonNull;
use std::mem;
use std::heap::{Alloc, Layout, Global, oom};
use std::cmp::min;
use std::cell::UnsafeCell;

use prefetch::prefetch::*;

use super::PreFetchable;
use core::piece_move::BitMove;

// TODO: investigate the potential for SIMD in key lookup.
// Currently, there is no way to do this in Rust without it being extensive.

/// A Zobrist key identifying a position. Its upper 16 bits are what an [`Entry`] stores as
/// `partial_key`; its low bits select the [`Cluster`].
pub type Key = u64;

/// Bit mask selecting the search-generation ("time") bits of a [`NodeTypeTimeBound`].
///
/// Generations advance in steps of 4, so they never touch the two bound bits.
pub const TIME_MASK: u8 = 0b1111_1100;

/// Bit mask selecting the [`NodeBound`] bits of a [`NodeTypeTimeBound`].
pub const NODE_TYPE_MASK: u8 = 0b0000_0011;

/// Number of Entries per [`Cluster`].
pub const CLUSTER_SIZE: usize = 3;

// Decimal (SI) units: 1 KB = 1,000 bytes, not 1,024.
const BYTES_PER_KB: usize = 1000;
const BYTES_PER_MB: usize = BYTES_PER_KB * 1000;
const BYTES_PER_GB: usize = BYTES_PER_MB * 1000;

/// Designates the type of Node in the Chess Search tree.
///
/// The discriminants fit in the two bits selected by `NODE_TYPE_MASK`.
/// See the [Chess Programming Wiki](https://www.chessprogramming.org/Node_Types) for more
/// information about node types and their use.
#[derive(Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum NodeBound {
    /// No bound recorded. A zero-initialized entry decodes to this variant.
    NoBound = 0,
    LowerBound = 1,
    UpperBound = 2,
    Exact = 3,
}

/// Packs the search generation ("time") an entry was written in together with its
/// [`NodeBound`] into a single byte.
///
/// The upper six bits (`TIME_MASK`) hold the generation and the lower two bits
/// (`NODE_TYPE_MASK`) hold the bound.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct NodeTypeTimeBound {
    data: u8
}

impl NodeTypeTimeBound {
    /// Creates a NodeTypeTimeBound with the designated node_type and time.
    ///
    /// # Usage
    ///
    /// `time_bound` is expected to be a multiple of 4, because its two low bits are reserved for
    /// the bound. Any low bits that are set are masked off, so they cannot corrupt the bound or
    /// overflow.
    pub fn create(node_type: NodeBound, time_bound: u8) -> Self {
        NodeTypeTimeBound {
            data: (time_bound & TIME_MASK) | (node_type as u8)
        }
    }

    /// Overwrites both the [`NodeBound`] and the generation of an entry.
    ///
    /// The previous generation is *replaced* by `gen`. Previously it was OR-ed with `gen`, which
    /// mixed the old and new generation bits into a meaningless value whenever an entry was
    /// rewritten in a later search.
    pub fn update_bound(&mut self, node_type: NodeBound, gen: u8) {
        self.data = (gen & TIME_MASK) | (node_type as u8);
    }

    /// Overwrites the generation of an entry, keeping its [`NodeBound`].
    ///
    /// Low bits of `time_bound` are masked off so they cannot alter the stored bound.
    pub fn update_time(&mut self, time_bound: u8) {
        self.data = (self.data & NODE_TYPE_MASK) | (time_bound & TIME_MASK);
    }
}



// Size: 2 bytes (partial_key) + 2 bytes (best_move, assuming `BitMove` is 16 bits) + 2 bytes (score)
// + 2 bytes (eval) + 1 byte (depth) + 1 byte (time_node_bound) = 10 bytes.

/// Structure defining a single Entry in a table, containing the `BestMove` found,
/// the score of that node, its static evaluation, the type of Node, the depth it was found at,
/// the search generation it was written in, and a partial key identifying the node.
///
/// Only the upper 16 bits of the position's key are stored (see `Entry::place`), so the partial
/// key does NOT uniquely identify a position.
///
/// `#[repr(C)]` fixes the fields in declaration order, so reordering them changes the layout.
#[derive(Clone,PartialEq)]
#[repr(C)]
pub struct Entry {
    pub partial_key: u16, // Upper 16 bits of the key; `probe` treats 0 as an unused slot.
    pub best_move: BitMove, // What was the best move found here?
    pub score: i16, // What was the Score of this node?
    pub eval: i16, // What is the evaluation of this node
    pub depth: i8, // How deep was this Score Found?
    pub time_node_bound: NodeTypeTimeBound, // Search generation (upper 6 bits) + NodeBound (lower 2 bits).
}

impl Entry {

    pub fn is_empty(&self) -> bool {
        self.node_type() == NodeBound::NoBound || self.partial_key == 0
    }

    /// Rewrites over an Entry.
    pub fn place(&mut self, key: Key, best_move: BitMove, score: i16, eval: i16, depth: i16, node_type: NodeBound, gen: u8) {
        let partial_key = key.wrapping_shr(48) as u16;

        if partial_key != self.partial_key {
            self.best_move = best_move;
        }

        if partial_key != self.partial_key
            || node_type == NodeBound::Exact || depth > self.depth as i16 - 4 {
            self.partial_key = partial_key;
            self.score = score;
            self.eval = eval;
            self.depth = depth as i8;
            self.time_node_bound.update_bound(node_type, gen);
        }
    }

    /// Returns the current search time of the node.
    pub fn time(&self) -> u8 {
        self.time_node_bound.data & TIME_MASK
    }

    /// Returns the [NodeType] of an Entry.
    pub fn node_type(&self) -> NodeBound {
        match self.time_node_bound.data & NODE_TYPE_MASK {
            0 => NodeBound::NoBound,
            1 => NodeBound::LowerBound,
            2 => NodeBound::UpperBound,
            _ => NodeBound::Exact,
        }
    }

    /// Returns the value of the node in respect to the depth searched && when it was placed into the TranspositionTable.
    pub fn time_value(&self, curr_time: u8) -> i16 {
        let inner: i16 = ((259i16).wrapping_add(curr_time as i16)).wrapping_sub(self.time_node_bound.data as i16) & 0b1111_1100;
        (self.depth as i16).wrapping_sub(inner).wrapping_mul(2)
    }
}


// 3 entries * 10 bytes = 30 bytes, + 2 bytes of padding = 32 bytes.
/// Structure containing multiple Entries that share one cluster index. The low bits of a
/// zobrist key select the cluster, and any of its `CLUSTER_SIZE` entries may hold that key.
///
/// NOTE(review): `padding` appears intended to round the cluster up to 32 bytes (so two
/// clusters fill a 64-byte cache line). Confirm with `mem::size_of::<Cluster>()`.
#[repr(C)]
pub struct Cluster {
    pub entry: [Entry; CLUSTER_SIZE],
    pub padding: [u8; 2],
}

// clusters -> pointer to the heap-allocated array of clusters
// cap -> number of clusters n, so n * CLUSTER_SIZE entries
// time age -> current search generation, stamped onto entries when they are written or found

/// Structure for representing a `TranspositionTable`. A Transposition Table is a type
/// of HashTable that maps Zobrist Keys to information about that position, including the best move
/// found, score, depth the move was found at, and other information.
///
/// Every field sits in an `UnsafeCell` so it can be changed through `&self` (for example by
/// `new_search` or the `unsafe` resizing methods). None of these writes are synchronized.
pub struct TranspositionTable {
    clusters: UnsafeCell<NonNull<Cluster>>, // pointer to the heap allocation of `cap` clusters
    cap: UnsafeCell<usize>, // number of clusters (expected to be a power of two), so n * CLUSTER_SIZE entries
    time_age: UnsafeCell<u8>, // current search generation, advanced by 4 on each `new_search`
}

impl TranspositionTable {
    /// Largest table size, in megabytes, that this type advertises. The constructors in this
    /// impl do not check it.
    pub const MAX_SIZE_MB: usize = 100000;

    /// Creates a new table of roughly `mb_size` megabytes (1 MB = 1,000,000 bytes).
    ///
    /// The number of clusters is the largest power of two whose combined size does not exceed
    /// `mb_size` megabytes, and is never less than one.
    ///
    /// # Panics
    ///
    /// mb_size should be > 0, or else a panic will occur
    pub fn new(mb_size: usize) -> Self {
        assert!(mb_size > 0);
        TranspositionTable::create(TranspositionTable::clusters_for_megabytes(mb_size))
    }

    /// Creates a new table with room for at least `num_entries` entries.
    ///
    /// Entries are grouped into clusters of `CLUSTER_SIZE`, and the cluster count is rounded up
    /// to the nearest power of two, so `num_entries()` may exceed the request. Asking for 0
    /// entries yields a table with a single cluster.
    pub fn new_num_entries(num_entries: usize) -> Self {
        // Convert entries to clusters (rounding up so every requested entry has a slot).
        // This previously *multiplied* by CLUSTER_SIZE, allocating CLUSTER_SIZE^2 times too much.
        let num_clusters = (num_entries + CLUSTER_SIZE - 1) / CLUSTER_SIZE;
        TranspositionTable::new_num_clusters(num_clusters)
    }

    /// Creates a new table with `num_clusters` rounded up to the nearest power of two.
    ///
    /// Asking for 0 clusters yields a table with a single cluster.
    pub fn new_num_clusters(num_clusters: usize) -> Self {
        TranspositionTable::create(num_clusters.next_power_of_two())
    }

    // Creates a new table with exactly `size` clusters. Panics unless `size` is a power of two.
    fn create(size: usize) -> Self {
        assert_eq!(size.count_ones(), 1);
        assert!(size > 0);
        TranspositionTable {
            clusters: UnsafeCell::new(alloc_room(size)),
            cap: UnsafeCell::new(size),
            time_age: UnsafeCell::new(0),
        }
    }

    // Returns how many clusters fit in `mb_size` megabytes, rounded down to a power of two and
    // never less than one (so the result is always a valid argument to `create` / `resize`).
    fn clusters_for_megabytes(mb_size: usize) -> usize {
        let max_clusters: usize = (mb_size * BYTES_PER_MB) / mem::size_of::<Cluster>();
        if max_clusters <= 1 {
            1
        } else if max_clusters.is_power_of_two() {
            max_clusters
        } else {
            // `next_power_of_two` rounds up; halving it gives the power of two just below.
            max_clusters.next_power_of_two() / 2
        }
    }

    /// Allocates roughly `mb_size` megabytes of zeroed clusters (sized as in
    /// [`TranspositionTable::new`]). It also records the new capacity and resets the search
    /// generation to 0, without reading or freeing any previous allocation.
    ///
    /// # Safety
    ///
    /// No other thread may access the table during this call. Any previous allocation is NOT
    /// freed, so calling this on an already-initialized table leaks its clusters.
    pub unsafe fn uninitialized_init(&self, mb_size: usize) {
        self.re_alloc(TranspositionTable::clusters_for_megabytes(mb_size));
        *self.time_age.get() = 0;
    }

    /// Returns the size of the heap allocated portion of the TT in KiloBytes.
    #[inline(always)]
    pub fn size_kilobytes(&self) -> usize {
        (mem::size_of::<Cluster>() * self.num_clusters()) / BYTES_PER_KB
    }

    /// Returns the size of the heap allocated portion of the TT in MegaBytes.
    #[inline(always)]
    pub fn size_megabytes(&self) -> usize {
        (mem::size_of::<Cluster>() * self.num_clusters()) / BYTES_PER_MB
    }

    /// Returns the size of the heap allocated portion of the TT in GigaBytes.
    #[inline(always)]
    pub fn size_gigabytes(&self) -> usize {
        (mem::size_of::<Cluster>() * self.num_clusters()) / BYTES_PER_GB
    }

    /// Returns the number of clusters the Transposition Table holds.
    #[inline(always)]
    pub fn num_clusters(&self) -> usize {
        unsafe {
            *self.cap.get()
        }
    }

    /// Returns the number of Entries the Transposition Table holds.
    #[inline(always)]
    pub fn num_entries(&self) -> usize {
        self.num_clusters() * CLUSTER_SIZE
    }

    /// Re-sizes to `size` clusters rounded up to the nearest power of two, deleting all data.
    /// A `size` of 0 is rounded up to a single cluster.
    ///
    /// # Safety
    ///
    /// This function is unsafe to use if the TT is currently being accessed, or if any thread or
    /// structure holds a reference to an `Entry`. The old allocation is freed, so such
    /// references would dangle (use-after-free).
    pub unsafe fn resize_round_up(&self, size: usize) {
        self.resize(size.next_power_of_two());
    }

    /// Re-sizes to roughly `mb_size` megabytes, with the cluster count rounded down to a power
    /// of two, deleting all data. Returns the new size in megabytes.
    ///
    /// # Panic
    ///
    /// mb_size must be greater than 0
    ///
    /// # Safety
    ///
    /// This function is unsafe to use if the TT is currently being accessed, or if any thread or
    /// structure holds a reference to an `Entry`. The old allocation is freed, so such
    /// references would dangle (use-after-free).
    pub unsafe fn resize_to_megabytes(&self, mb_size: usize) -> usize {
        assert!(mb_size > 0);
        self.resize(TranspositionTable::clusters_for_megabytes(mb_size));
        self.size_megabytes()
    }

    // Frees the current clusters and allocates `size` zeroed clusters. `size` must be a power
    // of two. `re_alloc` records the new capacity, so `num_clusters()` reflects `size` afterwards.
    unsafe fn resize(&self, size: usize) {
        assert_eq!(size.count_ones(), 1);
        assert!(size > 0);
        self.de_alloc();
        self.re_alloc(size);
    }

    /// Clears the entire TranspositionTable, keeping its current number of clusters.
    ///
    /// # Safety
    ///
    /// This function is unsafe to use if the TT is currently being accessed, or if any thread or
    /// structure holds a reference to an `Entry`. The old allocation is freed, so such
    /// references would dangle (use-after-free).
    pub unsafe fn clear(&self) {
        let size = self.num_clusters();
        self.resize(size);
    }

    /// Called each time a new position is searched. Advances the search generation by 4
    /// (wrapping at 256), leaving the two low bits free for the [`NodeBound`].
    #[inline]
    pub fn new_search(&self) {
        unsafe {
            let c = self.time_age.get();
            *c = (*c).wrapping_add(4);
        }
    }

    /// Returns the current time age (search generation) of a TT.
    #[inline]
    pub fn time_age(&self) -> u8 {
        unsafe {
            *self.time_age.get()
        }
    }

    /// Returns the number of times `new_search` has been called since the generation was last
    /// reset, modulo 64.
    #[inline]
    pub fn time_age_cylces(&self) -> u8 {
        unsafe {
            (*self.time_age.get()).wrapping_shr(2)
        }
    }

    /// Probes the Transposition Table for a specified Key.
    ///
    /// * Returns `(true, entry)` if an Entry whose partial key matches `key` is found. Its
    ///   generation is refreshed to the current `time_age()`.
    /// * Returns `(false, entry)` with an open (unused) slot if one exists in the key's cluster.
    ///   Its contents can be checked with `Entry::is_empty()`.
    /// * Otherwise returns `(false, entry)` with the entry that is most irrelevant to the current
    ///   search, i.e. the lowest `Entry::time_value` (shallowest and/or oldest).
    ///
    /// NOTE: the returned `&mut Entry` is handed out from `&self` without synchronization, so
    /// two probes that land on the same entry (on one thread or several) alias it.
    pub fn probe(&self, key: Key) -> (bool, &mut Entry) {
        let partial_key: u16 = key.wrapping_shr(48) as u16;
        let time_age: u8 = self.time_age();

        unsafe {
            let cluster: *mut Cluster = self.cluster(key);
            let init_entry: *mut Entry = cluster_first_entry(cluster);

            // First pass: look for this key, or for an unused slot (partial_key == 0).
            for i in 0..CLUSTER_SIZE {
                let entry: &mut Entry = &mut *init_entry.offset(i as isize);

                if entry.partial_key == 0 || entry.partial_key == partial_key {
                    let found = entry.partial_key != 0;
                    // A hit is refreshed to the current generation so it is not treated as stale.
                    if found && entry.time() != time_age {
                        entry.time_node_bound.update_time(time_age);
                    }
                    return (found, entry);
                }
            }

            // Second pass: the cluster is full of other keys, so choose the entry with the lowest
            // replacement value to be overwritten.
            let mut replacement: *mut Entry = init_entry;
            let mut replacement_score: i16 = (*replacement).time_value(time_age);

            for i in 1..CLUSTER_SIZE {
                let entry_ptr: *mut Entry = init_entry.offset(i as isize);
                let entry_score: i16 = (*entry_ptr).time_value(time_age);
                if entry_score < replacement_score {
                    replacement = entry_ptr;
                    replacement_score = entry_score;
                }
            }
            (false, &mut *replacement)
        }
    }

    /// Returns the cluster of a given key (the low bits of the key, masked by `num_clusters() - 1`).
    #[inline]
    fn cluster(&self, key: Key) -> *mut Cluster {
        let index: usize = ((self.num_clusters() - 1) as u64 & key) as usize;
        unsafe {
            (*self.clusters.get()).as_ptr().offset(index as isize)
        }
    }

    // Points the table at a fresh zeroed allocation of `size` clusters and records `size` as the
    // capacity. Does not free the previous allocation.
    unsafe fn re_alloc(&self, size: usize) {
        *self.clusters.get() = alloc_room(size);
        // `cap` must track the allocation: `cluster()` masks keys with `cap - 1` and `de_alloc()`
        // rebuilds the Layout from it. Leaving it stale after a resize indexed out of bounds and
        // freed with the wrong Layout.
        *self.cap.get() = size;
    }

    /// De-allocates the current heap, using `cap` to rebuild the allocation's Layout.
    unsafe fn de_alloc(&self) {
        Global.dealloc((*self.clusters.get()).as_opaque(), Layout::array::<Cluster>(*self.cap.get()).unwrap());
    }

    /// Returns an estimate of the % of the hash table that is full.
    ///
    /// Samples up to the first 333 clusters and counts the occupied entries (non-zero partial key)
    /// stamped with the current search generation.
    pub fn hash_percent(&self) -> f64 {
        let time_age = self.time_age();
        // A table always has at least one cluster, so this never divides by zero.
        let clusters_scanned: usize = min(self.num_clusters(), 333);
        let mut hits: usize = 0;

        unsafe {
            for i in 0..clusters_scanned {
                // `i < num_clusters()`, so the key `i` maps to cluster `i`.
                let init_entry: *mut Entry = cluster_first_entry(self.cluster(i as u64));
                for e in 0..CLUSTER_SIZE {
                    let entry: &Entry = &*init_entry.offset(e as isize);
                    // Zeroed slots also carry generation 0, so skip unused slots explicitly.
                    if entry.partial_key != 0 && entry.time() == time_age {
                        hits += 1;
                    }
                }
            }
        }
        (hits as f64 * 100.0) / (clusters_scanned * CLUSTER_SIZE) as f64
    }
}

// NOTE(review): This asserts that sharing `&TranspositionTable` across threads is sound. However,
// the table hands out `&mut Entry` from `&self` (`probe`) and mutates `time_age` through an
// `UnsafeCell` (`new_search`) without any synchronization, so concurrent use can race on the same
// memory. The module docs accept collisions and over-writes as the trade-off for speed. In
// addition, `resize_round_up`, `resize_to_megabytes` and `clear` must only run while no other
// thread is using the table, as their `# Safety` sections state.
unsafe impl Sync for TranspositionTable {}

impl PreFetchable for TranspositionTable {
    /// Hints the CPU to bring the cluster that `key` maps to into cache, in anticipation of a
    /// write, so that a later access to it is faster.
    #[inline(always)]
    fn prefetch(&self, key: u64) {
        let target: *mut Cluster = self.cluster(key);
        unsafe {
            prefetch::<Write, High, Data, Cluster>(target);
        }
    }
}

impl Drop for TranspositionTable {
    /// Releases the heap-allocated clusters when the table goes out of scope.
    fn drop(&mut self) {
        // `&mut self` means no other borrow of the table is live. `de_alloc` relies on the
        // `clusters` pointer and `cap` describing the current allocation.
        unsafe {
            self.de_alloc();
        }
    }
}

/// Returns a pointer to the first entry of `cluster`.
///
/// The pointer is derived from the whole `entry` array, so callers may `offset` it by up to
/// `CLUSTER_SIZE - 1` to reach the cluster's other entries.
///
/// # Safety
///
/// `cluster` must point to a valid, live `Cluster`.
#[inline]
unsafe fn cluster_first_entry(cluster: *mut Cluster) -> *mut Entry {
    // Taking the pointer from the array, rather than casting a `&mut` to element 0, keeps it
    // valid for every entry in the cluster, which `probe` and `hash_percent` step through.
    (*cluster).entry.as_mut_ptr()
}

// Returns a zeroed heap allocation of `size` clusters, calling `oom()` if allocation fails.
//
// Panics if `size` is 0, because a zero-sized layout must not be passed to the allocator.
#[inline]
fn alloc_room(size: usize) -> NonNull<Cluster> {
    assert!(size > 0, "cannot allocate a transposition table with zero clusters");
    let layout = Layout::array::<Cluster>(size).unwrap();
    unsafe {
        match Global.alloc_zeroed(layout) {
            Ok(ptr) => ptr.cast(),
            Err(_err) => oom(),
        }
    }
}


#[cfg(test)]
mod tests {

    extern crate rand;
    use super::*;

    // 2^21 clusters: about 64 MiB with 32-byte clusters.
    const NUM_CLUSTERS_LARGE: usize = 1 << 21;

    #[test]
    fn tt_alloc_realloc() {
        let size: usize = 8;
        let tt = TranspositionTable::create(size);
        assert_eq!(tt.num_clusters(), size);

        let key = create_key(32, 44);
        let (_found, _entry) = tt.probe(key);
    }

    #[test]
    fn tt_test_sizes() {
        let tt = TranspositionTable::new_num_clusters(100);
        assert_eq!(tt.num_clusters(), (100 as usize).next_power_of_two());
        assert_eq!(tt.num_entries(), (100 as usize).next_power_of_two() * CLUSTER_SIZE);
    }

    #[test]
    fn tt_null_ptr() {
        let tt = TranspositionTable::new_num_clusters(NUM_CLUSTERS_LARGE);

        for x in 0..1_000_000 as u64 {
            let key: u64 = rand::random::<u64>();
            {
                let (_found, entry) = tt.probe(key);
                entry.depth = (x % 0b1111_1111) as i8;
                entry.partial_key = key.wrapping_shr(48) as u16;
            }
            tt.new_search();
        }
    }

    // A hit found after `new_search` must be re-stamped with the current generation.
    #[test]
    fn tt_probe_refreshes_time() {
        let tt = TranspositionTable::create(8);
        let key = create_key(4242, 3);
        {
            let (found, entry) = tt.probe(key);
            assert!(!found);
            entry.partial_key = 4242;
        }
        tt.new_search();
        let (found, entry) = tt.probe(key);
        assert!(found);
        assert_eq!(entry.time(), tt.time_age());
    }

    #[test]
    fn tt_basic_insert() {
        let tt = TranspositionTable::new_num_clusters(NUM_CLUSTERS_LARGE);
        let partial_key_1: u16 = 17773;
        let key_index: u64 = 0x5556;

        let key_1 = create_key(partial_key_1, key_index);
        let (found, entry) = tt.probe(key_1);
        assert!(!found);
        entry.partial_key = partial_key_1;
        entry.depth = 2;

        let (found, entry) = tt.probe(key_1);
        assert!(found);
        // Still "empty" by `is_empty`: no bound was ever recorded for it.
        assert!(entry.is_empty());
        assert_eq!(entry.partial_key, partial_key_1);
        assert_eq!(entry.depth, 2);

        let partial_key_2: u16 = 8091;
        let partial_key_3: u16 = 12;
        let key_2: u64 = create_key(partial_key_2, key_index);
        let key_3: u64 = create_key(partial_key_3, key_index);

        let (found, entry) = tt.probe(key_2);
        assert!(!found);
        assert!(entry.is_empty());
        entry.partial_key = partial_key_2;
        entry.depth = 3;

        let (found, entry) = tt.probe(key_3);
        assert!(!found);
        assert!(entry.is_empty());
        entry.partial_key = partial_key_3;
        entry.depth = 6;

        // The cluster is now full, so this key should be offered a replacement.
        let partial_key_4: u16 = 18;
        let key_4: u64 = create_key(partial_key_4, key_index);

        let (found, entry) = tt.probe(key_4);
        assert!(!found);

        // The shallowest entry (key_1) is the most vulnerable.
        assert_eq!(entry.partial_key, partial_key_1);
        assert_eq!(entry.depth, 2);
    }

    /// Helper function to create a key with the given partial key (upper 16 bits) and index bits
    /// (lower 48 bits of `full_key`).
    fn create_key(partial_key: u16, full_key: u64) -> u64 {
        (partial_key as u64).wrapping_shl(48) | (full_key & 0x0000_FFFF_FFFF_FFFF)
    }
}