candle_mi/memory.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Process and GPU memory reporting.
4//!
5//! Provides [`MemorySnapshot`] to capture current RAM and VRAM usage,
6//! and [`MemoryReport`] to measure deltas between two snapshots.
7//!
8//! # VRAM measurement strategy
9//!
10//! VRAM is measured using a three-tier approach:
11//!
12//! 1. **Windows primary — DXGI** (per-process): Uses
13//! `IDXGIAdapter3::QueryVideoMemoryInfo` (DXGI 1.4, Windows 10+) to get
14//! true per-process GPU memory. This is the only reliable method on Windows
15//! because WDDM means the Windows kernel manages GPU memory, not the NVIDIA
16//! driver — so NVML returns `NOT_AVAILABLE` for per-process queries.
17//! 2. **Linux primary — NVML** (per-process): Dynamically loads
18//! `libnvidia-ml.so.1` via `libloading` and calls
19//! `nvmlDeviceGetComputeRunningProcesses` for true per-process GPU memory.
20//! 3. **Fallback — `nvidia-smi`** (device-wide): If both DXGI and NVML fail,
21//! spawns `nvidia-smi` as a subprocess. Reports device-wide VRAM; the delta
22//! between two snapshots is accurate on single-user machines.
23//!
24//! # Platform support
25//!
26//! | Metric | Windows | Linux |
27//! |--------|---------|-------|
28//! | RAM (RSS) | `K32GetProcessMemoryInfo` (per-process, exact) | `/proc/self/status` `VmRSS` (per-process, exact) |
29//! | VRAM (DXGI) | `IDXGIAdapter3` (per-process, exact) | N/A |
30//! | VRAM (NVML) | `NOT_AVAILABLE` under WDDM | `libnvidia-ml.so.1` (per-process, exact) |
31//! | VRAM (fallback) | `nvidia-smi` (device-wide) | `nvidia-smi` (device-wide) |
32//!
33//! # Feature gates
34//!
35//! - **`memory`**: Enables this module. Relaxes `#![forbid(unsafe_code)]` to
36//! `#![deny(unsafe_code)]` for the Windows FFI calls (`K32GetProcessMemoryInfo`,
37//! DXGI COM calls) and for NVML dynamic symbol loading via `libloading`.
38//! On Linux RAM measurement, no unsafe code is used.
39//! - **`memory-debug`** (implies `memory`): Prints raw DXGI query results and
40//! per-chunk VRAM measurements to stderr for diagnosing GPU memory issues.
41
42use crate::{MIError, Result};
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
/// A point-in-time record of process RAM and, optionally, GPU VRAM usage.
///
/// Take one with [`MemorySnapshot::now`]; hand two of them to
/// [`MemoryReport::new`] to see how usage changed across an operation.
///
/// # Example
///
/// ```no_run
/// use candle_mi::MemorySnapshot;
///
/// let before = MemorySnapshot::now(&candle_core::Device::Cpu)?;
/// // ... load a model ...
/// let after = MemorySnapshot::now(&candle_core::Device::Cpu)?;
/// let report = candle_mi::MemoryReport::new(before, after);
/// println!("RAM delta: {:+.1} MB", report.ram_delta_mb());
/// # Ok::<(), candle_mi::MIError>(())
/// ```
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
    /// Resident set size of this process in bytes (the working set on Windows).
    pub ram_bytes: u64,
    /// GPU memory in use, in bytes.
    ///
    /// The figure is per-process when it comes from DXGI or NVML and
    /// device-wide when it comes from the `nvidia-smi` fallback.
    /// `None` when the device is not CUDA or every measurement method failed.
    pub vram_bytes: Option<u64>,
    /// Total memory of the GPU that was queried, in bytes.
    /// `None` when it could not be determined.
    pub vram_total_bytes: Option<u64>,
    /// `Some(true)` if `vram_bytes` is a per-process figure, `Some(false)` if
    /// it is device-wide, `None` when there is no VRAM data at all.
    pub vram_per_process: Option<bool>,
    /// Adapter name of the GPU (e.g., `NVIDIA GeForce RTX 5060 Ti`).
    /// Only the DXGI path provides it; `None` otherwise.
    pub gpu_name: Option<String>,
}
84
85/// Memory delta between two snapshots.
86///
87/// Computed from a `before` and `after` [`MemorySnapshot`].
88/// Positive deltas mean memory increased; negative means freed.
89#[derive(Debug, Clone)]
90pub struct MemoryReport {
91 /// Snapshot taken before the operation.
92 pub before: MemorySnapshot,
93 /// Snapshot taken after the operation.
94 pub after: MemorySnapshot,
95}
96
97impl MemorySnapshot {
98 /// Capture current memory state.
99 ///
100 /// RAM is always measured (per-process RSS). VRAM is measured only if
101 /// `device` is CUDA — first via DXGI (Windows, per-process), then NVML
102 /// (Linux, per-process), falling back to `nvidia-smi` (device-wide).
103 ///
104 /// # Errors
105 ///
106 /// Returns [`MIError::Memory`] if the RAM query fails (platform API error).
107 /// VRAM measurement failures are non-fatal — `vram_bytes` is set to `None`.
108 pub fn now(device: &candle_core::Device) -> Result<Self> {
109 let ram_bytes = process_rss()?;
110 let (vram_bytes, vram_total_bytes, per_process, gpu_name) = if device.is_cuda() {
111 gpu_memory_used()
112 } else {
113 (None, None, None, None)
114 };
115 Ok(Self {
116 ram_bytes,
117 vram_bytes,
118 vram_total_bytes,
119 vram_per_process: per_process,
120 gpu_name,
121 })
122 }
123
124 /// Format RAM usage as megabytes.
125 #[must_use]
126 pub fn ram_mb(&self) -> f64 {
127 // CAST: u64 → f64, value is memory in bytes — fits in f64 mantissa
128 // for any realistic process size (< 2^53 bytes = 8 PB)
129 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
130 let mb = self.ram_bytes as f64 / 1_048_576.0;
131 mb
132 }
133
134 /// Format VRAM usage as megabytes, if available.
135 #[must_use]
136 pub fn vram_mb(&self) -> Option<f64> {
137 // CAST: u64 → f64, same justification as ram_mb
138 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
139 self.vram_bytes.map(|b| b as f64 / 1_048_576.0)
140 }
141}
142
143/// Synchronize the CUDA device and trim its memory pool.
144///
145/// On a CUDA device this:
146/// 1. Calls `cuCtxSynchronize` so all pending async frees complete.
147/// 2. Calls `cuMemPoolTrimTo(pool, 0)` to release all unused reserved
148/// VRAM back to the device.
149///
150/// cudarc's stream-ordered allocator (`malloc_async` / `free_async`)
151/// keeps freed blocks in a pool for reuse. Over many forward passes
152/// with varying tensor sizes the pool grows monotonically — DXGI and
153/// `nvidia-smi` report this reserved memory as "in use", eventually
154/// causing OOM even though no live tensors need it.
155///
156/// This function is a no-op on CPU and on Metal.
157///
158/// # Example
159///
160/// ```no_run
161/// # use candle_mi::sync_and_trim_gpu;
162/// # let device = candle_core::Device::Cpu;
163/// // After dropping all GPU tensors from a forward pass:
164/// sync_and_trim_gpu(&device);
165/// ```
166// Cannot be `const fn`: the `cuda` branch calls non-const FFI (cuDeviceGetDefaultMemPool,
167// cuMemPoolTrimTo). Without `cuda` the body collapses to a no-op, which is why clippy
168// suggests `const` — but `const fn` would break the cuda-enabled build.
169#[allow(clippy::missing_const_for_fn)]
170pub fn sync_and_trim_gpu(device: &candle_core::Device) {
171 #[cfg(feature = "cuda")]
172 if let candle_core::Device::Cuda(cuda_dev) = device {
173 use candle_core::backend::BackendDevice;
174 // Synchronize so all pending async frees complete.
175 let _ = cuda_dev.synchronize();
176
177 // Trim the default memory pool to release all unused reserved VRAM.
178 // SAFETY: cuDeviceGetDefaultMemPool and cuMemPoolTrimTo are
179 // documented CUDA driver APIs for pool management. The CUdevice
180 // handle comes from candle's CudaContext (valid after synchronize).
181 // cuMemPoolTrimTo(pool, 0) releases all unused memory — it cannot
182 // free memory that is still in use by live tensors.
183 #[allow(unsafe_code)]
184 {
185 use candle_core::cuda_backend::cudarc::driver::sys;
186
187 let stream = cuda_dev.cuda_stream();
188 // Allocate a zero-length slice just to access the CudaContext
189 // (CudaStream.ctx is pub(crate), but CudaSlice.context() is pub).
190 if let Ok(probe) = stream.null::<u8>() {
191 let ctx = probe.context();
192 let cu_device = ctx.cu_device();
193 unsafe {
194 let mut pool = std::mem::zeroed();
195 let rc = sys::cuDeviceGetDefaultMemPool(&raw mut pool, cu_device);
196 if rc == sys::CUresult::CUDA_SUCCESS {
197 let _ = sys::cuMemPoolTrimTo(pool, 0);
198 }
199 }
200 }
201 }
202 }
203
204 // Suppress unused-variable warning on non-CUDA builds.
205 #[cfg(not(feature = "cuda"))]
206 let _ = device;
207}
208
209impl MemoryReport {
210 /// Create a report from two snapshots.
211 #[must_use]
212 pub const fn new(before: MemorySnapshot, after: MemorySnapshot) -> Self {
213 Self { before, after }
214 }
215
216 /// RAM delta in megabytes (positive = increased).
217 #[must_use]
218 pub fn ram_delta_mb(&self) -> f64 {
219 self.after.ram_mb() - self.before.ram_mb()
220 }
221
222 /// VRAM delta in megabytes (positive = increased).
223 /// Returns `None` if either snapshot lacks VRAM data.
224 #[must_use]
225 pub fn vram_delta_mb(&self) -> Option<f64> {
226 match (self.after.vram_mb(), self.before.vram_mb()) {
227 (Some(after), Some(before)) => Some(after - before),
228 (Some(_) | None, None) | (None, Some(_)) => None,
229 }
230 }
231
232 /// Print a one-line summary of the delta.
233 pub fn print_delta(&self, label: &str) {
234 let ram = self.ram_delta_mb();
235 print!(" {label}: RAM {ram:+.0} MB");
236 if let Some(vram) = self.vram_delta_mb() {
237 let qualifier = self.vram_qualifier();
238 print!(" | VRAM {vram:+.0} MB{qualifier}");
239 }
240 println!();
241 }
242
243 /// Print a two-line summary showing before → after for both RAM and VRAM.
244 pub fn print_before_after(&self, label: &str) {
245 println!(
246 " {label}: RAM {:.0} MB → {:.0} MB ({:+.0} MB)",
247 self.before.ram_mb(),
248 self.after.ram_mb(),
249 self.ram_delta_mb(),
250 );
251 if let (Some(before), Some(after)) = (self.before.vram_mb(), self.after.vram_mb()) {
252 // CAST: u64 → f64, same justification as ram_mb
253 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
254 let total = self.after.vram_total_bytes.map_or(String::new(), |t| {
255 format!(" / {:.0} MB", t as f64 / 1_048_576.0)
256 });
257 let qualifier = self.vram_qualifier();
258 let gpu = self
259 .after
260 .gpu_name
261 .as_deref()
262 .map_or(String::new(), |name| format!(" [{name}]"));
263 println!(
264 " {label}: VRAM {before:.0} MB → {after:.0} MB ({:+.0} MB{total}){qualifier}{gpu}",
265 after - before,
266 );
267 }
268 }
269
270 /// Return a short qualifier string indicating VRAM measurement quality.
271 #[must_use]
272 const fn vram_qualifier(&self) -> &'static str {
273 match self.after.vram_per_process {
274 Some(true) => " [per-process]",
275 Some(false) => " [device-wide]",
276 None => "",
277 }
278 }
279}
280
281// ---------------------------------------------------------------------------
282// RAM measurement — per-process RSS
283// ---------------------------------------------------------------------------
284
285/// Query the current process's resident set size (RSS) in bytes.
286///
287/// # Platform
288///
289/// - **Windows**: `K32GetProcessMemoryInfo` → `WorkingSetSize` (exact, per-process).
290/// - **Linux**: `/proc/self/status` → `VmRSS` (exact, per-process, no unsafe).
291///
292/// # Errors
293///
294/// Returns [`MIError::Memory`] if the platform API call fails.
295fn process_rss() -> Result<u64> {
296 #[cfg(target_os = "windows")]
297 {
298 windows_rss()
299 }
300 #[cfg(target_os = "linux")]
301 {
302 linux_rss()
303 }
304 #[cfg(not(any(target_os = "windows", target_os = "linux")))]
305 {
306 Err(MIError::Memory(
307 "RAM measurement not supported on this platform".into(),
308 ))
309 }
310}
311
312// -- Windows ----------------------------------------------------------------
313
/// Minimal hand-written bindings for the Win32 calls used to read the
/// working set of the current process.
#[cfg(target_os = "windows")]
mod win_ffi {
    /// Rust mirror of the Win32 `PROCESS_MEMORY_COUNTERS` structure.
    ///
    /// Field order and widths must match the C definition exactly
    /// (`#[repr(C)]`); only `cb` and `working_set_size` are consumed by this
    /// module, the remaining fields exist to keep the layout right.
    ///
    /// See: <https://learn.microsoft.com/en-us/windows/win32/api/psapi/ns-psapi-process_memory_counters>
    #[repr(C)]
    pub(super) struct ProcessMemoryCounters {
        /// Byte size of this structure, filled in by the caller.
        pub cb: u32,
        /// Page-fault count.
        pub page_fault_count: u32,
        /// Largest working set size observed, in bytes.
        pub peak_working_set_size: usize,
        /// Working set size right now, in bytes (this is the RSS figure).
        pub working_set_size: usize,
        /// Largest paged-pool usage observed, in bytes.
        pub quota_peak_paged_pool_usage: usize,
        /// Paged-pool usage right now, in bytes.
        pub quota_paged_pool_usage: usize,
        /// Largest non-paged-pool usage observed, in bytes.
        pub quota_peak_non_paged_pool_usage: usize,
        /// Non-paged-pool usage right now, in bytes.
        pub quota_non_paged_pool_usage: usize,
        /// Pagefile usage right now, in bytes.
        pub pagefile_usage: usize,
        /// Largest pagefile usage observed, in bytes.
        pub peak_pagefile_usage: usize,
    }

