use crate::rocblas::Handle;
use crate::rocblas::ffi as rocblas_ffi;
use crate::rocsolver::bindings;
use crate::rocsolver::error::{Error, Result};
use crate::rocsolver::types::{Complex32, Complex64, Evect, Fill};
/// Handle type as declared by the rocBLAS FFI bindings.
type RocblasHandle = rocblas_ffi::rocblas_handle;
/// Status type as declared by the rocBLAS FFI bindings; every wrapper below returns it.
type RocblasStatus = rocblas_ffi::rocblas_status;

/// Converts a rocBLAS-bindings handle into the handle type the rocSOLVER bindings expect.
///
/// This is a plain `as` cast whose target is inferred from the return type. Presumably both
/// binding crates describe the same opaque C `rocblas_handle` (TODO confirm), so only the
/// Rust-level type changes.
#[inline]
fn cast_handle(handle: RocblasHandle) -> bindings::rocblas_handle {
    handle as _
}
/// Element types for which rocSOLVER provides a real symmetric eigensolver
/// (`rocsolver_?syev` family).
///
/// In this module it is implemented for `f32` (`ssyev`) and `f64` (`dsyev`). Each method is
/// an argument-for-argument binding to the corresponding C entry point.
// Lint allowances: the methods mirror the rocSOLVER C signatures one-to-one (hence the
// argument count) and keep the LAPACK-style matrix/vector names `A`, `D`, `E`.
#[allow(clippy::too_many_arguments, non_snake_case)]
pub trait SyevType: Sized + Copy {
    /// Single-problem eigensolver (`rocsolver_?syev`).
    ///
    /// # Safety
    ///
    /// `handle` must be a valid rocBLAS handle. `A`, `D`, `E` and `info` must be valid for
    /// the sizes the routine derives from `n` and `lda`. Per the rocSOLVER documentation
    /// these are device-memory pointers; confirm this for the library version in use. They
    /// are passed as `*mut`, so the routine may write through any of them.
    unsafe fn syev(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        D: *mut Self,
        E: *mut Self,
        info: *mut i32,
    ) -> RocblasStatus;

    /// Batched eigensolver (`rocsolver_?syev_batched`).
    ///
    /// Following rocSOLVER's batched convention, `A` addresses an array of per-problem
    /// matrix pointers. `D` and `E` are single buffers whose consecutive problems are
    /// `stride_d` / `stride_e` apart.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid rocBLAS handle. `A` must address `batch_count` valid matrix
    /// pointers. `D`, `E` and `info` must cover all `batch_count` problems for the given
    /// `n`, `lda` and strides; see the rocSOLVER documentation for the exact extents and
    /// memory space. The routine may write through every `*mut` argument.
    unsafe fn syev_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *const *mut Self,
        lda: i32,
        D: *mut Self,
        stride_d: i64,
        E: *mut Self,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus;

    /// Strided-batched eigensolver (`rocsolver_?syev_strided_batched`).
    ///
    /// All matrices live in the single buffer `A`, `stride_a` apart. `D` and `E` use
    /// `stride_d` / `stride_e`. Strides are in elements per the rocSOLVER docs.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid rocBLAS handle. `A`, `D`, `E` and `info` must cover all
    /// `batch_count` problems for the given `n`, `lda` and strides; see the rocSOLVER
    /// documentation for the exact extents and memory space. The routine may write through
    /// every `*mut` argument.
    unsafe fn syev_strided_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        stride_a: i64,
        D: *mut Self,
        stride_d: i64,
        E: *mut Self,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus;
}
/// Element types for which rocSOLVER provides a Hermitian eigensolver
/// (`rocsolver_?heev` family).
///
/// In this module it is implemented for `Complex32` (`cheev`) and `Complex64` (`zheev`).
/// The eigenvalue output `D` and the auxiliary array `E` use the real scalar type
/// [`RealType`](Self::RealType) rather than the complex element type.
// Lint allowances: the methods mirror the rocSOLVER C signatures one-to-one (hence the
// argument count) and keep the LAPACK-style matrix/vector names `A`, `D`, `E`.
#[allow(clippy::too_many_arguments, non_snake_case)]
pub trait HeevType: Sized + Copy {
    /// Real scalar type paired with `Self`. In this module it is `f32` for `Complex32` and
    /// `f64` for `Complex64`.
    type RealType: Copy;

