//! locklessness/containers/mpsc_queue.rs
//!
//! A bounded multi-producer, single-consumer queue built on a ring of
//! atomic cells (see the design notes below).

use std::sync::atomic::{AtomicUsize, Ordering};
use std::marker::PhantomData;

// Project-local infrastructure: handle/id management, atomic helpers,
// and the storage/scratch containers that hold the actual values.
use handle::{HandleInner, Handle, IdHandle, ResizingHandle, BoundedHandle, Like};
use primitives::atomic_ext::AtomicExt;
use primitives::index_allocator::IndexAllocator;
use containers::storage::{Place, Storage};
use containers::scratch::Scratch;

// Pointers are not wrapped until they reach WRAP_THRESHOLD, at
// which point they are wrapped modulo RING_SIZE*2. This allows
// accessors to be confident whether the pointers have changed
// since they were read, preventing the ABA problem, whilst also
// distinguishing between an empty queue and a full queue:
//  ___________________
// |___|_X_|_X_|___|___|
//       ^       ^
//       H       T
//
// (H == T) => Empty
// (H != T) && (H%C == T%C) => Full
//
//
// Each cell on the ring stores an access count in the high bits:
//  ____________________________
// | access count | value index |
// |____BITS/4____|__REMAINING__|
//
// An odd access count indicates that the cell contains a value,
// while an even access count indicates that the cell is empty.
// All access counts are initialized to zero.
// The access count is used to prevent a form of the ABA problem,
// where a producer tries to store into a cell which is no longer
// the tail of the queue, and happens to have the same value index.

// Number of bits in the access count (the "tag"), stored in the high
// bits of each cell.
const TAG_BITS: usize = ::POINTER_BITS/4;
// Mask to extract the value index (the low bits of a cell)
const VALUE_MASK: usize = !0 >> TAG_BITS;
// Mask to extract the tag (the high bits of a cell)
const TAG_MASK: usize = !VALUE_MASK;
// Lowest bit of the tag; adding this to a cell increments the access count
const TAG_BIT: usize = 1 << (::POINTER_BITS - TAG_BITS);
// Threshold at which to wrap the head/tail pointers (the topmost bit of
// a usize, i.e. half the pointer range)
const WRAP_THRESHOLD: usize = !0 ^ (!0 >> 1);

// The raw queue implementation can only store things that
// look like a `usize`. The values must also be less than
// or equal to VALUE_MASK, to allow room for the tag bits.
#[derive(Debug)]
pub struct MpscQueueInner<T: Like<usize>> {
    // Circular buffer storing the values. Each cell packs an access
    // count (high TAG_BITS bits) together with a value index.
    ring: Vec<AtomicUsize>,
    // Pair of pointers into the ring buffer. They grow freely and are
    // only reduced modulo ring.len()*2 once they reach WRAP_THRESHOLD
    // (see the notes above), so comparisons use `% (len*2)`.
    head: AtomicUsize,
    tail: AtomicUsize,
    // Pretend we actually store instances of T
    phantom: PhantomData<T>,
}

61// Advance a pointer by one cell, wrapping if necessary
62fn next_cell(mut index: usize, size2: usize) -> usize {
63    index += 1;
64    if index >= WRAP_THRESHOLD {
65        index = index % size2;
66    }
67    index
68}
69
// Determine whether the "live section" of the queue wraps past the end
// of the ring buffer. If it does (or the queue is completely full),
// extending the buffer requires shuffling elements afterwards; if not,
// empty cells can simply be appended at the end.
fn wraps_around(start: usize, end: usize, size: usize) -> bool {
    let size2 = size * 2;
    // Physical end position sitting before the physical start position
    // means the live section straddles the buffer boundary.
    let inverted = end % size < start % size;
    // Head and tail a full ring apart (modulo size*2) means the queue
    // is full, which also requires shuffling.
    let full = (start + size) % size2 == end % size2;
    inverted || full
}

// Rotate `slice` to the right by `places` positions, in place.
// Uses the classic three-reversal trick: reversing the two parts
// individually and then the whole slice yields a right-rotation.
// Requires `places <= slice.len()`.
fn rotate_slice<T>(slice: &mut [T], places: usize) {
    let pivot = slice.len() - places;
    slice[..pivot].reverse();
    slice[pivot..].reverse();
    slice.reverse();
}

