conquer_util/
local.rs

1#[cfg(all(feature = "alloc", not(feature = "std")))]
2use alloc::boxed::Box;
3
4use core::cell::UnsafeCell;
5use core::fmt;
6use core::hint;
7use core::ops::Index;
8use core::sync::atomic::{AtomicUsize, Ordering};
9
10////////////////////////////////////////////////////////////////////////////////////////////////////
11// BoundedThreadLocal
12////////////////////////////////////////////////////////////////////////////////////////////////////
13
14/// A one-shot, lock-free, bounded per-value thread local storage.
15///
16/// # Examples
17///
18/// ```
19/// use std::thread;
20/// use std::sync::Arc;
21///
22/// use conquer_util::BoundedThreadLocal;
23///
24/// const THREADS: usize = 4;
25/// let counter = Arc::new(BoundedThreadLocal::<usize>::new(THREADS));
26///
27/// let handles: Vec<_> = (0..THREADS)
28///     .map(|id| {
29///         let counter = Arc::clone(&counter);
30///         thread::spawn(move || {
31///             let mut token = counter.thread_token().unwrap();
32///             *token.get_mut() += id;
33///         })
34///     })
35///     .collect();
36///
37/// for handle in handles {
38///     handle.join().unwrap();
39/// }
40///
41/// let mut counter = Arc::try_unwrap(counter).unwrap();
42///
43/// let sum: usize = counter.into_iter().sum();
44/// assert_eq!(sum, (0..4).sum());
45/// ```
46pub struct BoundedThreadLocal<'s, T> {
47    storage: Storage<'s, T>,
48    registered: AtomicUsize,
49    completed: AtomicUsize,
50}
51
52/********** impl Send + Sync **********************************************************************/
53
54unsafe impl<T> Send for BoundedThreadLocal<'_, T> {}
55unsafe impl<T> Sync for BoundedThreadLocal<'_, T> {}
56
57/********** impl inherent *************************************************************************/
58
59impl<'s, T: Default> BoundedThreadLocal<'s, T> {
60    /// Creates a new [`Default`] initialized [`BoundedThreadLocal`] that
61    /// internally allocates a buffer of `max_size`.
62    ///
63    /// # Panics
64    ///
65    /// This method panics, if `max_size` is 0.
66    #[inline]
67    pub fn new(max_threads: usize) -> Self {
68        Self::with_init(max_threads, Default::default)
69    }
70}
71
72impl<'s, T> BoundedThreadLocal<'s, T> {
73    /// Creates a new [`BoundedThreadLocal`] that internally allocates a buffer
74    /// of `max_size` and initializes each [`Local`] with `init`.
75    ///
76    /// # Panics
77    ///
78    /// This method panics, if `max_size` is 0.
79    #[inline]
80    pub fn with_init(max_threads: usize, init: impl Fn() -> T) -> Self {
81        assert!(max_threads > 0, "`max_threads` must be greater than 0");
82        Self {
83            storage: Storage::Heap(
84                (0..max_threads).map(|_| Local(UnsafeCell::new(Some(init())))).collect(),
85            ),
86            registered: AtomicUsize::new(0),
87            completed: AtomicUsize::new(0),
88        }
89    }
90}
91
92impl<'s, T> BoundedThreadLocal<'s, T> {
93    /// Creates a new [`BoundedThreadLocal`] that is based on a separate `buffer`.
94    ///
95    /// # Safety
96    ///
97    /// The given `buf` must be treated as if it were mutably, i.e. it **must
98    /// not** be used or otherwise accessed during the lifetime of the
99    /// [`BoundedThreadLocal`] that borrows it.
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use conquer_util::{BoundedThreadLocal, Local};
105    ///
106    /// static BUF: [Local<usize>; 4] =
107    ///     [Local::new(0), Local::new(0), Local::new(0), Local::new(0)];
108    /// static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };
109    /// assert_eq!(TLS.thread_token().unwrap().get(), &0);
110    /// ```
111    #[inline]
112    pub const unsafe fn with_buffer(buf: &'s [Local<T>]) -> Self {
113        Self {
114            storage: Storage::Buffer(buf),
115            registered: AtomicUsize::new(0),
116            completed: AtomicUsize::new(0),
117        }
118    }
119
120    /// Returns a thread local token to a unique instance of `T`.
121    ///
122    /// The thread local instance will **not** be dropped, when the token itself
123    /// is dropped and can e.g. be iterated afterwards.
124    ///
125    /// # Errors
126    ///
127    /// This method fails, if the maximum number of tokens for this
128    /// [`BoundedThreadLocal`] has already been acquired.
129    ///
130    /// # Examples
131    ///
132    /// ```
133    /// use conquer_util::BoundedThreadLocal;
134    ///
135    /// # fn main() -> Result<(), conquer_util::BoundsError> {
136    ///
137    /// let tls = BoundedThreadLocal::<usize>::new(1);
138    /// let mut token = tls.thread_token()?;
139    /// *token.get_mut() += 1;
140    /// assert_eq!(token.get(), &1);
141    ///
142    /// # Ok(())
143    /// # }
144    /// ```
145    #[inline]
146    pub fn thread_token(&self) -> Result<Token<'_, '_, T>, BoundsError> {
147        let token: usize = self.registered.fetch_add(1, Ordering::Relaxed);
148        assert!(token <= isize::max_value() as usize, "thread counter close to overflow");
149
150        if token < self.storage.len() {
151            let local = unsafe { self.storage[token].as_ptr_unchecked() };
152            Ok(Token { local, tls: self })
153        } else {
154            Err(BoundsError(()))
155        }
156    }
157
158    /// Attempts to create an [`Iter`] over all [`Local`] instances.
159    ///
160    /// # Errors
161    ///
162    /// Fails, if there are still outstanding thread tokens, that might
163    /// concurrently access any of the thread local state instances.
164    #[inline]
165    pub fn try_iter(&self) -> Result<Iter<T>, ConcurrentAccessErr> {
166        let (completed, len) = (self.completed.load(Ordering::Relaxed), self.storage.len());
167        if completed == len || completed == self.registered.load(Ordering::Relaxed) {
168            Ok(Iter { idx: 0, tls: self })
169        } else {
170            Err(ConcurrentAccessErr(()))
171        }
172    }
173}
174
175/********** impl IntoIterator *********************************************************************/
176
177impl<'s, T> IntoIterator for BoundedThreadLocal<'s, T> {
178    type Item = T;
179    type IntoIter = IntoIter<'s, T>;
180
181    #[inline]
182    fn into_iter(self) -> Self::IntoIter {
183        IntoIter { tls: self, idx: 0 }
184    }
185}
186
187/********** impl Debug ****************************************************************************/
188
189impl<T> fmt::Debug for BoundedThreadLocal<'_, T> {
190    #[inline]
191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192        f.debug_struct("BoundedThreadLocal")
193            .field("max_size", &self.storage.len())
194            .field("access_count", &self.registered.load(Ordering::Relaxed))
195            .finish()
196    }
197}
198
199////////////////////////////////////////////////////////////////////////////////////////////////////
200// Local
201////////////////////////////////////////////////////////////////////////////////////////////////////
202
203/// A wrapper for an instance of `T` that can be managed by a
204/// [`BoundedThreadLocal`].
205#[derive(Debug, Default)]
206#[repr(align(64))]
207pub struct Local<T>(UnsafeCell<Option<T>>);
208
209/********** impl Send + Sync **********************************************************************/
210
211unsafe impl<T> Send for Local<T> {}
212unsafe impl<T> Sync for Local<T> {}
213
214/********** impl inherent *************************************************************************/
215
216impl<T> Local<T> {
217    /// Creates a new [`Local`].
218    #[inline]
219    pub const fn new(local: T) -> Self {
220        Self(UnsafeCell::new(Some(local)))
221    }
222
223    #[inline]
224    unsafe fn as_ptr_unchecked(&self) -> *mut T {
225        (*self.0.get()).as_mut().unwrap_or_else(|| hint::unreachable_unchecked())
226    }
227
228    #[inline]
229    unsafe fn take_unchecked(&self) -> T {
230        (*self.0.get()).take().unwrap_or_else(|| hint::unreachable_unchecked())
231    }
232}
233
234////////////////////////////////////////////////////////////////////////////////////////////////////
235// Token
236////////////////////////////////////////////////////////////////////////////////////////////////////
237
238/// A thread local token granting unique access to an instance of `T` that is
239/// contained in a [`BoundedThreadLocal`]
240pub struct Token<'s, 'tls, T> {
241    local: *mut T,
242    tls: &'tls BoundedThreadLocal<'s, T>,
243}
244
245/********** impl inherent *************************************************************************/
246
247impl<T> Token<'_, '_, T> {
248    /// Returns a reference to the thread local state.
249    #[inline]
250    pub fn get(&self) -> &T {
251        unsafe { &(*self.local) }
252    }
253
254    /// Returns a mutable reference to the thread local state.
255    #[inline]
256    pub fn get_mut(&mut self) -> &mut T {
257        unsafe { &mut (*self.local) }
258    }
259}
260
261/********** impl Debug ****************************************************************************/
262
263impl<T: fmt::Debug> fmt::Debug for Token<'_, '_, T> {
264    #[inline]
265    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
266        f.debug_struct("Token").field("slot", &self.get()).finish()
267    }
268}
269
270/********** impl Display **************************************************************************/
271
272impl<T: fmt::Display> fmt::Display for Token<'_, '_, T> {
273    #[inline]
274    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
275        fmt::Display::fmt(&self.get(), f)
276    }
277}
278
279/********** impl Drop *****************************************************************************/
280
281impl<T> Drop for Token<'_, '_, T> {
282    fn drop(&mut self) {
283        self.tls.completed.fetch_add(1, Ordering::Relaxed);
284    }
285}
286
287////////////////////////////////////////////////////////////////////////////////////////////////////
288// Iter
289////////////////////////////////////////////////////////////////////////////////////////////////////
290
291#[derive(Debug)]
292pub struct Iter<'s, 'tls, T> {
293    idx: usize,
294    tls: &'tls BoundedThreadLocal<'s, T>,
295}
296
297/********** impl Iterator *************************************************************************/
298
299impl<'s, 'tls, T> Iterator for Iter<'s, 'tls, T> {
300    type Item = &'tls T;
301
302    #[inline]
303    fn next(&mut self) -> Option<Self::Item> {
304        let idx = self.idx;
305        if idx < self.tls.storage.len() {
306            self.idx += 1;
307            let local = unsafe { &*self.tls.storage[idx].as_ptr_unchecked() };
308            Some(local)
309        } else {
310            None
311        }
312    }
313}
314
315////////////////////////////////////////////////////////////////////////////////////////////////////
316// IntoIter
317////////////////////////////////////////////////////////////////////////////////////////////////////
318
319/// An owning iterator that can be created from an owned [`BoundedThreadLocal`].
320#[derive(Debug)]
321pub struct IntoIter<'s, T> {
322    idx: usize,
323    tls: BoundedThreadLocal<'s, T>,
324}
325
326/********** impl Iterator *************************************************************************/
327
328impl<T> Iterator for IntoIter<'_, T> {
329    type Item = T;
330
331    #[inline]
332    fn next(&mut self) -> Option<Self::Item> {
333        let idx = self.idx;
334        if idx < self.tls.storage.len() {
335            self.idx += 1;
336            let local = unsafe { self.tls.storage[idx].take_unchecked() };
337            Some(local)
338        } else {
339            None
340        }
341    }
342}
343
////////////////////////////////////////////////////////////////////////////////////////////////////
// BoundsError
////////////////////////////////////////////////////////////////////////////////////////////////////

