Skip to main content

candle_mi/
memory.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! Process and GPU memory reporting.
//!
//! Provides [`MemorySnapshot`] to capture current RAM and VRAM usage,
//! and [`MemoryReport`] to measure deltas between two snapshots.
//!
//! # VRAM measurement strategy
//!
//! VRAM is measured using a three-tier approach:
//!
//! 1. **Windows primary — DXGI** (per-process): Uses
//!    `IDXGIAdapter3::QueryVideoMemoryInfo` (DXGI 1.4, Windows 10+) to get
//!    true per-process GPU memory. This is the only reliable method on Windows
//!    because WDDM means the Windows kernel manages GPU memory, not the NVIDIA
//!    driver — so NVML returns `NOT_AVAILABLE` for per-process queries.
//! 2. **Linux primary — NVML** (per-process): Dynamically loads
//!    `libnvidia-ml.so.1` via `libloading` and calls
//!    `nvmlDeviceGetComputeRunningProcesses` for true per-process GPU memory.
//! 3. **Fallback — `nvidia-smi`** (device-wide): If both DXGI and NVML fail,
//!    spawns `nvidia-smi` as a subprocess. Reports device-wide VRAM; the delta
//!    between two snapshots is accurate on single-user machines.
//!
//! # Platform support
//!
//! | Metric | Windows | Linux |
//! |--------|---------|-------|
//! | RAM (RSS) | `K32GetProcessMemoryInfo` (per-process, exact) | `/proc/self/status` `VmRSS` (per-process, exact) |
//! | VRAM (DXGI) | `IDXGIAdapter3` (per-process, exact) | N/A |
//! | VRAM (NVML) | `NOT_AVAILABLE` under WDDM | `libnvidia-ml.so.1` (per-process, exact) |
//! | VRAM (fallback) | `nvidia-smi` (device-wide) | `nvidia-smi` (device-wide) |
//!
//! # Feature gates
//!
//! - **`memory`**: Enables this module. Relaxes `#![forbid(unsafe_code)]` to
//!   `#![deny(unsafe_code)]` for the Windows FFI calls (`K32GetProcessMemoryInfo`,
//!   DXGI COM calls) and for NVML dynamic symbol loading via `libloading`.
//!   RAM measurement on Linux uses no unsafe code.
//! - **`memory-debug`** (implies `memory`): Prints raw DXGI query results and
//!   per-chunk VRAM measurements to stderr for diagnosing GPU memory issues.