91fn validate_value(v: usize) -> usize {
92    assert!(v <= VALUE_MASK, "Value index outside allowed range!");
93    v
94}
95
impl<T: Like<usize>> MpscQueueInner<T> {
    // Constructor takes an iterator to "fill" the buffer with an initial set of
    // values (even empty cells have a value index...)
    pub fn new<I: IntoIterator<Item=T>>(iter: I) -> Self {
        MpscQueueInner {
            // Each initial value is range-checked so the tag bits stay clear.
            ring: iter.into_iter()
                .map(Into::into)
                .map(validate_value)
                .map(AtomicUsize::new)
                .collect(),
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
            phantom: PhantomData
        }
    }

    // Grow the ring buffer, filling the new cells from `iter`.
    // Takes `&mut self`, so no concurrent queue operations can be in flight.
    pub fn extend<I: IntoIterator<Item=T>>(&mut self, iter: I) where I::IntoIter: ExactSizeIterator {
        let iter = iter.into_iter();
        let size = self.ring.len();
        // Size of the iterator tells us how much the queue is being extended
        let extra = iter.len();
        self.ring.reserve_exact(extra);
        self.ring.extend(iter.map(Into::into).map(validate_value).map(AtomicUsize::new));

        // If the queue wraps around the buffer, shift the elements
        // along such that the start section of the queue is moved to the
        // new end of the buffer.
        let head = self.head.get_mut();
        let tail = self.tail.get_mut();
        if wraps_around(*head, *tail, size) {
            // NOTE(review): `*head` is a logical pointer which may be >= `size`
            // (pointers range modulo size*2; `pop` stores `head % size2`), in
            // which case this slice index would panic or rotate the wrong
            // span. Should this be `*head % size`? TODO confirm.
            rotate_slice(&mut self.ring[*head..], extra);
            *head += extra;
        }
    }

    // This is the length of the buffer, not the number of "live" elements
    pub fn len(&self) -> usize {
        self.ring.len()
    }

    // Swap a value onto the tail of the queue. If the queue is observed to
    // be full, there are no side effects and `false` is returned.
    // On success, `value` is overwritten with the index that previously
    // occupied the cell (the producer takes ownership of the old empty slot).
    // NOTE(review): unsafe presumably because `value` must be a valid index
    // into the storage shared by all users of this queue — confirm against
    // the callers in `MpscQueueWrapper`.
    pub unsafe fn push(&self, value: &mut T) -> bool {
        let index = value.borrow_mut();
        let size = self.ring.len();
        let size2 = size*2;

        validate_value(*index);

        loop {
            // Update the cell pointed to by the tail
            // `try_update_indirect` takes two functions:
            //
            // deref
            //   Takes the tail pointer as input, and returns
            //   `Ok(&cell_to_update)` or `Err(should_retry)`
            //
            // update
            //   Takes tail pointer, and the cell's previous value,
            //   and returns `Ok(new_value)` or `Err(should_retry)`
            //
            // The function ensures that the tail pointer did not
            // get updated while the previous value in the cell
            // was being read.
            match self.tail.try_update_indirect(|tail| {
                // deref

                let head = self.head.load(Ordering::Acquire);
                // If not full
                if (tail % size2) != (head + size) % size2 {
                    // Try updating cell at tail position
                    Ok(&self.ring[tail % size])
                } else {
                    // We observed a full queue, so stop trying
                    Err(false)
                }
            }, |tail, cell| {
                // update

                // If cell at tail is empty (even access count)
                if cell & TAG_BIT == 0 {
                    // Swap in our index, and mark as full (odd access count)
                    Ok((cell & TAG_MASK).wrapping_add(TAG_BIT) | *index)
                } else {
                    // Cell is full, another thread is midway through an insertion
                    // Try to assist the stalled thread, by advancing the tail pointer for them.
                    let _ = self.tail.compare_exchange(tail, next_cell(tail, size2), Ordering::AcqRel, Ordering::Acquire);
                    // Retry the insertion now that we've helped the other thread to progress
                    Err(true)
                }
            }) {
                Ok((tail, prev_cell, _)) => {
                    // Update the tail pointer if necessary
                    let _ = self.tail.compare_exchange(tail, next_cell(tail, size2), Ordering::AcqRel, Ordering::Acquire);
                    // Hand the old (empty) slot's index back to the caller.
                    *index = prev_cell & VALUE_MASK;
                    return true;
                }
                Err(false) => return false,
                Err(true) => {},
            }
        }
    }

