#![cfg_attr(all(not(test)), no_std)]
#![doc = include_str!("../README.md")]
extern crate alloc;
use core::{
cell::UnsafeCell,
fmt::Debug,
future::Future,
ops::{Deref, DerefMut, Range},
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll},
};
use alloc::{boxed::Box, vec::Vec};
use portable_atomic::AtomicU64;
use portable_atomic_util::{Arc, Weak};
use snafu::Snafu;
pub mod async_support;
pub mod stage;
/// Publishing handle of a flow: a fixed-capacity ring of `T` slots that is
/// shared with the flow's subscribers.
///
/// Created by [`Flow::new`] together with its [`FlowSubscriber`]s. Cloning the
/// handle clones the inner `Arc`, so every clone refers to the same state.
#[derive(Clone, Debug)]
pub struct Flow<T>
where
    T: Flows,
{
    /// Ring buffer and sequence cursors shared with the subscribers.
    state: Arc<FlowState<T>>,
}
impl<T: Flows> Flow<T> {
pub fn new<const SUB: usize>(capacity: usize) -> (Self, [FlowSubscriber<T>; SUB])
where
T: Default,
{
assert!(capacity & (capacity - 1) == 0, "flow capacity _must_ be a power of two (like `2`, `4`, `256`, `2048`...), not {capacity}");
let mut buffer = Vec::with_capacity(capacity);
for _ in 0..capacity {
buffer.push(UnsafeCell::new(T::default()));
}
let buffer = buffer.into_boxed_slice();
let mut flow_state = FlowState {
buffer,
next_writable_seq: AtomicU64::new(0),
next_publishable_seq: AtomicU64::new(0),
next_receivable_seqs: Vec::with_capacity(SUB),
};
let mut subscriber_seqs = Vec::with_capacity(SUB);
for _ in 0..SUB {
subscriber_seqs.push(flow_state.add_subscriber_seq());
}
let flow_state = Arc::new(flow_state);
let subscribers: Vec<FlowSubscriber<T>> = subscriber_seqs
.into_iter()
.map(|seq| FlowSubscriber {
flow_state: flow_state.clone(),
next_receivable_seq: seq,
})
.collect();
(Self { state: flow_state }, subscribers.try_into().unwrap())
}
pub fn try_next(&mut self) -> Result<UnpublishedData<'_, T>, Error> {
self.try_next_internal()
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> impl Future<Output = Result<UnpublishedData<'_, T>, Error>> {
PublishNextFuture { flow: self }
}
#[inline(always)]
fn try_next_internal(&self) -> Result<UnpublishedData<'_, T>, Error> {
if let Some(next) = self.state.try_claim_publishable() {
let next_item = UnpublishedData {
flow: self,
sequence: next,
data: unsafe { self.state.get_mut(next) },
};
Ok(next_item)
} else {
Err(Error::Full)
}
}
}
/// Future returned by `Flow::next`; resolves once a slot could be claimed.
struct PublishNextFuture<'a, T>
where
    T: Flows,
{
    /// The publisher on whose behalf the slot is claimed.
    flow: &'a Flow<T>,
}
impl<'a, T: Flows> Future for PublishNextFuture<'a, T> {
type Output = Result<UnpublishedData<'a, T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.flow.try_next_internal() {
Ok(next) => Poll::Ready(Ok(next)),
Err(Error::Full) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
/// State shared by the publishing handles and all subscribers of one flow.
struct FlowState<T>
where
    T: Flows,
{
    /// The slots of the ring; a sequence is mapped to a slot by masking it
    /// with `len - 1`.
    buffer: Box<[UnsafeCell<T>]>,
    /// Next sequence that will be handed out to a publisher.
    next_writable_seq: AtomicU64,
    /// Publish cursor: subscribers may read sequences strictly below it.
    next_publishable_seq: AtomicU64,
    /// Weak handles to each subscriber's cursor; entries that no longer
    /// upgrade are skipped when looking for the slowest subscriber.
    next_receivable_seqs: Vec<Weak<AtomicU64>>,
}
// Cursor bookkeeping and raw slot access for the shared ring buffer.
impl<T> FlowState<T>
where
T: Flows,
{
/// Creates a subscriber cursor starting at sequence `0`, records a weak handle to it in
/// `next_receivable_seqs`, and returns the owning `Arc`.
fn add_subscriber_seq(&mut self) -> Arc<AtomicU64> {
let next_receivable_seq = Arc::new(AtomicU64::new(0));
self.next_receivable_seqs
.push(Arc::downgrade(&next_receivable_seq));
next_receivable_seq
}
/// Tries to claim the next writable sequence for a publisher.
///
/// Returns `Some(sequence)` when the claim succeeded. Returns `None` when the slowest
/// cursor is a full buffer length behind, or when the compare-exchange on
/// `next_writable_seq` failed (the value changed since it was loaded); the claim is not
/// retried here.
#[inline(always)]
fn try_claim_publishable(&self) -> Option<u64> {
let next_writable = self.next_writable_seq.load(Ordering::SeqCst);
// Slowest cursor: start from the publish cursor and lower it to the smallest cursor of
// any subscriber that is still alive.
let mut min_receivable_seq = self.next_publishable_seq.load(Ordering::SeqCst);
for next_received_seq in self.next_receivable_seqs.iter() {
// A failed upgrade means the subscriber's cursor was dropped; it is ignored.
if let Some(seq) = next_received_seq.upgrade() {
min_receivable_seq = min_receivable_seq.min(seq.load(Ordering::SeqCst));
}
}
// Claim only while `next_writable` is less than one buffer length ahead of the slowest
// cursor; the compare-exchange then hands the sequence to exactly one caller.
if min_receivable_seq + self.buffer.len() as u64 > next_writable
&& self
.next_writable_seq
.compare_exchange(
next_writable,
next_writable + 1,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_ok()
{
return Some(next_writable);
}
None
}
/// Tries to advance the publish cursor from `sequence` to `sequence + 1`.
///
/// Returns `true` when the cursor was advanced. Uses `compare_exchange_weak`, so `false`
/// can also be a spurious failure; callers are expected to retry.
#[inline(always)]
fn try_publish(&self, sequence: u64) -> bool {
self.next_publishable_seq
.compare_exchange_weak(sequence, sequence + 1, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}
/// Returns a shared reference to the slot that `sequence` maps to.
///
/// # Safety
///
/// The slot is read through an `UnsafeCell` without any check, so the caller must ensure
/// that nothing holds a mutable reference to the same slot for as long as the returned
/// reference is used.
// NOTE(review): `clippy::mut_from_ref` looks unnecessary here since this returns `&T`,
// not `&mut T` — confirm and drop the allow.
#[allow(clippy::mut_from_ref)]
#[inline(always)]
unsafe fn get(&self, sequence: u64) -> &T {
// With a power-of-two length, masking with `len - 1` keeps the index below `len`,
// which is what `get_unchecked` relies on.
// NOTE(review): a zero-length buffer would make `len - 1` underflow — assumes the
// constructor never builds one; confirm.
assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
let index = (self.buffer.len() - 1) & sequence as usize;
&*self.buffer.get_unchecked(index).get()
}
/// Returns a mutable reference to the slot that `sequence` maps to.
///
/// # Safety
///
/// The reference is created from `&self` through an `UnsafeCell`, so the caller must
/// ensure that no other reference (shared or mutable) to the same slot exists for as long
/// as the returned reference is used.
#[allow(clippy::mut_from_ref)]
#[inline(always)]
unsafe fn get_mut(&self, sequence: u64) -> &mut T {
// Same index computation as in `get`: mask the sequence with `len - 1`.
assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
let index = (self.buffer.len() - 1) & sequence as usize;
&mut *self.buffer.get_unchecked(index).get()
}
}
impl<T: Flows> Debug for FlowState<T> {
    /// Prints the buffer capacity and the sequence cursors; the slot contents
    /// themselves are not included.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let capacity = self.buffer.len();
        let mut out = f.debug_struct("Flow");
        out.field("capacity", &capacity);
        out.field("next_writable_seq", &self.next_writable_seq);
        out.field("next_publishable_seq", &self.next_publishable_seq);
        out.field("next_receivable_seqs", &self.next_receivable_seqs);
        out.finish()
    }
}
/// Receiving end of a flow, handed out by [`Flow::new`].
///
/// Each subscriber keeps its own cursor and reads published items in sequence
/// order.
pub struct FlowSubscriber<T>
where
    T: Flows,
{
    /// State shared with the publisher and the other subscribers.
    flow_state: Arc<FlowState<T>>,
    /// This subscriber's cursor: the next sequence it has not yet received.
    next_receivable_seq: Arc<AtomicU64>,
}
impl<T: Flows> FlowSubscriber<T> {
pub fn try_next(&mut self) -> Result<impl Deref<Target = T> + '_, Error> {
self.try_next_internal()
}
#[allow(clippy::should_implement_trait)]
pub fn next(&mut self) -> impl Future<Output = Result<impl Deref<Target = T> + '_, Error>> {
ReceiveNextFuture { subscriber: self }
}
#[inline(always)]
fn try_next_internal(&self) -> Result<PublishedData<'_, T>, Error> {
if let Some(next) = self.receivable_seqs().next() {
let data = PublishedData {
subscription: self,
sequence: next,
data: unsafe { self.flow_state.get(next) },
};
Ok(data)
} else {
Err(Error::Ahead)
}
}
#[inline(always)]
fn receivable_seqs(&self) -> Range<u64> {
self.next_receivable_seq.load(Ordering::SeqCst)
..self.flow_state.next_publishable_seq.load(Ordering::SeqCst)
}
#[inline(always)]
fn receive_up_to(&self, sequence: u64) {
self.next_receivable_seq
.fetch_max(sequence + 1, Ordering::SeqCst);
}
}
/// Future returned by `FlowSubscriber::next`; resolves once an item can be read.
struct ReceiveNextFuture<'a, T>
where
    T: Flows,
{
    /// The subscriber on whose behalf the item is read.
    subscriber: &'a FlowSubscriber<T>,
}
impl<'a, T: Flows> Future for ReceiveNextFuture<'a, T> {
type Output = Result<PublishedData<'a, T>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.subscriber.try_next_internal() {
Ok(next) => Poll::Ready(Ok(next)),
Err(Error::Ahead) => {
cx.waker().wake_by_ref();
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl<T> Debug for FlowSubscriber<T>
where
    T: Flows,
{
    /// Prints the shared flow state followed by this subscriber's cursor.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The struct name shown must match the type; it previously printed a
        // stale name (`OutBarrier`) that no longer exists in this file.
        f.debug_struct("FlowSubscriber")
            .field("flow_state", &self.flow_state)
            .field("next_receivable_seq", &self.next_receivable_seq)
            .finish()
    }
}
// SAFETY (as relied upon by this file): `T: Flows` implies `T: Send + Sync`, and
// the `UnsafeCell` slots are only reached through `FlowState::get` / `get_mut`,
// whose callers coordinate access through the atomic sequence cursors.
// NOTE(review): soundness hinges on that sequencing protocol — re-check these
// impls whenever the claim/publish/receive logic changes.
unsafe impl<T: Flows> Send for FlowState<T> {}
unsafe impl<T: Flows> Sync for FlowState<T> {}
/// Marker for types that can be sent through a flow.
///
/// Implemented automatically for every `Send + Sync + 'static` type, so it
/// never has to be implemented by hand.
pub trait Flows: Send + Sync + 'static {}

impl<T: Send + Sync + 'static> Flows for T {}
/// Write guard for a claimed, not yet published slot.
///
/// Dereferences (mutably) to the slot's value. The slot is published when the
/// guard is dropped, whether or not it was written to.
#[derive(Debug)]
pub struct UnpublishedData<'a, T>
where
    T: Flows,
{
    /// The publisher that claimed the slot; used to publish on drop.
    flow: &'a Flow<T>,
    /// Sequence number of the claimed slot.
    sequence: u64,
    /// The slot's value.
    data: &'a mut T,
}
impl<'a, T> UnpublishedData<'a, T>
where
    T: Flows,
{
    /// Sequence number of the slot this guard refers to.
    pub fn sequence(&self) -> u64 {
        self.sequence
    }

    /// Stores `data` in the slot and publishes it.
    pub fn publish(self, data: T) {
        *self.data = data;
        // `self` is dropped at the end of this scope; the `Drop` impl performs
        // the actual publish.
    }
}
impl<'a, T> Deref for UnpublishedData<'a, T>
where
    T: Flows,
{
    type Target = T;

    /// Shared access to the value currently stored in the claimed slot.
    fn deref(&self) -> &Self::Target {
        &*self.data
    }
}
impl<'a, T> DerefMut for UnpublishedData<'a, T>
where
    T: Flows,
{
    /// Mutable access to the claimed slot, for writing the value in place.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.data
    }
}
impl<T: Flows> Drop for UnpublishedData<'_, T> {
    /// Publishes the slot by advancing the publish cursor past `self.sequence`.
    ///
    /// The cursor only advances from exactly `self.sequence`, so this busy-waits
    /// until every earlier sequence has been published. The attempt uses a weak
    /// compare-exchange, so it is also retried after a spurious failure.
    fn drop(&mut self) {
        while !self.flow.state.try_publish(self.sequence) {
            // Signal the CPU that this is a spin-wait so it can relax the core
            // instead of hammering the cache line at full speed.
            core::hint::spin_loop();
        }
    }
}
/// Read guard for one published item.
///
/// Dereferences to the item. Dropping the guard moves the owning subscriber's
/// cursor past the item.
#[derive(Debug)]
struct PublishedData<'a, T>
where
    T: Flows,
{
    /// The subscriber that is reading the item.
    subscription: &'a FlowSubscriber<T>,
    /// Sequence number of the item.
    sequence: u64,
    /// The item itself.
    data: &'a T,
}
impl<'a, T> Deref for PublishedData<'a, T>
where
    T: Flows,
{
    type Target = T;

    /// Shared access to the published item.
    fn deref(&self) -> &Self::Target {
        self.data
    }
}
impl<'a, T> Drop for PublishedData<'a, T>
where
    T: Flows,
{
    /// Marks the item as received so the subscriber continues with the
    /// following sequence.
    fn drop(&mut self) {
        let received = self.sequence;
        self.subscription.receive_up_to(received);
    }
}
/// Reasons why a non-waiting flow operation could not proceed.
// NOTE(review): the variants are annotated with plain `//` comments on purpose.
// The `Snafu` derive may use `///` doc comments on variants as their `Display`
// text — confirm against the snafu version in use before converting them.
#[derive(Debug, Snafu, PartialEq)]
pub enum Error {
// Returned to a publisher when no sequence could be claimed (see `Flow::try_next`).
Full,
// Returned to a subscriber when nothing published is left to read
// (see `FlowSubscriber::try_next`).
Ahead,
}
#[cfg(test)]
mod test {
    use super::*;

    /// Publishes one value through a two-slot flow and reads it back with a
    /// single subscriber, checking the receivable range before and after.
    #[test]
    fn pubs_and_subs() -> Result<(), crate::Error> {
        let (mut flow, [mut reader]) = Flow::new(2);

        // Claim the first slot, write to it and publish it by dropping the guard.
        let mut slot = flow.try_next().unwrap();
        *slot = 42u32;
        assert_eq!(0, slot.sequence());
        drop(slot);
        assert_eq!(0..1, reader.receivable_seqs());

        // Receive the item; dropping the guard moves the cursor past it.
        let received = reader.try_next().unwrap();
        assert!(42u32 == *received);
        drop(received);
        assert_eq!(1..1, reader.receivable_seqs());

        Ok(())
    }
}