exclusion_set/lib.rs
//! A lock-free concurrent set. The operations on this set are O(n), where n
//! is the number of distinct values that have ever been inserted into the
//! set.
//!
//! The intended use case is to provide a way to ensure exclusivity over
//! critical sections of code on a per-value basis, where the number of
//! distinct values is small but dynamic.
//!
//! This data structure uses atomic singly-linked lists in two forms to enable
//! its operations. One list has a node for every distinct value that has
//! ever been inserted. The other type of list exists within each of those
//! nodes, and manages a queue of threads waiting in the `wait_to_insert`
//! method for another thread to call `remove`.
//!
//! An atomic singly-linked list is relatively straightforward to insert into:
//! allocate a new node, and then, in a loop, update the 'next' pointer of the
//! node to the most recent value of the 'head' pointer, and then attempt a
//! compare-exchange, replacing the old 'head' with the pointer to the new
//! node.
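//!
//! A minimal sketch of that insertion loop (simplified relative to the real
//! `try_insert` below, and assuming a node type with only a `next` field):
//!
//! ```ignore
//! let new = Box::into_raw(Box::new(Node { next: ptr::null_mut() }));
//! let mut old = head.load(Ordering::Acquire);
//! loop {
//!     // Point the new node at the current head, then try to publish it.
//!     unsafe { (*new).next = old };
//!     match head.compare_exchange_weak(old, new, Ordering::Release, Ordering::Acquire) {
//!         Ok(_) => break,
//!         // Another thread moved `head` first; retry against its value.
//!         Err(actual) => old = actual,
//!     }
//! }
//! ```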
//!
//! Things get more complicated as soon as you additionally consider removing
//! items from the list. Anything that dereferences a node pointer now runs
//! the risk of attempting to dereference a value which has been removed
//! between the load that returned the pointer and the dereference of the
//! pointer. Note that removal itself requires a dereference of the head
//! pointer, to determine the value of `head.next`. This data structure avoids
//! this issue in slightly different ways for the two different types of list.
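//!
//! Spelled out, the hazardous pattern looks like this (a sketch of the bug
//! being avoided, not code from this crate):
//!
//! ```ignore
//! let p = head.load(Ordering::Acquire); // thread A loads a node pointer
//! // ... thread B unlinks and frees the node `p` points to ...
//! let next = unsafe { (*p).next };      // use-after-free!
//! ```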
//!
//! The main list of nodes for each value avoids the issue by never removing
//! nodes except in `Drop`. The exclusive-access guarantee of `Drop` ensures
//! that no other thread could attempt to access the list while it is being
//! freed.
//!
//! The list of waiting threads instead avoids the issue by requiring that,
//! for each list of waiting threads (which, in the context of this set, means
//! for each unique value), at most one thread at a time may dereference a
//! pointer. It exposes this contract as the safety requirement of the unsafe
//! `remove` method. This requirement is easy to fulfil for applications where
//! a value is only removed from the set by a logical "owner" which knows that
//! it previously inserted the value.
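//!
//! One way to uphold that contract is a RAII guard owned by the inserting
//! thread. A sketch (the `ExclusionGuard` type is hypothetical, not part of
//! this crate):
//!
//! ```ignore
//! struct ExclusionGuard<'a, T: Eq> {
//!     set: &'a Set<T>,
//!     value: T,
//! }
//!
//! impl<T: Eq> Drop for ExclusionGuard<'_, T> {
//!     fn drop(&mut self) {
//!         // Safety: this guard is the sole logical owner of `value`, so no
//!         // other thread calls `remove` for the same value concurrently.
//!         unsafe { self.set.remove(&self.value) };
//!     }
//! }
//! ```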
//!
//! # Example
//!
//! The following code inserts some values into the set, then removes one of
//! them, and then spawns a second thread that waits to insert into the set.
//!
#![cfg_attr(
    feature = "std",
    doc = "
```
# use std::{any::TypeId, sync::Arc};
# use exclusion_set::Set;
# unsafe {
struct A;
struct B;
struct C;

let set: Arc<Set<TypeId>> = Arc::default();
set.try_insert(TypeId::of::<A>());
set.try_insert(TypeId::of::<B>());
set.try_insert(TypeId::of::<C>());
set.remove(&TypeId::of::<A>());
let set2 = set.clone();
# let handle =
std::thread::spawn(move || {
    set2.wait_to_insert(TypeId::of::<B>());
});
# set.remove(&TypeId::of::<B>()); // avoid a deadlock in the example
# handle.join();
# }
```
"
)]
//!
//! After this code has been run, we can expect the data structure to look
//! like this:
//!
//! <div style="background-color: white">
#![doc=include_str!("lib-example.svg")]
//! </div>

#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(clippy::all, clippy::pedantic)]

extern crate alloc;

use {
    alloc::boxed::Box,
    core::{ptr, sync::atomic::Ordering},
};

#[cfg(loom)]
use loom::{sync::atomic, thread};

#[cfg(not(loom))]
use core::sync::atomic;

#[cfg(all(not(loom), feature = "std"))]
use std::thread;

/// A set of values held in a linked list.
pub struct Set<T> {
    /// This pointer is either null (if the set has never been inserted into)
    /// or a pointer to the first `Node` in the set.
    head: atomic::AtomicPtr<Node<T>>,
    /// This pointer is either null (if no value has ever been removed from
    /// the set) or a pointer to one of the nodes in the set. This enables an
    /// optimization where the last removed value is cheaper to re-insert,
    /// because `find` starts its search at this node instead of traversing
    /// the whole linked list from the head.
    last_removed: atomic::AtomicPtr<Node<T>>,
}

struct Node<T> {
    /// The value this node was created for.
    value: T,

    /// The current status of the associated value; null if the value is
    /// currently considered absent, `occupied` if the value is currently
    /// considered present, or a valid pointer if the value is currently
    /// considered present and one or more threads are waiting to insert it.
    status: atomic::AtomicPtr<WaitingThreadNode>,

    /// The next node, or `null` if this is the end of the list.
    next: *const Node<T>,
}

/// This otherwise invalid pointer is used as a marker value that a value is
/// currently considered "in" the set.
fn occupied() -> *mut WaitingThreadNode {
    static RESERVED_MEMORY: usize = usize::from_ne_bytes([0xA5; core::mem::size_of::<usize>()]);
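    // The address of a static is stable and non-null, and it can never alias
    // a real `WaitingThreadNode` allocation, so it is safe to use as a
    // sentinel as long as it is never dereferenced.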
    core::ptr::addr_of!(RESERVED_MEMORY).cast_mut().cast()
}

/// This type is used for a stack-based linked list of waiting threads, so
/// that when a value is removed from the set, a thread which is waiting to
/// insert that value can be notified that it may proceed.
struct WaitingThreadNode {
    /// The handle of the waiting thread.
    #[cfg(feature = "std")]
    thread: thread::Thread,

    /// A flag indicating that this node has been removed from the list of
    /// waiting threads and that the thread should stop waiting.
    #[cfg(feature = "std")]
    popped: atomic::AtomicBool,

    /// The next node, or `occupied` if there are no more waiting threads.
    next: *mut WaitingThreadNode,
}

impl<T> Default for Set<T> {
    fn default() -> Self {
        Self {
            head: atomic::AtomicPtr::new(ptr::null_mut()),
            last_removed: atomic::AtomicPtr::new(ptr::null_mut()),
        }
    }
}

impl<T> Set<T> {
    /// Create a new, empty `Set`.
    #[cfg(not(loom))]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            head: atomic::AtomicPtr::new(ptr::null_mut()),
            last_removed: atomic::AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Search linearly through the list for a node with the given value. If
    /// it was not found, return the value of the head pointer when we started
    /// searching (it is guaranteed that a node with that value cannot be
    /// found by traversing from that pointer onward).
    fn find(&self, value: &T) -> Result<&Node<T>, *const Node<T>>
    where
        T: Eq,
    {
        let starting_node = self.last_removed.load(Ordering::Acquire).cast_const();
        let mut current_node = starting_node;
        // Safety: current_node is loaded from self.last_removed or node.next,
        // both of which only ever store null or valid pointers created by
        // Box::into_raw, so it's safe to call .as_ref on it here
        while let Some(node) = unsafe { current_node.as_ref() } {
            if node.value == *value {
                return Ok(node);
            }
            current_node = node.next;
        }
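        // Not found in the tail starting at `last_removed`; fall back to
        // scanning from the current head, stopping once we reach the segment
        // we already searched.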
        let original_head = self.head.load(Ordering::Acquire).cast_const();
        let mut current_node = original_head;
        while current_node != starting_node {
            // Safety: current_node is loaded from self.head or node.next, both
            // of which only ever store null or valid pointers created by
            // Box::into_raw, and it won't be null since it must not be the end
            // of the chain if it wasn't equal to starting_node
            let node = unsafe { &*current_node };
            if node.value == *value {
                return Ok(node);
            }
            current_node = node.next;
        }
        Err(original_head)
    }

    /// Try to insert a value into the set. Returns `true` if the value was
    /// inserted, or `false` if the value was already considered present.
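    ///
    /// For example, inserting the same value twice:
    ///
    /// ```
    /// use exclusion_set::Set;
    ///
    /// let set: Set<u32> = Set::default();
    /// assert!(set.try_insert(7));
    /// assert!(!set.try_insert(7)); // already considered present
    /// ```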
    #[must_use]
    pub fn try_insert(&self, value: T) -> bool
    where
        T: Eq,
    {
        self.try_insert_inner(value).is_ok()
    }

    /// Try to insert a value into the set. Returns `Ok` if the value was
    /// inserted, or the occupied node if the value was already considered
    /// present.
    fn try_insert_inner(&self, value: T) -> Result<(), &Node<T>>
    where
        T: Eq,
    {
        let next = match self.find(&value) {
            Ok(node) => {
                // The failure Ordering can be Relaxed here, because we don't
                // try to read any data associated with the value, we just care
                // if the operation succeeded or not.
                return node
                    .status
                    .compare_exchange(
                        ptr::null_mut(),
                        occupied(),
                        Ordering::Acquire,
                        Ordering::Relaxed,
                    )
                    .map(|_| ())
                    .map_err(|_| node);
            }
            Err(original_head) => original_head,
        };
        let new_node = Box::into_raw(Box::new(Node {
            value,
            status: atomic::AtomicPtr::new(occupied()),
            next,
        }));
        // Safety: we just created the pointer from the box, so it's safe to
        // dereference here
        let Node {
            ref value,
            ref mut next,
            ..
        } = unsafe { &mut *new_node };
        let mut found_and_set = Ok(());
        self.head
            .fetch_update(Ordering::Release, Ordering::Acquire, |most_recent_head| {
                let most_recent_head = most_recent_head.cast_const();
                let mut current_next = most_recent_head;
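                // `*next` still holds the head we last observed, so only the
                // nodes inserted since then need to be checked for a duplicate
                // value; everything older was already searched by `find` or by
                // a previous iteration of this closure.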
                loop {
                    if current_next == *next {
                        *next = most_recent_head;
                        return Some(new_node);
                    }
                    // Safety: current_next is loaded from self.head or
                    // node.next, both of which only ever store null or valid
                    // pointers created by Box::into_raw, and only the last in
                    // the chain can be null, in which case we would have caught
                    // it with the previous condition, so it's safe to
                    // dereference here.
                    let node = unsafe { &*current_next };
                    if &node.value == value {
                        // The failure Ordering can be Relaxed here, because we
                        // don't try to read any data associated with the value,
                        // we just care if the operation succeeded or not.
                        found_and_set = node
                            .status
                            .compare_exchange(
                                ptr::null_mut(),
                                occupied(),
                                Ordering::Acquire,
                                Ordering::Relaxed,
                            )
                            .map(|_| ())
                            .map_err(|_| node);
                        return None;
                    }
                    current_next = node.next;
                }
            })
            .map(|_| ())
            .or_else(|_| {
                // Safety: in the error case, we have not stored the box
                // anywhere else, so we can free it here
                let _: Box<Node<T>> = unsafe { Box::from_raw(new_node) };
                found_and_set
            })
    }

    /// If the value provided is not in the set, insert it. Otherwise, block
    /// the current thread until another thread calls `remove` for the given
    /// value (if multiple threads are waiting, only one of them will
    /// return).
    #[cfg(feature = "std")]
    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
    pub fn wait_to_insert(&self, value: T)
    where
        T: Eq,
    {
        let Err(node) = self.try_insert_inner(value) else { return };
        let mut waiting_node = WaitingThreadNode {
            thread: thread::current(),
            popped: atomic::AtomicBool::new(false),
            next: occupied(),
        };
        let mut status_guess = occupied();
        let mut set_status_to: *mut WaitingThreadNode = &mut waiting_node;
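        // Try to push our node onto the stack of waiters. If the status
        // becomes null along the way (the value was removed), claim the value
        // instead by setting the status back to `occupied`.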
        while let Err(status) = node.status.compare_exchange_weak(
            status_guess,
            set_status_to,
            Ordering::Release,
            Ordering::Acquire,
        ) {
            status_guess = status;
            if status.is_null() {
                set_status_to = occupied();
            } else {
                waiting_node.next = status;
                set_status_to = &mut waiting_node;
            }
        }
        if set_status_to == occupied() {
            // The status was null, so we didn't end up needing to wait.
            return;
        }
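        // Block until `remove` pops our node and sets `popped`. `park` may
        // wake spuriously, so the flag is re-checked on every iteration.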
        loop {
            if waiting_node.popped.load(Ordering::Acquire) {
                break;
            }
            thread::park();
        }
        drop(waiting_node);
    }

    /// Mark a value as absent from the set, or notify a waiting thread that
    /// it may proceed.
    ///
    /// Returns `true` if the value was present in the set.
    ///
    /// # Safety
    ///
    /// Must not be called concurrently from multiple threads with the same
    /// value.
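    ///
    /// For example, on a single thread (which trivially satisfies the safety
    /// requirement):
    ///
    /// ```
    /// use exclusion_set::Set;
    ///
    /// let set: Set<&str> = Set::default();
    /// assert!(set.try_insert("job"));
    /// // Safety: no other thread is removing `"job"` concurrently.
    /// assert!(unsafe { set.remove(&"job") });
    /// assert!(!unsafe { set.remove(&"job") });
    /// ```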
    pub unsafe fn remove(&self, value: &T) -> bool
    where
        T: Eq,
    {
        let Ok(node) = self.find(value) else { return false };
        let mut status_guess = occupied();
        let mut set_status_to = ptr::null_mut();
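        // If one or more threads are waiting, pop one waiter off the stack;
        // otherwise transition the status from `occupied` back to null.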
        while let Err(status) = node.status.compare_exchange_weak(
            status_guess,
            set_status_to,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            if status.is_null() {
                return false;
            } else if status == occupied() {
                set_status_to = ptr::null_mut();
                status_guess = status;
            } else {
                // Safety: `status` is either null, `occupied()`, or valid, and
                // we just checked that it wasn't null or `occupied`, so it's
                // safe to dereference here. The pointer is still alive unless
                // this function is called concurrently, which is why it's an
                // unsafe function with that condition.
                set_status_to = unsafe { (*status).next };
                status_guess = status;
            }
        }
        self.last_removed
            .store(<*const _>::cast_mut(node), Ordering::Release);
        // If we were successful, it's because our guess was correct, so
        // `status_guess` holds the previous value of `node.status`.
        #[cfg(feature = "std")]
        if status_guess != occupied() {
            // Safety: `status_guess` is either null, `occupied`, or valid. If
            // it was null, we would have returned false up above, and we just
            // checked that it wasn't `occupied()`. The pointer is still alive
            // unless this function is called concurrently, which is why it's
            // an unsafe function with that condition.
            let WaitingThreadNode { thread, popped, .. } = unsafe { &*status_guess };
            // Clone the thread handle here because it could be invalid as soon
            // as we store into `popped`.
            let thread = thread.clone();
            popped.store(true, Ordering::Release);
            thread.unpark();
        }
        true
    }
}

impl<T> Drop for Set<T> {
    fn drop(&mut self) {
        #[cfg(loom)]
        let mut node = self
            .head
            .with_mut(|p| core::mem::replace(p, ptr::null_mut()));
        #[cfg(not(loom))]
        let mut node = core::mem::replace(self.head.get_mut(), ptr::null_mut());
        while !node.is_null() {
            // Safety: node pointers are either null or valid pointers created
            // by `Box::into_raw`, and we just checked that it was not null, so
            // it's safe to call `Box::from_raw` here.
            let boxed = unsafe { Box::from_raw(node) };
            node = boxed.next.cast_mut();
        }
    }
}