Skip to main content

cubecl_common/
stream_id.rs

1#[cfg(multi_threading)]
2use core::sync::atomic::AtomicU64;
3
/// Identifier of a stream, normally derived from the calling thread.
///
/// Two values compare equal exactly when their `value`s are equal, and they
/// order by `value`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId {
    /// Raw numeric id of the stream (the per-thread value).
    pub value: u64,
}
11
/// Global source of fresh stream ids; the next unused value is taken by a
/// thread that asks for its id before one has been installed.
#[cfg(multi_threading)]
static STREAM_COUNT: AtomicU64 = AtomicU64::new(0);

#[cfg(multi_threading)]
std::thread_local! {
    /// Stream id installed on the current thread; `None` until first assigned.
    static ID: core::cell::RefCell<Option<u64>> = const { core::cell::RefCell::new(None) };
}
19
20impl StreamId {
21    /// Executes `f` on this stream, restoring the previous stream afterward.
22    ///
23    /// The previous [`StreamId`] is saved before the call and restored on
24    /// return — including on unwind — so the caller never has to manage
25    /// raw `swap` pairs.
26    pub fn executes<F, T>(self, f: F) -> T
27    where
28        F: FnOnce() -> T,
29    {
30        struct Guard(StreamId);
31
32        impl Drop for Guard {
33            fn drop(&mut self) {
34                unsafe {
35                    StreamId::swap(self.0);
36                }
37            }
38        }
39
40        let old = unsafe { StreamId::swap(self) };
41        let guard = Guard(old);
42
43        let returned = f();
44        core::mem::drop(guard);
45        returned
46    }
47
48    /// Get the current thread id.
49    pub fn current() -> Self {
50        Self {
51            #[cfg(multi_threading)]
52            value: Self::from_current_thread(),
53            #[cfg(not(multi_threading))]
54            value: 0,
55        }
56    }
57
58    /// Swap the current stream id for the given one.
59    ///
60    /// # Safety
61    ///
62    /// Unknown at this point, don't use that if you don't know what you are doing.
63    pub unsafe fn swap(stream: StreamId) -> StreamId {
64        unsafe {
65            #[cfg(multi_threading)]
66            return Self::swap_multi_thread(stream);
67
68            #[cfg(not(multi_threading))]
69            return Self::swap_single_thread(stream);
70        }
71    }
72
73    #[cfg(multi_threading)]
74    unsafe fn swap_multi_thread(stream: StreamId) -> StreamId {
75        let old = Self::current();
76        ID.with(|cell| {
77            let mut val = cell.borrow_mut();
78            *val = Some(stream.value)
79        });
80
81        old
82    }
83
84    #[cfg(not(multi_threading))]
85    unsafe fn swap_single_thread(stream: StreamId) -> StreamId {
86        stream
87    }
88
89    #[cfg(multi_threading)]
90    fn from_current_thread() -> u64 {
91        ID.with(|cell| {
92            let mut val = cell.borrow_mut();
93            match val.as_mut() {
94                Some(val) => *val,
95                None => {
96                    let new = STREAM_COUNT.fetch_add(1, core::sync::atomic::Ordering::Acquire);
97                    *val = Some(new);
98                    new
99                }
100            }
101        })
102    }
103}
104
105impl core::fmt::Display for StreamId {
106    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
107        f.write_fmt(format_args!("StreamId({:?})", self.value))
108    }
109}