real_time/
fifo.rs

1use {
2    crate::{
3        backoff::Backoff,
4        sync::{
5            atomic::{AtomicUsize, Ordering},
6            Arc,
7        },
8    },
9    crossbeam_utils::CachePadded,
10    std::{alloc, cell::UnsafeCell, cmp::PartialEq, marker::PhantomData, ops::Deref},
11};
12
13/// A handle for writing values to the FIFO.
14pub struct Producer<T, const N: usize> {
15    shared: Arc<Shared<T, N>>,
16}
17
18/// A handle for reading values from the FIFO.
19pub struct Consumer<T, const N: usize> {
20    shared: Arc<Shared<T, N>>,
21}
22
23/// Create a new FIFO with the given capacity.
24///
25/// This FIFO is optimised for a consumer running on a real-time thread.
26/// Elements are not dropped when they are popped, instead they will be dropped
27/// when they are overwritten by a later push, or when the buffer is dropped.
28pub fn fifo<T, const N: usize>() -> (Producer<T, N>, Consumer<T, N>) {
29    let shared = Arc::new(Shared::new());
30
31    (
32        Producer {
33            shared: Arc::clone(&shared),
34        },
35        Consumer { shared },
36    )
37}
38
39unsafe impl<T, const N: usize> Send for Producer<T, N> where T: Send {}
40unsafe impl<T, const N: usize> Send for Consumer<T, N> where T: Send {}
41
42struct Shared<T, const N: usize> {
43    buffer: Buffer<T, N>,
44    atomic_head: CachePadded<AtomicHead>,
45    cached_tail: CachePadded<CachedTail>,
46    cached_head: CachePadded<CachedHead>,
47    atomic_tail: CachePadded<AtomicTail>,
48}
49
/// Uninitialised storage for `Buffer::SIZE` elements of `T`, owned through a
/// raw pointer.
///
/// The buffer does not track which slots hold live values, so it never drops
/// elements itself. It only manages the storage.
struct Buffer<T, const N: usize> {
    /// First slot of the storage.
    ptr: *mut T,
}
53
54impl<T, const N: usize> Producer<T, N> {
55    /// Push a value into the FIFO.
56    ///
57    /// If the FIFO is full, this method will block until there is space
58    /// available.
59    pub fn push_blocking(&self, mut value: T) {
60        let backoff = Backoff::default();
61
62        while let Err(value_failed_to_push) = self.push(value) {
63            backoff.snooze();
64            value = value_failed_to_push;
65        }
66    }
67
68    /// Push a value into the FIFO.
69    ///
70    /// If the FIFO is full, this method will not block, and instead returns the
71    /// value to the caller.
72    pub fn push(&self, value: T) -> Result<(), T> {
73        let tail = self.shared.get(Ordering::Relaxed);
74        let head = self.shared.get_cached();
75
76        let size = match size(head, tail) {
77            size if size < N => size,
78            _ => {
79                let head = self.shared.get(Ordering::Acquire);
80                self.shared.set_cached(head);
81                size(head, tail)
82            }
83        };
84
85        debug_assert!(
86            size <= Buffer::<T, N>::SIZE,
87            "size ({}) should not be greater than capacity ({})",
88            size,
89            Buffer::<T, N>::SIZE
90        );
91
92        if size == N {
93            return Err(value);
94        }
95
96        let element = self.shared.buffer.get(tail);
97        if self.shared.buffer.has_wrapped(tail) {
98            unsafe { element.drop_in_place() };
99        }
100        unsafe { element.write(value) };
101
102        self.shared.set(advance(tail), Ordering::Release);
103        Ok(())
104    }
105}
106
107impl<T, const N: usize> Consumer<T, N> {
108    /// Pop a value from the FIFO.
109    pub fn pop(&self) -> Option<T>
110    where
111        T: Copy,
112    {
113        self.pop_head_impl().map(|r| *r)
114    }
115
116    /// Pop a value from the FIFO by reference.
117    ///
118    /// This method is useful when the elements in the FIFO do not implement
119    /// `Copy`.
120    pub fn pop_ref(&mut self) -> Option<PopRef<'_, T, N>> {
121        self.pop_head_impl()
122    }
123
124    fn pop_head_impl(&self) -> Option<PopRef<'_, T, N>> {
125        let head = self.shared.get(Ordering::Relaxed);
126        let tail = self.shared.get_cached();
127
128        let size = match size(head, tail) {
129            0 => {
130                let tail = self.shared.get(Ordering::Acquire);
131                self.shared.set_cached(tail);
132                size(head, tail)
133            }
134            size => size,
135        };
136
137        debug_assert!(
138            size <= Buffer::<T, N>::SIZE,
139            "size ({}) should not be greater than capacity ({})",
140            size,
141            Buffer::<T, N>::SIZE
142        );
143
144        if size == 0 {
145            return None;
146        }
147
148        Some(PopRef {
149            head,
150            consumer: self,
151        })
152    }
153}
154
155impl<T, const N: usize> Shared<T, N> {
156    fn new() -> Self {
157        Self {
158            buffer: Buffer::new(),
159            atomic_head: CachePadded::default(),
160            cached_tail: CachePadded::default(),
161            cached_head: CachePadded::default(),
162            atomic_tail: CachePadded::default(),
163        }
164    }
165}
166
167trait SetCursor<Role> {
168    fn set(&self, cursor: Cursor<Role>, ordering: Ordering);
169    fn set_cached(&self, cursor: Cursor<Role>);
170}
171
172impl<T, const N: usize> SetCursor<HeadRole> for Shared<T, N> {
173    #[inline]
174    fn set(&self, cursor: Head, ordering: Ordering) {
175        self.atomic_head.store(cursor, ordering);
176    }
177
178    #[inline]
179    fn set_cached(&self, cursor: Head) {
180        self.cached_head.set(cursor);
181    }
182}
183
184impl<T, const N: usize> SetCursor<TailRole> for Shared<T, N> {
185    #[inline]
186    fn set(&self, cursor: Tail, ordering: Ordering) {
187        self.atomic_tail.store(cursor, ordering);
188    }
189
190    #[inline]
191    fn set_cached(&self, cursor: Tail) {
192        self.cached_tail.set(cursor);
193    }
194}
195
196trait GetCursor<Role> {
197    fn get(&self, ordering: Ordering) -> Cursor<Role>;
198    fn get_cached(&self) -> Cursor<Role>;
199}
200
201impl<T, const N: usize> GetCursor<HeadRole> for Shared<T, N> {
202    #[inline]
203    fn get(&self, ordering: Ordering) -> Head {
204        self.atomic_head.load(ordering)
205    }
206
207    #[inline]
208    fn get_cached(&self) -> Head {
209        self.cached_head.get()
210    }
211}
212
213impl<T, const N: usize> GetCursor<TailRole> for Shared<T, N> {
214    #[inline]
215    fn get(&self, ordering: Ordering) -> Tail {
216        self.atomic_tail.load(ordering)
217    }
218
219    #[inline]
220    fn get_cached(&self) -> Tail {
221        self.cached_tail.get()
222    }
223}
224
225impl<T, const N: usize> Drop for Shared<T, N> {
226    fn drop(&mut self) {
227        let tail: Tail = self.get(Ordering::Relaxed);
228
229        let elements_to_drop = if self.buffer.has_wrapped(tail) {
230            Buffer::<T, N>::SIZE
231        } else {
232            tail.into()
233        };
234
235        for i in 0..elements_to_drop {
236            let element = self.buffer.at(i);
237            unsafe { element.drop_in_place() };
238        }
239    }
240}
241
242impl<T, const N: usize> Buffer<T, N> {
243    const SIZE: usize = usize::next_power_of_two(N);
244
245    fn new() -> Self {
246        let layout = layout_for::<T>(Self::SIZE);
247
248        let buffer = unsafe { alloc::alloc(layout) };
249        if buffer.is_null() {
250            panic!("failed to allocate buffer");
251        }
252
253        Self { ptr: buffer.cast() }
254    }
255
256    #[inline]
257    fn at(&self, index: usize) -> *mut T {
258        debug_assert!(index < Self::SIZE, "index out of bounds");
259        unsafe { self.ptr.add(index) }
260    }
261
262    #[inline]
263    fn index<Role>(&self, cursor: Cursor<Role>) -> usize {
264        index(cursor, Self::SIZE)
265    }
266
267    #[inline]
268    fn get<Role>(&self, cursor: Cursor<Role>) -> *mut T {
269        self.at(self.index(cursor))
270    }
271
272    #[inline]
273    fn has_wrapped<Role>(&self, Cursor(pos, _): Cursor<Role>) -> bool {
274        pos >= Buffer::<T, N>::SIZE
275    }
276}
277
278impl<T, const N: usize> Drop for Buffer<T, N> {
279    fn drop(&mut self) {
280        let layout = layout_for::<T>(Self::SIZE);
281        unsafe { alloc::dealloc(self.ptr.cast(), layout) };
282    }
283}
284
/// Memory layout of an array of `size` elements of type `T`.
///
/// # Panics
///
/// Panics with "capacity overflow" if the total byte size overflows `usize`
/// or exceeds `isize::MAX`.
fn layout_for<T>(size: usize) -> alloc::Layout {
    alloc::Layout::array::<T>(size).expect("capacity overflow")
}
289
290/// A reference to a value that has been popped from the FIFO.
291pub struct PopRef<'a, T, const N: usize> {
292    head: Head,
293    consumer: &'a Consumer<T, N>,
294}
295
296impl<T, const N: usize> Deref for PopRef<'_, T, N> {
297    type Target = T;
298
299    fn deref(&self) -> &Self::Target {
300        let element = self.consumer.shared.buffer.get(self.head);
301
302        // SAFETY: We have unique access to the value at head for the lifetime of this
303        // guard object. Only once it is dropped will the head cursor be advanced.
304        unsafe { &*element }
305    }
306}
307
308impl<T, const N: usize> Drop for PopRef<'_, T, N> {
309    fn drop(&mut self) {
310        self.consumer
311            .shared
312            .set(advance(self.head), Ordering::Release);
313    }
314}
315
/// An unbounded position in the FIFO, tagged with a `Role` marker so that
/// head and tail positions are distinct types.
///
/// Positions are only moved forward (by `advance`); `index` reduces them to a
/// buffer slot.
#[derive(Debug, Copy, Clone)]
#[repr(transparent)]
struct Cursor<Role>(usize, PhantomData<Role>);
319
320#[repr(transparent)]
321struct AtomicCursor<Role>(AtomicUsize, PhantomData<Role>);
322
323impl<Role> Default for AtomicCursor<Role> {
324    fn default() -> Self {
325        Self(AtomicUsize::new(0), PhantomData)
326    }
327}
328
329impl<Role> AtomicCursor<Role> {
330    #[inline]
331    fn load(&self, ordering: Ordering) -> Cursor<Role> {
332        Cursor(self.0.load(ordering), PhantomData)
333    }
334
335    #[inline]
336    fn store(&self, Cursor(cursor, _): Cursor<Role>, ordering: Ordering) {
337        self.0.store(cursor, ordering);
338    }
339}
340
341#[repr(transparent)]
342struct CachedCursor<Role>(UnsafeCell<Cursor<Role>>);
343
344impl<Role> Default for CachedCursor<Role> {
345    fn default() -> Self {
346        Self(UnsafeCell::new(Cursor(0, PhantomData)))
347    }
348}
349
350impl<Role> CachedCursor<Role> {
351    #[inline]
352    fn get(&self) -> Cursor<Role>
353    where
354        Cursor<Role>: Copy,
355    {
356        unsafe { *self.0.get() }
357    }
358
359    #[inline]
360    fn set(&self, cursor: Cursor<Role>) {
361        unsafe { *self.0.get() = cursor }
362    }
363}
364
365#[inline]
366fn size(Cursor(head, _): Head, Cursor(tail, _): Tail) -> usize {
367    tail - head
368}
369
370#[inline]
371fn advance<Role>(Cursor(cursor, _): Cursor<Role>) -> Cursor<Role> {
372    Cursor(cursor + 1, PhantomData)
373}
374
375#[inline]
376fn index<Role>(Cursor(cursor, _): Cursor<Role>, size: usize) -> usize {
377    debug_assert!(
378        size.is_power_of_two(),
379        "size must be a power of two, got {size:?}",
380    );
381    cursor & (size - 1)
382}
383
384#[derive(Debug, Copy, Clone)]
385struct HeadRole;
386
387#[derive(Debug, Copy, Clone)]
388struct TailRole;
389
390type Head = Cursor<HeadRole>;
391
392type Tail = Cursor<TailRole>;
393
394type AtomicHead = AtomicCursor<HeadRole>;
395
396type AtomicTail = AtomicCursor<TailRole>;
397
398type CachedHead = CachedCursor<HeadRole>;
399
400type CachedTail = CachedCursor<TailRole>;
401
402impl<RoleA, RoleB> PartialEq<Cursor<RoleA>> for Cursor<RoleB> {
403    fn eq(&self, other: &Cursor<RoleA>) -> bool {
404        self.0 == other.0
405    }
406}
407
408impl<RoleA, RoleB> PartialOrd<Cursor<RoleA>> for Cursor<RoleB> {
409    fn partial_cmp(&self, other: &Cursor<RoleA>) -> Option<std::cmp::Ordering> {
410        self.0.partial_cmp(&other.0)
411    }
412}
413
414impl<Role> From<usize> for Cursor<Role> {
415    fn from(value: usize) -> Self {
416        Cursor(value, PhantomData)
417    }
418}
419
420impl<Role> From<Cursor<Role>> for usize {
421    fn from(Cursor(cursor, _): Cursor<Role>) -> usize {
422        cursor
423    }
424}
425
#[cfg(test)]
mod test {
    use {
        super::*,
        static_assertions::{assert_impl_all, assert_not_impl_any},
        std::thread,
    };

