use std::fmt;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
/// Singly linked node for [`LockFreeStack`]; `next` points toward the
/// bottom of the stack. Nodes are heap-allocated via `Box::into_raw` and
/// reclaimed with `Box::from_raw` on pop/drop.
struct StackNode<T> {
    value: T,
    next: *mut StackNode<T>,
}
/// Treiber-style lock-free stack (LIFO).
///
/// `head` is the top of the stack (null when empty); `len` is an advisory
/// element count maintained with relaxed atomics, so it may transiently
/// disagree with the actual list under contention.
pub struct LockFreeStack<T> {
    head: AtomicPtr<StackNode<T>>,
    len: AtomicUsize,
}
// SAFETY: the stack owns its `T` values and only hands them out by value;
// the raw node pointers are touched exclusively through atomic operations.
// Transferring the whole stack to another thread is therefore sound
// whenever `T: Send`.
unsafe impl<T: Send> Send for LockFreeStack<T> {}
// SAFETY: shared access goes through `&self` methods whose mutations are
// CAS-based; no `&T` is ever exposed, so `T: Send` suffices (as with
// `Mutex<T>`), and `T: Sync` is not required.
unsafe impl<T: Send> Sync for LockFreeStack<T> {}
impl<T> LockFreeStack<T> {
    /// Creates an empty stack.
    pub fn new() -> Self {
        Self {
            head: AtomicPtr::new(std::ptr::null_mut()),
            len: AtomicUsize::new(0),
        }
    }

    /// Pushes `value` onto the top of the stack using a Treiber-style
    /// CAS loop; safe to call concurrently from many threads.
    pub fn push(&self, value: T) {
        // Allocate the node up front; ownership transfers to the stack
        // only once the CAS below succeeds.
        let new_node = Box::into_raw(Box::new(StackNode {
            value,
            next: std::ptr::null_mut(),
        }));
        loop {
            let current_head = self.head.load(Ordering::Acquire);
            // SAFETY: `new_node` has not been published yet, so this thread
            // still has exclusive access to it.
            unsafe {
                (*new_node).next = current_head;
            }
            // Release on success publishes the node's contents to whichever
            // thread later pops it (pairs with the Acquire load of `head`).
            if self
                .head
                .compare_exchange_weak(current_head, new_node, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                // `len` is advisory: it is bumped after the CAS, so it can
                // lag the real list length under contention.
                self.len.fetch_add(1, Ordering::Relaxed);
                return;
            }
        }
    }

    /// Pops the most recently pushed value, or `None` when empty.
    ///
    /// NOTE(review): this pop has no safe-memory-reclamation scheme.
    /// Between the `head` load and the dereference of `current_head`
    /// below, a concurrent `pop` may have freed that node
    /// (use-after-free), and allocator address reuse can let the CAS
    /// succeed with a stale `next` (the classic ABA problem). Hazard
    /// pointers or epoch-based reclamation would be needed to make
    /// concurrent pops sound — confirm how this type is used before
    /// relying on it with multiple concurrent poppers.
    pub fn pop(&self) -> Option<T> {
        loop {
            let current_head = self.head.load(Ordering::Acquire);
            if current_head.is_null() {
                return None;
            }
            // SAFETY(review): sound only if no other thread frees
            // `current_head` between the load above and this read.
            let next = unsafe { (*current_head).next };
            if self
                .head
                .compare_exchange_weak(current_head, next, Ordering::Release, Ordering::Relaxed)
                .is_ok()
            {
                // The CAS detached the node; rebuilding the Box frees the
                // allocation and moves the value out exactly once.
                let node = unsafe { Box::from_raw(current_head) };
                self.len.fetch_sub(1, Ordering::Relaxed);
                return Some(node.value);
            }
        }
    }

    /// True when the stack has no elements (point-in-time snapshot; may be
    /// stale by the time the caller acts on it).
    pub fn is_empty(&self) -> bool {
        self.head.load(Ordering::Acquire).is_null()
    }

