sack/
lib.rs

1#![cfg_attr(not(test), no_std)]
2#![doc = include_str!("../README.md")]
3
4//! A lock-free data structure.
5//!
6//! This crate provides a `Sack<T>` type, which is a concurrent, lock-free
7//! collection that supports adding and draining items. See [`Sack<T>`] for more
8//! details.
9//!
//! When the `waker` feature is enabled, this crate also provides a `WakerSet`
//! type, which is a set of wakers that can be woken all at once. This is useful
//! for implementing synchronization primitives that need to wake up multiple
//! tasks.
13
14extern crate alloc;
15
16use core::{
17    ptr,
18    sync::atomic::{AtomicPtr, Ordering},
19};
20
21use alloc::boxed::Box;
22
23#[cfg(feature = "waker")]
24mod waker;
25#[cfg(feature = "waker")]
26pub use waker::*;
27
/// A single entry (node) of the singly-linked list that backs a [`Sack`].
///
/// Entries are heap-allocated by [`Sack::add`] and turned back into a `Box`
/// by [`Drain`] when the list is consumed.
struct Entry<T> {
    /// The item stored in the entry.
    item: T,
    /// A pointer to the next entry in the sack, or null if this is the last
    /// entry of the list.
    next: *mut Entry<T>,
}
35
36/// A lock-free sack data structure.
37///
38/// A sack is a concurrent data structure that allows adding items and draining
39/// them in a lock-free manner. It is implemented as a singly-linked list where
40/// the head is an atomic pointer. This allows multiple producers to add items
41/// concurrently without locks.
42///
43/// ## How it works
44///
45/// The `Sack` is essentially a LIFO (last-in, first-out) stack. When an item is
46/// added, it is pushed to the front of the list. When the sack is drained, the
47/// entire list is atomically swapped with an empty list, and the old list is
48/// returned as a draining iterator.
49///
50/// This design has the following properties:
51///
52/// * **Lock-free:** Adding and draining items are lock-free operations, which
53///   means they don't require mutual exclusion. This makes them very fast and
54///   scalable.
55/// * **Concurrent producers:** Multiple threads can add items to the sack
56///   concurrently.
57/// * **Single consumer:** Only one thread can drain the sack at a time. This is
58///   enforced by the `&self` receiver on the `drain` method.
59///
60/// ## Example
61///
62/// ```
63/// use sack::Sack;
64/// use std::sync::Arc;
65/// use std::thread;
66///
67/// let sack = Arc::new(Sack::new());
68///
69/// // Spawn a producer thread.
70/// let producer = {
71///     let sack = Arc::clone(&sack);
72///     thread::spawn(move || {
73///         for i in 0..10 {
74///             sack.add(i);
75///         }
76///     })
77/// };
78///
79/// // Wait for the producer to finish.
80/// producer.join().unwrap();
81///
82/// // Drain the sack and collect the items.
83/// let mut items: Vec<_> = sack.drain().collect();
84/// items.sort();
85///
86/// assert_eq!(items, (0..10).collect::<Vec<_>>());
87/// ```
88pub struct Sack<T> {
89    head: AtomicPtr<Entry<T>>,
90}
91
92impl<T> Default for Sack<T> {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97
98impl<T> Sack<T> {
99    /// Creates a new, empty sack.
100    pub const fn new() -> Self {
101        Self {
102            head: AtomicPtr::new(ptr::null_mut()),
103        }
104    }
105
106    /// Adds an item to the sack.
107    ///
108    /// This operation is lock-free and can be called by multiple threads concurrently.
109    pub fn add(&self, item: T) {
110        let entry = Box::leak(Box::new(Entry {
111            item,
112            next: ptr::null_mut(),
113        }));
114
115        entry.next = self.head.load(Ordering::Acquire);
116        loop {
117            match self.head.compare_exchange_weak(
118                entry.next,
119                entry,
120                Ordering::Release,
121                Ordering::Acquire,
122            ) {
123                Ok(_) => break,
124                Err(current) => entry.next = current,
125            }
126        }
127    }
128
129    /// Drains all items from the sack.
130    ///
131    /// This operation is lock-free and returns a draining iterator over the items in the sack.
132    pub fn drain(&self) -> Drain<T> {
133        let head = self.head.swap(ptr::null_mut(), Ordering::AcqRel);
134        Drain::new(head)
135    }
136
137    /// Checks if the sack is empty.
138    ///
139    /// This operation is lock-free.
140    pub fn is_empty(&self) -> bool {
141        self.head.load(Ordering::Acquire).is_null()
142    }
143}
144
145/// A draining iterator for [`Sack<T>`].
146///
147/// This struct is created by [`Sack<T>::drain`]. See its documentation for more.
148pub struct Drain<T>(Option<Box<Entry<T>>>);
149
150impl<T> Drain<T> {
151    /// Creates a new draining iterator from a pointer to the head of the sack.
152    fn new(ptr: *mut Entry<T>) -> Self {
153        let head = if ptr.is_null() {
154            None
155        } else {
156            Some(unsafe { Box::from_raw(ptr) })
157        };
158        Self(head)
159    }
160}
161impl<T> Iterator for Drain<T> {
162    type Item = T;
163
164    fn next(&mut self) -> Option<Self::Item> {
165        let entry = self.0.take()?;
166        *self = Self::new(entry.next);
167        Some(entry.item)
168    }
169}
170impl<T> Drop for Drain<T> {
171    fn drop(&mut self) {
172        while let Some(entry) = self.0.take() {
173            *self = Self::new(entry.next);
174        }
175    }
176}
177
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread, vec, vec::Vec};
    // Only the `WakerSet` test needs these; gate them so the test module still
    // builds (without unused-import warnings) when the `waker` feature is off.
    #[cfg(feature = "waker")]
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        task::{Wake, Waker},
    };

    use super::*;

    /// A waker that counts how many times it has been woken.
    #[cfg(feature = "waker")]
    struct CountingWaker {
        count: AtomicUsize,
    }

    #[cfg(feature = "waker")]
    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.count.fetch_add(1, Ordering::SeqCst);
        }
    }

    // `WakerSet` only exists when the `waker` feature is enabled.
    #[cfg(feature = "waker")]
    #[test]
    fn test_waker_set() {
        let waker = Arc::new(CountingWaker {
            count: AtomicUsize::new(0),
        });

        let wake_set = WakerSet::new();
        wake_set.add(Waker::from(waker.clone()));
        wake_set.add(Waker::from(waker.clone()));

        assert_eq!(wake_set.wake_all(), 2);
        assert_eq!(waker.count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_sack_add_drain() {
        let sack = Sack::new();
        sack.add(1);
        sack.add(2);
        sack.add(3);

        let mut drained: Vec<_> = sack.drain().collect();
        drained.sort();
        assert_eq!(drained, vec![1, 2, 3]);
    }

    #[test]
    fn test_sack_drain_empty() {
        let sack: Sack<u32> = Sack::new();
        assert_eq!(sack.drain().next(), None);
        assert!(sack.is_empty());
    }

    #[test]
    fn test_sack_drain_is_lifo() {
        let sack = Sack::new();
        for i in 0..5 {
            sack.add(i);
        }

        // Items come back most-recently-added first.
        let drained: Vec<_> = sack.drain().collect();
        assert_eq!(drained, vec![4, 3, 2, 1, 0]);
    }

    #[test]
    fn test_drain_drop_releases_remaining_items() {
        let tracker = Arc::new(());
        let sack = Sack::new();
        for _ in 0..3 {
            sack.add(Arc::clone(&tracker));
        }
        assert_eq!(Arc::strong_count(&tracker), 4);

        // Yield one item, then drop the iterator with two items left in it.
        let mut drain = sack.drain();
        drop(drain.next());
        assert_eq!(Arc::strong_count(&tracker), 3);
        drop(drain);
        assert_eq!(Arc::strong_count(&tracker), 1);
    }

    #[test]
    fn test_sack_is_empty() {
        let sack = Sack::new();
        assert!(sack.is_empty());
        sack.add(1);
        assert!(!sack.is_empty());
        let _ = sack.drain();
        assert!(sack.is_empty());
    }

    #[test]
    fn test_sack_concurrent_add() {
        let sack = Arc::new(Sack::new());
        let mut handles = vec![];

        for i in 0..10 {
            let sack = Arc::clone(&sack);
            handles.push(thread::spawn(move || {
                for j in 0..100 {
                    sack.add(i * 100 + j);
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let mut drained: Vec<_> = sack.drain().collect();
        assert_eq!(drained.len(), 1000);
        drained.sort();
        for (i, item) in drained.into_iter().enumerate() {
            assert_eq!(item, i);
        }
    }
}