#![allow(non_camel_case_types)]
#![allow(clippy::too_many_arguments)]
use core::ffi::c_void;
use core::ptr;
use super::{
cuComplex, cuDoubleComplex, cudaDataType, cusolverDnCheevd, cusolverDnCheevd_bufferSize,
cusolverDnCreate, cusolverDnCreateGesvdjInfo, cusolverDnCreateParams, cusolverDnDDgels,
cusolverDnDDgels_bufferSize, cusolverDnDestroy, cusolverDnDestroyGesvdjInfo,
cusolverDnDestroyParams, cusolverDnDgeqrf, cusolverDnDgeqrf_bufferSize,
cusolverDnDgesvd, cusolverDnDgesvd_bufferSize, cusolverDnDgesvdaStridedBatched,
cusolverDnDgesvdaStridedBatched_bufferSize, cusolverDnDgesvdjBatched,
cusolverDnDgesvdjBatched_bufferSize, cusolverDnDgetrf, cusolverDnDgetrf_bufferSize,
cusolverDnDgetrs, cusolverDnDormqr, cusolverDnDpotrf, cusolverDnDpotrfBatched,
cusolverDnDpotrf_bufferSize, cusolverDnDsyevd, cusolverDnDsyevd_bufferSize, cusolverDnHandle_t,
cusolverDnParams_t, cusolverDnSSgels, cusolverDnSSgels_bufferSize, cusolverDnSetStream,
cusolverDnSgeqrf, cusolverDnSgeqrf_bufferSize, cusolverDnSgesvd, cusolverDnSgesvd_bufferSize,
cusolverDnSgesvdaStridedBatched, cusolverDnSgesvdaStridedBatched_bufferSize,
cusolverDnSgesvdjBatched, cusolverDnSgesvdjBatched_bufferSize, cusolverDnSgetrf,
cusolverDnSgetrf_bufferSize, cusolverDnSgetrs, cusolverDnSormqr, cusolverDnSpotrf,
cusolverDnSpotrfBatched, cusolverDnSpotrf_bufferSize, cusolverDnSsyevd,
cusolverDnSsyevd_bufferSize, cusolverDnXgeev, cusolverDnXgeev_bufferSize, cusolverDnZheevd,
cusolverDnZheevd_bufferSize, gesvdjInfo_t, CUBLAS_FILL_MODE_LOWER, CUBLAS_FILL_MODE_UPPER,
CUBLAS_OP_N, CUDA_C_32F, CUDA_C_64F, CUDA_R_32F, CUDA_R_64F,
CUSOLVER_EIG_MODE_NOVECTOR, CUSOLVER_EIG_MODE_VECTOR,
};
/// Status code returned when a call completes successfully.
const OK: i32 = 0;
/// Status code returned when an argument fails validation.
const INVALID: i32 = 2;
/// Status code returned when the caller-supplied workspace is null or smaller than required.
const WS_TOO_SMALL: i32 = 4;
/// Status code returned when a cuSOLVER call reports a non-zero status.
const INTERNAL: i32 = 5;

/// Collapses a raw cuSOLVER status into this module's status codes.
///
/// Zero maps to [`OK`]; every other value maps to [`INTERNAL`].
#[inline]
fn map_cusolver(status: i32) -> i32 {
    match status {
        0 => OK,
        _ => INTERNAL,
    }
}
/// Owning wrapper around a `cusolverDnHandle_t`.
///
/// The wrapper starts out null; a non-null handle is passed to
/// `cusolverDnDestroy` when the wrapper is dropped.
struct Handle {
    h: cusolverDnHandle_t,
}

impl Handle {
    /// Returns a wrapper holding a null handle; nothing is created yet.
    #[inline]
    fn new() -> Self {
        Handle { h: ptr::null_mut() }
    }
}

impl Drop for Handle {
    fn drop(&mut self) {
        if self.h.is_null() {
            return;
        }
        // SAFETY: the handle is non-null; this wrapper assumes it was filled in
        // by `cusolverDnCreate`. The destroy status is deliberately ignored
        // because `drop` cannot report it.
        let _ = unsafe { cusolverDnDestroy(self.h) };
    }
}
/// Creates a cuSOLVER dense handle inside `g` and binds it to `stream`.
///
/// Returns `OK` when both `cusolverDnCreate` and `cusolverDnSetStream` report
/// status zero, and `INTERNAL` as soon as either one does not.
///
/// # Safety
/// `stream` is forwarded to `cusolverDnSetStream` unchanged, so it must be a
/// value that call accepts (callers in this file also pass null).
#[inline]
unsafe fn setup_handle(g: &mut Handle, stream: *mut c_void) -> i32 {
    let created = unsafe { cusolverDnCreate(&mut g.h as *mut _) };
    if created != 0 {
        return INTERNAL;
    }
    let bound = unsafe { cusolverDnSetStream(g.h, stream) };
    if bound == 0 {
        OK
    } else {
        INTERNAL
    }
}
/// Owning wrapper around a `cusolverDnParams_t`.
///
/// The wrapper starts out null; a non-null value is passed to
/// `cusolverDnDestroyParams` when the wrapper is dropped.
struct Params {
    p: cusolverDnParams_t,
}

impl Params {
    /// Returns a wrapper holding a null params object; nothing is created yet.
    #[inline]
    fn new() -> Self {
        Params { p: ptr::null_mut() }
    }
}

impl Drop for Params {
    fn drop(&mut self) {
        if self.p.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null; this wrapper assumes it was filled
        // in by `cusolverDnCreateParams`. The destroy status is ignored because
        // `drop` cannot report it.
        let _ = unsafe { cusolverDnDestroyParams(self.p) };
    }
}
/// Owning wrapper around a `gesvdjInfo_t`.
///
/// The wrapper starts out null; a non-null value is passed to
/// `cusolverDnDestroyGesvdjInfo` when the wrapper is dropped.
struct JacobiInfo {
    p: gesvdjInfo_t,
}

impl JacobiInfo {
    /// Returns a wrapper holding a null info object; nothing is created yet.
    #[inline]
    fn new() -> Self {
        JacobiInfo { p: ptr::null_mut() }
    }
}

