conquer_util/local.rs
1#[cfg(all(feature = "alloc", not(feature = "std")))]
2use alloc::boxed::Box;
3
4use core::cell::UnsafeCell;
5use core::fmt;
6use core::hint;
7use core::ops::Index;
8use core::sync::atomic::{AtomicUsize, Ordering};
9
10////////////////////////////////////////////////////////////////////////////////////////////////////
11// BoundedThreadLocal
12////////////////////////////////////////////////////////////////////////////////////////////////////
13
14/// A one-shot, lock-free, bounded per-value thread local storage.
15///
16/// # Examples
17///
18/// ```
19/// use std::thread;
20/// use std::sync::Arc;
21///
22/// use conquer_util::BoundedThreadLocal;
23///
24/// const THREADS: usize = 4;
25/// let counter = Arc::new(BoundedThreadLocal::<usize>::new(THREADS));
26///
27/// let handles: Vec<_> = (0..THREADS)
28/// .map(|id| {
29/// let counter = Arc::clone(&counter);
30/// thread::spawn(move || {
31/// let mut token = counter.thread_token().unwrap();
32/// *token.get_mut() += id;
33/// })
34/// })
35/// .collect();
36///
37/// for handle in handles {
38/// handle.join().unwrap();
39/// }
40///
41/// let mut counter = Arc::try_unwrap(counter).unwrap();
42///
43/// let sum: usize = counter.into_iter().sum();
44/// assert_eq!(sum, (0..4).sum());
45/// ```
46pub struct BoundedThreadLocal<'s, T> {
47 storage: Storage<'s, T>,
48 registered: AtomicUsize,
49 completed: AtomicUsize,
50}
51
52/********** impl Send + Sync **********************************************************************/
53
54unsafe impl<T> Send for BoundedThreadLocal<'_, T> {}
55unsafe impl<T> Sync for BoundedThreadLocal<'_, T> {}
56
57/********** impl inherent *************************************************************************/
58
59impl<'s, T: Default> BoundedThreadLocal<'s, T> {
60 /// Creates a new [`Default`] initialized [`BoundedThreadLocal`] that
61 /// internally allocates a buffer of `max_size`.
62 ///
63 /// # Panics
64 ///
65 /// This method panics, if `max_size` is 0.
66 #[inline]
67 pub fn new(max_threads: usize) -> Self {
68 Self::with_init(max_threads, Default::default)
69 }
70}
71
72impl<'s, T> BoundedThreadLocal<'s, T> {
73 /// Creates a new [`BoundedThreadLocal`] that internally allocates a buffer
74 /// of `max_size` and initializes each [`Local`] with `init`.
75 ///
76 /// # Panics
77 ///
78 /// This method panics, if `max_size` is 0.
79 #[inline]
80 pub fn with_init(max_threads: usize, init: impl Fn() -> T) -> Self {
81 assert!(max_threads > 0, "`max_threads` must be greater than 0");
82 Self {
83 storage: Storage::Heap(
84 (0..max_threads).map(|_| Local(UnsafeCell::new(Some(init())))).collect(),
85 ),
86 registered: AtomicUsize::new(0),
87 completed: AtomicUsize::new(0),
88 }
89 }
90}
91
92impl<'s, T> BoundedThreadLocal<'s, T> {
93 /// Creates a new [`BoundedThreadLocal`] that is based on a separate `buffer`.
94 ///
95 /// # Safety
96 ///
97 /// The given `buf` must be treated as if it were mutably, i.e. it **must
98 /// not** be used or otherwise accessed during the lifetime of the
99 /// [`BoundedThreadLocal`] that borrows it.
100 ///
101 /// # Examples
102 ///
103 /// ```
104 /// use conquer_util::{BoundedThreadLocal, Local};
105 ///
106 /// static BUF: [Local<usize>; 4] =
107 /// [Local::new(0), Local::new(0), Local::new(0), Local::new(0)];
108 /// static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };
109 /// assert_eq!(TLS.thread_token().unwrap().get(), &0);
110 /// ```
111 #[inline]
112 pub const unsafe fn with_buffer(buf: &'s [Local<T>]) -> Self {
113 Self {
114 storage: Storage::Buffer(buf),
115 registered: AtomicUsize::new(0),
116 completed: AtomicUsize::new(0),
117 }
118 }
119
120 /// Returns a thread local token to a unique instance of `T`.
121 ///
122 /// The thread local instance will **not** be dropped, when the token itself
123 /// is dropped and can e.g. be iterated afterwards.
124 ///
125 /// # Errors
126 ///
127 /// This method fails, if the maximum number of tokens for this
128 /// [`BoundedThreadLocal`] has already been acquired.
129 ///
130 /// # Examples
131 ///
132 /// ```
133 /// use conquer_util::BoundedThreadLocal;
134 ///
135 /// # fn main() -> Result<(), conquer_util::BoundsError> {
136 ///
137 /// let tls = BoundedThreadLocal::<usize>::new(1);
138 /// let mut token = tls.thread_token()?;
139 /// *token.get_mut() += 1;
140 /// assert_eq!(token.get(), &1);
141 ///
142 /// # Ok(())
143 /// # }
144 /// ```
145 #[inline]
146 pub fn thread_token(&self) -> Result<Token<'_, '_, T>, BoundsError> {
147 let token: usize = self.registered.fetch_add(1, Ordering::Relaxed);
148 assert!(token <= isize::max_value() as usize, "thread counter close to overflow");
149
150 if token < self.storage.len() {
151 let local = unsafe { self.storage[token].as_ptr_unchecked() };
152 Ok(Token { local, tls: self })
153 } else {
154 Err(BoundsError(()))
155 }
156 }
157
158 /// Attempts to create an [`Iter`] over all [`Local`] instances.
159 ///
160 /// # Errors
161 ///
162 /// Fails, if there are still outstanding thread tokens, that might
163 /// concurrently access any of the thread local state instances.
164 #[inline]
165 pub fn try_iter(&self) -> Result<Iter<T>, ConcurrentAccessErr> {
166 let (completed, len) = (self.completed.load(Ordering::Relaxed), self.storage.len());
167 if completed == len || completed == self.registered.load(Ordering::Relaxed) {
168 Ok(Iter { idx: 0, tls: self })
169 } else {
170 Err(ConcurrentAccessErr(()))
171 }
172 }
173}
174
175/********** impl IntoIterator *********************************************************************/
176
177impl<'s, T> IntoIterator for BoundedThreadLocal<'s, T> {
178 type Item = T;
179 type IntoIter = IntoIter<'s, T>;
180
181 #[inline]
182 fn into_iter(self) -> Self::IntoIter {
183 IntoIter { tls: self, idx: 0 }
184 }
185}
186
187/********** impl Debug ****************************************************************************/
188
189impl<T> fmt::Debug for BoundedThreadLocal<'_, T> {
190 #[inline]
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 f.debug_struct("BoundedThreadLocal")
193 .field("max_size", &self.storage.len())
194 .field("access_count", &self.registered.load(Ordering::Relaxed))
195 .finish()
196 }
197}
198
199////////////////////////////////////////////////////////////////////////////////////////////////////
200// Local
201////////////////////////////////////////////////////////////////////////////////////////////////////
202
203/// A wrapper for an instance of `T` that can be managed by a
204/// [`BoundedThreadLocal`].
205#[derive(Debug, Default)]
206#[repr(align(64))]
207pub struct Local<T>(UnsafeCell<Option<T>>);
208
209/********** impl Send + Sync **********************************************************************/
210
211unsafe impl<T> Send for Local<T> {}
212unsafe impl<T> Sync for Local<T> {}
213
214/********** impl inherent *************************************************************************/
215
216impl<T> Local<T> {
217 /// Creates a new [`Local`].
218 #[inline]
219 pub const fn new(local: T) -> Self {
220 Self(UnsafeCell::new(Some(local)))
221 }
222
223 #[inline]
224 unsafe fn as_ptr_unchecked(&self) -> *mut T {
225 (*self.0.get()).as_mut().unwrap_or_else(|| hint::unreachable_unchecked())
226 }
227
228 #[inline]
229 unsafe fn take_unchecked(&self) -> T {
230 (*self.0.get()).take().unwrap_or_else(|| hint::unreachable_unchecked())
231 }
232}
233
234////////////////////////////////////////////////////////////////////////////////////////////////////
235// Token
236////////////////////////////////////////////////////////////////////////////////////////////////////
237
238/// A thread local token granting unique access to an instance of `T` that is
239/// contained in a [`BoundedThreadLocal`]
240pub struct Token<'s, 'tls, T> {
241 local: *mut T,
242 tls: &'tls BoundedThreadLocal<'s, T>,
243}
244
245/********** impl inherent *************************************************************************/
246
247impl<T> Token<'_, '_, T> {
248 /// Returns a reference to the thread local state.
249 #[inline]
250 pub fn get(&self) -> &T {
251 unsafe { &(*self.local) }
252 }
253
254 /// Returns a mutable reference to the thread local state.
255 #[inline]
256 pub fn get_mut(&mut self) -> &mut T {
257 unsafe { &mut (*self.local) }
258 }
259}
260
261/********** impl Debug ****************************************************************************/
262
263impl<T: fmt::Debug> fmt::Debug for Token<'_, '_, T> {
264 #[inline]
265 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
266 f.debug_struct("Token").field("slot", &self.get()).finish()
267 }
268}
269
270/********** impl Display **************************************************************************/
271
272impl<T: fmt::Display> fmt::Display for Token<'_, '_, T> {
273 #[inline]
274 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
275 fmt::Display::fmt(&self.get(), f)
276 }
277}
278
279/********** impl Drop *****************************************************************************/
280
281impl<T> Drop for Token<'_, '_, T> {
282 fn drop(&mut self) {
283 self.tls.completed.fetch_add(1, Ordering::Relaxed);
284 }
285}
286
287////////////////////////////////////////////////////////////////////////////////////////////////////
288// Iter
289////////////////////////////////////////////////////////////////////////////////////////////////////
290
291#[derive(Debug)]
292pub struct Iter<'s, 'tls, T> {
293 idx: usize,
294 tls: &'tls BoundedThreadLocal<'s, T>,
295}
296
297/********** impl Iterator *************************************************************************/
298
299impl<'s, 'tls, T> Iterator for Iter<'s, 'tls, T> {
300 type Item = &'tls T;
301
302 #[inline]
303 fn next(&mut self) -> Option<Self::Item> {
304 let idx = self.idx;
305 if idx < self.tls.storage.len() {
306 self.idx += 1;
307 let local = unsafe { &*self.tls.storage[idx].as_ptr_unchecked() };
308 Some(local)
309 } else {
310 None
311 }
312 }
313}
314
315////////////////////////////////////////////////////////////////////////////////////////////////////
316// IntoIter
317////////////////////////////////////////////////////////////////////////////////////////////////////
318
319/// An owning iterator that can be created from an owned [`BoundedThreadLocal`].
320#[derive(Debug)]
321pub struct IntoIter<'s, T> {
322 idx: usize,
323 tls: BoundedThreadLocal<'s, T>,
324}
325
326/********** impl Iterator *************************************************************************/
327
328impl<T> Iterator for IntoIter<'_, T> {
329 type Item = T;
330
331 #[inline]
332 fn next(&mut self) -> Option<Self::Item> {
333 let idx = self.idx;
334 if idx < self.tls.storage.len() {
335 self.idx += 1;
336 let local = unsafe { self.tls.storage[idx].take_unchecked() };
337 Some(local)
338 } else {
339 None
340 }
341 }
342}
343
344////////////////////////////////////////////////////////////////////////////////////////////////////
345// BoundsError
346////////////////////////////////////////////////////////////////////////////////////////////////////
347
/// An error signalling that more than the specified maximum number of threads
/// attempted to access a [`BoundedThreadLocal`].
///
/// Returned by [`BoundedThreadLocal::thread_token`] once every slot has been
/// handed out.
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct BoundsError(());

