use crate::rocblas::bindings::_rocblas_handle;
use crate::rocblas::error::{Error, Result};
use crate::rocblas::handle::Handle;
use crate::rocblas::types::{Diagonal, Fill, Operation};
use crate::rocblas::{ffi, rocblas_operation};
use crate::*;
use super::level3::{
HemmType, HerkType, Spr2Type, SprType, SyrBatchedType, SyrStridedBatchedType, SyrType,
Syr2Type,
};
use super::types::Side;
/// Generic `gemv` entry point: hands every argument, unmodified, to `T::rocblas_gemv`.
///
/// The element type `T` alone selects which implementation runs.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_gemv`.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gemv<T>(
    handle: &Handle,
    trans: Operation,
    m: i32,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()>
where
    T: GemvType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_gemv(
            handle,
            trans,
            m,
            n,
            alpha,
            A,
            lda,
            x,
            incx,
            beta,
            y,
            incy,
        )
    }
}
/// Generic `gemv_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_gemv_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_gemv_batched`.
///
/// # Safety
/// `A`, `x` and `y` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gemv_batched<T>(
    handle: &Handle,
    trans: Operation,
    m: i32,
    n: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    x: *const *const T,
    incx: i32,
    beta: &T,
    y: *const *mut T,
    incy: i32,
    batch_count: i32,
) -> Result<()>
where
    T: GemvBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_gemv_batched(
            handle, trans, m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count,
        )
    }
}
/// Generic `gemv_strided_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_gemv_strided_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_gemv_strided_batched`.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gemv_strided_batched<T>(
    handle: &Handle,
    trans: Operation,
    m: i32,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    x: *const T,
    incx: i32,
    stride_x: i64,
    beta: &T,
    y: *mut T,
    incy: i32,
    stride_y: i64,
    batch_count: i32,
) -> Result<()>
where
    T: GemvStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_gemv_strided_batched(
            handle, trans, m, n, alpha,
            A, lda, stride_A,
            x, incx, stride_x,
            beta,
            y, incy, stride_y,
            batch_count,
        )
    }
}
/// Generic `gbmv` entry point: hands every argument, unmodified, to `T::rocblas_gbmv`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_gbmv`.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gbmv<T>(
    handle: &Handle,
    trans: Operation,
    m: i32,
    n: i32,
    kl: i32,
    ku: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()>
