1use std::cell::Cell;
46use std::fmt;
47use std::mem::ManuallyDrop;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use crossbeam_utils::CachePadded;
52
53use crate::Full;
54
55pub fn ring_buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
63 assert!(capacity > 0, "capacity must be non-zero");
64
65 let capacity = capacity
66 .checked_next_power_of_two()
67 .expect("capacity too large (must be <= usize::MAX / 2)");
68 let mask = capacity - 1;
69
70 let mut slots = ManuallyDrop::new(Vec::<T>::with_capacity(capacity));
71 let buffer = slots.as_mut_ptr();
72
73 let shared = Arc::new(Shared {
74 tail: CachePadded::new(AtomicUsize::new(0)),
75 head: CachePadded::new(AtomicUsize::new(0)),
76 buffer,
77 mask,
78 });
79
80 (
81 Producer {
82 local_tail: Cell::new(0),
83 cached_head: Cell::new(0),
84 buffer,
85 mask,
86 shared: Arc::clone(&shared),
87 },
88 Consumer {
89 local_head: Cell::new(0),
90 cached_tail: Cell::new(0),
91 buffer,
92 mask,
93 shared,
94 },
95 )
96}
97
98#[repr(C)]
101struct Shared<T> {
102 tail: CachePadded<AtomicUsize>,
103 head: CachePadded<AtomicUsize>,
104 buffer: *mut T,
105 mask: usize,
106}
107
108unsafe impl<T: Send> Send for Shared<T> {}
112unsafe impl<T: Send> Sync for Shared<T> {}
113
114impl<T> Drop for Shared<T> {
115 fn drop(&mut self) {
116 let head = self.head.load(Ordering::Relaxed);
117 let tail = self.tail.load(Ordering::Relaxed);
118
119 let mut i = head;
120 while i != tail {
121 unsafe { self.buffer.add(i & self.mask).drop_in_place() };
124 i = i.wrapping_add(1);
125 }
126
127 unsafe {
130 let capacity = self.mask + 1;
131 let _ = Vec::from_raw_parts(self.buffer, 0, capacity);
132 }
133 }
134}
135
/// Producing half of a ring created by `ring_buffer`.
///
/// The producer keeps its own copy of its write position (`local_tail`) and of
/// the last head index it observed (`cached_head`). The shared `head` counter
/// is therefore only loaded when the ring appears full.
#[repr(C)]
pub struct Producer<T> {
    // Next logical index to write; stored to `shared.tail` after each push.
    local_tail: Cell<usize>,
    // Last value loaded from `shared.head`; it can lag the real head, never lead it.
    cached_head: Cell<usize>,
    // Same pointer as `shared.buffer`: start of the `mask + 1` slot array.
    buffer: *mut T,
    // `capacity - 1`; capacity is a power of two, so `index & mask` picks the slot.
    mask: usize,
    // Shared counters; the slot allocation is freed when the last handle drops.
    shared: Arc<Shared<T>>,
}

