use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
use super::runtime::GpuRuntime;
/// Shape descriptor for a linear-algebra kernel that may be offloaded to CUDA.
///
/// Each variant carries only the extents the dispatch gate needs;
/// [`DispatchOp::flops`] turns them into a work estimate which
/// `route_through_gpu` compares against the runtime policy.
#[derive(Clone, Copy, Debug)]
pub enum DispatchOp {
    /// Dense product of an `m × k` matrix with a `k × n` matrix.
    Gemm { m: usize, n: usize, k: usize },
    /// `batch` independent products, each of shape `m × k` times `k × n`.
    BatchedGemm { batch: usize, m: usize, n: usize, k: usize },
    /// Cholesky factorisation of `batch` matrices of order `p`.
    Potrf { p: usize, batch: usize },
    /// Cholesky of `batch` order-`p` matrices, gated by the small-dense batched
    /// POTRF policy window (maximum `p`, minimum `batch`).
    SmallDenseBatchedPotrf { p: usize, batch: usize },
    /// Triangular solve with an `m`-row factor and an `m × n` right-hand side.
    Trsm { m: usize, n: usize },
    /// Matrix–vector product with an `m × k` matrix.
    Gemv { m: usize, k: usize },
    /// Weighted cross-product `Xᵀ·diag(w)·X` of an `n × p` matrix.
    XtDiagX { n: usize, p: usize },
    /// Weighted cross-product `Xᵀ·diag(w)·Y` with `X: n × px` and `Y: n × q`.
    XtDiagY { n: usize, px: usize, q: usize },
    /// Joint two-block Hessian over `n` rows with `pa` and `pb` block widths.
    JointHessian2x2 { n: usize, pa: usize, pb: usize },
}
impl DispatchOp {
    /// Approximate floating-point operation count for this op.
    ///
    /// Arithmetic is done in `u128` so products of several `usize` extents are
    /// not computed at machine-word width. Conventions: GEMM/GEMV count two
    /// flops per multiply-add; POTRF counts `p³/3` per matrix, with the division
    /// applied after scaling by `batch`; TRSM counts `m²·n`.
    #[inline]
    pub const fn flops(self) -> u128 {
        match self {
            Self::Gemm { m, n, k } => {
                let (m, n, k) = (m as u128, n as u128, k as u128);
                2 * m * n * k
            }
            Self::BatchedGemm { batch, m, n, k } => {
                let (batch, m, n, k) = (batch as u128, m as u128, n as u128, k as u128);
                2 * batch * m * n * k
            }
            Self::Gemv { m, k } => {
                let (m, k) = (m as u128, k as u128);
                2 * m * k
            }
            Self::Potrf { p, batch } | Self::SmallDenseBatchedPotrf { p, batch } => {
                let order = p as u128;
                let cube = order * order * order;
                (batch as u128) * cube / 3
            }
            Self::Trsm { m, n } => {
                let rows = m as u128;
                rows * rows * (n as u128)
            }
            Self::XtDiagX { n, p } => {
                let (n, p) = (n as u128, p as u128);
                2 * n * p * p
            }
            Self::XtDiagY { n, px, q } => {
                let (n, px, q) = (n as u128, px as u128, q as u128);
                2 * n * px * q
            }
            Self::JointHessian2x2 { n, pa, pb } => {
                let width = (pa as u128) + (pb as u128);
                2 * (n as u128) * width * width
            }
        }
    }
}
/// Decides whether `op` should run on the GPU and, if so, returns the global runtime.
///
/// Returns `None` when no [`GpuRuntime`] is available or when the runtime
/// policy rejects the op: degenerate extents, too little work relative to
/// `policy.gemm_min_flops`, outside the small-dense POTRF window, or declined
/// by the policy's `XᵀWX` / `XᵀWY` target predicates.
#[inline]
#[must_use]
pub fn route_through_gpu(op: DispatchOp) -> Option<&'static GpuRuntime> {
    let runtime = GpuRuntime::global()?;
    let policy = &runtime.policy;
    // Evaluated lazily so gates that never consult the flop floor never compute it.
    let clears_gemm_floor = || op.flops() >= policy.gemm_min_flops as u128;
    let admitted = match op {
        DispatchOp::Gemm { m, n, k } => clears_gemm_floor() && m.min(n).min(k) > 0,
        DispatchOp::BatchedGemm { batch, m, n, k } => {
            clears_gemm_floor() && batch > 1 && m.min(n).min(k) > 0
        }
        // GEMV and TRSM share the same gate: enough work and two non-empty extents.
        DispatchOp::Gemv { m, k } | DispatchOp::Trsm { m, n: k } => {
            clears_gemm_floor() && m > 0 && k > 0
        }
        DispatchOp::Potrf { p, batch } => {
            if p == 0 || batch == 0 {
                false
            } else {
                // Either the matrix is large on its own, or a multi-matrix batch
                // carries enough aggregate work.
                p >= policy.potrf_min_p || (batch > 1 && clears_gemm_floor())
            }
        }
        DispatchOp::SmallDenseBatchedPotrf { p, batch } => {
            (1..=policy.small_dense_batched_potrf_max_p).contains(&p)
                && batch >= policy.small_dense_batched_potrf_min_batch
        }
        DispatchOp::XtDiagX { n, p } => policy.xtwx_target_is_gpu(n, p, true),
        DispatchOp::XtDiagY { n, px, q } => policy.xtwy_target_is_gpu(n, px, q, true),
        DispatchOp::JointHessian2x2 { n, pa, pb } => {
            n != 0 && pa + pb != 0 && clears_gemm_floor()
        }
    };
    admitted.then_some(runtime)
}
/// Smallest batch size for which scattering work across several GPUs is attempted.
#[cfg(target_os = "linux")]
const MULTI_GPU_BATCH_FLOOR: usize = 64;
/// Returns `true` when the global runtime exposes more than one device and
/// `batch` reaches [`MULTI_GPU_BATCH_FLOOR`].
#[cfg(target_os = "linux")]
#[inline]
fn should_split_batch(batch: usize) -> bool {
    match GpuRuntime::global() {
        Some(rt) if rt.device_count() > 1 => batch >= MULTI_GPU_BATCH_FLOOR,
        _ => false,
    }
}
/// Batched product `A[i] · B` where a single `B` is shared by every batch entry.
///
/// `a` is `(batch, m, k)` and `b` is `(k, n)`; the result is `(batch, m, n)`.
/// Returns `None` (use the CPU path) on an inner-dimension mismatch, any empty
/// extent, non-Linux targets, when routing declines, or when CUDA fails. With
/// several GPUs and a large batch the work is first scattered across devices;
/// if that does not produce a result, a single-device launch is tried.
#[inline]
#[must_use]
pub fn try_fast_ab_broadcast_b_batched(
    a: ArrayView3<'_, f64>,
    b: ArrayView2<'_, f64>,
) -> Option<Array3<f64>> {
    let (batch, m, k) = a.dim();
    let (b_rows, n) = b.dim();
    let usable = k == b_rows && batch != 0 && m != 0 && n != 0;
    if !usable {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
        should_split_batch(batch)
            .then(|| scatter_broadcast_b_batched(runtime, a, b, m, n))
            .flatten()
            .or_else(|| cuda_backend::gemm_broadcast_b_batched(runtime.device.ordinal, a, b))
    }
}
#[cfg(target_os = "linux")]
fn scatter_broadcast_b_batched(
runtime: &GpuRuntime,
a: ArrayView3<'_, f64>,
b: ArrayView2<'_, f64>,
m: usize,
n: usize,
) -> Option<Array3<f64>> {
let batch = a.dim().0;
let mut items: Vec<(Array2<f64>, Option<Array2<f64>>)> = (0..batch)
.map(|i| (a.index_axis(ndarray::Axis(0), i).to_owned(), None))
.collect();
super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
let tile_batch = tile.len();
if tile_batch == 0 {
return Some(());
}
let k = b.dim().0;
let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
for (idx, (a_i, _)) in tile.iter().enumerate() {
a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
}
let out = cuda_backend::gemm_broadcast_b_batched(ordinal, a_tile.view(), b)?;
for (idx, (_, slot)) in tile.iter_mut().enumerate() {
*slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
}
Some(())
})?;
stitch_batched(items, m, n)
}
/// Batched product `A[i] · B[i]ᵀ` for `a: (batch, m, k)` and `b: (batch, n, k)`.
///
/// The result is `(batch, m, n)`. Returns `None` on mismatched batch or inner
/// extents, any empty extent, non-Linux targets, when routing declines, or
/// when CUDA fails. Large batches on multi-GPU hosts are scattered first, with
/// a single-device launch as the fallback.
#[inline]
#[must_use]
pub fn try_fast_abt_strided_batched(
    a: ArrayView3<'_, f64>,
    b: ArrayView3<'_, f64>,
) -> Option<Array3<f64>> {
    let (batch, m, k) = a.dim();
    let (b_batch, n, b_k) = b.dim();
    let conforming = batch == b_batch && k == b_k;
    let non_empty = batch != 0 && m != 0 && n != 0;
    if !(conforming && non_empty) {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let runtime = route_through_gpu(DispatchOp::BatchedGemm { batch, m, n, k })?;
        let scattered = if should_split_batch(batch) {
            scatter_abt_strided_batched(runtime, a, b, m, n)
        } else {
            None
        };
        scattered.or_else(|| cuda_backend::gemm_abt_strided_batched(runtime.device.ordinal, a, b))
    }
}
#[cfg(target_os = "linux")]
fn scatter_abt_strided_batched(
runtime: &GpuRuntime,
a: ArrayView3<'_, f64>,
b: ArrayView3<'_, f64>,
m: usize,
n: usize,
) -> Option<Array3<f64>> {
let batch = a.dim().0;
let mut items: Vec<(Array2<f64>, Array2<f64>, Option<Array2<f64>>)> = (0..batch)
.map(|i| {
(
a.index_axis(ndarray::Axis(0), i).to_owned(),
b.index_axis(ndarray::Axis(0), i).to_owned(),
None,
)
})
.collect();
super::pool::scatter_batched(runtime, &mut items, |ordinal, tile| {
let tile_batch = tile.len();
if tile_batch == 0 {
return Some(());
}
let k = tile[0].0.dim().1;
let mut a_tile = Array3::<f64>::zeros((tile_batch, m, k));
let mut b_tile = Array3::<f64>::zeros((tile_batch, n, k));
for (idx, (a_i, b_i, _)) in tile.iter().enumerate() {
a_tile.index_axis_mut(ndarray::Axis(0), idx).assign(a_i);
b_tile.index_axis_mut(ndarray::Axis(0), idx).assign(b_i);
}
let out = cuda_backend::gemm_abt_strided_batched(ordinal, a_tile.view(), b_tile.view())?;
for (idx, (_, _, slot)) in tile.iter_mut().enumerate() {
*slot = Some(out.index_axis(ndarray::Axis(0), idx).to_owned());
}
Some(())
})?;
let slots: Vec<((), Option<Array2<f64>>)> =
items.into_iter().map(|(_, _, slot)| ((), slot)).collect();
stitch_batched(slots, m, n)
}
/// Stacks per-item `m × n` results into one `(items.len(), m, n)` array.
///
/// The first tuple element is ignored. Returns `None` if any slot is still
/// empty or holds a block whose shape is not `(m, n)`.
#[cfg(target_os = "linux")]
fn stitch_batched<L>(
    items: Vec<(L, Option<Array2<f64>>)>,
    m: usize,
    n: usize,
) -> Option<Array3<f64>> {
    let mut stacked = Array3::<f64>::zeros((items.len(), m, n));
    for (mut dst, (_, slot)) in stacked.outer_iter_mut().zip(items) {
        let block = slot?;
        if block.dim() != (m, n) {
            return None;
        }
        dst.assign(&block);
    }
    Some(stacked)
}
/// Dense product `A · B` on the GPU when the dispatch gate admits it.
///
/// Every call with conforming inner dimensions is recorded in the kernel
/// profile, including calls that are then declined; `gpu_ms` is set to a `0.0`
/// placeholder when routing admitted the op. Returns `None` on an
/// inner-dimension mismatch, on non-Linux targets, when routing declines, or
/// when the CUDA GEMM fails.
#[inline]
#[must_use]
pub fn try_fast_ab(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
    let (m, k) = a.dim();
    let (b_rows, n) = b.dim();
    if b_rows != k {
        return None;
    }
    let op = DispatchOp::Gemm { m, n, k };
    let runtime = route_through_gpu(op);
    // Saturate instead of truncating when the u128 estimate exceeds usize.
    let flops_est = usize::try_from(op.flops()).unwrap_or(usize::MAX);
    super::profile::record(super::profile::KernelStat {
        name: "try_fast_ab",
        n: m,
        p: n,
        k,
        flops_est,
        gpu_ms: runtime.map(|_| 0.0),
        ..Default::default()
    });
    #[cfg(not(target_os = "linux"))]
    {
        None
    }
    #[cfg(target_os = "linux")]
    {
        cuda_backend::gemm(runtime?, a, b, false, false)
    }
}
/// Computes `Aᵀ · B` on the GPU for `a: rows × p` and `b: rows × q`.
///
/// Returns `None` on a row-count mismatch, `p == 0` or `q == 0`, non-Linux
/// targets, when routing declines, or when the CUDA GEMM fails.
#[inline]
#[must_use]
pub fn try_fast_atb(a: ArrayView2<'_, f64>, b: ArrayView2<'_, f64>) -> Option<Array2<f64>> {
    let (rows, p) = a.dim();
    let (b_rows, q) = b.dim();
    let usable = rows == b_rows && p != 0 && q != 0;
    if !usable {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: rows })
            .and_then(|runtime| cuda_backend::gemm(runtime, a, b, true, false))
    }
}
/// Computes `Aᵀ · B` on a caller-chosen device `ordinal`.
///
/// Admission still goes through the global dispatch gate (sized as a
/// `p × q × rows` GEMM), but the launch targets `ordinal` directly. Returns
/// `None` on a row-count mismatch, an empty output dimension, non-Linux
/// targets (with a trace log), a declined gate, or a CUDA failure.
#[inline]
#[must_use]
pub fn try_fast_atb_on_ordinal(
    ordinal: usize,
    a: ArrayView2<'_, f64>,
    b: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
    let (rows, p) = a.dim();
    let (b_rows, q) = b.dim();
    let usable = rows == b_rows && p != 0 && q != 0;
    if !usable {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        log::trace!(
            "try_fast_atb_on_ordinal: CUDA unavailable off Linux; declining ordinal {ordinal}"
        );
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        route_through_gpu(DispatchOp::Gemm { m: p, n: q, k: rows })
            .and_then(|_| cuda_backend::gemm_on_ordinal(ordinal, a, b, true, false))
    }
}
/// Matrix–vector product `A · v` on the GPU for `a: m × k` and `v` of length `k`.
///
/// Returns `None` on a length mismatch, an empty matrix, non-Linux targets,
/// when routing declines, or when the CUDA GEMV fails.
#[inline]
#[must_use]
pub fn try_fast_av(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
    let (m, k) = a.dim();
    let usable = v.len() == k && m != 0 && k != 0;
    if !usable {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        route_through_gpu(DispatchOp::Gemv { m, k })
            .and_then(|runtime| cuda_backend::gemv(runtime, a, v, false))
    }
}
/// Transposed matrix–vector product `Aᵀ · v` on the GPU for `a: n × p` and `v` of length `n`.
///
/// Returns `None` on a length mismatch, an empty matrix, non-Linux targets,
/// when routing declines, or when the CUDA GEMV fails.
#[inline]
#[must_use]
pub fn try_fast_atv(a: ArrayView2<'_, f64>, v: ArrayView1<'_, f64>) -> Option<Array1<f64>> {
    let (n, p) = a.dim();
    let usable = v.len() == n && n != 0 && p != 0;
    if !usable {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        // The output has `p` entries and each reduces over `n` rows.
        route_through_gpu(DispatchOp::Gemv { m: p, k: n })
            .and_then(|runtime| cuda_backend::gemv(runtime, a, v, true))
    }
}
/// Weighted cross-product `Xᵀ · diag(w) · X` on the GPU for `x: n × p`, `w` of length `n`.
///
/// Returns `None` on a length mismatch, an empty matrix, non-Linux targets,
/// when the policy keeps this shape on the CPU, or when the CUDA kernel fails.
#[inline]
#[must_use]
pub fn try_fast_xt_diag_x(x: ArrayView2<'_, f64>, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
    let (n, p) = x.dim();
    let usable = w.len() == n && n != 0 && p != 0;
    if !usable {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        route_through_gpu(DispatchOp::XtDiagX { n, p })
            .and_then(|runtime| cuda_backend::xt_diag_x(runtime, x, w))
    }
}
/// Desired number of row chunks per visible device in the leverage scatter.
#[cfg(target_os = "linux")]
const LEVERAGE_CHUNKS_PER_DEVICE: usize = 4;
/// Row-chunk height for streaming an `n_rows × cols` f64 matrix in ~8 MiB pieces.
///
/// The height is at least 512 rows unless `n_rows` is smaller, never exceeds
/// `max(n_rows, 1)`, and `cols == 0` is treated as a single column.
#[cfg(target_os = "linux")]
#[inline]
fn leverage_chunk_rows(cols: usize, n_rows: usize) -> usize {
    const TARGET_BYTES: usize = 8 * 1024 * 1024;
    const MIN_CHUNK_ROWS: usize = 512;
    let row_bytes = cols.max(1) * std::mem::size_of::<f64>();
    let ceiling = n_rows.max(1);
    let by_budget = TARGET_BYTES / row_bytes;
    // Not `clamp`: the floor may legitimately exceed the ceiling for short matrices.
    by_budget.max(MIN_CHUNK_ROWS).min(ceiling)
}
/// Per-row squared norms of `X · G`, i.e. `h[i] = ‖x_i · G‖²`, computed on the GPU.
///
/// `x` is an `n × p` design and `g` must be `p × rank`. Rows are split into
/// chunks sized by both a byte budget over `p + rank` columns and an even
/// split into [`LEVERAGE_CHUNKS_PER_DEVICE`] pieces per device. Chunks are
/// scattered across devices, and each chunk is multiplied by `g` on its
/// device. Returns `None` on empty or mismatched shapes, non-Linux targets, a
/// declined gate, a failed row-chunk extraction, or a CUDA failure.
#[inline]
#[must_use]
pub fn try_fast_spectral_leverage_diagonal(
    x: &crate::linalg::matrix::DesignMatrix,
    g: ArrayView2<'_, f64>,
) -> Option<Array1<f64>> {
    let (n, p) = (x.nrows(), x.ncols());
    let rank = g.ncols();
    let degenerate = n == 0 || p == 0 || rank == 0;
    if degenerate || g.nrows() != p {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let runtime = route_through_gpu(DispatchOp::XtDiagX { n, p })?;
        let pieces = runtime
            .device_count()
            .max(1)
            .saturating_mul(LEVERAGE_CHUNKS_PER_DEVICE)
            .max(1);
        let chunk_rows = leverage_chunk_rows(p + rank, n)
            .min(n.div_ceil(pieces).max(1))
            .max(1);
        let mut tiles: Vec<(std::ops::Range<usize>, Option<Array1<f64>>)> = (0..n)
            .step_by(chunk_rows)
            .map(|start| (start..(start + chunk_rows).min(n), None))
            .collect();
        super::pool::scatter_batched(runtime, &mut tiles, |ordinal, tile| {
            for (rows_range, slot) in tile.iter_mut() {
                let block = x.try_row_chunk(rows_range.clone()).ok()?;
                let projected =
                    cuda_backend::gemm_on_ordinal(ordinal, block.view(), g, false, false)?;
                let mut norms = Array1::<f64>::zeros(rows_range.len());
                for (dst_idx, row) in projected.outer_iter().enumerate() {
                    norms[dst_idx] = row.iter().map(|&v| v * v).sum();
                }
                *slot = Some(norms);
            }
            Some(())
        })?;
        let mut leverage = Array1::<f64>::zeros(n);
        for (rows_range, slot) in tiles {
            let values = slot?;
            if values.len() != rows_range.len() {
                return None;
            }
            leverage.slice_mut(ndarray::s![rows_range]).assign(&values);
        }
        Some(leverage)
    }
}
/// Weighted cross-product `Xᵀ · diag(w) · Y` on the GPU.
///
/// `x` is `n × px`, `y` is `n × q`, and `w` has length `n`. Returns `None` on
/// a row mismatch, any empty extent, non-Linux targets, when the policy keeps
/// the shape on the CPU, or when the CUDA kernel fails.
#[inline]
#[must_use]
pub fn try_fast_xt_diag_y(
    x: ArrayView2<'_, f64>,
    w: ArrayView1<'_, f64>,
    y: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
    let (n, px) = x.dim();
    let (y_rows, q) = y.dim();
    let rows_agree = y_rows == n && w.len() == n;
    if !rows_agree || n == 0 || px == 0 || q == 0 {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        route_through_gpu(DispatchOp::XtDiagY { n, px, q })
            .and_then(|runtime| cuda_backend::xt_diag_y(runtime, x, w, y))
    }
}
/// Joint two-block Hessian over designs `x_a: n × pa` and `x_b: n × pb` on the GPU.
///
/// The three weight vectors (`w_aa`, `w_ab`, `w_bb`) must each have length
/// `n`. Returns `None` on any row mismatch, when both blocks are empty,
/// non-Linux targets, a declined gate, or a CUDA failure.
#[inline]
#[must_use]
pub fn try_fast_joint_hessian_2x2(
    x_a: ArrayView2<'_, f64>,
    x_b: ArrayView2<'_, f64>,
    w_aa: ArrayView1<'_, f64>,
    w_ab: ArrayView1<'_, f64>,
    w_bb: ArrayView1<'_, f64>,
) -> Option<Array2<f64>> {
    let (n, pa) = x_a.dim();
    let (b_rows, pb) = x_b.dim();
    let rows_agree = [b_rows, w_aa.len(), w_ab.len(), w_bb.len()]
        .iter()
        .all(|&len| len == n);
    if !rows_agree || pa + pb == 0 {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        route_through_gpu(DispatchOp::JointHessian2x2 { n, pa, pb }).and_then(|runtime| {
            cuda_backend::joint_hessian_2x2(runtime, x_a, x_b, w_aa, w_ab, w_bb)
        })
    }
}
/// Replaces the square matrix `a` with its lower Cholesky factor computed on the GPU.
///
/// Returns `None` and leaves `a` untouched if `a` is not square, on non-Linux
/// targets, when the POTRF gate declines, or when the factorisation fails.
#[inline]
#[must_use]
pub fn try_cholesky_lower_inplace(a: &mut Array2<f64>) -> Option<()> {
    let p = a.nrows();
    if a.ncols() != p {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let runtime = route_through_gpu(DispatchOp::Potrf { p, batch: 1 })?;
        // `?` returns before the assignment, so a failure never clobbers `a`.
        *a = cuda_backend::cholesky_lower(runtime, a.view())?;
        Some(())
    }
}
/// Replaces every matrix in `matrices` with its lower Cholesky factor, computed on the GPU.
///
/// All matrices must be non-empty, square, and share one order `p`. Admission
/// tries the small-dense batched POTRF gate first and then the general POTRF
/// gate. With multiple GPUs and a large enough batch the work is first
/// scattered across devices; otherwise, or if scattering fails, a single
/// batched launch on `runtime.device.ordinal` is used.
///
/// Returns `Some(())` when every matrix was factored. If the multi-GPU scatter
/// fails, the original matrices are restored before the single-device retry,
/// so a partial scatter can never leave factored and unfactored entries mixed.
#[inline]
#[must_use]
pub fn try_cholesky_batched_lower_inplace(matrices: &mut [Array2<f64>]) -> Option<()> {
    let first = matrices.first()?;
    let p = first.nrows();
    if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let batch = matrices.len();
        let runtime = route_through_gpu(DispatchOp::SmallDenseBatchedPotrf { p, batch })
            .or_else(|| route_through_gpu(DispatchOp::Potrf { p, batch }))?;
        if should_split_batch(batch) {
            // The backend overwrites each tile it is handed with its factors on
            // success. If one device fails after another already wrote its tile,
            // `matrices` would hold a mix of factors and inputs. Retrying (or a
            // caller's CPU fallback) would then factor existing Cholesky factors
            // instead of the original matrices. Snapshot so the failure path can
            // roll back. The matrices are small (small-dense batches), so the copy
            // is cheap next to the factorisation.
            let originals = matrices.to_vec();
            let split = super::pool::scatter_batched(runtime, matrices, |ordinal, tile| {
                cuda_backend::cholesky_batched_lower(ordinal, tile)
            });
            if split.is_some() {
                return Some(());
            }
            matrices.clone_from_slice(&originals);
        }
        cuda_backend::cholesky_batched_lower(runtime.device.ordinal, matrices)
    }
}
/// Left triangular solve `lower · X = rhs` on the GPU.
///
/// `rhs` is `m × n` and `lower` must be a square `m × m` lower-triangular
/// factor. Returns `None` on empty or non-conforming shapes (including a
/// non-square factor), on non-Linux targets, when the TRSM gate declines, or
/// when the CUDA solve fails.
#[inline]
#[must_use]
pub fn try_solve_lower_triangular_matrix(
    lower: ArrayView2<'_, f64>,
    rhs: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
    let (m, n) = rhs.dim();
    // Require a square factor. Checking only the row count would let an
    // `m × c` factor with `c != m` reach the CUDA TRSM, whose handling of a
    // non-square triangle is not validated here.
    if m == 0 || n == 0 || lower.dim() != (m, m) {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
        cuda_backend::trsm(runtime, lower, rhs, false)
    }
}
/// Left triangular solve `upper · X = rhs` on the GPU.
///
/// `rhs` is `m × n` and `upper` must be a square `m × m` upper-triangular
/// factor. Returns `None` on empty or non-conforming shapes (including a
/// non-square factor), on non-Linux targets, when the TRSM gate declines, or
/// when the CUDA solve fails.
#[inline]
#[must_use]
pub fn try_solve_upper_triangular_matrix(
    upper: ArrayView2<'_, f64>,
    rhs: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
    let (m, n) = rhs.dim();
    // Require a square factor. Checking only the row count would let an
    // `m × c` factor with `c != m` reach the CUDA TRSM, whose handling of a
    // non-square triangle is not validated here.
    if m == 0 || n == 0 || upper.dim() != (m, m) {
        return None;
    }
    #[cfg(not(target_os = "linux"))]
    {
        return None;
    }
    #[cfg(target_os = "linux")]
    {
        let runtime = route_through_gpu(DispatchOp::Trsm { m, n })?;
        cuda_backend::trsm(runtime, upper, rhs, true)
    }
}
#[cfg(test)]
mod tests {
    use super::{DispatchOp, route_through_gpu};
    use crate::gpu::runtime::GpuRuntime;