42use crate::{MIError, Result};
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
/// Point-in-time record of the memory used by this process.
///
/// Holds the process RAM figure (resident set size) and, when a GPU query
/// succeeded, the matching VRAM figures. Take one with
/// [`MemorySnapshot::now`]; hand two of them to [`MemoryReport::new`] to
/// obtain deltas.
///
/// # Example
///
/// ```no_run
/// use candle_mi::MemorySnapshot;
///
/// let before = MemorySnapshot::now(&candle_core::Device::Cpu)?;
/// // ... load a model ...
/// let after = MemorySnapshot::now(&candle_core::Device::Cpu)?;
/// let report = candle_mi::MemoryReport::new(before, after);
/// println!("RAM delta: {:+.1} MB", report.ram_delta_mb());
/// # Ok::<(), candle_mi::MIError>(())
/// ```
#[derive(Clone, Debug)]
pub struct MemorySnapshot {
    /// Resident set size of the process in bytes (the working set on Windows).
    pub ram_bytes: u64,
    /// Bytes of GPU memory in use.
    /// The figure is per-process when it came from DXGI/NVML and device-wide
    /// when it came from the `nvidia-smi` fallback.
    /// `None` when there is no GPU or the measurement failed.
    pub vram_bytes: Option<u64>,
    /// Total bytes of GPU memory on the active device.
    /// `None` when there is no GPU or the measurement failed.
    pub vram_total_bytes: Option<u64>,
    /// `Some(true)` for a per-process VRAM figure, `Some(false)` for a
    /// device-wide one, `None` when no VRAM data is available.
    pub vram_per_process: Option<bool>,
    /// Name of the GPU adapter (e.g., `NVIDIA GeForce RTX 5060 Ti`).
    /// `None` when unavailable (non-DXGI path or no GPU).
    pub gpu_name: Option<String>,
}
84
85/// Memory delta between two snapshots.
86///
87/// Computed from a `before` and `after` [`MemorySnapshot`].
88/// Positive deltas mean memory increased; negative means freed.
89#[derive(Debug, Clone)]
90pub struct MemoryReport {
91    /// Snapshot taken before the operation.
92    pub before: MemorySnapshot,
93    /// Snapshot taken after the operation.
94    pub after: MemorySnapshot,
95}
96
97impl MemorySnapshot {
98    /// Capture current memory state.
99    ///
100    /// RAM is always measured (per-process RSS). VRAM is measured only if
101    /// `device` is CUDA — first via DXGI (Windows, per-process), then NVML
102    /// (Linux, per-process), falling back to `nvidia-smi` (device-wide).
103    ///
104    /// # Errors
105    ///
106    /// Returns [`MIError::Memory`] if the RAM query fails (platform API error).
107    /// VRAM measurement failures are non-fatal — `vram_bytes` is set to `None`.
108    pub fn now(device: &candle_core::Device) -> Result<Self> {
109        let ram_bytes = process_rss()?;
110        let (vram_bytes, vram_total_bytes, per_process, gpu_name) = if device.is_cuda() {
111            gpu_memory_used()
112        } else {
113            (None, None, None, None)
114        };
115        Ok(Self {
116            ram_bytes,
117            vram_bytes,
118            vram_total_bytes,
119            vram_per_process: per_process,
120            gpu_name,
121        })
122    }
123
124    /// Format RAM usage as megabytes.
125    #[must_use]
126    pub fn ram_mb(&self) -> f64 {
127        // CAST: u64 → f64, value is memory in bytes — fits in f64 mantissa
128        // for any realistic process size (< 2^53 bytes = 8 PB)
129        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
130        let mb = self.ram_bytes as f64 / 1_048_576.0;
131        mb
132    }
133
134    /// Format VRAM usage as megabytes, if available.
135    #[must_use]
136    pub fn vram_mb(&self) -> Option<f64> {
137        // CAST: u64 → f64, same justification as ram_mb
138        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
139        self.vram_bytes.map(|b| b as f64 / 1_048_576.0)
140    }
141}
142
143/// Synchronize the CUDA device and trim its memory pool.
144///
145/// On a CUDA device this:
146/// 1. Calls `cuCtxSynchronize` so all pending async frees complete.
147/// 2. Calls `cuMemPoolTrimTo(pool, 0)` to release all unused reserved
148///    VRAM back to the device.
149///
150/// cudarc's stream-ordered allocator (`malloc_async` / `free_async`)
151/// keeps freed blocks in a pool for reuse. Over many forward passes
152/// with varying tensor sizes the pool grows monotonically — DXGI and
153/// `nvidia-smi` report this reserved memory as "in use", eventually
154/// causing OOM even though no live tensors need it.
155///
156/// This function is a no-op on CPU and on Metal.
157///
158/// # Example
159///
160/// ```no_run
161/// # use candle_mi::sync_and_trim_gpu;
162/// # let device = candle_core::Device::Cpu;
163/// // After dropping all GPU tensors from a forward pass:
164/// sync_and_trim_gpu(&device);
165/// ```
166// Cannot be `const fn`: the `cuda` branch calls non-const FFI (cuDeviceGetDefaultMemPool,
167// cuMemPoolTrimTo). Without `cuda` the body collapses to a no-op, which is why clippy
168// suggests `const` — but `const fn` would break the cuda-enabled build.
169#[allow(clippy::missing_const_for_fn)]
170pub fn sync_and_trim_gpu(device: &candle_core::Device) {
171    #[cfg(feature = "cuda")]
172    if let candle_core::Device::Cuda(cuda_dev) = device {
173        use candle_core::backend::BackendDevice;
174        // Synchronize so all pending async frees complete.
175        let _ = cuda_dev.synchronize();
176
177        // Trim the default memory pool to release all unused reserved VRAM.
178        // SAFETY: cuDeviceGetDefaultMemPool and cuMemPoolTrimTo are
179        // documented CUDA driver APIs for pool management. The CUdevice
180        // handle comes from candle's CudaContext (valid after synchronize).
181        // cuMemPoolTrimTo(pool, 0) releases all unused memory — it cannot
182        // free memory that is still in use by live tensors.
183        #[allow(unsafe_code)]
184        {
185            use candle_core::cuda_backend::cudarc::driver::sys;
186
187            let stream = cuda_dev.cuda_stream();
188            // Allocate a zero-length slice just to access the CudaContext
189            // (CudaStream.ctx is pub(crate), but CudaSlice.context() is pub).
190            if let Ok(probe) = stream.null::<u8>() {
191                let ctx = probe.context();
192                let cu_device = ctx.cu_device();
193                unsafe {
194                    let mut pool = std::mem::zeroed();
195                    let rc = sys::cuDeviceGetDefaultMemPool(&raw mut pool, cu_device);
196                    if rc == sys::CUresult::CUDA_SUCCESS {
197                        let _ = sys::cuMemPoolTrimTo(pool, 0);
198                    }
199                }
200            }
201        }
202    }
203
204    // Suppress unused-variable warning on non-CUDA builds.
205    #[cfg(not(feature = "cuda"))]
206    let _ = device;
207}
208
209impl MemoryReport {
210    /// Create a report from two snapshots.
211    #[must_use]
212    pub const fn new(before: MemorySnapshot, after: MemorySnapshot) -> Self {
213        Self { before, after }
214    }
215
216    /// RAM delta in megabytes (positive = increased).
217    #[must_use]
218    pub fn ram_delta_mb(&self) -> f64 {
219        self.after.ram_mb() - self.before.ram_mb()
220    }
221
222    /// VRAM delta in megabytes (positive = increased).
223    /// Returns `None` if either snapshot lacks VRAM data.
224    #[must_use]
225    pub fn vram_delta_mb(&self) -> Option<f64> {
226        match (self.after.vram_mb(), self.before.vram_mb()) {
227            (Some(after), Some(before)) => Some(after - before),
228            (Some(_) | None, None) | (None, Some(_)) => None,
229        }
230    }
231
232    /// Print a one-line summary of the delta.
233    pub fn print_delta(&self, label: &str) {
234        let ram = self.ram_delta_mb();
235        print!("  {label}: RAM {ram:+.0} MB");
236        if let Some(vram) = self.vram_delta_mb() {
237            let qualifier = self.vram_qualifier();
238            print!("  |  VRAM {vram:+.0} MB{qualifier}");
239        }
240        println!();
241    }
242
243    /// Print a two-line summary showing before → after for both RAM and VRAM.
244    pub fn print_before_after(&self, label: &str) {
245        println!(
246            "  {label}: RAM {:.0} MB → {:.0} MB ({:+.0} MB)",
247            self.before.ram_mb(),
248            self.after.ram_mb(),
249            self.ram_delta_mb(),
250        );
251        if let (Some(before), Some(after)) = (self.before.vram_mb(), self.after.vram_mb()) {
252            // CAST: u64 → f64, same justification as ram_mb
253            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
254            let total = self.after.vram_total_bytes.map_or(String::new(), |t| {
255                format!(" / {:.0} MB", t as f64 / 1_048_576.0)
256            });
257            let qualifier = self.vram_qualifier();
258            let gpu = self
259                .after
260                .gpu_name
261                .as_deref()
262                .map_or(String::new(), |name| format!(" [{name}]"));
263            println!(
264                "  {label}: VRAM {before:.0} MB → {after:.0} MB ({:+.0} MB{total}){qualifier}{gpu}",
265                after - before,
266            );
267        }
268    }
269
270    /// Return a short qualifier string indicating VRAM measurement quality.
271    #[must_use]
272    const fn vram_qualifier(&self) -> &'static str {
273        match self.after.vram_per_process {
274            Some(true) => " [per-process]",
275            Some(false) => " [device-wide]",
276            None => "",
277        }
278    }
279}
280
281// ---------------------------------------------------------------------------
282// RAM measurement — per-process RSS
283// ---------------------------------------------------------------------------
284
285/// Query the current process's resident set size (RSS) in bytes.
286///
287/// # Platform
288///
289/// - **Windows**: `K32GetProcessMemoryInfo` → `WorkingSetSize` (exact, per-process).
290/// - **Linux**: `/proc/self/status` → `VmRSS` (exact, per-process, no unsafe).
291///
292/// # Errors
293///
294/// Returns [`MIError::Memory`] if the platform API call fails.
295fn process_rss() -> Result<u64> {
296    #[cfg(target_os = "windows")]
297    {
298        windows_rss()
299    }
300    #[cfg(target_os = "linux")]
301    {
302        linux_rss()
303    }
304    #[cfg(not(any(target_os = "windows", target_os = "linux")))]
305    {
306        Err(MIError::Memory(
307            "RAM measurement not supported on this platform".into(),
308        ))
309    }
310}
311
312// -- Windows ----------------------------------------------------------------
313
/// Hand-written Windows bindings needed to call `K32GetProcessMemoryInfo`.
#[cfg(target_os = "windows")]
mod win_ffi {
    /// Rust mirror of the Windows API `PROCESS_MEMORY_COUNTERS` structure.
    ///
    /// Field order and types follow the C definition; `#[repr(C)]` keeps the
    /// layout compatible.
    ///
    /// See: <https://learn.microsoft.com/en-us/windows/win32/api/psapi/ns-psapi-process_memory_counters>
    #[repr(C)]
    pub(super) struct ProcessMemoryCounters {
        /// Byte size of this structure.
        pub cb: u32,
        /// Page fault count.
        pub page_fault_count: u32,
        /// Peak working set size, in bytes.
        pub peak_working_set_size: usize,
        /// Current working set size, in bytes (= RSS).
        pub working_set_size: usize,
        /// Peak paged pool usage, in bytes.
        pub quota_peak_paged_pool_usage: usize,
        /// Current paged pool usage, in bytes.
        pub quota_paged_pool_usage: usize,
        /// Peak non-paged pool usage, in bytes.
        pub quota_peak_non_paged_pool_usage: usize,
        /// Current non-paged pool usage, in bytes.
        pub quota_non_paged_pool_usage: usize,
        /// Current pagefile usage, in bytes.
        pub pagefile_usage: usize,
        /// Peak pagefile usage, in bytes.
        pub peak_pagefile_usage: usize,
    }

