use std::ptr;
use std::sync::atomic::{AtomicPtr, Ordering};
/// Memory ordering used for all atomic loads, stores, swaps and CASes below.
const OSC: Ordering = Ordering::SeqCst;
/// A node of the sorted list.
///
/// The head/tail sentinels (built by `Node::empty`) carry `key: None`; every real
/// entry (built by `Node::new`) carries `Some(key)`.
#[derive(Debug)]
pub(super) struct Node<K, V> {
    // `None` only for the key-less sentinel nodes.
    key: Option<K>,
    // Heap-allocated value obtained from `Box::into_raw`; null for sentinels.
    // `LinkedList::insert` swaps it atomically when an existing key is overwritten.
    pub val: AtomicPtr<V>,
    // Successor pointer. `LinkedList` uses bit 0 as a "logically deleted" mark,
    // so it must be cleared before the pointer is dereferenced.
    next: AtomicPtr<Node<K, V>>,
}
impl<K, V> Node<K, V> {
fn empty() -> Self {
Node {
key: None,
val: AtomicPtr::new(ptr::null_mut()),
next: AtomicPtr::new(ptr::null_mut()),
}
}
fn new(key: K, val: V) -> Self {
let v = Box::new(val);
Node {
key: Some(key),
val: AtomicPtr::new(Box::into_raw(v)),
next: AtomicPtr::new(ptr::null_mut()),
}
}
}
/// A sorted singly linked list that deletes nodes by setting a mark bit in their
/// `next` pointer before unlinking them. Its naming and structure follow Harris's
/// non-blocking linked list.
///
/// `head` and `tail` are key-less sentinels; real entries sit between them in
/// ascending key order.
///
/// NOTE(review): no `Drop` impl is visible for this type, so nodes still in the list
/// (and their boxed values) appear to be leaked when the list is dropped. Confirm
/// against the rest of the module.
#[derive(Debug)]
pub(super) struct LinkedList<K, V> {
    // Key-less sentinel at the front; every traversal starts here.
    head: AtomicPtr<Node<K, V>>,
    // Key-less sentinel at the end; traversals stop when they reach it.
    tail: AtomicPtr<Node<K, V>>,
}
impl<K, V> Default for LinkedList<K, V> {
fn default() -> Self {
let head = Box::new(Node::empty());
let tail = Box::into_raw(Box::new(Node::empty()));
head.next.store(tail, OSC);
LinkedList {
head: AtomicPtr::new(Box::into_raw(head)),
tail: AtomicPtr::new(tail),
}
}
}
/// Operations on the list.
///
/// SAFETY (applies to every raw-pointer dereference below): pointers read from `head`,
/// `tail`, or a node's `next` field (after clearing the mark bit) are assumed to point at
/// live nodes. This list never frees nodes itself. Nodes it physically unlinks are
/// appended to the caller's `remove_nodes` vector. NOTE(review): soundness therefore
/// relies on the caller reclaiming those nodes only once no thread can still be
/// traversing them. Confirm against the caller.
impl<K, V> LinkedList<K, V>
where
    K: Ord,
    V: Copy,
{
    /// Inserts `val` under `key`, or overwrites the value if `key` is already present.
    ///
    /// Returns `None` when a new node was linked into the list. When the key already
    /// existed, the node's value pointer is replaced and the previous pointer is
    /// returned. That allocation (created with `Box::into_raw`) now belongs to the caller.
    ///
    /// Marked nodes physically unlinked by the internal search are appended to
    /// `remove_nodes`.
    pub(super) fn insert(
        &self,
        key: K,
        val: V,
        remove_nodes: &mut Vec<*mut Node<K, V>>,
    ) -> Option<*mut V> {
        let mut new_node = Box::new(Node::new(key, val));
        let mut left_node = ptr::null_mut();
        loop {
            let new_key = new_node.key.as_ref().expect("Node::new always sets a key");
            let right_node = self.search(new_key, &mut left_node, remove_nodes);
            if self.holds_key(right_node, new_key) {
                // Key already present: move the boxed value that `Node::new` allocated
                // into the existing node. Nulling it out of `new_node` first means the
                // allocation is not orphaned when the unused `new_node` is dropped.
                let boxed_val = new_node.val.swap(ptr::null_mut(), OSC);
                let existing = unsafe { &*right_node };
                return Some(existing.val.swap(boxed_val, OSC));
            }
            new_node.next.store(right_node, OSC);
            let new_node_ptr = Box::into_raw(new_node);
            let left = unsafe { &*left_node };
            if left
                .next
                .compare_exchange(right_node, new_node_ptr, OSC, OSC)
                .is_ok()
            {
                return None;
            }
            // `left_node.next` changed under us. The node was never published, so take
            // ownership back and retry with a fresh search.
            new_node = unsafe { Box::from_raw(new_node_ptr) };
        }
    }

    /// Returns a copy of the value stored under `search_key`, or `None` if it is absent.
    ///
    /// Marked nodes physically unlinked by the internal search are appended to
    /// `remove_nodes`.
    pub(super) fn get(
        &self,
        search_key: &K,
        remove_nodes: &mut Vec<*mut Node<K, V>>,
    ) -> Option<V> {
        let mut left_node = ptr::null_mut();
        let right_node = self.search(search_key, &mut left_node, remove_nodes);
        if !self.holds_key(right_node, search_key) {
            return None;
        }
        // Keyed nodes are only built by `Node::new`, so a published node's value pointer
        // is non-null.
        Some(unsafe { *(*right_node).val.load(OSC) })
    }

    /// Removes `search_key` from the list, returning a copy of its value, or `None` if
    /// the key is absent or another thread deleted it first.
    ///
    /// Any node this call (or its internal searches) physically unlinks is appended to
    /// `remove_nodes`.
    pub(super) fn delete(
        &self,
        search_key: &K,
        remove_nodes: &mut Vec<*mut Node<K, V>>,
    ) -> Option<V> {
        let mut left_node = ptr::null_mut();
        let (right_node, right_node_next) = loop {
            let right_node = self.search(search_key, &mut left_node, remove_nodes);
            if !self.holds_key(right_node, search_key) {
                return None;
            }
            let right = unsafe { &*right_node };
            let right_node_next = right.next.load(OSC);
            // Logical deletion: set the mark bit in `right_node.next`. If the pointer is
            // already marked, or the CAS loses a race, search again. Only one thread can
            // win this CAS, so only one delete reports the value.
            if !Self::is_marked_reference(right_node_next)
                && right
                    .next
                    .compare_exchange(
                        right_node_next,
                        Self::get_marked_reference(right_node_next),
                        OSC,
                        OSC,
                    )
                    .is_ok()
            {
                break (right_node, right_node_next);
            }
        };
        let old = unsafe { *(*right_node).val.load(OSC) };
        // Physical deletion: swing `left_node.next` past the now-marked node.
        let left = unsafe { &*left_node };
        if left
            .next
            .compare_exchange(right_node, right_node_next, OSC, OSC)
            .is_ok()
        {
            // This thread unlinked the node, so it hands the node to the deferred-reclamation
            // list, just as `search` does for the nodes it unlinks.
            remove_nodes.push(right_node);
        } else {
            // Another thread changed `left_node.next` and may already have unlinked (and
            // handed off) `right_node`. Search by key rather than dereferencing the node
            // again, and let `search` finish the unlink if it is still needed.
            let _ = self.search(search_key, &mut left_node, remove_nodes);
        }
        Some(old)
    }

    /// Returns `true` if `node` is a non-tail node whose key equals `key`.
    fn holds_key(&self, node: *mut Node<K, V>, key: &K) -> bool {
        node != self.tail.load(OSC) && unsafe { &*node }.key.as_ref() == Some(key)
    }

    /// Returns `true` if bit 0 (the "logically deleted" mark) is set in `ptr`.
    fn is_marked_reference(ptr: *mut Node<K, V>) -> bool {
        (ptr as usize & 0x1) == 1
    }

    /// Returns `ptr` with the deletion mark bit set.
    fn get_marked_reference(ptr: *mut Node<K, V>) -> *mut Node<K, V> {
        (ptr as usize | 0x1) as *mut _
    }

    /// Returns `ptr` with the deletion mark bit cleared, so it can be dereferenced.
    fn get_unmarked_reference(ptr: *mut Node<K, V>) -> *mut Node<K, V> {
        (ptr as usize & !0x1) as *mut _
    }

    /// Finds the adjacent pair (`*left_node`, returned node) for `search_key`.
    ///
    /// On return, `*left_node` is `head` or an unmarked node whose key is `< search_key`.
    /// The returned node is `tail` or the first unmarked node whose key is
    /// `>= search_key`, and `(*left_node).next` points directly at it. Any run of marked
    /// nodes found between them is unlinked with a single CAS and appended to
    /// `remove_nodes`. If the returned node turns out to be marked, the search restarts
    /// from `head`.
    fn search(
        &self,
        search_key: &K,
        left_node: &mut *mut Node<K, V>,
        remove_nodes: &mut Vec<*mut Node<K, V>>,
    ) -> *mut Node<K, V> {
        let mut left_node_next = ptr::null_mut();
        'search_again: loop {
            // 1. Walk from `head`, remembering the last unmarked node seen as `left_node`.
            let mut t = self.head.load(OSC);
            let mut t_next = unsafe { &*t }.next.load(OSC);
            loop {
                if !Self::is_marked_reference(t_next) {
                    *left_node = t;
                    left_node_next = t_next;
                }
                t = Self::get_unmarked_reference(t_next);
                if t == self.tail.load(OSC) {
                    break;
                }
                t_next = unsafe { &*t }.next.load(OSC);
                if !Self::is_marked_reference(t_next)
                    && unsafe { &*t }.key.as_ref().map_or(false, |k| k >= search_key)
                {
                    break;
                }
            }
            let right_node = t;

            // 2. Already adjacent: there is nothing to unlink.
            if left_node_next == right_node {
                if right_node != self.tail.load(OSC)
                    && Self::is_marked_reference(unsafe { &*right_node }.next.load(OSC))
                {
                    continue 'search_again;
                }
                return right_node;
            }

            // 3. Unlink the run of marked nodes between `left_node` and `right_node`.
            if unsafe { &**left_node }
                .next
                .compare_exchange(left_node_next, right_node, OSC, OSC)
                .is_ok()
            {
                let mut curr_node = left_node_next;
                loop {
                    assert!(!Self::is_marked_reference(curr_node));
                    remove_nodes.push(curr_node);
                    curr_node = unsafe { &*curr_node }.next.load(OSC);
                    // Every node in the unlinked run must already be logically deleted.
                    assert!(Self::is_marked_reference(curr_node));
                    curr_node = Self::get_unmarked_reference(curr_node);
                    if curr_node == right_node {
                        break;
                    }
                }
                if right_node != self.tail.load(OSC)
                    && Self::is_marked_reference(unsafe { &*right_node }.next.load(OSC))
                {
                    continue 'search_again;
                }
                return right_node;
            }
            // The unlink CAS lost a race, so the outer loop restarts the search from `head`.
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Takes ownership of the old-value pointer that `insert` returns when it overwrites
    /// an existing entry, frees that allocation, and yields the previous value.
    fn reclaim<V: Copy>(old: Option<*mut V>) -> Option<V> {
        // SAFETY: the pointer was produced by `Box::into_raw` and has been swapped out of
        // the list, and these single-threaded tests hold no other reference to it.
        old.map(|p| unsafe { *Box::from_raw(p) })
    }

    #[test]
    fn linkedlist_basics() {
        let mut remove_nodes = Vec::new();
        let new_linked_list = LinkedList::default();
        println!("{:?}", new_linked_list);
        assert_eq!(reclaim(new_linked_list.insert(3, 2, &mut remove_nodes)), None);
        // Re-inserting an existing key replaces the value and hands back the old one.
        assert_eq!(reclaim(new_linked_list.insert(3, 4, &mut remove_nodes)), Some(2));
        assert_eq!(reclaim(new_linked_list.insert(5, 8, &mut remove_nodes)), None);
        assert_eq!(reclaim(new_linked_list.insert(4, 6, &mut remove_nodes)), None);
        assert_eq!(reclaim(new_linked_list.insert(1, 8, &mut remove_nodes)), None);
        assert_eq!(reclaim(new_linked_list.insert(6, 6, &mut remove_nodes)), None);
        assert_eq!(new_linked_list.get(&3, &mut remove_nodes), Some(4));
        assert_eq!(new_linked_list.get(&5, &mut remove_nodes), Some(8));
        assert_eq!(new_linked_list.get(&2, &mut remove_nodes), None);
    }

    #[test]
    fn more_linked_list_tests() {
        let mut remove_nodes = Vec::new();
        let new_linked_list = LinkedList::default();
        assert_eq!(reclaim(new_linked_list.insert(5, 3, &mut remove_nodes)), None);
        assert_eq!(reclaim(new_linked_list.insert(5, 8, &mut remove_nodes)), Some(3));
        assert_eq!(reclaim(new_linked_list.insert(2, 3, &mut remove_nodes)), None);
        assert_eq!(new_linked_list.get(&5, &mut remove_nodes), Some(8));
        assert_eq!(new_linked_list.delete(&5, &mut remove_nodes), Some(8));
        // After deletion the key is gone, a second delete is a no-op, and the
        // neighbouring entry is untouched.
        assert_eq!(new_linked_list.get(&5, &mut remove_nodes), None);
        assert_eq!(new_linked_list.delete(&5, &mut remove_nodes), None);
        assert_eq!(new_linked_list.get(&2, &mut remove_nodes), Some(3));
    }
}