    // SAFETY: Both declarations name stable Windows API functions with a
    // well-defined `system` ABI. GetCurrentProcess always returns a valid
    // pseudo-handle, which is why it can be declared `safe`.
    // K32GetProcessMemoryInfo writes into caller-provided memory of a
    // caller-stated size, so it stays `unsafe`.
    #[allow(unsafe_code)]
    unsafe extern "system" {
        /// Fills `ppsmem_counters` with memory statistics for `process`.
        /// `cb` is the byte size of the buffer. A non-zero return value
        /// signals success.
        pub(super) unsafe fn K32GetProcessMemoryInfo(
            process: isize,
            ppsmem_counters: *mut ProcessMemoryCounters,
            cb: u32,
        ) -> i32;

        /// Pseudo-handle referring to the calling process (always valid,
        /// never null).
        pub(super) safe fn GetCurrentProcess() -> isize;
    }
}
360
/// Working set size of the current process on Windows, in bytes.
///
/// Calls `K32GetProcessMemoryInfo` on the current-process pseudo-handle and
/// returns the `working_set_size` field of the filled structure.
///
/// # Errors
///
/// Returns [`MIError::Memory`] when the API call reports failure (returns 0).
#[cfg(target_os = "windows")]
#[allow(unsafe_code)]
fn windows_rss() -> Result<u64> {
    // CAST: usize → u32. The struct is two u32 plus eight usize fields
    // (72 bytes on 64-bit targets), far below u32::MAX.
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let struct_size = std::mem::size_of::<win_ffi::ProcessMemoryCounters>() as u32;

    // Everything starts zeroed except `cb`, which carries the struct size.
    let mut counters = win_ffi::ProcessMemoryCounters {
        cb: struct_size,
        page_fault_count: 0,
        peak_working_set_size: 0,
        working_set_size: 0,
        quota_peak_paged_pool_usage: 0,
        quota_paged_pool_usage: 0,
        quota_peak_non_paged_pool_usage: 0,
        quota_non_paged_pool_usage: 0,
        pagefile_usage: 0,
        peak_pagefile_usage: 0,
    };

    let process = win_ffi::GetCurrentProcess();

    // SAFETY: `counters` is a live, correctly sized stack value and
    // `struct_size` is exactly its size, so the API writes only inside it.
    // The pseudo-handle from GetCurrentProcess is valid for the lifetime of
    // the process.
    let status =
        unsafe { win_ffi::K32GetProcessMemoryInfo(process, &raw mut counters, struct_size) };
    if status == 0 {
        return Err(MIError::Memory("K32GetProcessMemoryInfo failed".into()));
    }

    // CAST: usize → u64, widening (or same width) — never truncates.
    #[allow(clippy::as_conversions)]
    let rss = counters.working_set_size as u64;
    Ok(rss)
}
399
400// -- Linux ------------------------------------------------------------------
401
402/// Query RSS on Linux via `/proc/self/status`.
403#[cfg(target_os = "linux")]
404fn linux_rss() -> Result<u64> {
405 let status = std::fs::read_to_string("/proc/self/status")
406 .map_err(|e| MIError::Memory(format!("failed to read /proc/self/status: {e}")))?;
407
408 for line in status.lines() {
409 if let Some(rest) = line.strip_prefix("VmRSS:") {
410 let kb_str = rest.trim().trim_end_matches(" kB").trim();
411 let kb: u64 = kb_str.parse().map_err(|e| {
412 MIError::Memory(format!("failed to parse VmRSS value '{kb_str}': {e}"))
413 })?;
414 return Ok(kb * 1024);
415 }
416 }
417
418 Err(MIError::Memory(
419 "VmRSS not found in /proc/self/status".into(),
420 ))
421}
422
423// ---------------------------------------------------------------------------
424// VRAM measurement — DXGI (Windows), NVML, nvidia-smi fallback
425// ---------------------------------------------------------------------------
426
/// Result of a GPU memory query: `(used_bytes, total_bytes, per_process, gpu_name)`.
type GpuMemoryResult = (Option<u64>, Option<u64>, Option<bool>, Option<String>);