    // SAFETY: These are stable Windows API functions with well-defined ABI.
    // GetCurrentProcess always returns a valid pseudo-handle.
    // K32GetProcessMemoryInfo writes to caller-provided memory of known size.
    #[allow(unsafe_code)]
    unsafe extern "system" {
        /// Fills `counters` with memory usage information for `process`.
        ///
        /// `cb` is the byte size of the structure `counters` points to.
        /// The return value is nonzero on success.
        pub(super) unsafe fn K32GetProcessMemoryInfo(
            process: isize,
            counters: *mut ProcessMemoryCounters,
            cb: u32,
        ) -> i32;

        /// Pseudo-handle for the current process (always valid, never null),
        /// which is why this binding is declared `safe`.
        pub(super) safe fn GetCurrentProcess() -> isize;
    }
}
360
/// Read the working set size of the current process on Windows via
/// `K32GetProcessMemoryInfo`.
///
/// # Errors
///
/// Returns [`MIError::Memory`] when the API call reports failure (returns 0).
#[cfg(target_os = "windows")]
#[allow(unsafe_code)]
fn windows_rss() -> Result<u64> {
    // CAST: usize → u32, struct size is 80 bytes on x64 — fits in u32
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let struct_size = std::mem::size_of::<win_ffi::ProcessMemoryCounters>() as u32;

    // Zero-filled output buffer; the API requires `cb` to hold the struct size.
    let mut counters = win_ffi::ProcessMemoryCounters {
        cb: struct_size,
        page_fault_count: 0,
        peak_working_set_size: 0,
        working_set_size: 0,
        quota_peak_paged_pool_usage: 0,
        quota_paged_pool_usage: 0,
        quota_peak_non_paged_pool_usage: 0,
        quota_non_paged_pool_usage: 0,
        pagefile_usage: 0,
        peak_pagefile_usage: 0,
    };

    let process = win_ffi::GetCurrentProcess();

    // SAFETY: K32GetProcessMemoryInfo writes into the stack-allocated
    // `counters` struct, which is correctly sized (cb field set to struct
    // size). The process handle from GetCurrentProcess is a pseudo-handle
    // that is always valid for the lifetime of the process.
    let status =
        unsafe { win_ffi::K32GetProcessMemoryInfo(process, &raw mut counters, struct_size) };

    // A zero return value signals failure.
    if status == 0 {
        return Err(MIError::Memory("K32GetProcessMemoryInfo failed".into()));
    }

    // CAST: usize → u64, working set size in bytes — always fits
    #[allow(clippy::as_conversions)]
    let rss = counters.working_set_size as u64;
    Ok(rss)
}
399
400// -- Linux ------------------------------------------------------------------
401
402/// Query RSS on Linux via `/proc/self/status`.
403#[cfg(target_os = "linux")]
404fn linux_rss() -> Result<u64> {
405    let status = std::fs::read_to_string("/proc/self/status")
406        .map_err(|e| MIError::Memory(format!("failed to read /proc/self/status: {e}")))?;
407
408    for line in status.lines() {
409        if let Some(rest) = line.strip_prefix("VmRSS:") {
410            let kb_str = rest.trim().trim_end_matches(" kB").trim();
411            let kb: u64 = kb_str.parse().map_err(|e| {
412                MIError::Memory(format!("failed to parse VmRSS value '{kb_str}': {e}"))
413            })?;
414            return Ok(kb * 1024);
415        }
416    }
417
418    Err(MIError::Memory(
419        "VmRSS not found in /proc/self/status".into(),
420    ))
421}
422
423// ---------------------------------------------------------------------------
424// VRAM measurement — DXGI (Windows), NVML, nvidia-smi fallback
425// ---------------------------------------------------------------------------
426
/// Result of a GPU memory query: `(used_bytes, total_bytes, per_process, gpu_name)`.
type GpuMemoryResult = (Option<u64>, Option<u64>, Option<bool>, Option<String>);

