thread_local_scope 1.0.0

Scoped access to thread local storage
Documentation
//! Provides a token type [`LocalScope`] that guards access to thread local storage, avoiding the need for a closure on every access.
//!
//! # Examples
//!
//! You can use the scoping to gracefully handle errors.
//!
//! ```
//! use thread_local_scope::local_scope;
//! # struct Whatever();
//! # impl Whatever { fn new() -> Self { Self() } }
//! thread_local! {
//!     static WHATEVER: Whatever = Whatever::new();
//! }
//!
//! fn with_whatever<R>(f: impl FnOnce(&Whatever) -> R) -> R {
//!     local_scope(|scope| {
//!         if let Ok(x) = scope.try_access(&WHATEVER) {
//!             f(x)
//!         } else {
//!             let stack_local_fallback = Whatever::new();
//!             f(&stack_local_fallback)
//!         }
//!     })
//! }
//!
//! // The equivalent without this requires .unwrap()
//! fn with_whatever_std<R>(f: impl FnOnce(&Whatever) -> R) -> R {
//!     let mut f = Some(f);
//!     WHATEVER
//!         .try_with(|x| f.take().unwrap()(x))
//!         .unwrap_or_else(|_| {
//!             let stack_local_fallback = Whatever::new();
//!             f.unwrap()(&stack_local_fallback)
//!         })
//! }
//! ```
//!
//!
//!
//! This allows avoiding nested closures if working with multiple LocalKeys.
//! ```
//! # use std::{thread::LocalKey, cell::Cell};
//! # use thread_local_scope::local_scope;
//! fn swap_local_cells<T>(a: &'static LocalKey<Cell<T>>, b: &'static LocalKey<Cell<T>>) {
//!     local_scope(|s| {
//!         s.access(a).swap(s.access(b))
//!     })
//! }
//!
//! fn swap_local_cells_std<T>(a: &'static LocalKey<Cell<T>>, b: &'static LocalKey<Cell<T>>) {
//!     a.with(|a| b.with(|b| a.swap(b)))
//! }
//! ```

use std::{
    fmt,
    marker::PhantomData,
    thread::{AccessError, LocalKey},
};

/// ZST token that guarantees consistent access to thread local storage values for the duration of `'a`.
///
/// Created with [`local_scope`].
///
/// # Thread safety
///
/// Since this struct makes assertions about the current thread, it implements neither [`Send`] nor [`Sync`].
#[derive(Clone, Copy)]
#[repr(transparent)]
// The marker does double duty: the raw pointer (`*const _`) suppresses the
// `Send`/`Sync` auto traits, while the `&'a ()` inside it ties the token to
// `'a`. The composition is covariant in `'a`, so a scope for a longer
// lifetime can shrink to a shorter one — which is always sound, since the
// token only *promises* that TLS stays alive throughout `'a`.
pub struct LocalScope<'a>(PhantomData<*const &'a ()>);

impl<'a> fmt::Debug for LocalScope<'a> {
    /// Formats as the bare type name; the token is a ZST with no state worth showing.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `debug_struct("LocalScope").finish()` with zero recorded fields
        // emits just the name (in both `{:?}` and `{:#?}`), so writing the
        // name directly is observably identical.
        f.write_str("LocalScope")
    }
}

/// Creates a new [`LocalScope`] bound to the current call stack.
///
/// Within the callback function, thread local storage can be freely accessed.
///
/// References to thread locals can't escape the callback:
/// ```compile_fail
/// thread_local! { static A: u8 = 0; }
/// thread_local_scope::local_scope(|x| x.access(&A));
/// ```
pub fn local_scope<F, R>(f: F) -> R
where
    F: for<'a> FnOnce(LocalScope<'a>) -> R,
{
    f(
        // SAFETY: because 'a is higher-ranked (unbound) in the callback's
        // signature, the compiler confines it to the duration of this call,
        // during which no TLS destructor of the current thread can run —
        // so every thread local outlives 'a, as `new_unchecked` requires.
        //
        // This is the same safety argument for why `LocalKey::with` is safe
        // in the first place.
        unsafe { LocalScope::new_unchecked() },
    )
}

