1#![cfg_attr(all(not(test)), no_std)]
2#![doc = include_str!("../README.md")]
5
6extern crate alloc;
7
8use core::{
9 cell::UnsafeCell,
10 fmt::Debug,
11 future::Future,
12 ops::{Deref, DerefMut, Range},
13 pin::Pin,
14 sync::atomic::Ordering,
15 task::{Context, Poll},
16};
17
18use alloc::{boxed::Box, vec::Vec};
19use portable_atomic::AtomicU64;
20use portable_atomic_util::{Arc, Weak};
21use snafu::Snafu;
22
23pub mod async_support;
24pub mod stage;
25
26#[derive(Debug, Clone)]
31pub struct Flow<T: Flows> {
32 state: Arc<FlowState<T>>,
33}
34
35impl<T: Flows> Flow<T> {
36 pub fn new<const SUB: usize>(capacity: usize) -> (Self, [FlowSubscriber<T>; SUB]) {
45 assert!(capacity & (capacity - 1) == 0, "flow capacity _must_ be a power of two (like `2`, `4`, `256`, `2048`...), not {capacity}");
46
47 let mut buffer = Vec::with_capacity(capacity);
49 for _ in 0..capacity {
50 buffer.push(UnsafeCell::new(T::default()));
51 }
52 let buffer = buffer.into_boxed_slice();
53
54 let mut flow_state = FlowState {
56 buffer,
57 next_writable_seq: AtomicU64::new(0),
58 next_publishable_seq: AtomicU64::new(0),
59 next_receivable_seqs: Vec::with_capacity(SUB),
60 };
61
62 let mut subscriber_seqs = Vec::with_capacity(SUB);
64 for _ in 0..SUB {
65 subscriber_seqs.push(flow_state.add_subscriber_seq());
66 }
67
68 let flow_state = Arc::new(flow_state);
71 let subscribers: Vec<FlowSubscriber<T>> = subscriber_seqs
72 .into_iter()
73 .map(|seq| FlowSubscriber {
74 flow_state: flow_state.clone(),
75 next_receivable_seq: seq,
76 })
77 .collect();
78
79 (Self { state: flow_state }, subscribers.try_into().unwrap())
80 }
81
82 pub fn try_next(&mut self) -> Result<UnpublishedData<T>, Error> {
86 self.try_next_internal()
87 }
88
89 #[allow(clippy::should_implement_trait)]
93 pub fn next(&mut self) -> impl Future<Output = Result<UnpublishedData<T>, Error>> {
94 PublishNextFuture { flow: self }
95 }
96
97 #[inline(always)]
101 fn try_next_internal(&self) -> Result<UnpublishedData<T>, Error> {
102 if let Some(next) = self.state.try_claim_publishable() {
103 let next_item = UnpublishedData {
104 flow: self,
105 sequence: next,
106 data: unsafe { self.state.get_mut(next) },
107 };
108 Ok(next_item)
109 } else {
110 Err(Error::Full)
111 }
112 }
113}
114
115struct PublishNextFuture<'a, T: Flows> {
117 flow: &'a Flow<T>,
118}
119
120impl<'a, T: Flows> Future for PublishNextFuture<'a, T> {
121 type Output = Result<UnpublishedData<'a, T>, Error>;
122
123 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
124 match self.flow.try_next_internal() {
125 Ok(next) => Poll::Ready(Ok(next)),
126 Err(Error::Full) => {
127 cx.waker().wake_by_ref();
128 Poll::Pending
129 }
130 Err(e) => Poll::Ready(Err(e)),
131 }
132 }
133}
134
135struct FlowState<T: Flows> {
142 buffer: Box<[UnsafeCell<T>]>,
155
156 next_writable_seq: AtomicU64,
159
160 next_publishable_seq: AtomicU64,
167
168 next_receivable_seqs: Vec<Weak<AtomicU64>>,
176}
177
178impl<T> FlowState<T>
179where
180 T: Flows,
181{
182 fn add_subscriber_seq(&mut self) -> Arc<AtomicU64> {
185 let next_receivable_seq = Arc::new(AtomicU64::new(0));
186 self.next_receivable_seqs
187 .push(Arc::downgrade(&next_receivable_seq));
188 next_receivable_seq
189 }
190
191 #[inline(always)]
201 fn try_claim_publishable(&self) -> Option<u64> {
202 let next_writable = self.next_writable_seq.load(Ordering::SeqCst);
203
204 let mut min_receivable_seq = self.next_publishable_seq.load(Ordering::SeqCst);
208 for next_received_seq in self.next_receivable_seqs.iter() {
209 if let Some(seq) = next_received_seq.upgrade() {
210 min_receivable_seq = min_receivable_seq.min(seq.load(Ordering::SeqCst));
211 }
212 }
213
214 if min_receivable_seq + self.buffer.len() as u64 > next_writable
216 && self
217 .next_writable_seq
218 .compare_exchange(
219 next_writable,
220 next_writable + 1,
221 Ordering::SeqCst,
222 Ordering::SeqCst,
223 )
224 .is_ok()
225 {
226 return Some(next_writable);
227 }
228
229 None
230 }
231
232 #[inline(always)]
235 fn try_publish(&self, sequence: u64) -> bool {
236 self.next_publishable_seq
237 .compare_exchange_weak(sequence, sequence + 1, Ordering::SeqCst, Ordering::SeqCst)
238 .is_ok()
239 }
240
241 #[allow(clippy::mut_from_ref)]
251 #[inline(always)]
252 unsafe fn get(&self, sequence: u64) -> &T {
253 assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
254
255 let index = (self.buffer.len() - 1) & sequence as usize;
257
258 &*self.buffer.get_unchecked(index).get()
260 }
261
262 #[allow(clippy::mut_from_ref)]
287 #[inline(always)]
288 unsafe fn get_mut(&self, sequence: u64) -> &mut T {
289 assert!(self.buffer.len() & (self.buffer.len() - 1) == 0);
290
291 let index = (self.buffer.len() - 1) & sequence as usize;
293
294 &mut *self.buffer.get_unchecked(index).get()
296 }
297}
298
299impl<T> Debug for FlowState<T>
300where
301 T: Flows,
302{
303 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
304 f.debug_struct("Flow")
305 .field("capacity", &self.buffer.len())
306 .field("next_writable_seq", &self.next_writable_seq)
307 .field("next_publishable_seq", &self.next_publishable_seq)
308 .field("next_receivable_seqs", &self.next_receivable_seqs)
309 .finish()
310 }
311}
312
313pub struct FlowSubscriber<T: Flows> {
315 flow_state: Arc<FlowState<T>>,
316
317 next_receivable_seq: Arc<AtomicU64>,
319}
320
321impl<T: Flows> FlowSubscriber<T> {
322 pub fn try_next(&mut self) -> Result<impl Deref<Target = T> + '_, Error> {
326 self.try_next_internal()
327 }
328
329 #[allow(clippy::should_implement_trait)]
332 pub fn next(&mut self) -> impl Future<Output = Result<impl Deref<Target = T> + '_, Error>> {
333 ReceiveNextFuture { subscriber: self }
334 }
335
336 #[inline(always)]
340 fn try_next_internal(&self) -> Result<PublishedData<T>, Error> {
341 if let Some(next) = self.receivable_seqs().next() {
342 let data = PublishedData {
343 subscription: self,
344 sequence: next,
345 data: unsafe { self.flow_state.get(next) },
346 };
347
348 Ok(data)
349 } else {
350 Err(Error::Ahead)
351 }
352 }
353
354 #[inline(always)]
357 fn receivable_seqs(&self) -> Range<u64> {
358 self.next_receivable_seq.load(Ordering::SeqCst)
359 ..self.flow_state.next_publishable_seq.load(Ordering::SeqCst)
360 }
361
362 #[inline(always)]
365 fn receive_up_to(&self, sequence: u64) {
366 self.next_receivable_seq
367 .fetch_max(sequence + 1, Ordering::SeqCst);
368 }
369}
370
371struct ReceiveNextFuture<'a, T: Flows> {
373 subscriber: &'a FlowSubscriber<T>,
374}
375
376impl<'a, T: Flows> Future for ReceiveNextFuture<'a, T> {
377 type Output = Result<PublishedData<'a, T>, Error>;
378
379 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
380 match self.subscriber.try_next_internal() {
381 Ok(next) => Poll::Ready(Ok(next)),
382 Err(Error::Ahead) => {
383 cx.waker().wake_by_ref();
384 Poll::Pending
385 }
386 Err(e) => Poll::Ready(Err(e)),
387 }
388 }
389}
390
391impl<T> Debug for FlowSubscriber<T>
392where
393 T: Flows,
394{
395 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
396 f.debug_struct("OutBarrier")
397 .field("flow_state", &self.flow_state)
398 .field("next_receivable_seq", &self.next_receivable_seq)
399 .finish()
400 }
401}
402
403unsafe impl<T> Send for FlowState<T> where T: Flows {}
406unsafe impl<T> Sync for FlowState<T> where T: Flows {}
407
/// Marker for types that can be carried through a flow.
///
/// When a flow is created, its slots are pre-filled with `T::default()`. The
/// flow's `unsafe impl Send`/`Sync` rely on the `Send + Sync` bounds. The trait
/// is blanket-implemented for every type that meets the bounds.
pub trait Flows: Default + Send + Sync + 'static {}

