use core::alloc::Layout;
use core::ptr::NonNull;
use std::alloc::{alloc, alloc_zeroed, dealloc};
/// Coarse model of the machine's NUMA layout.
///
/// The model assumes CPUs are numbered contiguously and split evenly across
/// nodes: node `k` owns CPU ids `k * cpus_per_node .. (k + 1) * cpus_per_node`.
#[derive(Debug, Clone)]
pub struct NumaTopology {
    /// Number of NUMA nodes (always at least 1 when produced by [`NumaTopology::detect`]).
    pub num_nodes: usize,
    /// CPUs attributed to each node (`total_cpus / num_nodes`, rounded down).
    pub cpus_per_node: usize,
    /// Parallelism available to the process as reported by the standard library.
    pub total_cpus: usize,
}

impl NumaTopology {
    /// Builds a topology for the current machine.
    ///
    /// The CPU count comes from [`std::thread::available_parallelism`] (falling
    /// back to 1). On Linux the node count is read from sysfs; on any other
    /// platform, or if sysfs cannot be read, a single node is reported.
    pub fn detect() -> Self {
        let total_cpus = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);

        #[cfg(target_os = "linux")]
        {
            if let Ok(num_nodes) = Self::detect_linux_numa_nodes() {
                return NumaTopology {
                    num_nodes,
                    // `max(1)` keeps the division well-defined even if the
                    // detector were ever to report zero nodes.
                    cpus_per_node: total_cpus / num_nodes.max(1),
                    total_cpus,
                };
            }
        }

        NumaTopology {
            num_nodes: 1,
            cpus_per_node: total_cpus,
            total_cpus,
        }
    }

    /// Returns `true` when more than one node is present.
    #[inline]
    pub fn is_numa_system(&self) -> bool {
        self.num_nodes > 1
    }

    /// Maps a CPU id to the node that owns it under the even-split model.
    ///
    /// Always returns 0 on single-node systems. CPU ids past the last full
    /// block wrap around modulo `num_nodes`.
    #[inline]
    pub fn cpu_to_node(&self, cpu_id: usize) -> usize {
        if self.num_nodes <= 1 {
            0
        } else {
            // `max(1)` guards against a zero `cpus_per_node` (more nodes than CPUs).
            (cpu_id / self.cpus_per_node.max(1)) % self.num_nodes
        }
    }

    /// Returns the half-open CPU id range `(start, end)` attributed to `node_id`.
    ///
    /// Both bounds are clamped to `total_cpus`, so a node beyond the last CPU
    /// yields an empty range (`start == end`) instead of an inverted one, and
    /// the arithmetic cannot overflow for arbitrarily large `node_id`.
    pub fn node_cpu_range(&self, node_id: usize) -> (usize, usize) {
        let start = node_id
            .saturating_mul(self.cpus_per_node)
            .min(self.total_cpus);
        let end = node_id
            .saturating_add(1)
            .saturating_mul(self.cpus_per_node)
            .min(self.total_cpus);
        (start, end)
    }

    /// Counts the `node<N>` entries under `/sys/devices/system/node`.
    ///
    /// Returns `Ok(1)` when the directory does not exist and never reports
    /// fewer than one node.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors raised while listing the directory.
    #[cfg(target_os = "linux")]
    fn detect_linux_numa_nodes() -> Result<usize, std::io::Error> {
        use std::fs;

        let node_path = std::path::Path::new("/sys/devices/system/node");
        if !node_path.exists() {
            return Ok(1);
        }

        let mut count = 0usize;
        for entry in fs::read_dir(node_path)? {
            let name = entry?.file_name();
            // Only `node` followed by decimal digits names a node directory;
            // anything else that merely shares the prefix is ignored.
            let is_node_dir = name
                .to_str()
                .and_then(|s| s.strip_prefix("node"))
                .map_or(false, |suffix| {
                    !suffix.is_empty() && suffix.bytes().all(|b| b.is_ascii_digit())
                });
            if is_node_dir {
                count += 1;
            }
        }
        Ok(count.max(1))
    }
}

