Skip to main content

kernelkit/
numa.rs

1//! NUMA-aware helpers with Linux fast paths and portable fallbacks.
2
3#[cfg(target_os = "linux")]
4use std::sync::OnceLock;
5
6use crate::{Error, Result, checked_len};
7
8/// Return the current NUMA node for the calling thread when the platform can determine it.
9///
10/// # Example
11///
12/// ```rust
13/// let _node = kernelkit::numa::current_node();
14/// ```
15#[must_use]
16pub fn current_node() -> Option<u32> {
17    #[cfg(target_os = "linux")]
18    {
19        let mut cpu = 0_u32;
20        let mut node = 0_u32;
21        let result = unsafe {
22            libc::syscall(
23                libc::SYS_getcpu,
24                &raw mut cpu,
25                &raw mut node,
26                std::ptr::null_mut::<libc::c_void>(),
27            )
28        };
29        if result == 0 {
30            return Some(node);
31        }
32        None
33    }
34
35    #[cfg(not(target_os = "linux"))]
36    {
37        None
38    }
39}
40
41/// Pin the calling thread to a NUMA node when supported.
42///
43/// # Errors
44///
45/// Returns an error when the node is out of range or the operating system
46/// rejects the affinity request.
47pub fn pin_to_node(node: u32) -> Result<()> {
48    validate_node(node)?;
49
50    #[cfg(target_os = "linux")]
51    {
52        if let Some(library) = LinuxNuma::load()? {
53            library.run_on_node(node)?;
54        }
55    }
56
57    Ok(())
58}
59
60/// Allocate initialized values and, on Linux, migrate the backing pages toward a NUMA node.
61///
62/// # Example
63///
64/// ```rust
65/// let values = kernelkit::numa::alloc_on_node::<u64>(8, 0)?;
66/// assert_eq!(values.len(), 8);
67/// # Ok::<(), kernelkit::Error>(())
68/// ```
69/// # Errors
70/// Returns an error if node is out of bounds or allocation fails.
71pub fn alloc_on_node<T: Default>(count: usize, node: u32) -> Result<Vec<T>> {
72    validate_node(node)?;
73    checked_len::<T>(count)?;
74
75    let mut values = Vec::new();
76    values.try_reserve(count).map_err(|e| crate::Error::System {
77        operation: "numa alloc try_reserve",
78        source: std::io::Error::other(format!("{e}")),
79    })?;
80    values.resize_with(count, T::default);
81
82    #[cfg(target_os = "linux")]
83    {
84        if !values.is_empty()
85            && let Some(library) = LinuxNuma::load()?
86            && library.has_multiple_nodes()
87        {
88            let byte_len = checked_len::<T>(count)?;
89            library.tonode_memory(values.as_mut_ptr().cast::<libc::c_void>(), byte_len, node)?;
90        }
91    }
92
93    Ok(values)
94}
95
96/// Return the number of NUMA nodes visible to the current process.
97///
98/// # Example
99///
100/// ```rust
101/// assert!(kernelkit::numa::node_count() >= 1);
102/// ```
103#[must_use]
104pub fn node_count() -> usize {
105    #[cfg(target_os = "linux")]
106    {
107        if let Ok(Some(library)) = LinuxNuma::load() {
108            let count = library.max_node().saturating_add(1);
109            return usize::try_from(count).unwrap_or(1);
110        }
111    }
112
113    1
114}
115
116fn validate_node(node: u32) -> Result<()> {
117    let available = node_count();
118    if usize::try_from(node)
119        .ok()
120        .is_some_and(|index| index < available)
121    {
122        Ok(())
123    } else {
124        Err(Error::InvalidNode { node, available })
125    }
126}
127
128#[cfg(target_os = "linux")]
129struct LinuxNuma {
130    _library: libloading::Library,
131    available: unsafe extern "C" fn() -> libc::c_int,
132    max_node: unsafe extern "C" fn() -> libc::c_int,
133    run_on_node: unsafe extern "C" fn(libc::c_int) -> libc::c_int,
134    tonode_memory:
135        unsafe extern "C" fn(*mut libc::c_void, libc::size_t, libc::c_int) -> libc::c_long,
136}
137
138#[cfg(target_os = "linux")]
139impl LinuxNuma {
140    fn load() -> Result<Option<&'static Self>> {
141        static LIBRARY: OnceLock<Option<LinuxNuma>> = OnceLock::new();
142        static INIT_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
143
144        if let Some(value) = LIBRARY.get() {
145            return Ok(value.as_ref());
146        }
147
148        let _guard = INIT_LOCK
149            .lock()
150            .unwrap_or_else(std::sync::PoisonError::into_inner);
151        if let Some(value) = LIBRARY.get() {
152            return Ok(value.as_ref());
153        }
154
155        let value = Self::try_load()?;
156        if LIBRARY.set(value).is_err() {
157            // Another thread initialized the cell while we held the lock.
158            // This is harmless; drop our loaded value and use the cached one.
159        }
160        if let Some(l) = LIBRARY.get() {
161            return Ok(l.as_ref());
162        }
163        Err(Error::System {
164            operation: "numa library load",
165            source: std::io::Error::other(
166                "OnceLock initialization failed unexpectedly",
167            ),
168        })
169    }
170
171    fn try_load() -> Result<Option<Self>> {
172        let candidates = ["libnuma.so.1", "libnuma.so"];
173        for library_name in candidates {
174            match Self::load_from_name(library_name) {
175                Ok(Some(library)) => return Ok(Some(library)),
176                Ok(None) | Err(crate::Error::LibraryLoad { .. }) => {}
177
178                Err(other) => return Err(other),
179            }
180        }
181        Ok(None)
182    }
183
184    fn load_from_name(library_name: &'static str) -> Result<Option<Self>> {
185        let library = unsafe { libloading::Library::new(library_name) }.map_err(|source| {
186            Error::LibraryLoad {
187                library: library_name,
188                source,
189            }
190        })?;
191
192        let available = unsafe {
193            *library
194                .get::<unsafe extern "C" fn() -> libc::c_int>(b"numa_available\0")
195                .map_err(|source| Error::SymbolLoad {
196                    library: library_name,
197                    symbol: "numa_available",
198                    source,
199                })?
200        };
201        let max_node = unsafe {
202            *library
203                .get::<unsafe extern "C" fn() -> libc::c_int>(b"numa_max_node\0")
204                .map_err(|source| Error::SymbolLoad {
205                    library: library_name,
206                    symbol: "numa_max_node",
207                    source,
208                })?
209        };
210        let run_on_node = unsafe {
211            *library
212                .get::<unsafe extern "C" fn(libc::c_int) -> libc::c_int>(b"numa_run_on_node\0")
213                .map_err(|source| Error::SymbolLoad {
214                    library: library_name,
215                    symbol: "numa_run_on_node",
216                    source,
217                })?
218        };
219        let tonode_memory = unsafe {
220            *library
221                .get::<unsafe extern "C" fn(*mut libc::c_void, libc::size_t, libc::c_int) -> libc::c_long>(
222                    b"numa_tonode_memory\0",
223                )
224                .map_err(|source| Error::SymbolLoad {
225                    library: library_name,
226                    symbol: "numa_tonode_memory",
227                    source,
228                })?
229        };
230
231        let state = unsafe { available() };
232        if state < 0 {
233            return Ok(None);
234        }
235
236        Ok(Some(Self {
237            _library: library,
238            available,
239            max_node,
240            run_on_node,
241            tonode_memory,
242        }))
243    }
244
245    fn max_node(&self) -> libc::c_int {
246        unsafe {
247            let _ = (self.available)();
248            (self.max_node)()
249        }
250    }
251
252    fn has_multiple_nodes(&self) -> bool {
253        self.max_node() > 0
254    }
255
256    fn run_on_node(&self, node: u32) -> Result<()> {
257        let node = libc::c_int::try_from(node).map_err(|_| Error::InvalidNode {
258            node,
259            available: node_count(),
260        })?;
261        let result = unsafe { (self.run_on_node)(node) };
262        if result == 0 {
263            Ok(())
264        } else {
265            Err(Error::System {
266                operation: "numa_run_on_node",
267                source: std::io::Error::last_os_error(),
268            })
269        }
270    }
271
272    fn tonode_memory(&self, ptr: *mut libc::c_void, len: usize, node: u32) -> Result<()> {
273        let node = libc::c_int::try_from(node).map_err(|_| Error::InvalidNode {
274            node,
275            available: node_count(),
276        })?;
277        let result = unsafe { (self.tonode_memory)(ptr, len, node) };
278        if result == 0 {
279            Ok(())
280        } else {
281            Err(Error::System {
282                operation: "numa_tonode_memory",
283                source: std::io::Error::last_os_error(),
284            })
285        }
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::{alloc_on_node, current_node, node_count, pin_to_node};

    /// The reported node count is never zero.
    #[test]
    fn node_count_is_at_least_one() {
        assert!(node_count() >= 1);
    }

    /// Allocating on node 0 yields `count` default-initialized elements.
    #[test]
    fn alloc_on_node_returns_initialized_values() {
        let outcome = alloc_on_node::<u32>(4, 0);
        let values = outcome.expect("allocation on node 0 must succeed");
        assert_eq!(values, [0_u32; 4]);
    }

    /// On Linux the query may or may not produce a node, but it must not panic.
    #[cfg(target_os = "linux")]
    #[test]
    fn linux_current_node_query_is_non_fatal() {
        let _node = current_node();
    }

    /// Off Linux there is no way to determine the node.
    #[cfg(not(target_os = "linux"))]
    #[test]
    fn non_linux_current_node_is_none() {
        assert_eq!(current_node(), None);
    }

    /// The first index past the available nodes must be refused.
    #[test]
    fn invalid_node_is_rejected() {
        let out_of_range: u32 = node_count().try_into().unwrap_or(u32::MAX);
        let error = pin_to_node(out_of_range).expect_err("invalid node must fail");
        assert!(matches!(error, crate::Error::InvalidNode { .. }));
    }
}