1use super::get_current_cpu_numa_node;
16use cudarc::driver::CudaContext;
17use cudarc::driver::result::malloc_host;
18use cudarc::driver::sys::CU_MEMHOSTALLOC_DEVICEMAP;
19use nix::libc;
20use std::collections::HashMap;
21use std::sync::mpsc::{Receiver, Sender, channel};
22use std::sync::{Arc, Mutex, OnceLock};
23use std::thread::{self, JoinHandle};
24use std::time::Duration;
25
26use super::{NumaNode, get_device_numa_node};
27
28fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>, String> {
30 static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
31 let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
32
33 if let Some(existing) = map.get(&device_id) {
34 return Ok(existing.clone());
35 }
36
37 let ctx = CudaContext::new(device_id as usize).map_err(|e| {
38 format!(
39 "Failed to create CUDA context for device {}: {:?}",
40 device_id, e
41 )
42 })?;
43 map.insert(device_id, ctx.clone());
44 Ok(ctx)
45}
46
/// Thin wrapper around a raw host pointer so it can be sent across the
/// worker-thread channel (`*mut u8` is not `Send` by itself).
struct SendPtr(*mut u8);

// SAFETY: the wrapped pointer is just a handle to a CUDA pinned-host
// allocation produced on the worker thread; it is sent to exactly one
// receiver, which takes over ownership and is responsible for freeing it.
// NOTE(review): single ownership is a protocol invariant of this module, not
// something the type enforces.
unsafe impl Send for SendPtr {}
61
/// One allocation request sent to a NUMA worker thread.
struct AllocRequest {
    /// Number of bytes to allocate (must be non-zero; the worker rejects 0).
    size: usize,
    /// NUMA node the worker is expected to be pinned to (used for sanity checks).
    node: NumaNode,
    /// CUDA device whose context is bound before allocating.
    gpu_id: u32,
    /// Channel on which the worker sends back the allocation result.
    response: Sender<AllocResult>,
}

/// Result of an allocation: a pinned host pointer, or an error description.
type AllocResult = Result<SendPtr, String>;
76
/// A dedicated thread pinned to one NUMA node that performs CUDA pinned-host
/// allocations, so that first-touch page placement lands on that node.
struct NumaWorker {
    /// The NUMA node this worker is (attempted to be) pinned to.
    node: NumaNode,
    /// Request channel to the worker thread; `None` after shutdown (`Drop`
    /// takes it to close the channel and stop the worker loop).
    request_tx: Option<Sender<AllocRequest>>,
    /// Join handle for the worker thread; taken and joined on drop.
    handle: Option<JoinHandle<()>>,
}
83
impl NumaWorker {
    /// Spawns the worker thread for `node`. The thread pins itself to the
    /// node and then serves allocation requests from a channel until the
    /// channel closes.
    fn spawn(node: NumaNode) -> Result<Self, String> {
        let (request_tx, request_rx) = channel();

        let handle = thread::Builder::new()
            .name(format!("numa-worker-{}", node.0))
            .spawn(move || {
                Self::worker_loop(node, request_rx);
            })
            .map_err(|e| format!("Failed to spawn worker thread: {}", e))?;

        Ok(Self {
            node,
            request_tx: Some(request_tx),
            handle: Some(handle),
        })
    }

    /// Body of the worker thread: pin to `node`, then loop serving
    /// allocation requests until the request channel is closed (all senders
    /// dropped), which is the shutdown signal from `Drop`.
    fn worker_loop(node: NumaNode, requests: Receiver<AllocRequest>) {
        tracing::trace!("Pinning worker thread to node {}", node.0);
        if let Err(e) = super::pin_thread_to_numa_node(node) {
            // Pinning failure is non-fatal: allocations still succeed, but
            // first-touch placement may land on the wrong node.
            tracing::error!("Failed to pin worker thread to node {}: {}", node.0, e);
            tracing::error!("Worker will continue but allocations may be suboptimal");
        } else {
            tracing::trace!("Successfully pinned worker thread to node {}", node.0);

            // Give the scheduler a chance to migrate this thread onto a CPU
            // of the target node before verifying where we actually run.
            thread::yield_now();
            thread::sleep(Duration::from_millis(1));

            let current_node = super::get_current_cpu_numa_node();
            tracing::trace!("Current node after pinning: {}", current_node.0);
            if current_node != node {
                tracing::warn!(
                    "Worker thread on node {} after pinning (expected {})",
                    current_node.0,
                    node.0
                );
            } else {
                tracing::trace!("NUMA worker thread for node {} started and pinned", node.0);
            }
        }

        loop {
            tracing::trace!("Worker waiting for request on node {}", node.0);
            match requests.recv() {
                Ok(req) => {
                    tracing::trace!(
                        "Worker received CUDA pinned allocation request on node {}",
                        node.0
                    );
                    let result = Self::do_cuda_pinned_allocation(req.size, req.node, req.gpu_id);
                    match result {
                        Ok(SendPtr(ptr)) => {
                            // If the requester timed out (receiver dropped),
                            // nobody owns the allocation — free it here to
                            // avoid leaking pinned memory.
                            if let Err(_e) = req.response.send(Ok(SendPtr(ptr))) {
                                tracing::warn!(
                                    "Receiver dropped before receiving allocation, freeing {} bytes at {:p}",
                                    req.size,
                                    ptr
                                );
                                unsafe {
                                    let _ = cudarc::driver::result::free_host(
                                        ptr as *mut std::ffi::c_void,
                                    );
                                }
                            }
                        }
                        Err(err) => {
                            // Best-effort delivery; a dropped receiver is fine
                            // since there is nothing to clean up on error.
                            let _ = req.response.send(Err(err));
                        }
                    }
                }
                Err(_) => {
                    tracing::trace!(
                        "NUMA worker for node {} shutting down (channel closed)",
                        node.0
                    );
                    break;
                }
            }
        }
    }