impl Default for NumaTopology {
    /// Equivalent to [`NumaTopology::detect`].
    fn default() -> Self {
        Self::detect()
    }
}
/// One contiguous slice of a workload together with the node it is assigned to.
///
/// As produced by [`numa_distribute_work`], the slice is the half-open range
/// `range_start..range_end`.
#[derive(Debug, Clone, Copy)]
pub struct NumaWorkHint {
/// Index of the NUMA node this slice is assigned to.
pub preferred_node: usize,
/// First work-item index in the slice (inclusive).
pub range_start: usize,
/// One past the last work-item index in the slice (exclusive).
pub range_end: usize,
}
/// Splits `total_size` work items into one contiguous range per NUMA node.
///
/// With zero or one node the whole range `0..total_size` is assigned to node 0.
/// Otherwise every node receives `total_size / num_nodes` items and the first
/// `total_size % num_nodes` nodes receive one extra, so the returned ranges are
/// adjacent, ordered by node index, and together cover `0..total_size`.
pub fn numa_distribute_work(total_size: usize, topology: &NumaTopology) -> Vec<NumaWorkHint> {
    let node_count = topology.num_nodes;
    if node_count <= 1 {
        return vec![NumaWorkHint {
            preferred_node: 0,
            range_start: 0,
            range_end: total_size,
        }];
    }

    let per_node = total_size / node_count;
    let extra = total_size % node_count;

    // `cursor` always holds the end of the previously emitted range.
    let mut cursor = 0usize;
    (0..node_count)
        .map(|node| {
            let len = if node < extra { per_node + 1 } else { per_node };
            let range_start = cursor;
            cursor += len;
            NumaWorkHint {
                preferred_node: node,
                range_start,
                range_end: cursor,
            }
        })
        .collect()
}
/// Memory placement policy requested for an allocation.
///
/// On Linux the non-default variants are translated into an `mbind` memory
/// policy by the allocation functions in this file; on other platforms no
/// policy is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumaInterleavingStrategy {
    /// Apply no explicit policy.
    FirstTouch,
    /// Request the interleave policy (`MPOL_INTERLEAVE`).
    Interleave,
    /// Request the preferred-node policy (`MPOL_PREFERRED`) for the given node index.
    PreferNode(usize),
    /// Request the bind policy (`MPOL_BIND`) for the given node index.
    BindNode(usize),
}

/// Options describing how an allocation should be placed and initialised.
///
/// The type is a small plain-data value, so it is `Copy` and comparable; this
/// lets one hint be reused for several allocations and printed in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NumaAllocHint {
    /// Placement policy to request.
    pub strategy: NumaInterleavingStrategy,
    /// Page size associated with the hint. Every constructor in this file sets
    /// it to 0 and no code in this file reads it.
    pub page_size: usize,
    /// When `true`, the allocation's pages are touched right after allocation.
    pub prefault: bool,
}

impl Default for NumaAllocHint {
    /// First-touch placement, `page_size` 0, no prefaulting.
    fn default() -> Self {
        NumaAllocHint {
            strategy: NumaInterleavingStrategy::FirstTouch,
            page_size: 0,
            prefault: false,
        }
    }
}

impl NumaAllocHint {
    /// Default hint: [`NumaInterleavingStrategy::FirstTouch`], no prefaulting.
    pub fn first_touch() -> Self {
        Self::default()
    }

    /// Hint requesting [`NumaInterleavingStrategy::Interleave`].
    pub fn interleaved() -> Self {
        NumaAllocHint {
            strategy: NumaInterleavingStrategy::Interleave,
            ..Self::default()
        }
    }

    /// Hint requesting [`NumaInterleavingStrategy::PreferNode`] for `node`.
    pub fn on_node(node: usize) -> Self {
        NumaAllocHint {
            strategy: NumaInterleavingStrategy::PreferNode(node),
            ..Self::default()
        }
    }

    /// Hint requesting [`NumaInterleavingStrategy::BindNode`] for `node`.
    pub fn bind_node(node: usize) -> Self {
        NumaAllocHint {
            strategy: NumaInterleavingStrategy::BindNode(node),
            ..Self::default()
        }
    }

