Skip to main content

lockless_datastructures/
atomic_ring_buffer_mpmc.rs

1use std::cell::UnsafeCell;
2use std::mem::MaybeUninit;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use crate::primitives::Arc;
6use crate::{Backoff, Padded};
7
/// A single cell of the ring buffer: a sequence stamp plus possibly-uninitialised storage.
///
/// The 64-byte alignment presumably keeps neighbouring slots on separate cache lines
/// (assuming 64-byte lines on the target) — NOTE(review): confirm for non-x86 targets.
#[repr(align(64))]
struct Slot<T> {
    /// Stamp that tells producers and consumers what state the slot is in. A producer at
    /// position `p` may write when the stamp equals `p`. It then stores `p + 1` to publish the
    /// value. A consumer at `p` reads when the stamp equals `p + 1` and hands the slot back
    /// by storing `p + N`.
    sequence: AtomicUsize,
    /// Value storage. It is initialised only between a producer's publish and the matching
    /// consumer's read.
    data: UnsafeCell<MaybeUninit<T>>,
}
13
14///Uses atomic's instead of mutexes
15pub struct AtomicRingBufferMpmc<T, const N: usize> {
16    head: Padded<AtomicUsize>,
17    tail: Padded<AtomicUsize>,
18    buffer: [Slot<T>; N],
19}
20
21unsafe impl<T: Send, const N: usize> Sync for AtomicRingBufferMpmc<T, N> {}
22unsafe impl<T: Send, const N: usize> Send for AtomicRingBufferMpmc<T, N> {}
23
24impl<T, const N: usize> AtomicRingBufferMpmc<T, N> {
25    pub fn new() -> Arc<Self> {
26        const { assert!(N != 0 && N.is_power_of_two()) };
27
28        let buffer = std::array::from_fn(|i| Slot {
29            sequence: AtomicUsize::new(i),
30            data: UnsafeCell::new(MaybeUninit::uninit()),
31        });
32
33        Arc::new(Self {
34            head: Padded(AtomicUsize::new(0)),
35            tail: Padded(AtomicUsize::new(0)),
36            buffer,
37        })
38    }
39
40    pub fn push(&self, value: T) -> Result<(), T> {
41        let mut backoff = Backoff::new();
42        let mut head = self.head.load(Ordering::Relaxed);
43
44        loop {
45            let idx = head & (N - 1);
46            let slot;
47            unsafe {
48                slot = self.buffer.get_unchecked(idx);
49            }
50            let seq = slot.sequence.load(Ordering::Acquire);
51
52            let diff = seq as isize - head as isize;
53
54            if diff == 0 {
55                match self.head.compare_exchange_weak(
56                    head,
57                    head + 1,
58                    Ordering::Relaxed,
59                    Ordering::Relaxed,
60                ) {
61                    Ok(_) => {
62                        unsafe {
63                            (*slot.data.get()).write(value);
64                        }
65                        slot.sequence.store(head.wrapping_add(1), Ordering::Release);
66                        return Ok(());
67                    }
68                    Err(real_head) => {
69                        head = real_head;
70                    }
71                }
72            } else if diff < 0 {
73                let new_head = self.head.load(Ordering::Relaxed);
74                if new_head != head {
75                    head = new_head;
76                    backoff.reset();
77                    continue;
78                }
79                return Err(value);
80            } else {
81                head = self.head.load(Ordering::Relaxed);
82            }
83
84            backoff.snooze();
85        }
86    }
87
88    pub fn pop(&self) -> Option<T> {
89        let mut backoff = Backoff::new();
90        let mut tail = self.tail.load(Ordering::Relaxed);
91
92        loop {
93            let idx = tail & (N - 1);
94            let slot;
95            unsafe {
96                slot = self.buffer.get_unchecked(idx);
97            }
98
99            let seq = slot.sequence.load(Ordering::Acquire);
100
101            let diff = seq as isize - (tail.wrapping_add(1) as isize);
102
103            if diff == 0 {
104                match self.tail.compare_exchange_weak(
105                    tail,
106                    tail + 1,
107                    Ordering::Relaxed,
108                    Ordering::Relaxed,
109                ) {
110                    Ok(_) => {
111                        let value = unsafe { (*slot.data.get()).assume_init_read() };
112
113                        slot.sequence.store(tail.wrapping_add(N), Ordering::Release);
114
115                        return Some(value);
116                    }
117                    Err(real_tail) => {
118                        tail = real_tail;
119                    }
120                }
121            } else if diff < 0 {
122                return None;
123            } else {
124                tail = self.tail.load(Ordering::Relaxed);
125            }
126
127            backoff.snooze();
128        }
129    }
130    pub fn read_head(&self) -> usize {
131        self.head.load(Ordering::Acquire) % N
132    }
133
134    pub fn read_tail(&self) -> usize {
135        self.tail.load(Ordering::Acquire) % N
136    }
137
138    pub fn exists(&self, index: usize) -> bool {
139        let mut tail = self.tail.load(Ordering::Acquire);
140        let mut head = self.head.load(Ordering::Acquire);
141        if head == tail {
142            return false;
143        }
144        head &= N - 1;
145        tail &= N - 1;
146        if head > tail {
147            head > index && index >= tail
148        } else {
149            !(index >= head && tail > index)
150        }
151    }
152}
153
154impl<T, const N: usize> Drop for AtomicRingBufferMpmc<T, N> {
155    fn drop(&mut self) {
156        if !std::mem::needs_drop::<T>() {
157            return;
158        }
159
160        let head = self.head.load(Ordering::Relaxed);
161        let mut tail = self.tail.load(Ordering::Relaxed);
162
163        while tail != head {
164            let idx = tail & (N - 1);
165            let slot = &self.buffer[idx];
166
167            let seq = slot.sequence.load(Ordering::Relaxed);
168            let expected_seq = tail.wrapping_add(1);
169
170            if seq == expected_seq {
171                unsafe {
172                    let raw_ptr = (*slot.data.get()).as_mut_ptr();
173                    std::ptr::drop_in_place(raw_ptr);
174                }
175            }
176
177            tail = tail.wrapping_add(1);
178        }
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Barrier;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn test_basic_push_and_read() {
        let queue: Arc<AtomicRingBufferMpmc<i32, 4>> = AtomicRingBufferMpmc::new();

        assert!(queue.push(1).is_ok());
        assert!(queue.push(2).is_ok());
        assert!(queue.push(3).is_ok());

        assert_eq!(queue.pop(), Some(1));
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_buffer_full() {
        let queue: Arc<AtomicRingBufferMpmc<i32, 2>> = AtomicRingBufferMpmc::new();

        assert!(queue.push(10).is_ok());
        assert!(queue.push(20).is_ok());

        let result = queue.push(30);
        assert_eq!(result, Err(30));

        assert_eq!(queue.pop(), Some(10));

        assert!(queue.push(30).is_ok());
        assert_eq!(queue.pop(), Some(20));
        assert_eq!(queue.pop(), Some(30));
    }

    #[test]
    fn test_wrap_around() {
        let queue: Arc<AtomicRingBufferMpmc<usize, 4>> = AtomicRingBufferMpmc::new();

        for i in 0..100 {
            assert!(queue.push(i).is_ok());
            assert_eq!(queue.pop(), Some(i));
        }

        assert_eq!(queue.pop(), None);
    }

    /// Keeps a backlog of items in flight while the positions lap the buffer many times.
    #[test]
    fn test_wrap_around_with_backlog() {
        let queue: Arc<AtomicRingBufferMpmc<usize, 4>> = AtomicRingBufferMpmc::new();

        assert!(queue.push(0).is_ok());
        assert!(queue.push(1).is_ok());
        for i in 2..100 {
            assert!(queue.push(i).is_ok());
            assert_eq!(queue.pop(), Some(i - 2));
        }

        assert_eq!(queue.pop(), Some(98));
        assert_eq!(queue.pop(), Some(99));
        assert_eq!(queue.pop(), None);
    }

    #[test]
    fn test_exists_and_positions() {
        let queue: Arc<AtomicRingBufferMpmc<u8, 4>> = AtomicRingBufferMpmc::new();
        assert!((0..4).all(|i| !queue.exists(i)));

        assert!(queue.push(1).is_ok());
        assert!(queue.push(2).is_ok());
        assert_eq!((queue.read_tail(), queue.read_head()), (0, 2));
        assert!(queue.exists(0) && queue.exists(1));
        assert!(!queue.exists(2) && !queue.exists(3));

        // Move to tail = 3, head = 6 so the occupied region wraps: slots 3, 0, 1.
        assert_eq!(queue.pop(), Some(1));
        assert!(queue.push(3).is_ok());
        assert!(queue.push(4).is_ok());
        assert_eq!(queue.pop(), Some(2));
        assert_eq!(queue.pop(), Some(3));
        assert!(queue.push(5).is_ok());
        assert!(queue.push(6).is_ok());

        assert_eq!((queue.read_tail(), queue.read_head()), (3, 2));
        assert!(queue.exists(3) && queue.exists(0) && queue.exists(1));
        assert!(!queue.exists(2));
    }

    #[test]
    fn test_mpmc_concurrency() {
        const BUFFER_SIZE: usize = 64;
        const NUM_PRODUCERS: usize = 4;
        const NUM_CONSUMERS: usize = 4;
        const OPS_PER_THREAD: usize = 10_000;
        const TOTAL: usize = NUM_PRODUCERS * OPS_PER_THREAD;

        let queue: Arc<AtomicRingBufferMpmc<usize, BUFFER_SIZE>> = AtomicRingBufferMpmc::new();
        let barrier = Arc::new(Barrier::new(NUM_PRODUCERS + NUM_CONSUMERS));
        let consumed = Arc::new(AtomicUsize::new(0));

        let producers: Vec<_> = (0..NUM_PRODUCERS)
            .map(|p_id| {
                let q = queue.clone();
                let b = barrier.clone();
                thread::spawn(move || {
                    b.wait();
                    for i in 0..OPS_PER_THREAD {
                        let value = p_id * OPS_PER_THREAD + i;
                        while q.push(value).is_err() {
                            thread::yield_now();
                        }
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..NUM_CONSUMERS)
            .map(|_| {
                let q = queue.clone();
                let b = barrier.clone();
                let c = consumed.clone();
                thread::spawn(move || {
                    b.wait();
                    let mut received = Vec::new();
                    loop {
                        match q.pop() {
                            Some(value) => {
                                received.push(value);
                                c.fetch_add(1, Ordering::Relaxed);
                            }
                            None if c.load(Ordering::Relaxed) == TOTAL => break,
                            None => thread::yield_now(),
                        }
                    }
                    received
                })
            })
            .collect();

        for h in producers {
            h.join().unwrap();
        }
        let mut all: Vec<usize> = consumers
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect();

        assert_eq!(
            all.len(),
            TOTAL,
            "Total items consumed must match total items produced"
        );
        // Counting alone cannot detect a duplicated value paired with a lost one, so check
        // that every produced value was consumed exactly once.
        all.sort_unstable();
        assert!(
            all.iter().copied().eq(0..TOTAL),
            "every produced value must be consumed exactly once"
        );
    }

    static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

    #[derive(Debug)]
    struct DropTracker;

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNTER.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_drop_cleanup() {
        DROP_COUNTER.store(0, Ordering::Relaxed);

        {
            let buffer = AtomicRingBufferMpmc::<DropTracker, 8>::new();

            for _ in 0..5 {
                buffer.push(DropTracker).unwrap();
            }

            drop(buffer.pop());
            drop(buffer.pop());

            assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 2);
        }

        assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 5);
    }

    /// Drop tracker with its own counter, so it cannot race with other tests.
    struct CountedDrop(std::sync::Arc<AtomicUsize>);

    impl Drop for CountedDrop {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_drop_cleanup_after_wrap() {
        let drops = std::sync::Arc::new(AtomicUsize::new(0));

        {
            let buffer = AtomicRingBufferMpmc::<CountedDrop, 4>::new();
            for _ in 0..4 {
                assert!(buffer.push(CountedDrop(drops.clone())).is_ok());
            }
            for _ in 0..3 {
                drop(buffer.pop());
            }
            // These land in slots 0 and 1 on the second lap (tail = 3, head = 6).
            for _ in 0..2 {
                assert!(buffer.push(CountedDrop(drops.clone())).is_ok());
            }
            assert_eq!(drops.load(Ordering::Relaxed), 3);
        }

        assert_eq!(drops.load(Ordering::Relaxed), 6);
    }
}