where
    T: GbmvType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_gbmv(
            handle,
            trans,
            m,
            n,
            kl,
            ku,
            alpha,
            A,
            lda,
            x,
            incx,
            beta,
            y,
            incy,
        )
    }
}
/// Generic `gbmv_batched` entry point: hands every argument, unmodified, to
/// `GbmvBatchedType::rocblas_gbmv_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gbmv_batched<T>(
    handle: &Handle,
    trans: Operation,
    m: i32,
    n: i32,
    kl: i32,
    ku: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    x: *const *const T,
    incx: i32,
    beta: &T,
    y: *const *mut T,
    incy: i32,
    batch_count: i32,
) -> Result<()>
where
    T: GbmvBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GbmvBatchedType>::rocblas_gbmv_batched(
            handle, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy, batch_count,
        )
    }
}
/// Generic `gbmv_strided_batched` entry point: hands every argument, unmodified, to
/// `GbmvStridedBatchedType::rocblas_gbmv_strided_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gbmv_strided_batched<T>(
    handle: &Handle,
    trans: Operation,
    m: i32,
    n: i32,
    kl: i32,
    ku: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    x: *const T,
    incx: i32,
    stride_x: i64,
    beta: &T,
    y: *mut T,
    incy: i32,
    stride_y: i64,
    batch_count: i32,
) -> Result<()>
where
    T: GbmvStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GbmvStridedBatchedType>::rocblas_gbmv_strided_batched(
            handle, trans, m, n, kl, ku, alpha,
            A, lda, stride_A,
            x, incx, stride_x,
            beta,
            y, incy, stride_y,
            batch_count,
        )
    }
}
/// Generic `hbmv` entry point: hands every argument, unmodified, to
/// `HbmvType::rocblas_hbmv` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hbmv<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()>
where
    T: HbmvType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as HbmvType>::rocblas_hbmv(
            handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy,
        )
    }
}
/// Generic `hbmv_batched` entry point: hands every argument, unmodified, to
/// `HbmvBatchedType::rocblas_hbmv_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hbmv_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    x: *const *const T,
    incx: i32,
    beta: &T,
    y: *const *mut T,
    incy: i32,
    batch_count: i32,
) -> Result<()>
where
    T: HbmvBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as HbmvBatchedType>::rocblas_hbmv_batched(
            handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy, batch_count,
        )
    }
}
/// Generic `hbmv_strided_batched` entry point: hands every argument, unmodified, to
/// `HbmvStridedBatchedType::rocblas_hbmv_strided_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hbmv_strided_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    x: *const T,
    incx: i32,
    stride_x: i64,
    beta: &T,
    y: *mut T,
    incy: i32,
    stride_y: i64,
    batch_count: i32,
) -> Result<()>
where
    T: HbmvStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as HbmvStridedBatchedType>::rocblas_hbmv_strided_batched(
            handle, uplo, n, k, alpha,
            A, lda, stride_A,
            x, incx, stride_x,
            beta,
            y, incy, stride_y,
            batch_count,
        )
    }
}
// Wires up `GemvType` / `rocblas_gemv`, which the generic `gemv` wrapper dispatches through.
// NOTE(review): the definition of `impl_rocblas_traits!` is not visible from here; presumably
// it declares the trait (with `GemvFn` naming the FFI function type) and implements it for each
// `element type => ffi function` pair listed below — confirm against the macro.
// The three parenthesised lists appear to be, in order: the trait-method parameters, the FFI
// argument types, and the expressions passed to the FFI call — TODO confirm.
impl_rocblas_traits!(
GemvType,
GemvFn,
{
f32 => ffi::rocblas_sgemv,
f64 => ffi::rocblas_dgemv,
ffi::rocblas_float_complex => ffi::rocblas_cgemv,
ffi::rocblas_double_complex => ffi::rocblas_zgemv,
},
rocblas_gemv,
(handle: &Handle, trans: Operation, m: i32, n: i32, alpha: &Self, A: *const Self, lda: i32, x: *const Self, incx: i32, beta: &Self, y: *mut Self, incy: i32),
(*mut _rocblas_handle, rocblas_operation, i32, i32, *const T, *const T, i32, *const T, i32, *const T, *mut T, i32),
(handle.as_raw(), trans.into(), m, n, alpha, A, lda, x, incx, beta, y, incy)
);
// Wires up `GemvBatchedType` / `rocblas_gemv_batched`, which the generic `gemv_batched` wrapper
// dispatches through.
// NOTE(review): the definition of `impl_rocblas_traits!` is not visible from here; presumably
// it declares the trait (with `GemvBatchedFn` naming the FFI function type) and implements it
// for each `element type => ffi function` pair listed below — confirm against the macro.
// The three parenthesised lists appear to be, in order: the trait-method parameters, the FFI
// argument types, and the expressions passed to the FFI call — TODO confirm.
impl_rocblas_traits!(
GemvBatchedType,
GemvBatchedFn,
{
f32 => ffi::rocblas_sgemv_batched,
f64 => ffi::rocblas_dgemv_batched,
ffi::rocblas_float_complex => ffi::rocblas_cgemv_batched,
ffi::rocblas_double_complex => ffi::rocblas_zgemv_batched,
},
rocblas_gemv_batched,
(handle: &Handle, trans: Operation, m: i32, n: i32, alpha: &Self, A: *const *const Self, lda: i32, x: *const *const Self, incx: i32, beta: &Self, y: *const *mut Self, incy: i32, batch_count: i32),
(*mut _rocblas_handle, rocblas_operation, i32, i32, *const T, *const *const T, i32, *const *const T, i32, *const T, *const *mut T, i32, i32),
(handle.as_raw(), trans.into(), m, n, alpha, A, lda, x, incx, beta, y, incy, batch_count)
);
// Wires up `GemvStridedBatchedType` / `rocblas_gemv_strided_batched`, which the generic
// `gemv_strided_batched` wrapper dispatches through.
// NOTE(review): the definition of `impl_rocblas_traits!` is not visible from here; presumably
// it declares the trait (with `GemvStridedBatchedFn` naming the FFI function type) and
// implements it for each `element type => ffi function` pair listed below — confirm against
// the macro.
// The three parenthesised lists appear to be, in order: the trait-method parameters, the FFI
// argument types, and the expressions passed to the FFI call — TODO confirm.
impl_rocblas_traits!(
GemvStridedBatchedType,
GemvStridedBatchedFn,
{
f32 => ffi::rocblas_sgemv_strided_batched,
f64 => ffi::rocblas_dgemv_strided_batched,
ffi::rocblas_float_complex => ffi::rocblas_cgemv_strided_batched,
ffi::rocblas_double_complex => ffi::rocblas_zgemv_strided_batched,
},
rocblas_gemv_strided_batched,
(handle: &Handle,trans: Operation,m: i32,n: i32,alpha: &Self,A: *const Self,lda: i32,stride_A: i64,x: *const Self,incx: i32,stride_x: i64,beta: &Self,y: *mut Self,incy: i32,stride_y: i64,batch_count: i32),
(*mut _rocblas_handle,rocblas_operation,i32,i32,*const T,*const T,i32,i64,*const T,i32,i64,*const T,*mut T,i32,i64,i32),
(handle.as_raw(),trans.into(),m,n,alpha,A,lda,stride_A,x,incx,stride_x,beta,y,incy,stride_y,batch_count)
);
// Wires up `GbmvType` / `rocblas_gbmv`, which the generic `gbmv` wrapper dispatches through.
// NOTE(review): the definition of `impl_rocblas_traits!` is not visible from here; presumably
// it declares the trait (with `GbmvFn` naming the FFI function type) and implements it for each
// `element type => ffi function` pair listed below — confirm against the macro.
// The three parenthesised lists appear to be, in order: the trait-method parameters, the FFI
// argument types, and the expressions passed to the FFI call — TODO confirm.
impl_rocblas_traits!(
GbmvType,
GbmvFn,
{
f32 => ffi::rocblas_sgbmv,
f64 => ffi::rocblas_dgbmv,
ffi::rocblas_float_complex => ffi::rocblas_cgbmv,
ffi::rocblas_double_complex => ffi::rocblas_zgbmv,
},
rocblas_gbmv,
(handle: &Handle,trans: Operation,m: i32,n: i32,kl: i32,ku: i32,alpha: &Self,A: *const Self,lda: i32,x: *const Self,incx: i32,beta: &Self,y: *mut Self,incy: i32),
(*mut _rocblas_handle,rocblas_operation,i32,i32,i32,i32,*const T,*const T,i32,*const T,i32,*const T,*mut T,i32),
(handle.as_raw(),trans.into(),m,n,kl,ku,alpha,A,lda,x,incx,beta,y,incy)
);
/// Element types that provide a `gbmv_batched` backend.
///
/// The generic `gbmv_batched` wrapper in this module dispatches through this trait.
pub trait GbmvBatchedType {
/// Backend entry point; the parameter list matches that of the `gbmv_batched` wrapper, with
/// `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_gbmv_batched(
handle: &Handle,
trans: Operation,
m: i32,
n: i32,
kl: i32,
ku: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()>;
}
impl GbmvBatchedType for f32 {
unsafe fn rocblas_gbmv_batched(
handle: &Handle,
trans: Operation,
m: i32,
n: i32,
kl: i32,
ku: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sgbmv_batched(
handle.as_raw(),
trans.into(),
m,
n,
kl,
ku,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `gbmv_strided_batched` backend.
///
/// The generic `gbmv_strided_batched` wrapper in this module dispatches through this trait.
pub trait GbmvStridedBatchedType {
/// Backend entry point; the parameter list matches that of the `gbmv_strided_batched` wrapper,
/// with `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_gbmv_strided_batched(
handle: &Handle,
trans: Operation,
m: i32,
n: i32,
kl: i32,
ku: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()>;
}
/// Element types that provide an `hbmv` backend.
///
/// The generic `hbmv` wrapper in this module dispatches through this trait.
pub trait HbmvType {
/// Backend entry point; the parameter list matches that of the `hbmv` wrapper, with `Self` as
/// the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_hbmv(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()>;
}
impl HbmvType for ffi::rocblas_float_complex {
unsafe fn rocblas_hbmv(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chbmv(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HbmvType for ffi::rocblas_double_complex {
unsafe fn rocblas_hbmv(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhbmv(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide an `hbmv_batched` backend.
///
/// The generic `hbmv_batched` wrapper in this module dispatches through this trait.
pub trait HbmvBatchedType {
/// Backend entry point; the parameter list matches that of the `hbmv_batched` wrapper, with
/// `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_hbmv_batched(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()>;
}
/// Element types that provide an `hbmv_strided_batched` backend.
///
/// The generic `hbmv_strided_batched` wrapper in this module dispatches through this trait.
pub trait HbmvStridedBatchedType {
/// Backend entry point; the parameter list matches that of the `hbmv_strided_batched` wrapper,
/// with `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_hbmv_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()>;
}
impl HbmvBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_hbmv_batched(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chbmv_batched(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HbmvBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_hbmv_batched(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhbmv_batched(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HbmvStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_hbmv_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chbmv_strided_batched(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
stride_A,
x,
incx,
stride_x,
beta,
y,
incy,
stride_y,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HbmvStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_hbmv_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhbmv_strided_batched(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
stride_A,
x,
incx,
stride_x,
beta,
y,
incy,
stride_y,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Generic `hemv` entry point: hands every argument, unmodified, to
/// `HemvType::rocblas_hemv` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hemv<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()>
where
    T: HemvType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as HemvType>::rocblas_hemv(
            handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy,
        )
    }
}
/// Generic `hemv_batched` entry point: hands every argument, unmodified, to
/// `HemvBatchedType::rocblas_hemv_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hemv_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    x: *const *const T,
    incx: i32,
    beta: &T,
    y: *const *mut T,
    incy: i32,
    batch_count: i32,
) -> Result<()>
where
    T: HemvBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as HemvBatchedType>::rocblas_hemv_batched(
            handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy, batch_count,
        )
    }
}
/// Generic `hemv_strided_batched` entry point: hands every argument, unmodified, to
/// `HemvStridedBatchedType::rocblas_hemv_strided_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `A`, `x` and `y` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hemv_strided_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    x: *const T,
    incx: i32,
    stride_x: i64,
    beta: &T,
    y: *mut T,
    incy: i32,
    stride_y: i64,
    batch_count: i32,
) -> Result<()>
where
    T: HemvStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as HemvStridedBatchedType>::rocblas_hemv_strided_batched(
            handle, uplo, n, alpha,
            A, lda, stride_A,
            x, incx, stride_x,
            beta,
            y, incy, stride_y,
            batch_count,
        )
    }
}
/// Element types that provide an `hemv` backend.
///
/// The generic `hemv` wrapper in this module dispatches through this trait.
pub trait HemvType {
/// Backend entry point; the parameter list matches that of the `hemv` wrapper, with `Self` as
/// the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_hemv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()>;
}
impl HemvType for ffi::rocblas_float_complex {
unsafe fn rocblas_hemv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chemv(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HemvType for ffi::rocblas_double_complex {
unsafe fn rocblas_hemv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhemv(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide an `hemv_batched` backend.
///
/// The generic `hemv_batched` wrapper in this module dispatches through this trait.
pub trait HemvBatchedType {
/// Backend entry point; the parameter list matches that of the `hemv_batched` wrapper, with
/// `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_hemv_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()>;
}
impl HemvBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_hemv_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chemv_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HemvBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_hemv_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
x: *const *const Self,
incx: i32,
beta: &Self,
y: *const *mut Self,
incy: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhemv_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide an `hemv_strided_batched` backend.
///
/// The generic `hemv_strided_batched` wrapper in this module dispatches through this trait.
pub trait HemvStridedBatchedType {
/// Backend entry point; the parameter list matches that of the `hemv_strided_batched` wrapper,
/// with `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_hemv_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()>;
}
impl HemvStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_hemv_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chemv_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
stride_A,
x,
incx,
stride_x,
beta,
y,
incy,
stride_y,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HemvStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_hemv_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
x: *const Self,
incx: i32,
stride_x: i64,
beta: &Self,
y: *mut Self,
incy: i32,
stride_y: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhemv_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
stride_A,
x,
incx,
stride_x,
beta,
y,
incy,
stride_y,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Generic `ger` entry point: hands every argument, unmodified, to `GerType::rocblas_ger` for
/// the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn ger<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    y: *const T,
    incy: i32,
    A: *mut T,
    lda: i32,
) -> Result<()>
where
    T: GerType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GerType>::rocblas_ger(
            handle, m, n, alpha, x, incx, y, incy, A, lda,
        )
    }
}
/// Generic `geru` entry point: hands every argument, unmodified, to `GeruType::rocblas_geru`
/// for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn geru<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    y: *const T,
    incy: i32,
    A: *mut T,
    lda: i32,
) -> Result<()>
where
    T: GeruType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GeruType>::rocblas_geru(
            handle, m, n, alpha, x, incx, y, incy, A, lda,
        )
    }
}
/// Generic `gerc` entry point: hands every argument, unmodified, to `GercType::rocblas_gerc`
/// for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gerc<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    y: *const T,
    incy: i32,
    A: *mut T,
    lda: i32,
) -> Result<()>
where
    T: GercType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GercType>::rocblas_gerc(
            handle, m, n, alpha, x, incx, y, incy, A, lda,
        )
    }
}
/// Generic `ger_batched` entry point: hands every argument, unmodified, to
/// `GerBatchedType::rocblas_ger_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn ger_batched<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    y: *const *const T,
    incy: i32,
    A: *const *mut T,
    lda: i32,
    batch_count: i32,
) -> Result<()>
where
    T: GerBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GerBatchedType>::rocblas_ger_batched(
            handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count,
        )
    }
}
/// Generic `geru_batched` entry point: hands every argument, unmodified, to
/// `GeruBatchedType::rocblas_geru_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn geru_batched<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    y: *const *const T,
    incy: i32,
    A: *const *mut T,
    lda: i32,
    batch_count: i32,
) -> Result<()>
where
    T: GeruBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GeruBatchedType>::rocblas_geru_batched(
            handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count,
        )
    }
}
/// Generic `gerc_batched` entry point: hands every argument, unmodified, to
/// `GercBatchedType::rocblas_gerc_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gerc_batched<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    y: *const *const T,
    incy: i32,
    A: *const *mut T,
    lda: i32,
    batch_count: i32,
) -> Result<()>
where
    T: GercBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GercBatchedType>::rocblas_gerc_batched(
            handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count,
        )
    }
}
/// Generic `ger_strided_batched` entry point: hands every argument, unmodified, to
/// `GerStridedBatchedType::rocblas_ger_strided_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn ger_strided_batched<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    y: *const T,
    incy: i32,
    stride_y: i64,
    A: *mut T,
    lda: i32,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: GerStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GerStridedBatchedType>::rocblas_ger_strided_batched(
            handle, m, n, alpha,
            x, incx, stride_x,
            y, incy, stride_y,
            A, lda, stride_A,
            batch_count,
        )
    }
}
/// Generic `geru_strided_batched` entry point: hands every argument, unmodified, to
/// `GeruStridedBatchedType::rocblas_geru_strided_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn geru_strided_batched<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    y: *const T,
    incy: i32,
    stride_y: i64,
    A: *mut T,
    lda: i32,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: GeruStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GeruStridedBatchedType>::rocblas_geru_strided_batched(
            handle, m, n, alpha,
            x, incx, stride_x,
            y, incy, stride_y,
            A, lda, stride_A,
            batch_count,
        )
    }
}
/// Generic `gerc_strided_batched` entry point: hands every argument, unmodified, to
/// `GercStridedBatchedType::rocblas_gerc_strided_batched` for the element type `T`.
///
/// # Errors
/// Returns the error, if any, reported by the trait implementation.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn gerc_strided_batched<T>(
    handle: &Handle,
    m: i32,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    y: *const T,
    incy: i32,
    stride_y: i64,
    A: *mut T,
    lda: i32,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: GercStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        <T as GercStridedBatchedType>::rocblas_gerc_strided_batched(
            handle, m, n, alpha,
            x, incx, stride_x,
            y, incy, stride_y,
            A, lda, stride_A,
            batch_count,
        )
    }
}
/// Element types that provide a `ger` backend.
///
/// The generic `ger` wrapper in this module dispatches through this trait.
pub trait GerType {
/// Backend entry point; the parameter list matches that of the `ger` wrapper, with `Self` as
/// the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_ger(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()>;
}
impl GerType for f32 {
unsafe fn rocblas_ger(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_sger(handle.as_raw(), m, n, alpha, x, incx, y, incy, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GerType for f64 {
unsafe fn rocblas_ger(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_dger(handle.as_raw(), m, n, alpha, x, incx, y, incy, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `geru` backend.
///
/// The generic `geru` wrapper in this module dispatches through this trait.
pub trait GeruType {
/// Backend entry point; the parameter list matches that of the `geru` wrapper, with `Self` as
/// the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_geru(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()>;
}
impl GeruType for ffi::rocblas_float_complex {
unsafe fn rocblas_geru(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_cgeru(handle.as_raw(), m, n, alpha, x, incx, y, incy, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GeruType for ffi::rocblas_double_complex {
unsafe fn rocblas_geru(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_zgeru(handle.as_raw(), m, n, alpha, x, incx, y, incy, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `gerc` backend.
///
/// The generic `gerc` wrapper in this module dispatches through this trait.
pub trait GercType {
/// Backend entry point; the parameter list matches that of the `gerc` wrapper, with `Self` as
/// the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_gerc(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()>;
}
impl GercType for ffi::rocblas_float_complex {
unsafe fn rocblas_gerc(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_cgerc(handle.as_raw(), m, n, alpha, x, incx, y, incy, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GercType for ffi::rocblas_double_complex {
unsafe fn rocblas_gerc(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_zgerc(handle.as_raw(), m, n, alpha, x, incx, y, incy, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `ger_batched` backend.
///
/// The generic `ger_batched` wrapper in this module dispatches through this trait.
pub trait GerBatchedType {
/// Backend entry point; the parameter list matches that of the `ger_batched` wrapper, with
/// `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_ger_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()>;
}
/// Generic `spr` entry point: hands every argument, unmodified, to `T::rocblas_spr`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_spr`.
///
/// # Safety
/// `x` and `AP` are raw pointers that are passed on without validation; the caller must satisfy
/// the requirements of the implementation chosen by `T`.
pub unsafe fn spr<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    AP: *mut T,
) -> Result<()>
where
    T: SprType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe { T::rocblas_spr(handle, uplo, n, alpha, x, incx, AP) }
}
/// Generic `syr` entry point: hands every argument, unmodified, to `T::rocblas_syr`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_syr`.
///
/// # Safety
/// `x` and `A` are raw pointers that are passed on without validation; the caller must satisfy
/// the requirements of the implementation chosen by `T`.
pub unsafe fn syr<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    A: *mut T,
    lda: i32,
) -> Result<()>
where
    T: SyrType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_syr(
            handle,
            uplo,
            n,
            alpha,
            x,
            incx,
            A,
            lda,
        )
    }
}
/// Generic `syr2` entry point: hands every argument, unmodified, to `T::rocblas_syr2`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_syr2`.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn syr2<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    y: *const T,
    incy: i32,
    A: *mut T,
    lda: i32,
) -> Result<()>
where
    T: Syr2Type,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_syr2(
            handle,
            uplo,
            n,
            alpha,
            x,
            incx,
            y,
            incy,
            A,
            lda,
        )
    }
}
/// Generic `spr2` entry point: hands every argument, unmodified, to `T::rocblas_spr2`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_spr2`.
///
/// # Safety
/// `x`, `y` and `AP` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn spr2<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    y: *const T,
    incy: i32,
    AP: *mut T,
) -> Result<()>
where
    T: Spr2Type,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_spr2(
            handle,
            uplo,
            n,
            alpha,
            x,
            incx,
            y,
            incy,
            AP,
        )
    }
}
/// Generic `hemm` entry point: hands every argument, unmodified, to `T::rocblas_hemm`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_hemm`.
///
/// # Safety
/// `A`, `B` and `C` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn hemm<T>(
    handle: &Handle,
    side: Side,
    uplo: Fill,
    m: i32,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    B: *const T,
    ldb: i32,
    beta: &T,
    C: *mut T,
    ldc: i32,
) -> Result<()>
where
    T: HemmType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_hemm(
            handle,
            side,
            uplo,
            m,
            n,
            alpha,
            A,
            lda,
            B,
            ldb,
            beta,
            C,
            ldc,
        )
    }
}
/// Generic `herk` entry point: hands every argument, unmodified, to `T::rocblas_herk`.
///
/// The scalars `alpha` and `beta` use `R`, which the bound ties to `T`'s
/// `HerkType::ScalarType`; the matrices use `T` itself.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_herk`.
///
/// # Safety
/// `A` and `C` are raw pointers that are passed on without validation; the caller must satisfy
/// the requirements of the implementation chosen by `T`.
pub unsafe fn herk<T, R>(
    handle: &Handle,
    uplo: Fill,
    transA: Operation,
    n: i32,
    k: i32,
    alpha: &R,
    A: *const T,
    lda: i32,
    beta: &R,
    C: *mut T,
    ldc: i32,
) -> Result<()>
where
    T: HerkType<ScalarType = R>,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_herk(
            handle,
            uplo,
            transA,
            n,
            k,
            alpha,
            A,
            lda,
            beta,
            C,
            ldc,
        )
    }
}
/// Generic `syr_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_syr_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_syr_batched`.
///
/// # Safety
/// `x` and `A` are raw pointer-to-pointer arguments that are passed on without validation; the
/// caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn syr_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    A: *const *mut T,
    lda: i32,
    batch_count: i32,
) -> Result<()>
where
    T: SyrBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_syr_batched(
            handle,
            uplo,
            n,
            alpha,
            x,
            incx,
            A,
            lda,
            batch_count,
        )
    }
}
/// Generic `syr_strided_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_syr_strided_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_syr_strided_batched`.
///
/// # Safety
/// `x` and `A` are raw pointers that are passed on without validation; the caller must satisfy
/// the requirements of the implementation chosen by `T`.
pub unsafe fn syr_strided_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    A: *mut T,
    lda: i32,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: SyrStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_syr_strided_batched(
            handle, uplo, n, alpha,
            x, incx, stride_x,
            A, lda, stride_A,
            batch_count,
        )
    }
}
/// Generic `syr2_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_syr2_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_syr2_batched`.
///
/// # Safety
/// `x`, `y` and `A` are raw pointer-to-pointer arguments that are passed on without validation;
/// the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn syr2_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    y: *const *const T,
    incy: i32,
    A: *const *mut T,
    lda: i32,
    batch_count: i32,
) -> Result<()>
where
    T: Syr2BatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_syr2_batched(
            handle, uplo, n, alpha, x, incx, y, incy, A, lda, batch_count,
        )
    }
}
/// Generic `syr2_strided_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_syr2_strided_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_syr2_strided_batched`.
///
/// # Safety
/// `x`, `y` and `A` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn syr2_strided_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    y: *const T,
    incy: i32,
    stride_y: i64,
    A: *mut T,
    lda: i32,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: Syr2StridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_syr2_strided_batched(
            handle, uplo, n, alpha,
            x, incx, stride_x,
            y, incy, stride_y,
            A, lda, stride_A,
            batch_count,
        )
    }
}
/// Generic `spr_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_spr_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_spr_batched`.
///
/// # Safety
/// `x` and `AP` are raw pointer-to-pointer arguments that are passed on without validation; the
/// caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn spr_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    AP: *const *mut T,
    batch_count: i32,
) -> Result<()>
where
    T: SprBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_spr_batched(
            handle,
            uplo,
            n,
            alpha,
            x,
            incx,
            AP,
            batch_count,
        )
    }
}
/// Generic `spr_strided_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_spr_strided_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_spr_strided_batched`.
///
/// # Safety
/// `x` and `AP` are raw pointers that are passed on without validation; the caller must satisfy
/// the requirements of the implementation chosen by `T`.
pub unsafe fn spr_strided_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    AP: *mut T,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: SprStridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_spr_strided_batched(
            handle, uplo, n, alpha,
            x, incx, stride_x,
            AP, stride_A,
            batch_count,
        )
    }
}
/// Generic `spr2_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_spr2_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_spr2_batched`.
///
/// # Safety
/// `x`, `y` and `AP` are raw pointer-to-pointer arguments that are passed on without
/// validation; the caller must satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn spr2_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const *const T,
    incx: i32,
    y: *const *const T,
    incy: i32,
    AP: *const *mut T,
    batch_count: i32,
) -> Result<()>
where
    T: Spr2BatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_spr2_batched(
            handle,
            uplo,
            n,
            alpha,
            x,
            incx,
            y,
            incy,
            AP,
            batch_count,
        )
    }
}
/// Generic `spr2_strided_batched` entry point: hands every argument, unmodified, to
/// `T::rocblas_spr2_strided_batched`.
///
/// # Errors
/// Returns the error, if any, reported by `T::rocblas_spr2_strided_batched`.
///
/// # Safety
/// `x`, `y` and `AP` are raw pointers that are passed on without validation; the caller must
/// satisfy the requirements of the implementation chosen by `T`.
pub unsafe fn spr2_strided_batched<T>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    stride_x: i64,
    y: *const T,
    incy: i32,
    stride_y: i64,
    AP: *mut T,
    stride_A: i64,
    batch_count: i32,
) -> Result<()>
where
    T: Spr2StridedBatchedType,
{
    // SAFETY: the caller's obligations are forwarded verbatim to the trait method.
    unsafe {
        T::rocblas_spr2_strided_batched(
            handle, uplo, n, alpha,
            x, incx, stride_x,
            y, incy, stride_y,
            AP, stride_A,
            batch_count,
        )
    }
}
impl GerBatchedType for f32 {
unsafe fn rocblas_ger_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sger_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GerBatchedType for f64 {
unsafe fn rocblas_ger_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dger_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `geru_batched` backend.
///
/// The generic `geru_batched` wrapper in this module dispatches through this trait.
pub trait GeruBatchedType {
/// Backend entry point; the parameter list matches that of the `geru_batched` wrapper, with
/// `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_geru_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()>;
}
impl GeruBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_geru_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgeru_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GeruBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_geru_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgeru_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `gerc_batched` backend.
///
/// The generic `gerc_batched` wrapper in this module dispatches through this trait.
pub trait GercBatchedType {
/// Backend entry point; the parameter list matches that of the `gerc_batched` wrapper, with
/// `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_gerc_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()>;
}
impl GercBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_gerc_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgerc_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GercBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_gerc_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgerc_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `ger_strided_batched` backend.
///
/// The generic `ger_strided_batched` wrapper in this module dispatches through this trait.
pub trait GerStridedBatchedType {
/// Backend entry point; the parameter list matches that of the `ger_strided_batched` wrapper,
/// with `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_ger_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()>;
}
impl GerStridedBatchedType for f32 {
unsafe fn rocblas_ger_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sger_strided_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GerStridedBatchedType for f64 {
unsafe fn rocblas_ger_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dger_strided_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that provide a `geru_strided_batched` backend.
///
/// The generic `geru_strided_batched` wrapper in this module dispatches through this trait.
pub trait GeruStridedBatchedType {
/// Backend entry point; the parameter list matches that of the `geru_strided_batched` wrapper,
/// with `Self` as the element type.
///
/// # Safety
/// Accepts raw pointers. NOTE(review): their validity requirements are presumably those of
/// the rocBLAS routine an implementation calls — confirm against the rocBLAS documentation.
unsafe fn rocblas_geru_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()>;
}
impl GeruStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_geru_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgeru_strided_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GeruStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_geru_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgeru_strided_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a strided-batched `gerc` rocBLAS binding.
///
/// Implementations in this file pass the arguments straight to the
/// type-specific FFI function and map any non-success status to `Err`.
pub trait GercStridedBatchedType {
/// Runs the strided-batched `gerc` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers covering
/// `batch_count` strided problems — confirm against the rocBLAS docs.
unsafe fn rocblas_gerc_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()>;
}
impl GercStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_gerc_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgerc_strided_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GercStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_gerc_strided_batched(
handle: &Handle,
m: i32,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgerc_strided_batched(
handle.as_raw(),
m,
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a batched `spr` rocBLAS binding.
///
/// Implementations in this file convert `uplo` with `.into()`, pass the
/// remaining arguments straight to the type-specific FFI function, and map
/// any non-success status to `Err`.
pub trait SprBatchedType {
/// Runs the batched `spr` routine for `Self`.
///
/// # Safety
/// The pointer arrays are forwarded to rocBLAS untouched.
/// NOTE(review): presumably each must hold `batch_count` valid device
/// pointers — confirm against the rocBLAS docs.
unsafe fn rocblas_spr_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()>;
}
impl SprBatchedType for f32 {
unsafe fn rocblas_spr_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sspr_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
AP,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprBatchedType for f64 {
unsafe fn rocblas_spr_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dspr_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
AP,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_spr_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cspr_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
AP,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_spr_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zspr_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
AP,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a strided-batched `spr` rocBLAS binding.
///
/// Implementations in this file convert `uplo` with `.into()`, pass the
/// remaining arguments straight to the type-specific FFI function, and map
/// any non-success status to `Err`.
pub trait SprStridedBatchedType {
/// Runs the strided-batched `spr` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers covering
/// `batch_count` strided problems — confirm against the rocBLAS docs.
unsafe fn rocblas_spr_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()>;
}
impl SprStridedBatchedType for f32 {
unsafe fn rocblas_spr_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sspr_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
AP,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprStridedBatchedType for f64 {
unsafe fn rocblas_spr_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dspr_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
AP,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_spr_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cspr_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
AP,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_spr_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zspr_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
AP,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a batched `spr2` rocBLAS binding.
///
/// Implementations in this file convert `uplo` with `.into()`, pass the
/// remaining arguments straight to the type-specific FFI function, and map
/// any non-success status to `Err`.
pub trait Spr2BatchedType {
/// Runs the batched `spr2` routine for `Self`.
///
/// # Safety
/// The pointer arrays are forwarded to rocBLAS untouched.
/// NOTE(review): presumably each must hold `batch_count` valid device
/// pointers — confirm against the rocBLAS docs.
unsafe fn rocblas_spr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()>;
}
impl Spr2BatchedType for f32 {
unsafe fn rocblas_spr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sspr2_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
AP,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Spr2BatchedType for f64 {
unsafe fn rocblas_spr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
AP: *const *mut Self,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dspr2_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
AP,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a strided-batched `spr2` rocBLAS binding.
///
/// Implementations in this file convert `uplo` with `.into()`, pass the
/// remaining arguments straight to the type-specific FFI function, and map
/// any non-success status to `Err`.
pub trait Spr2StridedBatchedType {
/// Runs the strided-batched `spr2` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers covering
/// `batch_count` strided problems — confirm against the rocBLAS docs.
unsafe fn rocblas_spr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()>;
}
impl Spr2StridedBatchedType for f32 {
unsafe fn rocblas_spr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sspr2_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
AP,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Spr2StridedBatchedType for f64 {
unsafe fn rocblas_spr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
AP: *mut Self,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dspr2_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
AP,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a batched `syr2` rocBLAS binding.
///
/// Implementations in this file convert `uplo` with `.into()`, pass the
/// remaining arguments straight to the type-specific FFI function, and map
/// any non-success status to `Err`.
pub trait Syr2BatchedType {
/// Runs the batched `syr2` routine for `Self`.
///
/// # Safety
/// The pointer arrays are forwarded to rocBLAS untouched.
/// NOTE(review): presumably each must hold `batch_count` valid device
/// pointers — confirm against the rocBLAS docs.
unsafe fn rocblas_syr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()>;
}
impl Syr2BatchedType for f32 {
unsafe fn rocblas_syr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_ssyr2_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2BatchedType for f64 {
unsafe fn rocblas_syr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dsyr2_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2BatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_syr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_csyr2_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2BatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_syr2_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
y: *const *const Self,
incy: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zsyr2_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types that have a strided-batched `syr2` rocBLAS binding.
///
/// Implementations in this file convert `uplo` with `.into()`, pass the
/// remaining arguments straight to the type-specific FFI function, and map
/// any non-success status to `Err`.
pub trait Syr2StridedBatchedType {
/// Runs the strided-batched `syr2` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers covering
/// `batch_count` strided problems — confirm against the rocBLAS docs.
unsafe fn rocblas_syr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()>;
}
impl Syr2StridedBatchedType for f32 {
unsafe fn rocblas_syr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_ssyr2_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2StridedBatchedType for f64 {
unsafe fn rocblas_syr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dsyr2_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2StridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_syr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_csyr2_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2StridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_syr2_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
y: *const Self,
incy: i32,
stride_y: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zsyr2_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
y,
incy,
stride_y,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Generic `symv` entry point; dispatches to `T::rocblas_symv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `SymvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn symv<T: SymvType>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe {
        <T as SymvType>::rocblas_symv(handle, uplo, n, alpha, A, lda, x, incx, beta, y, incy)
    }
}
/// Element types that have a `symv` rocBLAS binding.
///
/// Implementations are generated by `impl_symv!`, which converts `uplo` with
/// `.into()`, forwards the rest to the FFI function, and maps any non-success
/// status to `Err`.
pub trait SymvType {
/// Runs the `symv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers sized for
/// `n`, `lda` and the increments — confirm against the rocBLAS docs.
unsafe fn rocblas_symv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()>;
}
macro_rules! impl_symv {
($t:ty, $func:path) => {
impl SymvType for $t {
unsafe fn rocblas_symv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_symv!(f32, ffi::rocblas_ssymv);
impl_symv!(f64, ffi::rocblas_dsymv);
impl_symv!(ffi::rocblas_float_complex, ffi::rocblas_csymv);
impl_symv!(ffi::rocblas_double_complex, ffi::rocblas_zsymv);
/// Generic `trmv` entry point; dispatches to `T::rocblas_trmv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `TrmvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn trmv<T: TrmvType>(
    handle: &Handle,
    uplo: Fill,
    trans_a: Operation,
    diag: Diagonal,
    n: i32,
    A: *const T,
    lda: i32,
    x: *mut T,
    incx: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as TrmvType>::rocblas_trmv(handle, uplo, trans_a, diag, n, A, lda, x, incx) }
}
/// Element types that have a `trmv` rocBLAS binding.
///
/// Implementations are generated by `impl_trmv!`, which converts `uplo`,
/// `trans_a` and `diag` with `.into()`, forwards the rest to the FFI
/// function, and maps any non-success status to `Err`.
pub trait TrmvType {
/// Runs the `trmv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers sized for
/// `n`, `lda` and `incx` — confirm against the rocBLAS docs.
unsafe fn rocblas_trmv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()>;
}
macro_rules! impl_trmv {
($t:ty, $func:path) => {
impl TrmvType for $t {
unsafe fn rocblas_trmv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
trans_a.into(),
diag.into(),
n,
A,
lda,
x,
incx,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_trmv!(f32, ffi::rocblas_strmv);
impl_trmv!(f64, ffi::rocblas_dtrmv);
impl_trmv!(ffi::rocblas_float_complex, ffi::rocblas_ctrmv);
impl_trmv!(ffi::rocblas_double_complex, ffi::rocblas_ztrmv);
/// Generic `trsv` entry point; dispatches to `T::rocblas_trsv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `TrsvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn trsv<T: TrsvType>(
    handle: &Handle,
    uplo: Fill,
    trans_a: Operation,
    diag: Diagonal,
    n: i32,
    A: *const T,
    lda: i32,
    x: *mut T,
    incx: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as TrsvType>::rocblas_trsv(handle, uplo, trans_a, diag, n, A, lda, x, incx) }
}
/// Element types that have a `trsv` rocBLAS binding.
///
/// Implementations are generated by `impl_trsv!`, which converts `uplo`,
/// `trans_a` and `diag` with `.into()`, forwards the rest to the FFI
/// function, and maps any non-success status to `Err`.
pub trait TrsvType {
/// Runs the `trsv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers sized for
/// `n`, `lda` and `incx` — confirm against the rocBLAS docs.
unsafe fn rocblas_trsv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()>;
}
macro_rules! impl_trsv {
($t:ty, $func:path) => {
impl TrsvType for $t {
unsafe fn rocblas_trsv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
trans_a.into(),
diag.into(),
n,
A,
lda,
x,
incx,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_trsv!(f32, ffi::rocblas_strsv);
impl_trsv!(f64, ffi::rocblas_dtrsv);
impl_trsv!(ffi::rocblas_float_complex, ffi::rocblas_ctrsv);
impl_trsv!(ffi::rocblas_double_complex, ffi::rocblas_ztrsv);
/// Generic `tpmv` entry point; dispatches to `T::rocblas_tpmv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `TpmvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn tpmv<T: TpmvType>(
    handle: &Handle,
    uplo: Fill,
    trans_a: Operation,
    diag: Diagonal,
    n: i32,
    A: *const T,
    x: *mut T,
    incx: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as TpmvType>::rocblas_tpmv(handle, uplo, trans_a, diag, n, A, x, incx) }
}
/// Element types that have a `tpmv` rocBLAS binding.
///
/// Implementations are generated by `impl_tpmv!`, which converts `uplo`,
/// `trans_a` and `diag` with `.into()`, forwards the rest to the FFI
/// function, and maps any non-success status to `Err`.
pub trait TpmvType {
/// Runs the `tpmv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): `A` has no leading dimension here, so it is presumably a
/// packed matrix — confirm against the rocBLAS docs.
unsafe fn rocblas_tpmv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
x: *mut Self,
incx: i32,
) -> Result<()>;
}
macro_rules! impl_tpmv {
($t:ty, $func:path) => {
impl TpmvType for $t {
unsafe fn rocblas_tpmv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
x: *mut Self,
incx: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
trans_a.into(),
diag.into(),
n,
A,
x,
incx,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_tpmv!(f32, ffi::rocblas_stpmv);
impl_tpmv!(f64, ffi::rocblas_dtpmv);
impl_tpmv!(ffi::rocblas_float_complex, ffi::rocblas_ctpmv);
impl_tpmv!(ffi::rocblas_double_complex, ffi::rocblas_ztpmv);
/// Generic `tpsv` entry point; dispatches to `T::rocblas_tpsv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `TpsvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn tpsv<T: TpsvType>(
    handle: &Handle,
    uplo: Fill,
    trans_a: Operation,
    diag: Diagonal,
    n: i32,
    A: *const T,
    x: *mut T,
    incx: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as TpsvType>::rocblas_tpsv(handle, uplo, trans_a, diag, n, A, x, incx) }
}
/// Element types that have a `tpsv` rocBLAS binding.
///
/// Implementations are generated by `impl_tpsv!`, which converts `uplo`,
/// `trans_a` and `diag` with `.into()`, forwards the rest to the FFI
/// function, and maps any non-success status to `Err`.
pub trait TpsvType {
/// Runs the `tpsv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): `A` has no leading dimension here, so it is presumably a
/// packed matrix — confirm against the rocBLAS docs.
unsafe fn rocblas_tpsv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
x: *mut Self,
incx: i32,
) -> Result<()>;
}
macro_rules! impl_tpsv {
($t:ty, $func:path) => {
impl TpsvType for $t {
unsafe fn rocblas_tpsv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
A: *const Self,
x: *mut Self,
incx: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
trans_a.into(),
diag.into(),
n,
A,
x,
incx,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_tpsv!(f32, ffi::rocblas_stpsv);
impl_tpsv!(f64, ffi::rocblas_dtpsv);
impl_tpsv!(ffi::rocblas_float_complex, ffi::rocblas_ctpsv);
impl_tpsv!(ffi::rocblas_double_complex, ffi::rocblas_ztpsv);
/// Generic `tbmv` entry point; dispatches to `T::rocblas_tbmv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `TbmvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn tbmv<T: TbmvType>(
    handle: &Handle,
    uplo: Fill,
    trans_a: Operation,
    diag: Diagonal,
    n: i32,
    k: i32,
    A: *const T,
    lda: i32,
    x: *mut T,
    incx: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as TbmvType>::rocblas_tbmv(handle, uplo, trans_a, diag, n, k, A, lda, x, incx) }
}
/// Element types that have a `tbmv` rocBLAS binding.
///
/// Implementations are generated by `impl_tbmv!`, which converts `uplo`,
/// `trans_a` and `diag` with `.into()`, forwards the rest to the FFI
/// function, and maps any non-success status to `Err`.
pub trait TbmvType {
/// Runs the `tbmv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers sized for
/// `n`, `k`, `lda` and `incx` — confirm against the rocBLAS docs.
unsafe fn rocblas_tbmv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
k: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()>;
}
macro_rules! impl_tbmv {
($t:ty, $func:path) => {
impl TbmvType for $t {
unsafe fn rocblas_tbmv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
k: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
trans_a.into(),
diag.into(),
n,
k,
A,
lda,
x,
incx,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_tbmv!(f32, ffi::rocblas_stbmv);
impl_tbmv!(f64, ffi::rocblas_dtbmv);
impl_tbmv!(ffi::rocblas_float_complex, ffi::rocblas_ctbmv);
impl_tbmv!(ffi::rocblas_double_complex, ffi::rocblas_ztbmv);
/// Generic `tbsv` entry point; dispatches to `T::rocblas_tbsv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `TbsvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn tbsv<T: TbsvType>(
    handle: &Handle,
    uplo: Fill,
    trans_a: Operation,
    diag: Diagonal,
    n: i32,
    k: i32,
    A: *const T,
    lda: i32,
    x: *mut T,
    incx: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as TbsvType>::rocblas_tbsv(handle, uplo, trans_a, diag, n, k, A, lda, x, incx) }
}
/// Element types that have a `tbsv` rocBLAS binding.
///
/// Implementations are generated by `impl_tbsv!`, which converts `uplo`,
/// `trans_a` and `diag` with `.into()`, forwards the rest to the FFI
/// function, and maps any non-success status to `Err`.
pub trait TbsvType {
/// Runs the `tbsv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers sized for
/// `n`, `k`, `lda` and `incx` — confirm against the rocBLAS docs.
unsafe fn rocblas_tbsv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
k: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()>;
}
macro_rules! impl_tbsv {
($t:ty, $func:path) => {
impl TbsvType for $t {
unsafe fn rocblas_tbsv(
handle: &Handle,
uplo: Fill,
trans_a: Operation,
diag: Diagonal,
n: i32,
k: i32,
A: *const Self,
lda: i32,
x: *mut Self,
incx: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
trans_a.into(),
diag.into(),
n,
k,
A,
lda,
x,
incx,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_tbsv!(f32, ffi::rocblas_stbsv);
impl_tbsv!(f64, ffi::rocblas_dtbsv);
impl_tbsv!(ffi::rocblas_float_complex, ffi::rocblas_ctbsv);
impl_tbsv!(ffi::rocblas_double_complex, ffi::rocblas_ztbsv);
/// Generic `spmv` entry point; dispatches to `T::rocblas_spmv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `SpmvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn spmv<T: SpmvType>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    A: *const T,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as SpmvType>::rocblas_spmv(handle, uplo, n, alpha, A, x, incx, beta, y, incy) }
}
/// Element types that have an `spmv` rocBLAS binding.
///
/// Implementations are generated by `impl_spmv!` (for `f32` and `f64` in
/// this file), which converts `uplo` with `.into()`, forwards the rest to the
/// FFI function, and maps any non-success status to `Err`.
pub trait SpmvType {
/// Runs the `spmv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): `A` has no leading dimension here, so it is presumably a
/// packed matrix — confirm against the rocBLAS docs.
unsafe fn rocblas_spmv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()>;
}
macro_rules! impl_spmv {
($t:ty, $func:path) => {
impl SpmvType for $t {
unsafe fn rocblas_spmv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
A: *const Self,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
n,
alpha,
A,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_spmv!(f32, ffi::rocblas_sspmv);
impl_spmv!(f64, ffi::rocblas_dspmv);
/// Generic `sbmv` entry point; dispatches to `T::rocblas_sbmv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `SbmvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn sbmv<T: SbmvType>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe {
        <T as SbmvType>::rocblas_sbmv(handle, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy)
    }
}
/// Element types that have an `sbmv` rocBLAS binding.
///
/// Implementations are generated by `impl_sbmv!` (for `f32` and `f64` in
/// this file), which converts `uplo` with `.into()`, forwards the rest to the
/// FFI function, and maps any non-success status to `Err`.
pub trait SbmvType {
/// Runs the `sbmv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably they must be valid device pointers sized for
/// `n`, `k`, `lda` and the increments — confirm against the rocBLAS docs.
unsafe fn rocblas_sbmv(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()>;
}
macro_rules! impl_sbmv {
($t:ty, $func:path) => {
impl SbmvType for $t {
unsafe fn rocblas_sbmv(
handle: &Handle,
uplo: Fill,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
n,
k,
alpha,
A,
lda,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_sbmv!(f32, ffi::rocblas_ssbmv);
impl_sbmv!(f64, ffi::rocblas_dsbmv);
/// Generic `hpmv` entry point; dispatches to `T::rocblas_hpmv` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `HpmvType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn hpmv<T: HpmvType>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    AP: *const T,
    x: *const T,
    incx: i32,
    beta: &T,
    y: *mut T,
    incy: i32,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as HpmvType>::rocblas_hpmv(handle, uplo, n, alpha, AP, x, incx, beta, y, incy) }
}
/// Element types that have an `hpmv` rocBLAS binding.
///
/// Implementations are generated by `impl_hpmv!` (for the two rocBLAS complex
/// types in this file), which converts `uplo` with `.into()`, forwards the
/// rest to the FFI function, and maps any non-success status to `Err`.
pub trait HpmvType {
/// Runs the `hpmv` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably `AP` is a packed matrix and all pointers must
/// be valid device pointers — confirm against the rocBLAS docs.
unsafe fn rocblas_hpmv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
AP: *const Self,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()>;
}
macro_rules! impl_hpmv {
($t:ty, $func:path) => {
impl HpmvType for $t {
unsafe fn rocblas_hpmv(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
AP: *const Self,
x: *const Self,
incx: i32,
beta: &Self,
y: *mut Self,
incy: i32,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
n,
alpha,
AP,
x,
incx,
beta,
y,
incy,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_hpmv!(ffi::rocblas_float_complex, ffi::rocblas_chpmv);
impl_hpmv!(ffi::rocblas_double_complex, ffi::rocblas_zhpmv);
/// Generic `hpr` entry point; dispatches to `T::rocblas_hpr` with the
/// arguments in the order given. `alpha` uses the associated `T::Real` type.
///
/// # Errors
/// Returns whatever error the `HprType` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn hpr<T: HprType>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T::Real,
    x: *const T,
    incx: i32,
    AP: *mut T,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as HprType>::rocblas_hpr(handle, uplo, n, alpha, x, incx, AP) }
}
/// Element types that have an `hpr` rocBLAS binding.
///
/// Implementations are generated by `impl_hpr!`, which converts `uplo` with
/// `.into()`, forwards the rest to the FFI function, and maps any non-success
/// status to `Err`.
pub trait HprType {
/// Type of the `alpha` scalar; this file pairs `rocblas_float_complex` with
/// `f32` and `rocblas_double_complex` with `f64`.
type Real;
/// Runs the `hpr` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably `AP` is a packed matrix and both pointers must
/// be valid device pointers — confirm against the rocBLAS docs.
unsafe fn rocblas_hpr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self::Real,
x: *const Self,
incx: i32,
AP: *mut Self,
) -> Result<()>;
}
macro_rules! impl_hpr {
($t:ty, $r:ty, $func:path) => {
impl HprType for $t {
type Real = $r;
unsafe fn rocblas_hpr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self::Real,
x: *const Self,
incx: i32,
AP: *mut Self,
) -> Result<()> {
let status =
unsafe { $func(handle.as_raw(), uplo.into(), n, alpha, x, incx, AP) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_hpr!(ffi::rocblas_float_complex, f32, ffi::rocblas_chpr);
impl_hpr!(ffi::rocblas_double_complex, f64, ffi::rocblas_zhpr);
/// Generic `hpr2` entry point; dispatches to `T::rocblas_hpr2` with the
/// arguments in the order given.
///
/// # Errors
/// Returns whatever error the `Hpr2Type` implementation for `T` reports.
///
/// # Safety
/// The raw pointers are forwarded untouched; the caller must uphold whatever
/// the selected implementation requires of them.
pub unsafe fn hpr2<T: Hpr2Type>(
    handle: &Handle,
    uplo: Fill,
    n: i32,
    alpha: &T,
    x: *const T,
    incx: i32,
    y: *const T,
    incy: i32,
    AP: *mut T,
) -> Result<()> {
    // SAFETY: the contract of this `unsafe fn` is passed on to the trait method.
    unsafe { <T as Hpr2Type>::rocblas_hpr2(handle, uplo, n, alpha, x, incx, y, incy, AP) }
}
/// Element types that have an `hpr2` rocBLAS binding.
///
/// Implementations are generated by `impl_hpr2!` (for the two rocBLAS complex
/// types in this file), which converts `uplo` with `.into()`, forwards the
/// rest to the FFI function, and maps any non-success status to `Err`.
pub trait Hpr2Type {
/// Runs the `hpr2` routine for `Self`.
///
/// # Safety
/// The raw pointers are forwarded to rocBLAS untouched.
/// NOTE(review): presumably `AP` is a packed matrix and all pointers must
/// be valid device pointers — confirm against the rocBLAS docs.
unsafe fn rocblas_hpr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
AP: *mut Self,
) -> Result<()>;
}
macro_rules! impl_hpr2 {
($t:ty, $func:path) => {
impl Hpr2Type for $t {
unsafe fn rocblas_hpr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
AP: *mut Self,
) -> Result<()> {
let status = unsafe {
$func(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
AP,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
};
}
impl_hpr2!(ffi::rocblas_float_complex, ffi::rocblas_chpr2);
impl_hpr2!(ffi::rocblas_double_complex, ffi::rocblas_zhpr2);
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hip::DeviceMemory;

    /// Creates a device buffer of `data.len()` elements and copies `data` into it.
    fn dev(data: &[f32]) -> DeviceMemory<f32> {
        let mut buffer = DeviceMemory::<f32>::new(data.len()).unwrap();
        buffer.copy_from_host(data).unwrap();
        buffer
    }

    /// Copies `n` elements from the device buffer back into a fresh `Vec`.
    fn host(m: &DeviceMemory<f32>, n: usize) -> Vec<f32> {
        let mut readback = vec![0.0f32; n];
        m.copy_to_host(&mut readback).unwrap();
        readback
    }

    /// Asserts equal lengths and element-wise agreement within `1e-4`.
    fn approx(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        actual.iter().zip(expected).for_each(|(a, e)| {
            assert!((a - e).abs() < 1e-4, "{actual:?} != {expected:?}");
        });
    }

    #[test]
    fn test_symv() {
        let handle = Handle::new().unwrap();
        let a = dev(&[2.0, 0.0, 1.0, 3.0]);
        let x = dev(&[1.0, 1.0]);
        let mut y = dev(&[0.0, 0.0]);
        // SAFETY: the buffers were just allocated with the sizes the call uses.
        let outcome = unsafe {
            symv(
                &handle,
                Fill::Upper,
                2,
                &1.0,
                a.as_ptr().cast::<f32>(),
                2,
                x.as_ptr().cast::<f32>(),
                1,
                &0.0,
                y.as_ptr().cast::<f32>(),
                1,
            )
        };
        outcome.unwrap();
        approx(&host(&y, 2), &[3.0, 4.0]);
    }

    #[test]
    fn test_trmv() {
        let handle = Handle::new().unwrap();
        let a = dev(&[2.0, 0.0, 1.0, 3.0]);
        let mut x = dev(&[1.0, 1.0]);
        // SAFETY: the buffers were just allocated with the sizes the call uses.
        let outcome = unsafe {
            trmv(
                &handle,
                Fill::Upper,
                Operation::None,
                Diagonal::NonUnit,
                2,
                a.as_ptr().cast::<f32>(),
                2,
                x.as_ptr().cast::<f32>(),
                1,
            )
        };
        outcome.unwrap();
        approx(&host(&x, 2), &[3.0, 3.0]);
    }

    #[test]
    fn test_trsv() {
        let handle = Handle::new().unwrap();
        let a = dev(&[2.0, 0.0, 1.0, 3.0]);
        let mut x = dev(&[3.0, 3.0]);
        // SAFETY: the buffers were just allocated with the sizes the call uses.
        let outcome = unsafe {
            trsv(
                &handle,
                Fill::Upper,
                Operation::None,
                Diagonal::NonUnit,
                2,
                a.as_ptr().cast::<f32>(),
                2,
                x.as_ptr().cast::<f32>(),
                1,
            )
        };
        outcome.unwrap();
        approx(&host(&x, 2), &[1.0, 1.0]);
    }
}