
// lockless_datastructures/atomic_ring_buffer_spsc.rs

use std::{
    cell::UnsafeCell,
    mem::MaybeUninit,
    sync::atomic::{AtomicUsize, Ordering},
};

use crate::{Padded, primitives::Arc};

/// A lock-free single-producer single-consumer (SPSC) ring buffer that uses
/// atomics instead of mutexes.
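///
/// A minimal usage sketch (marked `ignore` since it assumes the crate-local
/// `Arc` clones like `std::sync::Arc`):
///
/// ```ignore
/// let buffer = AtomicRingBufferSpsc::<u32, 8>::new();
/// let consumer = buffer.clone();
///
/// std::thread::spawn(move || {
///     while buffer.push(42).is_err() {
///         std::hint::spin_loop();
///     }
/// });
///
/// loop {
///     if let Some(v) = consumer.pop() {
///         assert_eq!(v, 42);
///         break;
///     }
///     std::hint::spin_loop();
/// }
/// ```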
#[derive(Debug)]
pub struct AtomicRingBufferSpsc<T, const N: usize> {
    /// Consumer-side cache of the producer's `head` (touched only by the consumer).
    cached_head: UnsafeCell<usize>,
    /// Producer-side cache of the consumer's `tail` (touched only by the producer).
    cached_tail: UnsafeCell<usize>,
    /// Next write position, owned by the producer. Padded against false sharing.
    head: Padded<AtomicUsize>,
    /// Next read position, owned by the consumer. Padded against false sharing.
    tail: Padded<AtomicUsize>,
    buffer: UnsafeCell<[MaybeUninit<T>; N]>,
}

// SAFETY: the producer only writes `head` and `cached_tail`, the consumer only
// writes `tail` and `cached_head`, and slot ownership is handed over through
// the Release/Acquire pairs on `head` and `tail`. `T: Send` is required
// because elements cross from the producer thread to the consumer thread.
unsafe impl<T: Send, const N: usize> Sync for AtomicRingBufferSpsc<T, N> {}

impl<T, const N: usize> AtomicRingBufferSpsc<T, N> {
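    /// Allocates the buffer behind the crate's `Arc`. `N` must be a nonzero
    /// power of two so indices can be masked with `N - 1` instead of taking
    /// a modulo; this is checked at compile time below.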
    pub fn new() -> Arc<Self> {
        const {
            assert!(
                N != 0 && N.is_power_of_two(),
                "Buffer size N must be a power of two"
            )
        };
        Arc::new(Self {
            cached_head: UnsafeCell::new(0),
            cached_tail: UnsafeCell::new(0),
            buffer: UnsafeCell::new(std::array::from_fn(|_| MaybeUninit::uninit())),
            head: Padded(AtomicUsize::new(0)),
            tail: Padded(AtomicUsize::new(0)),
        })
    }

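    /// Attempts to enqueue `value`, handing it back as `Err(value)` when the
    /// buffer is full. Producer-side: must only be called from the single
    /// producer thread.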
    pub fn push(&self, value: T) -> Result<(), T> {
        let head = self.head.load(Ordering::Relaxed);
        // Check fullness against the producer's cached view of `tail` first,
        // so the common case avoids touching the consumer's cache line.
        let mut tail = unsafe { self.cached_tail.get().read() };

        if head.wrapping_sub(tail) == N {
            // Looks full; refresh the cache from the real tail and re-check.
            tail = self.tail.load(Ordering::Acquire);

            unsafe {
                self.cached_tail.get().write(tail);
            }

            if head.wrapping_sub(tail) == N {
                return Err(value);
            }
        }

        unsafe {
            let buffer_ptr = self.buffer.get() as *mut MaybeUninit<T>;
            let slot_ptr = buffer_ptr.add(head & (N - 1));
            (*slot_ptr).write(value);
        }

        // Release so the consumer's Acquire load of `head` sees the write above.
        self.head.store(head.wrapping_add(1), Ordering::Release);

        Ok(())
    }

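    /// Attempts to dequeue the oldest element, returning `None` when the
    /// buffer is empty. Consumer-side: must only be called from the single
    /// consumer thread.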
    pub fn pop(&self) -> Option<T> {
        let tail = self.tail.load(Ordering::Relaxed);

        // Check emptiness against the consumer's cached view of `head` first.
        let mut head = unsafe { self.cached_head.get().read() };

        if tail == head {
            // Looks empty; refresh the cache from the real head and re-check.
            head = self.head.load(Ordering::Acquire);

            unsafe {
                self.cached_head.get().write(head);
            }

            if head == tail {
                return None;
            }
        }

        let value = unsafe {
            let buffer_ptr = self.buffer.get() as *mut MaybeUninit<T>;
            let slot_ptr = buffer_ptr.add(tail & (N - 1));
            (*slot_ptr).assume_init_read()
        };

        // Release so the producer's Acquire load of `tail` knows the slot is free.
        self.tail.store(tail.wrapping_add(1), Ordering::Release);

        Some(value)
    }

    /// Returns the producer's write position, masked to a physical slot index in `0..N`.
    pub fn read_head(&self) -> usize {
        self.head.load(Ordering::Acquire) & (N - 1)
    }

    /// Returns the consumer's read position, masked to a physical slot index in `0..N`.
    pub fn read_tail(&self) -> usize {
        self.tail.load(Ordering::Acquire) & (N - 1)
    }

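    /// Best-effort check of whether the physical slot at `index` currently
    /// holds a live element. The answer is a racy snapshot of both indices
    /// and may be stale by the time the caller acts on it.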
    pub fn exists(&self, index: usize) -> bool {
        let mut tail = self.tail.load(Ordering::Acquire);
        let mut head = self.head.load(Ordering::Acquire);
        if head == tail {
            return false;
        }
        head &= N - 1;
        tail &= N - 1;
        if head > tail {
            // Occupied region is contiguous: [tail, head).
            index >= tail && index < head
        } else {
            // Occupied region wraps (or the buffer is full): [tail, N) and [0, head).
            index >= tail || index < head
        }
    }
}

impl<T, const N: usize> Drop for AtomicRingBufferSpsc<T, N> {
    fn drop(&mut self) {
        if std::mem::needs_drop::<T>() {
            // `&mut self` guarantees exclusive access, so Relaxed loads suffice.
            let head = self.head.load(Ordering::Relaxed);
            let tail = self.tail.load(Ordering::Relaxed);

            // Drop every element still queued, from oldest to newest.
            let mut current = tail;
            while current != head {
                let mask = current & (N - 1);
                unsafe {
                    let slot = (*self.buffer.get()).get_unchecked_mut(mask);
                    std::ptr::drop_in_place(slot.as_mut_ptr());
                }
                current = current.wrapping_add(1);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn test_simple_push_pop() {
        let buffer = AtomicRingBufferSpsc::<i32, 4>::new();

        assert!(buffer.push(1).is_ok());
        assert!(buffer.push(2).is_ok());
        assert!(buffer.push(3).is_ok());
        assert!(buffer.push(4).is_ok());

        assert!(buffer.push(5).is_err());

        assert_eq!(buffer.pop(), Some(1));
        assert_eq!(buffer.pop(), Some(2));

        assert!(buffer.push(5).is_ok());

        assert_eq!(buffer.pop(), Some(3));
        assert_eq!(buffer.pop(), Some(4));
        assert_eq!(buffer.pop(), Some(5));
        assert_eq!(buffer.pop(), None);
    }

    #[test]
    fn test_threaded_spsc_ordering() {
        let buffer = AtomicRingBufferSpsc::<usize, 16>::new();
        let consumer_buffer = buffer.clone();

        let item_count = 100_000;

        let producer = thread::spawn(move || {
            for i in 0..item_count {
                while buffer.push(i).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            for i in 0..item_count {
                loop {
                    if let Some(val) = consumer_buffer.pop() {
                        assert_eq!(val, i, "Items received out of order!");
                        break;
                    }
                    std::hint::spin_loop();
                }
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

    #[derive(Debug)]
    struct DropTracker;

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNTER.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[test]
    fn test_drop_cleanup() {
        DROP_COUNTER.store(0, Ordering::Relaxed);

        {
            let buffer = AtomicRingBufferSpsc::<DropTracker, 8>::new();

            for _ in 0..5 {
                buffer.push(DropTracker).unwrap();
            }

            buffer.pop();
            buffer.pop();

            assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 2);
        }

        assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 5);
    }

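    // A sketch exercising `exists` under its snapshot semantics (occupied
    // physical slots are [tail, head) modulo N); single-threaded use assumed.
    #[test]
    fn test_exists_snapshot() {
        let buffer = AtomicRingBufferSpsc::<i32, 4>::new();
        assert!(!buffer.exists(0));

        buffer.push(10).unwrap(); // occupies physical slot 0
        assert!(buffer.exists(0));
        assert!(!buffer.exists(1));

        buffer.push(20).unwrap(); // occupies physical slot 1
        assert!(buffer.exists(1));

        buffer.pop(); // frees physical slot 0
        assert!(!buffer.exists(0));
        assert!(buffer.exists(1));
    }
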
    #[test]
    fn test_zst() {
        struct Zst;

        let buffer = AtomicRingBufferSpsc::<Zst, 4>::new();

        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_err());

        assert!(buffer.pop().is_some());
        assert!(buffer.push(Zst).is_ok());
    }
}