// lockless_datastructures/mutex_ring_buffer.rs

use parking_lot::Mutex;
use std::mem::MaybeUninit;

use crate::primitives::Arc;
/// Fixed-capacity ring storage guarded by the `Mutex` in [`MutexRingBuffer`].
///
/// Invariant: `head` and `tail` are monotonically increasing wrapping
/// counters with `head.wrapping_sub(tail) <= N`; exactly the slots at
/// `index & (N - 1)` for `index` in `[tail, head)` hold initialized values.
#[derive(Debug)]
struct RingBuffer<T, const N: usize> {
    // Write counter: next push lands at `head & (N - 1)`.
    head: usize,
    // Read counter: next pop comes from `tail & (N - 1)`.
    tail: usize,
    // Backing storage; slots outside `[tail, head)` are uninitialized.
    buffer: [MaybeUninit<T>; N],
}
12
13#[derive(Debug, Clone)]
15pub struct MutexRingBuffer<T, const N: usize>(Arc<Mutex<RingBuffer<T, N>>>);
16
17impl<T, const N: usize> Default for MutexRingBuffer<T, N> {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl<T, const N: usize> MutexRingBuffer<T, N> {
24 pub fn new() -> Self {
25 const {
26 assert!(
27 N != 0 && N.is_power_of_two(),
28 "Buffer size N must be a power of two"
29 )
30 };
31 Self(Arc::new(Mutex::new(RingBuffer {
32 buffer: std::array::from_fn(|_| MaybeUninit::uninit()),
33 head: 0,
34 tail: 0,
35 })))
36 }
37
38 pub fn push(&self, value: T) -> Result<(), T> {
39 let mut ring_buffer = self.0.lock();
40
41 if ring_buffer.head.wrapping_sub(ring_buffer.tail) == N {
42 return Err(value);
43 }
44
45 let idx = Self::mask(ring_buffer.head);
46 unsafe {
47 ring_buffer.buffer.get_unchecked_mut(idx).write(value);
48 }
49 ring_buffer.head = ring_buffer.head.wrapping_add(1);
50 Ok(())
51 }
52
53 pub fn pop(&self) -> Option<T> {
54 let mut ring_buffer = self.0.lock();
55 if ring_buffer.tail != ring_buffer.head {
56 let idx = Self::mask(ring_buffer.tail);
57 let value;
58 unsafe {
59 let ptr = ring_buffer.buffer.get_unchecked(idx).as_ptr();
60
61 value = std::ptr::read(ptr);
62 }
63 ring_buffer.tail = ring_buffer.tail.wrapping_add(1);
64 return Some(value);
65 }
66
67 None
68 }
69 #[inline(always)]
70 fn mask(index: usize) -> usize {
71 index & (N - 1)
72 }
73}
74
75impl<T, const N: usize> Drop for RingBuffer<T, N> {
76 fn drop(&mut self) {
77 if std::mem::needs_drop::<T>() {
78 while self.tail != self.head {
79 let mask = self.tail & (N - 1);
80 unsafe {
81 std::ptr::drop_in_place(self.buffer.get_unchecked_mut(mask).as_mut_ptr());
82 }
83 self.tail = self.tail.wrapping_add(1);
84 }
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    /// FIFO order, full/empty detection, and index wrap-around.
    #[test]
    fn test_basic_push_pop_wrap() {
        let buffer = MutexRingBuffer::<i32, 4>::new();

        for v in 1..=4 {
            assert!(buffer.push(v).is_ok());
        }
        // Full buffer hands the rejected value back.
        assert_eq!(buffer.push(5), Err(5));

        assert_eq!(buffer.pop(), Some(1));
        assert_eq!(buffer.pop(), Some(2));

        // Two slots freed; these pushes wrap past the end of the array.
        for v in [5, 6] {
            assert!(buffer.push(v).is_ok());
        }
        assert_eq!(buffer.push(7), Err(7));

        // FIFO order holds across the wrap point.
        for expected in 3..=6 {
            assert_eq!(buffer.pop(), Some(expected));
        }
        assert_eq!(buffer.pop(), None);
    }

    /// Two producers and two consumers hammer one buffer; conservation of the
    /// value sums shows nothing was lost or duplicated.
    #[test]
    fn test_multithreaded_concurrency() {
        const TOTAL_ITEMS: usize = 10_000;
        let buffer = MutexRingBuffer::<usize, 32>::new();

        let producer_sum = Arc::new(AtomicUsize::new(0));
        let consumer_sum = Arc::new(AtomicUsize::new(0));

        let producers: Vec<_> = (0..2)
            .map(|_| {
                let buf = buffer.clone();
                let sum = Arc::clone(&producer_sum);
                thread::spawn(move || {
                    for i in 0..TOTAL_ITEMS / 2 {
                        // Spin until the buffer accepts the item.
                        while buf.push(i).is_err() {
                            std::hint::spin_loop();
                        }
                        sum.fetch_add(i, Ordering::Relaxed);
                    }
                })
            })
            .collect();

        let consumers: Vec<_> = (0..2)
            .map(|_| {
                let buf = buffer.clone();
                let sum = Arc::clone(&consumer_sum);
                thread::spawn(move || {
                    let mut received = 0;
                    while received < TOTAL_ITEMS / 2 {
                        match buf.pop() {
                            Some(val) => {
                                sum.fetch_add(val, Ordering::Relaxed);
                                received += 1;
                            }
                            None => std::hint::spin_loop(),
                        }
                    }
                })
            })
            .collect();

        for handle in producers.into_iter().chain(consumers) {
            handle.join().unwrap();
        }

        assert_eq!(
            producer_sum.load(Ordering::Relaxed),
            consumer_sum.load(Ordering::Relaxed),
            "Sum of pushed items should equal sum of popped items"
        );
    }

    /// Popped values drop when the caller releases them, and the buffer's own
    /// `Drop` releases whatever is still queued.
    #[test]
    fn test_drop_cleanup() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Debug)]
        struct Droppable;
        impl Drop for Droppable {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);

        let buffer = MutexRingBuffer::<Droppable, 8>::new();
        for _ in 0..5 {
            buffer.push(Droppable).unwrap();
        }

        // Popping transfers ownership; dropping the results bumps the count.
        drop(buffer.pop());
        drop(buffer.pop());
        assert_eq!(
            DROP_COUNT.load(Ordering::Relaxed),
            2,
            "Popped items didn't drop"
        );

        // Destroying the buffer must drop the three queued items.
        drop(buffer);
        assert_eq!(
            DROP_COUNT.load(Ordering::Relaxed),
            5,
            "Buffer failed to drop remaining items"
        );
    }

    /// Zero-sized payloads still obey capacity accounting.
    #[test]
    fn test_zst() {
        struct Zst;

        let buffer = MutexRingBuffer::<Zst, 4>::new();

        for _ in 0..4 {
            assert!(buffer.push(Zst).is_ok());
        }
        assert!(buffer.push(Zst).is_err());

        assert!(buffer.pop().is_some());
        assert!(buffer.push(Zst).is_ok());
    }
}