codas_flow/
lib.rs

1#![cfg_attr(all(not(test)), no_std)]
2// Use the README file as the root-level
3// docs for this library.
4#![doc = include_str!("../README.md")]
5
6extern crate alloc;
7
8use core::{
9    cell::UnsafeCell,
10    fmt::Debug,
11    future::Future,
12    ops::{Deref, DerefMut, Range},
13    pin::Pin,
14    sync::atomic::Ordering,
15    task::{Context, Poll},
16};
17
18use alloc::{boxed::Box, vec::Vec};
19use portable_atomic::AtomicU64;
20use portable_atomic_util::{Arc, Weak};
21use snafu::Snafu;
22
23pub mod async_support;
24pub mod stage;
25
26/// Bounded queue for publishing and receiving
27/// data from (a)synchronous tasks.
28///
29/// Refer to the [crate] docs for more info.
30#[derive(Debug, Clone)]
31pub struct Flow<T: Flows> {
32    state: Arc<FlowState<T>>,
33}
34
35impl<T: Flows> Flow<T> {
36    /// Returns a tuple of `(flow, [subscribers])`,
37    /// where `capacity` is the maximum capacity
38    /// of the flow.
39    ///
40    /// # Panics
41    ///
42    /// Iff `capacity` is _not_ a power of two
43    /// (like `2`, `32`, `256`, and so on).
44    pub fn new<const SUB: usize>(capacity: usize) -> (Self, [FlowSubscriber<T>; SUB])
45    where
46        T: Default,
47    {
48        assert!(capacity & (capacity - 1) == 0, "flow capacity _must_ be a power of two (like `2`, `4`, `256`, `2048`...), not {capacity}");
49
50        // Allocate the flow buffer.
51        let mut buffer = Vec::with_capacity(capacity);
52        for _ in 0..capacity {
53            buffer.push(UnsafeCell::new(T::default()));
54        }
55        let buffer = buffer.into_boxed_slice();
56
57        // Build the flow state.
58        let mut flow_state = FlowState {
59            buffer,
60            next_writable_seq: AtomicU64::new(0),
61            next_publishable_seq: AtomicU64::new(0),
62            next_receivable_seqs: Vec::with_capacity(SUB),
63        };
64
65        // Add subscribers to the state.
66        let mut subscriber_seqs = Vec::with_capacity(SUB);
67        for _ in 0..SUB {
68            subscriber_seqs.push(flow_state.add_subscriber_seq());
69        }
70
71        // Finalize flow state and wrap subscriber
72        // sequences in the subscriber API.
73        let flow_state = Arc::new(flow_state);
74        let subscribers: Vec<FlowSubscriber<T>> = subscriber_seqs
75            .into_iter()
76            .map(|seq| FlowSubscriber {
77                flow_state: flow_state.clone(),
78                next_receivable_seq: seq,
79            })
80            .collect();
81
82        (Self { state: flow_state }, subscribers.try_into().unwrap())
83    }
84
85    /// Tries to claim the next publishable
86    /// sequence in the flow, returning
87    /// a [`UnpublishedData`] iff successful.
88    pub fn try_next(&mut self) -> Result<UnpublishedData<T>, Error> {
89        self.try_next_internal()
90    }
91
92    /// Awaits and claims the next publishable sequence
93    /// in the flow, returning a [`UnpublishedData`]
94    /// iff successful.
95    #[allow(clippy::should_implement_trait)]
96    pub fn next(&mut self) -> impl Future<Output = Result<UnpublishedData<T>, Error>> {
97        PublishNextFuture { flow: self }
98    }
99
100    /// Implementation of [`Self::try_next`] that
101    /// takes `self` as an immutable reference with
102    /// interior mutability.
103    #[inline(always)]
104    fn try_next_internal(&self) -> Result<UnpublishedData<T>, Error> {
105        if let Some(next) = self.state.try_claim_publishable() {
106            let next_item = UnpublishedData {
107                flow: self,
108                sequence: next,
109                data: unsafe { self.state.get_mut(next) },
110            };
111            Ok(next_item)
112        } else {
113            Err(Error::Full)
114        }
115    }
116}
117
118/// Future returned by [`Flow::next`].
119struct PublishNextFuture<'a, T: Flows> {
120    flow: &'a Flow<T>,
121}
122
123impl<'a, T: Flows> Future for PublishNextFuture<'a, T> {
124    type Output = Result<UnpublishedData<'a, T>, Error>;
125
126    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
127        match self.flow.try_next_internal() {
128            Ok(next) => Poll::Ready(Ok(next)),
129            Err(Error::Full) => {
130                cx.waker().wake_by_ref();
131                Poll::Pending
132            }
133            Err(e) => Poll::Ready(Err(e)),
134        }
135    }
136}
137
138/// Internal state of a [`Flow`].
139///
140/// This state is placed in a separate data
141/// structure from the rest of a [`Flow`]
142/// to simplify sharing references to the state
143/// between a flow and it's subscribers.
144struct FlowState<T: Flows> {
145    /// Pre-allocated contiguous buffer of
146    /// data entries in the flow.
147    ///
148    /// This buffer is a ring buffer: When it
149    /// is full, writes "wrap" around to the
150    /// beginning of the buffer, overwriting
151    /// the oldest data.
152    ///
153    /// Each data entry in the buffer is
154    /// wrapped in an [`UnsafeCell`], enabling
155    /// concurrent tasks to immutably read
156    /// the same data at the same time.
157    buffer: Box<[UnsafeCell<T>]>,
158
159    /// The sequence number that will be assigned
160    /// to the _next_ data entry written into the flow.
161    next_writable_seq: AtomicU64,
162
163    /// The sequence number of the next data entry
164    /// that will become readable by the flow's
165    /// subscriber(s).
166    ///
167    /// All data entries with sequences less than
168    /// this number are assumed to be readable.
169    next_publishable_seq: AtomicU64,
170
171    /// The sequence numbers of the next data entry
172    /// that will be read by each of the flow's
173    /// subscriber(s).
174    ///
175    /// All data entries with sequences less than
176    /// the _lowest_ of these sequence numbers are
177    /// assumed to be overwritable.
178    next_receivable_seqs: Vec<Weak<AtomicU64>>,
179}
180
181impl<T> FlowState<T>
182where
183    T: Flows,
184{
185    /// Adds and returns a new subscriber sequence
186    /// number to the flow.
187    fn add_subscriber_seq(&mut self) -> Arc<AtomicU64> {
188        let next_receivable_seq = Arc::new(AtomicU64::new(0));
189        self.next_receivable_seqs
190            .push(Arc::downgrade(&next_receivable_seq));
191        next_receivable_seq
192    }
193
194    /// Tries to claim and return the next
195    /// publishable data sequence in the flow.
196    ///
197    /// Iff `Some(sequence)` is returned, the
198    /// sequence _must_ be published via
199    /// [`Self::try_publish`], or the flow
200    /// will stall from backpressure.
201    ///
202    /// Iff `None` is returned, the flow is full.
203    #[inline(always)]
204    fn try_claim_publishable(&self) -> Option<u64> {
205        let next_writable = self.next_writable_seq.load(Ordering::SeqCst);
206
207        // Calculate the minimum receivable sequence
208        // across all subscribers, defaulting to the
209        // current sequence that's publishable.
210        let mut min_receivable_seq = self.next_publishable_seq.load(Ordering::SeqCst);
211        for next_received_seq in self.next_receivable_seqs.iter() {
212            if let Some(seq) = next_received_seq.upgrade() {
213                min_receivable_seq = min_receivable_seq.min(seq.load(Ordering::SeqCst));
214            }
215        }
216
217        // Only claim if there's space.
218        if min_receivable_seq + self.buffer.len() as u64 > next_writable
219            && self
220                .next_writable_seq
221                .compare_exchange(
222                    next_writable,
223                    next_writable + 1,
224                    Ordering::SeqCst,
225                    Ordering::SeqCst,
226                )
227                .is_ok()
228        {
229            return Some(next_writable);
230        }
231
232        None
233    }
234
235    /// Tries to publish `sequence`, returning
236    /// true iff the sequence was published.
237    #[inline(always)]
238    fn try_publish(&self, sequence: u64) -> bool {
239        self.next_publishable_seq
240            .compare_exchange_weak(sequence, sequence + 1, Ordering::SeqCst, Ordering::SeqCst)
241            .is_ok()
242    }
243
244    /// Returns a reference to the data at `sequence`.
245    ///
246    /// Refer to [`Self::get_mut`] for information
247    /// on the safety properties of this function.
248    ///
249    /// # Panics
250    ///
251    /// Iff any other thread attempts to acquire a _mutable_
252    /// reference to `sequence` at the same time.
253    #[allow(clippy::mut_from_ref)]
254    #[inline(always)]
255    unsafe fn get(&self, sequence: u64) -> &T {
256        assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
257
258        // Convert sequence to an queue index.
259        let index = (self.buffer.len() - 1) & sequence as usize;
260
261        // Array access will always be within bounds.
262        &*self.buffer.get_unchecked(index).get()
263    }
264
265    /// Returns a mutable reference to the data at `sequence`.
266    ///
267    /// # Safety
268    ///
269    /// This function is unsafe because it _can_ return
270    /// multiple mutable references to the sae data.
271    ///
272    /// This function _is_ safe to call from any task which
273    /// has successfully claimed a sequence number via
274    /// [`Self::try_claim_publishable`] and
275    /// has not yet published that sequence number
276    /// via [`Self::try_publish`]. In this scenario,
277    /// the task is guaranteed to be the only one with
278    /// read/write access to the data.
279    ///
280    /// This function's behavior is undefined if the task
281    /// (having claimed a sequence number via
282    /// [`Self::try_claim_publishable`]) calls this
283    /// function _repeatedly_ with the same sequence number.
284    ///
285    /// # Panics
286    ///
287    /// Iff the same or different tasks attempt to acquire
288    /// more than one _mutable_ reference to `sequence`.
289    #[allow(clippy::mut_from_ref)]
290    #[inline(always)]
291    unsafe fn get_mut(&self, sequence: u64) -> &mut T {
292        assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
293
294        // Convert sequence to an queue index.
295        let index = (self.buffer.len() - 1) & sequence as usize;
296
297        // Array access will always be within bounds.
298        &mut *self.buffer.get_unchecked(index).get()
299    }
300}
301
302impl<T> Debug for FlowState<T>
303where
304    T: Flows,
305{
306    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
307        f.debug_struct("Flow")
308            .field("capacity", &self.buffer.len())
309            .field("next_writable_seq", &self.next_writable_seq)
310            .field("next_publishable_seq", &self.next_publishable_seq)
311            .field("next_receivable_seqs", &self.next_receivable_seqs)
312            .finish()
313    }
314}
315
316/// Subscriber which receives data from a [`Flow`].
317pub struct FlowSubscriber<T: Flows> {
318    flow_state: Arc<FlowState<T>>,
319
320    /// See [`FlowState::next_receivable_seqs`].
321    next_receivable_seq: Arc<AtomicU64>,
322}
323
324impl<T: Flows> FlowSubscriber<T> {
325    /// Returns a reference to the next data
326    /// in the flow, if the flow is active and
327    /// any data is available.
328    pub fn try_next(&mut self) -> Result<impl Deref<Target = T> + '_, Error> {
329        self.try_next_internal()
330    }
331
332    /// Awaits and returns a reference to the next
333    /// data  in the flow, if the flow is active.
334    #[allow(clippy::should_implement_trait)]
335    pub fn next(&mut self) -> impl Future<Output = Result<impl Deref<Target = T> + '_, Error>> {
336        ReceiveNextFuture { subscriber: self }
337    }
338
339    /// Implementation of [`Self::try_next`] that
340    /// takes `self` as an immutable reference with
341    /// interior mutability.
342    #[inline(always)]
343    fn try_next_internal(&self) -> Result<PublishedData<T>, Error> {
344        if let Some(next) = self.receivable_seqs().next() {
345            let data = PublishedData {
346                subscription: self,
347                sequence: next,
348                data: unsafe { self.flow_state.get(next) },
349            };
350
351            Ok(data)
352        } else {
353            Err(Error::Ahead)
354        }
355    }
356
357    /// Returns the range of data sequence numbers
358    /// that are receivable by this subscriber.
359    #[inline(always)]
360    fn receivable_seqs(&self) -> Range<u64> {
361        self.next_receivable_seq.load(Ordering::SeqCst)
362            ..self.flow_state.next_publishable_seq.load(Ordering::SeqCst)
363    }
364
365    /// Marks all sequences up to (and including)
366    /// `sequence` as received by this subscriber.
367    #[inline(always)]
368    fn receive_up_to(&self, sequence: u64) {
369        self.next_receivable_seq
370            .fetch_max(sequence + 1, Ordering::SeqCst);
371    }
372}
373
374/// Future returned by [`FlowSubscriber::next`].
375struct ReceiveNextFuture<'a, T: Flows> {
376    subscriber: &'a FlowSubscriber<T>,
377}
378
379impl<'a, T: Flows> Future for ReceiveNextFuture<'a, T> {
380    type Output = Result<PublishedData<'a, T>, Error>;
381
382    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
383        match self.subscriber.try_next_internal() {
384            Ok(next) => Poll::Ready(Ok(next)),
385            Err(Error::Ahead) => {
386                cx.waker().wake_by_ref();
387                Poll::Pending
388            }
389            Err(e) => Poll::Ready(Err(e)),
390        }
391    }
392}
393
394impl<T> Debug for FlowSubscriber<T>
395where
396    T: Flows,
397{
398    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
399        f.debug_struct("OutBarrier")
400            .field("flow_state", &self.flow_state)
401            .field("next_receivable_seq", &self.next_receivable_seq)
402            .finish()
403    }
404}
405
406// Flow states may be sent between threads and
407// safely accessed concurrently.
408unsafe impl<T> Send for FlowState<T> where T: Flows {}
409unsafe impl<T> Sync for FlowState<T> where T: Flows {}
410
/// Blanket trait for data in a [`Flow`].
///
/// Implemented automatically for every type that is
/// `Send + Sync + 'static`.
pub trait Flows: Send + Sync + 'static {}

