use std::{
ffi::CString,
fmt::{self, Display, Formatter},
marker::PhantomData,
ptr,
};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use singe_core::impl_enum_conversion;
use singe_cuda_sys::{driver, runtime};
use crate::{
dim::Dim3,
error::{Error, Result},
event::Event,
memory::{MemAccessDescriptor, MemoryCopyKind, MemoryPoolProps},
stream::Stream,
try_cuda,
types::{DeviceFunction, DevicePtr, HostFunction},
};
/// Identifiers of the kernel-node attributes that this module can read or write.
///
/// Every discriminant is taken from the matching `runtime::cudaLaunchAttributeID` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum GraphKernelNodeAttributeId {
/// Mirrors `cudaLaunchAttributeCooperative`.
Cooperative = runtime::cudaLaunchAttributeID::cudaLaunchAttributeCooperative as _,
/// Mirrors `cudaLaunchAttributeClusterDimension`.
ClusterDimension = runtime::cudaLaunchAttributeID::cudaLaunchAttributeClusterDimension as _,
/// Mirrors `cudaLaunchAttributePriority`.
Priority = runtime::cudaLaunchAttributeID::cudaLaunchAttributePriority as _,
/// Mirrors `cudaLaunchAttributePreferredSharedMemoryCarveout`.
PreferredSharedMemoryCarveout =
runtime::cudaLaunchAttributeID::cudaLaunchAttributePreferredSharedMemoryCarveout as _,
}
// NOTE(review): presumably generates the conversions between `runtime::cudaLaunchAttributeID`
// and `GraphKernelNodeAttributeId` (through `u32`) — confirm against `singe_core`.
impl_enum_conversion!(
u32,
runtime::cudaLaunchAttributeID,
GraphKernelNodeAttributeId
);
/// A kernel-node attribute paired with its value.
///
/// The variants have the same names as those of `GraphKernelNodeAttributeId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphKernelNodeAttribute {
/// Value of the cooperative attribute, exposed as a boolean.
Cooperative(bool),
/// Value of the cluster-dimension attribute.
ClusterDimension(Dim3),
/// Value of the priority attribute.
Priority(i32),
/// Value of the preferred shared-memory carveout attribute.
PreferredSharedMemoryCarveout(u32),
}
/// Device pointer and byte size read back from a memory-allocation node's parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MemoryAllocationNodeInfo {
// Kept private; exposed read-only through `MemoryAllocationNodeInfo::ptr`.
ptr: DevicePtr,
/// Byte size reported by the node parameters.
pub byte_size: usize,
}
bitflags::bitflags! {
/// Flags accepted when instantiating a graph.
///
/// Each constant takes its value from the matching `driver::CUgraphInstantiate_flags` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphInstantiateFlags: u64 {
const AUTO_FREE_ON_LAUNCH = driver::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH as _;
const UPLOAD = driver::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD as _;
const DEVICE_LAUNCH = driver::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH as _;
const USE_NODE_PRIORITY = driver::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY as _;
}
}
bitflags::bitflags! {
/// Flags accepted by `Graph::write_dot`.
///
/// Each constant takes its value from the matching `driver::CUgraphDebugDot_flags` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphDebugDotFlags: u32 {
const VERBOSE = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE as _;
const RUNTIME_TYPES = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES as _;
const KERNEL_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS as _;
const MEMCPY_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS as _;
const MEMSET_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS as _;
const HOST_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS as _;
const EVENT_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS as _;
const EXTERNAL_SEMAPHORE_SIGNAL_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS as _;
const EXTERNAL_SEMAPHORE_WAIT_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS as _;
const KERNEL_NODE_ATTRIBUTES = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES as _;
const HANDLES = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES as _;
const MEMORY_ALLOC_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS as _;
const MEMORY_FREE_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS as _;
const BATCH_MEM_OP_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS as _;
const EXTRA_TOPOLOGY_INFO = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO as _;
const CONDITIONAL_NODE_PARAMS = driver::CUgraphDebugDot_flags::CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS as _;
}
}
/// Kind of a graph node.
///
/// Every discriminant is taken from the matching `driver::CUgraphNodeType` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum GraphNodeType {
Kernel = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL as _,
Memcpy = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMCPY as _,
Memset = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEMSET as _,
Host = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_HOST as _,
Graph = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_GRAPH as _,
Empty = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EMPTY as _,
WaitEvent = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_WAIT_EVENT as _,
EventRecord = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EVENT_RECORD as _,
ExternalSemaphoresSignal = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL as _,
ExternalSemaphoresWait = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT as _,
MemoryAlloc = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_ALLOC as _,
MemoryFree = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_MEM_FREE as _,
BatchMemOp = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_BATCH_MEM_OP as _,
Conditional = driver::CUgraphNodeType::CU_GRAPH_NODE_TYPE_CONDITIONAL as _,
}
// NOTE(review): presumably generates the conversions between `runtime::cudaGraphNodeType`
// and `GraphNodeType` (through `u32`) — confirm against `singe_core`.
impl_enum_conversion!(u32, runtime::cudaGraphNodeType, GraphNodeType);
/// Type of a dependency edge between two graph nodes.
///
/// Stored as `u8`; the discriminants come from `driver::CUgraphDependencyType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u8)]
pub enum GraphDependencyType {
/// Mirrors `CU_GRAPH_DEPENDENCY_TYPE_DEFAULT`.
Default = driver::CUgraphDependencyType::CU_GRAPH_DEPENDENCY_TYPE_DEFAULT as _,
/// Mirrors `CU_GRAPH_DEPENDENCY_TYPE_PROGRAMMATIC`.
Programmatic = driver::CUgraphDependencyType::CU_GRAPH_DEPENDENCY_TYPE_PROGRAMMATIC as _,
}
impl From<driver::CUgraphDependencyType> for GraphDependencyType {
fn from(value: driver::CUgraphDependencyType) -> Self {
match value {
driver::CUgraphDependencyType::CU_GRAPH_DEPENDENCY_TYPE_DEFAULT => Self::Default,
driver::CUgraphDependencyType::CU_GRAPH_DEPENDENCY_TYPE_PROGRAMMATIC => {
Self::Programmatic
}
}
}
}
impl From<GraphDependencyType> for driver::CUgraphDependencyType {
fn from(value: GraphDependencyType) -> Self {
match value {
GraphDependencyType::Default => Self::CU_GRAPH_DEPENDENCY_TYPE_DEFAULT,
GraphDependencyType::Programmatic => Self::CU_GRAPH_DEPENDENCY_TYPE_PROGRAMMATIC,
}
}
}
impl Display for GraphNodeType {
    /// Writes a fixed `cudaGraphNodeType…` label for the variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            Self::Kernel => "cudaGraphNodeTypeKernel",
            Self::Memcpy => "cudaGraphNodeTypeMemcpy",
            Self::Memset => "cudaGraphNodeTypeMemset",
            Self::Host => "cudaGraphNodeTypeHost",
            Self::Graph => "cudaGraphNodeTypeGraph",
            Self::Empty => "cudaGraphNodeTypeEmpty",
            Self::WaitEvent => "cudaGraphNodeTypeWaitEvent",
            Self::EventRecord => "cudaGraphNodeTypeEventRecord",
            Self::ExternalSemaphoresSignal => "cudaGraphNodeTypeExternalSemaphoresSignal",
            Self::ExternalSemaphoresWait => "cudaGraphNodeTypeExternalSemaphoresWait",
            Self::MemoryAlloc => "cudaGraphNodeTypeMemAlloc",
            Self::MemoryFree => "cudaGraphNodeTypeMemFree",
            Self::BatchMemOp => "cudaGraphNodeTypeBatchMemOp",
            Self::Conditional => "cudaGraphNodeTypeConditional",
        };
        f.write_str(label)
    }
}
/// Outcome of updating an executable graph.
///
/// Every discriminant is taken from the matching `driver::CUgraphExecUpdateResult` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum GraphExecUpdateResult {
Success = driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_SUCCESS as _,
Error = driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR as _,
ErrorTopologyChanged =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED as _,
ErrorNodeTypeChanged =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED as _,
ErrorFunctionChanged =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED as _,
ErrorParametersChanged =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED as _,
ErrorNotSupported =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED as _,
ErrorUnsupportedFunctionChange =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE
as _,
ErrorAttributesChanged =
driver::CUgraphExecUpdateResult::CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED as _,
}
// NOTE(review): presumably generates the conversions between
// `driver::CUgraphExecUpdateResult` and `GraphExecUpdateResult` — confirm against `singe_core`.
impl_enum_conversion!(driver::CUgraphExecUpdateResult, GraphExecUpdateResult);
impl Display for GraphExecUpdateResult {
    /// Writes the `CU_GRAPH_EXEC_UPDATE_…` label that corresponds to the variant.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let label = match *self {
            Self::Success => "CU_GRAPH_EXEC_UPDATE_SUCCESS",
            Self::Error => "CU_GRAPH_EXEC_UPDATE_ERROR",
            Self::ErrorTopologyChanged => "CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED",
            Self::ErrorNodeTypeChanged => "CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED",
            Self::ErrorFunctionChanged => "CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED",
            Self::ErrorParametersChanged => "CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED",
            Self::ErrorNotSupported => "CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED",
            Self::ErrorUnsupportedFunctionChange => {
                "CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE"
            }
            Self::ErrorAttributesChanged => "CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED",
        };
        f.write_str(label)
    }
}
/// Copyable wrapper around a raw `runtime::cudaGraphNode_t` handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphNode {
// Raw node handle; the wrapper has no `Drop` impl in this section of the file.
handle: runtime::cudaGraphNode_t,
}
/// Data attached to a graph edge; converts to and from `runtime::cudaGraphEdgeData`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphEdgeData {
/// Copied to/from the raw `from_port` field.
pub from_port: u8,
/// Copied to/from the raw `to_port` field.
pub to_port: u8,
/// Typed form of the raw `type_` field.
pub dependency_type: GraphDependencyType,
}
/// A neighbouring node together with the data of the edge that connects it.
///
/// Returned by `GraphNode::dependencies` and `GraphNode::dependent_nodes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphDependency {
/// The node at the other end of the edge.
pub node: GraphNode,
/// Data reported for the edge.
pub data: GraphEdgeData,
}
/// A graph edge described by both endpoints and its edge data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GraphEdge {
/// Node the edge starts at.
pub from: GraphNode,
/// Node the edge ends at.
pub to: GraphNode,
/// Data attached to the edge.
pub data: GraphEdgeData,
}
/// Three-component position; converts into `runtime::cudaPos`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Position {
pub x: usize,
pub y: usize,
pub z: usize,
}
/// Three-component extent; converts into `runtime::cudaExtent`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Extent {
pub width: usize,
pub height: usize,
pub depth: usize,
}
/// Raw pointer bundled with pitch and size values; converts into `runtime::cudaPitchedPtr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PitchedPtr {
    // Private so the raw pointer can only be set through `new` and read through `ptr`.
    ptr: *mut (),
    /// Copied to the raw `pitch` field on conversion.
    pub pitch: usize,
    /// Copied to the raw `xsize` field on conversion.
    pub x_size: usize,
    /// Copied to the raw `ysize` field on conversion.
    pub y_size: usize,
}

