1use std::cell::Cell;
46use std::fmt;
47use std::mem::ManuallyDrop;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use crossbeam_utils::CachePadded;
52
53use crate::Full;
54
55pub fn ring_buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
63 assert!(capacity > 0, "capacity must be non-zero");
64
65 let capacity = capacity.next_power_of_two();
66 let mask = capacity - 1;
67
68 let mut slots = ManuallyDrop::new(Vec::<T>::with_capacity(capacity));
69 let buffer = slots.as_mut_ptr();
70
71 let shared = Arc::new(Shared {
72 tail: CachePadded::new(AtomicUsize::new(0)),
73 head: CachePadded::new(AtomicUsize::new(0)),
74 buffer,
75 mask,
76 });
77
78 (
79 Producer {
80 local_tail: Cell::new(0),
81 cached_head: Cell::new(0),
82 buffer,
83 mask,
84 shared: Arc::clone(&shared),
85 },
86 Consumer {
87 local_head: Cell::new(0),
88 cached_tail: Cell::new(0),
89 buffer,
90 mask,
91 shared,
92 },
93 )
94}
95
// NOTE(review): #[repr(C)] pins the declared field order, presumably so the
// two cache-padded counters lead the struct and sit on separate lines —
// confirm intent.
#[repr(C)]
struct Shared<T> {
    /// Producer-advanced write index: monotonically increasing, wrapped
    /// into the slot array via `mask`. Cache-padded so tail and head
    /// updates do not false-share a cache line.
    tail: CachePadded<AtomicUsize>,
    /// Consumer-advanced read index (same monotonic/wrapping scheme).
    head: CachePadded<AtomicUsize>,
    /// Start of the slot allocation, leaked from a Vec in `ring_buffer`
    /// and reclaimed in `Drop`.
    buffer: *mut T,
    /// capacity - 1; capacity is a power of two, so `i & mask` selects a slot.
    mask: usize,
}
105
// SAFETY: the raw `buffer` pointer is the only field blocking auto
// Send/Sync; it points at slots holding `T` values, so moving or sharing
// `Shared` across threads is sound whenever `T: Send`. Cross-thread slot
// access is coordinated by the head/tail protocol in `push`/`pop`.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
111
impl<T> Drop for Shared<T> {
    /// Runs when the last endpoint (producer or consumer) is dropped.
    ///
    /// Drops every element still queued between `head` and `tail`, then
    /// releases the slot allocation. Relaxed loads suffice here: `Arc`'s
    /// reference-count decrement already synchronizes this thread with all
    /// prior accesses by the other endpoint.
    fn drop(&mut self) {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);

        // Indices grow monotonically and wrap; `& mask` maps each logical
        // index to its slot. Every slot in [head, tail) holds a live value.
        let mut i = head;
        while i != tail {
            unsafe { self.buffer.add(i & self.mask).drop_in_place() };
            i = i.wrapping_add(1);
        }

        unsafe {
            // SAFETY: `buffer` came from a leaked Vec of this capacity;
            // rebuilding it with len 0 frees the allocation without
            // double-dropping any element (all live ones were dropped above).
            let capacity = self.mask + 1;
            let _ = Vec::from_raw_parts(self.buffer, 0, capacity);
        }
    }
}
133
// The write half of the ring buffer. Not `Sync` (Cell fields), so only one
// thread at a time can push.
#[repr(C)]
pub struct Producer<T> {
    /// Producer-private copy of the write index; authoritative for pushes.
    local_tail: Cell<usize>,
    /// Last head value observed from the consumer; refreshed only when the
    /// buffer looks full, keeping the fast path free of atomic loads.
    cached_head: Cell<usize>,
    /// Copy of `Shared::buffer`, kept inline to avoid an Arc dereference
    /// on the hot path.
    buffer: *mut T,
    /// Copy of `Shared::mask` (capacity - 1).
    mask: usize,
    /// Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
147
148unsafe impl<T: Send> Send for Producer<T> {}
152
impl<T> Producer<T> {
    /// Attempts to append `value`; returns it back inside `Err(Full)` if
    /// no slot is currently free.
    ///
    /// Fast path touches only producer-private state (`local_tail`,
    /// `cached_head`); the shared `head` is reloaded only when the cached
    /// value suggests the buffer might be full.
    #[inline]
    #[must_use = "push returns Err if full, which should be handled"]
    pub fn push(&self, value: T) -> Result<(), Full<T>> {
        let tail = self.local_tail.get();

        // Full when tail - head == capacity; since capacity == mask + 1 and
        // indices wrap, `> mask` expresses the same test.
        if tail.wrapping_sub(self.cached_head.get()) > self.mask {
            self.cached_head.set(self.shared.head.load(Ordering::Relaxed));

            // Acquire fence pairs with the consumer's release fence in
            // `pop`, so the consumer's read of a slot happens-before we
            // overwrite that slot below.
            std::sync::atomic::fence(Ordering::Acquire);
            if tail.wrapping_sub(self.cached_head.get()) > self.mask {
                return Err(Full(value));
            }
        }

        // SAFETY: the full-check above guarantees this slot has been
        // consumed (or never written), so `write` does not clobber a live
        // value and the slot is in-bounds (`tail & mask < capacity`).
        unsafe { self.buffer.add(tail & self.mask).write(value) };
        let new_tail = tail.wrapping_add(1);
        // Release fence + relaxed store publishes the slot write before the
        // consumer can observe the advanced tail. Do not reorder these.
        std::sync::atomic::fence(Ordering::Release);

        self.shared.tail.store(new_tail, Ordering::Relaxed);
        self.local_tail.set(new_tail);

        Ok(())
    }

