#[allow(unused_imports)]
use crate::{eigen::xsyevd, error::Status};
use std::ptr;
use singe_cuda::{
data_type::{DataType, DataTypeLike},
memory::DeviceMemory,
types::{Complex32, Complex64},
};
use crate::{
context::Context,
error::{Error, Result},
layout::{
ByteWorkspaceMut, MatrixMut, MatrixRef, StridedBatchedMatrixMut, StridedBatchedMatrixRef,
StridedBatchedVectorMut, StridedBatchedVectorRef, WorkspaceSizes,
},
params::Params,
sys, try_ffi,
types::{EigenMode, SvdMode, TruncatedSvdMode},
utility::{to_i32, to_i64, to_usize},
};
/// Owning wrapper around a cuSOLVER `gesvdjInfo_t` configuration object used by the
/// `*gesvdj*` routines (tolerance, sweep limit, sorting flag, and post-run residual/sweep
/// queries).
///
/// Created by [`GesvdjInfo::create`]; the handle is destroyed when the wrapper is dropped.
#[derive(Debug)]
pub struct GesvdjInfo {
    // Raw cuSOLVER handle; `create` only constructs `Self` after checking it is non-null.
    handle: sys::gesvdjInfo_t,
}
// SAFETY: NOTE(review) — the handle is an opaque pointer owned exclusively by this wrapper,
// so moving it to another thread is assumed sound; confirm cuSOLVER imposes no thread
// affinity on `gesvdjInfo_t`.
unsafe impl Send for GesvdjInfo {}
// SAFETY: NOTE(review) — every setter takes `&mut self`; `&self` methods only query the
// handle. Assumes concurrent read-only cuSOLVER queries on one info object are safe — TODO confirm.
unsafe impl Sync for GesvdjInfo {}
impl GesvdjInfo {
pub fn create() -> Result<Self> {
let mut handle = ptr::null_mut();
unsafe {
try_ffi!(sys::cusolverDnCreateGesvdjInfo(&raw mut handle))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Self { handle })
}
pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnXgesvdjSetTolerance(self.as_raw(), tolerance,))?;
}
Ok(())
}
pub fn set_max_sweeps(&mut self, max_sweeps: i32) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnXgesvdjSetMaxSweeps(
self.as_raw(),
max_sweeps,
))?;
}
Ok(())
}
pub fn set_sort_eigenvalues(&mut self, sort_eigenvalues: bool) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnXgesvdjSetSortEig(
self.as_raw(),
i32::from(sort_eigenvalues),
))?;
}
Ok(())
}
pub fn residual(&self, ctx: &Context) -> Result<f64> {
ctx.bind()?;
let mut residual = 0.0;
unsafe {
try_ffi!(sys::cusolverDnXgesvdjGetResidual(
ctx.as_raw(),
self.as_raw(),
&raw mut residual,
))?;
}
Ok(residual)
}
pub fn executed_sweeps(&self, ctx: &Context) -> Result<i32> {
ctx.bind()?;
let mut sweeps = 0;
unsafe {
try_ffi!(sys::cusolverDnXgesvdjGetSweeps(
ctx.as_raw(),
self.as_raw(),
&raw mut sweeps,
))?;
}
Ok(sweeps)
}
pub fn as_raw(&self) -> sys::gesvdjInfo_t {
self.handle
}
}
impl Drop for GesvdjInfo {
    /// Destroys the cuSOLVER handle.
    ///
    /// Failures cannot be propagated out of `drop`. They are reported on stderr in debug
    /// builds and deliberately ignored in release builds.
    fn drop(&mut self) {
        // SAFETY: `self.handle` came from `cusolverDnCreateGesvdjInfo` and was checked non-null
        // in `create`; `drop` runs at most once, so the handle is destroyed exactly once.
        let status = unsafe { try_ffi!(sys::cusolverDnDestroyGesvdjInfo(self.handle)) };
        if let Err(err) = status {
            #[cfg(debug_assertions)]
            eprintln!("failed to destroy cusolver gesvdj info: {err}");
            // Without the debug print the binding would be unused in release builds.
            #[cfg(not(debug_assertions))]
            let _ = err;
        }
    }
}
/// Problem description for the 64-bit generic `cusolverDnXgesvd` SVD of a `rows x columns`
/// matrix. Query buffers with [`Gesvd::workspace_size`], then run [`Gesvd::execute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Gesvd {
    /// Job mode for the left singular vectors `U` (cuSOLVER `jobu`).
    pub job_u: SvdMode,
    /// Job mode for the transposed right singular vectors `VT` (cuSOLVER `jobvt`).
    pub job_vt: SvdMode,
    /// Number of rows `m` of `A`.
    pub rows: usize,
    /// Number of columns `n` of `A`.
    pub columns: usize,
}
impl Gesvd {
pub fn new(job_u: SvdMode, job_vt: SvdMode, rows: usize, columns: usize) -> Self {
Self {
job_u,
job_vt,
rows,
columns,
}
}
pub fn workspace_size<
TA: DataTypeLike,
TS: DataTypeLike,
TU: DataTypeLike,
TVT: DataTypeLike,
>(
self,
ctx: &Context,
params: &Params,
input: GesvdInput<'_, TA, TS, TU, TVT>,
) -> Result<WorkspaceSizes> {
xgesvd_buffer_size(
ctx,
params,
self.job_u,
self.job_vt,
self.rows,
self.columns,
input.a,
input.singular_values,
input.left_vectors,
input.right_vectors_transposed,
)
}
pub fn execute<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TVT: DataTypeLike>(
self,
ctx: &Context,
params: &Params,
bindings: GesvdBindings<'_, TA, TS, TU, TVT>,
) -> Result<()> {
xgesvd(
ctx,
params,
self.job_u,
self.job_vt,
self.rows,
self.columns,
bindings.a,
bindings.singular_values,
bindings.left_vectors,
bindings.right_vectors_transposed,
bindings.workspace,
bindings.dev_info,
)
}
}
/// Read-only operand views used by [`Gesvd::workspace_size`] to query buffer requirements.
#[derive(Debug, Clone, Copy)]
pub struct GesvdInput<'a, TA, TS, TU, TVT> {
    /// The input matrix `A`.
    pub a: MatrixRef<'a, TA>,
    /// Device buffer for the singular values `S`.
    pub singular_values: &'a DeviceMemory<TS>,
    /// Optional `U` output view; may be `None` when `job_u` does not produce it.
    pub left_vectors: Option<MatrixRef<'a, TU>>,
    /// Optional `VT` output view; may be `None` when `job_vt` does not produce it.
    pub right_vectors_transposed: Option<MatrixRef<'a, TVT>>,
}
/// Mutable buffers consumed by [`Gesvd::execute`].
#[derive(Debug)]
pub struct GesvdBindings<'a, TA, TS, TU, TVT> {
    /// Input matrix `A`, passed as a writable buffer to the solver.
    pub a: MatrixMut<'a, TA>,
    /// Output singular values `S`.
    pub singular_values: &'a mut DeviceMemory<TS>,
    /// Optional output `U`.
    pub left_vectors: Option<MatrixMut<'a, TU>>,
    /// Optional output `VT`.
    pub right_vectors_transposed: Option<MatrixMut<'a, TVT>>,
    /// Device and host byte workspaces, checked against the sizes re-queried in `xgesvd`.
    pub workspace: ByteWorkspaceMut<'a>,
    /// Device-side `info` status output.
    pub dev_info: &'a mut DeviceMemory<i32>,
}
/// Problem description for the Jacobi-based `cusolverDn?gesvdj` SVD of a `rows x columns`
/// matrix, with typed entry points for `f32`, `f64`, `Complex32` and `Complex64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Gesvdj {
    /// Whether singular vectors are computed (forwarded as cuSOLVER `jobz`).
    pub mode: EigenMode,
    /// Request the economy-size `U`/`V` (forwarded as cuSOLVER `econ`, `1`/`0`).
    pub economy: bool,
    /// Number of rows `m` of `A`.
    pub rows: usize,
    /// Number of columns `n` of `A`.
    pub columns: usize,
}
impl Gesvdj {
pub fn new(mode: EigenMode, economy: bool, rows: usize, columns: usize) -> Self {
Self {
mode,
economy,
rows,
columns,
}
}
pub fn workspace_size_f32(
self,
ctx: &Context,
input: GesvdjInput<'_, f32, f32>,
params: &GesvdjInfo,
) -> Result<usize> {
sgesvdj_buffer_size(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
input.a,
input.singular_values,
input.left_vectors,
input.right_vectors,
params,
)
}
pub fn execute_f32(
self,
ctx: &Context,
bindings: GesvdjBindings<'_, f32, f32>,
params: &GesvdjInfo,
) -> Result<()> {
sgesvdj(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
bindings.a,
bindings.singular_values,
bindings.left_vectors,
bindings.right_vectors,
bindings.workspace,
bindings.dev_info,
params,
)
}
pub fn workspace_size_f64(
self,
ctx: &Context,
input: GesvdjInput<'_, f64, f64>,
params: &GesvdjInfo,
) -> Result<usize> {
dgesvdj_buffer_size(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
input.a,
input.singular_values,
input.left_vectors,
input.right_vectors,
params,
)
}
pub fn execute_f64(
self,
ctx: &Context,
bindings: GesvdjBindings<'_, f64, f64>,
params: &GesvdjInfo,
) -> Result<()> {
dgesvdj(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
bindings.a,
bindings.singular_values,
bindings.left_vectors,
bindings.right_vectors,
bindings.workspace,
bindings.dev_info,
params,
)
}
pub fn workspace_size_complex_f32(
self,
ctx: &Context,
input: GesvdjInput<'_, Complex32, f32>,
params: &GesvdjInfo,
) -> Result<usize> {
cgesvdj_buffer_size(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
input.a,
input.singular_values,
input.left_vectors,
input.right_vectors,
params,
)
}
pub fn execute_complex_f32(
self,
ctx: &Context,
bindings: GesvdjBindings<'_, Complex32, f32>,
params: &GesvdjInfo,
) -> Result<()> {
cgesvdj(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
bindings.a,
bindings.singular_values,
bindings.left_vectors,
bindings.right_vectors,
bindings.workspace,
bindings.dev_info,
params,
)
}
pub fn workspace_size_complex_f64(
self,
ctx: &Context,
input: GesvdjInput<'_, Complex64, f64>,
params: &GesvdjInfo,
) -> Result<usize> {
zgesvdj_buffer_size(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
input.a,
input.singular_values,
input.left_vectors,
input.right_vectors,
params,
)
}
pub fn execute_complex_f64(
self,
ctx: &Context,
bindings: GesvdjBindings<'_, Complex64, f64>,
params: &GesvdjInfo,
) -> Result<()> {
zgesvdj(
ctx,
self.mode,
self.economy,
self.rows,
self.columns,
bindings.a,
bindings.singular_values,
bindings.left_vectors,
bindings.right_vectors,
bindings.workspace,
bindings.dev_info,
params,
)
}
}
/// Read-only operand views used by the `Gesvdj::workspace_size_*` queries.
///
/// `U` and `V` share the element type of `A`; `S` uses the (real) type `TS`.
#[derive(Debug, Clone, Copy)]
pub struct GesvdjInput<'a, TA, TS> {
    /// The input matrix `A`.
    pub a: MatrixRef<'a, TA>,
    /// Device buffer for the singular values `S`.
    pub singular_values: &'a DeviceMemory<TS>,
    /// Optional left singular vectors `U`.
    pub left_vectors: Option<MatrixRef<'a, TA>>,
    /// Optional right singular vectors `V` (not transposed, unlike `GesvdInput`).
    pub right_vectors: Option<MatrixRef<'a, TA>>,
}
/// Mutable buffers consumed by the `Gesvdj::execute_*` entry points.
#[derive(Debug)]
pub struct GesvdjBindings<'a, TA, TS> {
    /// Input matrix `A`, passed as a writable buffer to the solver.
    pub a: MatrixMut<'a, TA>,
    /// Output singular values `S`.
    pub singular_values: &'a mut DeviceMemory<TS>,
    /// Optional output `U`.
    pub left_vectors: Option<MatrixMut<'a, TA>>,
    /// Optional output `V`.
    pub right_vectors: Option<MatrixMut<'a, TA>>,
    /// Typed device workspace; its `len()` is checked against the queried `lwork`.
    pub workspace: &'a mut DeviceMemory<TA>,
    /// Device-side `info` status output.
    pub dev_info: &'a mut DeviceMemory<i32>,
}
pub fn xgesvd_buffer_size<
TA: DataTypeLike,
TS: DataTypeLike,
TU: DataTypeLike,
TVT: DataTypeLike,
>(
ctx: &Context,
params: &Params,
job_u: SvdMode,
job_vt: SvdMode,
m: usize,
n: usize,
a: MatrixRef<'_, TA>,
s: &DeviceMemory<TS>,
u: Option<MatrixRef<'_, TU>>,
vt: Option<MatrixRef<'_, TVT>>,
) -> Result<WorkspaceSizes> {
let a_type = TA::data_type();
let s_type = TS::data_type();
let u_type = TU::data_type();
let vt_type = TVT::data_type();
ctx.bind()?;
validate_gesvd_dims(m, n)?;
validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
validate_x_vector(m.min(n), s.byte_len(), s_type)?;
validate_x_svd_output(m, m, matrix_ref_parts(u), job_u, u_type)?;
validate_x_svd_output(n, n, matrix_ref_parts(vt), job_vt, vt_type)?;
if matches!(job_u, SvdMode::Overwrite) && matches!(job_vt, SvdMode::Overwrite) {
return Err(Error::InvalidSvdMode);
}
let (u_ptr, ldu) = optional_x_matrix_ptr(matrix_ref_parts(u), m, m, job_u, u_type)?;
let (vt_ptr, ldvt) = optional_x_matrix_ptr(matrix_ref_parts(vt), n, n, job_vt, vt_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXgesvd_bufferSize(
ctx.as_raw(),
params.as_raw(),
job_u.as_raw(),
job_vt.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
s_type.into(),
s.as_ptr().cast(),
u_type.into(),
u_ptr.cast(),
ldu,
vt_type.into(),
vt_ptr.cast(),
ldvt,
a_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
/// Computes the SVD of the `m x n` matrix `a` with the 64-bit generic `cusolverDnXgesvd` API.
///
/// Validates shapes and `dev_info`, re-queries the workspace via [`xgesvd_buffer_size`],
/// checks both `workspace` halves against it, then calls the solver. Both jobs set to
/// [`SvdMode::Overwrite`] is rejected with [`Error::InvalidSvdMode`].
pub fn xgesvd<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TVT: DataTypeLike>(
    ctx: &Context,
    params: &Params,
    job_u: SvdMode,
    job_vt: SvdMode,
    m: usize,
    n: usize,
    a: MatrixMut<'_, TA>,
    s: &mut DeviceMemory<TS>,
    u: Option<MatrixMut<'_, TU>>,
    vt: Option<MatrixMut<'_, TVT>>,
    workspace: ByteWorkspaceMut<'_>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
    let (a_type, s_type, u_type, vt_type) =
        (TA::data_type(), TS::data_type(), TU::data_type(), TVT::data_type());
    ctx.bind()?;

    validate_gesvd_dims(m, n)?;
    validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
    validate_x_vector(m.min(n), s.byte_len(), s_type)?;
    validate_x_svd_output(m, m, matrix_mut_ref_parts(u.as_ref()), job_u, u_type)?;
    validate_x_svd_output(n, n, matrix_mut_ref_parts(vt.as_ref()), job_vt, vt_type)?;
    let both_overwrite =
        matches!(job_u, SvdMode::Overwrite) && matches!(job_vt, SvdMode::Overwrite);
    if both_overwrite {
        return Err(Error::InvalidSvdMode);
    }
    require_info_buffer(dev_info)?;

    let sizes = xgesvd_buffer_size(
        ctx,
        params,
        job_u,
        job_vt,
        m,
        n,
        a.as_ref(),
        s,
        matrix_mut_ref_option(u.as_ref()),
        matrix_mut_ref_option(vt.as_ref()),
    )?;
    require_workspace_bytes(workspace.device.byte_len(), sizes.device_bytes)?;
    require_host_workspace(workspace.host.len(), sizes.host_bytes)?;

    let (u_ptr, ldu) = optional_x_matrix_mut_ptr(matrix_mut_parts(u), m, m, job_u, u_type)?;
    let (vt_ptr, ldvt) = optional_x_matrix_mut_ptr(matrix_mut_parts(vt), n, n, job_vt, vt_type)?;
    let rows = to_i64(m, "m")?;
    let cols = to_i64(n, "n")?;
    let lda = to_i64(a.leading_dimension, "lda")?;

    // SAFETY: all buffers were validated above and both workspaces hold at least the
    // freshly queried number of bytes.
    unsafe {
        try_ffi!(sys::cusolverDnXgesvd(
            ctx.as_raw(),
            params.as_raw(),
            job_u.as_raw(),
            job_vt.as_raw(),
            rows,
            cols,
            a_type.into(),
            a.data.as_mut_ptr().cast(),
            lda,
            s_type.into(),
            s.as_mut_ptr().cast(),
            u_type.into(),
            u_ptr.cast(),
            ldu,
            vt_type.into(),
            vt_ptr.cast(),
            ldvt,
            a_type.into(),
            workspace.device.as_mut_ptr().cast(),
            sizes.device_bytes as _,
            workspace.host.as_mut_ptr().cast(),
            sizes.host_bytes as _,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    Ok(())
}
pub fn xgesvdp_buffer_size<
TA: DataTypeLike,
TS: DataTypeLike,
TU: DataTypeLike,
TV: DataTypeLike,
>(
ctx: &Context,
params: &Params,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixRef<'_, TA>,
s: &DeviceMemory<TS>,
u: Option<MatrixRef<'_, TU>>,
v: Option<MatrixRef<'_, TV>>,
) -> Result<WorkspaceSizes> {
let a_type = TA::data_type();
let s_type = TS::data_type();
let u_type = TU::data_type();
let v_type = TV::data_type();
ctx.bind()?;
validate_xgesvdp_inputs(
m,
n,
a.data.byte_len(),
a.leading_dimension,
a_type,
s.byte_len(),
s_type,
jobz,
econ,
matrix_ref_parts(u).as_ref(),
u_type,
matrix_ref_parts(v).as_ref(),
v_type,
)?;
let (u_ptr, ldu) = optional_x_eig_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ, u_type)?;
let (v_ptr, ldv) = optional_x_eig_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ, v_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXgesvdp_bufferSize(
ctx.as_raw(),
params.as_raw(),
jobz.into(),
i32::from(econ),
to_i64(m, "m")?,
to_i64(n, "n")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
s_type.into(),
s.as_ptr().cast(),
u_type.into(),
u_ptr.cast(),
ldu,
v_type.into(),
v_ptr.cast(),
ldv,
a_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
/// Computes the SVD of the `m x n` matrix `a` with the polar-decomposition-based
/// `cusolverDnXgesvdp` API.
///
/// Validates the operands and `dev_info`, re-queries the workspace via
/// [`xgesvdp_buffer_size`], checks both `workspace` halves against it, then calls the solver.
///
/// `err_sigma`: when `Some`, receives the host-side perturbation magnitude (`h_err_sigma`)
/// reported by cuSOLVER. The library always gets a valid host pointer, even when the caller
/// passes `None`; the output is not documented as optional, so it is never passed as NULL.
pub fn xgesvdp<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TV: DataTypeLike>(
    ctx: &Context,
    params: &Params,
    jobz: EigenMode,
    econ: bool,
    m: usize,
    n: usize,
    a: MatrixMut<'_, TA>,
    s: &mut DeviceMemory<TS>,
    u: Option<MatrixMut<'_, TU>>,
    v: Option<MatrixMut<'_, TV>>,
    workspace: ByteWorkspaceMut<'_>,
    dev_info: &mut DeviceMemory<i32>,
    err_sigma: Option<&mut f64>,
) -> Result<()> {
    let a_type = TA::data_type();
    let s_type = TS::data_type();
    let u_type = TU::data_type();
    let v_type = TV::data_type();
    ctx.bind()?;
    validate_xgesvdp_inputs(
        m,
        n,
        a.data.byte_len(),
        a.leading_dimension,
        a_type,
        s.byte_len(),
        s_type,
        jobz,
        econ,
        matrix_mut_ref_parts(u.as_ref()).as_ref(),
        u_type,
        matrix_mut_ref_parts(v.as_ref()).as_ref(),
        v_type,
    )?;
    require_info_buffer(dev_info)?;
    let workspace_sizes = xgesvdp_buffer_size(
        ctx,
        params,
        jobz,
        econ,
        m,
        n,
        a.as_ref(),
        s,
        matrix_mut_ref_option(u.as_ref()),
        matrix_mut_ref_option(v.as_ref()),
    )?;
    require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
    require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
    let (u_ptr, ldu) =
        optional_x_eig_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ, u_type)?;
    let (v_ptr, ldv) =
        optional_x_eig_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ, v_type)?;
    // Host-side landing slot for `h_err_sigma`. It is always valid, so no NULL reaches the
    // library; the value is forwarded only if the caller asked for it.
    let mut host_err_sigma = 0.0_f64;
    // SAFETY: all device buffers were validated above, both workspaces hold at least the
    // freshly queried number of bytes, and `host_err_sigma` outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnXgesvdp(
            ctx.as_raw(),
            params.as_raw(),
            jobz.into(),
            i32::from(econ),
            to_i64(m, "m")?,
            to_i64(n, "n")?,
            a_type.into(),
            a.data.as_mut_ptr().cast(),
            to_i64(a.leading_dimension, "lda")?,
            s_type.into(),
            s.as_mut_ptr().cast(),
            u_type.into(),
            u_ptr.cast(),
            ldu,
            v_type.into(),
            v_ptr.cast(),
            ldv,
            a_type.into(),
            workspace.device.as_mut_ptr().cast(),
            workspace_sizes.device_bytes as _,
            workspace.host.as_mut_ptr().cast(),
            workspace_sizes.host_bytes as _,
            dev_info.as_mut_ptr().cast(),
            &raw mut host_err_sigma,
        ))?;
    }
    if let Some(out) = err_sigma {
        *out = host_err_sigma;
    }
    Ok(())
}
pub fn xgesvdr_buffer_size<
TA: DataTypeLike,
TS: DataTypeLike,
TU: DataTypeLike,
TV: DataTypeLike,
>(
ctx: &Context,
params: &Params,
job_u: TruncatedSvdMode,
job_v: TruncatedSvdMode,
m: usize,
n: usize,
k: usize,
p: usize,
niters: usize,
a: MatrixRef<'_, TA>,
s: &DeviceMemory<TS>,
u: Option<MatrixRef<'_, TU>>,
v: Option<MatrixRef<'_, TV>>,
) -> Result<WorkspaceSizes> {
let a_type = TA::data_type();
let s_type = TS::data_type();
let u_type = TU::data_type();
let v_type = TV::data_type();
ctx.bind()?;
validate_xgesvdr_inputs(
m,
n,
k,
p,
niters,
a.data.byte_len(),
a.leading_dimension,
a_type,
s.byte_len(),
s_type,
job_u,
matrix_ref_parts(u).as_ref(),
u_type,
job_v,
matrix_ref_parts(v).as_ref(),
v_type,
)?;
let (u_ptr, ldu) = optional_x_truncated_u_ptr(matrix_ref_parts(u), m, k, job_u, u_type)?;
let (v_ptr, ldv) = optional_x_truncated_v_ptr(matrix_ref_parts(v), n, k, job_v, v_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXgesvdr_bufferSize(
ctx.as_raw(),
params.as_raw(),
job_u.as_raw(),
job_v.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
to_i64(k, "k")?,
to_i64(p, "p")?,
to_i64(niters, "niters")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
s_type.into(),
s.as_ptr().cast(),
u_type.into(),
u_ptr.cast(),
ldu,
v_type.into(),
v_ptr.cast(),
ldv,
a_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xgesvdr<TA: DataTypeLike, TS: DataTypeLike, TU: DataTypeLike, TV: DataTypeLike>(
ctx: &Context,
params: &Params,
job_u: TruncatedSvdMode,
job_v: TruncatedSvdMode,
m: usize,
n: usize,
k: usize,
p: usize,
niters: usize,
a: MatrixMut<'_, TA>,
s: &mut DeviceMemory<TS>,
u: Option<MatrixMut<'_, TU>>,
v: Option<MatrixMut<'_, TV>>,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
let a_type = TA::data_type();
let s_type = TS::data_type();
let u_type = TU::data_type();
let v_type = TV::data_type();
ctx.bind()?;
validate_xgesvdr_inputs(
m,
n,
k,
p,
niters,
a.data.byte_len(),
a.leading_dimension,
a_type,
s.byte_len(),
s_type,
job_u,
matrix_mut_ref_parts(u.as_ref()).as_ref(),
u_type,
job_v,
matrix_mut_ref_parts(v.as_ref()).as_ref(),
v_type,
)?;
require_info_buffer(dev_info)?;
let workspace_sizes = xgesvdr_buffer_size(
ctx,
params,
job_u,
job_v,
m,
n,
k,
p,
niters,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
let (u_ptr, ldu) = optional_x_truncated_u_mut_ptr(matrix_mut_parts(u), m, k, job_u, u_type)?;
let (v_ptr, ldv) = optional_x_truncated_v_mut_ptr(matrix_mut_parts(v), n, k, job_v, v_type)?;
unsafe {
try_ffi!(sys::cusolverDnXgesvdr(
ctx.as_raw(),
params.as_raw(),
job_u.as_raw(),
job_v.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
to_i64(k, "k")?,
to_i64(p, "p")?,
to_i64(niters, "niters")?,
a_type.into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
s_type.into(),
s.as_mut_ptr().cast(),
u_type.into(),
u_ptr.cast(),
ldu,
v_type.into(),
v_ptr.cast(),
ldv,
a_type.into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_gesvd_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgesvd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_gesvd_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgesvd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_gesvd_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgesvd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgesvd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_gesvd_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgesvd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgesvd(
ctx: &Context,
job_u: SvdMode,
job_vt: SvdMode,
m: usize,
n: usize,
a: MatrixMut<'_, f32>,
s: &mut DeviceMemory<f32>,
u: Option<MatrixMut<'_, f32>>,
vt: Option<MatrixMut<'_, f32>>,
workspace: &mut DeviceMemory<f32>,
rwork: Option<&mut DeviceMemory<f32>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_gesvd_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
job_u,
matrix_mut_ref_parts(u.as_ref()).as_ref(),
job_vt,
matrix_mut_ref_parts(vt.as_ref()).as_ref(),
)?;
require_info_buffer(dev_info)?;
require_rwork_buffer(rwork.as_deref(), m, n)?;
let lwork = sgesvd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
unsafe {
try_ffi!(sys::cusolverDnSgesvd(
ctx.as_raw(),
job_u.as_raw(),
job_vt.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
vt_ptr.cast(),
ldvt,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dgesvd(
ctx: &Context,
job_u: SvdMode,
job_vt: SvdMode,
m: usize,
n: usize,
a: MatrixMut<'_, f64>,
s: &mut DeviceMemory<f64>,
u: Option<MatrixMut<'_, f64>>,
vt: Option<MatrixMut<'_, f64>>,
workspace: &mut DeviceMemory<f64>,
rwork: Option<&mut DeviceMemory<f64>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_gesvd_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
job_u,
matrix_mut_ref_parts(u.as_ref()).as_ref(),
job_vt,
matrix_mut_ref_parts(vt.as_ref()).as_ref(),
)?;
require_info_buffer(dev_info)?;
require_rwork_buffer(rwork.as_deref(), m, n)?;
let lwork = dgesvd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
unsafe {
try_ffi!(sys::cusolverDnDgesvd(
ctx.as_raw(),
job_u.as_raw(),
job_vt.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
vt_ptr.cast(),
ldvt,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cgesvd(
ctx: &Context,
job_u: SvdMode,
job_vt: SvdMode,
m: usize,
n: usize,
a: MatrixMut<'_, Complex32>,
s: &mut DeviceMemory<f32>,
u: Option<MatrixMut<'_, Complex32>>,
vt: Option<MatrixMut<'_, Complex32>>,
workspace: &mut DeviceMemory<Complex32>,
rwork: Option<&mut DeviceMemory<f32>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_gesvd_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
job_u,
matrix_mut_ref_parts(u.as_ref()).as_ref(),
job_vt,
matrix_mut_ref_parts(vt.as_ref()).as_ref(),
)?;
require_info_buffer(dev_info)?;
require_rwork_buffer(rwork.as_deref(), m, n)?;
let lwork = cgesvd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
unsafe {
try_ffi!(sys::cusolverDnCgesvd(
ctx.as_raw(),
job_u.as_raw(),
job_vt.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
vt_ptr.cast(),
ldvt,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zgesvd(
ctx: &Context,
job_u: SvdMode,
job_vt: SvdMode,
m: usize,
n: usize,
a: MatrixMut<'_, Complex64>,
s: &mut DeviceMemory<f64>,
u: Option<MatrixMut<'_, Complex64>>,
vt: Option<MatrixMut<'_, Complex64>>,
workspace: &mut DeviceMemory<Complex64>,
rwork: Option<&mut DeviceMemory<f64>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_gesvd_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
job_u,
matrix_mut_ref_parts(u.as_ref()).as_ref(),
job_vt,
matrix_mut_ref_parts(vt.as_ref()).as_ref(),
)?;
require_info_buffer(dev_info)?;
require_rwork_buffer(rwork.as_deref(), m, n)?;
let lwork = zgesvd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_matrix_ptr(matrix_mut_parts(u), m, job_u)?;
let (vt_ptr, ldvt) = optional_matrix_ptr(matrix_mut_parts(vt), n, job_vt)?;
unsafe {
try_ffi!(sys::cusolverDnZgesvd(
ctx.as_raw(),
job_u.as_raw(),
job_vt.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
vt_ptr.cast(),
ldvt,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
rwork.map_or(ptr::null_mut(), |buffer| buffer.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sgesvdj_buffer_size(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixRef<'_, f32>,
s: &DeviceMemory<f32>,
u: Option<MatrixRef<'_, f32>>,
v: Option<MatrixRef<'_, f32>>,
params: &GesvdjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_ref_parts(u),
matrix_ref_parts(v),
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgesvdj_bufferSize(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgesvdj_buffer_size(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixRef<'_, f64>,
s: &DeviceMemory<f64>,
u: Option<MatrixRef<'_, f64>>,
v: Option<MatrixRef<'_, f64>>,
params: &GesvdjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_ref_parts(u),
matrix_ref_parts(v),
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgesvdj_bufferSize(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgesvdj_buffer_size(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixRef<'_, Complex32>,
s: &DeviceMemory<f32>,
u: Option<MatrixRef<'_, Complex32>>,
v: Option<MatrixRef<'_, Complex32>>,
params: &GesvdjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_ref_parts(u),
matrix_ref_parts(v),
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgesvdj_bufferSize(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgesvdj_buffer_size(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixRef<'_, Complex64>,
s: &DeviceMemory<f64>,
u: Option<MatrixRef<'_, Complex64>>,
v: Option<MatrixRef<'_, Complex64>>,
params: &GesvdjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_ref_parts(u),
matrix_ref_parts(v),
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, econ)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgesvdj_bufferSize(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgesvdj(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixMut<'_, f32>,
s: &mut DeviceMemory<f32>,
u: Option<MatrixMut<'_, f32>>,
v: Option<MatrixMut<'_, f32>>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
)?;
require_info_buffer(dev_info)?;
let lwork = sgesvdj_buffer_size(
ctx,
jobz,
econ,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
unsafe {
try_ffi!(sys::cusolverDnSgesvdj(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn dgesvdj(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixMut<'_, f64>,
s: &mut DeviceMemory<f64>,
u: Option<MatrixMut<'_, f64>>,
v: Option<MatrixMut<'_, f64>>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
)?;
require_info_buffer(dev_info)?;
let lwork = dgesvdj_buffer_size(
ctx,
jobz,
econ,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
unsafe {
try_ffi!(sys::cusolverDnDgesvdj(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn cgesvdj(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixMut<'_, Complex32>,
s: &mut DeviceMemory<f32>,
u: Option<MatrixMut<'_, Complex32>>,
v: Option<MatrixMut<'_, Complex32>>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
)?;
require_info_buffer(dev_info)?;
let lwork = cgesvdj_buffer_size(
ctx,
jobz,
econ,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
unsafe {
try_ffi!(sys::cusolverDnCgesvdj(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn zgesvdj(
ctx: &Context,
jobz: EigenMode,
econ: bool,
m: usize,
n: usize,
a: MatrixMut<'_, Complex64>,
s: &mut DeviceMemory<f64>,
u: Option<MatrixMut<'_, Complex64>>,
v: Option<MatrixMut<'_, Complex64>>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
econ,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
)?;
require_info_buffer(dev_info)?;
let lwork = zgesvdj_buffer_size(
ctx,
jobz,
econ,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, econ)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, econ)?;
unsafe {
try_ffi!(sys::cusolverDnZgesvdj(
ctx.as_raw(),
jobz.into(),
i32::from(econ),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn sgesvdj_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixRef<'_, f32>,
s: &DeviceMemory<f32>,
u: Option<MatrixRef<'_, f32>>,
v: Option<MatrixRef<'_, f32>>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_ref_parts(u),
matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgesvdjBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgesvdj_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixRef<'_, f64>,
s: &DeviceMemory<f64>,
u: Option<MatrixRef<'_, f64>>,
v: Option<MatrixRef<'_, f64>>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_ref_parts(u),
matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgesvdjBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgesvdj_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixRef<'_, Complex32>,
s: &DeviceMemory<f32>,
u: Option<MatrixRef<'_, Complex32>>,
v: Option<MatrixRef<'_, Complex32>>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_ref_parts(u),
matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgesvdjBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgesvdj_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixRef<'_, Complex64>,
s: &DeviceMemory<f64>,
u: Option<MatrixRef<'_, Complex64>>,
v: Option<MatrixRef<'_, Complex64>>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_ref_parts(u),
matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_ptr(matrix_ref_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_ptr(matrix_ref_parts(v), n, n, jobz, true)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgesvdjBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
&raw mut lwork,
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgesvdj_batched(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixMut<'_, f32>,
s: &mut DeviceMemory<f32>,
u: Option<MatrixMut<'_, f32>>,
v: Option<MatrixMut<'_, f32>>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = sgesvdj_batched_buffer_size(
ctx,
jobz,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
unsafe {
try_ffi!(sys::cusolverDnSgesvdjBatched(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn dgesvdj_batched(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixMut<'_, f64>,
s: &mut DeviceMemory<f64>,
u: Option<MatrixMut<'_, f64>>,
v: Option<MatrixMut<'_, f64>>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = dgesvdj_batched_buffer_size(
ctx,
jobz,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
unsafe {
try_ffi!(sys::cusolverDnDgesvdjBatched(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn cgesvdj_batched(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixMut<'_, Complex32>,
s: &mut DeviceMemory<f32>,
u: Option<MatrixMut<'_, Complex32>>,
v: Option<MatrixMut<'_, Complex32>>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = cgesvdj_batched_buffer_size(
ctx,
jobz,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
unsafe {
try_ffi!(sys::cusolverDnCgesvdjBatched(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn zgesvdj_batched(
ctx: &Context,
jobz: EigenMode,
m: usize,
n: usize,
a: MatrixMut<'_, Complex64>,
s: &mut DeviceMemory<f64>,
u: Option<MatrixMut<'_, Complex64>>,
v: Option<MatrixMut<'_, Complex64>>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
params: &GesvdjInfo,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvdj_batched_inputs(
m,
n,
a.data.len(),
a.leading_dimension,
s.len(),
jobz,
matrix_mut_ref_parts(u.as_ref()),
matrix_mut_ref_parts(v.as_ref()),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = zgesvdj_batched_buffer_size(
ctx,
jobz,
m,
n,
a.as_ref(),
s,
matrix_mut_ref_option(u.as_ref()),
matrix_mut_ref_option(v.as_ref()),
params,
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(u), m, n, jobz, true)?;
let (v_ptr, ldv) = optional_gesvdj_matrix_mut_ptr(matrix_mut_parts(v), n, n, jobz, true)?;
unsafe {
try_ffi!(sys::cusolverDnZgesvdjBatched(
ctx.as_raw(),
jobz.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
s.as_mut_ptr().cast(),
u_ptr.cast(),
ldu,
v_ptr.cast(),
ldv,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn sgesvda_strided_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, f32>,
s: StridedBatchedVectorRef<'_, f32>,
u: Option<StridedBatchedMatrixRef<'_, f32>>,
v: Option<StridedBatchedMatrixRef<'_, f32>>,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_ref_parts(u),
strided_batched_matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgesvdaStridedBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
&raw mut lwork,
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgesvda_strided_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, f64>,
s: StridedBatchedVectorRef<'_, f64>,
u: Option<StridedBatchedMatrixRef<'_, f64>>,
v: Option<StridedBatchedMatrixRef<'_, f64>>,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_ref_parts(u),
strided_batched_matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgesvdaStridedBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
&raw mut lwork,
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgesvda_strided_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, Complex32>,
s: StridedBatchedVectorRef<'_, f32>,
u: Option<StridedBatchedMatrixRef<'_, Complex32>>,
v: Option<StridedBatchedMatrixRef<'_, Complex32>>,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_ref_parts(u),
strided_batched_matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgesvdaStridedBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
&raw mut lwork,
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgesvda_strided_batched_buffer_size(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, Complex64>,
s: StridedBatchedVectorRef<'_, f64>,
u: Option<StridedBatchedMatrixRef<'_, Complex64>>,
v: Option<StridedBatchedMatrixRef<'_, Complex64>>,
batch_size: usize,
) -> Result<usize> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_ref_parts(u),
strided_batched_matrix_ref_parts(v),
batch_size,
)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_ptr(strided_batched_matrix_ref_parts(v), n, rank, jobz)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgesvdaStridedBatched_bufferSize(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
&raw mut lwork,
to_i32(batch_size, "batch_size")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgesvda_strided_batched(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, f32>,
s: StridedBatchedVectorMut<'_, f32>,
u: Option<StridedBatchedMatrixMut<'_, f32>>,
v: Option<StridedBatchedMatrixMut<'_, f32>>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
residual: Option<&mut f64>,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_mut_ref_option(u.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
strided_batched_matrix_mut_ref_option(v.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = sgesvda_strided_batched_buffer_size(
ctx,
jobz,
rank,
m,
n,
a,
s.as_ref(),
strided_batched_matrix_mut_ref_option(u.as_ref()),
strided_batched_matrix_mut_ref_option(v.as_ref()),
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
unsafe {
try_ffi!(sys::cusolverDnSgesvdaStridedBatched(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_mut_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
residual.map_or(ptr::null_mut(), |value| value as *mut f64),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn dgesvda_strided_batched(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, f64>,
s: StridedBatchedVectorMut<'_, f64>,
u: Option<StridedBatchedMatrixMut<'_, f64>>,
v: Option<StridedBatchedMatrixMut<'_, f64>>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
residual: Option<&mut f64>,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_mut_ref_option(u.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
strided_batched_matrix_mut_ref_option(v.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = dgesvda_strided_batched_buffer_size(
ctx,
jobz,
rank,
m,
n,
a,
s.as_ref(),
strided_batched_matrix_mut_ref_option(u.as_ref()),
strided_batched_matrix_mut_ref_option(v.as_ref()),
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
unsafe {
try_ffi!(sys::cusolverDnDgesvdaStridedBatched(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_mut_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
residual.map_or(ptr::null_mut(), |value| value as *mut f64),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn cgesvda_strided_batched(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, Complex32>,
s: StridedBatchedVectorMut<'_, f32>,
u: Option<StridedBatchedMatrixMut<'_, Complex32>>,
v: Option<StridedBatchedMatrixMut<'_, Complex32>>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
residual: Option<&mut f64>,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_mut_ref_option(u.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
strided_batched_matrix_mut_ref_option(v.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = cgesvda_strided_batched_buffer_size(
ctx,
jobz,
rank,
m,
n,
a,
s.as_ref(),
strided_batched_matrix_mut_ref_option(u.as_ref()),
strided_batched_matrix_mut_ref_option(v.as_ref()),
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
unsafe {
try_ffi!(sys::cusolverDnCgesvdaStridedBatched(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_mut_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
residual.map_or(ptr::null_mut(), |value| value as *mut f64),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
pub fn zgesvda_strided_batched(
ctx: &Context,
jobz: EigenMode,
rank: usize,
m: usize,
n: usize,
a: StridedBatchedMatrixRef<'_, Complex64>,
s: StridedBatchedVectorMut<'_, f64>,
u: Option<StridedBatchedMatrixMut<'_, Complex64>>,
v: Option<StridedBatchedMatrixMut<'_, Complex64>>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
residual: Option<&mut f64>,
batch_size: usize,
) -> Result<()> {
ctx.bind()?;
validate_gesvda_strided_batched_inputs(
rank,
m,
n,
a.data.len(),
a.leading_dimension,
a.stride,
s.data.len(),
s.stride,
jobz,
strided_batched_matrix_mut_ref_option(u.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
strided_batched_matrix_mut_ref_option(v.as_ref())
.map(|m| (m.data, m.leading_dimension, m.stride)),
batch_size,
)?;
require_info_buffer_len(dev_info, batch_size)?;
let lwork = zgesvda_strided_batched_buffer_size(
ctx,
jobz,
rank,
m,
n,
a,
s.as_ref(),
strided_batched_matrix_mut_ref_option(u.as_ref()),
strided_batched_matrix_mut_ref_option(v.as_ref()),
batch_size,
)?;
require_workspace(workspace.len(), lwork)?;
let (u_ptr, ldu, stride_u) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(u), m, rank, jobz)?;
let (v_ptr, ldv, stride_v) =
optional_gesvda_output_mut_ptr(strided_batched_matrix_mut_parts(v), n, rank, jobz)?;
unsafe {
try_ffi!(sys::cusolverDnZgesvdaStridedBatched(
ctx.as_raw(),
jobz.into(),
to_i32(rank, "rank")?,
to_i32(m, "m")?,
to_i32(n, "n")?,
a.data.as_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
to_i64(a.stride, "stride_a")?,
s.data.as_mut_ptr().cast(),
to_i64(s.stride, "stride_s")?,
u_ptr.cast(),
ldu,
stride_u,
v_ptr.cast(),
ldv,
stride_v,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
residual.map_or(ptr::null_mut(), |value| value as *mut f64),
to_i32(batch_size, "batch_size")?,
))?;
}
Ok(())
}
fn validate_gesvd_dims(m: usize, n: usize) -> Result<()> {
if m == 0 || n == 0 || m < n {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
/// Validates the arguments of the generic-API `Xgesvdp` routines.
///
/// Buffer sizes are given in bytes and are interpreted through the accompanying
/// [`DataType`]s. `s` needs `min(m, n)` entries. When `jobz` requests vectors, both `u` and
/// `v` must be present:
/// - `U` is checked as `m x m`, or as `m x min(m, n)` when `econ` is set.
/// - `V` is checked as `n x n`.
fn validate_xgesvdp_inputs<TU, TV>(
    m: usize,
    n: usize,
    a_bytes: usize,
    lda: usize,
    a_type: DataType,
    s_bytes: usize,
    s_type: DataType,
    jobz: EigenMode,
    econ: bool,
    u: Option<&(&DeviceMemory<TU>, usize)>,
    u_type: DataType,
    v: Option<&(&DeviceMemory<TV>, usize)>,
    v_type: DataType,
) -> Result<()> {
    if m == 0 || n == 0 {
        return Err(Error::InvalidMatrixShape);
    }
    validate_x_matrix(m, n, a_bytes, lda, a_type)?;
    validate_x_vector(m.min(n), s_bytes, s_type)?;
    match jobz {
        EigenMode::NoVector => Ok(()),
        EigenMode::Vector => {
            let Some((u, ldu)) = u else {
                return Err(Error::InvalidMatrixShape);
            };
            let Some((v, ldv)) = v else {
                return Err(Error::InvalidMatrixShape);
            };
            // A full `U` is square (`m x m`); only economy mode trims it to `min(m, n)`
            // columns. Checking it as `m x n` accepted undersized buffers whenever `m > n`.
            let u_cols = if econ { n } else { m };
            validate_x_eig_output(m, u_cols, u.byte_len(), *ldu, econ, u_type)?;
            validate_x_eig_output(n, n, v.byte_len(), *ldv, econ, v_type)
        }
    }
}
/// Splits an optional read-only matrix view into `(buffer, leading_dimension)`.
fn matrix_ref_parts<T>(matrix: Option<MatrixRef<'_, T>>) -> Option<(&DeviceMemory<T>, usize)> {
    let view = matrix?;
    Some((view.data, view.leading_dimension))
}
/// Consumes an optional mutable matrix view, yielding its buffer and leading dimension.
fn matrix_mut_parts<T>(matrix: Option<MatrixMut<'_, T>>) -> Option<(&mut DeviceMemory<T>, usize)> {
    let view = matrix?;
    Some((view.data, view.leading_dimension))
}
/// Borrows an optional mutable matrix view read-only, as `(buffer, leading_dimension)`.
fn matrix_mut_ref_parts<'a, T>(
    matrix: Option<&'a MatrixMut<'a, T>>,
) -> Option<(&'a DeviceMemory<T>, usize)> {
    let view = matrix?;
    // Reborrow the `&mut` buffer as shared through the outer shared reference.
    Some((&*view.data, view.leading_dimension))
}
fn matrix_mut_ref_option<'a, T>(matrix: Option<&'a MatrixMut<'a, T>>) -> Option<MatrixRef<'a, T>> {
matrix.map(MatrixMut::as_ref)
}
/// Splits an optional read-only strided-batched view into `(buffer, leading_dimension, stride)`.
fn strided_batched_matrix_ref_parts<T>(
    matrix: Option<StridedBatchedMatrixRef<'_, T>>,
) -> Option<(&DeviceMemory<T>, usize, usize)> {
    let view = matrix?;
    Some((view.data, view.leading_dimension, view.stride))
}
/// Consumes an optional mutable strided-batched view, yielding
/// `(buffer, leading_dimension, stride)`.
fn strided_batched_matrix_mut_parts<T>(
    matrix: Option<StridedBatchedMatrixMut<'_, T>>,
) -> Option<(&mut DeviceMemory<T>, usize, usize)> {
    let view = matrix?;
    Some((view.data, view.leading_dimension, view.stride))
}
fn strided_batched_matrix_mut_ref_option<'a, T>(
matrix: Option<&'a StridedBatchedMatrixMut<'a, T>>,
) -> Option<StridedBatchedMatrixRef<'a, T>> {
matrix.map(StridedBatchedMatrixMut::as_ref)
}
fn validate_xgesvdr_inputs<TU, TV>(
m: usize,
n: usize,
k: usize,
p: usize,
niters: usize,
a_bytes: usize,
lda: usize,
a_type: DataType,
s_bytes: usize,
s_type: DataType,
job_u: TruncatedSvdMode,
u: Option<&(&DeviceMemory<TU>, usize)>,
u_type: DataType,
job_v: TruncatedSvdMode,
v: Option<&(&DeviceMemory<TV>, usize)>,
v_type: DataType,
) -> Result<()> {
if m == 0 || n == 0 || k == 0 || k >= m.min(n) || p == 0 || k.checked_add(p).is_none() {
return Err(Error::InvalidMatrixShape);
}
let kp = k.checked_add(p).ok_or(Error::InvalidMatrixShape)?;
if kp >= m.min(n) || niters == 0 {
return Err(Error::InvalidMatrixShape);
}
validate_x_matrix(m, n, a_bytes, lda, a_type)?;
validate_x_vector(k, s_bytes, s_type)?;
if matches!(job_u, TruncatedSvdMode::Some) {
let Some((u, ldu)) = u else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(m, k, u.byte_len(), *ldu, u_type)?;
}
if matches!(job_v, TruncatedSvdMode::Some) {
let Some((v, ldv)) = v else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(n, k, v.byte_len(), *ldv, v_type)?;
}
Ok(())
}
/// Validates the arguments of the single-matrix Jacobi SVD (`gesvdj`) routines.
///
/// `a` must hold an `m x n` matrix with leading dimension `lda`, and `s` needs `min(m, n)`
/// entries. When `jobz` requests vectors:
/// - `U` must hold `m x m` elements, or `m x min(m, n)` when `econ` is set.
/// - `V` is checked via [`validate_gesvdj_output`] with `n x n` as the full shape.
fn validate_gesvdj_inputs<T>(
    m: usize,
    n: usize,
    a_len: usize,
    lda: usize,
    s_len: usize,
    jobz: EigenMode,
    econ: bool,
    u: Option<(&DeviceMemory<T>, usize)>,
    v: Option<(&DeviceMemory<T>, usize)>,
) -> Result<()> {
    if m == 0 || n == 0 {
        return Err(Error::InvalidMatrixShape);
    }
    validate_matrix(m, n, a_len, lda)?;
    if s_len < m.min(n) {
        return Err(Error::InvalidVectorShape);
    }
    // A full `U` is square (`m x m`); economy mode keeps `min(m, n)` columns, which
    // `validate_gesvdj_output` derives from `(m, n)`. Checking a full `U` as `m x n`
    // accepted undersized buffers whenever `m > n`.
    let u_cols = if econ { n } else { m };
    validate_gesvdj_output(m, u_cols, jobz, econ, u)?;
    validate_gesvdj_output(n, n, jobz, econ, v)?;
    Ok(())
}
/// Validates the arguments of the strided-batched approximate SVD (`gesvdaStridedBatched`).
///
/// The following conditions must hold:
/// - `m >= n > 0`, `1 <= rank <= n`, and `batch_size > 0`.
/// - `a` holds `batch_size` `m x n` matrices at stride `stride_a`.
/// - `s` is checked with `n` entries per item at stride `stride_s`.
/// - When `jobz` requests vectors, `U` (`m x rank`) and `V` (`n x rank`) are both required,
///   each at its own stride.
fn validate_gesvda_strided_batched_inputs<T>(
    rank: usize,
    m: usize,
    n: usize,
    a_len: usize,
    lda: usize,
    stride_a: usize,
    s_len: usize,
    stride_s: usize,
    jobz: EigenMode,
    u: Option<(&DeviceMemory<T>, usize, usize)>,
    v: Option<(&DeviceMemory<T>, usize, usize)>,
    batch_size: usize,
) -> Result<()> {
    let bad_dims = batch_size == 0 || m == 0 || n == 0 || m < n;
    let bad_rank = rank == 0 || rank > n;
    if bad_dims || bad_rank {
        return Err(Error::InvalidMatrixShape);
    }
    validate_strided_matrix(m, n, a_len, lda, stride_a, batch_size)?;
    validate_strided_vector(s_len, n, stride_s, batch_size)?;
    match jobz {
        EigenMode::NoVector => return Ok(()),
        EigenMode::Vector => {}
    }
    let (Some((u_buf, ldu, stride_u)), Some((v_buf, ldv, stride_v))) = (u, v) else {
        return Err(Error::InvalidMatrixShape);
    };
    validate_strided_matrix(m, rank, u_buf.len(), ldu, stride_u, batch_size)?;
    validate_strided_matrix(n, rank, v_buf.len(), ldv, stride_v, batch_size)?;
    Ok(())
}
/// Validates the arguments of the batched Jacobi SVD (`gesvdjBatched`) routines.
///
/// Requirements:
/// - Both `m` and `n` must be in `1..=32`, and `batch_size` must be non-zero.
/// - `a` is checked as a single `m x (n * batch_size)` matrix with leading dimension `lda`.
/// - `s` needs `min(m, n)` entries per batch item.
/// - When `jobz` requests vectors, `U` must hold `batch_size` square `m x m` matrices and `V`
///   `batch_size` square `n x n` matrices.
fn validate_gesvdj_batched_inputs<T>(
    m: usize,
    n: usize,
    a_len: usize,
    lda: usize,
    s_len: usize,
    jobz: EigenMode,
    u: Option<(&DeviceMemory<T>, usize)>,
    v: Option<(&DeviceMemory<T>, usize)>,
    batch_size: usize,
) -> Result<()> {
    if batch_size == 0 || m == 0 || n == 0 || m > 32 || n > 32 {
        return Err(Error::InvalidMatrixShape);
    }
    let a_cols = n.checked_mul(batch_size).ok_or(Error::InvalidMatrixShape)?;
    validate_matrix(m, a_cols, a_len, lda)?;
    let s_required = m
        .min(n)
        .checked_mul(batch_size)
        .ok_or(Error::InvalidVectorShape)?;
    if s_len < s_required {
        return Err(Error::InvalidVectorShape);
    }
    // The batched solver has no economy mode, so each `U` is a full `m x m` matrix.
    // `validate_gesvdj_batched_output` checks `rows x min(rows, cols)` per item, so passing
    // `(m, n)` here under-sized `U` whenever `m > n`.
    validate_gesvdj_batched_output(m, m, jobz, u, batch_size)?;
    validate_gesvdj_batched_output(n, n, jobz, v, batch_size)?;
    Ok(())
}
fn validate_gesvdj_output<T>(
rows: usize,
cols: usize,
jobz: EigenMode,
econ: bool,
matrix: Option<(&DeviceMemory<T>, usize)>,
) -> Result<()> {
match jobz {
EigenMode::NoVector => Ok(()),
EigenMode::Vector => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
let out_cols = if econ { rows.min(cols) } else { cols };
validate_matrix(rows, out_cols, matrix.len(), ld)
}
}
}
fn validate_gesvdj_batched_output<T>(
rows: usize,
cols: usize,
jobz: EigenMode,
matrix: Option<(&DeviceMemory<T>, usize)>,
batch_size: usize,
) -> Result<()> {
match jobz {
EigenMode::NoVector => Ok(()),
EigenMode::Vector => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
let out_cols = rows
.min(cols)
.checked_mul(batch_size)
.ok_or(Error::InvalidMatrixShape)?;
validate_matrix(rows, out_cols, matrix.len(), ld)
}
}
}
fn optional_gesvda_output_ptr<T>(
matrix: Option<(&DeviceMemory<T>, usize, usize)>,
rows: usize,
cols: usize,
jobz: EigenMode,
) -> Result<(*mut T, i32, i64)> {
match jobz {
EigenMode::NoVector => Ok((ptr::null_mut(), 1, 0)),
EigenMode::Vector => {
let Some((matrix, ld, stride)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_strided_matrix(rows, cols, matrix.len(), ld, stride, 1)?;
Ok((
matrix.as_ptr() as *mut T,
to_i32(ld, "ld")?,
to_i64(stride, "stride")?,
))
}
}
}
fn optional_gesvda_output_mut_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize, usize)>,
rows: usize,
cols: usize,
jobz: EigenMode,
) -> Result<(*mut T, i32, i64)> {
match jobz {
EigenMode::NoVector => Ok((ptr::null_mut(), 1, 0)),
EigenMode::Vector => {
let Some((matrix, ld, stride)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_strided_matrix(rows, cols, matrix.len(), ld, stride, 1)?;
Ok((
matrix.as_mut_ptr().cast(),
to_i32(ld, "ld")?,
to_i64(stride, "stride")?,
))
}
}
}
fn validate_gesvd_inputs<T>(
m: usize,
n: usize,
a_len: usize,
lda: usize,
s_len: usize,
job_u: SvdMode,
u: Option<&(&DeviceMemory<T>, usize)>,
job_vt: SvdMode,
vt: Option<&(&DeviceMemory<T>, usize)>,
) -> Result<()> {
validate_gesvd_dims(m, n)?;
validate_matrix(m, n, a_len, lda)?;
if s_len < n {
return Err(Error::InvalidVectorShape);
}
validate_svd_output(m, m, job_u, u)?;
validate_svd_output(n, n, job_vt, vt)?;
Ok(())
}
/// Validates an optional generic-API SVD output matrix (sizes in bytes) for the given mode.
///
/// [`SvdMode::None`] and [`SvdMode::Overwrite`] need no buffer. [`SvdMode::All`] needs
/// `rows x full_cols`, and [`SvdMode::Some`] needs `rows x min(full_cols, rows)`.
fn validate_x_svd_output<T>(
    rows: usize,
    full_cols: usize,
    matrix: Option<(&DeviceMemory<T>, usize)>,
    mode: SvdMode,
    data_type: DataType,
) -> Result<()> {
    let needed_cols = match mode {
        SvdMode::None | SvdMode::Overwrite => return Ok(()),
        SvdMode::All => full_cols,
        SvdMode::Some => full_cols.min(rows),
    };
    let Some((buffer, ld)) = matrix else {
        return Err(Error::InvalidMatrixShape);
    };
    validate_x_matrix(rows, needed_cols, buffer.byte_len(), ld, data_type)
}
/// Validates an optional classic-API SVD output matrix (sizes in elements) for the given mode.
///
/// [`SvdMode::None`] and [`SvdMode::Overwrite`] need no buffer. [`SvdMode::All`] needs
/// `rows x full_cols`, and [`SvdMode::Some`] needs `rows x min(full_cols, rows)`.
fn validate_svd_output<T>(
    rows: usize,
    full_cols: usize,
    mode: SvdMode,
    matrix: Option<&(&DeviceMemory<T>, usize)>,
) -> Result<()> {
    let needed_cols = match mode {
        SvdMode::None | SvdMode::Overwrite => return Ok(()),
        SvdMode::All => full_cols,
        SvdMode::Some => full_cols.min(rows),
    };
    let Some(&(buffer, ld)) = matrix else {
        return Err(Error::InvalidMatrixShape);
    };
    validate_matrix(rows, needed_cols, buffer.len(), ld)
}
/// Validates a generic-API vector output of `rows` rows (size in bytes).
///
/// The output holds `cols` columns, or only the leading `min(rows, cols)` columns when
/// `econ` is set.
fn validate_x_eig_output(
    rows: usize,
    cols: usize,
    bytes: usize,
    ld: usize,
    econ: bool,
    data_type: DataType,
) -> Result<()> {
    let kept_cols = if econ { cols.min(rows) } else { cols };
    validate_x_matrix(rows, kept_cols, bytes, ld, data_type)
}
fn optional_matrix_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize)>,
order: usize,
mode: SvdMode,
) -> Result<(*mut T, i32)> {
match mode {
SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), to_i32(order.max(1), "ld")?)),
SvdMode::All | SvdMode::Some => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
Ok((matrix.as_mut_ptr().cast(), to_i32(ld, "ld")?))
}
}
}
/// Resolves an optional read-only generic-API SVD output into `(pointer, ld)`.
///
/// For [`SvdMode::None`]/[`SvdMode::Overwrite`] this yields `(null, 1)`. Otherwise the matrix
/// is required and validated as `rows x cols` ([`SvdMode::All`]) or
/// `rows x min(cols, rows)` ([`SvdMode::Some`]). The pointer is cast to `*mut` only for the C
/// signature.
fn optional_x_matrix_ptr<T>(
    matrix: Option<(&DeviceMemory<T>, usize)>,
    rows: usize,
    cols: usize,
    mode: SvdMode,
    data_type: DataType,
) -> Result<(*mut T, i64)> {
    let needed_cols = match mode {
        SvdMode::None | SvdMode::Overwrite => return Ok((ptr::null_mut(), 1)),
        SvdMode::All => cols,
        SvdMode::Some => cols.min(rows),
    };
    let Some((buffer, ld)) = matrix else {
        return Err(Error::InvalidMatrixShape);
    };
    validate_x_matrix(rows, needed_cols, buffer.byte_len(), ld, data_type)?;
    Ok((buffer.as_ptr() as *mut T, to_i64(ld, "ld")?))
}
fn optional_x_matrix_mut_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
mode: SvdMode,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match mode {
SvdMode::None | SvdMode::Overwrite => Ok((ptr::null_mut(), 1)),
SvdMode::All => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
}
SvdMode::Some => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(rows, cols.min(rows), matrix.byte_len(), ld, data_type)?;
Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
}
}
}
fn optional_x_eig_matrix_ptr<T>(
matrix: Option<(&DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
jobz: EigenMode,
econ: bool,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match jobz {
EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
EigenMode::Vector => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_eig_output(rows, cols, matrix.byte_len(), ld, econ, data_type)?;
Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
}
}
}
fn optional_x_eig_matrix_mut_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
jobz: EigenMode,
econ: bool,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match jobz {
EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
EigenMode::Vector => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_eig_output(rows, cols, matrix.byte_len(), ld, econ, data_type)?;
Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
}
}
}
fn optional_x_truncated_u_ptr<T>(
matrix: Option<(&DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
mode: TruncatedSvdMode,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match mode {
TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
TruncatedSvdMode::Some => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
}
}
}
fn optional_x_truncated_u_mut_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
mode: TruncatedSvdMode,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match mode {
TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
TruncatedSvdMode::Some => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
}
}
}
fn optional_x_truncated_v_ptr<T>(
matrix: Option<(&DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
mode: TruncatedSvdMode,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match mode {
TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
TruncatedSvdMode::Some => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
Ok((matrix.as_ptr() as *mut T, to_i64(ld, "ld")?))
}
}
}
fn optional_x_truncated_v_mut_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
mode: TruncatedSvdMode,
data_type: DataType,
) -> Result<(*mut T, i64)> {
match mode {
TruncatedSvdMode::None => Ok((ptr::null_mut(), 1)),
TruncatedSvdMode::Some => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
validate_x_matrix(rows, cols, matrix.byte_len(), ld, data_type)?;
Ok((matrix.as_mut_ptr().cast(), to_i64(ld, "ld")?))
}
}
}
fn optional_gesvdj_matrix_ptr<T>(
matrix: Option<(&DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
jobz: EigenMode,
econ: bool,
) -> Result<(*mut T, i32)> {
match jobz {
EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
EigenMode::Vector => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
let out_cols = if econ { rows.min(cols) } else { cols };
validate_matrix(rows, out_cols, matrix.len(), ld)?;
Ok((matrix.as_ptr() as *mut T, to_i32(ld, "ld")?))
}
}
}
fn optional_gesvdj_matrix_mut_ptr<T>(
matrix: Option<(&mut DeviceMemory<T>, usize)>,
rows: usize,
cols: usize,
jobz: EigenMode,
econ: bool,
) -> Result<(*mut T, i32)> {
match jobz {
EigenMode::NoVector => Ok((ptr::null_mut(), 1)),
EigenMode::Vector => {
let Some((matrix, ld)) = matrix else {
return Err(Error::InvalidMatrixShape);
};
let out_cols = if econ { rows.min(cols) } else { cols };
validate_matrix(rows, out_cols, matrix.len(), ld)?;
Ok((matrix.as_mut_ptr().cast(), to_i32(ld, "ld")?))
}
}
}
fn require_rwork_buffer<T>(rwork: Option<&DeviceMemory<T>>, m: usize, n: usize) -> Result<()> {
let required = n.saturating_sub(1).min(m.saturating_sub(1));
if let Some(rwork) = rwork
&& rwork.len() < required
{
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
if len < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn validate_x_matrix(
rows: usize,
cols: usize,
bytes: usize,
lda: usize,
data_type: DataType,
) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let elem_size = data_type.size_in_bytes();
let required = lda
.checked_mul(cols)
.and_then(|count| count.checked_mul(elem_size))
.ok_or(Error::InvalidMatrixShape)?;
if bytes < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn validate_x_vector(len: usize, bytes: usize, data_type: DataType) -> Result<()> {
let required = len
.checked_mul(data_type.size_in_bytes())
.ok_or(Error::InvalidVectorShape)?;
if bytes < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn validate_strided_matrix(
rows: usize,
cols: usize,
len: usize,
lda: usize,
stride: usize,
batch_size: usize,
) -> Result<()> {
validate_matrix(rows, cols, len, lda)?;
if batch_size == 0 {
return Err(Error::InvalidMatrixShape);
}
let footprint = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
if stride < footprint {
return Err(Error::InvalidMatrixShape);
}
let required = if batch_size == 1 {
footprint
} else {
stride
.checked_mul(batch_size - 1)
.and_then(|base| base.checked_add(footprint))
.ok_or(Error::InvalidMatrixShape)?
};
if len < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn validate_strided_vector(
len: usize,
width: usize,
stride: usize,
batch_size: usize,
) -> Result<()> {
if width == 0 || batch_size == 0 {
return Err(Error::InvalidVectorShape);
}
if stride < width {
return Err(Error::InvalidVectorShape);
}
let required = if batch_size == 1 {
width
} else {
stride
.checked_mul(batch_size - 1)
.and_then(|base| base.checked_add(width))
.ok_or(Error::InvalidVectorShape)?
};
if len < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Fails with `Error::InsufficientWorkspaceSize` when `actual` is smaller than `required`.
///
/// This is the single implementation shared by every workspace-size check in this file.
fn require_workspace(actual: usize, required: usize) -> Result<()> {
    if actual < required {
        return Err(Error::InsufficientWorkspaceSize { required, actual });
    }
    Ok(())
}

/// Workspace check kept under a separate name for call-site readability.
///
/// It behaves exactly like `require_workspace`.
fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
    require_workspace(actual, required)
}

/// Host-side workspace check kept under a separate name for call-site readability.
///
/// It behaves exactly like `require_workspace`.
fn require_host_workspace(actual: usize, required: usize) -> Result<()> {
    require_workspace(actual, required)
}
fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
if dev_info.is_empty() {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn require_info_buffer_len(dev_info: &DeviceMemory<i32>, required: usize) -> Result<()> {
if dev_info.len() < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
#[cfg(all(test, feature = "testing"))]
mod tests {
    //! GPU smoke tests for the SVD wrappers.
    //!
    //! Each test returns `Ok(())` early when `setup_context_if_available` yields no
    //! context.

    use singe_core::assert_close;
    use singe_cuda::memory::DeviceMemory;

    use super::*;
    use crate::{params::Params, testing::setup_context_if_available};

    /// `sgesvd` on diag(3, 2) (ld = 2), with no U/VT output requested, must report
    /// `info == 0` and singular values `[3, 2]`.
    #[test]
    fn test_sgesvd_returns_expected_singular_values() -> Result<()> {
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };
        let mut a = DeviceMemory::from_slice(&[3.0_f32, 0.0, 0.0, 2.0])?;
        let mut s = DeviceMemory::create(2)?;
        let mut workspace = DeviceMemory::create(sgesvd_buffer_size(&ctx, 2, 2)?)?;
        let mut dev_info = DeviceMemory::create(1)?;
        sgesvd(
            &ctx,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixMut::new(&mut a, 2),
            &mut s,
            None,
            None,
            &mut workspace,
            None,
            &mut dev_info,
        )?;
        let singular_values = s.copy_to_host_vec()?;
        let info = dev_info.copy_to_host_vec()?;
        assert_eq!(info, vec![0]);
        assert_close!(&singular_values, &[3.0, 2.0], 1.0e-5);
        Ok(())
    }

    /// Same scenario as above through the generic `xgesvd` path.
    ///
    /// The workspace sizes come from `xgesvd_buffer_size`. Each buffer is allocated with at
    /// least one byte so that it is never created with size zero.
    #[test]
    fn test_xgesvd_returns_expected_singular_values() -> Result<()> {
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };
        let params = Params::create()?;
        let mut a = DeviceMemory::from_slice(&[3.0_f32, 0.0, 0.0, 2.0])?;
        let mut s = DeviceMemory::create(2)?;
        let workspace_sizes = xgesvd_buffer_size::<f32, f32, f32, f32>(
            &ctx,
            &params,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixRef::new(&a, 2),
            &s,
            None,
            None,
        )?;
        let mut device_workspace = DeviceMemory::create(workspace_sizes.device_bytes.max(1))?;
        let mut host_workspace = vec![0_u8; workspace_sizes.host_bytes.max(1)];
        let mut dev_info = DeviceMemory::create(1)?;
        xgesvd::<f32, f32, f32, f32>(
            &ctx,
            &params,
            SvdMode::None,
            SvdMode::None,
            2,
            2,
            MatrixMut::new(&mut a, 2),
            &mut s,
            None,
            None,
            ByteWorkspaceMut::new(&mut device_workspace, &mut host_workspace),
            &mut dev_info,
        )?;
        let singular_values = s.copy_to_host_vec()?;
        let info = dev_info.copy_to_host_vec()?;
        assert_eq!(info, vec![0]);
        assert_close!(&singular_values, &[3.0, 2.0], 1.0e-5);
        Ok(())
    }
}