use crate::error::CuptiResult;
use std::ffi::c_void;
/// Domain that a callback event or subscription belongs to.
///
/// Every variant maps to a fixed numeric id through [`CallbackDomain::cupti_id`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackDomain {
    /// Runtime API domain (numeric id 1).
    RuntimeApi,
    /// Driver API domain (numeric id 2).
    DriverApi,
    /// Resource domain (numeric id 3).
    Resource,
    /// Synchronization domain (numeric id 4).
    Synchronize,
    /// NVTX domain (numeric id 5).
    Nvtx,
}

impl CallbackDomain {
    /// Returns the numeric id this crate associates with the domain.
    ///
    /// NOTE(review): CUPTI's own `CUpti_CallbackDomain` header enum appears to
    /// number the driver API domain 1 and the runtime API domain 2 — confirm
    /// these values before handing them to the C API.
    pub fn cupti_id(&self) -> u32 {
        match *self {
            Self::RuntimeApi => 1,
            Self::DriverApi => 2,
            Self::Resource => 3,
            Self::Synchronize => 4,
            Self::Nvtx => 5,
        }
    }
}
/// Identifier of a single callback point within a domain.
///
/// Named variants map to fixed numeric ids through [`CallbackId::cupti_id`];
/// [`CallbackId::Other`] carries an arbitrary raw id.
///
/// Equality is the derived, variant-wise comparison: `Other(1)` is *not*
/// equal to `CudaMalloc` even though both report numeric id 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallbackId {
    /// Numeric id 1.
    CudaMalloc,
    /// Numeric id 2.
    CudaFree,
    /// Numeric id 3.
    CudaMemcpy,
    /// Numeric id 4.
    CudaMemcpyAsync,
    /// Numeric id 5.
    CudaLaunchKernel,
    /// Numeric id 6.
    CudaDeviceSynchronize,
    /// Numeric id 7.
    CudaStreamSynchronize,
    /// Numeric id 100.
    CuMemAlloc,
    /// Numeric id 101.
    CuMemFree,
    /// Numeric id 102.
    CuLaunchKernel,
    /// Numeric id 103.
    CuCtxSynchronize,
    /// Numeric id 200.
    ContextCreated,
    /// Numeric id 201.
    ContextDestroyed,
    /// Numeric id 202.
    StreamCreated,
    /// Numeric id 203.
    StreamDestroyed,
    /// Numeric id 204.
    ModuleLoaded,
    /// Numeric id 205.
    ModuleUnloaded,
    /// Any id without a named variant; the payload is reported unchanged.
    Other(u32),
}