    /// Total number of slots (always a power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.mask + 1
    }

    /// Returns `true` if the consumer has been dropped.
    ///
    /// Best-effort: derived from the `Arc` strong count, so a `false`
    /// answer may already be stale by the time the caller acts on it.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
196
197impl<T> fmt::Debug for Producer<T> {
198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
199 f.debug_struct("Producer")
200 .field("capacity", &self.capacity())
201 .finish_non_exhaustive()
202 }
203}
204
// The read half of the ring buffer. Not `Sync` (Cell fields), so only one
// thread at a time can pop.
#[repr(C)]
pub struct Consumer<T> {
    /// Consumer-private copy of the read index; authoritative for pops.
    local_head: Cell<usize>,
    /// Last tail value observed from the producer; refreshed only when the
    /// buffer looks empty, keeping the fast path free of atomic loads.
    cached_tail: Cell<usize>,
    /// Copy of `Shared::buffer`, kept inline to avoid an Arc dereference
    /// on the hot path.
    buffer: *mut T,
    /// Copy of `Shared::mask` (capacity - 1).
    mask: usize,
    /// Keeps the shared state (and the slot allocation) alive.
    shared: Arc<Shared<T>>,
}
218
219unsafe impl<T: Send> Send for Consumer<T> {}
223
impl<T> Consumer<T> {
    /// Removes and returns the oldest element, or `None` if the buffer is
    /// currently empty.
    ///
    /// Fast path touches only consumer-private state (`local_head`,
    /// `cached_tail`); the shared `tail` is reloaded only when the cached
    /// value suggests the buffer might be empty.
    #[inline]
    pub fn pop(&self) -> Option<T> {
        let head = self.local_head.get();

        // Empty when the read index has caught up with the write index.
        if head == self.cached_tail.get() {
            self.cached_tail.set(self.shared.tail.load(Ordering::Relaxed));
            // Acquire fence pairs with the producer's release fence in
            // `push`, so the producer's slot write happens-before the read
            // below.
            std::sync::atomic::fence(Ordering::Acquire);

            if head == self.cached_tail.get() {
                return None;
            }
        }

        // SAFETY: head != tail (checked above), so this in-bounds slot
        // holds an initialized value written by the producer; `read` moves
        // it out, and advancing `head` below marks the slot reusable.
        let value = unsafe { self.buffer.add(head & self.mask).read() };
        let new_head = head.wrapping_add(1);
        // Release fence + relaxed store publishes the slot read before the
        // producer can observe the advanced head and overwrite the slot.
        std::sync::atomic::fence(Ordering::Release);

        self.shared.head.store(new_head, Ordering::Relaxed);
        self.local_head.set(new_head);

        Some(value)
    }

