use crate::{MIError, Result};
/// Point-in-time view of this process's memory usage (RAM always, VRAM when probed).
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
    /// Resident-set size (RSS) of the current process, in bytes.
    pub ram_bytes: u64,
    /// GPU memory usage in bytes; `None` when no reading was available.
    pub vram_bytes: Option<u64>,
    /// Total GPU memory on the device, in bytes, when known.
    pub vram_total_bytes: Option<u64>,
    /// `Some(true)` when `vram_bytes` counts only this process,
    /// `Some(false)` when it is a device-wide figure, `None` when unknown.
    pub vram_per_process: Option<bool>,
    /// Adapter name, when reported (currently only the DXGI probe sets this).
    pub gpu_name: Option<String>,
}
/// Before/after pair of [`MemorySnapshot`]s used to report memory deltas.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// Snapshot taken before the measured operation.
    pub before: MemorySnapshot,
    /// Snapshot taken after the measured operation.
    pub after: MemorySnapshot,
}
impl MemorySnapshot {
    /// Capture the current process RSS, plus VRAM usage when `device` is CUDA.
    ///
    /// # Errors
    /// Propagates the error when the platform RSS probe fails.
    pub fn now(device: &candle_core::Device) -> Result<Self> {
        let ram_bytes = process_rss()?;
        // GPU probing is only attempted for CUDA devices; everything else
        // reports "no VRAM reading".
        let (vram_bytes, vram_total_bytes, vram_per_process, gpu_name) = if device.is_cuda() {
            gpu_memory_used()
        } else {
            (None, None, None, None)
        };
        Ok(Self {
            ram_bytes,
            vram_bytes,
            vram_total_bytes,
            vram_per_process,
            gpu_name,
        })
    }

    /// Process RSS in MiB.
    #[must_use]
    pub fn ram_mb(&self) -> f64 {
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let megabytes = self.ram_bytes as f64 / 1_048_576.0;
        megabytes
    }

    /// VRAM usage in MiB, when a reading exists.
    #[must_use]
    pub fn vram_mb(&self) -> Option<f64> {
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        self.vram_bytes.map(|bytes| bytes as f64 / 1_048_576.0)
    }
}
/// Synchronize the CUDA stream and ask the driver to release cached pool memory.
///
/// Best-effort: every step ignores failures. On non-CUDA builds, or for
/// non-CUDA devices, this is a no-op.
pub fn sync_and_trim_gpu(device: &candle_core::Device) {
    #[cfg(feature = "cuda")]
    if let candle_core::Device::Cuda(cuda_dev) = device {
        use candle_core::backend::BackendDevice;
        // Wait for pending kernels so the pool's accounting is settled before trimming.
        let _ = cuda_dev.synchronize();
        #[allow(unsafe_code)]
        {
            use candle_core::cuda_backend::cudarc::driver::sys;
            let stream = cuda_dev.cuda_stream();
            // `null::<u8>()` presumably creates a zero-sized slice — used here only
            // as a handle to reach the context and CUdevice id. TODO confirm against
            // the cudarc API docs.
            if let Ok(probe) = stream.null::<u8>() {
                let ctx = probe.context();
                let cu_device = ctx.cu_device();
                // SAFETY: `pool` is a plain driver handle that is zeroable;
                // cuDeviceGetDefaultMemPool fully initializes it on success, and it
                // is only used when the call reports CUDA_SUCCESS.
                unsafe {
                    let mut pool = std::mem::zeroed();
                    let rc = sys::cuDeviceGetDefaultMemPool(&raw mut pool, cu_device);
                    if rc == sys::CUresult::CUDA_SUCCESS {
                        // Trim to 0 bytes: return all unused cached memory to the driver.
                        let _ = sys::cuMemPoolTrimTo(pool, 0);
                    }
                }
            }
        }
    }
    // Keep the parameter "used" so non-CUDA builds don't warn.
    #[cfg(not(feature = "cuda"))]
    let _ = device;
}
impl MemoryReport {
    /// Pair the snapshots taken before and after an operation.
    #[must_use]
    pub const fn new(before: MemorySnapshot, after: MemorySnapshot) -> Self {
        Self { before, after }
    }

