Skip to main content

forge_alloc/hardening/
numa.rs

1//! `NumaLocal<I>` — bind an [`OsBacked`] allocator's memory range to one
2//! or more NUMA nodes.
3//!
4//! The wrapper calls `mbind()` once at construction to apply a NUMA
5//! placement policy to the inner allocator's entire region. Subsequent
6//! page faults that touch the range allocate physical pages on the
7//! chosen node(s) — this is the only point where the kernel actually
8//! decides physical placement; setting policy after pages are faulted
9//! has no effect.
10//!
11//! # Platform support
12//!
13//! - **Linux**: `mbind` is invoked via `libc::syscall(SYS_mbind, …)`.
14//!   Failure (kernel rejects, capability missing) is captured into
15//!   `crate::backing::mmap_last_os_error()` and the construction returns
16//!   `AllocError` — refuse silently-degraded NUMA placement.
17//! - **macOS / Apple Silicon**: UMA platform with no NUMA semantics.
18//!   `NumaLocal` is a no-op; the wrapper compiles to a direct pass-
19//!   through.
20//! - **Windows / other**: no `mbind` equivalent that operates on an
21//!   already-mapped region. `NumaLocal::new` returns the inner
22//!   unchanged with the policy stored but unenforced; production
23//!   Windows NUMA work belongs to `MmapBacked::with_numa_node` at
24//!   MAP-time (deferred to a future release).
25//!
26//! `LocalAtRequest` — re-bind on every backing request — is **not**
27//! implemented in v0.1. The wrap-once model doesn't fit per-allocate
28//! dispatch, and most NUMA-sensitive workloads are well-served by
29//! a one-shot bind at construction with thread-local slabs at the
30//! application layer.
31//!
32//! See `docs/ARCHITECTURE.md` for design context.
33
34use core::ptr::NonNull;
35
36use forge_alloc_core::{
37    AllocError, Allocator, Deallocator, FixedRange, NonZeroLayout, OsBacked, ProtectFlags,
38};
39
40/// NUMA placement policy. v0.1 accepts an explicit node set rather
41/// than dispatching against the calling thread's node — supply
42/// `current_numa_node()` if you want the local-at-construction
43/// behaviour.
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub enum NumaPolicy {
46    /// `MPOL_BIND` — pages must come from the listed nodes; if no
47    /// node has free memory, allocation fails. Maximum strictness.
48    Bind(NodeSet),
49    /// `MPOL_PREFERRED` — a soft hint; falls back to other nodes
50    /// under memory pressure.
51    Preferred(u32),
52    /// `MPOL_INTERLEAVE` — round-robin pages across the listed nodes.
53    /// Bandwidth-bound workloads benefit; latency-bound ones suffer.
54    Interleaved(NodeSet),
55}
56
/// Compact set of NUMA node IDs, limited to nodes `0..=63`.
///
/// The set is stored as a single 64-bit word (bit `n` set ⇔ node `n` is a
/// member), which is the form handed to `mbind` as its nodemask. Systems
/// with more than 64 nodes would need a dynamic representation; that is not
/// provided here.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct NodeSet {
    // Bit `n` is set when node `n` belongs to the set.
    mask: u64,
}

impl NodeSet {
    /// Number of node ids representable by the single-word mask
    /// (valid ids are `0..MAX_NODES`).
    const MAX_NODES: u32 = 64;

    /// Empty set.
    pub const fn empty() -> Self {
        Self { mask: 0 }
    }

    /// Single-node set. Returns `None` if `node >= 64`.
    pub const fn single(node: u32) -> Option<Self> {
        // Reuse `with` so the range check lives in exactly one place.
        Self::empty().with(node)
    }

    /// Add a node. Returns `None` if `node >= 64`.
    ///
    /// `NodeSet` is `Copy` and this method takes `self` by value, so the
    /// receiver is *not* modified — the enlarged set is only available
    /// through the return value.
    #[must_use = "`with` returns the enlarged set; the receiver is left unchanged"]
    pub const fn with(mut self, node: u32) -> Option<Self> {
        if node >= Self::MAX_NODES {
            return None;
        }
        self.mask |= 1u64 << node;
        Some(self)
    }

    /// Bit-mask view (low 64 nodes).
    #[inline]
    pub const fn mask(&self) -> u64 {
        self.mask
    }

