1use {
2 crate::{
3 backoff::Backoff,
4 sync::{
5 atomic::{AtomicUsize, Ordering},
6 Arc,
7 },
8 },
9 crossbeam_utils::CachePadded,
10 std::{alloc, cmp::PartialEq, marker::PhantomData, ops::Deref},
11};
12
/// The sending half of a bounded FIFO created by [`fifo`].
///
/// Values are written into the shared ring buffer; [`Producer::push`] fails
/// (returning the value) when the queue is at capacity.
pub struct Producer<T> {
    shared: Arc<Shared<T>>,
}
17
/// The receiving half of a bounded FIFO created by [`fifo`].
///
/// Elements are observed in FIFO order via [`Consumer::pop`] (for `T: Copy`)
/// or borrowed via [`Consumer::pop_ref`].
pub struct Consumer<T> {
    shared: Arc<Shared<T>>,
}
22
23pub fn fifo<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
29 let shared = Arc::new(Shared::new(capacity));
30
31 (
32 Producer {
33 shared: Arc::clone(&shared),
34 },
35 Consumer { shared },
36 )
37}
38
// SAFETY: a `Producer` only writes `T: Send` values into the shared buffer,
// and all cross-thread access to the buffer is synchronised through the
// head/tail atomics, so moving the handle to another thread is sound.
unsafe impl<T> Send for Producer<T> where T: Send {}
// SAFETY: likewise, a `Consumer` only reads `T: Send` values out of the
// atomically synchronised shared buffer.
unsafe impl<T> Send for Consumer<T> where T: Send {}
41
/// State shared between the producer and consumer halves.
struct Shared<T> {
    // Maximum number of live elements; may be smaller than `buffer.len`,
    // which is rounded up to a power of two (see `Buffer::new`).
    capacity: usize,
    // Backing slot storage; slot initialisation is tracked by the cursors.
    buffer: Buffer<T>,
    // Position of the next element to pop; advanced only by the consumer.
    // Cache-padded so head and tail don't false-share a cache line.
    head: CachePadded<AtomicUsize>,
    // Position of the next slot to write; advanced only by the producer.
    tail: CachePadded<AtomicUsize>,
}
48
/// Raw fixed-size allocation of `len` (a power of two) slots of `T`.
/// Which slots hold initialised values is tracked externally via the cursors.
struct Buffer<T> {
    ptr: *mut T,
    len: usize,
}
53
impl<T> Producer<T> {
    /// Pushes `value`, spinning with exponential backoff until a slot frees up.
    pub fn push_blocking(&self, mut value: T) {
        let backoff = Backoff::default();

        // `push` hands the value back on failure so we can retry without
        // requiring `T: Clone`.
        while let Err(value_failed_to_push) = self.push(value) {
            backoff.snooze();
            value = value_failed_to_push;
        }
    }

    /// Attempts to push `value`, returning it in `Err` if the queue is full.
    pub fn push(&self, value: T) -> Result<(), T> {
        // Only this producer ever writes `tail`, so Relaxed is enough for it;
        // the Acquire on `head` pairs with the consumer's Release store in
        // `PopRef::drop`, making freed slots visible before we reuse them.
        let tail = self.shared.get(Ordering::Relaxed);
        let head = self.shared.get(Ordering::Acquire);

        let size = self.shared.size(head, tail);
        assert!(
            size <= self.shared.capacity,
            "size ({}) should not be greater than capacity ({})",
            size,
            self.shared.capacity
        );

        if size == self.shared.capacity {
            return Err(value);
        }

        let element = self.shared.element(tail);
        // Once the tail cursor has lapped the buffer, every slot it revisits
        // still holds an old value (pops never move values out), which must
        // be dropped before being overwritten.
        if self.shared.has_wrapped_around(tail) {
            unsafe { element.drop_in_place() };
        }
        unsafe { element.write(value) };

        // Release publishes the slot write; pairs with the consumer's
        // Acquire load of `tail` in `pop_ref`.
        self.shared.advance(tail, Ordering::Release);
        Ok(())
    }
}
96
impl<T> Consumer<T> {
    /// Pops the next element by value.
    ///
    /// Restricted to `T: Copy`: elements are never moved out of the buffer
    /// (old slots are dropped on overwrite or in `Shared::drop`), so only a
    /// bitwise copy is sound here.
    pub fn pop(&self) -> Option<T>
    where
        T: Copy,
    {
        self.pop_ref().map(|r| *r)
    }

