#[allow(unused_imports)]
use crate::error::Status;
use std::{
fmt::{self, Display, Formatter},
ptr,
time::Duration,
};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_core::impl_enum_conversion;
use singe_cuda_sys::driver;
use crate::{
device::Uuid,
error::{Error, Result},
try_ffi,
};
/// Checkpoint state of a CUDA process, as returned by [`CheckpointProcess::state`].
///
/// Every variant's discriminant is the raw `driver::CUprocessState` value of the
/// same name, so the `u32` representation matches the driver's numbering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum ProcessState {
    /// Driver state `CU_PROCESS_STATE_RUNNING`.
    Running = driver::CUprocessState::CU_PROCESS_STATE_RUNNING as u32,
    /// Driver state `CU_PROCESS_STATE_LOCKED`.
    Locked = driver::CUprocessState::CU_PROCESS_STATE_LOCKED as u32,
    /// Driver state `CU_PROCESS_STATE_CHECKPOINTED`.
    Checkpointed = driver::CUprocessState::CU_PROCESS_STATE_CHECKPOINTED as u32,
    /// Driver state `CU_PROCESS_STATE_FAILED`.
    Failed = driver::CUprocessState::CU_PROCESS_STATE_FAILED as u32,
}

// Generates the conversions between the raw driver enum and `ProcessState`
// (`CheckpointProcess::state` relies on `CUprocessState` -> `ProcessState` via `.into()`).
impl_enum_conversion!(driver::CUprocessState, ProcessState);

impl Display for ProcessState {
    /// Writes the driver's constant name for this state, e.g. `CU_PROCESS_STATE_RUNNING`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::Running => "CU_PROCESS_STATE_RUNNING",
            Self::Locked => "CU_PROCESS_STATE_LOCKED",
            Self::Checkpointed => "CU_PROCESS_STATE_CHECKPOINTED",
            Self::Failed => "CU_PROCESS_STATE_FAILED",
        };
        f.write_str(name)
    }
}
/// Options for [`CheckpointProcess::lock`].
///
/// The default (also produced by [`LockOptions::new`]) has no timeout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct LockOptions {
    /// How long the driver may try to take the lock; `None` is sent to the
    /// driver as a timeout of `0` (no timeout).
    timeout: Option<Duration>,
}

impl LockOptions {
    /// Creates options with no lock timeout.
    pub const fn new() -> Self {
        Self { timeout: None }
    }

    /// Returns these options with the lock timeout set to `timeout`.
    ///
    /// The driver takes the timeout in whole milliseconds. Non-zero durations are
    /// rounded *up* to the next millisecond when the options are converted, so a
    /// short timeout never degrades into "no timeout". `Duration::ZERO` is sent
    /// as `0`, which is the same value used for "no timeout".
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Converts these options into the driver's lock-argument struct.
    ///
    /// Reserved fields are left at their `Default` values.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidValue`] if the timeout in milliseconds does not fit
    /// the driver's `timeoutMs` field.
    fn to_raw(self) -> Result<driver::CUcheckpointLockArgs> {
        let timeout_ms = match self.timeout {
            // Round up instead of truncating: `as_millis()` would turn e.g. 500µs
            // into 0, which is the driver value for "no timeout" (see `None` below).
            Some(timeout) => timeout
                .as_nanos()
                .div_ceil(1_000_000)
                .try_into()
                .map_err(|_| Error::InvalidValue)?,
            None => 0,
        };
        Ok(driver::CUcheckpointLockArgs {
            timeoutMs: timeout_ms,
            ..Default::default()
        })
    }
}
/// One GPU remapping entry passed to [`CheckpointProcess::restore`]: an old GPU
/// UUID paired with the new GPU UUID to use in its place during restore.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GpuPair {
    /// UUID of the GPU being remapped from.
    pub old_uuid: Uuid,
    /// UUID of the GPU being remapped to.
    pub new_uuid: Uuid,
}

impl GpuPair {
    /// Pairs `old_uuid` with `new_uuid`.
    pub const fn new(old_uuid: Uuid, new_uuid: Uuid) -> Self {
        Self { old_uuid, new_uuid }
    }