    /// Advisory element count; may be stale under concurrent use.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }
}
impl<T> Default for LockFreeStack<T> {
fn default() -> Self {
Self::new()
}
}
/// Formats as `LockFreeStack { len: <n> }`. Elements are deliberately not
/// traversed — walking the list would race with concurrent mutation.
///
/// The previous `T: fmt::Debug` bound was needless (only `len` is printed),
/// so it has been dropped; this is strictly backward-compatible.
impl<T> fmt::Debug for LockFreeStack<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LockFreeStack")
            .field("len", &self.len())
            .finish()
    }
}
impl<T> Drop for LockFreeStack<T> {
    /// Pops every remaining element so each boxed node (and its payload)
    /// is freed exactly once.
    fn drop(&mut self) {
        loop {
            match self.pop() {
                Some(_discarded) => {}
                None => break,
            }
        }
    }
}
/// Node for [`LockFreeQueue`].
///
/// `value` is `None` for the sentinel node and for nodes whose payload has
/// already been taken by `dequeue`; it is wrapped in `ManuallyDrop` so that
/// `dequeue` and `Drop` control exactly when the payload is dropped.
struct QueueNode<T> {
    value: std::mem::ManuallyDrop<Option<T>>,
    next: AtomicPtr<QueueNode<T>>,
}
impl<T> QueueNode<T> {
    /// Heap-allocates a node with no successor and leaks it as a raw
    /// pointer; the queue reclaims it later with `Box::from_raw`.
    fn new(value: Option<T>) -> *mut Self {
        let node = Self {
            value: std::mem::ManuallyDrop::new(value),
            next: AtomicPtr::new(std::ptr::null_mut()),
        };
        Box::into_raw(Box::new(node))
    }
}
/// Michael–Scott style lock-free FIFO queue.
///
/// `head` always points at a sentinel node; `tail` points at (or lags just
/// behind) the last node. Dequeued sentinels are parked in `retired` rather
/// than freed, so concurrent threads may safely dereference stale pointers;
/// the cost is that memory grows with the number of dequeues until the
/// queue itself is dropped, when `Drop` reclaims everything.
/// `len` is an advisory count maintained with relaxed atomics.
pub struct LockFreeQueue<T> {
    head: AtomicPtr<QueueNode<T>>,
    tail: AtomicPtr<QueueNode<T>>,
    len: AtomicUsize,
    retired: std::sync::Mutex<Vec<*mut QueueNode<T>>>,
}
// SAFETY: the queue owns its `T` values and hands them out by value; raw
// node pointers are only manipulated through atomics (or under the
// `retired` mutex), so moving the queue across threads is sound for
// `T: Send`.
unsafe impl<T: Send> Send for LockFreeQueue<T> {}
// SAFETY: all shared mutation is CAS-based or mutex-guarded and no `&T`
// into a node is ever exposed, so `T: Send` suffices (as with `Mutex<T>`).
unsafe impl<T: Send> Sync for LockFreeQueue<T> {}
impl<T> LockFreeQueue<T> {
    /// Creates an empty queue (Michael–Scott construction): `head` and
    /// `tail` both start out pointing at a value-less sentinel node.
    pub fn new() -> Self {
        let sentinel = QueueNode::new(None);
        Self {
            head: AtomicPtr::new(sentinel),
            tail: AtomicPtr::new(sentinel),
            len: AtomicUsize::new(0),
            retired: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Appends `value` at the tail; lock-free, safe for many producers.
    pub fn enqueue(&self, value: T) {
        let new_node = QueueNode::new(Some(value));
        loop {
            let tail = self.tail.load(Ordering::Acquire);
            // SAFETY: nodes are never freed while the queue is alive
            // (dequeued sentinels are parked in `retired`), so even a stale
            // `tail` pointer remains dereferenceable.
            let tail_next = unsafe { (*tail).next.load(Ordering::Acquire) };
            if tail_next.is_null() {
                // `tail` appears to be the last node: try to link after it.
                if unsafe {
                    (*tail)
                        .next
                        .compare_exchange_weak(
                            std::ptr::null_mut(),
                            new_node,
                            Ordering::Release,
                            Ordering::Relaxed,
                        )
                        .is_ok()
                } {
                    // Swing `tail` forward. Failure is benign: another
                    // thread already helped advance it.
                    let _ = self.tail.compare_exchange(
                        tail,
                        new_node,
                        Ordering::Release,
                        Ordering::Relaxed,
                    );
                    self.len.fetch_add(1, Ordering::Relaxed);
                    return;
                }
            } else {
                // `tail` is lagging behind an in-flight enqueue; help
                // advance it, then retry.
                let _ = self.tail.compare_exchange(
                    tail,
                    tail_next,
                    Ordering::Release,
                    Ordering::Relaxed,
                );
            }
        }
    }

    /// Removes and returns the oldest value, or `None` when empty.
    pub fn dequeue(&self) -> Option<T> {
        loop {
            let head = self.head.load(Ordering::Acquire);
            let tail = self.tail.load(Ordering::Acquire);
            // SAFETY: `head` is never freed while the queue is alive.
            let head_next = unsafe { (*head).next.load(Ordering::Acquire) };
            // Consistency check: if `head` moved while we were reading,
            // the snapshot above is unusable — retry.
            if head != self.head.load(Ordering::Acquire) {
                continue;
            }
            if head == tail {
                if head_next.is_null() {
                    // Only the sentinel remains: the queue is empty.
                    return None;
                }
                // `tail` lags an in-progress enqueue; help it along.
                let _ = self.tail.compare_exchange(
                    tail,
                    head_next,
                    Ordering::Release,
                    Ordering::Relaxed,
                );
            } else if !head_next.is_null() {
                if self
                    .head
                    .compare_exchange_weak(head, head_next, Ordering::AcqRel, Ordering::Relaxed)
                    .is_ok()
                {
                    // We won the CAS, so we exclusively own this transition:
                    // move the payload out of the new sentinel (`head_next`)
                    // without running its destructor...
                    let value = unsafe {
                        std::ptr::read(
                            &(*head_next).value as *const std::mem::ManuallyDrop<Option<T>>,
                        )
                    };
                    let value = std::mem::ManuallyDrop::into_inner(value);
                    // ...and overwrite the moved-from slot with `None` so
                    // the node can later be dropped as a plain sentinel.
                    unsafe {
                        std::ptr::write(
                            &mut (*head_next).value as *mut std::mem::ManuallyDrop<Option<T>>,
                            std::mem::ManuallyDrop::new(None),
                        );
                    }
                    // Retire the old sentinel rather than freeing it: other
                    // threads may still hold pointers to it. Recover the
                    // list from a poisoned mutex too — the previous
                    // `if let Ok(..)` silently leaked the node whenever a
                    // thread had panicked while holding the lock.
                    self.retired
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner)
                        .push(head);
                    self.len.fetch_sub(1, Ordering::Relaxed);
                    return value;
                }
            }
        }
    }

    /// True when the queue holds only the sentinel (point-in-time
    /// snapshot; may be stale by the time the caller acts on it).
    pub fn is_empty(&self) -> bool {
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);
        if head != tail {
            return false;
        }
        // SAFETY: `head` is never freed while the queue is alive.
        let head_next = unsafe { (*head).next.load(Ordering::Acquire) };
        head_next.is_null()
    }