/// NVML shared library name on Windows (stable across driver versions).
#[cfg(target_os = "windows")]
const NVML_LIB_PATH: &str = "nvml.dll";

/// NVML shared library name on every non-Windows target (stable across
/// driver versions).
///
/// Gated on "not Windows" rather than "Linux" so that the constant exists on
/// every target: the NVML query function uses it unconditionally, and a
/// Linux-only gate would break compilation elsewhere. On a platform without
/// NVML, loading the library simply fails and the caller falls back to
/// `nvidia-smi`.
#[cfg(not(target_os = "windows"))]
const NVML_LIB_PATH: &str = "libnvidia-ml.so.1";

/// NVML return code: success.
const NVML_SUCCESS: u32 = 0;

/// NVML return code: buffer too small (need to retry with larger buffer).
const NVML_ERROR_INSUFFICIENT_SIZE: u32 = 7;

/// Maximum number of processes to query from NVML in a single call.
/// 64 is generous — most machines have fewer than 10 GPU processes.
const NVML_MAX_PROCESSES: usize = 64;
447
/// One entry of NVML's per-process GPU memory listing.
///
/// Layout-compatible with the C struct `nvmlProcessInfo_v2_t` (24 bytes),
/// which is what both `nvmlDeviceGetComputeRunningProcesses_v2` and `_v3`
/// fill in (the `_v3` suffix versions the function, not the struct).
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlProcessInfo__v2__t.html>
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct NvmlProcessInfo {
    /// ID of the process.
    pid: u32,
    /// Bytes of GPU memory used by the process.
    /// The value `u64::MAX` (`0xFFFF_FFFF_FFFF_FFFF`) means "not available".
    used_gpu_memory: u64,
    /// MIG GPU instance ID. Unused in non-MIG mode.
    gpu_instance_id: u32,
    /// MIG compute instance ID. Unused in non-MIG mode.
    compute_instance_id: u32,
}
467
/// Device-level memory figures reported by NVML.
///
/// Layout-compatible with the C struct `nvmlMemory_t` from the NVML API.
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlMemory__t.html>
#[repr(C)]
#[derive(Clone, Copy, Debug)]
struct NvmlMemoryInfo {
    /// Total bytes of GPU memory.
    total: u64,
    /// Free bytes of GPU memory.
    free: u64,
    /// Used bytes of GPU memory.
    used: u64,
}
482
483/// Opaque NVML device handle.
484type NvmlDevice = *mut std::ffi::c_void;
485
486/// Function signature: `nvmlInit_v2() -> nvmlReturn_t`.
487type NvmlInitFn = unsafe extern "C" fn() -> u32;
488
489/// Function signature: `nvmlShutdown() -> nvmlReturn_t`.
490type NvmlShutdownFn = unsafe extern "C" fn() -> u32;
491
492/// Function signature: `nvmlDeviceGetHandleByIndex_v2(index, *mut device) -> nvmlReturn_t`.
493type NvmlDeviceGetHandleByIndexFn = unsafe extern "C" fn(u32, *mut NvmlDevice) -> u32;
494
495/// Function signature: `nvmlDeviceGetMemoryInfo(device, *mut memory) -> nvmlReturn_t`.
496type NvmlDeviceGetMemoryInfoFn = unsafe extern "C" fn(NvmlDevice, *mut NvmlMemoryInfo) -> u32;
497
498/// Function signature:
499/// `nvmlDeviceGetComputeRunningProcesses_v3(device, *mut count, *mut infos) -> nvmlReturn_t`.
500type NvmlDeviceGetComputeRunningProcessesFn =
501    unsafe extern "C" fn(NvmlDevice, *mut u32, *mut NvmlProcessInfo) -> u32;
502
503/// Query GPU memory — DXGI (Windows), NVML, or `nvidia-smi` fallback.
504///
505/// Returns `(used_bytes, total_bytes, per_process, gpu_name)`:
506/// - `per_process = Some(true)` when per-process query succeeded (DXGI or NVML).
507/// - `per_process = Some(false)` when falling back to `nvidia-smi` (device-wide).
508/// - `gpu_name` is set when DXGI provides the adapter description.
509/// - All `None` if all methods fail.
510fn gpu_memory_used() -> GpuMemoryResult {
511    // Windows: try DXGI first (per-process, works under WDDM)
512    #[cfg(windows)]
513    if let Some(result) = dxgi_query_process_vram() {
514        return result;
515    }
516
517    // Try NVML (per-process on Linux, NOT_AVAILABLE on Windows WDDM)
518    if let Some(result) = nvml_query_process_vram() {
519        let (used, total, per_process) = result;
520        return (used, total, per_process, None);
521    }
522
523    // Fallback to nvidia-smi (device-wide)
524    let (used, total) = nvidia_smi_query();
525    if used.is_some() {
526        (used, total, Some(false), None)
527    } else {
528        (None, None, None, None)
529    }
530}
531
532/// Attempt to query per-process VRAM via NVML.
533///
534/// Returns `None` if NVML cannot be loaded or any API call fails,
535/// signaling the caller to try the fallback path.
536#[allow(unsafe_code)]
537fn nvml_query_process_vram() -> Option<(Option<u64>, Option<u64>, Option<bool>)> {
538    // SAFETY: libloading::Library::new dynamically loads a shared library.
539    // The NVML library is a stable NVIDIA driver component with a well-defined
540    // C ABI. We load it, call functions, and unload it within this scope.
541    let lib = unsafe { libloading::Library::new(NVML_LIB_PATH) }.ok()?;
542
543    // SAFETY: Loading function symbols from the NVML library. Each symbol
544    // name matches the documented NVML C API exactly. The function signatures
545    // (type aliases above) match the NVML header definitions.
546    let init: libloading::Symbol<'_, NvmlInitFn> = unsafe { lib.get(b"nvmlInit_v2\0") }.ok()?;
547    let shutdown: libloading::Symbol<'_, NvmlShutdownFn> =
548        unsafe { lib.get(b"nvmlShutdown\0") }.ok()?;
549    let get_handle: libloading::Symbol<'_, NvmlDeviceGetHandleByIndexFn> =
550        unsafe { lib.get(b"nvmlDeviceGetHandleByIndex_v2\0") }.ok()?;
551    let get_memory: libloading::Symbol<'_, NvmlDeviceGetMemoryInfoFn> =
552        unsafe { lib.get(b"nvmlDeviceGetMemoryInfo\0") }.ok()?;
553    let get_processes: libloading::Symbol<'_, NvmlDeviceGetComputeRunningProcessesFn> =
554        unsafe { lib.get(b"nvmlDeviceGetComputeRunningProcesses_v3\0") }.ok()?;
555
556    // Initialize NVML
557    // SAFETY: nvmlInit_v2 is safe to call from any thread; it initializes
558    // internal NVML state. Returns NVML_SUCCESS (0) on success.
559    let ret = unsafe { init() };
560    if ret != NVML_SUCCESS {
561        return None;
562    }
563
564    // Get device handle for GPU 0 (primary GPU)
565    let mut device: NvmlDevice = std::ptr::null_mut();
566    // SAFETY: nvmlDeviceGetHandleByIndex_v2 writes a valid opaque handle
567    // into `device` when it returns NVML_SUCCESS. Index 0 = primary GPU.
568    let ret = unsafe { get_handle(0, &raw mut device) };
569    if ret != NVML_SUCCESS {
570        // SAFETY: nvmlShutdown is always safe after a successful nvmlInit.
571        unsafe { shutdown() };
572        return None;
573    }
574
575    // Get total memory for the device
576    let mut mem_info = NvmlMemoryInfo {
577        total: 0,
578        free: 0,
579        used: 0,
580    };
581    // SAFETY: nvmlDeviceGetMemoryInfo writes into the caller-provided
582    // NvmlMemoryInfo struct. The device handle is valid (obtained above).
583    let ret = unsafe { get_memory(device, &raw mut mem_info) };
584    let total_bytes = if ret == NVML_SUCCESS {
585        Some(mem_info.total)
586    } else {
587        None
588    };
589
590    // Get per-process memory
591    // CAST: usize → u32, NVML_MAX_PROCESSES is 64 — fits in u32
592    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
593    let mut count = NVML_MAX_PROCESSES as u32;
594    let mut infos = [NvmlProcessInfo {
595        pid: 0,
596        used_gpu_memory: 0,
597        gpu_instance_id: 0,
598        compute_instance_id: 0,
599    }; NVML_MAX_PROCESSES];
600
601    // SAFETY: nvmlDeviceGetComputeRunningProcesses_v3 fills `infos` with
602    // up to `count` entries and updates `count` to the actual number written.
603    // The buffer is stack-allocated with NVML_MAX_PROCESSES slots, which is
604    // sufficient for typical workloads.
605    let ret = unsafe { get_processes(device, &raw mut count, infos.as_mut_ptr()) };
606
607    // SAFETY: nvmlShutdown pairs with nvmlInit; always called before return.
608    unsafe { shutdown() };
609
610    if ret != NVML_SUCCESS && ret != NVML_ERROR_INSUFFICIENT_SIZE {
611        return None;
612    }
613
614    // Find our process in the list
615    let my_pid = std::process::id();
616    // CAST: u32 → usize, count is a small process count — always fits
617    #[allow(clippy::as_conversions)]
618    let actual_count = count as usize;
619    let my_vram = infos
620        .get(..actual_count)?
621        .iter()
622        .find(|info| info.pid == my_pid)
623        .map(|info| info.used_gpu_memory);
624
625    // NVML uses u64::MAX as "not available" sentinel — some drivers (e.g., R570
626    // on RTX 5060 Ti) return this for all processes. Fall back to nvidia-smi.
627    if my_vram == Some(u64::MAX) {
628        return None;
629    }
630
631    // Sanity check: if per-process VRAM exceeds total device memory, the value
632    // is likely garbage. Fall back to nvidia-smi.
633    if let (Some(used), Some(total)) = (my_vram, total_bytes)
634        && used > total
635    {
636        return None;
637    }
638
639    // Our PID might not be in the list (no active CUDA context yet?) — return None to trigger fallback
640    my_vram.map(|used| (Some(used), total_bytes, Some(true)))
641}
642
/// Query GPU memory via `nvidia-smi` (device-wide fallback).
///
/// Runs `nvidia-smi --query-gpu=memory.used,memory.total
/// --format=csv,noheader,nounits` and reads the first line of its output,
/// so on a multi-GPU machine only the first listed GPU is reported.
///
/// Returns `(Some(used_bytes), Some(total_bytes))` on success, or
/// `(None, None)` if `nvidia-smi` cannot be run, exits unsuccessfully, or
/// prints something that cannot be parsed. The two elements are always
/// either both `Some` or both `None`.
fn nvidia_smi_query() -> (Option<u64>, Option<u64>) {
    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output();

    // Treat "could not spawn" and "non-zero exit" the same way: no data.
    let Ok(output) = output else {
        return (None, None);
    };
    if !output.status.success() {
        return (None, None);
    }

    // BORROW: explicit String::from_utf8_lossy — nvidia-smi output is ASCII
    let stdout = String::from_utf8_lossy(&output.stdout);
    stdout
        .lines()
        .next()
        .and_then(parse_nvidia_smi_line)
        .map_or((None, None), |(used, total)| (Some(used), Some(total)))
}