    /// Converts this pair into the driver's `CUcheckpointGpuPair` struct.
    fn to_raw(self) -> driver::CUcheckpointGpuPair {
        let Self { old_uuid, new_uuid } = self;
        driver::CUcheckpointGpuPair {
            oldUuid: old_uuid.into(),
            newUuid: new_uuid.into(),
        }
    }
}
/// A process, identified by OS process id, targeted by the CUDA driver
/// checkpoint API.
///
/// Only the pid is stored: nothing is owned and nothing is validated up front,
/// so an invalid pid surfaces as an error from the individual driver calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CheckpointProcess {
    pid: i32,
}

impl CheckpointProcess {
    /// Wraps the process with id `pid`.
    pub const fn from_pid(pid: i32) -> Self {
        Self { pid }
    }

    /// Returns the wrapped process id.
    pub const fn pid(self) -> i32 {
        self.pid
    }

    /// Queries the process's current checkpoint state.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuCheckpointProcessGetState` fails.
    pub fn state(self) -> Result<ProcessState> {
        // Placeholder only; the driver overwrites it on success.
        let mut state = driver::CUprocessState::CU_PROCESS_STATE_RUNNING;
        // SAFETY: `state` is a live, writable local for the entire call.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessGetState(
                self.pid,
                &raw mut state,
            ))?;
        }
        Ok(state.into())
    }

    /// Returns the restore thread id the driver reports for this process.
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuCheckpointProcessGetRestoreThreadId` fails.
    pub fn restore_thread_id(self) -> Result<i32> {
        let mut thread_id = 0;
        // SAFETY: `thread_id` is a live, writable local for the entire call.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessGetRestoreThreadId(
                self.pid,
                &raw mut thread_id,
            ))?;
        }
        Ok(thread_id)
    }

    /// Locks the process's CUDA state in preparation for [`checkpoint`](Self::checkpoint).
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidValue`] if the timeout in `options` cannot be
    /// represented, or the driver error if `cuCheckpointProcessLock` fails.
    pub fn lock(self, options: LockOptions) -> Result<()> {
        let mut args = options.to_raw()?;
        // SAFETY: `args` is a fully initialised local that outlives the call.
        unsafe { try_ffi!(driver::cuCheckpointProcessLock(self.pid, &raw mut args)) }
    }

    /// Checkpoints the process's CUDA state (the process is expected to be locked;
    /// see [`suspend`](Self::suspend) for the combined operation).
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuCheckpointProcessCheckpoint` fails.
    pub fn checkpoint(self) -> Result<()> {
        let mut args = driver::CUcheckpointCheckpointArgs::default();
        // SAFETY: `args` is a fully initialised local that outlives the call.
        unsafe {
            try_ffi!(driver::cuCheckpointProcessCheckpoint(
                self.pid,
                &raw mut args,
            ))
        }
    }

    /// Locks and then checkpoints the process.
    ///
    /// If the checkpoint step fails after the lock succeeded, a best-effort
    /// [`unlock`](Self::unlock) is attempted so the process is not left locked,
    /// and the checkpoint error is returned.
    ///
    /// # Errors
    ///
    /// Returns the error from [`lock`](Self::lock) or [`checkpoint`](Self::checkpoint).
    pub fn suspend(self, options: LockOptions) -> Result<()> {
        self.lock(options)?;
        if let Err(error) = self.checkpoint() {
            // Roll back the lock; if unlocking also fails, the checkpoint error is
            // still the more useful one to report, so the unlock result is ignored.
            let _ = self.unlock();
            return Err(error);
        }
        Ok(())
    }

    /// Restores the process's CUDA state, remapping GPUs as described by
    /// `gpu_pairs`. An empty slice is passed to the driver as a null pointer with
    /// a count of zero.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidValue`] if `gpu_pairs.len()` does not fit the
    /// driver's count field, or the driver error if `cuCheckpointProcessRestore` fails.
    pub fn restore(self, gpu_pairs: &[GpuPair]) -> Result<()> {
        let gpu_pairs_count = gpu_pairs
            .len()
            .try_into()
            .map_err(|_| Error::InvalidValue)?;
        let mut raw_pairs: Vec<_> = gpu_pairs.iter().copied().map(GpuPair::to_raw).collect();
        let gpu_pairs_ptr = if raw_pairs.is_empty() {
            ptr::null_mut()
        } else {
            raw_pairs.as_mut_ptr()
        };
        let mut args = driver::CUcheckpointRestoreArgs {
            gpuPairs: gpu_pairs_ptr,
            gpuPairsCount: gpu_pairs_count,
            ..Default::default()
        };
        // SAFETY: `args` outlives the call, and `gpuPairs`/`gpuPairsCount` describe
        // exactly the elements of `raw_pairs` (or null/0 when empty), which is not
        // dropped or reallocated until after the call returns.
        unsafe { try_ffi!(driver::cuCheckpointProcessRestore(self.pid, &raw mut args)) }
    }

    /// Restores and then unlocks the process; the reverse of [`suspend`](Self::suspend).
    ///
    /// If the restore step fails, no unlock is attempted.
    ///
    /// # Errors
    ///
    /// Returns the error from [`restore`](Self::restore) or [`unlock`](Self::unlock).
    pub fn resume(self, gpu_pairs: &[GpuPair]) -> Result<()> {
        self.restore(gpu_pairs)?;
        self.unlock()
    }

    /// Releases the lock taken by [`lock`](Self::lock).
    ///
    /// # Errors
    ///
    /// Returns the driver error if `cuCheckpointProcessUnlock` fails.
    pub fn unlock(self) -> Result<()> {
        let mut args = driver::CUcheckpointUnlockArgs::default();
        // SAFETY: `args` is a fully initialised local that outlives the call.
        unsafe { try_ffi!(driver::cuCheckpointProcessUnlock(self.pid, &raw mut args)) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for the driver wrappers.
    ///
    /// Querying the current process either succeeds or fails with one of the
    /// accepted checkpoint statuses. Every call against pid `-1` must fail.
    #[test]
    fn it_works() {
        let pid = i32::try_from(std::process::id()).expect("process id fits in i32");
        let process = CheckpointProcess::from_pid(pid);
        match process.state() {
            Ok(state) => assert!(matches!(
                state,
                ProcessState::Running
                    | ProcessState::Locked
                    | ProcessState::Checkpointed
                    | ProcessState::Failed
            )),
            Err(error) => assert_checkpoint_error(error),
        }
        let missing_process = CheckpointProcess::from_pid(-1);
        checkpoint_fails(missing_process.restore_thread_id());
        checkpoint_fails(
            missing_process.lock(LockOptions::new().with_timeout(Duration::from_millis(1))),
        );
        checkpoint_fails(missing_process.checkpoint());
        checkpoint_fails(missing_process.restore(&[]));
        checkpoint_fails(missing_process.unlock());
    }

    /// `LockOptions::to_raw` handles the no-timeout, whole-millisecond and
    /// overflow cases. No driver call is involved.
    #[test]
    fn lock_options_to_raw() {
        assert_eq!(LockOptions::new().to_raw().unwrap().timeoutMs, 0);
        assert_eq!(
            LockOptions::new()
                .with_timeout(Duration::from_millis(5))
                .to_raw()
                .unwrap()
                .timeoutMs,
            5
        );
        assert!(matches!(
            LockOptions::new().with_timeout(Duration::MAX).to_raw(),
            Err(Error::InvalidValue)
        ));
    }

    /// `Display` prints the driver constant name.
    #[test]
    fn process_state_display() {
        assert_eq!(ProcessState::Running.to_string(), "CU_PROCESS_STATE_RUNNING");
        assert_eq!(ProcessState::Failed.to_string(), "CU_PROCESS_STATE_FAILED");
    }

    /// Asserts that `result` is an error with one of the accepted checkpoint statuses.
    fn checkpoint_fails<T>(result: Result<T>) {
        match result {
            Err(error) => assert_checkpoint_error(error),
            Ok(_) => panic!("checkpoint call unexpectedly succeeded"),
        }
    }

    /// Accepts only CUDA errors whose status is plausible for a checkpoint call
    /// in a test environment (unsupported driver, bad pid, wrong state, ...).
    fn assert_checkpoint_error(error: Error) {
        match error {
            Error::Cuda { code, .. }
                if matches!(
                    code,
                    Status::InvalidValue
                        | Status::NotInitialized
                        | Status::NotSupported
                        | Status::IllegalState
                        | Status::OperatingSystem
                ) => {}
            error => panic!("{error:?}"),
        }
    }
}