atomic_queue/
lib.rs

// Augmented Audio: Audio libraries and applications
// Copyright (c) 2022 Pedro Tacla Yamada
//
// The MIT License (MIT)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//! [`atomic_queue`] is a port of C++'s [max0x7ba/atomic_queue](https://github.com/max0x7ba/atomic_queue)
//! implementation to Rust.
//!
//! This is part of [augmented-audio](https://github.com/yamadapc/augmented-audio).
//!
//! It provides a bounded multi-producer, multi-consumer lock-free queue that is real-time safe.
//!
//! # Usage
//! ```rust
//! let queue: atomic_queue::Queue<usize> = atomic_queue::bounded(10);
//!
//! queue.push(10);
//! if let Some(v) = queue.pop() {
//!     assert_eq!(v, 10);
//! }
//! ```
//!
//! # Safety
//! This queue implementation uses unsafe internally.
//!
//! # Performance
//! When benchmarked on a 2017 i7, this was a lot slower (~2x) than `ringbuf`.
//!
//! I'd think this is fine, since this queue supports multiple producers and
//! multiple consumers, while `ringbuf` is single-producer single-consumer.
//!
//! Testing again on an M1 Pro, this queue was 30% faster.

#![warn(missing_docs)]

use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicI8, AtomicUsize, Ordering};

/// State a slot in the Queue's circular buffer can be in.
///
/// Each slot cycles through `Empty -> Storing -> Stored -> Loading -> Empty`.
enum CellState {
    /// No element is stored in the slot.
    Empty = 0,
    /// A producer has claimed the slot and is writing into it.
    Storing = 1,
    /// An element is stored and ready to be popped.
    Stored = 2,
    /// A consumer has claimed the slot and is reading out of it.
    Loading = 3,
}

impl From<CellState> for i8 {
    fn from(value: CellState) -> Self {
        match value {
            CellState::Empty => 0,
            CellState::Storing => 1,
            CellState::Stored => 2,
            CellState::Loading => 3,
        }
    }
}

/// Atomic queue cloned from <https://github.com/max0x7ba/atomic_queue>
///
/// Should be:
/// * Lock-free
///
/// Any type can be pushed into the queue, but it's recommended to use some sort of smart pointer
/// that can be freed outside of the critical path.
///
/// Uses unsafe internally.
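///
/// For example, a minimal sketch of that smart-pointer pattern, where the consumer ends up
/// owning the allocation and can free it away from the real-time thread:
/// ```rust
/// let queue: atomic_queue::Queue<Box<usize>> = atomic_queue::bounded(4);
/// queue.push(Box::new(42));
/// if let Some(boxed) = queue.pop() {
///     // The consumer now owns the box and can drop it off the critical path.
///     drop(boxed);
/// }
/// ```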
pub struct Queue<T> {
    head: AtomicUsize,
    tail: AtomicUsize,
    elements: Vec<UnsafeCell<MaybeUninit<T>>>,
    states: Vec<AtomicI8>,
}

unsafe impl<T: Send> Send for Queue<T> {}
unsafe impl<T: Send> Sync for Queue<T> {}

/// Alias for `Queue::new`; creates a new bounded MPMC queue with the given capacity.
///
/// Writes will fail if the queue is full.
pub fn bounded<T>(capacity: usize) -> Queue<T> {
    Queue::new(capacity)
}
impl<T> Queue<T> {
    /// Create a queue with a certain capacity. Writes will fail when the queue is full.
    pub fn new(capacity: usize) -> Self {
        let mut elements = Vec::with_capacity(capacity);
        for _ in 0..capacity {
            elements.push(UnsafeCell::new(MaybeUninit::uninit()));
        }
        let mut states = Vec::with_capacity(capacity);
        for _ in 0..capacity {
            states.push(AtomicI8::new(CellState::Empty.into()));
        }
        let head = AtomicUsize::new(0);
        let tail = AtomicUsize::new(0);
        Queue {
            head,
            tail,
            elements,
            states,
        }
    }

    /// Push an element into the queue and return `true` on success.
    ///
    /// `false` will be returned if the queue is full. If there's contention this operation will
    /// wait until it's able to claim a slot in the queue.
    ///
    /// This is a CAS loop to increment the head of the queue, then another to push this element in.
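    ///
    /// A quick sketch of both outcomes, using a capacity-1 queue:
    /// ```rust
    /// let queue: atomic_queue::Queue<u32> = atomic_queue::bounded(1);
    /// assert!(queue.push(1)); // claims the only slot
    /// assert!(!queue.push(2)); // fails: the queue is full
    /// ```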
    pub fn push(&self, element: T) -> bool {
        let mut head = self.head.load(Ordering::Relaxed);
        let elements_len = self.elements.len();
        loop {
            let length = head as i64 - self.tail.load(Ordering::Relaxed) as i64;
            if length >= elements_len as i64 {
                return false;
            }

            if self
                .head
                .compare_exchange(head, head + 1, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                self.do_push(element, head);
                return true;
            }

            head = self.head.load(Ordering::Relaxed);
        }
    }

    /// Pop an element from the queue, returning `Some(element)` on success.
    ///
    /// `None` will be returned if the queue is empty. If there's contention this operation will
    /// wait until it's able to claim a slot in the queue.
    ///
    /// This is a CAS loop to increment the tail of the queue then another CAS loop to pop the
    /// element at this index out.
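    ///
    /// A small sketch of both outcomes:
    /// ```rust
    /// let queue: atomic_queue::Queue<u32> = atomic_queue::bounded(2);
    /// assert_eq!(queue.pop(), None); // empty queue
    /// queue.push(7);
    /// assert_eq!(queue.pop(), Some(7));
    /// ```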
    pub fn pop(&self) -> Option<T> {
        let mut tail = self.tail.load(Ordering::Relaxed);
        loop {
            let length = self.head.load(Ordering::Relaxed) as i64 - tail as i64;
            if length <= 0 {
                return None;
            }

            if self
                .tail
                .compare_exchange(tail, tail + 1, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
            {
                break;
            }

            tail = self.tail.load(Ordering::Relaxed);
        }
        Some(self.do_pop(tail))
    }

    /// Pop an element from the queue without checking if it's empty.
    ///
    /// # Safety
    /// There's nothing safe about this. If the queue is empty, this advances the tail past the
    /// head and spins until an element is stored, leaving the queue's accounting inconsistent.
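    ///
    /// A sketch of the one pattern that seems sound: only force-pop after a known successful push:
    /// ```rust
    /// let queue: atomic_queue::Queue<u32> = atomic_queue::bounded(2);
    /// assert!(queue.push(3));
    /// // Only OK because we know an element was just stored.
    /// let value = unsafe { queue.force_pop() };
    /// assert_eq!(value, 3);
    /// ```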
    pub unsafe fn force_pop(&self) -> T {
        let tail = self.tail.fetch_add(1, Ordering::Acquire);
        self.do_pop(tail)
    }

    /// Push an element into the queue without checking if it's full.
    ///
    /// # Safety
    /// There's nothing safe about this. If the queue is full, this advances the head past its
    /// bound and spins until a slot is freed, leaving the queue's accounting inconsistent.
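    ///
    /// A sketch under the assumption that spare capacity is known to exist:
    /// ```rust
    /// let queue: atomic_queue::Queue<u32> = atomic_queue::bounded(2);
    /// // Only OK because we know there is a free slot.
    /// unsafe { queue.force_push(9) };
    /// assert_eq!(queue.pop(), Some(9));
    /// ```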
    pub unsafe fn force_push(&self, element: T) {
        let head = self.head.fetch_add(1, Ordering::Acquire);
        self.do_push(element, head);
    }

    /// True if the queue is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the length of the queue.
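    ///
    /// A small doctest sketch of the length tracking:
    /// ```rust
    /// let queue: atomic_queue::Queue<u8> = atomic_queue::bounded(2);
    /// assert!(queue.is_empty());
    /// queue.push(1);
    /// assert_eq!(queue.len(), 1);
    /// ```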
    pub fn len(&self) -> usize {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Relaxed);
        // The two relaxed loads aren't a single atomic snapshot, so the tail can be observed
        // ahead of the head; saturate to zero instead of underflowing (the previous
        // `max(head - tail, 0)` was a no-op on `usize` and could panic in debug builds).
        head.saturating_sub(tail)
    }
}

impl<T> Queue<T> {
    fn do_pop(&self, tail: usize) -> T {
        let state = &self.states[tail % self.states.len()];
        loop {
            let expected = CellState::Stored;
            if state
                .compare_exchange(
                    expected.into(),
                    CellState::Loading.into(),
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                let element = unsafe {
                    self.elements[tail % self.elements.len()]
                        .get()
                        .replace(MaybeUninit::uninit())
                        .assume_init()
                };

                state.store(CellState::Empty.into(), Ordering::Release);

                return element;
            }
        }
    }

    fn do_push(&self, element: T, head: usize) {
        self.do_push_any(element, head);
    }

    fn do_push_any(&self, element: T, head: usize) {
        let state = &self.states[head % self.states.len()];
        loop {
            let expected = CellState::Empty;
            if state
                .compare_exchange(
                    expected.into(),
                    CellState::Storing.into(),
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                unsafe {
                    // There's a potential small % optimisation from removing bounds checking here &
                    // using mem::replace.
                    self.elements[head % self.elements.len()]
                        .get()
                        .write(MaybeUninit::new(element));
                }
                state.store(CellState::Stored.into(), Ordering::Release);
                return;
            }
        }
    }
}

impl<T> Drop for Queue<T> {
    fn drop(&mut self) {
        if std::mem::needs_drop::<T>() {
            // Could probably be made more efficient by using [std::ptr::drop_in_place()],
            // as the &mut self here guarantees that we are the only remaining user of this Queue
            while let Some(element) = self.pop() {
                drop(element);
            }
        }
    }
}

#[cfg(test)]
mod test {
    use std::ffi::c_void;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::thread::JoinHandle;
    use std::time::Duration;

    use super::*;

    #[derive(Eq, PartialEq, Debug, Copy, Clone)]
    struct MockPtr(*mut c_void);

    unsafe impl Send for MockPtr {}

    fn mock_ptr(value: i32) -> MockPtr {
        MockPtr(value as *mut c_void)
    }

    #[test]
    fn test_create_bounded_queue() {
        let _queue = Queue::<MockPtr>::new(10);
    }

    #[test]
    fn test_get_empty_queue_len() {
        let queue = Queue::<MockPtr>::new(10);
        assert_eq!(queue.len(), 0);
    }

    #[test]
    fn test_queue_drops_items() {
        struct Item {
            drop_count: Arc<AtomicUsize>,
        }
        impl Drop for Item {
            fn drop(&mut self) {
                self.drop_count.fetch_add(1, Ordering::Relaxed);
            }
        }
        let drop_count = Arc::new(AtomicUsize::new(0));
        let queue: Queue<Item> = Queue::new(10);
        queue.push(Item {
            drop_count: drop_count.clone(),
        });
        queue.push(Item {
            drop_count: drop_count.clone(),
        });
        queue.push(Item {
            drop_count: drop_count.clone(),
        });
        drop(queue);

        assert_eq!(drop_count.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn test_push_element_to_queue_increments_length() {
        let queue = Queue::<MockPtr>::new(10);
        assert_eq!(queue.len(), 0);
        let ptr = mock_ptr(1);
        assert!(queue.push(ptr));
        assert_eq!(queue.len(), 1);
        let value = queue.pop();
        assert_eq!(value.unwrap(), ptr);
        assert_eq!(queue.len(), 0);
    }

    #[test]
    fn test_push_pop_push_pop() {
        let queue = Queue::<MockPtr>::new(10);
        assert_eq!(queue.len(), 0);
        {
            let ptr = mock_ptr(1);
            assert!(queue.push(ptr));
            assert_eq!(queue.len(), 1);
            let value = queue.pop();
            assert_eq!(value.unwrap(), ptr);
            assert_eq!(queue.len(), 0);
        }
        {
            let ptr = mock_ptr(2);
            assert!(queue.push(ptr));
            assert_eq!(queue.len(), 1);
            let value = queue.pop();
            assert_eq!(value.unwrap(), ptr);
            assert_eq!(queue.len(), 0);
        }
    }

    #[test]
    fn test_overflow_will_not_break_things() {
        let queue = Queue::<MockPtr>::new(3);
        assert_eq!(queue.len(), 0);

        // ENTRY 1 - HEAD, ENTRY, TAIL, EMPTY, EMPTY
        assert!(queue.push(mock_ptr(1)));
        assert_eq!(queue.len(), 1);

        // ENTRY 2 - HEAD, ENTRY, ENTRY, TAIL, EMPTY
        assert!(queue.push(mock_ptr(2)));
        assert_eq!(queue.len(), 2);

        // ENTRY 3 - HEAD, ENTRY, ENTRY, ENTRY, TAIL
        assert!(queue.push(mock_ptr(3)));
        assert_eq!(queue.len(), 3);

        // ENTRY 4 - Will fail
        assert_eq!(queue.len(), 3);
        let result = queue.push(mock_ptr(4));
        assert!(!result);
        assert_eq!(queue.len(), 3);
    }

    #[test]
    fn test_multithread_push() {
        wisual_logger::init_from_env();

        let queue = Arc::new(Queue::new(50000));

        let writer_thread_1 = spawn_writer_thread(
            10,
            queue.clone(),
            Duration::from_millis((0.0 * rand::random::<f64>()) as u64),
        );
        let writer_thread_2 = spawn_writer_thread(
            10,
            queue.clone(),
            Duration::from_millis((0.0 * rand::random::<f64>()) as u64),
        );
        let writer_thread_3 = spawn_writer_thread(
            10,
            queue.clone(),
            Duration::from_millis((0.0 * rand::random::<f64>()) as u64),
        );

        writer_thread_1.join().unwrap();
        writer_thread_2.join().unwrap();
        writer_thread_3.join().unwrap();
        assert_eq!(queue.len(), 30);
    }

    #[test]
    fn test_multithread_push_pop() {
        wisual_logger::init_from_env();

        let size = 10000;
        let num_threads = 5;

        let queue: Arc<Queue<MockPtr>> = Arc::new(Queue::new(size * num_threads / 3));
        let output_queue: Arc<Queue<MockPtr>> = Arc::new(Queue::new(size * num_threads));

        let is_running = Arc::new(Mutex::new(true));
        let reader_thread = {
            let is_running = is_running.clone();
            let queue = queue.clone();
            let output_queue = output_queue.clone();
            thread::spawn(move || {
                while *is_running.lock().unwrap() || queue.len() > 0 {
                    loop {
                        match queue.pop() {
                            None => break,
                            Some(value) => {
                                output_queue.push(value);
                            }
                        }
                    }
                }
                log::info!("Reader thread done reading");
            })
        };

        let threads: Vec<JoinHandle<()>> = (0..num_threads)
            .map(|_| {
                spawn_writer_thread(
                    size,
                    queue.clone(),
                    Duration::from_millis((rand::random::<f64>()) as u64),
                )
            })
            .collect();

        for thread in threads {
            thread.join().unwrap();
        }

        {
            let mut is_running = is_running.lock().unwrap();
            *is_running = false;
        }
        reader_thread.join().unwrap();

        assert_eq!(queue.len(), 0);
        assert_eq!(output_queue.len(), size * num_threads);
    }

    fn spawn_writer_thread(
        size: usize,
        queue: Arc<Queue<MockPtr>>,
        duration: Duration,
    ) -> JoinHandle<()> {
        thread::spawn(move || {
            for i in 0..size {
                loop {
                    let pushed = queue.push(mock_ptr(i as i32));
                    if pushed {
                        break;
                    }
                }
                thread::sleep(duration);
            }
            log::info!("Thread done writing");
        })
    }
}