use crate::hip::DeviceMemory;
use crate::*;
use crate::rocblas::bindings::_rocblas_handle;
use crate::rocblas::error::{Error, Result};
use crate::rocblas::ffi;
use crate::rocblas::handle::Handle;
pub fn scal<T>(handle: &Handle, n: i32, alpha: &T, x: &DeviceMemory<T>, incx: i32) -> Result<()>
where
T: ScalType,
{
unsafe { T::rocblas_scal(handle, n, alpha, x.as_ptr().cast(), incx) }
}
pub fn scal_batched<T>(
handle: &Handle,
n: i32,
alpha: &T,
x: *const *mut T,
incx: i32,
batch_count: i32,
) -> Result<()>
where
T: ScalBatchedType,
{
unsafe { T::rocblas_scal_batched(handle, n, alpha, x, incx, batch_count) }
}
pub unsafe fn scal_strided_batched<T>(
handle: &Handle,
n: i32,
alpha: &T,
x: *mut T,
incx: i32,
stride_x: i64,
batch_count: i32,
) -> Result<()>
where
T: ScalStridedBatchedType,
{
unsafe {
T::rocblas_scal_strided_batched(handle, n, alpha, x, incx, stride_x, batch_count)
}
}
pub unsafe fn copy<T>(
handle: &Handle,
n: i32,
x: *const T,
incx: i32,
y: *mut T,
incy: i32,
) -> Result<()>
where
T: CopyType,
{
unsafe { T::rocblas_copy(handle, n, x, incx, y, incy) }
}
pub unsafe fn copy_batched<T>(
handle: &Handle,
n: i32,
x: *const *const T,
incx: i32,
y: *const *mut T,
incy: i32,
batch_count: i32,
) -> Result<()>
where
T: CopyBatchedType,
{
unsafe {
T::rocblas_copy_batched(handle, n, x, incx, y, incy, batch_count)
}
}
pub unsafe fn copy_strided_batched<T>(
handle: &Handle,
n: i32,
x: *const T,
incx: i32,
stridex: i64,
y: *mut T,
incy: i32,
stridey: i64,
batch_count: i32,
) -> Result<()>
where
T: CopyStridedBatchedType,
{
unsafe {
T::rocblas_copy_strided_batched(handle, n, x, incx, stridex, y, incy, stridey, batch_count)
}
}
pub unsafe fn dot<T>(
handle: &Handle,
n: i32,
x: *const T,
incx: i32,
y: *const T,
incy: i32,
result: *mut T,
) -> Result<()>
where
T: DotType,
{
unsafe { T::rocblas_dot(handle, n, x, incx, y, incy, result) }
}
pub unsafe fn dotu<T>(
handle: &Handle,
n: i32,
x: *const T,
incx: i32,
y: *const T,
incy: i32,
result: *mut T,
) -> Result<()>
where
T: DotuType,
{
unsafe { T::rocblas_dotu(handle, n, x, incx, y, incy, result) }
}
pub unsafe fn dotc<T>(
handle: &Handle,
n: i32,
x: *const T,
incx: i32,
y: *const T,
incy: i32,
result: *mut T,
) -> Result<()>
where
T: DotcType,
{
unsafe { T::rocblas_dotc(handle, n, x, incx, y, incy, result) }
}
impl_rocblas_traits!(
ScalType,
ScalTypeFn,
{
f32 => ffi::rocblas_sscal,
f64 => ffi::rocblas_dscal,
ffi::rocblas_float_complex => ffi::rocblas_cscal,
ffi::rocblas_double_complex => ffi::rocblas_zscal,
},
rocblas_scal,
(handle: &Handle, n: i32, alpha: &Self, x: *mut Self, incx: i32),
(*mut _rocblas_handle, i32, *const T, *mut T, i32),
(handle.as_raw(), n, alpha, x, incx)
);
impl_rocblas_traits!(
ScalBatchedType,
ScalBatchedTypeFn,
{
f32 => ffi::rocblas_sscal_batched,
f64 => ffi::rocblas_dscal_batched,
ffi::rocblas_float_complex => ffi::rocblas_cscal_batched,
ffi::rocblas_double_complex => ffi::rocblas_zscal_batched,
},
rocblas_scal_batched,
(handle: &Handle, n: i32, alpha: &Self, x: *const *mut Self, incx: i32, batch_count: i32),
(*mut _rocblas_handle, i32, *const T, *const *mut T, i32, i32),
(handle.as_raw(), n, alpha, x, incx, batch_count)
);
impl_rocblas_traits!(
ScalStridedBatchedType,
ScalStridedBatchedTypeFn,
{
f32 => ffi::rocblas_sscal_strided_batched,
f64 => ffi::rocblas_dscal_strided_batched,
ffi::rocblas_float_complex => ffi::rocblas_cscal_strided_batched,
ffi::rocblas_double_complex => ffi::rocblas_zscal_strided_batched,
},
rocblas_scal_strided_batched,
(handle: &Handle, n: i32, alpha: &Self, x: *mut Self, incx: i32, stride_x: i64, batch_count: i32),
(*mut _rocblas_handle, i32, *const T, *mut T, i32, i64, i32),
(handle.as_raw(), n, alpha, x, incx, stride_x, batch_count)
);
impl_rocblas_traits!(
CopyType,
CopyTypeFn,
{
f32 => ffi::rocblas_scopy,
f64 => ffi::rocblas_dcopy,
ffi::rocblas_float_complex => ffi::rocblas_ccopy,
ffi::rocblas_double_complex => ffi::rocblas_zcopy,
},
rocblas_copy,
(handle: &Handle, n: i32, x: *const Self, incx: i32, y: *mut Self, incy: i32),
(*mut _rocblas_handle, i32, *const T, i32, *mut T, i32),
(handle.as_raw(), n, x, incx, y, incy)
);
impl_rocblas_traits!(
CopyBatchedType,
CopyBatchedTypeFn,
{
f32 => ffi::rocblas_scopy_batched,
f64 => ffi::rocblas_dcopy_batched,
ffi::rocblas_float_complex => ffi::rocblas_ccopy_batched,
ffi::rocblas_double_complex => ffi::rocblas_zcopy_batched,
},
rocblas_copy_batched,
(handle: &Handle, n: i32, x: *const *const Self, incx: i32, y: *const *mut Self, incy: i32, batch_count: i32,),
(*mut _rocblas_handle, i32, *const *const T, i32, *const *mut T, i32, i32),
(handle.as_raw(), n, x, incx, y, incy, batch_count)
);
impl_rocblas_traits!(
CopyStridedBatchedType,
CopyStridedBatchedTypeFn,
{
f32 => ffi::rocblas_scopy_strided_batched,
f64 => ffi::rocblas_dcopy_strided_batched,
ffi::rocblas_float_complex => ffi::rocblas_ccopy_strided_batched,
ffi::rocblas_double_complex => ffi::rocblas_zcopy_strided_batched,
},
rocblas_copy_strided_batched,
(handle: &Handle, n: i32, x: *const Self, incx: i32, stridex: i64, y: *mut Self, incy: i32, stridey: i64, batch_count: i32),
(*mut _rocblas_handle, i32, *const T, i32, i64, *mut T, i32, i64, i32),
(handle.as_raw(), n, x, incx, stridex, y, incy, stridey, batch_count)
);
impl_rocblas_traits!(
DotType,
DotTypeFn,
{
f32 => ffi::rocblas_sdot,
f64 => ffi::rocblas_ddot,
ffi::rocblas_half => ffi::rocblas_hdot,
ffi::rocblas_bfloat16 => ffi::rocblas_bfdot,
},
rocblas_dot,
(handle: &Handle, n: i32, x: *const Self, incx: i32, y: *const Self, incy: i32, result: *mut Self),
(*mut _rocblas_handle, i32, *const T, i32, *const T, i32, *mut T),
(handle.as_raw(), n, x, incx, y, incy, result)
);
impl_rocblas_traits!(
DotuType,
DotuTypeFn,
{
ffi::rocblas_float_complex => ffi::rocblas_cdotu,
ffi::rocblas_double_complex => ffi::rocblas_zdotu,
},
rocblas_dotu,
(handle: &Handle, n: i32, x: *const Self, incx: i32, y: *const Self, incy: i32, result: *mut Self),
(*mut _rocblas_handle, i32, *const T, i32, *const T, i32, *mut T),
(handle.as_raw(), n, x, incx, y, incy, result)
);
impl_rocblas_traits!(
DotcType,
DotcTypeFn,
{
ffi::rocblas_float_complex => ffi::rocblas_cdotc,
ffi::rocblas_double_complex => ffi::rocblas_zdotc,
},
rocblas_dotc,
(handle: &Handle, n: i32, x: *const Self, incx: i32, y: *const Self, incy: i32, result: *mut Self),
(*mut _rocblas_handle, i32, *const T, i32, *const T, i32, *mut T),
(handle.as_raw(), n, x, incx, y, incy, result)
);
pub fn axpy<T>(
_handle: &Handle,
_n: i32,
_alpha: &T,
_x: *const T,
_incx: i32,
_y: *mut T,
_incy: i32,
) -> Result<()> {
todo!()
}
pub fn nrm2<T, R>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn asum<T, R>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn amax<T>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_result: *mut i32,
) -> Result<()> {
todo!()
}
pub fn amin<T>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_result: *mut i32,
) -> Result<()> {
todo!()
}
pub fn swap<T>(
_handle: &Handle,
_n: i32,
_x: *mut T,
_incx: i32,
_y: *mut T,
_incy: i32,
) -> Result<()> {
todo!()
}
pub fn rot<T>(
_handle: &Handle,
_n: i32,
_x: *mut T,
_incx: i32,
_y: *mut T,
_incy: i32,
_c: *const f32,
_s: *const f32,
) -> Result<()> {
todo!()
}
pub fn rotg<T>(_handle: &Handle, _a: *mut T, _b: *mut T, _c: *mut T, _s: *mut T) -> Result<()> {
todo!()
}
pub fn rotm<T>(
_handle: &Handle,
_n: i32,
_x: *mut T,
_incx: i32,
_y: *mut T,
_incy: i32,
_param: *const T,
) -> Result<()> {
todo!()
}
pub fn rotmg<T>(
_handle: &Handle,
_d1: *mut T,
_d2: *mut T,
_x1: *mut T,
_y1: *const T,
_param: *mut T,
) -> Result<()> {
todo!()
}
pub fn axpy_batched<T>(
_handle: &Handle,
_n: i32,
_alpha: &T,
_x: *const *const T,
_incx: i32,
_y: *const *mut T,
_incy: i32,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn dot_batched<T, R>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_y: *const *const T,
_incy: i32,
_batch_count: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn dotu_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_y: *const *const T,
_incy: i32,
_batch_count: i32,
_result: *mut T,
) -> Result<()> {
todo!()
}
pub fn dotc_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_y: *const *const T,
_incy: i32,
_batch_count: i32,
_result: *mut T,
) -> Result<()> {
todo!()
}
pub fn nrm2_batched<T, R>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_batch_count: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn asum_batched<T, R>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_batch_count: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn amax_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_batch_count: i32,
_result: *mut i32,
) -> Result<()> {
todo!()
}
pub fn amin_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *const T,
_incx: i32,
_batch_count: i32,
_result: *mut i32,
) -> Result<()> {
todo!()
}
pub fn swap_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *mut T,
_incx: i32,
_y: *const *mut T,
_incy: i32,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rot_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *mut T,
_incx: i32,
_y: *const *mut T,
_incy: i32,
_c: *const f32,
_s: *const f32,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rotg_batched<T>(
_handle: &Handle,
_a: *const *mut T,
_b: *const *mut T,
_c: *const *mut T,
_s: *const *mut T,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rotm_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const *mut T,
_incx: i32,
_y: *const *mut T,
_incy: i32,
_param: *const *const T,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rotmg_batched<T>(
_handle: &Handle,
_d1: *const *mut T,
_d2: *const *mut T,
_x1: *const *mut T,
_y1: *const *const T,
_param: *const *mut T,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn axpy_strided_batched<T>(
_handle: &Handle,
_n: i32,
_alpha: &T,
_x: *const T,
_incx: i32,
_stridex: i64,
_y: *mut T,
_incy: i32,
_stridey: i64,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn dot_strided_batched<T, R>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_y: *const T,
_incy: i32,
_stridey: i64,
_batch_count: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn dotu_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_y: *const T,
_incy: i32,
_stridey: i64,
_batch_count: i32,
_result: *mut T,
) -> Result<()> {
todo!()
}
pub fn dotc_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_y: *const T,
_incy: i32,
_stridey: i64,
_batch_count: i32,
_result: *mut T,
) -> Result<()> {
todo!()
}
pub fn nrm2_strided_batched<T, R>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_batch_count: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn asum_strided_batched<T, R>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_batch_count: i32,
_result: *mut R,
) -> Result<()> {
todo!()
}
pub fn amax_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_batch_count: i32,
_result: *mut i32,
) -> Result<()> {
todo!()
}
pub fn amin_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *const T,
_incx: i32,
_stridex: i64,
_batch_count: i32,
_result: *mut i32,
) -> Result<()> {
todo!()
}
pub fn swap_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *mut T,
_incx: i32,
_stridex: i64,
_y: *mut T,
_incy: i32,
_stridey: i64,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rot_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *mut T,
_incx: i32,
_stridex: i64,
_y: *mut T,
_incy: i32,
_stridey: i64,
_c: *const f32,
_s: *const f32,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rotg_strided_batched<T>(
_handle: &Handle,
_a: *mut T,
_stridea: i64,
_b: *mut T,
_strideb: i64,
_c: *mut T,
_stridec: i64,
_s: *mut T,
_strides: i64,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rotm_strided_batched<T>(
_handle: &Handle,
_n: i32,
_x: *mut T,
_incx: i32,
_stridex: i64,
_y: *mut T,
_incy: i32,
_stridey: i64,
_param: *const T,
_param_stride: i64,
_batch_count: i32,
) -> Result<()> {
todo!()
}
pub fn rotmg_strided_batched<T>(
_handle: &Handle,
_d1: *mut T,
_stride_d1: i64,
_d2: *mut T,
_stride_d2: i64,
_x1: *mut T,
_stride_x1: i64,
_y1: *const T,
_stride_y1: i64,
_param: *mut T,
_stride_param: i64,
_batch_count: i32,
) -> Result<()> {
todo!()
}