#![cfg(feature = "graphs-conditional")]
use std::sync::Arc;
use cudarc::driver::sys as driver_sys;
use cudarc::driver::CudaContext;
use crate::error::GpuError;
/// Library tag stored in the `lib` field of `GpuError::LibraryError` values produced by `check`.
const LIB: &str = "graph";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConditionalKind {
If,
While,
}
impl ConditionalKind {
fn raw(self) -> driver_sys::CUgraphConditionalNodeType {
match self {
ConditionalKind::If => driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_IF,
ConditionalKind::While => {
driver_sys::CUgraphConditionalNodeType::CU_GRAPH_COND_TYPE_WHILE
}
}
}
}
/// Configuration for a conditional IF node.
///
/// Derives `Debug` (public types should be printable, and containing types can then
/// derive `Debug` themselves) and `PartialEq`/`Eq` so descriptors can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IfNodeDescriptor {
    /// Initial value of the node's condition.
    ///
    /// NOTE(review): presumably forwarded as the `defaultLaunchValue` argument of
    /// `cuGraphConditionalHandleCreate`; that wiring is not visible in this module —
    /// confirm at the call site. Per the CUDA docs, an IF body runs when the condition
    /// value is non-zero.
    pub default_value: u32,
}
/// Configuration for a conditional WHILE node.
///
/// Derives `Debug` (public types should be printable, and containing types can then
/// derive `Debug` themselves) and `PartialEq`/`Eq` so descriptors can be compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WhileNodeDescriptor {
    /// Initial value of the node's condition.
    ///
    /// NOTE(review): presumably forwarded as the `defaultLaunchValue` argument of
    /// `cuGraphConditionalHandleCreate`; that wiring is not visible in this module —
    /// confirm at the call site. Per the CUDA docs, a WHILE body keeps running while
    /// the condition value is non-zero, so `0` means the body is skipped unless
    /// something sets the condition first.
    pub default_value: u32,
}
/// Assembles the driver-level parameter block for a conditional graph node.
///
/// * `kind` — IF or WHILE; converted to the raw driver enum.
/// * `handle` — conditional handle the node will evaluate, copied in as-is.
/// * `ctx` — context whose raw `CUcontext` is copied into the params.
/// * `inner_graph_out` — stored verbatim in `phGraph_out`.
///
/// Exactly one body graph is requested (`size` is 1).
///
/// The returned struct holds raw pointers only: it does not clone the `Arc`, so it
/// does not keep `ctx` alive. NOTE(review): per the CUDA docs the driver writes the
/// body graph handle through `phGraph_out` when the node is added, so presumably
/// `inner_graph_out` must point to writable storage and `ctx` must outlive that call —
/// confirm at the call site.
pub fn build_params(
    kind: ConditionalKind,
    handle: driver_sys::CUgraphConditionalHandle,
    ctx: &Arc<CudaContext>,
    inner_graph_out: *mut driver_sys::CUgraph,
) -> driver_sys::CUDA_CONDITIONAL_NODE_PARAMS {
    let node_type = kind.raw();
    let raw_ctx = ctx.cu_ctx();
    let body_graph_count = 1;

    driver_sys::CUDA_CONDITIONAL_NODE_PARAMS {
        ctx: raw_ctx,
        phGraph_out: inner_graph_out,
        size: body_graph_count,
        type_: node_type,
        handle,
    }
}
/// Reports whether the loaded CUDA driver can build conditional (IF/WHILE) graph nodes.
///
/// The probe runs in two steps:
///
/// 1. It reads the driver API version with `cuDriverGetVersion` and answers
///    `Ok(false)` for drivers older than 12.3, the release that introduced conditional
///    nodes. Without this gate the next call would target an entry point such drivers
///    do not export.
/// 2. On newer drivers it asks for a conditional handle with a null graph and context.
///    Only an explicit `CUDA_ERROR_NOT_SUPPORTED` turns the answer into `Ok(false)`.
///    Any other status, including argument-validation errors provoked by the null
///    graph, counts as supported: the entry point exists and did not reject the
///    feature.
///
/// # Errors
///
/// Returns `GpuError::Unrecoverable` in two cases:
///
/// * calling into the driver panics, for example when the driver library cannot be
///   loaded;
/// * `cuDriverGetVersion` itself reports a failure.
pub fn driver_supports_conditional() -> Result<bool, GpuError> {
    // Driver API version encoding is `1000 * major + 10 * minor`; 12030 == CUDA 12.3.
    const MIN_CONDITIONAL_DRIVER_VERSION: std::ffi::c_int = 12_030;

    // Inner `Err` carries a failed `cuDriverGetVersion` status; the outer `Err` (from
    // `catch_unwind`) means the driver could not be called at all.
    let probe = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let mut version: std::ffi::c_int = 0;
        // SAFETY: `version` is a live, writable c_int for the duration of the call.
        let version_status = unsafe { driver_sys::cuDriverGetVersion(&mut version as *mut _) };
        if version_status != driver_sys::cudaError_enum::CUDA_SUCCESS {
            return Err(version_status);
        }
        if version < MIN_CONDITIONAL_DRIVER_VERSION {
            // Older drivers lack the conditional-node entry points altogether; calling
            // one through a dynamically loaded driver would likely fail (panic) rather
            // than return CUDA_ERROR_NOT_SUPPORTED.
            return Ok(false);
        }

        let mut handle: driver_sys::CUgraphConditionalHandle = 0;
        // SAFETY: `handle` is a live, writable slot for the output handle. The null
        // graph and context are deliberate: only the returned status is inspected.
        let status = unsafe {
            driver_sys::cuGraphConditionalHandleCreate(
                &mut handle as *mut _,
                std::ptr::null_mut(),
                std::ptr::null_mut(),
                0,
                0,
            )
        };
        Ok(status != driver_sys::cudaError_enum::CUDA_ERROR_NOT_SUPPORTED)
    }));

    match probe {
        Ok(Ok(supported)) => Ok(supported),
        Ok(Err(status)) => Err(GpuError::Unrecoverable(
            format!("conditional probe: cuDriverGetVersion failed: {status:?}").into(),
        )),
        Err(_) => Err(GpuError::Unrecoverable(
            "conditional probe: CUDA driver not loadable".into(),
        )),
    }
}
/// Converts a raw driver status into a `Result`.
///
/// `CUDA_SUCCESS` yields `Ok(())`. Any other status becomes a
/// `GpuError::LibraryError` tagged with `LIB`. Its message is the operation name
/// followed by the `Debug` form of the status, e.g. `"op: CUDA_ERROR_..."`.
pub(crate) fn check(s: driver_sys::CUresult, op: &str) -> Result<(), GpuError> {
    if s != driver_sys::cudaError_enum::CUDA_SUCCESS {
        let msg = format!("{op}: {s:?}");
        return Err(GpuError::LibraryError { lib: LIB, msg });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The IF descriptor keeps its value, and both kinds compare and convert cleanly.
    #[test]
    fn if_node_descriptor_compiles() {
        let descriptor = IfNodeDescriptor { default_value: 1 };
        assert_eq!(descriptor.default_value, 1);

        let (if_kind, while_kind) = (ConditionalKind::If, ConditionalKind::While);
        assert_eq!(if_kind, ConditionalKind::If);
        assert_ne!(if_kind, while_kind);

        // Each kind must map onto a driver constant without panicking.
        for kind in [if_kind, while_kind] {
            let _ = kind.raw();
        }
    }

    /// The WHILE descriptor keeps its value.
    #[test]
    fn while_node_descriptor_compiles() {
        let descriptor = WhileNodeDescriptor { default_value: 0 };
        assert_eq!(descriptor.default_value, 0);
    }

    /// The probe answers either with a bool or with an `Unrecoverable` error,
    /// never with any other error variant.
    #[test]
    fn driver_probe_returns_typed_result() {
        let outcome = driver_supports_conditional();
        match outcome {
            Ok(_) | Err(GpuError::Unrecoverable(_)) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}