impl PitchedPtr {
    /// Builds a `PitchedPtr` from its four components without validating any of them.
    pub const fn new(ptr: *mut (), pitch: usize, x_size: usize, y_size: usize) -> Self {
        Self {
            x_size,
            y_size,
            pitch,
            ptr,
        }
    }

    /// Returns the raw pointer supplied at construction.
    pub const fn ptr(self) -> *mut () {
        let Self { ptr, .. } = self;
        ptr
    }
}
/// Parameters of a memcpy node; converted into `runtime::cudaMemcpy3DParms` before use.
#[derive(Debug, Clone, Copy)]
pub struct Memcpy3DNodeParams {
/// Source array; `None` becomes a null `srcArray`.
pub src_array: Option<ArrayHandle>,
pub src_pos: Position,
pub src_ptr: PitchedPtr,
/// Destination array; `None` becomes a null `dstArray`.
pub dst_array: Option<ArrayHandle>,
pub dst_pos: Position,
pub dst_ptr: PitchedPtr,
pub extent: Extent,
pub kind: MemoryCopyKind,
}
/// Arguments forwarded to the "memcpy to symbol" graph-node calls.
#[derive(Debug, Clone, Copy)]
pub struct MemcpyToSymbolNodeParams {
/// Destination symbol, passed through as a raw pointer.
pub symbol: *const (),
/// Source pointer, passed through unchanged.
pub src: *const (),
pub count: usize,
pub offset: usize,
pub kind: MemoryCopyKind,
}
/// Arguments forwarded to the "memcpy from symbol" graph-node calls.
#[derive(Debug, Clone, Copy)]
pub struct MemcpyFromSymbolNodeParams {
/// Destination pointer, passed through unchanged.
pub dst: *mut (),
/// Source symbol, passed through as a raw pointer.
pub symbol: *const (),
pub count: usize,
pub offset: usize,
pub kind: MemoryCopyKind,
}
/// Inputs used by `Graph::add_mem_alloc_node` to fill `runtime::cudaMemAllocNodeParams`.
#[derive(Debug, Clone)]
pub struct MemAllocNodeParams<'a> {
/// Converted into the raw `poolProps` field.
pub pool_props: MemoryPoolProps,
/// Each entry is converted and passed through `accessDescs`/`accessDescCount`.
pub access_descs: &'a [MemAccessDescriptor],
/// Copied into the raw `bytesize` field.
pub byte_size: usize,
}
impl Default for GraphEdgeData {
    /// Both ports are zero and the dependency type is `GraphDependencyType::Default`.
    fn default() -> Self {
        let dependency_type = GraphDependencyType::Default;
        Self {
            dependency_type,
            from_port: 0,
            to_port: 0,
        }
    }
}
impl From<runtime::cudaGraphEdgeData> for GraphEdgeData {
    /// Copies the ports and decodes the raw dependency type.
    ///
    /// A raw type that does not decode falls back to `GraphDependencyType::Default`.
    fn from(value: runtime::cudaGraphEdgeData) -> Self {
        let dependency_type = match GraphDependencyType::try_from(value.type_) {
            Ok(decoded) => decoded,
            Err(_) => GraphDependencyType::Default,
        };
        Self {
            from_port: value.from_port,
            to_port: value.to_port,
            dependency_type,
        }
    }
}
impl From<GraphEdgeData> for runtime::cudaGraphEdgeData {
fn from(value: GraphEdgeData) -> Self {
Self {
from_port: value.from_port,
to_port: value.to_port,
type_: value.dependency_type.into(),
reserved: [0; 5],
}
}
}
impl From<Position> for runtime::cudaPos {
fn from(value: Position) -> Self {
Self {
x: value.x as _,
y: value.y as _,
z: value.z as _,
}
}
}
impl From<Extent> for runtime::cudaExtent {
fn from(value: Extent) -> Self {
Self {
width: value.width as _,
height: value.height as _,
depth: value.depth as _,
}
}
}
impl From<PitchedPtr> for runtime::cudaPitchedPtr {
fn from(value: PitchedPtr) -> Self {
Self {
ptr: value.ptr.cast(),
pitch: value.pitch as _,
xsize: value.x_size as _,
ysize: value.y_size as _,
}
}
}
impl From<&Memcpy3DNodeParams> for runtime::cudaMemcpy3DParms {
    /// Converts every field into its raw counterpart; a missing array becomes a null pointer.
    fn from(value: &Memcpy3DNodeParams) -> Self {
        let src_array = match value.src_array {
            Some(handle) => unsafe { handle.as_raw() },
            None => ptr::null_mut(),
        };
        let dst_array = match value.dst_array {
            Some(handle) => unsafe { handle.as_raw() },
            None => ptr::null_mut(),
        };
        Self {
            srcArray: src_array,
            srcPos: value.src_pos.into(),
            srcPtr: value.src_ptr.into(),
            dstArray: dst_array,
            dstPos: value.dst_pos.into(),
            dstPtr: value.dst_ptr.into(),
            extent: value.extent.into(),
            kind: value.kind.into(),
        }
    }
}
impl GraphNode {
/// Wraps a raw node handle without checking it.
///
/// # Safety
/// NOTE(review): no contract is documented; presumably `handle` must be a valid graph
/// node handle for as long as the wrapper is used — confirm.
const unsafe fn from_raw(handle: runtime::cudaGraphNode_t) -> Self {
    Self { handle }
}
/// Queries this node's type through `cudaGraphNodeGetType`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn node_type(self) -> Result<GraphNodeType> {
    // Initial value only; the call overwrites it through the out-pointer.
    let mut raw_type = runtime::cudaGraphNodeType::CU_GRAPH_NODE_TYPE_KERNEL;
    unsafe {
        try_cuda!(runtime::cudaGraphNodeGetType(self.handle, &raw mut raw_type))?;
    }
    let node_type: GraphNodeType = raw_type.into();
    Ok(node_type)
}
/// Lists the nodes this node depends on, each with the data of the connecting edge.
///
/// `cudaGraphNodeGetDependencies` is called twice: first with null buffers to read the
/// count, then with buffers of that capacity.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for either call.
pub fn dependencies(self) -> Result<Vec<GraphDependency>> {
    let mut total = 0;
    unsafe {
        try_cuda!(runtime::cudaGraphNodeGetDependencies(
            self.handle,
            ptr::null_mut(),
            ptr::null_mut(),
            &raw mut total,
        ))?;
    }
    if total == 0 {
        return Ok(Vec::new());
    }

    let mut raw_nodes = Vec::with_capacity(total as usize);
    let mut raw_edges = Vec::with_capacity(total as usize);
    unsafe {
        try_cuda!(runtime::cudaGraphNodeGetDependencies(
            self.handle,
            raw_nodes.as_mut_ptr(),
            raw_edges.as_mut_ptr(),
            &raw mut total,
        ))?;
        // `total` was passed mutably again; its value after the call is the length used.
        raw_nodes.set_len(total as usize);
        raw_edges.set_len(total as usize);
    }

    let mut found = Vec::with_capacity(raw_nodes.len());
    for (handle, data) in raw_nodes.into_iter().zip(raw_edges) {
        found.push(GraphDependency {
            node: Self { handle },
            data: data.into(),
        });
    }
    Ok(found)
}
/// Lists the nodes that depend on this node, each with the data of the connecting edge.
///
/// `cudaGraphNodeGetDependentNodes` is called twice: first with null buffers to read the
/// count, then with buffers of that capacity.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for either call.
pub fn dependent_nodes(self) -> Result<Vec<GraphDependency>> {
    let mut total = 0;
    unsafe {
        try_cuda!(runtime::cudaGraphNodeGetDependentNodes(
            self.handle,
            ptr::null_mut(),
            ptr::null_mut(),
            &raw mut total,
        ))?;
    }
    if total == 0 {
        return Ok(Vec::new());
    }

    let mut raw_nodes = Vec::with_capacity(total as usize);
    let mut raw_edges = Vec::with_capacity(total as usize);
    unsafe {
        try_cuda!(runtime::cudaGraphNodeGetDependentNodes(
            self.handle,
            raw_nodes.as_mut_ptr(),
            raw_edges.as_mut_ptr(),
            &raw mut total,
        ))?;
        // `total` was passed mutably again; its value after the call is the length used.
        raw_nodes.set_len(total as usize);
        raw_edges.set_len(total as usize);
    }

    let mut found = Vec::with_capacity(raw_nodes.len());
    for (handle, data) in raw_nodes.into_iter().zip(raw_edges) {
        found.push(GraphDependency {
            node: Self { handle },
            data: data.into(),
        });
    }
    Ok(found)
}
pub fn event_record_node_event(self) -> Result<runtime::cudaEvent_t> {
let mut event = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaGraphEventRecordNodeGetEvent(
self.handle,
&raw mut event,
))?;
}
if event.is_null() {
return Err(Error::NullHandle);
}
Ok(event)
}
pub fn event_wait_node_event(self) -> Result<runtime::cudaEvent_t> {
let mut event = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaGraphEventWaitNodeGetEvent(
self.handle,
&raw mut event,
))?;
}
if event.is_null() {
return Err(Error::NullHandle);
}
Ok(event)
}
pub fn child_graph(self) -> Result<Graph> {
let mut graph = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaGraphChildGraphNodeGetGraph(
self.handle,
&raw mut graph,
))?;
}
if graph.is_null() {
return Err(Error::NullHandle);
}
Ok(unsafe { Graph::from_raw_borrowed(graph) })
}
/// Reads this node's memcpy parameters through `cudaGraphMemcpyNodeGetParams`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn memcpy_node_params(self) -> Result<runtime::cudaMemcpy3DParms> {
    // Start from the default value; the call fills it in through the out-pointer.
    let mut out = runtime::cudaMemcpy3DParms::default();
    unsafe {
        try_cuda!(runtime::cudaGraphMemcpyNodeGetParams(
            self.handle,
            &raw mut out,
        ))?;
    }
    Ok(out)
}
/// Reads this node's memset parameters through `cudaGraphMemsetNodeGetParams`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn memset_node_params(self) -> Result<driver::CUDA_MEMSET_NODE_PARAMS> {
    // Start from the default value; the call fills it in through the out-pointer.
    let mut out = driver::CUDA_MEMSET_NODE_PARAMS::default();
    unsafe {
        try_cuda!(runtime::cudaGraphMemsetNodeGetParams(
            self.handle,
            &raw mut out,
        ))?;
    }
    Ok(out)
}
/// Reads this node's host-node parameters through `cudaGraphHostNodeGetParams`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn host_node_params(self) -> Result<driver::CUDA_HOST_NODE_PARAMS> {
    // Start from the default value; the call fills it in through the out-pointer.
    let mut out = driver::CUDA_HOST_NODE_PARAMS::default();
    unsafe {
        try_cuda!(runtime::cudaGraphHostNodeGetParams(
            self.handle,
            &raw mut out,
        ))?;
    }
    Ok(out)
}
/// Reads the device pointer and byte size of a memory-allocation node.
///
/// The values come from the `dptr` and `bytesize` fields filled in by
/// `cudaGraphMemAllocNodeGetParams`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn mem_alloc_node_info(self) -> Result<MemoryAllocationNodeInfo> {
    let mut out = runtime::cudaMemAllocNodeParams::default();
    unsafe {
        try_cuda!(runtime::cudaGraphMemAllocNodeGetParams(
            self.handle,
            &raw mut out,
        ))?;
    }
    let ptr = DevicePtr::new(out.dptr as _);
    let byte_size = out.bytesize as usize;
    Ok(MemoryAllocationNodeInfo { ptr, byte_size })
}
/// Reads the device pointer stored in a memory-free node.
///
/// `cudaGraphMemFreeNodeGetParams` reports the pointer through an out-parameter, so the
/// address of a local slot is handed to the call and the slot's contents are returned.
/// Passing a null out-pointer (as an earlier version did) gives the call nowhere to
/// write and always yields a null result.
///
/// # Safety
/// NOTE(review): the function is declared `unsafe` without a documented contract;
/// presumably the caller must ensure this is a memory-free node of a live graph — confirm.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub unsafe fn mem_free_node_ptr(self) -> Result<DevicePtr> {
    // Slot that receives the device address written by the call.
    let mut dptr: *mut () = ptr::null_mut();
    unsafe {
        try_cuda!(runtime::cudaGraphMemFreeNodeGetParams(
            self.handle,
            (&raw mut dptr).cast(),
        ))?;
    }
    Ok(DevicePtr::new(dptr as _))
}
/// Reads one kernel-node attribute through `cudaGraphKernelNodeGetAttribute`.
///
/// The union member that is read back is selected by `id`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn kernel_node_attribute(
    self,
    id: GraphKernelNodeAttributeId,
) -> Result<GraphKernelNodeAttribute> {
    let mut out = runtime::cudaLaunchAttributeValue::default();
    unsafe {
        try_cuda!(runtime::cudaGraphKernelNodeGetAttribute(
            self.handle,
            id.into(),
            &raw mut out,
        ))?;
    }
    // Reading a union member is unsafe; only the member matching `id` is accessed.
    let attribute = unsafe {
        match id {
            GraphKernelNodeAttributeId::Cooperative => {
                let flag = *out.cooperative.as_ref();
                GraphKernelNodeAttribute::Cooperative(flag != 0)
            }
            GraphKernelNodeAttributeId::ClusterDimension => {
                let cluster = out.clusterDim.as_ref();
                GraphKernelNodeAttribute::ClusterDimension(Dim3::new(
                    cluster.x, cluster.y, cluster.z,
                ))
            }
            GraphKernelNodeAttributeId::Priority => {
                GraphKernelNodeAttribute::Priority(*out.priority.as_ref())
            }
            GraphKernelNodeAttributeId::PreferredSharedMemoryCarveout => {
                let carveout = *out.sharedMemCarveout.as_ref();
                GraphKernelNodeAttribute::PreferredSharedMemoryCarveout(carveout)
            }
        }
    };
    Ok(attribute)
}
/// Writes one kernel-node attribute through `cudaGraphKernelNodeSetAttribute`.
///
/// A default `cudaLaunchAttributeValue` is created, the union member matching the
/// variant of `attribute` is filled in, and the result is passed to the call together
/// with the corresponding attribute id.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn set_kernel_node_attribute(&mut self, attribute: GraphKernelNodeAttribute) -> Result<()> {
    let (attribute_id, payload) = match attribute {
        GraphKernelNodeAttribute::Cooperative(enabled) => {
            let mut slot = runtime::cudaLaunchAttributeValue {
                cooperative: runtime::__BindgenUnionField::new(),
                ..runtime::cudaLaunchAttributeValue::default()
            };
            unsafe { *slot.cooperative.as_mut() = i32::from(enabled) };
            (GraphKernelNodeAttributeId::Cooperative, slot)
        }
        GraphKernelNodeAttribute::ClusterDimension(dims) => {
            let mut slot = runtime::cudaLaunchAttributeValue {
                clusterDim: runtime::__BindgenUnionField::new(),
                ..runtime::cudaLaunchAttributeValue::default()
            };
            unsafe {
                *slot.clusterDim.as_mut() = runtime::cudaLaunchAttributeValue__bindgen_ty_1 {
                    x: dims.x,
                    y: dims.y,
                    z: dims.z,
                };
            }
            (GraphKernelNodeAttributeId::ClusterDimension, slot)
        }
        GraphKernelNodeAttribute::Priority(level) => {
            let mut slot = runtime::cudaLaunchAttributeValue {
                priority: runtime::__BindgenUnionField::new(),
                ..runtime::cudaLaunchAttributeValue::default()
            };
            unsafe { *slot.priority.as_mut() = level as _ };
            (GraphKernelNodeAttributeId::Priority, slot)
        }
        GraphKernelNodeAttribute::PreferredSharedMemoryCarveout(carveout) => {
            let mut slot = runtime::cudaLaunchAttributeValue {
                sharedMemCarveout: runtime::__BindgenUnionField::new(),
                ..runtime::cudaLaunchAttributeValue::default()
            };
            unsafe { *slot.sharedMemCarveout.as_mut() = carveout };
            (
                GraphKernelNodeAttributeId::PreferredSharedMemoryCarveout,
                slot,
            )
        }
    };
    unsafe {
        try_cuda!(runtime::cudaGraphKernelNodeSetAttribute(
            self.handle,
            attribute_id.into(),
            &raw const payload,
        ))?;
    }
    Ok(())
}
/// Calls `cudaGraphKernelNodeCopyAttributes` with this node's handle as the first
/// argument and `other`'s handle as the second.
///
/// NOTE(review): which of the two is source and which is destination is decided by the
/// CUDA call's parameter order — confirm against the CUDA documentation.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn copy_kernel_node_attributes(self, other: Self) -> Result<()> {
    let (first, second) = (self.handle, other.handle);
    unsafe {
        try_cuda!(runtime::cudaGraphKernelNodeCopyAttributes(first, second))?;
    }
    Ok(())
}
pub const unsafe fn as_raw(self) -> runtime::cudaGraphNode_t {
self.handle
}
}
impl MemoryAllocationNodeInfo {
    /// Returns the device pointer recorded for the allocation node.
    pub const fn ptr(&self) -> DevicePtr {
        let Self { ptr, .. } = *self;
        ptr
    }
}
/// Wrapper around a raw `runtime::cudaGraph_t` handle.
///
/// `Drop` destroys the underlying graph only when `owns_handle` is set.
#[derive(Debug)]
pub struct Graph {
handle: runtime::cudaGraph_t,
// `false` for graphs created through `from_raw_borrowed`, which must not be destroyed here.
owns_handle: bool,
}
impl Graph {
/// Wraps a raw graph handle and marks it as owned, so `Drop` will destroy it.
///
/// # Safety
/// NOTE(review): no contract is documented; presumably `handle` must be a valid graph
/// that nothing else destroys — confirm.
pub(crate) const unsafe fn from_raw(handle: runtime::cudaGraph_t) -> Self {
    let owns_handle = true;
    Self {
        handle,
        owns_handle,
    }
}
/// Wraps a raw graph handle without taking ownership; `Drop` leaves it untouched.
///
/// # Safety
/// NOTE(review): no contract is documented; presumably `handle` must stay valid for as
/// long as the returned wrapper is used — confirm.
pub(crate) const unsafe fn from_raw_borrowed(handle: runtime::cudaGraph_t) -> Self {
    let owns_handle = false;
    Self {
        handle,
        owns_handle,
    }
}
pub fn create() -> Result<Self> {
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaGraphCreate(&raw mut handle, 0))?;
}
Ok(Self {
handle,
owns_handle: true,
})
}
/// Instantiates the graph with no flags set.
///
/// # Errors
/// Returns whatever `instantiate_with_flags` returns.
pub fn instantiate(&self) -> Result<ExecutableGraph> {
    let no_flags = GraphInstantiateFlags::empty();
    self.instantiate_with_flags(no_flags)
}
/// Instantiates the graph through `cudaGraphInstantiateWithFlags`, passing `flags.bits()`.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn instantiate_with_flags(&self, flags: GraphInstantiateFlags) -> Result<ExecutableGraph> {
    let raw_flags = flags.bits();
    let mut exec = ptr::null_mut();
    unsafe {
        try_cuda!(runtime::cudaGraphInstantiateWithFlags(
            &raw mut exec,
            self.handle,
            raw_flags,
        ))?;
    }
    Ok(ExecutableGraph { handle: exec })
}
pub fn try_clone(&self) -> Result<Self> {
let mut handle = ptr::null_mut();
unsafe {
try_cuda!(runtime::cudaGraphClone(&raw mut handle, self.handle))?;
}
Ok(Self {
handle,
owns_handle: true,
})
}
/// Adds a single `from` → `to` dependency by delegating to `add_dependencies`.
///
/// # Errors
/// Returns whatever `add_dependencies` returns.
pub fn add_dependency(&mut self, from: GraphNode, to: GraphNode) -> Result<()> {
    let (sources, targets) = ([from], [to]);
    self.add_dependencies(&sources, &targets)
}
pub fn add_dependencies(&mut self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
self.add_dependencies_with_data(from, to, &[])
}
/// Adds pairwise dependencies through `cudaGraphAddDependencies`.
///
/// `edge_data` may be empty, in which case a null edge-data pointer is passed; otherwise
/// it must have the same length as `from`. An empty `from` returns `Ok(())` without
/// calling CUDA.
///
/// # Errors
/// Returns `Error::GraphDependencyMismatch` when the slice lengths disagree, and
/// propagates the error yielded by `try_cuda!` for the call.
pub fn add_dependencies_with_data(
    &mut self,
    from: &[GraphNode],
    to: &[GraphNode],
    edge_data: &[GraphEdgeData],
) -> Result<()> {
    let pair_count = from.len();
    let edge_data_mismatch = !edge_data.is_empty() && edge_data.len() != pair_count;
    if pair_count != to.len() || edge_data_mismatch {
        return Err(Error::GraphDependencyMismatch);
    }
    if pair_count == 0 {
        return Ok(());
    }

    let sources: Vec<_> = from.iter().map(|node| node.handle).collect();
    let targets: Vec<_> = to.iter().map(|node| node.handle).collect();
    let raw_edges: Vec<_> = edge_data.iter().copied().map(Into::into).collect();
    unsafe {
        try_cuda!(runtime::cudaGraphAddDependencies(
            self.handle,
            sources.as_ptr(),
            targets.as_ptr(),
            if raw_edges.is_empty() {
                ptr::null()
            } else {
                raw_edges.as_ptr()
            },
            sources.len() as runtime::size_t,
        ))?;
    }
    Ok(())
}
/// Removes a single `from` → `to` dependency by delegating to `remove_dependencies`.
///
/// # Errors
/// Returns whatever `remove_dependencies` returns.
pub fn remove_dependency(&mut self, from: GraphNode, to: GraphNode) -> Result<()> {
    let (sources, targets) = ([from], [to]);
    self.remove_dependencies(&sources, &targets)
}
pub fn remove_dependencies(&mut self, from: &[GraphNode], to: &[GraphNode]) -> Result<()> {
self.remove_dependencies_with_data(from, to, &[])
}
/// Removes pairwise dependencies through `cudaGraphRemoveDependencies`.
///
/// `edge_data` may be empty, in which case a null edge-data pointer is passed; otherwise
/// it must have the same length as `from`. An empty `from` returns `Ok(())` without
/// calling CUDA.
///
/// # Errors
/// Returns `Error::GraphDependencyMismatch` when the slice lengths disagree, and
/// propagates the error yielded by `try_cuda!` for the call.
pub fn remove_dependencies_with_data(
    &mut self,
    from: &[GraphNode],
    to: &[GraphNode],
    edge_data: &[GraphEdgeData],
) -> Result<()> {
    let pair_count = from.len();
    let edge_data_mismatch = !edge_data.is_empty() && edge_data.len() != pair_count;
    if pair_count != to.len() || edge_data_mismatch {
        return Err(Error::GraphDependencyMismatch);
    }
    if pair_count == 0 {
        return Ok(());
    }

    let sources: Vec<_> = from.iter().map(|node| node.handle).collect();
    let targets: Vec<_> = to.iter().map(|node| node.handle).collect();
    let raw_edges: Vec<_> = edge_data.iter().copied().map(Into::into).collect();
    unsafe {
        try_cuda!(runtime::cudaGraphRemoveDependencies(
            self.handle,
            sources.as_ptr(),
            targets.as_ptr(),
            if raw_edges.is_empty() {
                ptr::null()
            } else {
                raw_edges.as_ptr()
            },
            sources.len() as runtime::size_t,
        ))?;
    }
    Ok(())
}
/// Adds every edge in `edges`, including its edge data.
///
/// The edges are split into parallel source, target and data lists and handed to
/// `add_dependencies_with_data`. An empty slice returns `Ok(())` immediately.
///
/// # Errors
/// Returns whatever `add_dependencies_with_data` returns.
pub fn add_edges(&mut self, edges: &[GraphEdge]) -> Result<()> {
    if edges.is_empty() {
        return Ok(());
    }
    let (sources, targets): (Vec<_>, Vec<_>) =
        edges.iter().map(|edge| (edge.from, edge.to)).unzip();
    let payloads: Vec<_> = edges.iter().map(|edge| edge.data).collect();
    self.add_dependencies_with_data(&sources, &targets, &payloads)
}
/// Removes every edge in `edges`, matching on its edge data as well.
///
/// The edges are split into parallel source, target and data lists and handed to
/// `remove_dependencies_with_data`. An empty slice returns `Ok(())` immediately.
///
/// # Errors
/// Returns whatever `remove_dependencies_with_data` returns.
pub fn remove_edges(&mut self, edges: &[GraphEdge]) -> Result<()> {
    if edges.is_empty() {
        return Ok(());
    }
    let (sources, targets): (Vec<_>, Vec<_>) =
        edges.iter().map(|edge| (edge.from, edge.to)).unzip();
    let payloads: Vec<_> = edges.iter().map(|edge| edge.data).collect();
    self.remove_dependencies_with_data(&sources, &targets, &payloads)
}
/// Adds an empty node through `cudaGraphAddEmptyNode`.
///
/// The handles of `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_empty_node(&mut self, dependencies: &[GraphNode]) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddEmptyNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds an event-record node for `event` through `cudaGraphAddEventRecordNode`.
///
/// The handles of `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_event_record_node(
    &mut self,
    dependencies: &[GraphNode],
    event: &Event,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddEventRecordNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            event.as_raw(),
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds an event-wait node for `event` through `cudaGraphAddEventWaitNode`.
///
/// The handles of `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_event_wait_node(
    &mut self,
    dependencies: &[GraphNode],
    event: &Event,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddEventWaitNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            event.as_raw(),
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a host node through `cudaGraphAddHostNode`.
///
/// `params` is converted to the raw parameter struct before the call; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_host_node(
    &mut self,
    dependencies: &[GraphNode],
    params: &HostNodeParams,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let raw_params = params.into();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddHostNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            &raw const raw_params,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a kernel node through `cudaGraphAddKernelNode`.
