use crate::rocprofiler::error::{Error, Result};
use crate::rocprofiler::types::{Feature, ProfilerMode, Group};
use crate::rocprofiler::bindings;
use crate::hip;
use std::marker::PhantomData;
use std::ptr;
use std::mem;
/// Optional settings handed to `rocprofiler_open` when creating a [`Context`].
///
/// Wraps the raw `rocprofiler_properties_t` and, when a handler has been
/// installed with `with_handler`, owns the boxed closure that
/// `properties.handler_arg` points at.
pub struct Properties {
    /// Raw C properties struct passed to rocprofiler.
    pub(crate) properties: bindings::rocprofiler_properties_t,
    /// Owner of the handler closure; `properties.handler_arg` refers to this
    /// same heap allocation, so it must outlive every use of that pointer.
    _handler_data: Option<Box<HandlerData>>,
}
/// Heap-allocated state reached through the C `handler_arg` pointer.
struct HandlerData {
    /// User closure invoked by `handler_callback` with each converted group;
    /// its `bool` result becomes the callback's return value.
    handler: Box<dyn FnMut(Group) -> bool + Send + 'static>,
}
/// C-ABI trampoline installed as `rocprofiler_properties_t::handler` by
/// `Properties::with_handler`.
///
/// Converts the C group into a [`Group`] and forwards it to the Rust closure
/// stored in the `HandlerData` that `arg` points to. Returns `false` when `arg`
/// is null or when the conversion or the closure panics; otherwise returns the
/// closure's result.
///
/// # Safety
///
/// `arg` must be null or point to a live `HandlerData` that is not otherwise
/// borrowed for the duration of the call.
unsafe extern "C" fn handler_callback(
    group: bindings::rocprofiler_group_t,
    arg: *mut std::os::raw::c_void,
) -> bool {
    if arg.is_null() {
        return false;
    }
    // SAFETY: `arg` is non-null and, per this function's contract, points to a
    // live `HandlerData` with no other outstanding borrows.
    let handler_data = &mut *(arg as *mut HandlerData);
    // A panic must not unwind out of an `extern "C"` function (undefined
    // behavior on older toolchains, a process abort on newer ones), so contain
    // it here and report the same `false` used for a missing handler.
    std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let rust_group = Group::from_c(&group);
        (handler_data.handler)(rust_group)
    }))
    .unwrap_or(false)
}
impl Properties {
    /// Creates properties with no queue, a queue depth of 0, and no handler.
    pub fn new() -> Self {
        Self {
            properties: bindings::rocprofiler_properties_t {
                queue: ptr::null_mut(),
                queue_depth: 0,
                handler: None,
                handler_arg: ptr::null_mut(),
            },
            _handler_data: None,
        }
    }

    /// Sets the `queue_depth` field of the underlying C properties.
    pub fn with_queue_depth(mut self, depth: u32) -> Self {
        self.properties.queue_depth = depth;
        self
    }

    /// Installs `handler` as the Rust closure invoked through `handler_callback`.
    ///
    /// The closure is boxed and owned by `self`; `handler_arg` points at that
    /// box's heap allocation, which does not move when `self` is moved. Calling
    /// this again replaces (and drops) any previously installed handler.
    pub fn with_handler<F>(mut self, handler: F) -> Self
    where
        F: FnMut(Group) -> bool + Send + 'static,
    {
        let mut handler_data = Box::new(HandlerData {
            handler: Box::new(handler),
        });
        // Take the C argument straight from the box that `self` keeps, instead
        // of leaking it with `Box::into_raw` and immediately re-owning it.
        let arg: *mut HandlerData = &mut *handler_data;
        self.properties.handler = Some(handler_callback);
        self.properties.handler_arg = arg.cast();
        self._handler_data = Some(handler_data);
        self
    }

