// lockless_datastructures/atomic_ring_buffer_spsc.rs
use std::{
2 cell::UnsafeCell,
3 mem::MaybeUninit,
4 sync::atomic::{AtomicUsize, Ordering},
5};
6
7use crate::{Padded, primitives::Arc};
8
/// Fixed-capacity ring buffer with `N` slots, coordinated by two atomic counters.
///
/// `head` counts pushed elements and `tail` counts popped elements; both grow
/// monotonically (with wrap-around) and are reduced to a slot index with
/// `& (N - 1)`, which is why `new` asserts that `N` is a non-zero power of two.
/// The slots in the half-open counter range `tail..head` hold initialised values.
///
/// The name indicates a single-producer / single-consumer design: `push` is
/// meant to be called by one producer and `pop` by one consumer. The type
/// system does not enforce this.
// NOTE(review): `cached_head` and `cached_tail` are adjacent and not wrapped in
// `Padded`; if `Padded` is cache-line padding, the producer and consumer may
// still share a line through these two fields — confirm whether that matters.
#[derive(Debug)]
pub struct AtomicRingBufferSpsc<T, const N: usize> {
    /// Last `head` value observed by `pop`; only read and written inside `pop`.
    cached_head: UnsafeCell<usize>,
    /// Last `tail` value observed by `push`; only read and written inside `push`.
    cached_tail: UnsafeCell<usize>,
    /// Write counter: advanced by `push` (Release) after the slot has been written.
    head: Padded<AtomicUsize>,
    /// Read counter: advanced by `pop` (Release) after the slot has been read out.
    tail: Padded<AtomicUsize>,
    /// Element storage; a slot is initialised only while it lies in `tail..head`.
    buffer: UnsafeCell<[MaybeUninit<T>; N]>,
}
18unsafe impl<T, const N: usize> Sync for AtomicRingBufferSpsc<T, N> {}
19
20impl<T, const N: usize> AtomicRingBufferSpsc<T, N> {
21 pub fn new() -> Arc<Self> {
22 const {
23 assert!(
24 N != 0 && N.is_power_of_two(),
25 "Buffer size N must be a power of two"
26 )
27 };
28 Arc::new(Self {
29 cached_head: UnsafeCell::new(0),
30 cached_tail: UnsafeCell::new(0),
31 buffer: UnsafeCell::new(std::array::from_fn(|_| MaybeUninit::uninit())),
32 head: Padded(AtomicUsize::new(0)),
33 tail: Padded(AtomicUsize::new(0)),
34 })
35 }
36
37 pub fn push(&self, value: T) -> Result<(), T> {
38 let head = self.head.load(Ordering::Relaxed);
39 let mut tail;
40 unsafe {
41 tail = self.cached_tail.get().read();
42 }
43
44 if head.wrapping_sub(tail) == N {
45 tail = self.tail.load(Ordering::Acquire);
46
47 unsafe {
48 self.cached_tail.get().write(tail);
49 }
50
51 if head.wrapping_sub(tail) == N {
52 return Err(value);
53 }
54 }
55
56 unsafe {
57 let buffer_ptr = self.buffer.get() as *mut MaybeUninit<T>;
58 let slot_ptr = buffer_ptr.add(head & (N - 1));
59 (*slot_ptr).write(value);
60 }
61
62 self.head.store(head.wrapping_add(1), Ordering::Release);
63
64 Ok(())
65 }
66
67 pub fn pop(&self) -> Option<T> {
68 let tail = self.tail.load(Ordering::Relaxed);
69
70 let mut head;
71 unsafe {
72 head = self.cached_head.get().read();
73 }
74
75 if tail == head {
76 head = self.head.load(Ordering::Acquire);
77
78 unsafe {
79 self.cached_head.get().write(head);
80 }
81
82 if head == tail {
83 return None;
84 }
85 }
86
87 let value;
88 unsafe {
89 let buffer_ptr = self.buffer.get() as *mut MaybeUninit<T>;
90 let slot_ptr = buffer_ptr.add(tail & (N - 1));
91 value = (*slot_ptr).assume_init_read();
92 }
93
94 self.tail.store(tail.wrapping_add(1), Ordering::Release);
95
96 Some(value)
97 }
98 pub fn read_head(&self) -> usize {
99 self.head.load(Ordering::Acquire) % N
100 }
101
102 pub fn read_tail(&self) -> usize {
103 self.tail.load(Ordering::Acquire) % N
104 }
105
106 pub fn exists(&self, index: usize) -> bool {
107 let mut tail = self.tail.load(Ordering::Acquire);
108 let mut head = self.head.load(Ordering::Acquire);
109 if head == tail {
110 return false;
111 }
112 head &= N - 1;
113 tail &= N - 1;
114 if head > tail {
115 head > index && index > tail
116 } else {
117 !(index >= head && tail > index)
118 }
119 }
120}
121
122impl<T, const N: usize> Drop for AtomicRingBufferSpsc<T, N> {
123 fn drop(&mut self) {
124 if std::mem::needs_drop::<T>() {
125 let head = self.head.load(Ordering::Relaxed);
126 let tail = self.tail.load(Ordering::Relaxed);
127
128 let mut current = tail;
129 while current != head {
130 let mask = current & (N - 1);
131 unsafe {
132 let slot = (*self.buffer.get()).get_unchecked_mut(mask);
133 std::ptr::drop_in_place(slot.as_mut_ptr());
134 }
135 current = current.wrapping_add(1);
136 }
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// Fills a 4-slot buffer, checks that a fifth push is rejected, and
    /// verifies FIFO order across interleaved pushes and pops.
    #[test]
    fn test_simple_push_pop() {
        let ring = AtomicRingBufferSpsc::<i32, 4>::new();

        for item in 1..=4 {
            assert!(ring.push(item).is_ok());
        }

        // Buffer is full now.
        assert!(ring.push(5).is_err());

        assert_eq!(ring.pop(), Some(1));
        assert_eq!(ring.pop(), Some(2));

        // Two slots were freed, so this push succeeds.
        assert!(ring.push(5).is_ok());

        for expected in 3..=5 {
            assert_eq!(ring.pop(), Some(expected));
        }
        assert_eq!(ring.pop(), None);
    }

    /// Streams 100 000 integers from a producer thread to a consumer thread and
    /// checks that they arrive in the order they were sent.
    #[test]
    fn test_threaded_spsc_ordering() {
        let producer_ring = AtomicRingBufferSpsc::<usize, 16>::new();
        let consumer_ring = producer_ring.clone();

        let item_count = 100_000;

        let producer = thread::spawn(move || {
            for item in 0..item_count {
                // Spin until the consumer frees a slot.
                while producer_ring.push(item).is_err() {
                    std::hint::spin_loop();
                }
            }
        });

        let consumer = thread::spawn(move || {
            for expected in 0..item_count {
                // Spin until the producer publishes the next item.
                let received = loop {
                    match consumer_ring.pop() {
                        Some(val) => break val,
                        None => std::hint::spin_loop(),
                    }
                };
                assert_eq!(received, expected, "Items received out of order!");
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    /// Counts how many `DropTracker` values have been dropped.
    static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);

    /// Zero-sized marker that bumps `DROP_COUNTER` when dropped.
    #[derive(Debug)]
    struct DropTracker;

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNTER.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Checks that popped values are dropped by the caller and that values
    /// still queued are dropped together with the buffer.
    #[test]
    fn test_drop_cleanup() {
        DROP_COUNTER.store(0, Ordering::Relaxed);

        {
            let ring = AtomicRingBufferSpsc::<DropTracker, 8>::new();

            (0..5).for_each(|_| ring.push(DropTracker).unwrap());

            // The two popped trackers are dropped right away.
            drop(ring.pop());
            drop(ring.pop());

            assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 2);
        }

        // The remaining three were dropped by the buffer's `Drop` impl.
        assert_eq!(DROP_COUNTER.load(Ordering::Relaxed), 5);
    }

    /// Checks that capacity accounting also works for zero-sized element types.
    #[test]
    fn test_zst() {
        struct Zst;

        let ring = AtomicRingBufferSpsc::<Zst, 4>::new();

        for _ in 0..4 {
            assert!(ring.push(Zst).is_ok());
        }
        assert!(ring.push(Zst).is_err());

        assert!(ring.pop().is_some());
        assert!(ring.push(Zst).is_ok());
    }
}