cubecl_common/
stream_id.rs1#[cfg(multi_threading)]
2use core::sync::atomic::AtomicU64;
3
4#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, PartialOrd, Ord)]
7pub struct StreamId {
8 pub value: u64,
10}
11
#[cfg(multi_threading)]
std::thread_local! {
    /// Stream id of the current thread: `None` until it is first requested or installed.
    static ID: std::cell::RefCell<Option<u64>> = const { std::cell::RefCell::new(None) };
}

/// Global counter handing out a fresh stream id to each thread that asks for one.
#[cfg(multi_threading)]
static STREAM_COUNT: AtomicU64 = AtomicU64::new(0_u64);
19
20impl StreamId {
21 pub fn current() -> Self {
23 Self {
24 #[cfg(multi_threading)]
25 value: Self::from_current_thread(),
26 #[cfg(not(multi_threading))]
27 value: 0,
28 }
29 }
30
31 pub unsafe fn swap(stream: StreamId) -> StreamId {
37 unsafe {
38 #[cfg(multi_threading)]
39 return Self::swap_multi_thread(stream);
40
41 #[cfg(not(multi_threading))]
42 return Self::swap_single_thread(stream);
43 }
44 }
45
46 #[cfg(multi_threading)]
47 unsafe fn swap_multi_thread(stream: StreamId) -> StreamId {
48 let old = Self::current();
49 ID.with(|cell| {
50 let mut val = cell.borrow_mut();
51 *val = Some(stream.value)
52 });
53
54 old
55 }
56
57 #[cfg(not(multi_threading))]
58 unsafe fn swap_single_thread(stream: StreamId) -> StreamId {
59 stream
60 }
61
62 #[cfg(multi_threading)]
63 fn from_current_thread() -> u64 {
64 ID.with(|cell| {
65 let mut val = cell.borrow_mut();
66 match val.as_mut() {
67 Some(val) => *val,
68 None => {
69 let new = STREAM_COUNT.fetch_add(1, core::sync::atomic::Ordering::Acquire);
70 *val = Some(new);
71 new
72 }
73 }
74 })
75 }
76}
77
78impl core::fmt::Display for StreamId {
79 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
80 f.write_fmt(format_args!("StreamId({:?})", self.value))
81 }
82}