use std::ffi::CStr;
use crate::rocblas::error::Result;
use crate::rocblas::ffi;
use crate::rocblas::handle::Handle;
use super::Error;
/// Safe counterpart of the raw rocBLAS `rocblas_pointer_mode` enumeration
/// (host vs. device pointer mode).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PointerMode {
    /// Corresponds to `rocblas_pointer_mode_host`.
    Host,
    /// Corresponds to `rocblas_pointer_mode_device`.
    Device,
}

impl From<PointerMode> for ffi::rocblas_pointer_mode {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: PointerMode) -> Self {
        match value {
            PointerMode::Device => ffi::rocblas_pointer_mode__rocblas_pointer_mode_device,
            PointerMode::Host => ffi::rocblas_pointer_mode__rocblas_pointer_mode_host,
        }
    }
}

impl From<ffi::rocblas_pointer_mode> for PointerMode {
    /// Lifts a raw rocBLAS constant into the safe enum.
    ///
    /// Only the device constant maps to [`PointerMode::Device`]; every other
    /// value (including unrecognised ones) becomes [`PointerMode::Host`].
    fn from(raw: ffi::rocblas_pointer_mode) -> Self {
        if raw == ffi::rocblas_pointer_mode__rocblas_pointer_mode_device {
            Self::Device
        } else {
            Self::Host
        }
    }
}
/// Safe counterpart of the raw rocBLAS `rocblas_atomics_mode` enumeration:
/// whether atomic operations are allowed or not.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AtomicsMode {
    /// Corresponds to `rocblas_atomics_not_allowed`.
    NotAllowed,
    /// Corresponds to `rocblas_atomics_allowed`.
    Allowed,
}

impl From<AtomicsMode> for ffi::rocblas_atomics_mode {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: AtomicsMode) -> Self {
        match value {
            AtomicsMode::Allowed => ffi::rocblas_atomics_mode__rocblas_atomics_allowed,
            AtomicsMode::NotAllowed => ffi::rocblas_atomics_mode__rocblas_atomics_not_allowed,
        }
    }
}

impl From<ffi::rocblas_atomics_mode> for AtomicsMode {
    /// Lifts a raw rocBLAS constant into the safe enum.
    ///
    /// Only the "not allowed" constant maps to [`AtomicsMode::NotAllowed`];
    /// every other value (including unrecognised ones) becomes
    /// [`AtomicsMode::Allowed`].
    fn from(raw: ffi::rocblas_atomics_mode) -> Self {
        if raw == ffi::rocblas_atomics_mode__rocblas_atomics_not_allowed {
            Self::NotAllowed
        } else {
            Self::Allowed
        }
    }
}
/// Safe counterpart of the raw rocBLAS `rocblas_performance_metric`
/// enumeration.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PerformanceMetric {
    /// Corresponds to `rocblas_default_performance_metric`.
    Default,
    /// Corresponds to `rocblas_device_efficiency_performance_metric`.
    DeviceEfficiency,
    /// Corresponds to `rocblas_cu_efficiency_performance_metric`.
    CUEfficiency,
}

impl From<PerformanceMetric> for ffi::rocblas_performance_metric {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: PerformanceMetric) -> Self {
        match value {
            PerformanceMetric::CUEfficiency => {
                ffi::rocblas_performance_metric__rocblas_cu_efficiency_performance_metric
            }
            PerformanceMetric::DeviceEfficiency => {
                ffi::rocblas_performance_metric__rocblas_device_efficiency_performance_metric
            }
            PerformanceMetric::Default => {
                ffi::rocblas_performance_metric__rocblas_default_performance_metric
            }
        }
    }
}

impl From<ffi::rocblas_performance_metric> for PerformanceMetric {
    /// Lifts a raw rocBLAS constant into the safe enum.
    ///
    /// The default-metric constant and any unrecognised value both map to
    /// [`PerformanceMetric::Default`].
    fn from(raw: ffi::rocblas_performance_metric) -> Self {
        match raw {
            ffi::rocblas_performance_metric__rocblas_device_efficiency_performance_metric => {
                Self::DeviceEfficiency
            }
            ffi::rocblas_performance_metric__rocblas_cu_efficiency_performance_metric => {
                Self::CUEfficiency
            }
            _ => Self::Default,
        }
    }
}
/// Safe counterpart of the raw rocBLAS `rocblas_layer_mode` enumeration
/// (logging layer selection).
///
/// Each variant lowers to exactly one raw constant. NOTE(review): rocBLAS
/// appears to treat layer modes as combinable bit flags; this enum can only
/// express a single mode at a time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LayerMode {
    /// Corresponds to `rocblas_layer_mode_none`.
    None,
    /// Corresponds to `rocblas_layer_mode_log_trace`.
    LogTrace,
    /// Corresponds to `rocblas_layer_mode_log_bench`.
    LogBench,
    /// Corresponds to `rocblas_layer_mode_log_profile`.
    LogProfile,
}

