waitfree_sync/
spsc.rs

1//! A wait-free single-producer single-consumer (SPSC) queue to send data to another thread.
2//! It is based on the improved FastForward queue.
3//!
4//! # Example
5//! ```rust
6//! use waitfree_sync::spsc;
7//!
8//! //                            Type ──╮   ╭─ Capacity
9//! let (mut tx, mut rx) = spsc::spsc::<u64>(8);
//! tx.try_send(234).unwrap();
11//! assert_eq!(rx.try_recv(),Some(234u64));
12//! ```
13//!
//! # Behavior for full and empty queues
//! If the queue is full, [`Sender::try_send`] returns a [`NoSpaceLeftError`] that carries the rejected value.
//! If the queue is empty, [`Receiver::try_recv`] returns `None`.
//!
19use crate::import::{Arc, AtomicBool, Ordering, UnsafeCell};
20use core::error::Error;
21use crossbeam_utils::CachePadded;
22use std::fmt::Debug;
23
24/// Create a new wait-free SPSC queue. The `capacity` must be a power of two, which is validate during runtime.
25/// # Panic
26/// Panics if the `capacity` is not a power of two.
27/// # Example
28/// ```rust
29/// use waitfree_sync::spsc;
30///
31/// //               Data type ──╮   ╭─ Capacity
32/// let (tx, rx) = spsc::spsc::<u64>(8);
33/// ```
34pub fn spsc<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
35    if !is_power_of_two(capacity) {
36        panic!("The SIZE must be a power of 2")
37    }
38
39    let chan = Arc::new(Spsc::new(capacity));
40
41    let r = Receiver::new(chan.clone());
42    let w = Sender::new(chan);
43
44    (w, r)
45}
46
/// Returns `true` if `x` is a power of two that is greater than one.
///
/// `0` and `1` (2^0) both yield `false`, so the smallest capacity accepted by
/// the queue constructor is 2.
const fn is_power_of_two(x: usize) -> bool {
    x > 1 && x.is_power_of_two()
}
51
/// Indicates that a queue is full.
///
/// The value that could not be sent is kept inside the error so it is not lost.
/// Use [`NoSpaceLeftError::into_inner`] to get it back.
#[derive(Clone, Debug, PartialEq)]
pub struct NoSpaceLeftError<T>(T);

