cubecl-common 0.10.0-pre.4

Common crate for CubeCL
Documentation
#[cfg(multi_threading)]
use core::sync::atomic::AtomicU64;

/// Identifier of an execution stream.
///
/// [`StreamId::current`] reports the id associated with the calling thread.
/// When the `multi_threading` cfg is set, every thread receives its own value
/// the first time it asks for one; otherwise the value is always `0`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StreamId {
    /// Raw numeric id of the stream.
    pub value: u64,
}

/// Process-wide source of fresh stream ids.
///
/// Starts at `0`; `StreamId::from_current_thread` takes the next value with
/// `fetch_add(1, ..)` the first time a thread that has no id yet asks for one.
#[cfg(multi_threading)]
static STREAM_COUNT: AtomicU64 = AtomicU64::new(0);

#[cfg(multi_threading)]
std::thread_local! {
    // Stream id attached to the current thread. Stays `None` until the thread
    // either queries its id for the first time or has one installed by a swap.
    static ID: std::cell::RefCell<Option<u64>> = const { std::cell::RefCell::new(None) };
}

impl StreamId {
    /// Runs `f` with `self` installed as the current stream, then reinstates
    /// the stream that was current beforehand.
    ///
    /// The restoration is done by a drop guard, so it happens on normal return
    /// as well as when `f` unwinds. The value produced by `f` is returned
    /// unchanged.
    ///
    /// Without the `multi_threading` cfg there is only stream `0` and swapping
    /// has no effect: `f` simply runs.
    pub fn executes<F, T>(self, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        /// Reinstates the wrapped stream id when dropped.
        struct Restore(StreamId);

        impl Drop for Restore {
            fn drop(&mut self) {
                // SAFETY: `swap` performs no memory-unsafe operation; this only
                // writes back the id that was current when `executes` started.
                unsafe {
                    StreamId::swap(self.0);
                }
            }
        }

        // SAFETY: `swap` performs no memory-unsafe operation, and the guard
        // created here undoes the swap before `executes` returns or unwinds.
        let _restore = Restore(unsafe { StreamId::swap(self) });

        // The guard is dropped after `f` has produced its value (or panicked).
        f()
    }

    /// Returns the stream id associated with the calling thread.
    ///
    /// With the `multi_threading` cfg the id is allocated lazily on first use
    /// and then remembered for the thread; otherwise it is always `0`.
    pub fn current() -> Self {
        #[cfg(multi_threading)]
        let value = Self::from_current_thread();
        #[cfg(not(multi_threading))]
        let value = 0;

        Self { value }
    }

    /// Installs `stream` as the calling thread's current stream and returns
    /// the stream that was current before the call.
    ///
    /// Without the `multi_threading` cfg nothing is stored: the current stream
    /// stays `0`, and `0` is what gets returned.
    ///
    /// Prefer [`StreamId::executes`], which pairs the swap with its restore.
    ///
    /// # Safety
    ///
    /// The exact contract is not specified. The function performs no
    /// memory-unsafe operation itself, but it changes what every later
    /// [`StreamId::current`] call on this thread reports. Don't use it unless
    /// you know what you are doing.
    pub unsafe fn swap(stream: StreamId) -> StreamId {
        // SAFETY: the caller upholds the contract of `swap`, which is the same
        // contract as the cfg-selected helper it forwards to.
        unsafe {
            #[cfg(multi_threading)]
            return Self::swap_multi_thread(stream);

            #[cfg(not(multi_threading))]
            return Self::swap_single_thread(stream);
        }
    }

    /// Thread-local implementation of [`StreamId::swap`].
    #[cfg(multi_threading)]
    unsafe fn swap_multi_thread(stream: StreamId) -> StreamId {
        // Read first: this also allocates an id for a thread that had none, so
        // the returned "previous" stream is always a real one.
        let previous = Self::current();
        ID.with(|cell| *cell.borrow_mut() = Some(stream.value));

        previous
    }

    /// Implementation of [`StreamId::swap`] when `multi_threading` is off.
    ///
    /// There is no per-thread storage in this configuration, so the requested
    /// stream is not recorded and the previous stream is whatever
    /// [`StreamId::current`] reports (always `0`).
    #[cfg(not(multi_threading))]
    unsafe fn swap_single_thread(_stream: StreamId) -> StreamId {
        // Return the stream that is actually current rather than echoing the
        // argument, so the result agrees with `current()` and with the
        // multi-threaded variant.
        Self::current()
    }

    /// Returns the id stored for the calling thread, allocating one from
    /// `STREAM_COUNT` on first use.
    #[cfg(multi_threading)]
    fn from_current_thread() -> u64 {
        ID.with(|cell| {
            let mut slot = cell.borrow_mut();
            // `fetch_add` is an atomic read-modify-write, so two threads can
            // never be handed the same counter value.
            *slot.get_or_insert_with(|| {
                STREAM_COUNT.fetch_add(1, core::sync::atomic::Ordering::Acquire)
            })
        })
    }
}

impl core::fmt::Display for StreamId {
    /// Writes the id as `StreamId(<value>)`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self { value } = self;
        write!(f, "StreamId({:?})", value)
    }
}