use core::{
iter::{FromIterator, FusedIterator},
marker::PhantomData,
ops::{Deref, DerefMut},
ptr,
sync::atomic::{AtomicPtr, Ordering},
};
/// A singly linked list whose head pointer is stored in an [`AtomicPtr`].
///
/// `push` and `pop` take `&self` and update the head with compare-and-exchange
/// loops. New nodes are always linked in at the head, so popping and iterating
/// visit the most recently pushed element first.
pub struct LinkedList<T> {
/// Pointer to the first node, or null when the list is empty.
head: AtomicPtr<Node<T>>,
}
/// A single list entry: a link to the following node plus the stored value.
///
/// `#[repr(C)]` fixes the field layout in declaration order, so `next` comes
/// first. NOTE(review): presumably raw-node users rely on that layout — confirm.
#[repr(C)]
pub struct Node<T> {
/// The following node in the list, or null if this is the last one.
next: *mut Node<T>,
/// The payload carried by this node.
pub value: T,
}
/// Owning iterator produced by `LinkedList::into_iter`.
///
/// It holds the list itself and removes the head node on every step; nodes
/// that were never yielded are released when the contained list is dropped.
pub struct IntoIter<T> {
/// The list being consumed.
list: LinkedList<T>,
}
/// Iterator yielding `&'a mut T` for each node, starting at the head.
pub struct IterMut<'a, T> {
/// The node whose value is yielded next; null once the end is reached.
curr: *mut Node<T>,
/// Ties the iterator to the lifetime `'a` of the borrowed nodes, since the
/// raw pointer above carries no lifetime of its own.
marker: PhantomData<&'a mut Node<T>>,
}
/// Iterator that removes the nodes accepted by a filter and yields their values.
///
/// This is a thin wrapper around [`DrainFilterRaw`]: each node pointer produced
/// by the raw iterator is turned back into a `Box` and its value is returned.
pub struct DrainFilter<'a, T, F>
where
F: FnMut(*mut Node<T>) -> bool,
{
/// The pointer-level iterator doing the actual unlinking.
raw: DrainFilterRaw<'a, T, F>,
}
/// Iterator that unlinks the nodes accepted by `filter` and yields them as raw
/// pointers, leaving ownership of each yielded node with the caller.
pub struct DrainFilterRaw<'a, T, F>
where
F: FnMut(*mut Node<T>) -> bool,
{
/// Head pointer of the list being drained; updated when the first node is removed.
head: &'a AtomicPtr<Node<T>>,
/// Predecessor of `curr` as far as this iterator knows. Null until a node has
/// been skipped or a predecessor was located by walking from the head.
prev: *mut Node<T>,
/// The node to examine next; null once the end of the list is reached.
curr: *mut Node<T>,
/// Predicate deciding whether a node is removed (`true`) or kept (`false`).
filter: F,
}
// Marks the list as shareable between threads; the head is only ever accessed
// through `AtomicPtr` operations.
// NOTE(review): this impl is unbounded in `T`, yet `push` and `pop` move `T`
// values in and out through `&self`. A `T: Send` bound looks necessary for
// soundness — confirm before sharing lists of non-`Send` payloads.
unsafe impl<T> Sync for LinkedList<T> {}
impl<T> LinkedList<T> {
#[inline]
pub const fn new() -> Self {
Self { head: AtomicPtr::new(ptr::null_mut()) }
}
#[inline]
pub fn is_empty(&self) -> bool {
self.head.load(Ordering::Relaxed).is_null()
}
#[inline]
pub fn push(&self, data: T) {
unsafe { self.push_raw(Box::into_raw(Box::new(Node::from(data)))) };
}
#[inline]
pub unsafe fn push_raw(&self, node: *mut Node<T>) {
loop {
let curr = self.head.load(Ordering::Relaxed);
unsafe { (*node).next = curr };
if self
.head
.compare_exchange_weak(curr, node, Ordering::Release, Ordering::Relaxed)
.is_ok()
{
break;
}
}
}
#[inline]
pub fn pop(&self) -> Option<T> {
unsafe { self.pop_raw().map(|node| Box::from_raw(node).value) }
}
#[inline]
pub unsafe fn pop_raw(&self) -> Option<*mut Node<T>> {
loop {
let curr = self.head.load(Ordering::Acquire);
if curr.is_null() {
break None;
}
let next = unsafe { (*curr).next };
if self
.head
.compare_exchange_weak(curr, next, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
break Some(curr);
}
}
}
#[inline]
pub fn iter_mut(&mut self) -> IterMut<'_, T> {
unsafe { self.iter_mut_unchecked() }
}
#[inline]
pub unsafe fn iter_mut_unchecked(&self) -> IterMut<'_, T> {
IterMut { curr: self.head.load(Ordering::Acquire), marker: PhantomData }
}
#[inline]
pub fn drain_filter<'a, F: 'a>(
&'a mut self,
mut filter: F,
) -> DrainFilter<'_, T, impl FnMut(*mut Node<T>) -> bool>
where
F: FnMut(&mut T) -> bool,
{
unsafe { DrainFilter { raw: self.drain_filter_raw(move |node| filter(&mut *node)) } }
}
#[inline]
pub unsafe fn drain_filter_raw<F>(
&self,
filter: F,
) -> DrainFilterRaw<'_, T, impl FnMut(*mut Node<T>) -> bool>
where
F: FnMut(*mut Node<T>) -> bool,
{
DrainFilterRaw {
head: &self.head,
prev: ptr::null_mut(),
curr: self.head.load(Ordering::Acquire),
filter,
}
}
}
impl<T> Drop for LinkedList<T> {
#[inline]
fn drop(&mut self) {
let mut curr = self.head.load(Ordering::Acquire);
while !curr.is_null() {
let next = unsafe { (*curr).next };
drop(unsafe { Box::from_raw(curr) });
curr = next;
}
}
}
/// Pushes every item of the iterator onto the list, one after another. Since
/// each push goes to the head, the last item ends up first in the list.
impl<T> Extend<T> for LinkedList<T> {
    #[inline]
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        let list = &*self;
        iter.into_iter().for_each(|elem| list.push(elem));
    }
}
/// Extends the list from borrowed items by copying each one and delegating to
/// the by-value `Extend<T>` implementation.
impl<'a, T: 'a + Copy> Extend<&'a T> for LinkedList<T> {
    #[inline]
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        let by_value = iter.into_iter().copied();
        <Self as Extend<T>>::extend(self, by_value);
    }
}
/// Builds a list by pushing the iterator's items in order; because every push
/// goes to the head, the resulting list holds them in reverse.
impl<T> FromIterator<T> for LinkedList<T> {
    #[inline]
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let list = Self::new();
        for item in iter {
            list.push(item);
        }
        list
    }
}
/// Consumes the list and iterates over its values, head first.
impl<T> IntoIterator for LinkedList<T> {
    type IntoIter = IntoIter<T>;
    type Item = T;

    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let list = self;
        IntoIter { list }
    }
}
impl<T> Iterator for IntoIter<T> {
type Item = T;
#[inline]
fn next(&mut self) -> Option<T> {
let curr = self.list.head.load(Ordering::Acquire);
if curr.is_null() {
None
} else {
let next = unsafe { (*curr).next };
self.list.head.store(next, Ordering::Relaxed);
Some(unsafe { Box::from_raw(curr) }.value)
}
}
}
impl<T> FusedIterator for IntoIter<T> {}
impl<'a, T> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;

    /// Yields a mutable reference to the current node's value and advances to
    /// the following node; returns `None` at the end of the list.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.curr;
        if node.is_null() {
            return None;
        }
        // Step forward before handing out the reference to this node's value.
        self.curr = unsafe { (*node).next };
        let value = unsafe { &mut (*node).value };
        Some(value)
    }
}

