use std::sync::Arc;
use singe_cuda::graph::{ExecutableGraph, Graph, GraphNode};
use singe_cupti_sys as sys;
use crate::{
error::Result,
try_ffi,
types::{
ActivityAutoBoostState, ActivityThreadIdType, ContextId, DeviceId, GraphExecId, GraphId,
GraphNodeId,
},
};
/// Shared handle to a CUDA context, as used by the CUPTI wrappers in this module.
///
/// The underlying `singe_cuda::context::Context` is held behind an [`Arc`], so cloning this
/// value only clones the `Arc` and every clone refers to the same context object.
#[derive(Debug, Clone)]
pub struct Context {
    /// Reference-counted handle to the wrapped `singe_cuda` context.
    inner: Arc<singe_cuda::context::Context>,
}
impl Context {
    /// Builds a new context by delegating to `singe_cuda::context::Context::create`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying constructor.
    pub fn create() -> Result<Self> {
        let inner = singe_cuda::context::Context::create()?;
        Ok(Self { inner })
    }

    /// Calls `bind` on the wrapped `singe_cuda` context.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped context's `bind`.
    pub fn bind(&self) -> Result<()> {
        let wrapped = &self.inner;
        wrapped.bind()?;
        Ok(())
    }

    /// Returns the wrapped context's raw handle, cast to the `singe_cupti_sys` `CUcontext` type
    /// so it can be passed to CUPTI entry points.
    pub(crate) fn as_raw(&self) -> sys::CUcontext {
        let handle = self.inner.as_raw();
        handle as sys::CUcontext
    }
}
/// Returns the version number reported by `cuptiGetVersion`.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
pub fn runtime_version() -> Result<u32> {
    let mut api_version: u32 = 0;
    // SAFETY: `&mut api_version` points at a live, writable local for the whole call.
    let status = unsafe { try_ffi!(sys::cuptiGetVersion(&mut api_version)) };
    status?;
    Ok(api_version)
}
/// Asks CUPTI (`cuptiGetContextId`) for the identifier of `context`, storing it in `context_id`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_context_id(context: sys::CUcontext, context_id: &mut u32) -> Result<()> {
    // SAFETY: `context_id` is a `&mut u32`, hence non-null, aligned and valid for writes for
    // the duration of the call.
    // NOTE(review): `context` is forwarded unchecked; callers are expected to pass a handle
    // taken from a live context.
    unsafe {
        try_ffi!(sys::cuptiGetContextId(context, context_id))?;
    }
    Ok(())
}
/// Asks CUPTI (`cuptiGetDeviceId`) for the device identifier associated with `context`,
/// storing it in `device_id`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_device_id(context: sys::CUcontext, device_id: &mut u32) -> Result<()> {
    // SAFETY: `device_id` is a `&mut u32`, hence non-null, aligned and valid for writes for
    // the duration of the call.
    // NOTE(review): `context` is forwarded unchecked; callers are expected to pass a handle
    // taken from a live context.
    unsafe {
        try_ffi!(sys::cuptiGetDeviceId(context, device_id))?;
    }
    Ok(())
}
/// Asks CUPTI (`cuptiGetGraphExecId`) for the identifier of `graph_exec`, storing it in
/// `graph_exec_id`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_graph_exec_id(graph_exec: sys::CUgraphExec, graph_exec_id: &mut u32) -> Result<()> {
    // SAFETY: `graph_exec_id` is a `&mut u32`, hence non-null, aligned and valid for writes
    // for the duration of the call.
    // NOTE(review): `graph_exec` is forwarded unchecked; callers are expected to pass a handle
    // taken from a live executable graph.
    unsafe {
        try_ffi!(sys::cuptiGetGraphExecId(graph_exec, graph_exec_id))?;
    }
    Ok(())
}
/// Asks CUPTI (`cuptiGetGraphId`) for the identifier of `graph`, storing it in `graph_id`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_graph_id(graph: sys::CUgraph, graph_id: &mut u32) -> Result<()> {
    // SAFETY: `graph_id` is a `&mut u32`, hence non-null, aligned and valid for writes for the
    // duration of the call.
    // NOTE(review): `graph` is forwarded unchecked; callers are expected to pass a handle taken
    // from a live graph.
    unsafe {
        try_ffi!(sys::cuptiGetGraphId(graph, graph_id))?;
    }
    Ok(())
}
/// Asks CUPTI (`cuptiGetGraphNodeId`) for the identifier of `node`, storing it in `node_id`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_graph_node_id(node: sys::CUgraphNode, node_id: &mut u64) -> Result<()> {
    // SAFETY: `node_id` is a `&mut u64`, hence non-null, aligned and valid for writes for the
    // duration of the call.
    // NOTE(review): `node` is forwarded unchecked; callers are expected to pass a handle taken
    // from a live graph node.
    unsafe {
        try_ffi!(sys::cuptiGetGraphNodeId(node, node_id))?;
    }
    Ok(())
}
/// Converts the status returned by `cuptiGetLastError` into a `Result`.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the returned status.
pub fn last_error() -> Result<()> {
    // SAFETY: the call takes no arguments, so there are no pointer or handle preconditions to
    // uphold on this side of the FFI boundary.
    let status = unsafe { try_ffi!(sys::cuptiGetLastError()) };
    status?;
    Ok(())
}
/// Asks CUPTI (`cuptiGetTimestamp`) for a timestamp, storing it in `timestamp`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_timestamp(timestamp: &mut u64) -> Result<()> {
    // SAFETY: `timestamp` is a `&mut u64`, hence non-null, aligned and valid for writes for the
    // duration of the call.
    unsafe {
        try_ffi!(sys::cuptiGetTimestamp(timestamp))?;
    }
    Ok(())
}
/// Asks CUPTI (`cuptiGetThreadIdType`) for the current thread-id type, storing it in
/// `thread_id_type`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_thread_id_type(thread_id_type: &mut sys::CUpti_ActivityThreadIdType) -> Result<()> {
    // SAFETY: `thread_id_type` is an exclusive reference, hence non-null, aligned and valid for
    // writes for the duration of the call.
    unsafe {
        try_ffi!(sys::cuptiGetThreadIdType(thread_id_type))?;
    }
    Ok(())
}
/// Looks up the CUPTI identifier of `context` and wraps it in a [`ContextId`].
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn context_id(context: &Context) -> Result<ContextId> {
    let mut raw_id: u32 = 0;
    raw_get_context_id(context.as_raw(), &mut raw_id).map(|()| ContextId::from(raw_id))
}
/// Looks up the CUPTI device identifier for `context` and wraps it in a [`DeviceId`].
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn device_id(context: &Context) -> Result<DeviceId> {
    let mut raw_id: u32 = 0;
    raw_get_device_id(context.as_raw(), &mut raw_id).map(|()| DeviceId::from(raw_id))
}
/// Looks up the CUPTI identifier of an executable graph and wraps it in a [`GraphExecId`].
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn graph_exec_id(graph_exec: &ExecutableGraph) -> Result<GraphExecId> {
    // The `singe_cuda` handle is cast to the `singe_cupti_sys` handle type before the call.
    let handle = graph_exec.as_raw() as sys::CUgraphExec;
    let mut raw_id: u32 = 0;
    raw_get_graph_exec_id(handle, &mut raw_id).map(|()| GraphExecId::from(raw_id))
}
/// Looks up the CUPTI identifier of `graph` and wraps it in a [`GraphId`].
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn graph_id(graph: &Graph) -> Result<GraphId> {
    // The `singe_cuda` handle is cast to the `singe_cupti_sys` handle type before the call.
    let handle = graph.as_raw() as sys::CUgraph;
    let mut raw_id: u32 = 0;
    raw_get_graph_id(handle, &mut raw_id).map(|()| GraphId::from(raw_id))
}
/// Looks up the CUPTI identifier of a graph node and wraps it in a [`GraphNodeId`].
///
/// NOTE(review): unlike the sibling lookups, `node` is taken by value rather than by
/// reference; confirm this is intentional (e.g. `GraphNode` being a cheap handle).
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn graph_node_id(node: GraphNode) -> Result<GraphNodeId> {
    // The `singe_cuda` handle is cast to the `singe_cupti_sys` handle type before the call.
    let handle = node.as_raw() as sys::CUgraphNode;
    let mut raw_id: u64 = 0;
    raw_get_graph_node_id(handle, &mut raw_id).map(|()| GraphNodeId::from(raw_id))
}
/// Returns the timestamp reported by CUPTI.
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn timestamp() -> Result<u64> {
    let mut now: u64 = 0;
    raw_get_timestamp(&mut now).map(|()| now)
}
/// Returns the thread-id type currently reported by CUPTI, converted to
/// [`ActivityThreadIdType`].
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn thread_id_type() -> Result<ActivityThreadIdType> {
    // Seed the out-parameter with the default variant; the query overwrites it on success.
    let mut raw_kind = sys::CUpti_ActivityThreadIdType::CUPTI_ACTIVITY_THREAD_ID_TYPE_DEFAULT;
    raw_get_thread_id_type(&mut raw_kind)?;
    let kind = ActivityThreadIdType::from(raw_kind);
    Ok(kind)
}
/// Sets the thread-id type used by CUPTI via `cuptiSetThreadIdType`.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
pub fn set_thread_id_type(thread_id_type: ActivityThreadIdType) -> Result<()> {
    // Convert to the raw FFI representation before crossing the boundary.
    let raw_kind = thread_id_type.into();
    // SAFETY: the only argument is a plain value; no pointers are handed to CUPTI.
    let status = unsafe { try_ffi!(sys::cuptiSetThreadIdType(raw_kind)) };
    status?;
    Ok(())
}
/// Asks CUPTI (`cuptiGetAutoBoostState`) for the auto-boost state of `context`, storing it in
/// `state`.
///
/// The out-parameter is an exclusive reference rather than a raw pointer: this function is safe
/// to call, so it must not be possible to route a null or dangling address into the FFI call.
/// Call sites that pass `&mut local` are unaffected.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
fn raw_get_auto_boost_state(
    context: sys::CUcontext,
    state: &mut sys::CUpti_ActivityAutoBoostState,
) -> Result<()> {
    // SAFETY: `state` is an exclusive reference, hence non-null, aligned and valid for writes
    // for the duration of the call.
    // NOTE(review): `context` is forwarded unchecked; callers are expected to pass a handle
    // taken from a live context.
    unsafe {
        try_ffi!(sys::cuptiGetAutoBoostState(context, state))?;
    }
    Ok(())
}
/// Returns the auto-boost state CUPTI reports for `context`, converted to
/// [`ActivityAutoBoostState`].
///
/// # Errors
///
/// Propagates any error returned by the underlying CUPTI query.
pub fn auto_boost_state(context: &Context) -> Result<ActivityAutoBoostState> {
    // Start from the raw struct's `Default` value; the query fills it in on success.
    let mut raw_state = sys::CUpti_ActivityAutoBoostState::default();
    raw_get_auto_boost_state(context.as_raw(), &mut raw_state)?;
    let state: ActivityAutoBoostState = raw_state.into();
    Ok(state)
}
/// Calls `cuptiFinalize` and reports its status as a `Result`.
///
/// # Errors
///
/// Propagates any error yielded by `try_ffi!` for the call.
pub fn finalize_process() -> Result<()> {
    // SAFETY: the call takes no arguments, so there are no pointer or handle preconditions to
    // uphold on this side of the FFI boundary.
    // NOTE(review): confirm against the CUPTI docs whether callers must flush or stop
    // profiling activity before this is invoked.
    let status = unsafe { try_ffi!(sys::cuptiFinalize()) };
    status?;
    Ok(())
}