    /// Stores `queue` in the C properties' `queue` field via a pointer cast.
    ///
    /// NOTE(review): this casts a pointer to a HIP stream handle into
    /// rocprofiler's queue field; confirm that is the value rocprofiler expects.
    pub fn with_queue(mut self, queue: *mut hip::ffi::hipStream_t) -> Self {
        self.properties.queue = queue as *mut _;
        self
    }
}
impl Default for Properties {
fn default() -> Self {
Self::new()
}
}
/// An open rocprofiler session on one device.
///
/// Created by [`Context::new`], which calls `rocprofiler_open`; the native
/// context is closed when the value is dropped.
pub struct Context {
    /// Native handle filled in by `rocprofiler_open`.
    context: *mut bindings::rocprofiler_t,
    /// Rust-side features, refreshed from `c_features` by `get_data`.
    features: Vec<Feature>,
    /// Contiguous C feature array whose address is passed to
    /// `rocprofiler_open`. `get_data` reads results back out of it after
    /// `rocprofiler_get_data`, so it is never pushed to (never reallocated)
    /// after construction.
    c_features: Vec<bindings::rocprofiler_feature_t>,
    agent: hip::Device,
    /// Keeps the properties, and the handler closure their `handler_arg`
    /// points at, alive until after `rocprofiler_close` has run in `drop`.
    _properties: Option<Box<Properties>>,
    _phantom: PhantomData<()>,
}

// SAFETY: NOTE(review): these impls assert the native context may be moved to
// and shared between threads. Nothing in this file synchronizes access to
// `context`, so `Sync` relies on rocprofiler itself being thread-safe for
// concurrent calls on one context — confirm against the rocprofiler docs.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}

impl Context {
    /// Opens a rocprofiler context on `device` for `features`.
    ///
    /// `modes` are OR-ed together into the `mode` bit mask. `properties`, when
    /// given, are kept alive inside the returned context.
    ///
    /// # Errors
    ///
    /// Returns the HSA status as an [`Error`] if `rocprofiler_open` fails.
    ///
    /// # Panics
    ///
    /// Panics if a feature's `to_c` returns a null pointer, or if there are
    /// more than `u32::MAX` features.
    pub fn new(
        device: hip::Device,
        mut features: Vec<Feature>,
        modes: &[ProfilerMode],
        properties: Option<Properties>,
    ) -> Result<Self> {
        let mode: u32 = modes.iter().fold(0, |acc, mode| acc | (*mode as u32));

        // rocprofiler_open takes `feature_count` structs laid out contiguously,
        // not an array of pointers; copy each feature's C view into one buffer
        // so every feature (not just the first) is passed correctly.
        let mut c_features: Vec<bindings::rocprofiler_feature_t> = features
            .iter_mut()
            .map(|feature| {
                let raw = feature.to_c();
                assert!(!raw.is_null(), "Feature::to_c returned a null pointer");
                // SAFETY: `raw` is non-null and assumed to point to an
                // initialized `rocprofiler_feature_t` produced by `to_c`; it is
                // a plain C struct, so a bitwise copy is a valid value.
                unsafe { ptr::read(raw) }
            })
            .collect();

        // The properties must outlive the call (the pointer below refers into
        // them) and the context itself (`handler_arg` points at the handler
        // closure they own). Boxing gives the C struct a stable address.
        let mut properties: Option<Box<Properties>> = properties.map(Box::new);
        let properties_ptr: *mut bindings::rocprofiler_properties_t =
            match properties.as_deref_mut() {
                Some(props) => &mut props.properties as *mut _,
                None => ptr::null_mut(),
            };

        let agent = bindings::hsa_agent_t {
            // NOTE(review): HSA agent handles are opaque values normally obtained
            // by enumerating HSA agents; this is only correct if
            // `hip::Device::id` yields that handle — confirm.
            handle: device.id() as u64,
        };
        let features_ptr = if c_features.is_empty() {
            ptr::null_mut()
        } else {
            c_features.as_mut_ptr()
        };
        let feature_count =
            u32::try_from(c_features.len()).expect("feature count exceeds u32::MAX");
        let mut context: *mut bindings::rocprofiler_t = ptr::null_mut();
        // SAFETY: `features_ptr` is null or points to `feature_count`
        // initialized, contiguous structs owned by `c_features`;
        // `properties_ptr` is null or points into `properties`. Both are heap
        // buffers moved into the returned `Context`, so their addresses stay
        // valid for its lifetime. `context` is a valid out-pointer.
        let status = unsafe {
            bindings::rocprofiler_open(
                agent,
                features_ptr,
                feature_count,
                &mut context,
                mode,
                properties_ptr,
            )
        };
        if status != bindings::hsa_status_t_HSA_STATUS_SUCCESS {
            return Err(Error::new(status));
        }
        Ok(Self {
            context,
            features,
            c_features,
            agent: device,
            _properties: properties,
            _phantom: PhantomData,
        })
    }