impl From<LayerMode> for ffi::rocblas_layer_mode {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: LayerMode) -> Self {
        match value {
            LayerMode::LogProfile => ffi::rocblas_layer_mode__rocblas_layer_mode_log_profile,
            LayerMode::LogBench => ffi::rocblas_layer_mode__rocblas_layer_mode_log_bench,
            LayerMode::LogTrace => ffi::rocblas_layer_mode__rocblas_layer_mode_log_trace,
            LayerMode::None => ffi::rocblas_layer_mode__rocblas_layer_mode_none,
        }
    }
}
/// Safe counterpart of the raw rocBLAS `rocblas_gemm_algo` enumeration.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GemmAlgo {
    /// Corresponds to `rocblas_gemm_algo_standard`.
    Standard,
    /// Corresponds to `rocblas_gemm_algo_solution_index`.
    SolutionIndex,
}

impl From<GemmAlgo> for ffi::rocblas_gemm_algo {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: GemmAlgo) -> Self {
        match value {
            GemmAlgo::SolutionIndex => ffi::rocblas_gemm_algo__rocblas_gemm_algo_solution_index,
            GemmAlgo::Standard => ffi::rocblas_gemm_algo__rocblas_gemm_algo_standard,
        }
    }
}

impl From<ffi::rocblas_gemm_algo> for GemmAlgo {
    /// Lifts a raw rocBLAS constant into the safe enum.
    ///
    /// Only the solution-index constant maps to [`GemmAlgo::SolutionIndex`];
    /// every other value (including unrecognised ones) becomes
    /// [`GemmAlgo::Standard`].
    fn from(raw: ffi::rocblas_gemm_algo) -> Self {
        if raw == ffi::rocblas_gemm_algo__rocblas_gemm_algo_solution_index {
            Self::SolutionIndex
        } else {
            Self::Standard
        }
    }
}
/// Safe counterpart of the raw rocBLAS `rocblas_gemm_flags` enumeration.
///
/// Each variant lowers to exactly one raw constant. NOTE(review): the raw
/// flags look like combinable bit flags; this enum can only express one flag
/// at a time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GemmFlags {
    /// Corresponds to `rocblas_gemm_flags_none`.
    None,
    /// Corresponds to `rocblas_gemm_flags_use_cu_efficiency`.
    UseCUEfficiency,
    /// Corresponds to `rocblas_gemm_flags_fp16_alt_impl`.
    FP16AltImpl,
    /// Corresponds to `rocblas_gemm_flags_check_solution_index`.
    CheckSolutionIndex,
    /// Corresponds to `rocblas_gemm_flags_fp16_alt_impl_rnz`.
    FP16AltImplRNZ,
    /// Corresponds to `rocblas_gemm_flags_stochastic_rounding`.
    StochasticRounding,
}

impl From<GemmFlags> for ffi::rocblas_gemm_flags {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: GemmFlags) -> Self {
        match value {
            GemmFlags::StochasticRounding => {
                ffi::rocblas_gemm_flags__rocblas_gemm_flags_stochastic_rounding
            }
            GemmFlags::FP16AltImplRNZ => {
                ffi::rocblas_gemm_flags__rocblas_gemm_flags_fp16_alt_impl_rnz
            }
            GemmFlags::CheckSolutionIndex => {
                ffi::rocblas_gemm_flags__rocblas_gemm_flags_check_solution_index
            }
            GemmFlags::FP16AltImpl => {
                ffi::rocblas_gemm_flags__rocblas_gemm_flags_fp16_alt_impl
            }
            GemmFlags::UseCUEfficiency => {
                ffi::rocblas_gemm_flags__rocblas_gemm_flags_use_cu_efficiency
            }
            GemmFlags::None => ffi::rocblas_gemm_flags__rocblas_gemm_flags_none,
        }
    }
}
/// Safe counterpart of the raw rocBLAS `rocblas_math_mode` enumeration.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MathMode {
    /// Corresponds to `rocblas_default_math`.
    Default,
    /// Corresponds to `rocblas_xf32_xdl_math_op`.
    XF32XDLMathOp,
}

