// singe-cuda 0.1.0-alpha.6
//
// Safe Rust wrappers for CUDA driver, runtime, NVRTC, NVVM, NVTX, memory, streams, modules, and graphs.
// Documentation
#[allow(unused_imports)]
use crate::error::Status;

use std::{
    fmt::{self, Display, Formatter},
    ptr,
    time::Duration,
};

use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_core::impl_enum_conversion;
use singe_cuda_sys::driver;

use crate::{
    device::Uuid,
    error::{Error, Result},
    try_ffi,
};

/// CUDA process state used by the checkpoint and restore driver APIs.
///
/// Every variant takes its discriminant from the matching
/// `driver::CUprocessState` constant. The enum is `#[repr(u32)]`, and the
/// derived `TryFromPrimitive` / `IntoPrimitive` implementations convert it
/// from and to that primitive representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum ProcessState {
    /// The process can make CUDA API calls.
    Running = driver::CUprocessState::CU_PROCESS_STATE_RUNNING as _,
    /// CUDA API locks are taken and further CUDA API calls will block.
    Locked = driver::CUprocessState::CU_PROCESS_STATE_LOCKED as _,
    /// GPU memory has been moved to host memory and device handles were released.
    Checkpointed = driver::CUprocessState::CU_PROCESS_STATE_CHECKPOINTED as _,
    /// The process entered an unrecoverable error during checkpoint or restore.
    Failed = driver::CUprocessState::CU_PROCESS_STATE_FAILED as _,
}

// NOTE(review): presumably generates the conversions between the raw
// `driver::CUprocessState` and `ProcessState`; `CheckpointProcess::state`
// relies on such a conversion through `.into()` — confirm in `singe_core`.
impl_enum_conversion!(driver::CUprocessState, ProcessState);

/// Renders the state as the name of the corresponding CUDA driver constant.
impl Display for ProcessState {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Pick the constant name first, then emit it with a single write.
        let name = match self {
            Self::Running => "CU_PROCESS_STATE_RUNNING",
            Self::Locked => "CU_PROCESS_STATE_LOCKED",
            Self::Checkpointed => "CU_PROCESS_STATE_CHECKPOINTED",
            Self::Failed => "CU_PROCESS_STATE_FAILED",
        };
        f.write_str(name)
    }
}

/// Options for [`CheckpointProcess::lock`].
///
/// The derived [`Default`] value has no timeout set, the same as
/// [`LockOptions::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct LockOptions {
    /// Upper bound on the lock attempt; `None` means no timeout was requested.
    timeout: Option<Duration>,
}

impl LockOptions {
    /// Creates lock options without a timeout.
    pub const fn new() -> Self {
        Self { timeout: None }
    }

