use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A doubly linked list whose insertion methods return [`NodeRef`] handles that can later
/// be used to read, insert next to, or remove that particular element.
pub struct LinkedList<T> {
// Heap allocation used only for its address: each `NodeRef` records it so the list can
// reject handles it did not create. Freed in `Drop`.
list_ptr: *mut u8,
// First and last nodes; both null when the list is empty.
head_tail: HeadTail<T>,
// Element count as tracked by the insertion and removal methods.
len: usize,
}
// NOTE(review): these impls assert that the raw node pointers may cross threads whenever
// `T` does. Nodes are shared with `NodeRef` handles whose ownership count is an
// `AtomicUsize`; confirm every shared-node access path is thread-safe before relying on them.
unsafe impl<T> Send for LinkedList<T> where T: Send {}
unsafe impl<T> Sync for LinkedList<T> where T: Sync {}
impl<T: Debug> Debug for LinkedList<T> {
    /// Formats the elements front to back in list notation, e.g. `[1, 2, 3]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_list();
        // `&LinkedList<T>` is itself iterable over `&T`.
        out.entries(self);
        out.finish()
    }
}
impl<T> Default for LinkedList<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> LinkedList<T> {
    /// Creates an empty list.
    ///
    /// `list_ptr` is a one-byte heap allocation used only for its address: every [`NodeRef`]
    /// records it, so handles produced by a different list are rejected.
    pub fn new() -> Self {
        Self {
            list_ptr: Box::into_raw(Box::new(0u8)),
            head_tail: HeadTail {
                head: std::ptr::null_mut(),
                tail: std::ptr::null_mut(),
            },
            len: 0,
        }
    }

    /// Returns the number of elements in the list.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the list has no elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Inserts `value` at the front of the list and returns a handle to its node.
    ///
    /// # Panics
    /// Panics if the head/tail invariant is broken.
    pub fn push_head(&mut self, value: T) -> NodeRef<T> {
        self.assert_ends_consistent();
        let mut node = Node::new(value);
        node.next = self.head_tail.head;
        let node_ptr = Box::into_raw(Box::new(node));
        // SAFETY: `head` is null or a live node owned by this list; `&mut self` is exclusive.
        match unsafe { self.head_tail.head.as_mut() } {
            Some(old_head) => old_head.prev = node_ptr,
            // Empty list: the new node is also the tail.
            None => self.head_tail.tail = node_ptr,
        }
        self.head_tail.head = node_ptr;
        self.len += 1;
        NodeRef::new(self.list_ptr, node_ptr)
    }

    /// Inserts `value` at the back of the list and returns a handle to its node.
    ///
    /// # Panics
    /// Panics if the head/tail invariant is broken.
    pub fn push_tail(&mut self, value: T) -> NodeRef<T> {
        self.assert_ends_consistent();
        let mut node = Node::new(value);
        node.prev = self.head_tail.tail;
        let node_ptr = Box::into_raw(Box::new(node));
        // SAFETY: `tail` is null or a live node owned by this list; `&mut self` is exclusive.
        match unsafe { self.head_tail.tail.as_mut() } {
            Some(old_tail) => old_tail.next = node_ptr,
            // Empty list: the new node is also the head.
            None => self.head_tail.head = node_ptr,
        }
        self.head_tail.tail = node_ptr;
        self.len += 1;
        NodeRef::new(self.list_ptr, node_ptr)
    }

    /// Inserts `value` immediately before the element `node_ref` points at.
    ///
    /// # Errors
    /// Returns `Err(value)` if `node_ref` was not issued by this list or its element has
    /// already been removed.
    pub fn push_before(&mut self, node_ref: &NodeRef<T>, value: T) -> Result<NodeRef<T>, T> {
        let Some(other_node) = self.get_node_mut(node_ref) else {
            return Err(value);
        };
        let mut node = Node::new(value);
        node.prev = other_node.prev;
        node.next = node_ref.node_ptr;
        let prev_ptr = node.prev;
        let node_ptr = Box::into_raw(Box::new(node));
        other_node.prev = node_ptr;
        // SAFETY: the neighbour of a linked node is null or a live node of this list.
        match unsafe { prev_ptr.as_mut() } {
            Some(prev) => prev.next = node_ptr,
            // `node_ref` was the head, so the new node becomes the head.
            None => self.head_tail.head = node_ptr,
        }
        self.len += 1;
        Ok(NodeRef::new(self.list_ptr, node_ptr))
    }

