1#![allow(unsafe_code)]
12
13use crate::error::{OxiGdalError, Result};
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, Ordering};
16
17#[cfg(target_os = "linux")]
19const MPOL_BIND: libc::c_int = 2;
20#[cfg(target_os = "linux")]
21const MPOL_INTERLEAVE: libc::c_int = 3;
22#[cfg(target_os = "linux")]
23const MPOL_PREFERRED: libc::c_int = 1;
24
25#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
27const SYS_MBIND: libc::c_long = 237;
28
29#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
30const SYS_MBIND: libc::c_long = 235;
31
32#[cfg(all(
33 target_os = "linux",
34 not(any(target_arch = "x86_64", target_arch = "aarch64"))
35))]
36const SYS_MBIND: libc::c_long = 0; #[cfg(target_os = "linux")]
40unsafe fn mbind(
41 addr: *mut libc::c_void,
42 len: libc::size_t,
43 mode: libc::c_int,
44 nodemask: *const libc::c_ulong,
45 maxnode: libc::c_ulong,
46 flags: libc::c_uint,
47) -> libc::c_long {
48 unsafe { libc::syscall(SYS_MBIND, addr, len, mode, nodemask, maxnode, flags) }
51}
52
/// Identifier of a single NUMA node; `-1` (see [`NumaNode::ANY`])
/// means "no specific node".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NumaNode(pub i32);

impl NumaNode {
    /// Wildcard value: let the kernel decide placement.
    pub const ANY: Self = Self(-1);

    /// Wraps a raw node id as reported by the kernel.
    #[must_use]
    pub fn new(id: i32) -> Self {
        NumaNode(id)
    }

    /// Returns the raw node id.
    #[must_use]
    pub fn id(&self) -> i32 {
        let NumaNode(id) = *self;
        id
    }
}
73
/// Memory placement policy applied by [`NumaAllocator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NumaPolicy {
    /// Kernel-default placement (no `mbind` call).
    Default,
    /// Allocate strictly on the given node (`MPOL_BIND`).
    Bind(NumaNode),
    /// Spread pages across nodes (`MPOL_INTERLEAVE`).
    Interleave,
    /// Prefer the given node but allow fallback (`MPOL_PREFERRED`).
    Prefer(NumaNode),
}
86
/// Configuration controlling NUMA-aware allocation behaviour.
#[derive(Debug, Clone)]
pub struct NumaConfig {
    /// Placement policy applied to NUMA allocations.
    pub policy: NumaPolicy,
    /// Master switch; when `false` the plain global allocator is used.
    pub enabled: bool,
    /// Whether allocations update the shared [`NumaStats`] counters.
    pub track_stats: bool,
}
97
98impl Default for NumaConfig {
99 fn default() -> Self {
100 Self {
101 policy: NumaPolicy::Default,
102 enabled: is_numa_available(),
103 track_stats: true,
104 }
105 }
106}
107
108impl NumaConfig {
109 #[must_use]
111 pub fn new() -> Self {
112 Self::default()
113 }
114
115 #[must_use]
117 pub fn with_policy(mut self, policy: NumaPolicy) -> Self {
118 self.policy = policy;
119 self
120 }
121
122 #[must_use]
124 pub fn with_enabled(mut self, enabled: bool) -> Self {
125 self.enabled = enabled;
126 self
127 }
128
129 #[must_use]
131 pub fn with_stats(mut self, track: bool) -> Self {
132 self.track_stats = track;
133 self
134 }
135}
136
/// Counters describing NUMA allocation behaviour.
///
/// Every counter is an atomic, so one `NumaStats` can be shared across
/// threads behind an `Arc` without any locking.
#[derive(Debug, Default)]
pub struct NumaStats {
    /// Allocations that landed on the caller's own node.
    pub local_allocations: AtomicU64,
    /// Allocations that landed on a different node.
    pub remote_allocations: AtomicU64,
    /// Allocations placed with the interleave policy.
    pub interleaved_allocations: AtomicU64,
    /// Recorded page migrations between nodes.
    pub migrations: AtomicU64,
    /// Bytes recorded per node, indexed by node id.
    pub bytes_per_node: Vec<AtomicU64>,
}