    /// Advisory element count; may be stale under concurrent use.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }
}
impl<T> Default for LockFreeQueue<T> {
fn default() -> Self {
Self::new()
}
}
/// Formats as `LockFreeQueue { len: <n> }`. Elements are deliberately not
/// traversed — walking the nodes would race with concurrent operations.
///
/// The previous `T: fmt::Debug` bound was needless (only `len` is printed),
/// so it has been dropped; this is strictly backward-compatible.
impl<T> fmt::Debug for LockFreeQueue<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LockFreeQueue")
            .field("len", &self.len())
            .finish()
    }
}
impl<T> Drop for LockFreeQueue<T> {
    /// Reclaims every node: first the live chain (sentinel plus any
    /// un-dequeued nodes), then the retired sentinels parked by `dequeue`.
    fn drop(&mut self) {
        // Phase 1: walk and free the live chain. `&mut self` guarantees
        // exclusive access, so plain Relaxed loads are sufficient here.
        let mut current = *self.head.get_mut();
        while !current.is_null() {
            unsafe {
                let next = (*current).next.load(Ordering::Relaxed);
                // Drop the payload (a no-op for `None` sentinels), then
                // free the allocation made by `Box::into_raw`.
                std::mem::ManuallyDrop::drop(&mut (*current).value);
                let _ = Box::from_raw(current);
                current = next;
            }
        }
        // Phase 2: free retired sentinels. Recover the list even if the
        // mutex was poisoned by a panicking thread — the previous
        // `if let Ok(..)` leaked every retired node in that case.
        let retired = match self.retired.get_mut() {
            Ok(list) => list,
            Err(poisoned) => poisoned.into_inner(),
        };
        for &node in retired.iter() {
            if !node.is_null() {
                unsafe {
                    // Retired nodes hold `None` (payload was moved out by
                    // `dequeue`), so this only releases the allocation.
                    std::mem::ManuallyDrop::drop(&mut (*node).value);
                    let _ = Box::from_raw(node);
                }
            }
        }
    }
}
/// A thread-safe counter built on a single `AtomicUsize`.
///
/// `decrement` saturates at zero instead of wrapping.
#[derive(Debug)]
pub struct LockFreeCounter {
    value: AtomicUsize,
}