/// NVML shared library name on non-Windows targets (stable across driver
/// versions).
///
/// Defined for *every* non-Windows target rather than Linux only: the NVML
/// query code references this constant unconditionally, so leaving it
/// undefined elsewhere (e.g. macOS) breaks compilation of the whole module.
/// On a platform without NVML the dynamic load simply fails at run time and
/// the caller moves on to its fallback.
#[cfg(not(target_os = "windows"))]
const NVML_LIB_PATH: &str = "libnvidia-ml.so.1";

/// NVML shared library name on Windows (stable across driver versions).
#[cfg(target_os = "windows")]
const NVML_LIB_PATH: &str = "nvml.dll";

/// NVML return code: success.
const NVML_SUCCESS: u32 = 0;

/// NVML return code: buffer too small (need to retry with larger buffer).
const NVML_ERROR_INSUFFICIENT_SIZE: u32 = 7;

/// Maximum number of processes to query from NVML in a single call.
/// 64 is generous — most machines have fewer than 10 GPU processes.
const NVML_MAX_PROCESSES: usize = 64;
447
/// One entry of the per-process list filled in by NVML.
///
/// Laid out to match the C struct `nvmlProcessInfo_v2_t` (24 bytes), which
/// both `nvmlDeviceGetComputeRunningProcesses_v2` and `_v3` use — the `_v3`
/// suffix versions the function, not the struct.
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlProcessInfo__v2__t.html>
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlProcessInfo {
    /// Operating-system process ID.
    pid: u32,
    /// Bytes of GPU memory attributed to the process.
    /// The sentinel `u64::MAX` (`0xFFFF_FFFF_FFFF_FFFF`) means "not available".
    used_gpu_memory: u64,
    /// MIG GPU instance ID; not meaningful outside MIG mode.
    gpu_instance_id: u32,
    /// MIG compute instance ID; not meaningful outside MIG mode.
    compute_instance_id: u32,
}

