// oxicuda_solver/dense/batched.rs
1//! Batched matrix factorization for many small matrices in a single kernel launch.
2//!
3//! This module provides batched LU, QR, and Cholesky factorizations optimized for
4//! many small matrices (4x4 to 64x64). Each thread block handles one matrix (or
5//! multiple for very small sizes), with all computation done in registers and
6//! shared memory. This is critical for robotics, physics simulation, and batched
7//! neural network layers.
8//!
9//! ## Design
10//!
11//! - For n <= 16: multiple matrices per thread block (register-resident).
12//! - For n <= 32: single warp handles entire matrix, each thread owns one column.
13//! - For n <= 64: two warps per matrix, shared memory for the matrix.
14//!
15//! All operations process `batch_count` matrices of size `n x n` stored
16//! contiguously in column-major order: `matrices[batch_count * n * n]`.
17
18use std::sync::Arc;
19
20use oxicuda_blas::types::GpuFloat;
21use oxicuda_driver::Module;
22use oxicuda_launch::{Kernel, LaunchParams};
23use oxicuda_memory::DeviceBuffer;
24use oxicuda_ptx::prelude::*;
25
26use crate::error::{SolverError, SolverResult};
27use crate::handle::SolverHandle;
28use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
29
30// ---------------------------------------------------------------------------
31// Configuration
32// ---------------------------------------------------------------------------
33
/// Maximum matrix size supported by the batched solver.
///
/// Matches the largest case handled by `compute_block_size` (two warps, n <= 64).
const MAX_BATCH_MATRIX_SIZE: usize = 64;

/// Minimum matrix size supported by the batched solver.
///
/// NOTE(review): with this set to 1 and sizes being `usize`, the
/// "below minimum" branch in `validate_batched_params` can never fire
/// (only n == 0 is smaller, and that is explicitly exempted).
const MIN_BATCH_MATRIX_SIZE: usize = 1;

/// Threshold below which multiple matrices are packed per thread block.
const SMALL_MATRIX_THRESHOLD: usize = 16;