    /// Pops the next element by reference; the head cursor only advances when
    /// the returned [`PopRef`] is dropped, keeping the slot reserved until then.
    ///
    /// NOTE(review): this takes `&self`, so two overlapping `pop_ref` calls
    /// return references to the same slot; once the first guard drops (and
    /// the producer may overwrite the slot) the second reference could
    /// dangle — confirm whether `&mut self` was intended.
    pub fn pop_ref(&self) -> Option<PopRef<'_, T>> {
        // Only this consumer writes `head`, so Relaxed suffices for it; the
        // Acquire on `tail` pairs with the producer's Release in `push`,
        // making the pushed element visible before we read it.
        let head = self.shared.get(Ordering::Relaxed);
        let tail = self.shared.get(Ordering::Acquire);

        let size = self.shared.size(head, tail);
        assert!(
            size <= self.shared.capacity,
            "size ({}) should not be greater than capacity ({})",
            size,
            self.shared.capacity
        );

        if size == 0 {
            return None;
        }

        let element = self.shared.element(head);
        let value = unsafe { &*element };

        Some(PopRef {
            head,
            value,
            consumer: self,
        })
    }
}
135
136impl<T> Shared<T> {
137 fn new(capacity: usize) -> Self {
138 Self {
139 capacity,
140 buffer: Buffer::new(capacity),
141 head: CachePadded::new(AtomicUsize::new(0)),
142 tail: CachePadded::new(AtomicUsize::new(0)),
143 }
144 }
145
146 fn size(&self, head: Head, tail: Tail) -> usize {
147 size(head, tail, self.buffer.len)
148 }
149
150 fn index<Role>(&self, cursor: Cursor<Role>) -> usize {
151 index(cursor, self.buffer.len)
152 }
153
154 fn element<Role>(&self, cursor: Cursor<Role>) -> *mut T {
155 self.buffer.at(self.index(cursor))
156 }
157
158 fn has_wrapped_around<Role>(&self, Cursor(pos, _): Cursor<Role>) -> bool {
159 pos >= self.buffer.len
160 }
161
162 fn advance<Role>(&self, cursor: Cursor<Role>, ordering: Ordering)
163 where
164 Self: SetCursor<Role>,
165 {
166 self.set(advance(cursor, self.buffer.len), ordering);
167 }
168}
169
/// Role-selected store of a cursor into the matching atomic (`head` for the
/// consumer role, `tail` for the producer role).
trait SetCursor<Role> {
    fn set(&self, cursor: Cursor<Role>, ordering: Ordering);
}
173
174impl<T> SetCursor<HeadRole> for Shared<T> {
175 fn set(&self, head: Head, ordering: Ordering) {
176 self.head.store(head.into(), ordering);
177 }
178}
179
180impl<T> SetCursor<TailRole> for Shared<T> {
181 fn set(&self, tail: Tail, ordering: Ordering) {
182 self.tail.store(tail.into(), ordering);
183 }
184}
185
/// Role-selected load of a cursor from the matching atomic; the `Role` type
/// parameter lets callers pick head vs. tail purely by type inference.
trait GetCursor<Role> {
    fn get(&self, ordering: Ordering) -> Cursor<Role>;
}
189
190impl<T> GetCursor<HeadRole> for Shared<T> {
191 fn get(&self, ordering: Ordering) -> Head {
192 self.head.load(ordering).into()
193 }
194}
195
196impl<T> GetCursor<TailRole> for Shared<T> {
197 fn get(&self, ordering: Ordering) -> Tail {
198 self.tail.load(ordering).into()
199 }
200}
201
impl<T> Drop for Shared<T> {
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so Relaxed is sufficient.
        let tail: Tail = self.get(Ordering::Relaxed);

        // Values are never moved out of the buffer (pops copy or borrow), so
        // every slot that has ever been written still holds a live value:
        // all `len` slots once the tail has lapped the buffer, otherwise
        // just the prefix `[0, tail)`. Popped-but-not-overwritten slots are
        // intentionally included.
        let elements_to_drop = if self.has_wrapped_around(tail) {
            self.buffer.len
        } else {
            tail.into()
        };

