use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::thread;
mod linked_list;
use self::linked_list::{LinkedList, Node};
/// Memory ordering used for every atomic access in this module.
const OSC: Ordering = Ordering::SeqCst;
/// Threshold for a handle's `refresh` counter; reaching it triggers `cleanup`.
const REFRESH_RATE: usize = 100;
/// Fixed-size bucket table shared by all handles of one [`Map`].
///
/// A key lives in bucket `hash(key) % nbuckets`; each bucket is a
/// `linked_list::LinkedList`.
struct Table<K, V> {
    /// Number of buckets; always equals `map.len()` and is never zero.
    nbuckets: usize,
    /// The bucket lists.
    map: Vec<LinkedList<K, V>>,
    /// Number of distinct keys currently stored, maintained by insert/delete.
    nitems: AtomicUsize,
}
impl<K, V> Table<K, V> {
    /// Creates a table with `num_of_buckets` empty buckets.
    ///
    /// A request for zero buckets is rounded up to one: every lookup computes
    /// `hash % nbuckets`, which would otherwise panic with a division by zero
    /// on the first operation.
    fn new(num_of_buckets: usize) -> Self {
        let nbuckets = num_of_buckets.max(1);
        Table {
            nbuckets,
            map: (0..nbuckets).map(|_| LinkedList::default()).collect(),
            nitems: AtomicUsize::new(0),
        }
    }
}
impl<K, V> Table<K, V>
where
K: Hash + Ord,
V: Copy,
{
fn insert(&self, key: K, value: V, remove_nodes: &mut Vec<*mut Node<K, V>>) -> Option<*mut V> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash: usize = hasher.finish() as usize;
let index = hash % self.nbuckets;
let ret = self.map[index].insert(key, value, remove_nodes);
if ret.is_none() {
self.nitems.fetch_add(1, OSC);
}
ret
}
fn get(&self, key: &K, remove_nodes: &mut Vec<*mut Node<K, V>>) -> Option<V> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash: usize = hasher.finish() as usize;
let index = hash % self.nbuckets;
self.map[index].get(key, remove_nodes)
}
fn delete(&self, key: &K, remove_nodes: &mut Vec<*mut Node<K, V>>) -> Option<V> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash: usize = hasher.finish() as usize;
let index = hash % self.nbuckets;
let ret = self.map[index].delete(key, remove_nodes);
if ret.is_some() {
self.nitems.fetch_sub(1, OSC);
}
ret
}
}
/// A handle onto a shared [`Map`]; every map operation goes through one.
///
/// Each handle owns an epoch counter (registered in `Map::handles`) that it
/// increments on entry to and on exit from every operation, and it keeps the
/// nodes/values its operations retired until `cleanup` frees them.
/// Obtain more handles with `clone()`.
pub struct MapHandle<K, V> {
// The shared map; every handle cloned from another points at the same `Map`.
map: Arc<Map<K, V>>,
// Incremented before and after each operation: odd while an operation runs.
epoch_counter: Arc<AtomicUsize>,
// Nodes retired by this handle's table operations, freed in `cleanup`.
remove_nodes: Vec<*mut Node<K, V>>,
// Value allocations retired by this handle, freed in `cleanup`.
remove_val: Vec<*mut V>,
// Counter compared against `REFRESH_RATE` to decide when to run `cleanup`.
refresh: usize,
}
// SAFETY(review): `MapHandle` holds raw pointers, so it is not `Send`
// automatically. Moving a handle to another thread moves ownership of its
// retired nodes/values, which `cleanup` then drops on that thread (hence
// `K: Send`, `V: Send`). Keys are read concurrently by all handles (hence
// `K: Sync`).
// NOTE(review): values are also read concurrently through shared nodes with no
// `V: Sync` bound; this appears to rely on the operations' `V: Copy` bound
// (Copy types cannot contain interior mutability) — confirm against linked_list.
unsafe impl<K, V> Send for MapHandle<K, V>
where
K: Send + Sync,
V: Send,
{
}
impl<K, V> MapHandle<K, V> {
fn cleanup(&mut self) {
let mut started = Vec::new();
let handles_map = self.map.handles.read().unwrap();
for h in handles_map.iter() {
started.push(h.load(OSC));
}
for (i, h) in handles_map.iter().enumerate() {
if started[i] % 2 == 0 {
continue;
}
let mut check = h.load(OSC);
let mut iter = 0;
while (check <= started[i]) && (check % 2 == 1) {
if iter % 4 == 0 {
thread::yield_now();
}
check = h.load(OSC);
iter += 1;
}
}
for to_drop in &self.remove_nodes {
let n = unsafe { (&**to_drop).val.load(OSC) };
self.remove_val.push(n);
drop(unsafe { Box::from_raw(*to_drop) });
}
for to_drop in &self.remove_val {
drop(unsafe { Box::from_raw(*to_drop) });
}
self.remove_nodes = Vec::new();
self.remove_val = Vec::new();
}
}
impl<K, V> MapHandle<K, V>
where
K: Hash + Ord,
V: Copy,
{
pub fn insert(&mut self, key: K, value: V) -> Option<V> {
self.refresh += 1;
self.epoch_counter.fetch_add(1, OSC);
let val = self.map.table.insert(key, value, &mut self.remove_nodes);
self.epoch_counter.fetch_add(1, OSC);
let mut ret = None;
if let Some(v) = val {
ret = Some(unsafe { *v });
self.remove_val.push(v);
}
if self.refresh == REFRESH_RATE {
self.refresh = 0;
self.cleanup();
}
ret
}
pub fn get(&mut self, key: &K) -> Option<V> {
self.refresh = (self.refresh + 1) % REFRESH_RATE;
self.epoch_counter.fetch_add(1, OSC);
let ret = self.map.table.get(key, &mut self.remove_nodes);
self.epoch_counter.fetch_add(1, OSC);
if self.refresh == REFRESH_RATE {
self.refresh = 0;
self.cleanup();
}
ret
}
pub fn remove(&mut self, key: &K) -> Option<V> {
self.refresh = (self.refresh + 1) % REFRESH_RATE;
self.epoch_counter.fetch_add(1, OSC);
let ret = self.map.table.delete(key, &mut self.remove_nodes);
self.epoch_counter.fetch_add(1, OSC);
if self.refresh == REFRESH_RATE {
self.refresh = 0;
self.cleanup();
}
ret
}
pub fn len(&self) -> usize {
self.map.table.nitems.load(OSC)
}
pub fn is_empty(&self) -> bool {
self.map.table.nitems.load(OSC) == 0
}
}
impl<K, V> Clone for MapHandle<K, V> {
fn clone(&self) -> Self {
let ret = Self {
map: Arc::clone(&self.map),
epoch_counter: Arc::new(AtomicUsize::new(0)),
remove_nodes: Vec::new(),
remove_val: Vec::new(),
refresh: 0,
};
let mut handles_vec = self.map.handles.write().unwrap(); handles_vec.push(Arc::clone(&ret.epoch_counter));
ret
}
}
/// Shared state behind every [`MapHandle`]: the bucket table plus the epoch
/// counters of all handles created for it.
pub struct Map<K, V> {
    table: Table<K, V>,
    /// One epoch counter per handle created for this map; entries are only
    /// ever appended.
    handles: RwLock<Vec<Arc<AtomicUsize>>>,
}
impl<K, V> Map<K, V> {
    /// Creates an empty map with `nbuckets` buckets and returns the first
    /// handle to it. Further handles are obtained by cloning that handle.
    pub fn with_capacity(nbuckets: usize) -> MapHandle<K, V> {
        let first_counter = Arc::new(AtomicUsize::new(0));
        let shared = Arc::new(Map {
            table: Table::new(nbuckets),
            handles: RwLock::new(vec![Arc::clone(&first_counter)]),
        });
        MapHandle {
            map: shared,
            epoch_counter: first_counter,
            remove_nodes: Vec::new(),
            remove_val: Vec::new(),
            refresh: 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{thread_rng, Rng};
    use std::thread;

    /// Hammers one map from several threads with a random mix of insert, get
    /// and remove on a small key space. Every key is only ever stored with
    /// itself as the value, so any value read back must equal its key.
    #[test]
    fn hashmap_concurr() {
        let handle = Map::with_capacity(8);
        let mut threads = vec![];
        let nthreads = 5;
        for _ in 0..nthreads {
            let mut new_handle = handle.clone();
            threads.push(thread::spawn(move || {
                let num_iterations = 1000000;
                let mut rng = thread_rng();
                for _ in 0..num_iterations {
                    let val = rng.gen_range(0, 128);
                    match rng.gen_range(0, 3) {
                        0 => {
                            new_handle.insert(val, val);
                        }
                        1 => {
                            if let Some(v) = new_handle.get(&val) {
                                assert_eq!(v, val);
                            }
                        }
                        _ => {
                            new_handle.remove(&val);
                        }
                    }
                }
                // Every operation bumps the counter twice (enter and exit).
                assert_eq!(new_handle.epoch_counter.load(OSC), num_iterations * 2);
            }));
        }
        for t in threads {
            t.join().unwrap();
        }
    }

    /// Removing keys returns their values, makes them unreachable, and
    /// shrinks the item count.
    #[test]
    fn hashmap_delete() {
        let mut handle = Map::with_capacity(8);
        let entries = [(1, 3), (2, 5), (3, 8), (4, 3), (5, 4), (6, 5)];
        for &(k, v) in entries.iter() {
            handle.insert(k, v);
        }
        for k in 7..=16 {
            handle.insert(k, 3);
        }
        assert_eq!(handle.len(), 16);
        assert_eq!(handle.get(&1), Some(3));
        assert_eq!(handle.remove(&1), Some(3));
        assert_eq!(handle.get(&1), None);
        assert_eq!(handle.remove(&2), Some(5));
        assert_eq!(handle.remove(&16), Some(3));
        assert_eq!(handle.get(&16), None);
        assert_eq!(handle.len(), 13);
    }

    /// Inserts, overwrites and lookups on a single handle.
    #[test]
    fn hashmap_basics() {
        let mut new_hashmap = Map::with_capacity(8);
        new_hashmap.insert(1, 1);
        new_hashmap.insert(2, 5);
        new_hashmap.insert(12, 5);
        new_hashmap.insert(13, 7);
        new_hashmap.insert(0, 0);
        new_hashmap.insert(20, 3);
        new_hashmap.insert(3, 2);
        new_hashmap.insert(4, 1);
        assert_eq!(new_hashmap.insert(20, 5), Some(3));
        assert_eq!(new_hashmap.insert(3, 8), Some(2));
        assert_eq!(new_hashmap.insert(5, 5), None);
        assert_eq!(new_hashmap.len(), 9);
        new_hashmap.insert(3, 8);
        // Repeated lookups must be stable.
        for _ in 0..2 {
            assert_eq!(new_hashmap.get(&20), Some(5));
            assert_eq!(new_hashmap.get(&12), Some(5));
            assert_eq!(new_hashmap.get(&1), Some(1));
            assert_eq!(new_hashmap.get(&0), Some(0));
            assert_eq!(new_hashmap.get(&3), Some(8));
        }
    }
}