impl NumaStats {
    /// Creates a zeroed stats block with one per-node byte counter for
    /// each of `num_nodes` nodes.
    #[must_use]
    pub fn new(num_nodes: usize) -> Self {
        Self {
            local_allocations: AtomicU64::new(0),
            remote_allocations: AtomicU64::new(0),
            interleaved_allocations: AtomicU64::new(0),
            migrations: AtomicU64::new(0),
            // Idiomatic zero-fill; AtomicU64 is not Clone, so collect
            // from an iterator instead of a push loop.
            bytes_per_node: (0..num_nodes).map(|_| AtomicU64::new(0)).collect(),
        }
    }

    /// Records an allocation that landed on the local node.
    pub fn record_local(&self) {
        self.local_allocations.fetch_add(1, Ordering::Relaxed);
    }

    /// Records an allocation that landed on a remote node.
    pub fn record_remote(&self) {
        self.remote_allocations.fetch_add(1, Ordering::Relaxed);
    }

    /// Records an interleaved allocation.
    pub fn record_interleaved(&self) {
        self.interleaved_allocations.fetch_add(1, Ordering::Relaxed);
    }

    /// Records a page migration between nodes.
    pub fn record_migration(&self) {
        self.migrations.fetch_add(1, Ordering::Relaxed);
    }

    /// Adds `bytes` to the counter for `node`; out-of-range node ids
    /// are silently ignored (best-effort accounting).
    pub fn record_bytes(&self, node: usize, bytes: u64) {
        if node < self.bytes_per_node.len() {
            self.bytes_per_node[node].fetch_add(bytes, Ordering::Relaxed);
        }
    }

