// candle_mi/memory.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Process and GPU memory reporting.
4//!
5//! Provides [`MemorySnapshot`] to capture current RAM and VRAM usage,
6//! and [`MemoryReport`] to measure deltas between two snapshots.
7//!
8//! # VRAM measurement strategy
9//!
10//! VRAM is measured using a three-tier approach:
11//!
12//! 1. **Windows primary — DXGI** (per-process): Uses
13//! `IDXGIAdapter3::QueryVideoMemoryInfo` (DXGI 1.4, Windows 10+) to get
14//! true per-process GPU memory. This is the only reliable method on Windows
15//! because WDDM means the Windows kernel manages GPU memory, not the NVIDIA
16//! driver — so NVML returns `NOT_AVAILABLE` for per-process queries.
17//! 2. **Linux primary — NVML** (per-process): Dynamically loads
18//! `libnvidia-ml.so.1` via `libloading` and calls
19//! `nvmlDeviceGetComputeRunningProcesses` for true per-process GPU memory.
20//! 3. **Fallback — `nvidia-smi`** (device-wide): If both DXGI and NVML fail,
21//! spawns `nvidia-smi` as a subprocess. Reports device-wide VRAM; the delta
22//! between two snapshots is accurate on single-user machines.
23//!
24//! # Platform support
25//!
26//! | Metric | Windows | Linux |
27//! |--------|---------|-------|
28//! | RAM (RSS) | `K32GetProcessMemoryInfo` (per-process, exact) | `/proc/self/status` `VmRSS` (per-process, exact) |
29//! | VRAM (DXGI) | `IDXGIAdapter3` (per-process, exact) | N/A |
30//! | VRAM (NVML) | `NOT_AVAILABLE` under WDDM | `libnvidia-ml.so.1` (per-process, exact) |
31//! | VRAM (fallback) | `nvidia-smi` (device-wide) | `nvidia-smi` (device-wide) |
32//!
33//! # Feature gates
34//!
35//! - **`memory`**: Enables this module. Relaxes `#![forbid(unsafe_code)]` to
36//! `#![deny(unsafe_code)]` for the Windows FFI calls (`K32GetProcessMemoryInfo`,
37//! DXGI COM calls) and for NVML dynamic symbol loading via `libloading`.
38//! On Linux RAM measurement, no unsafe code is used.
39//! - **`memory-debug`** (implies `memory`): Prints raw DXGI query results and
40//! per-chunk VRAM measurements to stderr for diagnosing GPU memory issues.
41
42use crate::{MIError, Result};
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
/// A point-in-time capture of process RAM and (optionally) GPU VRAM.
///
/// Take a measurement with [`MemorySnapshot::now`]; compare two
/// measurements with [`MemoryReport::new`].
///
/// # Example
///
/// ```no_run
/// use candle_mi::MemorySnapshot;
///
/// let before = MemorySnapshot::now(&candle_core::Device::Cpu)?;
/// // ... load a model ...
/// let after = MemorySnapshot::now(&candle_core::Device::Cpu)?;
/// let report = candle_mi::MemoryReport::new(before, after);
/// println!("RAM delta: {:+.1} MB", report.ram_delta_mb());
/// # Ok::<(), candle_mi::MIError>(())
/// ```
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
    /// Resident set size of this process (working set on Windows), in bytes.
    pub ram_bytes: u64,
    /// GPU memory in use, in bytes. Per-process when obtained via DXGI or
    /// NVML; device-wide when obtained via the `nvidia-smi` fallback.
    /// `None` when no GPU is present or the measurement failed.
    pub vram_bytes: Option<u64>,
    /// Total memory of the active GPU, in bytes.
    /// `None` when no GPU is present or the measurement failed.
    pub vram_total_bytes: Option<u64>,
    /// `Some(true)` when `vram_bytes` is per-process, `Some(false)` when it
    /// is device-wide, `None` when no VRAM data is available.
    pub vram_per_process: Option<bool>,
    /// Adapter name (e.g., `NVIDIA GeForce RTX 5060 Ti`); `None` when not
    /// reported (non-DXGI path or no GPU).
    pub gpu_name: Option<String>,
}
84
85/// Memory delta between two snapshots.
86///
87/// Computed from a `before` and `after` [`MemorySnapshot`].
88/// Positive deltas mean memory increased; negative means freed.
89#[derive(Debug, Clone)]
90pub struct MemoryReport {
91 /// Snapshot taken before the operation.
92 pub before: MemorySnapshot,
93 /// Snapshot taken after the operation.
94 pub after: MemorySnapshot,
95}
96
97impl MemorySnapshot {
98 /// Capture current memory state.
99 ///
100 /// RAM is always measured (per-process RSS). VRAM is measured only if
101 /// `device` is CUDA — first via DXGI (Windows, per-process), then NVML
102 /// (Linux, per-process), falling back to `nvidia-smi` (device-wide).
103 ///
104 /// # Errors
105 ///
106 /// Returns [`MIError::Memory`] if the RAM query fails (platform API error).
107 /// VRAM measurement failures are non-fatal — `vram_bytes` is set to `None`.
108 pub fn now(device: &candle_core::Device) -> Result<Self> {
109 let ram_bytes = process_rss()?;
110 let (vram_bytes, vram_total_bytes, per_process, gpu_name) = if device.is_cuda() {
111 gpu_memory_used()
112 } else {
113 (None, None, None, None)
114 };
115 Ok(Self {
116 ram_bytes,
117 vram_bytes,
118 vram_total_bytes,
119 vram_per_process: per_process,
120 gpu_name,
121 })
122 }
123
124 /// Format RAM usage as megabytes.
125 #[must_use]
126 pub fn ram_mb(&self) -> f64 {
127 // CAST: u64 → f64, value is memory in bytes — fits in f64 mantissa
128 // for any realistic process size (< 2^53 bytes = 8 PB)
129 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
130 let mb = self.ram_bytes as f64 / 1_048_576.0;
131 mb
132 }
133
134 /// Format VRAM usage as megabytes, if available.
135 #[must_use]
136 pub fn vram_mb(&self) -> Option<f64> {
137 // CAST: u64 → f64, same justification as ram_mb
138 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
139 self.vram_bytes.map(|b| b as f64 / 1_048_576.0)
140 }
141}
142
/// Synchronize the CUDA device and trim its memory pool.
///
/// On a CUDA device this:
/// 1. Calls `cuCtxSynchronize` so all pending async frees complete.
/// 2. Calls `cuMemPoolTrimTo(pool, 0)` to release all unused reserved
///    VRAM back to the device.
///
/// cudarc's stream-ordered allocator (`malloc_async` / `free_async`)
/// keeps freed blocks in a pool for reuse. Over many forward passes
/// with varying tensor sizes the pool grows monotonically — DXGI and
/// `nvidia-smi` report this reserved memory as "in use", eventually
/// causing OOM even though no live tensors need it.
///
/// This function is a no-op on CPU and on Metal.
///
/// All failures are silently swallowed: trimming is strictly best-effort
/// and can only affect reported/reserved VRAM, never live tensors.
///
/// # Example
///
/// ```no_run
/// # use candle_mi::sync_and_trim_gpu;
/// # let device = candle_core::Device::Cpu;
/// // After dropping all GPU tensors from a forward pass:
/// sync_and_trim_gpu(&device);
/// ```
pub fn sync_and_trim_gpu(device: &candle_core::Device) {
    #[cfg(feature = "cuda")]
    if let candle_core::Device::Cuda(cuda_dev) = device {
        use candle_core::backend::BackendDevice;
        // Synchronize so all pending async frees complete.
        // The error is deliberately ignored — trimming is best-effort.
        let _ = cuda_dev.synchronize();

        // Trim the default memory pool to release all unused reserved VRAM.
        // SAFETY: cuDeviceGetDefaultMemPool and cuMemPoolTrimTo are
        // documented CUDA driver APIs for pool management. The CUdevice
        // handle comes from candle's CudaContext (valid after synchronize).
        // cuMemPoolTrimTo(pool, 0) releases all unused memory — it cannot
        // free memory that is still in use by live tensors.
        #[allow(unsafe_code)]
        {
            use candle_core::cuda_backend::cudarc::driver::sys;

            let stream = cuda_dev.cuda_stream();
            // Allocate a zero-length slice just to access the CudaContext
            // (CudaStream.ctx is pub(crate), but CudaSlice.context() is pub).
            // If even this probe allocation fails, skip trimming entirely.
            if let Ok(probe) = stream.null::<u8>() {
                let ctx = probe.context();
                let cu_device = ctx.cu_device();
                unsafe {
                    // NOTE(review): assumes the pool handle type is zeroable
                    // (an opaque handle overwritten on success) — confirm
                    // against cudarc's sys definition.
                    let mut pool = std::mem::zeroed();
                    let rc = sys::cuDeviceGetDefaultMemPool(&raw mut pool, cu_device);
                    if rc == sys::CUresult::CUDA_SUCCESS {
                        // Trim target 0 = release every unused block.
                        let _ = sys::cuMemPoolTrimTo(pool, 0);
                    }
                }
            }
        }
    }

    // Suppress unused-variable warning on non-CUDA builds.
    #[cfg(not(feature = "cuda"))]
    let _ = device;
}
204
205impl MemoryReport {
206 /// Create a report from two snapshots.
207 #[must_use]
208 pub const fn new(before: MemorySnapshot, after: MemorySnapshot) -> Self {
209 Self { before, after }
210 }
211
212 /// RAM delta in megabytes (positive = increased).
213 #[must_use]
214 pub fn ram_delta_mb(&self) -> f64 {
215 self.after.ram_mb() - self.before.ram_mb()
216 }
217
218 /// VRAM delta in megabytes (positive = increased).
219 /// Returns `None` if either snapshot lacks VRAM data.
220 #[must_use]
221 pub fn vram_delta_mb(&self) -> Option<f64> {
222 match (self.after.vram_mb(), self.before.vram_mb()) {
223 (Some(after), Some(before)) => Some(after - before),
224 (Some(_) | None, None) | (None, Some(_)) => None,
225 }
226 }
227
228 /// Print a one-line summary of the delta.
229 pub fn print_delta(&self, label: &str) {
230 let ram = self.ram_delta_mb();
231 print!(" {label}: RAM {ram:+.0} MB");
232 if let Some(vram) = self.vram_delta_mb() {
233 let qualifier = self.vram_qualifier();
234 print!(" | VRAM {vram:+.0} MB{qualifier}");
235 }
236 println!();
237 }
238
239 /// Print a two-line summary showing before → after for both RAM and VRAM.
240 pub fn print_before_after(&self, label: &str) {
241 println!(
242 " {label}: RAM {:.0} MB → {:.0} MB ({:+.0} MB)",
243 self.before.ram_mb(),
244 self.after.ram_mb(),
245 self.ram_delta_mb(),
246 );
247 if let (Some(before), Some(after)) = (self.before.vram_mb(), self.after.vram_mb()) {
248 // CAST: u64 → f64, same justification as ram_mb
249 #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
250 let total = self.after.vram_total_bytes.map_or(String::new(), |t| {
251 format!(" / {:.0} MB", t as f64 / 1_048_576.0)
252 });
253 let qualifier = self.vram_qualifier();
254 let gpu = self
255 .after
256 .gpu_name
257 .as_deref()
258 .map_or(String::new(), |name| format!(" [{name}]"));
259 println!(
260 " {label}: VRAM {before:.0} MB → {after:.0} MB ({:+.0} MB{total}){qualifier}{gpu}",
261 after - before,
262 );
263 }
264 }
265
266 /// Return a short qualifier string indicating VRAM measurement quality.
267 #[must_use]
268 const fn vram_qualifier(&self) -> &'static str {
269 match self.after.vram_per_process {
270 Some(true) => " [per-process]",
271 Some(false) => " [device-wide]",
272 None => "",
273 }
274 }
275}
276
277// ---------------------------------------------------------------------------
278// RAM measurement — per-process RSS
279// ---------------------------------------------------------------------------
280
/// Query the current process's resident set size (RSS) in bytes.
///
/// # Platform
///
/// - **Windows**: `K32GetProcessMemoryInfo` → `WorkingSetSize` (exact, per-process).
/// - **Linux**: `/proc/self/status` → `VmRSS` (exact, per-process, no unsafe).
///
/// # Errors
///
/// Returns [`MIError::Memory`] if the platform API call fails.
fn process_rss() -> Result<u64> {
    // Exactly one of the three cfg blocks below survives compilation, so
    // each arm is a plain tail expression for the function.
    #[cfg(target_os = "windows")]
    {
        windows_rss()
    }
    #[cfg(target_os = "linux")]
    {
        linux_rss()
    }
    // Any other OS (e.g. macOS): fail explicitly rather than report 0 —
    // MemorySnapshot::now treats a RAM measurement failure as fatal.
    #[cfg(not(any(target_os = "windows", target_os = "linux")))]
    {
        Err(MIError::Memory(
            "RAM measurement not supported on this platform".into(),
        ))
    }
}
307
308// -- Windows ----------------------------------------------------------------
309
/// Windows FFI types and functions for `K32GetProcessMemoryInfo`.
#[cfg(target_os = "windows")]
mod win_ffi {
    /// `PROCESS_MEMORY_COUNTERS` structure from the Windows API.
    ///
    /// `#[repr(C)]` and the exact field order are load-bearing: the kernel
    /// writes this struct in place during the FFI call.
    ///
    /// See: <https://learn.microsoft.com/en-us/windows/win32/api/psapi/ns-psapi-process_memory_counters>
    #[repr(C)]
    pub(super) struct ProcessMemoryCounters {
        /// Size of this structure in bytes. The caller must set this to
        /// `size_of::<ProcessMemoryCounters>()` before the FFI call.
        pub cb: u32,
        /// Number of page faults.
        pub page_fault_count: u32,
        /// Peak working set size in bytes.
        pub peak_working_set_size: usize,
        /// Current working set size in bytes (= RSS).
        pub working_set_size: usize,
        /// Peak paged pool usage in bytes.
        pub quota_peak_paged_pool_usage: usize,
        /// Current paged pool usage in bytes.
        pub quota_paged_pool_usage: usize,
        /// Peak non-paged pool usage in bytes.
        pub quota_peak_non_paged_pool_usage: usize,
        /// Current non-paged pool usage in bytes.
        pub quota_non_paged_pool_usage: usize,
        /// Current pagefile usage in bytes.
        pub pagefile_usage: usize,
        /// Peak pagefile usage in bytes.
        pub peak_pagefile_usage: usize,
    }