    /// Single-problem eigensolver (`rocsolver_?heev`).
    ///
    /// # Safety
    ///
    /// `handle` must be a valid rocBLAS handle. `A`, `D`, `E` and `info` must be valid for
    /// the sizes the routine derives from `n` and `lda`. Per the rocSOLVER documentation
    /// these are device-memory pointers; confirm this for the library version in use. They
    /// are passed as `*mut`, so the routine may write through any of them.
    unsafe fn heev(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        D: *mut Self::RealType,
        E: *mut Self::RealType,
        info: *mut i32,
    ) -> RocblasStatus;

    /// Batched eigensolver (`rocsolver_?heev_batched`).
    ///
    /// Following rocSOLVER's batched convention, `A` addresses an array of per-problem
    /// matrix pointers. `D` and `E` are single buffers whose consecutive problems are
    /// `stride_d` / `stride_e` apart.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid rocBLAS handle. `A` must address `batch_count` valid matrix
    /// pointers. `D`, `E` and `info` must cover all `batch_count` problems for the given
    /// `n`, `lda` and strides; see the rocSOLVER documentation for the exact extents and
    /// memory space. The routine may write through every `*mut` argument.
    unsafe fn heev_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *const *mut Self,
        lda: i32,
        D: *mut Self::RealType,
        stride_d: i64,
        E: *mut Self::RealType,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus;

    /// Strided-batched eigensolver (`rocsolver_?heev_strided_batched`).
    ///
    /// All matrices live in the single buffer `A`, `stride_a` apart. `D` and `E` use
    /// `stride_d` / `stride_e`. Strides are in elements per the rocSOLVER docs.
    ///
    /// # Safety
    ///
    /// `handle` must be a valid rocBLAS handle. `A`, `D`, `E` and `info` must cover all
    /// `batch_count` problems for the given `n`, `lda` and strides; see the rocSOLVER
    /// documentation for the exact extents and memory space. The routine may write through
    /// every `*mut` argument.
    unsafe fn heev_strided_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        stride_a: i64,
        D: *mut Self::RealType,
        stride_d: i64,
        E: *mut Self::RealType,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus;
}
/// Single-precision real bindings: every method forwards to the `rocsolver_ssyev*` family
/// after converting the handle type.
impl SyevType for f32 {
    unsafe fn syev(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        D: *mut Self,
        E: *mut Self,
        info: *mut i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_ssyev` unchanged.
        unsafe { bindings::rocsolver_ssyev(raw, evect, uplo, n, A, lda, D, E, info) }
    }

    unsafe fn syev_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *const *mut Self,
        lda: i32,
        D: *mut Self,
        stride_d: i64,
        E: *mut Self,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_ssyev_batched` unchanged.
        unsafe {
            bindings::rocsolver_ssyev_batched(
                raw, evect, uplo, n, A, lda, D, stride_d, E, stride_e, info, batch_count,
            )
        }
    }

    unsafe fn syev_strided_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        stride_a: i64,
        D: *mut Self,
        stride_d: i64,
        E: *mut Self,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_ssyev_strided_batched` unchanged.
        unsafe {
            bindings::rocsolver_ssyev_strided_batched(
                raw,
                evect,
                uplo,
                n,
                A,
                lda,
                stride_a,
                D,
                stride_d,
                E,
                stride_e,
                info,
                batch_count,
            )
        }
    }
}
/// Double-precision real bindings: every method forwards to the `rocsolver_dsyev*` family
/// after converting the handle type.
impl SyevType for f64 {
    unsafe fn syev(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        D: *mut Self,
        E: *mut Self,
        info: *mut i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_dsyev` unchanged.
        unsafe { bindings::rocsolver_dsyev(raw, evect, uplo, n, A, lda, D, E, info) }
    }

    unsafe fn syev_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *const *mut Self,
        lda: i32,
        D: *mut Self,
        stride_d: i64,
        E: *mut Self,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_dsyev_batched` unchanged.
        unsafe {
            bindings::rocsolver_dsyev_batched(
                raw, evect, uplo, n, A, lda, D, stride_d, E, stride_e, info, batch_count,
            )
        }
    }

    unsafe fn syev_strided_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        stride_a: i64,
        D: *mut Self,
        stride_d: i64,
        E: *mut Self,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_dsyev_strided_batched` unchanged.
        unsafe {
            bindings::rocsolver_dsyev_strided_batched(
                raw,
                evect,
                uplo,
                n,
                A,
                lda,
                stride_a,
                D,
                stride_d,
                E,
                stride_e,
                info,
                batch_count,
            )
        }
    }
}
/// Single-precision complex bindings: every method forwards to the `rocsolver_cheev*`
/// family after converting the handle type. Real-valued outputs use `f32`.
impl HeevType for Complex32 {
    type RealType = f32;