    /// RAM delta in MiB (positive means usage grew).
    #[must_use]
    pub fn ram_delta_mb(&self) -> f64 {
        self.after.ram_mb() - self.before.ram_mb()
    }

    /// VRAM delta in MiB; `None` unless both snapshots carry a VRAM reading.
    #[must_use]
    pub fn vram_delta_mb(&self) -> Option<f64> {
        self.after
            .vram_mb()
            .zip(self.before.vram_mb())
            .map(|(after, before)| after - before)
    }

    /// Print a compact one-line delta: RAM always, VRAM when available.
    pub fn print_delta(&self, label: &str) {
        let ram = self.ram_delta_mb();
        print!(" {label}: RAM {ram:+.0} MB");
        if let Some(vram) = self.vram_delta_mb() {
            let qualifier = self.vram_qualifier();
            print!(" | VRAM {vram:+.0} MB{qualifier}");
        }
        println!();
    }

    /// Print before → after lines for RAM and (when both snapshots have a
    /// reading) VRAM, including the device total, measurement scope, and GPU
    /// name when known.
    pub fn print_before_after(&self, label: &str) {
        println!(
            " {label}: RAM {:.0} MB → {:.0} MB ({:+.0} MB)",
            self.before.ram_mb(),
            self.after.ram_mb(),
            self.ram_delta_mb(),
        );
        // VRAM line only when both ends of the interval were measured.
        let (Some(before), Some(after)) = (self.before.vram_mb(), self.after.vram_mb()) else {
            return;
        };
        #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
        let total = match self.after.vram_total_bytes {
            Some(t) => format!(" / {:.0} MB", t as f64 / 1_048_576.0),
            None => String::new(),
        };
        let gpu = match self.after.gpu_name.as_deref() {
            Some(name) => format!(" [{name}]"),
            None => String::new(),
        };
        let qualifier = self.vram_qualifier();
        println!(
            " {label}: VRAM {before:.0} MB → {after:.0} MB ({:+.0} MB{total}){qualifier}{gpu}",
            after - before,
        );
    }

    /// Suffix describing what the VRAM figure measures, taken from `after`.
    #[must_use]
    const fn vram_qualifier(&self) -> &'static str {
        match self.after.vram_per_process {
            None => "",
            Some(true) => " [per-process]",
            Some(false) => " [device-wide]",
        }
    }
}
/// Resident-set size of the current process, in bytes.
///
/// Dispatches to the platform-specific probe at compile time.
///
/// # Errors
/// Returns [`MIError::Memory`] when the probe fails, or unconditionally on
/// platforms other than Windows and Linux.
fn process_rss() -> Result<u64> {
    #[cfg(target_os = "windows")]
    {
        windows_rss()
    }
    #[cfg(target_os = "linux")]
    {
        linux_rss()
    }
    #[cfg(not(any(target_os = "windows", target_os = "linux")))]
    {
        Err(MIError::Memory(
            "RAM measurement not supported on this platform".into(),
        ))
    }
}
#[cfg(target_os = "windows")]
mod win_ffi {
    //! Minimal hand-written Win32 bindings for querying this process's memory,
    //! avoiding a dependency on full psapi bindings.

    /// Mirror of the Win32 `PROCESS_MEMORY_COUNTERS` struct — field order and
    /// types must match the C layout exactly. TODO confirm against psapi.h.
    #[repr(C)]
    pub(super) struct ProcessMemoryCounters {
        /// Size of this struct in bytes; the API requires it set before the call.
        pub cb: u32,
        pub page_fault_count: u32,
        pub peak_working_set_size: usize,
        /// Current resident set ("working set") in bytes — the value we report.
        pub working_set_size: usize,
        pub quota_peak_paged_pool_usage: usize,
        pub quota_paged_pool_usage: usize,
        pub quota_peak_non_paged_pool_usage: usize,
        pub quota_non_paged_pool_usage: usize,
        pub pagefile_usage: usize,
        pub peak_pagefile_usage: usize,
    }

