//! Documentation for `rayon-tlsctx` 0.1.0.
//!
//! Thread local contexts for Rayon loops.
use std::cell::Cell;
use std::iter::FilterMap;
use std::ops::{Deref, DerefMut};
use std::sync::Mutex;

/// A thread local storage container for Rayon jobs
///
/// This context can be used to efficiently clone `inner`, only when it's necessary:
/// a value is produced for a thread index the first time `get` is called with that
/// index, and the same value is handed back on later calls.
pub struct ThreadLocalCtx<T, F> {
    // Initialiser that produces a fresh `T`. It is locked for the duration of each
    // invocation, so only one thread runs it at a time.
    inner: Mutex<F>,
    // Backing storage for the per-thread slots. Each slot is `(value, borrowed)`:
    // `value` stays `None` until first use and `borrowed` is the runtime borrow flag.
    // The mutex is held while the slots are being allocated.
    init_mutex: Mutex<Vec<(Option<T>, bool)>>,
    // Pointer to the first slot of the vector owned by `init_mutex`; null until the
    // slots have been allocated by `get`.
    cloned: Cell<*mut (Option<T>, bool)>,
}

// The raw-pointer `Cell` field prevents the compiler from deriving `Send`/`Sync`
// automatically, so both are asserted by hand.
//
// SAFETY (as far as this file shows): the initialiser and the slot vector are both
// wrapped in a `Mutex`, and `get` only hands out the slot selected by the calling
// Rayon thread's index, guarded by a per-slot borrow flag. Values of `T` are created
// on worker threads and may later be moved out by `into_iter`, hence `T: Send`.
// NOTE(review): soundness additionally relies on callers upholding the `# Safety`
// contract documented on `get` — confirm whenever `get` changes.
unsafe impl<T: Send, F: Send> Sync for ThreadLocalCtx<T, F> {}
unsafe impl<T: Send, F: Send> Send for ThreadLocalCtx<T, F> {}

impl<T, F: Fn() -> T> ThreadLocalCtx<T, F> {
    /// Create a new `ThreadLocalCtx`
    ///
    /// `inner` is the initialiser used to produce a thread's value. It is not called
    /// here; `get` invokes it the first time a given thread index asks for its value.
    ///
    /// # Examples
    ///
    /// Creating a thread-local byte buffer:
    /// ```
    /// use rayon_tlsctx::ThreadLocalCtx;
    /// let ctx = ThreadLocalCtx::new(Vec::<u8>::new);
    /// ```
    ///
    /// Cloning an initialised buffer for each thread:
    /// ```
    /// use rayon_tlsctx::ThreadLocalCtx;
    /// # use rand::prelude::*;
    /// # let mut rng = rand::thread_rng();
    ///
    /// let mut buf: Vec<u16> = (0..!0).collect();
    /// buf.shuffle(&mut rng);
    ///
    /// let ctx = ThreadLocalCtx::new(move || buf.clone());
    /// ```
    pub fn new(inner: F) -> Self {
        Self {
            inner: Mutex::new(inner),
            init_mutex: Mutex::new(vec![]),
            // Null marks "slots not allocated yet"; `get` allocates them lazily.
            cloned: Cell::new(std::ptr::null_mut()),
        }
    }

