cubecl_common/
stream_id.rs1#[cfg(multi_threading)]
2use core::sync::atomic::AtomicU64;
3
/// Identifier of the stream the calling code runs on.
///
/// The id wraps a plain `u64`; equality, hashing and ordering all follow that raw value.
/// See [`StreamId::current`] for how the calling thread's id is obtained.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StreamId {
    /// Raw numeric value of the id.
    pub value: u64,
}
11
/// Global source of fresh stream ids, handed out to threads that ask for their id before
/// one has been installed for them.
#[cfg(multi_threading)]
static STREAM_COUNT: AtomicU64 = AtomicU64::new(0);
14
#[cfg(multi_threading)]
std::thread_local! {
    /// Stream id installed for the current thread; `None` until one is assigned or swapped in.
    static ID: std::cell::RefCell<Option<u64>> = const { std::cell::RefCell::new(None) };
}
19
20impl StreamId {
21 pub fn executes<F, T>(self, f: F) -> T
27 where
28 F: FnOnce() -> T,
29 {
30 struct Guard(StreamId);
31
32 impl Drop for Guard {
33 fn drop(&mut self) {
34 unsafe {
35 StreamId::swap(self.0);
36 }
37 }
38 }
39
40 let old = unsafe { StreamId::swap(self) };
41 let guard = Guard(old);
42
43 let returned = f();
44 core::mem::drop(guard);
45 returned
46 }
47
48 pub fn current() -> Self {
50 Self {
51 #[cfg(multi_threading)]
52 value: Self::from_current_thread(),
53 #[cfg(not(multi_threading))]
54 value: 0,
55 }
56 }
57
58 pub unsafe fn swap(stream: StreamId) -> StreamId {
64 unsafe {
65 #[cfg(multi_threading)]
66 return Self::swap_multi_thread(stream);
67
68 #[cfg(not(multi_threading))]
69 return Self::swap_single_thread(stream);
70 }
71 }
72
73 #[cfg(multi_threading)]
74 unsafe fn swap_multi_thread(stream: StreamId) -> StreamId {
75 let old = Self::current();
76 ID.with(|cell| {
77 let mut val = cell.borrow_mut();
78 *val = Some(stream.value)
79 });
80
81 old
82 }
83
84 #[cfg(not(multi_threading))]
85 unsafe fn swap_single_thread(stream: StreamId) -> StreamId {
86 stream
87 }
88
89 #[cfg(multi_threading)]
90 fn from_current_thread() -> u64 {
91 ID.with(|cell| {
92 let mut val = cell.borrow_mut();
93 match val.as_mut() {
94 Some(val) => *val,
95 None => {
96 let new = STREAM_COUNT.fetch_add(1, core::sync::atomic::Ordering::Acquire);
97 *val = Some(new);
98 new
99 }
100 }
101 })
102 }
103}
104
105impl core::fmt::Display for StreamId {
106 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
107 f.write_fmt(format_args!("StreamId({:?})", self.value))
108 }
109}