dynamo_memory/numa/
mod.rs1pub mod topology;
28pub mod worker_pool;
29
30use nix::libc;
31use serde::{Deserialize, Serialize};
32use std::{mem, process::Command};
33
/// Reports whether NUMA-aware behaviour has been switched on via the environment.
///
/// Reads `DYN_KVBM_ENABLE_NUMA`. The feature counts as enabled only when the
/// variable is exactly `1`, or spells `true` in any letter case. If the
/// variable cannot be read (unset or not valid Unicode), or holds any other
/// value, the result is `false`.
pub fn is_numa_enabled() -> bool {
    match std::env::var("DYN_KVBM_ENABLE_NUMA") {
        Ok(flag) => flag == "1" || flag.to_lowercase() == "true",
        Err(_) => false,
    }
}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
49pub struct NumaNode(pub u32);
50
51impl NumaNode {
52 pub const UNKNOWN: NumaNode = NumaNode(u32::MAX);
54
55 pub fn is_unknown(&self) -> bool {
57 self.0 == u32::MAX
58 }
59}
60
61impl std::fmt::Display for NumaNode {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 if self.is_unknown() {
64 write!(f, "UNKNOWN")
65 } else {
66 write!(f, "NumaNode({})", self.0)
67 }
68 }
69}
70
71pub fn get_current_cpu_numa_node() -> NumaNode {
76 unsafe {
77 let mut cpu: libc::c_uint = 0;
78 let mut node: libc::c_uint = 0;
79
80 let result = libc::syscall(
82 libc::SYS_getcpu,
83 &mut cpu,
84 &mut node,
85 std::ptr::null_mut::<libc::c_void>(),
86 );
87 if result == 0 {
88 NumaNode(node)
89 } else {
90 NumaNode::UNKNOWN
91 }
92 }
93}
94
95fn cuda_device_id_to_nvidia_smi_id(device_id: u32) -> String {
104 let visible = match std::env::var("CUDA_VISIBLE_DEVICES") {
105 Ok(v) if !v.trim().is_empty() => v,
106 _ => return device_id.to_string(), };
108
109 let devices: Vec<&str> = visible
111 .split(',')
112 .map(|s| s.trim())
113 .filter(|s| !s.is_empty())
114 .collect();
115 if device_id as usize >= devices.len() {
116 tracing::warn!(
117 "device_id {} out of range for CUDA_VISIBLE_DEVICES ({} devices), using identity",
118 device_id,
119 devices.len()
120 );
121 return device_id.to_string();
122 }
123
124 let id = devices[device_id as usize];
125 id.to_string()
126}
127
128pub fn get_device_numa_node(device_id: u32) -> NumaNode {
143 let nvidia_smi_id = cuda_device_id_to_nvidia_smi_id(device_id);
144
145 let output = match Command::new("nvidia-smi")
148 .args(["topo", "--get-numa-id-of-nearby-cpu", "-i", &nvidia_smi_id])
149 .output()
150 {
151 Ok(out) if out.status.success() => out,
152 _ => {
153 tracing::warn!(
154 "nvidia-smi failed for GPU {} (nvidia-smi -i {}), using heuristic",
155 device_id,
156 nvidia_smi_id
157 );
158 return NumaNode(device_id % 2);
159 }
160 };
161
162 if let Ok(stdout) = std::str::from_utf8(&output.stdout)
163 && let Some(line) = stdout.lines().next()
164 && let Some(numa_str) = line.split(':').nth(1)
165 && let Ok(node) = numa_str.trim().parse::<u32>()
166 {
167 tracing::trace!(
168 "GPU {} (physical {}) on NUMA node {}",
169 device_id,
170 nvidia_smi_id,
171 node
172 );
173 return NumaNode(node);
174 }
175 tracing::warn!("Failed to get NUMA node for GPU {}", device_id);
176 NumaNode::UNKNOWN
177}
178
179pub fn pin_thread_to_numa_node(node: NumaNode) -> Result<(), String> {
194 let topology =
195 topology::get_numa_topology().map_err(|e| format!("Can not get NUMA topology: {}", e))?;
196
197 let cpus = topology
198 .cpus_for_node(node.0)
199 .ok_or_else(|| format!("No CPUs found for NUMA node {}", node.0))?;
200
201 if cpus.is_empty() {
202 return Err(format!("No CPUs found for NUMA node {}", node.0));
203 }
204
205 unsafe {
206 let mut cpu_set: libc::cpu_set_t = mem::zeroed();
207
208 for cpu in cpus {
209 libc::CPU_SET(*cpu, &mut cpu_set);
210 }
211
212 let result = libc::sched_setaffinity(
213 0, mem::size_of::<libc::cpu_set_t>(),
215 &cpu_set,
216 );
217
218 if result != 0 {
219 let err = std::io::Error::last_os_error();
220 return Err(format!("Failed to set CPU affinity: {}", err));
221 }
222 }
223
224 Ok(())
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous upper bound for a plausible node index, used only as a sanity
    /// check. The previous bound of 8 rejected legitimate hosts with more
    /// than eight NUMA nodes.
    const MAX_PLAUSIBLE_NODE: u32 = 1024;

    /// Nodes compare by their wrapped index.
    #[test]
    fn test_numa_node_equality() {
        let node0a = NumaNode(0);
        let node0b = NumaNode(0);
        let node1 = NumaNode(1);

        assert_eq!(node0a, node0b);
        assert_ne!(node0a, node1);
    }

    /// The sentinel wraps `u32::MAX` and is the only value reported as unknown.
    #[test]
    fn test_numa_node_unknown() {
        let unknown = NumaNode::UNKNOWN;
        assert!(unknown.is_unknown());
        assert_eq!(unknown.0, u32::MAX);

        let valid = NumaNode(0);
        assert!(!valid.is_unknown());
    }

    /// `Display` prints `NumaNode(n)` for real nodes and `UNKNOWN` for the sentinel.
    #[test]
    fn test_numa_node_display() {
        assert_eq!(format!("{}", NumaNode(0)), "NumaNode(0)");
        assert_eq!(format!("{}", NumaNode(7)), "NumaNode(7)");
        assert_eq!(format!("{}", NumaNode::UNKNOWN), "UNKNOWN");
    }

    /// A node survives a JSON round trip unchanged.
    #[test]
    fn test_numa_node_serialization() {
        let node = NumaNode(1);

        let json = serde_json::to_string(&node).unwrap();
        let deserialized: NumaNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node, deserialized);
    }

    /// Smoke test: detection either yields the sentinel or a plausible index.
    #[test]
    fn test_get_current_cpu_numa_node() {
        let node = get_current_cpu_numa_node();

        if !node.is_unknown() {
            assert!(
                node.0 < MAX_PLAUSIBLE_NODE,
                "NUMA node {} seems unreasonably high",
                node.0
            );
        }
    }

    /// Smoke test: the lookup must not panic, with or without a GPU present.
    /// The result is printed rather than asserted because it depends on the host.
    #[test]
    fn test_get_device_numa_node_valid_gpu() {
        let node = get_device_numa_node(0);

        println!("GPU 0 detected on NUMA node: {}", node.0);
    }

    /// Nodes can be used as hash-map keys.
    #[test]
    fn test_numa_node_hash() {
        use std::collections::HashMap;

        let mut map = HashMap::new();
        map.insert(NumaNode(0), "node0");
        map.insert(NumaNode(1), "node1");

        assert_eq!(map.get(&NumaNode(0)), Some(&"node0"));
        assert_eq!(map.get(&NumaNode(1)), Some(&"node1"));
        assert_eq!(map.get(&NumaNode(2)), None);
    }

    /// `NumaNode` is `Copy`: the source stays usable after being assigned away.
    #[test]
    fn test_numa_node_copy_clone() {
        let node1 = NumaNode(5);

        let node2 = node1;
        let node3 = node1;

        assert_eq!(node1, node2);
        assert_eq!(node1, node3);
        assert_eq!(node2, node3);
    }
}