impl<T: Send + Sync + 'static> Flows for T {}
414
415/// Reference to mutable, unpublished data in a [`Flow`].
416///
417/// When this reference is dropped, the data
418/// is marked as published into the [`Flow`].
419#[derive(Debug)]
420pub struct UnpublishedData<'a, T: Flows> {
421    flow: &'a Flow<T>,
422    sequence: u64,
423    data: &'a mut T,
424}
425
426impl<T: Flows> UnpublishedData<'_, T> {
427    /// Returns the data's sequence number.
428    pub fn sequence(&self) -> u64 {
429        self.sequence
430    }
431
432    /// Publishes `data` into this sequence.
433    pub fn publish(self, data: T) {
434        *self.data = data;
435        drop(self)
436    }
437}
438
439impl<T: Flows> Deref for UnpublishedData<'_, T> {
440    type Target = T;
441
442    fn deref(&self) -> &Self::Target {
443        self.data
444    }
445}
446
447impl<T: Flows> DerefMut for UnpublishedData<'_, T> {
448    fn deref_mut(&mut self) -> &mut Self::Target {
449        self.data
450    }
451}
452
453impl<T: Flows> Drop for UnpublishedData<'_, T> {
454    fn drop(&mut self) {
455        while !self.flow.state.try_publish(self.sequence) {}
456    }
457}
458
459/// Return value of [`FlowSubscriber::try_next`].
460///
461/// When this value is dropped, the data will
462/// be marked as received by its corresponding
463/// subscriber.
464#[derive(Debug)]
465struct PublishedData<'a, T: Flows> {
466    subscription: &'a FlowSubscriber<T>,
467    sequence: u64,
468    data: &'a T,
469}
470
471impl<T: Flows> Deref for PublishedData<'_, T> {
472    type Target = T;
473
474    fn deref(&self) -> &Self::Target {
475        self.data
476    }
477}
478
479impl<T: Flows> Drop for PublishedData<'_, T> {
480    fn drop(&mut self) {
481        self.subscription.receive_up_to(self.sequence);
482    }
483}
484
485/// Enumeration of non-retryable errors
486/// that may happen while using flows.
487#[derive(Debug, Snafu, PartialEq)]
488pub enum Error {
489    /// Publishing is temporarily impossible:
490    /// the flow is full of unreceived data.
491    Full,
492
493    /// The flow may or may not contain data, but the
494    /// subscriber has already read all data presently
495    /// in the flow.
496    Ahead,
497}
498
#[cfg(test)]
mod test {
    use super::*;

    /// Tests basic API functionality: one value is
    /// published, received, and acknowledged.
    #[test]
    fn pubs_and_subs() -> Result<(), crate::Error> {
        // A flow with room for two entries and one subscriber.
        let (mut publisher, [mut subscriber]) = Flow::new(2);

        // Claim the first sequence, fill it, and publish it
        // by letting the claim go out of scope.
        {
            let mut claimed = publisher.try_next()?;
            *claimed = 42u32;
            assert_eq!(0, claimed.sequence());
        }

        // The published entry is now receivable.
        assert_eq!(0..1, subscriber.receivable_seqs());

        // Receive it; dropping the reference acknowledges it.
        {
            let received = subscriber.try_next()?;
            assert_eq!(42u32, *received);
        }

        // Nothing is left to receive.
        assert_eq!(1..1, subscriber.receivable_seqs());

        Ok(())
    }
}