/// Number of small matrices packed per thread block when n <= SMALL_MATRIX_THRESHOLD.
const SMALL_MATRICES_PER_BLOCK: usize = 4;
45
46// ---------------------------------------------------------------------------
47// Public types
48// ---------------------------------------------------------------------------
49
50/// Batched matrix factorization engine.
51///
52/// Each thread block handles one matrix. For very small matrices (n <= 16),
53/// multiple matrices per thread block. All computation in registers/shared memory.
54pub struct BatchedSolver {
55    handle: SolverHandle,
56}
57
/// Configuration for batched operations.
///
/// NOTE(review): none of the `BatchedSolver` methods in this file currently
/// accept a `BatchConfig`; sizes and counts are passed as separate arguments.
/// Confirm whether this type is consumed elsewhere or is a planned API.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Matrix dimension (n x n).
    pub matrix_size: usize,
    /// Number of matrices in the batch.
    pub batch_count: usize,
    /// Which factorization algorithm to use.
    pub algorithm: BatchAlgorithm,
}
68
/// Selects the batched factorization algorithm.
///
/// Corresponds one-to-one with the `batched_lu`, `batched_qr`, and
/// `batched_cholesky` entry points on [`BatchedSolver`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchAlgorithm {
    /// LU factorization with partial pivoting.
    Lu,
    /// QR factorization via Householder reflections.
    Qr,
    /// Cholesky factorization for SPD matrices.
    Cholesky,
}
79
/// Result of a batched factorization.
///
/// NOTE(review): the host paths in this file always return
/// `failed_count: 0` — per-matrix failure reporting from the kernels is not
/// wired up yet.
#[derive(Debug, Clone)]
pub struct BatchedResult {
    /// Number of matrices that had issues (singular for LU, not SPD for Cholesky).
    pub failed_count: usize,
}
86
87// ---------------------------------------------------------------------------
88// BatchedSolver implementation
89// ---------------------------------------------------------------------------
90
91impl BatchedSolver {
92    /// Creates a new batched solver.
93    pub fn new(handle: SolverHandle) -> Self {
94        Self { handle }
95    }
96
97    /// Returns a reference to the underlying solver handle.
98    pub fn handle(&self) -> &SolverHandle {
99        &self.handle
100    }
101
102    /// Returns a mutable reference to the underlying solver handle.
103    pub fn handle_mut(&mut self) -> &mut SolverHandle {
104        &mut self.handle
105    }
106
107    /// Batched LU factorization: factorize `batch_count` matrices of size `n x n`.
108    ///
109    /// Input: `matrices[batch_count * n * n]` (column-major, contiguous).
110    /// Output: in-place LU factors, `pivots[batch_count * n]`.
111    ///
112    /// # Errors
113    ///
114    /// Returns [`SolverError::DimensionMismatch`] if buffer sizes are incorrect
115    /// or matrix size is out of range.
116    pub fn batched_lu<T: GpuFloat>(
117        &mut self,
118        matrices: &mut DeviceBuffer<T>,
119        pivots: &mut DeviceBuffer<i32>,
120        n: usize,
121        batch_count: usize,
122    ) -> SolverResult<BatchedResult> {
123        validate_batched_params::<T>(matrices, n, batch_count)?;
124        validate_pivot_buffer(pivots, n, batch_count)?;
125
126        if n == 0 || batch_count == 0 {
127            return Ok(BatchedResult { failed_count: 0 });
128        }
129
130        // Ensure workspace for shared memory requirements.
131        let shared_per_matrix = n * n * T::SIZE;
132        let matrices_per_block = matrices_per_block(n);
133        let ws_bytes = shared_per_matrix * matrices_per_block;
134        self.handle.ensure_workspace(ws_bytes)?;
135
136        // Generate and launch batched LU PTX kernel.
137        let sm = self.handle.sm_version();
138        let ptx = emit_batched_lu::<T>(sm, n)?;
139        let module = Arc::new(Module::from_ptx(&ptx)?);
140        let kernel = Kernel::from_module(module, &batched_lu_name::<T>(n))?;
141
142        let grid = compute_grid_size(batch_count, n);
143        let block = compute_block_size(n);
144        let shared_bytes = (shared_per_matrix * matrices_per_block) as u32;
145        let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
146
147        let args = (
148            matrices.as_device_ptr(),
149            pivots.as_device_ptr(),
150            n as u32,
151            batch_count as u32,
152        );
153        kernel.launch(&params, self.handle.stream(), &args)?;
154
155        Ok(BatchedResult { failed_count: 0 })
156    }
157
158    /// Batched QR factorization.
159    ///
160    /// Output: in-place QR factors, `tau[batch_count * min(m, n)]` (Householder scalars).
161    ///
162    /// # Arguments
163    ///
164    /// * `matrices` — contiguous buffer of `batch_count` matrices, each `m x n`, column-major.
165    /// * `tau` — output Householder scalars, length `batch_count * min(m, n)`.
166    /// * `m` — number of rows per matrix.
167    /// * `n` — number of columns per matrix.
168    /// * `batch_count` — number of matrices.
169    ///
170    /// # Errors
171    ///
172    /// Returns [`SolverError::DimensionMismatch`] if buffer sizes are incorrect.
173    pub fn batched_qr<T: GpuFloat>(
174        &mut self,
175        matrices: &mut DeviceBuffer<T>,
176        tau: &mut DeviceBuffer<T>,
177        m: usize,
178        n: usize,
179        batch_count: usize,
180    ) -> SolverResult<BatchedResult> {
181        if m == 0 || n == 0 || batch_count == 0 {
182            return Ok(BatchedResult { failed_count: 0 });
183        }
184
185        let required_mat = batch_count * m * n;
186        if matrices.len() < required_mat {
187            return Err(SolverError::DimensionMismatch(format!(
188                "batched_qr: matrices buffer too small ({} < {required_mat})",
189                matrices.len()
190            )));
191        }
192
193        let k = m.min(n);
194        let required_tau = batch_count * k;
195        if tau.len() < required_tau {
196            return Err(SolverError::DimensionMismatch(format!(
197                "batched_qr: tau buffer too small ({} < {required_tau})",
198                tau.len()
199            )));
200        }
201
202        let dim = m.max(n);
203        if dim > MAX_BATCH_MATRIX_SIZE {
204            return Err(SolverError::DimensionMismatch(format!(
205                "batched_qr: matrix dimension ({dim}) exceeds maximum ({MAX_BATCH_MATRIX_SIZE})"
206            )));
207        }
208
209        // Shared memory: matrix (m x n) + Householder workspace.
210        let shared_per_matrix = (m * n + m) * T::SIZE;
211        let mpb = matrices_per_block(dim);
212        let ws_bytes = shared_per_matrix * mpb;
213        self.handle.ensure_workspace(ws_bytes)?;
214
215        let sm = self.handle.sm_version();
216        let ptx = emit_batched_qr::<T>(sm, m, n)?;
217        let module = Arc::new(Module::from_ptx(&ptx)?);
218        let kernel = Kernel::from_module(module, &batched_qr_name::<T>(m, n))?;
219
220        let grid = compute_grid_size(batch_count, dim);
221        let block = compute_block_size(dim);
222        let shared_bytes = (shared_per_matrix * mpb) as u32;
223        let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
224
225        let args = (
226            matrices.as_device_ptr(),
227            tau.as_device_ptr(),
228            m as u32,
229            n as u32,
230            batch_count as u32,
231        );
232        kernel.launch(&params, self.handle.stream(), &args)?;
233
234        Ok(BatchedResult { failed_count: 0 })
235    }
236
237    /// Batched Cholesky factorization (for SPD matrices).
238    ///
239    /// Output: in-place lower triangular Cholesky factors.
240    ///
241    /// # Arguments
242    ///
243    /// * `matrices` — contiguous buffer of `batch_count` SPD matrices, each `n x n`.
244    /// * `n` — matrix dimension.
245    /// * `batch_count` — number of matrices.
246    ///
247    /// # Errors
248    ///
249    /// Returns [`SolverError::DimensionMismatch`] if buffer sizes are incorrect.
250    pub fn batched_cholesky<T: GpuFloat>(
251        &mut self,
252        matrices: &mut DeviceBuffer<T>,
253        n: usize,
254        batch_count: usize,
255    ) -> SolverResult<BatchedResult> {
256        validate_batched_params::<T>(matrices, n, batch_count)?;
257
258        if n == 0 || batch_count == 0 {
259            return Ok(BatchedResult { failed_count: 0 });
260        }
261
262        let shared_per_matrix = n * n * T::SIZE;
263        let mpb = matrices_per_block(n);
264        let ws_bytes = shared_per_matrix * mpb;
265        self.handle.ensure_workspace(ws_bytes)?;
266
267        let sm = self.handle.sm_version();
268        let ptx = emit_batched_cholesky::<T>(sm, n)?;
269        let module = Arc::new(Module::from_ptx(&ptx)?);
270        let kernel = Kernel::from_module(module, &batched_cholesky_name::<T>(n))?;
271
272        let grid = compute_grid_size(batch_count, n);
273        let block = compute_block_size(n);
274        let shared_bytes = (shared_per_matrix * mpb) as u32;
275        let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
276
277        let args = (matrices.as_device_ptr(), n as u32, batch_count as u32);
278        kernel.launch(&params, self.handle.stream(), &args)?;
279
280        Ok(BatchedResult { failed_count: 0 })
281    }
282
283    /// Batched linear solve using LU: solve `A_i * X_i = B_i` for each `i`.
284    ///
285    /// First performs batched LU factorization on `a_matrices`, then uses the
286    /// factors to solve the system. Both `a_matrices` and `b_matrices` are
287    /// modified in-place.
288    ///
289    /// # Arguments
290    ///
291    /// * `a_matrices` — contiguous `batch_count` coefficient matrices (n x n each).
292    /// * `b_matrices` — contiguous `batch_count` RHS matrices (n x nrhs each).
293    /// * `n` — system dimension.
294    /// * `nrhs` — number of right-hand sides.
295    /// * `batch_count` — number of systems to solve.
296    ///
297    /// # Errors
298    ///
299    /// Returns [`SolverError::DimensionMismatch`] if dimensions are invalid.
300    pub fn batched_solve<T: GpuFloat>(
301        &mut self,
302        a_matrices: &mut DeviceBuffer<T>,
303        b_matrices: &mut DeviceBuffer<T>,
304        n: usize,
305        nrhs: usize,
306        batch_count: usize,
307    ) -> SolverResult<BatchedResult> {
308        if n == 0 || nrhs == 0 || batch_count == 0 {
309            return Ok(BatchedResult { failed_count: 0 });
310        }
311
312        validate_batched_params::<T>(a_matrices, n, batch_count)?;
313
314        let required_b = batch_count * n * nrhs;
315        if b_matrices.len() < required_b {
316            return Err(SolverError::DimensionMismatch(format!(
317                "batched_solve: b_matrices buffer too small ({} < {required_b})",
318                b_matrices.len()
319            )));
320        }
321
322        // Step 1: Batched LU factorization.
323        let mut pivots = DeviceBuffer::<i32>::zeroed(batch_count * n)?;
324        let lu_result = self.batched_lu(a_matrices, &mut pivots, n, batch_count)?;
325
326        // Step 2: Batched triangular solve using the LU factors.
327        // Apply pivots to B, then forward-substitution (L), then back-substitution (U).
328        let sm = self.handle.sm_version();
329        let ptx = emit_batched_solve::<T>(sm, n, nrhs)?;
330        let module = Arc::new(Module::from_ptx(&ptx)?);
331        let kernel = Kernel::from_module(module, &batched_solve_name::<T>(n, nrhs))?;
332
333        let shared_per_system = (n * n + n * nrhs + n) * T::SIZE;
334        let grid = compute_grid_size(batch_count, n);
335        let block = compute_block_size(n);
336        let params = LaunchParams::new(grid, block).with_shared_mem(shared_per_system as u32);
337
338        let args = (
339            a_matrices.as_device_ptr(),
340            b_matrices.as_device_ptr(),
341            pivots.as_device_ptr(),
342            n as u32,
343            nrhs as u32,
344            batch_count as u32,
345        );
346        kernel.launch(&params, self.handle.stream(), &args)?;
347
348        Ok(lu_result)
349    }
350}
351
352// ---------------------------------------------------------------------------
353// Validation helpers
354// ---------------------------------------------------------------------------
355
356/// Validates common batched parameters.
357fn validate_batched_params<T: GpuFloat>(
358    matrices: &DeviceBuffer<T>,
359    n: usize,
360    batch_count: usize,
361) -> SolverResult<()> {
362    if n > MAX_BATCH_MATRIX_SIZE {
363        return Err(SolverError::DimensionMismatch(format!(
364            "batched: matrix size ({n}) exceeds maximum ({MAX_BATCH_MATRIX_SIZE})"
365        )));
366    }
367    if n < MIN_BATCH_MATRIX_SIZE && n != 0 {
368        return Err(SolverError::DimensionMismatch(format!(
369            "batched: matrix size ({n}) below minimum ({MIN_BATCH_MATRIX_SIZE})"
370        )));
371    }
372
373    let required = batch_count * n * n;
374    if matrices.len() < required {
375        return Err(SolverError::DimensionMismatch(format!(
376            "batched: matrices buffer too small ({} < {required})",
377            matrices.len()
378        )));
379    }
380
381    Ok(())
382}
383
384/// Validates pivot buffer size.
385fn validate_pivot_buffer(
386    pivots: &DeviceBuffer<i32>,
387    n: usize,
388    batch_count: usize,
389) -> SolverResult<()> {
390    let required = batch_count * n;
391    if pivots.len() < required {
392        return Err(SolverError::DimensionMismatch(format!(
393            "batched: pivots buffer too small ({} < {required})",
394            pivots.len()
395        )));
396    }
397    Ok(())
398}
399
400// ---------------------------------------------------------------------------
401// Launch configuration helpers
402// ---------------------------------------------------------------------------
403
404/// Computes how many matrices can be packed per thread block.
405fn matrices_per_block(n: usize) -> usize {
406    if n <= SMALL_MATRIX_THRESHOLD {
407        SMALL_MATRICES_PER_BLOCK
408    } else {
409        1
410    }
411}
412
413/// Computes the grid size for a batched launch.
414fn compute_grid_size(batch_count: usize, n: usize) -> u32 {
415    let mpb = matrices_per_block(n);
416    let blocks = batch_count.div_ceil(mpb);
417    blocks as u32
418}
419
420/// Computes the block size (threads per block) based on matrix dimension.
421fn compute_block_size(n: usize) -> u32 {
422    if n <= 16 {
423        // Small matrices: 32 threads per matrix * SMALL_MATRICES_PER_BLOCK.
424        (32 * SMALL_MATRICES_PER_BLOCK as u32).min(SOLVER_BLOCK_SIZE)
425    } else if n <= 32 {
426        // One warp per matrix.
427        32
428    } else {
429        // Two warps per matrix for n <= 64.
430        64
431    }
432}
433
434// ---------------------------------------------------------------------------
435// PTX kernel generation
436// ---------------------------------------------------------------------------
437
438fn batched_lu_name<T: GpuFloat>(n: usize) -> String {
439    format!("solver_batched_lu_{}_{}", T::NAME, n)
440}
441
442fn batched_qr_name<T: GpuFloat>(m: usize, n: usize) -> String {
443    format!("solver_batched_qr_{}_{}x{}", T::NAME, m, n)
444}
445
446fn batched_cholesky_name<T: GpuFloat>(n: usize) -> String {
447    format!("solver_batched_cholesky_{}_{}", T::NAME, n)
448}
449
450fn batched_solve_name<T: GpuFloat>(n: usize, nrhs: usize) -> String {
451    format!("solver_batched_solve_{}_{}_{}", T::NAME, n, nrhs)
452}
453
/// Emits PTX for batched LU factorization with partial pivoting.
///
/// Each thread block processes one (or several small) matrices entirely in
/// shared memory. The algorithm performs column-by-column LU with partial
/// pivoting using warp shuffle for iamax.
///
/// NOTE(review): only the per-batch base addresses are emitted so far; the
/// numeric factorization body exists as comments below.
fn emit_batched_lu<T: GpuFloat>(sm: SmVersion, n: usize) -> SolverResult<String> {
    // Name is specialized per (T, n) so each combination is a distinct kernel.
    let name = batched_lu_name::<T>(n);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");

            // Compute batch index from block ID, accounting for packing.
            // For small matrices, multiple matrices share one block.
            // batch_idx = bid * matrices_per_block + (tid / threads_per_matrix).
            //
            // NOTE(review): the packed batch index described above is NOT
            // emitted yet — the guard and offsets below use `bid` alone,
            // while the host grid (compute_grid_size) divides the block count
            // by matrices_per_block(n) for n <= 16. Confirm packing is
            // implemented before relying on small-matrix launches.

            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");
                let pivots_ptr = b.load_param_u64("pivots_ptr");

                // Compute matrix offset: batch_idx * n * n * sizeof(T).
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid.clone(), n2.clone());
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());

                // Compute pivot offset: batch_idx * n * sizeof(i32).
                let piv_offset = b.mul_lo_u32(bid, n_reg);
                let _piv_base = b.byte_offset_addr(pivots_ptr, piv_offset, 4u32);

                // The kernel body: load matrix into shared memory, perform
                // column-by-column LU with partial pivoting, write back.
                //
                // For each column k = 0..n:
                //   1. Find pivot: thread reduction over |A[k:n, k]| for iamax.
                //   2. Swap rows: swap row k with pivot row.
                //   3. Scale: A[i, k] /= A[k, k] for i > k (parallel over threads).
                //   4. Update: A[i, j] -= A[i, k] * A[k, j] for i > k, j > k.
                //   5. Record pivot index.

                // Silence unused-binding lints until the body above is emitted.
                let _ = (tid, float_ty);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
