use std::cell::Cell;
use std::iter::FilterMap;
use std::ops::{Deref, DerefMut};
use std::sync::Mutex;
/// Lazily created per-thread values, indexed by rayon worker.
///
/// The context owns a table of slots. Each slot pairs one thread's value
/// (`None` until that thread first asks for it) with a flag recording whether
/// a [`ThreadLocalMut`] guard for the slot is currently alive.
pub struct ThreadLocalCtx<T, F> {
    /// Generator called to build a thread's value the first time it is requested.
    inner: F,
    /// Owns the slot table; locked while the table is being created.
    init_mutex: Mutex<Vec<(Option<T>, bool)>>,
    /// Pointer to the first slot of the table held by `init_mutex`, or null
    /// while the table has not been created yet.
    cloned: Cell<*mut (Option<T>, bool)>,
}
// SAFETY: a shared `&ThreadLocalCtx` lets many threads call the generator
// through `&F` at once (hence `F: Sync`). It also hands out `&mut T` to values
// that may have been created on a different thread. Slot 0 is shared by every
// non-rayon thread, and `into_iter` later moves all values to the owner's
// thread, so `T` must be `Send`. Keeping each slot to one thread at a time is
// the caller's obligation under the `unsafe fn get` contract.
unsafe impl<T: Send, F: Send + Sync> Sync for ThreadLocalCtx<T, F> {}
// SAFETY: moving the context moves the generator and every stored `T` to the
// receiving thread, so both must be `Send`.
unsafe impl<T: Send, F: Send + Sync> Send for ThreadLocalCtx<T, F> {}
impl<T, F: Fn() -> T> ThreadLocalCtx<T, F> {
    /// Creates a context whose per-thread values are produced by `inner`.
    ///
    /// The generator is not called here. The slot table is created on the first
    /// call to [`ThreadLocalCtx::get`], and each thread's value is created on
    /// that thread's first `get`.
    pub fn new(inner: F) -> Self {
        Self {
            inner,
            init_mutex: Mutex::new(vec![]),
            cloned: Cell::new(std::ptr::null_mut()),
        }
    }

    /// Returns a guard for the calling thread's value, creating it on first use.
    ///
    /// The slot is chosen with `rayon::current_thread_index()`: worker `i` uses
    /// slot `i + 1`, and any thread that is not a rayon worker uses slot `0`.
    /// On the very first call the table is sized from
    /// `rayon::current_num_threads()`. Each call briefly locks an internal mutex.
    ///
    /// # Safety
    ///
    /// Slots are not tied to OS threads. The caller must ensure that no two
    /// threads mapping to the same slot use this context at the same time
    /// (e.g. two non-rayon threads, or equally indexed workers of two different
    /// rayon pools).
    ///
    /// # Panics
    ///
    /// - if the calling thread's slot is already held by a live guard;
    /// - if the slot index lies outside the table, i.e. the thread belongs to a
    ///   pool larger than the one that was current at the first call;
    /// - if the internal mutex is poisoned.
    ///
    /// If the generator panics, the panic propagates and the slot is left empty
    /// and unborrowed, so a later `get` on that thread retries the generator.
    pub unsafe fn get(&self) -> ThreadLocalMut<'_, T, F> {
        // Clears a slot's borrow flag when dropped. It is disarmed with
        // `mem::forget` once the generator has returned normally.
        struct ResetOnUnwind<'b>(&'b mut bool);
        impl Drop for ResetOnUnwind<'_> {
            fn drop(&mut self) {
                *self.0 = false;
            }
        }

        // Read (and on first use create) the table while holding the lock.
        // Reading `cloned` without the lock would race with the write below,
        // because `Cell` is not atomic. It would also give this thread no
        // happens-before edge to slots initialized by another thread.
        let (slots, len) = {
            let mut data = self.init_mutex.lock().unwrap();
            if self.cloned.get().is_null() {
                // One slot per rayon worker, plus slot 0 for non-rayon threads.
                *data = (0..=rayon::current_num_threads())
                    .map(|_| (None, false))
                    .collect();
                self.cloned.set(data.as_mut_ptr());
            }
            (self.cloned.get(), data.len())
        };

        let tid = rayon::current_thread_index().map_or(0, |i| i + 1);
        // The table was sized for the pool that was current at the first call.
        // Without this check, a thread from a larger pool would index past its end.
        assert!(
            tid < len,
            "Thread index {} is out of range for this context ({} slots)!",
            tid,
            len
        );

        // SAFETY: `tid < len`, so the slot lies inside the table. The table is
        // only (re)built while `cloned` is null, so it is never reallocated once
        // created. Per this function's contract no other thread uses slot `tid`
        // concurrently, and the borrow flag rules out a second live guard on
        // this thread.
        let slot = unsafe { &mut *slots.add(tid) };
        match slot {
            (_, true) => panic!("Already borrowed the value on thread {}!", tid),
            (Some(val), borrowed) => {
                *borrowed = true;
                ThreadLocalMut {
                    val,
                    parent: self,
                    tid,
                }
            }
            (val, borrowed) => {
                // Mark the slot before running the generator, so a generator
                // that re-enters `get` on this thread panics instead of aliasing
                // the slot.
                *borrowed = true;
                let reset = ResetOnUnwind(borrowed);
                let fresh = (self.inner)();
                std::mem::forget(reset);
                ThreadLocalMut {
                    val: val.insert(fresh),
                    parent: self,
                    tid,
                }
            }
        }
    }
}