    /// Whether the set is empty.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.mask == 0
    }

    /// Highest node id set, plus one. `0` if the set is empty.
    ///
    /// Informational only — the policy → `mbind` argument mapping uses
    /// [`mbind_maxnode`](Self::mbind_maxnode), not this value. Exposed for
    /// callers that want to know the occupied range.
    #[inline]
    pub const fn max_node_plus_one(&self) -> u32 {
        // `0u64.leading_zeros()` is 64, so the empty set naturally yields 0
        // and no separate branch is needed.
        Self::MAX_NODES - self.mask.leading_zeros()
    }

    /// Width, in bits, of the nodemask word this set is handed to `mbind`
    /// as. Always 64, independent of how many low bits are actually set.
    ///
    /// This is the width of the mask, not necessarily the literal value to
    /// place in the syscall's `maxnode` argument: the Linux kernel's
    /// nodemask copy-in (`get_nodes()` in `mm/mempolicy.c`) has an
    /// off-by-one quirk in how it interprets `maxnode`, which the syscall
    /// call site must account for.
    #[inline]
    pub const fn mbind_maxnode(&self) -> u32 {
        Self::MAX_NODES
    }
}
124
// `mbind` mode numbers, mirroring <linux/mempolicy.h>. They live at module
// scope rather than inside the Linux-only `apply_policy` so that the
// policy → syscall-argument mapping can be unit-tested on any host.