    assert_impl_all!(Producer<i32, 8>: Send);
    assert_not_impl_any!(Producer<i32, 8>: Sync, Copy, Clone);

    assert_impl_all!(Consumer<i32, 8>: Send);
    assert_not_impl_any!(Consumer<i32, 8>: Sync, Copy, Clone);

    /// Number of elements currently in the FIFO, read straight from the shared
    /// atomic cursors.
    fn get_buffer_size<T, const N: usize>(producer: &Producer<T, N>) -> usize {
        size(
            producer.shared.get(Ordering::Relaxed),
            producer.shared.get(Ordering::Relaxed),
        )
    }

    /// A value whose clones all share one counter that is bumped on every drop.
    #[derive(Debug, Default, Clone)]
    struct DropCounter(Arc<AtomicUsize>);

    impl DropCounter {
        fn count(&self) -> usize {
            self.0.load(Ordering::Relaxed)
        }
    }

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    fn head(pos: usize) -> Head {
        Cursor(pos, PhantomData)
    }

    fn tail(pos: usize) -> Tail {
        Cursor(pos, PhantomData)
    }

    #[test]
    fn querying_size() {
        assert_eq!(size(head(0), tail(0)), 0);
        assert_eq!(size(head(0), tail(1)), 1);
        assert_eq!(size(head(0), tail(2)), 2);
        assert_eq!(size(head(0), tail(3)), 3);
        assert_eq!(size(head(1), tail(3)), 2);
        assert_eq!(size(head(2), tail(3)), 1);
        assert_eq!(size(head(3), tail(3)), 0);
    }

