// thincollections/cla_map.rs

1// Copyright 2018 Mohammad Rezaei.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// Portions copyright The Rust Project Developers. Licensed under
10// the MIT License.
11
12use util::*;
13use thin_hasher::*;
14
15use std::{
16    alloc::{self, Layout},
17    mem, ptr, marker,
18};
19use std::hash::BuildHasher;
20use std::hash::Hash;
21use std::hash::Hasher;
22use std::fmt::{self, Debug};
23
/// A hash map that stores entries in 64-byte, cache-line-aligned blocks.
/// Each block packs `block_kv_count` keys from byte 0, the matching values
/// starting at byte `block_v_offset`, and a control byte at offset 63 whose
/// low 7 bits count occupied slots and whose high bit flags a block that was
/// ever completely full.
pub struct ClaMap<K: Eq + Hash + Debug, V: Debug, H: BuildHasher> {
    hasher: H, // builds a fresh Hasher per key lookup
    table_blocks: usize, // number of 64-byte blocks; power of two (0 = table not yet allocated)
    occupied: usize, // total entries stored across all blocks
    flagged_blocks: usize, // a block is flagged if it ever gets full
    max_occupied: usize, // load limit; reaching it triggers a rehash
    block_kv_count: i8, // key/value pairs per block, from calculate_sizes()
    block_v_offset: i8, // byte offset of the value region inside a block
    table: *mut u64, // 64-byte-aligned allocation of `table_blocks` blocks
    _marker: marker::PhantomData<(K, V)>, // ties K/V ownership to the raw table
}
35
/// Outcome of probing one block for a key.
#[derive(PartialEq)]
enum BucketState {
    /// The key was found at the returned index.
    Full,
    /// The block holds its maximum count; the key is not here.
    NoSpace,
    /// The key is absent and the returned index is a free slot.
    Empty,
    /// The block was flagged but has spare room (an entry was removed);
    /// probing must continue before the slot can be reused.
    Removed,
}

impl BucketState {
    /// True exactly when the probe located an existing entry.
    #[inline(always)]
    pub fn is_full(&self) -> bool {
        match *self {
            BucketState::Full => true,
            _ => false,
        }
    }
}
50
51impl<K: Eq + Hash + Debug, V: Debug> ClaMap<K, V, OneFieldHasherBuilder> {
52    pub fn new() -> Self {
53        let (count, v_start) = calculate_sizes::<K, V>();
54        ClaMap {
55            table_blocks: 0,
56            occupied: 0,
57            flagged_blocks: 0,
58            max_occupied: 0,
59            block_kv_count: count,
60            block_v_offset: v_start,
61            table: ptr::null_mut(),
62            hasher: OneFieldHasherBuilder::new(),
63            _marker: marker::PhantomData,
64        }
65    }
66}
67
impl<K: Eq + Hash + Debug, V: Debug, H: BuildHasher> ClaMap<K, V, H> {
    /// Returns the number of key/value entries currently stored.
    pub fn len(&self) -> usize {
        self.occupied
    }

    /// Raw pointer to the `index`-th key slot of the block at `ptr`.
    /// Keys are packed contiguously from the start of the block.
    fn key_at(ptr: *mut u8, index: i8) -> *mut K {
        unsafe {
            ptr.offset(mem::size_of::<K>() as isize * index as isize) as *mut K
        }
    }

    /// Raw pointer to the `index`-th value slot of the block at `ptr`.
    /// Values start `block_v_offset` bytes into the block (calculate_sizes
    /// rounds that offset up to V's alignment).
    fn value_at(&self, ptr: *mut u8, index: i8) -> *mut V {
        unsafe {
            ptr.offset(self.block_v_offset as isize + mem::size_of::<V>() as isize * index as isize) as *mut V
        }
    }

    /// Increments the occupancy count held in the low 7 bits of the block's
    /// control byte (byte 63). When the count reaches `max`, the high bit is
    /// set, flagging the block as "was full at some point"; an existing flag
    /// is preserved. Returns true only on the unflagged -> flagged
    /// transition so the caller bumps `flagged_blocks` exactly once.
    fn increment_control(block_ptr: *mut u8, max: i8) -> bool
    {
        unsafe {
            let control_ptr = block_ptr.add(63);
            let already_flagged = (*control_ptr & 0x80) != 0;
            let mut v = (*control_ptr & 0x7F) + 1;
            if v == max as u8 {
                v |= 0x80;
            }
            // Keep a previously-set flag bit sticky.
            v |= *control_ptr & 0x80;
            *control_ptr = v;
            return v & 0x80 != 0 && !already_flagged;
        }
    }