// SAFETY: `buffer` (a raw pointer) makes `Producer` `!Send` by default. Moving
// the producer to another thread moves the handle that writes slots and
// advances `tail`; the pushed `T` values are later read on the consumer's
// thread, hence `T: Send`. The `Cell` fields keep `Producer` `!Sync`, so
// `push` cannot run on two threads at once.
unsafe impl<T: Send> Send for Producer<T> {}
154
155impl<T> Producer<T> {
156 #[inline]
161 #[must_use = "push returns Err if full, which should be handled"]
162 pub fn push(&self, value: T) -> Result<(), Full<T>> {
163 let tail = self.local_tail.get();
164
165 if tail.wrapping_sub(self.cached_head.get()) > self.mask {
166 self.cached_head
167 .set(self.shared.head.load(Ordering::Relaxed));
168
169 std::sync::atomic::fence(Ordering::Acquire);
170 if tail.wrapping_sub(self.cached_head.get()) > self.mask {
171 return Err(Full(value));
172 }
173 }
174
175 unsafe { self.buffer.add(tail & self.mask).write(value) };
178 let new_tail = tail.wrapping_add(1);
179 std::sync::atomic::fence(Ordering::Release);
180
181 self.shared.tail.store(new_tail, Ordering::Relaxed);
182 self.local_tail.set(new_tail);
183
184 Ok(())
185 }
186
187 #[inline]
189 pub fn capacity(&self) -> usize {
190 self.mask + 1
191 }
192
193 #[inline]
195 pub fn is_disconnected(&self) -> bool {
196 Arc::strong_count(&self.shared) == 1
197 }
198}
199
200impl<T> fmt::Debug for Producer<T> {
201 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202 f.debug_struct("Producer")
203 .field("capacity", &self.capacity())
204 .finish_non_exhaustive()
205 }
206}
207
/// Consuming half of a ring created by `ring_buffer`.
///
/// The consumer keeps its own copy of its read position (`local_head`) and of
/// the last tail index it observed (`cached_tail`). The shared `tail` counter
/// is therefore only loaded when the ring appears empty.
#[repr(C)]
pub struct Consumer<T> {
    // Next logical index to read; stored to `shared.head` after each pop.
    local_head: Cell<usize>,
    // Last value loaded from `shared.tail`; it can lag the real tail, never lead it.
    cached_tail: Cell<usize>,
    // Same pointer as `shared.buffer`: start of the `mask + 1` slot array.
    buffer: *mut T,
    // `capacity - 1`; capacity is a power of two, so `index & mask` picks the slot.
    mask: usize,
    // Shared counters; the slot allocation is freed when the last handle drops.
    shared: Arc<Shared<T>>,
}

