lockless_datastructures/atomic_ring_buffer_mpmc.rs

use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::primitives::Arc;
use crate::{Backoff, Padded};
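/// One storage cell of the ring. `sequence` encodes the slot's state in the
/// Vyukov-style scheme: it equals the claiming position while the slot is
/// writable and position + 1 once a value has been published. The 64-byte
/// alignment keeps neighbouring slots on separate cache lines to reduce
/// false sharing.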
#[repr(align(64))]
struct Slot<T> {
    sequence: AtomicUsize,
    data: UnsafeCell<MaybeUninit<T>>,
}
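/// A bounded, lock-free multi-producer multi-consumer ring buffer in the
/// style of Vyukov's array-based MPMC queue. `head` is the next position to
/// write and `tail` the next position to read; each slot carries its own
/// sequence number so producers and consumers coordinate without locks.
/// The capacity `N` must be a power of two.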
pub struct AtomicRingBufferMpmc<T, const N: usize> {
    head: Padded<AtomicUsize>,
    tail: Padded<AtomicUsize>,
    buffer: [Slot<T>; N],
}
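// SAFETY: values are only moved in and out of the buffer, and the sequence
// protocol hands each slot to exactly one thread at a time, so sharing the
// buffer across threads requires only `T: Send`.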
unsafe impl<T: Send, const N: usize> Sync for AtomicRingBufferMpmc<T, N> {}
unsafe impl<T: Send, const N: usize> Send for AtomicRingBufferMpmc<T, N> {}

impl<T, const N: usize> AtomicRingBufferMpmc<T, N> {
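    /// Creates an empty buffer behind an `Arc`. `N` must be a non-zero power
    /// of two; this is checked at compile time.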
    pub fn new() -> Arc<Self> {
        // Compile-time check: the index mask below only works for powers of two.
        const { assert!(N != 0 && N.is_power_of_two()) };

        // Each slot starts with `sequence == index`, i.e. "empty, writable at
        // position `index`".
        let buffer = std::array::from_fn(|i| Slot {
            sequence: AtomicUsize::new(i),
            data: UnsafeCell::new(MaybeUninit::uninit()),
        });

        Arc::new(Self {
            head: Padded(AtomicUsize::new(0)),
            tail: Padded(AtomicUsize::new(0)),
            buffer,
        })
    }
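    /// Attempts to enqueue `value`. Returns `Err(value)` when the buffer is
    /// full so the caller keeps ownership of the rejected element. Multiple
    /// producers may call this concurrently.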
    pub fn push(&self, value: T) -> Result<(), T> {
        let mut backoff = Backoff::new();
        let mut head = self.head.load(Ordering::Relaxed);

        loop {
            let idx = head & (N - 1);
            // SAFETY: `idx` is masked with `N - 1`, so it is always in bounds.
            let slot = unsafe { self.buffer.get_unchecked(idx) };
            let seq = slot.sequence.load(Ordering::Acquire);

            // diff == 0: the slot is empty and expects a write at `head`.
            // diff < 0:  the slot still holds an element from the previous lap,
            //            i.e. the buffer looks full.
            // diff > 0:  another producer already claimed this position.
            let diff = seq.wrapping_sub(head) as isize;

            if diff == 0 {
                match self.head.compare_exchange_weak(
                    head,
                    head.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // We own the slot: write the value, then publish it by
                        // bumping the sequence so consumers can see it.
                        unsafe {
                            (*slot.data.get()).write(value);
                        }
                        slot.sequence.store(head.wrapping_add(1), Ordering::Release);
                        return Ok(());
                    }
                    Err(real_head) => {
                        head = real_head;
                    }
                }
            } else if diff < 0 {
                // Re-check head: if it moved, another thread made progress and
                // the buffer may no longer be full.
                let new_head = self.head.load(Ordering::Relaxed);
                if new_head != head {
                    head = new_head;
                    backoff.reset();
                    continue;
                }
                return Err(value);
            } else {
                head = self.head.load(Ordering::Relaxed);
            }

            backoff.snooze();
        }
    }
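    /// Attempts to dequeue the oldest element. Returns `None` when the buffer
    /// is observed empty. Multiple consumers may call this concurrently.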
    pub fn pop(&self) -> Option<T> {
        let mut backoff = Backoff::new();
        let mut tail = self.tail.load(Ordering::Relaxed);

        loop {
            let idx = tail & (N - 1);
            // SAFETY: `idx` is masked with `N - 1`, so it is always in bounds.
            let slot = unsafe { self.buffer.get_unchecked(idx) };

            let seq = slot.sequence.load(Ordering::Acquire);

            // diff == 0: the slot holds a published element for position `tail`.
            // diff < 0:  the slot has not been written yet, i.e. the buffer is empty.
            // diff > 0:  another consumer already claimed this position.
            let diff = seq.wrapping_sub(tail.wrapping_add(1)) as isize;

            if diff == 0 {
                match self.tail.compare_exchange_weak(
                    tail,
                    tail.wrapping_add(1),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // We own the slot: take the value, then mark the slot as
                        // writable for the next lap (`tail + N`).
                        let value = unsafe { (*slot.data.get()).assume_init_read() };

                        slot.sequence.store(tail.wrapping_add(N), Ordering::Release);

                        return Some(value);
                    }
                    Err(real_tail) => {
                        tail = real_tail;
                    }
                }
            } else if diff < 0 {
                return None;
            } else {
                tail = self.tail.load(Ordering::Relaxed);
            }

            backoff.snooze();
        }
    }
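
    /// Current head (write) position reduced to a slot index. This is only a
    /// snapshot and may already be stale when the caller uses it.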
    pub fn read_head(&self) -> usize {
        self.head.load(Ordering::Acquire) % N
    }
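    /// Current tail (read) position reduced to a slot index. Like `read_head`,
    /// this is only a racy snapshot.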
    pub fn read_tail(&self) -> usize {
        self.tail.load(Ordering::Acquire) % N
    }
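    /// Returns `true` if the slot at physical `index` currently holds an
    /// element, based on a snapshot of `head` and `tail`. The answer can be
    /// outdated as soon as it is returned, so treat it as a hint only.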
    pub fn exists(&self, index: usize) -> bool {
        let mut tail = self.tail.load(Ordering::Acquire);
        let mut head = self.head.load(Ordering::Acquire);
        if head == tail {
            return false;
        }
        head &= N - 1;
        tail &= N - 1;
        if head > tail {
            // Occupied region is contiguous: [tail, head).
            head > index && index >= tail
        } else {
            // Occupied region wraps around the end of the buffer (or the
            // buffer is full when the masked positions coincide).
            !(index >= head && tail > index)
        }
    }
}
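// Dropping the buffer must also drop any elements that were pushed but never
// popped. `drop` takes `&mut self`, so no other thread can be touching the
// buffer at this point.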
impl<T, const N: usize> Drop for AtomicRingBufferMpmc<T, N> {
    fn drop(&mut self) {
        if !std::mem::needs_drop::<T>() {
            return;
        }

        let head = self.head.load(Ordering::Relaxed);
        let mut tail = self.tail.load(Ordering::Relaxed);

        while tail != head {
            let idx = tail & (N - 1);
            let slot = &self.buffer[idx];

            let seq = slot.sequence.load(Ordering::Relaxed);
            let expected_seq = tail.wrapping_add(1);

            // Only drop slots whose sequence shows a fully published element;
            // anything else was never written or was already consumed.
            if seq == expected_seq {
                unsafe {
                    let raw_ptr = (*slot.data.get()).as_mut_ptr();
                    std::ptr::drop_in_place(raw_ptr);
                }
            }

            tail = tail.wrapping_add(1);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn test_basic_push_and_read() {
        let queue: Arc<AtomicRingBufferMpmc<i32, 4>> = AtomicRingBufferMpmc::new();

        assert!(queue.push(1).is_ok());
        assert!(queue.push(2).is_ok());
        assert!(queue.push(3).is_ok());

        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_buffer_full() {
        let queue: Arc<AtomicRingBufferMpmc<i32, 2>> = AtomicRingBufferMpmc::new();

        assert!(queue.push(10).is_ok());
        assert!(queue.push(20).is_ok());

        let result = queue.push(30);
        assert_eq!(result, Err(30));

        assert_eq!(queue.pop(), Some(10));

        assert!(queue.push(30).is_ok());
        assert_eq!(queue.pop(), Some(20));
        assert_eq!(queue.pop(), Some(30));
    }

    #[test]
    fn test_wrap_around() {
        let queue: Arc<AtomicRingBufferMpmc<usize, 4>> = AtomicRingBufferMpmc::new();

        for i in 0..100 {
            assert!(queue.push(i).is_ok());
            assert_eq!(queue.pop(), Some(i));
        }

        assert_eq!(queue.pop(), None);
    }
    #[test]
    fn test_mpmc_concurrency() {
        const BUFFER_SIZE: usize = 64;
        const NUM_PRODUCERS: usize = 4;
        const NUM_CONSUMERS: usize = 4;
        const OPS_PER_THREAD: usize = 10_000;

        let queue: Arc<AtomicRingBufferMpmc<usize, BUFFER_SIZE>> = AtomicRingBufferMpmc::new();
        let barrier = Arc::new(Barrier::new(NUM_PRODUCERS + NUM_CONSUMERS));

        let mut handles = vec![];

        for p_id in 0..NUM_PRODUCERS {
            let q = queue.clone();
            let b = barrier.clone();
            handles.push(thread::spawn(move || {
                b.wait();
                for i in 0..OPS_PER_THREAD {
                    let value = p_id * OPS_PER_THREAD + i;
                    while q.push(value).is_err() {
                        std::thread::yield_now();
                    }
                }
            }));
        }

        let results = Arc::new(AtomicUsize::new(0));
        for _ in 0..NUM_CONSUMERS {
            let q = queue.clone();
            let b = barrier.clone();
            let r = results.clone();
            handles.push(thread::spawn(move || {
                b.wait();

                // Keep popping until the shared counter shows that every
                // produced item has been consumed.
                loop {
                    match q.pop() {
                        Some(_) => {
                            r.fetch_add(1, Ordering::Relaxed);
                        }
                        None => {
                            if r.load(Ordering::Relaxed) == NUM_PRODUCERS * OPS_PER_THREAD {
                                break;
                            }
                            std::thread::yield_now();
                        }
                    }
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(
            results.load(Ordering::SeqCst),
            NUM_PRODUCERS * OPS_PER_THREAD,
            "Total items consumed must match total items produced"
        );
    }

    static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

    #[derive(Debug)]
    struct DropTracker;

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNTER.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_drop_cleanup() {
        DROP_COUNTER.store(0, Ordering::Relaxed);

        {
            let buffer = AtomicRingBufferMpmc::<DropTracker, 8>::new();

            for _ in 0..5 {
                buffer.push(DropTracker).unwrap();
            }

            // Each popped `DropTracker` is dropped immediately, bumping the counter.
            buffer.pop();
            buffer.pop();

            assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 2);
        }

        // Dropping the buffer must drop the three elements still inside it.
        assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 5);
    }
}