impl<T> NoSpaceLeftError<T> {
    /// Consumes the error and returns the value that could not be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}

impl<T: Debug> Error for NoSpaceLeftError<T> {}

impl<T> core::fmt::Display for NoSpaceLeftError<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "No space left in the SPSC queue.")
    }
}
61
62#[derive(Debug)]
63struct Slot<T> {
64    value: UnsafeCell<Option<T>>,
65    occupied: CachePadded<AtomicBool>,
66}
67impl<T> Slot<T> {
68    fn new() -> Self {
69        Self {
70            value: UnsafeCell::new(None),
71            occupied: CachePadded::new(false.into()),
72        }
73    }
74}
75
76#[derive(Debug)]
77struct Spsc<T> {
78    mem: Box<[Slot<T>]>,
79    // The mask is written when this structure is created and is then only read.
80    // Therefore, we do not need Atomic here.
81    mask: usize,
82}
83
84impl<T> Spsc<T> {
85    fn new(size: usize) -> Self {
86        let mut buffer = Vec::with_capacity(size);
87        for _ in 0..size {
88            buffer.push(Slot::new());
89        }
90        let buffer: Box<[Slot<T>]> = buffer.into_boxed_slice();
91        Spsc {
92            mem: buffer,
93            mask: size - 1,
94        }
95    }
96
97    #[inline]
98    fn capacity(&self) -> usize {
99        self.mask + 1
100    }
101}
102
103/// The receiving side of the [spsc] queue.
104#[derive(Debug)]
105pub struct Receiver<T> {
106    spsc: Arc<Spsc<T>>,
107    read: usize,
108}
109unsafe impl<T: Send> Send for Receiver<T> {}
110unsafe impl<T: Send> Sync for Receiver<T> {}
111
112impl<T> Receiver<T> {
113    fn new(spsc: Arc<Spsc<T>>) -> Self {
114        Receiver { spsc, read: 0 }
115    }
116}
117
118impl<T> Receiver<T> {
119    /// Retrieve the next available element from the queue.
120    /// Returns [None] if the queue is empty.
121    pub fn try_recv(&mut self) -> Option<T> {
122        let rpos = self.read & self.spsc.mask;
123        let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
124        if !slot.occupied.load(Ordering::Acquire) {
125            None
126        } else {
127            #[cfg(not(loom))]
128            let val = unsafe { slot.value.get().replace(None) };
129            #[cfg(loom)]
130            let val = unsafe { slot.value.get_mut().with(|ptr| ptr.replace(None)) };
131
132            slot.occupied.store(false, Ordering::Release);
133            self.read += 1;
134            val
135        }
136    }
137    /// Peeks the next element in the queue without removing it.
138    #[cfg(not(loom))] // We can't return a reference to an UnsafeCell of loom.
139    pub fn peek(&self) -> Option<&T> {
140        let rpos = self.read & self.spsc.mask;
141        let slot = unsafe { self.spsc.mem.get_unchecked(rpos) };
142        if !slot.occupied.load(Ordering::Acquire) {
143            None
144        } else {
145            let val = unsafe { &*slot.value.get() };
146            val.as_ref()
147        }
148    }
149    /// Returns the total number of items that the queue can hold at most.
150    #[inline]
151    pub fn capacity(&self) -> usize {
152        // SAFETY: This is safe because we only read size which is never written.
153        self.spsc.capacity()
154    }
155}
156
157/// The sending side of the [spsc] queue.
158#[derive(Debug)]
159pub struct Sender<T> {
160    spsc: Arc<Spsc<T>>,
161    write: usize,
162}
163unsafe impl<T: Send> Send for Sender<T> {}
164unsafe impl<T: Send> Sync for Sender<T> {}
165impl<T> Sender<T> {
166    fn new(spsc: Arc<Spsc<T>>) -> Self {
167        Sender { spsc, write: 0 }
168    }
169}
170
171impl<T> Sender<T> {
172    /// Attempts to send a value to the queue without blocking.
173    /// Returns a [NoSpaceLeftError] if the queue is full.
174    pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>> {
175        let wpos = self.write & self.spsc.mask;
176
177        let slot = unsafe { self.spsc.mem.get_unchecked(wpos) };
178        if slot.occupied.load(Ordering::Acquire) {
179            Err(NoSpaceLeftError(data))
180        } else {
181            #[cfg(not(loom))]
182            unsafe {
183                slot.value.get().write(Some(data))
184            };
185            #[cfg(loom)]
186            unsafe {
187                slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
188            };
189            slot.occupied.store(true, Ordering::Release);
190            self.write += 1;
191            Ok(())
192        }
193    }
194
195    /// Returns the total number of items that the queue can hold at most.
196    #[inline]
197    pub fn capacity(&self) -> usize {
198        // SAFETY: This is safe because we only read size which is never written.
199        self.spsc.capacity()
200    }
201}
202
#[cfg(not(loom))]
#[cfg(test)]
mod test {
    // This module is compiled only without loom, so the std thread API is always used.
    use std::thread;

    use super::*;

