1use std::fmt;
58use std::mem::ManuallyDrop;
59use std::sync::Arc;
60use std::sync::atomic::{AtomicUsize, Ordering};
61
62use crossbeam_utils::CachePadded;
63
64use crate::Full;
65
66pub fn ring_buffer<T>(capacity: usize) -> (Producer<T>, Consumer<T>) {
74 assert!(capacity > 0, "capacity must be non-zero");
75
76 let capacity = capacity.next_power_of_two();
77 let mask = capacity - 1;
78
79 let mut slots = ManuallyDrop::new(Vec::<T>::with_capacity(capacity));
80 let buffer = slots.as_mut_ptr();
81
82 let shared = Arc::new(Shared {
83 tail: CachePadded::new(AtomicUsize::new(0)),
84 head: CachePadded::new(AtomicUsize::new(0)),
85 buffer,
86 mask,
87 });
88
89 (
90 Producer {
91 local_tail: 0,
92 cached_head: 0,
93 buffer,
94 mask,
95 shared: Arc::clone(&shared),
96 },
97 Consumer {
98 local_head: 0,
99 cached_tail: 0,
100 buffer,
101 mask,
102 shared,
103 },
104 )
105}
106
/// State shared (through an `Arc`) by the [`Producer`] and [`Consumer`] halves.
///
/// Owns the slot allocation. When the last handle is dropped, the slots
/// between `head` and `tail` (reduced with `mask`) still hold initialized
/// values, and those are dropped together with it.
#[repr(C)]
struct Shared<T> {
    /// Index one past the most recently pushed element. Stored only by
    /// `Producer::push`; advances with wrapping arithmetic.
    tail: CachePadded<AtomicUsize>,
    /// Index of the next element to pop. Stored only by `Consumer::pop`;
    /// advances with wrapping arithmetic. Padded so that the
    /// producer-written and consumer-written counters do not share a cache line.
    head: CachePadded<AtomicUsize>,
    /// Start of the slot array, which holds `mask + 1` slots.
    buffer: *mut T,
    /// `capacity - 1`. The capacity is a power of two, so `index & mask` maps
    /// a running index to its slot.
    mask: usize,
}

// SAFETY: the raw `buffer` pointer opts this type out of the auto traits.
// Values of `T` are written by the producer thread and moved out by the
// consumer thread, so `T: Send` is required. No `&T` is ever handed out, and
// each slot is accessed by only one side at a time, as coordinated through
// the `head`/`tail` atomics, so `T: Sync` is not needed.
unsafe impl<T: Send> Send for Shared<T> {}
unsafe impl<T: Send> Sync for Shared<T> {}
117
118impl<T> Drop for Shared<T> {
119 fn drop(&mut self) {
120 let head = self.head.load(Ordering::Relaxed);
121 let tail = self.tail.load(Ordering::Relaxed);
122
123 let mut i = head;
124 while i != tail {
125 unsafe { self.buffer.add(i & self.mask).drop_in_place() };
126 i = i.wrapping_add(1);
127 }
128
129 unsafe {
130 let capacity = self.mask + 1;
131 let _ = Vec::from_raw_parts(self.buffer, 0, capacity);
132 }
133 }
134}
135
/// Sending half of a ring buffer created by [`ring_buffer`].
///
/// `ring_buffer` hands out exactly one `Producer` per ring, and `push` takes
/// `&mut self`.
#[repr(C)]
pub struct Producer<T> {
    /// Producer-local copy of the tail index. The producer is the only writer
    /// of `shared.tail`, so this never goes stale.
    local_tail: usize,
    /// Last value of `shared.head` this producer observed. It may lag the real
    /// head, which can only make the ring look fuller than it is.
    cached_head: usize,
    /// Same slot pointer as `shared.buffer`. Copied here, presumably to avoid
    /// going through the `Arc` on every push.
    buffer: *mut T,
    /// Same value as `shared.mask`: `capacity - 1`.
    mask: usize,
    /// Shared indices, plus ownership of the slot allocation.
    shared: Arc<Shared<T>>,
}

