use crate::miopen::error::{Error, Result};
use crate::miopen::ffi;
use crate::miopen::handle::Handle;
use crate::miopen::tensor::TensorDescriptor;
use std::os::raw::c_void;
use std::ptr;
/// Reduction operator selector (raw MIOpen FFI type) passed to
/// `miopenSetReduceTensorDescriptor` and returned by `miopenGetReduceTensorDescriptor`.
pub type ReduceTensorOp = ffi::miopenReduceTensorOp_t;
/// NaN-propagation mode (raw MIOpen FFI type) stored in a reduce-tensor descriptor.
pub type NanPropagation = ffi::miopenNanPropagation_t;
/// Selects whether the reduction also produces indices (raw MIOpen FFI type).
pub type ReduceTensorIndices = ffi::miopenReduceTensorIndices_t;
/// Element type of the indices buffer produced by a reduction (raw MIOpen FFI type).
pub type IndicesType = ffi::miopenIndicesType_t;
/// Owning wrapper around a MIOpen reduce-tensor descriptor.
///
/// The wrapped handle is created by `ReduceTensorDescriptor::new` and released
/// when the value is dropped. Because it stores a raw pointer, the type is
/// neither `Send` nor `Sync`.
#[derive(Debug)]
pub struct ReduceTensorDescriptor {
    /// Raw MIOpen descriptor handle; null only after it has been destroyed.
    desc: ffi::miopenReduceTensorDescriptor_t,
}
impl ReduceTensorDescriptor {
pub fn new() -> Result<Self> {
let mut desc = ptr::null_mut();
let status = unsafe { ffi::miopenCreateReduceTensorDescriptor(&mut desc) };
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(Self { desc })
}
pub fn set(
&mut self,
reduce_op: ReduceTensorOp,
comp_type: ffi::miopenDataType_t,
nan_opt: NanPropagation,
indices: ReduceTensorIndices,
indices_type: IndicesType,
) -> Result<()> {
let status = unsafe {
ffi::miopenSetReduceTensorDescriptor(
self.desc,
reduce_op,
comp_type,
nan_opt,
indices,
indices_type,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(())
}
pub fn get(
&self,
) -> Result<(
ReduceTensorOp,
ffi::miopenDataType_t,
NanPropagation,
ReduceTensorIndices,
IndicesType,
)> {
let mut reduce_op = 0;
let mut comp_type = 0;
let mut nan_opt = 0;
let mut indices = 0;
let mut indices_type = 0;
let status = unsafe {
ffi::miopenGetReduceTensorDescriptor(
self.desc,
&mut reduce_op,
&mut comp_type,
&mut nan_opt,
&mut indices,
&mut indices_type,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok((reduce_op, comp_type, nan_opt, indices, indices_type))
}
pub fn as_raw(&self) -> ffi::miopenReduceTensorDescriptor_t {
self.desc
}
}
impl Drop for ReduceTensorDescriptor {
    /// Destroys the owned MIOpen descriptor, if one is still held.
    ///
    /// Failures from `miopenDestroyReduceTensorDescriptor` are ignored on
    /// purpose: `drop` has no way to report them.
    fn drop(&mut self) {
        if self.desc.is_null() {
            return;
        }
        let owned = std::mem::replace(&mut self.desc, ptr::null_mut());
        // SAFETY: `owned` is the non-null handle created in `new`, and
        // `self.desc` has been nulled so it can never be destroyed twice.
        let _ = unsafe { ffi::miopenDestroyReduceTensorDescriptor(owned) };
    }
}
pub fn get_reduction_indices_size(
handle: &Handle,
reduce_desc: &ReduceTensorDescriptor,
a_desc: &TensorDescriptor,
c_desc: &TensorDescriptor,
) -> Result<usize> {
let mut size_in_bytes = 0;
let status = unsafe {
ffi::miopenGetReductionIndicesSize(
handle.as_raw(),
reduce_desc.as_raw(),
a_desc.as_raw(),
c_desc.as_raw(),
&mut size_in_bytes,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(size_in_bytes)
}
pub fn get_reduction_workspace_size(
handle: &Handle,
reduce_desc: &ReduceTensorDescriptor,
a_desc: &TensorDescriptor,
c_desc: &TensorDescriptor,
) -> Result<usize> {
let mut size_in_bytes = 0;
let status = unsafe {
ffi::miopenGetReductionWorkspaceSize(
handle.as_raw(),
reduce_desc.as_raw(),
a_desc.as_raw(),
c_desc.as_raw(),
&mut size_in_bytes,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(size_in_bytes)
}
pub unsafe fn reduce_tensor(
handle: &Handle,
reduce_desc: &ReduceTensorDescriptor,
indices: *mut c_void,
indices_size: usize,
workspace: *mut c_void,
workspace_size: usize,
alpha: &[u8],
a_desc: &TensorDescriptor,
a: *const c_void,
beta: &[u8],
c_desc: &TensorDescriptor,
c: *mut c_void,
) -> Result<()> {
let status = unsafe {
ffi::miopenReduceTensor(
handle.as_raw(),
reduce_desc.as_raw(),
indices,
indices_size,
workspace,
workspace_size,
alpha.as_ptr() as *const c_void,
a_desc.as_raw(),
a,
beta.as_ptr() as *const c_void,
c_desc.as_raw(),
c,
)
};
if status != ffi::miopenStatus_t_miopenStatusSuccess {
return Err(Error::new(status));
}
Ok(())
}