    /// Returns the same hint with `prefault` switched on.
    pub fn with_prefault(mut self) -> Self {
        self.prefault = true;
        self
    }
}
/// Allocates uninitialised memory for `layout` and applies the placement `hint`.
///
/// The memory comes from the global allocator. On Linux the requested policy
/// is then applied to the address range on a best-effort basis, and if
/// `hint.prefault` is set the pages are touched before returning.
///
/// Returns `None` when the allocator fails or when `layout` has size zero
/// (the global allocator must never be asked for a zero-sized block).
///
/// # Safety
///
/// The zero-size precondition of [`std::alloc::alloc`] is checked here, so
/// calling this function has no additional requirements; it stays `unsafe`
/// to keep its signature. The returned memory is uninitialised and must be
/// released with [`std::alloc::dealloc`] using the same `layout`.
pub unsafe fn numa_alloc(layout: Layout, hint: &NumaAllocHint) -> Option<NonNull<u8>> {
    if layout.size() == 0 {
        return None;
    }

    // SAFETY: `layout` has a non-zero size, which is `alloc`'s only precondition.
    let block = NonNull::new(unsafe { alloc(layout) })?;

    #[cfg(target_os = "linux")]
    {
        apply_linux_numa_policy(block.as_ptr(), layout.size(), hint);
    }

    // Touch the pages only after the policy is in place so that they are
    // faulted in under that policy.
    if hint.prefault {
        prefault_pages(block.as_ptr(), layout.size());
    }

    Some(block)
}
pub unsafe fn numa_alloc_zeroed(layout: Layout, hint: &NumaAllocHint) -> Option<NonNull<u8>> {
let ptr = alloc_zeroed(layout);
if ptr.is_null() {
return None;
}
#[cfg(target_os = "linux")]
{
apply_linux_numa_policy(ptr, layout.size(), hint);
}
#[cfg(not(target_os = "linux"))]
let _ = hint;
NonNull::new(ptr)
}
/// Forces the pages backing `ptr..ptr + size` to be faulted in by writing to them.
///
/// A zero byte is written at every 4096-byte stride starting at `ptr`, and
/// additionally at the very last byte. The extra write matters because `ptr`
/// need not be page-aligned: the page holding the tail of the buffer can start
/// between two strides and would otherwise never be touched.
///
/// The written bytes are overwritten with 0, so this must only be used on
/// memory whose contents are not yet meaningful (or are already zero).
/// Callers must pass a pointer that is valid for writes of `size` bytes.
/// A `size` of 0 does nothing.
fn prefault_pages(ptr: *mut u8, size: usize) {
    // Stride used between touches; a larger real page size only means some
    // pages are touched more than once.
    const PAGE_SIZE: usize = 4096;

    if size == 0 {
        return;
    }

    for offset in (0..size).step_by(PAGE_SIZE) {
        // SAFETY: `offset < size`, and the caller guarantees `size` writable bytes.
        unsafe { core::ptr::write_volatile(ptr.add(offset), 0) };
    }

    // SAFETY: `size - 1 < size`, and the caller guarantees `size` writable bytes.
    unsafe { core::ptr::write_volatile(ptr.add(size - 1), 0) };
}
/// Best-effort application of the hint's memory policy to `ptr..ptr + size` via `mbind(2)`.
///
/// `FirstTouch` applies nothing. The other strategies build a one-word node
/// mask: all bits set for `Interleave`, a single bit for `PreferNode` /
/// `BindNode`. Node indices that do not fit in the one-word mask are ignored.
///
/// The syscall result is deliberately discarded: placement is an optimisation
/// and the allocation stays usable whether or not the kernel accepts the
/// request.
///
/// NOTE(review): `mbind` expects a page-aligned start address, so for blocks
/// that are not page-aligned the kernel is expected to reject the call and no
/// policy is applied. Confirm whether callers should request page-aligned
/// layouts when placement matters.
#[cfg(target_os = "linux")]
fn apply_linux_numa_policy(ptr: *mut u8, size: usize, hint: &NumaAllocHint) {
    use std::os::raw::{c_int, c_ulong};

    // Policy modes from the kernel's mempolicy interface.
    const MPOL_PREFERRED: c_int = 1;
    const MPOL_BIND: c_int = 2;
    const MPOL_INTERLEAVE: c_int = 3;

    // Mask with only `node`'s bit set, or `None` if the bit does not fit in one word.
    let single_node_mask = |node: usize| -> Option<c_ulong> {
        if node < c_ulong::BITS as usize {
            Some((1 as c_ulong) << node)
        } else {
            None
        }
    };

    let (mode, node_mask): (c_int, c_ulong) = match hint.strategy {
        NumaInterleavingStrategy::FirstTouch => return,
        NumaInterleavingStrategy::Interleave => (MPOL_INTERLEAVE, !0),
        NumaInterleavingStrategy::PreferNode(node) => match single_node_mask(node) {
            Some(mask) => (MPOL_PREFERRED, mask),
            None => return,
        },
        NumaInterleavingStrategy::BindNode(node) => match single_node_mask(node) {
            Some(mask) => (MPOL_BIND, mask),
            None => return,
        },
    };

    // The kernel only looks at `maxnode - 1` bits of the mask, so one more
    // than the mask width must be passed for the highest bit to be seen.
    let max_node = c_ulong::BITS as c_ulong + 1;

    // SAFETY: the mask pointer refers to a live local for the duration of the
    // call. The kernel validates the address range itself and reports bad
    // ranges through the (ignored) return value.
    unsafe {
        // `libc::SYS_mbind` is the per-architecture syscall number. Every
        // integer argument is widened to register width because `syscall`
        // is variadic.
        let _ = libc::syscall(
            libc::SYS_mbind,
            ptr as *mut libc::c_void,
            size as c_ulong,
            mode as c_ulong,
            &node_mask as *const c_ulong,
            max_node,
            0 as c_ulong,
        );
    }
}
/// Returns the system page size in bytes.
///
/// On Unix the value is queried with `sysconf(_SC_PAGESIZE)`. If that call
/// reports an error (a non-positive result), or on non-Unix targets, 4096 is
/// returned.
pub fn get_page_size() -> usize {
    const FALLBACK_PAGE_SIZE: usize = 4096;

    #[cfg(unix)]
    {
        // SAFETY: `sysconf` only reads a configuration value and has no
        // memory-safety preconditions.
        let raw = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        // `sysconf` signals failure with -1; casting that to `usize` would
        // yield a nonsensical huge page size.
        if raw > 0 {
            raw as usize
        } else {
            FALLBACK_PAGE_SIZE
        }
    }
    #[cfg(not(unix))]
    {
        FALLBACK_PAGE_SIZE
    }
}
/// Returns the huge page size in bytes, if it can be determined.
///
/// On Linux this reads `/proc/meminfo`, takes the first `Hugepagesize:` line
/// whose value parses as an integer, and converts it from KiB to bytes.
/// Returns `None` if the file cannot be read, no such line is present, the
/// value does not parse, or the byte count would overflow `usize`. On other
/// platforms it always returns `None`.
pub fn get_huge_page_size() -> Option<usize> {
    #[cfg(target_os = "linux")]
    {
        let meminfo = std::fs::read_to_string("/proc/meminfo").ok()?;
        meminfo
            .lines()
            .filter_map(|line| line.strip_prefix("Hugepagesize:"))
            // The first whitespace-separated token after the label is the value in KiB.
            .filter_map(|rest| rest.split_whitespace().next()?.parse::<usize>().ok())
            .find_map(|kib| kib.checked_mul(1024))
    }
    #[cfg(not(target_os = "linux"))]
    {
        None
    }
}
/// Typed allocator for arrays of `T` that forwards a fixed placement hint to
/// [`numa_alloc`] / [`numa_alloc_zeroed`].
pub struct NumaAllocator<T> {
/// Placement hint passed along with every allocation made through this allocator.
hint: NumaAllocHint,
/// Ties the allocator to element type `T` without storing a value of it.
_marker: core::marker::PhantomData<T>,
}
impl<T> NumaAllocator<T> {
    /// Layout of `count` elements of `T`, or `None` if it overflows or is zero-sized.
    ///
    /// Zero-sized layouts (either `count == 0` or a zero-sized `T`) must never
    /// reach the global allocator, so they are filtered out here for both the
    /// allocation and the deallocation paths.
    fn array_layout(count: usize) -> Option<Layout> {
        Layout::array::<T>(count)
            .ok()
            .filter(|layout| layout.size() != 0)
    }