///
/// `params` is converted to the raw parameter struct before the call; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_kernel_node(
    &mut self,
    dependencies: &[GraphNode],
    params: &KernelNodeParams<'_>,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let raw_params = params.into();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddKernelNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            &raw const raw_params,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a one-dimensional memcpy node through `cudaGraphAddMemcpyNode1D`.
///
/// `dst`, `src`, `count` and `kind` are taken from `params`; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_memcpy_node_1d(
    &mut self,
    dependencies: &[GraphNode],
    params: &Memcpy1DNodeParams,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemcpyNode1D(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            params.dst.cast(),
            params.src.cast(),
            params.count as _,
            params.kind.into(),
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a memcpy node through `cudaGraphAddMemcpyNode`.
///
/// `params` is converted to the raw parameter struct before the call; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_memcpy_node(
    &mut self,
    dependencies: &[GraphNode],
    params: &Memcpy3DNodeParams,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let raw_params = params.into();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemcpyNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            &raw const raw_params,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a memcpy-to-symbol node through `cudaGraphAddMemcpyNodeToSymbol`.
///
/// `symbol`, `src`, `count`, `offset` and `kind` are taken from `params`; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_memcpy_node_to_symbol(
    &mut self,
    dependencies: &[GraphNode],
    params: &MemcpyToSymbolNodeParams,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemcpyNodeToSymbol(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            params.symbol.cast(),
            params.src.cast(),
            params.count as _,
            params.offset as _,
            params.kind.into(),
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a memcpy-from-symbol node through `cudaGraphAddMemcpyNodeFromSymbol`.
///
/// `dst`, `symbol`, `count`, `offset` and `kind` are taken from `params`; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_memcpy_node_from_symbol(
    &mut self,
    dependencies: &[GraphNode],
    params: &MemcpyFromSymbolNodeParams,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemcpyNodeFromSymbol(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            params.dst.cast(),
            params.symbol.cast(),
            params.count as _,
            params.offset as _,
            params.kind.into(),
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a memset node through `cudaGraphAddMemsetNode`.
///
/// `params` is converted to the raw parameter struct before the call; the handles of
/// `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_memset_node(
    &mut self,
    dependencies: &[GraphNode],
    params: &MemsetNodeParams,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let raw_params = params.into();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemsetNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            &raw const raw_params,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a child-graph node for `child_graph` through `cudaGraphAddChildGraphNode`.
///
/// The handles of `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_child_graph_node(
    &mut self,
    dependencies: &[GraphNode],
    child_graph: &Self,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddChildGraphNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            child_graph.handle,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a memory-free node for the device pointer `ptr` through `cudaGraphAddMemFreeNode`.
///
/// The handles of `dependencies` are passed as the new node's dependency list.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_mem_free_node(
    &mut self,
    dependencies: &[GraphNode],
    ptr: DevicePtr,
) -> Result<GraphNode> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    // `ptr::` below resolves to the module; the `ptr` parameter is only used as a value.
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemFreeNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            ptr.as_ptr() as _,
        ))?;
        GraphNode::from_raw(new_node)
    };
    Ok(node)
}
/// Adds a memory-allocation node through `cudaGraphAddMemAllocNode`.
///
/// The raw parameter struct is built from `params` with `dptr` set to `0`; after the
/// call, the `dptr` field is returned as the allocation's device pointer alongside the
/// new node.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn add_mem_alloc_node(
    &mut self,
    dependencies: &[GraphNode],
    params: &MemAllocNodeParams<'_>,
) -> Result<(GraphNode, DevicePtr)> {
    let deps: Vec<_> = dependencies.iter().map(|dep| dep.handle).collect();
    // Converted descriptors must outlive the call because only their pointer is passed.
    let raw_descs: Vec<_> = params
        .access_descs
        .iter()
        .copied()
        .map(Into::into)
        .collect();
    let mut raw_params = runtime::cudaMemAllocNodeParams {
        poolProps: params.pool_props.into(),
        accessDescs: raw_descs.as_ptr(),
        accessDescCount: raw_descs.len() as runtime::size_t,
        bytesize: params.byte_size as _,
        dptr: 0,
    };
    let mut new_node = ptr::null_mut();
    let node = unsafe {
        try_cuda!(runtime::cudaGraphAddMemAllocNode(
            &raw mut new_node,
            self.handle,
            deps.as_ptr(),
            deps.len() as runtime::size_t,
            &raw mut raw_params,
        ))?;
        GraphNode::from_raw(new_node)
    };
    let allocation = DevicePtr::new(raw_params.dptr as *mut ());
    Ok((node, allocation))
}
/// Lists the graph's nodes.
///
/// `cudaGraphGetNodes` is called twice: first with a null buffer to read the count, then
/// with a buffer of that capacity.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for either call.
pub fn nodes(&self) -> Result<Vec<GraphNode>> {
    let mut total = 0;
    unsafe {
        try_cuda!(runtime::cudaGraphGetNodes(
            self.handle,
            ptr::null_mut(),
            &raw mut total,
        ))?;
    }
    if total == 0 {
        return Ok(Vec::new());
    }

    let mut raw_nodes = Vec::with_capacity(total as usize);
    unsafe {
        try_cuda!(runtime::cudaGraphGetNodes(
            self.handle,
            raw_nodes.as_mut_ptr(),
            &raw mut total,
        ))?;
        // `total` was passed mutably again; its value after the call is the length used.
        raw_nodes.set_len(total as usize);
    }

    let mut wrapped = Vec::with_capacity(raw_nodes.len());
    for handle in raw_nodes {
        wrapped.push(GraphNode { handle });
    }
    Ok(wrapped)
}
/// Lists the graph's root nodes.
///
/// `cudaGraphGetRootNodes` is called twice: first with a null buffer to read the count,
/// then with a buffer of that capacity.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for either call.
pub fn root_nodes(&self) -> Result<Vec<GraphNode>> {
    let mut total = 0;
    unsafe {
        try_cuda!(runtime::cudaGraphGetRootNodes(
            self.handle,
            ptr::null_mut(),
            &raw mut total,
        ))?;
    }
    if total == 0 {
        return Ok(Vec::new());
    }

    let mut raw_nodes = Vec::with_capacity(total as usize);
    unsafe {
        try_cuda!(runtime::cudaGraphGetRootNodes(
            self.handle,
            raw_nodes.as_mut_ptr(),
            &raw mut total,
        ))?;
        // `total` was passed mutably again; its value after the call is the length used.
        raw_nodes.set_len(total as usize);
    }

    let mut wrapped = Vec::with_capacity(raw_nodes.len());
    for handle in raw_nodes {
        wrapped.push(GraphNode { handle });
    }
    Ok(wrapped)
}
/// Lists the graph's edges, each with both endpoints and its edge data.
///
/// `cudaGraphGetEdges` is called twice: first with null buffers to read the count, then
/// with three buffers of that capacity.
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for either call.
pub fn edges(&self) -> Result<Vec<GraphEdge>> {
    let mut total = 0;
    unsafe {
        try_cuda!(runtime::cudaGraphGetEdges(
            self.handle,
            ptr::null_mut(),
            ptr::null_mut(),
            ptr::null_mut(),
            &raw mut total,
        ))?;
    }
    if total == 0 {
        return Ok(Vec::new());
    }

    let capacity = total as usize;
    let mut sources = Vec::with_capacity(capacity);
    let mut targets = Vec::with_capacity(capacity);
    let mut payloads = Vec::with_capacity(capacity);
    unsafe {
        try_cuda!(runtime::cudaGraphGetEdges(
            self.handle,
            sources.as_mut_ptr(),
            targets.as_mut_ptr(),
            payloads.as_mut_ptr(),
            &raw mut total,
        ))?;
        // `total` was passed mutably again; its value after the call is the length used.
        let filled = total as usize;
        sources.set_len(filled);
        targets.set_len(filled);
        payloads.set_len(filled);
    }

    let mut collected = Vec::with_capacity(sources.len());
    for ((from, to), data) in sources.into_iter().zip(targets).zip(payloads) {
        collected.push(GraphEdge {
            from: GraphNode { handle: from },
            to: GraphNode { handle: to },
            data: data.into(),
        });
    }
    Ok(collected)
}
/// Calls `cudaGraphDebugDotPrint` with `path` and `flags.bits()`.
///
/// # Errors
/// Fails when `path` cannot be turned into a `CString` (the `CString::new` error is
/// propagated with `?`), and propagates the error yielded by `try_cuda!` for the call.
pub fn write_dot(&self, path: &str, flags: GraphDebugDotFlags) -> Result<()> {
    let c_path = CString::new(path)?;
    let raw_flags = flags.bits();
    unsafe {
        try_cuda!(runtime::cudaGraphDebugDotPrint(
            self.handle,
            c_path.as_ptr(),
            raw_flags,
        ))?;
    }
    Ok(())
}
/// Returns the wrapped raw graph handle; ownership stays with `self`.
///
/// # Safety
/// NOTE(review): no contract is documented; presumably the caller must not destroy the
/// handle or use it after this `Graph` is dropped — confirm.
pub const unsafe fn as_raw(&self) -> runtime::cudaGraph_t {
    let raw_graph = self.handle;
    raw_graph
}
}
impl Drop for Graph {
    /// Destroys the graph through `cudaGraphDestroy` when this wrapper owns the handle.
    ///
    /// A failure is only reported (to stderr) in builds with `debug_assertions`.
    fn drop(&mut self) {
        if self.owns_handle {
            let outcome = unsafe { try_cuda!(runtime::cudaGraphDestroy(self.handle)) };
            if let Err(err) = outcome {
                #[cfg(debug_assertions)]
                eprintln!("failed to destroy cuda graph: {err}");
            }
        }
    }
}
/// Wrapper around a raw `runtime::cudaGraphExec_t` handle, as produced by
/// `Graph::instantiate_with_flags`.
#[derive(Debug)]
pub struct ExecutableGraph {
handle: runtime::cudaGraphExec_t,
}
impl ExecutableGraph {
/// Reads the flags of this executable graph through `cudaGraphExecGetFlags`.
///
/// Bits that match no named constant are kept (`from_bits_retain`).
///
/// # Errors
/// Propagates the error yielded by `try_cuda!` for the call.
pub fn flags(&self) -> Result<GraphInstantiateFlags> {
    let mut raw_flags = 0;
    unsafe {
        try_cuda!(runtime::cudaGraphExecGetFlags(self.handle, &raw mut raw_flags))?;
    }
    let flags = GraphInstantiateFlags::from_bits_retain(raw_flags);
    Ok(flags)
}
pub fn launch(&self, stream: &Stream) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphLaunch(self.handle, stream.as_raw()))?;
}
Ok(())
}
pub fn upload(&self, stream: &Stream) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphUpload(self.handle, stream.as_raw()))?;
}
Ok(())
}
pub fn update(&mut self, graph: &Graph) -> Result<ExecutableGraphUpdate> {
let mut result_info = runtime::cudaGraphExecUpdateResultInfo::default();
unsafe {
try_cuda!(runtime::cudaGraphExecUpdate(
self.handle,
graph.handle,
&raw mut result_info,
))?;
}
Ok(result_info.into())
}
pub fn set_kernel_node_params(
&mut self,
node: GraphNode,
params: &KernelNodeParams<'_>,
) -> Result<()> {
let params = params.into();
unsafe {
try_cuda!(runtime::cudaGraphExecKernelNodeSetParams(
self.handle,
node.handle,
&raw const params,
))?;
}
Ok(())
}
pub fn set_memcpy_node_1d_params(
&mut self,
node: GraphNode,
params: &Memcpy1DNodeParams,
) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphExecMemcpyNodeSetParams1D(
self.handle,
node.handle,
params.dst.cast(),
params.src.cast(),
params.count as _,
params.kind.into(),
))?;
}
Ok(())
}
pub fn set_memcpy_node_params(
&mut self,
node: GraphNode,
params: &Memcpy3DNodeParams,
) -> Result<()> {
let params = params.into();
unsafe {
try_cuda!(runtime::cudaGraphExecMemcpyNodeSetParams(
self.handle,
node.handle,
&raw const params,
))?;
}
Ok(())
}
pub fn set_memcpy_node_to_symbol_params(
&mut self,
node: GraphNode,
params: &MemcpyToSymbolNodeParams,
) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphExecMemcpyNodeSetParamsToSymbol(
self.handle,
node.handle,
params.symbol.cast(),
params.src.cast(),
params.count as _,
params.offset as _,
params.kind.into(),
))?;
}
Ok(())
}
pub fn set_memcpy_node_from_symbol_params(
&mut self,
node: GraphNode,
params: &MemcpyFromSymbolNodeParams,
) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphExecMemcpyNodeSetParamsFromSymbol(
self.handle,
node.handle,
params.dst.cast(),
params.symbol.cast(),
params.count as _,
params.offset as _,
params.kind.into(),
))?;
}
Ok(())
}
pub fn set_memset_node_params(
&mut self,
node: GraphNode,
params: &MemsetNodeParams,
) -> Result<()> {
let params = params.into();
unsafe {
try_cuda!(runtime::cudaGraphExecMemsetNodeSetParams(
self.handle,
node.handle,
&raw const params,
))?;
}
Ok(())
}
pub fn set_host_node_params(&mut self, node: GraphNode, params: &HostNodeParams) -> Result<()> {
let params = params.into();
unsafe {
try_cuda!(runtime::cudaGraphExecHostNodeSetParams(
self.handle,
node.handle,
&raw const params,
))?;
}
Ok(())
}
pub fn set_event_record_node_event(&mut self, node: GraphNode, event: &Event) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphExecEventRecordNodeSetEvent(
self.handle,
node.handle,
event.as_raw(),
))?;
}
Ok(())
}
pub fn set_child_graph_node(&mut self, node: GraphNode, child_graph: &Graph) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphExecChildGraphNodeSetParams(
self.handle,
node.handle,
child_graph.handle,
))?;
}
Ok(())
}
pub fn set_event_wait_node_event(&mut self, node: GraphNode, event: &Event) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphExecEventWaitNodeSetEvent(
self.handle,
node.handle,
event.as_raw(),
))?;
}
Ok(())
}
pub fn set_node_enabled(&mut self, node: GraphNode, enabled: bool) -> Result<()> {
unsafe {
try_cuda!(runtime::cudaGraphNodeSetEnabled(
self.handle,
node.handle,
u32::from(enabled),
))?;
}
Ok(())
}
pub fn is_node_enabled(&self, node: GraphNode) -> Result<bool> {
let mut enabled = 0;
unsafe {
try_cuda!(runtime::cudaGraphNodeGetEnabled(
self.handle,
node.handle,
&raw mut enabled,
))?;
}
Ok(enabled != 0)
}
pub const unsafe fn as_raw(&self) -> runtime::cudaGraphExec_t {
self.handle
}
}
impl Drop for ExecutableGraph {
    /// Destroys the wrapped executable graph with `cudaGraphExecDestroy`.
    ///
    /// A failure cannot be propagated out of `drop`, so it is printed to
    /// stderr in builds with debug assertions and ignored in other builds.
    fn drop(&mut self) {
        // SAFETY: assumes `self.handle` is a valid executable-graph handle.
        let outcome = unsafe { try_cuda!(runtime::cudaGraphExecDestroy(self.handle)) };
        // `cfg!` keeps the error binding used in every build profile, so
        // release builds no longer emit an unused-variable warning.
        if cfg!(debug_assertions) {
            if let Err(err) = outcome {
                eprintln!("failed to destroy cuda graph exec: {err}");
            }
        }
    }
}
/// Outcome returned by [`ExecutableGraph::update`], converted from the raw
/// `cudaGraphExecUpdateResultInfo`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExecutableGraphUpdate {
/// Update result converted from the raw `result` field.
pub result: GraphExecUpdateResult,
/// Node taken from the raw `errorNode` field; `None` when that handle is null.
pub error_node: Option<GraphNode>,
/// Node taken from the raw `errorFromNode` field; `None` when that handle is null.
pub error_from_node: Option<GraphNode>,
}
impl From<runtime::cudaGraphExecUpdateResultInfo> for ExecutableGraphUpdate {
    /// Converts the raw update info, mapping null node handles to `None` and
    /// wrapping non-null ones in [`GraphNode`].
    fn from(value: runtime::cudaGraphExecUpdateResultInfo) -> Self {
        let error_node = (!value.errorNode.is_null()).then(|| GraphNode {
            handle: value.errorNode,
        });
        let error_from_node = (!value.errorFromNode.is_null()).then(|| GraphNode {
            handle: value.errorFromNode,
        });
        Self {
            result: value.result.into(),
            error_node,
            error_from_node,
        }
    }
}
/// Parameters describing a memset graph node.
///
/// Each field is forwarded to the identically purposed field of
/// `driver::CUDA_MEMSET_NODE_PARAMS` by the `From` conversion in this file.
#[derive(Debug, Clone, Copy)]
pub struct MemsetNodeParams {
/// Destination device pointer.
pub dst: DevicePtr,
/// Forwarded as `pitch`; `new` sets it to 0.
pub pitch: usize,
/// Forwarded as `value` (presumably the fill pattern — confirm against the CUDA docs).
pub value: u32,
/// Forwarded as `elementSize` (presumably in bytes — confirm against the CUDA docs).
pub element_size: u32,
/// Forwarded as `width`.
pub width: usize,
/// Forwarded as `height`; `new` sets it to 1.
pub height: usize,
}
impl MemsetNodeParams {
    /// Creates parameters for `dst` with the given `element_size` and `width`.
    ///
    /// The remaining fields start at fixed values: `pitch` and `value` are 0
    /// and `height` is 1.
    pub const fn new(dst: DevicePtr, element_size: u32, width: usize) -> Self {
        let pitch = 0;
        let value = 0;
        let height = 1;
        Self {
            dst,
            pitch,
            value,
            element_size,
            width,
            height,
        }
    }
}
impl From<&MemsetNodeParams> for driver::CUDA_MEMSET_NODE_PARAMS {
fn from(value: &MemsetNodeParams) -> Self {
Self {
dst: value.dst.as_ptr() as _,
pitch: value.pitch as _,
value: value.value,
elementSize: value.element_size,
width: value.width as _,
height: value.height as _,
}
}
}
/// Parameters for a host-function graph node: the callback and the pointer
/// passed along with it.
#[derive(Debug, Clone, Copy)]
pub struct HostNodeParams {
/// Host callback; its raw form is forwarded as `fn_`.
pub func: HostFunction,
/// Opaque pointer forwarded as `userData`.
/// NOTE(review): nothing here ties the pointee's lifetime to the node —
/// presumably callers must keep it alive; confirm.
pub user_data: *mut (),
}
impl HostNodeParams {
pub const fn new(func: HostFunction, user_data: *mut ()) -> Self {
Self { func, user_data }
}
}
impl From<&HostNodeParams> for driver::CUDA_HOST_NODE_PARAMS {
    /// Builds the raw host-node parameters from the callback's raw form and
    /// the user-data pointer cast to the raw field type.
    fn from(value: &HostNodeParams) -> Self {
        let callback = value.func.as_raw();
        let user_data = value.user_data.cast();
        Self {
            fn_: callback,
            userData: user_data,
        }
    }
}
/// Launch parameters for a kernel graph node.
///
/// Converted into `runtime::cudaKernelNodeParams` by the `From` conversion in
/// this file. The `'a` lifetime is carried only by a `PhantomData` marker.
/// NOTE(review): presumably `'a` is meant to bound the data behind the raw
/// argument pointers — confirm, since `new` leaves it unconstrained.
#[derive(Debug, Clone)]
pub struct KernelNodeParams<'a> {
/// Device function; its raw form is forwarded as `func`.
pub func: DeviceFunction,
/// Forwarded as `gridDim`.
pub grid_dim: Dim3,
/// Forwarded as `blockDim`.
pub block_dim: Dim3,
/// Forwarded as `sharedMemBytes`; `new` sets it to 0.
pub shared_mem_bytes: usize,
/// Forwarded as `kernelParams`; `new` sets it to null.
pub kernel_params: *mut *mut (),
/// Forwarded as `extra`; `new` sets it to null.
pub extra: *mut *mut (),
/// Zero-sized marker that makes the struct use the `'a` lifetime.
_phantom: PhantomData<&'a ()>,
}
impl KernelNodeParams<'_> {
    /// Creates parameters for `func` with the given grid and block dimensions.
    ///
    /// Shared memory starts at 0 bytes and both argument pointers start null.
    pub const fn new(func: DeviceFunction, grid_dim: Dim3, block_dim: Dim3) -> Self {
        let no_args: *mut *mut () = ptr::null_mut();
        Self {
            func,
            grid_dim,
            block_dim,
            shared_mem_bytes: 0,
            kernel_params: no_args,
            extra: no_args,
            _phantom: PhantomData,
        }
    }

    /// Returns `self` with `shared_mem_bytes` replaced.
    pub const fn with_shared_mem_bytes(self, shared_mem_bytes: usize) -> Self {
        let mut updated = self;
        updated.shared_mem_bytes = shared_mem_bytes;
        updated
    }

    /// Returns `self` with the `kernel_params` pointer replaced.
    pub const fn with_kernel_params(self, kernel_params: *mut *mut ()) -> Self {
        let mut updated = self;
        updated.kernel_params = kernel_params;
        updated
    }

    /// Returns `self` with the `extra` pointer replaced.
    pub const fn with_extra(self, extra: *mut *mut ()) -> Self {
        let mut updated = self;
        updated.extra = extra;
        updated
    }
}
impl From<&KernelNodeParams<'_>> for runtime::cudaKernelNodeParams {
    /// Builds the raw kernel-node parameters from `value`, converting the
    /// dimensions and casting the pointers and size to the raw field types.
    fn from(value: &KernelNodeParams<'_>) -> Self {
        let func = value.func.as_raw().cast();
        let grid_dim = value.grid_dim.into();
        let block_dim = value.block_dim.into();
        Self {
            func,
            gridDim: grid_dim,
            blockDim: block_dim,
            sharedMemBytes: value.shared_mem_bytes as _,
            kernelParams: value.kernel_params.cast(),
            extra: value.extra.cast(),
        }
    }
}
/// Parameters for a one-dimensional memcpy graph node.
#[derive(Debug, Clone, Copy)]
pub struct Memcpy1DNodeParams {
/// Destination pointer of the copy.
pub dst: *mut (),
/// Source pointer of the copy.
pub src: *const (),
/// Amount to copy (presumably in bytes — confirm against the CUDA docs).
pub count: usize,
/// Kind of the copy, converted to the raw copy-kind value when the parameters are applied.
pub kind: MemoryCopyKind,
}
impl Memcpy1DNodeParams {
pub const fn new(dst: *mut (), src: *const (), count: usize, kind: MemoryCopyKind) -> Self {
Self {
dst,
src,
count,
kind,
}
}
}
/// Transparent, copyable wrapper around a raw `cudaArray_t` handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct ArrayHandle(runtime::cudaArray_t);
impl ArrayHandle {
pub const unsafe fn from_raw(handle: runtime::cudaArray_t) -> Self {
Self(handle)
}
pub const unsafe fn as_raw(self) -> runtime::cudaArray_t {
self.0
}
}