    // SAFETY: These are stable Windows API functions with well-defined ABI.
    // GetCurrentProcess always returns a valid pseudo-handle.
    // K32GetProcessMemoryInfo writes to caller-provided memory of known size.
    #[allow(unsafe_code)]
    unsafe extern "system" {
        /// Returns a pseudo-handle to the current process (always valid, never null).
        pub(super) safe fn GetCurrentProcess() -> isize;

        /// Retrieves memory usage information for the specified process.
        /// Returns nonzero on success, zero on failure (checked at the call site).
        pub(super) unsafe fn K32GetProcessMemoryInfo(
            process: isize,
            ppsmem_counters: *mut ProcessMemoryCounters,
            cb: u32,
        ) -> i32;
    }
}
356
/// Query RSS on Windows via `K32GetProcessMemoryInfo`.
///
/// Returns the current working set size in bytes, or [`MIError::Memory`]
/// if the Win32 call reports failure.
#[cfg(target_os = "windows")]
#[allow(unsafe_code)]
fn windows_rss() -> Result<u64> {
    // Zero-initialize every field; the kernel overwrites the whole struct.
    let mut counters = win_ffi::ProcessMemoryCounters {
        cb: 0,
        page_fault_count: 0,
        peak_working_set_size: 0,
        working_set_size: 0,
        quota_peak_paged_pool_usage: 0,
        quota_paged_pool_usage: 0,
        quota_peak_non_paged_pool_usage: 0,
        quota_non_paged_pool_usage: 0,
        pagefile_usage: 0,
        peak_pagefile_usage: 0,
    };
    // CAST: usize → u32, struct size is 80 bytes on x64 — fits in u32
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let cb = std::mem::size_of::<win_ffi::ProcessMemoryCounters>() as u32;
    counters.cb = cb;

    // NOTE(review): pseudo-handle — per Win32 docs it needs no CloseHandle.
    let handle = win_ffi::GetCurrentProcess();

    // SAFETY: K32GetProcessMemoryInfo writes into the stack-allocated
    // `counters` struct, which is correctly sized (cb field set to struct
    // size). The process handle from GetCurrentProcess is a pseudo-handle
    // that is always valid for the lifetime of the process.
    let ok = unsafe { win_ffi::K32GetProcessMemoryInfo(handle, &raw mut counters, cb) };

    if ok != 0 {
        // CAST: usize → u64, working set size in bytes — always fits
        #[allow(clippy::as_conversions)]
        let rss = counters.working_set_size as u64;
        Ok(rss)
    } else {
        Err(MIError::Memory("K32GetProcessMemoryInfo failed".into()))
    }
}
395
396// -- Linux ------------------------------------------------------------------
397
398/// Query RSS on Linux via `/proc/self/status`.
399#[cfg(target_os = "linux")]
400fn linux_rss() -> Result<u64> {
401 let status = std::fs::read_to_string("/proc/self/status")
402 .map_err(|e| MIError::Memory(format!("failed to read /proc/self/status: {e}")))?;
403
404 for line in status.lines() {
405 if let Some(rest) = line.strip_prefix("VmRSS:") {
406 let kb_str = rest.trim().trim_end_matches(" kB").trim();
407 let kb: u64 = kb_str.parse().map_err(|e| {
408 MIError::Memory(format!("failed to parse VmRSS value '{kb_str}': {e}"))
409 })?;
410 return Ok(kb * 1024);
411 }
412 }
413
414 Err(MIError::Memory(
415 "VmRSS not found in /proc/self/status".into(),
416 ))
417}
418
419// ---------------------------------------------------------------------------
420// VRAM measurement — DXGI (Windows), NVML, nvidia-smi fallback
421// ---------------------------------------------------------------------------
422
/// Result of a GPU memory query: `(used_bytes, total_bytes, per_process, gpu_name)`.
type GpuMemoryResult = (Option<u64>, Option<u64>, Option<bool>, Option<String>);

