use crate::rocprofiler::error::{Error, Result};
use crate::rocprofiler::types::{Feature, ProfilerMode, InfoData, InfoKind};
use crate::rocprofiler::context::{Context, Properties};
use crate::rocprofiler::bindings;
use crate::hip;
use std::sync::Once;
use std::ffi::{CStr, CString};
use std::ptr;
use std::os::raw::c_void;
/// One-shot guard set aside for process-wide initialisation.
///
/// NOTE(review): nothing in this module calls `INIT` yet; `init` below does not consult it.
static INIT: Once = Once::new();

/// Initialise the ROCProfiler bindings.
///
/// There is currently no setup work to perform, so this always succeeds.
pub fn init() -> Result<()> {
    // Intentionally empty: kept as a stable entry point for future setup.
    Ok(())
}
/// Return the linked ROCProfiler library version formatted as `"<major>.<minor>"`.
pub fn version_string() -> String {
    // SAFETY: both version queries are argument-less FFI calls; they are assumed to
    // have no preconditions beyond the library being linked.
    let (major, minor) = unsafe {
        (
            bindings::rocprofiler_version_major(),
            bindings::rocprofiler_version_minor(),
        )
    };
    format!("{}.{}", major, minor)
}
/// C callback passed to `rocprofiler_iterate_info`: decodes one info record and
/// appends it to the caller's `Vec<InfoData>`.
///
/// Returns `HSA_STATUS_ERROR_INVALID_ARGUMENT` for a null `data` pointer,
/// `HSA_STATUS_ERROR` when the record cannot be decoded, and `HSA_STATUS_SUCCESS`
/// otherwise.
///
/// # Safety
/// `data` must be null or point to a live `Vec<InfoData>` that is not otherwise
/// borrowed for the duration of the call.
unsafe extern "C" fn info_callback(
    info: bindings::rocprofiler_info_data_t,
    data: *mut c_void,
) -> u32 {
    // `as_mut` yields `None` for a null pointer, which is rejected up front.
    let sink = match (data as *mut Vec<InfoData>).as_mut() {
        Some(vec) => vec,
        None => return bindings::hsa_status_t_HSA_STATUS_ERROR_INVALID_ARGUMENT,
    };
    // The decode error itself is dropped; only a generic error status is reported.
    InfoData::from_c(&info).map_or(bindings::hsa_status_t_HSA_STATUS_ERROR, |entry| {
        sink.push(entry);
        bindings::hsa_status_t_HSA_STATUS_SUCCESS
    })
}
/// Enumerate ROCProfiler info records of the raw kind `kind`, optionally restricted
/// to the agent built from `device`.
///
/// Shared implementation of [`get_metrics`] and [`get_traces`].
///
/// # Errors
/// Returns [`Error`] wrapping the status when `rocprofiler_iterate_info` does not
/// return `HSA_STATUS_SUCCESS`.
fn collect_info(device: Option<&hip::Device>, kind: u32) -> Result<Vec<InfoData>> {
    let mut results: Vec<InfoData> = Vec::new();

    // The agent is bound in this function's scope on purpose. The previous code built
    // it inside a match arm and took a raw pointer to it; that local's storage ended
    // with the arm, so the FFI call read a dangling pointer.
    // NOTE(review): `hsa_agent_t::handle` is an opaque HSA handle; deriving it from
    // `Device::id()` assumes `id()` returns that handle rather than a HIP ordinal — confirm.
    let agent = device.map(|dev| bindings::hsa_agent_t {
        handle: dev.id() as u64,
    });
    let agent_ptr = agent
        .as_ref()
        .map_or(ptr::null(), |a| a as *const bindings::hsa_agent_t);

    // SAFETY: `agent_ptr` is null or points at `agent`, which lives until this function
    // returns. `data` points at `results`, which `info_callback` casts back to
    // `*mut Vec<InfoData>`; `results` is not touched here until the call returns
    // (assumes the callback is only invoked synchronously within the call).
    let status = unsafe {
        bindings::rocprofiler_iterate_info(
            agent_ptr,
            kind,
            Some(info_callback),
            &mut results as *mut Vec<InfoData> as *mut c_void,
        )
    };
    if status != bindings::hsa_status_t_HSA_STATUS_SUCCESS {
        return Err(Error::new(status));
    }
    Ok(results)
}

/// List the metric records ROCProfiler reports, optionally restricted to `device`.
///
/// # Errors
/// Returns [`Error`] wrapping the status when the underlying iteration fails.
pub fn get_metrics(device: Option<&hip::Device>) -> Result<Vec<InfoData>> {
    collect_info(device, InfoKind::Metric as u32)
}

/// List the trace records ROCProfiler reports, optionally restricted to `device`.
///
/// # Errors
/// Returns [`Error`] wrapping the status when the underlying iteration fails.
pub fn get_traces(device: Option<&hip::Device>) -> Result<Vec<InfoData>> {
    collect_info(device, InfoKind::Trace as u32)
}
/// A ROCProfiler [`Context`] paired with the HIP device it was created for.
pub struct Profiler {
    // Field order fixes drop order (context before device); keep it unless `Context`
    // is known not to depend on the device outliving it.
    context: Context,
    device: hip::Device,
}

