use std::{cmp::max, sync::atomic::AtomicUsize};
use crate::Context;
use atomic::Ordering;
use parking_lot::{Mutex, RwLock};
use super::notifier::Notifier;
use std::fmt::Debug;
pub struct MpmcCircularBuffer<T> {
buffer: Box<[Slot<T>]>,
head: AtomicUsize,
maintenance: Mutex<()>,
readers: AtomicUsize,
}
impl<T> Debug for MpmcCircularBuffer<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MpmcCircularBuffer")
.field("buffer", &self.buffer)
.field("head", &self.head)
.field("readers", &self.readers)
.finish()
}
}
impl<T> MpmcCircularBuffer<T>
where
T: Clone,
{
pub fn new(capacity: usize) -> (Self, BufferReader) {
let capacity = max(2, capacity);
let mut vec = Vec::with_capacity(capacity);
for _ in 0..capacity {
vec.push(Slot::new(0));
}
let this = Self {
buffer: vec.into_boxed_slice(),
head: AtomicUsize::new(1),
readers: AtomicUsize::new(1),
maintenance: Mutex::new(()),
};
let reader = BufferReader { index: 1 };
(this, reader)
}
}
pub enum TryWrite<T> {
Pending(T),
Ready,
}
pub enum SlotTryWrite<T> {
Pending(T),
Ready,
Written(T),
}
impl<T> MpmcCircularBuffer<T> {
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn try_write(&self, mut value: T, cx: &Context<'_>) -> TryWrite<T> {
loop {
let head_id = self.head.load(Ordering::Acquire);
let head_slot = self.get_slot(head_id);
#[cfg(feature = "debug")]
log::debug!(
"[{}] Attempting write with required readers {:?}, slot index {:?} with {:?} readers of {:?} required",
head_id,
&self.readers,
head_slot.index,
head_slot.reads,
&self.readers
);
let try_write = head_slot.try_write(head_id, value, &self.readers, cx, || {
if let Err(_e) = self.head.compare_exchange(
head_id,
head_id + 1,
Ordering::SeqCst,
Ordering::Relaxed,
) {
#[cfg(feature = "debug")]
log::warn!(
"[{}] Expected {} head value, found {}",
head_id,
head_id + 1,
_e
);
}
});
match try_write {
SlotTryWrite::Pending(v) => {
return TryWrite::Pending(v);
}
SlotTryWrite::Ready => {
#[cfg(feature = "debug")]
let slot_index = head_id % self.len();
#[cfg(feature = "debug")]
log::info!(
"[{}] Write complete in slot {}, head incremented from {} to {}",
head_id,
slot_index,
head_id,
head_id + 1
);
return TryWrite::Ready;
}
SlotTryWrite::Written(v) => {
value = v;
continue;
}
}
}
}
pub fn new_reader(&self) -> BufferReader {
let _maint = self.maintenance.lock();
let index = self.head.load(Ordering::Acquire);
self.readers.fetch_add(1, Ordering::AcqRel);
self.mark_read_in_range(0, index);
#[cfg(feature = "debug")]
log::info!("[{}] New reader", index);
BufferReader { index }
}
fn mark_read_in_range(&self, min: usize, max: usize) {
for slot in self.buffer.iter() {
let readers = self.readers.load(Ordering::Acquire);
slot.mark_read_in_range(min, max, readers);
}
}
pub(in crate::sync::mpmc_circular_buffer) fn get_slot(&self, id: usize) -> &Slot<T> {
let index = id % self.len();
&self.buffer[index]
}
}
#[derive(Debug)]
pub struct BufferReader {
index: usize,
}
pub enum TryRead<T> {
Ready(T),
Pending,
}
impl BufferReader {
pub fn try_read<T>(&mut self, buffer: &MpmcCircularBuffer<T>, cx: &Context<'_>) -> TryRead<T>
where
T: Clone,
{
let index = self.index;
let slot = buffer.get_slot(index);
let try_read = slot.try_read(index, &buffer.readers, cx);
match &try_read {
TryRead::Ready(_) => {
self.index += 1;
#[cfg(feature = "debug")]
log::debug!(
"[{}] Read complete in slot {} with {:?} reads of {:?} required",
index,
index % buffer.len(),
slot.reads,
&buffer.readers,
);
}
TryRead::Pending => {
#[cfg(feature = "debug")]
log::debug!("[{}] Read pending, slot: {:?}", index, slot);
}
}
try_read
}
pub fn clone_with<T>(&self, buffer: &MpmcCircularBuffer<T>) -> Self {
let _maint = buffer.maintenance.lock();
buffer.readers.fetch_add(1, Ordering::AcqRel);
let index = self.index;
buffer.mark_read_in_range(0, index);
#[cfg(feature = "debug")]
log::error!("[{}] Cloned reader", index);
BufferReader { index }
}
pub fn drop_with<T>(&mut self, buffer: &MpmcCircularBuffer<T>) {
let _maint = buffer.maintenance.lock();
buffer
.buffer
.iter()
.for_each(|slot| slot.decrement_read_in_range(0, self.index));
buffer.readers.fetch_sub(1, Ordering::AcqRel);
for (_id, slot) in buffer.buffer.iter().enumerate() {
#[cfg(feature = "debug")]
log::debug!(
"[{}] Dropping reader, notifying slot {} with reads {:?} of new reader count {:?}",
self.index,
_id,
slot.reads,
buffer.readers,
);
slot.notify_readers_decreased(&buffer.readers);
}
#[cfg(feature = "debug")]
log::error!(
"[{}] Dropped reader, readers reduced to {:?}",
self.index,
buffer.readers
);
}
}
pub struct Slot<T> {
data: RwLock<Option<T>>,
reads: AtomicUsize,
index: AtomicUsize,
on_write: Notifier,
on_release: Notifier,
}
impl<T> Slot<T> {
pub fn new(index: usize) -> Self {
Self {
data: RwLock::new(None),
reads: AtomicUsize::new(0),
index: AtomicUsize::new(index),
on_write: Notifier::new(),
on_release: Notifier::new(),
}
}
pub fn try_write<OnWrite>(
&self,
index: usize,
value: T,
readers: &AtomicUsize,
cx: &Context<'_>,
on_write: OnWrite,
) -> SlotTryWrite<T>
where
OnWrite: FnOnce(),
{
loop {
let prev_index = self.index.load(Ordering::Acquire);
if prev_index >= index {
return SlotTryWrite::Written(value);
} else if prev_index != 0
&& self.reads.load(Ordering::Acquire) < readers.load(Ordering::Acquire)
{
self.on_release.subscribe(cx);
if prev_index < self.index.load(Ordering::Acquire) {
#[cfg(feature = "debug")]
log::warn!(
"[{}] Slot index advanced during write, invalidating subscription",
index
);
continue;
}
if self.reads.load(Ordering::Acquire) >= readers.load(Ordering::Acquire) {
#[cfg(feature = "debug")]
log::warn!(
"[{}] Reads incremented during write, invalidating subscription",
index
);
continue;
}
return SlotTryWrite::Pending(value);
}
let mut data = self.data.write();
if prev_index != 0
&& self.reads.load(Ordering::Acquire) < readers.load(Ordering::Acquire)
{
#[cfg(feature = "debug")]
log::warn!(
"[{}] Reads decreased during write (upgrading index {})",
index,
prev_index
);
continue;
}
if self
.index
.compare_exchange(prev_index, index, Ordering::AcqRel, Ordering::Relaxed)
.is_err()
{
continue;
}
on_write();
*data = Some(value);
self.reads.store(0, Ordering::Release);
self.on_write.notify();
return SlotTryWrite::Ready;
}
}
fn mark_read_in_range(&self, min: usize, max: usize, readers: usize) {
let _read = self.data.read();
let index = self.index.load(Ordering::Acquire);
if index >= min && index < max {
let reads = 1 + self.reads.fetch_add(1, Ordering::AcqRel);
#[cfg(feature = "debug")]
log::debug!(
"[{}] Mark read in range occurred. Increased reads to {} of required readers {}",
index,
reads,
readers
);
if reads >= readers {
self.on_release.notify();
}
}
}
fn decrement_read_in_range(&self, min: usize, max: usize) {
let _read = self.data.read();
let index = self.index.load(Ordering::Acquire);
if index >= min && index < max {
loop {
let reads = self.reads.load(Ordering::Acquire);
if reads == 0 {
return;
}
if self
.reads
.compare_exchange(reads, reads - 1, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
{
#[cfg(feature = "debug")]
log::debug!(
"[{}] Mark decrement in range occurred. Decreased reads to {}",
index,
reads - 1
);
return;
}
}
}
}
fn notify_readers_decreased(&self, readers: &AtomicUsize) {
if self.reads.load(Ordering::Acquire) >= readers.load(Ordering::Acquire) {
self.on_release.notify();
}
}
}
impl<T> Slot<T>
where
T: Clone,
{
#[allow(clippy::comparison_chain)]
pub fn try_read(&self, index: usize, readers: &AtomicUsize, cx: &Context<'_>) -> TryRead<T> {
loop {
let slot_index = self.index.load(Ordering::Acquire);
if slot_index < index {
self.on_write.subscribe(cx);
if self.index.load(Ordering::Acquire) >= index {
continue;
}
return TryRead::Pending;
} else if slot_index > index {
#[cfg(feature = "debug")]
log::error!(
"Slot index {} has advanced past reader position {}",
slot_index,
index
);
return TryRead::Pending;
}
let data_lock = self.data.read();
let reads = 1 + self.reads.fetch_add(1, Ordering::AcqRel);
#[cfg(feature = "debug")]
log::debug!(
"[{}] Read action occurred. Increased reads to {}",
index,
reads
);
let data_ref = data_lock.as_ref().unwrap();
let data_cloned = data_ref.clone();
if reads >= readers.load(Ordering::Acquire) {
self.on_release.notify();
}
break TryRead::Ready(data_cloned);
}
}
}
impl<T> Debug for Slot<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Slot")
.field("reads", &self.reads)
.field("index", &self.index)
.finish()
}
}