use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
cudaDataType, cusolverDnCreate, cusolverDnCreateParams, cusolverDnDestroy,
cusolverDnDestroyParams, cusolverDnHandle_t, cusolverDnParams_t, cusolverDnSetStream,
cusolverDnXgeev, cusolverDnXgeev_bufferSize, CUDA_C_32F, CUDA_C_64F, CUDA_R_32F, CUDA_R_64F,
CUSOLVER_EIG_MODE_NOVECTOR, CUSOLVER_EIG_MODE_VECTOR,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, LinalgKind, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, Workspace,
};
use super::cholesky::unpack_workspace;
/// Problem descriptor for an eigen-decomposition plan backed by `cusolverDnXgeev`.
///
/// The descriptor is validated by `EigPlan::select` and then stored by value in the plan.
#[derive(Copy, Clone, Debug)]
pub struct EigDescriptor {
/// Matrix order `N`: the plan expects `A` to have shape `[N, N]`.
/// `EigPlan::select` rejects `n <= 0`.
pub n: i32,
/// Request left eigenvectors. When `true`, `jobvl` is set to
/// `CUSOLVER_EIG_MODE_VECTOR` and `EigArgs::vl` becomes mandatory.
/// NOTE(review): the cuSOLVER Xgeev documentation may restrict `jobvl` to
/// `CUSOLVER_EIG_MODE_NOVECTOR` — confirm against the targeted toolkit version.
pub compute_left: bool,
/// Request right eigenvectors. When `true`, `jobvr` is set to
/// `CUSOLVER_EIG_MODE_VECTOR` and `EigArgs::vr` becomes mandatory.
pub compute_right: bool,
/// Element kind of the problem; must equal `T::KIND` of the `EigPlan<T>` built from it.
pub element: ElementKind,
}
/// Tensor arguments for a single `EigPlan::run` call.
///
/// Shapes are validated by the plan before the launch; `N` below is `EigDescriptor::n`.
pub struct EigArgs<'a, T: Element> {
/// Input matrix, shape `[N, N]`. Handed to cuSOLVER through a mutable pointer.
/// NOTE(review): presumably overwritten by the solver — confirm against the cuSOLVER docs.
pub a: TensorMut<'a, T, 2>,
/// Eigenvalue buffer: shape `[2*N]` when `T` is f32 / f64, `[N]` when `T` is complex.
/// NOTE(review): the real/imaginary packing inside the `2*N` layout is defined by
/// cuSOLVER, not by this file — confirm before consuming it.
pub w: TensorMut<'a, T, 1>,
/// Left-eigenvector buffer, shape `[N, N]`. Required when `compute_left` is set;
/// when `None`, a null pointer is passed to cuSOLVER.
pub vl: Option<TensorMut<'a, T, 2>>,
/// Right-eigenvector buffer, shape `[N, N]`. Required when `compute_right` is set;
/// when `None`, a null pointer is passed to cuSOLVER.
pub vr: Option<TensorMut<'a, T, 2>>,
/// Single-element (`[1]`) buffer passed to cuSOLVER as the `info` pointer.
/// `EigPlan::run` does not read it back.
pub info: TensorMut<'a, i32, 1>,
}
/// Eigen-decomposition plan for element type `T`, executed through `cusolverDnXgeev`.
///
/// The cuSOLVER handle, the params object and the workspace sizes are created or
/// computed lazily through `&self`, which is why they live in `Cell`s.
pub struct EigPlan<T: Element> {
// Descriptor copied in by `select`.
desc: EigDescriptor,
// SKU metadata assembled by `select`.
sku: KernelSku,
// cuSOLVER handle; null until first needed, destroyed in `Drop`.
handle: Cell<cusolverDnHandle_t>,
// cuSOLVER params object; null until first needed, destroyed in `Drop`.
params: Cell<cusolverDnParams_t>,
// Device workspace size in bytes from the last size query (0 before any query).
workspace_bytes_device: Cell<usize>,
// Host workspace size in bytes from the last size query (0 before any query).
workspace_bytes_host: Cell<usize>,
// Ties the plan to its element type without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: Element> EigPlan<T> {
/// Validates `desc` against the element type `T` and builds a plan.
///
/// No cuSOLVER object is created here: the handle and params object start out
/// null and the cached workspace sizes start at 0. `_stream` and `_pref` are
/// currently unused.
///
/// # Errors
/// * [`Error::Unsupported`] when `desc.element != T::KIND`, or when `T` is not
///   one of f32 / f64 / Complex32 / Complex64.
/// * [`Error::InvalidProblem`] when `desc.n <= 0`.
pub fn select(
    _stream: &Stream,
    desc: &EigDescriptor,
    _pref: PlanPreference,
) -> Result<Self> {
    if desc.element != T::KIND {
        return Err(Error::Unsupported(
            "baracuda-kernels::EigPlan: descriptor.element != T::KIND",
        ));
    }

    let element_supported = matches!(
        T::KIND,
        ElementKind::F32 | ElementKind::F64 | ElementKind::Complex32 | ElementKind::Complex64
    );
    if !element_supported {
        return Err(Error::Unsupported(
            "baracuda-kernels::EigPlan: supports f32 / f64 / Complex32 / Complex64 only",
        ));
    }

    if desc.n <= 0 {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::EigPlan: n must be > 0",
        ));
    }

    // Double-precision kinds (real or complex) report F64 math; everything else F32.
    let is_double = matches!(T::KIND, ElementKind::F64 | ElementKind::Complex64);
    let math_precision = if is_double {
        MathPrecision::F64
    } else {
        MathPrecision::F32
    };

    let plan = Self {
        desc: *desc,
        sku: KernelSku {
            category: OpCategory::Linalg,
            op: LinalgKind::Eig as u16,
            element: T::KIND,
            aux_element: Some(ElementKind::I32),
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Cusolver,
            precision_guarantee: PrecisionGuarantee {
                math_precision,
                accumulator: T::KIND,
                bit_stable_on_same_hardware: false,
                deterministic: true,
            },
        },
        handle: Cell::new(core::ptr::null_mut()),
        params: Cell::new(core::ptr::null_mut()),
        workspace_bytes_device: Cell::new(0),
        workspace_bytes_host: Cell::new(0),
        _marker: PhantomData,
    };
    Ok(plan)
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the precision guarantee stored inside the plan's SKU.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
    let sku = &self.sku;
    sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
