1#[cfg(target_os = "linux")]
4use std::sync::OnceLock;
5
6use crate::{Error, Result, checked_len};
7
8#[must_use]
16pub fn current_node() -> Option<u32> {
17 #[cfg(target_os = "linux")]
18 {
19 let mut cpu = 0_u32;
20 let mut node = 0_u32;
21 let result = unsafe {
22 libc::syscall(
23 libc::SYS_getcpu,
24 &raw mut cpu,
25 &raw mut node,
26 std::ptr::null_mut::<libc::c_void>(),
27 )
28 };
29 if result == 0 {
30 return Some(node);
31 }
32 None
33 }
34
35 #[cfg(not(target_os = "linux"))]
36 {
37 None
38 }
39}
40
41pub fn pin_to_node(node: u32) -> Result<()> {
48 validate_node(node)?;
49
50 #[cfg(target_os = "linux")]
51 {
52 if let Some(library) = LinuxNuma::load()? {
53 library.run_on_node(node)?;
54 }
55 }
56
57 Ok(())
58}
59
60pub fn alloc_on_node<T: Default>(count: usize, node: u32) -> Result<Vec<T>> {
72 validate_node(node)?;
73 checked_len::<T>(count)?;
74
75 let mut values = Vec::new();
76 values.try_reserve(count).map_err(|e| crate::Error::System {
77 operation: "numa alloc try_reserve",
78 source: std::io::Error::other(format!("{e}")),
79 })?;
80 values.resize_with(count, T::default);
81
82 #[cfg(target_os = "linux")]
83 {
84 if !values.is_empty()
85 && let Some(library) = LinuxNuma::load()?
86 && library.has_multiple_nodes()
87 {
88 let byte_len = checked_len::<T>(count)?;
89 library.tonode_memory(values.as_mut_ptr().cast::<libc::c_void>(), byte_len, node)?;
90 }
91 }
92
93 Ok(values)
94}
95
96#[must_use]
104pub fn node_count() -> usize {
105 #[cfg(target_os = "linux")]
106 {
107 if let Ok(Some(library)) = LinuxNuma::load() {
108 let count = library.max_node().saturating_add(1);
109 return usize::try_from(count).unwrap_or(1);
110 }
111 }
112
113 1
114}
115
116fn validate_node(node: u32) -> Result<()> {
117 let available = node_count();
118 if usize::try_from(node)
119 .ok()
120 .is_some_and(|index| index < available)
121 {
122 Ok(())
123 } else {
124 Err(Error::InvalidNode { node, available })
125 }
126}
127
128#[cfg(target_os = "linux")]
129struct LinuxNuma {
130 _library: libloading::Library,
131 available: unsafe extern "C" fn() -> libc::c_int,
132 max_node: unsafe extern "C" fn() -> libc::c_int,
133 run_on_node: unsafe extern "C" fn(libc::c_int) -> libc::c_int,
134 tonode_memory:
135 unsafe extern "C" fn(*mut libc::c_void, libc::size_t, libc::c_int) -> libc::c_long,
136}
137
138#[cfg(target_os = "linux")]
139impl LinuxNuma {
140 fn load() -> Result<Option<&'static Self>> {
141 static LIBRARY: OnceLock<Option<LinuxNuma>> = OnceLock::new();
142 static INIT_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
143
144 if let Some(value) = LIBRARY.get() {
145 return Ok(value.as_ref());
146 }
147
148 let _guard = INIT_LOCK
149 .lock()
150 .unwrap_or_else(std::sync::PoisonError::into_inner);
151 if let Some(value) = LIBRARY.get() {
152 return Ok(value.as_ref());
153 }
154
155 let value = Self::try_load()?;
156 if LIBRARY.set(value).is_err() {
157 }
160 if let Some(l) = LIBRARY.get() {
161 return Ok(l.as_ref());
162 }
163 Err(Error::System {
164 operation: "numa library load",
165 source: std::io::Error::other(
166 "OnceLock initialization failed unexpectedly",
167 ),
168 })
169 }
170
171 fn try_load() -> Result<Option<Self>> {
172 let candidates = ["libnuma.so.1", "libnuma.so"];
173 for library_name in candidates {
174 match Self::load_from_name(library_name) {
175 Ok(Some(library)) => return Ok(Some(library)),
176 Ok(None) | Err(crate::Error::LibraryLoad { .. }) => {}
177
178 Err(other) => return Err(other),
179 }
180 }
181 Ok(None)
182 }
183
184 fn load_from_name(library_name: &'static str) -> Result<Option<Self>> {
185 let library = unsafe { libloading::Library::new(library_name) }.map_err(|source| {
186 Error::LibraryLoad {
187 library: library_name,
188 source,
189 }
190 })?;
191
192 let available = unsafe {
193 *library
194 .get::<unsafe extern "C" fn() -> libc::c_int>(b"numa_available\0")
195 .map_err(|source| Error::SymbolLoad {
196 library: library_name,
197 symbol: "numa_available",
198 source,
199 })?
200 };
201 let max_node = unsafe {
202 *library
203 .get::<unsafe extern "C" fn() -> libc::c_int>(b"numa_max_node\0")
204 .map_err(|source| Error::SymbolLoad {
205 library: library_name,
206 symbol: "numa_max_node",
207 source,
208 })?
209 };
210 let run_on_node = unsafe {
211 *library
212 .get::<unsafe extern "C" fn(libc::c_int) -> libc::c_int>(b"numa_run_on_node\0")
213 .map_err(|source| Error::SymbolLoad {
214 library: library_name,
215 symbol: "numa_run_on_node",
216 source,
217 })?
218 };
219 let tonode_memory = unsafe {
220 *library
221 .get::<unsafe extern "C" fn(*mut libc::c_void, libc::size_t, libc::c_int) -> libc::c_long>(
222 b"numa_tonode_memory\0",
223 )
224 .map_err(|source| Error::SymbolLoad {
225 library: library_name,
226 symbol: "numa_tonode_memory",
227 source,
228 })?
229 };
230
231 let state = unsafe { available() };
232 if state < 0 {
233 return Ok(None);
234 }
235
236 Ok(Some(Self {
237 _library: library,
238 available,
239 max_node,
240 run_on_node,
241 tonode_memory,
242 }))
243 }
244
245 fn max_node(&self) -> libc::c_int {
246 unsafe {
247 let _ = (self.available)();
248 (self.max_node)()
249 }
250 }
251
252 fn has_multiple_nodes(&self) -> bool {
253 self.max_node() > 0
254 }
255
256 fn run_on_node(&self, node: u32) -> Result<()> {
257 let node = libc::c_int::try_from(node).map_err(|_| Error::InvalidNode {
258 node,
259 available: node_count(),
260 })?;
261 let result = unsafe { (self.run_on_node)(node) };
262 if result == 0 {
263 Ok(())
264 } else {
265 Err(Error::System {
266 operation: "numa_run_on_node",
267 source: std::io::Error::last_os_error(),
268 })
269 }
270 }
271
272 fn tonode_memory(&self, ptr: *mut libc::c_void, len: usize, node: u32) -> Result<()> {
273 let node = libc::c_int::try_from(node).map_err(|_| Error::InvalidNode {
274 node,
275 available: node_count(),
276 })?;
277 let result = unsafe { (self.tonode_memory)(ptr, len, node) };
278 if result == 0 {
279 Ok(())
280 } else {
281 Err(Error::System {
282 operation: "numa_tonode_memory",
283 source: std::io::Error::last_os_error(),
284 })
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::{alloc_on_node, current_node, node_count, pin_to_node};

    #[test]
    fn node_count_is_at_least_one() {
        assert!(node_count() >= 1);
    }

    #[test]
    fn alloc_on_node_returns_initialized_values() {
        let values = alloc_on_node::<u32>(4, 0).expect("allocation on node 0 must succeed");
        assert_eq!(values, vec![0; 4]);
    }

    #[test]
    fn alloc_on_node_with_zero_count_is_empty() {
        let values = alloc_on_node::<u64>(0, 0).expect("empty allocation on node 0 must succeed");
        assert!(values.is_empty());
    }

    #[test]
    fn alloc_on_invalid_node_is_rejected() {
        // Node numbers are zero-based, so `node_count()` itself is the first invalid one.
        let invalid = u32::try_from(node_count()).unwrap_or(u32::MAX);
        let error = alloc_on_node::<u8>(1, invalid).expect_err("invalid node must fail");
        assert!(matches!(error, crate::Error::InvalidNode { .. }));
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn linux_current_node_query_is_non_fatal() {
        let _ = current_node();
    }

    #[cfg(not(target_os = "linux"))]
    #[test]
    fn non_linux_current_node_is_none() {
        assert_eq!(current_node(), None);
    }

    #[test]
    fn invalid_node_is_rejected() {
        let invalid = u32::try_from(node_count()).unwrap_or(u32::MAX);
        let error = pin_to_node(invalid).expect_err("invalid node must fail");
        assert!(matches!(error, crate::Error::InvalidNode { .. }));
    }
}
322}