    /// Creates an allocator that uses `hint` for every allocation.
    #[inline]
    pub fn new(hint: NumaAllocHint) -> Self {
        NumaAllocator {
            hint,
            _marker: core::marker::PhantomData,
        }
    }

    /// Allocator using [`NumaAllocHint::on_node`] for `node`.
    #[inline]
    pub fn on_node(node: usize) -> Self {
        Self::new(NumaAllocHint::on_node(node))
    }

    /// Allocator using [`NumaAllocHint::interleaved`].
    #[inline]
    pub fn interleaved() -> Self {
        Self::new(NumaAllocHint::interleaved())
    }

    /// Allocator using [`NumaAllocHint::first_touch`].
    #[inline]
    pub fn first_touch() -> Self {
        Self::new(NumaAllocHint::first_touch())
    }

    /// Allocates uninitialised storage for `count` values of `T`.
    ///
    /// Returns `None` if `count` is 0, `T` is zero-sized, the total size
    /// overflows, or the underlying allocation fails.
    pub fn allocate(&self, count: usize) -> Option<NonNull<T>> {
        let layout = Self::array_layout(count)?;
        // SAFETY: `array_layout` only yields layouts with a non-zero size.
        let raw = unsafe { numa_alloc(layout, &self.hint) }?;
        Some(raw.cast::<T>())
    }

    /// Allocates zero-initialised storage for `count` values of `T`.
    ///
    /// Returns `None` under the same conditions as [`NumaAllocator::allocate`].
    pub fn allocate_zeroed(&self, count: usize) -> Option<NonNull<T>> {
        let layout = Self::array_layout(count)?;
        // SAFETY: `array_layout` only yields layouts with a non-zero size.
        let raw = unsafe { numa_alloc_zeroed(layout, &self.hint) }?;
        Some(raw.cast::<T>())
    }

