// memonitor 0.2.4
//
// Query CPU and GPU memory information in a portable way.
// Documentation
use std::ffi::CStr;
use std::ptr::addr_of_mut;

use tracing::{debug, info, warn};

use memonitor_sys::cuda;

use crate::{BackendHandle, BackendId, DeviceHandle, DeviceKind, GPUKind, MemoryStats, CUDA_NAME};

/// CUDA backend handle.
///
/// Owns the native device list obtained from `cuda::list_devices` during `Cuda::init`.
/// Dropping it calls `cuda::destroy_devices` on that list and then `cuda::term`.
pub(super) struct Cuda {
    // Native device list; device references handed out by `init` are obtained from it.
    handle: cuda::Devices,
}

// SAFETY: NOTE(review) — these impls assert that the native `cuda::Devices` handle may be
// moved to and shared across threads. Nothing in this file establishes that; TODO confirm
// against memonitor-sys and the CUDA driver's thread-safety guarantees.
unsafe impl Send for Cuda {}

unsafe impl Sync for Cuda {}

impl Cuda {
    pub(super) fn init() -> Option<(Self, Vec<CudaDevice>)> {
        debug!("Attempting to load CUDA monitor");

        let res = unsafe { cuda::init() };

        if res == 0 {
            let mut c_devices = unsafe { cuda::list_devices() };

            if c_devices.count == 0 {
                warn!("No Cuda devices found. Aborting CUDA initialisation.");
                unsafe {
                    cuda::destroy_devices(addr_of_mut!(c_devices));
                    cuda::term();
                }
                return None;
            }

            let mut devices = Vec::with_capacity(c_devices.count as usize);
            for i in 0..c_devices.count {
                let c_device = unsafe { cuda::get_device(addr_of_mut!(c_devices), i) };
                if c_device.ctx_handle.is_null() {
                    warn!("Invalid device handle. Aborting CUDA initialisation.");
                    unsafe {
                        cuda::destroy_devices(addr_of_mut!(c_devices));
                        cuda::term();
                    };
                    return None;
                }

                let properties = unsafe { cuda::device_properties(c_device) };
                if properties.name[0] == 0 {
                    warn!("Invalid device name. Aborting CUDA initialisation.");
                    unsafe {
                        cuda::destroy_devices(addr_of_mut!(c_devices));
                        cuda::term();
                    };
                    return None;
                }

                let name = unsafe { CStr::from_ptr(properties.name.as_ptr()) };
                let kind = match properties.kind {
                    cuda::DeviceKind::IntegratedGPU => DeviceKind::GPU(GPUKind::Integrated),
                    cuda::DeviceKind::DiscreteGPU => DeviceKind::GPU(GPUKind::Discrete),
                    cuda::DeviceKind::Other => DeviceKind::Other,
                    _ => DeviceKind::Other,
                };

                if properties.total_memory == 0 {
                    warn!("Invalid amount of memory. Skipping device.");
                    continue;
                }

                let device = CudaDevice {
                    handle: c_device,
                    name: name.to_string_lossy().to_string(),
                    kind,
                    memory: properties.total_memory,
                };
                devices.push(device);
            }

            let backend = Cuda { handle: c_devices };

            info!("Successfully initialised CUDA backend.");
            Some((backend, devices))
        } else {
            warn!("Failed to initialise the CUDA backend.");
            None
        }
    }
}

impl BackendHandle for Cuda {
    /// Identifier of this backend: always `BackendId::CUDA`.
    fn id(&self) -> BackendId {
        BackendId::CUDA
    }

    /// Human-readable backend name, taken from the shared `CUDA_NAME` constant.
    fn name(&self) -> &str {
        CUDA_NAME
    }
}

impl Drop for Cuda {
    /// Destroys the native device list owned by this backend, then calls `cuda::term`.
    fn drop(&mut self) {
        let devices = addr_of_mut!(self.handle);

        // SAFETY: NOTE(review) — `devices` points at the list this backend took from
        // `cuda::list_devices`. Assumes it has not been destroyed elsewhere; TODO confirm.
        unsafe { cuda::destroy_devices(devices) };

        // SAFETY: NOTE(review) — assumes one `term` per successful `init`; TODO confirm.
        unsafe { cuda::term() };
    }
}

/// A single CUDA device discovered by `Cuda::init`.
pub(super) struct CudaDevice {
    // Native device reference obtained from `cuda::get_device` on the backend's device list.
    // NOTE(review): whether it stays valid after the owning `Cuda` is dropped is not visible
    // here; confirm before using it past the backend's lifetime.
    handle: cuda::DeviceRef,
    // Device name, decoded lossily from the native properties at initialisation.
    name: String,
    // Device classification mapped from `cuda::DeviceKind`.
    kind: DeviceKind,
    // Total memory reported by `cuda::device_properties` at initialisation
    // (presumably bytes; TODO confirm).
    memory: usize,
}

// SAFETY: NOTE(review) — these impls assert that `cuda::DeviceRef` may be moved to and shared
// across threads. Nothing in this file establishes that; TODO confirm against memonitor-sys.
unsafe impl Send for CudaDevice {}

unsafe impl Sync for CudaDevice {}

impl DeviceHandle for CudaDevice {
    /// Device name captured when the backend was initialised.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Device classification captured when the backend was initialised.
    fn kind(&self) -> DeviceKind {
        self.kind
    }

    /// Owning backend: always `BackendId::CUDA`.
    fn backend(&self) -> BackendId {
        BackendId::CUDA
    }

    /// Reports the device's current memory figures.
    ///
    /// `total` is the value cached at initialisation. `available` and `used` come from a
    /// fresh `cuda::device_memory_properties` call, using its `budget` and `used` fields.
    fn current_memory_stats(&self) -> MemoryStats {
        // SAFETY: NOTE(review) — assumes `self.handle` is still valid, i.e. the owning `Cuda`
        // backend has not been dropped; nothing here enforces that.
        let live = unsafe { cuda::device_memory_properties(self.handle) };
        let (available, used) = (live.budget, live.used);

        MemoryStats {
            used,
            available,
            total: self.memory,
        }
    }
}