impl Drop for JacobiInfo {
    fn drop(&mut self) {
        if self.p.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null; this wrapper assumes it was filled
        // in by `cusolverDnCreateGesvdjInfo`. The destroy status is ignored
        // because `drop` cannot report it.
        let _ = unsafe { cusolverDnDestroyGesvdjInfo(self.p) };
    }
}
/// Checks a caller-supplied workspace against the number of bytes required.
///
/// Returns `OK` when `needed` is zero (the pointer is not inspected in that
/// case), or when `workspace` is non-null and `workspace_bytes >= needed`.
/// Returns `WS_TOO_SMALL` otherwise.
#[inline]
fn check_ws(workspace: *mut c_void, workspace_bytes: usize, needed: usize) -> i32 {
    let sufficient = needed == 0 || (!workspace.is_null() && workspace_bytes >= needed);
    if sufficient {
        OK
    } else {
        WS_TOO_SMALL
    }
}
/// Generates the C ABI entry points that wrap cuSOLVER `potrf` (Cholesky
/// factorisation) for one element type.
///
/// Macro arguments, in order: name of the single-matrix run function, name of
/// the batched run function, name of the workspace-size function, the `potrf`
/// routine, the `potrfBatched` routine, the `potrf_bufferSize` routine, and the
/// element type.
macro_rules! cholesky_pair {
    ($nb_name:ident, $bs_name:ident, $ws_name:ident,
     $potrf:ident, $potrfb:ident, $potrf_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `potrf_bufferSize` routine for an `n` x `n` matrix with leading
        /// dimension `lda`. The query is always made with `CUBLAS_FILL_MODE_LOWER`.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n` or a null `out_bytes`;
        /// `INTERNAL` if handle setup fails, the query fails, or the query
        /// reports a negative size; `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(n: i32, lda: i32, out_bytes: *mut usize) -> i32 {
            if n <= 0 || lda < n || out_bytes.is_null() {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, ptr::null_mut()) };
            if s != OK {
                return s;
            }
            let mut lwork: i32 = 0;
            let st = unsafe {
                $potrf_bs(h.h, CUBLAS_FILL_MODE_LOWER, n, ptr::null_mut(), lda, &mut lwork)
            };
            if st != 0 {
                return INTERNAL;
            }
            // A negative element count would be sign-extended by `as usize`
            // into an enormous byte count; report it as an internal failure.
            if lwork < 0 {
                return INTERNAL;
            }
            let bytes = (lwork as usize) * core::mem::size_of::<$T>();
            unsafe { *out_bytes = bytes };
            OK
        }

        /// Runs `potrf` in place on `a_inout` using the caller's workspace.
        ///
        /// The workspace length handed to cuSOLVER is `workspace_bytes`
        /// divided by the element size, saturated at `i32::MAX` elements.
        /// This function does not itself compare the workspace against the
        /// size `potrf_bufferSize` reports.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n`, a null `a_inout` or
        /// `info_out`, an `uplo` other than `CUBLAS_FILL_MODE_LOWER` /
        /// `CUBLAS_FILL_MODE_UPPER`, or a null `workspace` with a non-zero
        /// `workspace_bytes`; `INTERNAL` if handle setup or `potrf` fails;
        /// `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $nb_name(
            uplo: i32,
            n: i32,
            lda: i32,
            a_inout: *mut c_void,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            if n <= 0 || lda < n || a_inout.is_null() || info_out.is_null() {
                return INVALID;
            }
            if !matches!(uplo, CUBLAS_FILL_MODE_LOWER | CUBLAS_FILL_MODE_UPPER) {
                return INVALID;
            }
            // Validate every argument before acquiring the handle so that an
            // invalid call never pays for handle creation.
            if workspace_bytes > 0 && workspace.is_null() {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, stream) };
            if s != OK {
                return s;
            }
            // Saturate instead of truncating: a plain `as i32` cast would wrap
            // very large workspaces to a wrong (possibly negative) length.
            let elems = workspace_bytes / core::mem::size_of::<$T>();
            let lwork = elems.min(i32::MAX as usize) as i32;
            let st = unsafe {
                $potrf(
                    h.h,
                    uplo,
                    n,
                    a_inout as *mut $T,
                    lda,
                    workspace as *mut $T,
                    lwork,
                    info_out,
                )
            };
            map_cusolver(st)
        }

        /// Runs `potrfBatched` over `batch_size` matrices addressed through
        /// the pointer array `a_array`, with one status slot per matrix in
        /// `info_array`.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n`, `batch_size <= 0`, a null
        /// `a_array` or `info_array`, or an `uplo` other than
        /// `CUBLAS_FILL_MODE_LOWER` / `CUBLAS_FILL_MODE_UPPER`; `INTERNAL` if
        /// handle setup or the batched routine fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $bs_name(
            uplo: i32,
            n: i32,
            lda: i32,
            a_array: *mut *mut c_void,
            info_array: *mut i32,
            batch_size: i32,
            stream: *mut c_void,
        ) -> i32 {
            if n <= 0
                || lda < n
                || batch_size <= 0
                || a_array.is_null()
                || info_array.is_null()
            {
                return INVALID;
            }
            if !matches!(uplo, CUBLAS_FILL_MODE_LOWER | CUBLAS_FILL_MODE_UPPER) {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, stream) };
            if s != OK {
                return s;
            }
            let st = unsafe {
                $potrfb(
                    h.h,
                    uplo,
                    n,
                    a_array as *mut *mut $T,
                    lda,
                    info_array,
                    batch_size,
                )
            };
            map_cusolver(st)
        }
    };
}
// Cholesky entry points for `f32`, backed by the cuSOLVER `S` routines.
cholesky_pair!(
baracuda_kernels_cholesky_f32_run,
baracuda_kernels_cholesky_batched_f32_run,
baracuda_kernels_cholesky_f32_workspace_size,
cusolverDnSpotrf,
cusolverDnSpotrfBatched,
cusolverDnSpotrf_bufferSize,
f32
);
// Cholesky entry points for `f64`, backed by the cuSOLVER `D` routines.
cholesky_pair!(
baracuda_kernels_cholesky_f64_run,
baracuda_kernels_cholesky_batched_f64_run,
baracuda_kernels_cholesky_f64_workspace_size,
cusolverDnDpotrf,
cusolverDnDpotrfBatched,
cusolverDnDpotrf_bufferSize,
f64
);
/// Generates the C ABI entry points that wrap cuSOLVER `getrf` (LU
/// factorisation) for one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `getrf` routine, the `getrf_bufferSize`
/// routine, and the element type.
macro_rules! lu_pair {
    ($name:ident, $ws_name:ident, $getrf:ident, $getrf_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `getrf_bufferSize` routine for an `m` x `n` matrix with leading
        /// dimension `lda`.
        ///
        /// Returns `INVALID` for non-positive `m` or `n`, `lda < m` or a null
        /// `out_bytes`; `INTERNAL` if handle setup or the query fails; `OK`
        /// otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(m: i32, n: i32, lda: i32, out_bytes: *mut usize) -> i32 {
            let dims_ok = m > 0 && n > 0 && lda >= m;
            if !dims_ok || out_bytes.is_null() {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $getrf_bs(solver.h, m, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Runs `getrf` in place on `a_inout`, writing pivots to `pivots_out`
        /// and the routine's status word to `info_out`.
        ///
        /// The required workspace is re-queried here and compared against the
        /// caller's buffer before the factorisation is launched.
        ///
        /// Returns `INVALID` for non-positive `m` or `n`, `lda < m`, or a null
        /// `a_inout`, `pivots_out` or `info_out`; `WS_TOO_SMALL` if the
        /// workspace is null or too small; `INTERNAL` if handle setup, the
        /// size query or `getrf` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            m: i32,
            n: i32,
            lda: i32,
            a_inout: *mut c_void,
            pivots_out: *mut i32,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = m <= 0 || n <= 0 || lda < m;
            let bad_ptrs = a_inout.is_null() || pivots_out.is_null() || info_out.is_null();
            if bad_dims || bad_ptrs {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $getrf_bs(solver.h, m, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let status = unsafe {
                $getrf(
                    solver.h,
                    m,
                    n,
                    a_inout as *mut $T,
                    lda,
                    workspace as *mut $T,
                    pivots_out,
                    info_out,
                )
            };
            map_cusolver(status)
        }
    };
}
// LU entry points for `f32`, backed by the cuSOLVER `S` routines.
lu_pair!(
baracuda_kernels_lu_f32_run,
baracuda_kernels_lu_f32_workspace_size,
cusolverDnSgetrf,
cusolverDnSgetrf_bufferSize,
f32
);
// LU entry points for `f64`, backed by the cuSOLVER `D` routines.
lu_pair!(
baracuda_kernels_lu_f64_run,
baracuda_kernels_lu_f64_workspace_size,
cusolverDnDgetrf,
cusolverDnDgetrf_bufferSize,
f64
);
/// Generates the C ABI entry points that wrap cuSOLVER `geqrf` (QR
/// factorisation) for one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `geqrf` routine, the `geqrf_bufferSize`
/// routine, and the element type.
macro_rules! qr_pair {
    ($name:ident, $ws_name:ident, $geqrf:ident, $geqrf_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `geqrf_bufferSize` routine for an `m` x `n` matrix with leading
        /// dimension `lda`.
        ///
        /// Returns `INVALID` for non-positive `m` or `n`, `lda < m` or a null
        /// `out_bytes`; `INTERNAL` if handle setup or the query fails; `OK`
        /// otherwise. Unlike the run function, `m < n` is not rejected here.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(m: i32, n: i32, lda: i32, out_bytes: *mut usize) -> i32 {
            let dims_ok = m > 0 && n > 0 && lda >= m;
            if !dims_ok || out_bytes.is_null() {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $geqrf_bs(solver.h, m, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Runs `geqrf` in place on `a_inout`, writing the scalar factors to
        /// `tau_out` and the routine's status word to `info_out`.
        ///
        /// The required workspace is re-queried here and compared against the
        /// caller's buffer before the factorisation is launched.
        ///
        /// Returns `INVALID` for non-positive `m` or `n`, `m < n`, `lda < m`,
        /// or a null `a_inout`, `tau_out` or `info_out`; `WS_TOO_SMALL` if the
        /// workspace is null or too small; `INTERNAL` if handle setup, the
        /// size query or `geqrf` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            m: i32,
            n: i32,
            lda: i32,
            a_inout: *mut c_void,
            tau_out: *mut c_void,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = m <= 0 || n <= 0 || m < n || lda < m;
            let bad_ptrs = a_inout.is_null() || tau_out.is_null() || info_out.is_null();
            if bad_dims || bad_ptrs {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $geqrf_bs(solver.h, m, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let status = unsafe {
                $geqrf(
                    solver.h,
                    m,
                    n,
                    a_inout as *mut $T,
                    lda,
                    tau_out as *mut $T,
                    workspace as *mut $T,
                    lwork_elems,
                    info_out,
                )
            };
            map_cusolver(status)
        }
    };
}
// QR entry points for `f32`, backed by the cuSOLVER `S` routines.
qr_pair!(
baracuda_kernels_qr_f32_run,
baracuda_kernels_qr_f32_workspace_size,
cusolverDnSgeqrf,
cusolverDnSgeqrf_bufferSize,
f32
);
// QR entry points for `f64`, backed by the cuSOLVER `D` routines.
qr_pair!(
baracuda_kernels_qr_f64_run,
baracuda_kernels_qr_f64_workspace_size,
cusolverDnDgeqrf,
cusolverDnDgeqrf_bufferSize,
f64
);
/// Generates the C ABI entry point that wraps cuSOLVER `ormqr` for one
/// element type.
///
/// Macro arguments, in order: name of the run function, the `ormqr` routine,
/// and the element type.
macro_rules! ormqr_pair {
    ($name:ident, $ormqr:ident, $T:ty) => {
        /// Runs `ormqr` on `c_inout` using the packed factor `a_packed` and
        /// the scalar factors `tau`.
        ///
        /// `side` and `op` are forwarded to cuSOLVER without validation. The
        /// workspace length handed to cuSOLVER is `workspace_bytes` divided by
        /// the element size, saturated at `i32::MAX` elements. This function
        /// does not query the required workspace size itself.
        ///
        /// Returns `INVALID` for non-positive `m`, `n`, `k`, `lda` or `ldc`, a
        /// null `a_packed`, `tau`, `c_inout` or `info_out`, or a null
        /// `workspace` with a non-zero `workspace_bytes`; `INTERNAL` if handle
        /// setup or `ormqr` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            side: i32,
            op: i32,
            m: i32,
            n: i32,
            k: i32,
            a_packed: *const c_void,
            lda: i32,
            tau: *const c_void,
            c_inout: *mut c_void,
            ldc: i32,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            if m <= 0 || n <= 0 || k <= 0 || lda <= 0 || ldc <= 0
                || a_packed.is_null() || tau.is_null() || c_inout.is_null() || info_out.is_null()
            {
                return INVALID;
            }
            // A non-empty workspace must be backed by a real pointer; this
            // mirrors the check performed by the Cholesky run wrapper.
            if workspace_bytes > 0 && workspace.is_null() {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, stream) };
            if s != OK {
                return s;
            }
            // Saturate instead of truncating: a plain `as i32` cast would wrap
            // very large workspaces to a wrong (possibly negative) length.
            let elems = workspace_bytes / core::mem::size_of::<$T>();
            let lwork = elems.min(i32::MAX as usize) as i32;
            let st = unsafe {
                $ormqr(
                    h.h,
                    side,
                    op,
                    m,
                    n,
                    k,
                    a_packed as *const $T,
                    lda,
                    tau as *const $T,
                    c_inout as *mut $T,
                    ldc,
                    workspace as *mut $T,
                    lwork,
                    info_out,
                )
            };
            map_cusolver(st)
        }
    };
}
// `ormqr` entry points for `f32` (cuSOLVER `S` routine) and `f64` (`D` routine).
ormqr_pair!(baracuda_kernels_ormqr_f32_run, cusolverDnSormqr, f32);
ormqr_pair!(baracuda_kernels_ormqr_f64_run, cusolverDnDormqr, f64);
/// Generates the C ABI entry points that wrap cuSOLVER `gesvd` (singular
/// value decomposition) for one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `gesvd` routine, the `gesvd_bufferSize`
/// routine, and the element type.
macro_rules! svd_pair {
    ($name:ident, $ws_name:ident, $gesvd:ident, $gesvd_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `gesvd_bufferSize` routine for an `m` x `n` matrix.
        ///
        /// Returns `INVALID` for non-positive `m` or `n`, `m < n` or a null
        /// `out_bytes`; `INTERNAL` if handle setup or the query fails; `OK`
        /// otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(m: i32, n: i32, out_bytes: *mut usize) -> i32 {
            if m <= 0 || n <= 0 || m < n || out_bytes.is_null() {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, ptr::null_mut()) };
            if s != OK {
                return s;
            }
            let mut lwork: i32 = 0;
            let st = unsafe { $gesvd_bs(h.h, m, n, &mut lwork) };
            if st != 0 {
                return INTERNAL;
            }
            unsafe { *out_bytes = (lwork as usize) * core::mem::size_of::<$T>() };
            OK
        }

        /// Runs `gesvd` on `a_inout`, writing singular values to `s_out` and,
        /// when requested, vectors to `u_out` / `vt_out`.
        ///
        /// `jobu` and `jobv` must each be one of `b'A'`, `b'S'` or `b'N'`.
        /// When a job is not `b'N'`, the matching output pointer must be
        /// non-null and its leading dimension large enough (`ldu >= m`,
        /// `ldvt >= n`). The real-work array argument of `gesvd` is always
        /// passed as null.
        ///
        /// Returns `INVALID` for non-positive `m` or `n`, `m < n`, `lda < m`,
        /// a null `a_inout`, `s_out` or `info_out`, an unsupported job code,
        /// or a missing/undersized vector output; `WS_TOO_SMALL` if the
        /// workspace is null or too small; `INTERNAL` if handle setup, the
        /// size query or `gesvd` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            jobu: u8,
            jobv: u8,
            m: i32,
            n: i32,
            lda: i32,
            a_inout: *mut c_void,
            ldu: i32,
            ldvt: i32,
            s_out: *mut c_void,
            u_out: *mut c_void,
            vt_out: *mut c_void,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            if m <= 0 || n <= 0 || m < n || lda < m
                || a_inout.is_null() || s_out.is_null() || info_out.is_null()
            {
                return INVALID;
            }
            if !matches!(jobu, b'A' | b'S' | b'N') || !matches!(jobv, b'A' | b'S' | b'N') {
                return INVALID;
            }
            // When vectors are requested the routine writes through the output
            // pointer, so reject a null pointer or a leading dimension smaller
            // than the number of rows stored (m for U; n for VT, since m >= n).
            if jobu != b'N' && (u_out.is_null() || ldu < m) {
                return INVALID;
            }
            if jobv != b'N' && (vt_out.is_null() || ldvt < n) {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, stream) };
            if s != OK {
                return s;
            }
            let mut lwork: i32 = 0;
            let st = unsafe { $gesvd_bs(h.h, m, n, &mut lwork) };
            if st != 0 {
                return INTERNAL;
            }
            let needed = (lwork as usize) * core::mem::size_of::<$T>();
            let s = check_ws(workspace, workspace_bytes, needed);
            if s != OK {
                return s;
            }
            let st = unsafe {
                $gesvd(
                    h.h,
                    jobu,
                    jobv,
                    m,
                    n,
                    a_inout as *mut $T,
                    lda,
                    s_out as *mut $T,
                    u_out as *mut $T,
                    ldu,
                    vt_out as *mut $T,
                    ldvt,
                    workspace as *mut $T,
                    lwork,
                    ptr::null_mut(),
                    info_out,
                )
            };
            map_cusolver(st)
        }
    };
}
// SVD entry points for `f32`, backed by the cuSOLVER `S` routines.
svd_pair!(
baracuda_kernels_svd_f32_run,
baracuda_kernels_svd_f32_workspace_size,
cusolverDnSgesvd,
cusolverDnSgesvd_bufferSize,
f32
);
// SVD entry points for `f64`, backed by the cuSOLVER `D` routines.
svd_pair!(
baracuda_kernels_svd_f64_run,
baracuda_kernels_svd_f64_workspace_size,
cusolverDnDgesvd,
cusolverDnDgesvd_bufferSize,
f64
);
/// Generates the C ABI entry points that wrap cuSOLVER `gesvdjBatched` for
/// square `n` x `n` matrices of one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `gesvdjBatched` routine, the matching
/// `_bufferSize` routine, and the element type.
macro_rules! svd_batched_pair {
    ($name:ident, $ws_name:ident, $gesvdj:ident, $gesvdj_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `_bufferSize` routine for `batch_size` matrices of size `n` x `n`,
        /// querying with every leading dimension equal to `n`.
        ///
        /// Returns `INVALID` for non-positive `n` or `batch_size`, a null
        /// `out_bytes`, or a `jobz` other than `CUSOLVER_EIG_MODE_VECTOR` /
        /// `CUSOLVER_EIG_MODE_NOVECTOR`; `INTERNAL` if handle setup, info
        /// creation or the query fails; `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(
            jobz: i32,
            n: i32,
            batch_size: i32,
            out_bytes: *mut usize,
        ) -> i32 {
            let jobz_ok = jobz == CUSOLVER_EIG_MODE_VECTOR || jobz == CUSOLVER_EIG_MODE_NOVECTOR;
            if n <= 0 || batch_size <= 0 || out_bytes.is_null() || !jobz_ok {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let mut jacobi = JacobiInfo::new();
            let status = unsafe { cusolverDnCreateGesvdjInfo(&mut jacobi.p as *mut _) };
            if status != 0 {
                return INTERNAL;
            }
            let mut lwork_elems: i32 = 0;
            let status = unsafe {
                $gesvdj_bs(
                    solver.h,
                    jobz,
                    n,
                    n,
                    ptr::null(),
                    n,
                    ptr::null(),
                    ptr::null(),
                    n,
                    ptr::null(),
                    n,
                    &mut lwork_elems,
                    jacobi.p,
                    batch_size,
                )
            };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Runs `gesvdjBatched` on `batch_size` square matrices stored in
        /// `a_inout`, writing singular values to `s_out` and, when
        /// `jobz == CUSOLVER_EIG_MODE_VECTOR`, vectors to `u_out` / `v_out`.
        ///
        /// The required workspace is re-queried with the caller's leading
        /// dimensions and compared against the supplied buffer first.
        ///
        /// Returns `INVALID` for non-positive `n` or `batch_size`, any of
        /// `lda`, `ldu`, `ldv` below `n`, a null `a_inout`, `s_out` or
        /// `info_out`, an unsupported `jobz`, or a null `u_out` / `v_out` when
        /// vectors are requested; `WS_TOO_SMALL` if the workspace is null or
        /// too small; `INTERNAL` if handle setup, info creation, the size
        /// query or the routine fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            jobz: i32,
            n: i32,
            lda: i32,
            ldu: i32,
            ldv: i32,
            a_inout: *mut c_void,
            s_out: *mut c_void,
            u_out: *mut c_void,
            v_out: *mut c_void,
            info_out: *mut i32,
            batch_size: i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = n <= 0 || lda < n || ldu < n || ldv < n || batch_size <= 0;
            let bad_ptrs = a_inout.is_null() || s_out.is_null() || info_out.is_null();
            let wants_vectors = jobz == CUSOLVER_EIG_MODE_VECTOR;
            let jobz_ok = wants_vectors || jobz == CUSOLVER_EIG_MODE_NOVECTOR;
            let vectors_missing = wants_vectors && (u_out.is_null() || v_out.is_null());
            if bad_dims || bad_ptrs || !jobz_ok || vectors_missing {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut jacobi = JacobiInfo::new();
            let status = unsafe { cusolverDnCreateGesvdjInfo(&mut jacobi.p as *mut _) };
            if status != 0 {
                return INTERNAL;
            }
            let mut lwork_elems: i32 = 0;
            let status = unsafe {
                $gesvdj_bs(
                    solver.h,
                    jobz,
                    n,
                    n,
                    ptr::null(),
                    lda,
                    ptr::null(),
                    ptr::null(),
                    ldu,
                    ptr::null(),
                    ldv,
                    &mut lwork_elems,
                    jacobi.p,
                    batch_size,
                )
            };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let status = unsafe {
                $gesvdj(
                    solver.h,
                    jobz,
                    n,
                    n,
                    a_inout as *mut $T,
                    lda,
                    s_out as *mut $T,
                    u_out as *mut $T,
                    ldu,
                    v_out as *mut $T,
                    ldv,
                    workspace as *mut $T,
                    lwork_elems,
                    info_out,
                    jacobi.p,
                    batch_size,
                )
            };
            map_cusolver(status)
        }
    };
}
// Batched Jacobi SVD entry points for `f32`, backed by the cuSOLVER `S` routines.
svd_batched_pair!(
baracuda_kernels_svd_batched_f32_run,
baracuda_kernels_svd_batched_f32_workspace_size,
cusolverDnSgesvdjBatched,
cusolverDnSgesvdjBatched_bufferSize,
f32
);
// Batched Jacobi SVD entry points for `f64`, backed by the cuSOLVER `D` routines.
svd_batched_pair!(
baracuda_kernels_svd_batched_f64_run,
baracuda_kernels_svd_batched_f64_workspace_size,
cusolverDnDgesvdjBatched,
cusolverDnDgesvdjBatched_bufferSize,
f64
);
/// Generates the C ABI entry points that wrap cuSOLVER `gesvdaStridedBatched`
/// for one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `gesvdaStridedBatched` routine, the matching
/// `_bufferSize` routine, and the element type.
macro_rules! svda_batched_pair {
    ($name:ident, $ws_name:ident, $gesvda:ident, $gesvda_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `_bufferSize` routine.
        ///
        /// The query assumes densely packed operands: leading dimensions `m`
        /// (A and U) and `n` (V), and strides `m*n`, `rank`, `m*rank` and
        /// `n*rank` for A, S, U and V respectively.
        ///
        /// Returns `INVALID` for non-positive `m`, `n` or `batch_size`, a
        /// `rank` outside `1..=min(m, n)`, a null `out_bytes`, or a `jobz`
        /// other than `CUSOLVER_EIG_MODE_VECTOR` / `CUSOLVER_EIG_MODE_NOVECTOR`;
        /// `INTERNAL` if handle setup or the query fails; `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(
            jobz: i32,
            rank: i32,
            m: i32,
            n: i32,
            batch_size: i32,
            out_bytes: *mut usize,
        ) -> i32 {
            let bad_dims = m <= 0 || n <= 0 || batch_size <= 0;
            let bad_rank = rank < 1 || rank > m || rank > n;
            let jobz_ok = jobz == CUSOLVER_EIG_MODE_VECTOR || jobz == CUSOLVER_EIG_MODE_NOVECTOR;
            if bad_dims || bad_rank || out_bytes.is_null() || !jobz_ok {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let (rows, cols, kept) = (i64::from(m), i64::from(n), i64::from(rank));
            let stride_a = rows * cols;
            let stride_s = kept;
            let stride_u = rows * kept;
            let stride_v = cols * kept;
            let mut lwork_elems: i32 = 0;
            let status = unsafe {
                $gesvda_bs(
                    solver.h,
                    jobz,
                    rank,
                    m,
                    n,
                    ptr::null(),
                    m,
                    stride_a,
                    ptr::null(),
                    stride_s,
                    ptr::null(),
                    m,
                    stride_u,
                    ptr::null(),
                    n,
                    stride_v,
                    &mut lwork_elems,
                    batch_size,
                )
            };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Runs `gesvdaStridedBatched` over `batch_size` matrices read from
        /// `a_in`, writing singular values to `s_out` and, when
        /// `jobz == CUSOLVER_EIG_MODE_VECTOR`, vectors to `u_out` / `v_out`.
        ///
        /// The required workspace is re-queried with the caller's leading
        /// dimensions and strides and compared against the supplied buffer
        /// first. `h_r_nrm_f_out` is forwarded without a null check.
        ///
        /// Returns `INVALID` for non-positive `m`, `n` or `batch_size`, a
        /// `rank` outside `1..=min(m, n)`, `lda < m`, `ldu < m`, `ldv < n`, a
        /// null `a_in`, `s_out` or `info_out`, an unsupported `jobz`, or a
        /// null `u_out` / `v_out` when vectors are requested; `WS_TOO_SMALL`
        /// if the workspace is null or too small; `INTERNAL` if handle setup,
        /// the size query or the routine fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements. The name
        /// `h_r_nrm_f_out` suggests a host array; confirm against the
        /// cuSOLVER documentation.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            jobz: i32,
            rank: i32,
            m: i32,
            n: i32,
            lda: i32,
            ldu: i32,
            ldv: i32,
            stride_a: i64,
            stride_s: i64,
            stride_u: i64,
            stride_v: i64,
            a_in: *const c_void,
            s_out: *mut c_void,
            u_out: *mut c_void,
            v_out: *mut c_void,
            info_out: *mut i32,
            h_r_nrm_f_out: *mut f64,
            batch_size: i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = m <= 0 || n <= 0 || batch_size <= 0;
            let bad_rank = rank < 1 || rank > m || rank > n;
            let bad_ld = lda < m || ldu < m || ldv < n;
            let bad_ptrs = a_in.is_null() || s_out.is_null() || info_out.is_null();
            let wants_vectors = jobz == CUSOLVER_EIG_MODE_VECTOR;
            let jobz_ok = wants_vectors || jobz == CUSOLVER_EIG_MODE_NOVECTOR;
            let vectors_missing = wants_vectors && (u_out.is_null() || v_out.is_null());
            if bad_dims || bad_rank || bad_ld || bad_ptrs || !jobz_ok || vectors_missing {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status = unsafe {
                $gesvda_bs(
                    solver.h,
                    jobz,
                    rank,
                    m,
                    n,
                    ptr::null(),
                    lda,
                    stride_a,
                    ptr::null(),
                    stride_s,
                    ptr::null(),
                    ldu,
                    stride_u,
                    ptr::null(),
                    ldv,
                    stride_v,
                    &mut lwork_elems,
                    batch_size,
                )
            };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let status = unsafe {
                $gesvda(
                    solver.h,
                    jobz,
                    rank,
                    m,
                    n,
                    a_in as *const $T,
                    lda,
                    stride_a,
                    s_out as *mut $T,
                    stride_s,
                    u_out as *mut $T,
                    ldu,
                    stride_u,
                    v_out as *mut $T,
                    ldv,
                    stride_v,
                    workspace as *mut $T,
                    lwork_elems,
                    info_out,
                    h_r_nrm_f_out,
                    batch_size,
                )
            };
            map_cusolver(status)
        }
    };
}
// Strided-batched approximate SVD entry points for `f32` (cuSOLVER `S` routines).
svda_batched_pair!(
baracuda_kernels_svda_batched_f32_run,
baracuda_kernels_svda_batched_f32_workspace_size,
cusolverDnSgesvdaStridedBatched,
cusolverDnSgesvdaStridedBatched_bufferSize,
f32
);
// Strided-batched approximate SVD entry points for `f64` (cuSOLVER `D` routines).
svda_batched_pair!(
baracuda_kernels_svda_batched_f64_run,
baracuda_kernels_svda_batched_f64_workspace_size,
cusolverDnDgesvdaStridedBatched,
cusolverDnDgesvdaStridedBatched_bufferSize,
f64
);
/// Generates the C ABI entry points that wrap cuSOLVER `syevd` (symmetric
/// eigendecomposition) for one real element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `syevd` routine, the `syevd_bufferSize`
/// routine, and the element type.
macro_rules! eigh_real_pair {
    ($name:ident, $ws_name:ident, $syevd:ident, $syevd_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `syevd_bufferSize` routine for an `n` x `n` matrix, querying with
        /// `CUSOLVER_EIG_MODE_VECTOR` and a leading dimension of `n`.
        ///
        /// Returns `INVALID` for `n <= 0`, a null `out_bytes`, or an `uplo`
        /// other than `CUBLAS_FILL_MODE_LOWER` / `CUBLAS_FILL_MODE_UPPER`;
        /// `INTERNAL` if handle setup or the query fails; `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(uplo: i32, n: i32, out_bytes: *mut usize) -> i32 {
            let uplo_ok = uplo == CUBLAS_FILL_MODE_LOWER || uplo == CUBLAS_FILL_MODE_UPPER;
            if n <= 0 || out_bytes.is_null() || !uplo_ok {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status = unsafe {
                $syevd_bs(
                    solver.h,
                    CUSOLVER_EIG_MODE_VECTOR,
                    uplo,
                    n,
                    ptr::null(),
                    n,
                    ptr::null(),
                    &mut lwork_elems,
                )
            };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Runs `syevd` in place on `a_inout` with `CUSOLVER_EIG_MODE_VECTOR`,
        /// writing eigenvalues to `eigenvalues_out` and the routine's status
        /// word to `info_out`.
        ///
        /// The required workspace is re-queried with the caller's `lda` and
        /// compared against the supplied buffer first.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n`, a null `a_inout`,
        /// `eigenvalues_out` or `info_out`, or an unsupported `uplo`;
        /// `WS_TOO_SMALL` if the workspace is null or too small; `INTERNAL` if
        /// handle setup, the size query or `syevd` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            uplo: i32,
            n: i32,
            lda: i32,
            a_inout: *mut c_void,
            eigenvalues_out: *mut c_void,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = n <= 0 || lda < n;
            let bad_ptrs = a_inout.is_null() || eigenvalues_out.is_null() || info_out.is_null();
            let uplo_ok = uplo == CUBLAS_FILL_MODE_LOWER || uplo == CUBLAS_FILL_MODE_UPPER;
            if bad_dims || bad_ptrs || !uplo_ok {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status = unsafe {
                $syevd_bs(
                    solver.h,
                    CUSOLVER_EIG_MODE_VECTOR,
                    uplo,
                    n,
                    ptr::null(),
                    lda,
                    ptr::null(),
                    &mut lwork_elems,
                )
            };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let status = unsafe {
                $syevd(
                    solver.h,
                    CUSOLVER_EIG_MODE_VECTOR,
                    uplo,
                    n,
                    a_inout as *mut $T,
                    lda,
                    eigenvalues_out as *mut $T,
                    workspace as *mut $T,
                    lwork_elems,
                    info_out,
                )
            };
            map_cusolver(status)
        }
    };
}
// Symmetric eigendecomposition entry points for `f32` (cuSOLVER `S` routines).
eigh_real_pair!(
baracuda_kernels_eigh_f32_run,
baracuda_kernels_eigh_f32_workspace_size,
cusolverDnSsyevd,
cusolverDnSsyevd_bufferSize,
f32
);
// Symmetric eigendecomposition entry points for `f64` (cuSOLVER `D` routines).
eigh_real_pair!(
baracuda_kernels_eigh_f64_run,
baracuda_kernels_eigh_f64_workspace_size,
cusolverDnDsyevd,
cusolverDnDsyevd_bufferSize,
f64
);
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_eigh_c32_run(
uplo: i32,
n: i32,
lda: i32,
a_inout: *mut c_void,
eigenvalues_out: *mut c_void,
info_out: *mut i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32 {
if n <= 0 || lda < n
|| a_inout.is_null() || eigenvalues_out.is_null() || info_out.is_null()
{
return INVALID;
}
if !matches!(uplo, CUBLAS_FILL_MODE_LOWER | CUBLAS_FILL_MODE_UPPER) {
return INVALID;
}
let mut h = Handle::new();
let s = unsafe { setup_handle(&mut h, stream) };
if s != OK {
return s;
}
let mut lwork: i32 = 0;
let st = unsafe {
cusolverDnCheevd_bufferSize(
h.h, CUSOLVER_EIG_MODE_VECTOR, uplo, n, ptr::null(), lda, ptr::null(), &mut lwork,
)
};
if st != 0 {
return INTERNAL;
}
let needed = (lwork as usize) * core::mem::size_of::<cuComplex>();
let s = check_ws(workspace, workspace_bytes, needed);
if s != OK {
return s;
}
let st = unsafe {
cusolverDnCheevd(
h.h,
CUSOLVER_EIG_MODE_VECTOR,
uplo,
n,
a_inout as *mut cuComplex,
lda,
eigenvalues_out as *mut f32,
workspace as *mut cuComplex,
lwork,
info_out,
)
};
map_cusolver(st)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_eigh_c32_workspace_size(
uplo: i32,
n: i32,
out_bytes: *mut usize,
) -> i32 {
if n <= 0 || out_bytes.is_null() {
return INVALID;
}
if !matches!(uplo, CUBLAS_FILL_MODE_LOWER | CUBLAS_FILL_MODE_UPPER) {
return INVALID;
}
let mut h = Handle::new();
let s = unsafe { setup_handle(&mut h, ptr::null_mut()) };
if s != OK {
return s;
}
let mut lwork: i32 = 0;
let st = unsafe {
cusolverDnCheevd_bufferSize(
h.h, CUSOLVER_EIG_MODE_VECTOR, uplo, n, ptr::null(), n, ptr::null(), &mut lwork,
)
};
if st != 0 {
return INTERNAL;
}
unsafe { *out_bytes = (lwork as usize) * core::mem::size_of::<cuComplex>() };
OK
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_eigh_c64_run(
uplo: i32,
n: i32,
lda: i32,
a_inout: *mut c_void,
eigenvalues_out: *mut c_void,
info_out: *mut i32,
workspace: *mut c_void,
workspace_bytes: usize,
stream: *mut c_void,
) -> i32 {
if n <= 0 || lda < n
|| a_inout.is_null() || eigenvalues_out.is_null() || info_out.is_null()
{
return INVALID;
}
if !matches!(uplo, CUBLAS_FILL_MODE_LOWER | CUBLAS_FILL_MODE_UPPER) {
return INVALID;
}
let mut h = Handle::new();
let s = unsafe { setup_handle(&mut h, stream) };
if s != OK {
return s;
}
let mut lwork: i32 = 0;
let st = unsafe {
cusolverDnZheevd_bufferSize(
h.h, CUSOLVER_EIG_MODE_VECTOR, uplo, n, ptr::null(), lda, ptr::null(), &mut lwork,
)
};
if st != 0 {
return INTERNAL;
}
let needed = (lwork as usize) * core::mem::size_of::<cuDoubleComplex>();
let s = check_ws(workspace, workspace_bytes, needed);
if s != OK {
return s;
}
let st = unsafe {
cusolverDnZheevd(
h.h,
CUSOLVER_EIG_MODE_VECTOR,
uplo,
n,
a_inout as *mut cuDoubleComplex,
lda,
eigenvalues_out as *mut f64,
workspace as *mut cuDoubleComplex,
lwork,
info_out,
)
};
map_cusolver(st)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_eigh_c64_workspace_size(
uplo: i32,
n: i32,
out_bytes: *mut usize,
) -> i32 {
if n <= 0 || out_bytes.is_null() {
return INVALID;
}
if !matches!(uplo, CUBLAS_FILL_MODE_LOWER | CUBLAS_FILL_MODE_UPPER) {
return INVALID;
}
let mut h = Handle::new();
let s = unsafe { setup_handle(&mut h, ptr::null_mut()) };
if s != OK {
return s;
}
let mut lwork: i32 = 0;
let st = unsafe {
cusolverDnZheevd_bufferSize(
h.h, CUSOLVER_EIG_MODE_VECTOR, uplo, n, ptr::null(), n, ptr::null(), &mut lwork,
)
};
if st != 0 {
return INTERNAL;
}
unsafe { *out_bytes = (lwork as usize) * core::mem::size_of::<cuDoubleComplex>() };
OK
}
/// Writes the device and host workspace sizes, in bytes, reported by
/// `cusolverDnXgeev_bufferSize` for an `n` x `n` problem.
///
/// The same `dtype_tag` is used for every data-type argument of the query,
/// and every leading dimension is passed as `n`.
///
/// Returns `INVALID` for `n <= 0`, a null output pointer, a `jobvl` / `jobvr`
/// other than `CUSOLVER_EIG_MODE_VECTOR` / `CUSOLVER_EIG_MODE_NOVECTOR`, or a
/// `dtype_tag` other than `CUDA_R_32F`, `CUDA_R_64F`, `CUDA_C_32F` or
/// `CUDA_C_64F`; `INTERNAL` if handle setup, params creation or the query
/// fails; `OK` otherwise.
///
/// # Safety
/// `out_device_bytes` and `out_host_bytes` must each be valid for writing one
/// `usize`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_eig_workspace_size(
    dtype_tag: cudaDataType,
    jobvl: i32,
    jobvr: i32,
    n: i64,
    out_device_bytes: *mut usize,
    out_host_bytes: *mut usize,
) -> i32 {
    let outputs_missing = out_device_bytes.is_null() || out_host_bytes.is_null();
    let jobvl_ok = matches!(jobvl, CUSOLVER_EIG_MODE_VECTOR | CUSOLVER_EIG_MODE_NOVECTOR);
    let jobvr_ok = matches!(jobvr, CUSOLVER_EIG_MODE_VECTOR | CUSOLVER_EIG_MODE_NOVECTOR);
    let dtype_ok = matches!(dtype_tag, CUDA_R_32F | CUDA_R_64F | CUDA_C_32F | CUDA_C_64F);
    if n <= 0 || outputs_missing || !jobvl_ok || !jobvr_ok || !dtype_ok {
        return INVALID;
    }
    let mut solver = Handle::new();
    let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
    if rc != OK {
        return rc;
    }
    let mut params = Params::new();
    let status = unsafe { cusolverDnCreateParams(&mut params.p as *mut _) };
    if status != 0 {
        return INTERNAL;
    }
    let mut device_bytes: usize = 0;
    let mut host_bytes: usize = 0;
    let status = unsafe {
        cusolverDnXgeev_bufferSize(
            solver.h,
            params.p,
            jobvl,
            jobvr,
            n,
            dtype_tag,
            ptr::null(),
            n,
            dtype_tag,
            ptr::null(),
            dtype_tag,
            ptr::null(),
            n,
            dtype_tag,
            ptr::null(),
            n,
            dtype_tag,
            &mut device_bytes,
            &mut host_bytes,
        )
    };
    if status != 0 {
        return INTERNAL;
    }
    unsafe {
        out_device_bytes.write(device_bytes);
        out_host_bytes.write(host_bytes);
    }
    OK
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn baracuda_kernels_eig_run(
dtype_tag: cudaDataType,
jobvl: i32,
jobvr: i32,
n: i64,
lda: i64,
ldvl: i64,
ldvr: i64,
a_inout: *mut c_void,
w_out: *mut c_void,
vl_out: *mut c_void,
vr_out: *mut c_void,
info_out: *mut i32,
workspace_device: *mut c_void,
workspace_bytes_device: usize,
workspace_host: *mut c_void,
workspace_bytes_host: usize,
stream: *mut c_void,
) -> i32 {
if n <= 0 || lda < n
|| a_inout.is_null() || w_out.is_null() || info_out.is_null()
{
return INVALID;
}
if !matches!(jobvl, CUSOLVER_EIG_MODE_VECTOR | CUSOLVER_EIG_MODE_NOVECTOR) {
return INVALID;
}
if !matches!(jobvr, CUSOLVER_EIG_MODE_VECTOR | CUSOLVER_EIG_MODE_NOVECTOR) {
return INVALID;
}
if !matches!(dtype_tag, CUDA_R_32F | CUDA_R_64F | CUDA_C_32F | CUDA_C_64F) {
return INVALID;
}
if jobvl == CUSOLVER_EIG_MODE_VECTOR && (vl_out.is_null() || ldvl < n) {
return INVALID;
}
if jobvr == CUSOLVER_EIG_MODE_VECTOR && (vr_out.is_null() || ldvr < n) {
return INVALID;
}
let mut h = Handle::new();
let s = unsafe { setup_handle(&mut h, stream) };
if s != OK {
return s;
}
let mut p = Params::new();
let st = unsafe { cusolverDnCreateParams(&mut p.p as *mut _) };
if st != 0 {
return INTERNAL;
}
let mut ws_dev: usize = 0;
let mut ws_host: usize = 0;
let st = unsafe {
cusolverDnXgeev_bufferSize(
h.h,
p.p,
jobvl,
jobvr,
n,
dtype_tag,
ptr::null(),
lda,
dtype_tag,
ptr::null(),
dtype_tag,
ptr::null(),
if ldvl > 0 { ldvl } else { n },
dtype_tag,
ptr::null(),
if ldvr > 0 { ldvr } else { n },
dtype_tag,
&mut ws_dev,
&mut ws_host,
)
};
if st != 0 {
return INTERNAL;
}
if check_ws(workspace_device, workspace_bytes_device, ws_dev) != OK {
return WS_TOO_SMALL;
}
if ws_host > 0 && (workspace_host.is_null() || workspace_bytes_host < ws_host) {
return WS_TOO_SMALL;
}
let st = unsafe {
cusolverDnXgeev(
h.h,
p.p,
jobvl,
jobvr,
n,
dtype_tag,
a_inout,
lda,
dtype_tag,
w_out,
dtype_tag,
vl_out,
if ldvl > 0 { ldvl } else { n },
dtype_tag,
vr_out,
if ldvr > 0 { ldvr } else { n },
dtype_tag,
workspace_device,
workspace_bytes_device,
workspace_host,
workspace_bytes_host,
info_out,
)
};
map_cusolver(st)
}
/// Generates the C ABI entry points that wrap cuSOLVER `gels` (least-squares
/// solve) for one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `gels` routine, the `gels_bufferSize`
/// routine, and the element type.
macro_rules! lstsq_pair {
    ($name:ident, $ws_name:ident, $gels:ident, $gels_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `gels_bufferSize` routine, querying with leading dimensions `m`
        /// (A and B) and `n` (X).
        ///
        /// Returns `INVALID` for non-positive `m`, `n` or `nrhs`, `m < n` or a
        /// null `out_bytes`; `INTERNAL` if handle setup or the query fails;
        /// `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(
            m: i32,
            n: i32,
            nrhs: i32,
            out_bytes: *mut usize,
        ) -> i32 {
            if m <= 0 || n <= 0 || nrhs <= 0 || m < n || out_bytes.is_null() {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, ptr::null_mut()) };
            if s != OK {
                return s;
            }
            let mut lwork_bytes: usize = 0;
            let st = unsafe {
                $gels_bs(
                    h.h, m, n, nrhs, ptr::null_mut(), m, ptr::null_mut(), m, ptr::null_mut(), n,
                    ptr::null_mut(), &mut lwork_bytes,
                )
            };
            if st != 0 {
                return INTERNAL;
            }
            unsafe { *out_bytes = lwork_bytes };
            OK
        }

        /// Runs `gels` with `a_inout` and `b_inout` as inputs, writing the
        /// solution to `x_out`, the iteration count to `niters_out` and the
        /// routine's status word to `info_out`.
        ///
        /// The required workspace is re-queried with the caller's leading
        /// dimensions and checked with `check_ws`, so a null `workspace` is
        /// rejected whenever a non-zero size is required. The queried byte
        /// count is what is passed on to `gels`.
        ///
        /// Returns `INVALID` for non-positive `m`, `n` or `nrhs`, `m < n`,
        /// `lda < m`, `ldb < m`, `ldx < n`, or any null pointer among
        /// `a_inout`, `b_inout`, `x_out`, `niters_out`, `info_out`;
        /// `WS_TOO_SMALL` if the workspace is null or too small; `INTERNAL` if
        /// handle setup, the size query or `gels` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routine's requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            m: i32,
            n: i32,
            nrhs: i32,
            lda: i32,
            ldb: i32,
            ldx: i32,
            a_inout: *mut c_void,
            b_inout: *mut c_void,
            x_out: *mut c_void,
            niters_out: *mut i32,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            if m <= 0 || n <= 0 || nrhs <= 0 || m < n
                || lda < m || ldb < m || ldx < n
                || a_inout.is_null() || b_inout.is_null() || x_out.is_null()
                || niters_out.is_null() || info_out.is_null()
            {
                return INVALID;
            }
            let mut h = Handle::new();
            let s = unsafe { setup_handle(&mut h, stream) };
            if s != OK {
                return s;
            }
            let mut lwork_bytes: usize = 0;
            let st = unsafe {
                $gels_bs(
                    h.h, m, n, nrhs, ptr::null_mut(), lda, ptr::null_mut(), ldb,
                    ptr::null_mut(), ldx, ptr::null_mut(), &mut lwork_bytes,
                )
            };
            if st != 0 {
                return INTERNAL;
            }
            // Use the shared helper so that a null workspace is rejected too;
            // a size-only comparison would let a null pointer reach `gels`.
            let s = check_ws(workspace, workspace_bytes, lwork_bytes);
            if s != OK {
                return s;
            }
            let st = unsafe {
                $gels(
                    h.h,
                    m,
                    n,
                    nrhs,
                    a_inout as *mut $T,
                    lda,
                    b_inout as *mut $T,
                    ldb,
                    x_out as *mut $T,
                    ldx,
                    workspace,
                    lwork_bytes,
                    niters_out,
                    info_out,
                )
            };
            map_cusolver(st)
        }
    };
}
// Least-squares entry points for `f32`, backed by `cusolverDnSSgels`.
lstsq_pair!(
baracuda_kernels_lstsq_f32_run,
baracuda_kernels_lstsq_f32_workspace_size,
cusolverDnSSgels,
cusolverDnSSgels_bufferSize,
f32
);
// Least-squares entry points for `f64`, backed by `cusolverDnDDgels`.
lstsq_pair!(
baracuda_kernels_lstsq_f64_run,
baracuda_kernels_lstsq_f64_workspace_size,
cusolverDnDDgels,
cusolverDnDDgels_bufferSize,
f64
);
/// Generates the C ABI entry points that solve a square linear system by
/// calling cuSOLVER `getrf` followed by `getrs`, for one element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `getrf` routine, the `getrs` routine, the
/// `getrf_bufferSize` routine, and the element type.
macro_rules! solve_pair {
    ($name:ident, $ws_name:ident, $getrf:ident, $getrs:ident, $getrf_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `getrf_bufferSize` routine for an `n` x `n` matrix with leading
        /// dimension `lda`.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n` or a null `out_bytes`;
        /// `INTERNAL` if handle setup or the query fails; `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(n: i32, lda: i32, out_bytes: *mut usize) -> i32 {
            let dims_ok = n > 0 && lda >= n;
            if !dims_ok || out_bytes.is_null() {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $getrf_bs(solver.h, n, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Factorises `a_inout` in place with `getrf`, then calls `getrs`
        /// (with `CUBLAS_OP_N`) to overwrite the `nrhs` right-hand sides in
        /// `b_inout`.
        ///
        /// NOTE(review): both calls receive the same `info_out`, so whatever
        /// `getrf` stored there may be replaced by `getrs`; confirm callers
        /// do not rely on the factorisation status.
        ///
        /// Returns `INVALID` for non-positive `n` or `nrhs`, `lda < n`,
        /// `ldb < n`, or a null `a_inout`, `pivots_out`, `b_inout` or
        /// `info_out`; `WS_TOO_SMALL` if the workspace is null or too small;
        /// `INTERNAL` if handle setup, the size query, `getrf` or `getrs`
        /// fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routines' requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            n: i32,
            nrhs: i32,
            lda: i32,
            ldb: i32,
            a_inout: *mut c_void,
            pivots_out: *mut i32,
            b_inout: *mut c_void,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = n <= 0 || nrhs <= 0 || lda < n || ldb < n;
            let bad_ptrs = a_inout.is_null()
                || pivots_out.is_null()
                || b_inout.is_null()
                || info_out.is_null();
            if bad_dims || bad_ptrs {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $getrf_bs(solver.h, n, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let factor_status = unsafe {
                $getrf(
                    solver.h,
                    n,
                    n,
                    a_inout as *mut $T,
                    lda,
                    workspace as *mut $T,
                    pivots_out,
                    info_out,
                )
            };
            if factor_status != 0 {
                return INTERNAL;
            }
            let solve_status = unsafe {
                $getrs(
                    solver.h,
                    CUBLAS_OP_N,
                    n,
                    nrhs,
                    a_inout as *const $T,
                    lda,
                    pivots_out as *const i32,
                    b_inout as *mut $T,
                    ldb,
                    info_out,
                )
            };
            map_cusolver(solve_status)
        }
    };
}
// Linear-solve entry points for `f32`, backed by the cuSOLVER `S` routines.
solve_pair!(
baracuda_kernels_solve_f32_run,
baracuda_kernels_solve_f32_workspace_size,
cusolverDnSgetrf,
cusolverDnSgetrs,
cusolverDnSgetrf_bufferSize,
f32
);
// Linear-solve entry points for `f64`, backed by the cuSOLVER `D` routines.
solve_pair!(
baracuda_kernels_solve_f64_run,
baracuda_kernels_solve_f64_workspace_size,
cusolverDnDgetrf,
cusolverDnDgetrs,
cusolverDnDgetrf_bufferSize,
f64
);
/// Generates the C ABI entry points that compute a matrix inverse by calling
/// cuSOLVER `getrf` followed by `getrs` with `n` right-hand sides, for one
/// element type.
///
/// Macro arguments, in order: name of the run function, name of the
/// workspace-size function, the `getrf` routine, the `getrs` routine, the
/// `getrf_bufferSize` routine, and the element type.
macro_rules! inverse_pair {
    ($name:ident, $ws_name:ident, $getrf:ident, $getrs:ident, $getrf_bs:ident, $T:ty) => {
        /// Writes to `out_bytes` the workspace size, in bytes, reported by the
        /// `getrf_bufferSize` routine for an `n` x `n` matrix with leading
        /// dimension `lda`.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n` or a null `out_bytes`;
        /// `INTERNAL` if handle setup or the query fails; `OK` otherwise.
        ///
        /// # Safety
        /// `out_bytes` must be valid for writing one `usize`.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws_name(n: i32, lda: i32, out_bytes: *mut usize) -> i32 {
            let dims_ok = n > 0 && lda >= n;
            if !dims_ok || out_bytes.is_null() {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, ptr::null_mut()) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $getrf_bs(solver.h, n, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let elem_size = core::mem::size_of::<$T>();
            unsafe { out_bytes.write((lwork_elems as usize) * elem_size) };
            OK
        }

        /// Factorises `a_inout` in place with `getrf`, then calls `getrs`
        /// (with `CUBLAS_OP_N`) using `inv_inout` as an `n`-column right-hand
        /// side that is overwritten with the solution.
        ///
        /// Presumably the caller pre-fills `inv_inout` with the identity so
        /// the solve yields the inverse; TODO confirm against callers.
        /// NOTE(review): both calls receive the same `info_out`, so whatever
        /// `getrf` stored there may be replaced by `getrs`.
        ///
        /// Returns `INVALID` for `n <= 0`, `lda < n`, `ldinv < n`, or a null
        /// `a_inout`, `pivots_out`, `inv_inout` or `info_out`; `WS_TOO_SMALL`
        /// if the workspace is null or too small; `INTERNAL` if handle setup,
        /// the size query, `getrf` or `getrs` fails; `OK` otherwise.
        ///
        /// # Safety
        /// All pointers and `stream` are forwarded to cuSOLVER unchanged and
        /// must satisfy the wrapped routines' requirements (presumably device
        /// memory, per cuSOLVER conventions; confirm against callers).
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $name(
            n: i32,
            lda: i32,
            ldinv: i32,
            a_inout: *mut c_void,
            pivots_out: *mut i32,
            inv_inout: *mut c_void,
            info_out: *mut i32,
            workspace: *mut c_void,
            workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_dims = n <= 0 || lda < n || ldinv < n;
            let bad_ptrs = a_inout.is_null()
                || pivots_out.is_null()
                || inv_inout.is_null()
                || info_out.is_null();
            if bad_dims || bad_ptrs {
                return INVALID;
            }
            let mut solver = Handle::new();
            let rc = unsafe { setup_handle(&mut solver, stream) };
            if rc != OK {
                return rc;
            }
            let mut lwork_elems: i32 = 0;
            let status =
                unsafe { $getrf_bs(solver.h, n, n, ptr::null_mut(), lda, &mut lwork_elems) };
            if status != 0 {
                return INTERNAL;
            }
            let required_bytes = (lwork_elems as usize) * core::mem::size_of::<$T>();
            let rc = check_ws(workspace, workspace_bytes, required_bytes);
            if rc != OK {
                return rc;
            }
            let factor_status = unsafe {
                $getrf(
                    solver.h,
                    n,
                    n,
                    a_inout as *mut $T,
                    lda,
                    workspace as *mut $T,
                    pivots_out,
                    info_out,
                )
            };
            if factor_status != 0 {
                return INTERNAL;
            }
            let solve_status = unsafe {
                $getrs(
                    solver.h,
                    CUBLAS_OP_N,
                    n,
                    n,
                    a_inout as *const $T,
                    lda,
                    pivots_out as *const i32,
                    inv_inout as *mut $T,
                    ldinv,
                    info_out,
                )
            };
            map_cusolver(solve_status)
        }
    };
}
// Matrix-inverse entry points for `f32`, backed by the cuSOLVER `S` routines.
inverse_pair!(
baracuda_kernels_inverse_f32_run,
baracuda_kernels_inverse_f32_workspace_size,
cusolverDnSgetrf,
cusolverDnSgetrs,
cusolverDnSgetrf_bufferSize,
f32
);
// Matrix-inverse entry points for `f64`, backed by the cuSOLVER `D` routines.
inverse_pair!(
baracuda_kernels_inverse_f64_run,
baracuda_kernels_inverse_f64_workspace_size,
cusolverDnDgetrf,
cusolverDnDgetrs,
cusolverDnDgetrf_bufferSize,
f64
);