    /// Releases storage previously obtained from this allocator.
    ///
    /// Does nothing when `count` describes a layout that `allocate` could
    /// never have returned (zero-sized or overflowing).
    ///
    /// # Safety
    ///
    /// `ptr` must have been returned by [`NumaAllocator::allocate`] or
    /// [`NumaAllocator::allocate_zeroed`] for the same `T` and the same
    /// `count`, and must not have been deallocated already. Any values stored
    /// in the block are not dropped.
    pub unsafe fn deallocate(&self, ptr: NonNull<T>, count: usize) {
        if let Some(layout) = Self::array_layout(count) {
            // SAFETY: per the caller's contract `ptr` was allocated with exactly this layout.
            unsafe { dealloc(ptr.cast::<u8>().as_ptr(), layout) };
        }
    }
}
/// Growable array whose backing storage is obtained through [`numa_alloc`]
/// with a caller-supplied placement hint.
pub struct NumaVec<T> {
/// Start of the backing allocation; a dangling placeholder while `cap == 0`.
ptr: NonNull<T>,
/// Number of initialised elements.
len: usize,
/// Number of elements the current allocation can hold (0 means nothing is allocated).
cap: usize,
/// Placement hint reused for every (re)allocation.
hint: NumaAllocHint,
}
// NOTE(review): these impls rely on `NumaVec` being the sole owner of the
// buffer behind `ptr`, so that it can be sent/shared across threads exactly
// when `T` can — confirm no other alias to the buffer is kept elsewhere.
unsafe impl<T: Send> Send for NumaVec<T> {}
unsafe impl<T: Sync> Sync for NumaVec<T> {}
impl<T> NumaVec<T> {
    /// Layout for `cap` elements of `T`.
    ///
    /// Zero-sized element types are rejected: they would produce a zero-size
    /// layout, which must never be handed to the global allocator.
    fn layout_for(cap: usize) -> Result<Layout, &'static str> {
        if core::mem::size_of::<T>() == 0 {
            return Err("zero-sized element types are not supported");
        }
        Layout::array::<T>(cap).map_err(|_| "layout overflow")
    }

    /// Creates an empty vector with the first-touch hint; nothing is allocated.
    pub fn new() -> Self {
        Self::with_hint(NumaAllocHint::first_touch())
    }

    /// Creates an empty vector that will use `hint` for its allocations;
    /// nothing is allocated yet.
    pub fn with_hint(hint: NumaAllocHint) -> Self {
        NumaVec {
            ptr: NonNull::dangling(),
            len: 0,
            cap: 0,
            hint,
        }
    }

    /// Creates an empty vector with room for exactly `capacity` elements.
    ///
    /// A `capacity` of 0 allocates nothing.
    ///
    /// # Errors
    ///
    /// Returns an error if `T` is zero-sized (and `capacity > 0`), if the
    /// byte size overflows (`"layout overflow"`), or if the allocation fails
    /// (`"allocation failed"`).
    pub fn with_hint_and_capacity(
        hint: NumaAllocHint,
        capacity: usize,
    ) -> Result<Self, &'static str> {
        if capacity == 0 {
            return Ok(Self::with_hint(hint));
        }
        let layout = Self::layout_for(capacity)?;
        // SAFETY: `capacity > 0` and `T` is not zero-sized, so the layout size is non-zero.
        let raw = unsafe { numa_alloc(layout, &hint) }.ok_or("allocation failed")?;
        Ok(NumaVec {
            ptr: raw.cast::<T>(),
            len: 0,
            cap: capacity,
            hint,
        })
    }

    /// Number of elements currently stored.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when no elements are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of elements that fit without reallocating.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.cap
    }

    /// Raw pointer to the first element (dangling while nothing is allocated).
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Raw mutable pointer to the first element (dangling while nothing is allocated).
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// The stored elements as a shared slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: the first `len` slots are initialised; for `len == 0` the
        // dangling pointer is non-null and aligned, which is all an empty
        // slice requires.
        unsafe { core::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// The stored elements as a mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: as in `as_slice`; `&mut self` guarantees exclusive access.
        unsafe { core::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Appends `value`, growing the allocation if it is full.
    ///
    /// # Errors
    ///
    /// Returns an error (and drops `value`) if growing is necessary and fails.
    pub fn push(&mut self, value: T) -> Result<(), &'static str> {
        if self.len == self.cap {
            self.grow()?;
        }
        // SAFETY: after a successful `grow`, `len < cap`, so the slot lies
        // inside the allocation and is currently uninitialised.
        unsafe { core::ptr::write(self.ptr.as_ptr().add(self.len), value) };
        self.len += 1;
        Ok(())
    }

    /// Removes and returns the last element, or `None` if the vector is empty.
    pub fn pop(&mut self) -> Option<T> {
        if self.len == 0 {
            return None;
        }
        self.len -= 1;
        // SAFETY: the slot at the old `len - 1` is initialised; decrementing
        // `len` first hands ownership of that value to the caller.
        Some(unsafe { core::ptr::read(self.ptr.as_ptr().add(self.len)) })
    }

    /// Ensures there is room for at least `additional` more elements.
    ///
    /// When growth is needed the allocation is resized to exactly
    /// `len + additional` elements.
    ///
    /// # Errors
    ///
    /// Returns `"capacity overflow"` if `len + additional` overflows, or the
    /// error from the reallocation.
    pub fn reserve(&mut self, additional: usize) -> Result<(), &'static str> {
        let required = self
            .len
            .checked_add(additional)
            .ok_or("capacity overflow")?;
        if required <= self.cap {
            return Ok(());
        }
        self.realloc(required)
    }

    /// Grows to 4 elements from empty, otherwise doubles the capacity.
    fn grow(&mut self) -> Result<(), &'static str> {
        let new_cap = if self.cap == 0 {
            4
        } else {
            self.cap.checked_mul(2).ok_or("capacity overflow")?
        };
        self.realloc(new_cap)
    }

    /// Moves the contents into a fresh allocation of `new_cap` elements.
    ///
    /// Callers pass a `new_cap` that is greater than the current capacity
    /// (and therefore non-zero and at least `len`). On failure the vector is
    /// left untouched.
    fn realloc(&mut self, new_cap: usize) -> Result<(), &'static str> {
        let new_layout = Self::layout_for(new_cap)?;
        // SAFETY: `new_cap > 0` and `T` is not zero-sized, so the layout size is non-zero.
        let new_raw = unsafe { numa_alloc(new_layout, &self.hint) }.ok_or("allocation failed")?;
        let new_ptr = new_raw.cast::<T>();

        if self.cap > 0 {
            // SAFETY: both buffers hold at least `len` elements and are
            // distinct allocations, so the ranges cannot overlap.
            unsafe {
                core::ptr::copy_nonoverlapping(self.ptr.as_ptr(), new_ptr.as_ptr(), self.len);
            }
            if let Ok(old_layout) = Layout::array::<T>(self.cap) {
                // SAFETY: the old buffer was allocated with exactly this layout.
                unsafe { dealloc(self.ptr.cast::<u8>().as_ptr(), old_layout) };
            }
        }

        self.ptr = new_ptr;
        self.cap = new_cap;
        Ok(())
    }
}
impl<T> Default for NumaVec<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Drop for NumaVec<T> {
    /// Drops every stored element, then frees the backing allocation.
    ///
    /// A vector that never allocated (`cap == 0`) has nothing to release.
    fn drop(&mut self) {
        if self.cap != 0 {
            let base = self.ptr.as_ptr();

            // Run the destructor of each initialised slot, front to back.
            (0..self.len).for_each(|index| unsafe {
                core::ptr::drop_in_place(base.add(index));
            });

            // Free the buffer with the same layout it was allocated with.
            if let Ok(layout) = Layout::array::<T>(self.cap) {
                unsafe { dealloc(base.cast::<u8>(), layout) };
            }
        }
    }
}
impl<T> core::ops::Deref for NumaVec<T> {
type Target = [T];
#[inline]
fn deref(&self) -> &[T] {
self.as_slice()
}
}
impl<T> core::ops::DerefMut for NumaVec<T> {
#[inline]
fn deref_mut(&mut self) -> &mut [T] {
self.as_mut_slice()
}
}
/// Dense matrix stored in a single [`NumaVec`]; element `(row, col)` lives at
/// index `row * cols + col`.
pub struct MatNuma<T> {
/// Flat element storage.
data: NumaVec<T>,
/// Number of rows.
rows: usize,
/// Number of columns.
cols: usize,
}
impl<T: Copy + Default> MatNuma<T> {
    /// Creates a `rows` x `cols` matrix with every element set to `T::default()`.
    ///
    /// The storage is allocated once, up front, using `hint`.
    ///
    /// # Errors
    ///
    /// Returns `"dimension overflow"` if `rows * cols` overflows, the error
    /// from the allocation if it fails, or `"push failed"` if an element
    /// cannot be appended.
    pub fn zeros(rows: usize, cols: usize, hint: NumaAllocHint) -> Result<Self, &'static str> {
        let element_count = match rows.checked_mul(cols) {
            Some(count) => count,
            None => return Err("dimension overflow"),
        };