    /// With a CUDA runtime present, SAE-shaped workloads (2 000 rows, 2 048
    /// features, 12 blocks of width 8) must clear the runtime GEMM flop floor
    /// and be admitted by `route_through_gpu`. The uniform per-row blocks must
    /// also pass the small-dense batched POTRF gate. Without CUDA the test
    /// prints a note and returns.
    #[test]
    fn sae_shape_dispatch_ops_route_when_cuda_runtime_is_present() {
        let runtime = match GpuRuntime::global() {
            Some(runtime) => runtime,
            None => {
                eprintln!("[sae dispatch gate] no CUDA runtime - skipping branch-admission check");
                return;
            }
        };
        let rows = 2_000usize;
        let features = 2_048usize;
        let blocks = 12usize;
        let block_width = 8usize;
        let block_cols = blocks * block_width;
        let floor = runtime.policy.gemm_min_flops as u128;

        let dense_reduction_ops = [
            DispatchOp::XtDiagX { n: rows, p: features },
            DispatchOp::XtDiagY {
                n: rows,
                px: features,
                q: block_cols,
            },
            DispatchOp::JointHessian2x2 {
                n: rows,
                pa: features,
                pb: block_cols,
            },
            DispatchOp::Gemm {
                m: features,
                n: features,
                k: rows * blocks,
            },
        ];
        for op in dense_reduction_ops.iter().copied() {
            let work = op.flops();
            assert!(
                work >= floor,
                "SAE dispatch fixture must clear the runtime GEMM work floor: op={op:?}, flops={}, floor={}",
                work,
                runtime.policy.gemm_min_flops
            );
            assert!(
                route_through_gpu(op).is_some(),
                "SAE dispatch fixture should route to GPU when CUDA is present: {op:?}"
            );
        }

        let batched_potrf = DispatchOp::SmallDenseBatchedPotrf {
            p: blocks,
            batch: rows,
        };
        assert!(
            route_through_gpu(batched_potrf).is_some(),
            "uniform SAE row blocks should reach the small-dense batched POTRF gate"
        );
    }
}
/// CUDA implementations behind the `try_*` entry points in the parent module.
///
/// Most functions are thin forwards to `gpu::blas`; the Cholesky helpers call
/// cuSOLVER directly through `cudarc`. Every function reports failure as `None`
/// so callers can fall back to a CPU path.
#[cfg(target_os = "linux")]
mod cuda_backend {
use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
use super::super::runtime::GpuRuntime;
use crate::gpu::driver::{from_col_major, to_col_major, to_i32};
use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
use cudarc::driver::{DevicePtrMut, sys as driver_sys};
/// Dense GEMM on the runtime's device; forwards to `blas::gemm_cuda`.
/// `trans_a` / `trans_b` select whether each operand is used transposed
/// (e.g. `try_fast_atb` passes `trans_a = true` for `Aᵀ · B`).
#[inline]
pub(super) fn gemm(
runtime: &GpuRuntime,
a: ArrayView2<'_, f64>,
b: ArrayView2<'_, f64>,
trans_a: bool,
trans_b: bool,
) -> Option<Array2<f64>> {
super::super::blas::gemm_cuda(runtime, a, b, trans_a, trans_b)
}
/// Same as [`gemm`], but launched on an explicit device `ordinal` instead of a
/// runtime handle; forwards to `blas::gemm_on_ordinal_cuda`.
#[inline]
pub(super) fn gemm_on_ordinal(
ordinal: usize,
a: ArrayView2<'_, f64>,
b: ArrayView2<'_, f64>,
trans_a: bool,
trans_b: bool,
) -> Option<Array2<f64>> {
super::super::blas::gemm_on_ordinal_cuda(ordinal, a, b, trans_a, trans_b)
}
/// Matrix–vector product on the runtime's device (`trans_a` selects `Aᵀ · v`);
/// forwards to `blas::gemv_cuda`.
#[inline]
pub(super) fn gemv(
runtime: &GpuRuntime,
a: ArrayView2<'_, f64>,
v: ArrayView1<'_, f64>,
trans_a: bool,
) -> Option<Array1<f64>> {
super::super::blas::gemv_cuda(runtime, a, v, trans_a)
}
/// Batched `A[i] · B` with a single shared `B` on device `ordinal`;
/// forwards to `blas::gemm_broadcast_b_batched_cuda`.
#[inline]
pub(super) fn gemm_broadcast_b_batched(
ordinal: usize,
a: ArrayView3<'_, f64>,
b: ArrayView2<'_, f64>,
) -> Option<Array3<f64>> {
super::super::blas::gemm_broadcast_b_batched_cuda(ordinal, a, b)
}
/// Batched `A[i] · B[i]ᵀ` on device `ordinal`;
/// forwards to `blas::gemm_abt_strided_batched_cuda`.
#[inline]
pub(super) fn gemm_abt_strided_batched(
ordinal: usize,
a: ArrayView3<'_, f64>,
b: ArrayView3<'_, f64>,
) -> Option<Array3<f64>> {
super::super::blas::gemm_abt_strided_batched_cuda(ordinal, a, b)
}
/// Weighted cross-product `Xᵀ · diag(w) · X`; forwards to `blas::xt_diag_x_cuda`.
#[inline]
pub(super) fn xt_diag_x(
runtime: &GpuRuntime,
x: ArrayView2<'_, f64>,
w: ArrayView1<'_, f64>,
) -> Option<Array2<f64>> {
super::super::blas::xt_diag_x_cuda(runtime, x, w)
}
/// Weighted cross-product `Xᵀ · diag(w) · Y`; forwards to `blas::xt_diag_y_cuda`.
#[inline]
pub(super) fn xt_diag_y(
runtime: &GpuRuntime,
x: ArrayView2<'_, f64>,
w: ArrayView1<'_, f64>,
y: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
super::super::blas::xt_diag_y_cuda(runtime, x, w, y)
}
/// Joint two-block Hessian; forwards to `blas::joint_hessian_2x2_cuda`.
#[inline]
pub(super) fn joint_hessian_2x2(
runtime: &GpuRuntime,
x_a: ArrayView2<'_, f64>,
x_b: ArrayView2<'_, f64>,
w_aa: ArrayView1<'_, f64>,
w_ab: ArrayView1<'_, f64>,
w_bb: ArrayView1<'_, f64>,
) -> Option<Array2<f64>> {
super::super::blas::joint_hessian_2x2_cuda(runtime, x_a, x_b, w_aa, w_ab, w_bb)
}
/// Triangular solve against `triangular`; `upper` selects the upper- versus
/// lower-triangular variant. Forwards to `blas::trsm_cuda`.
#[inline]
pub(super) fn trsm(
runtime: &GpuRuntime,
triangular: ArrayView2<'_, f64>,
rhs: ArrayView2<'_, f64>,
upper: bool,
) -> Option<Array2<f64>> {
super::super::blas::trsm_cuda(runtime, triangular, rhs, upper)
}
/// Lower Cholesky factor of the square matrix `a` on the runtime's device.
///
/// Creates a fresh stream and cuSOLVER handle, uploads `a` in column-major
/// order, factors it in place, downloads the result, and zeroes the strict
/// upper triangle. Returns `None` for an empty or non-square input or on any
/// CUDA/cuSOLVER failure.
#[inline]
pub(super) fn cholesky_lower(
runtime: &GpuRuntime,
a: ArrayView2<'_, f64>,
) -> Option<Array2<f64>> {
let (p, p2) = a.dim();
if p == 0 || p != p2 {
return None;
}
let stream = super::super::runtime::cuda_context_for(runtime.device.ordinal)?
.new_stream()
.ok()?;
let solver = DnHandle::new(stream.clone()).ok()?;
let a_col = to_col_major(&a);
let mut a_dev = stream.clone_htod(&*a_col).ok()?;
potrf_lower_in_place(&solver, &stream, p, &mut a_dev)?;
let factor_col = stream.clone_dtoh(&a_dev).ok()?;
let mut lower = from_col_major(&factor_col, p, p)?;
// NOTE(review): the device buffer's upper triangle presumably still holds the
// input entries after a lower-fill POTRF; clear it so callers receive a clean L.
for row in 0..p {
for col in (row + 1)..p {
lower[[row, col]] = 0.0;
}
}
Some(lower)
}
/// Batched lower Cholesky of equally sized `p × p` matrices on device `ordinal`.
///
/// Packs every matrix column-major into one contiguous device buffer, builds
/// the array of per-matrix device pointers required by
/// `cusolverDnDpotrfBatched`, and runs a single batched POTRF. If any `info`
/// entry is non-zero the function returns `None` before anything is written
/// back. Otherwise each entry of `matrices` is replaced with its factor (strict
/// upper triangle zeroed).
#[inline]
pub(super) fn cholesky_batched_lower(
ordinal: usize,
matrices: &mut [Array2<f64>],
) -> Option<()> {
let first = matrices.first()?;
let p = first.nrows();
if p == 0 || first.ncols() != p || matrices.iter().any(|matrix| matrix.dim() != (p, p)) {
return None;
}
let stream = super::super::runtime::cuda_context_for(ordinal)?
.new_stream()
.ok()?;
let solver = DnHandle::new(stream.clone()).ok()?;
let matrix_len = p.checked_mul(p)?;
// Host staging buffer: all matrices back to back, each in column-major order.
let mut batch_col = Vec::with_capacity(matrices.len().checked_mul(matrix_len)?);
for matrix in matrices.iter() {
batch_col.extend(to_col_major(&matrix.view()).iter().copied());
}
let mut matrices_dev = stream.clone_htod(&batch_col).ok()?;
// One device address per matrix: base + idx * (p * p * size_of::<f64>()) bytes.
let matrix_ptrs = {
// The second tuple element is bound (not discarded with `_`) so it lives to the
// end of this block. NOTE(review): presumably a stream-ordering record from
// cudarc; confirm against the `DevicePtrMut` docs.
let (base_ptr, _matrix_record) = matrices_dev.device_ptr_mut(&stream);
let bytes_per_matrix = driver_sys::CUdeviceptr::try_from(
matrix_len.checked_mul(std::mem::size_of::<f64>())?,
)
.ok()?;
let mut matrix_ptrs = Vec::with_capacity(matrices.len());
for idx in 0..matrices.len() {
let offset = driver_sys::CUdeviceptr::try_from(idx).ok()? * bytes_per_matrix;
matrix_ptrs.push(base_ptr + offset);
}
matrix_ptrs
};
let mut matrix_ptrs_dev = stream.clone_htod(&matrix_ptrs).ok()?;
let mut info_dev = stream.alloc_zeros::<i32>(matrices.len()).ok()?;
let p_i = to_i32(p)?;
let batch_i = to_i32(matrices.len())?;
{
let (ptrs_ptr, _ptrs_record) = matrix_ptrs_dev.device_ptr_mut(&stream);
let (info_ptr, _info_record) = info_dev.device_ptr_mut(&stream);
// SAFETY: `ptrs_ptr` addresses a device array of `matrices.len()` pointers,
// each pointing at a `p × p` column-major matrix inside `matrices_dev`
// (offsets computed above from `matrix_len`). `info_ptr` addresses
// `matrices.len()` zero-initialised i32s. `matrices_dev`, `matrix_ptrs_dev` and
// `info_dev` are locals that outlive this call, and the pointer guards are held
// until the end of this block. `p_i` (also used as the leading dimension) and
// `batch_i` come from the fallible `to_i32` conversions.
let status = unsafe {
cusolver_sys::cusolverDnDpotrfBatched(
solver.cu(),
cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
p_i,
ptrs_ptr as *mut *mut f64,
p_i,
info_ptr as *mut i32,
batch_i,
)
};
check_cusolver(status)?;
}
// A non-zero info entry marks a failed factorisation; bail out before touching
// `matrices` so the caller's inputs stay intact.
let info_host = stream.clone_dtoh(&info_dev).ok()?;
if info_host.iter().any(|info| *info != 0) {
return None;
}
let factored_col = stream.clone_dtoh(&matrices_dev).ok()?;
for (idx, matrix) in matrices.iter_mut().enumerate() {
let start = idx.checked_mul(matrix_len)?;
let end = start.checked_add(matrix_len)?;
let mut lower = from_col_major(&factored_col[start..end], p, p)?;
for row in 0..p {
for col in (row + 1)..p {
lower[[row, col]] = 0.0;
}
}
*matrix = lower;
}
Some(())
}
/// Runs the crate's generic in-place POTRF on the `p × p` device matrix `a`,
/// discarding the error detail. NOTE(review): the "lower" fill mode is implied by
/// the name and callers; `potrf_in_place_generic` takes no fill argument here,
/// so confirm its convention in `crate::gpu::solver`.
fn potrf_lower_in_place(
solver: &DnHandle,
stream: &std::sync::Arc<cudarc::driver::CudaStream>,
p: usize,
a: &mut cudarc::driver::CudaSlice<f64>,
) -> Option<()> {
crate::gpu::solver::potrf_in_place_generic::<f64>(solver, stream, p, a).ok()
}
/// Maps a cuSOLVER status to `Some(())` on `CUSOLVER_STATUS_SUCCESS`, `None` otherwise.
#[inline]
fn check_cusolver(status: cusolver_sys::cusolverStatus_t) -> Option<()> {
if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
Some(())
} else {
None
}
}
}