    /// Inserts `key` -> `value`. If the key already exists, the value is
    /// replaced and the old value returned; the stored key is left in place
    /// and the incoming `key` is dropped. The table is allocated lazily on
    /// first use, and a rehash is triggered when the occupancy limit is hit
    /// or when at least half the blocks are flagged.
    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
        if self.table_blocks == 0 {
            self.allocate_table();
        }
        let (block_ptr, index, bucket_state) = self.probe(&key);
        let flagged: bool;
        unsafe {
            if bucket_state.is_full() {
                return Some(mem::replace(&mut (*self.value_at(block_ptr, index)), value));
            }
            // NOTE(review): probe() can report BucketState::Removed, whose
            // index is a *negated* slot count (see search_block, case 3).
            // Writing through it unadjusted here looks out of bounds —
            // verify against the removal support, which is not visible in
            // this section of the file.
            ptr::write(<ClaMap<K, V, H>>::key_at(block_ptr, index), key);
            ptr::write(self.value_at(block_ptr, index), value);
            flagged = <ClaMap<K, V, H>>::increment_control(block_ptr, self.block_kv_count);
        }
        self.occupied += 1;
        if flagged {
            self.flagged_blocks += 1;
        }
        if self.occupied >= self.max_occupied || self.flagged_blocks >= (self.table_blocks >> 1) {
            self.rehash();
        }
        return None;
    }

    /// Hashes `key` with a fresh hasher from this map's BuildHasher.
    #[inline]
    fn hash(&self, key: &K) -> u64 {
        let mut state = self.hasher.build_hasher();
        (*key).hash(&mut state);
        state.finish()
    }

    /// Returns `key`'s hash together with its primary block index.
    #[inline]
    fn hash_and_mask(&self, key: &K) -> (u64, isize) {
        let hash = self.hash(key);
        (hash, self.mask(hash) as isize)
    }

    /// Maps a hash onto a block index. Correct only while `table_blocks`
    /// is a power of two, which allocate_table/rehash maintain.
    #[inline(always)]
    fn mask(&self, hash: u64) -> u64 {
        hash & ((self.table_blocks - 1) as u64)
    }

    // 4 things can happen:
    // 1: we find the key, we'll return the index and BucketState::FULL
    // 2: we don't find the key, the block is not flagged, the block is not full,
    //      we'll return the index, BucketState::EMPTY
    // 3: we don't find the key, the block is flagged, the block is not full,
    //      we'll return the negative index, BucketState::REMOVED
    // 4: we don't find the key, the block is full (and therefore flagged),
    //      we'll return -1, BucketState::NO_SPACE
    fn search_block(&self, block_ptr: *mut u8, key: &K) -> (i8, BucketState) {
        unsafe {
            let control_ptr = block_ptr.offset(63);
            let count = (*control_ptr & 0x7F) as i8;
            let mut ptr: *mut K = block_ptr as *mut K;
            let mut index = 0;
            // Linear scan over the occupied key slots of this block.
            while index < count {
                if *key == *ptr {
                    return (index, BucketState::Full); // case 1
                }
                index += 1;
                ptr = ptr.add(1);
            }
            if count == self.block_kv_count {
                return (-1, BucketState::NoSpace); // case 4
            }
            let flagged = *control_ptr & 0x80 != 0;
            if flagged {
                // Flagged but not full: an entry was removed after the block
                // filled, so the key may still live further along the probe
                // chain; the caller must keep probing before reusing a slot.
                return (-count, BucketState::Removed); // case 3
            }
            return (count, BucketState::Empty); // case 2
        }
    }

    /// First probe: checks the key's primary block and falls back to the
    /// secondary probe sequence when that block is flagged or full.
    fn probe(&self, key: &K) -> (*mut u8, i8, BucketState) {
        let (hash, block_index) = self.hash_and_mask(key);
        unsafe {
            // `table` is *mut u64, so one 64-byte block spans 8 words.
            let block_ptr: *mut u8 = self.table.offset(block_index << 3) as *mut u8;
            let (index, state): (i8, BucketState) = self.search_block(block_ptr, key);
            if index >= 0 {
                return (block_ptr, index, state);
            }
            self.probe2(key, hash, block_ptr, index, state)
        }
    }

    /// Secondary block index derived from a second spread of the hash.
    fn spread_two_and_mask(&self, hash: u64) -> isize {
        self.mask(spread_two(hash)) as isize
    }

    /// Second probe position. Prefers reusing a Removed slot seen in the
    /// primary block over an Empty slot found here.
    fn probe2(&self, key: &K, hash: u64, original_block_ptr: *mut u8, original_index: i8,
              original_state: BucketState) -> (*mut u8, i8, BucketState) {
        let block_index = self.spread_two_and_mask(hash);
        unsafe {
            let block_ptr: *mut u8 = self.table.offset(block_index << 3) as *mut u8;
            let (index, state): (i8, BucketState) = self.search_block(block_ptr, key);
            if index >= 0 {
                if state == BucketState::Full {
                    return (block_ptr, index, state);
                }
                if state == BucketState::Empty {
                    if original_state == BucketState::Removed {
                        // Key proven absent: reuse the earlier removed slot.
                        return (original_block_ptr, original_index, original_state);
                    }
                    return (block_ptr, index, state);
                }
            }
            if original_state == BucketState::Removed {
                return self.probe3(key, hash, original_block_ptr, original_index, original_state);
            }
            return self.probe3(key, hash, block_ptr, index, state);
        }
    }

    /// Open-ended double-hash probing: steps through blocks with an odd
    /// stride (odd is coprime with the power-of-two table size, so every
    /// block is eventually visited) until the key or an insertable slot is
    /// found. Termination relies on the rehash thresholds keeping some
    /// block non-full.
    fn probe3(&self, key: &K, hash: u64, mut original_block_ptr: *mut u8, mut original_index: i8,
              mut original_state: BucketState) -> (*mut u8, i8, BucketState) {
        let mut next_index = spread_one(hash) as isize;
        // Forcing the low bit makes the stride odd.
        let spread_two = spread_two(hash).rotate_right(32) | 1;

        loop {
            unsafe {
                next_index = self.mask((next_index as u64).wrapping_add(spread_two)) as isize;
                let block_ptr: *mut u8 = self.table.offset(next_index << 3) as *mut u8;
                let (index, state): (i8, BucketState) = self.search_block(block_ptr, key);
                if index >= 0 {
                    if state == BucketState::Full {
                        return (block_ptr, index, state);
                    }
                    if state == BucketState::Empty {
                        if original_state == BucketState::Removed {
                            return (original_block_ptr, original_index, original_state);
                        }
                        return (block_ptr, index, state);
                    }
                }
                // Remember the first Removed slot seen, for possible reuse.
                if state == BucketState::Removed && original_state != BucketState::Removed {
                    original_block_ptr = block_ptr;
                    original_state = state;
                    original_index = index;
                }
            }
        }
    }

    /// Allocates the initial table, sized for a capacity of 8 entries.
    fn allocate_table(&mut self) {
        let (num_blocks, max_occupied) = self.block_for_capcity(8);
        self.max_occupied = max_occupied;
        self.table = <ClaMap<K, V, H>>::allocate_table_for_blocks(num_blocks);
        self.table_blocks = num_blocks;
    }

    /// Allocates `blocks` zeroed, 64-byte-aligned blocks. Zeroed control
    /// bytes mean "empty, unflagged".
    /// NOTE(review): the alloc_zeroed result is not checked for null; on
    /// allocation failure later writes dereference a null pointer. Consider
    /// alloc::handle_alloc_error.
    fn allocate_table_for_blocks(blocks: usize) -> *mut u64 {
        unsafe {
            let layout = Layout::from_size_align(64 * blocks, 64).unwrap();
            alloc::alloc_zeroed(layout) as *mut u64
        }
    }

    /// Computes (power-of-two block count, max occupancy) for `capacity`,
    /// budgeting one spare slot per block (count - 1 usable slots each).
    /// (The "capcity" spelling is kept as-is for existing callers.)
    fn block_for_capcity(&self, capacity: usize) -> (usize, usize) {
        let count_minus_one = self.block_kv_count - 1;
        let mut num_blocks = ceil_pow2(capacity as u64 / (count_minus_one as u64)) as usize;
        if count_minus_one as usize * num_blocks < capacity { num_blocks <<= 1; }
        (num_blocks, num_blocks * (count_minus_one as usize))
    }

    /// Doubles the table size and reinserts every entry.
    fn rehash(&mut self) {
        let x = self.table_blocks << 1;
        self.rehash_for_blocks(x);
    }

    /// Rebuilds the table with `new_block_count` blocks. Entries are moved
    /// by bitwise copy (no Drop runs for K or V), block flags are recomputed
    /// from scratch, and the old allocation is freed at the end.
    fn rehash_for_blocks(&mut self, new_block_count: usize) {
        let old_table = self.table;
        let old_block_count = self.table_blocks;
        self.table = <ClaMap<K, V, H>>::allocate_table_for_blocks(new_block_count);
        self.table_blocks = new_block_count;
        self.flagged_blocks = 0;
        self.max_occupied = new_block_count * ((self.block_kv_count - 1) as usize);
        unsafe {
            let mut ptr: *mut u64 = old_table as *mut u64;
            let table_end = old_table.add(old_block_count << 3);
            // Walk the old table block by block (8 u64 words per block).
            while ptr < table_end {
                let block_ptr = ptr as *mut u8;
                let control_ptr = block_ptr.offset(63);
                let count = (*control_ptr & 0x7F) as i8;
                let mut kptr: *mut K = block_ptr as *mut K;
                let mut vptr: *mut V = block_ptr.offset(self.block_v_offset as isize) as *mut V;
                let mut i = 0;
                while i < count {
                    // Probing the fresh table for a unique key always ends
                    // at an insertable slot.
                    let (insert_block_ptr, index, _bucket_state) = self.probe(&(*kptr));
                    ptr::copy_nonoverlapping(kptr, <ClaMap<K, V, H>>::key_at(insert_block_ptr, index), 1);
                    ptr::copy_nonoverlapping(vptr, self.value_at(insert_block_ptr, index), 1);
                    let flagged = <ClaMap<K, V, H>>::increment_control(insert_block_ptr, self.block_kv_count);
                    if flagged {
                        self.flagged_blocks += 1;
                    }
                    i += 1;
                    kptr = kptr.add(1);
                    vptr = vptr.add(1);
                }
                ptr = ptr.add(8);
            }
            let layout = Layout::from_size_align(64 * old_block_count, 64).unwrap();
            alloc::dealloc(old_table as *mut u8, layout);
        }
    }
}
306
impl<K: Eq + Hash + Debug, V: Debug, H: BuildHasher> Drop for ClaMap<K, V, H> {
    /// Drops every stored key and value in place, then frees the block table.
    fn drop(&mut self) {
        if self.table_blocks > 0 {
            unsafe {
                if self.occupied > 0 {
                    let mut ptr: *mut u64 = self.table;
                    // `table` is *mut u64, so one 64-byte block is 8 words.
                    let table_end = self.table.add(self.table_blocks << 3);
                    while ptr < table_end {
                        let block_ptr = ptr as *mut u8;
                        let control_ptr = block_ptr.offset(63);
                        // Low 7 bits of the control byte = occupied slots.
                        let count = (*control_ptr & 0x7F) as i8;
                        let mut kptr: *mut K = block_ptr as *mut K;
                        let mut i = 0;
                        while i < count {
                            ptr::drop_in_place(kptr);
                            i += 1;
                            kptr = kptr.add(1);
                        }
                        // NOTE(review): a zero-sized V skips this branch, so
                        // a ZST with a Drop impl would never be dropped —
                        // confirm that is intended (std uses needs_drop).
                        if mem::size_of::<V>() > 0 {
                            i = 0;
                            let mut vptr = block_ptr.offset(self.block_v_offset as isize) as *mut V;
                            while i < count {
                                ptr::drop_in_place(vptr);
                                i += 1;
                                vptr = vptr.add(1);
                            }
                        }
                        ptr = ptr.add(8);
                    }
                }
                // Layout must match allocate_table_for_blocks exactly.
                let layout = Layout::from_size_align(self.table_blocks * 64, 64).unwrap();
                alloc::dealloc(self.table as *mut u8, layout);
            }
        }
    }
}
343
impl<K, V, S> Debug for ClaMap<K, V, S>
    where K: Eq + Hash + Debug,
          V: Debug,
          S: BuildHasher
{
    /// Formats the map as `{key: value, ...}` by walking every block and
    /// reading the occupied slots directly from the raw table.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut debug_map = f.debug_map();
        if self.table_blocks > 0 && self.occupied > 0 {
            unsafe {
                let mut ptr: *mut u64 = self.table;
                // One 64-byte block = 8 u64 words.
                let table_end = self.table.add(self.table_blocks << 3);
                while ptr < table_end {
                    let block_ptr = ptr as *mut u8;
                    let control_ptr = block_ptr.offset(63);
                    // Low 7 bits of the control byte = occupied slot count.
                    let count = (*control_ptr & 0x7F) as i8;
                    let mut kptr: *mut K = block_ptr as *mut K;
                    let mut vptr = block_ptr.offset(self.block_v_offset as isize) as *mut V;
                    let mut i = 0;
                    while i < count {
                        debug_map.entry(&(*kptr), &(*vptr));
                        i += 1;
                        kptr = kptr.add(1);
                        vptr = vptr.add(1);
                    }
                    ptr = ptr.add(8);
                }
            }
        }
        debug_map.finish()
    }
}
375
376impl<K, V, S> ClaMap<K, V, S>
377    where K: Eq + Hash + Debug,
378          V: Debug,
379          S: BuildHasher
380{
381    pub fn debug(&self) {
382        println!("occupied {}, table_blocks {}, flagged_blocks {}", self.occupied, self.table_blocks, self.flagged_blocks);
383        if self.table_blocks > 0 && self.occupied > 0 {
384            unsafe {
385                let mut ptr: *mut u64 = self.table;
386                let table_end = self.table.add(self.table_blocks << 3);
387                while ptr < table_end {
388                    let block_ptr = ptr as *mut u8;
389                    let control_ptr = block_ptr.offset(63);
390                    let count = (*control_ptr & 0x7F) as i8;
391                    let mut kptr: *mut K = block_ptr as *mut K;
392                    let mut vptr = block_ptr.offset(self.block_v_offset as isize) as *mut V;
393                    let mut i = 0;
394                    while i < count {
395                        println!("[{:?},{:?}]", &(*kptr), &(*vptr));
396                        i += 1;
397                        kptr = kptr.add(1);
398                        vptr = vptr.add(1);
399                    }
400                    ptr = ptr.add(8);
401                }
402            }
403        }
404    }
405}
406
/// Computes how many key/value pairs fit in the 63-byte payload of a 64-byte
/// block (byte 63 is the control byte), and the byte offset of the value
/// region, rounded up to V's alignment.
///
/// Returns `(pairs_per_block, value_offset)`.
///
/// # Panics
/// Panics when no pair fits at all — either because `K` plus `V` exceed the
/// 63-byte payload, or because alignment padding for `V` consumes the space.
/// (Previously the alignment loop could silently return a count of 0, e.g.
/// for `<u8, [u64; 7]>`, which would break the map's probing.)
pub fn calculate_sizes<K, V>() -> (i8, i8) {
    let k_size = mem::size_of::<K>();
    let v_size = mem::size_of::<V>();
    if k_size + v_size == 0 {
        // Both K and V are zero-sized: slots cost no payload bytes, so a
        // block holds the maximum count representable in the control byte.
        // (Previously this case hit a divide-by-zero panic below.)
        return (63, 0);
    }
    let mut nominal_count = 63usize / (k_size + v_size);
    if nominal_count == 0 {
        panic!("Key-value size is too large!");
    }
    if v_size == 0 {
        // No value region needed; keys fill the payload.
        return (nominal_count as i8, 0);
    }
    let v_align = mem::align_of::<V>();
    loop {
        // Round the start of the value region up to V's alignment.
        let keys_bytes = k_size * nominal_count;
        let mut v_offset = keys_bytes / v_align * v_align;
        if v_offset < keys_bytes {
            v_offset += v_align;
        }
        if v_size * nominal_count + v_offset <= 63 {
            return (nominal_count as i8, v_offset as i8);
        }
        // Padding pushed us past the payload; try one fewer pair.
        nominal_count -= 1;
        if nominal_count == 0 {
            // Alignment padding ate all the space — fail loudly instead of
            // returning a zero count.
            panic!("Key-value size is too large!");
        }
    }
}
428
429
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the block-packing results of calculate_sizes for several
    // key/value size/alignment combinations (63 usable bytes per block,
    // value region aligned to V).
    #[test]
    fn allocator_tests() {
        let (count, v_start) = calculate_sizes::<u32, i32>();
        assert_eq!((7, 28), (count, v_start));

        let x = calculate_sizes::<i8, u64>();
        assert_eq!((6, 8), x);

        // Zero-sized values: keys fill the entire 63-byte payload.
        let x = calculate_sizes::<u8, ()>();
        assert_eq!((63, 0), x);

        let x = calculate_sizes::<i32, u64>();
        assert_eq!((4, 16), x);

        let x = calculate_sizes::<u64, i32>();
        assert_eq!((5, 40), x);

        let x = calculate_sizes::<u64, u8>();
        assert_eq!((7, 56), x);
    }

    // A 64-byte pair cannot fit in the 63-byte payload and must panic.
    #[test]
    #[should_panic]
    fn allocator_panic_test() {
        calculate_sizes::<(u64, u64, u64, u64), (i64, i64, i64, i64)>();
    }
}
461