    /// Percentage of node-aware allocations that were local, or `0.0`
    /// when nothing has been recorded yet.
    pub fn local_percentage(&self) -> f64 {
        let local = self.local_allocations.load(Ordering::Relaxed);
        let remote = self.remote_allocations.load(Ordering::Relaxed);
        let total = local + remote;

        if total == 0 {
            0.0
        } else {
            (local as f64 / total as f64) * 100.0
        }
    }
}
210
/// Reports whether this system exposes NUMA topology information.
///
/// Detection is purely filesystem-based: the Linux kernel publishes
/// per-node directories under `/sys/devices/system/node`. Non-Linux
/// targets always report `false`.
#[must_use]
pub fn is_numa_available() -> bool {
    if cfg!(target_os = "linux") {
        std::path::Path::new("/sys/devices/system/node").exists()
    } else {
        false
    }
}
225
226pub fn get_numa_node_count() -> Result<usize> {
228 #[cfg(target_os = "linux")]
229 {
230 let mut count = 0;
231 let node_dir = std::path::Path::new("/sys/devices/system/node");
232
233 if !node_dir.exists() {
234 return Ok(1); }
236
237 let entries =
238 std::fs::read_dir(node_dir).map_err(|e| OxiGdalError::io_error(e.to_string()))?;
239
240 for entry in entries {
241 let entry = entry.map_err(|e| OxiGdalError::io_error(e.to_string()))?;
242 let name = entry.file_name();
243 let name_str = name.to_string_lossy();
244
245 if name_str.starts_with("node") && name_str[4..].parse::<u32>().is_ok() {
246 count += 1;
247 }
248 }
249
250 Ok(if count > 0 { count } else { 1 })
251 }
252
253 #[cfg(not(target_os = "linux"))]
254 {
255 Ok(1)
256 }
257}
258
259pub fn get_current_node() -> Result<NumaNode> {
261 #[cfg(target_os = "linux")]
262 {
263 let cpu = unsafe { libc::sched_getcpu() };
264 if cpu < 0 {
265 return Err(OxiGdalError::io_error("Failed to get CPU".to_string()));
266 }
267
268 let path = format!("/sys/devices/system/cpu/cpu{}/node", cpu);
270 let node_dirs = std::fs::read_dir(&path)
271 .map_err(|_| OxiGdalError::io_error("Failed to read NUMA node".to_string()))?;
272
273 for entry in node_dirs {
274 let entry = entry.map_err(|e| OxiGdalError::io_error(e.to_string()))?;
275 let name = entry.file_name();
276 let name_str = name.to_string_lossy();
277
278 if name_str.starts_with("node") {
279 if let Ok(node_id) = name_str[4..].parse::<i32>() {
280 return Ok(NumaNode(node_id));
281 }
282 }
283 }
284
285 Ok(NumaNode(0))
286 }
287
288 #[cfg(not(target_os = "linux"))]
289 {
290 Ok(NumaNode(0))
291 }
292}
293
/// NUMA-aware memory allocator: uses `mmap`/`mbind` on Linux when
/// enabled, and the plain global allocator otherwise.
pub struct NumaAllocator {
    /// Placement policy and feature switches.
    config: NumaConfig,
    /// Shared counters, also handed out by the `stats()` accessor.
    stats: Arc<NumaStats>,
}
301
302impl NumaAllocator {
303 pub fn new() -> Result<Self> {
305 Self::with_config(NumaConfig::default())
306 }
307
308 pub fn with_config(config: NumaConfig) -> Result<Self> {
310 let num_nodes = get_numa_node_count()?;
311 Ok(Self {
312 config,
313 stats: Arc::new(NumaStats::new(num_nodes)),
314 })
315 }
316
317 pub fn allocate(&self, size: usize) -> Result<*mut u8> {
319 if self.config.enabled {
320 self.allocate_numa(size)
321 } else {
322 let layout = std::alloc::Layout::from_size_align(size, 16)
324 .map_err(|e| OxiGdalError::allocation_error(e.to_string()))?;
325
326 unsafe {
327 let ptr = std::alloc::alloc(layout);
328 if ptr.is_null() {
329 return Err(OxiGdalError::allocation_error(
330 "Allocation failed".to_string(),
331 ));
332 }
333 Ok(ptr)
334 }
335 }
336 }
337
338 fn allocate_numa(&self, size: usize) -> Result<*mut u8> {
340 #[cfg(target_os = "linux")]
341 {
342 use std::ptr::null_mut;
343
344 let ptr = match self.config.policy {
345 NumaPolicy::Default => {
346 self.stats.record_local();
347 unsafe {
348 libc::mmap(
349 null_mut(),
350 size,
351 libc::PROT_READ | libc::PROT_WRITE,
352 libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
353 -1,
354 0,
355 )
356 }
357 }
358 NumaPolicy::Bind(node) => {
359 if self.config.track_stats {
360 let current = get_current_node()?;
361 if current == node {
362 self.stats.record_local();
363 } else {
364 self.stats.record_remote();
365 }
366 }
367
368 unsafe {
369 let ptr = libc::mmap(
370 null_mut(),
371 size,
372 libc::PROT_READ | libc::PROT_WRITE,
373 libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
374 -1,
375 0,
376 );
377
378 if ptr != libc::MAP_FAILED {
379 let node_mask: u64 = 1 << node.id();
381 mbind(
382 ptr,
383 size,
384 MPOL_BIND,
385 &node_mask as *const u64 as *const libc::c_ulong,
386 64,
387 0,
388 );
389 }
390
391 ptr
392 }
393 }
394 NumaPolicy::Interleave => {
395 self.stats.record_interleaved();
396 unsafe {
397 let ptr = libc::mmap(
398 null_mut(),
399 size,
400 libc::PROT_READ | libc::PROT_WRITE,
401 libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
402 -1,
403 0,
404 );
405
406 if ptr != libc::MAP_FAILED {
407 mbind(ptr, size, MPOL_INTERLEAVE, null_mut(), 0, 0);
408 }
409
410 ptr
411 }
412 }
413 NumaPolicy::Prefer(node) => {
414 if self.config.track_stats {
415 let current = get_current_node()?;
416 if current == node {
417 self.stats.record_local();
418 } else {
419 self.stats.record_remote();
420 }
421 }
422
423 unsafe {
424 let ptr = libc::mmap(
425 null_mut(),
426 size,
427 libc::PROT_READ | libc::PROT_WRITE,
428 libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
429 -1,
430 0,
431 );
432
433 if ptr != libc::MAP_FAILED {
434 let node_mask: u64 = 1 << node.id();
435 mbind(
436 ptr,
437 size,
438 MPOL_PREFERRED,
439 &node_mask as *const u64 as *const libc::c_ulong,
440 64,
441 0,
442 );
443 }
444
445 ptr
446 }
447 }
448 };
449
450 if ptr == libc::MAP_FAILED {
451 return Err(OxiGdalError::allocation_error(
452 "NUMA allocation failed".to_string(),
453 ));
454 }
455
456 Ok(ptr as *mut u8)
457 }
458
459 #[cfg(not(target_os = "linux"))]
460 {
461 let layout = std::alloc::Layout::from_size_align(size, 16)
463 .map_err(|e| OxiGdalError::allocation_error(e.to_string()))?;
464
465 unsafe {
466 let ptr = std::alloc::alloc(layout);
467 if ptr.is_null() {
468 return Err(OxiGdalError::allocation_error(
469 "Allocation failed".to_string(),
470 ));
471 }
472 Ok(ptr)
473 }
474 }
475 }
476
477 #[allow(clippy::not_unsafe_ptr_arg_deref)]
486 pub fn deallocate(&self, ptr: *mut u8, size: usize) -> Result<()> {
487 #[cfg(target_os = "linux")]
488 {
489 if self.config.enabled {
490 unsafe {
491 libc::munmap(ptr as *mut libc::c_void, size);
492 }
493 return Ok(());
494 }
495 }
496
497 unsafe {
499 let layout = std::alloc::Layout::from_size_align_unchecked(size, 16);
500 std::alloc::dealloc(ptr, layout);
501 }
502
503 Ok(())
504 }
505
506 #[must_use]
508 pub fn stats(&self) -> Arc<NumaStats> {
509 Arc::clone(&self.stats)
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke-tests NUMA detection; availability and node count are
    /// system-dependent, so only the lower bound is asserted.
    #[test]
    fn test_numa_detection() {
        let available = is_numa_available();
        println!("NUMA available: {}", available);

        let node_count = get_numa_node_count().expect("Failed to get NUMA node count in test");
        println!("NUMA nodes: {}", node_count);
        assert!(node_count >= 1);
    }

    /// The current node id must resolve and be non-negative.
    #[test]
    fn test_current_node() {
        let node = get_current_node().expect("Failed to get current NUMA node in test");
        println!("Current NUMA node: {}", node.id());
        assert!(node.id() >= 0);
    }

    /// Round-trips one page-sized allocation through the allocator.
    #[test]
    fn test_numa_allocator() {
        let allocator = NumaAllocator::new().expect("Failed to create NUMA allocator in test");
        let ptr = allocator
            .allocate(4096)
            .expect("Failed to allocate 4096 bytes with NUMA allocator in test");
        assert!(!ptr.is_null());

        allocator
            .deallocate(ptr, 4096)
            .expect("Failed to deallocate NUMA memory in test");
    }

    /// Checks counter bookkeeping and the local-percentage computation
    /// (2 local out of 3 total ≈ 66.67%).
    #[test]
    fn test_numa_stats() {
        let stats = NumaStats::new(4);
        stats.record_local();
        stats.record_local();
        stats.record_remote();

        assert_eq!(stats.local_allocations.load(Ordering::Relaxed), 2);
        assert_eq!(stats.remote_allocations.load(Ordering::Relaxed), 1);
        assert!((stats.local_percentage() - 66.66).abs() < 0.1);
    }
}