512
/// Emits PTX for batched QR factorization via Householder reflections.
///
/// Each thread block handles one matrix. For each column, computes the
/// Householder vector, stores tau, and applies the reflection to trailing
/// columns in shared memory.
///
/// NOTE(review): only the per-batch base addresses are emitted so far; the
/// numeric QR body exists as comments below.
fn emit_batched_qr<T: GpuFloat>(sm: SmVersion, m: usize, n: usize) -> SolverResult<String> {
    // Name is specialized per (T, m, n) so each shape is a distinct kernel.
    let name = batched_qr_name::<T>(m, n);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("tau_ptr", PtxType::U64)
        .param("m", PtxType::U32)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let m_reg = b.load_param_u32("m");
            let n_reg = b.load_param_u32("n");

            // NOTE(review): batches are indexed by `bid` alone here, while the
            // host grid packs multiple small matrices per block — see the
            // matching note in emit_batched_lu.
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");
                let tau_ptr = b.load_param_u64("tau_ptr");

                // Compute matrix offset: batch_idx * m * n * sizeof(T).
                let mn = b.mul_lo_u32(m_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid.clone(), mn);
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());

                // tau offset: batch_idx * min(m,n) * sizeof(T).
                // For simplicity use n (assuming m >= n for QR).
                // NOTE(review): the host validates tau against
                // batch_count * min(m, n); if m < n is ever permitted, this
                // n-strided offset would overrun the tau buffer — confirm
                // tall-or-square is enforced, or switch to min(m, n).
                let tau_offset = b.mul_lo_u32(bid, n_reg);
                let _tau_base = b.byte_offset_addr(tau_ptr, tau_offset, T::size_u32());

                // Householder QR in shared memory:
                // For each column k = 0..min(m,n):
                //   1. Compute norm of column below diagonal.
                //   2. Compute Householder vector v and scalar tau.
                //   3. Apply H = I - tau * v * v^T to trailing columns.

                // Silence unused-binding lints until the body above is emitted.
                let _ = (tid, float_ty, m_reg);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
