Skip to main content

candle_mi/
memory.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Process and GPU memory reporting.
4//!
5//! Provides [`MemorySnapshot`] to capture current RAM and VRAM usage,
6//! and [`MemoryReport`] to measure deltas between two snapshots.
7//!
8//! # VRAM measurement strategy
9//!
10//! VRAM is measured using a three-tier approach:
11//!
12//! 1. **Windows primary — DXGI** (per-process): Uses
13//!    `IDXGIAdapter3::QueryVideoMemoryInfo` (DXGI 1.4, Windows 10+) to get
14//!    true per-process GPU memory. This is the only reliable method on Windows
15//!    because WDDM means the Windows kernel manages GPU memory, not the NVIDIA
16//!    driver — so NVML returns `NOT_AVAILABLE` for per-process queries.
17//! 2. **Linux primary — NVML** (per-process): Dynamically loads
18//!    `libnvidia-ml.so.1` via `libloading` and calls
19//!    `nvmlDeviceGetComputeRunningProcesses` for true per-process GPU memory.
20//! 3. **Fallback — `nvidia-smi`** (device-wide): If both DXGI and NVML fail,
21//!    spawns `nvidia-smi` as a subprocess. Reports device-wide VRAM; the delta
22//!    between two snapshots is accurate on single-user machines.
23//!
24//! # Platform support
25//!
26//! | Metric | Windows | Linux |
27//! |--------|---------|-------|
28//! | RAM (RSS) | `K32GetProcessMemoryInfo` (per-process, exact) | `/proc/self/status` `VmRSS` (per-process, exact) |
29//! | VRAM (DXGI) | `IDXGIAdapter3` (per-process, exact) | N/A |
30//! | VRAM (NVML) | `NOT_AVAILABLE` under WDDM | `libnvidia-ml.so.1` (per-process, exact) |
31//! | VRAM (fallback) | `nvidia-smi` (device-wide) | `nvidia-smi` (device-wide) |
32//!
33//! # Feature gates
34//!
35//! - **`memory`**: Enables this module. Relaxes `#![forbid(unsafe_code)]` to
36//!   `#![deny(unsafe_code)]` for the Windows FFI calls (`K32GetProcessMemoryInfo`,
37//!   DXGI COM calls) and for NVML dynamic symbol loading via `libloading`.
//!   Linux RAM measurement itself uses no unsafe code.
39//! - **`memory-debug`** (implies `memory`): Prints raw DXGI query results and
40//!   per-chunk VRAM measurements to stderr for diagnosing GPU memory issues.
41
42use crate::{MIError, Result};
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
48/// Memory snapshot at a point in time.
49///
50/// Captures process RAM (resident set size) and optionally GPU VRAM.
51/// Use [`MemorySnapshot::now`] to take a measurement, and
52/// [`MemoryReport::new`] to compute deltas between two snapshots.
53///
54/// # Example
55///
56/// ```no_run
57/// use candle_mi::MemorySnapshot;
58///
59/// let before = MemorySnapshot::now(&candle_core::Device::Cpu)?;
60/// // ... load a model ...
61/// let after = MemorySnapshot::now(&candle_core::Device::Cpu)?;
62/// let report = candle_mi::MemoryReport::new(before, after);
63/// println!("RAM delta: {:+.1} MB", report.ram_delta_mb());
64/// # Ok::<(), candle_mi::MIError>(())
65/// ```
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
    /// Process resident set size (working set on Windows) in bytes.
    /// Always populated by [`MemorySnapshot::now`].
    pub ram_bytes: u64,
    /// GPU memory used in bytes.
    /// Per-process when measured via DXGI/NVML, device-wide when via `nvidia-smi` fallback.
    /// `None` if no GPU is present or measurement failed.
    pub vram_bytes: Option<u64>,
    /// Total GPU memory on the active device in bytes.
    /// `None` if no GPU is present or measurement failed.
    pub vram_total_bytes: Option<u64>,
    /// Whether the VRAM measurement is per-process (`true`) or device-wide (`false`).
    /// `None` if no VRAM data is available.
    pub vram_per_process: Option<bool>,
    /// GPU adapter name (e.g., `NVIDIA GeForce RTX 5060 Ti`).
    /// Only the DXGI path populates this; it is `None` on the NVML and
    /// `nvidia-smi` paths, and when no GPU is present.
    pub gpu_name: Option<String>,
}
84
85/// Memory delta between two snapshots.
86///
87/// Computed from a `before` and `after` [`MemorySnapshot`].
88/// Positive deltas mean memory increased; negative means freed.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// Snapshot taken before the operation.
    pub before: MemorySnapshot,
    /// Snapshot taken after the operation.
    /// Delta accessors (`ram_delta_mb`, `vram_delta_mb`) compute `after - before`.
    pub after: MemorySnapshot,
}
96
97impl MemorySnapshot {
98    /// Capture current memory state.
99    ///
100    /// RAM is always measured (per-process RSS). VRAM is measured only if
101    /// `device` is CUDA — first via DXGI (Windows, per-process), then NVML
102    /// (Linux, per-process), falling back to `nvidia-smi` (device-wide).
103    ///
104    /// # Errors
105    ///
106    /// Returns [`MIError::Memory`] if the RAM query fails (platform API error).
107    /// VRAM measurement failures are non-fatal — `vram_bytes` is set to `None`.
108    pub fn now(device: &candle_core::Device) -> Result<Self> {
109        let ram_bytes = process_rss()?;
110        let (vram_bytes, vram_total_bytes, per_process, gpu_name) = if device.is_cuda() {
111            gpu_memory_used()
112        } else {
113            (None, None, None, None)
114        };
115        Ok(Self {
116            ram_bytes,
117            vram_bytes,
118            vram_total_bytes,
119            vram_per_process: per_process,
120            gpu_name,
121        })
122    }
123
124    /// Format RAM usage as megabytes.
125    #[must_use]
126    pub fn ram_mb(&self) -> f64 {
127        // CAST: u64 → f64, value is memory in bytes — fits in f64 mantissa
128        // for any realistic process size (< 2^53 bytes = 8 PB)
129        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
130        let mb = self.ram_bytes as f64 / 1_048_576.0;
131        mb
132    }
133
134    /// Format VRAM usage as megabytes, if available.
135    #[must_use]
136    pub fn vram_mb(&self) -> Option<f64> {
137        // CAST: u64 → f64, same justification as ram_mb
138        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
139        self.vram_bytes.map(|b| b as f64 / 1_048_576.0)
140    }
141}
142
/// Synchronize the CUDA device and trim its memory pool.
///
/// On a CUDA device this:
/// 1. Calls `cuCtxSynchronize` so all pending async frees complete.
/// 2. Calls `cuMemPoolTrimTo(pool, 0)` to release all unused reserved
///    VRAM back to the device.
///
/// cudarc's stream-ordered allocator (`malloc_async` / `free_async`)
/// keeps freed blocks in a pool for reuse. Over many forward passes
/// with varying tensor sizes the pool grows monotonically — DXGI and
/// `nvidia-smi` report this reserved memory as "in use", eventually
/// causing OOM even though no live tensors need it.
///
/// All CUDA driver errors are deliberately ignored (`let _ =`): trimming
/// is best-effort and failure only means the pool stays large.
///
/// This function is a no-op on CPU and on Metal.
///
/// # Example
///
/// ```no_run
/// # use candle_mi::sync_and_trim_gpu;
/// # let device = candle_core::Device::Cpu;
/// // After dropping all GPU tensors from a forward pass:
/// sync_and_trim_gpu(&device);
/// ```
pub fn sync_and_trim_gpu(device: &candle_core::Device) {
    #[cfg(feature = "cuda")]
    if let candle_core::Device::Cuda(cuda_dev) = device {
        use candle_core::backend::BackendDevice;
        // Synchronize so all pending async frees complete.
        let _ = cuda_dev.synchronize();

        // Trim the default memory pool to release all unused reserved VRAM.
        // SAFETY: cuDeviceGetDefaultMemPool and cuMemPoolTrimTo are
        // documented CUDA driver APIs for pool management. The CUdevice
        // handle comes from candle's CudaContext (valid after synchronize).
        // cuMemPoolTrimTo(pool, 0) releases all unused memory — it cannot
        // free memory that is still in use by live tensors.
        #[allow(unsafe_code)]
        {
            use candle_core::cuda_backend::cudarc::driver::sys;

            let stream = cuda_dev.cuda_stream();
            // Allocate a zero-length slice just to access the CudaContext
            // (CudaStream.ctx is pub(crate), but CudaSlice.context() is pub).
            // NOTE(review): a zero-length allocation is assumed to be free of
            // side effects on the pool being trimmed — confirm against cudarc.
            if let Ok(probe) = stream.null::<u8>() {
                let ctx = probe.context();
                let cu_device = ctx.cu_device();
                unsafe {
                    // CUmemoryPool is an opaque handle; zeroed is a valid
                    // "uninitialized" state before the driver writes it.
                    let mut pool = std::mem::zeroed();
                    let rc = sys::cuDeviceGetDefaultMemPool(&raw mut pool, cu_device);
                    if rc == sys::CUresult::CUDA_SUCCESS {
                        let _ = sys::cuMemPoolTrimTo(pool, 0);
                    }
                }
            }
        }
    }

    // Suppress unused-variable warning on non-CUDA builds.
    #[cfg(not(feature = "cuda"))]
    let _ = device;
}
204
205impl MemoryReport {
206    /// Create a report from two snapshots.
207    #[must_use]
208    pub const fn new(before: MemorySnapshot, after: MemorySnapshot) -> Self {
209        Self { before, after }
210    }
211
212    /// RAM delta in megabytes (positive = increased).
213    #[must_use]
214    pub fn ram_delta_mb(&self) -> f64 {
215        self.after.ram_mb() - self.before.ram_mb()
216    }
217
218    /// VRAM delta in megabytes (positive = increased).
219    /// Returns `None` if either snapshot lacks VRAM data.
220    #[must_use]
221    pub fn vram_delta_mb(&self) -> Option<f64> {
222        match (self.after.vram_mb(), self.before.vram_mb()) {
223            (Some(after), Some(before)) => Some(after - before),
224            (Some(_) | None, None) | (None, Some(_)) => None,
225        }
226    }
227
228    /// Print a one-line summary of the delta.
229    pub fn print_delta(&self, label: &str) {
230        let ram = self.ram_delta_mb();
231        print!("  {label}: RAM {ram:+.0} MB");
232        if let Some(vram) = self.vram_delta_mb() {
233            let qualifier = self.vram_qualifier();
234            print!("  |  VRAM {vram:+.0} MB{qualifier}");
235        }
236        println!();
237    }
238
239    /// Print a two-line summary showing before → after for both RAM and VRAM.
240    pub fn print_before_after(&self, label: &str) {
241        println!(
242            "  {label}: RAM {:.0} MB → {:.0} MB ({:+.0} MB)",
243            self.before.ram_mb(),
244            self.after.ram_mb(),
245            self.ram_delta_mb(),
246        );
247        if let (Some(before), Some(after)) = (self.before.vram_mb(), self.after.vram_mb()) {
248            // CAST: u64 → f64, same justification as ram_mb
249            #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
250            let total = self.after.vram_total_bytes.map_or(String::new(), |t| {
251                format!(" / {:.0} MB", t as f64 / 1_048_576.0)
252            });
253            let qualifier = self.vram_qualifier();
254            let gpu = self
255                .after
256                .gpu_name
257                .as_deref()
258                .map_or(String::new(), |name| format!(" [{name}]"));
259            println!(
260                "  {label}: VRAM {before:.0} MB → {after:.0} MB ({:+.0} MB{total}){qualifier}{gpu}",
261                after - before,
262            );
263        }
264    }
265
266    /// Return a short qualifier string indicating VRAM measurement quality.
267    #[must_use]
268    const fn vram_qualifier(&self) -> &'static str {
269        match self.after.vram_per_process {
270            Some(true) => " [per-process]",
271            Some(false) => " [device-wide]",
272            None => "",
273        }
274    }
275}
276
277// ---------------------------------------------------------------------------
278// RAM measurement — per-process RSS
279// ---------------------------------------------------------------------------
280
/// Query the current process's resident set size (RSS) in bytes.
///
/// Pure compile-time dispatch: exactly one of the three `cfg` blocks below
/// exists in any given build.
///
/// # Platform
///
/// - **Windows**: `K32GetProcessMemoryInfo` → `WorkingSetSize` (exact, per-process).
/// - **Linux**: `/proc/self/status` → `VmRSS` (exact, per-process, no unsafe).
/// - **Other** (macOS, BSD, …): always an error — not implemented here.
///
/// # Errors
///
/// Returns [`MIError::Memory`] if the platform API call fails, or on
/// unsupported platforms.
fn process_rss() -> Result<u64> {
    #[cfg(target_os = "windows")]
    {
        windows_rss()
    }
    #[cfg(target_os = "linux")]
    {
        linux_rss()
    }
    #[cfg(not(any(target_os = "windows", target_os = "linux")))]
    {
        Err(MIError::Memory(
            "RAM measurement not supported on this platform".into(),
        ))
    }
}
307
308// -- Windows ----------------------------------------------------------------
309
/// Windows FFI types and functions for `K32GetProcessMemoryInfo`.
#[cfg(target_os = "windows")]
mod win_ffi {
    /// `PROCESS_MEMORY_COUNTERS` structure from the Windows API.
    ///
    /// Field order and `#[repr(C)]` must match the Win32 layout exactly —
    /// the kernel writes into this memory through a raw pointer.
    ///
    /// See: <https://learn.microsoft.com/en-us/windows/win32/api/psapi/ns-psapi-process_memory_counters>
    #[repr(C)]
    pub(super) struct ProcessMemoryCounters {
        /// Size of this structure in bytes.
        pub cb: u32,
        /// Number of page faults.
        pub page_fault_count: u32,
        /// Peak working set size in bytes.
        pub peak_working_set_size: usize,
        /// Current working set size in bytes (= RSS).
        pub working_set_size: usize,
        /// Peak paged pool usage in bytes.
        pub quota_peak_paged_pool_usage: usize,
        /// Current paged pool usage in bytes.
        pub quota_paged_pool_usage: usize,
        /// Peak non-paged pool usage in bytes.
        pub quota_peak_non_paged_pool_usage: usize,
        /// Current non-paged pool usage in bytes.
        pub quota_non_paged_pool_usage: usize,
        /// Current pagefile usage in bytes.
        pub pagefile_usage: usize,
        /// Peak pagefile usage in bytes.
        pub peak_pagefile_usage: usize,
    }

