1use std::sync::Arc;
19
20use oxicuda_blas::types::GpuFloat;
21use oxicuda_driver::Module;
22use oxicuda_launch::{Kernel, LaunchParams};
23use oxicuda_memory::DeviceBuffer;
24use oxicuda_ptx::prelude::*;
25
26use crate::error::{SolverError, SolverResult};
27use crate::handle::SolverHandle;
28use crate::ptx_helpers::SOLVER_BLOCK_SIZE;
29
/// Largest per-matrix dimension accepted by the batched kernels; each
/// matrix is staged in a block's shared memory, so the cap bounds that
/// allocation.
const MAX_BATCH_MATRIX_SIZE: usize = 64;

/// Smallest non-empty matrix dimension (n == 0 batches are no-ops).
const MIN_BATCH_MATRIX_SIZE: usize = 1;

/// Matrices with n at or below this threshold are packed several per block.
const SMALL_MATRIX_THRESHOLD: usize = 16;

/// Number of small matrices that share one thread block.
const SMALL_MATRICES_PER_BLOCK: usize = 4;
/// Entry point for batched dense factorizations (LU, QR, Cholesky) and
/// solves over many small matrices at once.
pub struct BatchedSolver {
    // Supplies the stream, workspace allocation, and SM version used to
    // emit and launch the specialized kernels.
    handle: SolverHandle,
}
57
/// Describes one batched operation request.
///
/// NOTE(review): not consumed by any method in this file; presumably used
/// by callers elsewhere — confirm before changing.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Side length `n` of each square matrix in the batch.
    pub matrix_size: usize,
    /// Number of matrices processed together.
    pub batch_count: usize,
    /// Which factorization to run on the batch.
    pub algorithm: BatchAlgorithm,
}
68
/// Factorization algorithm selector for a batched operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BatchAlgorithm {
    /// LU factorization (see `BatchedSolver::batched_lu`).
    Lu,
    /// QR factorization (see `BatchedSolver::batched_qr`).
    Qr,
    /// Cholesky factorization (see `BatchedSolver::batched_cholesky`).
    Cholesky,
}
79
/// Outcome summary for a batched operation.
#[derive(Debug, Clone)]
pub struct BatchedResult {
    /// Number of matrices in the batch that failed. The launch paths in
    /// this file always report 0; per-matrix failure reporting from the
    /// kernels is not wired up here.
    pub failed_count: usize,
}
86
87impl BatchedSolver {
92 pub fn new(handle: SolverHandle) -> Self {
94 Self { handle }
95 }
96
97 pub fn handle(&self) -> &SolverHandle {
99 &self.handle
100 }
101
102 pub fn handle_mut(&mut self) -> &mut SolverHandle {
104 &mut self.handle
105 }
106
107 pub fn batched_lu<T: GpuFloat>(
117 &mut self,
118 matrices: &mut DeviceBuffer<T>,
119 pivots: &mut DeviceBuffer<i32>,
120 n: usize,
121 batch_count: usize,
122 ) -> SolverResult<BatchedResult> {
123 validate_batched_params::<T>(matrices, n, batch_count)?;
124 validate_pivot_buffer(pivots, n, batch_count)?;
125
126 if n == 0 || batch_count == 0 {
127 return Ok(BatchedResult { failed_count: 0 });
128 }
129
130 let shared_per_matrix = n * n * T::SIZE;
132 let matrices_per_block = matrices_per_block(n);
133 let ws_bytes = shared_per_matrix * matrices_per_block;
134 self.handle.ensure_workspace(ws_bytes)?;
135
136 let sm = self.handle.sm_version();
138 let ptx = emit_batched_lu::<T>(sm, n)?;
139 let module = Arc::new(Module::from_ptx(&ptx)?);
140 let kernel = Kernel::from_module(module, &batched_lu_name::<T>(n))?;
141
142 let grid = compute_grid_size(batch_count, n);
143 let block = compute_block_size(n);
144 let shared_bytes = (shared_per_matrix * matrices_per_block) as u32;
145 let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
146
147 let args = (
148 matrices.as_device_ptr(),
149 pivots.as_device_ptr(),
150 n as u32,
151 batch_count as u32,
152 );
153 kernel.launch(¶ms, self.handle.stream(), &args)?;
154
155 Ok(BatchedResult { failed_count: 0 })
156 }
157
158 pub fn batched_qr<T: GpuFloat>(
174 &mut self,
175 matrices: &mut DeviceBuffer<T>,
176 tau: &mut DeviceBuffer<T>,
177 m: usize,
178 n: usize,
179 batch_count: usize,
180 ) -> SolverResult<BatchedResult> {
181 if m == 0 || n == 0 || batch_count == 0 {
182 return Ok(BatchedResult { failed_count: 0 });
183 }
184
185 let required_mat = batch_count * m * n;
186 if matrices.len() < required_mat {
187 return Err(SolverError::DimensionMismatch(format!(
188 "batched_qr: matrices buffer too small ({} < {required_mat})",
189 matrices.len()
190 )));
191 }
192
193 let k = m.min(n);
194 let required_tau = batch_count * k;
195 if tau.len() < required_tau {
196 return Err(SolverError::DimensionMismatch(format!(
197 "batched_qr: tau buffer too small ({} < {required_tau})",
198 tau.len()
199 )));
200 }
201
202 let dim = m.max(n);
203 if dim > MAX_BATCH_MATRIX_SIZE {
204 return Err(SolverError::DimensionMismatch(format!(
205 "batched_qr: matrix dimension ({dim}) exceeds maximum ({MAX_BATCH_MATRIX_SIZE})"
206 )));
207 }
208
209 let shared_per_matrix = (m * n + m) * T::SIZE;
211 let mpb = matrices_per_block(dim);
212 let ws_bytes = shared_per_matrix * mpb;
213 self.handle.ensure_workspace(ws_bytes)?;
214
215 let sm = self.handle.sm_version();
216 let ptx = emit_batched_qr::<T>(sm, m, n)?;
217 let module = Arc::new(Module::from_ptx(&ptx)?);
218 let kernel = Kernel::from_module(module, &batched_qr_name::<T>(m, n))?;
219
220 let grid = compute_grid_size(batch_count, dim);
221 let block = compute_block_size(dim);
222 let shared_bytes = (shared_per_matrix * mpb) as u32;
223 let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
224
225 let args = (
226 matrices.as_device_ptr(),
227 tau.as_device_ptr(),
228 m as u32,
229 n as u32,
230 batch_count as u32,
231 );
232 kernel.launch(¶ms, self.handle.stream(), &args)?;
233
234 Ok(BatchedResult { failed_count: 0 })
235 }
236
237 pub fn batched_cholesky<T: GpuFloat>(
251 &mut self,
252 matrices: &mut DeviceBuffer<T>,
253 n: usize,
254 batch_count: usize,
255 ) -> SolverResult<BatchedResult> {
256 validate_batched_params::<T>(matrices, n, batch_count)?;
257
258 if n == 0 || batch_count == 0 {
259 return Ok(BatchedResult { failed_count: 0 });
260 }
261
262 let shared_per_matrix = n * n * T::SIZE;
263 let mpb = matrices_per_block(n);
264 let ws_bytes = shared_per_matrix * mpb;
265 self.handle.ensure_workspace(ws_bytes)?;
266
267 let sm = self.handle.sm_version();
268 let ptx = emit_batched_cholesky::<T>(sm, n)?;
269 let module = Arc::new(Module::from_ptx(&ptx)?);
270 let kernel = Kernel::from_module(module, &batched_cholesky_name::<T>(n))?;
271
272 let grid = compute_grid_size(batch_count, n);
273 let block = compute_block_size(n);
274 let shared_bytes = (shared_per_matrix * mpb) as u32;
275 let params = LaunchParams::new(grid, block).with_shared_mem(shared_bytes);
276
277 let args = (matrices.as_device_ptr(), n as u32, batch_count as u32);
278 kernel.launch(¶ms, self.handle.stream(), &args)?;
279
280 Ok(BatchedResult { failed_count: 0 })
281 }
282
283 pub fn batched_solve<T: GpuFloat>(
301 &mut self,
302 a_matrices: &mut DeviceBuffer<T>,
303 b_matrices: &mut DeviceBuffer<T>,
304 n: usize,
305 nrhs: usize,
306 batch_count: usize,
307 ) -> SolverResult<BatchedResult> {
308 if n == 0 || nrhs == 0 || batch_count == 0 {
309 return Ok(BatchedResult { failed_count: 0 });
310 }
311
312 validate_batched_params::<T>(a_matrices, n, batch_count)?;
313
314 let required_b = batch_count * n * nrhs;
315 if b_matrices.len() < required_b {
316 return Err(SolverError::DimensionMismatch(format!(
317 "batched_solve: b_matrices buffer too small ({} < {required_b})",
318 b_matrices.len()
319 )));
320 }
321
322 let mut pivots = DeviceBuffer::<i32>::zeroed(batch_count * n)?;
324 let lu_result = self.batched_lu(a_matrices, &mut pivots, n, batch_count)?;
325
326 let sm = self.handle.sm_version();
329 let ptx = emit_batched_solve::<T>(sm, n, nrhs)?;
330 let module = Arc::new(Module::from_ptx(&ptx)?);
331 let kernel = Kernel::from_module(module, &batched_solve_name::<T>(n, nrhs))?;
332
333 let shared_per_system = (n * n + n * nrhs + n) * T::SIZE;
334 let grid = compute_grid_size(batch_count, n);
335 let block = compute_block_size(n);
336 let params = LaunchParams::new(grid, block).with_shared_mem(shared_per_system as u32);
337
338 let args = (
339 a_matrices.as_device_ptr(),
340 b_matrices.as_device_ptr(),
341 pivots.as_device_ptr(),
342 n as u32,
343 nrhs as u32,
344 batch_count as u32,
345 );
346 kernel.launch(¶ms, self.handle.stream(), &args)?;
347
348 Ok(lu_result)
349 }
350}
351
352fn validate_batched_params<T: GpuFloat>(
358 matrices: &DeviceBuffer<T>,
359 n: usize,
360 batch_count: usize,
361) -> SolverResult<()> {
362 if n > MAX_BATCH_MATRIX_SIZE {
363 return Err(SolverError::DimensionMismatch(format!(
364 "batched: matrix size ({n}) exceeds maximum ({MAX_BATCH_MATRIX_SIZE})"
365 )));
366 }
367 if n < MIN_BATCH_MATRIX_SIZE && n != 0 {
368 return Err(SolverError::DimensionMismatch(format!(
369 "batched: matrix size ({n}) below minimum ({MIN_BATCH_MATRIX_SIZE})"
370 )));
371 }
372
373 let required = batch_count * n * n;
374 if matrices.len() < required {
375 return Err(SolverError::DimensionMismatch(format!(
376 "batched: matrices buffer too small ({} < {required})",
377 matrices.len()
378 )));
379 }
380
381 Ok(())
382}
383
384fn validate_pivot_buffer(
386 pivots: &DeviceBuffer<i32>,
387 n: usize,
388 batch_count: usize,
389) -> SolverResult<()> {
390 let required = batch_count * n;
391 if pivots.len() < required {
392 return Err(SolverError::DimensionMismatch(format!(
393 "batched: pivots buffer too small ({} < {required})",
394 pivots.len()
395 )));
396 }
397 Ok(())
398}
399
400fn matrices_per_block(n: usize) -> usize {
406 if n <= SMALL_MATRIX_THRESHOLD {
407 SMALL_MATRICES_PER_BLOCK
408 } else {
409 1
410 }
411}
412
413fn compute_grid_size(batch_count: usize, n: usize) -> u32 {
415 let mpb = matrices_per_block(n);
416 let blocks = batch_count.div_ceil(mpb);
417 blocks as u32
418}
419
420fn compute_block_size(n: usize) -> u32 {
422 if n <= 16 {
423 (32 * SMALL_MATRICES_PER_BLOCK as u32).min(SOLVER_BLOCK_SIZE)
425 } else if n <= 32 {
426 32
428 } else {
429 64
431 }
432}
433
434fn batched_lu_name<T: GpuFloat>(n: usize) -> String {
439 format!("solver_batched_lu_{}_{}", T::NAME, n)
440}
441
442fn batched_qr_name<T: GpuFloat>(m: usize, n: usize) -> String {
443 format!("solver_batched_qr_{}_{}x{}", T::NAME, m, n)
444}
445
446fn batched_cholesky_name<T: GpuFloat>(n: usize) -> String {
447 format!("solver_batched_cholesky_{}_{}", T::NAME, n)
448}
449
450fn batched_solve_name<T: GpuFloat>(n: usize, nrhs: usize) -> String {
451 format!("solver_batched_solve_{}_{}_{}", T::NAME, n, nrhs)
452}
453
/// Emits PTX for the batched LU kernel `solver_batched_lu_<T>_<n>`.
///
/// NOTE(review): the emitted body currently only computes per-batch base
/// addresses for the matrix and pivot buffers — the factorization loop is
/// not yet emitted (the trailing `let _ = ...` silences unused handles).
fn emit_batched_lu<T: GpuFloat>(sm: SmVersion, n: usize) -> SolverResult<String> {
    // Kernel name is specialized on element type and matrix size.
    let name = batched_lu_name::<T>(n);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");

            // Guard: treats block id as the batch index (bid < batch_count).
            // NOTE(review): host-side compute_grid_size packs several small
            // matrices per block, which disagrees with this 1:1 mapping —
            // confirm the intended block->batch scheme when the body is
            // implemented.
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");
                let pivots_ptr = b.load_param_u64("pivots_ptr");

                // This batch's matrix starts bid * n * n elements in.
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid.clone(), n2.clone());
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());

                // Pivot row starts bid * n entries in (i32 => 4 bytes each).
                let piv_offset = b.mul_lo_u32(bid, n_reg);
                let _piv_base = b.byte_offset_addr(pivots_ptr, piv_offset, 4u32);

                // Keep currently-unused handles alive without warnings.
                let _ = (tid, float_ty);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
