use {
    crate::{
        backoff::Backoff,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    },
    crossbeam_utils::CachePadded,
    std::{alloc, cell::UnsafeCell, cmp::PartialEq, marker::PhantomData, ops::Deref},
};

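/// The sending half of the queue. There is exactly one producer per queue;
/// it may be moved to another thread (`Send`) but not shared (`!Sync`).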
pub struct Producer<T, const N: usize> {
    shared: Arc<Shared<T, N>>,
}

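/// The receiving half of the queue. Like `Producer`, it is `Send` but
/// deliberately `!Sync`: the algorithm assumes a single consumer.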
pub struct Consumer<T, const N: usize> {
    shared: Arc<Shared<T, N>>,
}

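/// Creates a single-producer, single-consumer FIFO holding up to `N`
/// elements (backed by a buffer of `N` rounded up to a power of two).
///
/// A minimal usage sketch (element type and capacity are illustrative):
///
/// ```ignore
/// let (tx, rx) = fifo::<u32, 8>();
/// tx.push(1).unwrap();
/// assert_eq!(rx.pop(), Some(1));
/// assert_eq!(rx.pop(), None);
/// ```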
pub fn fifo<T, const N: usize>() -> (Producer<T, N>, Consumer<T, N>) {
    let shared = Arc::new(Shared::new());

    (
        Producer {
            shared: Arc::clone(&shared),
        },
        Consumer { shared },
    )
}

// SAFETY: each half may be moved to another thread (hence `T: Send`), but
// neither half is `Sync`, so the single-producer/single-consumer discipline
// is preserved.
unsafe impl<T, const N: usize> Send for Producer<T, N> where T: Send {}
unsafe impl<T, const N: usize> Send for Consumer<T, N> where T: Send {}

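/// State shared by both halves of the queue. Every cursor sits on its own
/// cache line (`CachePadded`) to avoid false sharing: the consumer writes
/// `atomic_head` and `cached_tail`, while the producer writes `cached_head`
/// and `atomic_tail`. The cached cursors are cheap local snapshots of the
/// other side's atomic cursor, refreshed only when the queue looks full or
/// empty.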
struct Shared<T, const N: usize> {
    buffer: Buffer<T, N>,
    atomic_head: CachePadded<AtomicHead>,
    cached_tail: CachePadded<CachedTail>,
    cached_head: CachePadded<CachedHead>,
    atomic_tail: CachePadded<AtomicTail>,
}

struct Buffer<T, const N: usize> {
    ptr: *mut T,
}

impl<T, const N: usize> Producer<T, N> {
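    /// Pushes `value`, spinning with backoff until a slot frees up. Only the
    /// consumer can free a slot, so this will spin indefinitely if the
    /// consumer has been dropped.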
    pub fn push_blocking(&self, mut value: T) {
        let backoff = Backoff::default();

        while let Err(value_failed_to_push) = self.push(value) {
            backoff.snooze();
            value = value_failed_to_push;
        }
    }

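    /// Attempts to push `value` without blocking, handing it back as
    /// `Err(value)` if the queue is full. An illustrative sketch:
    ///
    /// ```ignore
    /// let (tx, _rx) = fifo::<u32, 1>();
    /// assert!(tx.push(1).is_ok());
    /// assert_eq!(tx.push(2), Err(2));
    /// ```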
    pub fn push(&self, value: T) -> Result<(), T> {
        let tail = self.shared.get(Ordering::Relaxed);
        let head = self.shared.get_cached();

        // Only when the queue looks full through the cached head do we pay
        // for an `Acquire` load of the consumer's head and refresh the cache.
        let size = match size(head, tail) {
            size if size < N => size,
            _ => {
                let head = self.shared.get(Ordering::Acquire);
                self.shared.set_cached(head);
                size(head, tail)
            }
        };

        debug_assert!(
            size <= Buffer::<T, N>::SIZE,
            "size ({}) should not be greater than capacity ({})",
            size,
            Buffer::<T, N>::SIZE
        );

        if size == N {
            return Err(value);
        }

        // Once the tail has lapped the buffer, every slot still holds a
        // previously pushed value (popping only advances the head), so the
        // stale element must be dropped before the slot is reused.
        let element = self.shared.buffer.get(tail);
        if self.shared.buffer.has_wrapped(tail) {
            unsafe { element.drop_in_place() };
        }
        unsafe { element.write(value) };

        self.shared.set(advance(tail), Ordering::Release);
        Ok(())
    }
}

impl<T, const N: usize> Consumer<T, N> {
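    /// Pops the oldest element by copy, or returns `None` if the queue is
    /// empty. Restricted to `T: Copy`; use `pop_ref` for non-`Copy` types.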
    pub fn pop(&self) -> Option<T>
    where
        T: Copy,
    {
        self.pop_head_impl().map(|r| *r)
    }

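    /// Pops the oldest element by reference. The slot is handed back to the
    /// producer only when the returned `PopRef` is dropped; taking
    /// `&mut self` ensures two `PopRef`s can never overlap.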
    pub fn pop_ref(&mut self) -> Option<PopRef<'_, T, N>> {
        self.pop_head_impl()
    }

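    /// Shared implementation for `pop` and `pop_ref`: checks emptiness and,
    /// if an element is available, wraps the current head in a `PopRef`
    /// without yet advancing it.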
    fn pop_head_impl(&self) -> Option<PopRef<'_, T, N>> {
        let head = self.shared.get(Ordering::Relaxed);
        let tail = self.shared.get_cached();

        // Only when the queue looks empty through the cached tail do we pay
        // for an `Acquire` load of the producer's tail and refresh the cache.
        let size = match size(head, tail) {
            0 => {
                let tail = self.shared.get(Ordering::Acquire);
                self.shared.set_cached(tail);
                size(head, tail)
            }
            size => size,
        };

        debug_assert!(
            size <= Buffer::<T, N>::SIZE,
            "size ({}) should not be greater than capacity ({})",
            size,
            Buffer::<T, N>::SIZE
        );

        if size == 0 {
            return None;
        }

        Some(PopRef {
            head,
            consumer: self,
        })
    }
}

impl<T, const N: usize> Shared<T, N> {
    fn new() -> Self {
        Self {
            buffer: Buffer::new(),
            atomic_head: CachePadded::default(),
            cached_tail: CachePadded::default(),
            cached_head: CachePadded::default(),
            atomic_tail: CachePadded::default(),
        }
    }
}

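/// Getter/setter pairs for the two cursors, parameterized by role so that
/// `Shared` exposes one uniform interface for both. Type inference (or an
/// explicit annotation) selects the head or tail impl at each call site.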
trait SetCursor<Role> {
    fn set(&self, cursor: Cursor<Role>, ordering: Ordering);
    fn set_cached(&self, cursor: Cursor<Role>);
}

impl<T, const N: usize> SetCursor<HeadRole> for Shared<T, N> {
    #[inline]
    fn set(&self, cursor: Head, ordering: Ordering) {
        self.atomic_head.store(cursor, ordering);
    }

    #[inline]
    fn set_cached(&self, cursor: Head) {
        self.cached_head.set(cursor);
    }
}

impl<T, const N: usize> SetCursor<TailRole> for Shared<T, N> {
    #[inline]
    fn set(&self, cursor: Tail, ordering: Ordering) {
        self.atomic_tail.store(cursor, ordering);
    }

    #[inline]
    fn set_cached(&self, cursor: Tail) {
        self.cached_tail.set(cursor);
    }
}

trait GetCursor<Role> {
    fn get(&self, ordering: Ordering) -> Cursor<Role>;
    fn get_cached(&self) -> Cursor<Role>;
}

impl<T, const N: usize> GetCursor<HeadRole> for Shared<T, N> {
    #[inline]
    fn get(&self, ordering: Ordering) -> Head {
        self.atomic_head.load(ordering)
    }

    #[inline]
    fn get_cached(&self) -> Head {
        self.cached_head.get()
    }
}

impl<T, const N: usize> GetCursor<TailRole> for Shared<T, N> {
    #[inline]
    fn get(&self, ordering: Ordering) -> Tail {
        self.atomic_tail.load(ordering)
    }

    #[inline]
    fn get_cached(&self) -> Tail {
        self.cached_tail.get()
    }
}

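// Popped elements are dropped lazily, so on teardown every slot that was
// ever written still holds a live value: all `SIZE` slots once the tail has
// wrapped, otherwise just the first `tail` slots.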
impl<T, const N: usize> Drop for Shared<T, N> {
    fn drop(&mut self) {
        let tail: Tail = self.get(Ordering::Relaxed);

        let elements_to_drop = if self.buffer.has_wrapped(tail) {
            Buffer::<T, N>::SIZE
        } else {
            tail.into()
        };

        for i in 0..elements_to_drop {
            let element = self.buffer.at(i);
            unsafe { element.drop_in_place() };
        }
    }
}

impl<T, const N: usize> Buffer<T, N> {
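    /// Allocated slot count: `N` rounded up to the next power of two so a
    /// cursor maps to an index with a single bitmask (see `index` below).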
    const SIZE: usize = usize::next_power_of_two(N);

    fn new() -> Self {
        let layout = layout_for::<T>(Self::SIZE);

        // `alloc` must not be called with a zero-sized layout, so zero-sized
        // types get a dangling, well-aligned pointer instead.
        if layout.size() == 0 {
            return Self {
                ptr: std::ptr::NonNull::dangling().as_ptr(),
            };
        }

        let buffer = unsafe { alloc::alloc(layout) };
        if buffer.is_null() {
            alloc::handle_alloc_error(layout);
        }

        Self { ptr: buffer.cast() }
    }

    #[inline]
    fn at(&self, index: usize) -> *mut T {
        debug_assert!(index < Self::SIZE, "index out of bounds");
        unsafe { self.ptr.add(index) }
    }

    #[inline]
    fn index<Role>(&self, cursor: Cursor<Role>) -> usize {
        index(cursor, Self::SIZE)
    }

    #[inline]
    fn get<Role>(&self, cursor: Cursor<Role>) -> *mut T {
        self.at(self.index(cursor))
    }

    /// Whether the cursor has completed at least one full lap around the
    /// buffer, i.e. the slot it points at already holds an initialized value.
    #[inline]
    fn has_wrapped<Role>(&self, Cursor(pos, _): Cursor<Role>) -> bool {
        pos >= Buffer::<T, N>::SIZE
    }
}

impl<T, const N: usize> Drop for Buffer<T, N> {
    fn drop(&mut self) {
        let layout = layout_for::<T>(Self::SIZE);

        // Zero-sized layouts were never actually allocated, so there is
        // nothing to deallocate.
        if layout.size() != 0 {
            unsafe { alloc::dealloc(self.ptr.cast(), layout) };
        }
    }
}

fn layout_for<T>(size: usize) -> alloc::Layout {
    alloc::Layout::array::<T>(size).expect("capacity overflow")
}

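/// A borrowed view of the element at the head of the queue. Dropping the
/// `PopRef` advances the head with `Release` ordering, publishing the slot
/// back to the producer; the element itself is dropped later, when the slot
/// is overwritten or the queue is torn down.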
pub struct PopRef<'a, T, const N: usize> {
    head: Head,
    consumer: &'a Consumer<T, N>,
}

impl<T, const N: usize> Deref for PopRef<'_, T, N> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let element = self.consumer.shared.buffer.get(self.head);

        unsafe { &*element }
    }
}

impl<T, const N: usize> Drop for PopRef<'_, T, N> {
    fn drop(&mut self) {
        self.consumer
            .shared
            .set(advance(self.head), Ordering::Release);
    }
}

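/// A monotonically increasing position, tagged with a zero-sized `Role`
/// (head or tail) so the two kinds of cursor cannot be mixed up at compile
/// time. Positions are only reduced modulo the buffer size when indexing.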
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
struct Cursor<Role>(usize, PhantomData<Role>);

#[repr(transparent)]
struct AtomicCursor<Role>(AtomicUsize, PhantomData<Role>);

impl<Role> Default for AtomicCursor<Role> {
    fn default() -> Self {
        Self(AtomicUsize::new(0), PhantomData)
    }
}

impl<Role> AtomicCursor<Role> {
    #[inline]
    fn load(&self, ordering: Ordering) -> Cursor<Role> {
        Cursor(self.0.load(ordering), PhantomData)
    }

    #[inline]
    fn store(&self, Cursor(cursor, _): Cursor<Role>, ordering: Ordering) {
        self.0.store(cursor, ordering);
    }
}

/// A cursor value cached by one side of the queue, read and written without
/// synchronization. This is sound under the SPSC contract: `cached_head` is
/// only ever accessed by the producer and `cached_tail` only by the consumer.
#[repr(transparent)]
struct CachedCursor<Role>(UnsafeCell<Cursor<Role>>);

impl<Role> Default for CachedCursor<Role> {
    fn default() -> Self {
        Self(UnsafeCell::new(Cursor(0, PhantomData)))
    }
}

impl<Role> CachedCursor<Role> {
    #[inline]
    fn get(&self) -> Cursor<Role>
    where
        Cursor<Role>: Copy,
    {
        unsafe { *self.0.get() }
    }

    #[inline]
    fn set(&self, cursor: Cursor<Role>) {
        unsafe { *self.0.get() = cursor }
    }
}

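/// The number of elements currently buffered. The head can never overtake
/// the tail, so the subtraction cannot underflow.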
#[inline]
fn size(Cursor(head, _): Head, Cursor(tail, _): Tail) -> usize {
    tail - head
}

#[inline]
fn advance<Role>(Cursor(cursor, _): Cursor<Role>) -> Cursor<Role> {
    Cursor(cursor + 1, PhantomData)
}

#[inline]
fn index<Role>(Cursor(cursor, _): Cursor<Role>, size: usize) -> usize {
    debug_assert!(
        size.is_power_of_two(),
        "size must be a power of two, got {size:?}",
    );
    cursor & (size - 1)
}

#[derive(Debug, Copy, Clone)]
struct HeadRole;

#[derive(Debug, Copy, Clone)]
struct TailRole;

type Head = Cursor<HeadRole>;

type Tail = Cursor<TailRole>;

type AtomicHead = AtomicCursor<HeadRole>;

type AtomicTail = AtomicCursor<TailRole>;

type CachedHead = CachedCursor<HeadRole>;

type CachedTail = CachedCursor<TailRole>;

impl<RoleA, RoleB> PartialEq<Cursor<RoleA>> for Cursor<RoleB> {
    fn eq(&self, other: &Cursor<RoleA>) -> bool {
        self.0 == other.0
    }
}

impl<RoleA, RoleB> PartialOrd<Cursor<RoleA>> for Cursor<RoleB> {
    fn partial_cmp(&self, other: &Cursor<RoleA>) -> Option<std::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}

impl<Role> From<usize> for Cursor<Role> {
    fn from(value: usize) -> Self {
        Cursor(value, PhantomData)
    }
}

impl<Role> From<Cursor<Role>> for usize {
    fn from(Cursor(cursor, _): Cursor<Role>) -> usize {
        cursor
    }
}

#[cfg(test)]
mod test {
    use {
        super::*,
        static_assertions::{assert_impl_all, assert_not_impl_any},
        std::thread,
    };

    assert_impl_all!(Producer<i32, 8>: Send);
    assert_not_impl_any!(Producer<i32, 8>: Sync, Copy, Clone);

    assert_impl_all!(Consumer<i32, 8>: Send);
    assert_not_impl_any!(Consumer<i32, 8>: Sync, Copy, Clone);

    fn get_buffer_size<T, const N: usize>(producer: &Producer<T, N>) -> usize {
        // The two identical-looking loads resolve to different `GetCursor`
        // impls: `size` expects a head first and a tail second, and type
        // inference selects the matching role for each call.
        size(
            producer.shared.get(Ordering::Relaxed),
            producer.shared.get(Ordering::Relaxed),
        )
    }

    #[derive(Debug, Default, Clone)]
    struct DropCounter(Arc<AtomicUsize>);

    impl DropCounter {
        fn count(&self) -> usize {
            self.0.load(Ordering::Relaxed)
        }
    }

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    fn head(pos: usize) -> Head {
        Cursor(pos, PhantomData)
    }

    fn tail(pos: usize) -> Tail {
        Cursor(pos, PhantomData)
    }

    #[test]
    fn querying_size() {
        assert_eq!(size(head(0), tail(0)), 0);
        assert_eq!(size(head(0), tail(1)), 1);
        assert_eq!(size(head(0), tail(2)), 2);
        assert_eq!(size(head(0), tail(3)), 3);
        assert_eq!(size(head(1), tail(3)), 2);
        assert_eq!(size(head(2), tail(3)), 1);
        assert_eq!(size(head(3), tail(3)), 0);
    }

    #[test]
    fn advancing_cursors() {
        let cursor = head(0);

        let cursor = advance(cursor);
        assert_eq!(cursor, head(1));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(2));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(3));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(4));

        let cursor = advance(cursor);
        assert_eq!(cursor, head(5));
    }

    #[test]
    fn cursor_to_index() {
        assert_eq!(index(head(0), 4), 0);
        assert_eq!(index(head(1), 4), 1);
        assert_eq!(index(head(2), 4), 2);
        assert_eq!(index(head(3), 4), 3);
        assert_eq!(index(head(4), 4), 0);
        assert_eq!(index(head(5), 4), 1);
        assert_eq!(index(head(6), 4), 2);
        assert_eq!(index(head(7), 4), 3);
        assert_eq!(index(head(8), 4), 0);
    }

    #[test]
    fn using_a_fifo() {
        let (tx, rx) = fifo::<i32, 3>();
        assert_eq!(get_buffer_size(&tx), 0);

        assert!(rx.pop().is_none());

        tx.push(5).unwrap();

        assert_eq!(rx.pop(), Some(5));

        tx.push(1).unwrap();
        tx.push(2).unwrap();
        tx.push(3).unwrap();

        let push_result = tx.push(4);
        assert_eq!(push_result, Err(4));

        assert_eq!(rx.pop(), Some(1));

        let push_result = tx.push(4);
        assert!(push_result.is_ok());

        let (tx, mut rx) = fifo::<String, 2>();
        tx.push("hello".to_string()).unwrap();

        let value_ref = rx.pop_ref();
        assert!(value_ref.is_some());
        assert_eq!(value_ref.unwrap().as_str(), "hello");
    }

    #[test]
    fn elements_are_dropped_when_overwritten() {
        let drop_counter = DropCounter::default();
        let (tx, mut rx) = fifo::<_, 3>();

        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        // Popping only advances the head; the element stays in its slot.
        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        // This push wraps around to slot 0 and drops the stale element there.
        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 1);
    }

    #[test]
    fn elements_are_dropped_when_buffer_is_dropped() {
        let drop_counter = DropCounter::default();
        let (tx, mut rx) = fifo::<_, 3>();

        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        drop((tx, rx));

        assert_eq!(drop_counter.count(), 4);
    }

    #[test]
    fn reading_and_writing_on_different_threads() {
        let (writer, reader) = fifo::<_, 12>();

        #[cfg(miri)]
        const NUM_WRITES: usize = 128;

        #[cfg(not(miri))]
        const NUM_WRITES: usize = 1_000_000;

        thread::spawn(move || {
            for value in 1..=NUM_WRITES {
                writer.push_blocking(value);
            }
        });

        let mut last = None;
        while last != Some(NUM_WRITES) {
            match reader.pop() {
                Some(value) => {
                    if let Some(last) = last {
                        assert_eq!(last + 1, value, "values should be popped in order");
                    }
                    last = Some(value);
                }
                None => thread::yield_now(),
            }
        }
    }
}