    // Rust 2024 `unsafe extern` block. `GetCurrentProcess` is declared `safe`
    // because it only returns a pseudo-handle; the memory-info call stays
    // `unsafe` since it writes through a raw pointer.
    #[allow(unsafe_code)]
    unsafe extern "system" {
        pub(super) safe fn GetCurrentProcess() -> isize;
        pub(super) unsafe fn K32GetProcessMemoryInfo(
            process: isize,
            ppsmem_counters: *mut ProcessMemoryCounters,
            cb: u32,
        ) -> i32;
    }
}
/// Current process RSS (working-set size) in bytes on Windows, via
/// `K32GetProcessMemoryInfo`.
///
/// # Errors
/// Returns [`MIError::Memory`] when the Win32 call reports failure.
#[cfg(target_os = "windows")]
#[allow(unsafe_code)]
fn windows_rss() -> Result<u64> {
    // Zero-initialized counters; `cb` must carry the struct size before the call.
    let mut counters = win_ffi::ProcessMemoryCounters {
        cb: 0,
        page_fault_count: 0,
        peak_working_set_size: 0,
        working_set_size: 0,
        quota_peak_paged_pool_usage: 0,
        quota_paged_pool_usage: 0,
        quota_peak_non_paged_pool_usage: 0,
        quota_non_paged_pool_usage: 0,
        pagefile_usage: 0,
        peak_pagefile_usage: 0,
    };
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let cb = std::mem::size_of::<win_ffi::ProcessMemoryCounters>() as u32;
    counters.cb = cb;
    // Pseudo-handle to the current process; does not need to be closed.
    let handle = win_ffi::GetCurrentProcess();
    // SAFETY: `counters` is a valid, writable struct of exactly `cb` bytes and
    // lives for the duration of the call.
    let ok = unsafe { win_ffi::K32GetProcessMemoryInfo(handle, &raw mut counters, cb) };
    if ok != 0 {
        #[allow(clippy::as_conversions)]
        let rss = counters.working_set_size as u64;
        Ok(rss)
    } else {
        Err(MIError::Memory("K32GetProcessMemoryInfo failed".into()))
    }
}
/// Current process RSS in bytes on Linux, read from `/proc/self/status`.
///
/// # Errors
/// Returns [`MIError::Memory`] when the status file cannot be read, the
/// `VmRSS` line is missing, or its value does not parse.
#[cfg(target_os = "linux")]
fn linux_rss() -> Result<u64> {
    let status = std::fs::read_to_string("/proc/self/status")
        .map_err(|e| MIError::Memory(format!("failed to read /proc/self/status: {e}")))?;
    // The kernel reports the line as e.g. "VmRSS:     1234 kB".
    let rss_line = status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"));
    match rss_line {
        Some(rest) => {
            let kb_str = rest.trim().trim_end_matches(" kB").trim();
            let kb = kb_str.parse::<u64>().map_err(|e| {
                MIError::Memory(format!("failed to parse VmRSS value '{kb_str}': {e}"))
            })?;
            // /proc reports kB (really KiB); convert to bytes.
            Ok(kb * 1024)
        }
        None => Err(MIError::Memory(
            "VmRSS not found in /proc/self/status".into(),
        )),
    }
}
/// `(used_bytes, total_bytes, per_process, gpu_name)` tuple shared by the GPU probes.
type GpuMemoryResult = (Option<u64>, Option<u64>, Option<bool>, Option<String>);
// Runtime-loaded NVML library name per platform.
#[cfg(target_os = "linux")]
const NVML_LIB_PATH: &str = "libnvidia-ml.so.1";
#[cfg(target_os = "windows")]
const NVML_LIB_PATH: &str = "nvml.dll";
// NVML return codes we handle (values from nvml.h: NVML_SUCCESS,
// NVML_ERROR_INSUFFICIENT_SIZE).
const NVML_SUCCESS: u32 = 0;
const NVML_ERROR_INSUFFICIENT_SIZE: u32 = 7;
// Fixed capacity of the compute-process buffer; more running processes than
// this makes the NVML probe give up (see `nvml_query_process_vram`).
const NVML_MAX_PROCESSES: usize = 64;
/// Mirror of NVML's process-info struct (presumably `nvmlProcessInfo_t` as used
/// by the `_v3` process-list call — TODO confirm field layout against nvml.h).
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlProcessInfo {
    pid: u32,
    /// Bytes used by `pid`; all-ones sentinel when unavailable (checked by caller).
    used_gpu_memory: u64,
    gpu_instance_id: u32,
    compute_instance_id: u32,
}
/// Mirror of NVML's `nvmlMemory_t`: device memory totals in bytes.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct NvmlMemoryInfo {
    total: u64,
    free: u64,
    used: u64,
}
/// Opaque NVML device handle.
type NvmlDevice = *mut std::ffi::c_void;
// Function-pointer types for the NVML symbols resolved at runtime.
type NvmlInitFn = unsafe extern "C" fn() -> u32;
type NvmlShutdownFn = unsafe extern "C" fn() -> u32;
type NvmlDeviceGetHandleByIndexFn = unsafe extern "C" fn(u32, *mut NvmlDevice) -> u32;
type NvmlDeviceGetMemoryInfoFn = unsafe extern "C" fn(NvmlDevice, *mut NvmlMemoryInfo) -> u32;
type NvmlDeviceGetComputeRunningProcessesFn =
    unsafe extern "C" fn(NvmlDevice, *mut u32, *mut NvmlProcessInfo) -> u32;
