use core::ffi::CStr;
use super::sys;
use crate::driver;
pub mod activity;
/// Error type for this module: a newtype around the raw [`sys::CUptiResult`] status code.
///
/// `sys::CUptiResult::result` produces this for every status other than
/// `CUPTI_SUCCESS`. The wrapped code is public and can be inspected via `.0`.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct CuptiError(pub sys::CUptiResult);
impl sys::CUptiResult {
    /// Converts this raw status code into a [`Result`].
    ///
    /// `CUPTI_SUCCESS` becomes `Ok(())`; every other value is wrapped in a
    /// [`CuptiError`] and returned as `Err`.
    #[inline]
    pub fn result(self) -> Result<(), CuptiError> {
        if matches!(self, sys::CUptiResult::CUPTI_SUCCESS) {
            return Ok(());
        }
        Err(CuptiError(self))
    }
}
impl CuptiError {
    /// Looks up the message for the wrapped status code via
    /// `sys::cuptiGetResultString`.
    ///
    /// # Errors
    ///
    /// If the lookup itself returns a status other than `CUPTI_SUCCESS`, that
    /// status is returned as a [`CuptiError`].
    pub fn error_string(&self) -> Result<&CStr, CuptiError> {
        let mut message = std::ptr::null();
        // SAFETY: `&mut message` refers to a live local, so the out-pointer
        // handed to the C function is valid for the duration of the call.
        let lookup = unsafe { sys::cuptiGetResultString(self.0, &mut message) };
        lookup.result()?;
        // SAFETY: NOTE(review): relies on CUPTI having stored a pointer to a
        // valid, NUL-terminated string in `message` on success, and on that
        // string outliving `self` — confirm against the CUPTI documentation.
        Ok(unsafe { CStr::from_ptr(message) })
    }
}
impl std::fmt::Debug for CuptiError {
    /// Formats as `CuptiError(<code>, <message>)`.
    ///
    /// The message comes from [`CuptiError::error_string`]. If that lookup
    /// fails, a placeholder is printed instead of panicking: a panic here
    /// would abort whatever was printing the error, and the panic message
    /// would itself format a `CuptiError` through this same impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = self.error_string().ok();
        let mut tuple = f.debug_tuple("CuptiError");
        tuple.field(&self.0);
        match description {
            Some(message) => tuple.field(&message),
            // Keep the two-field shape even when no message is available.
            None => tuple.field(&"<no description available>"),
        };
        tuple.finish()
    }
}
#[cfg(feature = "std")]
impl std::fmt::Display for CuptiError {
    /// Renders the error using its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Goes through `format_args!` (rather than calling `Debug::fmt`
        // directly) so the Debug output is produced with default formatting
        // options, independent of the flags on `f`.
        f.write_fmt(format_args!("{self:?}"))
    }
}
// Marks `CuptiError` as a standard error type; all trait methods use their
// default implementations. Only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for CuptiError {}
/// Calls `sys::cuptiComputeCapabilitySupported` with the arguments unchanged
/// and converts the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `support` is handed to the C function as-is. The caller must ensure it is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `c_int` — confirm against the CUPTI documentation).
pub unsafe fn compute_capability_supported(
    major: core::ffi::c_int,
    minor: core::ffi::c_int,
    support: *mut core::ffi::c_int,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiComputeCapabilitySupported(major, minor, support) };
    status.result()
}
/// Calls `sys::cuptiDeviceSupported` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `support` is handed to the C function as-is. The caller must ensure it is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `c_int` — confirm against the CUPTI documentation).
pub unsafe fn device_supported(
    dev: driver::sys::CUdevice,
    support: *mut core::ffi::c_int,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiDeviceSupported(dev, support) };
    status.result()
}
/// Calls `sys::cuptiDeviceVirtualizationMode` with the arguments unchanged and
/// converts the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `mode` is handed to the C function as-is. The caller must ensure it is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `CUpti_DeviceVirtualizationMode` — confirm against the CUPTI
/// documentation).
pub unsafe fn device_virtualization_mode(
    dev: driver::sys::CUdevice,
    mode: *mut sys::CUpti_DeviceVirtualizationMode,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiDeviceVirtualizationMode(dev, mode) };
    status.result()
}
/// Calls `sys::cuptiEnableAllDomains` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `subscriber` is handed to the C function as-is. The caller must ensure it
/// is a handle CUPTI accepts (presumably one obtained from a successful
/// subscribe call and not yet unsubscribed — confirm against the CUPTI
/// documentation).
pub unsafe fn enable_all_domains(
    enable: u32,
    subscriber: sys::CUpti_SubscriberHandle,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiEnableAllDomains(enable, subscriber) };
    status.result()
}
/// Calls `sys::cuptiEnableCallback` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `subscriber` is handed to the C function as-is. The caller must ensure it
/// is a handle CUPTI accepts (presumably one obtained from a successful
/// subscribe call and not yet unsubscribed — confirm against the CUPTI
/// documentation).
pub unsafe fn enable_callback(
    enable: u32,
    subscriber: sys::CUpti_SubscriberHandle,
    domain: sys::CUpti_CallbackDomain,
    cbid: sys::CUpti_CallbackId,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiEnableCallback(enable, subscriber, domain, cbid) };
    status.result()
}
/// Calls `sys::cuptiEnableDomain` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `subscriber` is handed to the C function as-is. The caller must ensure it
/// is a handle CUPTI accepts (presumably one obtained from a successful
/// subscribe call and not yet unsubscribed — confirm against the CUPTI
/// documentation).
pub unsafe fn enable_domain(
    enable: u32,
    subscriber: sys::CUpti_SubscriberHandle,
    domain: sys::CUpti_CallbackDomain,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiEnableDomain(enable, subscriber, domain) };
    status.result()
}
/// Calls `sys::cuptiFinalize` and converts the returned status into a
/// [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
pub fn finalize() -> Result<(), CuptiError> {
    // SAFETY: the call takes no arguments, so there are no pointer or handle
    // invariants to uphold at this call site.
    let status = unsafe { sys::cuptiFinalize() };
    status.result()
}
/// Calls `sys::cuptiGetAutoBoostState` with the arguments unchanged and
/// converts the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `context` and `state` are handed to the C function as-is. The caller must
/// ensure `context` is a context CUPTI accepts and that `state` is valid for
/// whatever access CUPTI performs through it (presumably a write of one
/// `CUpti_ActivityAutoBoostState` — confirm against the CUPTI documentation).
pub unsafe fn get_auto_boost_state(
    context: driver::sys::CUcontext,
    state: *mut sys::CUpti_ActivityAutoBoostState,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetAutoBoostState(context, state) };
    status.result()
}
/// Calls `sys::cuptiGetCallbackName` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `name` is handed to the C function as-is. The caller must ensure it is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one string pointer — confirm against the CUPTI documentation).
pub unsafe fn get_callback_name(
    domain: sys::CUpti_CallbackDomain,
    cbid: u32,
    name: *mut *const core::ffi::c_char,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetCallbackName(domain, cbid, name) };
    status.result()
}
/// Calls `sys::cuptiGetCallbackState` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `enable` and `subscriber` are handed to the C function as-is. The caller
/// must ensure `enable` is valid for whatever access CUPTI performs through it
/// (presumably a write of one `u32`) and that `subscriber` is a handle CUPTI
/// accepts — confirm both against the CUPTI documentation.
pub unsafe fn get_callback_state(
    enable: *mut u32,
    subscriber: sys::CUpti_SubscriberHandle,
    domain: sys::CUpti_CallbackDomain,
    cbid: sys::CUpti_CallbackId,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetCallbackState(enable, subscriber, domain, cbid) };
    status.result()
}
/// Calls `sys::cuptiGetContextId` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `context` and `context_id` are handed to the C function as-is. The caller
/// must ensure `context` is a context CUPTI accepts and that `context_id` is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `u32` — confirm against the CUPTI documentation).
pub unsafe fn get_context_id(
    context: driver::sys::CUcontext,
    context_id: *mut u32,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetContextId(context, context_id) };
    status.result()
}
/// Calls `sys::cuptiGetDeviceId` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `context` and `device_id` are handed to the C function as-is. The caller
/// must ensure `context` is a context CUPTI accepts and that `device_id` is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `u32` — confirm against the CUPTI documentation).
pub unsafe fn get_device_id(
    context: driver::sys::CUcontext,
    device_id: *mut u32,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetDeviceId(context, device_id) };
    status.result()
}
/// Calls `sys::cuptiGetGraphExecId` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// Only compiled when one of the CUDA version features listed in the `cfg`
/// attribute below is enabled.
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `graph_exec` and `p_id` are handed to the C function as-is. The caller must
/// ensure `graph_exec` is a handle CUPTI accepts and that `p_id` is valid for
/// whatever access CUPTI performs through it (presumably a write of one `u32`
/// — confirm against the CUPTI documentation).
#[cfg(any(
    feature = "cuda-12030",
    feature = "cuda-12040",
    feature = "cuda-12050",
    feature = "cuda-12060",
    feature = "cuda-12080",
    feature = "cuda-12090",
    feature = "cuda-13000"
))]
pub unsafe fn get_graph_exec_id(
    graph_exec: driver::sys::CUgraphExec,
    p_id: *mut u32,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetGraphExecId(graph_exec, p_id) };
    status.result()
}
/// Calls `sys::cuptiGetGraphId` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `graph` and `p_id` are handed to the C function as-is. The caller must
/// ensure `graph` is a handle CUPTI accepts and that `p_id` is valid for
/// whatever access CUPTI performs through it (presumably a write of one `u32`
/// — confirm against the CUPTI documentation).
pub unsafe fn get_graph_id(graph: driver::sys::CUgraph, p_id: *mut u32) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetGraphId(graph, p_id) };
    status.result()
}
/// Calls `sys::cuptiGetGraphNodeId` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `node` and `node_id` are handed to the C function as-is. The caller must
/// ensure `node` is a handle CUPTI accepts and that `node_id` is valid for
/// whatever access CUPTI performs through it (presumably a write of one `u64`
/// — confirm against the CUPTI documentation).
pub unsafe fn get_graph_node_id(
    node: driver::sys::CUgraphNode,
    node_id: *mut u64,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetGraphNodeId(node, node_id) };
    status.result()
}
/// Calls `sys::cuptiGetLastError` and converts the status it returns into a
/// [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping the returned status whenever it is not
/// `CUPTI_SUCCESS`.
pub fn get_last_error() -> Result<(), CuptiError> {
    // SAFETY: the call takes no arguments, so there are no pointer or handle
    // invariants to uphold at this call site.
    let status = unsafe { sys::cuptiGetLastError() };
    status.result()
}
/// Calls `sys::cuptiGetStreamId` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `context`, `stream` and `stream_id` are handed to the C function as-is. The
/// caller must ensure the handles are ones CUPTI accepts and that `stream_id`
/// is valid for whatever access CUPTI performs through it (presumably a write
/// of one `u32` — confirm against the CUPTI documentation).
pub unsafe fn get_stream_id(
    context: driver::sys::CUcontext,
    stream: driver::sys::CUstream,
    stream_id: *mut u32,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetStreamId(context, stream, stream_id) };
    status.result()
}
/// Calls `sys::cuptiGetStreamIdEx` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `context`, `stream` and `stream_id` are handed to the C function as-is. The
/// caller must ensure the handles are ones CUPTI accepts and that `stream_id`
/// is valid for whatever access CUPTI performs through it (presumably a write
/// of one `u32` — confirm against the CUPTI documentation).
pub unsafe fn get_stream_id_ex(
    context: driver::sys::CUcontext,
    stream: driver::sys::CUstream,
    per_thread_stream: u8,
    stream_id: *mut u32,
) -> Result<(), CuptiError> {
    let status =
        unsafe { sys::cuptiGetStreamIdEx(context, stream, per_thread_stream, stream_id) };
    status.result()
}
/// Calls `sys::cuptiGetThreadIdType` with the argument unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `type` is handed to the C function as-is. The caller must ensure it is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `CUpti_ActivityThreadIdType` — confirm against the CUPTI documentation).
pub unsafe fn get_thread_id_type(
    r#type: *mut sys::CUpti_ActivityThreadIdType,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetThreadIdType(r#type) };
    status.result()
}
/// Calls `sys::cuptiGetTimestamp` with the argument unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `timestamp` is handed to the C function as-is. The caller must ensure it is
/// valid for whatever access CUPTI performs through it (presumably a write of
/// one `u64` — confirm against the CUPTI documentation).
pub unsafe fn get_timestamp(timestamp: *mut u64) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiGetTimestamp(timestamp) };
    status.result()
}
/// Calls `sys::cuptiSetThreadIdType` with the argument unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
pub fn set_thread_id_type(r#type: sys::CUpti_ActivityThreadIdType) -> Result<(), CuptiError> {
    // SAFETY: the only argument is passed by value; no pointers or handles
    // cross the FFI boundary at this call site.
    let status = unsafe { sys::cuptiSetThreadIdType(r#type) };
    status.result()
}
/// Calls `sys::cuptiSubscribe` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `subscriber`, `callback` and `userdata` are handed to the C function as-is.
/// The caller must ensure `subscriber` is valid for whatever access CUPTI
/// performs through it (presumably a write of one handle), and that `callback`
/// and `userdata` stay usable for as long as CUPTI may use them (presumably
/// until the subscriber is unsubscribed) — confirm against the CUPTI
/// documentation.
pub unsafe fn subscribe(
    subscriber: *mut sys::CUpti_SubscriberHandle,
    callback: sys::CUpti_CallbackFunc,
    // `core::ffi::c_void` is the same type as `std::os::raw::c_void`; spelled
    // via `core` to match the other FFI signatures in this file.
    userdata: *mut core::ffi::c_void,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiSubscribe(subscriber, callback, userdata) };
    status.result()
}
/// Calls `sys::cuptiSubscribe_v2` with the arguments unchanged and converts the
/// returned status into a [`Result`].
///
/// Only compiled when the `cuda-13000` feature is enabled.
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `subscriber`, `callback`, `userdata` and `p_params` are handed to the C
/// function as-is. The caller must ensure each pointer is valid for whatever
/// access CUPTI performs through it, and that `callback` and `userdata` stay
/// usable for as long as CUPTI may use them (presumably until the subscriber
/// is unsubscribed) — confirm against the CUPTI documentation.
#[cfg(feature = "cuda-13000")]
pub unsafe fn subscribe_v2(
    subscriber: *mut sys::CUpti_SubscriberHandle,
    callback: sys::CUpti_CallbackFunc,
    userdata: *mut core::ffi::c_void,
    p_params: *mut sys::CUpti_SubscriberParams,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiSubscribe_v2(subscriber, callback, userdata, p_params) };
    status.result()
}
/// Calls `sys::cuptiSupportedDomains` with the arguments unchanged and converts
/// the returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `domain_count` and `domain_table` are handed to the C function as-is. The
/// caller must ensure both are valid for whatever access CUPTI performs through
/// them (presumably one write each — confirm against the CUPTI documentation).
pub unsafe fn supported_domains(
    domain_count: *mut usize,
    domain_table: *mut sys::CUpti_DomainTable,
) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiSupportedDomains(domain_count, domain_table) };
    status.result()
}
/// Calls `sys::cuptiUnsubscribe` with the argument unchanged and converts the
/// returned status into a [`Result`].
///
/// # Errors
///
/// Returns a [`CuptiError`] wrapping any status other than `CUPTI_SUCCESS`.
///
/// # Safety
///
/// `subscriber` is handed to the C function as-is. The caller must ensure it
/// is a handle CUPTI accepts (presumably one obtained from a successful
/// subscribe call and not already unsubscribed — confirm against the CUPTI
/// documentation).
pub unsafe fn unsubscribe(subscriber: sys::CUpti_SubscriberHandle) -> Result<(), CuptiError> {
    let status = unsafe { sys::cuptiUnsubscribe(subscriber) };
    status.result()
}