impl From<MathMode> for ffi::rocblas_math_mode {
    /// Lowers the safe enum to the matching raw rocBLAS constant.
    fn from(value: MathMode) -> Self {
        match value {
            MathMode::XF32XDLMathOp => ffi::rocblas_math_mode__rocblas_xf32_xdl_math_op,
            MathMode::Default => ffi::rocblas_math_mode__rocblas_default_math,
        }
    }
}

impl From<ffi::rocblas_math_mode> for MathMode {
    /// Lifts a raw rocBLAS constant into the safe enum.
    ///
    /// Only the xf32 XDL constant maps to [`MathMode::XF32XDLMathOp`]; every
    /// other value (including unrecognised ones) becomes [`MathMode::Default`].
    fn from(raw: ffi::rocblas_math_mode) -> Self {
        if raw == ffi::rocblas_math_mode__rocblas_xf32_xdl_math_op {
            Self::XF32XDLMathOp
        } else {
            Self::Default
        }
    }
}
/// Sets the pointer mode on `handle`.
///
/// # Errors
///
/// Propagates any error returned by `Handle::set_pointer_mode`.
pub fn set_pointer_mode(handle: &Handle, mode: PointerMode) -> Result<()> {
    let raw_mode = mode.into();
    handle.set_pointer_mode(raw_mode)
}
pub fn get_pointer_mode(handle: &Handle) -> Result<PointerMode> {
let mode = handle.get_pointer_mode()?;
Ok(mode.into())
}
/// Sets the atomics mode on `handle`.
///
/// # Errors
///
/// Propagates any error returned by `Handle::set_atomics_mode`.
pub fn set_atomics_mode(handle: &Handle, mode: AtomicsMode) -> Result<()> {
    let raw_mode = mode.into();
    handle.set_atomics_mode(raw_mode)
}
pub fn get_atomics_mode(handle: &Handle) -> Result<AtomicsMode> {
let mode = handle.get_atomics_mode()?;
Ok(mode.into())
}
/// Sets the performance metric on `handle`.
///
/// # Errors
///
/// Propagates any error returned by `Handle::set_performance_metric`.
pub fn set_performance_metric(handle: &Handle, metric: PerformanceMetric) -> Result<()> {
    let raw_metric = metric.into();
    handle.set_performance_metric(raw_metric)
}
pub fn get_performance_metric(handle: &Handle) -> Result<PerformanceMetric> {
let metric = handle.get_performance_metric()?;
Ok(metric.into())
}
/// Sets the math mode on `handle`.
///
/// # Errors
///
/// Propagates any error returned by `Handle::set_math_mode`.
pub fn set_math_mode(handle: &Handle, mode: MathMode) -> Result<()> {
    let raw_mode = mode.into();
    handle.set_math_mode(raw_mode)
}
pub fn get_math_mode(handle: &Handle) -> Result<MathMode> {
let mode = handle.get_math_mode()?;
Ok(mode.into())
}
/// Returns rocBLAS's textual description of `status`.
///
/// If rocBLAS returns a null pointer, the fixed text
/// `"Unknown rocBLAS status"` is returned instead. Invalid UTF-8 in the
/// library-provided text is replaced with U+FFFD.
pub fn status_to_string(status: ffi::rocblas_status) -> String {
    // SAFETY: `rocblas_status_to_string` takes a plain status value and only
    // returns a pointer; nothing is dereferenced here.
    let text = unsafe { ffi::rocblas_status_to_string(status) };
    if text.is_null() {
        return String::from("Unknown rocBLAS status");
    }
    // SAFETY: `text` is non-null; it is assumed to point to a NUL-terminated
    // string that rocBLAS keeps alive at least for the duration of this copy.
    let borrowed = unsafe { CStr::from_ptr(text) };
    borrowed.to_string_lossy().into_owned()
}
/// Calls `rocblas_initialize` to initialize the rocBLAS library up front.
pub fn initialize() {
    // SAFETY: `rocblas_initialize` takes no arguments and touches no Rust-owned memory.
    unsafe { ffi::rocblas_initialize() };
}
/// Returns the rocBLAS library version string.
///
/// First queries the required buffer size with
/// `rocblas_get_version_string_size`, then has rocBLAS write the string into a
/// buffer of exactly that size. Bytes from the first NUL terminator onward are
/// discarded, and any invalid UTF-8 is replaced with U+FFFD.
///
/// # Errors
///
/// Returns an [`Error`] wrapping the rocBLAS status if either FFI call does not
/// report success.
pub fn get_version_string() -> Result<String> {
    let mut size: usize = 0;
    // SAFETY: `size` is a live, writable `usize` for the duration of the call.
    let status = unsafe { ffi::rocblas_get_version_string_size(&mut size) };
    if status != ffi::rocblas_status__rocblas_status_success {
        return Err(Error::new(status));
    }

    let mut buffer = vec![0u8; size];
    // `cast()` lets the compiler take the pointee type from the FFI signature.
    // The previous `as *mut i8` only compiled where `c_char == i8` and broke on
    // targets where `c_char` is `u8` (e.g. aarch64 Linux).
    // SAFETY: `buffer` is exactly `size` bytes long and outlives the call.
    let status = unsafe { ffi::rocblas_get_version_string(buffer.as_mut_ptr().cast(), size) };
    if status != ffi::rocblas_status__rocblas_status_success {
        return Err(Error::new(status));
    }

    // Drop the NUL terminator and anything after it, if present.
    let text_len = buffer.iter().position(|&b| b == 0).unwrap_or(buffer.len());
    buffer.truncate(text_len);
    Ok(String::from_utf8_lossy(&buffer).into_owned())
}
pub fn start_device_memory_size_query(handle: &Handle) -> Result<()> {
let status = unsafe { ffi::rocblas_start_device_memory_size_query(handle.as_raw()) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
pub fn stop_device_memory_size_query(handle: &Handle) -> Result<usize> {
let mut size: usize = 0;
let status = unsafe { ffi::rocblas_stop_device_memory_size_query(handle.as_raw(), &mut size) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(size)
}
/// Reports whether `handle` is currently in a device-memory size query, as
/// answered by `rocblas_is_device_memory_size_query`.
pub fn is_device_memory_size_query(handle: &Handle) -> bool {
    let raw = handle.as_raw();
    // SAFETY: assumes `Handle` keeps its raw rocBLAS handle valid while borrowed.
    unsafe { ffi::rocblas_is_device_memory_size_query(raw) }
}
pub fn get_device_memory_size(handle: &Handle) -> Result<usize> {
let mut size: usize = 0;
let status = unsafe { ffi::rocblas_get_device_memory_size(handle.as_raw(), &mut size) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(size)
}
pub fn set_device_memory_size(handle: &Handle, size: usize) -> Result<()> {
let status = unsafe { ffi::rocblas_set_device_memory_size(handle.as_raw(), size) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
/// Hands a caller-owned device buffer to `handle` for use as its workspace.
///
/// Thin wrapper over `rocblas_set_workspace`; `addr` and `size` are forwarded
/// unchanged.
///
/// # Safety
///
/// rocBLAS receives `addr` and `size` without any validation on the Rust side.
/// The caller must ensure that:
///
/// - `addr` is either null or points to a device allocation of at least `size`
///   bytes.
/// - That allocation stays valid (not freed or repurposed) for as long as
///   `handle` may use it, i.e. until the workspace is replaced or the handle is
///   destroyed.
///
/// NOTE(review): the rocBLAS documentation describes a null `addr` or a `size`
/// of 0 as returning workspace management to rocBLAS. Confirm this against the
/// rocBLAS version in use.
///
/// # Errors
///
/// Returns an [`Error`] wrapping the rocBLAS status if the call does not succeed.
pub unsafe fn set_workspace(
    handle: &Handle,
    addr: *mut std::ffi::c_void,
    size: usize,
) -> Result<()> {
    // SAFETY: the caller upholds the workspace contract documented above; the raw
    // handle is assumed to remain valid while `handle` is borrowed.
    let status = unsafe { ffi::rocblas_set_workspace(handle.as_raw(), addr, size) };
    if status != ffi::rocblas_status__rocblas_status_success {
        return Err(Error::new(status));
    }
    Ok(())
}
/// Returns the answer of `rocblas_is_managing_device_memory` for `handle`.
pub fn is_managing_device_memory(handle: &Handle) -> bool {
    let raw = handle.as_raw();
    // SAFETY: assumes `Handle` keeps its raw rocBLAS handle valid while borrowed.
    unsafe { ffi::rocblas_is_managing_device_memory(raw) }
}
/// Returns the answer of `rocblas_is_user_managing_device_memory` for `handle`.
pub fn is_user_managing_device_memory(handle: &Handle) -> bool {
    let raw = handle.as_raw();
    // SAFETY: assumes `Handle` keeps its raw rocBLAS handle valid while borrowed.
    unsafe { ffi::rocblas_is_user_managing_device_memory(raw) }
}
/// Forwards `size` to `rocblas_device_malloc_set_default_memory_size`.
pub fn device_malloc_set_default_memory_size(size: usize) {
    // SAFETY: only a plain integer is passed; no Rust-owned memory is involved.
    unsafe { ffi::rocblas_device_malloc_set_default_memory_size(size) };
}
pub fn abort() -> ! {
unsafe {
ffi::rocblas_abort();
}
}