// After reaching the null terminator `curr` stays null, so `next` keeps
// returning `None`.
impl<T> FusedIterator for IterMut<'_, T> {}
impl<T, F> DrainFilter<'_, T, F>
where
    F: FnMut(*mut Node<T>) -> bool,
{
    /// Returns `true` once the underlying raw iterator has walked past the
    /// last node, i.e. its current-node pointer is null.
    #[inline]
    pub fn is_end(&self) -> bool {
        let position = self.raw.curr;
        position.is_null()
    }
}
impl<T, F> Iterator for DrainFilter<'_, T, F>
where
    F: FnMut(*mut Node<T>) -> bool,
{
    type Item = T;

    /// Advances to the next node accepted by the filter, unlinks it, frees the
    /// node and returns its value. Returns `None` when no further node matches.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // The raw iterator has already unlinked the node, so rebuilding the
        // `Box` takes over ownership and releases the allocation once the
        // value has been moved out.
        self.raw.next().map(|node| unsafe { Box::from_raw(node).value })
    }
}

// The raw iterator never leaves its end state, so neither does this wrapper.
impl<T, F> FusedIterator for DrainFilter<'_, T, F> where F: FnMut(*mut Node<T>) -> bool {}

/// Finishes the drain when the iterator is dropped early.
///
/// Without this impl, dropping a partially consumed `DrainFilter` would let the
/// inner `DrainFilterRaw` unlink the remaining matching nodes and discard the
/// raw pointers, leaking the nodes and never running the destructors of their
/// values. Draining through `DrainFilter::next` instead frees each node.
impl<T, F> Drop for DrainFilter<'_, T, F>
where
    F: FnMut(*mut Node<T>) -> bool,
{
    #[inline]
    fn drop(&mut self) {
        // After this loop the raw iterator is at its end, so its own `Drop`
        // has nothing left to unlink.
        self.for_each(drop);
    }
}
impl<T, F> DrainFilterRaw<'_, T, F>
where
    F: FnMut(*mut Node<T>) -> bool,
{
    /// Returns `true` once the iterator has walked past the last node.
    #[inline]
    pub fn is_end(&self) -> bool {
        let position = self.curr;
        position.is_null()
    }

    /// Unlinks the current node by making its predecessor (or the list head)
    /// point at `next`.
    fn cut_out(&mut self, next: *mut Node<T>) {
        if self.prev.is_null() {
            // No predecessor is known, so the current node is expected to be
            // the head: try to swing the head straight to `next`.
            let exchanged = self.head.compare_exchange(
                self.curr,
                next,
                Ordering::Relaxed,
                Ordering::Relaxed,
            );
            let mut pred = match exchanged {
                Ok(_) => return,
                // The head is a different node by now; start from it.
                Err(observed) => observed,
            };
            // Walk forward until the node right before the current one is found.
            loop {
                let succ = unsafe { (*pred).next };
                if succ == self.curr {
                    break;
                }
                pred = succ;
            }
            self.prev = pred;
        }
        unsafe { (*self.prev).next = next };
    }
}
impl<T, F> Iterator for DrainFilterRaw<'_, T, F>
where
    F: FnMut(*mut Node<T>) -> bool,
{
    type Item = *mut Node<T>;

    /// Walks forward until the filter accepts a node, unlinks that node and
    /// returns its pointer. Nodes the filter rejects are skipped and stay in
    /// the list. Returns `None` when the end of the list is reached.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let node = self.curr;
            if node.is_null() {
                return None;
            }
            // Read the successor before the filter gets to see the node.
            let following = unsafe { (*node).next };
            if (self.filter)(node) {
                self.cut_out(following);
                self.curr = following;
                return Some(node);
            }
            // Rejected: this node becomes the predecessor for the next step.
            self.prev = node;
            self.curr = following;
        }
    }
}
/// Runs the iterator to completion on drop, so every remaining node accepted
/// by the filter is still unlinked. Only the raw pointers are discarded here;
/// the unlinked nodes themselves are not freed.
impl<T, F> Drop for DrainFilterRaw<'_, T, F>
where
    F: FnMut(*mut Node<T>) -> bool,
{
    #[inline]
    fn drop(&mut self) {
        while self.next().is_some() {}
    }
}