/// An error for signalling that more than the specified maximum number of
/// threads attempted to access a [`BoundedThreadLocal`].
///
/// Returned by `BoundedThreadLocal::thread_token` once all slots have been
/// handed out.
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct BoundsError(());

/********** impl Display **************************************************************************/

impl fmt::Display for BoundsError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "exceeded bounds for `BoundedThreadLocal`")
    }
}

/********** impl Error ****************************************************************************/

#[cfg(feature = "std")]
impl std::error::Error for BoundsError {}
366
////////////////////////////////////////////////////////////////////////////////////////////////////
// ConcurrentAccessErr
////////////////////////////////////////////////////////////////////////////////////////////////////

/// An error for signalling that an iterator over a [`BoundedThreadLocal`]
/// was requested while thread tokens that might still access the thread
/// local state were alive.
///
/// Returned by `BoundedThreadLocal::try_iter`.
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct ConcurrentAccessErr(());

/********** impl Display **************************************************************************/

impl fmt::Display for ConcurrentAccessErr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(
            "concurrent access from live thread token (not all tokens have yet been dropped)",
        )
    }
}

/********** impl Error ****************************************************************************/

#[cfg(feature = "std")]
impl std::error::Error for ConcurrentAccessErr {}
387
388////////////////////////////////////////////////////////////////////////////////////////////////////
389// Storage
390////////////////////////////////////////////////////////////////////////////////////////////////////
391
392#[derive(Debug)]
393enum Storage<'s, T> {
394    Buffer(&'s [Local<T>]),
395    Heap(Box<[Local<T>]>),
396}
397
398/********** impl inherent *************************************************************************/
399
400impl<T> Storage<'_, T> {
401    #[inline]
402    fn len(&self) -> usize {
403        match self {
404            Storage::Buffer(slice) => slice.len(),
405            Storage::Heap(boxed) => boxed.len(),
406        }
407    }
408}
409
410/********** impl Index ****************************************************************************/
411
412impl<T> Index<usize> for Storage<'_, T> {
413    type Output = Local<T>;
414
415    #[inline]
416    fn index(&self, index: usize) -> &Self::Output {
417        match self {
418            &Storage::Buffer(slice) => &slice[index],
419            &Storage::Heap(ref boxed) => &boxed[index],
420        }
421    }
422}
423
#[cfg(test)]
mod tests {
    extern crate std;

