1#[cfg(target_os = "linux")]
7use crate::LanceError;
8use crate::Result;
9use std::sync::OnceLock;
10
/// Lazily-detected, process-wide topology cache (see [`NumaTopology::get`]).
static NUMA_TOPOLOGY: OnceLock<NumaTopology> = OnceLock::new();

/// NUMA layout of the host machine.
///
/// On Linux the layout is read from `/sys/devices/system/node`; on every
/// other platform a single synthetic node containing all CPUs is reported.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes; always at least 1.
    pub node_count: usize,
    /// Logical CPU ids grouped by node, indexed by node id.
    pub cpus_per_node: Vec<Vec<usize>>,
    /// Total number of logical CPUs available to this process.
    pub cpu_count: usize,
}

impl NumaTopology {
    /// Detects the NUMA topology from sysfs.
    ///
    /// Falls back to an even CPU split across nodes when the per-node
    /// `cpulist` files cannot be read.
    #[cfg(target_os = "linux")]
    #[must_use]
    pub fn detect() -> Self {
        let node_count = Self::detect_node_count();
        let cpu_count = Self::detect_cpu_count();
        let cpus_per_node = Self::detect_cpus_per_node(node_count, cpu_count);

        Self {
            node_count,
            cpus_per_node,
            cpu_count,
        }
    }

    /// Non-Linux fallback: a single node that owns every CPU.
    #[cfg(not(target_os = "linux"))]
    #[must_use]
    pub fn detect() -> Self {
        let cpu_count = std::thread::available_parallelism()
            .map(std::num::NonZero::get)
            .unwrap_or(1);
        Self {
            node_count: 1,
            cpus_per_node: vec![(0..cpu_count).collect()],
            cpu_count,
        }
    }

    /// Counts `node<N>` directories under `/sys/devices/system/node`.
    ///
    /// Only entries named `node` followed by digits are counted, so sibling
    /// files in that directory (e.g. `online`, `possible`) can never be
    /// mistaken for nodes. Returns 1 when sysfs is unavailable.
    #[cfg(target_os = "linux")]
    fn detect_node_count() -> usize {
        std::fs::read_dir("/sys/devices/system/node")
            .map(|entries| {
                entries
                    .filter_map(std::result::Result::ok)
                    .filter(|e| {
                        e.file_name().to_str().is_some_and(|s| {
                            s.strip_prefix("node").is_some_and(|suffix| {
                                !suffix.is_empty()
                                    && suffix.bytes().all(|b| b.is_ascii_digit())
                            })
                        })
                    })
                    .count()
                    .max(1)
            })
            .unwrap_or(1)
    }

    /// Number of logical CPUs, defaulting to 1 if detection fails.
    #[cfg(target_os = "linux")]
    fn detect_cpu_count() -> usize {
        std::thread::available_parallelism()
            .map(std::num::NonZero::get)
            .unwrap_or(1)
    }

    /// Reads each node's `cpulist` file. If none of them could be read,
    /// evenly partitions `cpu_count` CPUs across `node_count` nodes, with
    /// the last node absorbing any remainder.
    #[cfg(target_os = "linux")]
    fn detect_cpus_per_node(node_count: usize, cpu_count: usize) -> Vec<Vec<usize>> {
        let mut result = vec![Vec::new(); node_count];

        for (node, node_cpus) in result.iter_mut().enumerate().take(node_count) {
            let path = format!("/sys/devices/system/node/node{node}/cpulist");
            if let Ok(content) = std::fs::read_to_string(&path) {
                *node_cpus = Self::parse_cpu_list(&content);
            }
        }

        // Synthetic fallback when sysfs gave us nothing at all.
        if result.iter().all(std::vec::Vec::is_empty) {
            let cpus_per = cpu_count / node_count.max(1);
            for (node, cpus) in result.iter_mut().enumerate() {
                let start = node * cpus_per;
                let end = if node == node_count - 1 {
                    cpu_count
                } else {
                    start + cpus_per
                };
                *cpus = (start..end).collect();
            }
        }

        result
    }