    // Take the value at the head of the queue, passing its index to
    // `receiver` via a virtual borrow. Returns `Err(())` if the queue
    // is observed to be empty.
    // NOTE(review): head is read and later written with a plain
    // load/store pair (not an atomic RMW), so this is only sound with a
    // single consumer — presumably part of the `unsafe` contract; confirm.
    pub unsafe fn pop<R, F: FnOnce(&mut T) -> R>(&self, receiver: F) -> Result<R, ()> {
        let size = self.ring.len();
        let size2 = size*2;
        let head = self.head.load(Ordering::Acquire);
        let tail = self.tail.load(Ordering::Acquire);

        // If the queue is empty
        if head % size2 == tail % size2 {
            Err(())
        } else {
            // Bump the access count: this flips the cell from "full" (odd
            // count) to "empty" (even count) while leaving the value index
            // bits untouched.
            let cell = self.ring[head % size].fetch_add(TAG_BIT, Ordering::AcqRel);
            assert!(cell & TAG_BIT != 0, "Producer advanced without adding an item!");
            let result = T::virtual_borrow(cell & VALUE_MASK, receiver);
            // Advance the head, keeping it within [0, size*2).
            self.head.store((head+1) % size2, Ordering::Release);
            Ok(result)
        }
    }
}

// Id type identifying each sender (generated by the project's
// `define_id!` macro).
define_id!(MpscQueueSenderId);

// Ties the raw index queue to the storage holding the actual values,
// plus the per-sender scratch places and sender-id allocation.
pub struct MpscQueueWrapper<T> {
    // Backing storage holding the actual `T` values.
    storage: Storage<T>,
    // One scratch `Place` per sender id, used to stage a value while it
    // is being swapped into the queue.
    scratch: Scratch<MpscQueueSenderId, Place<T>>,
    // The raw ring-buffer queue of `Place` indices.
    inner: MpscQueueInner<Place<T>>,
    // Allocates unique ids for senders.
    id_alloc: IndexAllocator
}

impl<T> MpscQueueWrapper<T> {
    // Construct a new queue behind the given handle type.
    // `id_limit` bounds the number of concurrent sender ids (must be at
    // least 1); `size` is the queue capacity.
    pub fn new<H: Handle<HandleInner=Self>>(id_limit: usize, size: usize) -> H {
        assert!(id_limit > 0);
        // One storage slot per sender scratch place, plus one per queue cell.
        let mut storage = Storage::with_capacity(id_limit + size);
        let scratch = Scratch::new(storage.none_storing_iter(id_limit));
        let inner = MpscQueueInner::new(storage.none_storing_iter(size));
        let id_alloc = IndexAllocator::new(id_limit);

        Handle::new(MpscQueueWrapper {
            storage: storage,
            scratch: scratch,
            inner: inner,
            id_alloc: id_alloc,
        })
    }

    // Push `value` onto the queue, staging it in the scratch place that
    // belongs to `id`. On a full queue the value is handed back as `Err`.
    // NOTE(review): unsafe presumably because `id` must be a valid,
    // uniquely-held sender id for this queue — confirm against the
    // `IdHandle` contract.
    pub unsafe fn push(&self, id: &mut MpscQueueSenderId, value: T) -> Result<(), T> {
        let place = self.scratch.get_mut(id);
        self.storage.replace(place, Some(value));
        if self.inner.push(place) {
            Ok(())
        } else {
            // Queue was full: take the value back out of our scratch place.
            Err(self.storage.replace(place, None).expect("Some(value) in container"))
        }
    }

    // Pop the oldest value, or `Err(())` if the queue is empty.
    // NOTE(review): unsafe presumably because there must be at most one
    // concurrent consumer (see `MpscQueueInner::pop`) — confirm.
    pub unsafe fn pop(&self) -> Result<T, ()> {
        self.inner.pop(|place| self.storage.replace(place, None).expect("Some(value) in container"))
    }
}

