use std::alloc::{self, Layout};
use std::cell::UnsafeCell;
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
// Without the `lock_free` feature, portable-atomic is not imported, so fall
// back to the standard library's AtomicU64 to keep this file compiling.
#[cfg(not(feature = "lock_free"))]
use std::sync::atomic::AtomicU64;
#[cfg(feature = "lock_free")]
use portable_atomic::AtomicU64;

use crossbeam_epoch::{self as epoch, Atomic, Guard, Owned};

use crate::{BufferError, BufferResult, Config};
/// Fixed-size, power-of-two-capacity slot storage backing the queue.
///
/// Slots are raw and possibly uninitialized; this type does not track which
/// indices hold live values — callers (the queue) are responsible for that.
struct RingBuffer<T> {
// Raw allocation of `capacity` slots. UnsafeCell permits slot access
// through `&self` in `write`/`read`; the pointer itself is never replaced
// after construction (as far as this file shows).
data: UnsafeCell<NonNull<T>>,
// Total slot count; always a power of two (enforced in `new`).
capacity: usize,
// capacity - 1, so indices wrap with a cheap bitwise AND instead of `%`.
mask: usize,
}
impl<T> RingBuffer<T> {
    /// Allocates storage for `capacity` elements, rounded up to the next
    /// power of two so slot indices can be wrapped with a bitwise AND.
    ///
    /// # Errors
    /// Returns `BufferError::ResizeError` when `T` is zero-sized, the rounded
    /// capacity would overflow `usize`, the array layout is invalid, or the
    /// allocator returns null.
    fn new(capacity: usize) -> BufferResult<Self> {
        // `alloc::alloc` must never be called with a zero-size layout (UB),
        // which is what a zero-sized element type would produce.
        if std::mem::size_of::<T>() == 0 {
            return Err(BufferError::ResizeError(
                "Zero-sized element types are not supported".to_string(),
            ));
        }
        // checked_next_power_of_two: plain next_power_of_two panics in debug
        // builds (and wraps to 0 in release) for capacities above 2^63.
        let capacity = capacity
            .checked_next_power_of_two()
            .ok_or_else(|| BufferError::ResizeError("Capacity overflow".to_string()))?;
        let mask = capacity - 1;
        let layout = Layout::array::<T>(capacity)
            .map_err(|_| BufferError::ResizeError("Invalid capacity".to_string()))?;
        // SAFETY: `layout` has non-zero size (T is not a ZST and capacity >= 1)
        // and was validated by Layout::array above.
        let data = unsafe {
            let ptr = alloc::alloc(layout) as *mut T;
            NonNull::new(ptr).ok_or_else(|| {
                BufferError::ResizeError("Failed to allocate memory".to_string())
            })?
        };
        Ok(Self {
            data: UnsafeCell::new(data),
            capacity,
            mask,
        })
    }

    /// Writes `value` into the slot at `index & mask`, overwriting without
    /// dropping any previous occupant.
    ///
    /// # Safety
    /// The caller must guarantee exclusive access to that slot for the
    /// duration of the write, and that any previous live value in the slot
    /// has already been moved out (it is overwritten, not dropped).
    unsafe fn write(&self, index: usize, value: T) {
        let masked_index = index & self.mask;
        let data_ptr = (*self.data.get()).as_ptr();
        ptr::write(data_ptr.add(masked_index), value);
    }