// SAFETY: the raw `buffer` pointer makes this type `!Send` by default. Moving
// the producer to another thread only moves the right to write `T` values into
// free slots; the consumer later takes them by value, so `T: Send` suffices.
unsafe impl<T: Send> Send for Producer<T> {}
149
150impl<T> Producer<T> {
151 #[inline]
156 #[must_use = "push returns Err if full, which should be handled"]
157 pub fn push(&mut self, value: T) -> Result<(), Full<T>> {
158 let tail = self.local_tail;
159
160 if tail.wrapping_sub(self.cached_head) > self.mask {
161 self.cached_head = self.shared.head.load(Ordering::Relaxed);
162
163 std::sync::atomic::fence(Ordering::Acquire);
164 if tail.wrapping_sub(self.cached_head) > self.mask {
165 return Err(Full(value));
166 }
167 }
168
169 unsafe { self.buffer.add(tail & self.mask).write(value) };
170 let new_tail = tail.wrapping_add(1);
171 std::sync::atomic::fence(Ordering::Release);
172
173 self.shared.tail.store(new_tail, Ordering::Relaxed);
174 self.local_tail = new_tail;
175
176 Ok(())
177 }
178
179 #[inline]
181 pub fn capacity(&self) -> usize {
182 self.mask + 1
183 }
184
185 #[inline]
187 pub fn is_disconnected(&self) -> bool {
188 Arc::strong_count(&self.shared) == 1
189 }
190}
191
192impl<T> fmt::Debug for Producer<T> {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 f.debug_struct("Producer")
195 .field("capacity", &self.capacity())
196 .finish_non_exhaustive()
197 }
198}
199
/// Receiving half of a ring buffer created by [`ring_buffer`].
///
/// `ring_buffer` hands out exactly one `Consumer` per ring, and `pop` takes
/// `&mut self`.
#[repr(C)]
pub struct Consumer<T> {
    /// Consumer-local copy of the head index. The consumer is the only writer
    /// of `shared.head`, so this never goes stale.
    local_head: usize,
    /// Last value of `shared.tail` this consumer observed. It may lag the real
    /// tail, which can only make the ring look emptier than it is.
    cached_tail: usize,
    /// Same slot pointer as `shared.buffer`. Copied here, presumably to avoid
    /// going through the `Arc` on every pop.
    buffer: *mut T,
    /// Same value as `shared.mask`: `capacity - 1`.
    mask: usize,
    /// Shared indices, plus ownership of the slot allocation.
    shared: Arc<Shared<T>>,
}