// Prefer one node, fall back to others.
const MPOL_PREFERRED: i32 = 1;
// Restrict allocation to the given nodes.
const MPOL_BIND: i32 = 2;
// Spread pages across the given nodes.
const MPOL_INTERLEAVE: i32 = 3;
131
132/// Pure mapping from a [`NumaPolicy`] to the `mbind` arguments
133/// `(mode, nodemask, maxnode)`.
134///
135/// Returns `Err(AllocError)` for an invalid policy — an empty `Bind` /
136/// `Interleaved` node set, or a `Preferred` node id `>= 64`. Because this is
137/// platform-independent, the nodemask construction (and the rejection of
138/// invalid policies) is exercised by tests on *every* host, not only Linux —
139/// the syscall itself is no-op'd off-Linux, so without this split the bitmask
140/// build would be untested on the CI host.
141fn mbind_args(policy: &NumaPolicy) -> Result<(i32, u64, u32), AllocError> {
142    match policy {
143        NumaPolicy::Bind(s) | NumaPolicy::Interleaved(s) if s.is_empty() => Err(AllocError),
144        NumaPolicy::Bind(s) => Ok((MPOL_BIND, s.mask(), s.mbind_maxnode())),
145        NumaPolicy::Interleaved(s) => Ok((MPOL_INTERLEAVE, s.mask(), s.mbind_maxnode())),
146        NumaPolicy::Preferred(n) => {
147            let s = NodeSet::single(*n).ok_or(AllocError)?;
148            Ok((MPOL_PREFERRED, s.mask(), s.mbind_maxnode()))
149        }
150    }
151}
152
153/// NumaLocal wrapper.
154pub struct NumaLocal<I: OsBacked> {
155    inner: I,
156    policy: NumaPolicy,
157}
158
159impl<I: OsBacked> NumaLocal<I> {
160    /// Wrap and apply `policy` to the inner allocator's region.
161    ///
162    /// Returns `Err(AllocError)` if the platform supports NUMA and the
163    /// kernel rejects the bind (insufficient capability, invalid node
164    /// id, no memory available on the bound nodes). On unsupported
165    /// platforms (macOS, Windows, other) returns `Ok` without binding
166    /// — caller can inspect with [`policy`](Self::policy) but the
167    /// region's physical placement is the kernel's default.
168    pub fn new(inner: I, policy: NumaPolicy) -> Result<Self, AllocError> {
169        // Validate the policy on EVERY platform (not just Linux) so an empty
170        // Bind/Interleaved set or an out-of-range `Preferred` node is rejected
171        // uniformly — previously `Preferred(huge)` was accepted off-Linux
172        // because the only range check lived inside the Linux syscall path.
173        let args = mbind_args(&policy)?;
174        apply_policy(&inner, args)?;
175        Ok(Self { inner, policy })
176    }
177
178    /// Borrow the inner allocator.
179    #[inline]
180    pub fn inner(&self) -> &I {
181        &self.inner
182    }
183
184    /// Active policy.
185    #[inline]
186    pub fn policy(&self) -> &NumaPolicy {
187        &self.policy
188    }
189}
190
191unsafe impl<I: OsBacked> Deallocator for NumaLocal<I> {
192    #[inline]
193    unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: NonZeroLayout) {
194        // SAFETY: forwarded.
195        unsafe { self.inner.deallocate(ptr, layout) }
196    }
197}
198
199unsafe impl<I: OsBacked> Allocator for NumaLocal<I> {
200    #[inline]
201    fn allocate(&self, layout: NonZeroLayout) -> Result<NonNull<[u8]>, AllocError> {
202        self.inner.allocate(layout)
203    }
204
205    #[inline]
206    unsafe fn usable_size(&self, ptr: NonNull<u8>, layout: NonZeroLayout) -> Option<usize> {
207        // Layout-transparent forwarder: `allocate` returns the inner's block
208        // unchanged, so forward `usable_size` too — otherwise an outer scrub
209        // wrapper (`PoisonOnFree`/`ZeroizeOnFree`) over `NumaLocal` would see
210        // `None`, fall back to `layout.size()`, and leave the slack tail
211        // un-scrubbed.
212        // SAFETY: forwarded; caller upholds usable_size's contract on inner.
213        unsafe { self.inner.usable_size(ptr, layout) }
214    }
215
216    #[inline]
217    fn capacity_bytes(&self) -> Option<usize> {
218        self.inner.capacity_bytes()
219    }
220
221    #[inline]
222    fn corruption_events(&self) -> u64 {
223        self.inner.corruption_events()
224    }
225}
226
227unsafe impl<I: OsBacked> OsBacked for NumaLocal<I> {
228    #[inline]
229    fn base_ptr(&self) -> NonNull<u8> {
230        self.inner.base_ptr()
231    }
232
233    #[inline]
234    fn region_size(&self) -> usize {
235        self.inner.region_size()
236    }
237
238    #[inline]
239    unsafe fn release_pages(&self, ptr: NonNull<u8>, size: usize) {
240        // SAFETY: forwarded; caller's contract preserved.
241        unsafe { self.inner.release_pages(ptr, size) }
242    }
243
244    #[inline]
245    unsafe fn protect(&self, ptr: NonNull<u8>, size: usize, flags: ProtectFlags) {
246        // SAFETY: forwarded.
247        unsafe { self.inner.protect(ptr, size, flags) }
248    }
249}
250
251impl<I: OsBacked + FixedRange> FixedRange for NumaLocal<I> {
252    #[inline]
253    fn base(&self) -> NonNull<u8> {
254        self.inner.base()
255    }
256
257    #[inline]
258    fn size(&self) -> usize {
259        self.inner.size()
260    }
261
262    /// Pass-through forward so a `commit`-aware consumer reaches the inner
263    /// backing when this wrapper sits over a `lazy_commit` `MmapBacked`.
264    #[inline]
265    fn commit(&self, offset: usize, len: usize) -> Result<(), AllocError> {
266        self.inner.commit(offset, len)
267    }
268}
269
270// ============================================================================
271// Platform glue: apply_policy()
272// ============================================================================
273
274#[cfg(target_os = "linux")]
275fn apply_policy<I: OsBacked>(inner: &I, args: (i32, u64, u32)) -> Result<(), AllocError> {
276    let (mode, mask, maxnode) = args;
277    let base = inner.base_ptr().as_ptr() as *mut libc::c_void;
278    let size = inner.region_size();
279    // mbind's `nodemask` is an array of unsigned longs (bitmap).
280    // For up to 64 nodes a single u64 suffices.
281    let nodemask: u64 = mask;
282    // SAFETY: the FFI signature for SYS_mbind matches the kernel's
283    // ABI: (unsigned long start, unsigned long len, unsigned long mode,
284    // const unsigned long *nodemask, unsigned long maxnode, unsigned flags).
285    // `mode` is passed as `c_ulong` to match the kernel's `unsigned long mode`
286    // (values 1–3, but the width must match the ABI, not just the value).
287    let rc = unsafe {
288        libc::syscall(
289            libc::SYS_mbind,
290            base,
291            size as libc::c_ulong,
292            mode as libc::c_ulong,
293            &nodemask as *const u64,
294            // mbind's `maxnode` is the nodemask width in bits.
295            maxnode as libc::c_ulong,
296            0u32 as libc::c_uint,
297        )
298    };
299    if rc != 0 {
300        // Capture errno into the cross-crate thread-local slot so callers
301        // reading `crate::backing::mmap_last_os_error()` after a failing
302        // `NumaLocal::new(...)` see the actual mbind errno (EINVAL for a
303        // bad node set, EPERM for missing CAP_SYS_NICE, ESRCH for an
304        // off-line node, …) rather than `None` or stale state.
305        crate::backing::mmap_record_os_error();
306        return Err(AllocError);
307    }
308    Ok(())
309}
310
/// Fallback for every non-Linux target: accept the policy without
/// enforcing it.
#[cfg(not(target_os = "linux"))]
fn apply_policy<I: OsBacked>(inner: &I, args: (i32, u64, u32)) -> Result<(), AllocError> {
    // There is no operation here (macOS, Windows, BSD, other Unix) that
    // re-places an already-mapped region, so both inputs are deliberately
    // unused. Succeeding keeps the wrapper usable as a marker and as a
    // future extension point. Invalid policies never get this far: `new`
    // rejects them through `mbind_args` before calling in.
    let _ = (inner, args);
    Ok(())
}
320
321/// Best-effort detect of the calling thread's NUMA node.
322///
323/// - **Linux**: uses `sched_getcpu()` and `/sys/devices/system/node/...`
324///   to map CPU → node. Returns `None` on lookup failure or non-NUMA
325///   systems (single-node WSL, containers without sysfs).
326/// - **Other**: returns `None` — supply node IDs explicitly via the
327///   `NumaPolicy` constructor instead.
328#[cfg(target_os = "linux")]
329#[must_use]
330pub fn current_numa_node() -> Option<u32> {
331    // Use the getcpu(2) syscall directly. Signature: (cpu, node,
332    // tcache). We only need the node out-pointer.
333    let mut node: libc::c_uint = 0;
334    // SAFETY: getcpu writes through the supplied non-null out-pointer
335    // and returns 0 on success / -1 on failure (errno set). We pass
336    // null for cpu and tcache — both are documented as optional.
337    let rc = unsafe {
338        libc::syscall(
339            libc::SYS_getcpu,
340            core::ptr::null_mut::<libc::c_uint>(),
341            &mut node as *mut libc::c_uint,
342            core::ptr::null_mut::<libc::c_void>(),
343        )
344    };
345    if rc != 0 {
346        None
347    } else {
348        Some(node as u32)
349    }
350}
351
/// Non-Linux stand-in for NUMA node detection.
///
/// No lookup is attempted on these platforms, so the answer is always
/// `None`. Callers should name node IDs explicitly when building a
/// [`NumaPolicy`].
#[cfg(not(target_os = "linux"))]
#[must_use]
pub fn current_numa_node() -> Option<u32> {
    Option::None
}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backing::MmapBacked;

    // ---- NodeSet construction --------------------------------------------

    #[test]
    fn nodeset_single() {
        let only_three = NodeSet::single(3).unwrap();
        assert_eq!(only_three.mask(), 0b1000);
        assert_eq!(only_three.max_node_plus_one(), 4);
    }

    #[test]
    fn nodeset_with() {
        let zero = NodeSet::single(0).unwrap();
        let zero_two = zero.with(2).unwrap();
        let zero_two_five = zero_two.with(5).unwrap();
        assert_eq!(zero_two_five.mask(), 0b100101);
        assert_eq!(zero_two_five.max_node_plus_one(), 6);
    }

    #[test]
    fn nodeset_rejects_overflow() {
        assert!(NodeSet::single(64).is_none());
        assert!(NodeSet::single(100).is_none());
        let nothing = NodeSet::empty();
        assert!(nothing.with(64).is_none());
    }

    // ---- NumaLocal::new against a real mapping ---------------------------

    #[test]
    #[cfg_attr(miri, ignore = "miri-incompatible: mmap / threads")]
    fn empty_bind_rejected() {
        let backing = MmapBacked::new(64 * 1024).unwrap();
        let policy = NumaPolicy::Bind(NodeSet::empty());
        let outcome = NumaLocal::new(backing, policy);
        assert!(outcome.is_err());
    }

    /// Binding to node 0 is expected to work everywhere: on Linux
    /// (including WSL / single-node machines) it goes through mbind, and on
    /// macOS/Windows the platform glue is a no-op that also succeeds.
    #[test]
    #[cfg_attr(miri, ignore = "miri-incompatible: mmap / threads")]
    fn bind_to_node_zero_succeeds() {
        let backing = MmapBacked::new(64 * 1024).unwrap();
        let node_zero = NodeSet::single(0).unwrap();
        let outcome = NumaLocal::new(backing, NumaPolicy::Bind(node_zero));
        assert!(
            outcome.is_ok(),
            "expected mbind(MPOL_BIND, [0]) to succeed on any host"
        );
    }

    #[test]
    #[cfg_attr(miri, ignore = "miri-incompatible: mmap / threads")]
    fn interleaved_succeeds() {
        let backing = MmapBacked::new(64 * 1024).unwrap();
        let node_zero = NodeSet::single(0).unwrap();
        let outcome = NumaLocal::new(backing, NumaPolicy::Interleaved(node_zero));
        assert!(outcome.is_ok());
    }

    #[test]
    #[cfg_attr(miri, ignore = "miri-incompatible: mmap / threads")]
    fn preferred_succeeds() {
        let backing = MmapBacked::new(64 * 1024).unwrap();
        let outcome = NumaLocal::new(backing, NumaPolicy::Preferred(0));
        assert!(outcome.is_ok());
    }

    // ---- Policy → mbind-args mapping -------------------------------------
    //
    // These call `mbind_args` directly and need no mapping or syscall, so
    // they run on every platform — including hosts where the syscall path
    // is compiled down to a no-op. That keeps the mode / nodemask / maxnode
    // construction covered off-Linux.

    #[test]
    fn mbind_args_bind_builds_nodemask() {
        let nodes = NodeSet::single(0).unwrap().with(3).unwrap();
        let triple = mbind_args(&NumaPolicy::Bind(nodes)).unwrap();
        assert_eq!(triple.0, MPOL_BIND);
        assert_eq!(triple.1, 0b1001);
        assert_eq!(triple.2, 64);
    }

    #[test]
    fn mbind_args_interleaved_builds_nodemask() {
        let nodes = NodeSet::single(1).unwrap();
        let triple = mbind_args(&NumaPolicy::Interleaved(nodes)).unwrap();
        assert_eq!(triple.0, MPOL_INTERLEAVE);
        assert_eq!(triple.1, 0b10);
        assert_eq!(triple.2, 64);
    }

    #[test]
    fn mbind_args_preferred_single_node() {
        let triple = mbind_args(&NumaPolicy::Preferred(2)).unwrap();
        assert_eq!(triple.0, MPOL_PREFERRED);
        assert_eq!(triple.1, 0b100);
        assert_eq!(triple.2, 64);
    }

    #[test]
    fn mbind_args_rejects_empty_and_out_of_range() {
        let empty_bind = NumaPolicy::Bind(NodeSet::empty());
        let empty_interleave = NumaPolicy::Interleaved(NodeSet::empty());
        assert!(mbind_args(&empty_bind).is_err());
        assert!(mbind_args(&empty_interleave).is_err());
        assert!(mbind_args(&NumaPolicy::Preferred(64)).is_err());
        assert!(mbind_args(&NumaPolicy::Preferred(9999)).is_err());
    }

    /// An out-of-range `Preferred` node has to be refused by `new` on every
    /// platform, not only where a syscall would have caught it.
    #[test]
    #[cfg_attr(miri, ignore = "miri-incompatible: mmap / threads")]
    fn preferred_out_of_range_node_rejected_uniformly() {
        let backing = MmapBacked::new(64 * 1024).unwrap();
        let outcome = NumaLocal::new(backing, NumaPolicy::Preferred(9999));
        assert!(outcome.is_err(), "out-of-range Preferred node must be rejected");
    }

    #[test]
    fn current_numa_node_returns_some_or_none() {
        // Smoke test only: the call must not panic on any supported
        // platform. The value itself depends on the host.
        let _ = current_numa_node();
    }
}