/// Device-level memory counters filled in by NVML.
///
/// Laid out to match the C struct `nvmlMemory_t`.
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlMemory__t.html>
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlMemoryInfo {
    /// Total memory on the device, in bytes.
    total: u64,
    /// Memory currently free, in bytes.
    free: u64,
    /// Memory currently in use, in bytes.
    used: u64,
}

/// Opaque device handle handed out by NVML.
type NvmlDevice = *mut std::ffi::c_void;

/// `nvmlInit_v2() -> nvmlReturn_t`.
type NvmlInitFn = unsafe extern "C" fn() -> u32;

/// `nvmlShutdown() -> nvmlReturn_t`.
type NvmlShutdownFn = unsafe extern "C" fn() -> u32;

/// `nvmlDeviceGetHandleByIndex_v2(index, *mut device) -> nvmlReturn_t`.
type NvmlDeviceGetHandleByIndexFn = unsafe extern "C" fn(u32, *mut NvmlDevice) -> u32;

/// `nvmlDeviceGetMemoryInfo(device, *mut memory) -> nvmlReturn_t`.
type NvmlDeviceGetMemoryInfoFn = unsafe extern "C" fn(NvmlDevice, *mut NvmlMemoryInfo) -> u32;

/// `nvmlDeviceGetComputeRunningProcesses_v3(device, *mut count, *mut infos) -> nvmlReturn_t`.
type NvmlDeviceGetComputeRunningProcessesFn =
    unsafe extern "C" fn(NvmlDevice, *mut u32, *mut NvmlProcessInfo) -> u32;
