// oxigdal_core/memory/numa.rs
1//! NUMA-Aware Memory Allocation
2//!
3//! This module provides NUMA (Non-Uniform Memory Access) aware memory allocation:
4//! - NUMA node detection
5//! - Local allocation preference
6//! - NUMA interleaving for large buffers
7//! - Migration hints
8//! - NUMA metrics (local vs remote access)
9
10// Unsafe code is necessary for NUMA allocations
11#![allow(unsafe_code)]
12
13use crate::error::{OxiGdalError, Result};
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, Ordering};
16
// mbind() memory-policy mode constants for Linux NUMA support.
// Values mirror <linux/mempolicy.h> (see set_mempolicy(2)); the libc
// crate does not export them, so they are defined here.
#[cfg(target_os = "linux")]
const MPOL_BIND: libc::c_int = 2;
#[cfg(target_os = "linux")]
const MPOL_INTERLEAVE: libc::c_int = 3;
#[cfg(target_os = "linux")]
const MPOL_PREFERRED: libc::c_int = 1;

// System call number for mbind (varies by architecture).
#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
const SYS_MBIND: libc::c_long = 237;

#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
const SYS_MBIND: libc::c_long = 235;

// NOTE(review): 0 is itself a valid syscall number on most Linux
// architectures, so passing it to libc::syscall() would invoke an
// unrelated syscall rather than "fail at runtime" — callers must guard
// against this sentinel before issuing the syscall.
#[cfg(all(
    target_os = "linux",
    not(any(target_arch = "x86_64", target_arch = "aarch64"))
))]
const SYS_MBIND: libc::c_long = 0; // Unsupported architecture sentinel
37
38// Wrapper for mbind syscall
39#[cfg(target_os = "linux")]
40unsafe fn mbind(
41    addr: *mut libc::c_void,
42    len: libc::size_t,
43    mode: libc::c_int,
44    nodemask: *const libc::c_ulong,
45    maxnode: libc::c_ulong,
46    flags: libc::c_uint,
47) -> libc::c_long {
48    // SAFETY: mbind is a valid Linux system call. The caller must ensure
49    // that addr and nodemask point to valid memory regions.
50    unsafe { libc::syscall(SYS_MBIND, addr, len, mode, nodemask, maxnode, flags) }
51}
52
/// Identifier for a single NUMA node.
///
/// The sentinel [`NumaNode::ANY`] (id `-1`) means "no specific node".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NumaNode(pub i32);

impl NumaNode {
    /// Sentinel meaning "any NUMA node".
    pub const ANY: Self = Self(-1);

    /// Wrap a raw node id in a `NumaNode`.
    #[must_use]
    pub fn new(id: i32) -> Self {
        NumaNode(id)
    }

    /// The raw node id (may be `-1` for [`NumaNode::ANY`]).
    #[must_use]
    pub fn id(&self) -> i32 {
        let NumaNode(raw) = *self;
        raw
    }
}
73
/// Memory placement policy for NUMA-aware allocations.
///
/// Variants mirror the kernel's MPOL_* memory policies applied via
/// mbind(2) — see set_mempolicy(2) for the kernel-side semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumaPolicy {
    /// Use the kernel's default policy (typically allocate on the local node).
    Default,
    /// Hard-bind pages to one specific node (MPOL_BIND).
    Bind(NumaNode),
    /// Spread pages round-robin across nodes (MPOL_INTERLEAVE).
    Interleave,
    /// Prefer one node but allow fallback to others (MPOL_PREFERRED).
    Prefer(NumaNode),
}
86
87/// NUMA configuration
88#[derive(Debug, Clone)]
89pub struct NumaConfig {
90    /// Allocation policy
91    pub policy: NumaPolicy,
92    /// Enable NUMA awareness
93    pub enabled: bool,
94    /// Track statistics
95    pub track_stats: bool,
96}
97
98impl Default for NumaConfig {
99    fn default() -> Self {
100        Self {
101            policy: NumaPolicy::Default,
102            enabled: is_numa_available(),
103            track_stats: true,
104        }
105    }
106}
107
108impl NumaConfig {
109    /// Create new configuration
110    #[must_use]
111    pub fn new() -> Self {
112        Self::default()
113    }
114
115    /// Set policy
116    #[must_use]
117    pub fn with_policy(mut self, policy: NumaPolicy) -> Self {
118        self.policy = policy;
119        self
120    }
121
122    /// Enable NUMA awareness
123    #[must_use]
124    pub fn with_enabled(mut self, enabled: bool) -> Self {
125        self.enabled = enabled;
126        self
127    }
128
129    /// Enable statistics tracking
130    #[must_use]
131    pub fn with_stats(mut self, track: bool) -> Self {
132        self.track_stats = track;
133        self
134    }
135}
136
/// Counters describing NUMA allocation behavior.
///
/// All counters use relaxed atomics: they are advisory metrics,
/// not synchronization points.
#[derive(Debug, Default)]
pub struct NumaStats {
    /// Allocations satisfied on the calling thread's node.
    pub local_allocations: AtomicU64,
    /// Allocations satisfied on a different node.
    pub remote_allocations: AtomicU64,
    /// Allocations spread across nodes.
    pub interleaved_allocations: AtomicU64,
    /// Page-migration operations performed.
    pub migrations: AtomicU64,
    /// Bytes allocated per node, indexed by node id.
    pub bytes_per_node: Vec<AtomicU64>,
}