    /// Inserts `value` immediately after the element `node_ref` points at.
    ///
    /// # Errors
    /// Returns `Err(value)` if `node_ref` was not issued by this list or its element has
    /// already been removed.
    pub fn push_after(&mut self, node_ref: &NodeRef<T>, value: T) -> Result<NodeRef<T>, T> {
        let Some(other_node) = self.get_node_mut(node_ref) else {
            return Err(value);
        };
        let mut node = Node::new(value);
        node.prev = node_ref.node_ptr;
        node.next = other_node.next;
        let next_ptr = node.next;
        let node_ptr = Box::into_raw(Box::new(node));
        other_node.next = node_ptr;
        // SAFETY: the neighbour of a linked node is null or a live node of this list.
        match unsafe { next_ptr.as_mut() } {
            Some(next) => next.prev = node_ptr,
            // `node_ref` was the tail, so the new node becomes the tail.
            None => self.head_tail.tail = node_ptr,
        }
        self.len += 1;
        Ok(NodeRef::new(self.list_ptr, node_ptr))
    }

    /// Removes and returns the first element, or `None` if the list is empty.
    ///
    /// # Panics
    /// Panics if the head/tail invariant is broken.
    pub fn pop_head(&mut self) -> Option<T> {
        self.assert_ends_consistent();
        // `remove_node` returns `None` for a null pointer (empty list).
        self.remove_node(self.head_tail.head)
    }

    /// Removes and returns the last element, or `None` if the list is empty.
    ///
    /// # Panics
    /// Panics if the head/tail invariant is broken.
    pub fn pop_tail(&mut self) -> Option<T> {
        self.assert_ends_consistent();
        self.remove_node(self.head_tail.tail)
    }

    /// Borrows the element `node_ref` points at, or `None` if the handle belongs to another
    /// list or its element was removed.
    pub fn get(&self, node_ref: &NodeRef<T>) -> Option<&T> {
        self.get_node(node_ref)
            .map(|node| node.value.as_ref().expect("value"))
    }

    /// Mutably borrows the element `node_ref` points at, or `None` if the handle belongs to
    /// another list or its element was removed.
    pub fn get_mut(&mut self, node_ref: &NodeRef<T>) -> Option<&mut T> {
        self.get_node_mut(node_ref)
            .map(|node| node.value.as_mut().expect("value"))
    }

    /// Resolves `node_ref` to its node if it was issued by this list and still holds a value.
    fn get_node(&self, node_ref: &NodeRef<T>) -> Option<&Node<T>> {
        if node_ref.list_ptr != self.list_ptr {
            return None;
        }
        // SAFETY: `node_ref` holds a share of the node's reference count, so it is allocated.
        let node = unsafe { node_ref.node_ptr.as_ref()? };
        Some(node).filter(|node| node.value.is_some())
    }

    /// Mutable counterpart of [`Self::get_node`].
    fn get_node_mut(&mut self, node_ref: &NodeRef<T>) -> Option<&mut Node<T>> {
        if node_ref.list_ptr != self.list_ptr {
            return None;
        }
        // SAFETY: `node_ref` keeps the node allocated; `&mut self` gives exclusive access.
        let node = unsafe { node_ref.node_ptr.as_mut()? };
        Some(node).filter(|node| node.value.is_some())
    }

    /// Borrows the first element, if any.
    pub fn head(&self) -> Option<&T> {
        // SAFETY: `head` is null or a live node owned by this list.
        unsafe {
            self.head_tail
                .head
                .as_ref()
                .and_then(|node| node.value.as_ref())
        }
    }

    /// Borrows the last element, if any.
    pub fn tail(&self) -> Option<&T> {
        // SAFETY: `tail` is null or a live node owned by this list.
        unsafe {
            self.head_tail
                .tail
                .as_ref()
                .and_then(|node| node.value.as_ref())
        }
    }