/// Best-effort GPU memory probe: DXGI (Windows) → NVML → `nvidia-smi`.
///
/// Returns `(used_bytes, total_bytes, per_process, gpu_name)`; all `None`
/// when no backend produced a reading.
fn gpu_memory_used() -> GpuMemoryResult {
    // Windows first: DXGI reports true per-process usage plus the adapter name.
    #[cfg(windows)]
    if let Some(result) = dxgi_query_process_vram() {
        return result;
    }
    // NVML gives a per-process figure when this PID appears in the device's
    // compute-process list.
    if let Some((used, total, per_process)) = nvml_query_process_vram() {
        return (used, total, per_process, None);
    }
    // Last resort: device-wide numbers parsed from the nvidia-smi CLI.
    match nvidia_smi_query() {
        (Some(used), total) => (Some(used), total, Some(false), None),
        _ => (None, None, None, None),
    }
}
/// Per-process VRAM usage via NVML, loaded dynamically at runtime.
///
/// Queries GPU index 0 only. Returns `None` when NVML (or any call in the
/// sequence) is unavailable or fails, or when this PID is absent from the
/// device's compute-process list — callers then fall back to a device-wide
/// probe. On success the tuple is `(used_bytes, total_bytes, Some(true))`,
/// where `Some(true)` flags the reading as per-process.
#[allow(unsafe_code)]
fn nvml_query_process_vram() -> Option<(Option<u64>, Option<u64>, Option<bool>)> {
    // SAFETY: loading the shared library runs its initializers; a missing
    // NVIDIA driver stack simply yields Err -> None here.
    let lib = unsafe { libloading::Library::new(NVML_LIB_PATH) }.ok()?;
    // SAFETY (each `get`): symbol names are NVML entry points and the typedefs
    // above are written to mirror their C signatures — TODO confirm vs nvml.h.
    let init: libloading::Symbol<'_, NvmlInitFn> = unsafe { lib.get(b"nvmlInit_v2\0") }.ok()?;
    let shutdown: libloading::Symbol<'_, NvmlShutdownFn> =
        unsafe { lib.get(b"nvmlShutdown\0") }.ok()?;
    let get_handle: libloading::Symbol<'_, NvmlDeviceGetHandleByIndexFn> =
        unsafe { lib.get(b"nvmlDeviceGetHandleByIndex_v2\0") }.ok()?;
    let get_memory: libloading::Symbol<'_, NvmlDeviceGetMemoryInfoFn> =
        unsafe { lib.get(b"nvmlDeviceGetMemoryInfo\0") }.ok()?;
    let get_processes: libloading::Symbol<'_, NvmlDeviceGetComputeRunningProcessesFn> =
        unsafe { lib.get(b"nvmlDeviceGetComputeRunningProcesses_v3\0") }.ok()?;
    let ret = unsafe { init() };
    if ret != NVML_SUCCESS {
        return None;
    }
    // From here on, every early exit calls `shutdown()` to balance the init.
    let mut device: NvmlDevice = std::ptr::null_mut();
    let ret = unsafe { get_handle(0, &raw mut device) };
    if ret != NVML_SUCCESS {
        unsafe { shutdown() };
        return None;
    }
    let mut mem_info = NvmlMemoryInfo {
        total: 0,
        free: 0,
        used: 0,
    };
    // Device total is optional: failure here still lets the per-process
    // lookup below succeed (total_bytes just stays None).
    let ret = unsafe { get_memory(device, &raw mut mem_info) };
    let total_bytes = if ret == NVML_SUCCESS {
        Some(mem_info.total)
    } else {
        None
    };
    #[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
    let mut count = NVML_MAX_PROCESSES as u32;
    let mut infos = [NvmlProcessInfo {
        pid: 0,
        used_gpu_memory: 0,
        gpu_instance_id: 0,
        compute_instance_id: 0,
    }; NVML_MAX_PROCESSES];
    // NVML fills `infos` and rewrites `count` with the entry count (or the
    // required size on INSUFFICIENT_SIZE).
    let ret = unsafe { get_processes(device, &raw mut count, infos.as_mut_ptr()) };
    unsafe { shutdown() };
    if ret != NVML_SUCCESS && ret != NVML_ERROR_INSUFFICIENT_SIZE {
        return None;
    }
    let my_pid = std::process::id();
    #[allow(clippy::as_conversions)]
    let actual_count = count as usize;
    // If `count` now exceeds the fixed buffer (possible after
    // INSUFFICIENT_SIZE), `get(..actual_count)` is None and the probe gives
    // up via `?` rather than scanning partially-filled data.
    let my_vram = infos
        .get(..actual_count)?
        .iter()
        .find(|info| info.pid == my_pid)
        .map(|info| info.used_gpu_memory);
    // All-ones looks like NVML's "value not available" sentinel (e.g. WDDM) —
    // TODO confirm; treat it as no reading.
    if my_vram == Some(u64::MAX) {
        return None;
    }
    // Sanity check: a per-process figure larger than the whole device is bogus.
    if let (Some(used), Some(total)) = (my_vram, total_bytes)
        && used > total
    {
        return None;
    }
    // PID not found => None here, which callers treat as "no per-process reading".
    my_vram.map(|used| (Some(used), total_bytes, Some(true)))
}
/// Device-wide GPU memory numbers from the `nvidia-smi` CLI.
///
/// Returns `(used_bytes, total_bytes)` for the first GPU listed, or
/// `(None, None)` when the tool is missing, exits non-zero, or its output
/// cannot be parsed.
fn nvidia_smi_query() -> (Option<u64>, Option<u64>) {
    // Parse one CSV line of "used, total" megabyte values into bytes.
    fn parse(line: &str) -> Option<(u64, u64)> {
        let mut fields = line.trim().split(',');
        let used_mb: u64 = fields.next()?.trim().parse().ok()?;
        let total_mb: u64 = fields.next()?.trim().parse().ok()?;
        Some((used_mb * 1_048_576, total_mb * 1_048_576))
    }
    let output = std::process::Command::new("nvidia-smi")
        .args([
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output();
    let Ok(output) = output else {
        return (None, None);
    };
    if !output.status.success() {
        return (None, None);
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    match stdout.lines().next().and_then(parse) {
        Some((used, total)) => (Some(used), Some(total)),
        None => (None, None),
    }
}
/// Per-process VRAM usage on Windows via DXGI video-memory budgeting.
///
/// Enumerates adapters and, for the first one exposing dedicated VRAM,
/// returns `(Some(current_usage), Some(total_dedicated), Some(true),
/// Some(name))`. `Some(true)` marks the reading as per-process — DXGI memory
/// budgets are tracked per process (TODO confirm against the DXGI docs).
/// Returns `None` when enumeration runs out of adapters or any call fails.
#[cfg(windows)]
#[allow(unsafe_code)]
fn dxgi_query_process_vram() -> Option<GpuMemoryResult> {
    use windows::Win32::Graphics::Dxgi::{
        CreateDXGIFactory1, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, DXGI_QUERY_VIDEO_MEMORY_INFO,
        IDXGIAdapter, IDXGIAdapter3, IDXGIFactory1,
    };
    use windows::core::Interface;
    let factory: IDXGIFactory1 = unsafe { CreateDXGIFactory1() }.ok()?;
    let mut adapter_idx = 0u32;
    loop {
        // EnumAdapters1 errors past the last adapter (DXGI_ERROR_NOT_FOUND),
        // which exits the loop via `?`. The `.cast()` converts the enumerated
        // interface to the IDXGIAdapter we call GetDesc on.
        let adapter: IDXGIAdapter = unsafe { factory.EnumAdapters1(adapter_idx) }
            .ok()?
            .cast()
            .ok()?;
        let desc = unsafe { adapter.GetDesc() }.ok()?;
        let dedicated_vram = desc.DedicatedVideoMemory;
        // Skip adapters with no dedicated VRAM (software/basic render devices).
        if dedicated_vram == 0 {
            adapter_idx += 1;
            continue;
        }
        // IDXGIAdapter3 is required for QueryVideoMemoryInfo.
        let adapter3: IDXGIAdapter3 = adapter.cast().ok()?;
        let mut mem_info = DXGI_QUERY_VIDEO_MEMORY_INFO::default();
        // Node 0, local segment group = on-board VRAM.
        unsafe {
            adapter3.QueryVideoMemoryInfo(0, DXGI_MEMORY_SEGMENT_GROUP_LOCAL, &raw mut mem_info)
        }
        .ok()?;
        #[allow(clippy::as_conversions)]
        let total = dedicated_vram as u64;
        // Description is a fixed-size UTF-16 buffer padded with NULs.
        let raw_name = String::from_utf16_lossy(&desc.Description);
        let gpu_name = raw_name.trim_end_matches('\0').to_owned();
        #[cfg(feature = "memory-debug")]
        eprintln!(
            "[DXGI debug] adapter={gpu_name}, dedicated_vram={total}, \
            current_usage={}, budget={}",
            mem_info.CurrentUsage, mem_info.Budget,
        );
        return Some((
            Some(mem_info.CurrentUsage),
            Some(total),
            Some(true),
            Some(gpu_name),
        ));
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    const MIB: u64 = 1_048_576;

    // Shorthand for building snapshots in tests; `gpu_name` is always None.
    fn snapshot(
        ram_bytes: u64,
        vram_bytes: Option<u64>,
        vram_total_bytes: Option<u64>,
        vram_per_process: Option<bool>,
    ) -> MemorySnapshot {
        MemorySnapshot {
            ram_bytes,
            vram_bytes,
            vram_total_bytes,
            vram_per_process,
            gpu_name: None,
        }
    }

    #[test]
    fn snapshot_cpu_has_ram() {
        let snap = MemorySnapshot::now(&candle_core::Device::Cpu).unwrap();
        assert!(snap.ram_bytes > 0, "RAM should be non-zero");
        assert!(snap.vram_bytes.is_none(), "CPU should have no VRAM");
        assert!(
            snap.vram_per_process.is_none(),
            "CPU should have no VRAM qualifier"
        );
    }

    #[test]
    fn report_delta_positive_for_allocation() {
        let before = snapshot(100 * MIB, Some(500 * MIB), Some(16_384 * MIB), Some(true));
        let after = snapshot(200 * MIB, Some(1_000 * MIB), Some(16_384 * MIB), Some(true));
        let report = MemoryReport::new(before, after);
        let ram_delta = report.ram_delta_mb();
        assert!(
            (ram_delta - 100.0).abs() < 0.01,
            "RAM delta should be ~100 MB, got {ram_delta}"
        );
        let vram_delta = report.vram_delta_mb().unwrap();
        assert!(
            (vram_delta - 500.0).abs() < 0.01,
            "VRAM delta should be ~500 MB, got {vram_delta}"
        );
    }

    #[test]
    fn report_delta_none_when_no_vram() {
        let before = snapshot(100, None, None, None);
        let after = snapshot(200, None, None, None);
        assert!(MemoryReport::new(before, after).vram_delta_mb().is_none());
    }

    #[test]
    fn ram_mb_conversion() {
        let snap = snapshot(MIB, None, None, None);
        assert!((snap.ram_mb() - 1.0).abs() < 0.001);
    }

    #[test]
    fn vram_qualifier_per_process() {
        let snap = snapshot(100, Some(500), Some(1000), Some(true));
        let report = MemoryReport::new(snap.clone(), snap);
        assert_eq!(report.vram_qualifier(), " [per-process]");
    }

    #[test]
    fn vram_qualifier_device_wide() {
        let snap = snapshot(100, Some(500), Some(1000), Some(false));
        let report = MemoryReport::new(snap.clone(), snap);
        assert_eq!(report.vram_qualifier(), " [device-wide]");
    }
}