codas_flow/
lib.rs

1#![cfg_attr(all(not(test)), no_std)]
2// Use the README file as the root-level
3// docs for this library.
4#![doc = include_str!("../README.md")]
5
6extern crate alloc;
7
8use core::{
9    cell::UnsafeCell,
10    fmt::Debug,
11    future::Future,
12    ops::{Deref, DerefMut, Range},
13    pin::Pin,
14    sync::atomic::Ordering,
15    task::{Context, Poll},
16};
17
18use alloc::{boxed::Box, vec::Vec};
19use portable_atomic::AtomicU64;
20use portable_atomic_util::{Arc, Weak};
21use snafu::Snafu;
22
23pub mod async_support;
24pub mod stage;
25
26/// Bounded queue for publishing and receiving
27/// data from (a)synchronous tasks.
28///
29/// Refer to the [crate] docs for more info.
30#[derive(Debug, Clone)]
31pub struct Flow<T: Flows> {
32    state: Arc<FlowState<T>>,
33}
34
35impl<T: Flows> Flow<T> {
36    /// Returns a tuple of `(flow, [subscribers])`,
37    /// where `capacity` is the maximum capacity
38    /// of the flow.
39    ///
40    /// # Panics
41    ///
42    /// Iff `capacity` is _not_ a power of two
43    /// (like `2`, `32`, `256`, and so on).
44    pub fn new<const SUB: usize>(capacity: usize) -> (Self, [FlowSubscriber<T>; SUB]) {
45        assert!(capacity & (capacity - 1) == 0, "flow capacity _must_ be a power of two (like `2`, `4`, `256`, `2048`...), not {capacity}");
46
47        // Allocate the flow buffer.
48        let mut buffer = Vec::with_capacity(capacity);
49        for _ in 0..capacity {
50            buffer.push(UnsafeCell::new(T::default()));
51        }
52        let buffer = buffer.into_boxed_slice();
53
54        // Build the flow state.
55        let mut flow_state = FlowState {
56            buffer,
57            next_writable_seq: AtomicU64::new(0),
58            next_publishable_seq: AtomicU64::new(0),
59            next_receivable_seqs: Vec::with_capacity(SUB),
60        };
61
62        // Add subscribers to the state.
63        let mut subscriber_seqs = Vec::with_capacity(SUB);
64        for _ in 0..SUB {
65            subscriber_seqs.push(flow_state.add_subscriber_seq());
66        }
67
68        // Finalize flow state and wrap subscriber
69        // sequences in the subscriber API.
70        let flow_state = Arc::new(flow_state);
71        let subscribers: Vec<FlowSubscriber<T>> = subscriber_seqs
72            .into_iter()
73            .map(|seq| FlowSubscriber {
74                flow_state: flow_state.clone(),
75                next_receivable_seq: seq,
76            })
77            .collect();
78
79        (Self { state: flow_state }, subscribers.try_into().unwrap())
80    }
81
82    /// Tries to claim the next publishable
83    /// sequence in the flow, returning
84    /// a [`UnpublishedData`] iff successful.
85    pub fn try_next(&mut self) -> Result<UnpublishedData<T>, Error> {
86        self.try_next_internal()
87    }
88
89    /// Awaits and claims the next publishable sequence
90    /// in the flow, returning a [`UnpublishedData`]
91    /// iff successful.
92    #[allow(clippy::should_implement_trait)]
93    pub fn next(&mut self) -> impl Future<Output = Result<UnpublishedData<T>, Error>> {
94        PublishNextFuture { flow: self }
95    }
96
97    /// Implementation of [`Self::try_next`] that
98    /// takes `self` as an immutable reference with
99    /// interior mutability.
100    #[inline(always)]
101    fn try_next_internal(&self) -> Result<UnpublishedData<T>, Error> {
102        if let Some(next) = self.state.try_claim_publishable() {
103            let next_item = UnpublishedData {
104                flow: self,
105                sequence: next,
106                data: unsafe { self.state.get_mut(next) },
107            };
108            Ok(next_item)
109        } else {
110            Err(Error::Full)
111        }
112    }
113}
114
115/// Future returned by [`Flow::next`].
116struct PublishNextFuture<'a, T: Flows> {
117    flow: &'a Flow<T>,
118}
119
120impl<'a, T: Flows> Future for PublishNextFuture<'a, T> {
121    type Output = Result<UnpublishedData<'a, T>, Error>;
122
123    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
124        match self.flow.try_next_internal() {
125            Ok(next) => Poll::Ready(Ok(next)),
126            Err(Error::Full) => {
127                cx.waker().wake_by_ref();
128                Poll::Pending
129            }
130            Err(e) => Poll::Ready(Err(e)),
131        }
132    }
133}
134
135/// Internal state of a [`Flow`].
136///
137/// This state is placed in a separate data
138/// structure from the rest of a [`Flow`]
139/// to simplify sharing references to the state
140/// between a flow and it's subscribers.
141struct FlowState<T: Flows> {
142    /// Pre-allocated contiguous buffer of
143    /// data entries in the flow.
144    ///
145    /// This buffer is a ring buffer: When it
146    /// is full, writes "wrap" around to the
147    /// beginning of the buffer, overwriting
148    /// the oldest data.
149    ///
150    /// Each data entry in the buffer is
151    /// wrapped in an [`UnsafeCell`], enabling
152    /// concurrent tasks to immutably read
153    /// the same data at the same time.
154    buffer: Box<[UnsafeCell<T>]>,
155
156    /// The sequence number that will be assigned
157    /// to the _next_ data entry written into the flow.
158    next_writable_seq: AtomicU64,
159
160    /// The sequence number of the next data entry
161    /// that will become readable by the flow's
162    /// subscriber(s).
163    ///
164    /// All data entries with sequences less than
165    /// this number are assumed to be readable.
166    next_publishable_seq: AtomicU64,
167
168    /// The sequence numbers of the next data entry
169    /// that will be read by each of the flow's
170    /// subscriber(s).
171    ///
172    /// All data entries with sequences less than
173    /// the _lowest_ of these sequence numbers are
174    /// assumed to be overwritable.
175    next_receivable_seqs: Vec<Weak<AtomicU64>>,
176}
177
178impl<T> FlowState<T>
179where
180    T: Flows,
181{
182    /// Adds and returns a new subscriber sequence
183    /// number to the flow.
184    fn add_subscriber_seq(&mut self) -> Arc<AtomicU64> {
185        let next_receivable_seq = Arc::new(AtomicU64::new(0));
186        self.next_receivable_seqs
187            .push(Arc::downgrade(&next_receivable_seq));
188        next_receivable_seq
189    }
190
191    /// Tries to claim and return the next
192    /// publishable data sequence in the flow.
193    ///
194    /// Iff `Some(sequence)` is returned, the
195    /// sequence _must_ be published via
196    /// [`Self::try_publish`], or the flow
197    /// will stall from backpressure.
198    ///
199    /// Iff `None` is returned, the flow is full.
200    #[inline(always)]
201    fn try_claim_publishable(&self) -> Option<u64> {
202        let next_writable = self.next_writable_seq.load(Ordering::SeqCst);
203
204        // Calculate the minimum receivable sequence
205        // across all subscribers, defaulting to the
206        // current sequence that's publishable.
207        let mut min_receivable_seq = self.next_publishable_seq.load(Ordering::SeqCst);
208        for next_received_seq in self.next_receivable_seqs.iter() {
209            if let Some(seq) = next_received_seq.upgrade() {
210                min_receivable_seq = min_receivable_seq.min(seq.load(Ordering::SeqCst));
211            }
212        }
213
214        // Only claim if there's space.
215        if min_receivable_seq + self.buffer.len() as u64 > next_writable
216            && self
217                .next_writable_seq
218                .compare_exchange(
219                    next_writable,
220                    next_writable + 1,
221                    Ordering::SeqCst,
222                    Ordering::SeqCst,
223                )
224                .is_ok()
225        {
226            return Some(next_writable);
227        }
228
229        None
230    }
231
232    /// Tries to publish `sequence`, returning
233    /// true iff the sequence was published.
234    #[inline(always)]
235    fn try_publish(&self, sequence: u64) -> bool {
236        self.next_publishable_seq
237            .compare_exchange_weak(sequence, sequence + 1, Ordering::SeqCst, Ordering::SeqCst)
238            .is_ok()
239    }
240
241    /// Returns a reference to the data at `sequence`.
242    ///
243    /// Refer to [`Self::get_mut`] for information
244    /// on the safety properties of this function.
245    ///
246    /// # Panics
247    ///
248    /// Iff any other thread attempts to acquire a _mutable_
249    /// reference to `sequence` at the same time.
250    #[allow(clippy::mut_from_ref)]
251    #[inline(always)]
252    unsafe fn get(&self, sequence: u64) -> &T {
253        assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
254
255        // Convert sequence to an queue index.
256        let index = (self.buffer.len() - 1) & sequence as usize;
257
258        // Array access will always be within bounds.
259        &*self.buffer.get_unchecked(index).get()
260    }
261
262    /// Returns a mutable reference to the data at `sequence`.
263    ///
264    /// # Safety
265    ///
266    /// This function is unsafe because it _can_ return
267    /// multiple mutable references to the sae data.
268    ///
269    /// This function _is_ safe to call from any task which
270    /// has successfully claimed a sequence number via
271    /// [`Self::try_claim_publishable`] and
272    /// has not yet published that sequence number
273    /// via [`Self::try_publish`]. In this scenario,
274    /// the task is guaranteed to be the only one with
275    /// read/write access to the data.
276    ///
277    /// This function's behavior is undefined if the task
278    /// (having claimed a sequence number via
279    /// [`Self::try_claim_publishable`]) calls this
280    /// function _repeatedly_ with the same sequence number.
281    ///
282    /// # Panics
283    ///
284    /// Iff the same or different tasks attempt to acquire
285    /// more than one _mutable_ reference to `sequence`.
286    #[allow(clippy::mut_from_ref)]
287    #[inline(always)]
288    unsafe fn get_mut(&self, sequence: u64) -> &mut T {
289        assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
290
291        // Convert sequence to an queue index.
292        let index = (self.buffer.len() - 1) & sequence as usize;
293
294        // Array access will always be within bounds.
295        &mut *self.buffer.get_unchecked(index).get()
296    }
297}
298
299impl<T> Debug for FlowState<T>
300where
301    T: Flows,
302{
303    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
304        f.debug_struct("Flow")
305            .field("capacity", &self.buffer.len())
306            .field("next_writable_seq", &self.next_writable_seq)
307            .field("next_publishable_seq", &self.next_publishable_seq)
308            .field("next_receivable_seqs", &self.next_receivable_seqs)
309            .finish()
310    }
311}
312
313/// Subscriber which receives data from a [`Flow`].
314pub struct FlowSubscriber<T: Flows> {
315    flow_state: Arc<FlowState<T>>,
316
317    /// See [`FlowState::next_receivable_seqs`].
318    next_receivable_seq: Arc<AtomicU64>,
319}
320
321impl<T: Flows> FlowSubscriber<T> {
322    /// Returns a reference to the next data
323    /// in the flow, if the flow is active and
324    /// any data is available.
325    pub fn try_next(&mut self) -> Result<impl Deref<Target = T> + '_, Error> {
326        self.try_next_internal()
327    }
328
329    /// Awaits and returns a reference to the next
330    /// data  in the flow, if the flow is active.
331    #[allow(clippy::should_implement_trait)]
332    pub fn next(&mut self) -> impl Future<Output = Result<impl Deref<Target = T> + '_, Error>> {
333        ReceiveNextFuture { subscriber: self }
334    }
335
336    /// Implementation of [`Self::try_next`] that
337    /// takes `self` as an immutable reference with
338    /// interior mutability.
339    #[inline(always)]
340    fn try_next_internal(&self) -> Result<PublishedData<T>, Error> {
341        if let Some(next) = self.receivable_seqs().next() {
342            let data = PublishedData {
343                subscription: self,
344                sequence: next,
345                data: unsafe { self.flow_state.get(next) },
346            };
347
348            Ok(data)
349        } else {
350            Err(Error::Ahead)
351        }
352    }
353
354    /// Returns the range of data sequence numbers
355    /// that are receivable by this subscriber.
356    #[inline(always)]
357    fn receivable_seqs(&self) -> Range<u64> {
358        self.next_receivable_seq.load(Ordering::SeqCst)
359            ..self.flow_state.next_publishable_seq.load(Ordering::SeqCst)
360    }
361
362    /// Marks all sequences up to (and including)
363    /// `sequence` as received by this subscriber.
364    #[inline(always)]
365    fn receive_up_to(&self, sequence: u64) {
366        self.next_receivable_seq
367            .fetch_max(sequence + 1, Ordering::SeqCst);
368    }
369}
370
371/// Future returned by [`FlowSubscriber::next`].
372struct ReceiveNextFuture<'a, T: Flows> {
373    subscriber: &'a FlowSubscriber<T>,
374}
375
376impl<'a, T: Flows> Future for ReceiveNextFuture<'a, T> {
377    type Output = Result<PublishedData<'a, T>, Error>;
378
379    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
380        match self.subscriber.try_next_internal() {
381            Ok(next) => Poll::Ready(Ok(next)),
382            Err(Error::Ahead) => {
383                cx.waker().wake_by_ref();
384                Poll::Pending
385            }
386            Err(e) => Poll::Ready(Err(e)),
387        }
388    }
389}
390
391impl<T> Debug for FlowSubscriber<T>
392where
393    T: Flows,
394{
395    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
396        f.debug_struct("OutBarrier")
397            .field("flow_state", &self.flow_state)
398            .field("next_receivable_seq", &self.next_receivable_seq)
399            .finish()
400    }
401}
402
// Flow states may be sent between threads and
// safely accessed concurrently.
//
// SAFETY: the `UnsafeCell` buffer makes `FlowState` `!Sync` by
// default. `T: Flows` requires `T: Send + Sync`. Access to each
// slot is intended to be coordinated by the sequence counters:
// - writers hold a slot only between `try_claim_publishable` and
//   `try_publish`;
// - readers only read sequences below `next_publishable_seq` that
//   they have not yet marked as received;
// - claims are refused while any live subscriber may still read
//   the target slot.
// Soundness of these impls relies on that protocol being upheld.
unsafe impl<T> Send for FlowState<T> where T: Flows {}
unsafe impl<T> Sync for FlowState<T> where T: Flows {}
407
/// Marker trait for data carried by a [`Flow`].
///
/// It is implemented automatically for every type that is
/// [`Default`] (the flow's buffer is pre-filled with default
/// values), [`Send`] and [`Sync`] (entries may be written and
/// read from different tasks), and `'static`.
pub trait Flows: Default + Send + Sync + 'static {}

