use core::{
cell::UnsafeCell,
mem::{replace, swap},
sync::atomic::{AtomicU64, Ordering},
};
use lock_api::{RawRwLock, RwLock};
use crate::ring_buffer::{self, RingBuffer, ring_index};
pub struct FlipBuffer<T> {
buffer: RingBuffer<UnsafeCell<T>>,
popped: AtomicU64,
pushed: AtomicU64,
}
// SAFETY: through `&FlipBuffer<T>` the methods in this file only move `T`
// values in (`push_sync`) or out (`pop_sync`) by value; no `&T` is handed out,
// so `T: Send` is the only requirement placed on the element type.
unsafe impl<T: Send> Sync for FlipBuffer<T> {}
// SAFETY: sending the buffer to another thread sends the `T` values it owns,
// hence the `T: Send` bound.
unsafe impl<T: Send> Send for FlipBuffer<T> {}
impl<T> FlipBuffer<T> {
    /// Creates a buffer on top of `RingBuffer::new()` with no pending
    /// synchronized pushes or pops.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buffer: RingBuffer::new(),
            popped: AtomicU64::new(0),
            pushed: AtomicU64::new(0),
        }
    }

    /// Creates a buffer on top of `RingBuffer::with_capacity(cap)` with no
    /// pending synchronized pushes or pops.
    #[must_use]
    pub fn with_capacity(cap: usize) -> Self {
        Self {
            buffer: RingBuffer::with_capacity(cap),
            popped: AtomicU64::new(0),
            pushed: AtomicU64::new(0),
        }
    }

    /// Folds the pending `pop_sync` / `push_sync` counters into the ring's
    /// `head` and `len`, then resets both counters to zero.
    ///
    /// The counters record *attempts*, so they can exceed what actually
    /// happened: failed pops and pushes still increment them. Each counter is
    /// therefore clamped, pops to the current length and pushes to the vacant
    /// space, before being applied.
    fn flush(&mut self) {
        let live = self.buffer.len();
        let free = self.buffer.capacity() - live;

        // Exclusive access: plain loads/stores through `get_mut` are enough.
        let popped_raw = replace(self.popped.get_mut(), 0);
        let pushed_raw = replace(self.pushed.get_mut(), 0);

        // A counter that does not fit in `usize` is necessarily over the limit.
        let taken = usize::try_from(popped_raw).map_or(live, |n| n.min(live));
        let added = usize::try_from(pushed_raw).map_or(free, |n| n.min(free));

        if taken != 0 || added != 0 {
            let head = ring_index(self.buffer.head(), taken, self.buffer.capacity());
            let len = live - taken + added;
            // SAFETY: the first `taken` slots from the old head were moved out
            // by `pop_sync`, and the `added` slots following the old length
            // were written by `push_sync`, so the new range covers exactly the
            // slots that currently hold values.
            // NOTE(review): `set_head`/`set_len` contracts are defined in
            // `ring_buffer` — confirm they match this reasoning.
            unsafe {
                self.buffer.set_head(head);
                self.buffer.set_len(len);
            }
        }
    }

    /// Applies pending synchronized operations and then clears the underlying
    /// ring via `RingBuffer::clear`.
    pub fn clear(&mut self) {
        self.flush();
        self.buffer.clear();
    }

    /// Attempts to push `value` through a shared reference.
    ///
    /// A slot is reserved by incrementing the `pushed` counter. If the
    /// reserved position lies within the space that was vacant at the last
    /// flush, the value is written there and `Ok(())` is returned. Otherwise
    /// the value is handed back as `Err(value)`; the counter stays incremented
    /// and is clamped by the next flush.
    ///
    /// Values stored this way become part of the ring only after the next
    /// operation that flushes (any `&mut self` method of this type).
    pub fn push_sync(&self, value: T) -> Result<(), T> {
        let ticket = self.pushed.fetch_add(1, Ordering::Acquire);
        let reserved = usize::try_from(ticket)
            .ok()
            .filter(|&offset| offset < self.buffer.capacity() - self.buffer.len());

        if let Some(offset) = reserved {
            let idx = ring_index(
                self.buffer.head(),
                self.buffer.len() + offset,
                self.buffer.capacity(),
            );
            // SAFETY: `offset` is below the vacant space, so `idx` addresses a
            // slot outside the ring's live range, and `fetch_add` gives every
            // caller a distinct `offset`, so no two pushes share a slot.
            // NOTE(review): assumes `as_ptr` points at `capacity()` contiguous
            // slots — confirm against `RingBuffer`.
            unsafe {
                UnsafeCell::raw_get(self.buffer.as_ptr().add(idx)).write(value);
            }
            Ok(())
        } else {
            Err(value)
        }
    }

    /// Flushes pending synchronized operations, then pushes `value` with
    /// `RingBuffer::push`.
    pub fn push(&mut self, value: T) {
        self.flush();
        self.buffer.push(UnsafeCell::new(value));
    }

    /// Attempts to pop a value through a shared reference.
    ///
    /// A position is claimed by incrementing the `popped` counter. If it falls
    /// within the ring's length as of the last flush, the value at that
    /// position (counted from the head) is moved out and returned. Otherwise
    /// `None` is returned. Values added by `push_sync` since the last flush
    /// are not visible here.
    pub fn pop_sync(&self) -> Option<T> {
        let ticket: u64 = self.popped.fetch_add(1, Ordering::Acquire);
        let offset = usize::try_from(ticket)
            .ok()
            .filter(|&offset| offset < self.buffer.len())?;

        let idx = ring_index(self.buffer.head(), offset, self.buffer.capacity());
        // SAFETY: `offset` is below the ring's length, so the slot holds a
        // value, and `fetch_add` gives every caller a distinct `offset`, so
        // each value is moved out at most once before the next flush.
        let value = unsafe { UnsafeCell::raw_get(self.buffer.as_ptr().add(idx)).read() };
        Some(value)
    }

    /// Flushes pending synchronized operations, then pops with
    /// `RingBuffer::pop`.
    pub fn pop(&mut self) -> Option<T> {
        self.flush();
        let cell = self.buffer.pop()?;
        Some(cell.into_inner())
    }

    /// Flushes pending synchronized operations and returns an iterator that
    /// wraps `RingBuffer::drain`, yielding the unwrapped `T` values.
    pub fn drain(&mut self) -> Drain<'_, T> {
        self.flush();
        let inner = self.buffer.drain();
        Drain { inner }
    }

    /// Flushes pending synchronized operations and then exchanges the backing
    /// ring with `ring`.
    pub fn swap_buffer(&mut self, ring: &mut RingBuffer<T>) {
        self.flush();
        swap(&mut self.buffer, ring.as_unsafe_cell_mut());
    }
}
/// Iterator returned by `FlipBuffer::drain` and `FlipQueue::drain`.
///
/// Wraps the ring buffer's own drain iterator and unwraps each yielded
/// `UnsafeCell<T>` into a plain `T`.
#[must_use = "iterator does nothing unless consumed"]
pub struct Drain<'a, T> {
// Underlying ring-buffer drain; every item is passed through
// `UnsafeCell::into_inner` before being handed to the caller.
inner: ring_buffer::Drain<'a, UnsafeCell<T>>,
}
impl<T> Iterator for Drain<'_, T> {
type Item = T;
#[inline]
fn next(&mut self) -> Option<T> {
self.inner.next().map(UnsafeCell::into_inner)
}
#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
#[inline]
fn count(self) -> usize {
self.inner.count()
}
#[inline]
fn nth(&mut self, n: usize) -> Option<T> {
self.inner.nth(n).map(UnsafeCell::into_inner)
}
}
// The wrapper yields exactly one `T` per item of the wrapped drain, so the
// remaining length is whatever the wrapped drain reports.
impl<T> ExactSizeIterator for Drain<'_, T> {
/// Returns the length reported by the wrapped ring-buffer drain.
#[inline]
fn len(&self) -> usize {
self.inner.len()
}
}
/// Queue built from a `FlipBuffer` guarded by a reader-writer lock.
///
/// Operations that work through the buffer's `&self` API (`push_sync`,
/// `pop_sync`) run under the read lock; operations that need exclusive access
/// to the buffer take the write lock.
///
/// `L` is the raw lock implementation and defaults to
/// `crate::DefaultRawRwLock`.
pub struct FlipQueue<T, L = crate::DefaultRawRwLock> {
// The lock-protected buffer holding the queued values.
buffer: lock_api::RwLock<L, FlipBuffer<T>>,
}
// SAFETY: moving the queue to another thread moves the lock and the `T` values
// stored in the buffer, so both must be `Send`.
unsafe impl<T: Send, L: Send> Send for FlipQueue<T, L> {}

