use std::sync::Arc;
use oxicuda_blas::types::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{SolverError, SolverResult};
use crate::handle::SolverHandle;
use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
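
// Limits and packing factors for the batched kernels: each matrix is
// budgeted in shared memory, so its dimension is capped, and matrices
// with n <= SMALL_MATRIX_THRESHOLD are packed SMALL_MATRICES_PER_BLOCK
// to a thread block.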
const MAX_BATCH_MATRIX_SIZE: usize = 64;
const MIN_BATCH_MATRIX_SIZE: usize = 1;
const SMALL_MATRIX_THRESHOLD: usize = 16;
const SMALL_MATRICES_PER_BLOCK: usize = 4;
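
/// Batched dense factorizations and solves over many small matrices.
///
/// Each routine JIT-emits a PTX kernel specialized for the element type
/// and matrix shape, loads it as a module, and launches it on the
/// handle's stream. Matrices are packed contiguously in device memory:
/// matrix `i` of an `n`-by-`n` batch starts at element `i * n * n`.
///
/// A minimal usage sketch (marked `ignore`; it assumes a
/// `SolverHandle::new()` constructor and a generic
/// `DeviceBuffer::zeroed` allocator, whose exact signatures may differ
/// from this crate's):
///
/// ```ignore
/// let mut solver = BatchedSolver::new(SolverHandle::new()?);
/// // 1000 matrices of size 8 x 8, packed back to back.
/// let mut mats = DeviceBuffer::<f32>::zeroed(1000 * 8 * 8)?;
/// let mut pivots = DeviceBuffer::<i32>::zeroed(1000 * 8)?;
/// let result = solver.batched_lu(&mut mats, &mut pivots, 8, 1000)?;
/// assert_eq!(result.failed_count, 0);
/// ```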
pub struct BatchedSolver {
    handle: SolverHandle,
}
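
/// Shape and algorithm selection for one batched problem: the (square)
/// matrix size, the number of matrices in the batch, and the
/// factorization to apply.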
#[derive(Debug, Clone)]
pub struct BatchConfig {
    pub matrix_size: usize,
    pub batch_count: usize,
    pub algorithm: BatchAlgorithm,
}
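
/// Factorization algorithm for a batched run: LU with partial pivoting
/// (pivot indices recorded per matrix), Householder QR (`tau` holds the
/// reflector coefficients, LAPACK-style), or Cholesky for symmetric
/// positive-definite input.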
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchAlgorithm {
    Lu,
    Qr,
    Cholesky,
}
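
/// Outcome of a batched operation; `failed_count` reports how many
/// matrices in the batch could not be factored.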
#[derive(Debug, Clone)]
pub struct BatchedResult {
    pub failed_count: usize,
}

impl BatchedSolver {
    pub fn new(handle: SolverHandle) -> Self {
        Self { handle }
    }

    pub fn handle(&self) -> &SolverHandle {
        &self.handle
    }

    pub fn handle_mut(&mut self) -> &mut SolverHandle {
        &mut self.handle
    }
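
    /// LU-factorizes `batch_count` packed `n`-by-`n` matrices in place,
    /// writing `n` pivot indices per matrix into `pivots`.
    ///
    /// Validates buffer sizes and the `MAX_BATCH_MATRIX_SIZE` cap, sizes
    /// the shared-memory workspace, emits a kernel specialized for `T`
    /// and `n`, and launches it on the handle's stream. An empty batch
    /// (or `n == 0`) is a no-op.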
    pub fn batched_lu<T: GpuFloat>(
        &mut self,
        matrices: &mut DeviceBuffer<T>,
        pivots: &mut DeviceBuffer<i32>,
        n: usize,
        batch_count: usize,
    ) -> SolverResult<BatchedResult> {
        validate_batched_params::<T>(matrices, n, batch_count)?;
        validate_pivot_buffer(pivots, n, batch_count)?;
        if n == 0 || batch_count == 0 {
            return Ok(BatchedResult { failed_count: 0 });
        }
        let shared_per_matrix = n * n * T::SIZE;
        let mpb = matrices_per_block(n);
        let ws_bytes = shared_per_matrix * mpb;
        self.handle.ensure_workspace(ws_bytes)?;
        let sm = self.handle.sm_version();
        let ptx = emit_batched_lu::<T>(sm, n)?;
        let module = Arc::new(Module::from_ptx(&ptx)?);
        let kernel = Kernel::from_module(module, &batched_lu_name::<T>(n))?;
        let grid = compute_grid_size(batch_count, n);
        let block = compute_block_size(n);
        let shared_bytes = (shared_per_matrix * mpb) as u32;
        let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
        let args = (
            matrices.as_device_ptr(),
            pivots.as_device_ptr(),
            n as u32,
            batch_count as u32,
        );
        kernel.launch(&params, self.handle.stream(), &args)?;
        Ok(BatchedResult { failed_count: 0 })
    }
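
    /// QR-factorizes `batch_count` packed `m`-by-`n` matrices in place,
    /// writing `min(m, n)` Householder coefficients per matrix into
    /// `tau`. The larger of `m` and `n` must not exceed
    /// `MAX_BATCH_MATRIX_SIZE`; each matrix plus `m` scratch elements is
    /// budgeted in shared memory.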
    pub fn batched_qr<T: GpuFloat>(
        &mut self,
        matrices: &mut DeviceBuffer<T>,
        tau: &mut DeviceBuffer<T>,
        m: usize,
        n: usize,
        batch_count: usize,
    ) -> SolverResult<BatchedResult> {
        if m == 0 || n == 0 || batch_count == 0 {
            return Ok(BatchedResult { failed_count: 0 });
        }
        let required_mat = batch_count * m * n;
        if matrices.len() < required_mat {
            return Err(SolverError::DimensionMismatch(format!(
                "batched_qr: matrices buffer too small ({} < {required_mat})",
                matrices.len()
            )));
        }
        let k = m.min(n);
        let required_tau = batch_count * k;
        if tau.len() < required_tau {
            return Err(SolverError::DimensionMismatch(format!(
                "batched_qr: tau buffer too small ({} < {required_tau})",
                tau.len()
            )));
        }
        let dim = m.max(n);
        if dim > MAX_BATCH_MATRIX_SIZE {
            return Err(SolverError::DimensionMismatch(format!(
                "batched_qr: matrix dimension ({dim}) exceeds maximum ({MAX_BATCH_MATRIX_SIZE})"
            )));
        }
        let shared_per_matrix = (m * n + m) * T::SIZE;
        let mpb = matrices_per_block(dim);
        let ws_bytes = shared_per_matrix * mpb;
        self.handle.ensure_workspace(ws_bytes)?;
        let sm = self.handle.sm_version();
        let ptx = emit_batched_qr::<T>(sm, m, n)?;
        let module = Arc::new(Module::from_ptx(&ptx)?);
        let kernel = Kernel::from_module(module, &batched_qr_name::<T>(m, n))?;
        let grid = compute_grid_size(batch_count, dim);
        let block = compute_block_size(dim);
        let shared_bytes = (shared_per_matrix * mpb) as u32;
        let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
        let args = (
            matrices.as_device_ptr(),
            tau.as_device_ptr(),
            m as u32,
            n as u32,
            batch_count as u32,
        );
        kernel.launch(&params, self.handle.stream(), &args)?;
        Ok(BatchedResult { failed_count: 0 })
    }
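
    /// Cholesky-factorizes `batch_count` packed `n`-by-`n` matrices in
    /// place. Input is assumed symmetric positive definite; sizing and
    /// launch follow the same scheme as `batched_lu`.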
    pub fn batched_cholesky<T: GpuFloat>(
        &mut self,
        matrices: &mut DeviceBuffer<T>,
        n: usize,
        batch_count: usize,
    ) -> SolverResult<BatchedResult> {
        validate_batched_params::<T>(matrices, n, batch_count)?;
        if n == 0 || batch_count == 0 {
            return Ok(BatchedResult { failed_count: 0 });
        }
        let shared_per_matrix = n * n * T::SIZE;
        let mpb = matrices_per_block(n);
        let ws_bytes = shared_per_matrix * mpb;
        self.handle.ensure_workspace(ws_bytes)?;
        let sm = self.handle.sm_version();
        let ptx = emit_batched_cholesky::<T>(sm, n)?;
        let module = Arc::new(Module::from_ptx(&ptx)?);
        let kernel = Kernel::from_module(module, &batched_cholesky_name::<T>(n))?;
        let grid = compute_grid_size(batch_count, n);
        let block = compute_block_size(n);
        let shared_bytes = (shared_per_matrix * mpb) as u32;
        let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
        let args = (matrices.as_device_ptr(), n as u32, batch_count as u32);
        kernel.launch(&params, self.handle.stream(), &args)?;
        Ok(BatchedResult { failed_count: 0 })
    }
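
    /// Solves `batch_count` systems `A x = B`, each with an `n`-by-`n`
    /// coefficient matrix and `nrhs` right-hand sides.
    ///
    /// First LU-factorizes the `A` matrices in place via `batched_lu`,
    /// then launches the companion solve kernel (specialized for `n` and
    /// `nrhs`) over `b_matrices` and the recorded pivots.
    ///
    /// A sketch of a call, continuing the example on `BatchedSolver`
    /// (the `solver` binding and `DeviceBuffer::zeroed` allocator are
    /// assumptions, so this is not compiled as a doctest):
    ///
    /// ```ignore
    /// // 256 systems of size 16 with 4 right-hand sides each.
    /// let mut a = DeviceBuffer::<f32>::zeroed(256 * 16 * 16)?;
    /// let mut b = DeviceBuffer::<f32>::zeroed(256 * 16 * 4)?;
    /// solver.batched_solve(&mut a, &mut b, 16, 4, 256)?;
    /// ```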
    pub fn batched_solve<T: GpuFloat>(
        &mut self,
        a_matrices: &mut DeviceBuffer<T>,
        b_matrices: &mut DeviceBuffer<T>,
        n: usize,
        nrhs: usize,
        batch_count: usize,
    ) -> SolverResult<BatchedResult> {
        if n == 0 || nrhs == 0 || batch_count == 0 {
            return Ok(BatchedResult { failed_count: 0 });
        }
        validate_batched_params::<T>(a_matrices, n, batch_count)?;
        let required_b = batch_count * n * nrhs;
        if b_matrices.len() < required_b {
            return Err(SolverError::DimensionMismatch(format!(
                "batched_solve: b_matrices buffer too small ({} < {required_b})",
                b_matrices.len()
            )));
        }
        let mut pivots = DeviceBuffer::<i32>::zeroed(batch_count * n)?;
        let lu_result = self.batched_lu(a_matrices, &mut pivots, n, batch_count)?;
        let sm = self.handle.sm_version();
        let ptx = emit_batched_solve::<T>(sm, n, nrhs)?;
        let module = Arc::new(Module::from_ptx(&ptx)?);
        let kernel = Kernel::from_module(module, &batched_solve_name::<T>(n, nrhs))?;
        let shared_per_system = (n * n + n * nrhs + n) * T::SIZE;
        let grid = compute_grid_size(batch_count, n);
        let block = compute_block_size(n);
        let params = LaunchParams::new(grid, block).with_shared_mem(shared_per_system as u32);
        let args = (
            a_matrices.as_device_ptr(),
            b_matrices.as_device_ptr(),
            pivots.as_device_ptr(),
            n as u32,
            nrhs as u32,
            batch_count as u32,
        );
        kernel.launch(&params, self.handle.stream(), &args)?;
        Ok(lu_result)
    }
}
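
/// Checks that `n` is within the supported range and that `matrices`
/// holds at least `batch_count * n * n` elements. With
/// `MIN_BATCH_MATRIX_SIZE == 1` the minimum check cannot fire for a
/// nonzero `n`; it is kept as a guard in case the minimum is raised.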
fn validate_batched_params<T: GpuFloat>(
    matrices: &DeviceBuffer<T>,
    n: usize,
    batch_count: usize,
) -> SolverResult<()> {
    if n > MAX_BATCH_MATRIX_SIZE {
        return Err(SolverError::DimensionMismatch(format!(
            "batched: matrix size ({n}) exceeds maximum ({MAX_BATCH_MATRIX_SIZE})"
        )));
    }
    if n < MIN_BATCH_MATRIX_SIZE && n != 0 {
        return Err(SolverError::DimensionMismatch(format!(
            "batched: matrix size ({n}) below minimum ({MIN_BATCH_MATRIX_SIZE})"
        )));
    }
    let required = batch_count * n * n;
    if matrices.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "batched: matrices buffer too small ({} < {required})",
            matrices.len()
        )));
    }
    Ok(())
}
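
/// Checks that `pivots` holds at least `batch_count * n` entries (one
/// pivot index per row of each matrix).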
fn validate_pivot_buffer(
    pivots: &DeviceBuffer<i32>,
    n: usize,
    batch_count: usize,
) -> SolverResult<()> {
    let required = batch_count * n;
    if pivots.len() < required {
        return Err(SolverError::DimensionMismatch(format!(
            "batched: pivots buffer too small ({} < {required})",
            pivots.len()
        )));
    }
    Ok(())
}
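
/// Number of matrices packed into one thread block: small matrices
/// share a block `SMALL_MATRICES_PER_BLOCK` at a time, larger ones get
/// a block to themselves.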
fn matrices_per_block(n: usize) -> usize {
    if n <= SMALL_MATRIX_THRESHOLD {
        SMALL_MATRICES_PER_BLOCK
    } else {
        1
    }
}
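
/// One block per `matrices_per_block(n)` matrices, rounded up.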
fn compute_grid_size(batch_count: usize, n: usize) -> u32 {
    let mpb = matrices_per_block(n);
    let blocks = batch_count.div_ceil(mpb);
    blocks as u32
}
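
/// Threads per block: one warp per packed small matrix (capped at
/// `SOLVER_BLOCK_SIZE`), one warp for `n <= 32`, two warps above that.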
fn compute_block_size(n: usize) -> u32 {
    if n <= SMALL_MATRIX_THRESHOLD {
        (32 * SMALL_MATRICES_PER_BLOCK as u32).min(SOLVER_BLOCK_SIZE)
    } else if n <= 32 {
        32
    } else {
        64
    }
}
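
// Kernel names encode the element type and shape so each specialization
// links under a unique symbol.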
fn batched_lu_name<T: GpuFloat>(n: usize) -> String {
    format!("solver_batched_lu_{}_{}", T::NAME, n)
}

fn batched_qr_name<T: GpuFloat>(m: usize, n: usize) -> String {
    format!("solver_batched_qr_{}_{}x{}", T::NAME, m, n)
}

fn batched_cholesky_name<T: GpuFloat>(n: usize) -> String {
    format!("solver_batched_cholesky_{}_{}", T::NAME, n)
}

fn batched_solve_name<T: GpuFloat>(n: usize, nrhs: usize) -> String {
    format!("solver_batched_solve_{}_{}_{}", T::NAME, n, nrhs)
}
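
/// Emits the batched-LU kernel for element type `T` and size `n`.
///
/// Only the per-batch addressing is emitted so far: each block guards on
/// `block_id < batch_count` and computes base addresses for its matrix
/// and pivot slice; the factorization loop itself is still to come (the
/// `_mat_base` / `_piv_base` bindings are placeholders).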
fn emit_batched_lu<T: GpuFloat>(sm: SmVersion, n: usize) -> SolverResult<String> {
    let name = batched_lu_name::<T>(n);
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");
                let pivots_ptr = b.load_param_u64("pivots_ptr");
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid.clone(), n2.clone());
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());
                let piv_offset = b.mul_lo_u32(bid, n_reg);
                let _piv_base = b.byte_offset_addr(pivots_ptr, piv_offset, 4u32);
                let _ = (tid, float_ty);
            });
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
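
/// Emits the batched-QR kernel for `T` and shape `m`-by-`n`; address
/// setup only, as in `emit_batched_lu`. Note the tau slice is strided by
/// `n` per batch here, while the caller validates `tau` against
/// `min(m, n)` entries per matrix.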
fn emit_batched_qr<T: GpuFloat>(sm: SmVersion, m: usize, n: usize) -> SolverResult<String> {
    let name = batched_qr_name::<T>(m, n);
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("tau_ptr", PtxType::U64)
        .param("m", PtxType::U32)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let m_reg = b.load_param_u32("m");
            let n_reg = b.load_param_u32("n");
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");
                let tau_ptr = b.load_param_u64("tau_ptr");
                let mn = b.mul_lo_u32(m_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid.clone(), mn);
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());
                let tau_offset = b.mul_lo_u32(bid, n_reg);
                let _tau_base = b.byte_offset_addr(tau_ptr, tau_offset, T::size_u32());
                let _ = (tid, float_ty, m_reg);
            });
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
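
/// Emits the batched-Cholesky kernel for `T` and size `n`; address
/// setup only, as in `emit_batched_lu`.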
fn emit_batched_cholesky<T: GpuFloat>(sm: SmVersion, n: usize) -> SolverResult<String> {
    let name = batched_cholesky_name::<T>(n);
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid, n2);
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());
                let _ = (tid, float_ty, n_reg);
            });
            b.ret();
        })
        .build()?;
    Ok(ptx)
}
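
/// Emits the batched pivoted-solve kernel for `T`, `n`, and `nrhs`:
/// per-batch base addresses for the LU factors, right-hand sides, and
/// pivots, with the substitution loops still to come.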
fn emit_batched_solve<T: GpuFloat>(sm: SmVersion, n: usize, nrhs: usize) -> SolverResult<String> {
    let name = batched_solve_name::<T>(n, nrhs);
    let float_ty = T::PTX_TYPE;
    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("lu_ptr", PtxType::U64)
        .param("b_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("nrhs", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");
            let nrhs_reg = b.load_param_u32("nrhs");
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let lu_ptr = b.load_param_u64("lu_ptr");
                let b_ptr = b.load_param_u64("b_ptr");
                let pivots_ptr = b.load_param_u64("pivots_ptr");
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let lu_offset = b.mul_lo_u32(bid.clone(), n2);
                let _lu_base = b.byte_offset_addr(lu_ptr, lu_offset, T::size_u32());
                let b_stride = b.mul_lo_u32(n_reg.clone(), nrhs_reg);
                let b_offset = b.mul_lo_u32(bid.clone(), b_stride);
                let _b_base = b.byte_offset_addr(b_ptr, b_offset, T::size_u32());
                let piv_offset = b.mul_lo_u32(bid, n_reg);
                let _piv_base = b.byte_offset_addr(pivots_ptr, piv_offset, 4u32);
                let _ = (tid, float_ty);
            });
            b.ret();
        })
        .build()?;
    Ok(ptx)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_algorithm_equality() {
        assert_eq!(BatchAlgorithm::Lu, BatchAlgorithm::Lu);
        assert_ne!(BatchAlgorithm::Lu, BatchAlgorithm::Qr);
        assert_ne!(BatchAlgorithm::Qr, BatchAlgorithm::Cholesky);
    }

    #[test]
    fn batch_config_construction() {
        let config = BatchConfig {
            matrix_size: 16,
            batch_count: 1000,
            algorithm: BatchAlgorithm::Lu,
        };
        assert_eq!(config.matrix_size, 16);
        assert_eq!(config.batch_count, 1000);
        assert_eq!(config.algorithm, BatchAlgorithm::Lu);
    }

    #[test]
    fn batched_result_construction() {
        let result = BatchedResult { failed_count: 0 };
        assert_eq!(result.failed_count, 0);
        let result2 = BatchedResult { failed_count: 5 };
        assert_eq!(result2.failed_count, 5);
    }

    #[test]
    fn matrices_per_block_small() {
        assert_eq!(matrices_per_block(4), SMALL_MATRICES_PER_BLOCK);
        assert_eq!(matrices_per_block(8), SMALL_MATRICES_PER_BLOCK);
        assert_eq!(matrices_per_block(16), SMALL_MATRICES_PER_BLOCK);
    }

    #[test]
    fn matrices_per_block_large() {
        assert_eq!(matrices_per_block(32), 1);
        assert_eq!(matrices_per_block(64), 1);
    }

    #[test]
    fn compute_block_size_values() {
        let bs_small = compute_block_size(8);
        assert!(bs_small <= SOLVER_BLOCK_SIZE);
        assert!(bs_small >= 32);
        let bs_med = compute_block_size(32);
        assert_eq!(bs_med, 32);
        let bs_large = compute_block_size(64);
        assert_eq!(bs_large, 64);
    }

    #[test]
    fn compute_grid_size_values() {
        let grid = compute_grid_size(100, 8);
        assert_eq!(grid, 25);
        let grid = compute_grid_size(100, 32);
        assert_eq!(grid, 100);
        let grid = compute_grid_size(101, 8);
        assert_eq!(grid, 26);
    }

    #[test]
    fn batched_lu_name_format() {
        let name = batched_lu_name::<f32>(16);
        assert!(name.contains("f32"));
        assert!(name.contains("16"));
    }

    #[test]
    fn batched_qr_name_format() {
        let name = batched_qr_name::<f64>(32, 16);
        assert!(name.contains("f64"));
        assert!(name.contains("32x16"));
    }

    #[test]
    fn batched_cholesky_name_format() {
        let name = batched_cholesky_name::<f32>(64);
        assert!(name.contains("f32"));
        assert!(name.contains("64"));
    }

    #[test]
    fn batched_solve_name_format() {
        let name = batched_solve_name::<f64>(16, 4);
        assert!(name.contains("f64"));
        assert!(name.contains("16"));
        assert!(name.contains("4"));
    }

    #[test]
    fn max_batch_matrix_size_reasonable() {
        let max_size = MAX_BATCH_MATRIX_SIZE;
        assert!(max_size >= 32);
        assert!(max_size <= 128);
    }

    #[test]
    fn small_matrix_threshold_consistent() {
        let threshold = SMALL_MATRIX_THRESHOLD;
        let per_block = SMALL_MATRICES_PER_BLOCK;
        assert!(threshold <= 32);
        assert!(per_block >= 1);
        assert!(per_block <= 16);
    }
}