1use std::fmt;
46use std::mem::ManuallyDrop;
47use std::sync::Arc;
48use std::sync::atomic::{AtomicUsize, Ordering};
49
50use crossbeam_utils::CachePadded;
51
52use crate::Full;
53
54pub fn ring_buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
62 assert!(capacity > 0, "capacity must be non-zero");
63
64 let capacity = capacity.next_power_of_two();
65 let mask = capacity - 1;
66
67 let mut slots = ManuallyDrop::new(Vec::<T>::with_capacity(capacity));
68 let buffer = slots.as_mut_ptr();
69
70 let shared = Arc::new(Shared {
71 tail: CachePadded::new(AtomicUsize::new(0)),
72 head: CachePadded::new(AtomicUsize::new(0)),
73 buffer,
74 mask,
75 });
76
77 (
78 Producer {
79 local_tail: 0,
80 cached_head: 0,
81 buffer,
82 mask,
83 shared: Arc::clone(&shared),
84 },
85 Consumer {
86 local_head: 0,
87 cached_tail: 0,
88 buffer,
89 mask,
90 shared,
91 },
92 )
93}
94
95#[repr(C)]
98struct Shared<T> {
99 tail: CachePadded<AtomicUsize>,
100 head: CachePadded<AtomicUsize>,
101 buffer: *mut T,
102 mask: usize,
103}
104
105unsafe impl<T: Send> Send for Shared<T> {}
109unsafe impl<T: Send> Sync for Shared<T> {}
110
111impl<T> Drop for Shared<T> {
112 fn drop(&mut self) {
113 let head = self.head.load(Ordering::Relaxed);
114 let tail = self.tail.load(Ordering::Relaxed);
115
116 let mut i = head;
117 while i != tail {
118 unsafe { self.buffer.add(i & self.mask).drop_in_place() };
121 i = i.wrapping_add(1);
122 }
123
124 unsafe {
127 let capacity = self.mask + 1;
128 let _ = Vec::from_raw_parts(self.buffer, 0, capacity);
129 }
130 }
131}
132
133#[repr(C)]
139pub struct Producer<T> {
140 local_tail: usize,
141 cached_head: usize,
142 buffer: *mut T,
143 mask: usize,
144 shared: Arc<Shared<T>>,
145}
146
147unsafe impl<T: Send> Send for Producer<T> {}
151
152impl<T> Producer<T> {
153 #[inline]
158 #[must_use = "push returns Err if full, which should be handled"]
159 pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
160 let tail = self.local_tail;
161
162 if tail.wrapping_sub(self.cached_head) > self.mask {
163 self.cached_head = self.shared.head.load(Ordering::Relaxed);
164
165 std::sync::atomic::fence(Ordering::Acquire);
166 if tail.wrapping_sub(self.cached_head) > self.mask {
167 return Err(Full(value));
168 }
169 }
170
171 unsafe { self.buffer.add(tail & self.mask).write(value) };
174 let new_tail = tail.wrapping_add(1);
175 std::sync::atomic::fence(Ordering::Release);
176
177 self.shared.tail.store(new_tail, Ordering::Relaxed);
178 self.local_tail = new_tail;
179
180 Ok(())
181 }
182
183 #[inline]
185 pub fn capacity(&self) -> usize {
186 self.mask + 1
187 }
188
189 #[inline]
191 pub fn is_disconnected(&self) -> bool {
192 Arc::strong_count(&self.shared) == 1
193 }
194}
195
196impl<T> fmt::Debug for Producer<T> {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 f.debug_struct("Producer")
199 .field("capacity", &self.capacity())
200 .finish_non_exhaustive()
201 }
202}
203
204#[repr(C)]
210pub struct Consumer<T> {
211 local_head: usize,
212 cached_tail: usize,
213 buffer: *mut T,
214 mask: usize,
215 shared: Arc<Shared<T>>,
216}
217
218unsafe impl<T: Send> Send for Consumer<T> {}
222
223impl<T> Consumer<T> {
224 #[inline]
228 pub fn pop(&mut self) -> Option<T> {
229 let head = self.local_head;
230
231 if head == self.cached_tail {
232 self.cached_tail = self.shared.tail.load(Ordering::Relaxed);
233 std::sync::atomic::fence(Ordering::Acquire);
234
235 if head == self.cached_tail {
236 return None;
237 }
238 }
239
240 let value = unsafe { self.buffer.add(head & self.mask).read() };
243 let new_head = head.wrapping_add(1);
244 std::sync::atomic::fence(Ordering::Release);
245
246 self.shared.head.store(new_head, Ordering::Relaxed);
247 self.local_head = new_head;
248
249 Some(value)
250 }
251
252 #[inline]
254 pub fn capacity(&self) -> usize {
255 self.mask + 1
256 }
257
258 #[inline]
260 pub fn is_disconnected(&self) -> bool {
261 Arc::strong_count(&self.shared) == 1
262 }
263}
264
265impl<T> fmt::Debug for Consumer<T> {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 f.debug_struct("Consumer")
268 .field("capacity", &self.capacity())
269 .finish_non_exhaustive()
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Values come back out in insertion order, after which the buffer is
    /// empty.
    #[test]
    fn basic_push_pop() {
        let (mut tx, mut rx) = ring_buffer::<u64>(4);

        for value in 1..=3 {
            assert!(tx.push(value).is_ok());
        }
        for value in 1..=3 {
            assert_eq!(rx.pop(), Some(value));
        }

        assert_eq!(rx.pop(), None);
    }

    /// Popping from a buffer that never received anything keeps returning
    /// `None`.
    #[test]
    fn empty_pop_returns_none() {
        let (_, mut rx) = ring_buffer::<u64>(4);
        for _ in 0..2 {
            assert_eq!(rx.pop(), None);
        }
    }

    /// A buffer filled to capacity can be drained completely.
    #[test]
    fn fill_then_drain() {
        let (mut tx, mut rx) = ring_buffer::<u64>(4);

        (0..4).for_each(|i| assert!(tx.push(i).is_ok()));
        (0..4).for_each(|i| assert_eq!(rx.pop(), Some(i)));

        assert_eq!(rx.pop(), None);
    }

    /// Pushing into a full buffer hands the rejected value back.
    #[test]
    fn push_returns_error_when_full() {
        let (mut tx, _rx) = ring_buffer::<u64>(4);

        for value in 1..=4 {
            assert!(tx.push(value).is_ok());
        }

        let rejected = tx.push(5).unwrap_err();
        assert_eq!(rejected.into_inner(), 5);
    }

    /// Alternating push/pop many times never loses or overwrites a value.
    #[test]
    fn interleaved_no_overwrite() {
        let (mut tx, mut rx) = ring_buffer::<u64>(8);

        (0..1000).for_each(|i| {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        });
    }

    /// Repeatedly half-filling and draining keeps values in order.
    #[test]
    fn partial_fill_drain_cycles() {
        let (mut tx, mut rx) = ring_buffer::<u64>(8);

        for round in 0..100 {
            let base = round * 4;

            for offset in 0..4 {
                assert!(tx.push(base + offset).is_ok());
            }
            for offset in 0..4 {
                assert_eq!(rx.pop(), Some(base + offset));
            }
        }
    }

    /// A capacity of one holds exactly one element at a time.
    #[test]
    fn single_slot_bounded() {
        let (mut tx, mut rx) = ring_buffer::<u64>(1);

        assert!(tx.push(1).is_ok());
        assert!(tx.push(2).is_err());

        assert_eq!(rx.pop(), Some(1));
        assert!(tx.push(2).is_ok());
    }

    /// The consumer notices when the producer has been dropped.
    #[test]
    fn producer_disconnected() {
        let (tx, rx) = ring_buffer::<u64>(4);

        assert!(!rx.is_disconnected());
        drop(tx);
        assert!(rx.is_disconnected());
    }

    /// The producer notices when the consumer has been dropped.
    #[test]
    fn consumer_disconnected() {
        let (tx, rx) = ring_buffer::<u64>(4);

        assert!(!tx.is_disconnected());
        drop(rx);
        assert!(tx.is_disconnected());
    }

    /// Elements still in the buffer are dropped once both endpoints are gone.
    #[test]
    fn drop_cleans_up_remaining() {
        use std::sync::atomic::AtomicUsize;

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut tx, rx) = ring_buffer::<DropCounter>(4);

        for _ in 0..3 {
            let _ = tx.push(DropCounter);
        }

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(tx);
        drop(rx);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    /// Every value sent from one thread arrives on the other.
    #[test]
    fn cross_thread_bounded() {
        use std::thread;

        let (mut tx, mut rx) = ring_buffer::<u64>(64);

        let writer = thread::spawn(move || {
            for i in 0..10_000 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let reader = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                match rx.pop() {
                    Some(_) => received += 1,
                    None => std::hint::spin_loop(),
                }
            }
            received
        });

        writer.join().unwrap();
        let received = reader.join().unwrap();
        assert_eq!(received, 10_000);
    }

    /// Zero-sized element types work like any other.
    #[test]
    fn zero_sized_type() {
        let (mut tx, mut rx) = ring_buffer::<()>(8);

        for _ in 0..2 {
            let _ = tx.push(());
        }

        for _ in 0..2 {
            assert_eq!(rx.pop(), Some(()));
        }
        assert_eq!(rx.pop(), None);
    }

    /// Heap-owning element types round-trip intact.
    #[test]
    fn string_type() {
        let (mut tx, mut rx) = ring_buffer::<String>(4);

        let _ = tx.push(String::from("hello"));
        let _ = tx.push(String::from("world"));

        assert_eq!(rx.pop(), Some(String::from("hello")));
        assert_eq!(rx.pop(), Some(String::from("world")));
    }

    /// A zero capacity is rejected up front.
    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = ring_buffer::<u64>(0);
    }

    /// Large, over-aligned element types round-trip intact.
    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (mut tx, mut rx) = ring_buffer::<LargeMessage>(8);

        let outgoing = LargeMessage { data: [42u8; 256] };
        assert!(tx.push(outgoing).is_ok());

        let incoming = rx.pop().unwrap();
        assert_eq!(incoming.data[0], 42);
        assert_eq!(incoming.data[255], 42);
    }

    /// Positions keep working after wrapping around the slot array many
    /// times.
    #[test]
    fn multiple_laps() {
        let (mut tx, mut rx) = ring_buffer::<u64>(4);

        (0..40).for_each(|i| {
            assert!(tx.push(i).is_ok());
            assert_eq!(rx.pop(), Some(i));
        });
    }

    /// Values cross threads in exactly the order they were pushed.
    #[test]
    fn fifo_order_cross_thread() {
        use std::thread;

        let (mut tx, mut rx) = ring_buffer::<u64>(64);

        let writer = thread::spawn(move || {
            for i in 0..10_000u64 {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let reader = thread::spawn(move || {
            let mut expected = 0u64;
            while expected < 10_000 {
                match rx.pop() {
                    Some(val) => {
                        assert_eq!(val, expected, "FIFO order violated");
                        expected += 1;
                    }
                    None => std::hint::spin_loop(),
                }
            }
        });

        writer.join().unwrap();
        reader.join().unwrap();
    }

    /// A million values cross threads without loss (checked via their sum).
    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;

        let (mut tx, mut rx) = ring_buffer::<u64>(1024);

        let writer = thread::spawn(move || {
            for i in 0..COUNT {
                while tx.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let reader = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                match rx.pop() {
                    Some(val) => {
                        sum = sum.wrapping_add(val);
                        received += 1;
                    }
                    None => std::hint::spin_loop(),
                }
            }
            sum
        });

        writer.join().unwrap();
        let sum = reader.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    /// Requested capacities are rounded up to the next power of two.
    #[test]
    fn capacity_rounds_to_power_of_two() {
        for (requested, rounded) in [(100, 128), (1000, 1024)] {
            let (tx, _) = ring_buffer::<u64>(requested);
            assert_eq!(tx.capacity(), rounded);
        }
    }
}