impl LockFreeCounter {
    /// Creates a counter starting at `initial`.
    pub fn new(initial: usize) -> Self {
        Self {
            value: AtomicUsize::new(initial),
        }
    }

    /// Adds one; returns the value observed *before* the increment.
    pub fn increment(&self) -> usize {
        self.value.fetch_add(1, Ordering::AcqRel)
    }

    /// Subtracts one unless the counter is already zero (saturating).
    /// Returns the value observed before the decrement, i.e. 0 only when
    /// the counter was already 0.
    pub fn decrement(&self) -> usize {
        // `checked_sub` makes the closure refuse to go below zero, so
        // `fetch_update` bails out (Err) exactly when the counter is 0.
        match self
            .value
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
        {
            Ok(previous) => previous,
            Err(_was_zero) => 0,
        }
    }

    /// Current value (point-in-time snapshot).
    pub fn get(&self) -> usize {
        self.value.load(Ordering::Acquire)
    }

    /// Adds `n`; returns the value observed *before* the addition.
    pub fn add(&self, n: usize) -> usize {
        self.value.fetch_add(n, Ordering::AcqRel)
    }

    /// Atomically replaces `expected` with `new_val`.
    /// Returns `Ok(previous)` on success, `Err(actual)` on mismatch.
    pub fn compare_and_swap(&self, expected: usize, new_val: usize) -> Result<usize, usize> {
        self.value
            .compare_exchange(expected, new_val, Ordering::AcqRel, Ordering::Acquire)
    }

    /// Sets the counter to zero; returns the value it held beforehand.
    pub fn reset(&self) -> usize {
        self.value.swap(0, Ordering::AcqRel)
    }
}

