#[allow(unused_imports)]
use crate::error::Status;
use std::{mem::size_of, ptr};
use singe_cuda::{
data_type::{DataType, DataTypeLike},
memory::DeviceMemory,
types::{Complex32, Complex64},
};
use crate::{
context::Context,
error::{Error, Result},
layout::{ByteWorkspaceMut, MatrixMut, MatrixRef, SelectionWorkspaceSizes, WorkspaceSizes},
params::Params,
sys, try_ffi,
types::{EigenMode, EigenRange, EigenType, FillMode},
utility::{to_i32, to_i64, to_usize},
};
/// Selects which eigenvalues an `xsyevdx`-style solve should compute.
///
/// The wrappers in this file hand the value to `selection_parts`, which splits it into a
/// range kind plus optional value bounds and optional index bounds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EigenSelection<T> {
/// No bounds are supplied; presumably every eigenvalue is computed.
All,
/// Select by the eigenvalue bounds `lower` and `upper`.
// NOTE(review): whether the bounds are open or closed is not visible here — confirm.
ByValue { lower: T, upper: T },
/// Select by the position bounds `start` and `end`.
// NOTE(review): index base and inclusiveness are decided by `selection_parts` — confirm.
ByIndex { start: usize, end: usize },
}
/// Descriptor for an eigen-solve routed through `xsyevd`.
///
/// Every field is forwarded unchanged by `Syevd::workspace_size` and `Syevd::execute`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Syevd {
/// Forwarded as the `mode` argument of `xsyevd` / `xsyevd_buffer_size`.
pub mode: EigenMode,
/// Forwarded as the `fill_mode` argument.
pub fill_mode: FillMode,
/// Forwarded as the problem size `n`.
pub n: usize,
/// Forwarded as `lda`.
pub leading_dimension: usize,
}
impl Syevd {
pub fn new(mode: EigenMode, fill_mode: FillMode, n: usize, leading_dimension: usize) -> Self {
Self {
mode,
fill_mode,
n,
leading_dimension,
}
}
pub fn workspace_size<TA: DataTypeLike, TW: DataTypeLike>(
self,
ctx: &Context,
params: &Params,
input: SyevdInput<'_, TA, TW>,
) -> Result<WorkspaceSizes> {
xsyevd_buffer_size(
ctx,
params,
self.mode,
self.fill_mode,
self.n,
input.a,
self.leading_dimension,
input.eigenvalues,
)
}
pub fn execute<TA: DataTypeLike, TW: DataTypeLike>(
self,
ctx: &Context,
params: &Params,
bindings: SyevdBindings<'_, TA, TW>,
) -> Result<()> {
xsyevd(
ctx,
params,
self.mode,
self.fill_mode,
self.n,
bindings.a,
self.leading_dimension,
bindings.eigenvalues,
bindings.workspace,
bindings.dev_info,
)
}
}
/// Read-only buffers handed to the `workspace_size` queries of `Syevd` and `Syevdx`.
#[derive(Debug, Clone, Copy)]
pub struct SyevdInput<'a, TA, TW> {
/// Matrix buffer, forwarded as the `a` argument of the buffer-size query.
pub a: &'a DeviceMemory<TA>,
/// Eigenvalue buffer, forwarded as the `w` argument of the buffer-size query.
pub eigenvalues: &'a DeviceMemory<TW>,
}
/// Mutable buffers handed to the `execute` methods of `Syevd` and `Syevdx`.
#[derive(Debug)]
pub struct SyevdBindings<'a, TA, TW> {
/// Matrix buffer, forwarded as the mutable `a` argument of the solver.
pub a: &'a mut DeviceMemory<TA>,
/// Eigenvalue buffer, forwarded as the mutable `w` argument of the solver.
pub eigenvalues: &'a mut DeviceMemory<TW>,
/// Scratch space; the solvers check its `device` and `host` parts against the queried sizes.
pub workspace: ByteWorkspaceMut<'a>,
/// Device-side status buffer passed to cuSOLVER as its info output.
pub dev_info: &'a mut DeviceMemory<i32>,
}
/// Descriptor for a selective eigen-solve routed through `xsyevdx`.
///
/// Every field is forwarded unchanged by `Syevdx::workspace_size` and `Syevdx::execute`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Syevdx<T> {
/// Forwarded as the `mode` argument of `xsyevdx` / `xsyevdx_buffer_size`.
pub mode: EigenMode,
/// Forwarded as the `fill_mode` argument.
pub fill_mode: FillMode,
/// Which eigenvalues to compute; see [`EigenSelection`].
pub selection: EigenSelection<T>,
/// Forwarded as the problem size `n`.
pub n: usize,
/// Forwarded as `lda`.
pub leading_dimension: usize,
}
impl<T> Syevdx<T> {
pub fn new(
mode: EigenMode,
fill_mode: FillMode,
selection: EigenSelection<T>,
n: usize,
leading_dimension: usize,
) -> Self {
Self {
mode,
fill_mode,
selection,
n,
leading_dimension,
}
}
}
impl<T: Copy + Default> Syevdx<T> {
pub fn workspace_size<TA: DataTypeLike, TW: DataTypeLike>(
self,
ctx: &Context,
params: &Params,
input: SyevdInput<'_, TA, TW>,
) -> Result<SelectionWorkspaceSizes> {
xsyevdx_buffer_size(
ctx,
params,
self.mode,
self.fill_mode,
self.selection,
self.n,
input.a,
self.leading_dimension,
input.eigenvalues,
)
}
pub fn execute<TA: DataTypeLike, TW: DataTypeLike>(
self,
ctx: &Context,
params: &Params,
bindings: SyevdBindings<'_, TA, TW>,
) -> Result<usize> {
xsyevdx(
ctx,
params,
self.mode,
self.fill_mode,
self.selection,
self.n,
bindings.a,
self.leading_dimension,
bindings.eigenvalues,
bindings.workspace,
bindings.dev_info,
)
}
}
/// Owning wrapper around a raw cuSOLVER `syevjInfo_t` handle.
///
/// The handle is created in `SyevjInfo::create` and destroyed by the `Drop` impl.
#[derive(Debug)]
pub struct SyevjInfo {
// Raw handle; `create` rejects a null value before constructing the wrapper.
handle: sys::syevjInfo_t,
}
// NOTE(review): these impls assert that the raw handle may be moved to and shared between
// threads. The justification is not visible in this file — confirm against cuSOLVER's
// thread-safety guarantees.
unsafe impl Send for SyevjInfo {}
unsafe impl Sync for SyevjInfo {}
impl SyevjInfo {
pub fn create() -> Result<Self> {
let mut handle = ptr::null_mut();
unsafe {
try_ffi!(sys::cusolverDnCreateSyevjInfo(&raw mut handle))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Self { handle })
}
pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnXsyevjSetTolerance(self.as_raw(), tolerance,))?;
}
Ok(())
}
pub fn set_max_sweeps(&mut self, max_sweeps: i32) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnXsyevjSetMaxSweeps(self.as_raw(), max_sweeps,))?;
}
Ok(())
}
pub fn set_sort_eigenvalues(&mut self, sort_eigenvalues: bool) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnXsyevjSetSortEig(
self.as_raw(),
i32::from(sort_eigenvalues),
))?;
}
Ok(())
}
pub fn residual(&self, ctx: &Context) -> Result<f64> {
ctx.bind()?;
let mut residual = 0.0;
unsafe {
try_ffi!(sys::cusolverDnXsyevjGetResidual(
ctx.as_raw(),
self.as_raw(),
&raw mut residual,
))?;
}
Ok(residual)
}
pub fn executed_sweeps(&self, ctx: &Context) -> Result<i32> {
ctx.bind()?;
let mut sweeps = 0;
unsafe {
try_ffi!(sys::cusolverDnXsyevjGetSweeps(
ctx.as_raw(),
self.as_raw(),
&raw mut sweeps,
))?;
}
Ok(sweeps)
}
pub fn as_raw(&self) -> sys::syevjInfo_t {
self.handle
}
}
impl Drop for SyevjInfo {
    /// Destroys the handle through `cusolverDnDestroySyevjInfo`.
    ///
    /// A failure cannot be propagated from `drop`, so it is only reported on stderr,
    /// and only in builds with debug assertions enabled.
    fn drop(&mut self) {
        let outcome = unsafe { try_ffi!(sys::cusolverDnDestroySyevjInfo(self.handle)) };
        if let Err(err) = outcome {
            #[cfg(debug_assertions)]
            eprintln!("failed to destroy cusolver syevj info: {err}");
        }
    }
}
/// Queries the `xsyevd` workspace sizes, deriving the data types from `TA` and `TW`.
///
/// The matrix element type is also used as the compute type.
pub fn xsyevd_buffer_size<TA: DataTypeLike, TW: DataTypeLike>(
    ctx: &Context,
    params: &Params,
    mode: EigenMode,
    fill_mode: FillMode,
    n: usize,
    a: &DeviceMemory<TA>,
    lda: usize,
    w: &DeviceMemory<TW>,
) -> Result<WorkspaceSizes> {
    let matrix_type = TA::data_type();
    let eigenvalue_type = TW::data_type();
    xsyevd_raw_buffer_size(
        ctx,
        params,
        mode,
        fill_mode,
        n,
        matrix_type,
        a,
        lda,
        eigenvalue_type,
        w,
        matrix_type,
    )
}
/// Runs the `xsyevd` solve, deriving the data types from `TA` and `TW`.
///
/// The matrix element type is also used as the compute type.
pub fn xsyevd<TA: DataTypeLike, TW: DataTypeLike>(
    ctx: &Context,
    params: &Params,
    mode: EigenMode,
    fill_mode: FillMode,
    n: usize,
    a: &mut DeviceMemory<TA>,
    lda: usize,
    w: &mut DeviceMemory<TW>,
    workspace: ByteWorkspaceMut<'_>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
    let matrix_type = TA::data_type();
    let eigenvalue_type = TW::data_type();
    xsyevd_raw(
        ctx,
        params,
        mode,
        fill_mode,
        n,
        matrix_type,
        a,
        lda,
        eigenvalue_type,
        w,
        matrix_type,
        workspace,
        dev_info,
    )
}
/// Queries the `xsyevdx` workspace sizes for the given eigenvalue `selection`.
///
/// The selection is split by `selection_parts`; data types come from `TA` and `TW`,
/// with the matrix element type also used as the compute type.
pub fn xsyevdx_buffer_size<TA: DataTypeLike, TR: Copy + Default, TW: DataTypeLike>(
    ctx: &Context,
    params: &Params,
    mode: EigenMode,
    fill_mode: FillMode,
    selection: EigenSelection<TR>,
    n: usize,
    a: &DeviceMemory<TA>,
    lda: usize,
    w: &DeviceMemory<TW>,
) -> Result<SelectionWorkspaceSizes> {
    let (range, value_bounds, index_bounds) = selection_parts(selection);
    let matrix_type = TA::data_type();
    let eigenvalue_type = TW::data_type();
    xsyevdx_raw_buffer_size(
        ctx,
        params,
        mode,
        range,
        fill_mode,
        n,
        matrix_type,
        a,
        lda,
        value_bounds,
        index_bounds,
        eigenvalue_type,
        w,
        matrix_type,
    )
}
/// Runs the `xsyevdx` solve for the given eigenvalue `selection` and returns the count
/// reported by `xsyevdx_raw`.
///
/// The selection is split by `selection_parts`; data types come from `TA` and `TW`,
/// with the matrix element type also used as the compute type.
pub fn xsyevdx<TA: DataTypeLike, TR: Copy + Default, TW: DataTypeLike>(
    ctx: &Context,
    params: &Params,
    mode: EigenMode,
    fill_mode: FillMode,
    selection: EigenSelection<TR>,
    n: usize,
    a: &mut DeviceMemory<TA>,
    lda: usize,
    w: &mut DeviceMemory<TW>,
    workspace: ByteWorkspaceMut<'_>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    let (range, value_bounds, index_bounds) = selection_parts(selection);
    let matrix_type = TA::data_type();
    let eigenvalue_type = TW::data_type();
    xsyevdx_raw(
        ctx,
        params,
        mode,
        range,
        fill_mode,
        n,
        matrix_type,
        a,
        lda,
        value_bounds,
        index_bounds,
        eigenvalue_type,
        w,
        matrix_type,
        workspace,
        dev_info,
    )
}
pub fn xsyev_batched_buffer_size<TA: DataTypeLike, TW: DataTypeLike>(
ctx: &Context,
params: &Params,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: MatrixRef<'_, TA>,
w: &DeviceMemory<TW>,
batch_count: usize,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
validate_xsyev_batched_buffers(
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
w.byte_len(),
TW::data_type(),
batch_count,
)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXsyevBatched_bufferSize(
ctx.as_raw(),
params.as_raw(),
mode.into(),
fill_mode.into(),
to_i64(n, "n")?,
TA::data_type().into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
TW::data_type().into(),
w.as_ptr().cast(),
TA::data_type().into(),
&raw mut device_bytes,
&raw mut host_bytes,
to_i64(batch_count, "batch_count")?,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xsyev_batched<TA: DataTypeLike, TW: DataTypeLike>(
ctx: &Context,
params: &Params,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: MatrixMut<'_, TA>,
w: &mut DeviceMemory<TW>,
batch_count: usize,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_xsyev_batched_buffers(
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
w.byte_len(),
TW::data_type(),
batch_count,
)?;
require_info_buffer_len(dev_info, batch_count)?;
let workspace_sizes =
xsyev_batched_buffer_size(ctx, params, mode, fill_mode, n, a.as_ref(), w, batch_count)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXsyevBatched(
ctx.as_raw(),
params.as_raw(),
mode.into(),
fill_mode.into(),
to_i64(n, "n")?,
TA::data_type().into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
TW::data_type().into(),
w.as_mut_ptr().cast(),
TA::data_type().into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
to_i64(batch_count, "batch_count")?,
))?;
}
Ok(())
}
pub fn xgeev_buffer_size<TA: DataTypeLike, TW: DataTypeLike, TV: DataTypeLike>(
ctx: &Context,
params: &Params,
n: usize,
a: MatrixRef<'_, TA>,
eigenvalues: &DeviceMemory<TW>,
right_vectors: Option<MatrixRef<'_, TV>>,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
validate_xgeev_inputs(
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
eigenvalues.byte_len(),
TW::data_type(),
matrix_ref_parts(right_vectors),
TV::data_type(),
)?;
let (vr_ptr, ldvr) = optional_xgeev_matrix_ptr(matrix_ref_parts(right_vectors))?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXgeev_bufferSize(
ctx.as_raw(),
params.as_raw(),
EigenMode::NoVector.into(),
if right_vectors.is_some() {
EigenMode::Vector
} else {
EigenMode::NoVector
}
.into(),
to_i64(n, "n")?,
TA::data_type().into(),
a.data.as_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
TW::data_type().into(),
eigenvalues.as_ptr().cast(),
TA::data_type().into(),
ptr::null(),
1,
TV::data_type().into(),
vr_ptr.cast(),
ldvr,
TA::data_type().into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
pub fn xgeev<TA: DataTypeLike, TW: DataTypeLike, TV: DataTypeLike>(
ctx: &Context,
params: &Params,
n: usize,
a: MatrixMut<'_, TA>,
eigenvalues: &mut DeviceMemory<TW>,
right_vectors: Option<MatrixMut<'_, TV>>,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_xgeev_inputs(
n,
a.data.byte_len(),
a.leading_dimension,
TA::data_type(),
eigenvalues.byte_len(),
TW::data_type(),
matrix_mut_ref_parts(right_vectors.as_ref()),
TV::data_type(),
)?;
require_info_buffer(dev_info)?;
let workspace_sizes = xgeev_buffer_size(
ctx,
params,
n,
a.as_ref(),
eigenvalues,
matrix_mut_ref_option(right_vectors.as_ref()),
)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
let (vr_ptr, ldvr) = optional_xgeev_matrix_mut_ptr(matrix_mut_parts(right_vectors))?;
unsafe {
try_ffi!(sys::cusolverDnXgeev(
ctx.as_raw(),
params.as_raw(),
EigenMode::NoVector.into(),
if vr_ptr.is_null() {
EigenMode::NoVector
} else {
EigenMode::Vector
}
.into(),
to_i64(n, "n")?,
TA::data_type().into(),
a.data.as_mut_ptr().cast(),
to_i64(a.leading_dimension, "lda")?,
TW::data_type().into(),
eigenvalues.as_mut_ptr().cast(),
TA::data_type().into(),
ptr::null_mut(),
1,
TV::data_type().into(),
vr_ptr.cast(),
ldvr,
TA::data_type().into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
fn xsyevd_raw_buffer_size<TA, TW>(
ctx: &Context,
params: &Params,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a_type: DataType,
a: &DeviceMemory<TA>,
lda: usize,
w_type: DataType,
w: &DeviceMemory<TW>,
compute_type: DataType,
) -> Result<WorkspaceSizes> {
ctx.bind()?;
validate_xsyevd_buffers(n, a.byte_len(), lda, a_type, w.byte_len(), w_type)?;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXsyevd_bufferSize(
ctx.as_raw(),
params.as_raw(),
mode.into(),
fill_mode.into(),
to_i64(n, "n")?,
a_type.into(),
a.as_ptr().cast(),
to_i64(lda, "lda")?,
w_type.into(),
w.as_ptr().cast(),
compute_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(WorkspaceSizes::new(
device_bytes as usize,
host_bytes as usize,
))
}
fn xsyevd_raw<TA, TW>(
ctx: &Context,
params: &Params,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a_type: DataType,
a: &mut DeviceMemory<TA>,
lda: usize,
w_type: DataType,
w: &mut DeviceMemory<TW>,
compute_type: DataType,
workspace: ByteWorkspaceMut<'_>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_xsyevd_buffers(n, a.byte_len(), lda, a_type, w.byte_len(), w_type)?;
require_info_buffer(dev_info)?;
let workspace_sizes = xsyevd_raw_buffer_size(
ctx,
params,
mode,
fill_mode,
n,
a_type,
a,
lda,
w_type,
w,
compute_type,
)?;
require_workspace_bytes(workspace.device.byte_len(), workspace_sizes.device_bytes)?;
require_host_workspace(workspace.host.len(), workspace_sizes.host_bytes)?;
unsafe {
try_ffi!(sys::cusolverDnXsyevd(
ctx.as_raw(),
params.as_raw(),
mode.into(),
fill_mode.into(),
to_i64(n, "n")?,
a_type.into(),
a.as_mut_ptr().cast(),
to_i64(lda, "lda")?,
w_type.into(),
w.as_mut_ptr().cast(),
compute_type.into(),
workspace.device.as_mut_ptr().cast(),
workspace_sizes.device_bytes as _,
workspace.host.as_mut_ptr().cast(),
workspace_sizes.host_bytes as _,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
fn xsyevdx_raw_buffer_size<TA, TR, TW>(
ctx: &Context,
params: &Params,
mode: EigenMode,
range: EigenRange,
fill_mode: FillMode,
n: usize,
a_type: DataType,
a: &DeviceMemory<TA>,
lda: usize,
value_range: Option<(TR, TR)>,
index_range: Option<(usize, usize)>,
w_type: DataType,
w: &DeviceMemory<TW>,
compute_type: DataType,
) -> Result<SelectionWorkspaceSizes>
where
TR: Copy + Default,
{
ctx.bind()?;
validate_xsyevd_buffers(n, a.byte_len(), lda, a_type, w.byte_len(), w_type)?;
validate_xsyevdx_value_type::<TR>(w_type)?;
let (mut vl, mut vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
let mut meig = 0;
let mut device_bytes = 0;
let mut host_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnXsyevdx_bufferSize(
ctx.as_raw(),
params.as_raw(),
mode.into(),
range.into(),
fill_mode.into(),
to_i64(n, "n")?,
a_type.into(),
a.as_ptr().cast(),
to_i64(lda, "lda")?,
(&raw mut vl).cast(),
(&raw mut vu).cast(),
to_i64(il, "il")?,
to_i64(iu, "iu")?,
&raw mut meig,
w_type.into(),
w.as_ptr().cast(),
compute_type.into(),
&raw mut device_bytes,
&raw mut host_bytes,
))?;
}
Ok(SelectionWorkspaceSizes::new(
to_usize(meig, "meig")?,
device_bytes as usize,
host_bytes as usize,
))
}
/// Runs `cusolverDnXsyevdx` with explicitly supplied data types.
///
/// Before the call it binds `ctx`, validates the buffers, the bound type `TR`, `dev_info`
/// and the requested range, and checks `workspace` against a fresh buffer-size query.
///
/// # Returns
///
/// The `meig` value written by the solve itself, converted to `usize`. The count from
/// the buffer-size query is deliberately not used as the result.
///
/// # Errors
///
/// Propagates validation, conversion and FFI errors.
fn xsyevdx_raw<TA, TR, TW>(
    ctx: &Context,
    params: &Params,
    mode: EigenMode,
    range: EigenRange,
    fill_mode: FillMode,
    n: usize,
    a_type: DataType,
    a: &mut DeviceMemory<TA>,
    lda: usize,
    value_range: Option<(TR, TR)>,
    index_range: Option<(usize, usize)>,
    w_type: DataType,
    w: &mut DeviceMemory<TW>,
    compute_type: DataType,
    workspace: ByteWorkspaceMut<'_>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize>
where
    TR: Copy + Default,
{
    ctx.bind()?;
    validate_xsyevd_buffers(n, a.byte_len(), lda, a_type, w.byte_len(), w_type)?;
    validate_xsyevdx_value_type::<TR>(w_type)?;
    require_info_buffer(dev_info)?;
    let workspace_sizes = xsyevdx_raw_buffer_size(
        ctx,
        params,
        mode,
        range,
        fill_mode,
        n,
        a_type,
        a,
        lda,
        value_range,
        index_range,
        w_type,
        w,
        compute_type,
    )?;
    require_workspace_bytes(
        workspace.device.byte_len(),
        workspace_sizes.workspace.device_bytes,
    )?;
    require_host_workspace(workspace.host.len(), workspace_sizes.workspace.host_bytes)?;
    let (mut vl, mut vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
    let mut meig_raw = 0;
    unsafe {
        try_ffi!(sys::cusolverDnXsyevdx(
            ctx.as_raw(),
            params.as_raw(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i64(n, "n")?,
            a_type.into(),
            a.as_mut_ptr().cast(),
            to_i64(lda, "lda")?,
            // The value bounds are passed as pointers to host locals.
            (&raw mut vl).cast(),
            (&raw mut vu).cast(),
            to_i64(il, "il")?,
            to_i64(iu, "iu")?,
            &raw mut meig_raw,
            w_type.into(),
            w.as_mut_ptr().cast(),
            compute_type.into(),
            workspace.device.as_mut_ptr().cast(),
            workspace_sizes.workspace.device_bytes as _,
            workspace.host.as_mut_ptr().cast(),
            workspace_sizes.workspace.host_bytes as _,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    // Report the count from the solve, not the pre-solve estimate: for a value-range
    // selection the number of eigenvalues is only known once they have been computed.
    // The two counts may differ, so asserting that they match would be a spurious
    // debug-build panic.
    to_usize(meig_raw, "meig")
}
pub fn ssyevd_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
w: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsyevd_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsyevd_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
w: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsyevd_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cheevd_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
w: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCheevd_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zheevd_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
w: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZheevd_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssyevd(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = ssyevd_buffer_size(ctx, mode, fill_mode, n, a, lda, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsyevd(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dsyevd(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = dsyevd_buffer_size(ctx, mode, fill_mode, n, a, lda, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsyevd(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn cheevd(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = cheevd_buffer_size(ctx, mode, fill_mode, n, a, lda, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCheevd(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zheevd(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = zheevd_buffer_size(ctx, mode, fill_mode, n, a, lda, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZheevd(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn ssyevj_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
w: &DeviceMemory<f32>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsyevj_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsyevj_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
w: &DeviceMemory<f64>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsyevj_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn cheevj_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
w: &DeviceMemory<f32>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCheevj_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn zheevj_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
w: &DeviceMemory<f64>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZheevj_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssyevj(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = ssyevj_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsyevj(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn dsyevj(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = dsyevj_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsyevj(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn cheevj(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = cheevj_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCheevj(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn zheevj(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
require_info_buffer(dev_info)?;
let lwork = zheevj_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZheevj(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn ssyevj_batched_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
w: &DeviceMemory<f32>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<usize> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsyevjBatched_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsyevj_batched_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
w: &DeviceMemory<f64>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<usize> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsyevjBatched_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn cheevj_batched_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
w: &DeviceMemory<f32>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<usize> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCheevjBatched_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zheevj_batched_buffer_size(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
w: &DeviceMemory<f64>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<usize> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZheevjBatched_bufferSize(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssyevj_batched(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<()> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
require_info_buffer_len(dev_info, batch_count)?;
let lwork =
ssyevj_batched_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params, batch_count)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsyevjBatched(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
Ok(())
}
pub fn dsyevj_batched(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<()> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
require_info_buffer_len(dev_info, batch_count)?;
let lwork =
dsyevj_batched_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params, batch_count)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsyevjBatched(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
Ok(())
}
pub fn cheevj_batched(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<()> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
require_info_buffer_len(dev_info, batch_count)?;
let lwork =
cheevj_batched_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params, batch_count)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnCheevjBatched(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
Ok(())
}
pub fn zheevj_batched(
ctx: &Context,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
batch_count: usize,
) -> Result<()> {
ctx.bind()?;
validate_syevj_batched_buffers(n, a.len(), lda, w.len(), batch_count)?;
require_info_buffer_len(dev_info, batch_count)?;
let lwork =
zheevj_batched_buffer_size(ctx, mode, fill_mode, n, a, lda, w, params, batch_count)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZheevjBatched(
ctx.as_raw(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
to_i32(batch_count, "batch_count")?,
))?;
}
Ok(())
}
pub fn ssygvj_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
b: &DeviceMemory<f32>,
ldb: usize,
w: &DeviceMemory<f32>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsygvj_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsygvj_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
b: &DeviceMemory<f64>,
ldb: usize,
w: &DeviceMemory<f64>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsygvj_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn chegvj_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
b: &DeviceMemory<Complex32>,
ldb: usize,
w: &DeviceMemory<f32>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnChegvj_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn zhegvj_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
b: &DeviceMemory<Complex64>,
ldb: usize,
w: &DeviceMemory<f64>,
params: &SyevjInfo,
) -> Result<usize> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZhegvj_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
params.as_raw(),
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssygvj(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
b: &mut DeviceMemory<f32>,
ldb: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = ssygvj_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsygvj(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn dsygvj(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
b: &mut DeviceMemory<f64>,
ldb: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = dsygvj_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsygvj(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn chegvj(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
b: &mut DeviceMemory<Complex32>,
ldb: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = chegvj_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnChegvj(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn zhegvj(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
b: &mut DeviceMemory<Complex64>,
ldb: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
params: &SyevjInfo,
) -> Result<()> {
ctx.bind()?;
validate_sygvj_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = zhegvj_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w, params)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZhegvj(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
params.as_raw(),
))?;
}
Ok(())
}
pub fn ssygvd_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
b: &DeviceMemory<f32>,
ldb: usize,
w: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsygvd_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn dsygvd_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
b: &DeviceMemory<f64>,
ldb: usize,
w: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsygvd_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn chegvd_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
b: &DeviceMemory<Complex32>,
ldb: usize,
w: &DeviceMemory<f32>,
) -> Result<usize> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnChegvd_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn zhegvd_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
b: &DeviceMemory<Complex64>,
ldb: usize,
w: &DeviceMemory<f64>,
) -> Result<usize> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZhegvd_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
to_usize(lwork, "lwork")
}
pub fn ssygvd(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f32>,
lda: usize,
b: &mut DeviceMemory<f32>,
ldb: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<f32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = ssygvd_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnSsygvd(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn dsygvd(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<f64>,
lda: usize,
b: &mut DeviceMemory<f64>,
ldb: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<f64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = dsygvd_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnDsygvd(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn chegvd(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex32>,
lda: usize,
b: &mut DeviceMemory<Complex32>,
ldb: usize,
w: &mut DeviceMemory<f32>,
workspace: &mut DeviceMemory<Complex32>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = chegvd_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnChegvd(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn zhegvd(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
n: usize,
a: &mut DeviceMemory<Complex64>,
lda: usize,
b: &mut DeviceMemory<Complex64>,
ldb: usize,
w: &mut DeviceMemory<f64>,
workspace: &mut DeviceMemory<Complex64>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<()> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
require_info_buffer(dev_info)?;
let lwork = zhegvd_buffer_size(ctx, eig_type, mode, fill_mode, n, a, lda, b, ldb, w)?;
require_workspace(workspace.len(), lwork)?;
unsafe {
try_ffi!(sys::cusolverDnZhegvd(
ctx.as_raw(),
eig_type.into(),
mode.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_mut_ptr().cast(),
to_i32(lda, "lda")?,
b.as_mut_ptr().cast(),
to_i32(ldb, "ldb")?,
w.as_mut_ptr().cast(),
workspace.as_mut_ptr().cast(),
to_i32(lwork, "lwork")?,
dev_info.as_mut_ptr().cast(),
))?;
}
Ok(())
}
pub fn ssygvdx_selected_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
selection: EigenSelection<f32>,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
b: &DeviceMemory<f32>,
ldb: usize,
w: &DeviceMemory<f32>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let (range, value_range, index_range) = selection_parts(selection);
let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsygvdx_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
vl,
vu,
to_i32(il, "il")?,
to_i32(iu, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
pub fn dsygvdx_selected_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
selection: EigenSelection<f64>,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
b: &DeviceMemory<f64>,
ldb: usize,
w: &DeviceMemory<f64>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let (range, value_range, index_range) = selection_parts(selection);
let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsygvdx_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
vl,
vu,
to_i32(il, "il")?,
to_i32(iu, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
pub fn chegvdx_selected_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
selection: EigenSelection<f32>,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
b: &DeviceMemory<Complex32>,
ldb: usize,
w: &DeviceMemory<f32>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let (range, value_range, index_range) = selection_parts(selection);
let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnChegvdx_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
vl,
vu,
to_i32(il, "il")?,
to_i32(iu, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
pub fn zhegvdx_selected_buffer_size(
ctx: &Context,
eig_type: EigenType,
mode: EigenMode,
fill_mode: FillMode,
selection: EigenSelection<f64>,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
b: &DeviceMemory<Complex64>,
ldb: usize,
w: &DeviceMemory<f64>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
let (range, value_range, index_range) = selection_parts(selection);
let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZhegvdx_bufferSize(
ctx.as_raw(),
eig_type.into(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
b.as_ptr().cast(),
to_i32(ldb, "ldb")?,
vl,
vu,
to_i32(il, "il")?,
to_i32(iu, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
/// Thin wrapper over cuSOLVER's `cusolverDnSsygvdx` for the given eigenvalue
/// `selection`.
///
/// Binds the context, checks `a`, `b` and `w` with `validate_sygvd_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`ssygvdx_selected_buffer_size`].
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for by-value selections the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn ssygvdx_selected(
    ctx: &Context,
    eig_type: EigenType,
    mode: EigenMode,
    fill_mode: FillMode,
    selection: EigenSelection<f32>,
    n: usize,
    a: &mut DeviceMemory<f32>,
    lda: usize,
    b: &mut DeviceMemory<f32>,
    ldb: usize,
    w: &mut DeviceMemory<f32>,
    workspace: &mut DeviceMemory<f32>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
    require_info_buffer(dev_info)?;
    let (range, value_range, index_range) = selection_parts(selection);
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = ssygvdx_selected_buffer_size(
        ctx, eig_type, mode, fill_mode, selection, n, a, lda, b, ldb, w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnSsygvdx(
            ctx.as_raw(),
            eig_type.into(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            b.as_mut_ptr().cast(),
            to_i32(ldb, "ldb")?,
            vl,
            vu,
            to_i32(il, "il")?,
            to_i32(iu, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
/// Thin wrapper over cuSOLVER's `cusolverDnDsygvdx` for the given eigenvalue
/// `selection`.
///
/// Binds the context, checks `a`, `b` and `w` with `validate_sygvd_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`dsygvdx_selected_buffer_size`].
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for by-value selections the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn dsygvdx_selected(
    ctx: &Context,
    eig_type: EigenType,
    mode: EigenMode,
    fill_mode: FillMode,
    selection: EigenSelection<f64>,
    n: usize,
    a: &mut DeviceMemory<f64>,
    lda: usize,
    b: &mut DeviceMemory<f64>,
    ldb: usize,
    w: &mut DeviceMemory<f64>,
    workspace: &mut DeviceMemory<f64>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
    require_info_buffer(dev_info)?;
    let (range, value_range, index_range) = selection_parts(selection);
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = dsygvdx_selected_buffer_size(
        ctx, eig_type, mode, fill_mode, selection, n, a, lda, b, ldb, w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnDsygvdx(
            ctx.as_raw(),
            eig_type.into(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            b.as_mut_ptr().cast(),
            to_i32(ldb, "ldb")?,
            vl,
            vu,
            to_i32(il, "il")?,
            to_i32(iu, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
/// Thin wrapper over cuSOLVER's `cusolverDnChegvdx` for the given eigenvalue
/// `selection`.
///
/// Binds the context, checks `a`, `b` and `w` with `validate_sygvd_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`chegvdx_selected_buffer_size`].
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for by-value selections the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn chegvdx_selected(
    ctx: &Context,
    eig_type: EigenType,
    mode: EigenMode,
    fill_mode: FillMode,
    selection: EigenSelection<f32>,
    n: usize,
    a: &mut DeviceMemory<Complex32>,
    lda: usize,
    b: &mut DeviceMemory<Complex32>,
    ldb: usize,
    w: &mut DeviceMemory<f32>,
    workspace: &mut DeviceMemory<Complex32>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
    require_info_buffer(dev_info)?;
    let (range, value_range, index_range) = selection_parts(selection);
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = chegvdx_selected_buffer_size(
        ctx, eig_type, mode, fill_mode, selection, n, a, lda, b, ldb, w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnChegvdx(
            ctx.as_raw(),
            eig_type.into(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            b.as_mut_ptr().cast(),
            to_i32(ldb, "ldb")?,
            vl,
            vu,
            to_i32(il, "il")?,
            to_i32(iu, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
/// Thin wrapper over cuSOLVER's `cusolverDnZhegvdx` for the given eigenvalue
/// `selection`.
///
/// Binds the context, checks `a`, `b` and `w` with `validate_sygvd_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`zhegvdx_selected_buffer_size`].
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for by-value selections the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn zhegvdx_selected(
    ctx: &Context,
    eig_type: EigenType,
    mode: EigenMode,
    fill_mode: FillMode,
    selection: EigenSelection<f64>,
    n: usize,
    a: &mut DeviceMemory<Complex64>,
    lda: usize,
    b: &mut DeviceMemory<Complex64>,
    ldb: usize,
    w: &mut DeviceMemory<f64>,
    workspace: &mut DeviceMemory<Complex64>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_sygvd_buffers(n, a.len(), lda, b.len(), ldb, w.len())?;
    require_info_buffer(dev_info)?;
    let (range, value_range, index_range) = selection_parts(selection);
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = zhegvdx_selected_buffer_size(
        ctx, eig_type, mode, fill_mode, selection, n, a, lda, b, ldb, w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let (vl, vu, il, iu) = validate_xsyevdx_range(range, n, value_range, index_range)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnZhegvdx(
            ctx.as_raw(),
            eig_type.into(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            b.as_mut_ptr().cast(),
            to_i32(ldb, "ldb")?,
            vl,
            vu,
            to_i32(il, "il")?,
            to_i32(iu, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
pub fn ssyevdx_buffer_size(
ctx: &Context,
mode: EigenMode,
range: EigenRange,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f32>,
lda: usize,
value_range: (f32, f32),
index_range: (usize, usize),
w: &DeviceMemory<f32>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnSsyevdx_bufferSize(
ctx.as_raw(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
value_range.0,
value_range.1,
to_i32(index_range.0, "il")?,
to_i32(index_range.1, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
pub fn dsyevdx_buffer_size(
ctx: &Context,
mode: EigenMode,
range: EigenRange,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<f64>,
lda: usize,
value_range: (f64, f64),
index_range: (usize, usize),
w: &DeviceMemory<f64>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnDsyevdx_bufferSize(
ctx.as_raw(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
value_range.0,
value_range.1,
to_i32(index_range.0, "il")?,
to_i32(index_range.1, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
pub fn cheevdx_buffer_size(
ctx: &Context,
mode: EigenMode,
range: EigenRange,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex32>,
lda: usize,
value_range: (f32, f32),
index_range: (usize, usize),
w: &DeviceMemory<f32>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnCheevdx_bufferSize(
ctx.as_raw(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
value_range.0,
value_range.1,
to_i32(index_range.0, "il")?,
to_i32(index_range.1, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
pub fn zheevdx_buffer_size(
ctx: &Context,
mode: EigenMode,
range: EigenRange,
fill_mode: FillMode,
n: usize,
a: &DeviceMemory<Complex64>,
lda: usize,
value_range: (f64, f64),
index_range: (usize, usize),
w: &DeviceMemory<f64>,
) -> Result<(usize, usize)> {
ctx.bind()?;
validate_syev_buffers(n, a.len(), lda, w.len())?;
let mut meig = 0;
let mut lwork = 0;
unsafe {
try_ffi!(sys::cusolverDnZheevdx_bufferSize(
ctx.as_raw(),
mode.into(),
range.into(),
fill_mode.into(),
to_i32(n, "n")?,
a.as_ptr().cast(),
to_i32(lda, "lda")?,
value_range.0,
value_range.1,
to_i32(index_range.0, "il")?,
to_i32(index_range.1, "iu")?,
&raw mut meig,
w.as_ptr().cast(),
&raw mut lwork,
))?;
}
Ok((to_usize(meig, "meig")?, to_usize(lwork, "lwork")?))
}
/// Thin wrapper over cuSOLVER's `cusolverDnSsyevdx` for `f32` input.
///
/// Binds the context, checks `a` and `w` with `validate_syev_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`ssyevdx_buffer_size`] for the same arguments.
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for value-range requests the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn ssyevdx(
    ctx: &Context,
    mode: EigenMode,
    range: EigenRange,
    fill_mode: FillMode,
    n: usize,
    a: &mut DeviceMemory<f32>,
    lda: usize,
    value_range: (f32, f32),
    index_range: (usize, usize),
    w: &mut DeviceMemory<f32>,
    workspace: &mut DeviceMemory<f32>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_syev_buffers(n, a.len(), lda, w.len())?;
    require_info_buffer(dev_info)?;
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = ssyevdx_buffer_size(
        ctx,
        mode,
        range,
        fill_mode,
        n,
        a,
        lda,
        value_range,
        index_range,
        w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnSsyevdx(
            ctx.as_raw(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            value_range.0,
            value_range.1,
            to_i32(index_range.0, "il")?,
            to_i32(index_range.1, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
/// Thin wrapper over cuSOLVER's `cusolverDnDsyevdx` for `f64` input.
///
/// Binds the context, checks `a` and `w` with `validate_syev_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`dsyevdx_buffer_size`] for the same arguments.
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for value-range requests the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn dsyevdx(
    ctx: &Context,
    mode: EigenMode,
    range: EigenRange,
    fill_mode: FillMode,
    n: usize,
    a: &mut DeviceMemory<f64>,
    lda: usize,
    value_range: (f64, f64),
    index_range: (usize, usize),
    w: &mut DeviceMemory<f64>,
    workspace: &mut DeviceMemory<f64>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_syev_buffers(n, a.len(), lda, w.len())?;
    require_info_buffer(dev_info)?;
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = dsyevdx_buffer_size(
        ctx,
        mode,
        range,
        fill_mode,
        n,
        a,
        lda,
        value_range,
        index_range,
        w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnDsyevdx(
            ctx.as_raw(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            value_range.0,
            value_range.1,
            to_i32(index_range.0, "il")?,
            to_i32(index_range.1, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
/// Thin wrapper over cuSOLVER's `cusolverDnCheevdx` for `Complex32` input.
///
/// Binds the context, checks `a` and `w` with `validate_syev_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`cheevdx_buffer_size`] for the same arguments.
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for value-range requests the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn cheevdx(
    ctx: &Context,
    mode: EigenMode,
    range: EigenRange,
    fill_mode: FillMode,
    n: usize,
    a: &mut DeviceMemory<Complex32>,
    lda: usize,
    value_range: (f32, f32),
    index_range: (usize, usize),
    w: &mut DeviceMemory<f32>,
    workspace: &mut DeviceMemory<Complex32>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_syev_buffers(n, a.len(), lda, w.len())?;
    require_info_buffer(dev_info)?;
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = cheevdx_buffer_size(
        ctx,
        mode,
        range,
        fill_mode,
        n,
        a,
        lda,
        value_range,
        index_range,
        w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnCheevdx(
            ctx.as_raw(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            value_range.0,
            value_range.1,
            to_i32(index_range.0, "il")?,
            to_i32(index_range.1, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
/// Thin wrapper over cuSOLVER's `cusolverDnZheevdx` for `Complex64` input.
///
/// Binds the context, checks `a` and `w` with `validate_syev_buffers`,
/// requires a non-empty `dev_info`, and requires `workspace` to hold at least
/// the length reported by [`zheevdx_buffer_size`] for the same arguments.
///
/// Returns the eigenvalue count (`meig`) written by the solve itself. The
/// count from the workspace query is not used as the result, so debug and
/// release builds report the same value.
/// NOTE(review): for value-range requests the count presumably depends on the
/// computed spectrum — confirm against the cuSOLVER reference.
///
/// # Errors
///
/// Propagates any binding, validation, integer-conversion, or cuSOLVER failure.
pub fn zheevdx(
    ctx: &Context,
    mode: EigenMode,
    range: EigenRange,
    fill_mode: FillMode,
    n: usize,
    a: &mut DeviceMemory<Complex64>,
    lda: usize,
    value_range: (f64, f64),
    index_range: (usize, usize),
    w: &mut DeviceMemory<f64>,
    workspace: &mut DeviceMemory<Complex64>,
    dev_info: &mut DeviceMemory<i32>,
) -> Result<usize> {
    ctx.bind()?;
    validate_syev_buffers(n, a.len(), lda, w.len())?;
    require_info_buffer(dev_info)?;
    // Only the workspace length is taken from the query; the eigenvalue count
    // is read back from the solve below.
    let (_, lwork) = zheevdx_buffer_size(
        ctx,
        mode,
        range,
        fill_mode,
        n,
        a,
        lda,
        value_range,
        index_range,
        w,
    )?;
    require_workspace(workspace.len(), lwork)?;
    let mut meig_raw = 0;
    // SAFETY: all device pointers come from `DeviceMemory` borrows held for
    // the whole call with lengths checked above; `meig_raw` is a local that
    // outlives the call.
    unsafe {
        try_ffi!(sys::cusolverDnZheevdx(
            ctx.as_raw(),
            mode.into(),
            range.into(),
            fill_mode.into(),
            to_i32(n, "n")?,
            a.as_mut_ptr().cast(),
            to_i32(lda, "lda")?,
            value_range.0,
            value_range.1,
            to_i32(index_range.0, "il")?,
            to_i32(index_range.1, "iu")?,
            &raw mut meig_raw,
            w.as_mut_ptr().cast(),
            workspace.as_mut_ptr().cast(),
            to_i32(lwork, "lwork")?,
            dev_info.as_mut_ptr().cast(),
        ))?;
    }
    to_usize(meig_raw, "meig")
}
fn validate_syev_buffers(n: usize, a_len: usize, lda: usize, w_len: usize) -> Result<()> {
validate_square_matrix(n, a_len, lda)?;
if w_len < n {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Checks an `n`-by-`n` matrix by delegating to `validate_matrix` with equal
/// row and column counts.
fn validate_square_matrix(n: usize, len: usize, lda: usize) -> Result<()> {
    let (rows, cols) = (n, n);
    validate_matrix(rows, cols, len, lda)
}
fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
if len < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn require_workspace(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
fn require_host_workspace(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
if dev_info.is_empty() {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
/// Checks the byte sizes of the matrix and eigenvalue buffers: `a` as an
/// `n`-by-`n` matrix of `a_type` via `validate_x_matrix`, then `w` as `n`
/// elements of `w_type` via `validate_x_vector`.
fn validate_xsyevd_buffers(
    n: usize,
    a_bytes: usize,
    lda: usize,
    a_type: DataType,
    w_bytes: usize,
    w_type: DataType,
) -> Result<()> {
    // The vector check only runs once the matrix check has passed.
    validate_x_matrix(n, n, a_bytes, lda, a_type)
        .and_then(|()| validate_x_vector(n, w_bytes, w_type))
}
fn validate_xsyev_batched_buffers(
n: usize,
a_bytes: usize,
lda: usize,
a_type: DataType,
w_bytes: usize,
w_type: DataType,
batch_count: usize,
) -> Result<()> {
if batch_count == 0 {
return Err(Error::InvalidMatrixShape);
}
validate_x_matrix(
n,
n.checked_mul(batch_count)
.ok_or(Error::InvalidMatrixShape)?,
a_bytes,
lda,
a_type,
)?;
validate_x_vector(
n.checked_mul(batch_count)
.ok_or(Error::InvalidVectorShape)?,
w_bytes,
w_type,
)?;
let problem_size = n
.checked_mul(lda)
.and_then(|value| value.checked_mul(batch_count))
.ok_or(Error::InvalidMatrixShape)?;
if problem_size > i32::MAX as usize {
return Err(Error::OutOfRange {
name: "batched problem size".into(),
});
}
Ok(())
}
/// Checks element counts of batched matrix and eigenvalue buffers.
///
/// The batch is checked as one `n`-by-`n * batch_count` matrix via
/// `validate_matrix`, and `w_len` must be at least `n * batch_count`.
///
/// # Errors
///
/// * `Error::InvalidMatrixShape` for a zero `batch_count` or when
///   `n * batch_count` overflows, plus whatever `validate_matrix` reports.
/// * `Error::InvalidVectorShape` when the eigenvalue buffer is too short.
fn validate_syevj_batched_buffers(
    n: usize,
    a_len: usize,
    lda: usize,
    w_len: usize,
    batch_count: usize,
) -> Result<()> {
    if batch_count == 0 {
        return Err(Error::InvalidMatrixShape);
    }
    // The stacked column count doubles as the required eigenvalue count.
    let Some(stacked) = n.checked_mul(batch_count) else {
        return Err(Error::InvalidMatrixShape);
    };
    validate_matrix(n, stacked, a_len, lda)?;
    if stacked > w_len {
        return Err(Error::InvalidVectorShape);
    }
    Ok(())
}
/// Checks that `a` and `b` can each hold an `n`-by-`n` matrix (in that order,
/// via `validate_square_matrix`) and that `w` holds at least `n` elements.
///
/// Returns `Error::InvalidVectorShape` when `w_len` is smaller than `n`.
fn validate_sygvj_buffers(
    n: usize,
    a_len: usize,
    lda: usize,
    b_len: usize,
    ldb: usize,
    w_len: usize,
) -> Result<()> {
    for (len, leading_dim) in [(a_len, lda), (b_len, ldb)] {
        validate_square_matrix(n, len, leading_dim)?;
    }
    if n > w_len {
        return Err(Error::InvalidVectorShape);
    }
    Ok(())
}
fn validate_sygvd_buffers(
n: usize,
a_len: usize,
lda: usize,
b_len: usize,
ldb: usize,
w_len: usize,
) -> Result<()> {
validate_square_matrix(n, a_len, lda)?;
validate_square_matrix(n, b_len, ldb)?;
if w_len < n {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn validate_x_matrix(
rows: usize,
cols: usize,
bytes: usize,
lda: usize,
data_type: DataType,
) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let required = lda
.checked_mul(cols)
.and_then(|count| count.checked_mul(data_type.size_in_bytes()))
.ok_or(Error::InvalidMatrixShape)?;
if bytes < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
fn validate_x_vector(len: usize, bytes: usize, data_type: DataType) -> Result<()> {
let required = len
.checked_mul(data_type.size_in_bytes())
.ok_or(Error::InvalidVectorShape)?;
if bytes < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn validate_xsyevdx_value_type<T>(w_type: DataType) -> Result<()> {
if size_of::<T>() != w_type.size_in_bytes() {
return Err(Error::InvalidEigRange);
}
Ok(())
}
/// Decomposed form of an [`EigenSelection`] as produced by `selection_parts`:
/// the range kind, the optional `(lower, upper)` value bounds, and the
/// optional `(start, end)` index bounds.
type EigenSelectionParts<T> = (EigenRange, Option<(T, T)>, Option<(usize, usize)>);
fn validate_xsyevdx_range<T>(
range: EigenRange,
n: usize,
value_range: Option<(T, T)>,
index_range: Option<(usize, usize)>,
) -> Result<(T, T, usize, usize)>
where
T: Default,
{
match range {
EigenRange::All => {
if value_range.is_some() || index_range.is_some() {
return Err(Error::InvalidEigRange);
}
Ok((T::default(), T::default(), 0, 0))
}
EigenRange::Value => {
if index_range.is_some() {
return Err(Error::InvalidEigRange);
}
let Some((vl, vu)) = value_range else {
return Err(Error::InvalidEigRange);
};
Ok((vl, vu, 0, 0))
}
EigenRange::Index => {
if value_range.is_some() {
return Err(Error::InvalidEigRange);
}
let Some((il, iu)) = index_range else {
return Err(Error::InvalidEigRange);
};
if il == 0 || iu == 0 || il > iu || iu > n {
return Err(Error::InvalidEigRange);
}
Ok((T::default(), T::default(), il, iu))
}
}
}
/// Splits an [`EigenSelection`] into its range kind, optional value bounds and optional
/// index bounds.
///
/// At most one of the two optional components is populated; `All` populates neither.
fn selection_parts<T>(selection: EigenSelection<T>) -> EigenSelectionParts<T> {
    match selection {
        EigenSelection::ByIndex {
            start: first,
            end: last,
        } => (EigenRange::Index, None, Some((first, last))),
        EigenSelection::ByValue {
            lower: low,
            upper: high,
        } => (EigenRange::Value, Some((low, high)), None),
        EigenSelection::All => (EigenRange::All, None, None),
    }
}
/// Unpacks an optional [`MatrixRef`] into its device buffer and leading dimension.
///
/// Returns `None` when no matrix was supplied.
fn matrix_ref_parts<T>(matrix: Option<MatrixRef<'_, T>>) -> Option<(&DeviceMemory<T>, usize)> {
    let view = matrix?;
    Some((view.data, view.leading_dimension))
}
/// Unpacks an optional [`MatrixMut`] into its mutable device buffer and leading dimension.
///
/// Returns `None` when no matrix was supplied.
fn matrix_mut_parts<T>(matrix: Option<MatrixMut<'_, T>>) -> Option<(&mut DeviceMemory<T>, usize)> {
    let view = matrix?;
    Some((view.data, view.leading_dimension))
}
/// Unpacks an optional borrowed [`MatrixMut`] into a shared view of its device buffer and
/// its leading dimension.
///
/// The mutable buffer reference is reborrowed as shared, so the matrix stays usable by the
/// caller afterwards. Returns `None` when no matrix was supplied.
fn matrix_mut_ref_parts<'a, T>(
    matrix: Option<&'a MatrixMut<'a, T>>,
) -> Option<(&'a DeviceMemory<T>, usize)> {
    let view = matrix?;
    Some((&*view.data, view.leading_dimension))
}
/// Converts an optional borrowed [`MatrixMut`] into an optional [`MatrixRef`] via
/// `MatrixMut::as_ref`.
///
/// Returns `None` when no matrix was supplied.
fn matrix_mut_ref_option<'a, T>(matrix: Option<&'a MatrixMut<'a, T>>) -> Option<MatrixRef<'a, T>> {
    let view = matrix?;
    Some(MatrixMut::as_ref(view))
}
/// Checks the buffers supplied to an `xgeev` call.
///
/// Checks run in this order and stop at the first failure:
///
/// 1. `a` must hold an `n` x `n` matrix of `a_type` with leading dimension `lda`.
/// 2. The eigenvalue buffer must satisfy `validate_xgeev_eigenvalues`.
/// 3. When right eigenvectors are supplied, their buffer must hold an `n` x `n` matrix of
///    `vr_type` with the given leading dimension.
///
/// # Errors
///
/// Propagates the error of the first failing check.
fn validate_xgeev_inputs<TV>(
    n: usize,
    a_bytes: usize,
    lda: usize,
    a_type: DataType,
    w_bytes: usize,
    w_type: DataType,
    right_vectors: Option<(&DeviceMemory<TV>, usize)>,
    vr_type: DataType,
) -> Result<()> {
    validate_x_matrix(n, n, a_bytes, lda, a_type)?;
    validate_xgeev_eigenvalues(n, w_bytes, a_type, w_type)?;
    match right_vectors {
        Some((vectors, ldvr)) => validate_x_matrix(n, n, vectors.byte_len(), ldvr, vr_type),
        // Nothing further to check when no eigenvectors are requested.
        None => Ok(()),
    }
}
fn validate_xgeev_eigenvalues(
n: usize,
w_bytes: usize,
a_type: DataType,
w_type: DataType,
) -> Result<()> {
let expected_len = match (a_type, w_type) {
(DataType::F32, DataType::F32) | (DataType::F64, DataType::F64) => {
n.checked_mul(2).ok_or(Error::InvalidVectorShape)?
}
(DataType::F32, DataType::ComplexF32)
| (DataType::F64, DataType::ComplexF64)
| (DataType::ComplexF32, DataType::ComplexF32)
| (DataType::ComplexF64, DataType::ComplexF64) => n,
_ => return Err(Error::InvalidVectorShape),
};
validate_x_vector(expected_len, w_bytes, w_type)
}
/// Turns an optional device matrix into a raw const pointer and an `i64` leading dimension.
///
/// When no matrix is supplied, the result is a null pointer with a leading dimension of 1.
///
/// # Errors
///
/// Propagates the error from `to_i64` when the leading dimension cannot be converted.
fn optional_xgeev_matrix_ptr<T: DataTypeLike>(
    matrix: Option<(&DeviceMemory<T>, usize)>,
) -> Result<(*const T, i64)> {
    let Some((buffer, ld)) = matrix else {
        return Ok((ptr::null(), 1));
    };
    let pointer: *const T = buffer.as_ptr().cast();
    let leading_dimension = to_i64(ld, "ldvr")?;
    Ok((pointer, leading_dimension))
}
/// Turns an optional mutable device matrix into a raw mutable pointer and an `i64` leading
/// dimension.
///
/// When no matrix is supplied, the result is a null pointer with a leading dimension of 1.
///
/// # Errors
///
/// Propagates the error from `to_i64` when the leading dimension cannot be converted.
fn optional_xgeev_matrix_mut_ptr<T: DataTypeLike>(
    matrix: Option<(&mut DeviceMemory<T>, usize)>,
) -> Result<(*mut T, i64)> {
    let Some((buffer, ld)) = matrix else {
        return Ok((ptr::null_mut(), 1));
    };
    let pointer: *mut T = buffer.as_mut_ptr().cast();
    let leading_dimension = to_i64(ld, "ldvr")?;
    Ok((pointer, leading_dimension))
}
fn require_info_buffer_len(dev_info: &DeviceMemory<i32>, required: usize) -> Result<()> {
if dev_info.len() < required {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::{EigenSelection, selection_parts, validate_xsyevdx_range};
    use crate::error::Error;
    use crate::types::EigenRange;

    /// `All` carries no bounds, so both optional components must be empty.
    #[test]
    fn eigen_selection_all_maps_cleanly() {
        let (range, values, indices) = selection_parts::<f32>(EigenSelection::All);
        assert_eq!(range, EigenRange::All);
        assert_eq!(values, None);
        assert_eq!(indices, None);
    }

    /// `ByValue` must forward its bounds unchanged and leave the index component empty.
    #[test]
    fn eigen_selection_value_maps_cleanly() {
        let (range, values, indices) = selection_parts::<f64>(EigenSelection::ByValue {
            lower: -1.5,
            upper: 2.5,
        });
        assert_eq!(range, EigenRange::Value);
        assert_eq!(values, Some((-1.5, 2.5)));
        assert_eq!(indices, None);
    }

    /// `ByIndex` must forward its bounds unchanged and leave the value component empty.
    #[test]
    fn eigen_selection_index_maps_cleanly() {
        let (range, values, indices) =
            selection_parts::<f64>(EigenSelection::ByIndex { start: 2, end: 5 });
        assert_eq!(range, EigenRange::Index);
        assert_eq!(values, None);
        assert_eq!(indices, Some((2, 5)));
    }

    /// Index bounds inside `1..=n` are accepted and passed through.
    #[test]
    fn index_range_within_bounds_is_accepted() {
        let checked = validate_xsyevdx_range::<f64>(EigenRange::Index, 5, None, Some((2, 5)));
        assert!(matches!(checked, Ok((_, _, 2, 5))));
    }

    /// A zero start, an inverted range, and an end past `n` are all rejected.
    #[test]
    fn index_range_out_of_bounds_is_rejected() {
        for bounds in [(0, 2), (3, 2), (1, 6)] {
            let checked = validate_xsyevdx_range::<f64>(EigenRange::Index, 5, None, Some(bounds));
            assert!(matches!(checked, Err(Error::InvalidEigRange)));
        }
    }

    /// Bounds that do not belong to the requested range kind are rejected.
    #[test]
    fn mismatched_bounds_are_rejected() {
        let checked = validate_xsyevdx_range::<f64>(EigenRange::All, 5, None, Some((1, 2)));
        assert!(matches!(checked, Err(Error::InvalidEigRange)));
        let checked = validate_xsyevdx_range::<f64>(EigenRange::Value, 5, None, None);
        assert!(matches!(checked, Err(Error::InvalidEigRange)));
    }
}