    /// Get a thread local context reference with dynamically checked borrow rules
    ///
    /// The first successful call allocates one slot per thread of the current Rayon
    /// thread pool. The calling thread's slot is then filled by running the
    /// initialiser if it is still empty, marked as borrowed, and returned wrapped in
    /// a guard. The borrow is released when the guard is dropped.
    ///
    /// # Safety
    ///
    /// Only one thread pool at a given time should use this scope. The thread pool size can not
    /// grow midway through.
    ///
    /// The slot table is sized once, from the pool that performs the first successful
    /// call. The slot lookup is **not** bounds checked, so every later call must come
    /// from a thread whose index is smaller than that size.
    ///
    /// # Panics
    ///
    /// If called from outside of a Rayon thread pool, or when the reference is already being
    /// borrowed. A call from outside of a thread pool panics before any slot is
    /// allocated, leaving the context untouched. Also panics if one of the internal
    /// mutexes has been poisoned.
    ///
    /// # Examples
    ///
    /// ```
    /// use rayon_tlsctx::ThreadLocalCtx;
    /// use rayon::iter::*;
    ///
    /// const NUM_COPIES: usize = 16;
    ///
    /// let buf: Vec<u16> = (0..!0).collect();
    ///
    /// // Create a thread local context with value 0.
    /// let ctx = ThreadLocalCtx::new(|| 0);
    ///
    /// // Sum the buffer `NUM_COPIES` times and accumulate the results
    /// // into the threaded pool of counts. Note that the counts may be
    /// // Unevenly distributed.
    /// (0..NUM_COPIES)
    ///     .into_par_iter()
    ///     .flat_map(|_| buf.par_iter())
    ///     .for_each(|i| {
    ///         let mut cnt = unsafe { ctx.get() };
    ///         *cnt += *i as usize;
    ///     });
    ///
    /// let buf_sum = buf.into_iter().fold(0, |acc, i| acc + i as usize);
    ///
    /// // What matters is that the final sum matches the expected value.
    /// assert_eq!(ctx.into_iter().sum::<usize>(), buf_sum * NUM_COPIES);
    /// ```
    pub unsafe fn get<'a>(&'a self) -> ThreadLocalMut<'a, T, F> {
        use std::sync::atomic::{AtomicPtr, Ordering};

        // Resolve the worker index before touching any state, so that a call from
        // outside of a thread pool panics without sizing the slot table for a pool
        // the caller is not running in.
        let tid = rayon::current_thread_index()
            .expect("ThreadLocalCtx::get called from outside of a Rayon thread pool");

        // The slot pointer is read on the lock-free fast path while another thread
        // may be publishing it, so every access in this function goes through an
        // atomic view of the cell instead of a plain (racy) read.
        //
        // SAFETY: `AtomicPtr<U>` has the same in-memory representation as `*mut U`,
        // the pointee lives inside a `Cell` (so it may be written through a shared
        // reference), and it stays valid for as long as `self` is borrowed. The only
        // store ever performed is the single `Release` store below; plain reads made
        // elsewhere happen after a `get` on the same context has returned and are
        // therefore ordered after that store.
        let published = &*(self.cloned.as_ptr() as *const AtomicPtr<(Option<T>, bool)>);

        // `Acquire` pairs with the `Release` store so that the freshly allocated
        // slots are visible to whoever observes the non-null pointer.
        let mut slots = published.load(Ordering::Acquire);
        if slots.is_null() {
            let mut data = self.init_mutex.lock().unwrap();
            // Re-check under the lock: another thread may have allocated the slots
            // while this one was waiting. The mutex already orders this load.
            slots = published.load(Ordering::Relaxed);
            if slots.is_null() {
                *data = (0..rayon::current_num_threads())
                    .map(|_| (None, false))
                    .collect();
                // The vector is never resized afterwards, so this pointer stays
                // valid until the context is consumed.
                slots = data.as_mut_ptr();
                published.store(slots, Ordering::Release);
            }
        }

        // SAFETY: the caller guarantees (see `# Safety`) that `tid` lies within the
        // slot table. Each slot is only ever touched by the thread with that index,
        // and the borrow flag below rejects a second live borrow.
        let (value, borrowed) = &mut *slots.add(tid);
        if *borrowed {
            panic!("Already borrowed the value on thread {}!", tid);
        }
        // Mark the slot as borrowed before running the initialiser so that a
        // re-entrant `get` from inside it is reported as a double borrow.
        *borrowed = true;

        // Run the initialiser only on the first use of this slot.
        let val = value.get_or_insert_with(|| (self.inner.lock().unwrap())());

        ThreadLocalMut {
            val,
            parent: self,
            tid,
        }
    }
}

/// Consume the context and retrieve all created items.
///
/// Slots that were never initialised hold no value and are skipped; the borrow
/// flags are discarded.
impl<T, F> IntoIterator for ThreadLocalCtx<T, F> {
    type Item = T;
    type IntoIter =
        FilterMap<std::vec::IntoIter<(Option<T>, bool)>, fn((Option<T>, bool)) -> Option<T>>;

    fn into_iter(self) -> Self::IntoIter {
        // Keep only the value half of each slot; typed as a plain fn pointer so
        // that it matches the `IntoIter` associated type.
        let keep_value: fn((Option<T>, bool)) -> Option<T> = |(value, _)| value;

        // Take ownership of the slot vector out of its mutex.
        let slots = self.init_mutex.into_inner().unwrap();
        slots.into_iter().filter_map(keep_value)
    }
}

/// Borrowed thread local variable.
///
/// This structure tracks borrow rules at runtime; it may be necessary to manually
/// drop the object if multiple Rayon loops are involved. Dropping it clears the
/// borrow flag of the slot it was created from.
pub struct ThreadLocalMut<'a, T, F> {
    // Exclusive reference to the value stored in the slot at index `tid`.
    val: &'a mut T,
    // Context the value was borrowed from; read on drop to locate the slot table.
    parent: &'a ThreadLocalCtx<T, F>,
    // Index of the borrowed slot, i.e. the Rayon thread index seen by `get`.
    tid: usize,
}

/// Read access to the borrowed thread-local value.
impl<'a, T, F> Deref for ThreadLocalMut<'a, T, F> {
    type Target = T;

    /// Reborrows the guarded value immutably.
    fn deref(&self) -> &Self::Target {
        &*self.val
    }
}

/// Write access to the borrowed thread-local value.
impl<'a, T, F> DerefMut for ThreadLocalMut<'a, T, F> {
    /// Reborrows the guarded value mutably.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.val
    }
}

/// Releases the runtime borrow when the guard goes out of scope.
impl<'a, T, F> Drop for ThreadLocalMut<'a, T, F> {
    fn drop(&mut self) {
        let table = self.parent.cloned.get();
        // SAFETY: a guard is only built by `get`, which has already used this same
        // pointer and index to reach the slot, so the slot is assumed to still be
        // valid here. Only the borrow flag is written; the value is left alone.
        unsafe {
            let slot = table.add(self.tid);
            (*slot).1 = false;
        }
    }
}