/// NVML shared library path (stable across driver versions).
#[cfg(target_os = "linux")]
const NVML_LIB_PATH: &str = "libnvidia-ml.so.1";

/// NVML shared library path (stable across driver versions).
#[cfg(target_os = "windows")]
const NVML_LIB_PATH: &str = "nvml.dll";

/// NVML return code: success.
const NVML_SUCCESS: u32 = 0;

/// NVML return code: buffer too small (need to retry with larger buffer).
const NVML_ERROR_INSUFFICIENT_SIZE: u32 = 7;

/// Maximum number of processes to query from NVML in a single call.
/// 64 is generous — most machines have fewer than 10 GPU processes.
/// If more than 64 processes are running, the process query reports
/// `NVML_ERROR_INSUFFICIENT_SIZE` and the caller falls back to `nvidia-smi`.
const NVML_MAX_PROCESSES: usize = 64;

/// Per-process GPU memory info returned by NVML.
///
/// Matches the C struct `nvmlProcessInfo_v2_t` (24 bytes) used by both
/// `nvmlDeviceGetComputeRunningProcesses_v2` and `_v3` (the `_v3` suffix
/// is a function version, not a struct version).
/// `#[repr(C)]` and the field order are load-bearing — NVML fills an array
/// of these in place.
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlProcessInfo__v2__t.html>
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlProcessInfo {
    /// Process ID.
    pid: u32,
    /// GPU memory used by this process in bytes.
    /// `u64::MAX` (`0xFFFF_FFFF_FFFF_FFFF`) means "not available".
    used_gpu_memory: u64,
    /// GPU instance ID (MIG). Unused in non-MIG mode.
    gpu_instance_id: u32,
    /// Compute instance ID (MIG). Unused in non-MIG mode.
    compute_instance_id: u32,
}