// SAFETY: through `&FlipQueue` the methods in this file only move `T` values
// in or out by value (no `&T` is handed out), so `T: Send` suffices for the
// elements; the lock itself is accessed from several threads and must be
// `Sync`.
unsafe impl<T: Send, L: Sync> Sync for FlipQueue<T, L> {}
impl<T, L: RawRwLock> Default for FlipQueue<T, L> {
    /// Equivalent to [`FlipQueue::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<T, L: RawRwLock> FlipQueue<T, L> {
    /// Clears the underlying buffer.
    ///
    /// Exclusive access to the queue makes locking unnecessary, so the buffer
    /// is reached directly through `RwLock::get_mut`.
    pub fn clear(&mut self) {
        let buffer = self.buffer.get_mut();
        buffer.clear();
    }
}
impl<T, L: RawRwLock> FlipQueue<T, L> {
    /// Creates a queue around `FlipBuffer::new()`.
    #[must_use]
    pub fn new() -> Self {
        let buffer = RwLock::new(FlipBuffer::new());
        Self { buffer }
    }

    /// Creates a queue around `FlipBuffer::with_capacity(cap)`.
    #[must_use]
    pub fn with_capacity(cap: usize) -> Self {
        let buffer = RwLock::new(FlipBuffer::with_capacity(cap));
        Self { buffer }
    }