impl CallbackId {
    /// Returns the numeric id for this callback point.
    ///
    /// NOTE(review): the values look like crate-local numbering (1.., 100..,
    /// 200..) rather than the cbid constants from the CUPTI headers — confirm
    /// before passing them to the C API.
    pub fn cupti_id(&self) -> u32 {
        match *self {
            Self::CudaMalloc => 1,
            Self::CudaFree => 2,
            Self::CudaMemcpy => 3,
            Self::CudaMemcpyAsync => 4,
            Self::CudaLaunchKernel => 5,
            Self::CudaDeviceSynchronize => 6,
            Self::CudaStreamSynchronize => 7,
            Self::CuMemAlloc => 100,
            Self::CuMemFree => 101,
            Self::CuLaunchKernel => 102,
            Self::CuCtxSynchronize => 103,
            Self::ContextCreated => 200,
            Self::ContextDestroyed => 201,
            Self::StreamCreated => 202,
            Self::StreamDestroyed => 203,
            Self::ModuleLoaded => 204,
            Self::ModuleUnloaded => 205,
            Self::Other(raw) => raw,
        }
    }
}
/// Point in a call at which an event was reported.
///
/// Presumably `Enter` marks entry into the traced call and `Exit` its return;
/// nothing in this file distinguishes the two beyond carrying the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackSite {
    /// Event reported at the entry site.
    Enter,
    /// Event reported at the exit site.
    Exit,
}
/// One callback event, as handed to registered handlers.
///
/// `CallbackSubscriber::dispatch` routes on `domain` and `callback_id`; the
/// remaining fields are passed through to handlers untouched.
#[derive(Debug)]
pub struct CallbackData {
    /// Domain the event belongs to; compared against each subscription's domain.
    pub domain: CallbackDomain,
    /// Callback point that fired; compared against id-specific subscriptions.
    pub callback_id: CallbackId,
    /// Whether the event was reported at the enter or exit site.
    pub site: CallbackSite,
    /// Correlation id of the event (presumably pairs related events — TODO confirm).
    pub correlation_id: u64,
    /// Optional context value (presumably an opaque context handle — TODO confirm).
    pub context: Option<u64>,
    /// Optional name of the function the event refers to, e.g. `"cudaMalloc"`.
    pub function_name: Option<String>,
    /// Optional return code (presumably only known at the exit site — TODO confirm).
    pub return_value: Option<i32>,
}
/// Boxed handler invoked with a reference to each matching [`CallbackData`].
///
/// The `Send + Sync` bounds allow a boxed handler to be moved to, and shared
/// between, threads.
pub type CallbackFn = Box<dyn Fn(&CallbackData) + Send + Sync>;
/// Holds registered handlers and forwards matching events to them.
///
/// A subscriber starts inactive; while inactive, `dispatch` does nothing.
pub struct CallbackSubscriber {
    /// Identifier assigned at construction from a process-wide counter.
    id: u64,
    /// Registered `(domain, optional id filter, handler)` entries, in
    /// registration order. A `None` filter matches every id in the domain.
    callbacks: Vec<(CallbackDomain, Option<CallbackId>, CallbackFn)>,
    /// Whether `dispatch` currently forwards events.
    active: bool,
}
impl CallbackSubscriber {
pub fn new() -> Self {
static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);
Self {
id: NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
callbacks: Vec::new(),
active: false,
}
}
pub fn id(&self) -> u64 {
self.id
}
pub fn subscribe_domain<F>(&mut self, domain: CallbackDomain, callback: F) -> CuptiResult<()>
where
F: Fn(&CallbackData) + Send + Sync + 'static,
{
self.callbacks.push((domain, None, Box::new(callback)));
Ok(())
}
pub fn subscribe<F>(
&mut self,
domain: CallbackDomain,
callback_id: CallbackId,
callback: F,
) -> CuptiResult<()>
where
F: Fn(&CallbackData) + Send + Sync + 'static,
{
self.callbacks
.push((domain, Some(callback_id), Box::new(callback)));
Ok(())
}
pub fn enable(&mut self) -> CuptiResult<()> {
self.active = true;
Ok(())
}
pub fn disable(&mut self) -> CuptiResult<()> {
self.active = false;
Ok(())
}
pub fn is_active(&self) -> bool {
self.active
}
#[doc(hidden)]
pub fn dispatch(&self, data: &CallbackData) {
if !self.active {
return;
}
for (domain, callback_id, handler) in &self.callbacks {
if *domain == data.domain {
match callback_id {
None => handler(data),
Some(id) if *id == data.callback_id => handler(data),
_ => {}
}
}
}
}
}
impl Default for CallbackSubscriber {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder that collects event selections and turns them into a
/// [`CallbackSubscriber`] sharing a single handler.
#[derive(Default)]
pub struct CallbackBuilder {
    /// Queued `(domain, optional id)` selections, in the order they were added.
    /// A `None` id selects the whole domain.
    subscriptions: Vec<(CallbackDomain, Option<CallbackId>)>,
}
impl CallbackBuilder {
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn on_kernel_launch(mut self) -> Self {
self.subscriptions.push((
CallbackDomain::RuntimeApi,
Some(CallbackId::CudaLaunchKernel),
));
self.subscriptions
.push((CallbackDomain::DriverApi, Some(CallbackId::CuLaunchKernel)));
self
}
#[must_use]
pub fn on_memory_ops(mut self) -> Self {
self.subscriptions
.push((CallbackDomain::RuntimeApi, Some(CallbackId::CudaMalloc)));
self.subscriptions
.push((CallbackDomain::RuntimeApi, Some(CallbackId::CudaFree)));
self.subscriptions
.push((CallbackDomain::RuntimeApi, Some(CallbackId::CudaMemcpy)));
self
}
#[must_use]
pub fn on_synchronization(mut self) -> Self {
self.subscriptions.push((CallbackDomain::Synchronize, None));
self
}
#[must_use]
pub fn on_runtime_api(mut self) -> Self {
self.subscriptions.push((CallbackDomain::RuntimeApi, None));
self
}
#[must_use]
pub fn on_driver_api(mut self) -> Self {
self.subscriptions.push((CallbackDomain::DriverApi, None));
self
}
pub fn build<F>(self, handler: F) -> CuptiResult<CallbackSubscriber>
where
F: Fn(&CallbackData) + Send + Sync + Clone + 'static,
{
let mut subscriber = CallbackSubscriber::new();
for (domain, callback_id) in self.subscriptions {
let h = handler.clone();
match callback_id {
Some(id) => subscriber.subscribe(domain, id, h)?,
None => subscriber.subscribe_domain(domain, h)?,
}
}
Ok(subscriber)
}
}
/// Signature of a raw C-ABI callback function pointer.
///
/// It receives an opaque user-data pointer, the numeric domain and callback
/// ids, and an opaque pointer to the callback payload. Nothing in this file
/// calls through this type; presumably it mirrors the function shape expected
/// by the native callback API — TODO confirm against the FFI layer.
pub type RawCallbackFn = unsafe extern "C" fn(
    user_data: *mut c_void,
    domain: u32,
    callback_id: u32,
    callback_data: *const c_void,
);
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds an `Enter`-site event with no context, function name, or return value.
    fn enter_event(
        domain: CallbackDomain,
        callback_id: CallbackId,
        correlation_id: u64,
    ) -> CallbackData {
        CallbackData {
            domain,
            callback_id,
            site: CallbackSite::Enter,
            correlation_id,
            context: None,
            function_name: None,
            return_value: None,
        }
    }