/// NVML memory info for a device.
///
/// Matches the C struct `nvmlMemory_t` from the NVML API.
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlMemory__t.html>
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlMemoryInfo {
    /// Total GPU memory in bytes.
    total: u64,
    /// Free GPU memory in bytes.
    free: u64,
    /// Used GPU memory in bytes.
    used: u64,
}

/// Opaque NVML device handle. Never dereferenced on the Rust side — only
/// passed back into NVML calls.
type NvmlDevice = *mut std::ffi::c_void;

/// Function signature: `nvmlInit_v2() -> nvmlReturn_t`.
type NvmlInitFn = unsafe extern "C" fn() -> u32;

/// Function signature: `nvmlShutdown() -> nvmlReturn_t`.
type NvmlShutdownFn = unsafe extern "C" fn() -> u32;

/// Function signature: `nvmlDeviceGetHandleByIndex_v2(index, *mut device) -> nvmlReturn_t`.
type NvmlDeviceGetHandleByIndexFn = unsafe extern "C" fn(u32, *mut NvmlDevice) -> u32;

/// Function signature: `nvmlDeviceGetMemoryInfo(device, *mut memory) -> nvmlReturn_t`.
type NvmlDeviceGetMemoryInfoFn = unsafe extern "C" fn(NvmlDevice, *mut NvmlMemoryInfo) -> u32;

