burn_std/
id.rs

1//! # Unique Identifiers
2use crate::rand::gen_random;
3
4/// Simple ID generator.
5pub struct IdGenerator {}
6
7impl IdGenerator {
8    /// Generates a new ID.
9    pub fn generate() -> u64 {
10        // Generate a random u64 (18,446,744,073,709,551,615 combinations)
11        let random_bytes: [u8; 8] = gen_random();
12        u64::from_le_bytes(random_bytes)
13    }
14}
15
16pub use cubecl_common::stream_id::StreamId;
17
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::collections::BTreeSet;

    // Concurrent set: worker threads insert through a shared `Arc` without an external lock.
    #[cfg(feature = "std")]
    use dashmap::DashSet;
    #[cfg(feature = "std")]
    use std::{sync::Arc, thread};

    /// Draws many IDs on one thread and checks that none repeat.
    ///
    /// IDs are random, so this is a probabilistic check. A collision among
    /// 10,000 random 64-bit values is astronomically unlikely.
    #[test]
    fn uniqueness_test() {
        const IDS_CNT: usize = 10_000;

        let mut seen = BTreeSet::<u64>::new();

        // `insert` returns `false` on a duplicate, which fails the assertion.
        (0..IDS_CNT).for_each(|_| assert!(seen.insert(IdGenerator::generate())));

        assert_eq!(seen.len(), IDS_CNT);
    }

    /// Generates IDs from several threads at once and checks that the
    /// combined output contains no duplicates.
    #[cfg(feature = "std")]
    #[test]
    fn thread_safety_test() {
        const NUM_THREADS: usize = 10;
        const NUM_REPEATS: usize = 1_000;
        const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;

        let ids = Arc::new(DashSet::<u64>::new());

        // Spawn every worker up front so they all run concurrently.
        let mut workers = vec![];
        workers.extend((0..NUM_THREADS).map(|_| {
            let shared = Arc::clone(&ids);
            thread::spawn(move || {
                (0..NUM_REPEATS).for_each(|_| assert!(shared.insert(IdGenerator::generate())));
            })
        }));

        // Joining re-raises any assertion failure that happened inside a worker.
        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        assert_eq!(ids.len(), EXPECTED_TOTAL_IDS);
    }
}