use crate::map::{Map, ObjectMap};
use crate::spin::SpinLock;
use std::ops::Deref;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
use std::sync::atomic::{fence, AtomicUsize};
use std::sync::Arc;
/// Sentinel key meaning "no node".
///
/// `!0 >> 1` is a `usize` with every bit set except the most significant one. It is the initial
/// value of `head`/`tail` and the `prev`/`next` value of a node that has no neighbour on that
/// side. The insert methods debug-assert that a caller-supplied key never equals it.
const NONE_KEY: usize = !0 >> 1;
/// Reference-counted handle to a [`Node`], as stored in the map and handed out to callers.
pub type NodeRef<T> = Arc<Node<T>>;
/// One element of the linked list: the stored object plus the keys of its neighbours.
///
/// Neighbours are referenced by key (resolved through the map), not by pointer.
pub struct Node<T> {
/// Per-node spin lock taken by the insert/remove paths while they rewire links.
lock: SpinLock<()>,
/// Key of the previous node, or `NONE_KEY` when there is none.
prev: AtomicUsize,
/// Key of the next node, or `NONE_KEY` when there is none.
next: AtomicUsize,
/// The stored object; exposed through `Deref`.
obj: T,
}
/// A map from `usize` keys to nodes that additionally threads the nodes into a doubly linked
/// list, so entries can be inserted at either end and walked from head to tail.
pub struct LinkedObjectMap<T> {
/// Key -> node storage; also used to resolve the `prev`/`next` keys held by nodes.
map: ObjectMap<NodeRef<T>>,
/// Key of the first node, or `NONE_KEY` when no node is linked.
head: AtomicUsize,
/// Key of the last node, or `NONE_KEY` when no node is linked.
tail: AtomicUsize,
}
impl<T> LinkedObjectMap<T> {
pub fn with_capacity(cap: usize) -> Self {
LinkedObjectMap {
map: ObjectMap::with_capacity(cap),
head: AtomicUsize::new(NONE_KEY),
tail: AtomicUsize::new(NONE_KEY),
}
}
/// Inserts `value` under `key` and links the new node at the head of the list.
///
/// If `ObjectMap::insert` returns `Some(_)` (presumably: the key was already present), the
/// call returns without touching the list.
///
/// The new node's lock is taken *before* the node is published in the map, matching
/// `insert_back`. Otherwise a concurrent `remove` of the same key could lock the still
/// unlinked node, see both neighbours as `NONE_KEY`, and overwrite `head`/`tail`.
///
/// `key` must not be `NONE_KEY` (checked in debug builds only).
pub fn insert_front(&self, key: &usize, value: T) {
    debug_assert_ne!(*key, NONE_KEY);
    let backoff = crossbeam_utils::Backoff::new();
    let new_front = Node::new(value, NONE_KEY, NONE_KEY);
    // Hold the node's lock for the whole linking phase; it is released on return.
    let _new_guard = new_front.lock.lock();
    if self.map.insert(key, new_front.clone()).is_some() {
        return;
    }
    loop {
        let front = self.head.load(Acquire);
        let front_node = self.map.get(&front);
        // Lock the current head (if any); the guard is dropped at the end of each iteration.
        let _front_guard = front_node.as_ref().map(|n| n.lock.lock());
        if let Some(ref front_node) = front_node {
            // The node we read is no longer the first one: retry with a fresh head.
            if front_node.get_prev() != NONE_KEY {
                backoff.spin();
                continue;
            }
        } else if front != NONE_KEY {
            // `head` names a key that has no map entry right now: retry.
            backoff.spin();
            continue;
        }
        new_front.set_next(front);
        // `compare_exchange(.., AcqRel, Acquire)` is the documented replacement for the
        // deprecated `compare_and_swap(.., AcqRel)`.
        if self
            .head
            .compare_exchange(front, *key, AcqRel, Acquire)
            .is_ok()
        {
            if let Some(ref front_node) = front_node {
                front_node.prev.store(*key, Release);
            } else {
                debug_assert_eq!(front, NONE_KEY);
                // List was empty: the new node is also the tail.
                // NOTE(review): the result is ignored; if a concurrent `insert_back` already
                // claimed `tail`, the two nodes are never linked to each other — confirm
                // whether that interleaving is ruled out elsewhere.
                let _ = self
                    .tail
                    .compare_exchange(NONE_KEY, *key, AcqRel, Acquire);
            }
            break;
        } else {
            backoff.spin();
        }
    }
}
/// Inserts `value` under `key` and links the new node at the tail of the list.
///
/// The new node's lock is taken before the node is published in the map and is held until
/// the node is linked. If `ObjectMap::insert` returns `Some(_)` (presumably: the key was
/// already present), the call returns without touching the list.
///
/// `key` must not be `NONE_KEY` (checked in debug builds only).
pub fn insert_back(&self, key: &usize, value: T) {
    debug_assert_ne!(*key, NONE_KEY);
    let backoff = crossbeam_utils::Backoff::new();
    let new_back = Node::new(value, NONE_KEY, NONE_KEY);
    let _new_guard = new_back.lock.lock();
    if self.map.insert(key, new_back.clone()).is_some() {
        return;
    }
    loop {
        let back = self.tail.load(Acquire);
        let back_node = self.map.get(&back);
        // Lock the current tail (if any); the guard is dropped at the end of each iteration.
        let _back_guard = back_node.as_ref().map(|n| n.lock.lock());
        if let Some(ref back_node) = back_node {
            // The node we read is no longer the last one: retry with a fresh tail.
            if back_node.get_next() != NONE_KEY {
                backoff.spin();
                continue;
            }
        } else if back != NONE_KEY {
            // `tail` names a key that has no map entry right now: retry.
            backoff.spin();
            continue;
        }
        new_back.set_prev(back);
        // `compare_exchange(.., AcqRel, Acquire)` is the documented replacement for the
        // deprecated `compare_and_swap(.., AcqRel)`.
        if self
            .tail
            .compare_exchange(back, *key, AcqRel, Acquire)
            .is_ok()
        {
            if let Some(ref back_node) = back_node {
                back_node.next.store(*key, Release);
            } else {
                debug_assert_eq!(back, NONE_KEY);
                // List was empty: the new node is also the head.
                // NOTE(review): the result is ignored; if a concurrent `insert_front` already
                // claimed `head`, the two nodes are never linked to each other — confirm
                // whether that interleaving is ruled out elsewhere.
                let _ = self
                    .head
                    .compare_exchange(NONE_KEY, *key, AcqRel, Acquire);
            }
            break;
        } else {
            backoff.spin();
        }
    }
}
/// Looks up `key` in the underlying map and returns the node handle it yields, if any.
///
/// The list links are not consulted.
pub fn get(&self, key: &usize) -> Option<NodeRef<T>> {
self.map.get(key)
}
/// Removes `key` from the list and then from the map.
///
/// Returns `None` when the map has no entry for `key`; otherwise the node is first unlinked
/// via `remove_node` and the result of `ObjectMap::remove` is returned.
pub fn remove(&self, key: &usize) -> Option<NodeRef<T>> {
    let node = self.map.get(key)?;
    self.remove_node(*key, node);
    self.map.remove(key)
}
/// Unlinks `val_node`, stored under `key`, from the list by rewiring its neighbours (and
/// `head`/`tail` when the node sits at an end). The map entry itself is not touched here.
///
/// Each attempt reads the neighbour keys without locks, locks the nodes involved, and then
/// re-checks the links; any mismatch backs off and starts a fresh attempt.
fn remove_node(&self, key: usize, val_node: NodeRef<T>) {
let backoff = crossbeam_utils::Backoff::new();
loop {
// Read the neighbour keys before taking any lock; they are re-validated below.
let prev = val_node.get_prev();
let next = val_node.get_next();
let prev_node = self.map.get(&prev);
let next_node = self.map.get(&next);
// A real neighbour key that has no map entry right now: back off and re-read.
if (prev != NONE_KEY && prev_node.is_none())
|| (next != NONE_KEY && next_node.is_none())
{
backoff.spin();
continue;
}
// Locks are taken in the order prev, self, next. The guards are dropped at the end of
// each loop iteration, including on `continue`.
let _prev_guard = prev_node.as_ref().map(|n| n.lock.lock());
let _self_guard = val_node.lock.lock();
let _next_guard = next_node.as_ref().map(|n| n.lock.lock());
// Re-check all four links under the locks. `|` does not short-circuit, so every operand
// is evaluated. A missing neighbour counts as "no mismatch".
if {
prev_node
.as_ref()
.map(|n| n.get_next() != key)
.unwrap_or(false)
| (val_node.get_prev() != prev)
| (val_node.get_next() != next)
| next_node
.as_ref()
.map(|n| n.get_prev() != key)
.unwrap_or(false)
} {
backoff.spin();
continue;
}
// Splice the neighbours together around this node.
prev_node.as_ref().map(|n| n.set_next(next));
next_node.as_ref().map(|n| n.set_prev(prev));
// No predecessor: this node was the head, so `next` becomes the new head.
if prev_node.is_none() {
debug_assert_eq!(self.head.load(Acquire), key);
self.head.store(next, Release);
}
// No successor: this node was the tail, so `prev` becomes the new tail.
if next_node.is_none() {
debug_assert_eq!(self.tail.load(Acquire), key);
self.tail.store(prev, Release);
}
return;
}
}
/// Returns the length reported by the underlying map.
///
/// This reflects map entries, not linked nodes: inserts add the map entry before linking the
/// node, and `remove` drops the map entry after unlinking it.
pub fn len(&self) -> usize {
self.map.len()
}
/// Reports whether the underlying map has an entry for `key`.
pub fn contains_key(&self, key: &usize) -> bool {
self.map.contains_key(key)
}
/// Walks the list from `head` following `next` links and collects `(key, node)` pairs in
/// list order.
///
/// # Panics
///
/// Hits `unreachable!()` if the walk reaches a key other than `NONE_KEY` that the map cannot
/// resolve.
pub fn all_pairs(&self) -> Vec<(usize, NodeRef<T>)> {
    let mut pairs = Vec::new();
    let mut cursor = self.head.load(Acquire);
    loop {
        match self.map.get(&cursor) {
            Some(node) => {
                // Read the successor before the node handle is moved into the result.
                let following = node.get_next();
                pairs.push((cursor, node));
                cursor = following;
            }
            None if cursor == NONE_KEY => break,
            None => unreachable!(),
        }
    }
    pairs
}
/// Walks the list from `head` following `next` links and collects the keys in list order.
///
/// # Panics
///
/// Hits `unreachable!()` if the walk reaches a key other than `NONE_KEY` that the map cannot
/// resolve.
pub fn all_keys(&self) -> Vec<usize> {
    let mut keys = Vec::new();
    let mut cursor = self.head.load(Acquire);
    loop {
        match self.map.get(&cursor) {
            Some(node) => {
                keys.push(cursor);
                cursor = node.get_next();
            }
            None if cursor == NONE_KEY => break,
            None => unreachable!(),
        }
    }
    keys
}
/// Walks the list from `head` following `next` links and collects the node handles in list
/// order.
///
/// # Panics
///
/// Hits `unreachable!()` if the walk reaches a key other than `NONE_KEY` that the map cannot
/// resolve.
pub fn all_values(&self) -> Vec<NodeRef<T>> {
    let mut nodes = Vec::new();
    let mut cursor = self.head.load(Acquire);
    loop {
        match self.map.get(&cursor) {
            Some(node) => {
                // Advance first; the handle is moved into the result right after.
                cursor = node.get_next();
                nodes.push(node);
            }
            None if cursor == NONE_KEY => break,
            None => unreachable!(),
        }
    }
    nodes
}
/// Placeholder for an iterator over the list; not implemented yet.
///
/// # Panics
///
/// Always panics. An unfinished stub should fail fast with a clear message rather than
/// busy-loop forever and hang the calling thread.
pub fn iter(&self) -> LinkedMapIter<T> {
    unimplemented!("LinkedObjectMap::iter is not implemented yet")
}
}
/// Iterator state over a [`LinkedObjectMap`]: a node handle plus a borrow of the map.
///
/// NOTE(review): no `Iterator` impl is visible in this part of the file — confirm where (or
/// whether) it exists.
pub struct LinkedMapIter<'a, T> {
/// Node handle held by the iterator (presumably the current position — confirm).
node: Arc<Node<T>>,
/// The map this iterator borrows from.
map: &'a LinkedObjectMap<T>
}
impl<T> Node<T> {
pub fn new(obj: T, prev: usize, next: usize) -> NodeRef<T> {
Arc::new(Self {
obj,
lock: SpinLock::new(()),
prev: AtomicUsize::new(prev),
next: AtomicUsize::new(next),
})
}
fn get_next(&self) -> usize {
self.next.load(Acquire)
}
fn get_prev(&self) -> usize {
self.prev.load(Acquire)
}
fn set_next(&self, new: usize) {
self.next.store(new, Release)
}
fn set_prev(&self, new: usize) {
self.prev.store(new, Release)
}
}
/// Lets a `Node<T>` be used wherever a `&T` is expected by dereferencing to the stored object.
impl<T> Deref for Node<T> {
type Target = T;
/// Borrows the object stored in this node.
fn deref(&self) -> &Self::Target {
&self.obj
}
}
#[cfg(test)]
mod test {
use super::*;
use std::{collections::HashSet, thread};
/// Single-threaded smoke test: front insertions must end up in reverse insertion order,
/// followed by back insertions in insertion order, and every node must carry its own key.
#[test]
pub fn linked_map_serial() {
    let map = LinkedObjectMap::with_capacity(16);
    for i in 0..1024 {
        map.insert_front(&i, i);
    }
    for i in 1024..2048 {
        map.insert_back(&i, i);
    }
    // 1023, 1022, ..., 0 (pushed to the front), then 1024, ..., 2047 (pushed to the back).
    let expected: Vec<usize> = (0..1024).rev().chain(1024..2048).collect();
    assert_eq!(map.all_keys(), expected);
    for (key, node) in map.all_pairs() {
        assert_eq!(key, **node);
    }
}
/// Concurrent insertion test: one worker per CPU inserts 999 distinct keys, alternating
/// between the front and the back of the list, then walks the list three ways. Afterwards
/// every key must be present exactly once and map to itself.
#[test]
pub fn linked_map_insertions() {
    let _ = env_logger::try_init();
    let shared = Arc::new(LinkedObjectMap::with_capacity(16));
    let num_threads = num_cpus::get();
    let num_data = 999;
    let workers: Vec<_> = (0..num_threads)
        .map(|i| {
            let map = shared.clone();
            thread::spawn(move || {
                for j in 0..num_data {
                    // Each worker owns the key range [i * 1000, i * 1000 + num_data).
                    let num = i * 1000 + j;
                    debug!("Insert {}", num);
                    if j % 2 == 0 {
                        map.insert_front(&num, num);
                    } else {
                        map.insert_back(&num, num);
                    }
                }
                map.all_keys();
                map.all_values();
                map.all_pairs();
            })
        })
        .collect();
    info!("Waiting for threads to finish");
    for worker in workers {
        worker.join().unwrap();
    }
    let seen: HashSet<usize> = shared
        .all_pairs()
        .into_iter()
        .map(|(key, node)| {
            let value = **node;
            assert_eq!(key, value);
            key
        })
        .collect();
    assert_eq!(seen.len(), num_threads * num_data);
}
}