    #[test]
    fn advancing_cursors() {
        let cursor = head(0);

        let cursor = advance(cursor);
        assert_eq!(cursor, head(1));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(2));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(3));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(4));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(5));
    }

    #[test]
    fn cursor_to_index() {
        assert_eq!(index(head(0), 4), 0);
        assert_eq!(index(head(1), 4), 1);
        assert_eq!(index(head(2), 4), 2);
        assert_eq!(index(head(3), 4), 3);
        assert_eq!(index(head(4), 4), 0);
        assert_eq!(index(head(5), 4), 1);
        assert_eq!(index(head(6), 4), 2);
        assert_eq!(index(head(7), 4), 3);
        assert_eq!(index(head(8), 4), 0);
    }

    #[test]
    fn using_a_fifo() {
        let (tx, rx) = fifo::<i32, 3>();
        assert_eq!(get_buffer_size(&tx), 0);

        assert!(rx.pop().is_none());

        tx.push(5).unwrap();

        assert_eq!(rx.pop(), Some(5));

        tx.push(1).unwrap();
        tx.push(2).unwrap();
        tx.push(3).unwrap();

        let push_result = tx.push(4);
        assert_eq!(push_result, Err(4));

        assert_eq!(rx.pop(), Some(1));

        let push_result = tx.push(4);
        assert!(push_result.is_ok());

        let (tx, mut rx) = fifo::<String, 2>();
        tx.push("hello".to_string()).unwrap();

        let value_ref = rx.pop_ref();
        assert!(value_ref.is_some());
        assert_eq!(value_ref.unwrap().as_str(), "hello");
    }

    #[test]
    fn elements_are_dropped_when_overwritten() {
        let drop_counter = DropCounter::default();
        let (tx, mut rx) = fifo::<_, 3>();

        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 1);
    }

    #[test]
    fn elements_are_dropped_when_buffer_is_dropped() {
        let drop_counter = DropCounter::default();
        let (tx, mut rx) = fifo::<_, 3>();

        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        drop((tx, rx));

        assert_eq!(drop_counter.count(), 4);
    }

    #[test]
    fn reading_and_writing_on_different_threads() {
        let (writer, reader) = fifo::<_, 12>();

        #[cfg(miri)]
        const NUM_WRITES: usize = 128;

        #[cfg(not(miri))]
        const NUM_WRITES: usize = 1_000_000;

        let writer_thread = thread::spawn(move || {
            for value in 1..=NUM_WRITES {
                writer.push_blocking(value);
            }
        });

        let mut last = None;
        while last != Some(NUM_WRITES) {
            match reader.pop() {
                Some(value) => {
                    if let Some(last) = last {
                        assert_eq!(last + 1, value, "values should be popped in order");
                    }
                    last = Some(value);
                }
                None => thread::yield_now(),
            }
        }

        // Do not leave the writer detached: wait for it to finish so it cannot
        // outlive the test, and surface any panic it hit.
        writer_thread.join().expect("writer thread panicked");
    }
}