use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
baracuda_kernels_batched_ormqr_wy_build_t_complex32_run,
baracuda_kernels_batched_ormqr_wy_build_t_complex64_run,
baracuda_kernels_batched_ormqr_wy_build_t_f32_run,
baracuda_kernels_batched_ormqr_wy_build_t_f64_run,
baracuda_kernels_batched_ormqr_wy_extract_v_complex32_run,
baracuda_kernels_batched_ormqr_wy_extract_v_complex64_run,
baracuda_kernels_batched_ormqr_wy_extract_v_f32_run,
baracuda_kernels_batched_ormqr_wy_extract_v_f64_run, cublasCgemmStridedBatched, cublasCreate_v2,
cublasDestroy_v2, cublasDgemmStridedBatched, cublasHandle_t, cublasSetStream_v2,
cublasSgemmStridedBatched, cublasZgemmStridedBatched, cuComplex, cuDoubleComplex, CUBLAS_OP_C,
CUBLAS_OP_N, CUBLAS_OP_T,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Complex32, Complex64, Element, ElementKind, KernelSku, LinalgKind,
MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, Workspace,
};
use super::cholesky::unpack_workspace;
use super::ormqr_batched::{BatchedOrmqrOp, BatchedOrmqrSide};
/// Block width `nb` used by [`BatchedOrmqrWyPlan`].
///
/// The `K` columns of `A` are processed in `ceil(K / WY_NB)` blocks of at most `WY_NB`
/// columns each, and the workspace reserves one `WY_NB x WY_NB` factor per block and
/// batch entry.
pub const WY_NB: i32 = 32;
/// Problem description consumed by `BatchedOrmqrWyPlan::select`.
///
/// `select` rejects descriptors whose dimensions are not strictly positive, whose `k`
/// exceeds `m`, whose `side` is not `Left`, or whose `op` does not match the element
/// type (`T` is real-only, `C` is complex-only).
#[derive(Copy, Clone, Debug)]
pub struct BatchedOrmqrWyDescriptor {
/// `M`: second extent of `A` (`[batch, M, K]`) and of `C` (`[batch, M, N]`). Must be > 0.
pub m: i32,
/// `N`: last extent of `C` (`[batch, M, N]`). Must be > 0.
pub n: i32,
/// `K`: last extent of `A` and of `tau` (`[batch, K]`). Must satisfy `0 < K <= M`.
pub k: i32,
/// Number of independent problems in the batch. Must be > 0.
pub batch_size: i32,
/// Side on which the factor is applied; only `BatchedOrmqrSide::Left` is accepted.
pub side: BatchedOrmqrSide,
/// Operation applied to the factor: `N`, `T` (real element types only) or `C`
/// (complex element types only).
pub op: BatchedOrmqrOp,
/// Element kind of the operands; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
}
/// Operands for one `BatchedOrmqrWyPlan::run` call.
///
/// Shapes are validated against the plan's descriptor before any work is launched.
pub struct BatchedOrmqrWyArgs<'a, T: Element> {
/// Tensor of shape `[batch, M, K]`; handed to the build-T and extract-V kernels as a
/// const pointer.
pub a: TensorMut<'a, T, 3>,
/// Tensor of shape `[batch, K]`; handed to the build-T kernel as a const pointer.
pub tau: TensorMut<'a, T, 2>,
/// Tensor of shape `[batch, M, N]`; read by the first GEMM of every block and
/// accumulated into by the last one, so it is updated in place.
pub c: TensorMut<'a, T, 3>,
}
/// Plan for a batched, blocked `ormqr` built from custom kernels plus cuBLAS strided
/// batched GEMMs.
///
/// Created by `select`; `run` is generated per element type by
/// `impl_batched_ormqr_wy_run!`.
pub struct BatchedOrmqrWyPlan<T: Element> {
// Validated copy of the descriptor passed to `select`.
desc: BatchedOrmqrWyDescriptor,
// Kernel identity / precision metadata reported by `sku()`.
sku: KernelSku,
// cuBLAS handle; null until `ensure_handle` creates it, destroyed in `Drop`.
handle: Cell<cublasHandle_t>,
// Workspace requirement in bytes, computed once in `select`.
workspace_bytes: Cell<usize>,
// Number of `WY_NB`-wide blocks covering the `K` columns, i.e. ceil(K / WY_NB).
num_blocks: i32,
// Ties the plan to its element type without storing a `T`.
_marker: PhantomData<T>,
}
impl<T: Element> BatchedOrmqrWyPlan<T> {
pub fn select(
_stream: &Stream,
desc: &BatchedOrmqrWyDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::BatchedOrmqrWyPlan: descriptor.element != T::KIND",
));
}
let is_real = matches!(T::KIND, ElementKind::F32 | ElementKind::F64);
let is_complex = matches!(T::KIND, ElementKind::Complex32 | ElementKind::Complex64);
if !(is_real || is_complex) {
return Err(Error::Unsupported(
"baracuda-kernels::BatchedOrmqrWyPlan: dtype must be one of \
{f32, f64, Complex32, Complex64}",
));
}
if !matches!(desc.side, BatchedOrmqrSide::Left) {
return Err(Error::Unsupported(
"baracuda-kernels::BatchedOrmqrWyPlan: side = Right is deferred",
));
}
match (desc.op, is_complex) {
(BatchedOrmqrOp::T, true) => {
return Err(Error::Unsupported(
"baracuda-kernels::BatchedOrmqrWyPlan: op = T (plain transpose) is \
real-only; use op = C (conjugate transpose) for complex dtypes",
));
}
(BatchedOrmqrOp::C, false) => {
return Err(Error::Unsupported(
"baracuda-kernels::BatchedOrmqrWyPlan: op = C (conjugate transpose) is \
complex-only; use op = T for real dtypes",
));
}
_ => {}
}
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::BatchedOrmqrWyPlan: M, N, K must be > 0",
));
}
if desc.batch_size <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::BatchedOrmqrWyPlan: batch_size must be > 0",
));
}
if desc.k > desc.m {
return Err(Error::InvalidProblem(
"baracuda-kernels::BatchedOrmqrWyPlan: K must be <= M (ormqr contract)",
));
}
let nb = WY_NB;
let num_blocks = (desc.k + nb - 1) / nb;
let elem = core::mem::size_of::<T>();
let b = desc.batch_size as usize;
let m = desc.m as usize;
let n = desc.n as usize;
let nbu = nb as usize;
let nbb = num_blocks as usize;
let t_elems = b * nbb * nbu * nbu;
let v_elems = b * m * nbu;
let w_elems = b * nbu * n;
let w2_elems = b * nbu * n;
let ws_bytes = (t_elems + v_elems + w_elems + w2_elems) * elem;
let math_precision = match T::KIND {
ElementKind::F64 | ElementKind::Complex64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Linalg,
op: LinalgKind::BatchedOrmqrWy as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(core::ptr::null_mut()),
workspace_bytes: Cell::new(ws_bytes),
num_blocks,
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
self.workspace_bytes.get()
}
pub fn query_workspace_size(&self, _stream: &Stream) -> Result<usize> {
Ok(self.workspace_bytes.get())
}
fn ensure_handle(&self) -> Result<cublasHandle_t> {
let h = self.handle.get();
if !h.is_null() {
return Ok(h);
}
let mut handle: cublasHandle_t = core::ptr::null_mut();
let status = unsafe { cublasCreate_v2(&mut handle as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, h: cublasHandle_t, stream: &Stream) -> Result<()> {
let status = unsafe { cublasSetStream_v2(h, stream.as_raw() as *mut c_void) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn check_args(&self, args: &BatchedOrmqrWyArgs<'_, T>) -> Result<()> {
let b = self.desc.batch_size;
let m = self.desc.m;
let n = self.desc.n;
let k = self.desc.k;
if args.a.shape != [b, m, k] {
return Err(Error::InvalidProblem(
"baracuda-kernels::BatchedOrmqrWyPlan: A shape != [batch, M, K]",
));
}
if args.tau.shape != [b, k] {
return Err(Error::InvalidProblem(
"baracuda-kernels::BatchedOrmqrWyPlan: tau shape != [batch, K]",
));
}
if args.c.shape != [b, m, n] {
return Err(Error::InvalidProblem(
"baracuda-kernels::BatchedOrmqrWyPlan: C shape != [batch, M, N]",
));
}
Ok(())
}
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}
// GEMM scalars handed by pointer to the cuBLAS strided-batched GEMM calls in
// `impl_batched_ormqr_wy_run!`. Each element type gets the same four values:
// alpha = 1 and beta = 0 for the two GEMMs that overwrite their output, and
// alpha = -1 and beta = 1 for the final GEMM that accumulates into `C`.

// f32
const ALPHA_ONE_F32: f32 = 1.0;
const BETA_ZERO_F32: f32 = 0.0;
const ALPHA_NEG_ONE_F32: f32 = -1.0;
const BETA_ONE_F32: f32 = 1.0;
// f64
const ALPHA_ONE_F64: f64 = 1.0;
const BETA_ZERO_F64: f64 = 0.0;
const ALPHA_NEG_ONE_F64: f64 = -1.0;
const BETA_ONE_F64: f64 = 1.0;
// Complex32 (cuComplex: `x` is the real part, `y` the imaginary part)
const ALPHA_ONE_C32: cuComplex = cuComplex { x: 1.0, y: 0.0 };
const BETA_ZERO_C32: cuComplex = cuComplex { x: 0.0, y: 0.0 };
const ALPHA_NEG_ONE_C32: cuComplex = cuComplex { x: -1.0, y: 0.0 };
const BETA_ONE_C32: cuComplex = cuComplex { x: 1.0, y: 0.0 };
// Complex64 (cuDoubleComplex: `x` is the real part, `y` the imaginary part)
const ALPHA_ONE_C64: cuDoubleComplex = cuDoubleComplex { x: 1.0, y: 0.0 };
const BETA_ZERO_C64: cuDoubleComplex = cuDoubleComplex { x: 0.0, y: 0.0 };
const ALPHA_NEG_ONE_C64: cuDoubleComplex = cuDoubleComplex { x: -1.0, y: 0.0 };
const BETA_ONE_C64: cuDoubleComplex = cuDoubleComplex { x: 1.0, y: 0.0 };
/// Generates the inherent `run` method of `BatchedOrmqrWyPlan<$T>` for one element type.
///
/// Parameters, in order:
/// * `$T` - Rust element type of the plan and its tensors.
/// * `$CublasT` - element type expected by the cuBLAS entry point; device pointers are
///   reinterpreted as pointers to this type.
/// * `$build_t` - launcher that fills the per-block factors in the workspace.
/// * `$extract_v` - launcher that writes one block's V panel into the workspace.
/// * `$gemm_strided` - cuBLAS `gemmStridedBatched` entry point for `$CublasT`.
/// * `$alpha_one`, `$beta_zero`, `$alpha_neg_one`, `$beta_one` - names of the GEMM
///   scalar constants of type `$CublasT`.
/// * `$adjoint_trans` - cuBLAS operation applied to V in the first GEMM, and to the
///   block factor when `op` is `T` or `C`.
macro_rules! impl_batched_ormqr_wy_run {
    (
        $T:ty,
        $CublasT:ty,
        $build_t:ident,
        $extract_v:ident,
        $gemm_strided:ident,
        $alpha_one:ident,
        $beta_zero:ident,
        $alpha_neg_one:ident,
        $beta_one:ident,
        $adjoint_trans:expr
    ) => {
        impl BatchedOrmqrWyPlan<$T> {
            /// Applies the factor described by `args.a` / `args.tau` to `args.c` in
            /// place, enqueuing all work on `stream`.
            ///
            /// The block factors are built once, then for every block of `WY_NB`
            /// columns: the V panel is extracted, `W = op_adj(V) * C`,
            /// `W2 = op(T_blk) * W`, and finally `C = C - V * W2`. Blocks are visited
            /// last-to-first for `op = N` and first-to-last for `op = T` / `op = C`.
            ///
            /// # Errors
            ///
            /// * [`Error::InvalidProblem`] if an operand shape does not match the plan.
            /// * [`Error::WorkspaceTooSmall`] if `workspace` holds fewer bytes than
            ///   `workspace_size()`.
            /// * Whatever `map_status` yields for a failing kernel launcher, or
            ///   [`Error::CutlassInternal`] with the negated status for a failing
            ///   cuBLAS call. On a mid-sequence failure `args.c` may already have been
            ///   partially updated by earlier blocks.
            pub fn run(
                &self,
                stream: &Stream,
                workspace: Workspace<'_>,
                args: BatchedOrmqrWyArgs<'_, $T>,
            ) -> Result<()> {
                self.check_args(&args)?;
                let h = self.ensure_handle()?;
                self.bind_stream(h, stream)?;

                let b = self.desc.batch_size;
                let m = self.desc.m;
                let n = self.desc.n;
                let k = self.desc.k;
                let nb = WY_NB;
                let num_blocks = self.num_blocks;

                let needed = self.workspace_bytes.get();
                let (ws_ptr, ws_bytes) = unpack_workspace(workspace, needed)?;
                if ws_bytes < needed {
                    return Err(Error::WorkspaceTooSmall {
                        needed,
                        got: ws_bytes,
                    });
                }

                // Workspace layout (elements of `$T`): [ T factors | V panel | W | W2 ].
                // This must mirror the size computed in `select`.
                let elem = core::mem::size_of::<$T>();
                let bu = b as usize;
                let mu = m as usize;
                let nu = n as usize;
                let nbu = nb as usize;
                let nbb = num_blocks as usize;
                let t_elems = bu * nbb * nbu * nbu;
                let v_elems = bu * mu * nbu;
                let w_elems = bu * nbu * nu;
                debug_assert_eq!(needed, (t_elems + v_elems + w_elems + w_elems) * elem);

                let t_ptr = ws_ptr as *mut u8;
                // SAFETY: every offset is below `needed`, and the workspace was just
                // verified to provide at least `needed` bytes. The resulting pointers
                // are only forwarded to the kernels / cuBLAS, never dereferenced here.
                let (v_ptr, w_ptr, w2_ptr) = unsafe {
                    let v = t_ptr.add(t_elems * elem);
                    let w = v.add(v_elems * elem);
                    let w2 = w.add(w_elems * elem);
                    (v, w, w2)
                };

                let a_ptr_v = args.a.data.as_raw().0 as *const c_void;
                let tau_ptr_v = args.tau.data.as_raw().0 as *const c_void;
                let c_ptr = args.c.data.as_raw().0 as *mut $CublasT;
                let stream_ptr = stream.as_raw() as *mut c_void;

                // SAFETY: operand shapes were validated by `check_args` and the T region
                // of the workspace is sized for `b * num_blocks` factors of `nb x nb`.
                // NOTE(review): the launcher's full pointer contract lives in
                // `baracuda_kernels_sys` and is not visible here.
                let status = unsafe {
                    $build_t(
                        b,
                        m,
                        k,
                        nb,
                        num_blocks,
                        a_ptr_v,
                        tau_ptr_v,
                        t_ptr as *mut c_void,
                        core::ptr::null_mut(),
                        0,
                        stream_ptr,
                    )
                };
                map_status(status)?;

                // Everything below is identical for every block, so compute it once.
                let adjoint_trans: i32 = $adjoint_trans;
                let (trans_t, last_block_first) = match self.desc.op {
                    BatchedOrmqrOp::N => (CUBLAS_OP_N, true),
                    BatchedOrmqrOp::T | BatchedOrmqrOp::C => (adjoint_trans, false),
                };
                let t_base = t_ptr as *mut $CublasT;
                let v_typed = v_ptr as *const $CublasT;
                let w_typed = w_ptr as *mut $CublasT;
                let w2_typed = w2_ptr as *mut $CublasT;
                let t_block_elems: i64 = (nb as i64) * (nb as i64);
                let t_slot_stride: i64 = (num_blocks as i64) * t_block_elems;
                let v_slot_stride: i64 = (m as i64) * (nb as i64);
                let w_slot_stride: i64 = (nb as i64) * (n as i64);
                let w2_slot_stride: i64 = (nb as i64) * (n as i64);
                let c_slot_stride: i64 = (m as i64) * (n as i64);

                for step in 0..num_blocks {
                    // Map the loop counter to the block index without materialising
                    // the visiting order on the heap.
                    let blk = if last_block_first {
                        num_blocks - 1 - step
                    } else {
                        step
                    };
                    let block_start = blk * nb;
                    // Columns in this block: `nb`, except possibly fewer in the last
                    // one. Written as a `min` so that `block_start + nb` is never
                    // formed and cannot overflow `i32`.
                    let block_k = (k - block_start).min(nb);
                    if block_k <= 0 {
                        continue;
                    }

                    // SAFETY: the V region of the workspace holds `b * m * nb` elements.
                    // NOTE(review): launcher contract not visible here, as above.
                    let status = unsafe {
                        $extract_v(
                            b,
                            m,
                            k,
                            nb,
                            block_start,
                            block_k,
                            a_ptr_v,
                            v_ptr as *mut c_void,
                            core::ptr::null_mut(),
                            0,
                            stream_ptr,
                        )
                    };
                    map_status(status)?;

                    // SAFETY: `blk < num_blocks`, so the offset stays inside the first
                    // batch slot of the T region.
                    let t_block_ptr =
                        unsafe { t_base.offset(((blk as i64) * t_block_elems) as isize) };

                    // W (nb x n) = op_adj(V) (nb x m) * C (m x n)
                    // SAFETY: leading dimensions and batch strides describe buffers of
                    // exactly the sizes reserved above / validated by `check_args`; the
                    // scalar pointers reference constants.
                    let status = unsafe {
                        $gemm_strided(
                            h,
                            adjoint_trans,
                            CUBLAS_OP_N,
                            nb,
                            n,
                            m,
                            &$alpha_one as *const $CublasT,
                            v_typed,
                            m,
                            v_slot_stride,
                            c_ptr as *const $CublasT,
                            m,
                            c_slot_stride,
                            &$beta_zero as *const $CublasT,
                            w_typed,
                            nb,
                            w_slot_stride,
                            b,
                        )
                    };
                    if status != 0 {
                        return Err(Error::CutlassInternal(-status));
                    }

                    // W2 (nb x n) = op(T_blk) (nb x nb) * W (nb x n)
                    // SAFETY: as for the first GEMM; T, W and W2 are disjoint regions.
                    let status = unsafe {
                        $gemm_strided(
                            h,
                            trans_t,
                            CUBLAS_OP_N,
                            nb,
                            n,
                            nb,
                            &$alpha_one as *const $CublasT,
                            t_block_ptr as *const $CublasT,
                            nb,
                            t_slot_stride,
                            w_typed as *const $CublasT,
                            nb,
                            w_slot_stride,
                            &$beta_zero as *const $CublasT,
                            w2_typed,
                            nb,
                            w2_slot_stride,
                            b,
                        )
                    };
                    if status != 0 {
                        return Err(Error::CutlassInternal(-status));
                    }

                    // C (m x n) = C - V (m x nb) * W2 (nb x n)
                    // SAFETY: as for the first GEMM; C is the caller's output tensor.
                    let status = unsafe {
                        $gemm_strided(
                            h,
                            CUBLAS_OP_N,
                            CUBLAS_OP_N,
                            m,
                            n,
                            nb,
                            &$alpha_neg_one as *const $CublasT,
                            v_typed,
                            m,
                            v_slot_stride,
                            w2_typed as *const $CublasT,
                            nb,
                            w2_slot_stride,
                            &$beta_one as *const $CublasT,
                            c_ptr,
                            m,
                            c_slot_stride,
                            b,
                        )
                    };
                    if status != 0 {
                        return Err(Error::CutlassInternal(-status));
                    }
                }
                Ok(())
            }
        }
    };
}
// `run` for f32: cuBLAS Sgemm; plain transpose is used where an adjoint is needed.
impl_batched_ormqr_wy_run!(
f32,
f32,
baracuda_kernels_batched_ormqr_wy_build_t_f32_run,
baracuda_kernels_batched_ormqr_wy_extract_v_f32_run,
cublasSgemmStridedBatched,
ALPHA_ONE_F32,
BETA_ZERO_F32,
ALPHA_NEG_ONE_F32,
BETA_ONE_F32,
CUBLAS_OP_T
);
// `run` for f64: cuBLAS Dgemm; plain transpose is used where an adjoint is needed.
impl_batched_ormqr_wy_run!(
f64,
f64,
baracuda_kernels_batched_ormqr_wy_build_t_f64_run,
baracuda_kernels_batched_ormqr_wy_extract_v_f64_run,
cublasDgemmStridedBatched,
ALPHA_ONE_F64,
BETA_ZERO_F64,
ALPHA_NEG_ONE_F64,
BETA_ONE_F64,
CUBLAS_OP_T
);
// `run` for Complex32: cuBLAS Cgemm on `cuComplex`; conjugate transpose is used where
// an adjoint is needed.
impl_batched_ormqr_wy_run!(
Complex32,
cuComplex,
baracuda_kernels_batched_ormqr_wy_build_t_complex32_run,
baracuda_kernels_batched_ormqr_wy_extract_v_complex32_run,
cublasCgemmStridedBatched,
ALPHA_ONE_C32,
BETA_ZERO_C32,
ALPHA_NEG_ONE_C32,
BETA_ONE_C32,
CUBLAS_OP_C
);
// `run` for Complex64: cuBLAS Zgemm on `cuDoubleComplex`; conjugate transpose is used
// where an adjoint is needed.
impl_batched_ormqr_wy_run!(
Complex64,
cuDoubleComplex,
baracuda_kernels_batched_ormqr_wy_build_t_complex64_run,
baracuda_kernels_batched_ormqr_wy_extract_v_complex64_run,
cublasZgemmStridedBatched,
ALPHA_ONE_C64,
BETA_ZERO_C64,
ALPHA_NEG_ONE_C64,
BETA_ONE_C64,
CUBLAS_OP_C
);
/// Releases the plan's cuBLAS handle, if one was ever created.
impl<T: Element> Drop for BatchedOrmqrWyPlan<T> {
    fn drop(&mut self) {
        // Take the handle out of the cell, leaving null behind so it cannot be
        // destroyed twice.
        let handle = self.handle.replace(core::ptr::null_mut());
        if handle.is_null() {
            // No `run` ever created a handle; nothing to release.
            return;
        }
        // SAFETY: a non-null value in `self.handle` is only ever stored by
        // `ensure_handle` after a successful `cublasCreate_v2`. The returned status is
        // deliberately ignored because `drop` has no way to report it.
        unsafe {
            let _ = cublasDestroy_v2(handle);
        }
    }
}