    // SAFETY: These are stable Windows API functions with well-defined ABI.
    // GetCurrentProcess always returns a valid pseudo-handle.
    // K32GetProcessMemoryInfo writes to caller-provided memory of known size.
    // (`unsafe extern` with a `safe fn` item is Rust 2024 edition syntax.)
    #[allow(unsafe_code)]
    unsafe extern "system" {
        /// Returns a pseudo-handle to the current process (always valid, never null).
        pub(super) safe fn GetCurrentProcess() -> isize;

        /// Retrieves memory usage information for the specified process.
        pub(super) unsafe fn K32GetProcessMemoryInfo(
            process: isize,
            ppsmem_counters: *mut ProcessMemoryCounters,
            cb: u32,
        ) -> i32;
    }
}
356
/// Query RSS on Windows via `K32GetProcessMemoryInfo`.
///
/// Reads `WorkingSetSize` (the Windows analog of RSS) out of the process
/// memory counters.
///
/// # Errors
///
/// Returns [`MIError::Memory`] if the Win32 call reports failure.
#[cfg(target_os = "windows")]
#[allow(unsafe_code)]
fn windows_rss() -> Result<u64> {
    // CAST: usize → u32, struct size is 80 bytes on x64 — fits in u32
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let cb = std::mem::size_of::<win_ffi::ProcessMemoryCounters>() as u32;

    // The API requires `cb` to hold the struct size; every other field is
    // an output and starts at zero.
    let mut counters = win_ffi::ProcessMemoryCounters {
        cb,
        page_fault_count: 0,
        peak_working_set_size: 0,
        working_set_size: 0,
        quota_peak_paged_pool_usage: 0,
        quota_paged_pool_usage: 0,
        quota_peak_non_paged_pool_usage: 0,
        quota_non_paged_pool_usage: 0,
        pagefile_usage: 0,
        peak_pagefile_usage: 0,
    };

    let handle = win_ffi::GetCurrentProcess();

    // SAFETY: K32GetProcessMemoryInfo writes into the stack-allocated
    // `counters` struct, which is correctly sized (cb field set to struct
    // size). The process handle from GetCurrentProcess is a pseudo-handle
    // that is always valid for the lifetime of the process.
    let ok = unsafe { win_ffi::K32GetProcessMemoryInfo(handle, &raw mut counters, cb) };

    if ok == 0 {
        return Err(MIError::Memory("K32GetProcessMemoryInfo failed".into()));
    }

    // CAST: usize → u64, working set size in bytes — always fits
    #[allow(clippy::as_conversions)]
    Ok(counters.working_set_size as u64)
}
395
396// -- Linux ------------------------------------------------------------------
397
398/// Query RSS on Linux via `/proc/self/status`.
399#[cfg(target_os = "linux")]
400fn linux_rss() -> Result<u64> {
401    let status = std::fs::read_to_string("/proc/self/status")
402        .map_err(|e| MIError::Memory(format!("failed to read /proc/self/status: {e}")))?;
403
404    for line in status.lines() {
405        if let Some(rest) = line.strip_prefix("VmRSS:") {
406            let kb_str = rest.trim().trim_end_matches(" kB").trim();
407            let kb: u64 = kb_str.parse().map_err(|e| {
408                MIError::Memory(format!("failed to parse VmRSS value '{kb_str}': {e}"))
409            })?;
410            return Ok(kb * 1024);
411        }
412    }
413
414    Err(MIError::Memory(
415        "VmRSS not found in /proc/self/status".into(),
416    ))
417}
418
419// ---------------------------------------------------------------------------
420// VRAM measurement — DXGI (Windows), NVML, nvidia-smi fallback
421// ---------------------------------------------------------------------------
422
/// Result of a GPU memory query: `(used_bytes, total_bytes, per_process, gpu_name)`.
type GpuMemoryResult = (Option<u64>, Option<u64>, Option<bool>, Option<String>);