/// Parse one `used, total` CSV line from `nvidia-smi` (values in MiB) into
/// `(used_bytes, total_bytes)`.
///
/// Returns `None` if either field is missing, is not an unsigned integer, or
/// would overflow `u64` once converted to bytes. Fields after the second are
/// ignored.
fn parse_nvidia_smi_line(line: &str) -> Option<(u64, u64)> {
    /// nvidia-smi reports in MiB — this is the factor to bytes.
    const BYTES_PER_MIB: u64 = 1_048_576;

    // Format: "1234, 16384" (used MiB, total MiB)
    let mut fields = line.split(',').map(str::trim);
    let used_mib: u64 = fields.next()?.parse().ok()?;
    let total_mib: u64 = fields.next()?.parse().ok()?;

    // The values come from an external process, so the multiplication is
    // checked instead of trusted.
    Some((
        used_mib.checked_mul(BYTES_PER_MIB)?,
        total_mib.checked_mul(BYTES_PER_MIB)?,
    ))
}
690
691// ---------------------------------------------------------------------------
692// DXGI per-process VRAM (Windows only)
693// ---------------------------------------------------------------------------
694
/// Read this process's GPU memory usage through DXGI
/// (`IDXGIAdapter3::QueryVideoMemoryInfo`).
///
/// On Windows (WDDM) this is the only dependable per-process source: NVML
/// answers `NVML_VALUE_NOT_AVAILABLE` there because GPU memory is owned by
/// the Windows kernel memory manager rather than by the NVIDIA driver.
///
/// `QueryVideoMemoryInfo` (DXGI 1.4, Windows 10+) reports:
/// - `CurrentUsage`: per-process GPU memory in bytes (the value returned as "used").
/// - `Budget`: OS-assigned memory budget for this process.
///
/// The query targets `DXGI_MEMORY_SEGMENT_GROUP_LOCAL` (dedicated VRAM on
/// discrete GPUs) on the first enumerated adapter whose
/// `DedicatedVideoMemory` is non-zero; that same field supplies the total.
///
/// Returns `None` when DXGI is unavailable, no adapter with dedicated VRAM
/// exists, or any query fails — the caller then tries NVML or `nvidia-smi`.
#[cfg(windows)]
#[allow(unsafe_code)]
fn dxgi_query_process_vram() -> Option<GpuMemoryResult> {
    use windows::Win32::Graphics::Dxgi::{
        CreateDXGIFactory1, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, DXGI_QUERY_VIDEO_MEMORY_INFO,
        IDXGIAdapter, IDXGIAdapter3, IDXGIFactory1,
    };
    use windows::core::Interface;

    // SAFETY: CreateDXGIFactory1 is a well-documented COM factory function.
    // It initializes COM internally if needed. The returned IDXGIFactory1
    // is reference-counted and released automatically when dropped.
    let factory: IDXGIFactory1 = unsafe { CreateDXGIFactory1() }.ok()?;

    // Walk the adapter list until one reports dedicated VRAM > 0
    // (software/render-only adapters such as the Microsoft Basic Render
    // Driver are skipped). Running off the end of the list yields `None`.
    let mut index = 0u32;
    let (adapter, desc) = loop {
        // SAFETY: EnumAdapters1 returns S_OK with a valid adapter, or
        // DXGI_ERROR_NOT_FOUND when idx is out of range.
        let candidate: IDXGIAdapter = unsafe { factory.EnumAdapters1(index) }
            .ok()?
            .cast()
            .ok()?;

        // SAFETY: GetDesc writes a valid DXGI_ADAPTER_DESC into the
        // caller-provided struct. The adapter handle is valid (obtained above).
        let candidate_desc = unsafe { candidate.GetDesc() }.ok()?;

        if candidate_desc.DedicatedVideoMemory != 0 {
            break (candidate, candidate_desc);
        }
        index += 1;
    };

    // QueryVideoMemoryInfo lives on IDXGIAdapter3 (DXGI 1.4).
    let adapter3: IDXGIAdapter3 = adapter.cast().ok()?;

    // SAFETY: QueryVideoMemoryInfo fills a DXGI_QUERY_VIDEO_MEMORY_INFO
    // struct with per-process memory stats. Node 0 = primary GPU node.
    // DXGI_MEMORY_SEGMENT_GROUP_LOCAL = dedicated VRAM on discrete GPUs.
    let mut usage = DXGI_QUERY_VIDEO_MEMORY_INFO::default();
    let query = unsafe {
        adapter3.QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &raw mut usage)
    };
    query.ok()?;

    // CAST: usize → u64, DedicatedVideoMemory is usize on Windows
    #[allow(clippy::as_conversions)]
    let total = desc.DedicatedVideoMemory as u64;

    // The description is a fixed-size UTF-16 array padded with NULs; decode
    // it and drop the padding.
    // BORROW: explicit from_utf16_lossy — DXGI Description is a fixed-size UTF-16 array
    let decoded = String::from_utf16_lossy(&desc.Description);
    // BORROW: to_owned — trim returns a &str slice; we need an owned String
    let gpu_name = decoded.trim_end_matches('\0').to_owned();

    #[cfg(feature = "memory-debug")]
    eprintln!(
        "[DXGI debug] adapter={gpu_name}, dedicated_vram={total}, \
         current_usage={}, budget={}",
        usage.CurrentUsage, usage.Budget,
    );

    Some((
        Some(usage.CurrentUsage),
        Some(total),
        Some(true),
        Some(gpu_name),
    ))
}
782
783// ---------------------------------------------------------------------------
784// Tests
785// ---------------------------------------------------------------------------
786
#[cfg(test)]
// Tests are allowed to panic on unexpected `None`/`Err` — a panic is the failure signal.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a report whose `before` and `after` snapshots are identical and
    /// carry the given per-process flag, so only the qualifier is under test.
    fn identical_report(per_process: bool) -> MemoryReport {
        let snap = MemorySnapshot {
            ram_bytes: 100,
            vram_bytes: Some(500),
            vram_total_bytes: Some(1000),
            vram_per_process: Some(per_process),
            gpu_name: None,
        };
        MemoryReport::new(snap.clone(), snap)
    }

    /// A snapshot taken against the CPU device reports RAM but no VRAM fields.
    #[test]
    fn snapshot_cpu_has_ram() {
        let cpu = candle_core::Device::Cpu;
        let snap = MemorySnapshot::now(&cpu).unwrap();

        // A running process always occupies some RAM.
        assert!(snap.ram_bytes > 0, "RAM should be non-zero");

        // No GPU is involved, so neither the usage nor its qualifier is set.
        assert!(snap.vram_bytes.is_none(), "CPU should have no VRAM");
        assert!(
            snap.vram_per_process.is_none(),
            "CPU should have no VRAM qualifier"
        );
    }

    /// Growing RAM by 100 MB and VRAM by 500 MB yields matching positive deltas.
    #[test]
    fn report_delta_positive_for_allocation() {
        let baseline = MemorySnapshot {
            ram_bytes: 100 * 1_048_576, // 100 MB
            vram_bytes: Some(500 * 1_048_576),
            vram_total_bytes: Some(16_384 * 1_048_576),
            vram_per_process: Some(true),
            gpu_name: None,
        };
        // Same device and qualifier; only the usage figures change.
        let grown = MemorySnapshot {
            ram_bytes: 200 * 1_048_576, // 200 MB
            vram_bytes: Some(1_000 * 1_048_576),
            ..baseline.clone()
        };
        let report = MemoryReport::new(baseline, grown);

        let ram_delta = report.ram_delta_mb();
        let ram_error = (ram_delta - 100.0).abs();
        assert!(
            ram_error < 0.01,
            "RAM delta should be ~100 MB, got {ram_delta}"
        );

        let vram_delta = report.vram_delta_mb().unwrap();
        let vram_error = (vram_delta - 500.0).abs();
        assert!(
            vram_error < 0.01,
            "VRAM delta should be ~500 MB, got {vram_delta}"
        );
    }

    /// Without VRAM readings on either side there is no VRAM delta to report.
    #[test]
    fn report_delta_none_when_no_vram() {
        let baseline = MemorySnapshot {
            ram_bytes: 100,
            vram_bytes: None,
            vram_total_bytes: None,
            vram_per_process: None,
            gpu_name: None,
        };
        let later = MemorySnapshot {
            ram_bytes: 200,
            ..baseline.clone()
        };

        let vram_delta = MemoryReport::new(baseline, later).vram_delta_mb();
        assert!(vram_delta.is_none());
    }

    /// 1 048 576 bytes convert to exactly one megabyte.
    #[test]
    fn ram_mb_conversion() {
        let one_mb = MemorySnapshot {
            ram_bytes: 1_048_576, // exactly 1 MB
            vram_bytes: None,
            vram_total_bytes: None,
            vram_per_process: None,
            gpu_name: None,
        };
        let deviation = (one_mb.ram_mb() - 1.0).abs();
        assert!(deviation < 0.001);
    }

    /// A per-process measurement is labelled as such.
    #[test]
    fn vram_qualifier_per_process() {
        assert_eq!(identical_report(true).vram_qualifier(), " [per-process]");
    }

    /// A device-wide measurement is labelled as such.
    #[test]
    fn vram_qualifier_device_wide() {
        assert_eq!(identical_report(false).vram_qualifier(), " [device-wide]");
    }
}