extern crate crossbeam;
extern crate rand;

use std::sync::Arc;
use std::sync::Mutex;
use std::hash::{Hasher, BuildHasher, Hash};
use std::cmp::min;
use std::sync::atomic::Ordering::{Acquire, Release, Relaxed};
use std::sync::atomic::AtomicUsize;
use crossbeam::mem::epoch::Guard;
use crossbeam::mem::epoch::{self, Atomic, Owned, Shared};
use std::collections::hash_map::RandomState;

const DEFAULT_SEGMENT_COUNT: u32 = 8;
const DEFAULT_CAPACITY: u32 = 16;
const DEFAULT_LOAD_FACTOR: f32 = 0.8;
const MAX_CAPACITY: u32 = 1 << 30;
const MAX_SEGMENT_COUNT: u32 = 1 << 12;
const MIN_LOAD_FACTOR: f32 = 0.2;
const MAX_LOAD_FACTOR: f32 = 1.0;
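
/// A concurrent hash map that can be shared across threads by `clone`ing
/// handles; all clones point at the same underlying table.
///
/// The table is split into a fixed number of segments, each guarded by its
/// own `Mutex`. Writers (`insert`, `remove`) lock only the segment their key
/// hashes to, while readers (`get`, `entries`) take no locks at all and rely
/// on crossbeam's epoch-based reclamation to keep entries alive while they
/// traverse a bucket chain.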
pub struct ConcurrentHashMap<K: Eq + Hash + Sync + Clone, V: Sync + Clone, H: BuildHasher> {
    inner: Arc<CHMInner<K, V>>,
    hasher: H,
}

struct CHMInner<K: Eq + Hash + Sync + Clone, V: Sync + Clone> {
    segments: Vec<CHMSegment<K, V>>,
    // The segment index lives in the top bits of the hash: mask it out,
    // then shift it down.
    bit_mask: u32,
    mask_shift_count: u32,
}

struct CHMSegment<K: Eq + Hash + Sync + Clone, V: Sync + Clone> {
    table: Atomic<Vec<Atomic<CHMEntry<K, V>>>>,
    lock: Mutex<()>,
    // Resize threshold: the table grows once `len` reaches this value.
    max_capacity: AtomicUsize,
    len: AtomicUsize,
}

struct CHMEntry<K, V> {
    hash: u32,
    key: K,
    value: V,
    next: Atomic<CHMEntry<K, V>>
}

impl<K: Eq + Hash + Sync + Clone, V: Sync + Clone, H: BuildHasher + Clone> Clone for ConcurrentHashMap<K, V, H> {
    fn clone(&self) -> ConcurrentHashMap<K, V, H> {
        ConcurrentHashMap { inner: self.inner.clone(), hasher: self.hasher.clone() }
    }
}

impl<K: Eq + Hash + Sync + Clone, V: Sync + Clone> ConcurrentHashMap<K, V, RandomState> {
    pub fn new() -> ConcurrentHashMap<K, V, RandomState> {
        ConcurrentHashMap::new_with_options(DEFAULT_CAPACITY, DEFAULT_SEGMENT_COUNT, DEFAULT_LOAD_FACTOR, RandomState::new())
    }
}

impl<K: Eq + Hash + Sync + Clone, V: Sync + Clone, H: BuildHasher> ConcurrentHashMap<K, V, H> {

    pub fn new_with_options(capacity: u32, segments: u32, load_factor: f32, hasher: H) -> ConcurrentHashMap<K, V, H> {
        let (capacity, segments, load_factor) = Self::check_params(capacity, segments, load_factor);
        ConcurrentHashMap { inner: Arc::new(CHMInner::new(capacity, segments, load_factor)), hasher: hasher }
    }
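
    // Sanitise the requested parameters: round the segment count up to a
    // power of two (capped at MAX_SEGMENT_COUNT), clamp the load factor,
    // and scale the raw capacity by 1/load_factor so the map can hold
    // `capacity` entries before it first grows. For example, a request for
    // (capacity: 100, segments: 8, load_factor: 0.8) yields a 128-slot
    // table (100 / 0.8 = 125, rounded up to the next power of two) split
    // into 8 segments.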
    fn check_params(mut capacity: u32, mut segments: u32, mut load_factor: f32) -> (u32, u32, f32) {
        assert!(!load_factor.is_nan());

        segments = min(MAX_SEGMENT_COUNT, segments.checked_next_power_of_two().unwrap());
        if load_factor > MAX_LOAD_FACTOR {
            load_factor = MAX_LOAD_FACTOR;
        }
        if load_factor < MIN_LOAD_FACTOR {
            load_factor = MIN_LOAD_FACTOR;
        }

        capacity = (capacity as f64 / load_factor as f64) as u32;
        capacity = min(MAX_CAPACITY, capacity);
        capacity = capacity.checked_next_power_of_two().unwrap();
        if capacity < segments {
            capacity = segments;
        }

        // Both counts are powers of two, so the larger is a multiple of the smaller.
        assert!(capacity % segments == 0u32);
        (capacity, segments, load_factor)
    }

    /// Inserts a key-value pair, returning the previous value for the key
    /// if one was present.
    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
        let mut hasher = self.hasher.build_hasher();
        self.inner.insert(key, value, &mut hasher)
    }

    /// Returns a clone of the value for the key, if it is present.
    pub fn get(&self, key: K) -> Option<V> {
        let mut hasher = self.hasher.build_hasher();
        self.inner.get(key, &mut hasher)
    }

    /// Removes the key from the map, returning its value if it was present.
    pub fn remove(&mut self, key: K) -> Option<V> {
        let mut hasher = self.hasher.build_hasher();
        self.inner.remove(key, &mut hasher)
    }

    /// Returns the number of entries, summed across all segments.
    pub fn len(&self) -> u32 {
        self.inner.len()
    }

    /// Returns a clone of every key-value pair currently in the map.
    pub fn entries(&self) -> Vec<(K, V)> {
        self.inner.entries()
    }
}
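
// A sketch of typical use (illustrative, not exercised by the tests below):
// each thread clones its own handle, and all handles share one table
// through the `Arc`.
//
//     let mut map = ConcurrentHashMap::<u32, String, _>::new();
//     let mut handle = map.clone();
//     thread::spawn(move || { handle.insert(1, "one".to_string()); });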

impl<K: Eq + Hash + Sync + Clone, V: Sync + Clone> CHMInner<K, V> {

    fn new(capacity: u32, seg_count: u32, load_factor: f32) -> CHMInner<K, V> {
        assert!(seg_count.is_power_of_two());
        assert!(capacity % seg_count == 0);
        assert!(capacity > 0);
        assert!(load_factor <= MAX_LOAD_FACTOR);
        assert!(load_factor >= MIN_LOAD_FACTOR);

        let per_seg_capacity = capacity / seg_count;

        let (bit_mask, shift_count) = Self::make_segment_bit_mask(seg_count);
        let mut segments = Vec::with_capacity(seg_count as usize);

        for _ in 0..seg_count {
            segments.push(CHMSegment::new(per_seg_capacity, load_factor));
        }

        CHMInner { segments: segments, bit_mask: bit_mask, mask_shift_count: shift_count }
    }

    // Build a mask that selects the topmost bits of a hash for the segment
    // index: take the low log2(seg_count) bits and shift them up against
    // bit 31.
    fn make_segment_bit_mask(seg_count: u32) -> (u32, u32) {
        let mut bit_mask = seg_count - 1;
        let mut shift_count = 0;
        while bit_mask & 0b10000000000000000000000000000000 == 0 && bit_mask != 0 {
            bit_mask <<= 1;
            shift_count += 1;
        }
        (bit_mask, shift_count)
    }
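
    // With 16 segments, for example, the mask is the top four bits of the
    // hash (0b1111 << 28) and the shift is 28: the segment index comes from
    // the high bits of the hash, while `get_bucket` uses the low bits, so
    // the two indices stay independent.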
    fn get_segment_from_hash(&self, mut hash: u32) -> u32 {
        hash &= self.bit_mask;
        hash >>= self.mask_shift_count;
        hash
    }

    fn insert<H: Hasher>(&self, key: K, value: V, hasher: &mut H) -> Option<V> {
        let (segment, hash) = self.get_hash_and_segment(&key, hasher);
        self.segments[segment].insert(key, value, hash)
    }

    fn get<H: Hasher>(&self, key: K, hasher: &mut H) -> Option<V> {
        let (segment, hash) = self.get_hash_and_segment(&key, hasher);
        self.segments[segment].get(key, hash)
    }

    fn remove<H: Hasher>(&self, key: K, hasher: &mut H) -> Option<V> {
        let (segment, hash) = self.get_hash_and_segment(&key, hasher);
        self.segments[segment].remove(key, hash)
    }

    fn get_hash_and_segment<H: Hasher>(&self, key: &K, hasher: &mut H) -> (usize, u32) {
        key.hash(hasher);
        let hash = hasher.finish() as u32;
        let segment = self.get_segment_from_hash(hash);
        (segment as usize, hash)
    }

    fn len(&self) -> u32 {
        self.segments.iter().fold(0, |acc, segment| acc + segment.len() as u32)
    }

    fn entries(&self) -> Vec<(K, V)> {
        self.segments.iter().fold(Vec::new(), |mut acc, segment| { acc.extend_from_slice(&segment.entries()); acc })
    }
}

impl<K: Eq + Hash + Sync + Clone, V: Sync + Clone> CHMSegment<K, V> {

    fn new(capacity: u32, load_factor: f32) -> CHMSegment<K, V> {
        debug_assert!(capacity.is_power_of_two());

        let max_cap = (capacity as f32 * load_factor) as usize;

        let segment = CHMSegment { table: Atomic::null(), lock: Mutex::new(()), len: AtomicUsize::new(0), max_capacity: AtomicUsize::new(max_cap) };
        segment.table.store(Some(Owned::new(Self::new_table(capacity))), Relaxed);

        segment
    }

    fn len(&self) -> usize {
        self.len.load(Relaxed)
    }

    fn insert(&self, key: K, value: V, hash: u32) -> Option<V> {
        // Writers serialise on the segment lock; readers proceed concurrently.
        let lock_guard = self.lock.lock().expect("Couldn't lock segment mutex");
        let ret = self.insert_inner(key, value, hash, &self.table);
        drop(lock_guard);
        ret
    }

    fn insert_inner(&self, key: K, value: V, hash: u32, s_table: &Atomic<Vec<Atomic<CHMEntry<K, V>>>>) -> Option<V> {
        let guard = epoch::pin();

        let table = s_table.load(Relaxed, &guard).expect("Table should have been initialised on creation");
        let hash_bucket = self.get_bucket(hash, table.len() as u32);
        let mut ret = None;

        let mut bucket = &table[hash_bucket as usize];
        let new_node = self.create_new_entry(hash, key, value);
        loop {
            let bucket_data = bucket.load(Relaxed, &guard);
            let entry = match bucket_data {
                None => {
                    // End of chain: this is a fresh key.
                    self.len.fetch_add(1, Relaxed);
                    break;
                },
                Some(data) => data
            };

            if entry.hash == new_node.hash && entry.key == new_node.key {
                // Same key: the new node takes `entry`'s place in the
                // chain, inheriting its successor.
                ret = Some(entry.value.clone());
                new_node.next.store_shared(entry.next.load(Relaxed, &guard), Release);
                break;
            } else {
                bucket = &entry.next;
            }
        }
        let old_node = bucket.swap(Some(new_node), Release, &guard);
        if let Some(old_node_content) = old_node {
            // The replaced node may still be visible to concurrent readers;
            // defer freeing it until all current epochs have passed.
            unsafe {
                guard.unlinked(old_node_content);
            }
        } else if self.len() >= self.max_capacity.load(Relaxed) {
            self.grow(&guard);
        }
        ret
    }
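
    // A note on the scheme above: the segment lock guarantees a single
    // writer per chain, but readers may be walking the same chain at any
    // moment. Every structural change is therefore one atomic pointer
    // update, and displaced nodes go through `guard.unlinked` instead of
    // being freed immediately.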

    // Doubles the table. Called with the segment lock held.
    fn grow(&self, guard: &Guard) {
        // Double the resize threshold along with the table.
        self.max_capacity.fetch_add(self.max_capacity.load(Relaxed), Relaxed);

        let old_table = self.table.load(Relaxed, &guard).expect("Table should have been initialised on creation");

        let new_table = Owned::new(Self::new_table(old_table.len() as u32 * 2));

        for mut old_bucket in old_table.iter() {
            while let Some(entry) = old_bucket.load(Relaxed, guard) {
                let hash_bucket = self.get_bucket(entry.hash, new_table.len() as u32);
                let mut new_bucket = &new_table[hash_bucket as usize];
                // Walk to the end of the destination chain.
                while let Some(new_entry) = new_bucket.load(Relaxed, guard) {
                    new_bucket = &new_entry.next;
                }
                // Clone the entry rather than relinking it: readers may
                // still be traversing the old chain.
                let new_entry = self.create_new_entry(entry.hash, entry.key.clone(), entry.value.clone());
                new_bucket.store(Some(new_entry), Release);
                old_bucket = &entry.next;
            }
        }

        self.table.store(Some(new_table), Release);

        unsafe { Self::destroy_table(old_table, guard) };
    }

    // Unlink every entry and then the table itself. Callers must ensure no
    // new readers can reach `table` (it has been swapped out, or the
    // segment is being dropped).
    unsafe fn destroy_table(table: Shared<Vec<Atomic<CHMEntry<K, V>>>>, guard: &Guard) {
        for mut bucket in table.iter() {
            while let Some(entry) = bucket.load(Relaxed, guard) {
                guard.unlinked(entry);
                bucket = &entry.next;
            }
        }
        guard.unlinked(table);
    }

    fn entries(&self) -> Vec<(K, V)> {
        let mut xs = Vec::with_capacity(self.len());
        let guard = epoch::pin();
        let table = self.table.load(Acquire, &guard).unwrap();
        for mut bucket in table.iter() {
            while let Some(entry) = bucket.load(Acquire, &guard) {
                let e = (entry.key.clone(), entry.value.clone());
                xs.push(e);
                bucket = &entry.next;
            }
        }
        xs
    }

    // Lock-free read: pin the epoch, walk the chain, clone the value out.
    fn get(&self, key: K, hash: u32) -> Option<V> {
        let guard = epoch::pin();
        let table = self.table.load(Acquire, &guard).expect("Table should have been initialised on creation");
        let hash_bucket = self.get_bucket(hash, table.len() as u32);

        let mut bucket = &table[hash_bucket as usize];

        loop {
            let bucket_data = bucket.load(Acquire, &guard);
            let entry = match bucket_data {
                None => {
                    return None;
                },
                Some(data) => data
            };

            if entry.hash == hash && entry.key == key {
                return Some(entry.value.clone());
            } else {
                bucket = &entry.next;
            }
        }
    }

    fn remove(&self, key: K, hash: u32) -> Option<V> {
        let lock_guard = self.lock.lock().expect("Couldn't lock segment mutex");
        let ret = self.remove_inner(key, hash);
        drop(lock_guard);
        ret
    }

    fn remove_inner(&self, key: K, hash: u32) -> Option<V> {
        let guard = epoch::pin();

        let table = self.table.load(Relaxed, &guard).expect("Table should have been initialised on creation");
        let hash_bucket = self.get_bucket(hash, table.len() as u32);

        let mut bucket = &table[hash_bucket as usize];
        loop {
            let bucket_data = bucket.load(Relaxed, &guard);
            let entry = match bucket_data {
                None => {
                    return None;
                },
                Some(data) => data
            };

            if entry.hash == hash && entry.key == key {
                // Splice the entry out of the chain, then defer freeing it.
                bucket.store_shared(entry.next.load(Relaxed, &guard), Release);
                let ret = entry.value.clone();
                self.len.fetch_sub(1, Relaxed);
                unsafe {
                    guard.unlinked(entry);
                }
                return Some(ret);
            } else {
                bucket = &entry.next;
            }
        }
    }

    fn create_new_entry(&self, hash: u32, key: K, value: V) -> Owned<CHMEntry<K, V>> {
        Owned::new(CHMEntry {
            hash: hash,
            key: key,
            value: value,
            next: Atomic::null()
        })
    }

    // `cap` is always a power of two, so this is `hash % cap` without the
    // division.
    fn get_bucket(&self, hash: u32, cap: u32) -> u32 {
        hash & (cap - 1)
    }

    #[allow(dead_code)]
    fn table_cap(&self) -> usize {
        let guard = epoch::pin();
        self.table.load(Acquire, &guard).expect("Table should have been initialised on creation").len()
    }

    #[allow(dead_code)]
    fn max_cap(&self) -> usize {
        self.max_capacity.load(Relaxed)
    }

    fn new_table(capacity: u32) -> Vec<Atomic<CHMEntry<K, V>>> {
        let mut v = Vec::with_capacity(capacity as usize);
        for _ in 0..capacity {
            v.push(Atomic::null());
        }
        v
    }

    #[allow(dead_code)]
    fn lock_then_do_work<F: Fn()>(&self, work: F) {
        let lock_guard = self.lock.lock().expect("Couldn't lock segment mutex");
        work();
        drop(lock_guard);
    }
}

impl<K: Eq + Hash + Sync + Clone, V: Sync + Clone> Drop for CHMSegment<K, V> {
    fn drop(&mut self) {
        let lock_guard = self.lock.lock().expect("Couldn't lock segment mutex");
        let guard = epoch::pin();
        unsafe { Self::destroy_table(self.table.load(Relaxed, &guard).unwrap(), &guard) };
        drop(lock_guard);
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use super::CHMSegment;
    use super::CHMInner;
    use std::sync::mpsc::sync_channel;
    use std::thread;
    use std::sync::Arc;
    use std::hash::SipHasher;
    use std::hash::{Hasher, BuildHasher, BuildHasherDefault, Hash};

    #[test]
    fn seg_bit_mask() {
        assert_eq!(CHMInner::<u32,u32>::make_segment_bit_mask(16u32), (0b11110000000000000000000000000000u32, 28));
        assert_eq!(CHMInner::<u32,u32>::make_segment_bit_mask(1u32), (0u32, 0u32));
        assert_eq!(CHMInner::<u32,u32>::make_segment_bit_mask(2u32), (0b10000000000000000000000000000000u32, 31));
        assert_eq!(CHMInner::<u32,u32>::make_segment_bit_mask(1024u32), (0b11111111110000000000000000000000u32, 22));
    }
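
    // Added sanity check: with 16 segments the top four bits of a hash
    // select the segment.
    #[test]
    fn segment_from_hash() {
        let inner = CHMInner::<u32, u32>::new(64, 16, 1.0);
        assert_eq!(inner.get_segment_from_hash(0xF0000000), 15);
        assert_eq!(inner.get_segment_from_hash(0x10000000), 1);
        assert_eq!(inner.get_segment_from_hash(0x0FFFFFFF), 0);
    }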

    #[test]
    fn settings_weird_load_factors() {
        validate_chm_settings(16, 16, 32, 32, 1.0, 1.0);
        validate_chm_settings(16, 16, 32, 256, 0.1, 0.2);
        validate_chm_settings(16, 16, 32, 32, 1.1, 1.0);
    }

    #[test]
    fn settings_weird_capacities() {
        validate_chm_settings(12, 16, 30, 32, 1.0, 1.0);
        validate_chm_settings(17, 32, 30, 32, 1.0, 1.0);
        validate_chm_settings(17, 32, 10, 32, 1.0, 1.0);
    }

    #[test]
    fn settings_weird_segments() {
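        // This test body was left empty; the cases below are a plausible
        // completion based on check_params: segment counts round up to a
        // power of two, are capped at MAX_SEGMENT_COUNT (4096), and the
        // capacity is raised to at least the segment count.
        validate_chm_settings(3, 4, 32, 32, 1.0, 1.0);
        validate_chm_settings(5, 8, 32, 32, 1.0, 1.0);
        validate_chm_settings(8192, 4096, 32, 4096, 1.0, 1.0);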
    }

    #[test]
    fn simple_insert_and_get() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(100, 1, 0.8, BuildHasherDefault::<SipHasher>::default());
        chm.insert(1, 100);
        assert_eq!(chm.get(1), Some(100));
    }

    #[test]
    fn simple_insert_and_get_other() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(100, 1, 0.8, BuildHasherDefault::<SipHasher>::default());
        chm.insert(1, 100);
        assert_eq!(chm.get(2), None);
    }

    #[test]
    fn simple_insert_and_remove() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(100, 1, 0.8, BuildHasherDefault::<SipHasher>::default());
        chm.insert(1, 100);
        assert_eq!(chm.remove(1), Some(100));
        assert_eq!(chm.get(1), None);
    }

    #[test]
    fn many_insert_and_remove() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(16, 1, 1.0, BuildHasherDefault::<SipHasher>::default());
        for i in 0..100 {
            assert_eq!(chm.insert(i, i), None);
        }

        assert_eq!(chm.remove(101), None);

        for i in 0..100 {
            assert_eq!(chm.remove(i), Some(i));
        }

        for i in 0..100 {
            assert_eq!(chm.get(i), None);
        }
    }

    #[test]
    fn many_insert_and_get_back() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(16, 1, 1.0, BuildHasherDefault::<SipHasher>::default());
        assert_eq!(chm.entries(), Vec::new());
        let v: Vec<(u32, u32)> = (0..100).map(|i| (i, i + 1)).collect();
        for &(i, j) in v.iter() {
            chm.insert(i, j);
        }
        let mut entries = chm.entries();
        entries.sort();
        assert_eq!(entries, v);
    }

    #[test]
    fn many_insert_and_get() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(16, 1, 1.0, BuildHasherDefault::<SipHasher>::default());
        for i in 0..100 {
            chm.insert(i, i);
        }
        for i in 0..100 {
            assert_eq!(chm.get(i), Some(i));
        }
    }

    #[test]
    fn many_insert_and_get_none() {
        let mut chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(16, 1, 1.0, BuildHasherDefault::<SipHasher>::default());
        for i in 0..100 {
            chm.insert(i, i);
        }
        for i in 100..200 {
            assert_eq!(chm.get(i), None);
        }
    }

    #[test]
    fn check_hash_collisions() {
        // Every entry gets hash 0, forcing all 100 keys into one chain.
        let seg = CHMSegment::<u32, u32>::new(16, 1.0);
        for i in 0..100 {
            assert_eq!(seg.insert(i, i, 0), None);
        }
        for i in 0..100 {
            assert_eq!(seg.insert(i, i + 1, 0), Some(i));
        }
    }

    #[test]
    fn check_len() {
        let seg = CHMSegment::<u32, u32>::new(16, 1.0);
        assert_eq!(seg.max_cap(), 16);
        assert_eq!(seg.table_cap(), 16);
        assert_eq!(seg.len(), 0);
        for i in 0..100 {
            assert_eq!(seg.insert(i, i, i), None);
        }
        // 100 inserts push the table through three doublings: 16 -> 32 -> 64 -> 128.
        assert_eq!(seg.len(), 100);
        assert_eq!(seg.max_cap(), 128);
        assert_eq!(seg.table_cap(), 128);

        for i in 0..100 {
            assert_eq!(seg.insert(i, i + 1, i), Some(i));
        }
        assert_eq!(seg.len(), 100);
        assert_eq!(seg.max_cap(), 128);
        assert_eq!(seg.table_cap(), 128);

        for i in 0..100 {
            assert_eq!(seg.remove(i, i), Some(i + 1));
        }
        assert_eq!(seg.len(), 0);
        assert_eq!(seg.max_cap(), 128);
        assert_eq!(seg.table_cap(), 128);
    }

    #[test]
    fn read_segment_while_locked() {
        let seg = Arc::new(CHMSegment::<u32, u32>::new(16, 1.0));
        for i in 0..100 {
            seg.insert(i, i, i);
        }
        let seg_clone = seg.clone();
        // A rendezvous channel steps the writer thread while the main
        // thread reads the same segment without ever taking the lock.
        let (tx, rx) = sync_channel::<()>(0);
        thread::spawn(move || {
            seg_clone.lock_then_do_work(|| {
                rx.recv().unwrap();
                for i in 0..100 {
                    assert_eq!(seg_clone.insert_inner(i, i + 1, i, &seg_clone.table), Some(i));
                }
                rx.recv().unwrap();
                rx.recv().unwrap();
                for i in 0..100 {
                    assert_eq!(seg_clone.remove_inner(i, i), Some(i + 1));
                }
                rx.recv().unwrap();
            })
        });
        for i in 0..100 {
            assert_eq!(seg.get(i, i), Some(i));
        }
        tx.send(()).unwrap();
        tx.send(()).unwrap();
        for i in 0..100 {
            assert_eq!(seg.get(i, i), Some(i + 1));
        }
        tx.send(()).unwrap();
        tx.send(()).unwrap();
        for i in 0..100 {
            assert_eq!(seg.get(i, i), None);
        }
    }
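
    // Added end-to-end smoke test: two writer threads share the map through
    // `clone`d handles. This assumes the map is Send + Sync, which follows
    // from its parts (Arc, Mutex, atomics, crossbeam Atomics) for these
    // key/value types.
    #[test]
    fn concurrent_inserts() {
        let chm = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::new_with_options(16, 4, 0.8, BuildHasherDefault::<SipHasher>::default());
        let mut chm_a = chm.clone();
        let mut chm_b = chm.clone();
        let a = thread::spawn(move || {
            for i in 0..100 {
                chm_a.insert(i, i);
            }
        });
        let b = thread::spawn(move || {
            for i in 100..200 {
                chm_b.insert(i, i);
            }
        });
        a.join().unwrap();
        b.join().unwrap();
        assert_eq!(chm.len(), 200);
        for i in 0..200 {
            assert_eq!(chm.get(i), Some(i));
        }
    }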

    fn validate_chm_settings(seg_count: u32, expected_seg_count: u32,
                             capacity: u32, expected_capacity: u32,
                             load_factor: f32, expected_load_factor: f32) {
        let (capacity_chk, seg_count_chk, load_factor_chk) = ConcurrentHashMap::<u32, u32, BuildHasherDefault<SipHasher>>::check_params(capacity, seg_count, load_factor);
        assert_eq!(seg_count_chk, expected_seg_count);
        assert_eq!(capacity_chk, expected_capacity);
        assert_eq!(load_factor_chk, expected_load_factor);
    }
}