    /// Parses a Linux `cpulist` string (e.g. `"0-3,8-11"`) into CPU ids.
    ///
    /// Each comma-separated entry is either a single id or an inclusive
    /// `start-end` range. Entries are trimmed of surrounding whitespace;
    /// malformed or empty entries are skipped rather than aborting the parse.
    #[must_use]
    pub fn parse_cpu_list(s: &str) -> Vec<usize> {
        let mut cpus = Vec::new();
        for part in s.split(',') {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }
            if let Some((start, end)) = part.split_once('-') {
                if let (Ok(lo), Ok(hi)) =
                    (start.trim().parse::<usize>(), end.trim().parse::<usize>())
                {
                    cpus.extend(lo..=hi);
                }
            } else if let Ok(cpu) = part.parse::<usize>() {
                cpus.push(cpu);
            }
        }
        cpus
    }

    /// Returns the process-wide cached topology, detecting it on first use.
    pub fn get() -> &'static NumaTopology {
        NUMA_TOPOLOGY.get_or_init(Self::detect)
    }

    /// Node that owns `cpu`, or node 0 if the CPU is not listed anywhere.
    #[must_use]
    pub fn cpu_to_node(&self, cpu: usize) -> usize {
        self.cpus_per_node
            .iter()
            .position(|cpus| cpus.contains(&cpu))
            .unwrap_or(0)
    }

    /// CPUs belonging to `node`; empty for out-of-range node ids.
    #[must_use]
    pub fn node_cpus(&self, node: usize) -> &[usize] {
        self.cpus_per_node.get(node).map_or(&[], Vec::as_slice)
    }
}
147
/// Allocates NUMA-aware buffers with a preferred home node.
pub struct NumaAllocator {
    // Node that allocations should be placed on.
    preferred_node: usize,
}
152
153impl NumaAllocator {
154 #[must_use]
156 pub fn new(node: usize) -> Self {
157 Self {
158 preferred_node: node,
159 }
160 }
161
162 #[must_use]
164 pub fn for_current_thread() -> Self {
165 let node = get_current_numa_node();
166 Self::new(node)
167 }
168
169 pub fn allocate(&self, size: usize) -> Result<crate::buffer::NumaAlignedBuffer> {
174 crate::buffer::NumaAlignedBuffer::new(size, self.preferred_node)
175 }
176
177 #[must_use]
179 pub fn preferred_node(&self) -> usize {
180 self.preferred_node
181 }
182}
183
/// Returns the NUMA node the calling thread is currently executing on.
///
/// Uses the raw `getcpu` syscall; if the syscall fails, node 0 is reported.
#[cfg(target_os = "linux")]
#[must_use]
pub fn get_current_numa_node() -> usize {
    // SAFETY: `getcpu(cpu, node, tcache)` writes one u32 through each of the
    // first two pointers, which point at valid stack locals; the third
    // (unused cache) argument is permitted to be NULL.
    unsafe {
        let mut cpu: u32 = 0;
        let mut node: u32 = 0;
        if libc::syscall(
            libc::SYS_getcpu,
            std::ptr::addr_of_mut!(cpu),
            std::ptr::addr_of_mut!(node),
            std::ptr::null::<libc::c_void>(),
        ) == 0
        {
            node as usize
        } else {
            // Syscall failed: fall back to node 0.
            0
        }
    }
}
205
/// Returns the current NUMA node on non-Linux platforms.
///
/// Always 0, matching the single-node fallback topology produced by
/// `NumaTopology::detect` on these platforms.
#[cfg(not(target_os = "linux"))]
#[must_use]
pub fn get_current_numa_node() -> usize {
    0
}
211
/// Pins the calling thread to the single CPU `cpu` via `sched_setaffinity`.
///
/// # Errors
///
/// Returns `LanceError::PinFailed` when the kernel rejects the affinity
/// request (e.g. the CPU id is outside the thread's allowed set).
#[cfg(target_os = "linux")]
pub fn pin_thread_to_cpu(cpu: usize) -> Result<()> {
    use std::mem::MaybeUninit;

    // SAFETY: an all-zero byte pattern is a valid `cpu_set_t` (the empty CPU
    // set), so `zeroed().assume_init_mut()` is sound. `sched_setaffinity`
    // with pid 0 targets the calling thread and only reads the set.
    unsafe {
        let mut cpuset = MaybeUninit::<libc::cpu_set_t>::zeroed();
        let cpuset = cpuset.assume_init_mut();
        libc::CPU_ZERO(cpuset); // redundant after zeroed(), but explicit
        libc::CPU_SET(cpu, cpuset);

        let result = libc::sched_setaffinity(
            0, std::mem::size_of::<libc::cpu_set_t>(),
            cpuset,
        );

        if result == 0 {
            Ok(())
        } else {
            Err(LanceError::PinFailed(cpu))
        }
    }
}
241
/// No-op CPU pinning for non-Linux platforms.
///
/// # Errors
///
/// Never fails; always returns `Ok(())` so callers need no
/// platform-specific handling.
#[cfg(not(target_os = "linux"))]
pub fn pin_thread_to_cpu(_cpu: usize) -> Result<()> {
    Ok(())
}
251
252pub fn pin_thread_to_numa_node(node: usize) -> Result<()> {
257 let topology = NumaTopology::get();
258 let cpus = topology.node_cpus(node);
259
260 if cpus.is_empty() {
261 return Ok(()); }
263
264 pin_thread_to_cpu(cpus[0])
266}
267
/// Configuration for a NUMA-aware worker thread pool.
#[derive(Debug, Clone)]
pub struct NumaThreadPoolConfig {
    // Worker threads to spawn per NUMA node.
    pub threads_per_node: usize,
    // Whether workers should be pinned to specific CPUs.
    pub pin_threads: bool,
}
276
impl Default for NumaThreadPoolConfig {
    /// Defaults: two threads per node, with CPU pinning enabled.
    fn default() -> Self {
        Self {
            // NOTE(review): 2 looks like a heuristic default — confirm
            // against workload benchmarks before relying on it.
            threads_per_node: 2,
            pin_threads: true,
        }
    }
}
285
286impl NumaThreadPoolConfig {
287 #[must_use]
289 pub fn total_threads(&self) -> usize {
290 let topology = NumaTopology::get();
291 topology.node_count * self.threads_per_node
292 }
293
294 #[must_use]
296 pub fn thread_cpu(&self, thread_idx: usize) -> Option<usize> {
297 if !self.pin_threads {
298 return None;
299 }
300
301 let topology = NumaTopology::get();
302 let node = thread_idx / self.threads_per_node;
303 let node_thread = thread_idx % self.threads_per_node;
304
305 let cpus = topology.node_cpus(node);
306 cpus.get(node_thread % cpus.len()).copied()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    // Detection must always report at least one node and one CPU.
    #[test]
    fn test_numa_topology_detect() {
        let topo = NumaTopology::detect();
        assert!(topo.node_count >= 1);
        assert!(topo.cpu_count >= 1);
    }

    // Repeated calls to get() must observe the same cached topology.
    #[test]
    fn test_numa_topology_cached() {
        let first = NumaTopology::get();
        let second = NumaTopology::get();
        assert_eq!(first.node_count, second.node_count);
    }

    #[test]
    fn test_parse_cpu_list() {
        assert_eq!(
            NumaTopology::parse_cpu_list("0-3,8-11"),
            vec![0, 1, 2, 3, 8, 9, 10, 11]
        );
        assert_eq!(NumaTopology::parse_cpu_list("0,2,4"), vec![0, 2, 4]);
    }

    #[test]
    fn test_numa_allocator() {
        assert_eq!(NumaAllocator::new(0).preferred_node(), 0);
    }

    #[test]
    fn test_thread_pool_config() {
        assert_eq!(NumaThreadPoolConfig::default().threads_per_node, 2);
    }
}