        for i in 0..elements_to_drop {
            let element = self.buffer.at(i);
            // SAFETY: slot `i` was initialised by a `push` and has not been
            // dropped since (drops only happen right before an overwrite).
            unsafe { element.drop_in_place() };
        }
    }
}
218
219impl<T> Buffer<T> {
220 fn new(size: usize) -> Self {
221 let size = size.next_power_of_two();
222 let layout = layout_for::<T>(size);
223
224 let buffer = unsafe { alloc::alloc(layout) };
225 if buffer.is_null() {
226 panic!("failed to allocate buffer");
227 }
228
229 Self {
230 ptr: buffer.cast(),
231 len: size,
232 }
233 }
234
235 fn at(&self, index: usize) -> *mut T {
236 debug_assert!(index < self.len, "index out of bounds");
237 unsafe { self.ptr.add(index) }
238 }
239}
240
241impl<T> Drop for Buffer<T> {
242 fn drop(&mut self) {
243 let layout = layout_for::<T>(self.len);
244 unsafe { alloc::dealloc(self.ptr.cast(), layout) };
245 }
246}
247
/// Memory layout for an array of `size` values of `T`.
///
/// # Panics
///
/// Panics if the total size overflows (exceeds `isize::MAX` bytes).
fn layout_for<T>(size: usize) -> alloc::Layout {
    // `Layout::array` performs the element-count multiplication with overflow
    // checking and enforces the `isize::MAX` total-size limit, replacing the
    // hand-rolled `checked_mul` + `from_size_align` combination.
    alloc::Layout::array::<T>(size).expect("capacity overflow")
}
252
/// Borrowed view of the element at the head of the queue, produced by
/// [`Consumer::pop_ref`]. Dropping it advances the head cursor, releasing
/// the slot back to the producer.
pub struct PopRef<'a, T> {
    // Snapshot of the head cursor this element was read at; `drop` advances
    // from this value.
    head: Head,
    value: &'a T,
    consumer: &'a Consumer<T>,
}
259
// Transparent access to the borrowed element.
impl<T> Deref for PopRef<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.value
    }
}
267
impl<T> Drop for PopRef<'_, T> {
    fn drop(&mut self) {
        // Release pairs with the producer's Acquire load of `head` in `push`,
        // publishing that this slot may now be overwritten. The value itself
        // stays in place; it is dropped on overwrite or in `Shared::drop`.
        self.consumer.shared.advance(self.head, Ordering::Release);
    }
}
273
/// Monotonically increasing buffer position tagged with a role (head or
/// tail). Positions map onto slot indices by masking (see [`index`]); a
/// position `>= len` means the cursor has lapped the buffer at least once.
#[derive(Debug, Copy, Clone)]
struct Cursor<Role>(usize, PhantomData<Role>);
276
277fn size(head: Head, tail: Tail, size: usize) -> usize {
278 if head == tail {
279 return 0;
280 }
281
282 let head = index(head, size);
283 let tail = index(tail, size);
284
285 if head < tail {
286 tail - head
287 } else {
288 size - head + tail
289 }
290}
291
292fn advance<Role>(Cursor(cursor, _): Cursor<Role>, size: usize) -> Cursor<Role> {
293 match cursor.checked_add(1) {
294 Some(cursor) => cursor,
295 None => (cursor % size) + size + 1,
296 }
297 .into()
298}
299
300fn index<Role>(Cursor(cursor, _): Cursor<Role>, size: usize) -> usize {
301 debug_assert!(
302 size.is_power_of_two(),
303 "size must be a power of two, got {size:?}",
304 );
305
306 cursor & (size - 1)
307}
308
/// Marker type selecting the consumer-owned cursor.
#[derive(Debug, Copy, Clone)]
struct HeadRole;

/// Marker type selecting the producer-owned cursor.
#[derive(Debug, Copy, Clone)]
struct TailRole;

/// Position of the next element to pop.
type Head = Cursor<HeadRole>;

/// Position of the next slot to write.
type Tail = Cursor<TailRole>;
318
// Cursors compare by raw position, across roles, so a head can be compared
// directly against a tail.
impl<RoleA, RoleB> PartialEq<Cursor<RoleA>> for Cursor<RoleB> {
    fn eq(&self, other: &Cursor<RoleA>) -> bool {
        self.0 == other.0
    }
}
324
// Raw-position ordering. Note: after an overflow wrap (see `advance`) a
// logically newer cursor compares as smaller — the `cursor_overflow` test
// relies on this.
impl<RoleA, RoleB> PartialOrd<Cursor<RoleA>> for Cursor<RoleB> {
    fn partial_cmp(&self, other: &Cursor<RoleA>) -> Option<std::cmp::Ordering> {
        self.0.partial_cmp(&other.0)
    }
}
330
// Wraps a raw position into a role-tagged cursor.
impl<Role> From<usize> for Cursor<Role> {
    fn from(value: usize) -> Self {
        Cursor(value, PhantomData)
    }
}