    #[cfg(any(feature = "alloc", feature = "std"))]
    use std::sync::Arc;
    use std::thread;
    use std::vec::Vec;

    use super::BoundedThreadLocal;
    use crate::Local;

    #[test]
    fn static_buffer() {
        static BUF: [Local<usize>; 4] =
            [Local::new(0), Local::new(0), Local::new(0), Local::new(0)];
        static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };

        let handles: Vec<_> = (0..BUF.len())
            .map(|_| {
                thread::spawn(move || {
                    let mut token = TLS.thread_token().unwrap();
                    for _ in 0..10 {
                        *token.get_mut() += 1;
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(TLS.try_iter().unwrap().all(|&count| count == 10));
    }

    #[test]
    fn thread_token_bounds() {
        static BUF: [Local<usize>; 1] = [Local::new(0)];
        static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };

        let _token = TLS.thread_token().unwrap();
        // The only slot is taken, so every further attempt must fail.
        assert!(TLS.thread_token().is_err());
        assert!(TLS.thread_token().is_err());
    }

    #[test]
    fn try_iter_with_live_token() {
        static BUF: [Local<usize>; 2] = [Local::new(0), Local::new(0)];
        static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };

        let mut token = TLS.thread_token().unwrap();
        *token.get_mut() = 1;
        assert!(TLS.try_iter().is_err());

        drop(token);
        let values: Vec<usize> = TLS.try_iter().unwrap().cloned().collect();
        assert_eq!(values, [1, 0]);
    }

    // Requires heap allocation (`BoundedThreadLocal::new`) and `Arc`, which is
    // only imported when one of these features is enabled.
    #[cfg(any(feature = "alloc", feature = "std"))]
    #[test]
    fn into_iter() {
        const THREADS: usize = 4;
        let tls: Arc<BoundedThreadLocal<usize>> = Arc::new(BoundedThreadLocal::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let tls = Arc::clone(&tls);
                thread::spawn(move || {
                    let mut token = tls.thread_token().unwrap();
                    for _ in 0..10 {
                        *token.get_mut() += 1;
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let counter = Arc::try_unwrap(tls).unwrap();
        assert_eq!(counter.into_iter().sum::<usize>(), THREADS * 10);
    }
}