self.workspace_bytes_device.get()
}
#[inline]
pub fn host_workspace_size(&self) -> usize {
self.workspace_bytes_host.get()
}
/// Asks cuSOLVER how much scratch memory `cusolverDnXgeev` needs for this plan.
///
/// Creates the cuSOLVER handle and params object on first use. On success both
/// the device and the host size are cached in the plan; the device size is
/// returned, and the host size is available through `host_workspace_size`.
/// `_stream` is currently unused.
///
/// # Errors
/// Returns [`Error::CutlassInternal`] carrying the negated cuSOLVER status when
/// handle/params creation or the size query reports a non-zero status.
pub fn query_workspace_size(&self, _stream: &Stream) -> Result<usize> {
    let handle = self.ensure_handle()?;
    let solver_params = self.ensure_params()?;

    // Translate a "compute eigenvectors?" flag into the cuSOLVER job mode.
    let mode_for = |wanted: bool| {
        if wanted {
            CUSOLVER_EIG_MODE_VECTOR
        } else {
            CUSOLVER_EIG_MODE_NOVECTOR
        }
    };
    let left_mode = mode_for(self.desc.compute_left);
    let right_mode = mode_for(self.desc.compute_right);

    // The same order `n` doubles as every leading dimension.
    let order: i64 = self.desc.n as i64;
    let elem_type = dtype_tag::<T>();
    let (mut device_bytes, mut host_bytes) = (0usize, 0usize);

    // SAFETY: `handle` and `solver_params` are non-null objects produced by
    // `ensure_handle` / `ensure_params`; both size out-pointers reference live
    // locals. All data pointers are null.
    // NOTE(review): presumably the size query never dereferences A/W/VL/VR — confirm.
    let status = unsafe {
        cusolverDnXgeev_bufferSize(
            handle,
            solver_params,
            left_mode,
            right_mode,
            order,
            elem_type,
            core::ptr::null(),
            order,
            elem_type,
            core::ptr::null(),
            elem_type,
            core::ptr::null(),
            order,
            elem_type,
            core::ptr::null(),
            order,
            elem_type,
            &mut device_bytes as *mut _,
            &mut host_bytes as *mut _,
        )
    };
    if status != 0 {
        return Err(Error::CutlassInternal(-status));
    }

    self.workspace_bytes_device.set(device_bytes);
    self.workspace_bytes_host.set(host_bytes);
    Ok(device_bytes)
}
/// Returns the plan's cuSOLVER handle, creating and caching it on first use.
///
/// # Errors
/// Returns [`Error::CutlassInternal`] with the negated status when
/// `cusolverDnCreate` reports a non-zero status; nothing is cached in that case.
fn ensure_handle(&self) -> Result<cusolverDnHandle_t> {
    let cached = self.handle.get();
    if cached.is_null() {
        let mut fresh: cusolverDnHandle_t = core::ptr::null_mut();
        // SAFETY: the out-pointer refers to a live local for the duration of the call.
        let status = unsafe { cusolverDnCreate(&mut fresh as *mut _) };
        if status != 0 {
            return Err(Error::CutlassInternal(-status));
        }
        self.handle.set(fresh);
        Ok(fresh)
    } else {
        Ok(cached)
    }
}
/// Returns the plan's cuSOLVER params object, creating and caching it on first use.
///
/// # Errors
/// Returns [`Error::CutlassInternal`] with the negated status when
/// `cusolverDnCreateParams` reports a non-zero status; nothing is cached in that case.
fn ensure_params(&self) -> Result<cusolverDnParams_t> {
    let cached = self.params.get();
    if cached.is_null() {
        let mut fresh: cusolverDnParams_t = core::ptr::null_mut();
        // SAFETY: the out-pointer refers to a live local for the duration of the call.
        let status = unsafe { cusolverDnCreateParams(&mut fresh as *mut _) };
        if status != 0 {
            return Err(Error::CutlassInternal(-status));
        }
        self.params.set(fresh);
        Ok(fresh)
    } else {
        Ok(cached)
    }
}
fn bind_stream(&self, h: cusolverDnHandle_t, stream: &Stream) -> Result<()> {
let status = unsafe { cusolverDnSetStream(h, stream.as_raw() as *mut c_void) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
/// Checks every tensor shape in `args` against the plan's descriptor.
///
/// Checks run in this order and the first failure is reported:
/// `A` is `[N, N]`; `W` is `[2*N]` (real `T`) or `[N]` (complex `T`); `info` is `[1]`;
/// `VL` is present and `[N, N]` when `compute_left` is set; likewise `VR` for
/// `compute_right`. Eigenvector buffers are not inspected when their flag is clear.
///
/// # Errors
/// Returns [`Error::InvalidProblem`] naming the first violated constraint.
fn check_args(&self, args: &EigArgs<'_, T>) -> Result<()> {
    let n = self.desc.n;
    let square = [n, n];

    if args.a.shape != square {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::EigPlan: A shape != [N, N]",
        ));
    }

    let expected_w = [w_packed_len::<T>(n)];
    if args.w.shape != expected_w {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::EigPlan: W shape != [2*N] (real input) or [N] (complex input)",
        ));
    }

    if args.info.shape != [1] {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::EigPlan: info shape != [1]",
        ));
    }

    match (self.desc.compute_left, args.vl.as_ref()) {
        (false, _) => {}
        (true, None) => {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::EigPlan: VL is None but compute_left == true",
            ));
        }
        (true, Some(vl)) if vl.shape != square => {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::EigPlan: VL shape != [N, N]",
            ));
        }
        (true, Some(_)) => {}
    }

    match (self.desc.compute_right, args.vr.as_ref()) {
        (false, _) => {}
        (true, None) => {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::EigPlan: VR is None but compute_right == true",
            ));
        }
        (true, Some(vr)) if vr.shape != square => {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::EigPlan: VR shape != [N, N]",
            ));
        }
        (true, Some(_)) => {}
    }

    Ok(())
}
/// Executes the eigen-decomposition described by this plan with `cusolverDnXgeev`.
///
/// Steps, in order: validate `args`, lazily create the cuSOLVER handle and params
/// object, bind the handle to `stream`, make sure the workspace sizes are known,
/// unpack the device workspace, allocate the host scratch buffer, launch.
///
/// * `stream` – stream the cuSOLVER handle is bound to for this call.
/// * `workspace` – device scratch, unpacked against the cached device workspace size.
///   NOTE(review): `unpack_workspace` presumably rejects a too-small workspace — confirm.
/// * `args` – tensors for the call; see `check_args` for the required shapes.
///
/// The `info` tensor is only forwarded as cuSOLVER's `info` pointer; this method
/// does not read it back, so a zero return status says nothing about its content.
///
/// # Errors
/// * [`Error::InvalidProblem`] when a shape check fails.
/// * [`Error::CutlassInternal`] (negated cuSOLVER status) when handle/params creation,
///   stream binding, the size query, or the launch reports a non-zero status.
/// * Whatever `unpack_workspace` returns for an unusable workspace.
pub fn run(
    &self,
    stream: &Stream,
    workspace: Workspace<'_>,
    args: EigArgs<'_, T>,
) -> Result<()> {
    // Host scratch is allocated in `u64` words so the buffer is at least 8-byte aligned.
    const HOST_WS_WORD: usize = core::mem::size_of::<u64>();

    self.check_args(&args)?;
    let h = self.ensure_handle()?;
    let params = self.ensure_params()?;
    self.bind_stream(h, stream)?;
    let n: i64 = self.desc.n as i64;

    // Query sizes only when nothing is cached yet. Checking both values avoids a
    // redundant re-query on every call for plans whose device workspace is
    // legitimately 0 bytes while the host workspace is not.
    if self.workspace_bytes_device.get() == 0 && self.workspace_bytes_host.get() == 0 {
        self.query_workspace_size(stream)?;
    }
    let needed_device = self.workspace_bytes_device.get();
    let needed_host = self.workspace_bytes_host.get();
    let (ws_dev_ptr, _ws_bytes) = unpack_workspace(workspace, needed_device)?;

    // A `Vec<u8>` is only guaranteed to be 1-byte aligned, yet the buffer is handed
    // to native code as untyped scratch. Backing it with `u64` words guarantees
    // 8-byte alignment at the cost of at most 7 padding bytes; the size reported to
    // cuSOLVER stays `needed_host`.
    let host_words = needed_host / HOST_WS_WORD
        + if needed_host % HOST_WS_WORD == 0 { 0 } else { 1 };
    let mut host_ws: Vec<u64> = vec![0u64; host_words];
    let host_ws_ptr = if needed_host > 0 {
        host_ws.as_mut_ptr() as *mut c_void
    } else {
        core::ptr::null_mut()
    };

    // Translate a "compute eigenvectors?" flag into the cuSOLVER job mode.
    let mode_for = |wanted: bool| {
        if wanted {
            CUSOLVER_EIG_MODE_VECTOR
        } else {
            CUSOLVER_EIG_MODE_NOVECTOR
        }
    };
    let jobvl = mode_for(self.desc.compute_left);
    let jobvr = mode_for(self.desc.compute_right);
    let dtype = dtype_tag::<T>();

    let a_ptr = args.a.data.as_raw().0 as *mut c_void;
    let w_ptr = args.w.data.as_raw().0 as *mut c_void;
    // Absent eigenvector buffers are passed as null pointers.
    let vl_ptr = args
        .vl
        .as_ref()
        .map(|v| v.data.as_raw().0 as *mut c_void)
        .unwrap_or(core::ptr::null_mut());
    let vr_ptr = args
        .vr
        .as_ref()
        .map(|v| v.data.as_raw().0 as *mut c_void)
        .unwrap_or(core::ptr::null_mut());
    let info_ptr = args.info.data.as_raw().0 as *mut i32;

    // SAFETY: `h` and `params` are non-null objects from `ensure_handle` /
    // `ensure_params`; tensor shapes were validated by `check_args`; `host_ws`
    // stays alive until the end of this function and covers `needed_host` bytes.
    // NOTE(review): this assumes cuSOLVER is done with the host buffer when the
    // call returns — confirm against the cuSOLVER documentation.
    let status = unsafe {
        cusolverDnXgeev(
            h,
            params,
            jobvl,
            jobvr,
            n,
            dtype,
            a_ptr,
            n,
            dtype,
            w_ptr,
            dtype,
            vl_ptr,
            n,
            dtype,
            vr_ptr,
            n,
            dtype,
            ws_dev_ptr,
            needed_device,
            host_ws_ptr,
            needed_host,
            info_ptr,
        )
    };
    if status != 0 {
        return Err(Error::CutlassInternal(-status));
    }
    Ok(())
}
}
/// Releases the lazily created cuSOLVER objects: the params object first, then the handle.
impl<T: Element> Drop for EigPlan<T> {
    fn drop(&mut self) {
        // Swap each cached pointer for null and destroy the old value if one was
        // created. Destroy statuses are deliberately ignored: there is nothing
        // useful to do with a failure during drop.
        let params = self.params.replace(core::ptr::null_mut());
        if !params.is_null() {
            // SAFETY: a non-null value here was produced by `cusolverDnCreateParams`
            // and has just been removed from the plan, so it cannot be destroyed twice.
            let _ = unsafe { cusolverDnDestroyParams(params) };
        }

        let handle = self.handle.replace(core::ptr::null_mut());
        if !handle.is_null() {
            // SAFETY: a non-null value here was produced by `cusolverDnCreate`
            // and has just been removed from the plan, so it cannot be destroyed twice.
            let _ = unsafe { cusolverDnDestroy(handle) };
        }
    }
}
/// Maps the element kind of `T` to the `cudaDataType` tag passed to cuSOLVER.
///
/// # Panics
/// Panics for any kind other than F32 / F64 / Complex32 / Complex64; `select()`
/// rejects those kinds before a plan can exist.
#[inline]
fn dtype_tag<T: Element>() -> cudaDataType {
    let kind = T::KIND;
    match kind {
        ElementKind::Complex64 => CUDA_C_64F,
        ElementKind::Complex32 => CUDA_C_32F,
        ElementKind::F64 => CUDA_R_64F,
        ElementKind::F32 => CUDA_R_32F,
        _ => unreachable!("select() gates on F32 / F64 / Complex32 / Complex64"),
    }
}
/// Returns the required length of the eigenvalue buffer `W` for an order-`n` problem:
/// `2 * n` for real element kinds (f32 / f64), `n` for complex kinds.
///
/// # Panics
/// Panics for any kind other than F32 / F64 / Complex32 / Complex64; `select()`
/// rejects those kinds before a plan can exist.
#[inline]
fn w_packed_len<T: Element>(n: i32) -> i32 {
    // Number of `T` scalars stored per eigenvalue.
    let scalars_per_eigenvalue: i32 = match T::KIND {
        ElementKind::Complex32 | ElementKind::Complex64 => 1,
        ElementKind::F32 | ElementKind::F64 => 2,
        _ => unreachable!("select() gates on F32 / F64 / Complex32 / Complex64"),
    };
    scalars_per_eigenvalue * n
}