// Unwraps a cursor back to its raw position.
impl<Role> From<Cursor<Role>> for usize {
    fn from(Cursor(cursor, _): Cursor<Role>) -> usize {
        cursor
    }
}
342
#[cfg(test)]
mod test {
    use {super::*, std::thread};

    // Reads the queue size straight from the shared cursors, bypassing the
    // public API; the two `get` calls infer the Head and Tail roles from
    // `size`'s parameter types.
    fn get_buffer_size<T>(producer: &Producer<T>) -> usize {
        size(
            producer.shared.get(Ordering::Relaxed),
            producer.shared.get(Ordering::Relaxed),
            producer.shared.buffer.len,
        )
    }

    // Clone-able handle whose shared counter records how many times any
    // clone has been dropped.
    #[derive(Debug, Default, Clone)]
    struct DropCounter(Arc<AtomicUsize>);

    impl DropCounter {
        fn count(&self) -> usize {
            self.0.load(Ordering::Relaxed)
        }
    }

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    // Shorthand constructors for role-tagged cursors.
    fn head(pos: usize) -> Head {
        Cursor(pos, PhantomData)
    }

    fn tail(pos: usize) -> Tail {
        Cursor(pos, PhantomData)
    }

    #[test]
    fn querying_size() {
        assert_eq!(size(head(0), tail(0), 4), 0);
        assert_eq!(size(head(0), tail(1), 4), 1);
        assert_eq!(size(head(0), tail(2), 4), 2);
        assert_eq!(size(head(0), tail(3), 4), 3);
        assert_eq!(size(head(1), tail(3), 4), 2);
        assert_eq!(size(head(2), tail(3), 4), 1);
        assert_eq!(size(head(3), tail(3), 4), 0);
    }

    #[test]
    fn advancing_cursors() {
        let cursor = head(0);

        let cursor = advance(cursor, 4);
        assert_eq!(cursor, head(1));

        let cursor = advance(cursor, 4);
        assert_eq!(cursor, head(2));

        let cursor = advance(cursor, 4);
        assert_eq!(cursor, head(3));

        // Cursors are monotonic: they keep counting past the buffer length
        // rather than wrapping back to zero.
        let cursor = advance(cursor, 4);
        assert_eq!(cursor, head(4));

        let cursor = advance(cursor, 4);
        assert_eq!(cursor, head(5));
    }