502
503/// Query GPU memory — DXGI (Windows), NVML, or `nvidia-smi` fallback.
504///
505/// Returns `(used_bytes, total_bytes, per_process, gpu_name)`:
506/// - `per_process = Some(true)` when per-process query succeeded (DXGI or NVML).
507/// - `per_process = Some(false)` when falling back to `nvidia-smi` (device-wide).
508/// - `gpu_name` is set when DXGI provides the adapter description.
509/// - All `None` if all methods fail.
510fn gpu_memory_used() -> GpuMemoryResult {
511 // Windows: try DXGI first (per-process, works under WDDM)
512 #[cfg(windows)]
513 if let Some(result) = dxgi_query_process_vram() {
514 return result;
515 }
516
517 // Try NVML (per-process on Linux, NOT_AVAILABLE on Windows WDDM)
518 if let Some(result) = nvml_query_process_vram() {
519 let (used, total, per_process) = result;
520 return (used, total, per_process, None);
521 }
522
523 // Fallback to nvidia-smi (device-wide)
524 let (used, total) = nvidia_smi_query();
525 if used.is_some() {
526 (used, total, Some(false), None)
527 } else {
528 (None, None, None, None)
529 }
530}
531
532/// Attempt to query per-process VRAM via NVML.
533///
534/// Returns `None` if NVML cannot be loaded or any API call fails,
535/// signaling the caller to try the fallback path.
536#[allow(unsafe_code)]
537fn nvml_query_process_vram() -> Option<(Option<u64>, Option<u64>, Option<bool>)> {
538 // SAFETY: libloading::Library::new dynamically loads a shared library.
539 // The NVML library is a stable NVIDIA driver component with a well-defined
540 // C ABI. We load it, call functions, and unload it within this scope.
541 let lib = unsafe { libloading::Library::new(NVML_LIB_PATH) }.ok()?;
542
543 // SAFETY: Loading function symbols from the NVML library. Each symbol
544 // name matches the documented NVML C API exactly. The function signatures
545 // (type aliases above) match the NVML header definitions.
546 let init: libloading::Symbol<'_, NvmlInitFn> = unsafe { lib.get(b"nvmlInit_v2\0") }.ok()?;
547 let shutdown: libloading::Symbol<'_, NvmlShutdownFn> =
548 unsafe { lib.get(b"nvmlShutdown\0") }.ok()?;
549 let get_handle: libloading::Symbol<'_, NvmlDeviceGetHandleByIndexFn> =
550 unsafe { lib.get(b"nvmlDeviceGetHandleByIndex_v2\0") }.ok()?;
551 let get_memory: libloading::Symbol<'_, NvmlDeviceGetMemoryInfoFn> =
552 unsafe { lib.get(b"nvmlDeviceGetMemoryInfo\0") }.ok()?;
553 let get_processes: libloading::Symbol<'_, NvmlDeviceGetComputeRunningProcessesFn> =
554 unsafe { lib.get(b"nvmlDeviceGetComputeRunningProcesses_v3\0") }.ok()?;
555
556 // Initialize NVML
557 // SAFETY: nvmlInit_v2 is safe to call from any thread; it initializes
558 // internal NVML state. Returns NVML_SUCCESS (0) on success.
559 let ret = unsafe { init() };
560 if ret != NVML_SUCCESS {
561 return None;
562 }
563
564 // Get device handle for GPU 0 (primary GPU)
565 let mut device: NvmlDevice = std::ptr::null_mut();
566 // SAFETY: nvmlDeviceGetHandleByIndex_v2 writes a valid opaque handle
567 // into `device` when it returns NVML_SUCCESS. Index 0 = primary GPU.
568 let ret = unsafe { get_handle(0, &raw mut device) };
569 if ret != NVML_SUCCESS {
570 // SAFETY: nvmlShutdown is always safe after a successful nvmlInit.
571 unsafe { shutdown() };
572 return None;
573 }
574
575 // Get total memory for the device
576 let mut mem_info = NvmlMemoryInfo {
577 total: 0,
578 free: 0,
579 used: 0,
580 };
581 // SAFETY: nvmlDeviceGetMemoryInfo writes into the caller-provided
582 // NvmlMemoryInfo struct. The device handle is valid (obtained above).
583 let ret = unsafe { get_memory(device, &raw mut mem_info) };
584 let total_bytes = if ret == NVML_SUCCESS {
585 Some(mem_info.total)
586 } else {
587 None
588 };
589
590 // Get per-process memory
591 // CAST: usize → u32, NVML_MAX_PROCESSES is 64 — fits in u32
592 #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
593 let mut count = NVML_MAX_PROCESSES as u32;
594 let mut infos = [NvmlProcessInfo {
595 pid: 0,
596 used_gpu_memory: 0,
597 gpu_instance_id: 0,
598 compute_instance_id: 0,
599 }; NVML_MAX_PROCESSES];
600
601 // SAFETY: nvmlDeviceGetComputeRunningProcesses_v3 fills `infos` with
602 // up to `count` entries and updates `count` to the actual number written.
603 // The buffer is stack-allocated with NVML_MAX_PROCESSES slots, which is
604 // sufficient for typical workloads.
605 let ret = unsafe { get_processes(device, &raw mut count, infos.as_mut_ptr()) };
606
607 // SAFETY: nvmlShutdown pairs with nvmlInit; always called before return.
608 unsafe { shutdown() };
609
610 if ret != NVML_SUCCESS && ret != NVML_ERROR_INSUFFICIENT_SIZE {
611 return None;
612 }
613
614 // Find our process in the list
615 let my_pid = std::process::id();
616 // CAST: u32 → usize, count is a small process count — always fits
617 #[allow(clippy::as_conversions)]
618 let actual_count = count as usize;
619 let my_vram = infos
620 .get(..actual_count)?
621 .iter()
622 .find(|info| info.pid == my_pid)
623 .map(|info| info.used_gpu_memory);
624
625 // NVML uses u64::MAX as "not available" sentinel — some drivers (e.g., R570
626 // on RTX 5060 Ti) return this for all processes. Fall back to nvidia-smi.
627 if my_vram == Some(u64::MAX) {
628 return None;
629 }
630
631 // Sanity check: if per-process VRAM exceeds total device memory, the value
632 // is likely garbage. Fall back to nvidia-smi.
633 if let (Some(used), Some(total)) = (my_vram, total_bytes)
634 && used > total
635 {
636 return None;
637 }
638
639 // Our PID might not be in the list (no active CUDA context yet?) — return None to trigger fallback
640 my_vram.map(|used| (Some(used), total_bytes, Some(true)))
641}
642
/// Device-wide GPU memory figures obtained by running `nvidia-smi`.
///
/// Runs `nvidia-smi --query-gpu=memory.used,memory.total
/// --format=csv,noheader,nounits` and parses the first output line, which
/// has the form `"1234, 16384"` (used MiB, total MiB).
///
/// Returns `(Some(used_bytes), Some(total_bytes))` on success. Returns
/// `(None, None)` when the command cannot be started, exits unsuccessfully,
/// prints nothing, or prints something that does not parse.
fn nvidia_smi_query() -> (Option<u64>, Option<u64>) {
    let Ok(output) = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output()
    else {
        return (None, None);
    };
    if !output.status.success() {
        return (None, None);
    }

    // BORROW: explicit String::from_utf8_lossy — nvidia-smi output is ASCII
    let stdout = String::from_utf8_lossy(&output.stdout);

    // Only the first line is read: its first two comma-separated fields.
    let parsed = stdout.lines().next().and_then(|first_line| {
        let mut fields = first_line.trim().split(',');
        let used_mib: u64 = fields.next()?.trim().parse().ok()?;
        let total_mib: u64 = fields.next()?.trim().parse().ok()?;
        Some((used_mib, total_mib))
    });

    // nvidia-smi reports in MiB — convert to bytes
    match parsed {
        Some((used_mib, total_mib)) => (Some(used_mib * 1_048_576), Some(total_mib * 1_048_576)),
        None => (None, None),
    }
}
690
691// ---------------------------------------------------------------------------
692// DXGI per-process VRAM (Windows only)
693// ---------------------------------------------------------------------------
694
/// Query per-process GPU VRAM via DXGI (`IDXGIAdapter3::QueryVideoMemoryInfo`).
///
/// This is the only reliable way to get per-process GPU memory on Windows
/// (WDDM). NVML returns `NVML_VALUE_NOT_AVAILABLE` under WDDM because the
/// Windows kernel memory manager owns GPU memory, not the NVIDIA driver.
///
/// DXGI 1.4 (Windows 10+) provides `QueryVideoMemoryInfo` which returns:
/// - `CurrentUsage`: per-process GPU memory in bytes (exactly what we want).
/// - `Budget`: OS-assigned memory budget for this process.
///
/// We query `DXGI_MEMORY_SEGMENT_GROUP_LOCAL` (dedicated VRAM on discrete GPUs).
/// Total VRAM comes from `IDXGIAdapter::GetDesc` → `DedicatedVideoMemory`.
///
/// # Adapter selection
///
/// Adapters are enumerated in DXGI order and the first one that both has
/// dedicated VRAM and reports the NVIDIA PCI vendor ID is used. The vendor
/// check matters on hybrid-graphics machines: a non-NVIDIA adapter that also
/// reports dedicated video memory can be enumerated first, and measuring it
/// would report the wrong GPU. This function is only reached for CUDA
/// devices, so restricting the search to NVIDIA adapters is safe.
/// With several NVIDIA adapters the first one enumerated is used.
///
/// Returns `None` if DXGI is not available, no matching adapter exists, or
/// the query fails, signaling the caller to try NVML or nvidia-smi fallback.
#[cfg(windows)]
#[allow(unsafe_code)]
fn dxgi_query_process_vram() -> Option<GpuMemoryResult> {
    use windows::Win32::Graphics::Dxgi::{
        CreateDXGIFactory1, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, DXGI_QUERY_VIDEO_MEMORY_INFO,
        IDXGIAdapter, IDXGIAdapter3, IDXGIFactory1,
    };
    use windows::core::Interface;

    /// PCI vendor ID assigned to NVIDIA, as reported in `DXGI_ADAPTER_DESC::VendorId`.
    const NVIDIA_VENDOR_ID: u32 = 0x10DE;

    // SAFETY: CreateDXGIFactory1 is a well-documented COM factory function.
    // It initializes COM internally if needed. The returned IDXGIFactory1
    // is reference-counted and released automatically when dropped.
    let factory: IDXGIFactory1 = unsafe { CreateDXGIFactory1() }.ok()?;

    // Walk the adapters until one qualifies. Running past the last adapter
    // makes EnumAdapters1 fail, which ends the search with `None`.
    let mut adapter_idx = 0u32;
    loop {
        // SAFETY: EnumAdapters1 returns S_OK with a valid adapter, or
        // DXGI_ERROR_NOT_FOUND when idx is out of range.
        let adapter: IDXGIAdapter = unsafe { factory.EnumAdapters1(adapter_idx) }
            .ok()?
            .cast()
            .ok()?;

        // SAFETY: GetDesc writes a valid DXGI_ADAPTER_DESC into the
        // caller-provided struct. The adapter handle is valid (obtained above).
        let desc = unsafe { adapter.GetDesc() }.ok()?;
        let dedicated_vram = desc.DedicatedVideoMemory;

        // Skip software/render-only adapters (no dedicated VRAM, e.g. the
        // Microsoft Basic Render Driver) and adapters from other vendors,
        // which cannot be the CUDA device being measured.
        if dedicated_vram == 0 || desc.VendorId != NVIDIA_VENDOR_ID {
            adapter_idx += 1;
            continue;
        }

        // Cast to IDXGIAdapter3 for QueryVideoMemoryInfo (DXGI 1.4)
        let adapter3: IDXGIAdapter3 = adapter.cast().ok()?;

        // SAFETY: QueryVideoMemoryInfo fills a DXGI_QUERY_VIDEO_MEMORY_INFO
        // struct with per-process memory stats. Node 0 = primary GPU node.
        // DXGI_MEMORY_SEGMENT_GROUP_LOCAL = dedicated VRAM on discrete GPUs.
        let mut mem_info = DXGI_QUERY_VIDEO_MEMORY_INFO::default();
        unsafe {
            adapter3.QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &raw mut mem_info)
        }
        .ok()?;

        // CAST: usize → u64, DedicatedVideoMemory is usize on Windows
        #[allow(clippy::as_conversions)]
        let total = dedicated_vram as u64;

        // Trim trailing null characters from the UTF-16 adapter description
        // BORROW: explicit from_utf16_lossy — DXGI Description is a fixed-size UTF-16 array
        let raw_name = String::from_utf16_lossy(&desc.Description);
        // BORROW: to_owned — trim returns a &str slice; we need an owned String
        let gpu_name = raw_name.trim_end_matches('\0').to_owned();

        #[cfg(feature = "memory-debug")]
        eprintln!(
            "[DXGI debug] adapter={gpu_name}, dedicated_vram={total}, \
 current_usage={}, budget={}",
            mem_info.CurrentUsage, mem_info.Budget,
        );

        return Some((
            Some(mem_info.CurrentUsage),
            Some(total),
            Some(true),
            Some(gpu_name),
        ));
    }
}
782
783// ---------------------------------------------------------------------------
784// Tests
785// ---------------------------------------------------------------------------
786
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Bytes in one megabyte as used by the `*_mb` accessors.
    const MB: u64 = 1_048_576;

    /// Build a snapshot without a GPU name.
    ///
    /// `vram` is `(used_bytes, total_bytes, per_process)`; `None` leaves
    /// every VRAM field unset.
    fn snapshot(ram_bytes: u64, vram: Option<(u64, u64, bool)>) -> MemorySnapshot {
        MemorySnapshot {
            ram_bytes,
            vram_bytes: vram.map(|(used, _, _)| used),
            vram_total_bytes: vram.map(|(_, total, _)| total),
            vram_per_process: vram.map(|(_, _, per_process)| per_process),
            gpu_name: None,
        }
    }

    #[test]
    fn snapshot_cpu_has_ram() {
        let snap = MemorySnapshot::now(&candle_core::Device::Cpu).unwrap();
        // A running process always occupies some RAM.
        assert!(snap.ram_bytes > 0, "RAM should be non-zero");
        // VRAM is never measured for the CPU device.
        assert!(snap.vram_bytes.is_none(), "CPU should have no VRAM");
        assert!(
            snap.vram_per_process.is_none(),
            "CPU should have no VRAM qualifier"
        );
    }

    #[test]
    fn report_delta_positive_for_allocation() {
        // 100 MB → 200 MB of RAM, 500 MB → 1000 MB of VRAM.
        let before = snapshot(100 * MB, Some((500 * MB, 16_384 * MB, true)));
        let after = snapshot(200 * MB, Some((1_000 * MB, 16_384 * MB, true)));
        let report = MemoryReport::new(before, after);

        let ram_delta = report.ram_delta_mb();
        assert!(
            (ram_delta - 100.0).abs() < 0.01,
            "RAM delta should be ~100 MB, got {ram_delta}"
        );

        let vram_delta = report.vram_delta_mb().unwrap();
        assert!(
            (vram_delta - 500.0).abs() < 0.01,
            "VRAM delta should be ~500 MB, got {vram_delta}"
        );
    }

    #[test]
    fn report_delta_none_when_no_vram() {
        let report = MemoryReport::new(snapshot(100, None), snapshot(200, None));
        assert!(report.vram_delta_mb().is_none());
    }

    #[test]
    fn ram_mb_conversion() {
        // Exactly one megabyte.
        let snap = snapshot(MB, None);
        assert!((snap.ram_mb() - 1.0).abs() < 0.001);
    }

    #[test]
    fn vram_qualifier_per_process() {
        let snap = snapshot(100, Some((500, 1000, true)));
        let report = MemoryReport::new(snap.clone(), snap);
        assert_eq!(report.vram_qualifier(), " [per-process]");
    }

    #[test]
    fn vram_qualifier_device_wide() {
        let snap = snapshot(100, Some((500, 1000, false)));
        let report = MemoryReport::new(snap.clone(), snap);
        assert_eq!(report.vram_qualifier(), " [device-wide]");
    }
}