impl Default for LockFreeCounter {
    /// A counter starting at zero.
    fn default() -> Self {
        LockFreeCounter::new(0)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests and multi-threaded smoke tests for the lock-free stack,
    //! queue, and counter. The concurrent tests check conservation of
    //! items (nothing lost or duplicated), not any particular interleaving.
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // LIFO: last value pushed is the first popped.
    #[test]
    fn test_stack_push_pop_basic() {
        let stack = LockFreeStack::new();
        stack.push(1);
        stack.push(2);
        stack.push(3);
        assert_eq!(stack.pop(), Some(3));
        assert_eq!(stack.pop(), Some(2));
        assert_eq!(stack.pop(), Some(1));
        assert_eq!(stack.pop(), None);
    }

    // A fresh stack reports empty and pops None.
    #[test]
    fn test_stack_empty() {
        let stack = LockFreeStack::<i32>::new();
        assert!(stack.is_empty());
        assert_eq!(stack.len(), 0);
        assert_eq!(stack.pop(), None);
    }

    // len() tracks single-threaded pushes and pops.
    #[test]
    fn test_stack_len() {
        let stack = LockFreeStack::new();
        assert_eq!(stack.len(), 0);
        stack.push(10);
        assert_eq!(stack.len(), 1);
        stack.push(20);
        assert_eq!(stack.len(), 2);
        stack.pop();
        assert_eq!(stack.len(), 1);
    }

    // Many producers pushing concurrently: no items lost.
    #[test]
    fn test_stack_concurrent_push() {
        let stack = Arc::new(LockFreeStack::new());
        let n_threads = 8;
        let n_items = 1000;
        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    for i in 0..n_items {
                        stack.push(t * n_items + i);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        assert_eq!(stack.len(), n_threads * n_items);
        let mut count = 0;
        while stack.pop().is_some() {
            count += 1;
        }
        assert_eq!(count, n_threads * n_items);
    }

    // Producers and consumers racing: every pushed item is popped once.
    #[test]
    fn test_stack_concurrent_push_pop() {
        let stack = Arc::new(LockFreeStack::new());
        let n_threads = 4;
        let n_items = 500;
        let producers: Vec<_> = (0..n_threads)
            .map(|_| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    for i in 0..n_items {
                        stack.push(i);
                    }
                })
            })
            .collect();
        let consumers: Vec<_> = (0..n_threads)
            .map(|_| {
                let stack = Arc::clone(&stack);
                thread::spawn(move || {
                    let mut count = 0usize;
                    // Each consumer claims exactly n_items, spinning until
                    // a producer makes one available.
                    for _ in 0..n_items {
                        loop {
                            if stack.pop().is_some() {
                                count += 1;
                                break;
                            }
                            thread::yield_now();
                        }
                    }
                    count
                })
            })
            .collect();
        for h in producers {
            h.join().expect("producer panicked");
        }
        let total_consumed: usize = consumers
            .into_iter()
            .map(|h| h.join().expect("consumer panicked"))
            .sum();
        assert_eq!(total_consumed, n_threads * n_items);
    }

    // Dropping a non-empty stack must free every node (verified under
    // sanitizers/Miri; here it just must not crash).
    #[test]
    fn test_stack_drop_frees_memory() {
        let stack = LockFreeStack::new();
        for i in 0..100 {
            stack.push(format!("item_{i}"));
        }
        drop(stack);
    }

    // Default yields an empty stack.
    #[test]
    fn test_stack_default() {
        let stack: LockFreeStack<i32> = Default::default();
        assert!(stack.is_empty());
    }

    // FIFO: first value enqueued is the first dequeued.
    #[test]
    fn test_queue_enqueue_dequeue_basic() {
        let queue = LockFreeQueue::new();
        queue.enqueue(1);
        queue.enqueue(2);
        queue.enqueue(3);
        assert_eq!(queue.dequeue(), Some(1));
        assert_eq!(queue.dequeue(), Some(2));
        assert_eq!(queue.dequeue(), Some(3));
        assert_eq!(queue.dequeue(), None);
    }

    // A fresh queue reports empty and dequeues None.
    #[test]
    fn test_queue_empty() {
        let queue = LockFreeQueue::<i32>::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        assert_eq!(queue.dequeue(), None);
    }

    // len() tracks single-threaded enqueues and dequeues.
    #[test]
    fn test_queue_len() {
        let queue = LockFreeQueue::new();
        assert_eq!(queue.len(), 0);
        queue.enqueue(10);
        assert_eq!(queue.len(), 1);
        queue.enqueue(20);
        assert_eq!(queue.len(), 2);
        queue.dequeue();
        assert_eq!(queue.len(), 1);
    }

    // Ordering is preserved across a longer sequence.
    #[test]
    fn test_queue_fifo_order() {
        let queue = LockFreeQueue::new();
        for i in 0..20 {
            queue.enqueue(i);
        }
        for i in 0..20 {
            assert_eq!(queue.dequeue(), Some(i));
        }
    }

    // Many producers enqueueing concurrently: the multiset of dequeued
    // items equals exactly what was enqueued.
    #[test]
    fn test_queue_concurrent_enqueue() {
        let queue = Arc::new(LockFreeQueue::new());
        let n_threads = 8;
        let n_items = 1000;
        let handles: Vec<_> = (0..n_threads)
            .map(|t| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || {
                    for i in 0..n_items {
                        queue.enqueue(t * n_items + i);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        let mut items = Vec::new();
        while let Some(v) = queue.dequeue() {
            items.push(v);
        }
        assert_eq!(items.len(), n_threads * n_items);
        items.sort_unstable();
        let mut expected: Vec<usize> = Vec::new();
        for t in 0..n_threads {
            for i in 0..n_items {
                expected.push(t * n_items + i);
            }
        }
        expected.sort_unstable();
        assert_eq!(items, expected);
    }

    // Producers and consumers racing on the queue; a shared countdown
    // tells consumers when every item has been claimed.
    #[test]
    fn test_queue_concurrent_enqueue_dequeue() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let queue = Arc::new(LockFreeQueue::new());
        let n_threads = 4;
        let n_items = 500;
        let total = n_threads * n_items;
        let remaining = Arc::new(AtomicUsize::new(total));
        let producers: Vec<_> = (0..n_threads)
            .map(|_| {
                let queue = Arc::clone(&queue);
                thread::spawn(move || {
                    for i in 0..n_items {
                        queue.enqueue(i);
                    }
                })
            })
            .collect();
        let consumers: Vec<_> = (0..n_threads)
            .map(|_| {
                let queue = Arc::clone(&queue);
                let remaining = Arc::clone(&remaining);
                thread::spawn(move || {
                    let mut count = 0usize;
                    loop {
                        let rem = remaining.load(Ordering::Acquire);
                        if rem == 0 {
                            break;
                        }
                        // Only a successful dequeue decrements `remaining`,
                        // so the countdown matches claimed items exactly.
                        if let Some(_) = queue.dequeue() {
                            remaining.fetch_sub(1, Ordering::AcqRel);
                            count += 1;
                        } else {
                            thread::yield_now();
                        }
                    }
                    count
                })
            })
            .collect();
        for h in producers {
            h.join().expect("producer panicked");
        }
        let total_consumed: usize = consumers
            .into_iter()
            .map(|h| h.join().expect("consumer panicked"))
            .sum();
        assert_eq!(total_consumed, total);
    }

    // Dropping a non-empty queue must free the live chain and every
    // retired sentinel (verified under sanitizers/Miri; here it just must
    // not crash).
    #[test]
    fn test_queue_drop_frees_memory() {
        let queue = LockFreeQueue::new();
        for i in 0..100 {
            queue.enqueue(format!("item_{i}"));
        }
        drop(queue);
    }

    // Default yields an empty queue.
    #[test]
    fn test_queue_default() {
        let queue: LockFreeQueue<i32> = Default::default();
        assert!(queue.is_empty());
    }

    // increment/decrement return the pre-operation value.
    #[test]
    fn test_counter_basic() {
        let counter = LockFreeCounter::new(0);
        assert_eq!(counter.get(), 0);
        assert_eq!(counter.increment(), 0);
        assert_eq!(counter.get(), 1);
        assert_eq!(counter.increment(), 1);
        assert_eq!(counter.get(), 2);
        assert_eq!(counter.decrement(), 2);
        assert_eq!(counter.get(), 1);
    }

    // No lost updates with many concurrent incrementers.
    #[test]
    fn test_counter_concurrent() {
        let counter = Arc::new(LockFreeCounter::new(0));
        let n_threads = 8;
        let n_increments = 10_000;
        let handles: Vec<_> = (0..n_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    for _ in 0..n_increments {
                        counter.increment();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        assert_eq!(counter.get(), n_threads * n_increments);
    }

    // decrement saturates at zero instead of wrapping.
    #[test]
    fn test_counter_decrement_saturates() {
        let counter = LockFreeCounter::new(0);
        assert_eq!(counter.decrement(), 0);
        assert_eq!(counter.get(), 0);
    }

    // compare_and_swap: Ok(previous) on match, Err(actual) on mismatch.
    #[test]
    fn test_counter_compare_and_swap() {
        let counter = LockFreeCounter::new(10);
        assert_eq!(counter.compare_and_swap(10, 20), Ok(10));
        assert_eq!(counter.get(), 20);
        assert_eq!(counter.compare_and_swap(10, 30), Err(20));
        assert_eq!(counter.get(), 20);
    }

    // reset returns the old value and zeroes the counter.
    #[test]
    fn test_counter_reset() {
        let counter = LockFreeCounter::new(0);
        counter.add(100);
        assert_eq!(counter.reset(), 100);
        assert_eq!(counter.get(), 0);
    }

    // add returns the pre-addition value.
    #[test]
    fn test_counter_add() {
        let counter = LockFreeCounter::new(5);
        assert_eq!(counter.add(10), 5);
        assert_eq!(counter.get(), 15);
    }
}