use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Pads its contents to a 64-byte boundary so the producer and consumer
/// cursors land on separate cache lines, avoiding false sharing.
/// (64 bytes is the common cache-line size on x86-64/AArch64.)
#[repr(align(64))]
struct Padded<T>(T);

/// One ring-buffer cell. `sequence` implements the Vyukov MPMC handshake:
/// it identifies which ticket (producer or consumer turn) may touch `value`
/// next, and its Acquire/Release accesses publish the cell's contents.
struct Slot<T> {
    sequence: AtomicUsize,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Slot<T> {
    /// Creates an empty slot whose sequence equals its initial ring index,
    /// marking it as free for the producer holding that ticket.
    fn new(seq: usize) -> Self {
        Slot {
            sequence: AtomicUsize::new(seq),
            value: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }
}

// SAFETY: the sequence handshake in push/pop grants exactly one thread
// access to `value` at a time (claim via CAS, publish via Release store),
// so the UnsafeCell is never accessed concurrently. Moving the contained
// payload across threads requires `T: Send`.
unsafe impl<T: Send> Send for Slot<T> {}
unsafe impl<T: Send> Sync for Slot<T> {}

/// A bounded, lock-free, multi-producer multi-consumer FIFO queue
/// (Dmitry Vyukov's bounded MPMC queue).
///
/// Capacity is rounded up to the next power of two (minimum 2) so ring
/// indexing can use a bit mask. `head`/`tail` are free-running tickets
/// that wrap around; all arithmetic on them must therefore be wrapping.
pub struct LockFreeQueue<T> {
    buffer: Vec<Slot<T>>,
    capacity: usize,
    mask: usize,
    head: Padded<AtomicUsize>, // next ticket to pop (consumer cursor)
    tail: Padded<AtomicUsize>, // next ticket to push (producer cursor)
}

// SAFETY: all shared state is either immutable after construction or
// accessed through atomics plus the per-slot handshake above, so the
// queue may be shared across threads whenever `T: Send`.
unsafe impl<T: Send> Send for LockFreeQueue<T> {}
unsafe impl<T: Send> Sync for LockFreeQueue<T> {}

impl<T> LockFreeQueue<T> {
    /// Creates a queue able to hold `capacity` elements, rounded up to
    /// the next power of two, with a minimum of 2.
    pub fn new(capacity: usize) -> Self {
        let cap = capacity.max(2).next_power_of_two();
        // Slot i starts with sequence == i: "free, awaiting the producer
        // whose ticket maps to this cell".
        let buffer: Vec<Slot<T>> = (0..cap).map(Slot::new).collect();
        LockFreeQueue {
            buffer,
            capacity: cap,
            mask: cap - 1,
            head: Padded(AtomicUsize::new(0)),
            tail: Padded(AtomicUsize::new(0)),
        }
    }

    /// Attempts to enqueue `val`. Returns `false` if the queue is full.
    pub fn push(&self, val: T) -> bool {
        let mut pos = self.tail.0.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[pos & self.mask];
            let seq = slot.sequence.load(Ordering::Acquire);
            // Wrapping signed distance between the slot's state and our
            // ticket. A plain `seq as isize - pos as isize` would overflow
            // (panicking in debug builds) once the counters wrap past the
            // sign boundary; wrapping_sub then reinterpreting as signed is
            // the correct modular comparison.
            let diff = seq.wrapping_sub(pos) as isize;
            match diff.cmp(&0) {
                std::cmp::Ordering::Equal => {
                    // Slot is free for this ticket; try to claim it.
                    match self.tail.0.compare_exchange_weak(
                        pos,
                        pos.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            // SAFETY: winning the CAS grants exclusive
                            // access to this slot until the Release store
                            // below publishes it to consumers.
                            unsafe {
                                (*slot.value.get()).write(val);
                            }
                            // Hand the slot to the consumer with ticket
                            // `pos` (its expected sequence is pos + 1).
                            slot.sequence.store(pos.wrapping_add(1), Ordering::Release);
                            return true;
                        }
                        // Lost the race; retry at the ticket we observed.
                        Err(updated) => pos = updated,
                    }
                }
                // Slot still holds an unconsumed element: queue is full.
                std::cmp::Ordering::Less => return false,
                // Another producer already claimed this ticket; catch up.
                std::cmp::Ordering::Greater => {
                    pos = self.tail.0.load(Ordering::Relaxed);
                }
            }
        }
    }

