// parallel_processor/utils/scoped_thread_local.rs

1use dashmap::DashMap;
2use once_cell::sync::Lazy;
3use std::any::Any;
4use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::thread::ThreadId;
9
/// Type-erased storage for one thread's value of a `ScopedThreadLocal`.
struct ThreadVarRef {
    /// `Some(value)` while the value rests in the map; `None` while a
    /// `ThreadLocalVariable` handle has it checked out (`get` takes it out,
    /// the handle's `Drop` puts it back).
    val: UnsafeCell<Option<Box<dyn Any>>>,
}
// SAFETY (intended invariant): each slot is keyed by a `ThreadId` and its
// `UnsafeCell` is read/written only from that thread — by
// `ScopedThreadLocal::get` and by `ThreadLocalVariable`'s `Drop`, and the
// handle is `!Send` so it cannot migrate to another thread.
// NOTE(review): `ScopedThreadLocal::drop` drops every thread's slot on the
// dropping thread, and `T` is not required to be `Send`; for a non-`Send` `T`
// the `Send` impl below looks unsound — confirm the `T`s used are `Send`.
unsafe impl Sync for ThreadVarRef {}
unsafe impl Send for ThreadVarRef {}
15
16static THREADS_MAP: Lazy<DashMap<u64, DashMap<ThreadId, ThreadVarRef>>> =
17    Lazy::new(|| DashMap::new());
18
19static THREAD_LOCAL_VAR_INDEX: AtomicU64 = AtomicU64::new(0);
20
/// Handle to the calling thread's value of a `ScopedThreadLocal`.
///
/// The handle dereferences to the value and returns it to the thread's slot
/// when dropped. The `PhantomData<*mut ()>` marker makes the handle neither
/// `Send` nor `Sync`, so it is dropped on the thread that obtained it.
pub struct ThreadLocalVariable<T>
where
    T: 'static,
{
    /// Key of the owning `ScopedThreadLocal` in `THREADS_MAP`.
    index: u64,
    /// The checked-out value; `None` only between `take()` and `put_back()`.
    var: Option<T>,
    /// Opt-out marker for `Send` and `Sync`.
    _not_send_sync: PhantomData<*mut ()>,
}
26
27impl<T> Deref for ThreadLocalVariable<T> {
28    type Target = T;
29
30    #[inline(always)]
31    fn deref(&self) -> &Self::Target {
32        self.var.as_ref().unwrap()
33    }
34}
35
36impl<T> DerefMut for ThreadLocalVariable<T> {
37    #[inline(always)]
38    fn deref_mut(&mut self) -> &mut Self::Target {
39        self.var.as_mut().unwrap()
40    }
41}
42
43impl<T> ThreadLocalVariable<T> {
44    pub fn take(&mut self) -> T {
45        self.var.take().unwrap()
46    }
47
48    pub fn put_back(&mut self, value: T) {
49        assert!(self.var.is_none());
50        self.var = Some(value);
51    }
52}
53
54impl<T: 'static> Drop for ThreadLocalVariable<T> {
55    #[inline(always)]
56    fn drop(&mut self) {
57        let obj_entry = THREADS_MAP.get(&self.index).unwrap();
58        unsafe {
59            *obj_entry
60                .get(&std::thread::current().id())
61                .unwrap()
62                .val
63                .get() = Some(Box::new(
64                self.var
65                    .take()
66                    .expect("Thread local variable not managed correctly"),
67            ));
68        }
69    }
70}
71
/// A thread-local variable whose per-thread values belong to this object.
///
/// Each thread that calls `get()` receives its own `T`, created on that
/// thread's first access by the stored allocator. All stored values are
/// discarded when the `ScopedThreadLocal` is dropped.
pub struct ScopedThreadLocal<T>
where
    T: 'static,
{
    /// Factory invoked for a thread's first `get()`.
    alloc: Box<dyn Fn() -> T + Send + Sync>,
    /// Key of this variable's per-thread slots in `THREADS_MAP`.
    index: u64,
}
76
77impl<T: 'static> Drop for ScopedThreadLocal<T> {
78    fn drop(&mut self) {
79        THREADS_MAP.remove(&self.index);
80    }
81}
82
83impl<T: 'static> ScopedThreadLocal<T> {
84    pub fn new<F: Fn() -> T + Send + Sync + 'static>(alloc: F) -> Self {
85        let index = THREAD_LOCAL_VAR_INDEX.fetch_add(1, Ordering::Relaxed);
86        THREADS_MAP.insert(index, DashMap::new());
87
88        Self {
89            index,
90            alloc: Box::new(alloc),
91        }
92    }
93
94    pub fn get(&self) -> ThreadLocalVariable<T> {
95        let obj_entry = THREADS_MAP.get(&self.index).unwrap();
96        let entry = obj_entry
97            .entry(std::thread::current().id())
98            .or_insert_with(|| ThreadVarRef {
99                val: UnsafeCell::new(Some(Box::new((self.alloc)()))),
100            });
101
102        if let Some(value) = unsafe { (*entry.val.get()).take() } {
103            ThreadLocalVariable {
104                var: Some(*value.downcast().unwrap()),
105                index: self.index,
106                _not_send_sync: PhantomData,
107            }
108        } else {
109            panic!("Thread local variable taken multiple times, aborting!");
110        }
111    }
112}