/// Internal function used as the common callback for `try_access`/`access`:
/// extends the lifetime of a TLS reference from the anonymous borrow handed
/// out by `LocalKey`'s `with`-style methods to the caller-chosen `'a`.
///
/// This is only safe to use as a callback if thread locals will live for 'a
/// (the invariant that [`LocalScope::new_unchecked`] makes its caller uphold).
#[inline(always)]
fn unsafe_tls_callback<'a, T>(tls: &T) -> &'a T {
    // SAFETY: sound only under this function's documented contract — the
    // referent must outlive 'a. Round-tripping through a raw pointer erases
    // the borrow's original (shorter) lifetime.
    unsafe { &*(tls as *const T) }
}

impl<'a> LocalScope<'a> {
    /// Creates a new [LocalScope] without any checks. Prefer [local_scope] for safe usage.
    ///
    /// # Safety
    ///
    /// References to thread local storage must live for `'a`, that is none of the current thread's local keys may become destroyed during `'a`.
    pub const unsafe fn new_unchecked() -> Self {
        Self(PhantomData)
    }

    /// Equivalent to [`LocalKey::try_with`] without the need for the closure.
    pub fn try_access<T>(self, target: &'static LocalKey<T>) -> Result<&'a T, AccessError> {
        target.try_with(unsafe_tls_callback)
    }

    /// Equivalent to [`LocalKey::with`] without the need for the closure.
    #[track_caller]
    pub fn access<T>(self, target: &'static LocalKey<T>) -> &'a T {
        target.with(unsafe_tls_callback)
    }
}

#[cfg(test)]
mod test {
    static_assertions::assert_not_impl_any!(LocalScope<'static>: Send, Sync);
    static_assertions::assert_eq_size!(LocalScope<'static>, ());

    use crate::*;
    use std::{
        cell::Cell,
        sync::atomic::{AtomicUsize, Ordering},
        thread::spawn,
    };

    /// Exercises `try_access` from inside a TLS destructor: the key being
    /// dropped must report itself as inaccessible rather than hand out a
    /// dangling reference.
    #[test]
    fn re_entrant() {
        static DID_RUN_DESTRUCTOR: AtomicUsize = AtomicUsize::new(0);

        thread_local! {
            static MY_THING: MyThing = MyThing;
        }

        struct MyThing;
        impl Drop for MyThing {
            fn drop(&mut self) {
                local_scope(|scope| {
                    // Relaxed suffices: the join below synchronizes this
                    // write with the main thread's read.
                    DID_RUN_DESTRUCTOR.fetch_add(1, Ordering::Relaxed);
                    assert!(
                        scope.try_access(&MY_THING).is_err(),
                        "Can't access self while in destructor"
                    )
                })
            }
        }

        // Run on a fresh thread so MY_THING's destructor fires at thread exit.
        let worker = spawn(|| {
            local_scope(|scope| {
                let _ = scope
                    .try_access(&MY_THING)
                    .expect("Testing, should be defined");
            })
        });
        worker.join().unwrap();

        assert_eq!(DID_RUN_DESTRUCTOR.load(Ordering::Relaxed), 1);
    }

    /// The doc example: two thread locals accessed in one scope, no nesting.
    #[test]
    fn swap() {
        fn swap_local_cells<T>(a: &'static LocalKey<Cell<T>>, b: &'static LocalKey<Cell<T>>) {
            local_scope(|scope| scope.access(a).swap(scope.access(b)))
        }

        thread_local! {
            static A: Cell<u8> = Cell::new(0);
            static B: Cell<u8> = Cell::new(1);
        }

        swap_local_cells(&A, &B);

        let (a_now, b_now) = (A.with(Cell::get), B.with(Cell::get));
        assert_eq!(a_now, 1);
        assert_eq!(b_now, 0);
    }
}