use oxicuda_driver::device::Device;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::stream::Stream;
/// Snapshot of device memory as returned by [`memory_info`], in bytes.
///
/// Field order is significant for the derived `Debug` output (`free` first, then `total`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MemoryInfo {
    /// Bytes reported as free.
    pub free: usize,
    /// Total bytes reported.
    pub total: usize,
}

impl MemoryInfo {
    /// Bytes in use, computed as `total - free` and clamped at zero.
    ///
    /// The clamp means an inconsistent snapshot (`free > total`) yields `0`
    /// rather than an overflow panic in debug builds.
    #[inline]
    pub fn used(&self) -> usize {
        let Self { free, total } = *self;
        total.saturating_sub(free)
    }

    /// Fraction of `total` that is in use, in the range `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when `total` is zero to avoid dividing by zero.
    #[inline]
    pub fn usage_fraction(&self) -> f64 {
        match self.total {
            0 => 0.0,
            // `used()` never exceeds `total`, so the ratio stays within [0, 1].
            total => self.used() as f64 / total as f64,
        }
    }
}

impl std::fmt::Display for MemoryInfo {
    /// Formats as `MemoryInfo(free=<MB> MB, total=<MB> MB, used=<pct>%)`.
    ///
    /// MB values are whole mebibytes (integer division truncates); the usage
    /// percentage is shown with one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const BYTES_PER_MB: usize = 1024 * 1024;
        let free_mb = self.free / BYTES_PER_MB;
        let total_mb = self.total / BYTES_PER_MB;
        let used_pct = 100.0 * self.usage_fraction();
        write!(
            f,
            "MemoryInfo(free={} MB, total={} MB, used={:.1}%)",
            free_mb, total_mb, used_pct,
        )
    }
}
/// Queries free and total device memory through the driver's
/// `cu_mem_get_info_v2` entry point.
///
/// NOTE(review): under CUDA driver semantics this reports memory for the
/// device of the calling thread's current context — confirm callers have
/// made a context current before calling.
///
/// # Errors
///
/// Returns the error from [`try_driver`] if the driver is unavailable, or the
/// error produced by `oxicuda_driver::check` if the call reports a failure status.
pub fn memory_info() -> CudaResult<MemoryInfo> {
    let driver = try_driver()?;
    let (mut free, mut total) = (0usize, 0usize);
    // SAFETY: both out-parameters point at live, writable `usize` locals that
    // outlive the call; no other pointers are handed to the driver.
    let status = unsafe { (driver.cu_mem_get_info_v2)(&mut free, &mut total) };
    oxicuda_driver::check(status)?;
    Ok(MemoryInfo { free, total })
}
/// Advice hint for a managed-memory range, passed to [`mem_advise`].
///
/// Each discriminant is forwarded verbatim (as `u32`) to the driver's
/// `cu_mem_advise`, so the values must match CUDA's `CUmem_advise`
/// enumeration (`CU_MEM_ADVISE_SET_READ_MOSTLY = 1` through
/// `CU_MEM_ADVISE_UNSET_ACCESSED_BY = 6`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(u32)]
pub enum MemAdvice {
    /// Raw value 1 (`CU_MEM_ADVISE_SET_READ_MOSTLY`).
    SetReadMostly = 1,
    /// Raw value 2 (`CU_MEM_ADVISE_UNSET_READ_MOSTLY`).
    UnsetReadMostly = 2,
    /// Raw value 3 (`CU_MEM_ADVISE_SET_PREFERRED_LOCATION`).
    SetPreferredLocation = 3,
    /// Raw value 4 (`CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION`).
    UnsetPreferredLocation = 4,
    /// Raw value 5 (`CU_MEM_ADVISE_SET_ACCESSED_BY`).
    SetAccessedBy = 5,
    /// Raw value 6 (`CU_MEM_ADVISE_UNSET_ACCESSED_BY`).
    UnsetAccessedBy = 6,
}
/// Applies a [`MemAdvice`] hint to the `count`-byte range starting at `ptr`
/// on behalf of `device`, via the driver's `cu_mem_advise`.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `count` is zero. This is checked before the
///   driver is loaded.
/// * The error from [`try_driver`] if the driver is unavailable.
/// * The error produced by `oxicuda_driver::check` for a failing driver status.
pub fn mem_advise(ptr: u64, count: usize, advice: MemAdvice, device: &Device) -> CudaResult<()> {
    if count > 0 {
        let driver = try_driver()?;
        let raw_advice = advice as u32;
        // SAFETY: every argument is passed by value; no host memory is lent to
        // the driver. NOTE(review): relies on the driver validating `ptr`/`count`
        // and reporting bad ranges via its status code — confirm.
        let status = unsafe { (driver.cu_mem_advise)(ptr, count, raw_advice, device.raw()) };
        oxicuda_driver::check(status)
    } else {
        // An empty range is rejected locally rather than forwarded.
        Err(CudaError::InvalidValue)
    }
}
/// Enqueues a prefetch of the `count`-byte range starting at `ptr` to
/// `device` on `stream`, via the driver's `cu_mem_prefetch_async`.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `count` is zero. This is checked before the
///   driver is loaded.
/// * The error from [`try_driver`] if the driver is unavailable.
/// * The error produced by `oxicuda_driver::check` for a failing driver status.
pub fn mem_prefetch(ptr: u64, count: usize, device: &Device, stream: &Stream) -> CudaResult<()> {
    match count {
        0 => Err(CudaError::InvalidValue),
        byte_count => {
            let driver = try_driver()?;
            let (target, queue) = (device.raw(), stream.raw());
            // SAFETY: every argument is passed by value; no host memory is lent
            // to the driver. NOTE(review): assumes `Device::raw`/`Stream::raw`
            // return handles that remain valid for the call — confirm.
            let status =
                unsafe { (driver.cu_mem_prefetch_async)(ptr, byte_count, target, queue) };
            oxicuda_driver::check(status)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_info_used_calculation() {
        let info = MemoryInfo {
            free: 4096,
            total: 8192,
        };
        assert_eq!(info.used(), 4096);
    }

    #[test]
    fn memory_info_used_saturates_when_free_exceeds_total() {
        // An inconsistent snapshot must clamp to zero instead of underflowing.
        let info = MemoryInfo { free: 10, total: 5 };
        assert_eq!(info.used(), 0);
        assert!(info.usage_fraction().abs() < 1e-10);
    }

    #[test]
    fn memory_info_usage_fraction() {
        let info = MemoryInfo {
            free: 2048,
            total: 8192,
        };
        let frac = info.usage_fraction();
        assert!((frac - 0.75).abs() < 1e-10);
    }

    #[test]
    fn memory_info_usage_fraction_zero_total() {
        let info = MemoryInfo { free: 0, total: 0 };
        assert!((info.usage_fraction()).abs() < 1e-10);
    }

    #[test]
    fn memory_info_display() {
        let info = MemoryInfo {
            free: 4 * 1024 * 1024,
            total: 8 * 1024 * 1024,
        };
        let s = format!("{info}");
        assert!(s.contains("free=4 MB"));
        assert!(s.contains("total=8 MB"));
        assert!(s.contains("used=50.0%"));
    }

    #[test]
    fn mem_advice_variants() {
        assert_eq!(MemAdvice::SetReadMostly as u32, 1);
        assert_eq!(MemAdvice::UnsetReadMostly as u32, 2);
        assert_eq!(MemAdvice::SetPreferredLocation as u32, 3);
        assert_eq!(MemAdvice::UnsetPreferredLocation as u32, 4);
        assert_eq!(MemAdvice::SetAccessedBy as u32, 5);
        assert_eq!(MemAdvice::UnsetAccessedBy as u32, 6);
    }

    #[test]
    fn mem_advise_rejects_zero_count() {
        // Obtaining a `Device` needs a working driver; skip silently without one.
        if let Ok(dev) = Device::get(0) {
            let result = mem_advise(0x1000, 0, MemAdvice::SetReadMostly, &dev);
            // The zero-count guard runs before any driver call, so the error
            // must be exactly `InvalidValue`, not some driver failure.
            assert!(matches!(result, Err(CudaError::InvalidValue)));
        }
    }

    /// Compile-time check of `mem_prefetch`'s signature.
    ///
    /// Exercising the zero-count rejection would need a live `Stream`, so this
    /// test deliberately asserts only the function type.
    #[test]
    fn mem_prefetch_signature_compiles() {
        let _: fn(u64, usize, &Device, &Stream) -> CudaResult<()> = mem_prefetch;
    }

    #[test]
    fn memory_info_signature_compiles() {
        let _: fn() -> CudaResult<MemoryInfo> = memory_info;
    }
}