// cubecl_common/stream_id.rs

/// Identifier for a stream, derived from the thread that requested it.
///
/// Obtain the identifier for the calling thread with [`StreamId::current`].
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StreamId {
    /// Numeric value standing in for the originating thread's id.
    pub value: u64,
}
8
9impl StreamId {
10    /// Get the current thread id.
11    pub fn current() -> Self {
12        Self {
13            #[cfg(feature = "std")]
14            value: Self::from_current_thread(),
15            #[cfg(not(feature = "std"))]
16            value: 0,
17        }
18    }
19
20    #[cfg(feature = "std")]
21    fn from_current_thread() -> u64 {
22        use core::hash::Hash;
23
24        std::thread_local! {
25            static ID: std::cell::OnceCell::<u64> = const { std::cell::OnceCell::new() };
26        };
27
28        // Getting the current thread is expensive, so we cache the value into a thread local
29        // variable, which is very fast.
30        ID.with(|cell| {
31            *cell.get_or_init(|| {
32                // A way to get a thread id encoded as u64.
33                let mut hasher = std::hash::DefaultHasher::default();
34                let id = std::thread::current().id();
35                id.hash(&mut hasher);
36                std::hash::Hasher::finish(&hasher)
37            })
38        })
39    }
40}
41
42impl core::fmt::Display for StreamId {
43    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
44        f.write_fmt(format_args!("StreamId({:?})", self.value))
45    }
46}