use super::allocator::{VedaAllocator, SystemAllocator};
use std::alloc::Layout;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct NumaNode {
pub id: usize,
allocator: Arc<dyn VedaAllocator>,
}
impl NumaNode {
pub fn new(id: usize) -> Self {
Self {
id,
allocator: Arc::new(SystemAllocator),
}
}
pub fn allocate(&self, layout: Layout) -> *mut u8 {
self.allocator.allocate(layout)
}
pub fn deallocate(&self, ptr: *mut u8, layout: Layout) {
self.allocator.deallocate(ptr, layout);
}
}
pub struct NumaAllocator {
nodes: Vec<NumaNode>,
current_node: std::sync::atomic::AtomicUsize,
}
impl NumaAllocator {
pub fn new() -> crate::error::Result<Self> {
#[cfg(target_os = "linux")]
let num_nodes = Self::detect_numa_nodes();
#[cfg(not(target_os = "linux"))]
let num_nodes = 1;
let nodes = (0..num_nodes)
.map(NumaNode::new)
.collect();
Ok(Self {
nodes,
current_node: std::sync::atomic::AtomicUsize::new(0),
})
}
#[cfg(target_os = "linux")]
fn detect_numa_nodes() -> usize {
if let Ok(entries) = std::fs::read_dir("/sys/devices/system/node") {
let count = entries
.filter_map(|e| e.ok())
.filter(|e| {
e.file_name()
.to_str()
.map(|s| s.starts_with("node"))
.unwrap_or(false)
})
.count();
if count > 0 {
return count;
}
}
1
}
pub fn allocate_on_node(&self, layout: Layout, node_id: usize) -> *mut u8 {
if let Some(node) = self.nodes.get(node_id) {
node.allocate(layout)
} else {
std::ptr::null_mut()
}
}
pub fn allocate_round_robin(&self, layout: Layout) -> *mut u8 {
let node_id = self.current_node.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
% self.nodes.len();
self.allocate_on_node(layout, node_id)
}
pub fn num_nodes(&self) -> usize {
self.nodes.len()
}
pub fn node(&self, id: usize) -> Option<&NumaNode> {
self.nodes.get(id)
}
}
impl Default for NumaAllocator {
fn default() -> Self {
Self::new().expect("Failed to create NUMA allocator")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_numa_allocator_creation() {
let allocator = NumaAllocator::new().unwrap();
assert!(allocator.num_nodes() >= 1);
}
#[test]
fn test_numa_node_allocation() {
let allocator = NumaAllocator::new().unwrap();
let layout = Layout::from_size_align(1024, 8).unwrap();
let ptr = allocator.allocate_on_node(layout, 0);
assert!(!ptr.is_null());
if let Some(node) = allocator.node(0) {
node.deallocate(ptr, layout);
}
}
#[test]
fn test_round_robin_allocation() {
let allocator = NumaAllocator::new().unwrap();
let layout = Layout::from_size_align(64, 8).unwrap();
let mut ptrs = Vec::new();
for _ in 0..10 {
let ptr = allocator.allocate_round_robin(layout);
assert!(!ptr.is_null());
ptrs.push(ptr);
}
for (i, ptr) in ptrs.iter().enumerate() {
let node_id = i % allocator.num_nodes();
if let Some(node) = allocator.node(node_id) {
node.deallocate(*ptr, layout);
}
}
}
}