    #[test]
    fn test_callback_domain_id() {
        assert_eq!(CallbackDomain::RuntimeApi.cupti_id(), 1);
        assert_eq!(CallbackDomain::DriverApi.cupti_id(), 2);
    }

    #[test]
    fn test_callback_subscriber() {
        let mut subscriber = CallbackSubscriber::new();
        assert!(!subscriber.is_active());

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        subscriber
            .subscribe_domain(CallbackDomain::RuntimeApi, move |_data| {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();

        subscriber.enable().unwrap();
        assert!(subscriber.is_active());

        let data = CallbackData {
            function_name: Some("cudaMalloc".to_string()),
            ..enter_event(CallbackDomain::RuntimeApi, CallbackId::CudaMalloc, 1)
        };
        subscriber.dispatch(&data);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_builder() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let mut subscriber = CallbackBuilder::new()
            .on_kernel_launch()
            .on_memory_ops()
            .build(move |_data| {
                counter_clone.fetch_add(1, Ordering::SeqCst);
            })
            .unwrap();
        assert!(!subscriber.is_active());

        // A freshly built subscriber is inactive, so nothing is delivered yet.
        let launch = enter_event(CallbackDomain::RuntimeApi, CallbackId::CudaLaunchKernel, 1);
        subscriber.dispatch(&launch);
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        // Once enabled, each selected event reaches the handler exactly once.
        subscriber.enable().unwrap();
        subscriber.dispatch(&launch);
        subscriber.dispatch(&enter_event(
            CallbackDomain::DriverApi,
            CallbackId::CuLaunchKernel,
            2,
        ));
        subscriber.dispatch(&enter_event(
            CallbackDomain::RuntimeApi,
            CallbackId::CudaMalloc,
            3,
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 3);

        // Events that were never selected are ignored.
        subscriber.dispatch(&enter_event(
            CallbackDomain::RuntimeApi,
            CallbackId::CudaDeviceSynchronize,
            4,
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_callback_filtering() {
        let mut subscriber = CallbackSubscriber::new();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        subscriber
            .subscribe(
                CallbackDomain::RuntimeApi,
                CallbackId::CudaMalloc,
                move |_data| {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                },
            )
            .unwrap();
        subscriber.enable().unwrap();

        let malloc_data = enter_event(CallbackDomain::RuntimeApi, CallbackId::CudaMalloc, 1);
        subscriber.dispatch(&malloc_data);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        // Same domain, different id: the id-specific handler must not fire.
        let free_data = enter_event(CallbackDomain::RuntimeApi, CallbackId::CudaFree, 2);
        subscriber.dispatch(&free_data);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }
}