impl<T: Default + Send + Sync + 'static> Flows for T {}
411
412#[derive(Debug)]
417pub struct UnpublishedData<'a, T: Flows> {
418 flow: &'a Flow<T>,
419 sequence: u64,
420 data: &'a mut T,
421}
422
423impl<T: Flows> UnpublishedData<'_, T> {
424 pub fn sequence(&self) -> u64 {
426 self.sequence
427 }
428
429 pub fn publish(self, data: T) {
431 *self.data = data;
432 drop(self)
433 }
434}
435
436impl<T: Flows> Deref for UnpublishedData<'_, T> {
437 type Target = T;
438
439 fn deref(&self) -> &Self::Target {
440 self.data
441 }
442}
443
444impl<T: Flows> DerefMut for UnpublishedData<'_, T> {
445 fn deref_mut(&mut self) -> &mut Self::Target {
446 self.data
447 }
448}
449
450impl<T: Flows> Drop for UnpublishedData<'_, T> {
451 fn drop(&mut self) {
452 while !self.flow.state.try_publish(self.sequence) {}
453 }
454}
455
456#[derive(Debug)]
462struct PublishedData<'a, T: Flows> {
463 subscription: &'a FlowSubscriber<T>,
464 sequence: u64,
465 data: &'a T,
466}
467
468impl<T: Flows> Deref for PublishedData<'_, T> {
469 type Target = T;
470
471 fn deref(&self) -> &Self::Target {
472 self.data
473 }
474}
475
476impl<T: Flows> Drop for PublishedData<'_, T> {
477 fn drop(&mut self) {
478 self.subscription.receive_up_to(self.sequence);
479 }
480}
481
482#[derive(Debug, Snafu, PartialEq)]
485pub enum Error {
486 Full,
489
490 Ahead,
494}
495
#[cfg(test)]
mod test {
    use super::*;

