cubecl_common/
stream_id.rs1#[cfg(multi_threading)]
2use core::sync::atomic::AtomicU64;
3
/// Identifier of an execution stream.
///
/// Obtain the id of the calling context with [`StreamId::current`]; the wrapped integer is
/// exposed directly through [`StreamId::value`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId {
    /// Raw numeric value of this stream id.
    pub value: u64,
}
11
/// Process-wide counter from which threads draw a fresh stream id the first time they need one.
#[cfg(multi_threading)]
static STREAM_COUNT: AtomicU64 = AtomicU64::new(0);

#[cfg(multi_threading)]
std::thread_local! {
    /// Stream id of the current thread; `None` until an id is assigned or installed.
    static ID: std::cell::RefCell<Option<u64>> = const { std::cell::RefCell::new(None) };
}
19
20impl StreamId {
21    pub fn current() -> Self {
23        Self {
24            #[cfg(multi_threading)]
25            value: Self::from_current_thread(),
26            #[cfg(not(multi_threading))]
27            value: 0,
28        }
29    }
30
31    pub unsafe fn swap(stream: StreamId) -> StreamId {
37        unsafe {
38            #[cfg(multi_threading)]
39            return Self::swap_multi_thread(stream);
40
41            #[cfg(not(multi_threading))]
42            return Self::swap_single_thread(stream);
43        }
44    }
45
46    #[cfg(multi_threading)]
47    unsafe fn swap_multi_thread(stream: StreamId) -> StreamId {
48        let old = Self::current();
49        ID.with(|cell| {
50            let mut val = cell.borrow_mut();
51            *val = Some(stream.value)
52        });
53
54        old
55    }
56
57    #[cfg(not(multi_threading))]
58    unsafe fn swap_single_thread(stream: StreamId) -> StreamId {
59        stream
60    }
61
62    #[cfg(multi_threading)]
63    fn from_current_thread() -> u64 {
64        ID.with(|cell| {
65            let mut val = cell.borrow_mut();
66            match val.as_mut() {
67                Some(val) => *val,
68                None => {
69                    let new = STREAM_COUNT.fetch_add(1, core::sync::atomic::Ordering::Acquire);
70                    *val = Some(new);
71                    new
72                }
73            }
74        })
75    }
76}
77
78impl core::fmt::Display for StreamId {
79    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
80        f.write_fmt(format_args!("StreamId({:?})", self.value))
81    }
82}