    /// Performs a CUDA pinned-host allocation of `size` bytes on the current
    /// (worker) thread, then first-touches every page so physical pages are
    /// placed on the worker's NUMA node.
    ///
    /// Returns an error for zero-size requests, CUDA context/bind failures,
    /// allocation failure, or a null pointer from the driver.
    fn do_cuda_pinned_allocation(size: usize, node: NumaNode, gpu_id: u32) -> AllocResult {
        if size == 0 {
            return Err("Cannot allocate zero bytes".to_string());
        }

        // Sanity check: we should still be running on the node we pinned to.
        let node_before = get_current_cpu_numa_node();
        if node_before != node {
            tracing::warn!(
                "Worker thread moved! Expected NUMA node {}, currently on node {}",
                node.0,
                node_before.0
            );
        }

        let ctx = cuda_context(gpu_id)?;

        unsafe {
            // Make this thread's current CUDA context the one for `gpu_id`
            // before calling driver allocation APIs.
            ctx.bind_to_thread()
                .map_err(|e| format!("Failed to bind CUDA context to worker thread: {:?}", e))?;

            let node_after_ctx = get_current_cpu_numa_node();
            if node_after_ctx != node {
                tracing::warn!(
                    "Thread moved after CUDA context bind! Expected node {}, now on node {}",
                    node.0,
                    node_after_ctx.0
                );
            }

            // DEVICEMAP requests that the pinned buffer be mappable into the
            // device address space.
            let ptr = malloc_host(size, CU_MEMHOSTALLOC_DEVICEMAP)
                .map_err(|e| format!("malloc_host failed: {:?}", e))?;

            let ptr = ptr as *mut u8;

            if ptr.is_null() {
                return Err("malloc_host returned null".to_string());
            }

            let node_before_touch = get_current_cpu_numa_node();
            if node_before_touch != node {
                tracing::error!(
                    "Thread on wrong node before first-touch! Expected {}, on node {} - memory will be misplaced!",
                    node.0,
                    node_before_touch.0
                );
            }

            // First-touch: write one byte per page so the kernel backs each
            // page with memory local to this thread's NUMA node.
            let page_size = match libc::sysconf(libc::_SC_PAGESIZE) {
                n if n > 0 => n as usize,
                _ => 4096, // conservative fallback if sysconf fails
            };
            let mut offset = 0usize;
            while offset < size {
                std::ptr::write_volatile(ptr.add(offset), 0);
                offset = offset.saturating_add(page_size);
            }
            // NOTE(review): the loop above already touches the page that
            // contains the final byte, so this extra touch is redundant but
            // harmless; `size > 0` is also guaranteed by the early return.
            if size > 0 && !size.is_multiple_of(page_size) {
                std::ptr::write_volatile(ptr.add(size - 1), 0);
            }

            let node_after_touch = get_current_cpu_numa_node();

            tracing::trace!(
                "Worker allocated {} bytes (CUDA pinned) on GPU {} (target NUMA node {}) at {:p} - thread nodes: before={} after_ctx={} before_touch={} after_touch={}",
                size,
                gpu_id,
                node.0,
                ptr,
                node_before.0,
                node_after_ctx.0,
                node_before_touch.0,
                node_after_touch.0
            );

            Ok(SendPtr(ptr))
        }
    }