/// NVML shared library path (stable across driver versions).
#[cfg(target_os = "linux")]
const NVML_LIB_PATH: &str = "libnvidia-ml.so.1";

/// NVML shared library path (stable across driver versions).
#[cfg(target_os = "windows")]
const NVML_LIB_PATH: &str = "nvml.dll";

/// NVML return code: success (`NVML_SUCCESS` in the C API).
const NVML_SUCCESS: u32 = 0;

/// NVML return code: buffer too small (`NVML_ERROR_INSUFFICIENT_SIZE`).
/// Treated as non-fatal: the fixed buffer may still contain our PID.
const NVML_ERROR_INSUFFICIENT_SIZE: u32 = 7;

/// Maximum number of processes to query from NVML in a single call.
/// 64 is generous — most machines have fewer than 10 GPU processes.
/// If the driver reports more than this, `nvml_query_process_vram`
/// gives up and the caller falls back to `nvidia-smi`.
const NVML_MAX_PROCESSES: usize = 64;

/// Per-process GPU memory info returned by NVML.
///
/// Matches the C struct `nvmlProcessInfo_v2_t` (24 bytes) used by both
/// `nvmlDeviceGetComputeRunningProcesses_v2` and `_v3` (the `_v3` suffix
/// is a function version, not a struct version).
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlProcessInfo__v2__t.html>
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlProcessInfo {
    /// Process ID.
    pid: u32,
    /// GPU memory used by this process in bytes.
    /// `u64::MAX` (`0xFFFF_FFFF_FFFF_FFFF`) means "not available".
    used_gpu_memory: u64,
    /// GPU instance ID (MIG). Unused in non-MIG mode.
    gpu_instance_id: u32,
    /// Compute instance ID (MIG). Unused in non-MIG mode.
    compute_instance_id: u32,
}