impl NumaStats {
    /// Create zeroed statistics with one byte counter per NUMA node.
    #[must_use]
    pub fn new(num_nodes: usize) -> Self {
        Self {
            bytes_per_node: (0..num_nodes).map(|_| AtomicU64::new(0)).collect(),
            ..Self::default()
        }
    }

    /// Count one allocation on the local node.
    pub fn record_local(&self) {
        self.local_allocations.fetch_add(1, Ordering::Relaxed);
    }

    /// Count one allocation on a remote node.
    pub fn record_remote(&self) {
        self.remote_allocations.fetch_add(1, Ordering::Relaxed);
    }

    /// Count one interleaved allocation.
    pub fn record_interleaved(&self) {
        self.interleaved_allocations.fetch_add(1, Ordering::Relaxed);
    }

    /// Count one migration operation.
    pub fn record_migration(&self) {
        self.migrations.fetch_add(1, Ordering::Relaxed);
    }

    /// Add `bytes` to the counter for `node`; out-of-range nodes are ignored.
    pub fn record_bytes(&self, node: usize, bytes: u64) {
        if let Some(counter) = self.bytes_per_node.get(node) {
            counter.fetch_add(bytes, Ordering::Relaxed);
        }
    }

    /// Percentage of (local + remote) allocations that were local,
    /// or 0.0 when nothing has been recorded yet.
    pub fn local_percentage(&self) -> f64 {
        let local = self.local_allocations.load(Ordering::Relaxed);
        let remote = self.remote_allocations.load(Ordering::Relaxed);
        match local + remote {
            0 => 0.0,
            total => (local as f64 / total as f64) * 100.0,
        }
    }
}
210
/// Report whether this system exposes NUMA topology information.
///
/// On Linux this probes the sysfs node directory; on every other
/// platform NUMA support is not implemented and this returns `false`.
#[must_use]
pub fn is_numa_available() -> bool {
    if cfg!(target_os = "linux") {
        // metadata() succeeds iff the path exists — same check as
        // Path::exists(), expressed as an expression.
        std::fs::metadata("/sys/devices/system/node").is_ok()
    } else {
        false
    }
}
225
226/// Get number of NUMA nodes
227pub fn get_numa_node_count() -> Result<usize> {
228    #[cfg(target_os = "linux")]
229    {
230        let mut count = 0;
231        let node_dir = std::path::Path::new("/sys/devices/system/node");
232
233        if !node_dir.exists() {
234            return Ok(1); // No NUMA, single node
235        }
236
237        let entries =
238            std::fs::read_dir(node_dir).map_err(|e| OxiGdalError::io_error(e.to_string()))?;
239
240        for entry in entries {
241            let entry = entry.map_err(|e| OxiGdalError::io_error(e.to_string()))?;
242            let name = entry.file_name();
243            let name_str = name.to_string_lossy();
244
245            if name_str.starts_with("node") && name_str[4..].parse::<u32>().is_ok() {
246                count += 1;
247            }
248        }
249
250        Ok(if count > 0 { count } else { 1 })
251    }
252
253    #[cfg(not(target_os = "linux"))]
254    {
255        Ok(1)
256    }
257}
258
259/// Get current NUMA node for the calling thread
260pub fn get_current_node() -> Result<NumaNode> {
261    #[cfg(target_os = "linux")]
262    {
263        let cpu = unsafe { libc::sched_getcpu() };
264        if cpu < 0 {
265            return Err(OxiGdalError::io_error("Failed to get CPU".to_string()));
266        }
267
268        // Read NUMA node from sysfs
269        let path = format!("/sys/devices/system/cpu/cpu{}/node", cpu);
270        let node_dirs = std::fs::read_dir(&path)
271            .map_err(|_| OxiGdalError::io_error("Failed to read NUMA node".to_string()))?;
272
273        for entry in node_dirs {
274            let entry = entry.map_err(|e| OxiGdalError::io_error(e.to_string()))?;
275            let name = entry.file_name();
276            let name_str = name.to_string_lossy();
277
278            if name_str.starts_with("node") {
279                if let Ok(node_id) = name_str[4..].parse::<i32>() {
280                    return Ok(NumaNode(node_id));
281                }
282            }
283        }
284
285        Ok(NumaNode(0))
286    }
287
288    #[cfg(not(target_os = "linux"))]
289    {
290        Ok(NumaNode(0))
291    }
292}
293
/// NUMA-aware allocator.
///
/// Dispatches allocations either through mmap + mbind (Linux, when
/// `config.enabled`) or through the global allocator, and records
/// local/remote/interleave metrics in shared [`NumaStats`].
pub struct NumaAllocator {
    /// Allocation policy and feature toggles.
    config: NumaConfig,
    /// Shared counters; also handed out to callers via `stats()`.
    stats: Arc<NumaStats>,
}
301
302impl NumaAllocator {
303    /// Create a new NUMA allocator
304    pub fn new() -> Result<Self> {
305        Self::with_config(NumaConfig::default())
306    }
307
308    /// Create with custom configuration
309    pub fn with_config(config: NumaConfig) -> Result<Self> {
310        let num_nodes = get_numa_node_count()?;
311        Ok(Self {
312            config,
313            stats: Arc::new(NumaStats::new(num_nodes)),
314        })
315    }
316
317    /// Allocate memory with NUMA awareness
318    pub fn allocate(&self, size: usize) -> Result<*mut u8> {
319        if self.config.enabled {
320            self.allocate_numa(size)
321        } else {
322            // NUMA not enabled, use standard allocation
323            let layout = std::alloc::Layout::from_size_align(size, 16)
324                .map_err(|e| OxiGdalError::allocation_error(e.to_string()))?;
325
326            unsafe {
327                let ptr = std::alloc::alloc(layout);
328                if ptr.is_null() {
329                    return Err(OxiGdalError::allocation_error(
330                        "Allocation failed".to_string(),
331                    ));
332                }
333                Ok(ptr)
334            }
335        }
336    }
337
338    /// Allocate with NUMA policy
339    fn allocate_numa(&self, size: usize) -> Result<*mut u8> {
340        #[cfg(target_os = "linux")]
341        {
342            use std::ptr::null_mut;
343
344            let ptr = match self.config.policy {
345                NumaPolicy::Default => {
346                    self.stats.record_local();
347                    unsafe {
348                        libc::mmap(
349                            null_mut(),
350                            size,
351                            libc::PROT_READ | libc::PROT_WRITE,
352                            libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
353                            -1,
354                            0,
355                        )
356                    }
357                }
358                NumaPolicy::Bind(node) => {
359                    if self.config.track_stats {
360                        let current = get_current_node()?;
361                        if current == node {
362                            self.stats.record_local();
363                        } else {
364                            self.stats.record_remote();
365                        }
366                    }
367
368                    unsafe {
369                        let ptr = libc::mmap(
370                            null_mut(),
371                            size,
372                            libc::PROT_READ | libc::PROT_WRITE,
373                            libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
374                            -1,
375                            0,
376                        );
377
378                        if ptr != libc::MAP_FAILED {
379                            // Apply NUMA binding
380                            let node_mask: u64 = 1 << node.id();
381                            mbind(
382                                ptr,
383                                size,
384                                MPOL_BIND,
385                                &node_mask as *const u64 as *const libc::c_ulong,
386                                64,
387                                0,
388                            );
389                        }
390
391                        ptr
392                    }
393                }
394                NumaPolicy::Interleave => {
395                    self.stats.record_interleaved();
396                    unsafe {
397                        let ptr = libc::mmap(
398                            null_mut(),
399                            size,
400                            libc::PROT_READ | libc::PROT_WRITE,
401                            libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
402                            -1,
403                            0,
404                        );
405
406                        if ptr != libc::MAP_FAILED {
407                            mbind(ptr, size, MPOL_INTERLEAVE, null_mut(), 0, 0);
408                        }
409
410                        ptr
411                    }
412                }
413                NumaPolicy::Prefer(node) => {
414                    if self.config.track_stats {
415                        let current = get_current_node()?;
416                        if current == node {
417                            self.stats.record_local();
418                        } else {
419                            self.stats.record_remote();
420                        }
421                    }
422
423                    unsafe {
424                        let ptr = libc::mmap(
425                            null_mut(),
426                            size,
427                            libc::PROT_READ | libc::PROT_WRITE,
428                            libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
429                            -1,
430                            0,
431                        );
432
433                        if ptr != libc::MAP_FAILED {
434                            let node_mask: u64 = 1 << node.id();
435                            mbind(
436                                ptr,
437                                size,
438                                MPOL_PREFERRED,
439                                &node_mask as *const u64 as *const libc::c_ulong,
440                                64,
441                                0,
442                            );
443                        }
444
445                        ptr
446                    }
447                }
448            };
449
450            if ptr == libc::MAP_FAILED {
451                return Err(OxiGdalError::allocation_error(
452                    "NUMA allocation failed".to_string(),
453                ));
454            }
455
456            Ok(ptr as *mut u8)
457        }
458
459        #[cfg(not(target_os = "linux"))]
460        {
461            // Fallback to standard allocation
462            let layout = std::alloc::Layout::from_size_align(size, 16)
463                .map_err(|e| OxiGdalError::allocation_error(e.to_string()))?;
464
465            unsafe {
466                let ptr = std::alloc::alloc(layout);
467                if ptr.is_null() {
468                    return Err(OxiGdalError::allocation_error(
469                        "Allocation failed".to_string(),
470                    ));
471                }
472                Ok(ptr)
473            }
474        }
475    }
476
477    /// Deallocate memory
478    ///
479    /// # Safety
480    ///
481    /// The caller must ensure:
482    /// - `ptr` was allocated by this allocator
483    /// - `size` matches the original allocation size
484    /// - `ptr` has not been deallocated previously
485    #[allow(clippy::not_unsafe_ptr_arg_deref)]
486    pub fn deallocate(&self, ptr: *mut u8, size: usize) -> Result<()> {
487        #[cfg(target_os = "linux")]
488        {
489            if self.config.enabled {
490                unsafe {
491                    libc::munmap(ptr as *mut libc::c_void, size);
492                }
493                return Ok(());
494            }
495        }
496
497        // Standard deallocation
498        unsafe {
499            let layout = std::alloc::Layout::from_size_align_unchecked(size, 16);
500            std::alloc::dealloc(ptr, layout);
501        }
502
503        Ok(())
504    }
505
506    /// Get statistics
507    #[must_use]
508    pub fn stats(&self) -> Arc<NumaStats> {
509        Arc::clone(&self.stats)
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke-test NUMA detection: availability probe and node count.
    #[test]
    fn test_numa_detection() {
        let available = is_numa_available();
        println!("NUMA available: {}", available);

        let node_count = get_numa_node_count().expect("Failed to get NUMA node count in test");
        println!("NUMA nodes: {}", node_count);
        // Even non-NUMA systems report at least one logical node.
        assert!(node_count >= 1);
    }

    /// The calling thread must map to a valid (non-negative) node id.
    #[test]
    fn test_current_node() {
        let node = get_current_node().expect("Failed to get current NUMA node in test");
        println!("Current NUMA node: {}", node.id());
        assert!(node.id() >= 0);
    }

    /// Round-trip a page-sized allocation through the allocator.
    #[test]
    fn test_numa_allocator() {
        let allocator = NumaAllocator::new().expect("Failed to create NUMA allocator in test");
        let ptr = allocator
            .allocate(4096)
            .expect("Failed to allocate 4096 bytes with NUMA allocator in test");
        assert!(!ptr.is_null());

        allocator
            .deallocate(ptr, 4096)
            .expect("Failed to deallocate NUMA memory in test");
    }

    /// Counters and the derived local-percentage metric.
    #[test]
    fn test_numa_stats() {
        let stats = NumaStats::new(4);
        stats.record_local();
        stats.record_local();
        stats.record_remote();

        assert_eq!(stats.local_allocations.load(Ordering::Relaxed), 2);
        assert_eq!(stats.remote_allocations.load(Ordering::Relaxed), 1);
        // 2 of 3 recorded allocations were local: ~66.67%.
        assert!((stats.local_percentage() - 66.66).abs() < 0.1);
    }
}