1use crate::rand::gen_random;
3
/// Stateless namespace type for producing 64-bit identifiers.
///
/// The type carries no data; identifiers are obtained through its associated
/// `generate` function.
#[derive(Debug)]
pub struct IdGenerator {}
6
7impl IdGenerator {
8 pub fn generate() -> u64 {
10 let random_bytes: [u8; 8] = gen_random();
12 u64::from_le_bytes(random_bytes)
13 }
14}
15
16pub use cubecl_common::stream_id::StreamId;
17
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::collections::BTreeSet;

    #[cfg(feature = "std")]
    use dashmap::DashSet;
    #[cfg(feature = "std")]
    use std::{sync::Arc, thread};

    /// Generating many identifiers in a row must never yield a duplicate.
    #[test]
    fn uniqueness_test() {
        const IDS_CNT: usize = 10_000;

        let mut seen = BTreeSet::<u64>::new();

        (0..IDS_CNT).for_each(|_| {
            let id = IdGenerator::generate();
            // `insert` returns false when the value was already present.
            assert!(seen.insert(id));
        });

        assert_eq!(seen.len(), IDS_CNT);
    }

    /// Identifiers generated concurrently from several threads must also be
    /// pairwise distinct.
    #[cfg(feature = "std")]
    #[test]
    fn thread_safety_test() {
        const NUM_THREADS: usize = 10;
        const NUM_REPEATS: usize = 1_000;
        const EXPECTED_TOTAL_IDS: usize = NUM_THREADS * NUM_REPEATS;

        let shared: Arc<DashSet<u64>> = Arc::new(DashSet::new());

        // `collect` is eager, so every worker is spawned before any join.
        let workers: Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let shared = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..NUM_REPEATS {
                        assert!(shared.insert(IdGenerator::generate()));
                    }
                })
            })
            .collect();

        // Unwrapping the join result re-raises any assertion failure from a worker.
        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        assert_eq!(shared.len(), EXPECTED_TOTAL_IDS);
    }
}