impl Profiler {
    /// Create a profiler for `device` that collects `features` in the given `modes`.
    ///
    /// A clone of `device` is handed to [`Context::new`]; the original is kept here and
    /// returned by [`Profiler::device`].
    ///
    /// # Errors
    /// Propagates any error from [`Context::new`].
    pub fn new(
        device: hip::Device,
        features: Vec<Feature>,
        modes: &[ProfilerMode],
        properties: Option<Properties>,
    ) -> Result<Self> {
        let context = Context::new(device.clone(), features, modes, properties)?;
        Ok(Self { context, device })
    }

    /// Start collection for group `group_index` (delegates to `Context::start`).
    pub fn start(&self, group_index: u32) -> Result<()> {
        self.context.start(group_index)
    }

    /// Stop collection for group `group_index` (delegates to `Context::stop`).
    pub fn stop(&self, group_index: u32) -> Result<()> {
        self.context.stop(group_index)
    }

    /// Read counters for group `group_index` (delegates to `Context::read`).
    pub fn read(&self, group_index: u32) -> Result<()> {
        self.context.read(group_index)
    }

    /// Fetch collected data for group `group_index` (delegates to `Context::get_data`).
    pub fn get_data(&mut self, group_index: u32) -> Result<()> {
        self.context.get_data(group_index)
    }

    /// For every group, run start → `workload(group)` → stop → read → get_data.
    ///
    /// `workload` receives the group index and should execute the code to be measured.
    /// A started group is always stopped, even when `workload` fails.
    ///
    /// # Errors
    /// Returns the first error encountered. If both `workload` and the following
    /// `stop` fail, the workload's error is returned.
    pub fn profile_all_with<F>(&mut self, mut workload: F) -> Result<()>
    where
        F: FnMut(u32) -> Result<()>,
    {
        let group_count = self.context.group_count()?;
        for i in 0..group_count {
            self.start(i)?;
            let ran = workload(i);
            // Stop before propagating a workload failure so the group is not left running.
            let stopped = self.stop(i);
            ran?;
            stopped?;
            self.read(i)?;
            self.get_data(i)?;
        }
        Ok(())
    }

    /// For every group, run start → stop → read → get_data with nothing executed in between.
    ///
    /// No work runs between start and stop. To measure a workload, use
    /// [`Profiler::profile_all_with`].
    ///
    /// # Errors
    /// Returns the first error encountered.
    pub fn profile_all(&mut self) -> Result<()> {
        self.profile_all_with(|_| Ok(()))
    }

    /// Return every group of the context, in index order.
    ///
    /// # Errors
    /// Returns the first error from `group_count` or `get_group`.
    pub fn get_groups(&self) -> Result<Vec<crate::rocprofiler::types::Group>> {
        let group_count = self.context.group_count()?;
        (0..group_count).map(|i| self.context.get_group(i)).collect()
    }

    /// Features this profiler was configured with.
    pub fn features(&self) -> &[Feature] {
        self.context.features()
    }

    /// Mutable access to the configured features.
    pub fn features_mut(&mut self) -> &mut [Feature] {
        self.context.features_mut()
    }

    /// Reset group `group_index` (delegates to `Context::reset`).
    pub fn reset(&self, group_index: u32) -> Result<()> {
        self.context.reset(group_index)
    }

    /// The HIP device this profiler was created for.
    pub fn device(&self) -> &hip::Device {
        &self.device
    }

    /// Shared access to the underlying [`Context`].
    pub fn context(&self) -> &Context {
        &self.context
    }

    /// Mutable access to the underlying [`Context`].
    pub fn context_mut(&mut self) -> &mut Context {
        &mut self.context
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hip;
    use crate::rocprofiler::types::{Parameter, ParameterName};

    /// The reported version must look like "<major>.<minor>".
    #[test]
    fn test_version_string() {
        let reported = version_string();
        println!("ROCProfiler version: {}", reported);
        assert!(reported.contains('.'));
    }

    /// Smoke test: list metrics for device 0. Returns early without failing when no
    /// usable HIP device is present, and only prints retrieval errors.
    #[test]
    fn test_get_metrics() {
        let device = match hip::device_count() {
            Ok(count) if count > 0 => match hip::Device::new(0) {
                Ok(dev) => dev,
                Err(_) => return,
            },
            _ => return,
        };

        match get_metrics(Some(&device)) {
            Ok(metrics) => {
                println!("Found {} metrics", metrics.len());
                let metric_infos = metrics.iter().filter_map(|entry| match entry {
                    InfoData::Metric(info) => Some(info),
                    _ => None,
                });
                for info in metric_infos {
                    println!(" {}", info.name);
                }
            }
            Err(e) => println!("Error getting metrics: {}", e),
        }
    }
}