    /// One item makes a full round trip from the publisher to a subscriber.
    #[test]
    fn pubs_and_subs() -> Result<(), crate::Error> {
        let (mut publisher, [mut subscriber]) = Flow::new(2);

        let mut data = publisher.try_next().unwrap();
        *data = 42u32;
        assert_eq!(0, data.sequence());
        drop(data);

        assert_eq!(0..1, subscriber.receivable_seqs());

        let data = subscriber.try_next().unwrap();
        assert!(42u32 == *data);
        drop(data);

        assert_eq!(1..1, subscriber.receivable_seqs());

        Ok(())
    }

    /// Publishers are held back until the slowest subscriber frees a slot.
    #[test]
    fn full_until_subscriber_receives() -> Result<(), crate::Error> {
        let (mut publisher, [mut subscriber]) = Flow::<u32>::new(2);

        publisher.try_next()?.publish(1);
        publisher.try_next()?.publish(2);
        assert!(matches!(publisher.try_next(), Err(Error::Full)));

        // Receiving sequence 0 frees its slot for sequence 2.
        let data = subscriber.try_next()?;
        assert_eq!(1, *data);
        drop(data);

        let data = publisher.try_next()?;
        assert_eq!(2, data.sequence());
        data.publish(3);

        for expected in [2u32, 3] {
            let data = subscriber.try_next()?;
            assert_eq!(expected, *data);
        }

        Ok(())
    }

    /// A subscriber with nothing new to receive reports `Ahead`.
    #[test]
    fn subscriber_ahead_of_publisher() {
        let (_publisher, [mut subscriber]) = Flow::<u32>::new(4);
        assert!(matches!(subscriber.try_next(), Err(Error::Ahead)));
    }

    /// A dropped subscriber no longer holds publishers back.
    #[test]
    fn dropped_subscriber_releases_backpressure() -> Result<(), crate::Error> {
        let (mut publisher, [subscriber]) = Flow::<u32>::new(2);
        drop(subscriber);

        for value in 0..4 {
            publisher.try_next()?.publish(value);
        }

        Ok(())
    }

    /// Capacities that are not powers of two are rejected up front.
    #[test]
    #[should_panic(expected = "power of two")]
    fn rejects_non_power_of_two_capacity() {
        let _ = Flow::<u32>::new::<1>(3);
    }
}