/********** impl Display **************************************************************************/

impl fmt::Display for BoundsError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "exceeded bounds for `BoundedThreadLocal`")
    }
}

/********** impl Error ****************************************************************************/

#[cfg(feature = "std")]
impl std::error::Error for BoundsError {}
366
367////////////////////////////////////////////////////////////////////////////////////////////////////
368// ConcurrentAccessErr
369////////////////////////////////////////////////////////////////////////////////////////////////////
370
/// An error signalling that the [`Local`] instances of a [`BoundedThreadLocal`]
/// can not be iterated yet, because at least one acquired thread token has not
/// been dropped.
///
/// Returned by [`BoundedThreadLocal::try_iter`].
#[derive(Copy, Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
pub struct ConcurrentAccessErr(());

/********** impl Display **************************************************************************/

impl fmt::Display for ConcurrentAccessErr {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "concurrent access from live thread token (not all tokens have yet been dropped)")
    }
}

/********** impl Error ****************************************************************************/

#[cfg(feature = "std")]
impl std::error::Error for ConcurrentAccessErr {}
387
388////////////////////////////////////////////////////////////////////////////////////////////////////
389// Storage
390////////////////////////////////////////////////////////////////////////////////////////////////////
391
392#[derive(Debug)]
393enum Storage<'s, T> {
394 Buffer(&'s [Local<T>]),
395 Heap(Box<[Local<T>]>),
396}
397
398/********** impl inherent *************************************************************************/
399
400impl<T> Storage<'_, T> {
401 #[inline]
402 fn len(&self) -> usize {
403 match self {
404 Storage::Buffer(slice) => slice.len(),
405 Storage::Heap(boxed) => boxed.len(),
406 }
407 }
408}
409
410/********** impl Index ****************************************************************************/
411
412impl<T> Index<usize> for Storage<'_, T> {
413 type Output = Local<T>;
414
415 #[inline]
416 fn index(&self, index: usize) -> &Self::Output {
417 match self {
418 &Storage::Buffer(slice) => &slice[index],
419 &Storage::Heap(ref boxed) => &boxed[index],
420 }
421 }
422}
423
#[cfg(test)]
mod tests {
    extern crate std;

