1use std::sync::atomic::{AtomicBool, Ordering};
2
3use anyhow::Result;
4use rayon::ThreadPoolBuilder;
5
/// Guards [`init`]: `init` atomically swaps this flag to `true` and only
/// performs the setup when the previous value was `false`.
static IS_INITED: AtomicBool = AtomicBool::new(false);
7
/// Selects how [`init`] configures the global thread pool.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SystemType {
    /// Default variant: [`init`] builds the global pool without applying any
    /// extra builder option.
    #[default]
    Generic,
    /// [`init`] additionally enables `use_current_thread` on the pool builder
    /// before building the global pool.
    Python,
}
16
17pub fn init(system_type: SystemType) -> Result<()> {
18 if !IS_INITED.swap(true, Ordering::SeqCst) {
19 let threads = prepare_threads()?;
20
21 let mut builder = ThreadPoolBuilder::new().num_threads(threads.len());
22 if matches!(system_type, SystemType::Python) {
23 builder = builder.use_current_thread();
24 }
25 builder.build_global()?;
26
27 bind_threads(threads)?;
28 }
29 Ok(())
30}
31
/// Stand-in used when the `numa` feature is disabled.
///
/// There is no hardware topology to load in this configuration, so the
/// function always succeeds with the unit value.
#[cfg(not(feature = "numa"))]
const fn get_topology() -> Result<()> {
    // Nothing to discover without NUMA support.
    Ok(())
}
36
/// Loads the hardware topology through `hwlocality`.
///
/// # Errors
///
/// Returns the error from `Topology::new`, converted into the crate-wide
/// error type.
#[cfg(feature = "numa")]
fn get_topology() -> Result<::hwlocality::Topology> {
    // `?` performs the same `Into` conversion the error needs.
    let topology = ::hwlocality::Topology::new()?;
    Ok(topology)
}
41
/// Lists the thread indices to use when the `numa` feature is disabled.
///
/// Returns the indices `0..n`, where `n` is the value reported by
/// `std::thread::available_parallelism`, capped at `MAX_THREADS`. When the
/// parallelism cannot be queried, a single index (`0`) is returned.
///
/// The result is a `Vec<usize>`, like the `numa` variant, because the caller
/// takes its length to size the thread pool.
#[cfg(not(feature = "numa"))]
fn prepare_threads() -> Result<Vec<usize>> {
    use std::thread;

    const MAX_THREADS: usize = 32;

    let count = thread::available_parallelism()
        .map_or(1, usize::from)
        .min(MAX_THREADS);
    Ok((0..count).collect())
}
54
/// Picks one NUMA node at random and lists the CPU indices attributed to it.
///
/// The node count and CPU count are each taken as "highest set index + 1" of
/// the topology's node set and CPU set (an empty set counts as one entry).
/// CPUs are split between nodes by integer division, and node `k` is given
/// the contiguous index range `k * per_node .. (k + 1) * per_node`.
///
/// # Errors
///
/// Returns an error when the topology cannot be loaded.
#[cfg(feature = "numa")]
fn prepare_threads() -> Result<Vec<usize>> {
    use rand::{
        distributions::{Distribution, Uniform},
        thread_rng,
    };

    let topology = get_topology()?;

    let last_node: Option<usize> = topology.nodeset().last_set().map(Into::into);
    let last_cpu: Option<usize> = topology.cpuset().last_set().map(Into::into);

    // Highest index + 1; a missing index is treated as index 0.
    let node_count = last_node.map_or(1, |idx| idx + 1);
    let cpu_count = last_cpu.map_or(1, |idx| idx + 1);
    let cpus_per_node = cpu_count / node_count;

    // Choose the node uniformly from `0..node_count`.
    let node = Uniform::new(0usize, node_count).sample(&mut thread_rng());

    let first_cpu = node * cpus_per_node;
    Ok((first_cpu..first_cpu + cpus_per_node).collect())
}
84
/// No-op used when the `numa` feature is disabled.
///
/// Accepts the same `Vec<usize>` argument as the `numa` variant so that the
/// call site is identical in both configurations; the list is ignored.
/// The function is not `const` because a `const fn` cannot drop a `Vec`.
#[cfg(not(feature = "numa"))]
fn bind_threads(_threads: Vec<usize>) -> Result<()> {
    Ok(())
}
89
/// Pins the workers of the current rayon pool to the given CPUs.
///
/// The worker with index `i` is bound to the CPU index `threads[i]`. Workers
/// whose index has no entry in `threads` are left unbound.
///
/// # Errors
///
/// Returns the first error produced while loading the topology or while
/// applying a binding.
#[cfg(feature = "numa")]
fn bind_threads(threads: Vec<usize>) -> Result<()> {
    use hwlocality::cpu::{binding::CpuBindingFlags, cpuset::CpuSet};

    // `broadcast` runs the closure once on every worker of the pool. A
    // parallel iterator gives no such guarantee: work stealing may run several
    // items on one worker and none on another, leaving some workers unbound
    // while re-binding others.
    rayon::broadcast(|ctx| -> Result<()> {
        let cpu = match threads.get(ctx.index()) {
            Some(&cpu) => cpu,
            None => return Ok(()),
        };

        let topology = get_topology()?;
        let mut cpus = CpuSet::new();
        cpus.set(cpu);
        // Bind with the `THREAD` flag from inside the worker being pinned.
        topology.bind_cpu(&cpus, CpuBindingFlags::THREAD)?;
        Ok(())
    })
    .into_iter()
    .collect()
}