    /// Sends an allocation request to the worker thread and waits for the
    /// result with a size-scaled timeout (10 s base plus ~1 s per GiB,
    /// clamped to [10, 300] seconds).
    ///
    /// If the timeout fires, the worker's later `send` fails and the worker
    /// frees the orphaned allocation (see `worker_loop`), so no memory leaks.
    fn allocate(&self, size: usize, gpu_id: u32) -> AllocResult {
        let (response_tx, response_rx) = channel();

        let request = AllocRequest {
            size,
            node: self.node,
            gpu_id,
            response: response_tx,
        };

        self.request_tx
            .as_ref()
            .ok_or_else(|| "Worker has been shut down".to_string())?
            .send(request)
            .map_err(|_| "Worker thread has died".to_string())?;

        let timeout_secs = 10u64 + (size as u64 / (1024 * 1024 * 1024));
        let timeout = Duration::from_secs(timeout_secs.clamp(10, 300));
        tracing::trace!(
            "Worker pool waiting for allocation of {} MB with timeout of {} seconds",
            size / (1024 * 1024),
            timeout.as_secs()
        );

        response_rx
            .recv_timeout(timeout)
            .map_err(|e| format!("Worker timeout after {} seconds: {}", timeout.as_secs(), e))?
    }
}
309
310impl Drop for NumaWorker {
311 fn drop(&mut self) {
312 tracing::trace!("Dropping NUMA worker for node {}", self.node.0);
313
314 self.request_tx.take();
317 tracing::trace!("Channel closed for worker node {}", self.node.0);
318
319 if let Some(handle) = self.handle.take() {
321 tracing::trace!("Waiting for worker thread {} to join", self.node.0);
322 let _ = handle.join();
323 tracing::trace!("Worker thread {} joined", self.node.0);
324 }
325 }
326}
327
328pub struct NumaWorkerPool {
334 workers: Mutex<std::collections::HashMap<u32, Arc<NumaWorker>>>,
335}
336
337impl NumaWorkerPool {
338 fn new() -> Self {
339 Self {
340 workers: Mutex::new(std::collections::HashMap::new()),
341 }
342 }
343
344 pub fn global() -> &'static Self {
348 static POOL: OnceLock<NumaWorkerPool> = OnceLock::new();
349 POOL.get_or_init(NumaWorkerPool::new)
350 }
351
352 fn get_or_spawn_worker(&self, node: NumaNode) -> Result<Arc<NumaWorker>, String> {
354 let mut workers = self.workers.lock().unwrap();
355
356 if let Some(worker) = workers.get(&node.0) {
357 return Ok(worker.clone());
358 }
359
360 let worker = NumaWorker::spawn(node)?;
362 let worker = Arc::new(worker);
363 workers.insert(node.0, worker.clone());
364
365 tracing::trace!("Spawned NUMA worker for node {}", node.0);
366
367 Ok(worker)
368 }
369
370 pub fn allocate_pinned_for_gpu(&self, size: usize, gpu_id: u32) -> Result<*mut u8, String> {
385 let node = get_device_numa_node(gpu_id);
386
387 tracing::debug!(
388 "Allocating {} bytes pinned memory for GPU {} (NUMA node {})",
389 size,
390 gpu_id,
391 node.0
392 );
393
394 let worker = self.get_or_spawn_worker(node)?;
395 worker.allocate(size, gpu_id).map(|send_ptr| send_ptr.0)
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::numa::{get_current_cpu_numa_node, get_device_numa_node};

    /// Best-effort CUDA availability probe: require that `nvidia-smi` can be
    /// launched AND that a context for device 0 can actually be created.
    /// NOTE(review): `.output().is_err()` only detects spawn failure, not a
    /// non-zero exit status — the `cuda_context(0)` check covers that gap.
    fn is_cuda_available() -> bool {
        if std::process::Command::new("nvidia-smi")
            .arg("--query-gpu=count")
            .arg("--format=csv,noheader")
            .output()
            .is_err()
        {
            return false;
        }

        cuda_context(0).is_ok()
    }

    // Spawning a worker does not require CUDA, only a thread and a channel.
    #[test]
    fn test_worker_spawn() {
        let node = NumaNode(0);
        let worker = NumaWorker::spawn(node);
        assert!(worker.is_ok());
    }

    // End-to-end allocation through a single worker (requires CUDA).
    #[test]
    fn test_worker_allocate_pinned() {
        if !is_cuda_available() {
            eprintln!("Skipping test_worker_allocate_pinned: CUDA not available");
            return;
        }

        let node = NumaNode(0);
        let worker = NumaWorker::spawn(node).unwrap();

        let send_ptr = worker.allocate(4096, 0).unwrap();
        let ptr = send_ptr.0;
        assert!(!ptr.is_null());

        // Caller owns the pointer and must free it.
        unsafe {
            cudarc::driver::result::free_host(ptr as *mut std::ffi::c_void).unwrap();
        }
    }

    // Allocation through a (non-global) pool instance (requires CUDA).
    #[test]
    fn test_worker_pool() {
        if !is_cuda_available() {
            eprintln!("Skipping test_worker_pool: CUDA not available");
            return;
        }

        let pool = NumaWorkerPool::new();

        unsafe {
            let ptr = pool.allocate_pinned_for_gpu(8192, 0).unwrap();
            assert!(!ptr.is_null());

            cudarc::driver::result::free_host(ptr as *mut std::ffi::c_void).unwrap();
        }
    }

    // The global pool must hand out the same instance every time.
    #[test]
    fn test_worker_pool_singleton() {
        let pool1 = NumaWorkerPool::global();
        let pool2 = NumaWorkerPool::global();

        assert!(std::ptr::eq(pool1, pool2));
    }

    // Two allocations for the same GPU reuse one worker and return distinct
    // pointers (requires CUDA).
    #[test]
    fn test_worker_reuse() {
        if !is_cuda_available() {
            eprintln!("Skipping test_worker_reuse: CUDA not available");
            return;
        }

        let pool = NumaWorkerPool::new();

        unsafe {
            let ptr1 = pool.allocate_pinned_for_gpu(1024, 0).unwrap();

            let ptr2 = pool.allocate_pinned_for_gpu(1024, 0).unwrap();

            assert!(!ptr1.is_null());
            assert!(!ptr2.is_null());
            assert_ne!(ptr1, ptr2);

            cudarc::driver::result::free_host(ptr1 as *mut std::ffi::c_void).unwrap();
            cudarc::driver::result::free_host(ptr2 as *mut std::ffi::c_void).unwrap();
        }
    }

    // Zero-size requests are rejected before any CUDA call is made, so this
    // test runs even without a GPU.
    #[test]
    fn test_zero_size_allocation() {
        let pool = NumaWorkerPool::new();
        let result = pool.allocate_pinned_for_gpu(0, 0);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("zero"));
    }

    // Smoke test: node detection should not panic; either outcome is valid.
    #[test]
    fn test_get_current_cpu_numa_node() {
        let node = get_current_cpu_numa_node();

        if !node.is_unknown() {
            println!("Current CPU on NUMA node: {}", node.0);
        } else {
            println!("NUMA node detection unavailable (single-node or fake NUMA)");
        }
    }

    #[test]
    fn test_get_device_numa_node() {
        let node = get_device_numa_node(0);

        if !node.is_unknown() {
            println!("GPU 0 on NUMA node: {}", node.0);
            // NOTE(review): assumes test machines have at most 2 NUMA nodes
            // (or report UNKNOWN = u32::MAX) — confirm against CI hardware.
            assert!(node.0 <= 1 || node.0 == u32::MAX);
        } else {
            println!("GPU NUMA detection unavailable (no nvidia-smi or no GPU)");
        }
    }

    #[test]
    fn test_numa_node_display() {
        let node = NumaNode(0);
        assert_eq!(format!("{}", node), "NumaNode(0)");

        let unknown = NumaNode::UNKNOWN;
        assert_eq!(format!("{}", unknown), "UNKNOWN");
    }

    #[test]
    fn test_numa_node_is_unknown() {
        let valid = NumaNode(0);
        assert!(!valid.is_unknown());

        let unknown = NumaNode::UNKNOWN;
        assert!(unknown.is_unknown());
    }

    // API-shape test: succeeds trivially when CUDA is absent (Err branch).
    #[test]
    fn test_pinned_allocation_api() {
        let pool = NumaWorkerPool::new();

        unsafe {
            if let Ok(ptr) = pool.allocate_pinned_for_gpu(1024, 0) {
                assert!(!ptr.is_null());
                cudarc::driver::result::free_host(ptr as *mut std::ffi::c_void).unwrap();
            }
        }
    }

    // Channel round-trip works regardless of whether the allocation itself
    // succeeds (no CUDA → Err is delivered back through the channel).
    #[test]
    fn test_worker_channel_communication() {
        let node = NumaNode(0);
        let worker = NumaWorker::spawn(node).unwrap();

        let result = worker.allocate(1024, 0);

        // NOTE(review): this assertion is a tautology; it only documents that
        // allocate() returned (i.e. did not hang or panic).
        assert!(result.is_ok() || result.is_err());

        if let Ok(send_ptr) = result {
            unsafe {
                let ptr = send_ptr.0;
                assert!(!ptr.is_null());
                cudarc::driver::result::free_host(ptr as *mut std::ffi::c_void).unwrap();
            }
        }
    }
}