use core::ptr;
use std::cell::Cell;
use std::error::Error;
use std::mem::MaybeUninit;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::sync::atomic::{AtomicPtr, AtomicU8, AtomicUsize, Ordering};
use std::sync::Arc;
use crate::State::{Empty, Handled, Set};
/// Number of node slots in each linked buffer segment of the queue.
const BUFFER_SIZE: usize = 1620;
/// Lifecycle of a single queue slot.
///
/// The discriminants are load-bearing: `Node::is_set` stores them in an
/// `AtomicU8`, and the `From` impls below round-trip through these exact
/// values — so they are spelled out explicitly instead of being implied.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
    /// Slot reserved by a producer but no data published yet.
    Empty = 0,
    /// Data written and published; ready to be consumed.
    Set = 1,
    /// Data already consumed; the slot can be skipped or folded away.
    Handled = 2,
}
impl From<State> for u8 {
    fn from(state: State) -> Self {
        state as u8
    }
}
impl From<u8> for State {
    fn from(state: u8) -> Self {
        match state {
            0 => State::Empty,
            1 => State::Set,
            2 => State::Handled,
            // `is_set` is only ever written with values produced by
            // `u8::from(State)`, so any other value is a logic error.
            _ => unreachable!(),
        }
    }
}
impl PartialEq<State> for u8 {
    fn eq(&self, other: &State) -> bool {
        *self == *other as u8
    }
}
struct Node<T> {
data: MaybeUninit<T>,
is_set: AtomicU8,
}
impl<T> Node<T> {
unsafe fn set_data(&mut self, data: T) {
self.data.as_mut_ptr().write(data);
self.is_set.store(State::Set.into(), Release);
}
fn state(&self) -> State {
self.is_set.load(Acquire).into()
}
fn set_state(&mut self, state: State) {
self.is_set.store(state.into(), Release)
}
}
impl<T> Default for Node<T> {
fn default() -> Self {
Node {
data: MaybeUninit::uninit(),
is_set: AtomicU8::new(State::Empty.into()),
}
}
}
/// One fixed-size segment in the queue's doubly linked list of buffers.
struct BufferList<T> {
    // The slots of this segment.
    nodes: Vec<Node<T>>,
    // Previous segment (null for the first one); only traversed by producers.
    prev: *mut BufferList<T>,
    // Next segment, linked in lazily via CAS.
    next: AtomicPtr<BufferList<T>>,
    // Consumer's read index into `nodes`.
    head: usize,
    // 1-based position of this segment in the chain; used to map a global
    // slot index onto a segment-local one.
    pos: usize,
}
impl<T> BufferList<T> {
    /// First segment of a queue: same as `with_prev` with no predecessor.
    fn new(size: usize, position_in_queue: usize) -> Self {
        Self::with_prev(size, position_in_queue, ptr::null_mut())
    }
    /// Segment of `size` empty nodes at chain position `pos`, linked after
    /// `prev` (which may be null).
    fn with_prev(size: usize, pos: usize, prev: *mut BufferList<T>) -> Self {
        let nodes: Vec<Node<T>> = (0..size).map(|_| Node::default()).collect();
        Self {
            nodes,
            prev,
            next: AtomicPtr::new(ptr::null_mut()),
            head: 0,
            pos,
        }
    }
}
// SAFETY: cross-thread access to a BufferList goes through its atomics
// (`next`, each node's `is_set`); the payload values themselves still move
// between threads, so this is only sound when `T: Send`. The previous
// unconditional impls allowed sending non-`Send` payloads (e.g. `Rc`).
unsafe impl<T: Send> Send for BufferList<T> {}
unsafe impl<T: Send> Sync for BufferList<T> {}
/// An unbounded multi-producer single-consumer queue built from a linked
/// list of fixed-size buffers.
#[derive(Debug)]
pub struct MpscQueue<T> {
    // Oldest live buffer; read/written only by the consumer, hence `Cell`.
    // NOTE(review): nothing here enforces a single consumer if `&MpscQueue`
    // is shared directly — confirm callers route all dequeues through one
    // thread (the `Receiver` wrapper does not implement `Clone`).
    head_of_queue: Cell<*mut BufferList<T>>,
    // Last buffer producers append into; advanced cooperatively via CAS.
    tail_of_queue: AtomicPtr<BufferList<T>>,
    // Capacity of each buffer segment (always BUFFER_SIZE in `new`).
    buffer_size: usize,
    // Global enqueue counter; `fetch_add` hands each producer a unique slot.
    tail: AtomicUsize,
}
// SAFETY: all cross-thread queue state is accessed through atomics, but the
// payload values cross threads too, so `T: Send` is required. The previous
// unconditional impls were unsound (they let e.g. `Rc` values cross threads).
unsafe impl<T: Send> Send for MpscQueue<T> {}
unsafe impl<T: Send> Sync for MpscQueue<T> {}
impl<T> MpscQueue<T> {
pub fn new() -> Self {
let head_of_queue = BufferList::new(BUFFER_SIZE, 1);
let head = Box::new(head_of_queue);
let head = Box::into_raw(head);
let tail = AtomicPtr::new(head);
MpscQueue {
head_of_queue: Cell::new(head),
tail_of_queue: tail,
buffer_size: BUFFER_SIZE,
tail: AtomicUsize::new(0),
}
}
pub fn enqueue(&self, data: T) -> Result<(), T> {
let location = self.tail.fetch_add(1, Ordering::SeqCst);
let mut temp_tail;
let mut is_last_buffer = true;
loop {
temp_tail = unsafe { &mut *self.tail_of_queue.load(Ordering::Acquire) };
let mut prev_size = self.size_without_buffer(temp_tail);
while location < prev_size {
is_last_buffer = false;
temp_tail = unsafe { &mut *temp_tail.prev };
prev_size -= self.buffer_size;
}
let global_size = self.buffer_size + prev_size;
if prev_size <= location && location < global_size {
return self.insert(data, location - prev_size, temp_tail, is_last_buffer);
}
if location >= global_size {
let next = temp_tail.next.load(Ordering::Acquire);
if next.is_null() {
let new_buffer_ptr = self.allocate_buffer(temp_tail);
if temp_tail
.next
.compare_exchange(
ptr::null_mut(),
new_buffer_ptr,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
temp_tail.next.store(new_buffer_ptr, Ordering::Release)
} else {
MpscQueue::drop_buffer(new_buffer_ptr)
}
} else {
self.tail_of_queue.compare_and_swap(
temp_tail as *mut _,
next,
Ordering::SeqCst,
);
}
}
}
}
pub fn dequeue(&self) -> Option<T> {
let mut temp_tail;
loop {
temp_tail = unsafe { &mut *self.tail_of_queue.load(Ordering::SeqCst) };
let prev_size = self.size_without_buffer(temp_tail);
let head = &mut unsafe { &mut *self.head_of_queue.get() }.head;
let head_is_tail = self.head_of_queue.get() == temp_tail;
let head_is_empty = *head == self.tail.load(Ordering::Acquire) - prev_size;
if head_is_tail && head_is_empty {
break None;
}
if *head < self.buffer_size {
let node = &mut unsafe { &mut *self.head_of_queue.get() }.nodes[*head];
if node.state() == Handled {
*head += 1;
continue;
}
if node.state() == Empty {
if let Some(data) = self.search(head, node) {
return Some(data);
}
}
if node.state() == Set {
unsafe { (*self.head_of_queue.get()).head += 1 };
let data = MpscQueue::read_data(node);
return Some(data);
}
}
if *head >= self.buffer_size {
if self.head_of_queue.get() == self.tail_of_queue.load(Acquire) {
return None;
}
let next = unsafe { &*self.head_of_queue.get() }.next.load(Acquire);
if next.is_null() {
return None;
}
MpscQueue::drop_buffer(self.head_of_queue.get());
self.head_of_queue.set(next);
}
}
}
fn search(&self, head: &mut usize, node: &mut Node<T>) -> Option<T> {
let mut temp_buffer = self.head_of_queue.get();
let mut temp_head = *head;
let mut search_next_buffer = false;
let mut all_handled = true;
while node.state() == Empty {
if temp_head < self.buffer_size {
let mut temp_node = &mut unsafe { &mut *self.head_of_queue.get() }.nodes[temp_head];
temp_head += 1;
if temp_node.state() == Set && node.state() == Empty {
self.scan(node, &mut temp_buffer, &mut temp_head, &mut temp_node);
if node.state() == Set {
break;
}
let data = MpscQueue::read_data(&mut temp_node);
if search_next_buffer && (temp_head - 1) == unsafe { (*temp_buffer).head } {
unsafe { (*temp_buffer).head += 1 };
}
return Some(data);
}
if temp_node.state() == Empty {
all_handled = false;
}
}
if temp_head >= self.buffer_size {
if all_handled && search_next_buffer {
if self.fold_buffer(&mut temp_buffer, &mut temp_head) {
all_handled = true;
search_next_buffer = true;
} else {
return None;
}
} else {
let next = unsafe { &*temp_buffer }.next.load(Acquire);
if next.is_null() {
return None;
}
temp_buffer = next;
temp_head = unsafe { &*temp_buffer }.head;
all_handled = true;
search_next_buffer = true;
}
}
}
None
}
fn scan(
&self,
node: &Node<T>,
temp_head_of_queue: &mut *mut BufferList<T>,
temp_head: &mut usize,
temp_node: &mut &mut Node<T>,
) {
let mut scan_head_of_queue = self.head_of_queue.get();
let mut scan_head = unsafe { &*scan_head_of_queue }.head;
while node.state() == Empty && scan_head_of_queue != *temp_head_of_queue
|| scan_head < (*temp_head - 1)
{
if scan_head > self.buffer_size {
scan_head_of_queue = unsafe { (*scan_head_of_queue).next.load(Acquire) };
scan_head = unsafe { (*scan_head_of_queue).head };
continue;
}
let scan_node = &mut unsafe { &mut *scan_head_of_queue }.nodes[scan_head];
scan_head += 1;
if scan_node.state() == Set {
*temp_head = scan_head;
*temp_head_of_queue = scan_head_of_queue;
*temp_node = scan_node;
scan_head_of_queue = self.head_of_queue.get();
scan_head = unsafe { &*scan_head_of_queue }.head;
}
}
}
fn fold_buffer(&self, buffer_ptr: &mut *mut BufferList<T>, buffer_head: &mut usize) -> bool {
let buffer = unsafe { &**buffer_ptr };
let next = buffer.next.load(Acquire);
let prev = buffer.prev;
if next.is_null() {
return false;
}
unsafe { &mut *next }.prev = prev;
unsafe { &mut *prev }.next.store(next, Ordering::Release);
MpscQueue::drop_buffer(*buffer_ptr);
*buffer_ptr = next;
*buffer_head = unsafe { &mut **buffer_ptr }.head;
true
}
fn size_without_buffer(&self, buffer: &BufferList<T>) -> usize {
self.buffer_size * (buffer.pos - 1)
}
fn insert(
&self,
data: T,
index: usize,
buffer: &mut BufferList<T>,
is_last_buffer: bool,
) -> Result<(), T> {
unsafe {
buffer.nodes[index].set_data(data);
}
if index == 1 && is_last_buffer {
let new_buffer_ptr = self.allocate_buffer(buffer);
if buffer
.next
.compare_exchange(
ptr::null_mut(),
new_buffer_ptr,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_err()
{
MpscQueue::drop_buffer(new_buffer_ptr);
}
}
Ok(())
}
fn allocate_buffer(&self, buffer: &mut BufferList<T>) -> *mut BufferList<T> {
let new_buffer = BufferList::with_prev(self.buffer_size, buffer.pos + 1, buffer as *mut _);
Box::into_raw(Box::new(new_buffer))
}
fn drop_buffer(ptr: *mut BufferList<T>) {
drop(unsafe { Box::from_raw(ptr) })
}
fn read_data(node: &mut Node<T>) -> T {
let data = unsafe { node.data.as_ptr().read() };
node.data = MaybeUninit::uninit();
node.set_state(Handled);
data
}
}
impl<T> Default for MpscQueue<T> {
fn default() -> Self {
Self::new()
}
}
/// The producing half of a channel; cheap to clone and share across threads.
pub struct Sender<T> {
    queue: Arc<MpscQueue<T>>,
}
impl<T> Sender<T> {
    fn new(queue: Arc<MpscQueue<T>>) -> Self {
        Self { queue }
    }
    /// Enqueues `t` onto the shared queue.
    pub fn send(&self, t: T) -> Result<(), T> {
        self.queue.enqueue(t)
    }
}
impl<T> Clone for Sender<T> {
    // Manual impl: a derive would wrongly add a `T: Clone` bound.
    fn clone(&self) -> Self {
        Self {
            queue: Arc::clone(&self.queue),
        }
    }
}
/// The consuming half of a channel.
pub struct Receiver<T> {
    queue: Arc<MpscQueue<T>>,
}
impl<T> Receiver<T> {
    fn new(queue: Arc<MpscQueue<T>>) -> Self {
        Self { queue }
    }
    /// Pops the oldest element; `Ok(None)` when the queue is currently empty.
    pub fn recv(&self) -> Result<Option<T>, Box<dyn Error>> {
        Ok(self.queue.dequeue())
    }
}
/// Creates a connected sender/receiver pair backed by one shared queue.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let queue = Arc::new(MpscQueue::new());
    let tx = Sender::new(Arc::clone(&queue));
    let rx = Receiver::new(queue);
    (tx, rx)
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::Ordering;
    use std::sync::Arc;
    use std::thread;
    use super::*;
    #[test]
    fn enqueue() {
        let q = MpscQueue::new();
        unsafe {
            assert_eq!(State::Empty, (*q.head_of_queue.get()).nodes[0].state());
        }
        unsafe {
            assert_eq!(State::Empty, (*q.head_of_queue.get()).nodes[1].state());
        }
        assert_eq!(Ok(()), q.enqueue(42));
        assert_eq!(Ok(()), q.enqueue(43));
        // `assume_init` copies out of the `MaybeUninit` (i32 is Copy); the
        // slot itself stays Set until dequeued.
        unsafe {
            assert_eq!(42, (*q.head_of_queue.get()).nodes[0].data.assume_init());
        }
        unsafe {
            assert_eq!(State::Set, (*q.head_of_queue.get()).nodes[0].state());
        }
        unsafe {
            assert_eq!(43, (*q.head_of_queue.get()).nodes[1].data.assume_init());
        }
        unsafe {
            assert_eq!(State::Set, (*q.head_of_queue.get()).nodes[1].state());
        }
    }
    #[test]
    fn enqueue_exceeds_buffer() {
        let q = MpscQueue::new();
        for i in 0..BUFFER_SIZE * 2 {
            q.enqueue(i).unwrap();
        }
        for i in 0..BUFFER_SIZE * 2 {
            let buffer = i / BUFFER_SIZE;
            let index = i % BUFFER_SIZE;
            if buffer == 0 {
                unsafe {
                    assert_eq!(i, (*q.head_of_queue.get()).nodes[index].data.assume_init());
                }
            } else {
                unsafe {
                    assert_eq!(
                        i,
                        (*q.tail_of_queue.load(Ordering::SeqCst)).nodes[index]
                            .data
                            .assume_init()
                    );
                }
            }
        }
    }
    #[test]
    fn dequeue() {
        let q = MpscQueue::new();
        assert_eq!(None, q.dequeue());
        assert_eq!(Ok(()), q.enqueue(42));
        assert_eq!(Some(42), q.dequeue());
        assert_eq!(None, q.dequeue());
    }
    #[test]
    fn dequeue_exceeds_buffer() {
        let q = MpscQueue::new();
        // BUG FIX: the original `BUFFER_SIZE * 2.5 as usize` parsed as
        // `BUFFER_SIZE * (2.5 as usize)` == `BUFFER_SIZE * 2` — the cast
        // binds tighter than `*` — so the test never crossed into a partial
        // third buffer. Integer math gives the intended 2.5 buffers.
        let size = BUFFER_SIZE * 5 / 2;
        for i in 0..size {
            assert_eq!(q.enqueue(i), Ok(()));
        }
        for i in 0..size {
            assert_eq!(q.dequeue(), Some(i));
        }
    }
    #[test]
    fn multi_threaded_direct() {
        let nthreads = 8;
        let nmsgs = 1000;
        let q = MpscQueue::new();
        let q = Arc::new(q);
        let handles = (0..nthreads)
            .map(|_| {
                let q = q.clone();
                thread::spawn(move || {
                    for i in 0..nmsgs {
                        q.enqueue(i).unwrap();
                    }
                })
            })
            .collect::<Vec<_>>();
        for handle in handles {
            handle.join().unwrap();
        }
        let q = Arc::try_unwrap(q).unwrap();
        let mut i = 0;
        while let Some(data) = q.dequeue() {
            i += data;
        }
        // Each producer contributes 0 + 1 + ... + (nmsgs - 1).
        let expected = (0..nmsgs).sum::<i32>() * nthreads;
        assert_eq!(i, expected)
    }
    #[test]
    fn multi_threaded_channel() {
        let nthreads = 8;
        let nmsgs = 1000;
        let (tx, rx) = channel::<i32>();
        let handles = (0..nthreads)
            .map(|_| {
                let tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..nmsgs {
                        tx.send(i).unwrap();
                    }
                })
            })
            .collect::<Vec<_>>();
        for handle in handles {
            handle.join().unwrap();
        }
        let mut i = 0;
        while let Some(data) = rx.recv().unwrap() {
            i += data;
        }
        let expected = (0..nmsgs).sum::<i32>() * nthreads;
        assert_eq!(i, expected)
    }
}