    #[test]
    fn cursor_overflow() {
        for capacity in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            let head = head(usize::MAX - capacity);
            let tail = tail(usize::MAX);

            assert_eq!(size(head, tail, capacity), capacity);

            let head = advance(head, capacity);

            assert_eq!(size(head, tail, capacity), capacity - 1);

            // After the tail overflows usize, raw-position comparisons
            // invert — this is a known property of the wrap in `advance`.
            assert!(head < tail);
            let tail = advance(tail, capacity);
            assert!(tail < head);

            assert_eq!(size(head, tail, capacity), capacity);
        }
    }

    #[test]
    fn cursor_to_index() {
        assert_eq!(index(head(0), 4), 0);
        assert_eq!(index(head(1), 4), 1);
        assert_eq!(index(head(2), 4), 2);
        assert_eq!(index(head(3), 4), 3);
        assert_eq!(index(head(4), 4), 0);
        assert_eq!(index(head(5), 4), 1);
        assert_eq!(index(head(6), 4), 2);
        assert_eq!(index(head(7), 4), 3);
        assert_eq!(index(head(8), 4), 0);
    }

    // BDD-style scenario; the labeled blocks are expanded by `beady` into
    // nested given/when/then test cases.
    #[beady::scenario]
    #[test]
    fn using_a_fifo() {
        'given_an_empty_fifo: {
            let (tx, rx) = fifo::<i32>(3);

            'then_the_fifo_starts_empty: {
                assert_eq!(get_buffer_size(&tx), 0);
            }

            'when_we_try_to_pop_a_value: {
                let value = rx.pop();

                'then_nothing_is_returned: {
                    assert!(value.is_none());
                }
            }

            'when_we_push_a_value: {
                tx.push(5).unwrap();

                'and_when_we_try_to_pop_a_value: {
                    let value = rx.pop();

                    'then_the_popped_value_is_the_one_that_was_pushed: {
                        assert_eq!(value, Some(5));
                    }
                }
            }

            'when_the_fifo_is_at_capacity: {
                tx.push(1).unwrap();
                tx.push(2).unwrap();
                tx.push(3).unwrap();

                'and_when_we_try_to_push_another_value: {
                    let push_result = tx.push(4);

                    'then_the_push_fails: {
                        assert!(push_result.is_err());

                        'and_then_we_get_the_value_we_tried_to_push_back: {
                            assert_eq!(push_result.unwrap_err(), 4);
                        }
                    }
                }

                'and_when_we_pop_a_value: {
                    let value = rx.pop();

                    'then_the_popped_value_is_the_first_one_that_was_pushed: {
                        assert_eq!(value, Some(1));
                    }

                    'and_when_we_push_another_value: {
                        let push_result = tx.push(4);

                        'then_the_push_now_succeeds: {
                            assert!(push_result.is_ok());
                        }
                    }
                }
            }
        }

        'given_a_fifo_whose_elements_do_not_impl_copy: {
            let (tx, rx) = fifo::<String>(2);

            'when_we_push_a_value: {
                tx.push("hello".to_string()).unwrap();

                'and_when_we_pop_a_value_by_ref: {
                    let value_ref = rx.pop_ref();

                    'then_the_pop_succeeds: {
                        assert!(value_ref.is_some());

                        'and_then_the_value_is_correct: {
                            assert_eq!(value_ref.unwrap().as_str(), "hello");
                        }
                    }
                }
            }
        }
    }

    // Popped values stay in their slots; they are only dropped when the
    // producer overwrites the slot on a later lap.
    #[test]
    fn elements_are_dropped_when_overwritten() {
        let drop_counter = DropCounter::default();
        let (tx, rx) = fifo(3);

        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 1);
    }

    // All values still in the buffer (including popped-but-not-overwritten
    // slots) are dropped when the last endpoint goes away.
    #[test]
    fn elements_are_dropped_when_buffer_is_dropped() {
        let drop_counter = DropCounter::default();
        let (tx, rx) = fifo(3);

        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();
        tx.push(drop_counter.clone()).unwrap();

        rx.pop_ref();
        assert_eq!(drop_counter.count(), 0);

        tx.push(drop_counter.clone()).unwrap();
        assert_eq!(drop_counter.count(), 0);

        drop((tx, rx));

        assert_eq!(drop_counter.count(), 4);
    }

    #[test]
    fn cursors_wrap_around_and_start_at_the_allocated_capacity() {
        const BEFORE_OVERFLOW: Tail = Cursor(usize::MAX - 1, PhantomData);
        const OVERFLOW: Tail = Cursor(usize::MAX, PhantomData);

        for capacity in (0..=1024).map(|capacity: usize| capacity.next_power_of_two()) {
            let prev_tail = advance(BEFORE_OVERFLOW, capacity);
            let prev_tail_index = index(prev_tail, capacity);

            let next_tail = advance(OVERFLOW, capacity);
            let next_tail_index = index(next_tail, capacity);

            assert!(
                next_tail.0 >= capacity,
                "the wrapped cursor should not reuse the values between [0, allocated_capacity)"
            );
            assert_eq!(next_tail_index, (prev_tail_index + 1) % capacity);
        }
    }

    // SPSC smoke test: one writer thread, reader on the test thread; values
    // must arrive in order. Miri runs use a much smaller iteration count.
    #[test]
    fn reading_and_writing_on_different_threads() {
        let (writer, reader) = fifo(12);

        #[cfg(miri)]
        const NUM_WRITES: usize = 128;

        #[cfg(not(miri))]
        const NUM_WRITES: usize = 1_000_000;

        thread::spawn({
            move || {
                for value in 1..=NUM_WRITES {
                    writer.push_blocking(value);
                }
            }
        });

        let mut last = None;
        while last != Some(NUM_WRITES) {
            match reader.pop() {
                Some(value) => {
                    if let Some(last) = last {
                        assert_eq!(last + 1, value, "values should be popped in order");
                    }
                    last = Some(value);
                }
                None => thread::yield_now(),
            }
        }
    }
}