Skip to main content

lockless_datastructures/
mutex_ring_buffer.rs

1use parking_lot::Mutex;
2use std::mem::MaybeUninit;
3
4use crate::primitives::Arc;
5
/// Shared state that lives behind [`MutexRingBuffer`]'s lock.
///
/// `head` and `tail` are free-running counters that only ever advance with
/// wrapping arithmetic. The slots whose logical index lies in `tail..head`
/// (taken modulo wrapping) hold initialized values. Every other slot is
/// uninitialized.
#[derive(Debug)]
struct RingBuffer<T, const N: usize> {
    /// Logical index of the next slot `push` writes.
    head: usize,
    /// Logical index of the next slot `pop` reads.
    tail: usize,
    /// Backing storage; logical index `i` lives at `buffer[i & (N - 1)]`.
    buffer: [MaybeUninit<T>; N],
}

/// A fixed-capacity FIFO ring buffer of `N` elements protected by a mutex.
///
/// Cloning is cheap: every clone shares the same underlying buffer.
#[derive(Debug)]
pub struct MutexRingBuffer<T, const N: usize>(Arc<Mutex<RingBuffer<T, N>>>);

// Written by hand because `#[derive(Clone)]` would require `T: Clone`, even
// though cloning a handle only bumps the `Arc` reference count and never
// copies any element.
impl<T, const N: usize> Clone for MutexRingBuffer<T, N> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}
16
17impl<T, const N: usize> Default for MutexRingBuffer<T, N> {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl<T, const N: usize> MutexRingBuffer<T, N> {
24    pub fn new() -> Self {
25        const {
26            assert!(
27                N != 0 && N.is_power_of_two(),
28                "Buffer size N must be a power of two"
29            )
30        };
31        Self(Arc::new(Mutex::new(RingBuffer {
32            buffer: std::array::from_fn(|_| MaybeUninit::uninit()),
33            head: 0,
34            tail: 0,
35        })))
36    }
37
38    pub fn push(&self, value: T) -> Result<(), T> {
39        let mut ring_buffer = self.0.lock();
40
41        if ring_buffer.head.wrapping_sub(ring_buffer.tail) == N {
42            return Err(value);
43        }
44
45        let idx = Self::mask(ring_buffer.head);
46        unsafe {
47            ring_buffer.buffer.get_unchecked_mut(idx).write(value);
48        }
49        ring_buffer.head = ring_buffer.head.wrapping_add(1);
50        Ok(())
51    }
52
53    pub fn pop(&self) -> Option<T> {
54        let mut ring_buffer = self.0.lock();
55        if ring_buffer.tail != ring_buffer.head {
56            let idx = Self::mask(ring_buffer.tail);
57            let value;
58            unsafe {
59                let ptr = ring_buffer.buffer.get_unchecked(idx).as_ptr();
60
61                value = std::ptr::read(ptr);
62            }
63            ring_buffer.tail = ring_buffer.tail.wrapping_add(1);
64            return Some(value);
65        }
66
67        None
68    }
69    #[inline(always)]
70    fn mask(index: usize) -> usize {
71        index & (N - 1)
72    }
73}
74
75impl<T, const N: usize> Drop for RingBuffer<T, N> {
76    fn drop(&mut self) {
77        if std::mem::needs_drop::<T>() {
78            while self.tail != self.head {
79                let mask = self.tail & (N - 1);
80                unsafe {
81                    std::ptr::drop_in_place(self.buffer.get_unchecked_mut(mask).as_mut_ptr());
82                }
83                self.tail = self.tail.wrapping_add(1);
84            }
85        }
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn test_basic_push_pop_wrap() {
        let buffer = MutexRingBuffer::<i32, 4>::new();

        assert!(buffer.push(1).is_ok());
        assert!(buffer.push(2).is_ok());
        assert!(buffer.push(3).is_ok());
        assert!(buffer.push(4).is_ok());

        assert_eq!(buffer.push(5), Err(5));
        assert_eq!(buffer.pop(), Some(1));
        assert_eq!(buffer.pop(), Some(2));

        assert!(buffer.push(5).is_ok());
        assert!(buffer.push(6).is_ok());

        assert_eq!(buffer.push(7), Err(7));

        assert_eq!(buffer.pop(), Some(3));
        assert_eq!(buffer.pop(), Some(4));
        assert_eq!(buffer.pop(), Some(5));
        assert_eq!(buffer.pop(), Some(6));

        assert_eq!(buffer.pop(), None);
    }

    #[test]
    fn test_multithreaded_concurrency() {
        let buffer = MutexRingBuffer::<usize, 32>::new();
        let total_items = 10_000;

        let producer_sum = Arc::new(AtomicUsize::new(0));
        let consumer_sum = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];

        for _ in 0..2 {
            let buf = buffer.clone();
            let sum = producer_sum.clone();
            handles.push(thread::spawn(move || {
                for i in 0..(total_items / 2) {
                    loop {
                        if buf.push(i).is_ok() {
                            sum.fetch_add(i, Ordering::Relaxed);
                            break;
                        }
                        std::hint::spin_loop();
                    }
                }
            }));
        }

        for _ in 0..2 {
            let buf = buffer.clone();
            let sum = consumer_sum.clone();
            handles.push(thread::spawn(move || {
                let mut count = 0;
                while count < (total_items / 2) {
                    if let Some(val) = buf.pop() {
                        sum.fetch_add(val, Ordering::Relaxed);
                        count += 1;
                    } else {
                        std::hint::spin_loop();
                    }
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(
            producer_sum.load(Ordering::Relaxed),
            consumer_sum.load(Ordering::Relaxed),
            "Sum of pushed items should equal sum of popped items"
        );
    }

    #[test]
    fn test_drop_cleanup() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Debug)]
        struct Droppable;
        impl Drop for Droppable {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        DROP_COUNT.store(0, Ordering::Relaxed);

        {
            let buffer = MutexRingBuffer::<Droppable, 8>::new();

            for _ in 0..5 {
                buffer.push(Droppable).unwrap();
            }

            {
                let _a = buffer.pop();
                let _b = buffer.pop();
            }

            assert_eq!(
                DROP_COUNT.load(Ordering::Relaxed),
                2,
                "Popped items didn't drop"
            );
        }

        assert_eq!(
            DROP_COUNT.load(Ordering::Relaxed),
            5,
            "Buffer failed to drop remaining items"
        );
    }

    /// Exercises `Drop` when the live region straddles the physical end of the
    /// storage. In that state `head`/`tail` have both advanced past `N`, and the
    /// masked slot indices wrap around to 0.
    #[test]
    fn test_drop_after_wraparound() {
        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        struct Tracked;
        impl Drop for Tracked {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::Relaxed);
            }
        }

        {
            let buffer = MutexRingBuffer::<Tracked, 4>::new();

            for _ in 0..3 {
                assert!(buffer.push(Tracked).is_ok());
            }
            // Each popped value is a temporary and drops at the end of the statement.
            for _ in 0..3 {
                assert!(buffer.pop().is_some());
            }
            assert_eq!(DROP_COUNT.load(Ordering::Relaxed), 3);

            // Logical slots 3, 4, 5 land in physical slots 3, 0, 1.
            for _ in 0..3 {
                assert!(buffer.push(Tracked).is_ok());
            }
        }

        assert_eq!(
            DROP_COUNT.load(Ordering::Relaxed),
            6,
            "Buffer failed to drop wrapped-around items"
        );
    }

    #[test]
    fn test_zst() {
        struct Zst;

        let buffer = MutexRingBuffer::<Zst, 4>::new();

        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_ok());
        assert!(buffer.push(Zst).is_err());

        assert!(buffer.pop().is_some());
        assert!(buffer.push(Zst).is_ok());
    }
}