    /// Returns the number of groups reported by `rocprofiler_group_count`.
    pub fn group_count(&self) -> Result<u32> {
        let mut count: u32 = 0;
        // SAFETY: `self.context` came from a successful `rocprofiler_open`;
        // `count` is a valid out-pointer.
        let status = unsafe { bindings::rocprofiler_group_count(self.context, &mut count) };
        Error::from_hsa_status_with_value(status, count)
    }

    /// Fetches group `index` via `rocprofiler_get_group` and converts it.
    pub fn get_group(&self, index: u32) -> Result<Group> {
        let mut group_data = bindings::rocprofiler_group_t {
            index,
            features: ptr::null_mut(),
            feature_count: 0,
            context: self.context,
        };
        // SAFETY: `self.context` is open; `group_data` is a valid out-pointer.
        let status =
            unsafe { bindings::rocprofiler_get_group(self.context, index, &mut group_data) };
        if status != bindings::hsa_status_t_HSA_STATUS_SUCCESS {
            return Err(Error::new(status));
        }
        Ok(Group::from_c(&group_data))
    }

    /// Calls `rocprofiler_start` for group `group_index`.
    pub fn start(&self, group_index: u32) -> Result<()> {
        // SAFETY: `self.context` is an open rocprofiler context.
        let status = unsafe { bindings::rocprofiler_start(self.context, group_index) };
        Error::from_hsa_status(status)
    }

    /// Calls `rocprofiler_stop` for group `group_index`.
    pub fn stop(&self, group_index: u32) -> Result<()> {
        // SAFETY: `self.context` is an open rocprofiler context.
        let status = unsafe { bindings::rocprofiler_stop(self.context, group_index) };
        Error::from_hsa_status(status)
    }

    /// Calls `rocprofiler_read` for group `group_index`.
    pub fn read(&self, group_index: u32) -> Result<()> {
        // SAFETY: `self.context` is an open rocprofiler context.
        let status = unsafe { bindings::rocprofiler_read(self.context, group_index) };
        Error::from_hsa_status(status)
    }

    /// Calls `rocprofiler_get_data` for group `group_index`, then refreshes
    /// each Rust [`Feature`] from its entry in the C feature array.
    pub fn get_data(&mut self, group_index: u32) -> Result<()> {
        // SAFETY: `self.context` is an open rocprofiler context.
        let status = unsafe { bindings::rocprofiler_get_data(self.context, group_index) };
        if status != bindings::hsa_status_t_HSA_STATUS_SUCCESS {
            return Err(Error::new(status));
        }
        // Both vectors were built one-to-one in `new` and never change length.
        for (feature, c_feature) in self.features.iter_mut().zip(self.c_features.iter_mut()) {
            let raw: *mut bindings::rocprofiler_feature_t = c_feature;
            feature.update_from_c(raw);
        }
        Ok(())
    }

    /// Rust-side view of the profiled features.
    pub fn features(&self) -> &[Feature] {
        &self.features
    }

    /// Mutable Rust-side view of the features; changes are not copied into
    /// the C feature array.
    pub fn features_mut(&mut self) -> &mut [Feature] {
        &mut self.features
    }

    /// Calls `rocprofiler_reset` for group `group_index`.
    pub fn reset(&self, group_index: u32) -> Result<()> {
        // SAFETY: `self.context` is an open rocprofiler context.
        let status = unsafe { bindings::rocprofiler_reset(self.context, group_index) };
        Error::from_hsa_status(status)
    }

    /// The device this context was opened on.
    pub fn agent(&self) -> &hip::Device {
        &self.agent
    }

    /// Calls `rocprofiler_get_metrics` on this context.
    pub fn get_metrics(&self) -> Result<()> {
        // SAFETY: `self.context` is an open rocprofiler context.
        let status = unsafe { bindings::rocprofiler_get_metrics(self.context) };
        Error::from_hsa_status(status)
    }
}
impl Drop for Context {
fn drop(&mut self) {
if !self.context.is_null() {
unsafe {
let _ = bindings::rocprofiler_close(self.context);
self.context = ptr::null_mut();
}
}
}
}