exclusion_set/
lib.rs

//! A lock-free concurrent set.  The operations on this set are O(n) where n is
//! the number of distinct values that have ever been inserted into the set.
//!
//! The intended use-case is to provide a way to ensure exclusivity over
//! critical sections of code on a per-value basis, where the number of distinct
//! values is small but dynamic.
//!
//! This data structure uses atomic singly-linked lists in two forms to enable
//! its operations.  One list has a node for every distinct value that has
//! ever been inserted.  The other type of list exists within each of those
//! nodes, and manages a queue of threads waiting in the `wait_to_insert` method
//! for another thread to call `remove`.
//!
//! An atomic singly-linked list is relatively straightforward to insert into:
//! allocate a new node, and then, in a loop, set the node's 'next' pointer to
//! the most recent value of the 'head' pointer and attempt a compare-exchange,
//! replacing the old 'head' with the pointer to the new node.
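//!
//! For illustration only (a hypothetical, simplified sketch, not this crate's
//! internal code), a push loop of that shape might look like:
//!
//! ```
//! use std::{ptr, sync::atomic::{AtomicPtr, Ordering}};
//!
//! struct Node {
//!     value: u32,
//!     next: *mut Node,
//! }
//!
//! fn push(head: &AtomicPtr<Node>, value: u32) {
//!     let new = Box::into_raw(Box::new(Node { value, next: ptr::null_mut() }));
//!     let mut old = head.load(Ordering::Acquire);
//!     loop {
//!         // Point the new node at the current head, then try to publish it.
//!         unsafe { (*new).next = old };
//!         match head.compare_exchange_weak(old, new, Ordering::Release, Ordering::Acquire) {
//!             Ok(_) => return,
//!             Err(current) => old = current, // the head moved; retry against it
//!         }
//!     }
//! }
//!
//! let head = AtomicPtr::new(ptr::null_mut());
//! push(&head, 1);
//! push(&head, 2);
//! // The most recently pushed node sits at the head of the list.
//! // (Nodes are deliberately leaked here; a real list frees them in Drop.)
//! assert_eq!(unsafe { (*head.load(Ordering::Acquire)).value }, 2);
//! ```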
//!
//! Things get more complicated as soon as you additionally consider removing
//! items from the list.  Anything that dereferences a node pointer now runs the
//! risk of attempting to dereference a value which has been removed between the
//! load that returned the pointer and the dereference of the pointer.  Note
//! that removal itself requires a dereference of the head pointer, to determine
//! the value of `head.next`.  This data structure avoids this issue in slightly
//! different ways for the two different types of list.
//!
//! The main list of nodes for each value avoids the issue by never removing
//! nodes except in `Drop`.  The exclusive-access guarantee of `Drop` ensures
//! that no other thread can attempt to access the list while it is being freed.
//!
//! The list of waiting threads instead avoids the issue by requiring that, for
//! each list of waiting threads (which, in the context of this set, means for
//! each unique value), at most one thread at a time may dereference a pointer.
//! It exposes this contract as the safety requirement of the unsafe `remove`
//! method.  This requirement is easy to fulfill in applications where a value
//! is only removed from the set by a logical "owner" which knows that it
//! previously inserted that value, as sketched below.
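//!
//! A hypothetical shape of that owner pattern (`key` and `critical_section`
//! are placeholders, so this snippet is not compiled):
//!
//! ```ignore
//! // One logical owner per key: no other thread ever calls `remove` for
//! // this value, which satisfies `remove`'s safety requirement.
//! set.wait_to_insert(key);
//! critical_section(); // runs exclusively for `key`
//! // Safety: this thread inserted `key` and is its only remover.
//! unsafe { set.remove(&key) };
//! ```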
//!
//! # Example
//!
//! The following code inserts some values into the set, then removes one of
//! them, and then spawns a second thread that waits to insert into the set.
//!
#![cfg_attr(
    feature = "std",
    doc = "
```
# use std::{any::TypeId, sync::Arc};
# use exclusion_set::Set;
# unsafe {
struct A;
struct B;
struct C;

let set: Arc<Set<TypeId>> = Arc::default();
set.try_insert(TypeId::of::<A>());
set.try_insert(TypeId::of::<B>());
set.try_insert(TypeId::of::<C>());
set.remove(&TypeId::of::<A>());
let set2 = set.clone();
# let handle =
std::thread::spawn(move || {
    set2.wait_to_insert(TypeId::of::<B>());
});
# set.remove(&TypeId::of::<B>()); // avoid a deadlock in the example
# handle.join();
# }
```
"
)]
//!
//! After this code has been run, we can expect the data structure to look like
//! this:
//!
//! <div style="background-color: white">
#![doc=include_str!("lib-example.svg")]
//! </div>

#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(clippy::all, clippy::pedantic)]

extern crate alloc;

use {
    alloc::boxed::Box,
    core::{ptr, sync::atomic::Ordering},
};

#[cfg(loom)]
use loom::{sync::atomic, thread};

#[cfg(not(loom))]
use core::sync::atomic;

#[cfg(all(not(loom), feature = "std"))]
use std::thread;

/// A set of values held in a linked list.
pub struct Set<T> {
    /// This pointer is either null (if the set has never been inserted into)
    /// or a pointer to the first `Node` in the set.
    head: atomic::AtomicPtr<Node<T>>,
    /// This pointer is either null (if no value has ever been removed from
    /// the set) or a pointer to one of the nodes in the set.  It enables an
    /// optimization: `find` starts searching here, so the most recently
    /// removed value is cheaper to re-insert because the whole linked list
    /// need not be traversed.
    last_removed: atomic::AtomicPtr<Node<T>>,
}

struct Node<T> {
    /// The value this node was created for.
    value: T,

    /// The current status of the associated value; null if the value is
    /// currently considered absent, `occupied` if the value is currently
    /// considered present, or a valid pointer if the value is currently
    /// considered present and one or more threads are waiting to insert it.
    status: atomic::AtomicPtr<WaitingThreadNode>,

    /// The next node, or `null` if this is the end of the list.
    next: *const Node<T>,
}

/// Returns a sentinel pointer (never a valid `WaitingThreadNode`) used as a
/// marker that a value is currently considered "in" the set.
fn occupied() -> *mut WaitingThreadNode {
    static RESERVED_MEMORY: usize = usize::from_ne_bytes([0xA5; core::mem::size_of::<usize>()]);
    core::ptr::addr_of!(RESERVED_MEMORY).cast_mut().cast()
}

/// This type is used for a stack-based linked list of waiting threads, so that
/// when a value is removed from the set, a thread which is waiting to insert
/// that value can be notified that it may proceed.
struct WaitingThreadNode {
    /// The handle of the waiting thread.
    #[cfg(feature = "std")]
    thread: thread::Thread,

    /// A flag to indicate if this node has been removed from the list of
    /// waiting threads, and the thread should stop waiting.
    #[cfg(feature = "std")]
    popped: atomic::AtomicBool,

    /// The next node, or `occupied` if there are no more waiting threads.
    next: *mut WaitingThreadNode,
}

impl<T> Default for Set<T> {
    fn default() -> Self {
        Self {
            head: atomic::AtomicPtr::new(ptr::null_mut()),
            last_removed: atomic::AtomicPtr::new(ptr::null_mut()),
        }
    }
}

impl<T> Set<T> {
    /// Create a new, empty `Set`.
    #[cfg(not(loom))]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            head: atomic::AtomicPtr::new(ptr::null_mut()),
            last_removed: atomic::AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Search linearly through the list for a node with the given value.  If
    /// it is not found, return the value the head pointer had when we started
    /// searching (a node with that value is guaranteed not to be reachable by
    /// traversing from that pointer onward).
    fn find(&self, value: &T) -> Result<&Node<T>, *const Node<T>>
    where
        T: Eq,
    {
        let starting_node = self.last_removed.load(Ordering::Acquire).cast_const();
        let mut current_node = starting_node;
        // Safety: current_node is loaded from self.last_removed or node.next,
        // both of which only ever store null or valid pointers created by
        // Box::into_raw, so it's safe to call .as_ref on it here.
        while let Some(node) = unsafe { current_node.as_ref() } {
            if node.value == *value {
                return Ok(node);
            }
            current_node = node.next;
        }
        let original_head = self.head.load(Ordering::Acquire).cast_const();
        let mut current_node = original_head;
        while current_node != starting_node {
            // Safety: current_node is loaded from self.head or node.next, both of
            // which only ever store null or valid pointers created by
            // Box::into_raw, and it won't be null since it must not be the end of
            // the chain if it wasn't equal to starting_node
            let node = unsafe { &*current_node };
            if node.value == *value {
                return Ok(node);
            }
            current_node = node.next;
        }
        Err(original_head)
    }

    /// Try to insert a value into the set.  Returns `true` if the value was
    /// inserted or `false` if the value was already considered present.
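    ///
    /// A small usage sketch (a doctest assuming a default, non-`loom` build,
    /// where [`Set::new`] is available):
    ///
    /// ```
    /// let set = exclusion_set::Set::new();
    /// assert!(set.try_insert(1)); // newly inserted
    /// assert!(!set.try_insert(1)); // already considered present
    /// ```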
    #[must_use]
    pub fn try_insert(&self, value: T) -> bool
    where
        T: Eq,
    {
        self.try_insert_inner(value).is_ok()
    }

    /// Try to insert a value into the set.  Returns `Ok` if the value was
    /// inserted, or `Err` with the occupied node if the value was already
    /// considered present.
    fn try_insert_inner(&self, value: T) -> Result<(), &Node<T>>
    where
        T: Eq,
    {
        let next = match self.find(&value) {
            Ok(node) => {
                // The failure Ordering can be Relaxed here, because we don't
                // try to read any data associated with the value, we just care
                // if the operation succeeded or not.
                return node
                    .status
                    .compare_exchange(
                        ptr::null_mut(),
                        occupied(),
                        Ordering::Acquire,
                        Ordering::Relaxed,
                    )
                    .map(|_| ())
                    .map_err(|_| node);
            }
            Err(original_head) => original_head,
        };
        let new_node = Box::into_raw(Box::new(Node {
            value,
            status: atomic::AtomicPtr::new(occupied()),
            next,
        }));
        // Safety: we just created the pointer from the box, so it's safe to
        // dereference here
        let Node {
            ref value,
            ref mut next,
            ..
        } = unsafe { &mut *new_node };
        let mut found_and_set = Ok(());
        self.head
            .fetch_update(Ordering::Release, Ordering::Acquire, |most_recent_head| {
                let most_recent_head = most_recent_head.cast_const();
                let mut current_next = most_recent_head;
                loop {
                    if current_next == *next {
                        *next = most_recent_head;
                        return Some(new_node);
                    }
                    // Safety: current_next is loaded from self.head or
                    // node.next, both of which only ever store null or valid
                    // pointers created by Box::into_raw, and only the last in
                    // the chain can be null, in which case we would have caught
                    // it with the previous condition, so it's safe to
                    // dereference here.
                    let node = unsafe { &*current_next };
                    if &node.value == value {
                        // The failure Ordering can be Relaxed here, because we
                        // don't try to read any data associated with the value,
                        // we just care if the operation succeeded or not.
                        found_and_set = node
                            .status
                            .compare_exchange(
                                ptr::null_mut(),
                                occupied(),
                                Ordering::Acquire,
                                Ordering::Relaxed,
                            )
                            .map(|_| ())
                            .map_err(|_| node);
                        return None;
                    }
                    current_next = node.next;
                }
            })
            .map(|_| ())
            .or_else(|_| {
                // Safety: in the error case, we have not stored the box anywhere else
                // so we can free it here
                let _: Box<Node<T>> = unsafe { Box::from_raw(new_node) };
                found_and_set
            })
    }

    /// If the value provided is not in the set, insert it.  Otherwise, block
    /// the current thread until another thread calls `remove` for the given
    /// value (if multiple threads are waiting, only one of them will
    /// return).
    #[cfg(feature = "std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
    pub fn wait_to_insert(&self, value: T)
    where
        T: Eq,
    {
        let Err(node) = self.try_insert_inner(value) else { return };
        let mut waiting_node = WaitingThreadNode {
            thread: thread::current(),
            popped: atomic::AtomicBool::new(false),
            next: occupied(),
        };
        let mut status_guess = occupied();
        let mut set_status_to: *mut WaitingThreadNode = &mut waiting_node;
        while let Err(status) = node.status.compare_exchange_weak(
            status_guess,
            set_status_to,
            Ordering::Release,
            Ordering::Acquire,
        ) {
            status_guess = status;
            if status.is_null() {
                // The value was removed while we were preparing to wait, so
                // try to claim it directly rather than enqueueing ourselves.
                set_status_to = occupied();
            } else {
                // The value is still present (and possibly has other waiters),
                // so push our node onto the front of the waiting list.
                waiting_node.next = status;
                set_status_to = &mut waiting_node;
            }
        }
        if set_status_to == occupied() {
            // The status was null, so we didn't end up needing to wait.
            return;
        }
        loop {
            if waiting_node.popped.load(Ordering::Acquire) {
                break;
            }
            // Block until `remove` unparks us; the `popped` check above
            // guards against spurious wakeups.
            thread::park();
        }
        drop(waiting_node);
    }

    /// Mark a value as absent from the set, or notify a waiting thread that
    /// it may proceed.
    ///
    /// Returns `true` if the value was present in the set.
    ///
    /// # Safety
    /// Must not be called concurrently from multiple threads with the same
    /// value.
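    ///
    /// # Example
    ///
    /// A minimal sketch of sound usage, in which a single thread is the sole
    /// remover for the value:
    ///
    /// ```
    /// let set = exclusion_set::Set::new();
    /// assert!(set.try_insert(7));
    /// // Safety: no other thread calls `remove` concurrently for this value.
    /// assert!(unsafe { set.remove(&7) });
    /// assert!(!unsafe { set.remove(&7) }); // the value is already absent
    /// ```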
    pub unsafe fn remove(&self, value: &T) -> bool
    where
        T: Eq,
    {
        let Ok(node) = self.find(value) else { return false };
        let mut status_guess = occupied();
        let mut set_status_to = ptr::null_mut();
        while let Err(status) = node.status.compare_exchange_weak(
            status_guess,
            set_status_to,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            if status.is_null() {
                return false;
            } else if status == occupied() {
                set_status_to = ptr::null_mut();
                status_guess = status;
            } else {
                // Pop the first waiting thread: the new status is that node's
                // `next`, which is either another waiting thread or
                // `occupied`; the popped thread is notified below and
                // proceeds as the new inserter.
                //
                // Safety: `status` is either null, `occupied()`, or valid, and
                // we just checked that it wasn't null or `occupied`, so it's
                // safe to dereference here.  The pointer is still alive unless
                // this function is called concurrently, which is why it's an
                // unsafe function with that condition.
                set_status_to = unsafe { (*status).next };
                status_guess = status;
            }
        }
        self.last_removed
            .store(<*const _>::cast_mut(node), Ordering::Release);
        // If we were successful, it's because our guess was correct, so
        // `status_guess` holds the previous value of `node.status`.
        #[cfg(feature = "std")]
        if status_guess != occupied() {
            // Safety: `status_guess` is either null, `occupied`, or valid. If
            // it were null, we would have returned false up above, and we just
            // checked that it isn't `occupied()`.  The pointer is still alive
            // unless this function is called concurrently, which is why it's
            // an unsafe function with that condition.
            let WaitingThreadNode { thread, popped, .. } = unsafe { &*status_guess };
            // Clone the thread handle first: the node lives on the waiting
            // thread's stack and may be freed as soon as we store into
            // `popped`.
            let thread = thread.clone();
            popped.store(true, Ordering::Release);
            thread.unpark();
        }
        true
    }
}

impl<T> Drop for Set<T> {
    fn drop(&mut self) {
        #[cfg(loom)]
        let mut node = self
            .head
            .with_mut(|p| core::mem::replace(p, ptr::null_mut()));
        #[cfg(not(loom))]
        let mut node = core::mem::replace(self.head.get_mut(), ptr::null_mut());
        while !node.is_null() {
            // Safety: node pointers are either null or valid pointers created
            // by `Box::into_raw`, and we just checked that it was not null, so
            // it's safe to call `Box::from_raw` here.
            let boxed = unsafe { Box::from_raw(node) };
            node = boxed.next.cast_mut();
        }
    }
}