chess/
transposition.rs

// https://github.com/mvanthoor/rustic/blob/4.0-beta/src/engine/transposition.rs
// Adapted from the Rustic engine's transposition table linked above, mainly as an exercise in
// generic types. Currently only used with `TableEntry` as the payload type.
3
4use std::vec;
5
6use crate::zobrist::PositionHash;
7use crate::{util, ShortMove, NULL_SHORT_MOVE};
8
/// Size of the table built by `TT::new`, in MiB.
const DEFAULT_TABLE_SIZE_MB: usize = 200;
/// Number of buckets grouped under each table index.
const NUM_BUCKETS: usize = 3;
11const UNINIT_ENTRY: TableEntry = TableEntry {
12    bound_type: BoundType::Invalid,
13    depth: 0,
14    eval: 0,
15    mv: NULL_SHORT_MOVE,
16};
17
18// TT with generic type T as TableEntry
19pub type TranspositionTable<T = TableEntry> = TT<T>;
20
/// Payload stored in a [`TT`] bucket; every type kept in the table must implement it.
pub trait TTData {
    /// Returns the value used to fill unused buckets.
    fn new() -> Self;
    /// Depth recorded with this value; the bucket replacement scheme evicts low depths first.
    fn get_depth(&self) -> u8;
    /// Whether this value marks an unused bucket.
    fn is_empty(&self) -> bool;
}
27
/// Kind of score stored in a [`TableEntry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundType {
    /// Exact score.
    Exact,
    /// Lower bound.
    Lower,
    /// Upper bound.
    Upper,
    /// Marks an unused slot (`TableEntry::is_empty` tests for this variant).
    Invalid,
}
35
36// TODO detect checkmate distance
37#[derive(Debug, Clone, Copy)]
38pub struct TableEntry {
39    pub bound_type: BoundType,
40    pub depth: u8,
41    pub eval: i32,
42    pub mv: ShortMove,
43}
44impl TTData for TableEntry {
45    fn new() -> Self {
46        UNINIT_ENTRY
47    }
48
49    fn get_depth(&self) -> u8 {
50        self.depth
51    }
52
53    fn is_empty(&self) -> bool {
54        self.bound_type == BoundType::Invalid
55    }
56}
57
58#[derive(Debug)]
59pub struct TT<T> {
60    table: Vec<Entry<T>>,
61    entry_count: usize,
62    size_mb: usize,
63}
64impl<T: TTData + Copy + Clone> Default for TT<T> {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69impl<T: TTData + Copy + Clone> TT<T> {
70    pub fn new() -> Self {
71        Self::with_size(DEFAULT_TABLE_SIZE_MB)
72    }
73
74    pub fn with_size(size_mb: usize) -> Self {
75        let table = vec![Entry::<T>::new(); Self::mb_to_len(size_mb)];
76        Self {
77            table,
78            entry_count: 0,
79            size_mb,
80        }
81    }
82
83    pub fn get(&self, hash: PositionHash) -> Option<&T> {
84        if self.size_mb != 0 {
85            self.table[self.get_idx(hash)].get(self.get_bucket_hash(hash))
86        } else {
87            None
88        }
89    }
90
91    pub fn insert(&mut self, hash: PositionHash, data: T) {
92        if self.size_mb != 0 {
93            let idx = self.get_idx(hash);
94            let bucket_hash = self.get_bucket_hash(hash);
95            // returns true if the bucket was empty, so we can increment entry_count
96            if self.table[idx].insert(bucket_hash, data) {
97                self.entry_count += 1;
98            }
99        }
100    }
101
102    pub fn size(&self) -> usize {
103        self.table.len() * NUM_BUCKETS
104    }
105
106    pub fn heap_alloc_size(&self) -> usize {
107        self.table.len() * std::mem::size_of::<Entry<T>>()
108    }
109
110    pub fn len(&self) -> usize {
111        self.entry_count
112    }
113
114    pub fn is_empty(&self) -> bool {
115        self.entry_count == 0
116    }
117
118    pub fn clear(&mut self) {
119        self.entry_count = 0;
120        self.table.iter_mut().for_each(|entry| {
121            *entry = Entry::new();
122        });
123    }
124
125    const fn mb_to_len(mb_size: usize) -> usize {
126        (mb_size * 1024 * 1024) / std::mem::size_of::<Entry<T>>()
127    }
128
129    fn get_idx(&self, hash: PositionHash) -> usize {
130        let idx_hash = util::high_bits(hash); // use high bits for index, and low bits for bucket collision handling
131        (idx_hash as usize) % self.table.len()
132    }
133
134    const fn get_bucket_hash(&self, hash: PositionHash) -> u32 {
135        util::low_bits(hash)
136    }
137}
138
/// One slot of an `Entry`: the bucket hash of the stored position plus its payload.
#[derive(Debug, Clone, Copy)]
struct Bucket<T> {
    /// Low bits of the position hash (see `TT::get_bucket_hash`); `0` in a fresh bucket.
    hash: u32,
    /// Payload stored for that position.
    data: T,
}
144impl<T: TTData> Bucket<T> {
145    fn new() -> Self {
146        Self {
147            hash: 0,
148            data: T::new(),
149        }
150    }
151}
152
153#[derive(Debug, Clone, Copy)]
154struct Entry<T> {
155    buckets: [Bucket<T>; NUM_BUCKETS],
156}
157impl<T: TTData + Copy + Clone> Entry<T> {
158    fn new() -> Self {
159        Self {
160            buckets: [Bucket::new(); NUM_BUCKETS],
161        }
162    }
163
164    // returns true if the bucket was empty before data was inserted
165    fn insert(&mut self, hash: u32, data: T) -> bool {
166        let mut idx = 0;
167        for i in 1..self.buckets.len() {
168            // skip first bucket as we will start by comparing idx 0
169            // replacement strategy is removing lowest depth entry, uninitialised entries are depth 0
170            if self.buckets[i].data.get_depth() < self.buckets[idx].data.get_depth() {
171                idx = i;
172            }
173        }
174        let was_empty = self.buckets[idx].hash == 0;
175        self.buckets[idx].hash = hash;
176        self.buckets[idx].data = data;
177        if !was_empty {
178            log::trace!("TT bucket collision");
179        }
180        was_empty
181    }
182
183    fn get(&self, hash: u32) -> Option<&T> {
184        for bucket in &self.buckets {
185            if bucket.hash == hash {
186                return Some(&bucket.data);
187            }
188        }
189        None
190    }
191}