/// NVML memory info for a device.
///
/// Matches the C struct `nvmlMemory_t` from the NVML API.
/// See: <https://docs.nvidia.com/deploy/nvml-api/structnvmlMemory__t.html>
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlMemoryInfo {
    /// Total GPU memory in bytes.
    total: u64,
    /// Free GPU memory in bytes.
    free: u64,
    /// Used GPU memory in bytes.
    used: u64,
}

/// Opaque NVML device handle (`nvmlDevice_t` — never dereferenced here).
type NvmlDevice = *mut std::ffi::c_void;

/// Function signature: `nvmlInit_v2() -> nvmlReturn_t`.
type NvmlInitFn = unsafe extern "C" fn() -> u32;

/// Function signature: `nvmlShutdown() -> nvmlReturn_t`.
type NvmlShutdownFn = unsafe extern "C" fn() -> u32;

/// Function signature: `nvmlDeviceGetHandleByIndex_v2(index, *mut device) -> nvmlReturn_t`.
type NvmlDeviceGetHandleByIndexFn = unsafe extern "C" fn(u32, *mut NvmlDevice) -> u32;

/// Function signature: `nvmlDeviceGetMemoryInfo(device, *mut memory) -> nvmlReturn_t`.
type NvmlDeviceGetMemoryInfoFn = unsafe extern "C" fn(NvmlDevice, *mut NvmlMemoryInfo) -> u32;