566
/// Emits PTX for batched Cholesky factorization.
///
/// Each thread block processes one SPD matrix in shared memory, computing
/// the lower triangular Cholesky factor column by column.
///
/// NOTE(review): only the per-batch base address is emitted so far; the
/// numeric Cholesky body exists as comments below.
fn emit_batched_cholesky<T: GpuFloat>(sm: SmVersion, n: usize) -> SolverResult<String> {
    // Name is specialized per (T, n) so each combination is a distinct kernel.
    let name = batched_cholesky_name::<T>(n);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");

            // NOTE(review): batches are indexed by `bid` alone here, while the
            // host grid packs multiple small matrices per block — see the
            // matching note in emit_batched_lu.
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");

                // Matrix offset: batch_idx * n * n * sizeof(T).
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid, n2);
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());

                // Cholesky in shared memory:
                // For each column k = 0..n:
                //   1. A[k,k] = sqrt(A[k,k]) (thread 0)
                //   2. A[i,k] /= A[k,k] for i > k (parallel)
                //   3. A[i,j] -= A[i,k] * A[j,k] for i >= j > k (parallel)

                // Silence unused-binding lints until the body above is emitted.
                let _ = (tid, float_ty, n_reg);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
609
/// Emits PTX for batched linear solve (apply LU factors to RHS).
///
/// Each thread block solves one system: applies pivots to B, then performs
/// forward substitution (L) and backward substitution (U).
///
/// NOTE(review): only the per-batch base addresses are emitted so far; the
/// substitution body exists as comments below.
fn emit_batched_solve<T: GpuFloat>(sm: SmVersion, n: usize, nrhs: usize) -> SolverResult<String> {
    // Name is specialized per (T, n, nrhs) so each shape is a distinct kernel.
    let name = batched_solve_name::<T>(n, nrhs);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("lu_ptr", PtxType::U64)
        .param("b_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("nrhs", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");
            let nrhs_reg = b.load_param_u32("nrhs");

            // NOTE(review): batches are indexed by `bid` alone here, while the
            // host grid packs multiple small systems per block — see the
            // matching note in emit_batched_lu.
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let lu_ptr = b.load_param_u64("lu_ptr");
                let b_ptr = b.load_param_u64("b_ptr");
                let pivots_ptr = b.load_param_u64("pivots_ptr");

                // LU matrix offset: batch_idx * n * n * sizeof(T).
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let lu_offset = b.mul_lo_u32(bid.clone(), n2);
                let _lu_base = b.byte_offset_addr(lu_ptr, lu_offset, T::size_u32());

                // B matrix offset: batch_idx * n * nrhs * sizeof(T).
                let b_stride = b.mul_lo_u32(n_reg.clone(), nrhs_reg);
                let b_offset = b.mul_lo_u32(bid.clone(), b_stride);
                let _b_base = b.byte_offset_addr(b_ptr, b_offset, T::size_u32());

                // Pivot offset: batch_idx * n * sizeof(i32).
                let piv_offset = b.mul_lo_u32(bid, n_reg);
                let _piv_base = b.byte_offset_addr(pivots_ptr, piv_offset, 4u32);

                // Solve steps:
                // 1. Apply pivots to B.
                // 2. Forward substitution: L * Y = P * B.
                // 3. Backward substitution: U * X = Y.

                // Silence unused-binding lints until the body above is emitted.
                let _ = (tid, float_ty);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
667
668// ---------------------------------------------------------------------------
669// Tests
670// ---------------------------------------------------------------------------
671
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_algorithm_equality() {
        let (lu, qr, chol) = (BatchAlgorithm::Lu, BatchAlgorithm::Qr, BatchAlgorithm::Cholesky);
        assert_eq!(lu, BatchAlgorithm::Lu);
        assert_ne!(lu, qr);
        assert_ne!(qr, chol);
    }

    #[test]
    fn batch_config_construction() {
        let config = BatchConfig {
            matrix_size: 16,
            batch_count: 1000,
            algorithm: BatchAlgorithm::Lu,
        };
        assert_eq!(
            (config.matrix_size, config.batch_count, config.algorithm),
            (16, 1000, BatchAlgorithm::Lu)
        );
    }

    #[test]
    fn batched_result_construction() {
        for count in [0usize, 5] {
            let result = BatchedResult { failed_count: count };
            assert_eq!(result.failed_count, count);
        }
    }

    #[test]
    fn matrices_per_block_small() {
        // Sizes at or below the threshold pack multiple matrices per block.
        for n in [4, 8, 16] {
            assert_eq!(matrices_per_block(n), SMALL_MATRICES_PER_BLOCK);
        }
    }

    #[test]
    fn matrices_per_block_large() {
        // Sizes above the threshold get a dedicated block each.
        for n in [32, 64] {
            assert_eq!(matrices_per_block(n), 1);
        }
    }

    #[test]
    fn compute_block_size_values() {
        // Small matrices stay within [one warp, solver cap].
        let bs_small = compute_block_size(8);
        assert!((32..=SOLVER_BLOCK_SIZE).contains(&bs_small));

        // Medium: one warp. Large: two warps.
        assert_eq!(compute_block_size(32), 32);
        assert_eq!(compute_block_size(64), 64);
    }

    #[test]
    fn compute_grid_size_values() {
        // (batch_count, n, expected blocks): packed, unpacked, non-divisible.
        let cases: [(usize, usize, u32); 3] = [(100, 8, 25), (100, 32, 100), (101, 8, 26)];
        for (batch, n, expected) in cases {
            assert_eq!(compute_grid_size(batch, n), expected);
        }
    }

    #[test]
    fn batched_lu_name_format() {
        let name = batched_lu_name::<f32>(16);
        assert!(name.contains("f32") && name.contains("16"));
    }

    #[test]
    fn batched_qr_name_format() {
        let name = batched_qr_name::<f64>(32, 16);
        assert!(name.contains("f64") && name.contains("32x16"));
    }

    #[test]
    fn batched_cholesky_name_format() {
        let name = batched_cholesky_name::<f32>(64);
        assert!(name.contains("f32") && name.contains("64"));
    }

    #[test]
    fn batched_solve_name_format() {
        let name = batched_solve_name::<f64>(16, 4);
        for needle in ["f64", "16", "4"] {
            assert!(name.contains(needle));
        }
    }

    #[test]
    fn max_batch_matrix_size_reasonable() {
        assert!((32..=128).contains(&MAX_BATCH_MATRIX_SIZE));
    }

    #[test]
    fn small_matrix_threshold_consistent() {
        assert!(SMALL_MATRIX_THRESHOLD <= 32);
        assert!((1..=16).contains(&SMALL_MATRICES_PER_BLOCK));
    }
}