    /// Moves the value out of the slot at `index & mask`.
    ///
    /// # Safety
    /// The caller must guarantee the slot holds an initialized value that no
    /// other thread is concurrently accessing, and must not read the same
    /// slot again until it has been rewritten.
    unsafe fn read(&self, index: usize) -> T {
        let masked_index = index & self.mask;
        let data_ptr = (*self.data.get()).as_ptr();
        ptr::read(data_ptr.add(masked_index))
    }
}
impl<T> Drop for RingBuffer<T> {
fn drop(&mut self) {
// Frees the raw allocation only. NOTE(review): any elements still
// stored in the buffer are NOT dropped here — this type has no record
// of which slots are live — so a `T` with a destructor (e.g. String)
// leaks unless the owner drains every element first. Confirm callers
// guarantee that.
unsafe {
// Must mirror the layout used in `new`; the unwrap cannot fail for a
// capacity that already produced a valid array layout at alloc time.
let layout = Layout::array::<T>(self.capacity).unwrap();
let data_ptr = (*self.data.get()).as_ptr();
alloc::dealloc(data_ptr as *mut u8, layout);
}
}
}
/// Lock-free multi-producer queue over an epoch-managed ring buffer that can
/// grow (up to `config.max_capacity`) while in use.
///
/// The name implies a single consumer; nothing in the struct itself enforces
/// that — see `try_dequeue` for the consumer-side protocol.
pub struct LockFreeMPSCQueue<T> {
// Current ring buffer; swapped wholesale on resize, old ones reclaimed
// via crossbeam-epoch deferred destruction.
buffer: Atomic<RingBuffer<T>>,
// Packed word: low 32 bits = producer position, high 32 bits = resize
// generation (see pack_head/unpack_head).
head: AtomicU64,
// Consumer position; full-width, unlike the 32-bit packed head position.
tail: AtomicUsize,
// Incremented once per completed resize.
generation: AtomicU32,
// Best-effort flag that rejects producers while a resize is in flight.
resizing: AtomicBool,
config: Config,
// Monotonic counters; enqueued counts successes only, dropped counts
// rejected items (resize-in-progress and at-max-capacity paths).
messages_enqueued: AtomicUsize,
messages_dequeued: AtomicUsize,
messages_dropped: AtomicUsize,
}
impl<T> LockFreeMPSCQueue<T> {
/// Builds a queue from a validated `Config`; the initial buffer is sized
/// from `config.initial_capacity` (rounded up to a power of two by
/// `RingBuffer::new`).
///
/// # Errors
/// Propagates configuration-validation and allocation failures.
pub fn new(config: Config) -> BufferResult<Self> {
config.validate()?;
let initial_buffer = RingBuffer::new(config.initial_capacity)?;
let buffer = Atomic::new(initial_buffer);
Ok(Self {
buffer,
head: AtomicU64::new(0),
tail: AtomicUsize::new(0),
generation: AtomicU32::new(0),
resizing: AtomicBool::new(false),
config,
messages_enqueued: AtomicUsize::new(0),
messages_dequeued: AtomicUsize::new(0),
messages_dropped: AtomicUsize::new(0),
})
}
// Debug-only sanity checks: capacity bounds, head/tail relationship,
// generation monotonicity, and a message-conservation identity.
#[cfg(debug_assertions)]
#[allow(dead_code)] fn debug_assert_invariants(&self) {
let guard = &epoch::pin();
if let Some(buffer) = unsafe { self.buffer.load(Ordering::Acquire, guard).as_ref() } {
debug_assert!(buffer.capacity.is_power_of_two(),
"Ring buffer capacity must be power of 2, got {}", buffer.capacity);
debug_assert!(buffer.capacity >= self.config.min_capacity,
"Capacity {} below minimum {}", buffer.capacity, self.config.min_capacity);
debug_assert!(buffer.capacity <= self.config.max_capacity,
"Capacity {} exceeds maximum {}", buffer.capacity, self.config.max_capacity);
let head_packed = self.head.load(Ordering::Relaxed);
let (head_pos, generation) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Relaxed);
let queue_size = head_pos.wrapping_sub(tail_pos);
debug_assert!(queue_size <= buffer.capacity,
"Queue size {} exceeds capacity {}", queue_size, buffer.capacity);
let current_gen = self.generation.load(Ordering::Relaxed);
debug_assert!(generation <= current_gen,
"Head generation {} exceeds current generation {}", generation, current_gen);
let stats = self.stats();
debug_assert!(stats.messages_dequeued <= stats.messages_enqueued,
"Dequeued {} exceeds enqueued {}", stats.messages_dequeued, stats.messages_enqueued);
debug_assert!(stats.current_capacity == buffer.capacity,
"Stats capacity {} doesn't match buffer capacity {}", stats.current_capacity, buffer.capacity);
// NOTE(review): messages_dropped counts items that were never added to
// messages_enqueued (they were rejected before enqueue), so this
// identity appears to fail as soon as anything is dropped — verify the
// intended conservation law.
let expected_total = stats.messages_dequeued + stats.current_size + stats.messages_dropped;
debug_assert!(expected_total == stats.messages_enqueued,
"Message conservation violated: {} + {} + {} != {}",
stats.messages_dequeued, stats.current_size, stats.messages_dropped, stats.messages_enqueued);
}
}
// Release-build stub so call sites compile away to nothing.
#[cfg(not(debug_assertions))]
#[inline(always)]
fn debug_assert_invariants(&self) {
}
// Splits the packed head word: low 32 bits = position, high 32 bits =
// resize generation.
fn unpack_head(head: u64) -> (usize, u32) {
let position = (head & 0xFFFF_FFFF) as usize;
let generation = (head >> 32) as u32;
(position, generation)
}
// Inverse of `unpack_head`. The position is truncated to 32 bits here,
// while `tail` remains a full-width usize — see the note in try_dequeue.
fn pack_head(position: usize, generation: u32) -> u64 {
((generation as u64) << 32) | (position as u64 & 0xFFFF_FFFF)
}
/// Attempts to append `item`; safe to call from multiple producer threads.
///
/// # Errors
/// - `ResizeError` while a resize is in flight (item counted as dropped).
/// - `MaxCapacityReached` when full at the configured hard limit (item
///   counted as dropped).
/// - `Full` when full but still growable (caller may retry; NOT counted
///   as dropped — TODO confirm that asymmetry is intended).
pub fn try_enqueue(&self, item: T) -> BufferResult<()> {
// Pin the epoch so the buffer we load cannot be reclaimed while we
// hold a reference into it.
let guard = &epoch::pin();
if self.resizing.load(Ordering::Acquire) {
self.messages_dropped.fetch_add(1, Ordering::Relaxed);
return Err(BufferError::ResizeError("Resize in progress".to_string()));
}
loop {
let buffer_ref = self.buffer.load(Ordering::Acquire, guard);
let buffer = unsafe { buffer_ref.as_ref() }
.ok_or_else(|| BufferError::ResizeError("Invalid buffer reference".to_string()))?;
let head_packed = self.head.load(Ordering::Acquire);
let (head_pos, head_gen) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Acquire);
// Occupancy from this (possibly slightly stale) snapshot.
let current_size = head_pos.wrapping_sub(tail_pos);
if current_size >= buffer.capacity {
if buffer.capacity >= self.config.max_capacity {
// Hard limit: the item can never fit; record the drop.
self.messages_dropped.fetch_add(1, Ordering::Relaxed);
return Err(BufferError::MaxCapacityReached(self.config.max_capacity));
} else {
// Temporarily full; a consumer-driven resize may make room.
return Err(BufferError::Full);
}
}
// Reserve a slot by advancing the packed position (generation kept).
let new_head_packed = Self::pack_head(head_pos + 1, head_gen);
match self.head.compare_exchange_weak(
head_packed,
new_head_packed,
Ordering::Release,
Ordering::Relaxed,
) {
Ok(_) => {
let buffer_index = head_pos & (buffer.capacity - 1);
// NOTE(review): the slot is written AFTER the advanced head is
// published, so a consumer that observes the new head can read
// this slot before the write below lands. Reserve-then-write
// schemes usually need per-slot sequence numbers to close this
// window — verify under loom/miri before production use.
unsafe {
buffer.write(buffer_index, item);
}
self.messages_enqueued.fetch_add(1, Ordering::Relaxed);
#[cfg(all(debug_assertions, not(test)))]
self.debug_assert_invariants();
return Ok(());
}
Err(_) => {
// Lost the CAS race to another producer; reload state and retry.
continue;
}
}
}
}
/// Attempts to remove the oldest item; returns `Ok(None)` when empty.
/// Dequeueing may trigger a buffer resize (see try_resize_if_needed).
pub fn try_dequeue(&self) -> BufferResult<Option<T>> {
let guard = &epoch::pin();
loop {
let buffer_ref = self.buffer.load(Ordering::Acquire, guard);
let buffer = unsafe { buffer_ref.as_ref() }
.ok_or_else(|| BufferError::ResizeError("Invalid buffer reference".to_string()))?;
let head_packed = self.head.load(Ordering::Acquire);
let (head_pos, _) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Acquire);
// NOTE(review): head_pos is only the low 32 bits of the packed head
// while tail_pos is a full usize; once more than u32::MAX items have
// ever been enqueued this comparison breaks down — confirm the
// intended operating range.
if tail_pos >= head_pos {
return Ok(None);
}
// CAS on tail even though the type name says single consumer —
// presumably defensive against multiple consumers; verify intent.
if self.tail.compare_exchange_weak(
tail_pos,
tail_pos + 1,
Ordering::Release,
Ordering::Relaxed,
).is_ok() {
let buffer_index = tail_pos & (buffer.capacity - 1);
let item = unsafe { buffer.read(buffer_index) };
self.messages_dequeued.fetch_add(1, Ordering::Relaxed);
// Consumer-driven growth check on every successful dequeue.
self.try_resize_if_needed(guard)?;
#[cfg(all(debug_assertions, not(test)))]
self.debug_assert_invariants();
return Ok(Some(item));
}
continue;
}
}
// Grows the buffer when it is at least half full and still below the
// configured maximum.
fn try_resize_if_needed(&self, guard: &Guard) -> BufferResult<()> {
let buffer_ref = self.buffer.load(Ordering::Acquire, guard);
let buffer = unsafe { buffer_ref.as_ref() }
.ok_or_else(|| BufferError::ResizeError("Invalid buffer reference".to_string()))?;
let head_packed = self.head.load(Ordering::Acquire);
let (head_pos, _) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Acquire);
let current_size = head_pos.wrapping_sub(tail_pos);
if current_size * 2 >= buffer.capacity && buffer.capacity < self.config.max_capacity {
self.resize(guard)?;
}
Ok(())
}
// Allocates a larger buffer, copies live items to the front, resets
// tail to 0 and head to (item_count, generation+1), then swaps buffers
// and defers destruction of the old one to the epoch collector.
//
// NOTE(review): the `resizing` flag only stops producers that have not
// yet passed the check at the top of try_enqueue; a producer already
// inside its CAS loop can still reserve/write into the OLD buffer while
// items are being copied, and the head/tail resets below are not atomic
// with respect to such producers. Treat this resize as best-effort until
// verified under a concurrency checker.
fn resize(&self, guard: &Guard) -> BufferResult<()> {
// Claim resize ownership; if another thread holds it, simply return.
if self.resizing.compare_exchange(
false,
true,
Ordering::AcqRel,
Ordering::Relaxed,
).is_err() {
return Ok(());
}
let current_buffer_ref = self.buffer.load(Ordering::Acquire, guard);
let current_buffer = unsafe { current_buffer_ref.as_ref() }
.ok_or_else(|| BufferError::ResizeError("Invalid buffer reference".to_string()))?;
// Grow by the configured factor, clamped to max_capacity.
let new_capacity = (current_buffer.capacity as f64 * self.config.growth_factor)
.ceil() as usize;
let new_capacity = new_capacity.min(self.config.max_capacity);
if new_capacity <= current_buffer.capacity {
// Nothing to do (already at max); release the flag before returning.
self.resizing.store(false, Ordering::Release);
return Ok(());
}
// NOTE(review): early `?` returns between here and the end leave
// `resizing` stuck at true, permanently rejecting producers — confirm
// whether a guard/reset-on-error is needed.
let new_buffer = RingBuffer::new(new_capacity)?;
let head_packed = self.head.load(Ordering::Acquire);
let (head_pos, _head_gen) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Acquire);
let item_count = head_pos.wrapping_sub(tail_pos);
// Compact live items to indices 0..item_count of the new buffer.
for (new_index, i) in (0..item_count).enumerate() {
let old_index = tail_pos.wrapping_add(i) & (current_buffer.capacity - 1);
let item = unsafe { current_buffer.read(old_index) };
unsafe { new_buffer.write(new_index, item); }
}
// New epoch of positions: tail restarts at 0, head at item_count with
// a bumped generation so stale packed heads are distinguishable.
let new_generation = self.generation.fetch_add(1, Ordering::AcqRel) + 1;
self.tail.store(0, Ordering::Release);
let new_head_packed = Self::pack_head(item_count, new_generation);
self.head.store(new_head_packed, Ordering::Release);
let new_buffer_owned = Owned::new(new_buffer);
let old_buffer = self.buffer.swap(new_buffer_owned, Ordering::AcqRel, guard);
// SAFETY: the old buffer is unlinked; epoch GC frees it only after all
// pinned threads that might still reference it have unpinned.
unsafe {
guard.defer_destroy(old_buffer);
}
self.resizing.store(false, Ordering::Release);
Ok(())
}
/// Approximate number of items currently queued (racy snapshot).
pub fn len(&self) -> usize {
let head_packed = self.head.load(Ordering::Acquire);
let (head_pos, _) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Acquire);
head_pos.wrapping_sub(tail_pos)
}
/// True when the snapshot in `len` observes no items.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Capacity of the current ring buffer (0 only if the buffer pointer is
/// null, which `new` never produces).
pub fn capacity(&self) -> usize {
let guard = &epoch::pin();
let buffer_ref = self.buffer.load(Ordering::Acquire, guard);
if let Some(buffer) = unsafe { buffer_ref.as_ref() } {
buffer.capacity
} else {
0
}
}
/// Snapshot of the counters. Each field is read with an independent
/// atomic load, so values may be mutually inconsistent under concurrency.
pub fn stats(&self) -> QueueStats {
let messages_enqueued = self.messages_enqueued.load(Ordering::Acquire);
let messages_dequeued = self.messages_dequeued.load(Ordering::Acquire);
let messages_dropped = self.messages_dropped.load(Ordering::Acquire);
let head_packed = self.head.load(Ordering::Acquire);
let (head_pos, _) = Self::unpack_head(head_packed);
let tail_pos = self.tail.load(Ordering::Acquire);
let current_size = head_pos.wrapping_sub(tail_pos);
QueueStats {
messages_enqueued,
messages_dequeued,
messages_dropped,
current_size,
current_capacity: self.capacity(),
}
}
}
/// Point-in-time view of the queue's counters (see `stats`); fields are
/// sampled independently and may be slightly inconsistent with each other.
#[derive(Debug, Clone)]
pub struct QueueStats {
// Items successfully enqueued.
pub messages_enqueued: usize,
// Items successfully dequeued.
pub messages_dequeued: usize,
// Items rejected during resize or at the hard capacity limit.
pub messages_dropped: usize,
// head - tail at sampling time.
pub current_size: usize,
// Capacity of the buffer at sampling time.
pub current_capacity: usize,
}
// SAFETY: the queue moves owned `T` values between threads, so `T: Send` is
// required; all shared internal state is reached through atomics and
// epoch-protected pointers. NOTE(review): soundness also depends on the
// enqueue/dequeue protocol itself being race-free — see the notes on
// try_enqueue/resize before relying on these impls in production.
unsafe impl<T: Send> Send for LockFreeMPSCQueue<T> {}
unsafe impl<T: Send> Sync for LockFreeMPSCQueue<T> {}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
use std::sync::Arc;
// Deliberately tiny capacities so resize and full/drop paths are easy
// to reach in a handful of operations.
fn test_config() -> Config {
Config::default()
.with_initial_capacity(4)
.with_min_capacity(2)
.with_max_capacity(64)
.with_growth_factor(2.0)
}
// Single-threaded smoke test: FIFO order, len/is_empty bookkeeping, and
// None from an empty queue.
#[test]
fn test_basic_enqueue_dequeue() {
let queue = LockFreeMPSCQueue::new(test_config()).unwrap();
queue.try_enqueue(42).unwrap();
queue.try_enqueue(24).unwrap();
assert_eq!(queue.len(), 2);
assert!(!queue.is_empty());
assert_eq!(queue.try_dequeue().unwrap(), Some(42));
assert_eq!(queue.try_dequeue().unwrap(), Some(24));
assert_eq!(queue.try_dequeue().unwrap(), None);
assert_eq!(queue.len(), 0);
assert!(queue.is_empty());
}
// Fills the queue up to its hard limit; the next enqueue must be
// rejected with either Full or MaxCapacityReached (resize timing
// determines which variant is observed).
#[test]
fn test_capacity_limits() {
let config = test_config()
.with_initial_capacity(2)
.with_min_capacity(1)
.with_max_capacity(4);
let queue = LockFreeMPSCQueue::new(config).unwrap();
queue.try_enqueue(1).unwrap(); queue.try_enqueue(2).unwrap();
// This dequeue may trigger a resize toward max_capacity.
queue.try_dequeue().unwrap();
queue.try_enqueue(3).unwrap(); queue.try_enqueue(4).unwrap(); queue.try_enqueue(5).unwrap();
assert!(matches!(
queue.try_enqueue(6),
Err(BufferError::Full) | Err(BufferError::MaxCapacityReached(_))
));
}
// Two producers + one consumer with bounded retries on both sides.
// Only a loose lower bound on delivered messages is asserted, because
// the queue is allowed to drop under pressure by design.
#[test]
fn test_mpsc_concurrent_access() {
let queue = Arc::new(LockFreeMPSCQueue::new(test_config()).unwrap());
let num_producers = 2;
let messages_per_producer = 10;
let mut handles = vec![];
for producer_id in 0..num_producers {
let queue_clone = Arc::clone(&queue);
let handle = thread::spawn(move || {
for i in 0..messages_per_producer {
// Encode producer id in the payload so sources are distinguishable.
let message = producer_id * 1000 + i;
let mut retry_count = 0;
// Bounded retry loop: give up on a message after 1000 failures
// rather than hanging the test.
while queue_clone.try_enqueue(message).is_err() {
thread::yield_now();
retry_count += 1;
if retry_count > 1000 {
break;
}
}
}
});
handles.push(handle);
}
let queue_clone = Arc::clone(&queue);
let consumer_handle = thread::spawn(move || {
let mut received = vec![];
let expected_total = num_producers * messages_per_producer;
// Consecutive-empty counter; bail out after 10000 misses so the
// test cannot deadlock if messages were dropped.
let mut no_message_count = 0;
while received.len() < expected_total {
if let Ok(Some(message)) = queue_clone.try_dequeue() {
received.push(message);
no_message_count = 0; } else {
thread::yield_now();
no_message_count += 1;
if no_message_count > 10000 {
break;
}
}
}
received
});
for handle in handles {
handle.join().unwrap();
}
let received = consumer_handle.join().unwrap();
let expected_total = num_producers * messages_per_producer;
assert!(received.len() >= expected_total / 2,
"Expected at least {} messages, got {}", expected_total / 2, received.len());
let stats = queue.stats();
println!("Final stats: {:?}", stats);
}
// Exercises consumer-driven growth: after filling to capacity, the
// half-full heuristic in try_resize_if_needed should grow the buffer
// during the dequeues, and all counters must balance at the end.
#[test]
fn test_resize_behavior() {
let config = test_config()
.with_initial_capacity(4)
.with_min_capacity(2)
.with_max_capacity(16)
.with_growth_factor(2.0);
let queue = LockFreeMPSCQueue::new(config).unwrap();
assert_eq!(queue.capacity(), 4);
for i in 0..4 {
queue.try_enqueue(i).unwrap();
}
for _ in 0..2 {
queue.try_dequeue().unwrap();
}
// Dequeueing while at least half full should have triggered a resize.
assert!(queue.capacity() > 4);
for i in 4..8 {
queue.try_enqueue(i).unwrap();
}
for _ in 2..8 {
queue.try_dequeue().unwrap();
}
let stats = queue.stats();
assert_eq!(stats.messages_enqueued, 8);
assert_eq!(stats.messages_dequeued, 8);
}
}