    /// Attempts to dequeue the oldest element. Returns `None` if empty.
    pub fn pop(&self) -> Option<T> {
        let mut pos = self.head.0.load(Ordering::Relaxed);
        loop {
            let slot = &self.buffer[pos & self.mask];
            let seq = slot.sequence.load(Ordering::Acquire);
            // A full slot for ticket `pos` has sequence == pos + 1; use
            // wrapping arithmetic for the same reason as in push().
            let diff = seq.wrapping_sub(pos.wrapping_add(1)) as isize;
            match diff.cmp(&0) {
                std::cmp::Ordering::Equal => {
                    match self.head.0.compare_exchange_weak(
                        pos,
                        pos.wrapping_add(1),
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            // SAFETY: winning the CAS grants exclusive
                            // access; the producer's Release store (seen
                            // via the Acquire load above) guarantees the
                            // value is initialized.
                            let val = unsafe { (*slot.value.get()).assume_init_read() };
                            // Recycle the slot for the producer one full
                            // lap ahead (sequence = pos + capacity).
                            slot.sequence
                                .store(pos.wrapping_add(self.capacity), Ordering::Release);
                            return Some(val);
                        }
                        // Lost the race; retry at the ticket we observed.
                        Err(updated) => pos = updated,
                    }
                }
                // Slot not yet filled for this ticket: queue is empty.
                std::cmp::Ordering::Less => return None,
                // Another consumer already took this ticket; catch up.
                std::cmp::Ordering::Greater => {
                    pos = self.head.0.load(Ordering::Relaxed);
                }
            }
        }
    }

    /// Best-effort element count. Under concurrency this is only a racy
    /// snapshot; it is exact when no other thread is active.
    pub fn len(&self) -> usize {
        let tail = self.tail.0.load(Ordering::Relaxed);
        let head = self.head.0.load(Ordering::Relaxed);
        // The tickets wrap, so compare them with wrapping arithmetic
        // (saturating_sub would report 0 after counter wrap-around).
        // A negative snapshot — head observed ahead of tail — clamps to
        // 0, and the result is capped at capacity.
        let diff = tail.wrapping_sub(head) as isize;
        if diff <= 0 {
            0
        } else {
            (diff as usize).min(self.capacity)
        }
    }

    /// Whether the queue currently appears empty (racy; see `len`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The actual (rounded-up) capacity of the queue.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}

impl<T> Drop for LockFreeQueue<T> {
    /// Drains all remaining elements so their destructors run; slots left
    /// in the buffer hold only `MaybeUninit` and need no cleanup.
    fn drop(&mut self) {
        while self.pop().is_some() {}
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Single-threaded FIFO push/pop round trip.
    #[test]
    fn test_basic_push_pop() {
        let q: LockFreeQueue<u32> = LockFreeQueue::new(4);
        assert!(q.is_empty());
        assert!(q.push(10));
        assert!(q.push(20));
        assert_eq!(q.len(), 2);
        assert_eq!(q.pop(), Some(10));
        assert_eq!(q.pop(), Some(20));
        assert_eq!(q.pop(), None);
    }

    /// A full queue rejects pushes until a slot is freed by a pop.
    #[test]
    fn test_capacity_limit() {
        let q: LockFreeQueue<i32> = LockFreeQueue::new(4);
        assert_eq!(q.capacity(), 4);
        assert!(q.push(1));
        assert!(q.push(2));
        assert!(q.push(3));
        assert!(q.push(4));
        assert!(!q.push(5));
        assert_eq!(q.pop(), Some(1));
        assert!(q.push(5));
    }

    /// Elements come out in exactly the order they went in.
    #[test]
    fn test_fifo_order() {
        let q: LockFreeQueue<usize> = LockFreeQueue::new(16);
        for i in 0..10 {
            assert!(q.push(i));
        }
        for i in 0..10 {
            assert_eq!(q.pop(), Some(i));
        }
        assert_eq!(q.pop(), None);
    }

    /// Several producers against one consumer on a small ring; every
    /// produced item must be consumed exactly once.
    #[test]
    fn test_concurrent_mpmc() {
        const PRODUCERS: usize = 4;
        const ITEMS_PER_PRODUCER: usize = 1_000;
        const CAPACITY: usize = 256;
        let q = Arc::new(LockFreeQueue::<usize>::new(CAPACITY));
        let total = PRODUCERS * ITEMS_PER_PRODUCER;
        let handles: Vec<_> = (0..PRODUCERS)
            .map(|_p| {
                let q2 = Arc::clone(&q);
                thread::spawn(move || {
                    let mut sent = 0usize;
                    while sent < ITEMS_PER_PRODUCER {
                        if q2.push(1) {
                            sent += 1;
                        } else {
                            // Ring full: let consumers make progress.
                            thread::yield_now();
                        }
                    }
                })
            })
            .collect();
        let mut received = 0usize;
        while received < total {
            if q.pop().is_some() {
                received += 1;
            } else {
                thread::yield_now();
            }
        }
        for h in handles {
            h.join().expect("producer thread panicked");
        }
        assert_eq!(received, total);
        assert!(q.is_empty());
    }

    /// Dropping a non-empty queue must run the destructor of every
    /// element still inside it.
    #[test]
    fn test_drop_runs_destructors() {
        let counter = Arc::new(AtomicUsize::new(0));
        struct Tracker(Arc<AtomicUsize>);
        impl Drop for Tracker {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::Relaxed);
            }
        }
        {
            let q: LockFreeQueue<Tracker> = LockFreeQueue::new(8);
            assert!(q.push(Tracker(Arc::clone(&counter))));
            assert!(q.push(Tracker(Arc::clone(&counter))));
        }
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    /// A requested capacity of 0 is rounded up to the minimum of 2.
    #[test]
    fn test_zero_capacity_rounds_up_to_two() {
        let q: LockFreeQueue<u8> = LockFreeQueue::new(0);
        assert_eq!(q.capacity(), 2);
        assert!(q.push(42));
        assert!(q.push(43));
        assert!(!q.push(44));
        assert_eq!(q.pop(), Some(42));
        assert_eq!(q.pop(), Some(43));
        assert_eq!(q.pop(), None);
    }
}