1use std::cell::UnsafeCell;
7use std::mem::MaybeUninit;
8use std::sync::atomic::{AtomicUsize, Ordering};
9
10use super::{CachePadded, DEFAULT_RING_CAPACITY, next_power_of_two};
11
/// Bounded single-producer / single-consumer (SPSC) ring buffer.
///
/// `head` and `tail` are free-running position counters that are only ever
/// advanced with wrapping arithmetic; a position is mapped to a slot with
/// `position & mask`. The live elements are exactly the slots whose positions
/// lie in `tail..head`; every other slot is treated as uninitialised.
///
/// NOTE(review): `enqueue`/`dequeue` take `&self` and the type is `Sync`, so
/// nothing in the type system prevents two threads from producing (or two
/// from consuming) at the same time, which would race on the same slot.
/// Callers must keep to exactly one producer thread and one consumer thread.
pub struct LockFreeRing<T> {
    // Slot storage; a slot holds a live `T` only while its position is in `tail..head`.
    buffer: Box<[UnsafeCell<MaybeUninit<T>>]>,
    // Number of slots (== `buffer.len()`), rounded up by `next_power_of_two` in `new`.
    capacity: usize,
    // `capacity - 1`; maps a free-running position to a slot index.
    mask: usize,
    // Position of the next enqueue; only stored by the producer-side methods.
    // Wrapped in `CachePadded`, presumably to keep it on a different cache line than `tail`.
    head: CachePadded<AtomicUsize>,
    // Position of the next dequeue; only stored by the consumer-side methods.
    tail: CachePadded<AtomicUsize>,
}
40
41unsafe impl<T: Send> Send for LockFreeRing<T> {}
44unsafe impl<T: Send> Sync for LockFreeRing<T> {}
45
46impl<T> LockFreeRing<T> {
47 #[must_use]
55 pub fn new(capacity: usize) -> Self {
56 assert!(capacity > 0, "capacity must be > 0");
57
58 let capacity = next_power_of_two(capacity);
59 let mask = capacity - 1;
60
61 let buffer: Vec<UnsafeCell<MaybeUninit<T>>> = (0..capacity)
63 .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
64 .collect();
65
66 Self {
67 buffer: buffer.into_boxed_slice(),
68 capacity,
69 mask,
70 head: CachePadded::new(AtomicUsize::new(0)),
71 tail: CachePadded::new(AtomicUsize::new(0)),
72 }
73 }
74
75 #[must_use]
77 pub fn with_default_capacity() -> Self {
78 Self::new(DEFAULT_RING_CAPACITY)
79 }
80
81 #[inline]
83 #[must_use]
84 pub const fn capacity(&self) -> usize {
85 self.capacity
86 }
87
88 #[inline]
90 #[must_use]
91 pub fn len(&self) -> usize {
92 let head = self.head.0.load(Ordering::Acquire);
93 let tail = self.tail.0.load(Ordering::Acquire);
94 head.wrapping_sub(tail)
95 }
96
97 #[inline]
99 #[must_use]
100 pub fn is_empty(&self) -> bool {
101 self.len() == 0
102 }
103
104 #[inline]
106 #[must_use]
107 pub fn is_full(&self) -> bool {
108 self.len() >= self.capacity
109 }
110
111 #[inline]
113 #[must_use]
114 pub fn free_slots(&self) -> usize {
115 self.capacity - self.len()
116 }
117
118 #[inline]
122 pub fn enqueue(&self, item: T) -> Result<(), T> {
123 let head = self.head.0.load(Ordering::Relaxed);
124 let tail = self.tail.0.load(Ordering::Acquire);
125
126 if head.wrapping_sub(tail) >= self.capacity {
128 return Err(item);
129 }
130
131 let idx = head & self.mask;
133 unsafe {
134 (*self.buffer[idx].get()).write(item);
135 }
136
137 self.head.0.store(head.wrapping_add(1), Ordering::Release);
139
140 Ok(())
141 }
142
143 #[inline]
147 pub fn dequeue(&self) -> Option<T> {
148 let tail = self.tail.0.load(Ordering::Relaxed);
149 let head = self.head.0.load(Ordering::Acquire);
150
151 if tail == head {
153 return None;
154 }
155
156 let idx = tail & self.mask;
158 let item = unsafe { (*self.buffer[idx].get()).assume_init_read() };
159
160 self.tail.0.store(tail.wrapping_add(1), Ordering::Release);
162
163 Some(item)
164 }
165
166 pub fn enqueue_batch(&self, items: &[T]) -> usize
171 where
172 T: Copy,
173 {
174 let head = self.head.0.load(Ordering::Relaxed);
175 let tail = self.tail.0.load(Ordering::Acquire);
176
177 let free = self.capacity - head.wrapping_sub(tail);
178 let count = items.len().min(free);
179
180 if count == 0 {
181 return 0;
182 }
183
184 for (i, item) in items.iter().take(count).enumerate() {
186 let idx = (head + i) & self.mask;
187 unsafe {
188 (*self.buffer[idx].get()).write(*item);
189 }
190 }
191
192 self.head
194 .0
195 .store(head.wrapping_add(count), Ordering::Release);
196
197 count
198 }
199
200 pub fn dequeue_batch(&self, out: &mut [T]) -> usize
204 where
205 T: Copy,
206 {
207 let tail = self.tail.0.load(Ordering::Relaxed);
208 let head = self.head.0.load(Ordering::Acquire);
209
210 let available = head.wrapping_sub(tail);
211 let count = out.len().min(available);
212
213 if count == 0 {
214 return 0;
215 }
216
217 for (i, slot) in out[..count].iter_mut().enumerate() {
219 let idx = (tail + i) & self.mask;
220 *slot = unsafe { (*self.buffer[idx].get()).assume_init_read() };
221 }
222
223 self.tail
225 .0
226 .store(tail.wrapping_add(count), Ordering::Release);
227
228 count
229 }
230
231 #[inline]
237 pub unsafe fn peek(&self) -> Option<&T> {
238 let tail = self.tail.0.load(Ordering::Relaxed);
239 let head = self.head.0.load(Ordering::Acquire);
240
241 if tail == head {
242 return None;
243 }
244
245 let idx = tail & self.mask;
246 Some(unsafe { (*self.buffer[idx].get()).assume_init_ref() })
248 }
249
250 pub unsafe fn clear(&self) {
256 while self.dequeue().is_some() {}
257 }
258}
259
260impl<T> Drop for LockFreeRing<T> {
261 fn drop(&mut self) {
262 while self.dequeue().is_some() {}
264 }
265}
266
267#[allow(clippy::missing_fields_in_debug)] impl<T> std::fmt::Debug for LockFreeRing<T> {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.debug_struct("LockFreeRing")
271 .field("capacity", &self.capacity)
272 .field("len", &self.len())
273 .field("head", &self.head.0.load(Ordering::Relaxed))
274 .field("tail", &self.tail.0.load(Ordering::Relaxed))
275 .finish()
276 }
277}
278
/// One slot of an `MpmcRing`: the payload plus a sequence number that says
/// whose turn it is to touch the slot.
///
/// For a position `pos` mapped to this slot, `seq == pos` means the slot is
/// free for the producer claiming `pos`. `seq == pos + 1` means it holds the
/// value written at `pos` and is ready for the consumer claiming `pos`.
struct MpmcSlot<T> {
    // Turn counter; initialised to the slot index when the ring is built.
    seq: AtomicUsize,
    // Payload; initialised only between a producer's write and the matching consumer's read.
    data: UnsafeCell<MaybeUninit<T>>,
}
285
/// Bounded multi-producer / multi-consumer queue.
///
/// Follows the per-slot sequence-number scheme of Dmitry Vyukov's bounded
/// MPMC queue. Producers and consumers claim positions by CAS on `head` /
/// `tail`, and each slot's `seq` tells them whether the slot is ready for
/// them on the current lap. The scheme needs at least two slots.
///
/// NOTE(review): the constructor and queue operations are provided only for
/// `T: Copy`, and no `Drop` impl is visible in this file, so elements still
/// queued when the ring is dropped are simply forgotten. That is harmless
/// for `Copy` payloads.
pub struct MpmcRing<T> {
    // Slot storage; its length equals `capacity`.
    buffer: Box<[MpmcSlot<T>]>,
    // Number of slots, rounded up by `next_power_of_two` at construction.
    capacity: usize,
    // `capacity - 1`; maps a position to a slot index.
    mask: usize,
    // Next position a producer will try to claim.
    head: CachePadded<AtomicUsize>,
    // Next position a consumer will try to claim.
    tail: CachePadded<AtomicUsize>,
}
305
306unsafe impl<T: Send> Send for MpmcRing<T> {}
307unsafe impl<T: Send> Sync for MpmcRing<T> {}
308
309impl<T: Copy> MpmcRing<T> {
310 #[must_use]
312 pub fn new(capacity: usize) -> Self {
313 assert!(capacity > 0, "capacity must be > 0");
314
315 let capacity = next_power_of_two(capacity);
316 let mask = capacity - 1;
317
318 let buffer: Vec<MpmcSlot<T>> = (0..capacity)
322 .map(|i| MpmcSlot {
323 seq: AtomicUsize::new(i),
324 data: UnsafeCell::new(MaybeUninit::uninit()),
325 })
326 .collect();
327
328 Self {
329 buffer: buffer.into_boxed_slice(),
330 capacity,
331 mask,
332 head: CachePadded::new(AtomicUsize::new(0)),
333 tail: CachePadded::new(AtomicUsize::new(0)),
334 }
335 }
336
337 #[inline]
339 #[must_use]
340 pub const fn capacity(&self) -> usize {
341 self.capacity
342 }
343
344 #[inline]
346 #[must_use]
347 pub fn len(&self) -> usize {
348 let head = self.head.0.load(Ordering::Acquire);
349 let tail = self.tail.0.load(Ordering::Acquire);
350 head.wrapping_sub(tail)
351 }
352
353 #[inline]
355 #[must_use]
356 pub fn is_empty(&self) -> bool {
357 self.len() == 0
358 }
359
360 pub fn enqueue(&self, item: T) -> Result<(), T> {
364 let mut head = self.head.0.load(Ordering::Relaxed);
365
366 loop {
367 let slot = &self.buffer[head & self.mask];
368 let seq = slot.seq.load(Ordering::Acquire);
369
370 #[allow(clippy::cast_possible_wrap)]
371 let diff = (seq as isize).wrapping_sub(head as isize);
372
373 match diff.cmp(&0) {
374 std::cmp::Ordering::Equal => {
375 match self.head.0.compare_exchange_weak(
377 head,
378 head.wrapping_add(1),
379 Ordering::Relaxed,
380 Ordering::Relaxed,
381 ) {
382 Ok(_) => {
383 unsafe { (*slot.data.get()).write(item) };
385 slot.seq.store(head.wrapping_add(1), Ordering::Release);
387 return Ok(());
388 }
389 Err(h) => head = h,
390 }
391 }
392 std::cmp::Ordering::Less => {
393 return Err(item);
395 }
396 std::cmp::Ordering::Greater => {
397 head = self.head.0.load(Ordering::Relaxed);
399 }
400 }
401 }
402 }
403
404 pub fn dequeue(&self) -> Option<T> {
408 let mut tail = self.tail.0.load(Ordering::Relaxed);
409
410 loop {
411 let slot = &self.buffer[tail & self.mask];
412 let seq = slot.seq.load(Ordering::Acquire);
413
414 #[allow(clippy::cast_possible_wrap)]
415 let diff = (seq as isize).wrapping_sub(tail.wrapping_add(1) as isize);
416
417 match diff.cmp(&0) {
418 std::cmp::Ordering::Equal => {
419 match self.tail.0.compare_exchange_weak(
421 tail,
422 tail.wrapping_add(1),
423 Ordering::Relaxed,
424 Ordering::Relaxed,
425 ) {
426 Ok(_) => {
427 let item = unsafe { (*slot.data.get()).assume_init_read() };
430 slot.seq
432 .store(tail.wrapping_add(self.capacity), Ordering::Release);
433 return Some(item);
434 }
435 Err(t) => tail = t,
436 }
437 }
438 std::cmp::Ordering::Less => {
439 return None;
441 }
442 std::cmp::Ordering::Greater => {
443 tail = self.tail.0.load(Ordering::Relaxed);
445 }
446 }
447 }
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_spsc_basic() {
        let ring = LockFreeRing::<u32>::new(4);

        assert!(ring.is_empty());
        assert_eq!(ring.capacity(), 4);

        ring.enqueue(1).unwrap();
        ring.enqueue(2).unwrap();
        ring.enqueue(3).unwrap();
        ring.enqueue(4).unwrap();

        assert!(ring.is_full());
        assert!(ring.enqueue(5).is_err());

        assert_eq!(ring.dequeue(), Some(1));
        assert_eq!(ring.dequeue(), Some(2));
        assert_eq!(ring.dequeue(), Some(3));
        assert_eq!(ring.dequeue(), Some(4));

        assert!(ring.is_empty());
        assert_eq!(ring.dequeue(), None);
    }

    #[test]
    fn test_spsc_batch() {
        let ring = LockFreeRing::<u32>::new(8);

        let items = [1, 2, 3, 4, 5];
        let count = ring.enqueue_batch(&items);
        assert_eq!(count, 5);
        assert_eq!(ring.len(), 5);

        let mut out = [0u32; 10];
        let count = ring.dequeue_batch(&mut out);
        assert_eq!(count, 5);
        assert_eq!(&out[..5], &items);
    }

    #[test]
    fn test_spsc_wrap() {
        let ring = LockFreeRing::<u32>::new(4);

        // Fill and drain repeatedly so positions run far past `capacity`.
        for round in 0..10 {
            for i in 0..4 {
                ring.enqueue(round * 4 + i).unwrap();
            }
            for i in 0..4 {
                assert_eq!(ring.dequeue(), Some(round * 4 + i));
            }
        }
    }

    /// One producer, one consumer: the consumer must observe exactly
    /// `0, 1, 2, ...` with no gaps, duplicates or reordering.
    #[test]
    fn test_spsc_threaded() {
        let ring = Arc::new(LockFreeRing::<u64>::new(1024));
        let ring_producer = Arc::clone(&ring);
        let ring_consumer = Arc::clone(&ring);

        let count = 100_000u64;

        let producer = thread::spawn(move || {
            for i in 0..count {
                while ring_producer.enqueue(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            let mut expected = 0u64;
            while expected < count {
                if let Some(v) = ring_consumer.dequeue() {
                    // Exact equality (rather than a monotonicity check) also
                    // catches dropped and duplicated items.
                    assert_eq!(v, expected, "SPSC ring delivered items out of sequence");
                    expected += 1;
                } else {
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn test_capacity_rounding() {
        let ring = LockFreeRing::<u32>::new(3);
        assert_eq!(ring.capacity(), 4);

        let ring = LockFreeRing::<u32>::new(5);
        assert_eq!(ring.capacity(), 8);

        let ring = LockFreeRing::<u32>::new(1024);
        assert_eq!(ring.capacity(), 1024);
    }

    #[test]
    fn test_peek() {
        let ring = LockFreeRing::<u32>::new(4);

        unsafe {
            assert!(ring.peek().is_none());
        }

        ring.enqueue(42).unwrap();

        // Peeking twice must not consume the element.
        unsafe {
            assert_eq!(ring.peek(), Some(&42));
            assert_eq!(ring.peek(), Some(&42));
        }

        assert_eq!(ring.dequeue(), Some(42));
    }

    #[test]
    fn test_mpmc_basic() {
        let ring = MpmcRing::<u32>::new(4);

        ring.enqueue(1).unwrap();
        ring.enqueue(2).unwrap();

        assert_eq!(ring.dequeue(), Some(1));
        assert_eq!(ring.dequeue(), Some(2));
        assert_eq!(ring.dequeue(), None);
    }

    /// Several producers and consumers hammer one ring; every produced value
    /// must be consumed exactly once.
    #[test]
    fn test_mpmc_stress() {
        use std::sync::atomic::AtomicBool;

        const PRODUCERS: usize = 4;
        const CONSUMERS: usize = 4;
        const ITEMS_PER_PRODUCER: usize = 10_000;
        const TOTAL: usize = PRODUCERS * ITEMS_PER_PRODUCER;

        let ring = Arc::new(MpmcRing::<usize>::new(256));
        let producers_done = Arc::new(AtomicBool::new(false));

        // Producer `p` emits the disjoint range `p * N .. (p + 1) * N`.
        let mut producer_handles = Vec::new();
        for p in 0..PRODUCERS {
            let ring = Arc::clone(&ring);
            producer_handles.push(thread::spawn(move || {
                let base = p * ITEMS_PER_PRODUCER;
                for i in 0..ITEMS_PER_PRODUCER {
                    while ring.enqueue(base + i).is_err() {
                        std::hint::spin_loop();
                    }
                }
            }));
        }

        let mut consumer_handles = Vec::new();
        for _ in 0..CONSUMERS {
            let ring = Arc::clone(&ring);
            let done = Arc::clone(&producers_done);
            consumer_handles.push(thread::spawn(move || {
                let mut collected = Vec::new();
                loop {
                    match ring.dequeue() {
                        Some(v) => collected.push(v),
                        None => {
                            if done.load(Ordering::Acquire) {
                                // Producers have finished: drain what is left.
                                while let Some(v) = ring.dequeue() {
                                    collected.push(v);
                                }
                                break;
                            }
                            std::hint::spin_loop();
                        }
                    }
                }
                collected
            }));
        }

        for h in producer_handles {
            h.join().unwrap();
        }
        producers_done.store(true, Ordering::Release);

        let mut all: Vec<usize> = consumer_handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();

        // Anything still queued after all consumers exited.
        while let Some(v) = ring.dequeue() {
            all.push(v);
        }

        // Compare against the exact expected set without deduplicating first,
        // so an item delivered twice is caught even if nothing was lost.
        all.sort_unstable();
        assert_eq!(
            all.len(),
            TOTAL,
            "expected {TOTAL} items, got {} (duplicates or lost items)",
            all.len()
        );
        assert!(
            all.iter().copied().eq(0..TOTAL),
            "every value in 0..{TOTAL} must be delivered exactly once"
        );
    }
}