#[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::vec::Vec;
use core::{cell::RefCell, ops::DerefMut};
use thread_local::ThreadLocal;
#[derive(Default)]
pub struct Parallel<T: Send> {
locals: ThreadLocal<RefCell<T>>,
}
impl<T: Send> Parallel<T> {
pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
self.locals.iter_mut().map(RefCell::get_mut)
}
pub fn clear(&mut self) {
self.locals.clear();
}
}
impl<T: Default + Send> Parallel<T> {
    /// Runs `f` on the calling thread's value and returns what `f` returns.
    ///
    /// If the calling thread has no value yet, one is created with `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread's value is already mutably borrowed. This happens when
    /// `f` re-enters [`Parallel::scope`] or [`Parallel::borrow_local_mut`] on the same
    /// `Parallel`, or when a guard from [`Parallel::borrow_local_mut`] is still alive on
    /// this thread.
    pub fn scope<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        let mut local = self.locals.get_or_default().borrow_mut();
        // `local` (the `RefMut` guard) is released only after `f` has returned.
        f(&mut local)
    }

    /// Mutably borrows the calling thread's value.
    ///
    /// If the calling thread has no value yet, one is created with `T::default()`. The
    /// returned guard keeps the borrow alive until it is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the calling thread's value is already mutably borrowed, e.g. while another
    /// guard from this method is alive or from inside [`Parallel::scope`] on the same
    /// `Parallel`.
    pub fn borrow_local_mut(&self) -> impl DerefMut<Target = T> + '_ {
        self.locals.get_or_default().borrow_mut()
    }
}
impl<T, I> Parallel<I>
where
    I: IntoIterator<Item = T> + Default + Send + 'static,
{
    /// Returns an iterator that moves the items out of every thread's collection.
    ///
    /// The iterator is lazy. A thread's collection is swapped for `I::default()` only when
    /// the iterator reaches it. Collections it never reaches keep their contents. Items
    /// already moved out of the current collection but not yet yielded are dropped along
    /// with the iterator.
    pub fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
        self.locals
            .iter_mut()
            .flat_map(|slot| core::mem::take(slot.get_mut()))
    }
}
#[cfg(feature = "alloc")]
impl<T: Send> Parallel<Vec<T>> {
    /// Moves every thread's items onto the end of `out`, leaving each thread-local `Vec` empty.
    ///
    /// `out` reserves room for all items in one call before anything is moved, so it grows
    /// at most once.
    pub fn drain_into(&mut self, out: &mut Vec<T>) {
        let mut total = 0;
        for slot in self.locals.iter_mut() {
            total += slot.get_mut().len();
        }
        out.reserve(total);

        self.locals
            .iter_mut()
            .for_each(|slot| out.append(slot.get_mut()));
    }
}