burn_common/
id.rs

1use crate::rand::gen_random;
2
3/// Simple ID generator.
4pub struct IdGenerator {}
5
6impl IdGenerator {
7    /// Generates a new ID.
8    pub fn generate() -> u64 {
9        // Generate a random u64 (18,446,744,073,709,551,615 combinations)
10        let random_bytes: [u8; 8] = gen_random();
11        u64::from_le_bytes(random_bytes)
12    }
13}
14
15pub use cubecl_common::stream_id::StreamId;
16
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::collections::BTreeSet;

    #[cfg(feature = "std")]
    use dashmap::DashSet; // concurrent hash set, shared between threads below
    #[cfg(feature = "std")]
    use std::{sync::Arc, thread};

    /// Generating many IDs on one thread should never produce a duplicate.
    #[test]
    fn uniqueness_test() {
        const IDS_CNT: usize = 10_000;

        let mut seen = BTreeSet::<u64>::new();

        (0..IDS_CNT).for_each(|_| {
            // `insert` returns false when the value was already present.
            let is_new = seen.insert(IdGenerator::generate());
            assert!(is_new);
        });

        assert_eq!(seen.len(), IDS_CNT);
    }

    /// Generating IDs from several threads at once should never produce a duplicate.
    #[cfg(feature = "std")]
    #[test]
    fn thread_safety_test() {
        const NUM_THREADS: usize = 10;
        const NUM_REPEATS: usize = 1_000;
        const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;

        let shared: Arc<DashSet<u64>> = Arc::new(DashSet::new());

        // Collect eagerly so that every worker is spawned before any of them is joined.
        let workers = (0..NUM_THREADS)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..NUM_REPEATS {
                        assert!(shared.insert(IdGenerator::generate()));
                    }
                })
            })
            .collect::<alloc::vec::Vec<_>>();

        // Propagate any panic from a worker (e.g. a duplicate ID) into this test.
        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        assert_eq!(shared.len(), EXPECTED_TOTAL_IDS);
    }
}