    /// Sets the maximum time CUDA should spend attempting to lock the process.
    ///
    /// A missing timeout passes `0` to CUDA, which means no timeout.
    ///
    /// CUDA takes the timeout in whole milliseconds, so the duration is
    /// rounded up to the next millisecond. A zero or sub-millisecond duration
    /// is sent as 1 ms rather than `0`, because `0` would disable the timeout
    /// entirely.
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Builds the raw driver argument structure for a lock call.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidValue`] when the timeout, expressed in
    /// milliseconds, does not fit the driver's `timeoutMs` field.
    fn to_raw(self) -> Result<driver::CUcheckpointLockArgs> {
        const NANOS_PER_MILLI: u128 = 1_000_000;

        let timeout_ms = match self.timeout {
            // Round up and keep the value at least 1: truncating a short
            // duration to 0 would silently turn "time out quickly" into
            // "never time out".
            Some(timeout) => timeout
                .as_nanos()
                .div_ceil(NANOS_PER_MILLI)
                .max(1)
                .try_into()
                .map_err(|_| Error::InvalidValue)?,
            // 0 is the driver's "no timeout" value.
            None => 0,
        };

        // Only the timeout is set explicitly; every other field keeps its
        // `Default` value.
        Ok(driver::CUcheckpointLockArgs {
            timeoutMs: timeout_ms,
            ..Default::default()
        })
    }
}

/// GPU UUID remapping entry used during restore.
///
/// A slice of these entries is passed to [`CheckpointProcess::restore`], where
/// each one is converted into a `driver::CUcheckpointGpuPair`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GpuPair {
    /// UUID of the GPU that was checkpointed.
    pub old_uuid: Uuid,
    /// UUID of the GPU to restore onto.
    pub new_uuid: Uuid,
}

impl GpuPair {
    /// Creates a remapping entry from the checkpointed GPU's UUID to the UUID
    /// of the GPU to restore onto.
    pub const fn new(old_uuid: Uuid, new_uuid: Uuid) -> Self {
        Self { old_uuid, new_uuid }
    }

    /// Converts this entry into the raw driver representation.
    fn to_raw(self) -> driver::CUcheckpointGpuPair {
        let Self { old_uuid, new_uuid } = self;
        driver::CUcheckpointGpuPair {
            oldUuid: old_uuid.into(),
            newUuid: new_uuid.into(),
        }
    }
}

/// A CUDA process controlled through the driver checkpoint APIs.
///
/// These APIs are intended for an external controller process. Locking a
/// process blocks further CUDA API calls in that process until it is restored
/// and unlocked.
///
/// The value itself is only a `Copy` wrapper around a process ID; creating it
/// with [`CheckpointProcess::from_pid`] stores the PID and makes no driver call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CheckpointProcess {
    /// Operating-system process ID passed to every checkpoint driver call.
    pid: i32,
}

impl CheckpointProcess {
    /// Wraps an operating-system process ID as a CUDA checkpoint target.
    ///
    /// The PID is only stored; no driver call is made.
    pub const fn from_pid(pid: i32) -> Self {
        Self { pid }
    }

    /// Returns the operating-system process ID this value refers to.
    pub const fn pid(self) -> i32 {
        let Self { pid } = self;
        pid
    }

    /// Queries the driver for the current CUDA checkpoint state of the process.
    ///
    /// # Errors
    ///
    /// Propagates the error produced for a failed `cuCheckpointProcessGetState`
    /// call.
    pub fn state(self) -> Result<ProcessState> {
        // Initial value for the out-parameter handed to the driver.
        let mut raw_state = driver::CUprocessState::CU_PROCESS_STATE_RUNNING;
        // SAFETY: the pointer refers to a local that stays alive and writable
        // for the whole call.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessGetState(
                self.pid,
                &raw mut raw_state,
            ))?;
        }
        Ok(raw_state.into())
    }

    /// Queries the driver for the CUDA restore thread ID of the process.
    ///
    /// # Errors
    ///
    /// Propagates the error produced for a failed
    /// `cuCheckpointProcessGetRestoreThreadId` call.
    pub fn restore_thread_id(self) -> Result<i32> {
        let mut restore_tid = 0;
        // SAFETY: the pointer refers to a local that stays alive and writable
        // for the whole call.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessGetRestoreThreadId(
                self.pid,
                &raw mut restore_tid,
            ))?;
        }
        Ok(restore_tid)
    }

    /// Locks a running CUDA process so further CUDA API calls in that process block.
    ///
    /// On success the process enters [`ProcessState::Locked`].
    ///
    /// # Errors
    ///
    /// Fails when `options` cannot be converted to driver arguments or when
    /// `cuCheckpointProcessLock` reports an error.
    pub fn lock(self, options: LockOptions) -> Result<()> {
        let mut lock_args = options.to_raw()?;
        // SAFETY: `lock_args` is a local that outlives the call.
        unsafe { try_ffi!(driver::cuCheckpointProcessLock(self.pid, &raw mut lock_args)) }
    }

    /// Moves the locked process's GPU memory into host memory managed by the driver.
    ///
    /// On success the process enters [`ProcessState::Checkpointed`].
    ///
    /// # Errors
    ///
    /// Propagates the error produced for a failed
    /// `cuCheckpointProcessCheckpoint` call.
    pub fn checkpoint(self) -> Result<()> {
        let mut checkpoint_args = driver::CUcheckpointCheckpointArgs::default();
        // SAFETY: `checkpoint_args` is a local that outlives the call.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessCheckpoint(
                self.pid,
                &raw mut checkpoint_args,
            ))
        }
    }

    /// Locks and checkpoints a running CUDA process.
    ///
    /// On success the process enters [`ProcessState::Checkpointed`].
    ///
    /// # Errors
    ///
    /// Returns the first error from [`Self::lock`] or [`Self::checkpoint`].
    /// When the lock succeeds but the checkpoint fails, no unlock is
    /// attempted; the caller decides how to recover.
    pub fn suspend(self, options: LockOptions) -> Result<()> {
        self.lock(options)?;
        self.checkpoint()
    }

    /// Restores a checkpointed process, optionally remapping checkpointed GPUs.
    ///
    /// If `gpu_pairs` is not empty, CUDA requires it to contain every
    /// checkpointed GPU.
    ///
    /// On success the process enters [`ProcessState::Locked`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidValue`] when the number of pairs does not fit
    /// the driver's count field, and otherwise propagates the error produced
    /// for a failed `cuCheckpointProcessRestore` call.
    pub fn restore(self, gpu_pairs: &[GpuPair]) -> Result<()> {
        let pair_count = gpu_pairs
            .len()
            .try_into()
            .map_err(|_| Error::InvalidValue)?;
        let mut raw_pairs: Vec<_> = gpu_pairs.iter().map(|pair| pair.to_raw()).collect();
        // An empty remapping is passed as a null pointer instead of the
        // vector's own pointer.
        let pairs_ptr = match raw_pairs.as_mut_slice() {
            [] => ptr::null_mut(),
            pairs => pairs.as_mut_ptr(),
        };
        let mut restore_args = driver::CUcheckpointRestoreArgs {
            gpuPairs: pairs_ptr,
            gpuPairsCount: pair_count,
            ..Default::default()
        };

        // SAFETY: `restore_args` and the `raw_pairs` buffer it points into
        // both live until this function returns, and the count matches the
        // buffer length.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessRestore(
                self.pid,
                &raw mut restore_args,
            ))
        }
    }

    /// Restores and unlocks a checkpointed CUDA process.
    ///
    /// On success the process enters [`ProcessState::Running`].
    ///
    /// # Errors
    ///
    /// Returns the first error from [`Self::restore`] or [`Self::unlock`].
    /// When the restore succeeds but the unlock fails, nothing is rolled back.
    pub fn resume(self, gpu_pairs: &[GpuPair]) -> Result<()> {
        self.restore(gpu_pairs)?;
        self.unlock()
    }

    /// Unlocks a locked CUDA process so it can resume CUDA API calls.
    ///
    /// On success the process enters [`ProcessState::Running`].
    ///
    /// # Errors
    ///
    /// Propagates the error produced for a failed `cuCheckpointProcessUnlock`
    /// call.
    pub fn unlock(self) -> Result<()> {
        let mut unlock_args = driver::CUcheckpointUnlockArgs::default();
        // SAFETY: `unlock_args` is a local that outlives the call.
        unsafe { try_ffi!(driver::cuCheckpointProcessUnlock(self.pid, &raw mut unlock_args)) }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test that tolerates machines without a usable CUDA driver:
    /// querying the current process may succeed or fail with an expected
    /// status, while every call against PID `-1` must fail.
    #[test]
    fn it_works() {
        let own_pid = std::process::id() as i32;
        match CheckpointProcess::from_pid(own_pid).state() {
            Err(error) => assert_checkpoint_error(error),
            Ok(state) => assert!(matches!(
                state,
                ProcessState::Running
                    | ProcessState::Locked
                    | ProcessState::Checkpointed
                    | ProcessState::Failed
            )),
        }

        let missing_process = CheckpointProcess::from_pid(-1);
        let short_timeout = LockOptions::new().with_timeout(Duration::from_millis(1));

        checkpoint_fails(missing_process.restore_thread_id());
        checkpoint_fails(missing_process.lock(short_timeout));
        checkpoint_fails(missing_process.checkpoint());
        checkpoint_fails(missing_process.restore(&[]));
        checkpoint_fails(missing_process.unlock());
    }

    /// Asserts that `result` is an error carrying one of the accepted statuses.
    fn checkpoint_fails<T>(result: Result<T>) {
        let Err(error) = result else {
            panic!("checkpoint call unexpectedly succeeded");
        };
        assert_checkpoint_error(error);
    }

    /// Panics with the debug form of `error` unless it is a CUDA error whose
    /// status is one of the codes accepted by these tests.
    fn assert_checkpoint_error(error: Error) {
        let accepted = match &error {
            Error::Cuda { code, .. } => matches!(
                code,
                Status::InvalidValue
                    | Status::NotInitialized
                    | Status::NotSupported
                    | Status::IllegalState
                    | Status::OperatingSystem
            ),
            _ => false,
        };
        assert!(accepted, "{error:?}");
    }
}