512
513fn emit_batched_qr<T: GpuFloat>(sm: SmVersion, m: usize, n: usize) -> SolverResult<String> {
519 let name = batched_qr_name::<T>(m, n);
520 let float_ty = T::PTX_TYPE;
521
522 let ptx = KernelBuilder::new(&name)
523 .target(sm)
524 .max_threads_per_block(SOLVER_BLOCK_SIZE)
525 .param("matrices_ptr", PtxType::U64)
526 .param("tau_ptr", PtxType::U64)
527 .param("m", PtxType::U32)
528 .param("n", PtxType::U32)
529 .param("batch_count", PtxType::U32)
530 .body(move |b| {
531 let bid = b.block_id_x();
532 let tid = b.thread_id_x();
533 let batch_count_reg = b.load_param_u32("batch_count");
534 let m_reg = b.load_param_u32("m");
535 let n_reg = b.load_param_u32("n");
536
537 b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
538 let matrices_ptr = b.load_param_u64("matrices_ptr");
539 let tau_ptr = b.load_param_u64("tau_ptr");
540
541 let mn = b.mul_lo_u32(m_reg.clone(), n_reg.clone());
543 let mat_offset = b.mul_lo_u32(bid.clone(), mn);
544 let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());
545
546 let tau_offset = b.mul_lo_u32(bid, n_reg);
549 let _tau_base = b.byte_offset_addr(tau_ptr, tau_offset, T::size_u32());
550
551 let _ = (tid, float_ty, m_reg);
558 });
559
560 b.ret();
561 })
562 .build()?;
563
564 Ok(ptx)
565}
566
/// Emits PTX for the batched Cholesky kernel `solver_batched_cholesky_<T>_<n>`.
///
/// NOTE(review): the emitted body currently only computes the per-batch
/// matrix base address — the factorization itself is not yet emitted.
fn emit_batched_cholesky<T: GpuFloat>(sm: SmVersion, n: usize) -> SolverResult<String> {
    let name = batched_cholesky_name::<T>(n);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("matrices_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");

            // Guard: treats block id as the batch index; surplus blocks
            // fall through to ret. NOTE(review): host-side grid sizing
            // packs several small matrices per block — confirm mapping
            // when the kernel body is implemented.
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let matrices_ptr = b.load_param_u64("matrices_ptr");

                // This batch's matrix starts bid * n * n elements in.
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let mat_offset = b.mul_lo_u32(bid, n2);
                let _mat_base = b.byte_offset_addr(matrices_ptr, mat_offset, T::size_u32());

                // Keep currently-unused handles alive without warnings.
                let _ = (tid, float_ty, n_reg);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
609
/// Emits PTX for the batched solve kernel `solver_batched_solve_<T>_<n>_<nrhs>`,
/// which consumes LU factors and pivots produced by the batched LU pass.
///
/// NOTE(review): the emitted body currently only computes per-batch base
/// addresses for the LU, B, and pivot buffers — the triangular-solve logic
/// is not yet emitted.
fn emit_batched_solve<T: GpuFloat>(sm: SmVersion, n: usize, nrhs: usize) -> SolverResult<String> {
    let name = batched_solve_name::<T>(n, nrhs);
    let float_ty = T::PTX_TYPE;

    let ptx = KernelBuilder::new(&name)
        .target(sm)
        .max_threads_per_block(SOLVER_BLOCK_SIZE)
        .param("lu_ptr", PtxType::U64)
        .param("b_ptr", PtxType::U64)
        .param("pivots_ptr", PtxType::U64)
        .param("n", PtxType::U32)
        .param("nrhs", PtxType::U32)
        .param("batch_count", PtxType::U32)
        .body(move |b| {
            let bid = b.block_id_x();
            let tid = b.thread_id_x();
            let batch_count_reg = b.load_param_u32("batch_count");
            let n_reg = b.load_param_u32("n");
            let nrhs_reg = b.load_param_u32("nrhs");

            // Guard: treats block id as the batch index (see the NOTE on
            // emit_batched_lu about multi-matrix blocks).
            b.if_lt_u32(bid.clone(), batch_count_reg, |b| {
                let lu_ptr = b.load_param_u64("lu_ptr");
                let b_ptr = b.load_param_u64("b_ptr");
                let pivots_ptr = b.load_param_u64("pivots_ptr");

                // LU factors for this system start bid * n * n elements in.
                let n2 = b.mul_lo_u32(n_reg.clone(), n_reg.clone());
                let lu_offset = b.mul_lo_u32(bid.clone(), n2);
                let _lu_base = b.byte_offset_addr(lu_ptr, lu_offset, T::size_u32());

                // Right-hand sides: n * nrhs elements per system.
                let b_stride = b.mul_lo_u32(n_reg.clone(), nrhs_reg);
                let b_offset = b.mul_lo_u32(bid.clone(), b_stride);
                let _b_base = b.byte_offset_addr(b_ptr, b_offset, T::size_u32());

                // Pivot row: n entries per system (i32 => 4 bytes each).
                let piv_offset = b.mul_lo_u32(bid, n_reg);
                let _piv_base = b.byte_offset_addr(pivots_ptr, piv_offset, 4u32);

                // Keep currently-unused handles alive without warnings.
                let _ = (tid, float_ty);
            });

            b.ret();
        })
        .build()?;

    Ok(ptx)
}
667
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn batch_algorithm_equality() {
        // Copy + Eq derive sanity.
        assert_eq!(BatchAlgorithm::Lu, BatchAlgorithm::Lu);
        assert_ne!(BatchAlgorithm::Lu, BatchAlgorithm::Qr);
        assert_ne!(BatchAlgorithm::Qr, BatchAlgorithm::Cholesky);
    }

    #[test]
    fn batch_config_construction() {
        let cfg = BatchConfig {
            matrix_size: 16,
            batch_count: 1000,
            algorithm: BatchAlgorithm::Lu,
        };
        assert_eq!(cfg.matrix_size, 16);
        assert_eq!(cfg.batch_count, 1000);
        assert_eq!(cfg.algorithm, BatchAlgorithm::Lu);
    }

    #[test]
    fn batched_result_construction() {
        let ok = BatchedResult { failed_count: 0 };
        let failed = BatchedResult { failed_count: 5 };
        assert_eq!(ok.failed_count, 0);
        assert_eq!(failed.failed_count, 5);
    }

    #[test]
    fn matrices_per_block_small() {
        // Everything at or below the threshold packs multiple per block.
        for n in [4, 8, 16] {
            assert_eq!(matrices_per_block(n), SMALL_MATRICES_PER_BLOCK);
        }
    }

    #[test]
    fn matrices_per_block_large() {
        // Above the threshold: one matrix per block.
        for n in [32, 64] {
            assert_eq!(matrices_per_block(n), 1);
        }
    }

    #[test]
    fn compute_block_size_values() {
        let small = compute_block_size(8);
        assert!(small <= SOLVER_BLOCK_SIZE);
        assert!(small >= 32);

        assert_eq!(compute_block_size(32), 32);
        assert_eq!(compute_block_size(64), 64);
    }

    #[test]
    fn compute_grid_size_values() {
        // 100 small matrices packed 4-per-block.
        assert_eq!(compute_grid_size(100, 8), 25);
        // Large matrices: one per block.
        assert_eq!(compute_grid_size(100, 32), 100);
        // Ceil division: 101 / 4 rounds up.
        assert_eq!(compute_grid_size(101, 8), 26);
    }

    #[test]
    fn batched_lu_name_format() {
        let name = batched_lu_name::<f32>(16);
        assert!(name.contains("f32"));
        assert!(name.contains("16"));
    }

    #[test]
    fn batched_qr_name_format() {
        let name = batched_qr_name::<f64>(32, 16);
        assert!(name.contains("f64"));
        assert!(name.contains("32x16"));
    }

    #[test]
    fn batched_cholesky_name_format() {
        let name = batched_cholesky_name::<f32>(64);
        assert!(name.contains("f32"));
        assert!(name.contains("64"));
    }

    #[test]
    fn batched_solve_name_format() {
        let name = batched_solve_name::<f64>(16, 4);
        assert!(name.contains("f64"));
        assert!(name.contains("16"));
        assert!(name.contains("4"));
    }

    #[test]
    fn max_batch_matrix_size_reasonable() {
        assert!(MAX_BATCH_MATRIX_SIZE >= 32);
        assert!(MAX_BATCH_MATRIX_SIZE <= 128);
    }

    #[test]
    fn small_matrix_threshold_consistent() {
        assert!(SMALL_MATRIX_THRESHOLD <= 32);
        assert!(SMALL_MATRICES_PER_BLOCK >= 1);
        assert!(SMALL_MATRICES_PER_BLOCK <= 16);
    }
}