// Once `curr` is null it is never changed again, so `next` keeps returning `None`.
impl<T, F> FusedIterator for DrainFilterRaw<'_, T, F> where F: FnMut(*mut Node<T>) -> bool {}
/// Wraps a value in a node that is not linked to anything yet (`next` is null).
impl<T> From<T> for Node<T> {
    #[inline]
    fn from(value: T) -> Self {
        let next = ptr::null_mut();
        Self { next, value }
    }
}
impl<T> Deref for Node<T> {
type Target = T;
#[inline]
fn deref(&self) -> &T {
&self.value
}
}
impl<T> DerefMut for Node<T> {
#[inline]
fn deref_mut(&mut self) -> &mut T {
&mut self.value
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draining with a selective filter removes exactly the matching values and
    /// keeps the rest. Values come out head first, i.e. in reverse push order.
    #[test]
    fn drain_filter_test() {
        let mut list: LinkedList<u32> = LinkedList::new();
        list.extend(&[1, 2, 3, 4, 5, 6]);

        let removed: Vec<_> = list.drain_filter(|v| *v < 4).collect();
        assert_eq!(removed, &[3, 2, 1]);

        let kept: Vec<_> = list.into_iter().collect();
        assert_eq!(kept, &[6, 5, 4]);
    }

    /// Draining with an always-true filter empties the list completely.
    #[test]
    fn drain_to_empty_test() {
        let mut list: LinkedList<u32> = LinkedList::new();
        list.extend(&[1, 2, 3, 4, 5, 6]);

        let removed: Vec<_> = list.drain_filter(|_| true).collect();
        assert_eq!(removed, &[6, 5, 4, 3, 2, 1]);

        let kept: Vec<_> = list.into_iter().collect();
        assert_eq!(kept, &[]);
    }
}