impl<T: Default + Send + Sync + 'static> Flows for T {}
411
412/// Reference to mutable, unpublished data in a [`Flow`].
413///
414/// When this reference is dropped, the data
415/// is marked as published into the [`Flow`].
416#[derive(Debug)]
417pub struct UnpublishedData<'a, T: Flows> {
418    flow: &'a Flow<T>,
419    sequence: u64,
420    data: &'a mut T,
421}
422
423impl<T: Flows> UnpublishedData<'_, T> {
424    /// Returns the data's sequence number.
425    pub fn sequence(&self) -> u64 {
426        self.sequence
427    }
428
429    /// Publishes `data` into this sequence.
430    pub fn publish(self, data: T) {
431        *self.data = data;
432        drop(self)
433    }
434}
435
436impl<T: Flows> Deref for UnpublishedData<'_, T> {
437    type Target = T;
438
439    fn deref(&self) -> &Self::Target {
440        self.data
441    }
442}
443
444impl<T: Flows> DerefMut for UnpublishedData<'_, T> {
445    fn deref_mut(&mut self) -> &mut Self::Target {
446        self.data
447    }
448}
449
450impl<T: Flows> Drop for UnpublishedData<'_, T> {
451    fn drop(&mut self) {
452        while !self.flow.state.try_publish(self.sequence) {}
453    }
454}
455
456/// Return value of [`FlowSubscriber::try_next`].
457///
458/// When this value is dropped, the data will
459/// be marked as received by its corresponding
460/// subscriber.
461#[derive(Debug)]
462struct PublishedData<'a, T: Flows> {
463    subscription: &'a FlowSubscriber<T>,
464    sequence: u64,
465    data: &'a T,
466}
467
468impl<T: Flows> Deref for PublishedData<'_, T> {
469    type Target = T;
470
471    fn deref(&self) -> &Self::Target {
472        self.data
473    }
474}
475
476impl<T: Flows> Drop for PublishedData<'_, T> {
477    fn drop(&mut self) {
478        self.subscription.receive_up_to(self.sequence);
479    }
480}
481
482/// Enumeration of non-retryable errors
483/// that may happen while using flows.
484#[derive(Debug, Snafu, PartialEq)]
485pub enum Error {
486    /// Publishing is temporarily impossible:
487    /// the flow is full of unreceived data.
488    Full,
489
490    /// The flow may or may not contain data, but the
491    /// subscriber has already read all data presently
492    /// in the flow.
493    Ahead,
494}
495
#[cfg(test)]
mod test {
    use super::*;

