#[allow(unused_imports)]
use crate::error::Status;
use singe_cuda::{
data_type::{DataType, DataTypeLike},
memory::DeviceMemory,
types::{Complex32, Complex64},
};
use crate::{
context::Context,
error::{Error, Result},
layout::{
BatchedMatrixRef, BatchedVectorRef, ByteWorkspaceMut, MatrixMut, MatrixRef, VectorMut,
VectorRef, WorkspaceSizes,
},
params::Params,
sys, try_ffi,
types::{DiagonalType, DirectMode, FillMode, Operation, SideMode, StorevMode},
utility::{to_i32, to_i64, to_usize},
};
pub fn spotrf_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSpotrf_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dpotrf_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDpotrf_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cpotrf_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCpotrf_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zpotrf_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZpotrf_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn spotrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = spotrf_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSpotrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dpotrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = dpotrf_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDpotrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cpotrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = cpotrf_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCpotrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zpotrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = zpotrf_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZpotrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn spotrs(
ctx: &Context,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: &DeviceMemory<f32>,
lda: usize,
b: &mut DeviceMemory<f32>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnSpotrs(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dpotrs(
ctx: &Context,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: &DeviceMemory<f64>,
lda: usize,
b: &mut DeviceMemory<f64>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnDpotrs(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cpotrs(
ctx: &Context,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
b: &mut DeviceMemory<Complex32>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnCpotrs(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zpotrs(
ctx: &Context,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
b: &mut DeviceMemory<Complex64>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnZpotrs(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn spotrf_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, f32>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
require_info_entries(info, a.len())?;
unsafe {
try_ffi!(sys::cusolverDnSpotrfBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr(),
to_i32(a.leading_dimension, "lda")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn dpotrf_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, f64>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
require_info_entries(info, a.len())?;
unsafe {
try_ffi!(sys::cusolverDnDpotrfBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr(),
to_i32(a.leading_dimension, "lda")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn cpotrf_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, Complex32>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
require_info_entries(info, a.len())?;
unsafe {
try_ffi!(sys::cusolverDnCpotrfBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn zpotrf_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, Complex64>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
require_info_entries(info, a.len())?;
unsafe {
try_ffi!(sys::cusolverDnZpotrfBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn spotrs_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, f32>,
b: BatchedVectorRef<'_, f32>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
validate_batched_vector_pointers(n, b)?;
require_info_buffer(info)?;
if a.len() != b.len() {
return Err(Error::InvalidMatrixShape);
}
unsafe {
try_ffi!(sys::cusolverDnSpotrsBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
1,
a.as_mut_ptr(),
to_i32(a.leading_dimension, "lda")?,
b.as_mut_ptr(),
to_i32(b.leading_dimension, "ldb")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn dpotrs_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, f64>,
b: BatchedVectorRef<'_, f64>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
validate_batched_vector_pointers(n, b)?;
require_info_buffer(info)?;
if a.len() != b.len() {
return Err(Error::InvalidMatrixShape);
}
unsafe {
try_ffi!(sys::cusolverDnDpotrsBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
1,
a.as_mut_ptr(),
to_i32(a.leading_dimension, "lda")?,
b.as_mut_ptr(),
to_i32(b.leading_dimension, "ldb")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn zpotrs_batched(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: BatchedMatrixRef<'_, Complex64>,
b: BatchedVectorRef<'_, Complex64>,
info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_batched_square_matrix_pointers(n, a)?;
validate_batched_vector_pointers(n, b)?;
require_info_buffer(info)?;
if a.len() != b.len() {
return Err(Error::InvalidMatrixShape);
}
unsafe {
try_ffi!(sys::cusolverDnZpotrsBatched(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
1,
a.as_mut_ptr().cast(),
to_i32(a.leading_dimension, "lda")?,
b.as_mut_ptr().cast(),
to_i32(b.leading_dimension, "ldb")?,
info.as_mut_ptr().cast(),
to_i32(a.len(), "batch_size")?,
))?;
}
Ok(())
}
pub fn spotri_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSpotri_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dpotri_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDpotri_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cpotri_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCpotri_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zpotri_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZpotri_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn spotri(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = spotri_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSpotri(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dpotri(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = dpotri_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDpotri(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cpotri(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = cpotri_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCpotri(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zpotri(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
require_info_buffer(dev_info)?;
let lwork = zpotri_buffer_size(ctx, fill_mode, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZpotri(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sgetrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgetrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgetrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgetrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgetrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgetrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgetrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgetrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgetrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
workspace: &mut DeviceMemory<f32>,
pivots: Option<&mut DeviceMemory<i32>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_info_buffer(dev_info)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, m.min(n))?;
}
let lwork = sgetrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSgetrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dgetrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
workspace: &mut DeviceMemory<f64>,
pivots: Option<&mut DeviceMemory<i32>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_info_buffer(dev_info)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, m.min(n))?;
}
let lwork = dgetrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDgetrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cgetrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
workspace: &mut DeviceMemory<Complex32>,
pivots: Option<&mut DeviceMemory<i32>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_info_buffer(dev_info)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, m.min(n))?;
}
let lwork = cgetrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCgetrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zgetrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
workspace: &mut DeviceMemory<Complex64>,
pivots: Option<&mut DeviceMemory<i32>>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_info_buffer(dev_info)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, m.min(n))?;
}
let lwork = zgetrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZgetrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
workspace.as_mut_ptr().cast(),
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sgetrs(
ctx: &Context,
operation: Operation,
n: usize,
nrhs: usize,
a: &DeviceMemory<f32>,
lda: usize,
pivots: &DeviceMemory<i32>,
b: &mut DeviceMemory<f32>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_pivot_buffer(pivots, n)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnSgetrs(
ctx.as_raw(),
operation.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
pivots.as_ptr().cast(),
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dgetrs(
ctx: &Context,
operation: Operation,
n: usize,
nrhs: usize,
a: &DeviceMemory<f64>,
lda: usize,
pivots: &DeviceMemory<i32>,
b: &mut DeviceMemory<f64>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_pivot_buffer(pivots, n)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnDgetrs(
ctx.as_raw(),
operation.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
pivots.as_ptr().cast(),
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cgetrs(
ctx: &Context,
operation: Operation,
n: usize,
nrhs: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
pivots: &DeviceMemory<i32>,
b: &mut DeviceMemory<Complex32>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_pivot_buffer(pivots, n)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnCgetrs(
ctx.as_raw(),
operation.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
pivots.as_ptr().cast(),
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zgetrs(
ctx: &Context,
operation: Operation,
n: usize,
nrhs: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
pivots: &DeviceMemory<i32>,
b: &mut DeviceMemory<Complex64>,
ldb: usize,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
validate_matrix(n, nrhs, b.len(), ldb)?;
require_pivot_buffer(pivots, n)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnZgetrs(
ctx.as_raw(),
operation.into(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
pivots.as_ptr().cast(),
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn ssytrf_buffer_size(
ctx: &Context,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsytrf_bufferSize(
ctx.as_raw(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsytrf_buffer_size(
ctx: &Context,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsytrf_bufferSize(
ctx.as_raw(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn csytrf_buffer_size(
ctx: &Context,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCsytrf_bufferSize(
ctx.as_raw(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zsytrf_buffer_size(
ctx: &Context,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZsytrf_bufferSize(
ctx.as_raw(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssytrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
pivots: Option<&mut DeviceMemory<i32>>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, n)?;
}
require_info_buffer(dev_info)?;
let lwork = ssytrf_buffer_size(ctx, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsytrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dsytrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
pivots: Option<&mut DeviceMemory<i32>>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, n)?;
}
require_info_buffer(dev_info)?;
let lwork = dsytrf_buffer_size(ctx, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsytrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn csytrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
pivots: Option<&mut DeviceMemory<i32>>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, n)?;
}
require_info_buffer(dev_info)?;
let lwork = csytrf_buffer_size(ctx, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCsytrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zsytrf(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
pivots: Option<&mut DeviceMemory<i32>>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_square_matrix(n, a.len(), lda)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot_buffer(pivots, n)?;
}
require_info_buffer(dev_info)?;
let lwork = zsytrf_buffer_size(ctx, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZsytrf(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sgebrd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_bidiagonal_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgebrd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgebrd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_bidiagonal_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgebrd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgebrd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_bidiagonal_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgebrd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgebrd_buffer_size(ctx: &Context, m: usize, n: usize) -> Result<usize> {
ctx.bind()?;
validate_bidiagonal_dims(m, n)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgebrd_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgebrd(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
d: &mut DeviceMemory<f32>,
e: &mut DeviceMemory<f32>,
tauq: &mut DeviceMemory<f32>,
taup: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_bidiagonal_buffers(m, n, a.len(), lda, d.len(), e.len(), tauq.len(), taup.len())?;
require_info_buffer(dev_info)?;
let lwork = sgebrd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSgebrd(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tauq.as_mut_ptr().cast(),
taup.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dgebrd(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
d: &mut DeviceMemory<f64>,
e: &mut DeviceMemory<f64>,
tauq: &mut DeviceMemory<f64>,
taup: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_bidiagonal_buffers(m, n, a.len(), lda, d.len(), e.len(), tauq.len(), taup.len())?;
require_info_buffer(dev_info)?;
let lwork = dgebrd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDgebrd(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tauq.as_mut_ptr().cast(),
taup.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cgebrd(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
d: &mut DeviceMemory<f32>,
e: &mut DeviceMemory<f32>,
tauq: &mut DeviceMemory<Complex32>,
taup: &mut DeviceMemory<Complex32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_bidiagonal_buffers(m, n, a.len(), lda, d.len(), e.len(), tauq.len(), taup.len())?;
require_info_buffer(dev_info)?;
let lwork = cgebrd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCgebrd(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tauq.as_mut_ptr().cast(),
taup.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zgebrd(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
d: &mut DeviceMemory<f64>,
e: &mut DeviceMemory<f64>,
tauq: &mut DeviceMemory<Complex64>,
taup: &mut DeviceMemory<Complex64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_bidiagonal_buffers(m, n, a.len(), lda, d.len(), e.len(), tauq.len(), taup.len())?;
require_info_buffer(dev_info)?;
let lwork = zgebrd_buffer_size(ctx, m, n)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZgebrd(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tauq.as_mut_ptr().cast(),
taup.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sorgbr_buffer_size(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSorgbr_bufferSize(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dorgbr_buffer_size(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDorgbr_bufferSize(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cungbr_buffer_size(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
) -> Result<usize> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCungbr_bufferSize(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zungbr_buffer_size(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
) -> Result<usize> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZungbr_bufferSize(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sorgbr(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = sorgbr_buffer_size(ctx, side, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSorgbr(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dorgbr(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = dorgbr_buffer_size(ctx, side, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDorgbr(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cungbr(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = cungbr_buffer_size(ctx, side, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCungbr(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zungbr(
ctx: &Context,
side: SideMode,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgbr_inputs(side, m, n, k, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = zungbr_buffer_size(ctx, side, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZungbr(
ctx.as_raw(),
side.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn ssytrd_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
d: &DeviceMemory<f32>,
e: &DeviceMemory<f32>,
tau: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsytrd_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
d.as_ptr().cast(),
e.as_ptr().cast(),
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsytrd_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
d: &DeviceMemory<f64>,
e: &DeviceMemory<f64>,
tau: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsytrd_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
d.as_ptr().cast(),
e.as_ptr().cast(),
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn chetrd_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
d: &DeviceMemory<f32>,
e: &DeviceMemory<f32>,
tau: &DeviceMemory<Complex32>,
) -> Result<usize> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnChetrd_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
d.as_ptr().cast(),
e.as_ptr().cast(),
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zhetrd_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
d: &DeviceMemory<f64>,
e: &DeviceMemory<f64>,
tau: &DeviceMemory<Complex64>,
) -> Result<usize> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZhetrd_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
d.as_ptr().cast(),
e.as_ptr().cast(),
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssytrd(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
d: &mut DeviceMemory<f32>,
e: &mut DeviceMemory<f32>,
tau: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
require_info_buffer(dev_info)?;
let lwork = ssytrd_buffer_size(ctx, fill_mode, n, a, lda, d, e, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsytrd(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dsytrd(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
d: &mut DeviceMemory<f64>,
e: &mut DeviceMemory<f64>,
tau: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
require_info_buffer(dev_info)?;
let lwork = dsytrd_buffer_size(ctx, fill_mode, n, a, lda, d, e, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsytrd(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn chetrd(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
d: &mut DeviceMemory<f32>,
e: &mut DeviceMemory<f32>,
tau: &mut DeviceMemory<Complex32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
require_info_buffer(dev_info)?;
let lwork = chetrd_buffer_size(ctx, fill_mode, n, a, lda, d, e, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnChetrd(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zhetrd(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
d: &mut DeviceMemory<f64>,
e: &mut DeviceMemory<f64>,
tau: &mut DeviceMemory<Complex64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sytrd_inputs(n, a.len(), lda, d.len(), e.len(), tau.len())?;
require_info_buffer(dev_info)?;
let lwork = zhetrd_buffer_size(ctx, fill_mode, n, a, lda, d, e, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZhetrd(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
d.as_mut_ptr().cast(),
e.as_mut_ptr().cast(),
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sorgtr_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSorgtr_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dorgtr_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDorgtr_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cungtr_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
) -> Result<usize> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCungtr_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zungtr_buffer_size(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
) -> Result<usize> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZungtr_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sorgtr(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = sorgtr_buffer_size(ctx, fill_mode, n, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSorgtr(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dorgtr(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = dorgtr_buffer_size(ctx, fill_mode, n, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDorgtr(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cungtr(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = cungtr_buffer_size(ctx, fill_mode, n, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCungtr(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zungtr(
ctx: &Context,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_orgtr_inputs(n, a.len(), lda, tau.len())?;
require_info_buffer(dev_info)?;
let lwork = zungtr_buffer_size(ctx, fill_mode, n, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZungtr(
ctx.as_raw(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sormtr_buffer_size(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
c: &DeviceMemory<f32>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSormtr_bufferSize(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dormtr_buffer_size(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
c: &DeviceMemory<f64>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDormtr_bufferSize(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cunmtr_buffer_size(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
c: &DeviceMemory<Complex32>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCunmtr_bufferSize(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zunmtr_buffer_size(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
c: &DeviceMemory<Complex64>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZunmtr_bufferSize(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sormtr(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
tau: &mut DeviceMemory<f32>,
c: &mut DeviceMemory<f32>,
ldc: usize,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = sormtr_buffer_size(ctx, side, fill_mode, operation, m, n, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSormtr(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dormtr(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
tau: &mut DeviceMemory<f64>,
c: &mut DeviceMemory<f64>,
ldc: usize,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = dormtr_buffer_size(ctx, side, fill_mode, operation, m, n, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDormtr(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cunmtr(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
tau: &mut DeviceMemory<Complex32>,
c: &mut DeviceMemory<Complex32>,
ldc: usize,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = cunmtr_buffer_size(ctx, side, fill_mode, operation, m, n, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCunmtr(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zunmtr(
ctx: &Context,
side: SideMode,
fill_mode: FillMode,
operation: Operation,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
tau: &mut DeviceMemory<Complex64>,
c: &mut DeviceMemory<Complex64>,
ldc: usize,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_ormtr_inputs(side, m, n, a.len(), lda, tau.len(), c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = zunmtr_buffer_size(ctx, side, fill_mode, operation, m, n, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZunmtr(
ctx.as_raw(),
side.into(),
fill_mode.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sgeqrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSgeqrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dgeqrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDgeqrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cgeqrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCgeqrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zgeqrf_buffer_size(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZgeqrf_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sgeqrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
tau: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, m.min(n))?;
require_info_buffer(dev_info)?;
let lwork = sgeqrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSgeqrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dgeqrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
tau: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, m.min(n))?;
require_info_buffer(dev_info)?;
let lwork = dgeqrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDgeqrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cgeqrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
tau: &mut DeviceMemory<Complex32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, m.min(n))?;
require_info_buffer(dev_info)?;
let lwork = cgeqrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCgeqrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zgeqrf(
ctx: &Context,
m: usize,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
tau: &mut DeviceMemory<Complex64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, m.min(n))?;
require_info_buffer(dev_info)?;
let lwork = zgeqrf_buffer_size(ctx, m, n, a, lda)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZgeqrf(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sorgqr_buffer_size(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSorgqr_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dorgqr_buffer_size(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDorgqr_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cungqr_buffer_size(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCungqr_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zungqr_buffer_size(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZungqr_bufferSize(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sorgqr(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
require_info_buffer(dev_info)?;
let lwork = sorgqr_buffer_size(ctx, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSorgqr(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dorgqr(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
require_info_buffer(dev_info)?;
let lwork = dorgqr_buffer_size(ctx, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDorgqr(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cungqr(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
require_info_buffer(dev_info)?;
let lwork = cungqr_buffer_size(ctx, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCungqr(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zungqr(
ctx: &Context,
m: usize,
n: usize,
k: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(m, n, a.len(), lda)?;
require_tau_buffer(tau, k)?;
require_info_buffer(dev_info)?;
let lwork = zungqr_buffer_size(ctx, m, n, k, a, lda, tau)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZungqr(
ctx.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn sormqr_buffer_size(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
c: &DeviceMemory<f32>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSormqr_bufferSize(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dormqr_buffer_size(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
c: &DeviceMemory<f64>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDormqr_bufferSize(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cunmqr_buffer_size(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
c: &DeviceMemory<Complex32>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCunmqr_bufferSize(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zunmqr_buffer_size(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
c: &DeviceMemory<Complex64>,
ldc: usize,
) -> Result<usize> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZunmqr_bufferSize(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_ptr().cast(),
to_i32(ldc, "ldc")?,
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn sormqr(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f32>,
lda: usize,
tau: &DeviceMemory<f32>,
c: &mut DeviceMemory<f32>,
ldc: usize,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = sormqr_buffer_size(ctx, side, operation, m, n, k, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSormqr(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dormqr(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<f64>,
lda: usize,
tau: &DeviceMemory<f64>,
c: &mut DeviceMemory<f64>,
ldc: usize,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = dormqr_buffer_size(ctx, side, operation, m, n, k, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDormqr(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cunmqr(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
tau: &DeviceMemory<Complex32>,
c: &mut DeviceMemory<Complex32>,
ldc: usize,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = cunmqr_buffer_size(ctx, side, operation, m, n, k, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCunmqr(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zunmqr(
ctx: &Context,
side: SideMode,
operation: Operation,
m: usize,
n: usize,
k: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
tau: &DeviceMemory<Complex64>,
c: &mut DeviceMemory<Complex64>,
ldc: usize,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_matrix(qr_rows(side, m, n), k, a.len(), lda)?;
require_tau_buffer(tau, k)?;
validate_matrix(m, n, c.len(), ldc)?;
require_info_buffer(dev_info)?;
let lwork = zunmqr_buffer_size(ctx, side, operation, m, n, k, a, lda, tau, c, ldc)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZunmqr(
ctx.as_raw(),
side.into(),
operation.into(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(k, "k")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
tau.as_ptr().cast(),
c.as_mut_ptr().cast(),
to_i32(ldc, "ldc")?,
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xgeqrf_buffer_size<TA: DataTypeLike, TTau: DataTypeLike>(
ctx: &Context,
params: &Params,
m: usize,
n: usize,
a: MatrixRef<'_, TA>,
tau: VectorRef<'_, TTau>,
compute_type: DataType,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
let a_type = TA::data_type();
let tau_type = TTau::data_type();
validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
validate_x_vector(m.min(n), tau.data.byte_len(), tau_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXgeqrf_bufferSize(
ctx.as_raw(),
params.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
tau_type.into(),
tau.data.as_ptr().cast(),
compute_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xgeqrf<TA: DataTypeLike, TTau: DataTypeLike>(
ctx: &Context,
params: &Params,
m: usize,
n: usize,
a: MatrixMut<'_, TA>,
tau: VectorMut<'_, TTau>,
compute_type: DataType,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
let a_type = TA::data_type();
let tau_type = TTau::data_type();
validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
validate_x_vector(m.min(n), tau.data.byte_len(), tau_type)?;
require_info_buffer(dev_info)?;
let workspace_sizes =
xgeqrf_buffer_size(ctx, params, m, n, a.as_ref(), tau.as_ref(), compute_type)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXgeqrf(
ctx.as_raw(),
params.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
a_type.into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
tau_type.into(),
tau.data.as_mut_ptr().cast(),
compute_type.into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xpotrf_buffer_size<TA: DataTypeLike>(
ctx: &Context,
params: &Params,
fill_mode: FillMode,
n: usize,
a: MatrixRef<'_, TA>,
compute_type: DataType,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
let a_type = TA::data_type();
validate_x_matrix(n, n, a.data.byte_len(), a.leading_dimension, a_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXpotrf_bufferSize(
ctx.as_raw(),
params.as_raw(),
fill_mode.into(),
to_i64(n, "n")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
compute_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xpotrf<TA: DataTypeLike>(
ctx: &Context,
params: &Params,
fill_mode: FillMode,
n: usize,
a: MatrixMut<'_, TA>,
compute_type: DataType,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
let a_type = TA::data_type();
validate_x_matrix(n, n, a.data.byte_len(), a.leading_dimension, a_type)?;
require_info_buffer(dev_info)?;
let workspace_sizes = xpotrf_buffer_size(ctx, params, fill_mode, n, a.as_ref(), compute_type)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXpotrf(
ctx.as_raw(),
params.as_raw(),
fill_mode.into(),
to_i64(n, "n")?,
a_type.into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
compute_type.into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xpotrs<TA: DataTypeLike, TB: DataTypeLike>(
ctx: &Context,
params: &Params,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: MatrixRef<'_, TA>,
b: MatrixMut<'_, TB>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
let a_type = TA::data_type();
let b_type = TB::data_type();
validate_x_matrix(n, n, a.data.byte_len(), a.leading_dimension, a_type)?;
validate_x_matrix(n, nrhs, b.data.byte_len(), b.leading_dimension, b_type)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnXpotrs(
ctx.as_raw(),
params.as_raw(),
fill_mode.into(),
to_i64(n, "n")?,
to_i64(nrhs, "nrhs")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
b_type.into(),
b.data.as_mut_ptr().cast(),
to_i64(b.leading_dimension, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xtrtri_buffer_size<TA: DataTypeLike>(
ctx: &Context,
fill_mode: FillMode,
diagonal_type: DiagonalType,
n: usize,
a: MatrixRef<'_, TA>,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
validate_x_matrix(
n,
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXtrtri_bufferSize(
ctx.as_raw(),
fill_mode.into(),
diagonal_type.into(),
to_i64(n, "n")?,
TA::data_type().into(),
a.data.as_ptr().cast_mut().cast(),
to_i64(a.leading_dimension, "lda")?,
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xtrtri<TA: DataTypeLike>(
ctx: &Context,
fill_mode: FillMode,
diagonal_type: DiagonalType,
n: usize,
a: MatrixMut<'_, TA>,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_x_matrix(
n,
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
)?;
require_info_buffer(dev_info)?;
let workspace_sizes = xtrtri_buffer_size(ctx, fill_mode, diagonal_type, n, a.as_ref())?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXtrtri(
ctx.as_raw(),
fill_mode.into(),
diagonal_type.into(),
to_i64(n, "n")?,
TA::data_type().into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xgetrf_buffer_size<TA: DataTypeLike>(
ctx: &Context,
params: &Params,
m: usize,
n: usize,
a: MatrixRef<'_, TA>,
compute_type: DataType,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
let a_type = TA::data_type();
validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXgetrf_bufferSize(
ctx.as_raw(),
params.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
compute_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xgetrf<TA: DataTypeLike>(
ctx: &Context,
params: &Params,
m: usize,
n: usize,
a: MatrixMut<'_, TA>,
pivots: Option<&mut DeviceMemory<i64>>,
compute_type: DataType,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
let a_type = TA::data_type();
validate_x_matrix(m, n, a.data.byte_len(), a.leading_dimension, a_type)?;
if let Some(pivots) = pivots.as_ref() {
require_pivot64_buffer(pivots, m.min(n))?;
}
require_info_buffer(dev_info)?;
let workspace_sizes = xgetrf_buffer_size(ctx, params, m, n, a.as_ref(), compute_type)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXgetrf(
ctx.as_raw(),
params.as_raw(),
to_i64(m, "m")?,
to_i64(n, "n")?,
a_type.into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
pivots.map_or(std::ptr::null_mut(), |p| p.as_mut_ptr()),
compute_type.into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xgetrs<TA: DataTypeLike, TB: DataTypeLike>(
ctx: &Context,
params: &Params,
operation: Operation,
n: usize,
nrhs: usize,
a: MatrixRef<'_, TA>,
pivots: &DeviceMemory<i64>,
b: MatrixMut<'_, TB>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
let a_type = TA::data_type();
let b_type = TB::data_type();
validate_x_matrix(n, n, a.data.byte_len(), a.leading_dimension, a_type)?;
require_pivot64_buffer(pivots, n)?;
validate_x_matrix(n, nrhs, b.data.byte_len(), b.leading_dimension, b_type)?;
require_info_buffer(dev_info)?;
unsafe {
try_ffi!(sys::cusolverDnXgetrs(
ctx.as_raw(),
params.as_raw(),
operation.into(),
to_i64(n, "n")?,
to_i64(nrhs, "nrhs")?,
a_type.into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
pivots.as_ptr().cast(),
b_type.into(),
b.data.as_mut_ptr().cast(),
to_i64(b.leading_dimension, "ldb")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xsytrs_buffer_size<TA: DataTypeLike, TB: DataTypeLike>(
ctx: &Context,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: MatrixRef<'_, TA>,
pivots: Option<&DeviceMemory<i64>>,
b: MatrixRef<'_, TB>,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
validate_x_matrix(
n,
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
)?;
validate_x_matrix(
n,
nrhs,
b.data.byte_len(),
b.leading_dimension,
TB::data_type(),
)?;
if let Some(pivots) = pivots {
require_pivot64_buffer(pivots, n)?;
}
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXsytrs_bufferSize(
ctx.as_raw(),
fill_mode.into(),
to_i64(n, "n")?,
to_i64(nrhs, "nrhs")?,
TA::data_type().into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
pivots.map_or(std::ptr::null(), DeviceMemory::as_ptr),
TB::data_type().into(),
b.data.as_ptr().cast_mut().cast(),
to_i64(b.leading_dimension, "ldb")?,
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xsytrs<TA: DataTypeLike, TB: DataTypeLike>(
ctx: &Context,
fill_mode: FillMode,
n: usize,
nrhs: usize,
a: MatrixRef<'_, TA>,
pivots: Option<&DeviceMemory<i64>>,
b: MatrixMut<'_, TB>,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_x_matrix(
n,
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
)?;
validate_x_matrix(
n,
nrhs,
b.data.byte_len(),
b.leading_dimension,
TB::data_type(),
)?;
if let Some(pivots) = pivots {
require_pivot64_buffer(pivots, n)?;
}
require_info_buffer(dev_info)?;
let workspace_sizes = xsytrs_buffer_size(ctx, fill_mode, n, nrhs, a, pivots, b.as_ref())?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXsytrs(
ctx.as_raw(),
fill_mode.into(),
to_i64(n, "n")?,
to_i64(nrhs, "nrhs")?,
TA::data_type().into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
pivots.map_or(std::ptr::null(), DeviceMemory::as_ptr),
TB::data_type().into(),
b.data.as_mut_ptr().cast(),
to_i64(b.leading_dimension, "ldb")?,
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn xlarft_buffer_size<TV: DataTypeLike, TTau: DataTypeLike, TT: DataTypeLike>(
ctx: &Context,
params: &Params,
direct: DirectMode,
storev: StorevMode,
n: usize,
k: usize,
v: MatrixRef<'_, TV>,
tau: VectorRef<'_, TTau>,
t: MatrixRef<'_, TT>,
compute_type: DataType,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
let v_type = TV::data_type();
let tau_type = TTau::data_type();
let t_type = TT::data_type();
validate_xlarft_inputs(
n,
k,
storev,
v.data.byte_len(),
v.leading_dimension,
v_type,
tau.data.byte_len(),
tau_type,
t.data.byte_len(),
t.leading_dimension,
t_type,
)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXlarft_bufferSize(
ctx.as_raw(),
params.as_raw(),
direct.into(),
storev.into(),
to_i64(n, "n")?,
to_i64(k, "k")?,
v_type.into(),
v.data.as_ptr().cast(),
to_i64(v.leading_dimension, "ldv")?,
tau_type.into(),
tau.data.as_ptr().cast(),
t_type.into(),
t.data.as_ptr().cast_mut().cast(),
to_i64(t.leading_dimension, "ldt")?,
compute_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xlarft<TV: DataTypeLike, TTau: DataTypeLike, TT: DataTypeLike>(
ctx: &Context,
params: &Params,
direct: DirectMode,
storev: StorevMode,
n: usize,
k: usize,
v: MatrixRef<'_, TV>,
tau: VectorRef<'_, TTau>,
t: MatrixMut<'_, TT>,
compute_type: DataType,
workspace: ByteWorkspaceMut<'_>,
) -> Result<()> {
ctx.bind()?;
let v_type = TV::data_type();
let tau_type = TTau::data_type();
let t_type = TT::data_type();
validate_xlarft_inputs(
n,
k,
storev,
v.data.byte_len(),
v.leading_dimension,
v_type,
tau.data.byte_len(),
tau_type,
t.data.byte_len(),
t.leading_dimension,
t_type,
)?;
let workspace_sizes = xlarft_buffer_size(
ctx,
params,
direct,
storev,
n,
k,
v,
tau,
t.as_ref(),
compute_type,
)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXlarft(
ctx.as_raw(),
params.as_raw(),
direct.into(),
storev.into(),
to_i64(n, "n")?,
to_i64(k, "k")?,
v_type.into(),
v.data.as_ptr().cast(),
to_i64(v.leading_dimension, "ldv")?,
tau_type.into(),
tau.data.as_ptr().cast(),
t_type.into(),
t.data.as_mut_ptr().cast(),
to_i64(t.leading_dimension, "ldt")?,
compute_type.into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
))?;
}
Ok(())
}
/// Validates an `n`-by-`n` matrix stored in `len` elements with leading dimension `lda`.
///
/// This is [`validate_matrix`] with the row and column counts both set to `n`.
fn validate_square_matrix(n: usize, len: usize, lda: usize) -> Result<()> {
    let (rows, cols) = (n, n);
    validate_matrix(rows, cols, len, lda)
}
fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
if len < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn require_workspace(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
fn require_host_workspace(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
if dev_info.is_empty() {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn require_info_entries(dev_info: &DeviceMemory<i32>, required: usize) -> Result<()> {
if dev_info.len() < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn require_pivot_buffer(pivots: &DeviceMemory<i32>, required: usize) -> Result<()> {
if pivots.len() < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn require_pivot64_buffer(pivots: &DeviceMemory<i64>, required: usize) -> Result<()> {
if pivots.len() < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn require_tau_buffer<T>(tau: &DeviceMemory<T>, required: usize) -> Result<()> {
if tau.len() < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Picks the dimension that matches `side`: `m` for [`SideMode::Left`], `n` for
/// [`SideMode::Right`].
///
/// The `ormqr`/`unmqr` wrappers in this file use the result as the row count of the
/// reflector matrix `a`.
fn qr_rows(side: SideMode, m: usize, n: usize) -> usize {
    match side {
        SideMode::Right => n,
        SideMode::Left => m,
    }
}
/// Picks the dimension that matches `side`: `m` for [`SideMode::Left`], `n` for
/// [`SideMode::Right`].
///
/// `validate_ormtr_inputs` uses the result as the order of the square matrix `a`.
fn tridiagonal_order(side: SideMode, m: usize, n: usize) -> usize {
    match side {
        SideMode::Right => n,
        SideMode::Left => m,
    }
}
fn validate_bidiagonal_dims(m: usize, n: usize) -> Result<()> {
if m == 0 || n == 0 || m < n {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn validate_bidiagonal_buffers(
m: usize,
n: usize,
a_len: usize,
lda: usize,
d_len: usize,
e_len: usize,
tauq_len: usize,
taup_len: usize,
) -> Result<()> {
validate_bidiagonal_dims(m, n)?;
validate_matrix(m, n, a_len, lda)?;
if d_len < n || e_len < n || tauq_len < n || taup_len < n {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Validates the inputs of the `orgbr`/`ungbr` wrappers.
///
/// The shape rules follow the LAPACK/cuSOLVER `orgbr` contract:
///
/// * [`SideMode::Left`]: `m >= n >= min(m, k)` and `tau` holds at least `min(m, k)` entries.
/// * [`SideMode::Right`]: `n >= m >= min(n, k)` and `tau` holds at least `min(n, k)` entries.
///
/// `k` may therefore exceed `m` (left) or `n` (right); this is the normal situation when
/// generating the right factor of a tall matrix, where `k` is the row count of the original
/// matrix. The previous check rejected that case and also demanded `k` `tau` entries.
///
/// # Errors
///
/// * [`Error::InvalidMatrixShape`] if `m`, `n` or `k` is zero, if `a` fails
///   [`validate_matrix`] for an `m`-by-`n` matrix, or if the side-specific ordering above is
///   violated.
/// * [`Error::InvalidLeadingDimension`] if `lda < m` (propagated from [`validate_matrix`]).
/// * [`Error::InvalidVectorShape`] if `tau` is too short.
fn validate_orgbr_inputs(
    side: SideMode,
    m: usize,
    n: usize,
    k: usize,
    a_len: usize,
    lda: usize,
    tau_len: usize,
) -> Result<()> {
    if m == 0 || n == 0 || k == 0 {
        return Err(Error::InvalidMatrixShape);
    }
    validate_matrix(m, n, a_len, lda)?;

    // `outer` is the dimension that must dominate, `inner` the one squeezed between `outer`
    // and the reflector count.
    let (outer, inner) = match side {
        SideMode::Left => (m, n),
        SideMode::Right => (n, m),
    };
    let reflectors = outer.min(k);

    // The tau check stays ahead of the shape check so the error precedence is unchanged.
    if tau_len < reflectors {
        return Err(Error::InvalidVectorShape);
    }
    if outer < inner || inner < reflectors {
        return Err(Error::InvalidMatrixShape);
    }
    Ok(())
}
fn validate_sytrd_inputs(
n: usize,
a_len: usize,
lda: usize,
d_len: usize,
e_len: usize,
tau_len: usize,
) -> Result<()> {
validate_square_matrix(n, a_len, lda)?;
let reflectors = n.saturating_sub(1);
if d_len < n || e_len < reflectors || tau_len < reflectors {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn validate_orgtr_inputs(n: usize, a_len: usize, lda: usize, tau_len: usize) -> Result<()> {
validate_square_matrix(n, a_len, lda)?;
if tau_len < n.saturating_sub(1) {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn validate_ormtr_inputs(
side: SideMode,
m: usize,
n: usize,
a_len: usize,
lda: usize,
tau_len: usize,
c_len: usize,
ldc: usize,
) -> Result<()> {
let nq = tridiagonal_order(side, m, n);
validate_square_matrix(nq, a_len, lda)?;
validate_matrix(m, n, c_len, ldc)?;
if tau_len < nq.saturating_sub(1) {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Validates a batch of square-matrix pointers for order `n`.
///
/// Only the batch itself is inspected: it must be non-empty and its leading dimension must
/// be at least `n`. The individual matrices behind the pointers are not examined here.
///
/// # Errors
///
/// * [`Error::InvalidMatrixShape`] if `n` is zero or the batch is empty.
/// * [`Error::InvalidLeadingDimension`] if the leading dimension is smaller than `n`.
fn validate_batched_square_matrix_pointers<T>(
    n: usize,
    matrices: BatchedMatrixRef<'_, T>,
) -> Result<()> {
    let has_work = n != 0 && !matrices.is_empty();
    if !has_work {
        return Err(Error::InvalidMatrixShape);
    }
    if n > matrices.leading_dimension {
        return Err(Error::InvalidLeadingDimension);
    }
    Ok(())
}
/// Validates a batch of vector pointers for length `n`.
///
/// Only the batch itself is inspected: it must be non-empty and its leading dimension must
/// be at least `n`. The individual vectors behind the pointers are not examined here.
///
/// # Errors
///
/// * [`Error::InvalidVectorShape`] if `n` is zero or the batch is empty.
/// * [`Error::InvalidLeadingDimension`] if the leading dimension is smaller than `n`.
fn validate_batched_vector_pointers<T>(n: usize, vectors: BatchedVectorRef<'_, T>) -> Result<()> {
    let has_work = n != 0 && !vectors.is_empty();
    if !has_work {
        return Err(Error::InvalidVectorShape);
    }
    if n > vectors.leading_dimension {
        return Err(Error::InvalidLeadingDimension);
    }
    Ok(())
}
fn validate_x_matrix(
rows: usize,
cols: usize,
bytes: usize,
lda: usize,
data_type: DataType,
) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let required = lda
.checked_mul(cols)
.and_then(|count| count.checked_mul(data_type.size_in_bytes()))
.ok_or(Error::InvalidMatrixShape)?;
if bytes < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn validate_x_vector(len: usize, bytes: usize, data_type: DataType) -> Result<()> {
let required = len
.checked_mul(data_type.size_in_bytes())
.ok_or(Error::InvalidVectorShape)?;
if bytes < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Validates the inputs shared by `xlarft` and `xlarft_buffer_size`.
///
/// Requirements enforced here:
///
/// * `n` and `k` are non-zero and `k <= n`;
/// * `storev` is [`StorevMode::Columnwise`] (any other mode is rejected);
/// * `v` fits an `n`-by-`k` matrix, `tau` fits `k` entries, and `t` fits a `k`-by-`k` matrix,
///   each measured in bytes of its own data type.
///
/// # Errors
///
/// Returns [`Error::InvalidMatrixShape`] for bad dimensions or an unsupported `storev`, and
/// otherwise propagates the first failure from `validate_x_matrix` / `validate_x_vector`.
fn validate_xlarft_inputs(
    n: usize,
    k: usize,
    storev: StorevMode,
    v_bytes: usize,
    ldv: usize,
    v_type: DataType,
    tau_bytes: usize,
    tau_type: DataType,
    t_bytes: usize,
    ldt: usize,
    t_type: DataType,
) -> Result<()> {
    let dims_ok = n != 0 && k != 0 && k <= n;
    if !dims_ok {
        return Err(Error::InvalidMatrixShape);
    }
    if storev != StorevMode::Columnwise {
        return Err(Error::InvalidMatrixShape);
    }
    validate_x_matrix(n, k, v_bytes, ldv, v_type)?;
    validate_x_vector(k, tau_bytes, tau_type)?;
    validate_x_matrix(k, k, t_bytes, ldt, t_type)
}