// SAFETY: `buffer` (a raw pointer) makes `Consumer` `!Send` by default. Moving
// the consumer to another thread moves the handle that reads slots and
// advances `head`; popped `T` values were written on the producer's thread,
// hence `T: Send`. The `Cell` fields keep `Consumer` `!Sync`, so `pop`
// cannot run on two threads at once.
unsafe impl<T: Send> Send for Consumer<T> {}
226
227impl<T> Consumer<T> {
228 #[inline]
232 pub fn pop(&self) -> Option<T> {
233 let head = self.local_head.get();
234
235 if head == self.cached_tail.get() {
236 self.cached_tail
237 .set(self.shared.tail.load(Ordering::Relaxed));
238 std::sync::atomic::fence(Ordering::Acquire);
239
240 if head == self.cached_tail.get() {
241 return None;
242 }
243 }
244
245 let value = unsafe { self.buffer.add(head & self.mask).read() };
248 let new_head = head.wrapping_add(1);
249 std::sync::atomic::fence(Ordering::Release);
250
251 self.shared.head.store(new_head, Ordering::Relaxed);
252 self.local_head.set(new_head);
253
254 Some(value)
255 }
256
257 #[inline]
259 pub fn capacity(&self) -> usize {
260 self.mask + 1
261 }
262
263 #[inline]
265 pub fn is_disconnected(&self) -> bool {
266 Arc::strong_count(&self.shared) == 1
267 }
268}
269
270impl<T> fmt::Debug for Consumer<T> {
271 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272 f.debug_struct("Consumer")
273 .field("capacity", &self.capacity())
274 .finish_non_exhaustive()
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Values come back out in the order they were pushed.
    #[test]
    fn basic_push_pop() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());

        assert_eq!(cons.pop(), Some(1));
        assert_eq!(cons.pop(), Some(2));
        assert_eq!(cons.pop(), Some(3));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, cons) = ring_buffer::<u64>(4);
        assert_eq!(cons.pop(), None);
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (prod, cons) = ring_buffer::<u64>(4);

        for i in 0..4 {
            assert!(prod.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(cons.pop(), Some(i));
        }

        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (prod, _cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());
        assert!(prod.push(4).is_ok());

        let err = prod.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    /// Alternating push/pop across many laps never overwrites an unread value.
    #[test]
    fn interleaved_no_overwrite() {
        let (prod, cons) = ring_buffer::<u64>(8);

        for i in 0..1000 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (prod, cons) = ring_buffer::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(prod.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(cons.pop(), Some(round * 4 + i));
            }
        }
    }

    /// A capacity-one ring holds exactly one value at a time.
    #[test]
    fn single_slot_bounded() {
        let (prod, cons) = ring_buffer::<u64>(1);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_err());

        assert_eq!(cons.pop(), Some(1));
        assert!(prod.push(2).is_ok());
    }

    /// Dropping the producer is observable from the consumer.
    #[test]
    fn producer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!cons.is_disconnected());
        drop(prod);
        assert!(cons.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!prod.is_disconnected());
        drop(cons);
        assert!(prod.is_disconnected());
    }

    /// Everything pushed before the producer is dropped on another thread can
    /// still be drained once `is_disconnected` reports `true`.
    #[test]
    fn drain_after_producer_disconnect() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(128);

        let producer = thread::spawn(move || {
            // 100 < capacity, so no push can hit a full ring.
            for i in 0..100 {
                assert!(prod.push(i).is_ok());
            }
            // `prod` is dropped when the closure returns.
        });

        let mut expected = 0u64;
        loop {
            if let Some(val) = cons.pop() {
                assert_eq!(val, expected);
                expected += 1;
            } else if cons.is_disconnected() {
                while let Some(val) = cons.pop() {
                    assert_eq!(val, expected);
                    expected += 1;
                }
                break;
            } else {
                std::hint::spin_loop();
            }
        }

        producer.join().unwrap();
        assert_eq!(expected, 100);
    }

    /// Values still queued when both halves are dropped are dropped exactly once.
    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (prod, cons) = ring_buffer::<DropCounter>(4);

        assert!(prod.push(DropCounter).is_ok());
        assert!(prod.push(DropCounter).is_ok());
        assert!(prod.push(DropCounter).is_ok());

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(prod);
        drop(cons);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    /// Cleanup also handles a queued region that wraps past the end of the
    /// slot array.
    #[test]
    fn drop_cleans_up_wrapped_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (prod, cons) = ring_buffer::<DropCounter>(4);

        for _ in 0..4 {
            assert!(prod.push(DropCounter).is_ok());
        }
        // Each popped value is dropped at the end of its statement.
        for _ in 0..3 {
            assert!(cons.pop().is_some());
        }
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);

        // Refill: the four queued values now occupy slots 3, 0, 1, 2.
        for _ in 0..3 {
            assert!(prod.push(DropCounter).is_ok());
        }
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);

        drop(prod);
        drop(cons);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 7);
    }

    /// Every value pushed on one thread arrives on another.
    #[test]
    fn cross_thread_bounded() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if cons.pop().is_some() {
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, 10_000);
    }

    /// Zero-sized values are counted correctly even though slots hold no bytes.
    #[test]
    fn zero_sized_type() {
        let (prod, cons) = ring_buffer::<()>(8);

        assert!(prod.push(()).is_ok());
        assert!(prod.push(()).is_ok());

        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), None);
    }

    /// The capacity bound still applies when `T` is zero-sized.
    #[test]
    fn zero_sized_type_respects_capacity() {
        let (prod, cons) = ring_buffer::<()>(8);

        for _ in 0..8 {
            assert!(prod.push(()).is_ok());
        }
        assert!(prod.push(()).is_err());

        assert_eq!(cons.pop(), Some(()));
        assert!(prod.push(()).is_ok());
    }

    #[test]
    fn string_type() {
        let (prod, cons) = ring_buffer::<String>(4);

        assert!(prod.push("hello".to_string()).is_ok());
        assert!(prod.push("world".to_string()).is_ok());

        assert_eq!(cons.pop(), Some("hello".to_string()));
        assert_eq!(cons.pop(), Some("world".to_string()));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = ring_buffer::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (prod, cons) = ring_buffer::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(prod.push(msg).is_ok());

        let received = cons.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (prod, cons) = ring_buffer::<u64>(4);

        // 40 values through 4 slots wraps the slot index around ten times.
        for i in 0..40 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    #[test]
    fn fifo_order_cross_thread() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut expected = 0u64;
            while expected < 10_000 {
                if let Some(val) = cons.pop() {
                    assert_eq!(val, expected, "FIFO order violated");
                    expected += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;

        let (prod, cons) = ring_buffer::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = cons.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = ring_buffer::<u64>(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = ring_buffer::<u64>(1000);
        assert_eq!(prod.capacity(), 1024);

        // An exact power of two is kept as-is.
        let (prod, cons) = ring_buffer::<u64>(64);
        assert_eq!(prod.capacity(), 64);
        assert_eq!(cons.capacity(), 64);
    }
}