/// Function signature:
/// `nvmlDeviceGetComputeRunningProcesses_v3(device, *mut count, *mut infos) -> nvmlReturn_t`.
type NvmlDeviceGetComputeRunningProcessesFn =
    unsafe extern "C" fn(NvmlDevice, *mut u32, *mut NvmlProcessInfo) -> u32;
498
499/// Query GPU memory — DXGI (Windows), NVML, or `nvidia-smi` fallback.
500///
501/// Returns `(used_bytes, total_bytes, per_process, gpu_name)`:
502/// - `per_process = Some(true)` when per-process query succeeded (DXGI or NVML).
503/// - `per_process = Some(false)` when falling back to `nvidia-smi` (device-wide).
504/// - `gpu_name` is set when DXGI provides the adapter description.
505/// - All `None` if all methods fail.
506fn gpu_memory_used() -> GpuMemoryResult {
507    // Windows: try DXGI first (per-process, works under WDDM)
508    #[cfg(windows)]
509    if let Some(result) = dxgi_query_process_vram() {
510        return result;
511    }
512
513    // Try NVML (per-process on Linux, NOT_AVAILABLE on Windows WDDM)
514    if let Some(result) = nvml_query_process_vram() {
515        let (used, total, per_process) = result;
516        return (used, total, per_process, None);
517    }
518
519    // Fallback to nvidia-smi (device-wide)
520    let (used, total) = nvidia_smi_query();
521    if used.is_some() {
522        (used, total, Some(false), None)
523    } else {
524        (None, None, None, None)
525    }
526}
527
/// Attempt to query per-process VRAM via NVML.
///
/// Flow: load `libnvidia-ml` → resolve symbols → `nvmlInit_v2` →
/// device 0 handle → total memory → running compute processes →
/// `nvmlShutdown` → find our own PID in the process list.
///
/// Returns `None` if NVML cannot be loaded or any API call fails,
/// signaling the caller to try the fallback path.
#[allow(unsafe_code)]
fn nvml_query_process_vram() -> Option<(Option<u64>, Option<u64>, Option<bool>)> {
    // SAFETY: libloading::Library::new dynamically loads a shared library.
    // The NVML library is a stable NVIDIA driver component with a well-defined
    // C ABI. We load it, call functions, and unload it within this scope.
    let lib = unsafe { libloading::Library::new(NVML_LIB_PATH) }.ok()?;

    // SAFETY: Loading function symbols from the NVML library. Each symbol
    // name matches the documented NVML C API exactly. The function signatures
    // (type aliases above) match the NVML header definitions.
    let init: libloading::Symbol<'_, NvmlInitFn> = unsafe { lib.get(b"nvmlInit_v2\0") }.ok()?;
    let shutdown: libloading::Symbol<'_, NvmlShutdownFn> =
        unsafe { lib.get(b"nvmlShutdown\0") }.ok()?;
    let get_handle: libloading::Symbol<'_, NvmlDeviceGetHandleByIndexFn> =
        unsafe { lib.get(b"nvmlDeviceGetHandleByIndex_v2\0") }.ok()?;
    let get_memory: libloading::Symbol<'_, NvmlDeviceGetMemoryInfoFn> =
        unsafe { lib.get(b"nvmlDeviceGetMemoryInfo\0") }.ok()?;
    let get_processes: libloading::Symbol<'_, NvmlDeviceGetComputeRunningProcessesFn> =
        unsafe { lib.get(b"nvmlDeviceGetComputeRunningProcesses_v3\0") }.ok()?;

    // Initialize NVML
    // SAFETY: nvmlInit_v2 is safe to call from any thread; it initializes
    // internal NVML state. Returns NVML_SUCCESS (0) on success.
    let ret = unsafe { init() };
    if ret != NVML_SUCCESS {
        return None;
    }

    // Get device handle for GPU 0 (primary GPU)
    // NOTE(review): multi-GPU machines always measure device 0 here —
    // confirm that is the device candle runs on.
    let mut device: NvmlDevice = std::ptr::null_mut();
    // SAFETY: nvmlDeviceGetHandleByIndex_v2 writes a valid opaque handle
    // into `device` when it returns NVML_SUCCESS. Index 0 = primary GPU.
    let ret = unsafe { get_handle(0, &raw mut device) };
    if ret != NVML_SUCCESS {
        // SAFETY: nvmlShutdown is always safe after a successful nvmlInit.
        unsafe { shutdown() };
        return None;
    }

    // Get total memory for the device (non-fatal if it fails — we can
    // still report per-process usage without a total).
    let mut mem_info = NvmlMemoryInfo {
        total: 0,
        free: 0,
        used: 0,
    };
    // SAFETY: nvmlDeviceGetMemoryInfo writes into the caller-provided
    // NvmlMemoryInfo struct. The device handle is valid (obtained above).
    let ret = unsafe { get_memory(device, &raw mut mem_info) };
    let total_bytes = if ret == NVML_SUCCESS {
        Some(mem_info.total)
    } else {
        None
    };

    // Get per-process memory
    // CAST: usize → u32, NVML_MAX_PROCESSES is 64 — fits in u32
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let mut count = NVML_MAX_PROCESSES as u32;
    let mut infos = [NvmlProcessInfo {
        pid: 0,
        used_gpu_memory: 0,
        gpu_instance_id: 0,
        compute_instance_id: 0,
    }; NVML_MAX_PROCESSES];

    // SAFETY: nvmlDeviceGetComputeRunningProcesses_v3 fills `infos` with
    // up to `count` entries and updates `count` to the actual number written.
    // The buffer is stack-allocated with NVML_MAX_PROCESSES slots, which is
    // sufficient for typical workloads.
    let ret = unsafe { get_processes(device, &raw mut count, infos.as_mut_ptr()) };

    // SAFETY: nvmlShutdown pairs with nvmlInit; always called before return.
    unsafe { shutdown() };

    if ret != NVML_SUCCESS && ret != NVML_ERROR_INSUFFICIENT_SIZE {
        return None;
    }

    // Find our process in the list
    let my_pid = std::process::id();
    // CAST: u32 → usize, count is a small process count — always fits
    #[allow(clippy::as_conversions)]
    let actual_count = count as usize;
    // On INSUFFICIENT_SIZE the driver sets `count` to the required size,
    // which may exceed the buffer — `get(..)` then yields None and the
    // `?` below triggers the nvidia-smi fallback instead of reading
    // uninitialized entries.
    let my_vram = infos
        .get(..actual_count)?
        .iter()
        .find(|info| info.pid == my_pid)
        .map(|info| info.used_gpu_memory);

    // NVML uses u64::MAX as "not available" sentinel — some drivers (e.g., R570
    // on RTX 5060 Ti) return this for all processes. Fall back to nvidia-smi.
    if my_vram == Some(u64::MAX) {
        return None;
    }

    // Sanity check: if per-process VRAM exceeds total device memory, the value
    // is likely garbage. Fall back to nvidia-smi.
    if let (Some(used), Some(total)) = (my_vram, total_bytes) {
        if used > total {
            return None;
        }
    }

    // Our PID may be absent from the list (e.g., no active CUDA context yet);
    // `None` here triggers the nvidia-smi fallback in the caller.
    my_vram.map(|used| (Some(used), total_bytes, Some(true)))
}
638
/// Query GPU memory via `nvidia-smi` (device-wide fallback).
///
/// Spawns `nvidia-smi --query-gpu=memory.used,memory.total` and parses the
/// first CSV line of its output.
///
/// Returns `(Some(used_bytes), Some(total_bytes))` on success,
/// or `(None, None)` if `nvidia-smi` is not available or fails.
fn nvidia_smi_query() -> (Option<u64>, Option<u64>) {
    /// Parse the first line of `nvidia-smi` CSV output
    /// (format: `"1234, 16384"` — used MiB, total MiB) into bytes.
    fn parse_first_line(stdout: &[u8]) -> Option<(u64, u64)> {
        // BORROW: explicit String::from_utf8_lossy — nvidia-smi output is ASCII
        let text = String::from_utf8_lossy(stdout);
        let line = text.lines().next()?.trim();
        let mut fields = line.split(',');
        let used_mib: u64 = fields.next()?.trim().parse().ok()?;
        let total_mib: u64 = fields.next()?.trim().parse().ok()?;
        // nvidia-smi reports in MiB — convert to bytes
        Some((used_mib * 1_048_576, total_mib * 1_048_576))
    }

    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output();

    match output {
        Ok(out) if out.status.success() => match parse_first_line(&out.stdout) {
            Some((used, total)) => (Some(used), Some(total)),
            None => (None, None),
        },
        // Binary missing, spawn failure, or non-zero exit — no GPU data.
        _ => (None, None),
    }
}
686
687// ---------------------------------------------------------------------------
688// DXGI per-process VRAM (Windows only)
689// ---------------------------------------------------------------------------
690
/// Query per-process GPU VRAM via DXGI (`IDXGIAdapter3::QueryVideoMemoryInfo`).
///
/// This is the only reliable way to get per-process GPU memory on Windows
/// (WDDM). NVML returns `NVML_VALUE_NOT_AVAILABLE` under WDDM because the
/// Windows kernel memory manager owns GPU memory, not the NVIDIA driver.
///
/// DXGI 1.4 (Windows 10+) provides `QueryVideoMemoryInfo` which returns:
/// - `CurrentUsage`: per-process GPU memory in bytes (exactly what we want).
/// - `Budget`: OS-assigned memory budget for this process.
///
/// We query `DXGI_MEMORY_SEGMENT_GROUP_LOCAL` (dedicated VRAM on discrete GPUs).
/// Total VRAM comes from `IDXGIAdapter::GetDesc` → `DedicatedVideoMemory`.
///
/// Returns `None` if DXGI is not available or the query fails,
/// signaling the caller to try NVML or nvidia-smi fallback.
#[cfg(windows)]
#[allow(unsafe_code)]
fn dxgi_query_process_vram() -> Option<GpuMemoryResult> {
    use windows::Win32::Graphics::Dxgi::{
        CreateDXGIFactory1, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, DXGI_QUERY_VIDEO_MEMORY_INFO,
        IDXGIAdapter, IDXGIAdapter3, IDXGIFactory1,
    };
    use windows::core::Interface;

    // SAFETY: CreateDXGIFactory1 is a well-documented COM factory function.
    // It initializes COM internally if needed. The returned IDXGIFactory1
    // is reference-counted and released automatically when dropped.
    let factory: IDXGIFactory1 = unsafe { CreateDXGIFactory1() }.ok()?;

    // Enumerate adapters — find the first one with dedicated VRAM > 0
    // (skip software/render-only adapters like Microsoft Basic Render Driver).
    // The loop terminates because EnumAdapters1 eventually returns
    // DXGI_ERROR_NOT_FOUND, which `ok()?` turns into `None`.
    let mut adapter_idx = 0u32;
    loop {
        // SAFETY: EnumAdapters1 returns S_OK with a valid adapter, or
        // DXGI_ERROR_NOT_FOUND when idx is out of range.
        let adapter: IDXGIAdapter = unsafe { factory.EnumAdapters1(adapter_idx) }
            .ok()?
            .cast()
            .ok()?;

        // SAFETY: GetDesc writes a valid DXGI_ADAPTER_DESC into the
        // caller-provided struct. The adapter handle is valid (obtained above).
        let desc = unsafe { adapter.GetDesc() }.ok()?;
        let dedicated_vram = desc.DedicatedVideoMemory;

        // Software adapters report zero dedicated VRAM — try the next one.
        if dedicated_vram == 0 {
            adapter_idx += 1;
            continue;
        }

        // Cast to IDXGIAdapter3 for QueryVideoMemoryInfo (DXGI 1.4).
        // Fails (→ None → fallback) on pre-Windows-10 DXGI runtimes.
        let adapter3: IDXGIAdapter3 = adapter.cast().ok()?;

        // SAFETY: QueryVideoMemoryInfo fills a DXGI_QUERY_VIDEO_MEMORY_INFO
        // struct with per-process memory stats. Node 0 = primary GPU node.
        // DXGI_MEMORY_SEGMENT_GROUP_LOCAL = dedicated VRAM on discrete GPUs.
        let mut mem_info = DXGI_QUERY_VIDEO_MEMORY_INFO::default();
        unsafe {
            adapter3.QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &raw mut mem_info)
        }
        .ok()?;

        // CAST: usize → u64, DedicatedVideoMemory is usize on Windows
        #[allow(clippy::as_conversions)]
        let total = dedicated_vram as u64;

        // Trim trailing null characters from the UTF-16 adapter description
        // BORROW: explicit from_utf16_lossy — DXGI Description is a fixed-size UTF-16 array
        let raw_name = String::from_utf16_lossy(&desc.Description);
        // BORROW: to_owned — trim returns a &str slice; we need an owned String
        let gpu_name = raw_name.trim_end_matches('\0').to_owned();

        #[cfg(feature = "memory-debug")]
        eprintln!(
            "[DXGI debug] adapter={gpu_name}, dedicated_vram={total}, \
             current_usage={}, budget={}",
            mem_info.CurrentUsage, mem_info.Budget,
        );

        return Some((
            Some(mem_info.CurrentUsage),
            Some(total),
            Some(true),
            Some(gpu_name),
        ));
    }
}
778
779// ---------------------------------------------------------------------------
780// Tests
781// ---------------------------------------------------------------------------
782
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a `MemorySnapshot` literal with no GPU name, so each test
    /// below only spells out the figures it actually cares about.
    /// (A macro rather than a helper fn: it stays agnostic to field types.)
    macro_rules! snap {
        ($ram:expr, $vram:expr, $total:expr, $per_proc:expr) => {
            MemorySnapshot {
                ram_bytes: $ram,
                vram_bytes: $vram,
                vram_total_bytes: $total,
                vram_per_process: $per_proc,
                gpu_name: None,
            }
        };
    }

    #[test]
    fn snapshot_cpu_has_ram() {
        let current = MemorySnapshot::now(&candle_core::Device::Cpu).unwrap();
        // A live process always occupies some resident memory.
        assert!(current.ram_bytes > 0, "RAM should be non-zero");
        // The CPU device carries no GPU memory information at all.
        assert!(current.vram_bytes.is_none(), "CPU should have no VRAM");
        assert!(
            current.vram_per_process.is_none(),
            "CPU should have no VRAM qualifier"
        );
    }

    #[test]
    fn report_delta_positive_for_allocation() {
        // Before: 100 MB RSS, 500 MB VRAM. After: 200 MB RSS, 1000 MB VRAM.
        // Expected deltas: +100 MB RAM, +500 MB VRAM.
        let report = MemoryReport::new(
            snap!(
                100 * 1_048_576,
                Some(500 * 1_048_576),
                Some(16_384 * 1_048_576),
                Some(true)
            ),
            snap!(
                200 * 1_048_576,
                Some(1_000 * 1_048_576),
                Some(16_384 * 1_048_576),
                Some(true)
            ),
        );

        let ram_delta = report.ram_delta_mb();
        assert!(
            (ram_delta - 100.0).abs() < 0.01,
            "RAM delta should be ~100 MB, got {ram_delta}"
        );

        let vram_delta = report.vram_delta_mb().unwrap();
        assert!(
            (vram_delta - 500.0).abs() < 0.01,
            "VRAM delta should be ~500 MB, got {vram_delta}"
        );
    }

    #[test]
    fn report_delta_none_when_no_vram() {
        // Without VRAM figures on either side, the delta must be absent
        // rather than defaulting to zero.
        let report = MemoryReport::new(
            snap!(100, None, None, None),
            snap!(200, None, None, None),
        );
        assert!(report.vram_delta_mb().is_none());
    }

    #[test]
    fn ram_mb_conversion() {
        // Exactly one mebibyte of RSS must convert to 1.0 MB.
        let one_mb = snap!(1_048_576, None, None, None);
        assert!((one_mb.ram_mb() - 1.0).abs() < 0.001);
    }

    #[test]
    fn vram_qualifier_per_process() {
        // `vram_per_process == Some(true)` labels the report as per-process.
        let before = snap!(100, Some(500), Some(1000), Some(true));
        let report = MemoryReport::new(before.clone(), before);
        assert_eq!(report.vram_qualifier(), " [per-process]");
    }

    #[test]
    fn vram_qualifier_device_wide() {
        // `vram_per_process == Some(false)` labels the report as device-wide.
        let before = snap!(100, Some(500), Some(1000), Some(false));
        let report = MemoryReport::new(before.clone(), before);
        assert_eq!(report.vram_qualifier(), " [device-wide]");
    }
}