    /// Attempts to push `value` while holding only the read lock.
    ///
    /// Returns `Err(value)` when the buffer's `push_sync` could not reserve a
    /// slot.
    pub fn try_push_sync(&self, value: T) -> Result<(), T> {
        self.buffer.read().push_sync(value)
    }

    /// Pushes `value` through a shared reference.
    ///
    /// Tries the read-locked fast path first and falls back to a write-locked
    /// push when that fails.
    pub fn push_sync(&self, value: T) {
        match self.try_push_sync(value) {
            Ok(()) => {}
            Err(rejected) => self.push_slow(rejected),
        }
    }

    /// Fallback for `push_sync`: takes the write lock and pushes through the
    /// buffer's exclusive `push`.
    #[inline(never)]
    #[cold]
    fn push_slow(&self, value: T) {
        self.buffer.write().push(value);
    }

    /// Pushes `value` without locking; exclusive access to the queue is
    /// already guaranteed by `&mut self`.
    pub fn push(&mut self, value: T) {
        let buffer = self.buffer.get_mut();
        buffer.push(value);
    }

    /// Attempts to pop a value while holding only the read lock.
    ///
    /// Returns `None` when the buffer's `pop_sync` yields nothing, even if
    /// values pushed under the read lock are still waiting to be flushed.
    pub fn try_pop_sync(&self) -> Option<T> {
        self.buffer.read().pop_sync()
    }