    #[test]
    fn smoke() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
    }

    #[test]
    fn test_is_power_of_two() {
        assert!(!is_power_of_two(0));
        assert!(!is_power_of_two(1));
        assert!(is_power_of_two(2));
        assert!(!is_power_of_two(3));
        assert!(is_power_of_two(4));
        assert!(!is_power_of_two(5));
        assert!(!is_power_of_two(6));
        assert!(!is_power_of_two(7));
        assert!(is_power_of_two(8));
        assert!(!is_power_of_two(9));

        assert!(!is_power_of_two(15));
        assert!(is_power_of_two(16));
        assert!(!is_power_of_two(17));

        assert!(!is_power_of_two(31));
        assert!(is_power_of_two(32));
        assert!(!is_power_of_two(33));
    }

    #[test]
    #[should_panic]
    fn test_invalid_capacity_panics() {
        let _ = spsc::<u8>(3);
    }

    #[test]
    fn test_capacity() {
        let (w, r) = spsc::<u8>(8);
        assert_eq!(w.capacity(), 8);
        assert_eq!(r.capacity(), 8);
    }

    #[test]
    fn test_full_empty() {
        let (mut write, mut read) = spsc::<i32>(4);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
        assert_eq!(read.try_recv(), Some(1));
        assert_eq!(write.try_send(6), Ok(()));
        assert_eq!(read.try_recv(), Some(2));
        assert_eq!(read.try_recv(), Some(3));
        assert_eq!(read.try_recv(), Some(4));
        assert_eq!(read.try_recv(), Some(6));
        assert_eq!(read.try_recv(), None);
    }

    /// Cycles many items through a tiny ring to exercise index wrap-around.
    #[test]
    fn test_wraparound() {
        let (mut w, mut r) = spsc::<usize>(2);
        for i in 0..100 {
            assert_eq!(w.try_send(i), Ok(()));
            assert_eq!(w.try_send(i + 1000), Ok(()));
            assert_eq!(w.try_send(0), Err(NoSpaceLeftError(0)));
            assert_eq!(r.peek(), Some(&i));
            assert_eq!(r.try_recv(), Some(i));
            assert_eq!(r.try_recv(), Some(i + 1000));
            assert_eq!(r.try_recv(), None);
        }
    }

    #[test]
    fn test_drop_one_side() {
        let (mut write, read) = spsc::<i32>(4);
        drop(read);
        assert_eq!(write.try_send(1), Ok(()));
        assert_eq!(write.try_send(2), Ok(()));
        assert_eq!(write.try_send(3), Ok(()));
        assert_eq!(write.try_send(4), Ok(()));
        assert_eq!(write.try_send(5), Err(NoSpaceLeftError(5)));
    }

    /// Items still in the queue must be dropped when both halves are dropped.
    #[test]
    fn test_unreceived_values_are_dropped() {
        use std::rc::Rc;

        let tracker = Rc::new(());
        let (mut w, r) = spsc(4);
        w.try_send(Rc::clone(&tracker)).unwrap();
        w.try_send(Rc::clone(&tracker)).unwrap();
        assert_eq!(Rc::strong_count(&tracker), 3);
        drop(w);
        drop(r);
        assert_eq!(Rc::strong_count(&tracker), 1);
    }

    #[test]
    fn test_peek() {
        let (mut w, mut r) = spsc(4);
        w.try_send(vec![0; 15]).unwrap();
        w.try_send(vec![0; 16]).unwrap();
        w.try_send(vec![0; 17]).unwrap();
        w.try_send(vec![0; 18]).unwrap();

        assert_eq!(r.peek(), Some(&vec![0; 15]));
        assert_eq!(r.try_recv(), Some(vec![0; 15]));
        assert_eq!(r.peek(), Some(&vec![0; 16]));
        assert_eq!(r.try_recv(), Some(vec![0; 16]));
        assert_eq!(r.peek(), Some(&vec![0; 17]));
        assert_eq!(r.try_recv(), Some(vec![0; 17]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.peek(), Some(&vec![0; 18]));
        assert_eq!(r.try_recv(), Some(vec![0; 18]));
        assert_eq!(r.peek(), None);
    }

    /// The reader keeps polling until every item has arrived. This way the test
    /// actually checks that all items are delivered intact and in order, instead of
    /// passing when the reader happens to find the queue empty.
    #[test]
    fn test_peek_threaded() {
        const ITEMS: usize = 4;
        let (mut sender, mut receiver) = spsc(4);

        let writer_thread = thread::spawn(move || {
            for i in 0..ITEMS {
                assert_eq!(sender.try_send([i; 50]), Ok(()));
            }
        });
        let reader_thread = thread::spawn(move || {
            let mut received = 0;
            while received < ITEMS {
                let Some(val) = receiver.peek() else {
                    thread::yield_now();
                    continue;
                };
                assert!(val.iter().all(|&entry| entry == received));
                let val = receiver.try_recv().unwrap();
                assert!(val.iter().all(|&entry| entry == received));
                received += 1;
            }
        });
        assert!(writer_thread.join().is_ok());
        assert!(reader_thread.join().is_ok());
    }
}