use crate::rocblas::error::{Error, Result};
use crate::rocblas::ffi;
use crate::rocblas::handle::Handle;
use crate::rocblas::types::{DataType, Operation};
use crate::rocblas::utils::GemmAlgo;
use super::types::{Fill, Side};
/// Generic GEMM entry point that dispatches to the rocBLAS routine selected by `T`
/// through [`GemmType`].
///
/// # Safety
/// `A`, `B` and `C` are forwarded to rocBLAS unchecked; the caller must ensure they are
/// valid for `m`, `n`, `k` and the corresponding leading dimensions.
///
/// # Errors
/// Propagates whatever [`GemmType::rocblas_gemm`] returns.
pub unsafe fn gemm<T: GemmType>(
    handle: &Handle,
    transa: Operation,
    transb: Operation,
    m: i32,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    B: *const T,
    ldb: i32,
    beta: &T,
    C: *mut T,
    ldc: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as GemmType>::rocblas_gemm(
            handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc,
        )
    }
}
/// Generic batched GEMM (arrays of matrix pointers) that dispatches to the rocBLAS
/// routine selected by `T` through [`GemmBatchedType`].
///
/// # Safety
/// `A`, `B` and `C` are pointer arrays forwarded to rocBLAS unchecked; the caller must
/// ensure they hold `batch_count` valid entries, each valid for the given sizes.
///
/// # Errors
/// Propagates whatever [`GemmBatchedType::rocblas_gemm_batched`] returns.
pub unsafe fn gemm_batched<T: GemmBatchedType>(
    handle: &Handle,
    transa: Operation,
    transb: Operation,
    m: i32,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    B: *const *const T,
    ldb: i32,
    beta: &T,
    C: *const *mut T,
    ldc: i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as GemmBatchedType>::rocblas_gemm_batched(
            handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, batch_count,
        )
    }
}
/// Generic strided-batched GEMM that dispatches to the rocBLAS routine selected by `T`
/// through [`GemmStridedBatchedType`].
///
/// # Safety
/// `A`, `B` and `C` are forwarded to rocBLAS unchecked; the caller must ensure each is
/// valid for `batch_count` matrices laid out `stride_*` elements apart.
///
/// # Errors
/// Propagates whatever [`GemmStridedBatchedType::rocblas_gemm_strided_batched`] returns.
pub unsafe fn gemm_strided_batched<T: GemmStridedBatchedType>(
    handle: &Handle,
    transa: Operation,
    transb: Operation,
    m: i32,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    B: *const T,
    ldb: i32,
    stride_B: i64,
    beta: &T,
    C: *mut T,
    ldc: i32,
    stride_C: i64,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as GemmStridedBatchedType>::rocblas_gemm_strided_batched(
            handle, transa, transb, m, n, k, alpha, A, lda, stride_A, B, ldb, stride_B, beta,
            C, ldc, stride_C, batch_count,
        )
    }
}
/// Mixed-precision GEMM through `rocblas_gemm_ex`, using untyped pointers whose element
/// types are described by `a_type`, `b_type`, `c_type` and `compute_type`.
///
/// The result is written in place. `C` (with `c_type` and `ldc`) is passed to rocBLAS
/// both as the input matrix `C` and as the output matrix `D`.
///
/// # Safety
/// All pointers are forwarded unchecked. Per the rocBLAS documentation, `alpha` and `beta`
/// must point to values of `compute_type`. `A`, `B` and `C` must be valid for `m`, `n`, `k`
/// and their leading dimensions, with element types matching the declared `DataType`s.
///
/// # Errors
/// Returns [`Error`] when rocBLAS reports a status other than success.
pub unsafe fn gemm_ex(
    handle: &Handle,
    transa: Operation,
    transb: Operation,
    m: i32,
    n: i32,
    k: i32,
    alpha: *const std::ffi::c_void,
    A: *const std::ffi::c_void,
    a_type: DataType,
    lda: i32,
    B: *const std::ffi::c_void,
    b_type: DataType,
    ldb: i32,
    beta: *const std::ffi::c_void,
    C: *mut std::ffi::c_void,
    c_type: DataType,
    ldc: i32,
    compute_type: DataType,
    algo: GemmAlgo,
) -> Result<()> {
    // Named instead of bare literals in the call. 0 requests rocBLAS's default solution
    // and no extra gemm flags (see the `rocblas_gemm_ex` documentation).
    let solution_index = 0;
    let flags = 0;

    // SAFETY: this is an `unsafe fn`; the caller upholds the pointer contract documented above.
    let status = unsafe {
        ffi::rocblas_gemm_ex(
            handle.as_raw(),
            transa.into(),
            transb.into(),
            m,
            n,
            k,
            alpha,
            A,
            a_type.into(),
            lda,
            B,
            b_type.into(),
            ldb,
            beta,
            // Input matrix C ...
            C,
            c_type.into(),
            ldc,
            // ... also passed as output matrix D (same type and leading dimension), so the
            // result overwrites C in place.
            C,
            c_type.into(),
            ldc,
            compute_type.into(),
            algo.into(),
            solution_index,
            flags,
        )
    };
    if status != ffi::rocblas_status__rocblas_status_success {
        return Err(Error::new(status));
    }
    Ok(())
}
/// Element types with a bound rocBLAS `?gemm` routine.
pub trait GemmType {
    /// Calls the rocBLAS GEMM routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are passed to rocBLAS unchecked; the caller must ensure they are
    /// valid for `m`, `n`, `k` and the corresponding leading dimensions.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_gemm(
        handle: &Handle,
        transa: Operation,
        transb: Operation,
        m: i32,
        n: i32,
        k: i32,
        alpha: &Self,
        A: *const Self,
        lda: i32,
        B: *const Self,
        ldb: i32,
        beta: &Self,
        C: *mut Self,
        ldc: i32,
    ) -> Result<()>;
}
impl GemmType for f32 {
unsafe fn rocblas_gemm(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sgemm(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmType for f64 {
unsafe fn rocblas_gemm(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dgemm(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmType for ffi::rocblas_float_complex {
unsafe fn rocblas_gemm(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgemm(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmType for ffi::rocblas_double_complex {
unsafe fn rocblas_gemm(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgemm(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?gemm_batched` routine.
pub trait GemmBatchedType {
    /// Calls the rocBLAS batched GEMM routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are pointer arrays passed to rocBLAS unchecked; the caller must
    /// ensure they hold `batch_count` entries, each valid for the given sizes.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_gemm_batched(
        handle: &Handle,
        transa: Operation,
        transb: Operation,
        m: i32,
        n: i32,
        k: i32,
        alpha: &Self,
        A: *const *const Self,
        lda: i32,
        B: *const *const Self,
        ldb: i32,
        beta: &Self,
        C: *const *mut Self,
        ldc: i32,
        batch_count: i32,
    ) -> Result<()>;
}
impl GemmBatchedType for f32 {
unsafe fn rocblas_gemm_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sgemm_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmBatchedType for f64 {
unsafe fn rocblas_gemm_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dgemm_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_gemm_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgemm_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_gemm_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgemm_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?gemm_strided_batched` routine.
pub trait GemmStridedBatchedType {
    /// Calls the rocBLAS strided-batched GEMM routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are passed to rocBLAS unchecked; the caller must ensure each is
    /// valid for `batch_count` matrices placed `stride_*` elements apart.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_gemm_strided_batched(
        handle: &Handle,
        transa: Operation,
        transb: Operation,
        m: i32,
        n: i32,
        k: i32,
        alpha: &Self,
        A: *const Self,
        lda: i32,
        stride_A: i64,
        B: *const Self,
        ldb: i32,
        stride_B: i64,
        beta: &Self,
        C: *mut Self,
        ldc: i32,
        stride_C: i64,
        batch_count: i32,
    ) -> Result<()>;
}
impl GemmStridedBatchedType for f32 {
unsafe fn rocblas_gemm_strided_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sgemm_strided_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmStridedBatchedType for f64 {
unsafe fn rocblas_gemm_strided_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dgemm_strided_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_gemm_strided_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cgemm_strided_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl GemmStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_gemm_strided_batched(
handle: &Handle,
transa: Operation,
transb: Operation,
m: i32,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zgemm_strided_batched(
handle.as_raw(),
transa.into(),
transb.into(),
m,
n,
k,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?hemm` routine (complex types only in this module).
pub trait HemmType {
    /// Calls the rocBLAS HEMM routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are passed to rocBLAS unchecked; the caller must ensure they are
    /// valid for `side`, `m`, `n` and the corresponding leading dimensions.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_hemm(
        handle: &Handle,
        side: Side,
        uplo: Fill,
        m: i32,
        n: i32,
        alpha: &Self,
        A: *const Self,
        lda: i32,
        B: *const Self,
        ldb: i32,
        beta: &Self,
        C: *mut Self,
        ldc: i32,
    ) -> Result<()>;
}
impl HemmType for ffi::rocblas_float_complex {
unsafe fn rocblas_hemm(
handle: &Handle,
side: Side,
uplo: Fill,
m: i32,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chemm(
handle.as_raw(),
side.into(),
uplo.into(),
m,
n,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HemmType for ffi::rocblas_double_complex {
unsafe fn rocblas_hemm(
handle: &Handle,
side: Side,
uplo: Fill,
m: i32,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhemm(
handle.as_raw(),
side.into(),
uplo.into(),
m,
n,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?herk` routine (complex types only in this module).
pub trait HerkType {
    /// Real scalar type of `alpha` and `beta` (`f32` for `rocblas_float_complex`, `f64`
    /// for `rocblas_double_complex` in this module).
    type ScalarType;
    /// Calls the rocBLAS HERK routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A` and `C` are passed to rocBLAS unchecked; the caller must ensure they are valid
    /// for `n`, `k` and the corresponding leading dimensions.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_herk(
        handle: &Handle,
        uplo: Fill,
        transA: Operation,
        n: i32,
        k: i32,
        alpha: &Self::ScalarType,
        A: *const Self,
        lda: i32,
        beta: &Self::ScalarType,
        C: *mut Self,
        ldc: i32,
    ) -> Result<()>;
}
impl HerkType for ffi::rocblas_float_complex {
type ScalarType = f32;
unsafe fn rocblas_herk(
handle: &Handle,
uplo: Fill,
transA: Operation,
n: i32,
k: i32,
alpha: &Self::ScalarType,
A: *const Self,
lda: i32,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cherk(
handle.as_raw(),
uplo.into(),
transA.into(),
n,
k,
alpha,
A,
lda,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HerkType for ffi::rocblas_double_complex {
type ScalarType = f64;
unsafe fn rocblas_herk(
handle: &Handle,
uplo: Fill,
transA: Operation,
n: i32,
k: i32,
alpha: &Self::ScalarType,
A: *const Self,
lda: i32,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zherk(
handle.as_raw(),
uplo.into(),
transA.into(),
n,
k,
alpha,
A,
lda,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?spr` routine.
pub trait SprType {
    /// Calls the rocBLAS SPR routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `x` and `AP` are passed to rocBLAS unchecked; the caller must ensure they are valid
    /// for `n` and `incx`.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_spr(
        handle: &Handle,
        uplo: Fill,
        n: i32,
        alpha: &Self,
        x: *const Self,
        incx: i32,
        AP: *mut Self,
    ) -> Result<()>;
}
impl SprType for f32 {
unsafe fn rocblas_spr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
AP: *mut Self,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_sspr(handle.as_raw(), uplo.into(), n, alpha, x, incx, AP) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprType for f64 {
unsafe fn rocblas_spr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
AP: *mut Self,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_dspr(handle.as_raw(), uplo.into(), n, alpha, x, incx, AP) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprType for ffi::rocblas_float_complex {
unsafe fn rocblas_spr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
AP: *mut Self,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_cspr(handle.as_raw(), uplo.into(), n, alpha, x, incx, AP) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SprType for ffi::rocblas_double_complex {
unsafe fn rocblas_spr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
AP: *mut Self,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_zspr(handle.as_raw(), uplo.into(), n, alpha, x, incx, AP) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?spr2` routine (`f32` and `f64` in this module).
pub trait Spr2Type {
    /// Calls the rocBLAS SPR2 routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `x`, `y` and `AP` are passed to rocBLAS unchecked; the caller must ensure they are
    /// valid for `n`, `incx` and `incy`.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_spr2(
        handle: &Handle,
        uplo: Fill,
        n: i32,
        alpha: &Self,
        x: *const Self,
        incx: i32,
        y: *const Self,
        incy: i32,
        AP: *mut Self,
    ) -> Result<()>;
}
impl Spr2Type for f32 {
unsafe fn rocblas_spr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
AP: *mut Self,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_sspr2(handle.as_raw(), uplo.into(), n, alpha, x, incx, y, incy, AP)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Spr2Type for f64 {
unsafe fn rocblas_spr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
AP: *mut Self,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dspr2(handle.as_raw(), uplo.into(), n, alpha, x, incx, y, incy, AP)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?syr` routine.
pub trait SyrType {
    /// Calls the rocBLAS SYR routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `x` and `A` are passed to rocBLAS unchecked; the caller must ensure they are valid
    /// for `n`, `incx` and `lda`.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_syr(
        handle: &Handle,
        uplo: Fill,
        n: i32,
        alpha: &Self,
        x: *const Self,
        incx: i32,
        A: *mut Self,
        lda: i32,
    ) -> Result<()>;
}
impl SyrType for f32 {
unsafe fn rocblas_syr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_ssyr(handle.as_raw(), uplo.into(), n, alpha, x, incx, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?syr2` routine.
pub trait Syr2Type {
    /// Calls the rocBLAS SYR2 routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `x`, `y` and `A` are passed to rocBLAS unchecked; the caller must ensure they are
    /// valid for `n`, `incx`, `incy` and `lda`.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_syr2(
        handle: &Handle,
        uplo: Fill,
        n: i32,
        alpha: &Self,
        x: *const Self,
        incx: i32,
        y: *const Self,
        incy: i32,
        A: *mut Self,
        lda: i32,
    ) -> Result<()>;
}
impl Syr2Type for f32 {
unsafe fn rocblas_syr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_ssyr2(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2Type for f64 {
unsafe fn rocblas_syr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_dsyr2(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SyrType for ffi::rocblas_float_complex {
unsafe fn rocblas_syr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_csyr(handle.as_raw(), uplo.into(), n, alpha, x, incx, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl SyrType for ffi::rocblas_double_complex {
unsafe fn rocblas_syr(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status =
unsafe { ffi::rocblas_zsyr(handle.as_raw(), uplo.into(), n, alpha, x, incx, A, lda) };
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2Type for ffi::rocblas_float_complex {
unsafe fn rocblas_syr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_csyr2(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl Syr2Type for ffi::rocblas_double_complex {
unsafe fn rocblas_syr2(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
y: *const Self,
incy: i32,
A: *mut Self,
lda: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zsyr2(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
y,
incy,
A,
lda,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?syr_batched` routine (`f32` in this module).
pub trait SyrBatchedType {
    /// Calls the rocBLAS batched SYR routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `x` and `A` are pointer arrays passed to rocBLAS unchecked; the caller must ensure
    /// they hold `batch_count` entries, each valid for `n`, `incx` and `lda`.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_syr_batched(
        handle: &Handle,
        uplo: Fill,
        n: i32,
        alpha: &Self,
        x: *const *const Self,
        incx: i32,
        A: *const *mut Self,
        lda: i32,
        batch_count: i32,
    ) -> Result<()>;
}
impl SyrBatchedType for f32 {
unsafe fn rocblas_syr_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const *const Self,
incx: i32,
A: *const *mut Self,
lda: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_ssyr_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
A,
lda,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?syr_strided_batched` routine (`f32` in this module).
pub trait SyrStridedBatchedType {
    /// Calls the rocBLAS strided-batched SYR routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `x` and `A` are passed to rocBLAS unchecked; the caller must ensure each is valid
    /// for `batch_count` instances placed `stride_x` / `stride_A` elements apart.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_syr_strided_batched(
        handle: &Handle,
        uplo: Fill,
        n: i32,
        alpha: &Self,
        x: *const Self,
        incx: i32,
        stride_x: i64,
        A: *mut Self,
        lda: i32,
        stride_A: i64,
        batch_count: i32,
    ) -> Result<()>;
}
impl SyrStridedBatchedType for f32 {
unsafe fn rocblas_syr_strided_batched(
handle: &Handle,
uplo: Fill,
n: i32,
alpha: &Self,
x: *const Self,
incx: i32,
stride_x: i64,
A: *mut Self,
lda: i32,
stride_A: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_ssyr_strided_batched(
handle.as_raw(),
uplo.into(),
n,
alpha,
x,
incx,
stride_x,
A,
lda,
stride_A,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?hemm_batched` routine (complex types only in this
/// module).
pub trait HemmBatchedType {
    /// Calls the rocBLAS batched HEMM routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are pointer arrays passed to rocBLAS unchecked; the caller must
    /// ensure they hold `batch_count` entries, each valid for the given sizes.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_hemm_batched(
        handle: &Handle,
        side: Side,
        uplo: Fill,
        m: i32,
        n: i32,
        alpha: &Self,
        A: *const *const Self,
        lda: i32,
        B: *const *const Self,
        ldb: i32,
        beta: &Self,
        C: *const *mut Self,
        ldc: i32,
        batch_count: i32,
    ) -> Result<()>;
}
impl HemmBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_hemm_batched(
handle: &Handle,
side: Side,
uplo: Fill,
m: i32,
n: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chemm_batched(
handle.as_raw(),
side.into(),
uplo.into(),
m,
n,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HemmBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_hemm_batched(
handle: &Handle,
side: Side,
uplo: Fill,
m: i32,
n: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhemm_batched(
handle.as_raw(),
side.into(),
uplo.into(),
m,
n,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?hemm_strided_batched` routine (complex types only
/// in this module).
pub trait HemmStridedBatchedType {
    /// Calls the rocBLAS strided-batched HEMM routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are passed to rocBLAS unchecked; the caller must ensure each is
    /// valid for `batch_count` matrices placed `stride_*` elements apart.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_hemm_strided_batched(
        handle: &Handle,
        side: Side,
        uplo: Fill,
        m: i32,
        n: i32,
        alpha: &Self,
        A: *const Self,
        lda: i32,
        stride_A: i64,
        B: *const Self,
        ldb: i32,
        stride_B: i64,
        beta: &Self,
        C: *mut Self,
        ldc: i32,
        stride_C: i64,
        batch_count: i32,
    ) -> Result<()>;
}
impl HemmStridedBatchedType for ffi::rocblas_float_complex {
unsafe fn rocblas_hemm_strided_batched(
handle: &Handle,
side: Side,
uplo: Fill,
m: i32,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_chemm_strided_batched(
handle.as_raw(),
side.into(),
uplo.into(),
m,
n,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HemmStridedBatchedType for ffi::rocblas_double_complex {
unsafe fn rocblas_hemm_strided_batched(
handle: &Handle,
side: Side,
uplo: Fill,
m: i32,
n: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zhemm_strided_batched(
handle.as_raw(),
side.into(),
uplo.into(),
m,
n,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?herk_batched` routine (complex types only in this
/// module).
pub trait HerkBatchedType {
    /// Real scalar type of `alpha` and `beta` (`f32` / `f64` for the single / double
    /// complex implementations in this module).
    type ScalarType;
    /// Calls the rocBLAS batched HERK routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A` and `C` are pointer arrays passed to rocBLAS unchecked; the caller must ensure
    /// they hold `batch_count` entries, each valid for `n`, `k` and the leading dimensions.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_herk_batched(
        handle: &Handle,
        uplo: Fill,
        transA: Operation,
        n: i32,
        k: i32,
        alpha: &Self::ScalarType,
        A: *const *const Self,
        lda: i32,
        beta: &Self::ScalarType,
        C: *const *mut Self,
        ldc: i32,
        batch_count: i32,
    ) -> Result<()>;
}
impl HerkBatchedType for ffi::rocblas_float_complex {
type ScalarType = f32;
unsafe fn rocblas_herk_batched(
handle: &Handle,
uplo: Fill,
transA: Operation,
n: i32,
k: i32,
alpha: &Self::ScalarType,
A: *const *const Self,
lda: i32,
beta: &Self::ScalarType,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cherk_batched(
handle.as_raw(),
uplo.into(),
transA.into(),
n,
k,
alpha,
A,
lda,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HerkBatchedType for ffi::rocblas_double_complex {
type ScalarType = f64;
unsafe fn rocblas_herk_batched(
handle: &Handle,
uplo: Fill,
transA: Operation,
n: i32,
k: i32,
alpha: &Self::ScalarType,
A: *const *const Self,
lda: i32,
beta: &Self::ScalarType,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zherk_batched(
handle.as_raw(),
uplo.into(),
transA.into(),
n,
k,
alpha,
A,
lda,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?herk_strided_batched` routine (complex types only
/// in this module).
pub trait HerkStridedBatchedType {
    /// Real scalar type of `alpha` and `beta` (`f32` / `f64` for the single / double
    /// complex implementations in this module).
    type ScalarType;
    /// Calls the rocBLAS strided-batched HERK routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A` and `C` are passed to rocBLAS unchecked; the caller must ensure each is valid
    /// for `batch_count` matrices placed `stride_A` / `stride_C` elements apart.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_herk_strided_batched(
        handle: &Handle,
        uplo: Fill,
        transA: Operation,
        n: i32,
        k: i32,
        alpha: &Self::ScalarType,
        A: *const Self,
        lda: i32,
        stride_A: i64,
        beta: &Self::ScalarType,
        C: *mut Self,
        ldc: i32,
        stride_C: i64,
        batch_count: i32,
    ) -> Result<()>;
}
impl HerkStridedBatchedType for ffi::rocblas_float_complex {
type ScalarType = f32;
unsafe fn rocblas_herk_strided_batched(
handle: &Handle,
uplo: Fill,
transA: Operation,
n: i32,
k: i32,
alpha: &Self::ScalarType,
A: *const Self,
lda: i32,
stride_A: i64,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cherk_strided_batched(
handle.as_raw(),
uplo.into(),
transA.into(),
n,
k,
alpha,
A,
lda,
stride_A,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HerkStridedBatchedType for ffi::rocblas_double_complex {
type ScalarType = f64;
unsafe fn rocblas_herk_strided_batched(
handle: &Handle,
uplo: Fill,
transA: Operation,
n: i32,
k: i32,
alpha: &Self::ScalarType,
A: *const Self,
lda: i32,
stride_A: i64,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zherk_strided_batched(
handle.as_raw(),
uplo.into(),
transA.into(),
n,
k,
alpha,
A,
lda,
stride_A,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Generic batched HEMM that dispatches to the rocBLAS routine selected by `T` through
/// [`HemmBatchedType`].
///
/// # Safety
/// `A`, `B` and `C` are pointer arrays forwarded to rocBLAS unchecked; the caller must
/// ensure they hold `batch_count` valid entries, each valid for the given sizes.
///
/// # Errors
/// Propagates whatever [`HemmBatchedType::rocblas_hemm_batched`] returns.
pub unsafe fn hemm_batched<T: HemmBatchedType>(
    handle: &Handle,
    side: Side,
    uplo: Fill,
    m: i32,
    n: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    B: *const *const T,
    ldb: i32,
    beta: &T,
    C: *const *mut T,
    ldc: i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HemmBatchedType>::rocblas_hemm_batched(
            handle, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc, batch_count,
        )
    }
}
/// Generic strided-batched HEMM that dispatches to the rocBLAS routine selected by `T`
/// through [`HemmStridedBatchedType`].
///
/// # Safety
/// `A`, `B` and `C` are forwarded to rocBLAS unchecked; the caller must ensure each is
/// valid for `batch_count` matrices placed `stride_*` elements apart.
///
/// # Errors
/// Propagates whatever [`HemmStridedBatchedType::rocblas_hemm_strided_batched`] returns.
pub unsafe fn hemm_strided_batched<T: HemmStridedBatchedType>(
    handle: &Handle,
    side: Side,
    uplo: Fill,
    m: i32,
    n: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    B: *const T,
    ldb: i32,
    stride_B: i64,
    beta: &T,
    C: *mut T,
    ldc: i32,
    stride_C: i64,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HemmStridedBatchedType>::rocblas_hemm_strided_batched(
            handle, side, uplo, m, n, alpha, A, lda, stride_A, B, ldb, stride_B, beta, C, ldc,
            stride_C, batch_count,
        )
    }
}
/// Generic batched HERK that dispatches to the rocBLAS routine selected by `T` through
/// [`HerkBatchedType`]; `alpha` and `beta` use the real scalar type `R`.
///
/// # Safety
/// `A` and `C` are pointer arrays forwarded to rocBLAS unchecked; the caller must ensure
/// they hold `batch_count` valid entries, each valid for the given sizes.
///
/// # Errors
/// Propagates whatever [`HerkBatchedType::rocblas_herk_batched`] returns.
pub unsafe fn herk_batched<T: HerkBatchedType<ScalarType = R>, R>(
    handle: &Handle,
    uplo: Fill,
    transA: Operation,
    n: i32,
    k: i32,
    alpha: &R,
    A: *const *const T,
    lda: i32,
    beta: &R,
    C: *const *mut T,
    ldc: i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HerkBatchedType>::rocblas_herk_batched(
            handle, uplo, transA, n, k, alpha, A, lda, beta, C, ldc, batch_count,
        )
    }
}
/// Generic strided-batched HERK that dispatches to the rocBLAS routine selected by `T`
/// through [`HerkStridedBatchedType`]; `alpha` and `beta` use the real scalar type `R`.
///
/// # Safety
/// `A` and `C` are forwarded to rocBLAS unchecked; the caller must ensure each is valid
/// for `batch_count` matrices placed `stride_A` / `stride_C` elements apart.
///
/// # Errors
/// Propagates whatever [`HerkStridedBatchedType::rocblas_herk_strided_batched`] returns.
pub unsafe fn herk_strided_batched<T: HerkStridedBatchedType<ScalarType = R>, R>(
    handle: &Handle,
    uplo: Fill,
    transA: Operation,
    n: i32,
    k: i32,
    alpha: &R,
    A: *const T,
    lda: i32,
    stride_A: i64,
    beta: &R,
    C: *mut T,
    ldc: i32,
    stride_C: i64,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HerkStridedBatchedType>::rocblas_herk_strided_batched(
            handle, uplo, transA, n, k, alpha, A, lda, stride_A, beta, C, ldc, stride_C,
            batch_count,
        )
    }
}
/// Generic HERKX that dispatches to the rocBLAS routine selected by `T` through
/// [`HerkxType`]. `alpha` has the element type `T`, while `beta` uses the real scalar
/// type `R`.
///
/// # Safety
/// `A`, `B` and `C` are forwarded to rocBLAS unchecked; the caller must ensure they are
/// valid for `n`, `k` and the corresponding leading dimensions.
///
/// # Errors
/// Propagates whatever [`HerkxType::rocblas_herkx`] returns.
pub unsafe fn herkx<T: HerkxType<ScalarType = R>, R>(
    handle: &Handle,
    uplo: Fill,
    trans: Operation,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    B: *const T,
    ldb: i32,
    beta: &R,
    C: *mut T,
    ldc: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HerkxType>::rocblas_herkx(
            handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc,
        )
    }
}
/// Generic batched HERKX that dispatches to the rocBLAS routine selected by `T` through
/// [`HerkxBatchedType`]. `alpha` has the element type `T`, while `beta` uses the real
/// scalar type `R`.
///
/// # Safety
/// `A`, `B` and `C` are pointer arrays forwarded to rocBLAS unchecked; the caller must
/// ensure they hold `batch_count` valid entries, each valid for the given sizes.
///
/// # Errors
/// Propagates whatever [`HerkxBatchedType::rocblas_herkx_batched`] returns.
pub unsafe fn herkx_batched<T: HerkxBatchedType<ScalarType = R>, R>(
    handle: &Handle,
    uplo: Fill,
    trans: Operation,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const *const T,
    lda: i32,
    B: *const *const T,
    ldb: i32,
    beta: &R,
    C: *const *mut T,
    ldc: i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HerkxBatchedType>::rocblas_herkx_batched(
            handle, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc, batch_count,
        )
    }
}
/// Generic strided-batched HERKX that dispatches to the rocBLAS routine selected by `T`
/// through [`HerkxStridedBatchedType`]. `alpha` has the element type `T`, while `beta`
/// uses the real scalar type `R`.
///
/// # Safety
/// `A`, `B` and `C` are forwarded to rocBLAS unchecked; the caller must ensure each is
/// valid for `batch_count` matrices placed `stride_*` elements apart.
///
/// # Errors
/// Propagates whatever [`HerkxStridedBatchedType::rocblas_herkx_strided_batched`] returns.
pub unsafe fn herkx_strided_batched<T: HerkxStridedBatchedType<ScalarType = R>, R>(
    handle: &Handle,
    uplo: Fill,
    trans: Operation,
    n: i32,
    k: i32,
    alpha: &T,
    A: *const T,
    lda: i32,
    stride_A: i64,
    B: *const T,
    ldb: i32,
    stride_B: i64,
    beta: &R,
    C: *mut T,
    ldc: i32,
    stride_C: i64,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: arguments are forwarded verbatim; the caller upholds the contract above.
    unsafe {
        <T as HerkxStridedBatchedType>::rocblas_herkx_strided_batched(
            handle, uplo, trans, n, k, alpha, A, lda, stride_A, B, ldb, stride_B, beta, C, ldc,
            stride_C, batch_count,
        )
    }
}
/// Element types with a bound rocBLAS `?herkx` routine (complex types only in this module).
pub trait HerkxType {
    /// Real scalar type of `beta` (`f32` / `f64` for the single / double complex
    /// implementations in this module). `alpha` keeps the element type `Self`.
    type ScalarType;
    /// Calls the rocBLAS HERKX routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are passed to rocBLAS unchecked; the caller must ensure they are
    /// valid for `n`, `k` and the corresponding leading dimensions.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_herkx(
        handle: &Handle,
        uplo: Fill,
        trans: Operation,
        n: i32,
        k: i32,
        alpha: &Self,
        A: *const Self,
        lda: i32,
        B: *const Self,
        ldb: i32,
        beta: &Self::ScalarType,
        C: *mut Self,
        ldc: i32,
    ) -> Result<()>;
}
impl HerkxType for ffi::rocblas_float_complex {
type ScalarType = f32;
unsafe fn rocblas_herkx(
handle: &Handle,
uplo: Fill,
trans: Operation,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cherkx(
handle.as_raw(),
uplo.into(),
trans.into(),
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HerkxType for ffi::rocblas_double_complex {
type ScalarType = f64;
unsafe fn rocblas_herkx(
handle: &Handle,
uplo: Fill,
trans: Operation,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
B: *const Self,
ldb: i32,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zherkx(
handle.as_raw(),
uplo.into(),
trans.into(),
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?herkx_batched` routine (complex types only in this
/// module).
pub trait HerkxBatchedType {
    /// Real scalar type of `beta` (`f32` / `f64` for the single / double complex
    /// implementations in this module). `alpha` keeps the element type `Self`.
    type ScalarType;
    /// Calls the rocBLAS batched HERKX routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are pointer arrays passed to rocBLAS unchecked; the caller must
    /// ensure they hold `batch_count` entries, each valid for the given sizes.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_herkx_batched(
        handle: &Handle,
        uplo: Fill,
        trans: Operation,
        n: i32,
        k: i32,
        alpha: &Self,
        A: *const *const Self,
        lda: i32,
        B: *const *const Self,
        ldb: i32,
        beta: &Self::ScalarType,
        C: *const *mut Self,
        ldc: i32,
        batch_count: i32,
    ) -> Result<()>;
}
impl HerkxBatchedType for ffi::rocblas_float_complex {
type ScalarType = f32;
unsafe fn rocblas_herkx_batched(
handle: &Handle,
uplo: Fill,
trans: Operation,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self::ScalarType,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cherkx_batched(
handle.as_raw(),
uplo.into(),
trans.into(),
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HerkxBatchedType for ffi::rocblas_double_complex {
type ScalarType = f64;
unsafe fn rocblas_herkx_batched(
handle: &Handle,
uplo: Fill,
trans: Operation,
n: i32,
k: i32,
alpha: &Self,
A: *const *const Self,
lda: i32,
B: *const *const Self,
ldb: i32,
beta: &Self::ScalarType,
C: *const *mut Self,
ldc: i32,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zherkx_batched(
handle.as_raw(),
uplo.into(),
trans.into(),
n,
k,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Element types with a bound rocBLAS `?herkx_strided_batched` routine (complex types only
/// in this module).
pub trait HerkxStridedBatchedType {
    /// Real scalar type of `beta` (`f32` / `f64` for the single / double complex
    /// implementations in this module). `alpha` keeps the element type `Self`.
    type ScalarType;
    /// Calls the rocBLAS strided-batched HERKX routine for `Self` on `handle`.
    ///
    /// # Safety
    /// `A`, `B` and `C` are passed to rocBLAS unchecked; the caller must ensure each is
    /// valid for `batch_count` matrices placed `stride_*` elements apart.
    ///
    /// # Errors
    /// The implementations in this module return [`Error`] when rocBLAS reports a status
    /// other than success.
    unsafe fn rocblas_herkx_strided_batched(
        handle: &Handle,
        uplo: Fill,
        trans: Operation,
        n: i32,
        k: i32,
        alpha: &Self,
        A: *const Self,
        lda: i32,
        stride_A: i64,
        B: *const Self,
        ldb: i32,
        stride_B: i64,
        beta: &Self::ScalarType,
        C: *mut Self,
        ldc: i32,
        stride_C: i64,
        batch_count: i32,
    ) -> Result<()>;
}
impl HerkxStridedBatchedType for ffi::rocblas_float_complex {
type ScalarType = f32;
unsafe fn rocblas_herkx_strided_batched(
handle: &Handle,
uplo: Fill,
trans: Operation,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_cherkx_strided_batched(
handle.as_raw(),
uplo.into(),
trans.into(),
n,
k,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
impl HerkxStridedBatchedType for ffi::rocblas_double_complex {
type ScalarType = f64;
unsafe fn rocblas_herkx_strided_batched(
handle: &Handle,
uplo: Fill,
trans: Operation,
n: i32,
k: i32,
alpha: &Self,
A: *const Self,
lda: i32,
stride_A: i64,
B: *const Self,
ldb: i32,
stride_B: i64,
beta: &Self::ScalarType,
C: *mut Self,
ldc: i32,
stride_C: i64,
batch_count: i32,
) -> Result<()> {
let status = unsafe {
ffi::rocblas_zherkx_strided_batched(
handle.as_raw(),
uplo.into(),
trans.into(),
n,
k,
alpha,
A,
lda,
stride_A,
B,
ldb,
stride_B,
beta,
C,
ldc,
stride_C,
batch_count,
)
};
if status != ffi::rocblas_status__rocblas_status_success {
return Err(Error::new(status));
}
Ok(())
}
}
/// Safe mirror of rocBLAS's `rocblas_diagonal` (unit vs. non-unit diagonal).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Diagonal {
    /// Corresponds to `rocblas_diagonal_non_unit`.
    NonUnit,
    /// Corresponds to `rocblas_diagonal_unit`.
    Unit,
}
impl From<Diagonal> for ffi::rocblas_diagonal {
    /// Maps each [`Diagonal`] variant to its raw rocBLAS constant.
    fn from(diag: Diagonal) -> Self {
        match diag {
            Diagonal::Unit => ffi::rocblas_diagonal__rocblas_diagonal_unit,
            Diagonal::NonUnit => ffi::rocblas_diagonal__rocblas_diagonal_non_unit,
        }
    }
}
impl From<ffi::rocblas_diagonal> for Diagonal {
    /// Converts a raw rocBLAS diagonal value.
    ///
    /// Only `rocblas_diagonal_unit` maps to [`Diagonal::Unit`]. Every other value,
    /// including unrecognized ones, deliberately falls back to [`Diagonal::NonUnit`]
    /// rather than failing.
    fn from(diag: ffi::rocblas_diagonal) -> Self {
        if diag == ffi::rocblas_diagonal__rocblas_diagonal_unit {
            Diagonal::Unit
        } else {
            Diagonal::NonUnit
        }
    }
}