    /// Total number of slots (always a power of two).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.mask + 1
    }

    /// Returns `true` if the producer has been dropped.
    ///
    /// Best-effort: derived from the `Arc` strong count, so the answer may
    /// be stale by the time the caller acts on it.
    #[inline]
    pub fn is_disconnected(&self) -> bool {
        Arc::strong_count(&self.shared) == 1
    }
}
265
266impl<T> fmt::Debug for Consumer<T> {
267 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
268 f.debug_struct("Consumer")
269 .field("capacity", &self.capacity())
270 .finish_non_exhaustive()
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    // Single-threaded FIFO smoke test: values come out in push order and
    // the buffer reports empty afterwards.
    #[test]
    fn basic_push_pop() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());

        assert_eq!(cons.pop(), Some(1));
        assert_eq!(cons.pop(), Some(2));
        assert_eq!(cons.pop(), Some(3));
        assert_eq!(cons.pop(), None);
    }

    // A fresh buffer must report empty, repeatedly.
    #[test]
    fn empty_pop_returns_none() {
        let (_, cons) = ring_buffer::<u64>(4);
        assert_eq!(cons.pop(), None);
        assert_eq!(cons.pop(), None);
    }

    // Fill every slot, then drain them all in order.
    #[test]
    fn fill_then_drain() {
        let (prod, cons) = ring_buffer::<u64>(4);

        for i in 0..4 {
            assert!(prod.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(cons.pop(), Some(i));
        }

        assert_eq!(cons.pop(), None);
    }

    // With no consumer progress, the fifth push into a 4-slot buffer must
    // fail and hand the rejected value back.
    #[test]
    fn push_returns_error_when_full() {
        let (prod, _cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());
        assert!(prod.push(4).is_ok());

        let err = prod.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    // Many laps around an 8-slot buffer: wrap-around indexing must never
    // return a stale or overwritten value.
    #[test]
    fn interleaved_no_overwrite() {
        let (prod, cons) = ring_buffer::<u64>(8);

        for i in 0..1000 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    // Repeated half-capacity fill/drain cycles exercise wrap-around at
    // offsets not aligned to capacity.
    #[test]
    fn partial_fill_drain_cycles() {
        let (prod, cons) = ring_buffer::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(prod.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(cons.pop(), Some(round * 4 + i));
            }
        }
    }

    // Capacity 1 stays bounded: a second push fails until the slot frees up.
    #[test]
    fn single_slot_bounded() {
        let (prod, cons) = ring_buffer::<u64>(1);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_err());

        assert_eq!(cons.pop(), Some(1));
        assert!(prod.push(2).is_ok());
    }

    // Dropping the producer is observable from the consumer side.
    #[test]
    fn producer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!cons.is_disconnected());
        drop(prod);
        assert!(cons.is_disconnected());
    }

    // Dropping the consumer is observable from the producer side.
    #[test]
    fn consumer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!prod.is_disconnected());
        drop(cons);
        assert!(prod.is_disconnected());
    }

    // Elements still queued when both endpoints drop must themselves be
    // dropped exactly once (counted via a process-wide static; only this
    // test touches it).
    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (prod, cons) = ring_buffer::<DropCounter>(4);

        let _ = prod.push(DropCounter);
        let _ = prod.push(DropCounter);
        let _ = prod.push(DropCounter);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(prod);
        drop(cons);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    // Producer and consumer on separate threads, spinning on full/empty;
    // every message must arrive.
    #[test]
    fn cross_thread_bounded() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if cons.pop().is_some() {
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, 10_000);
    }

    // Zero-sized elements occupy no storage but must still be counted.
    #[test]
    fn zero_sized_type() {
        let (prod, cons) = ring_buffer::<()>(8);

        let _ = prod.push(());
        let _ = prod.push(());

        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), None);
    }

    // Heap-owning elements move through intact.
    #[test]
    fn string_type() {
        let (prod, cons) = ring_buffer::<String>(4);

        let _ = prod.push("hello".to_string());
        let _ = prod.push("world".to_string());

        assert_eq!(cons.pop(), Some("hello".to_string()));
        assert_eq!(cons.pop(), Some("world".to_string()));
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = ring_buffer::<u64>(0);
    }

    // Over-aligned, cache-line-sized payloads round-trip correctly.
    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (prod, cons) = ring_buffer::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(prod.push(msg).is_ok());

        let received = cons.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    // Ten full laps around a 4-slot buffer.
    #[test]
    fn multiple_laps() {
        let (prod, cons) = ring_buffer::<u64>(4);

        for i in 0..40 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    // Cross-thread ordering: the consumer asserts values arrive strictly
    // in the sequence the producer sent them.
    #[test]
    fn fifo_order_cross_thread() {
        use std::thread;

        let (prod, cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut expected = 0u64;
            while expected < 10_000 {
                if let Some(val) = cons.pop() {
                    assert_eq!(val, expected, "FIFO order violated");
                    expected += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    // High-volume cross-thread run; the checksum (sum of 0..COUNT) catches
    // lost or duplicated messages.
    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;

        let (prod, cons) = ring_buffer::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = cons.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    // Non-power-of-two requests round up, as advertised.
    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = ring_buffer::<u64>(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = ring_buffer::<u64>(1000);
        assert_eq!(prod.capacity(), 1024);
    }
}