    unsafe fn heev(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        D: *mut Self::RealType,
        E: *mut Self::RealType,
        info: *mut i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_cheev` unchanged.
        unsafe { bindings::rocsolver_cheev(raw, evect, uplo, n, A, lda, D, E, info) }
    }

    unsafe fn heev_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *const *mut Self,
        lda: i32,
        D: *mut Self::RealType,
        stride_d: i64,
        E: *mut Self::RealType,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_cheev_batched` unchanged.
        unsafe {
            bindings::rocsolver_cheev_batched(
                raw, evect, uplo, n, A, lda, D, stride_d, E, stride_e, info, batch_count,
            )
        }
    }

    unsafe fn heev_strided_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        stride_a: i64,
        D: *mut Self::RealType,
        stride_d: i64,
        E: *mut Self::RealType,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_cheev_strided_batched` unchanged.
        unsafe {
            bindings::rocsolver_cheev_strided_batched(
                raw,
                evect,
                uplo,
                n,
                A,
                lda,
                stride_a,
                D,
                stride_d,
                E,
                stride_e,
                info,
                batch_count,
            )
        }
    }
}
/// Double-precision complex bindings: every method forwards to the `rocsolver_zheev*`
/// family after converting the handle type. Real-valued outputs use `f64`.
impl HeevType for Complex64 {
    type RealType = f64;

    unsafe fn heev(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        D: *mut Self::RealType,
        E: *mut Self::RealType,
        info: *mut i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_zheev` unchanged.
        unsafe { bindings::rocsolver_zheev(raw, evect, uplo, n, A, lda, D, E, info) }
    }

    unsafe fn heev_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *const *mut Self,
        lda: i32,
        D: *mut Self::RealType,
        stride_d: i64,
        E: *mut Self::RealType,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_zheev_batched` unchanged.
        unsafe {
            bindings::rocsolver_zheev_batched(
                raw, evect, uplo, n, A, lda, D, stride_d, E, stride_e, info, batch_count,
            )
        }
    }

    unsafe fn heev_strided_batched(
        handle: RocblasHandle,
        evect: bindings::rocblas_evect,
        uplo: rocblas_ffi::rocblas_fill,
        n: i32,
        A: *mut Self,
        lda: i32,
        stride_a: i64,
        D: *mut Self::RealType,
        stride_d: i64,
        E: *mut Self::RealType,
        stride_e: i64,
        info: *mut i32,
        batch_count: i32,
    ) -> RocblasStatus {
        let raw = cast_handle(handle);
        // SAFETY: this is an `unsafe fn`; the caller vouches for the handle and every
        // pointer, and all arguments reach `rocsolver_zheev_strided_batched` unchanged.
        unsafe {
            bindings::rocsolver_zheev_strided_batched(
                raw,
                evect,
                uplo,
                n,
                A,
                lda,
                stride_a,
                D,
                stride_d,
                E,
                stride_e,
                info,
                batch_count,
            )
        }
    }
}
/// Computes the eigenvalues of a real symmetric `n`×`n` matrix `A`, and optionally its
/// eigenvectors depending on `evect`.
///
/// Dispatches through [`SyevType`] to `rocsolver_ssyev` (`f32`) or `rocsolver_dsyev`
/// (`f64`). Per the rocSOLVER documentation:
///
/// * `uplo` selects which triangle of `A` is used.
/// * `D` receives the eigenvalues.
/// * `E` is auxiliary storage for the routine.
/// * `info` receives a convergence indicator.
///
/// # Errors
///
/// The status returned by rocSOLVER is passed to `Error::from_status`, and its result is
/// returned unchanged.
///
/// # Caller contract
///
/// This function is not marked `unsafe`, yet it forwards `A`, `D`, `E` and `info` to the
/// FFI routine without any validation. They must be valid for the sizes implied by `n`
/// and `lda`; per the rocSOLVER docs they are device-memory pointers. Passing anything
/// else is undefined behavior.
// NOTE(review): a safe fn that hands caller-supplied raw pointers to unsafe FFI is unsound
// (the pattern clippy's `not_unsafe_ptr_arg_deref` flags). It should probably become an
// `unsafe fn`, but that would break existing callers, so it is only documented here.
#[allow(clippy::too_many_arguments)] // one-to-one with the rocSOLVER C signature
#[allow(non_snake_case)] // `A`, `D`, `E` keep the LAPACK/rocSOLVER argument names
#[inline]
pub fn syev<T: SyevType>(
    handle: &Handle,
    evect: Evect,
    uplo: Fill,
    n: i32,
    A: *mut T,
    lda: i32,
    D: *mut T,
    E: *mut T,
    info: *mut i32,
) -> Result<()> {
    // SAFETY: pointer validity is the caller's obligation (see "Caller contract" above).
    // `handle` stays borrowed for the duration of the call.
    let status = unsafe {
        T::syev(
            handle.as_raw(),
            evect.into(),
            uplo.into(),
            n,
            A,
            lda,
            D,
            E,
            info,
        )
    };
    Error::from_status(status)
}
/// Batched form of [`syev`]: solves `batch_count` independent real symmetric
/// eigenproblems of order `n`.
///
/// Dispatches through [`SyevType`] to `rocsolver_ssyev_batched` or
/// `rocsolver_dsyev_batched`. Following rocSOLVER's batched convention, `A` addresses an
/// array of per-problem matrix pointers. The outputs `D` and `E` are single buffers in
/// which consecutive problems are `stride_d` / `stride_e` apart.
///
/// # Errors
///
/// The status returned by rocSOLVER is passed to `Error::from_status`, and its result is
/// returned unchanged.
///
/// # Caller contract
///
/// This function is not marked `unsafe`, yet it forwards every pointer to the FFI routine
/// without validation. `A` must address `batch_count` valid matrices. `D`, `E` and `info`
/// must cover all `batch_count` problems for the given `n`, `lda` and strides; per the
/// rocSOLVER docs these are device-memory pointers. Anything else is undefined behavior.
// NOTE(review): a safe fn that hands caller-supplied raw pointers to unsafe FFI is unsound
// (the pattern clippy's `not_unsafe_ptr_arg_deref` flags). It should probably become an
// `unsafe fn`, but that would break existing callers, so it is only documented here.
#[allow(clippy::too_many_arguments)] // one-to-one with the rocSOLVER C signature
#[allow(non_snake_case)] // `A`, `D`, `E` keep the LAPACK/rocSOLVER argument names
#[inline]
pub fn syev_batched<T: SyevType>(
    handle: &Handle,
    evect: Evect,
    uplo: Fill,
    n: i32,
    A: *const *mut T,
    lda: i32,
    D: *mut T,
    stride_d: i64,
    E: *mut T,
    stride_e: i64,
    info: *mut i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: pointer validity is the caller's obligation (see "Caller contract" above).
    // `handle` stays borrowed for the duration of the call.
    let status = unsafe {
        T::syev_batched(
            handle.as_raw(),
            evect.into(),
            uplo.into(),
            n,
            A,
            lda,
            D,
            stride_d,
            E,
            stride_e,
            info,
            batch_count,
        )
    };
    Error::from_status(status)
}
/// Strided-batched form of [`syev`]: solves `batch_count` real symmetric eigenproblems
/// whose matrices share the single buffer `A`, `stride_a` apart.
///
/// Dispatches through [`SyevType`] to `rocsolver_ssyev_strided_batched` or
/// `rocsolver_dsyev_strided_batched`. The outputs `D` and `E` are addressed with
/// `stride_d` / `stride_e`. All strides are in elements, per the rocSOLVER docs.
///
/// # Errors
///
/// The status returned by rocSOLVER is passed to `Error::from_status`, and its result is
/// returned unchanged.
///
/// # Caller contract
///
/// This function is not marked `unsafe`, yet it forwards every pointer to the FFI routine
/// without validation. `A`, `D`, `E` and `info` must cover all `batch_count` problems for
/// the given `n`, `lda` and strides; per the rocSOLVER docs these are device-memory
/// pointers. Anything else is undefined behavior.
// NOTE(review): a safe fn that hands caller-supplied raw pointers to unsafe FFI is unsound
// (the pattern clippy's `not_unsafe_ptr_arg_deref` flags). It should probably become an
// `unsafe fn`, but that would break existing callers, so it is only documented here.
#[allow(clippy::too_many_arguments)] // one-to-one with the rocSOLVER C signature
#[allow(non_snake_case)] // `A`, `D`, `E` keep the LAPACK/rocSOLVER argument names
#[inline]
pub fn syev_strided_batched<T: SyevType>(
    handle: &Handle,
    evect: Evect,
    uplo: Fill,
    n: i32,
    A: *mut T,
    lda: i32,
    stride_a: i64,
    D: *mut T,
    stride_d: i64,
    E: *mut T,
    stride_e: i64,
    info: *mut i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: pointer validity is the caller's obligation (see "Caller contract" above).
    // `handle` stays borrowed for the duration of the call.
    let status = unsafe {
        T::syev_strided_batched(
            handle.as_raw(),
            evect.into(),
            uplo.into(),
            n,
            A,
            lda,
            stride_a,
            D,
            stride_d,
            E,
            stride_e,
            info,
            batch_count,
        )
    };
    Error::from_status(status)
}
/// Computes the eigenvalues of a Hermitian `n`×`n` matrix `A`, and optionally its
/// eigenvectors depending on `evect`.
///
/// Dispatches through [`HeevType`] to `rocsolver_cheev` (`Complex32`) or
/// `rocsolver_zheev` (`Complex64`). `D` and `E` use the associated real type
/// `T::RealType`. Per the rocSOLVER documentation:
///
/// * `uplo` selects which triangle of `A` is used.
/// * `D` receives the eigenvalues.
/// * `E` is auxiliary storage for the routine.
/// * `info` receives a convergence indicator.
///
/// # Errors
///
/// The status returned by rocSOLVER is passed to `Error::from_status`, and its result is
/// returned unchanged.
///
/// # Caller contract
///
/// This function is not marked `unsafe`, yet it forwards `A`, `D`, `E` and `info` to the
/// FFI routine without any validation. They must be valid for the sizes implied by `n`
/// and `lda`; per the rocSOLVER docs they are device-memory pointers. Passing anything
/// else is undefined behavior.
// NOTE(review): a safe fn that hands caller-supplied raw pointers to unsafe FFI is unsound
// (the pattern clippy's `not_unsafe_ptr_arg_deref` flags). It should probably become an
// `unsafe fn`, but that would break existing callers, so it is only documented here.
#[allow(clippy::too_many_arguments)] // one-to-one with the rocSOLVER C signature
#[allow(non_snake_case)] // `A`, `D`, `E` keep the LAPACK/rocSOLVER argument names
#[inline]
pub fn heev<T: HeevType>(
    handle: &Handle,
    evect: Evect,
    uplo: Fill,
    n: i32,
    A: *mut T,
    lda: i32,
    D: *mut T::RealType,
    E: *mut T::RealType,
    info: *mut i32,
) -> Result<()> {
    // SAFETY: pointer validity is the caller's obligation (see "Caller contract" above).
    // `handle` stays borrowed for the duration of the call.
    let status = unsafe {
        T::heev(
            handle.as_raw(),
            evect.into(),
            uplo.into(),
            n,
            A,
            lda,
            D,
            E,
            info,
        )
    };
    Error::from_status(status)
}
/// Batched form of [`heev`]: solves `batch_count` independent Hermitian eigenproblems of
/// order `n`.
///
/// Dispatches through [`HeevType`] to `rocsolver_cheev_batched` or
/// `rocsolver_zheev_batched`. Following rocSOLVER's batched convention, `A` addresses an
/// array of per-problem matrix pointers. The real-valued outputs `D` and `E` are single
/// buffers in which consecutive problems are `stride_d` / `stride_e` apart.
///
/// # Errors
///
/// The status returned by rocSOLVER is passed to `Error::from_status`, and its result is
/// returned unchanged.
///
/// # Caller contract
///
/// This function is not marked `unsafe`, yet it forwards every pointer to the FFI routine
/// without validation. `A` must address `batch_count` valid matrices. `D`, `E` and `info`
/// must cover all `batch_count` problems for the given `n`, `lda` and strides; per the
/// rocSOLVER docs these are device-memory pointers. Anything else is undefined behavior.
// NOTE(review): a safe fn that hands caller-supplied raw pointers to unsafe FFI is unsound
// (the pattern clippy's `not_unsafe_ptr_arg_deref` flags). It should probably become an
// `unsafe fn`, but that would break existing callers, so it is only documented here.
#[allow(clippy::too_many_arguments)] // one-to-one with the rocSOLVER C signature
#[allow(non_snake_case)] // `A`, `D`, `E` keep the LAPACK/rocSOLVER argument names
#[inline]
pub fn heev_batched<T: HeevType>(
    handle: &Handle,
    evect: Evect,
    uplo: Fill,
    n: i32,
    A: *const *mut T,
    lda: i32,
    D: *mut T::RealType,
    stride_d: i64,
    E: *mut T::RealType,
    stride_e: i64,
    info: *mut i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: pointer validity is the caller's obligation (see "Caller contract" above).
    // `handle` stays borrowed for the duration of the call.
    let status = unsafe {
        T::heev_batched(
            handle.as_raw(),
            evect.into(),
            uplo.into(),
            n,
            A,
            lda,
            D,
            stride_d,
            E,
            stride_e,
            info,
            batch_count,
        )
    };
    Error::from_status(status)
}
/// Strided-batched form of [`heev`]: solves `batch_count` Hermitian eigenproblems whose
/// matrices share the single buffer `A`, `stride_a` apart.
///
/// Dispatches through [`HeevType`] to `rocsolver_cheev_strided_batched` or
/// `rocsolver_zheev_strided_batched`. The real-valued outputs `D` and `E` are addressed
/// with `stride_d` / `stride_e`. All strides are in elements, per the rocSOLVER docs.
///
/// # Errors
///
/// The status returned by rocSOLVER is passed to `Error::from_status`, and its result is
/// returned unchanged.
///
/// # Caller contract
///
/// This function is not marked `unsafe`, yet it forwards every pointer to the FFI routine
/// without validation. `A`, `D`, `E` and `info` must cover all `batch_count` problems for
/// the given `n`, `lda` and strides; per the rocSOLVER docs these are device-memory
/// pointers. Anything else is undefined behavior.
// NOTE(review): a safe fn that hands caller-supplied raw pointers to unsafe FFI is unsound
// (the pattern clippy's `not_unsafe_ptr_arg_deref` flags). It should probably become an
// `unsafe fn`, but that would break existing callers, so it is only documented here.
#[allow(clippy::too_many_arguments)] // one-to-one with the rocSOLVER C signature
#[allow(non_snake_case)] // `A`, `D`, `E` keep the LAPACK/rocSOLVER argument names
#[inline]
pub fn heev_strided_batched<T: HeevType>(
    handle: &Handle,
    evect: Evect,
    uplo: Fill,
    n: i32,
    A: *mut T,
    lda: i32,
    stride_a: i64,
    D: *mut T::RealType,
    stride_d: i64,
    E: *mut T::RealType,
    stride_e: i64,
    info: *mut i32,
    batch_count: i32,
) -> Result<()> {
    // SAFETY: pointer validity is the caller's obligation (see "Caller contract" above).
    // `handle` stays borrowed for the duration of the call.
    let status = unsafe {
        T::heev_strided_batched(
            handle.as_raw(),
            evect.into(),
            uplo.into(),
            n,
            A,
            lda,
            stride_a,
            D,
            stride_d,
            E,
            stride_e,
            info,
            batch_count,
        )
    };
    Error::from_status(status)
}