    /// Returns a double-ended iterator over shared references, front to back.
    pub fn iter(&self) -> Iter<'_, T> {
        Iter::new(HeadTail {
            head: self.head_tail.head,
            tail: self.head_tail.tail,
        })
    }

    /// Returns a double-ended iterator over mutable references, front to back.
    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
        IterMut::new(HeadTail {
            head: self.head_tail.head,
            tail: self.head_tail.tail,
        })
    }

    /// Consumes the list, returning an iterator that yields the elements by value.
    pub fn into_iter(self) -> IntoIter<T> {
        IntoIter::new(self)
    }

    /// Removes the element `node_ref` points at and returns it.
    ///
    /// Returns `None` if the handle was issued by another list or its element was already
    /// removed. The handle is consumed either way, releasing its share of the node.
    pub fn remove(&mut self, node_ref: NodeRef<T>) -> Option<T> {
        if node_ref.list_ptr != self.list_ptr {
            return None;
        }
        self.remove_node(node_ref.node_ptr)
        // `node_ref` is dropped here, after unlinking, releasing the caller's share.
    }

    /// Unlinks `node_ptr`, updates `len`, releases the list's share and returns the value.
    ///
    /// Returns `None` for a null pointer or a node whose value was already taken. Every
    /// successful removal path (`remove`, `pop_head`, `pop_tail`) funnels through here, so
    /// `len` stays in sync.
    fn remove_node(&mut self, node_ptr: *mut Node<T>) -> Option<T> {
        // SAFETY: callers pass null, the list's own head/tail, or a node kept allocated by a
        // `NodeRef` share; `&mut self` gives exclusive access to linked nodes.
        let node = unsafe { node_ptr.as_mut()? };
        self.assert_ends_consistent();
        // An empty value means this node was already removed; nothing to unlink.
        let value = node.value.take()?;
        let (prev_ptr, next_ptr) = (node.prev, node.next);
        // SAFETY: neighbours of a linked node are null or live nodes of this list.
        match unsafe { prev_ptr.as_mut() } {
            Some(prev) => prev.next = next_ptr,
            None => self.head_tail.head = next_ptr,
        }
        // SAFETY: as above.
        match unsafe { next_ptr.as_mut() } {
            Some(next) => next.prev = prev_ptr,
            None => self.head_tail.tail = prev_ptr,
        }
        self.len -= 1;
        // Release the list's share; the node is freed now unless a `NodeRef` still holds one.
        maybe_drop_node(node_ptr);
        Some(value)
    }

    /// Asserts the invariant that `head` and `tail` are either both null or both non-null.
    fn assert_ends_consistent(&self) {
        assert_eq!(
            self.head_tail.head.is_null(),
            self.head_tail.tail.is_null(),
            "head and tail should both be null or non-null"
        );
    }
}
impl<T> FromIterator<T> for LinkedList<T> {
    /// Builds a list holding the items of `iter` in iteration order, appending each at the tail.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        iter.into_iter().fold(Self::new(), |mut list, item| {
            // The returned handle is not needed; dropping it releases its share.
            list.push_tail(item);
            list
        })
    }
}
impl<'a, T> IntoIterator for &'a LinkedList<T> {
type Item = &'a T;
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T> IntoIterator for LinkedList<T> {
type Item = T;
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
self.into_iter()
}
}
impl<T> Drop for LinkedList<T> {
    /// Frees the identity token, drops every element still in the list, and releases the
    /// list's share of each node.
    ///
    /// Each node's value is taken and its links cleared *before* the share is released.
    /// Nodes still referenced by an outstanding [`NodeRef`] therefore stay allocated but look
    /// removed. If a later list happens to reuse this list's freed `list_ptr` address, its
    /// `get`/`push_before`/`push_after`/`remove` reject such stale handles instead of
    /// following dangling `prev`/`next` pointers. Each element is also dropped here rather
    /// than whenever its last handle goes away.
    fn drop(&mut self) {
        // SAFETY: `list_ptr` was produced by `Box::into_raw` in `new` and is freed only here.
        unsafe { drop(Box::from_raw(self.list_ptr)) }
        let mut node_ptr = self.head_tail.head;
        while !node_ptr.is_null() {
            // SAFETY: every node reachable from `head` is alive because the list still holds
            // its share, and `&mut self` gives exclusive access to the list's nodes.
            let next_ptr = unsafe {
                let node = &mut *node_ptr;
                node.value = None;
                node.prev = std::ptr::null_mut();
                std::mem::replace(&mut node.next, std::ptr::null_mut())
            };
            maybe_drop_node(node_ptr);
            node_ptr = next_ptr;
        }
    }
}
/// Borrowing iterator over a [`LinkedList`], created by [`LinkedList::iter`].
///
/// `head_tail` holds the next unvisited node from the front and from the back. When both
/// cursors point at the same node, yielding it (from either end) clears both. Each element
/// is therefore produced exactly once, even when `next` and `next_back` are interleaved.
pub struct Iter<'a, T> {
    head_tail: HeadTail<T>,
    phantom: PhantomData<&'a T>,
}
impl<'a, T> Iter<'a, T> {
    fn new(head_tail: HeadTail<T>) -> Self {
        Self {
            head_tail,
            phantom: PhantomData,
        }
    }
    /// Marks the iterator as exhausted from both ends.
    fn finish(&mut self) {
        self.head_tail.head = std::ptr::null_mut();
        self.head_tail.tail = std::ptr::null_mut();
    }
}
impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<&'a T> {
        // SAFETY: `head` is null or a live node of the list borrowed for `'a`.
        let node = unsafe { self.head_tail.head.as_ref()? };
        if self.head_tail.head == self.head_tail.tail {
            // Cursors met: this is the last unvisited node. Do NOT advance to `node.next`,
            // which may already have been yielded by `next_back`.
            self.finish();
        } else {
            self.head_tail.head = node.next;
        }
        Some(node.value.as_ref().expect("value"))
    }
}
impl<'a, T> DoubleEndedIterator for Iter<'a, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        // SAFETY: `tail` is null or a live node of the list borrowed for `'a`.
        let node = unsafe { self.head_tail.tail.as_ref()? };
        if self.head_tail.head == self.head_tail.tail {
            // Cursors met: do not step back onto a node `next` already yielded.
            self.finish();
        } else {
            self.head_tail.tail = node.prev;
        }
        Some(node.value.as_ref().expect("value"))
    }
}
/// Mutable iterator over a [`LinkedList`], created by [`LinkedList::iter_mut`].
///
/// `head_tail` holds the next unvisited node from the front and from the back. When both
/// cursors point at the same node, yielding it (from either end) clears both. Each element
/// is therefore handed out at most once, even when `next` and `next_back` are interleaved.
/// Handing out the same element twice would create two aliasing `&mut T`.
pub struct IterMut<'a, T> {
    head_tail: HeadTail<T>,
    // Models the exclusive borrow of the list the iterator was created from.
    phantom: PhantomData<&'a mut T>,
}
impl<'a, T> IterMut<'a, T> {
    fn new(head_tail: HeadTail<T>) -> Self {
        Self {
            head_tail,
            phantom: PhantomData,
        }
    }
    /// Marks the iterator as exhausted from both ends.
    fn finish(&mut self) {
        self.head_tail.head = std::ptr::null_mut();
        self.head_tail.tail = std::ptr::null_mut();
    }
}
impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<&'a mut T> {
        // SAFETY: `head` is null or a live, not-yet-yielded node of the list mutably
        // borrowed for `'a`.
        let node = unsafe { self.head_tail.head.as_mut()? };
        if self.head_tail.head == self.head_tail.tail {
            // Cursors met: this is the last unvisited node. Do NOT advance to `node.next`,
            // which may already have been yielded by `next_back`.
            self.finish();
        } else {
            self.head_tail.head = node.next;
        }
        Some(node.value.as_mut().expect("value"))
    }
}
impl<'a, T> DoubleEndedIterator for IterMut<'a, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        // SAFETY: `tail` is null or a live, not-yet-yielded node of the list mutably
        // borrowed for `'a`.
        let node = unsafe { self.head_tail.tail.as_mut()? };
        if self.head_tail.head == self.head_tail.tail {
            // Cursors met: do not step back onto a node `next` already yielded.
            self.finish();
        } else {
            self.head_tail.tail = node.prev;
        }
        Some(node.value.as_mut().expect("value"))
    }
}
pub struct IntoIter<T> {
list: LinkedList<T>,
}
impl<T> IntoIter<T> {
fn new(list: LinkedList<T>) -> Self {
Self { list }
}
}
impl<T> Iterator for IntoIter<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
self.list.pop_head()
}
}
impl<T> DoubleEndedIterator for IntoIter<T> {
fn next_back(&mut self) -> Option<Self::Item> {
self.list.pop_tail()
}
}
/// A pair of node pointers; null means "none".
///
/// For `LinkedList` these are its first and last nodes. For `Iter`/`IterMut` they are the
/// next unvisited node from the front and from the back.
// NOTE(review): the derived `Clone` only exists when `T: Clone`, and no call to it is
// visible in this file; the constructors above build `HeadTail` field by field instead.
#[derive(Clone)]
struct HeadTail<T> {
head: *mut Node<T>,
tail: *mut Node<T>,
}
/// A heap-allocated list node, shared between the list and any `NodeRef` handles to it.
struct Node<T> {
// The element; `None` once it has been taken out by a removal. The list's lookups treat
// a `None` value as "this handle no longer refers to an element".
value: Option<T>,
// Previous node, or null at the head.
prev: *mut Node<T>,
// Next node, or null at the tail.
next: *mut Node<T>,
// Number of owners of this allocation: the list while the node is linked, plus each
// live `NodeRef`. `maybe_drop_node` frees the node when the last owner releases it.
ref_count: AtomicUsize,
}
/// Releases one share of the node at `node_ptr` and frees the node if that was the last one.
///
/// Shares are held by the owning list (until the node is removed or the list is dropped)
/// and by each live `NodeRef`. The decrement uses `AcqRel`, following the Release/Acquire
/// pairing `std::sync::Arc` uses. This makes every write other owners made to the node
/// happen-before the free, even when the last share is released on another thread.
///
/// # Panics
/// Panics with `"node"` if `node_ptr` is null.
fn maybe_drop_node<T>(node_ptr: *mut Node<T>) {
    assert!(!node_ptr.is_null(), "node");
    // SAFETY: the caller owns a share, so the node is still allocated. Only the atomic
    // counter field is borrowed (shared), never the whole node.
    let previous = unsafe { (*node_ptr).ref_count.fetch_sub(1, Ordering::AcqRel) };
    if previous == 1 {
        // SAFETY: that was the last share, and the node was created with `Box::into_raw`.
        drop(unsafe { Box::from_raw(node_ptr) });
    }
}
/// Handle to one element of a [`LinkedList`], returned by the push methods.
///
/// The handle holds a share of its node's reference count, so the node's memory stays
/// valid as long as the handle lives, even after the element is removed. List methods
/// return `None`/`Err` for handles whose element is gone or whose list token does not match.
pub struct NodeRef<T> {
// Identity token (`list_ptr`) of the list that issued this handle.
list_ptr: *mut u8,
// The shared node.
node_ptr: *mut Node<T>,
}
// NOTE(review): dropping the last share of a node frees it (dropping its `T` if one is
// still stored) on whichever thread drops the handle, which is why `Send` requires
// `T: Send`. `&NodeRef` exposes only `Clone`, which touches the atomic counter.
unsafe impl<T> Send for NodeRef<T> where T: Send {}
unsafe impl<T> Sync for NodeRef<T> where T: Sync {}
impl<T> Clone for NodeRef<T> {
    /// Creates another handle to the same node, taking an extra share of its reference count.
    ///
    /// # Panics
    /// Panics with `"deref"` if the handle's node pointer is null.
    fn clone(&self) -> Self {
        assert!(!self.node_ptr.is_null(), "deref");
        // SAFETY: `self` holds a share, so the node is alive. Only the atomic counter is
        // borrowed, and only shared. `NodeRef` is `Sync`, so several threads may clone the
        // same handle at once, and forming a `&mut Node` here would create aliasing
        // mutable references.
        unsafe {
            (*self.node_ptr).ref_count.fetch_add(1, Ordering::Relaxed);
        }
        Self {
            list_ptr: self.list_ptr,
            node_ptr: self.node_ptr,
        }
    }
}
impl<T> NodeRef<T> {
fn new(list_ptr: *mut u8, node_ptr: *mut Node<T>) -> Self {
Self { list_ptr, node_ptr }
}
}
impl<T> Drop for NodeRef<T> {
    /// Gives up this handle's share of the node, freeing it if no other owner remains.
    fn drop(&mut self) {
        let shared = self.node_ptr;
        maybe_drop_node(shared);
    }
}
impl<T> AsRef<T> for Node<T> {
    /// Borrows the stored element.
    ///
    /// # Panics
    /// Panics with `"value"` if the element has already been taken out of the node.
    fn as_ref(&self) -> &T {
        match &self.value {
            Some(value) => value,
            None => panic!("value"),
        }
    }
}
impl<T> Node<T> {
    /// Builds an unlinked node holding `value`.
    ///
    /// The reference count starts at 2. One share belongs to the list that links the node
    /// in, and one to the `NodeRef` the push methods hand back.
    fn new(value: T) -> Node<T> {
        let unlinked: *mut Node<T> = std::ptr::null_mut();
        Node {
            ref_count: AtomicUsize::new(2),
            next: unlinked,
            prev: unlinked,
            value: Some(value),
        }
    }
}
// Unit tests for `LinkedList`, its iterators, and `NodeRef` handles.
#[cfg(test)]
mod tests {
use super::*;
// A single element pushed at the tail is both head and tail; removing it through its
// handle leaves the list empty.
#[test]
fn push_tail_remove_drop() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref = list.push_tail(123);
assert!(matches!(list.head(), Some(&123)));
assert!(matches!(list.tail(), Some(&123)));
list.remove(node_ref);
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
drop(list);
}
// Same as above, but the element is inserted with `push_head`.
#[test]
fn push_head_remove_drop() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref = list.push_head(123);
assert!(matches!(list.head(), Some(&123)));
assert!(matches!(list.tail(), Some(&123)));
list.remove(node_ref);
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
drop(list);
}
// `collect` preserves iteration order; `iter` works forwards and reversed.
#[test]
fn collect_iter() {
let list: LinkedList<_> = (1..=4i32).collect();
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&1, &2, &3, &4]);
assert_eq!(list.iter().rev().collect::<Vec<_>>(), vec![&4, &3, &2, &1]);
}
// `for x in &list` yields the elements front to back.
#[test]
fn into_iter_ref() {
let mut range = 1..=4i32;
let list: LinkedList<_> = range.clone().collect();
for got in &list {
let want = range.next().unwrap();
assert_eq!(&want, got);
}
}
// Alternating pops from both ends drain the list; further pops return `None`.
#[test]
fn pop_head_tail() {
let mut list: LinkedList<_> = (1..=4i32).collect();
assert_eq!(list.pop_head(), Some(1));
assert_eq!(list.pop_tail(), Some(4));
assert_eq!(list.pop_head(), Some(2));
assert_eq!(list.pop_tail(), Some(3));
assert_eq!(list.pop_head(), None);
assert_eq!(list.pop_tail(), None);
}
// Cloned handles all resolve to the same element, and dropping some clones leaves the
// others (and the list) usable.
#[test]
fn clone_node_ref() {
let mut list = LinkedList::new();
let node_ref1 = list.push_tail(1);
let node_ref2 = node_ref1.clone();
let node_ref3 = node_ref2.clone();
assert_eq!(list.get(&node_ref1), Some(&1));
assert_eq!(list.get(&node_ref2), Some(&1));
assert_eq!(list.get(&node_ref3), Some(&1));
drop(node_ref1);
assert_eq!(list.get(&node_ref2), Some(&1));
assert_eq!(list.get(&node_ref3), Some(&1));
drop(node_ref2);
assert_eq!(list.get(&node_ref3), Some(&1));
drop(node_ref3);
assert_eq!(list.pop_head(), Some(1));
assert_eq!(list.pop_head(), None);
}
// Removing the head, a middle node, the tail and finally the last node keeps head, tail
// and iteration consistent at every step.
#[test]
fn add_remove_multiple_drop() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref1 = list.push_tail(1);
let node_ref2 = list.push_tail(2);
let node_ref3 = list.push_tail(3);
let node_ref4 = list.push_tail(4);
assert!(matches!(list.head(), Some(&1)));
assert!(matches!(list.tail(), Some(&4)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&1, &2, &3, &4]);
list.remove(node_ref1);
assert!(matches!(list.head(), Some(&2)));
assert!(matches!(list.tail(), Some(&4)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&2, &3, &4]);
list.remove(node_ref3);
assert!(matches!(list.head(), Some(&2)));
assert!(matches!(list.tail(), Some(&4)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&2, &4]);
list.remove(node_ref4);
assert!(matches!(list.head(), Some(&2)));
assert!(matches!(list.tail(), Some(&2)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&2]);
list.remove(node_ref2);
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
assert_eq!(list.iter().collect::<Vec<_>>(), Vec::<&i32>::new());
}
// Repeated `push_head` stores elements in reverse insertion order.
#[test]
fn push_head() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
list.push_head(1);
list.push_head(2);
list.push_head(3);
list.push_head(4);
assert!(matches!(list.head(), Some(&4)));
assert!(matches!(list.tail(), Some(&1)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&4, &3, &2, &1]);
assert_eq!(list.iter().rev().collect::<Vec<_>>(), vec![&1, &2, &3, &4]);
}
// Inserting before the same node each time places every new element directly in front
// of node 1, so insertion order is preserved ahead of it.
#[test]
fn push_before() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref = list.push_head(1);
list.push_before(&node_ref, 2).expect("ok");
list.push_before(&node_ref, 3).expect("ok");
list.push_before(&node_ref, 4).expect("ok");
assert!(matches!(list.head(), Some(&2)));
assert!(matches!(list.tail(), Some(&1)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&2, &3, &4, &1]);
assert_eq!(list.iter().rev().collect::<Vec<_>>(), vec![&1, &4, &3, &2]);
}
// Inserting before the handle returned by the previous insertion builds the list
// backwards from the tail.
#[test]
fn push_before_ref() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref = list.push_head(1);
let node_ref = list.push_before(&node_ref, 2).expect("ok");
let node_ref = list.push_before(&node_ref, 3).expect("ok");
list.push_before(&node_ref, 4).expect("ok");
assert!(matches!(list.head(), Some(&4)));
assert!(matches!(list.tail(), Some(&1)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&4, &3, &2, &1]);
assert_eq!(list.iter().rev().collect::<Vec<_>>(), vec![&1, &2, &3, &4]);
}
// Inserting after the same node each time places every new element directly behind
// node 1, so the later insertions come first.
#[test]
fn push_after() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref = list.push_tail(1);
list.push_after(&node_ref, 2).expect("ok");
list.push_after(&node_ref, 3).expect("ok");
list.push_after(&node_ref, 4).expect("ok");
assert!(matches!(list.head(), Some(&1)));
assert!(matches!(list.tail(), Some(&2)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&1, &4, &3, &2]);
assert_eq!(list.iter().rev().collect::<Vec<_>>(), vec![&2, &3, &4, &1]);
}
// Inserting after the handle returned by the previous insertion appends in order.
#[test]
fn push_after_ref() {
let mut list = LinkedList::<i32>::new();
assert!(matches!(list.head(), None));
assert!(matches!(list.tail(), None));
let node_ref = list.push_head(1);
let node_ref = list.push_after(&node_ref, 2).expect("ok");
let node_ref = list.push_after(&node_ref, 3).expect("ok");
list.push_after(&node_ref, 4).expect("ok");
assert!(matches!(list.head(), Some(&1)));
assert!(matches!(list.tail(), Some(&4)));
assert_eq!(list.iter().collect::<Vec<_>>(), vec![&1, &2, &3, &4]);
assert_eq!(list.iter().rev().collect::<Vec<_>>(), vec![&4, &3, &2, &1]);
}
// Mutations made through `iter_mut` are visible when the list is consumed.
#[test]
fn iter_mut() {
let mut list: LinkedList<_> = (0..10).collect();
for item in list.iter_mut() {
*item += 1;
}
assert_eq!(
list.into_iter().collect::<Vec<_>>(),
(1..11).collect::<Vec<_>>()
);
}
}