impl<T, F: FnMut() -> T> ThreadLocalCtx<T, F> {
    /// Creates a context whose generator is wrapped in a `Mutex`.
    ///
    /// Every generator call goes through the lock, so the generator runs on at
    /// most one thread at a time. This lets a `Send` but non-`Sync` generator
    /// back a shared context. Unlike [`ThreadLocalCtx::new`], it also accepts
    /// stateful `FnMut` generators.
    ///
    /// # Panics
    ///
    /// The wrapped generator panics if an earlier call panicked while holding
    /// the lock (mutex poisoning).
    pub fn new_locked(inner: F) -> ThreadLocalCtx<T, impl Fn() -> T> {
        let locked_inner = Mutex::new(inner);
        let inner = move || {
            let mut generator = locked_inner.lock().unwrap();
            (*generator)()
        };
        ThreadLocalCtx {
            inner,
            init_mutex: Mutex::new(vec![]),
            cloned: Cell::new(std::ptr::null_mut()),
        }
    }
}
/// Owning iterator over the raw slot table: one `(value, borrowed)` pair per slot.
type VecIter<T> = std::vec::IntoIter<(Option<T>, bool)>;
/// Function-pointer type of the projection that keeps a slot's value and
/// discards its borrow flag.
type FmapFn<T> = fn(slot: (Option<T>, bool)) -> Option<T>;
impl<T, F> IntoIterator for ThreadLocalCtx<T, F> {
    type Item = T;
    type IntoIter = FilterMap<VecIter<T>, FmapFn<T>>;

    /// Consumes the context and yields every value that has been created, in
    /// slot order (slot `0` first). Slots that were never filled are skipped.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn into_iter(self) -> Self::IntoIter {
        let slots: Vec<(Option<T>, bool)> = self.init_mutex.into_inner().unwrap();
        let keep_value: FmapFn<T> = |(value, _borrowed)| value;
        slots.into_iter().filter_map(keep_value)
    }
}
/// Exclusive access to the calling thread's value inside a [`ThreadLocalCtx`].
///
/// Dereferences to `T`. Dropping the guard clears the slot's borrow flag, so
/// the same thread can borrow its value again.
pub struct ThreadLocalMut<'a, T, F> {
    /// The thread's value, stored in the context's slot table.
    val: &'a mut T,
    /// Context owning the slot; used on drop to locate the borrow flag.
    parent: &'a ThreadLocalCtx<T, F>,
    /// Index of the slot this guard holds.
    tid: usize,
}
impl<T, F> Deref for ThreadLocalMut<'_, T, F> {
    type Target = T;

    /// Shared view of the borrowed per-thread value.
    fn deref(&self) -> &T {
        &*self.val
    }
}
impl<T, F> DerefMut for ThreadLocalMut<'_, T, F> {
    /// Mutable view of the borrowed per-thread value.
    fn deref_mut(&mut self) -> &mut T {
        &mut *self.val
    }
}
impl<T, F> Drop for ThreadLocalMut<'_, T, F> {
    /// Releases this thread's slot by clearing its borrow flag.
    fn drop(&mut self) {
        let first_slot = self.parent.cloned.get();
        // SAFETY: guards are only constructed by `get` after it has accessed slot
        // `self.tid` of this table. The table is owned by `parent`, which
        // outlives the guard.
        unsafe {
            let slot = first_slot.add(self.tid);
            (*slot).1 = false;
        }
    }
}