/// Function signature:
/// `nvmlDeviceGetComputeRunningProcesses_v3(device, *mut count, *mut infos) -> nvmlReturn_t`.
type NvmlDeviceGetComputeRunningProcessesFn =
    unsafe extern "C" fn(NvmlDevice, *mut u32, *mut NvmlProcessInfo) -> u32;
498
499/// Query GPU memory — DXGI (Windows), NVML, or `nvidia-smi` fallback.
500///
501/// Returns `(used_bytes, total_bytes, per_process, gpu_name)`:
502/// - `per_process = Some(true)` when per-process query succeeded (DXGI or NVML).
503/// - `per_process = Some(false)` when falling back to `nvidia-smi` (device-wide).
504/// - `gpu_name` is set when DXGI provides the adapter description.
505/// - All `None` if all methods fail.
506fn gpu_memory_used() -> GpuMemoryResult {
507 // Windows: try DXGI first (per-process, works under WDDM)
508 #[cfg(windows)]
509 if let Some(result) = dxgi_query_process_vram() {
510 return result;
511 }
512
513 // Try NVML (per-process on Linux, NOT_AVAILABLE on Windows WDDM)
514 if let Some(result) = nvml_query_process_vram() {
515 let (used, total, per_process) = result;
516 return (used, total, per_process, None);
517 }
518
519 // Fallback to nvidia-smi (device-wide)
520 let (used, total) = nvidia_smi_query();
521 if used.is_some() {
522 (used, total, Some(false), None)
523 } else {
524 (None, None, None, None)
525 }
526}
527
/// Attempt to query per-process VRAM via NVML.
///
/// On success returns `Some((used_bytes, total_bytes, Some(true)))` for the
/// current process. Returns `None` if NVML cannot be loaded or any API call
/// fails, signaling the caller to try the fallback path.
///
/// NVML is initialized and shut down on every call; `nvmlShutdown` is
/// invoked on every code path that runs after a successful `nvmlInit_v2`.
#[allow(unsafe_code)]
fn nvml_query_process_vram() -> Option<(Option<u64>, Option<u64>, Option<bool>)> {
    // SAFETY: libloading::Library::new dynamically loads a shared library.
    // The NVML library is a stable NVIDIA driver component with a well-defined
    // C ABI. We load it, call functions, and unload it within this scope.
    let lib = unsafe { libloading::Library::new(NVML_LIB_PATH) }.ok()?;

    // SAFETY: Loading function symbols from the NVML library. Each symbol
    // name matches the documented NVML C API exactly. The function signatures
    // (type aliases above) match the NVML header definitions.
    // All symbols are resolved up front, before nvmlInit, so a missing
    // symbol can never leave NVML initialized.
    let init: libloading::Symbol<'_, NvmlInitFn> = unsafe { lib.get(b"nvmlInit_v2\0") }.ok()?;
    let shutdown: libloading::Symbol<'_, NvmlShutdownFn> =
        unsafe { lib.get(b"nvmlShutdown\0") }.ok()?;
    let get_handle: libloading::Symbol<'_, NvmlDeviceGetHandleByIndexFn> =
        unsafe { lib.get(b"nvmlDeviceGetHandleByIndex_v2\0") }.ok()?;
    let get_memory: libloading::Symbol<'_, NvmlDeviceGetMemoryInfoFn> =
        unsafe { lib.get(b"nvmlDeviceGetMemoryInfo\0") }.ok()?;
    let get_processes: libloading::Symbol<'_, NvmlDeviceGetComputeRunningProcessesFn> =
        unsafe { lib.get(b"nvmlDeviceGetComputeRunningProcesses_v3\0") }.ok()?;

    // Initialize NVML
    // SAFETY: nvmlInit_v2 is safe to call from any thread; it initializes
    // internal NVML state. Returns NVML_SUCCESS (0) on success.
    let ret = unsafe { init() };
    if ret != NVML_SUCCESS {
        return None;
    }

    // Get device handle for GPU 0 (primary GPU)
    let mut device: NvmlDevice = std::ptr::null_mut();
    // SAFETY: nvmlDeviceGetHandleByIndex_v2 writes a valid opaque handle
    // into `device` when it returns NVML_SUCCESS. Index 0 = primary GPU.
    let ret = unsafe { get_handle(0, &raw mut device) };
    if ret != NVML_SUCCESS {
        // SAFETY: nvmlShutdown is always safe after a successful nvmlInit.
        unsafe { shutdown() };
        return None;
    }

    // Get total memory for the device; failure here is non-fatal — we just
    // lose the total, not the per-process figure.
    let mut mem_info = NvmlMemoryInfo {
        total: 0,
        free: 0,
        used: 0,
    };
    // SAFETY: nvmlDeviceGetMemoryInfo writes into the caller-provided
    // NvmlMemoryInfo struct. The device handle is valid (obtained above).
    let ret = unsafe { get_memory(device, &raw mut mem_info) };
    let total_bytes = if ret == NVML_SUCCESS {
        Some(mem_info.total)
    } else {
        None
    };

    // Get per-process memory
    // CAST: usize → u32, NVML_MAX_PROCESSES is 64 — fits in u32
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let mut count = NVML_MAX_PROCESSES as u32;
    let mut infos = [NvmlProcessInfo {
        pid: 0,
        used_gpu_memory: 0,
        gpu_instance_id: 0,
        compute_instance_id: 0,
    }; NVML_MAX_PROCESSES];

    // SAFETY: nvmlDeviceGetComputeRunningProcesses_v3 fills `infos` with
    // up to `count` entries and updates `count` to the actual number written.
    // The buffer is stack-allocated with NVML_MAX_PROCESSES slots, which is
    // sufficient for typical workloads.
    let ret = unsafe { get_processes(device, &raw mut count, infos.as_mut_ptr()) };

    // SAFETY: nvmlShutdown pairs with nvmlInit; always called before return.
    unsafe { shutdown() };

    if ret != NVML_SUCCESS && ret != NVML_ERROR_INSUFFICIENT_SIZE {
        return None;
    }

    // Find our process in the list
    let my_pid = std::process::id();
    // CAST: u32 → usize, count is a small process count — always fits
    #[allow(clippy::as_conversions)]
    let actual_count = count as usize;
    // If NVML reported INSUFFICIENT_SIZE, `count` now exceeds the buffer
    // length, so `get(..actual_count)` yields None and `?` triggers the
    // nvidia-smi fallback rather than reading an incompletely filled buffer.
    let my_vram = infos
        .get(..actual_count)?
        .iter()
        .find(|info| info.pid == my_pid)
        .map(|info| info.used_gpu_memory);

    // NVML uses u64::MAX as "not available" sentinel — some drivers (e.g., R570
    // on RTX 5060 Ti) return this for all processes. Fall back to nvidia-smi.
    if my_vram == Some(u64::MAX) {
        return None;
    }

    // Sanity check: if per-process VRAM exceeds total device memory, the value
    // is likely garbage. Fall back to nvidia-smi.
    if let (Some(used), Some(total)) = (my_vram, total_bytes) {
        if used > total {
            return None;
        }
    }

    // Our PID might not be in the list (no active CUDA context yet?) — return None to trigger fallback
    my_vram.map(|used| (Some(used), total_bytes, Some(true)))
}
638
/// Query GPU memory via `nvidia-smi` (device-wide fallback).
///
/// Spawns `nvidia-smi` as a subprocess and parses its CSV output.
///
/// Returns `(Some(used_bytes), Some(total_bytes))` on success,
/// or `(None, None)` if `nvidia-smi` is not available, exits with a
/// failure status, or its output cannot be parsed.
fn nvidia_smi_query() -> (Option<u64>, Option<u64>) {
    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output();

    // Spawn failure (binary missing) and non-zero exit both mean "no data".
    let stdout = match output {
        Ok(o) if o.status.success() => o.stdout,
        _ => return (None, None),
    };

    // BORROW: explicit String::from_utf8_lossy — nvidia-smi output is ASCII
    match parse_nvidia_smi_line(&String::from_utf8_lossy(&stdout)) {
        Some((used, total)) => (Some(used), Some(total)),
        None => (None, None),
    }
}