// SAFETY: the raw `buffer` pointer makes this type `!Send` by default. The
// consumer only moves `T` values out of slots the producer has published, so
// moving it to another thread requires just `T: Send`.
unsafe impl<T: Send> Send for Consumer<T> {}
213
214impl<T> Consumer<T> {
215 #[inline]
219 pub fn pop(&mut self) -> Option<T> {
220 let head = self.local_head;
221
222 if head == self.cached_tail {
223 self.cached_tail = self.shared.tail.load(Ordering::Relaxed);
224 std::sync::atomic::fence(Ordering::Acquire);
225
226 if head == self.cached_tail {
227 return None;
228 }
229 }
230
231 let value = unsafe { self.buffer.add(head & self.mask).read() };
232 let new_head = head.wrapping_add(1);
233 std::sync::atomic::fence(Ordering::Release);
234
235 self.shared.head.store(new_head, Ordering::Relaxed);
236 self.local_head = new_head;
237
238 Some(value)
239 }
240
241 #[inline]
243 pub fn capacity(&self) -> usize {
244 self.mask + 1
245 }
246
247 #[inline]
249 pub fn is_disconnected(&self) -> bool {
250 Arc::strong_count(&self.shared) == 1
251 }
252}
253
254impl<T> fmt::Debug for Consumer<T> {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 f.debug_struct("Consumer")
257 .field("capacity", &self.capacity())
258 .finish_non_exhaustive()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_push_pop() {
        let (mut prod, mut cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());

        assert_eq!(cons.pop(), Some(1));
        assert_eq!(cons.pop(), Some(2));
        assert_eq!(cons.pop(), Some(3));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn empty_pop_returns_none() {
        let (_, mut cons) = ring_buffer::<u64>(4);
        assert_eq!(cons.pop(), None);
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn fill_then_drain() {
        let (mut prod, mut cons) = ring_buffer::<u64>(4);

        for i in 0..4 {
            assert!(prod.push(i).is_ok());
        }

        for i in 0..4 {
            assert_eq!(cons.pop(), Some(i));
        }

        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn push_returns_error_when_full() {
        let (mut prod, _cons) = ring_buffer::<u64>(4);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_ok());
        assert!(prod.push(3).is_ok());
        assert!(prod.push(4).is_ok());

        let err = prod.push(5).unwrap_err();
        assert_eq!(err.into_inner(), 5);
    }

    #[test]
    fn interleaved_no_overwrite() {
        let (mut prod, mut cons) = ring_buffer::<u64>(8);

        for i in 0..1000 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    #[test]
    fn partial_fill_drain_cycles() {
        let (mut prod, mut cons) = ring_buffer::<u64>(8);

        for round in 0..100 {
            for i in 0..4 {
                assert!(prod.push(round * 4 + i).is_ok());
            }

            for i in 0..4 {
                assert_eq!(cons.pop(), Some(round * 4 + i));
            }
        }
    }

    #[test]
    fn single_slot_bounded() {
        let (mut prod, mut cons) = ring_buffer::<u64>(1);

        assert!(prod.push(1).is_ok());
        assert!(prod.push(2).is_err());

        assert_eq!(cons.pop(), Some(1));
        assert!(prod.push(2).is_ok());
    }

    #[test]
    fn producer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!cons.is_disconnected());
        drop(prod);
        assert!(cons.is_disconnected());
    }

    #[test]
    fn consumer_disconnected() {
        let (prod, cons) = ring_buffer::<u64>(4);

        assert!(!prod.is_disconnected());
        drop(cons);
        assert!(prod.is_disconnected());
    }

    #[test]
    fn consumer_drains_after_producer_drop() {
        let (mut prod, mut cons) = ring_buffer::<u64>(4);

        assert!(prod.push(7).is_ok());
        assert!(prod.push(8).is_ok());
        drop(prod);

        // Disconnection must not lose elements that were already published.
        assert!(cons.is_disconnected());
        assert_eq!(cons.pop(), Some(7));
        assert_eq!(cons.pop(), Some(8));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn drop_cleans_up_remaining() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut prod, cons) = ring_buffer::<DropCounter>(4);

        // Assert rather than discard: a rejected push would also drop its
        // value and skew the count.
        assert!(prod.push(DropCounter).is_ok());
        assert!(prod.push(DropCounter).is_ok());
        assert!(prod.push(DropCounter).is_ok());

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);

        drop(prod);
        drop(cons);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn drop_cleans_up_wrapped_elements() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct DropCounter;
        impl Drop for DropCounter {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);

        let (mut prod, mut cons) = ring_buffer::<DropCounter>(4);

        // Advance both indices so the elements left behind straddle the end
        // of the slot array (slots 3, 0, 1).
        for _ in 0..3 {
            assert!(prod.push(DropCounter).is_ok());
        }
        for _ in 0..3 {
            assert!(cons.pop().is_some());
        }
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 3);

        for _ in 0..3 {
            assert!(prod.push(DropCounter).is_ok());
        }

        drop(prod);
        drop(cons);

        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 6);
    }

    #[test]
    fn cross_thread_bounded() {
        use std::thread;

        let (mut prod, mut cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            while received < 10_000 {
                if cons.pop().is_some() {
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            received
        });

        producer.join().unwrap();
        let received = consumer.join().unwrap();
        assert_eq!(received, 10_000);
    }

    #[test]
    fn zero_sized_type() {
        let (mut prod, mut cons) = ring_buffer::<()>(8);

        assert!(prod.push(()).is_ok());
        assert!(prod.push(()).is_ok());

        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), Some(()));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    fn string_type() {
        let (mut prod, mut cons) = ring_buffer::<String>(4);

        assert!(prod.push("hello".to_string()).is_ok());
        assert!(prod.push("world".to_string()).is_ok());

        assert_eq!(cons.pop(), Some("hello".to_string()));
        assert_eq!(cons.pop(), Some("world".to_string()));
        assert_eq!(cons.pop(), None);
    }

    #[test]
    #[should_panic(expected = "capacity must be non-zero")]
    fn zero_capacity_panics() {
        let _ = ring_buffer::<u64>(0);
    }

    #[test]
    fn large_message_type() {
        #[repr(C, align(64))]
        struct LargeMessage {
            data: [u8; 256],
        }

        let (mut prod, mut cons) = ring_buffer::<LargeMessage>(8);

        let msg = LargeMessage { data: [42u8; 256] };
        assert!(prod.push(msg).is_ok());

        let received = cons.pop().unwrap();
        assert_eq!(received.data[0], 42);
        assert_eq!(received.data[255], 42);
    }

    #[test]
    fn multiple_laps() {
        let (mut prod, mut cons) = ring_buffer::<u64>(4);

        for i in 0..40 {
            assert!(prod.push(i).is_ok());
            assert_eq!(cons.pop(), Some(i));
        }
    }

    #[test]
    fn fifo_order_cross_thread() {
        use std::thread;

        let (mut prod, mut cons) = ring_buffer::<u64>(64);

        let producer = thread::spawn(move || {
            for i in 0..10_000u64 {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut expected = 0u64;
            while expected < 10_000 {
                if let Some(val) = cons.pop() {
                    assert_eq!(val, expected, "FIFO order violated");
                    expected += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn stress_high_volume() {
        use std::thread;

        const COUNT: u64 = 1_000_000;

        let (mut prod, mut cons) = ring_buffer::<u64>(1024);

        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                while prod.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut sum = 0u64;
            let mut received = 0u64;
            while received < COUNT {
                if let Some(val) = cons.pop() {
                    sum = sum.wrapping_add(val);
                    received += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
            sum
        });

        producer.join().unwrap();
        let sum = consumer.join().unwrap();
        assert_eq!(sum, COUNT * (COUNT - 1) / 2);
    }

    #[test]
    fn capacity_rounds_to_power_of_two() {
        let (prod, _) = ring_buffer::<u64>(100);
        assert_eq!(prod.capacity(), 128);

        let (prod, _) = ring_buffer::<u64>(1000);
        assert_eq!(prod.capacity(), 1024);
    }
}