1use crate::rand::gen_random;
2
/// Stateless handle used to generate random 64-bit identifiers.
///
/// The type has no fields; identifiers are produced through the associated
/// function `IdGenerator::generate`, so no instance is required to use it.
#[derive(Debug, Clone, Default)]
pub struct IdGenerator {}
5
6impl IdGenerator {
7 pub fn generate() -> u64 {
9 let random_bytes: [u8; 8] = gen_random();
11 u64::from_le_bytes(random_bytes)
12 }
13}
14
15pub use cubecl_common::stream_id::StreamId;
16
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::collections::BTreeSet;

    #[cfg(feature = "std")]
    use dashmap::DashSet;
    #[cfg(feature = "std")]
    use std::{sync::Arc, thread};

    /// Generates many ids on a single thread and checks that none repeats.
    #[test]
    fn uniqueness_test() {
        const IDS_CNT: usize = 10_000;

        let mut seen: BTreeSet<u64> = BTreeSet::new();

        (0..IDS_CNT).for_each(|_| {
            // `insert` returns `false` when the value was already present,
            // i.e. when the generator produced a duplicate.
            let id = IdGenerator::generate();
            assert!(seen.insert(id));
        });

        assert_eq!(seen.len(), IDS_CNT);
    }

    /// Generates ids from several threads at once into a shared concurrent
    /// set and checks that no id is produced twice across threads.
    #[cfg(feature = "std")]
    #[test]
    fn thread_safety_test() {
        const NUM_THREADS: usize = 10;
        const NUM_REPEATS: usize = 1_000;
        const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;

        let shared: Arc<DashSet<u64>> = Arc::new(DashSet::new());

        // Spawn every worker before joining any of them so they run concurrently.
        let workers: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..NUM_REPEATS {
                        assert!(shared.insert(IdGenerator::generate()));
                    }
                })
            })
            .collect();

        // A panic inside a worker (failed assertion) surfaces here via `unwrap`.
        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(shared.len(), EXPECTED_TOTAL_IDS);
    }
}