use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::{sleep, spawn};
use std::time::Duration;
use rayon::prelude::*;
/// Spawns a background timer thread that flips the returned flag to
/// `true` once `dur` has elapsed. The flag starts out `false`.
pub(super) fn timeout_signal(dur: Duration) -> Arc<AtomicBool> {
    let flag = Arc::new(AtomicBool::new(false));
    spawn({
        let flag = Arc::clone(&flag);
        move || {
            sleep(dur);
            flag.store(true, Ordering::Relaxed);
        }
    });
    flag
}
/// A write-once atomic cell holding a heap-allocated `T`.
///
/// The first successful `try_set` wins; later calls drop their argument
/// and return a reference to the already-stored value.
pub(super) struct AtomicBox<T>(AtomicPtr<T>);

impl<T> Default for AtomicBox<T> {
    /// Creates an empty cell (null pointer, no value stored).
    fn default() -> Self {
        Self(AtomicPtr::default())
    }
}

impl<T> AtomicBox<T> {
    /// Stores `value` if the cell is still empty, then returns a reference
    /// to whatever the cell holds afterwards: the new value on success, the
    /// previously stored one otherwise. The losing `Box` is dropped.
    pub(super) fn try_set(&self, value: Box<T>) -> &T {
        let ptr = Box::into_raw(value);
        // Release on success publishes the boxed contents to other threads;
        // Acquire on failure lets us safely dereference the winner's pointer.
        let ret_ptr = match self.0.compare_exchange(
            std::ptr::null_mut(),
            ptr,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => ptr,
            Err(current) => {
                // Lost the race: reclaim and drop our box, use the stored one.
                // SAFETY: `ptr` came from `Box::into_raw` above and was never
                // published (the CAS failed), so we still uniquely own it.
                unsafe { drop(Box::from_raw(ptr)) };
                current
            }
        };
        // SAFETY: `ret_ptr` is non-null (either our fresh pointer or the
        // non-null pointer another `try_set` installed), and once set the
        // pointer is never changed or freed until `drop`.
        unsafe { ret_ptr.as_ref().unwrap() }
    }

    /// Returns a reference to the stored value, or `None` if still empty.
    pub(super) fn get(&self) -> Option<&T> {
        // Acquire pairs with the Release store in `try_set` so the boxed
        // contents are visible before we hand out a reference. (The previous
        // Relaxed load was a data race on weakly-ordered hardware: the
        // pointer could be seen before the pointee's initialization.)
        let ptr = self.0.load(Ordering::Acquire);
        // SAFETY: any non-null pointer was produced by `Box::into_raw` and
        // stays valid until `drop`.
        unsafe { ptr.as_ref() }
    }
}

impl<T> Drop for AtomicBox<T> {
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so `get_mut` needs no
        // atomic ordering at all.
        let ptr = *self.0.get_mut();
        if !ptr.is_null() {
            // SAFETY: the pointer was created by `Box::into_raw` and is
            // freed only here, exactly once.
            unsafe { drop(Box::from_raw(ptr)) };
        }
    }
}
#[test]
fn test_atomic_box() {
    // Starts empty.
    let cell = AtomicBox::<u32>::default();
    assert_eq!(cell.get(), None);
    // First set wins and becomes visible via `get`.
    cell.try_set(Box::new(3));
    assert_eq!(cell.get(), Some(&3));
    // A second set is ignored; the original value is retained.
    cell.try_set(Box::new(4));
    assert_eq!(cell.get(), Some(&3));
}
/// Per-thread mutable slots for use inside a rayon pool: one `T` per
/// worker thread, addressed by `rayon::current_thread_index()`.
pub(super) struct ThreadLocal<T> {
    locals: Vec<T>,
    // Raw pointer into `locals`' buffer; lets `local_do` hand out `&mut T`
    // from `&self`. Stays valid because `locals` is never resized after
    // `new` (the heap buffer does not move even if the struct does).
    ptr: *mut T,
}

