Skip to main content

dynamo_memory/numa/
mod.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NUMA-aware memory allocation utilities.
5//!
6//! This module provides utilities for NUMA-aware memory allocation, which is critical
7//! for optimal performance on multi-socket systems with GPUs. Memory allocated on the
8//! NUMA node closest to the target GPU has significantly lower access latency.
9//!
10//! ## Architecture
11//!
12//! - [`NumaNode`]: Represents a NUMA node ID
13//! - [`topology`]: Reads CPU-to-NUMA mapping from `/sys/devices/system/node`
14//! - [`worker_pool`]: Dedicated worker threads pinned to specific NUMA nodes
15//!
16//! ## Usage
17//!
18//! NUMA optimization is opt-in via environment variable:
19//! ```bash
20//! export DYN_KVBM_ENABLE_NUMA=1
21//! ```
22//!
23//! When enabled, pinned memory allocations are routed through NUMA workers
24//! that are pinned to the target GPU's NUMA node, ensuring first-touch policy
25//! places pages on the correct node.
26
27pub mod topology;
28pub mod worker_pool;
29
30use nix::libc;
31use serde::{Deserialize, Serialize};
32use std::{mem, process::Command};
33
/// Report whether NUMA-aware allocation has been switched on.
///
/// This is opt-in: the result is `true` only when `DYN_KVBM_ENABLE_NUMA` equals `1`, or
/// equals `true` ignoring case. Any other value, an unset variable, or a value that is not
/// valid Unicode leaves NUMA handling disabled.
pub fn is_numa_enabled() -> bool {
    match std::env::var("DYN_KVBM_ENABLE_NUMA") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
43
44/// Represents a NUMA node identifier.
45///
46/// NUMA nodes are typically numbered 0, 1, 2, etc. corresponding to physical
47/// CPU sockets. Use [`NumaNode::UNKNOWN`] when the node cannot be determined.
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct NumaNode(pub u32);
50
51impl NumaNode {
52    /// Sentinel value for unknown NUMA node.
53    pub const UNKNOWN: NumaNode = NumaNode(u32::MAX);
54
55    /// Returns true if this represents an unknown NUMA node.
56    pub fn is_unknown(&self) -> bool {
57        self.0 == u32::MAX
58    }
59}
60
61impl std::fmt::Display for NumaNode {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        if self.is_unknown() {
64            write!(f, "UNKNOWN")
65        } else {
66            write!(f, "NumaNode({})", self.0)
67        }
68    }
69}
70
71/// Get the current CPU's NUMA node.
72///
73/// Uses the Linux `getcpu` syscall to determine which NUMA node the current CPU belongs to.
74/// Returns [`NumaNode::UNKNOWN`] if the syscall fails.
75pub fn get_current_cpu_numa_node() -> NumaNode {
76    unsafe {
77        let mut cpu: libc::c_uint = 0;
78        let mut node: libc::c_uint = 0;
79
80        // getcpu syscall: int getcpu(unsigned *cpu, unsigned *node, struct getcpu_cache *tcache);
81        let result = libc::syscall(
82            libc::SYS_getcpu,
83            &mut cpu,
84            &mut node,
85            std::ptr::null_mut::<libc::c_void>(),
86        );
87        if result == 0 {
88            NumaNode(node)
89        } else {
90            NumaNode::UNKNOWN
91        }
92    }
93}
94
95/// Resolve process-local CUDA device index to the physical identifier for nvidia-smi.
96///
97/// When `CUDA_VISIBLE_DEVICES` is set, the process sees a remapped device space (e.g. only
98/// GPU 2 visible as device 0). nvidia-smi's `-i` flag expects the *physical* device index or
99/// UUID, not the process-local index. This function parses `CUDA_VISIBLE_DEVICES` to map
100/// process-local `device_id` to the correct physical identifier.
101///
102/// Returns the identifier string to pass to `nvidia-smi -i` (physical index or UUID).
103fn cuda_device_id_to_nvidia_smi_id(device_id: u32) -> String {
104    let visible = match std::env::var("CUDA_VISIBLE_DEVICES") {
105        Ok(v) if !v.trim().is_empty() => v,
106        _ => return device_id.to_string(), // No remapping: identity
107    };
108
109    // Parse comma-separated list. Supports: "0,1,2", "2,3", "GPU-uuid", "2,GPU-uuid", etc.
110    let devices: Vec<&str> = visible
111        .split(',')
112        .map(|s| s.trim())
113        .filter(|s| !s.is_empty())
114        .collect();
115    if device_id as usize >= devices.len() {
116        tracing::warn!(
117            "device_id {} out of range for CUDA_VISIBLE_DEVICES ({} devices), using identity",
118            device_id,
119            devices.len()
120        );
121        return device_id.to_string();
122    }
123
124    let id = devices[device_id as usize];
125    id.to_string()
126}
127
128/// Get NUMA node for a GPU device.
129///
130/// For GPU memory, the NUMA affinity depends on which PCIe bus the GPU is attached to.
131/// This is queried via nvidia-smi. Falls back to a heuristic (device_id % 2) if nvidia-smi
132/// is unavailable.
133///
134/// When `CUDA_VISIBLE_DEVICES` is set, the process-local `device_id` is correctly mapped
135/// to the physical GPU identifier before querying nvidia-smi, so NUMA attribution is accurate.
136///
137/// # Arguments
138/// * `device_id` - CUDA device index (0, 1, 2, ...) as seen by the process
139///
140/// # Returns
141/// The NUMA node closest to the specified GPU, or a heuristic fallback.
142pub fn get_device_numa_node(device_id: u32) -> NumaNode {
143    let nvidia_smi_id = cuda_device_id_to_nvidia_smi_id(device_id);
144
145    // Use nvidia-smi topo to get NUMA ID of nearest CPU
146    // -i must be physical device index or UUID, not process-local index
147    let output = match Command::new("nvidia-smi")
148        .args(["topo", "--get-numa-id-of-nearby-cpu", "-i", &nvidia_smi_id])
149        .output()
150    {
151        Ok(out) if out.status.success() => out,
152        _ => {
153            tracing::warn!(
154                "nvidia-smi failed for GPU {} (nvidia-smi -i {}), using heuristic",
155                device_id,
156                nvidia_smi_id
157            );
158            return NumaNode(device_id % 2);
159        }
160    };
161
162    if let Ok(stdout) = std::str::from_utf8(&output.stdout)
163        && let Some(line) = stdout.lines().next()
164        && let Some(numa_str) = line.split(':').nth(1)
165        && let Ok(node) = numa_str.trim().parse::<u32>()
166    {
167        tracing::trace!(
168            "GPU {} (physical {}) on NUMA node {}",
169            device_id,
170            nvidia_smi_id,
171            node
172        );
173        return NumaNode(node);
174    }
175    tracing::warn!("Failed to get NUMA node for GPU {}", device_id);
176    NumaNode::UNKNOWN
177}
178
179/// Pin the current thread to a specific NUMA node's CPUs.
180///
181/// This sets the CPU affinity for the calling thread to only run on CPUs
182/// belonging to the specified NUMA node. This is critical for ensuring
183/// that memory allocations follow the first-touch policy on the correct node.
184///
185/// # Arguments
186/// * `node` - The NUMA node to pin the thread to
187///
188/// # Errors
189/// Returns an error if:
190/// - NUMA topology cannot be read
191/// - No CPUs are found for the specified node
192/// - The `sched_setaffinity` syscall fails
193pub fn pin_thread_to_numa_node(node: NumaNode) -> Result<(), String> {
194    let topology =
195        topology::get_numa_topology().map_err(|e| format!("Can not get NUMA topology: {}", e))?;
196
197    let cpus = topology
198        .cpus_for_node(node.0)
199        .ok_or_else(|| format!("No CPUs found for NUMA node {}", node.0))?;
200
201    if cpus.is_empty() {
202        return Err(format!("No CPUs found for NUMA node {}", node.0));
203    }
204
205    unsafe {
206        let mut cpu_set: libc::cpu_set_t = mem::zeroed();
207
208        for cpu in cpus {
209            libc::CPU_SET(*cpu, &mut cpu_set);
210        }
211
212        let result = libc::sched_setaffinity(
213            0, // current thread
214            mem::size_of::<libc::cpu_set_t>(),
215            &cpu_set,
216        );
217
218        if result != 0 {
219            let err = std::io::Error::last_os_error();
220            return Err(format!("Failed to set CPU affinity: {}", err));
221        }
222    }
223
224    Ok(())
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numa_node_equality() {
        let node0a = NumaNode(0);
        let node0b = NumaNode(0);
        let node1 = NumaNode(1);

        assert_eq!(node0a, node0b);
        assert_ne!(node0a, node1);
    }

    #[test]
    fn test_numa_node_unknown() {
        let unknown = NumaNode::UNKNOWN;
        assert!(unknown.is_unknown());
        assert_eq!(unknown.0, u32::MAX);

        let valid = NumaNode(0);
        assert!(!valid.is_unknown());
    }

    #[test]
    fn test_numa_node_display() {
        assert_eq!(format!("{}", NumaNode(0)), "NumaNode(0)");
        assert_eq!(format!("{}", NumaNode(7)), "NumaNode(7)");
        assert_eq!(format!("{}", NumaNode::UNKNOWN), "UNKNOWN");
    }

    #[test]
    fn test_numa_node_serialization() {
        // Verify NumaNode round-trips through serde (important for benchmarking), and pin
        // the wire format: the newtype serializes as its bare integer.
        let node = NumaNode(1);
        let json = serde_json::to_string(&node).unwrap();
        assert_eq!(json, "1");
        let deserialized: NumaNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, deserialized);
    }

    #[test]
    fn test_get_current_cpu_numa_node() {
        // Either a real node or UNKNOWN (syscall failure).
        let node = get_current_cpu_numa_node();

        // A hard-coded bound such as `< 8` fails on legitimately large machines, e.g.
        // many-socket or sub-NUMA-clustered systems. Instead, check that a reported node
        // actually exists in sysfs, whenever sysfs NUMA information is available.
        let sysfs = std::path::Path::new("/sys/devices/system/node");
        if !node.is_unknown() && sysfs.is_dir() {
            let node_dir = sysfs.join(format!("node{}", node.0));
            assert!(
                node_dir.is_dir(),
                "CPU reported NUMA node {} but {} does not exist",
                node.0,
                node_dir.display()
            );
        }
    }

    #[test]
    fn test_get_device_numa_node_valid_gpu() {
        // Smoke test: must not panic whether or not nvidia-smi is installed.
        // The result is the nvidia-smi answer, the `device_id % 2` heuristic (nvidia-smi
        // missing or failing), or UNKNOWN (unparseable nvidia-smi output).
        let node = get_device_numa_node(0);
        println!("GPU 0 detected on NUMA node: {}", node);
    }

    #[test]
    fn test_numa_node_hash() {
        // Verify NumaNode can be used as a HashMap key
        use std::collections::HashMap;

        let mut map = HashMap::new();
        map.insert(NumaNode(0), "node0");
        map.insert(NumaNode(1), "node1");

        assert_eq!(map.get(&NumaNode(0)), Some(&"node0"));
        assert_eq!(map.get(&NumaNode(1)), Some(&"node1"));
        assert_eq!(map.get(&NumaNode(2)), None);
    }

    #[test]
    fn test_numa_node_copy_clone() {
        // Verify NumaNode is Copy and Clone
        let node1 = NumaNode(5);
        let node2 = node1; // Copy
        #[allow(clippy::clone_on_copy)] // deliberately exercise the Clone impl
        let node3 = node1.clone(); // Clone

        assert_eq!(node1, node2);
        assert_eq!(node1, node3);
        assert_eq!(node2, node3);
    }
}