impl<T> HandleInner<MpscQueueSenderId> for MpscQueueWrapper<T> {
    type IdAllocator = IndexAllocator;
    // Expose the sender-id allocator to the handle machinery.
    fn id_allocator(&self) -> &IndexAllocator {
        &self.id_alloc
    }
    // Grow the maximum number of sender ids. Requires exclusive access;
    // the new limit must be strictly larger than the current one.
    fn raise_id_limit(&mut self, new_limit: usize) {
        let old_limit = self.id_limit();
        assert!(new_limit > old_limit);
        let extra = new_limit - old_limit;
        // Add one storage slot and one scratch place per new id.
        self.storage.reserve(extra);
        self.scratch.extend(self.storage.none_storing_iter(extra));
        self.id_alloc.resize(new_limit);
    }
}

// The receiving half of the queue. There is exactly one receiver (see
// the uniqueness comment in `receive`), so it owns the handle directly
// and needs no id.
#[derive(Debug)]
pub struct MpscQueueReceiver<T, H: Handle<HandleInner=MpscQueueWrapper<T>>>(H);

impl<T, H: Handle<HandleInner=MpscQueueWrapper<T>>> MpscQueueReceiver<T, H> {
    // Create a new queue with capacity `size` and at most `max_senders`
    // concurrent senders, returning its unique receiver.
    pub fn new(max_senders: usize, size: usize) -> Self {
        MpscQueueReceiver(MpscQueueWrapper::new(max_senders, size))
    }

    // Receive the oldest value, or `Err(())` if the queue is empty.
    pub fn receive(&mut self) -> Result<T, ()> {
        // This is safe because we guarantee that we are unique
        self.0.with(|inner| unsafe { inner.pop() })
    }
}

// Receiver flavours for the two handle strategies.
pub type ResizingMpscQueueReceiver<T> = MpscQueueReceiver<T, ResizingHandle<MpscQueueWrapper<T>>>;
pub type BoundedMpscQueueReceiver<T> = MpscQueueReceiver<T, BoundedHandle<MpscQueueWrapper<T>>>;

// A sending half of the queue. Each sender holds its own id via an
// `IdHandle`; there may be many senders at once, up to the id limit.
#[derive(Debug)]
pub struct MpscQueueSender<T, H: Handle<HandleInner=MpscQueueWrapper<T>>>(IdHandle<H, MpscQueueSenderId>);

impl<T, H: Handle<HandleInner=MpscQueueWrapper<T>>> MpscQueueSender<T, H> {
    // Create a sender attached to the given receiver's queue.
    pub fn new(receiver: &MpscQueueReceiver<T, H>) -> Self {
        MpscQueueSender(IdHandle::new(&receiver.0))
    }
    // Fallible variant of `new`: returns `None` when a sender id cannot
    // be obtained (presumably the id limit is reached — confirm against
    // `IdHandle::try_new`).
    pub fn try_new(receiver: &MpscQueueReceiver<T, H>) -> Option<Self> {
        IdHandle::try_new(&receiver.0).map(MpscQueueSender)
    }

    // Send a value; on a full queue the value is returned as `Err`.
    pub fn send(&mut self, value: T) -> Result<(), T> {
        self.0.with_mut(|inner, id| unsafe { inner.push(id, value) })
    }
    // Fallible clone: returns `None` when a new sender id cannot be
    // obtained.
    pub fn try_clone(&self) -> Option<Self> {
        self.0.try_clone().map(MpscQueueSender)
    }
}

309impl<T, H: Handle<HandleInner=MpscQueueWrapper<T>>> Clone for MpscQueueSender<T, H> {
310    fn clone(&self) -> Self {
311        MpscQueueSender(self.0.clone())
312    }
313}
314
// Sender flavours for the two handle strategies.
pub type ResizingMpscQueueSender<T> = MpscQueueSender<T, ResizingHandle<MpscQueueWrapper<T>>>;
pub type BoundedMpscQueueSender<T> = MpscQueueSender<T, BoundedHandle<MpscQueueWrapper<T>>>;