    /// Tests basic API functionality.
    #[test]
    fn pubs_and_subs() -> Result<(), crate::Error> {
        // Prepare pubsub.
        let (mut publisher, [mut subscriber]) = Flow::new(2);

        // Publish some data.
        let mut data = publisher.try_next().unwrap();
        *data = 42u32;
        assert_eq!(0, data.sequence());
        drop(data);

        // Check barrier sequences.
        assert_eq!(0..1, subscriber.receivable_seqs());

        // Receive some data.
        let data = subscriber.try_next().unwrap();
        assert!(42u32 == *data);
        drop(data);

        // Check barrier sequences.
        assert_eq!(1..1, subscriber.receivable_seqs());

        Ok(())
    }

    /// Tests that slow subscribers apply backpressure to
    /// publishers, that caught-up subscribers report
    /// `Ahead`, and that data survives buffer wrap-around.
    #[test]
    fn backpressure_and_wraparound() {
        let (mut publisher, [mut subscriber]) = Flow::<u32>::new(2);

        // Nothing has been published yet.
        assert!(matches!(subscriber.try_next(), Err(Error::Ahead)));

        // Fill the flow.
        publisher.try_next().unwrap().publish(1);
        publisher.try_next().unwrap().publish(2);

        // The subscriber hasn't received anything, so the flow is full.
        assert!(matches!(publisher.try_next(), Err(Error::Full)));

        // Receiving one entry frees exactly one slot.
        assert_eq!(1, *subscriber.try_next().unwrap());
        let wrapped = publisher.try_next().unwrap();
        assert_eq!(2, wrapped.sequence());
        wrapped.publish(3);
        assert!(matches!(publisher.try_next(), Err(Error::Full)));

        // Remaining entries arrive in publication order.
        assert_eq!(2, *subscriber.try_next().unwrap());
        assert_eq!(3, *subscriber.try_next().unwrap());
        assert!(matches!(subscriber.try_next(), Err(Error::Ahead)));
    }
}