parallel_processor/utils/
scoped_thread_local.rs1use dashmap::DashMap;
2use once_cell::sync::Lazy;
3use std::any::Any;
4use std::cell::UnsafeCell;
5use std::marker::PhantomData;
6use std::ops::{Deref, DerefMut};
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::thread::ThreadId;
9
/// One thread's storage slot for a type-erased value inside the global registry.
///
/// The slot holds `Some(value)` while the value is parked in the registry and `None`
/// while it is checked out through a `ThreadLocalVariable`.
struct ThreadVarRef {
    val: UnsafeCell<Option<Box<dyn Any>>>,
}

// SAFETY: slots live inside the global `THREADS_MAP` static, so the type must be
// `Send + Sync`. A slot's cell is only read or written through lookups keyed by
// `std::thread::current().id()`, and `ThreadLocalVariable` is `!Send`. As a result, each
// slot's contents are touched only by the thread whose id is its key, while the
// registry entry exists.
// NOTE(review): when a `ScopedThreadLocal` is dropped, its whole per-thread map (and every
// boxed value still parked in it) is dropped on the dropping thread. `T` carries no `Send`
// bound, so this is only sound for `Send` values — confirm callers respect that.
unsafe impl Send for ThreadVarRef {}
// SAFETY: see above — one slot's contents are never accessed from two threads at once,
// because every access is keyed by the accessing thread's own id.
unsafe impl Sync for ThreadVarRef {}
15
16static THREADS_MAP: Lazy<DashMap<u64, DashMap<ThreadId, ThreadVarRef>>> =
17 Lazy::new(|| DashMap::new());
18
19static THREAD_LOCAL_VAR_INDEX: AtomicU64 = AtomicU64::new(0);
20
/// A checked-out handle to the calling thread's value of a `ScopedThreadLocal`.
///
/// While the handle is alive it owns the value. The value is handed back to the registry
/// when the handle is dropped. The raw-pointer `PhantomData` makes the handle
/// `!Send + !Sync`, so it is always dropped on the thread that created it.
pub struct ThreadLocalVariable<T: 'static> {
    /// The checked-out value; `None` only between `take` and `put_back`.
    var: Option<T>,
    /// Registry index of the owning `ScopedThreadLocal`.
    index: u64,
    /// Opts the handle out of `Send` and `Sync`.
    _not_send_sync: PhantomData<*mut ()>,
}

impl<T> Deref for ThreadLocalVariable<T> {
    type Target = T;

    /// Borrows the checked-out value.
    ///
    /// # Panics
    /// Panics if the value was taken with `take` and not yet put back.
    #[inline]
    fn deref(&self) -> &Self::Target {
        self.var
            .as_ref()
            .expect("ThreadLocalVariable accessed after take() without put_back()")
    }
}

impl<T> DerefMut for ThreadLocalVariable<T> {
    /// Mutably borrows the checked-out value.
    ///
    /// # Panics
    /// Panics if the value was taken with `take` and not yet put back.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.var
            .as_mut()
            .expect("ThreadLocalVariable accessed after take() without put_back()")
    }
}

impl<T> ThreadLocalVariable<T> {
    /// Moves the value out of the handle, leaving the handle empty.
    ///
    /// The value should be returned with `put_back` before the handle is dereferenced
    /// again or dropped.
    ///
    /// # Panics
    /// Panics if the handle is already empty.
    pub fn take(&mut self) -> T {
        self.var
            .take()
            .expect("ThreadLocalVariable::take called on an empty handle")
    }

    /// Stores `value` back into a handle previously emptied by `take`.
    ///
    /// # Panics
    /// Panics if the handle still holds a value.
    pub fn put_back(&mut self, value: T) {
        assert!(
            self.var.is_none(),
            "ThreadLocalVariable::put_back called while a value is still held"
        );
        self.var = Some(value);
    }
}
53
54impl<T: 'static> Drop for ThreadLocalVariable<T> {
55 #[inline(always)]
56 fn drop(&mut self) {
57 let obj_entry = THREADS_MAP.get(&self.index).unwrap();
58 unsafe {
59 *obj_entry
60 .get(&std::thread::current().id())
61 .unwrap()
62 .val
63 .get() = Some(Box::new(
64 self.var
65 .take()
66 .expect("Thread local variable not managed correctly"),
67 ));
68 }
69 }
70}
71
/// A value created lazily once per thread, whose per-thread instances live until this
/// object is dropped.
///
/// Each instance owns a unique index in the global registry. Every thread that calls
/// `get` receives its own value, built by `alloc` on that thread's first access.
pub struct ScopedThreadLocal<T: 'static> {
    /// Factory producing a fresh value for a thread on its first access.
    alloc: Box<dyn Fn() -> T + Send + Sync>,
    /// Key of this variable's per-thread map in `THREADS_MAP`.
    index: u64,
}
76
77impl<T: 'static> Drop for ScopedThreadLocal<T> {
78 fn drop(&mut self) {
79 THREADS_MAP.remove(&self.index);
80 }
81}
82
83impl<T: 'static> ScopedThreadLocal<T> {
84 pub fn new<F: Fn() -> T + Send + Sync + 'static>(alloc: F) -> Self {
85 let index = THREAD_LOCAL_VAR_INDEX.fetch_add(1, Ordering::Relaxed);
86 THREADS_MAP.insert(index, DashMap::new());
87
88 Self {
89 index,
90 alloc: Box::new(alloc),
91 }
92 }
93
94 pub fn get(&self) -> ThreadLocalVariable<T> {
95 let obj_entry = THREADS_MAP.get(&self.index).unwrap();
96 let entry = obj_entry
97 .entry(std::thread::current().id())
98 .or_insert_with(|| ThreadVarRef {
99 val: UnsafeCell::new(Some(Box::new((self.alloc)()))),
100 });
101
102 if let Some(value) = unsafe { (*entry.val.get()).take() } {
103 ThreadLocalVariable {
104 var: Some(*value.downcast().unwrap()),
105 index: self.index,
106 _not_send_sync: PhantomData,
107 }
108 } else {
109 panic!("Thread local variable taken multiple times, aborting!");
110 }
111 }
112}