    /// Pops a value through a shared reference.
    ///
    /// Tries the read-locked fast path first. If it yields nothing and no
    /// synchronized push has been attempted since the last flush, the queue is
    /// reported empty. Otherwise the read lock is released and the pop is
    /// retried under the write lock, which flushes pending pushes first.
    pub fn pop_sync(&self) -> Option<T> {
        {
            let guard = self.buffer.read();
            let fast = guard.pop_sync();
            if fast.is_some() {
                return fast;
            }
            // Nothing was pushed under the read lock, so a flush cannot
            // reveal any new value.
            if guard.pushed.load(Ordering::Relaxed) == 0 {
                return None;
            }
            // The read guard is released here, before the write lock is taken.
        }
        self.pop_slow()
    }

    /// Fallback for `pop_sync`: takes the write lock and pops through the
    /// buffer's exclusive `pop`.
    fn pop_slow(&self) -> Option<T> {
        self.buffer.write().pop()
    }

    /// Pops a value without locking; exclusive access to the queue is already
    /// guaranteed by `&mut self`.
    pub fn pop(&mut self) -> Option<T> {
        let buffer = self.buffer.get_mut();
        buffer.pop()
    }

    /// Returns a draining iterator over the buffer without locking.
    pub fn drain(&mut self) -> Drain<'_, T> {
        let buffer = self.buffer.get_mut();
        buffer.drain()
    }

    /// Takes the write lock, passes a draining iterator to `f` and returns
    /// whatever `f` returns. The lock is held for the duration of the call.
    pub fn drain_locking<R>(&self, f: impl FnOnce(Drain<T>) -> R) -> R {
        let mut guard = self.buffer.write();
        let drain = guard.drain();
        f(drain)
    }

    /// Takes the write lock and exchanges the buffer's backing ring with
    /// `ring`.
    pub fn swap_buffer(&self, ring: &mut RingBuffer<T>) {
        self.buffer.write().swap_buffer(ring);
    }
}
// Ten threads each push ten distinct values through `push_sync`; draining
// afterwards must yield every value in 0..100 exactly once.
#[test]
#[cfg(feature = "std")]
fn test_flib_buffer() {
    let mut flip_buffer = FlipBuffer::with_capacity(256);

    std::thread::scope(|scope| {
        let shared = &flip_buffer;
        for tens in 0..10 {
            scope.spawn(move || {
                (0..10)
                    .map(|units| tens * 10 + units)
                    .for_each(|value| shared.push_sync(value).unwrap());
            });
        }
    });

    let mut seen: Vec<_> = flip_buffer.drain().collect();
    seen.sort_unstable();
    let expected: Vec<_> = (0..100).collect();
    assert_eq!(seen, expected);
}
// Starts from capacity 1 so that most pushes fall through to the write-locked
// slow path; every value in 0..100 must still arrive exactly once.
#[test]
#[cfg(feature = "std")]
fn test_flib_queue() {
    let mut flip_queue = FlipQueue::<_>::with_capacity(1);

    std::thread::scope(|scope| {
        let shared = &flip_queue;
        for tens in 0..10 {
            scope.spawn(move || {
                (0..10)
                    .map(|units| tens * 10 + units)
                    .for_each(|value| shared.push_sync(value));
            });
        }
    });

    let mut seen: Vec<_> = flip_queue.drain().collect();
    seen.sort_unstable();
    let expected: Vec<_> = (0..100).collect();
    assert_eq!(seen, expected);
}
// Ten consumers each pop ten values while ten producers each push ten values.
// Once all threads have finished the queue must be empty.
#[test]
#[cfg(feature = "std")]
fn test_flib_queue_push_pop() {
    let mut flip_queue = FlipQueue::<_>::with_capacity(1);

    std::thread::scope(|scope| {
        let shared = &flip_queue;

        // Consumers: spin (yielding) until a value is available.
        for _ in 0..10 {
            scope.spawn(move || {
                for _ in 0..10 {
                    while shared.pop_sync().is_none() {
                        std::thread::yield_now();
                    }
                }
            });
        }

        // Producers.
        for tens in 0..10 {
            scope.spawn(move || {
                for units in 0..10 {
                    shared.push_sync(tens * 10 + units);
                }
            });
        }
    });

    assert_eq!(flip_queue.pop(), None);
}