/// Parse the first line of `nvidia-smi --query-gpu=memory.used,memory.total
/// --format=csv,noheader,nounits` output into `(used_bytes, total_bytes)`.
///
/// Expected format: `"1234, 16384"` (used MiB, total MiB). Extra trailing
/// fields are ignored, matching the original lenient parser.
///
/// Returns `None` on empty input, missing fields, or non-numeric values.
fn parse_nvidia_smi_line(stdout: &str) -> Option<(u64, u64)> {
    let line = stdout.lines().next()?.trim();
    let mut fields = line.split(',');
    let used_mb: u64 = fields.next()?.trim().parse().ok()?;
    let total_mb: u64 = fields.next()?.trim().parse().ok()?;
    // nvidia-smi reports in MiB — convert to bytes
    Some((used_mb * 1_048_576, total_mb * 1_048_576))
}
686
687// ---------------------------------------------------------------------------
688// DXGI per-process VRAM (Windows only)
689// ---------------------------------------------------------------------------
690
/// Query per-process GPU VRAM via DXGI (`IDXGIAdapter3::QueryVideoMemoryInfo`).
///
/// This is the only reliable way to get per-process GPU memory on Windows
/// (WDDM). NVML returns `NVML_VALUE_NOT_AVAILABLE` under WDDM because the
/// Windows kernel memory manager owns GPU memory, not the NVIDIA driver.
///
/// DXGI 1.4 (Windows 10+) provides `QueryVideoMemoryInfo` which returns:
/// - `CurrentUsage`: per-process GPU memory in bytes (exactly what we want).
/// - `Budget`: OS-assigned memory budget for this process.
///
/// We query `DXGI_MEMORY_SEGMENT_GROUP_LOCAL` (dedicated VRAM on discrete GPUs).
/// Total VRAM comes from `IDXGIAdapter::GetDesc` → `DedicatedVideoMemory`.
///
/// Returns `None` if DXGI is not available or the query fails,
/// signaling the caller to try NVML or nvidia-smi fallback.
/// Note that any COM failure after an adapter is selected aborts the whole
/// DXGI path via `?` — there is no retry with the next adapter.
#[cfg(windows)]
#[allow(unsafe_code)]
fn dxgi_query_process_vram() -> Option<GpuMemoryResult> {
    use windows::Win32::Graphics::Dxgi::{
        CreateDXGIFactory1, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, DXGI_QUERY_VIDEO_MEMORY_INFO,
        IDXGIAdapter, IDXGIAdapter3, IDXGIFactory1,
    };
    use windows::core::Interface;

    // SAFETY: CreateDXGIFactory1 is a well-documented COM factory function.
    // It initializes COM internally if needed. The returned IDXGIFactory1
    // is reference-counted and released automatically when dropped.
    let factory: IDXGIFactory1 = unsafe { CreateDXGIFactory1() }.ok()?;

    // Enumerate adapters — find the first one with dedicated VRAM > 0
    // (skip software/render-only adapters like Microsoft Basic Render Driver).
    // The loop terminates because EnumAdapters1 eventually errors
    // (DXGI_ERROR_NOT_FOUND) and `.ok()?` returns None.
    let mut adapter_idx = 0u32;
    loop {
        // SAFETY: EnumAdapters1 returns S_OK with a valid adapter, or
        // DXGI_ERROR_NOT_FOUND when idx is out of range.
        // The extra `.cast()` converts the enumerated interface to the
        // IDXGIAdapter base interface used below.
        let adapter: IDXGIAdapter = unsafe { factory.EnumAdapters1(adapter_idx) }
            .ok()?
            .cast()
            .ok()?;

        // SAFETY: GetDesc writes a valid DXGI_ADAPTER_DESC into the
        // caller-provided struct. The adapter handle is valid (obtained above).
        let desc = unsafe { adapter.GetDesc() }.ok()?;
        let dedicated_vram = desc.DedicatedVideoMemory;

        // Software adapters report zero dedicated VRAM — try the next one.
        if dedicated_vram == 0 {
            adapter_idx += 1;
            continue;
        }

        // Cast to IDXGIAdapter3 for QueryVideoMemoryInfo (DXGI 1.4)
        let adapter3: IDXGIAdapter3 = adapter.cast().ok()?;

        // SAFETY: QueryVideoMemoryInfo fills a DXGI_QUERY_VIDEO_MEMORY_INFO
        // struct with per-process memory stats. Node 0 = primary GPU node.
        // DXGI_MEMORY_SEGMENT_GROUP_LOCAL = dedicated VRAM on discrete GPUs.
        let mut mem_info = DXGI_QUERY_VIDEO_MEMORY_INFO::default();
        unsafe {
            adapter3.QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &raw mut mem_info)
        }
        .ok()?;

        // CAST: usize → u64, DedicatedVideoMemory is usize on Windows
        #[allow(clippy::as_conversions)]
        let total = dedicated_vram as u64;

        // Trim trailing null characters from the UTF-16 adapter description
        // BORROW: explicit from_utf16_lossy — DXGI Description is a fixed-size UTF-16 array
        let raw_name = String::from_utf16_lossy(&desc.Description);
        // BORROW: to_owned — trim returns a &str slice; we need an owned String
        let gpu_name = raw_name.trim_end_matches('\0').to_owned();

        #[cfg(feature = "memory-debug")]
        eprintln!(
            "[DXGI debug] adapter={gpu_name}, dedicated_vram={total}, \
 current_usage={}, budget={}",
            mem_info.CurrentUsage, mem_info.Budget,
        );

        return Some((
            Some(mem_info.CurrentUsage),
            Some(total),
            Some(true),
            Some(gpu_name),
        ));
    }
}
778
779// ---------------------------------------------------------------------------
780// Tests
781// ---------------------------------------------------------------------------
782
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Build a snapshot with the given byte counts and no GPU name.
    fn snapshot(
        ram_bytes: u64,
        vram_bytes: Option<u64>,
        vram_total_bytes: Option<u64>,
        vram_per_process: Option<bool>,
    ) -> MemorySnapshot {
        MemorySnapshot {
            ram_bytes,
            vram_bytes,
            vram_total_bytes,
            vram_per_process,
            gpu_name: None,
        }
    }

    #[test]
    fn snapshot_cpu_has_ram() {
        let snap = MemorySnapshot::now(&candle_core::Device::Cpu).unwrap();
        // A live process always occupies some RAM.
        assert!(snap.ram_bytes > 0, "RAM should be non-zero");
        // The CPU device must never report VRAM data.
        assert!(snap.vram_bytes.is_none(), "CPU should have no VRAM");
        assert!(
            snap.vram_per_process.is_none(),
            "CPU should have no VRAM qualifier"
        );
    }

    #[test]
    fn report_delta_positive_for_allocation() {
        const MB: u64 = 1_048_576;
        let report = MemoryReport::new(
            snapshot(100 * MB, Some(500 * MB), Some(16_384 * MB), Some(true)),
            snapshot(200 * MB, Some(1_000 * MB), Some(16_384 * MB), Some(true)),
        );

        let ram_delta = report.ram_delta_mb();
        assert!(
            (ram_delta - 100.0).abs() < 0.01,
            "RAM delta should be ~100 MB, got {ram_delta}"
        );

        let vram_delta = report.vram_delta_mb().unwrap();
        assert!(
            (vram_delta - 500.0).abs() < 0.01,
            "VRAM delta should be ~500 MB, got {vram_delta}"
        );
    }

    #[test]
    fn report_delta_none_when_no_vram() {
        let report = MemoryReport::new(
            snapshot(100, None, None, None),
            snapshot(200, None, None, None),
        );
        assert!(report.vram_delta_mb().is_none());
    }

    #[test]
    fn ram_mb_conversion() {
        // Exactly 1 MiB of RAM must convert to 1.0 MB.
        let one_mb = snapshot(1_048_576, None, None, None);
        assert!((one_mb.ram_mb() - 1.0).abs() < 0.001);
    }

    #[test]
    fn vram_qualifier_per_process() {
        let snap = snapshot(100, Some(500), Some(1000), Some(true));
        let report = MemoryReport::new(snap.clone(), snap);
        assert_eq!(report.vram_qualifier(), " [per-process]");
    }

    #[test]
    fn vram_qualifier_device_wide() {
        let snap = snapshot(100, Some(500), Some(1000), Some(false));
        let report = MemoryReport::new(snap.clone(), snap);
        assert_eq!(report.vram_qualifier(), " [device-wide]");
    }
}
889}