    #[cfg(any(feature = "alloc", feature = "std"))]
    use std::sync::Arc;
    use std::thread;
    use std::vec::Vec;

    use super::BoundedThreadLocal;
    use crate::Local;

    #[test]
    fn static_buffer() {
        static BUF: [Local<usize>; 4] =
            [Local::new(0), Local::new(0), Local::new(0), Local::new(0)];
        static TLS: BoundedThreadLocal<usize> = unsafe { BoundedThreadLocal::with_buffer(&BUF) };

        let handles: Vec<_> = (0..BUF.len())
            .map(|_| {
                thread::spawn(move || {
                    let mut token = TLS.thread_token().unwrap();
                    for _ in 0..10 {
                        *token.get_mut() += 1;
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert!(TLS.try_iter().unwrap().all(|&count| count == 10));
    }

    /// Exercises the bounds check and the outstanding-token check without
    /// requiring any allocation.
    #[test]
    fn buffer_bounds_and_outstanding_tokens() {
        let buf = [Local::new(1usize), Local::new(2)];
        let tls = unsafe { BoundedThreadLocal::with_buffer(&buf) };

        let first = tls.thread_token().unwrap();
        assert_eq!(first.get(), &1);
        // A live token must prevent iteration.
        assert!(tls.try_iter().is_err());
        drop(first);

        // Every acquired token has been dropped, so iteration is permitted.
        assert_eq!(tls.try_iter().unwrap().sum::<usize>(), 3);

        let second = tls.thread_token().unwrap();
        assert_eq!(second.get(), &2);
        // Both slots have been handed out.
        assert!(tls.thread_token().is_err());
    }

    // Requires the heap-backed constructor and `Arc`, which is only imported
    // under these features.
    #[cfg(any(feature = "alloc", feature = "std"))]
    #[test]
    fn into_iter() {
        const THREADS: usize = 4;
        let tls: Arc<BoundedThreadLocal<usize>> = Arc::new(BoundedThreadLocal::new(THREADS));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let tls = Arc::clone(&tls);
                thread::spawn(move || {
                    let mut token = tls.thread_token().unwrap();
                    for _ in 0..10 {
                        *token.get_mut() += 1;
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let counter = Arc::try_unwrap(tls).unwrap();
        assert_eq!(counter.into_iter().sum::<usize>(), THREADS * 10);
    }
}