// SAFETY: the raw pointer only aliases `locals`, which the struct owns, so
// moving the struct to another thread is fine whenever `T` itself is `Send`.
unsafe impl<T: Send> Send for ThreadLocal<T> {}
// SAFETY: concurrent `local_do` calls touch disjoint slots (one per rayon
// worker index), and `do_all`/`do_all_mut` take `&self`/`&mut self` borrows
// that callers must not overlap with `local_do`. `T: Send` is required
// because a slot created on one thread is mutated on another. (The previous
// unbounded `impl<T> Sync` was unsound in principle, though unreachable in
// practice since `new` — the only constructor — already demands `T: Send`.)
unsafe impl<T: Send> Sync for ThreadLocal<T> {}

impl<T: Send> ThreadLocal<T> {
    /// Creates one slot per worker thread of `pool`, each initialized by `f()`.
    pub(super) fn new<F: Fn() -> T>(f: F, pool: &rayon::ThreadPool) -> Self {
        let n = pool.current_num_threads();
        let mut locals = (0..n).map(|_| f()).collect::<Vec<_>>();
        let ptr = locals.as_mut_ptr();
        Self { locals, ptr }
    }

    /// Runs `f` on the slot belonging to the current rayon worker.
    ///
    /// # Panics
    /// Panics when called outside a rayon pool (`current_thread_index` is
    /// `None`) or from a pool with more threads than the one used in `new`.
    pub(super) fn local_do<F: FnOnce(&mut T)>(&self, f: F) {
        let index = rayon::current_thread_index().unwrap();
        assert!(index < self.locals.len());
        // SAFETY: `index` is in bounds (asserted above), and each rayon
        // worker index runs at most one task at a time, so no two live
        // `&mut T` alias — assuming callers use the pool passed to `new`;
        // an oversized foreign pool is caught by the assert.
        f(unsafe { self.ptr.add(index).as_mut().unwrap() });
    }

    /// Mutably visits every slot; `&mut self` guarantees no task is racing.
    pub(super) fn do_all_mut<F: FnMut(&mut T)>(&mut self, f: F) {
        self.locals.iter_mut().for_each(f);
    }

    /// Immutably visits every slot. Callers must not run this concurrently
    /// with `local_do`, which would alias a slot's `&mut T`.
    pub(super) fn do_all<F: FnMut(&T)>(&self, f: F) {
        self.locals.iter().for_each(f);
    }
}
#[test]
fn test_threadlocal() {
    let pool = rayon::ThreadPoolBuilder::new().build().unwrap();
    let mut tls = ThreadLocal::<u32>::new(|| 0, &pool);
    let count = 100000;
    // Run the parallel loop inside the pool `tls` was sized for. The
    // original version used the implicit global pool, which only worked
    // because both pools happened to have the same thread count.
    pool.install(|| (0..count).into_par_iter().for_each(|_| tls.local_do(|x| *x += 1)));
    // Every increment must have landed in exactly one slot.
    let mut sum = 0;
    tls.do_all_mut(|x| sum += *x);
    assert_eq!(sum, count);
    // Outside any rayon pool, current_thread_index() is None, so local_do
    // must panic.
    let result = std::panic::catch_unwind(|| {
        tls.local_do(|x| *x += 1);
    });
    assert!(result.is_err());
}
/// Wrapper that aligns its contents to 64 bytes — a typical cache-line
/// size — so adjacent values placed in an array do not share a line
/// (avoiding false sharing between threads).
#[repr(align(64))]
pub(super) struct CachePadded<T> {
    value: T,
}

impl<T: Default> Default for CachePadded<T> {
    /// Pads a default-constructed `T`.
    fn default() -> Self {
        CachePadded { value: Default::default() }
    }
}

impl<T> Deref for CachePadded<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl<T> DerefMut for CachePadded<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.value
    }
}
/// Parallel iterator over `array` where each task claims the next
/// unconsumed element via a shared counter, rather than the element at its
/// own split position. `with_max_len(1)` forces one element per task, so
/// elements are handed out in claim order — presumably so earlier elements
/// tend to be processed first under work stealing (TODO confirm intent with
/// callers). Every element is yielded exactly once.
pub(super) fn par_iter_in_order<T: Send + Sync>(array: &[T]) -> impl ParallelIterator<Item = &T> {
    let next = AtomicUsize::new(0);
    let len = array.len();
    (0..len).into_par_iter().with_max_len(1).map(move |_| {
        // The task's own index is deliberately ignored; the counter decides
        // which element this task receives.
        let i = next.fetch_add(1, Ordering::SeqCst);
        &array[i]
    })
}