        let mut data = NumaVec::with_hint_and_capacity(hint, element_count)?;
        (0..element_count)
            .try_for_each(|_| data.push(T::default()).map_err(|_| "push failed"))?;

        Ok(MatNuma { data, rows, cols })
    }
}
impl<T> MatNuma<T> {
    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.cols
    }

    /// Total number of elements (`rows * cols`).
    #[inline]
    pub fn len(&self) -> usize {
        self.rows * self.cols
    }

    /// Returns `true` when the matrix has no rows or no columns.
    #[inline]
    pub fn is_empty(&self) -> bool {
        // The smaller dimension is zero exactly when at least one of them is.
        self.rows.min(self.cols) == 0
    }

    /// Shared reference to the element at (`row`, `col`), or `None` if either
    /// index is out of bounds.
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        if row < self.rows && col < self.cols {
            let flat = row * self.cols + col;
            self.data.as_slice().get(flat)
        } else {
            None
        }
    }

    /// Mutable reference to the element at (`row`, `col`), or `None` if either
    /// index is out of bounds.
    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
        if row < self.rows && col < self.cols {
            let flat = row * self.cols + col;
            self.data.as_mut_slice().get_mut(flat)
        } else {
            None
        }
    }

    /// All elements as one flat slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }

    /// All elements as one flat mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.data.as_mut_slice()
    }

    /// The elements of row `row` as a slice, or `None` if `row` is out of bounds.
    pub fn row(&self, row: usize) -> Option<&[T]> {
        if row < self.rows {
            let first = row * self.cols;
            let past_last = first + self.cols;
            self.data.as_slice().get(first..past_last)
        } else {
            None
        }
    }

    /// The elements of row `row` as a mutable slice, or `None` if `row` is out
    /// of bounds.
    pub fn row_mut(&mut self, row: usize) -> Option<&mut [T]> {
        if row < self.rows {
            let first = row * self.cols;
            let past_last = first + self.cols;
            self.data.as_mut_slice().get_mut(first..past_last)
        } else {
            None
        }
    }
}
// Unit tests for work distribution, allocation hints, the typed allocator,
// `NumaVec` and `MatNuma`.
#[cfg(test)]
mod tests {
use super::*;
// Four nodes: ranges must be contiguous and cover all 1000 items.
#[test]
fn test_numa_distribute_work() {
let topo = NumaTopology {
num_nodes: 4,
cpus_per_node: 8,
total_cpus: 32,
};
let hints = numa_distribute_work(1000, &topo);
assert_eq!(hints.len(), 4);
let total: usize = hints.iter().map(|h| h.range_end - h.range_start).sum();
assert_eq!(total, 1000);
for i in 1..hints.len() {
assert_eq!(hints[i].range_start, hints[i - 1].range_end);
}
}
// Single node: one hint spanning the whole range.
#[test]
fn test_numa_distribute_work_single_node() {
let topo = NumaTopology {
num_nodes: 1,
cpus_per_node: 8,
total_cpus: 8,
};
let hints = numa_distribute_work(500, &topo);
assert_eq!(hints.len(), 1);
assert_eq!(hints[0].range_start, 0);
assert_eq!(hints[0].range_end, 500);
}
// Each builder selects the matching strategy; `with_prefault` sets the flag.
#[test]
fn test_numa_alloc_hint_builders() {
let hint = NumaAllocHint::first_touch();
assert_eq!(hint.strategy, NumaInterleavingStrategy::FirstTouch);
assert!(!hint.prefault);
let hint = NumaAllocHint::interleaved();
assert_eq!(hint.strategy, NumaInterleavingStrategy::Interleave);
let hint = NumaAllocHint::on_node(2);
assert_eq!(hint.strategy, NumaInterleavingStrategy::PreferNode(2));
let hint = NumaAllocHint::bind_node(1).with_prefault();
assert_eq!(hint.strategy, NumaInterleavingStrategy::BindNode(1));
assert!(hint.prefault);
}
// The reported page size is a power of two between 4 KiB and 64 KiB.
#[test]
fn test_get_page_size() {
let page_size = get_page_size();
assert!(page_size.is_power_of_two());
assert!(page_size >= 4096);
assert!(page_size <= 65536);
}
// Allocate, write, read back and free through the typed allocator.
#[test]
fn test_numa_allocator_alloc_dealloc() {
let alloc: NumaAllocator<f64> = NumaAllocator::first_touch();
let count = 64usize;
let ptr = alloc.allocate(count).expect("allocation must succeed");
unsafe {
for i in 0..count {
core::ptr::write(ptr.as_ptr().add(i), i as f64);
}
for i in 0..count {
assert!((core::ptr::read(ptr.as_ptr().add(i)) - i as f64).abs() < f64::EPSILON);
}
alloc.deallocate(ptr, count);
}
}
// Zeroed allocations read back as all zeros.
#[test]
fn test_numa_allocator_zeroed() {
let alloc: NumaAllocator<u64> = NumaAllocator::first_touch();
let count = 32usize;
let ptr = alloc
.allocate_zeroed(count)
.expect("zeroed allocation must succeed");
unsafe {
for i in 0..count {
assert_eq!(core::ptr::read(ptr.as_ptr().add(i)), 0u64);
}
alloc.deallocate(ptr, count);
}
}
// Allocation with a node preference for node 0 succeeds.
#[test]
fn test_numa_allocator_on_node() {
let alloc: NumaAllocator<f32> = NumaAllocator::on_node(0);
let ptr = alloc.allocate(16).expect("allocation must succeed");
unsafe { alloc.deallocate(ptr, 16) };
}
// A zero element count yields `None` rather than an allocation.
#[test]
fn test_numa_allocator_zero_count_returns_none() {
let alloc: NumaAllocator<f64> = NumaAllocator::first_touch();
assert!(alloc.allocate(0).is_none());
assert!(alloc.allocate_zeroed(0).is_none());
}
// Pushed values come back from `pop` in reverse order.
#[test]
fn test_numa_vec_push_pop() {
let mut v: NumaVec<i32> = NumaVec::new();
assert!(v.is_empty());
for i in 0..100i32 {
v.push(i).expect("push must succeed");
}
assert_eq!(v.len(), 100);
for i in (0..100i32).rev() {
assert_eq!(v.pop(), Some(i));
}
assert!(v.is_empty());
}
// `as_slice` exposes the pushed values in insertion order.
#[test]
fn test_numa_vec_slice_access() {
let mut v: NumaVec<f64> = NumaVec::new();
for i in 0..10 {
v.push(i as f64).expect("push");
}
let s = v.as_slice();
assert_eq!(s.len(), 10);
for (i, &x) in s.iter().enumerate() {
assert!((x - i as f64).abs() < f64::EPSILON);
}
}
// Pre-sized construction reports the requested capacity and length 0.
#[test]
fn test_numa_vec_with_hint_and_capacity() {
let v = NumaVec::<f64>::with_hint_and_capacity(NumaAllocHint::first_touch(), 128)
.expect("alloc");
assert_eq!(v.len(), 0);
assert_eq!(v.capacity(), 128);
}
// `reserve` on an empty vector provides at least the requested capacity.
#[test]
fn test_numa_vec_reserve() {
let mut v: NumaVec<u8> = NumaVec::new();
v.reserve(256).expect("reserve");
assert!(v.capacity() >= 256);
}
// A vector with the interleave hint behaves like any other; indexing goes through `Deref`.
#[test]
fn test_numa_vec_interleaved() {
let mut v = NumaVec::<u32>::with_hint(NumaAllocHint::interleaved());
for i in 0..50u32 {
v.push(i).expect("push");
}
assert_eq!(v.len(), 50);
assert_eq!(v[0], 0);
assert_eq!(v[49], 49);
}
// `zeros` reports the requested shape and fills with the default value.
#[test]
fn test_mat_numa_zeros() {
let mat: MatNuma<f64> = MatNuma::zeros(4, 4, NumaAllocHint::first_touch()).expect("zeros");
assert_eq!(mat.nrows(), 4);
assert_eq!(mat.ncols(), 4);
assert_eq!(mat.len(), 16);
for &v in mat.as_slice() {
assert_eq!(v, 0.0f64);
}
}
// A value written through `get_mut` is visible through `get`.
#[test]
fn test_mat_numa_get_set() {
let mut mat: MatNuma<f64> =
MatNuma::zeros(3, 5, NumaAllocHint::first_touch()).expect("zeros");
*mat.get_mut(1, 3).expect("valid index") = 42.0;
assert!((mat.get(1, 3).expect("valid index") - 42.0).abs() < f64::EPSILON);
}
// Out-of-range row or column indices yield `None`.
#[test]
fn test_mat_numa_out_of_bounds() {
let mat: MatNuma<f32> = MatNuma::zeros(2, 2, NumaAllocHint::first_touch()).expect("zeros");
assert!(mat.get(2, 0).is_none());
assert!(mat.get(0, 2).is_none());
}
// Writes through `row_mut` are visible through `row`.
#[test]
fn test_mat_numa_row_slice() {
let mut mat: MatNuma<i32> =
MatNuma::zeros(3, 4, NumaAllocHint::first_touch()).expect("zeros");
if let Some(row) = mat.row_mut(1) {
for (i, v) in row.iter_mut().enumerate() {
*v = i as i32 * 10;
}
}
let row = mat.row(1).expect("valid row");
assert_eq!(row, &[0, 10, 20, 30]);
}
// Construction with a node preference for node 0 succeeds.
#[test]
fn test_mat_numa_on_node_zero() {
let mat: MatNuma<f64> = MatNuma::zeros(8, 8, NumaAllocHint::on_node(0)).expect("zeros");
assert_eq!(mat.len(), 64);
}
// Construction with the interleave hint still yields a usable, default-filled matrix.
#[test]
fn test_mat_numa_fallback_non_numa() {
let mat: MatNuma<f32> =
MatNuma::zeros(16, 16, NumaAllocHint::interleaved()).expect("zeros");
assert_eq!(mat.len(), 256);
for &v in mat.as_slice() {
assert_eq!(v, 0.0f32);
}
}
}