use oxicuda_driver::Stream;
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::arch::SmVersion;
use std::sync::Arc;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::{GpuFloat, Transpose};
use super::batched_gemm::{TILE_M, TILE_N, generate_batched_gemm_ptx, validate_batched_args};
/// Strategy for assigning the entries of a batched GEMM to CUDA streams.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamDistribution {
/// Produces one `(stream_index, count)` pair per active stream (see
/// `distribute_batches`).
/// NOTE(review): the first tuple element for this variant is the stream
/// index, not a batch offset — confirm how downstream launch code
/// interprets it.
RoundRobin,
/// Produces contiguous `(batch_offset, count)` slices of near-equal size,
/// one per active stream, covering the whole batch range.
EqualSplit,
}
/// Configuration for running a batched GEMM across multiple CUDA streams.
#[derive(Debug, Clone)]
pub struct MultiStreamBatchedConfig {
/// Number of streams to spread the batch over; must be at least 1
/// (enforced by `validate`).
pub num_streams: u32,
/// How batch entries are partitioned across the streams.
pub distribution: StreamDistribution,
}
impl MultiStreamBatchedConfig {
    /// Builds a configuration from a stream count and a distribution
    /// strategy. No validation happens here; call `validate` before use.
    #[must_use]
    pub fn new(num_streams: u32, distribution: StreamDistribution) -> Self {
        Self {
            num_streams,
            distribution,
        }
    }

    /// Checks that the configuration is usable: at least one stream is
    /// required. Returns `InvalidArgument` otherwise.
    fn validate(&self) -> BlasResult<()> {
        match self.num_streams {
            0 => Err(BlasError::InvalidArgument(
                "num_streams must be at least 1".into(),
            )),
            _ => Ok(()),
        }
    }
}
/// Splits `batch_count` batch entries across at most `num_streams` streams.
///
/// Returns one `(start, count)` pair per active stream; at most
/// `min(num_streams, batch_count)` streams are used, and an empty `Vec` is
/// returned when either input is zero. The first `batch_count % streams`
/// partitions receive one extra batch so the counts always sum to
/// `batch_count`.
///
/// For `EqualSplit`, `start` is a contiguous batch offset covering
/// `0..batch_count`. For `RoundRobin`, the first element is the *stream
/// index*, not a batch offset; only the counts describe the workload.
/// NOTE(review): `gemm_batched_multi_stream` treats every pair as a
/// contiguous `(batch_offset, count)` range, which only holds for
/// `EqualSplit` — confirm the intended round-robin launch semantics.
pub fn distribute_batches(
    batch_count: u32,
    num_streams: u32,
    distribution: &StreamDistribution,
) -> Vec<(u32, u32)> {
    if batch_count == 0 || num_streams == 0 {
        return Vec::new();
    }
    let streams = num_streams.min(batch_count);
    let (base, extra) = (batch_count / streams, batch_count % streams);
    // Per-stream counts: the first `extra` streams take one additional batch.
    let counts = (0..streams).map(move |i| base + u32::from(i < extra));
    match distribution {
        StreamDistribution::EqualSplit => counts
            .scan(0u32, |offset, count| {
                let start = *offset;
                *offset = offset.saturating_add(count);
                Some((start, count))
            })
            .collect(),
        StreamDistribution::RoundRobin => counts
            .enumerate()
            .map(|(i, count)| (i as u32, count))
            .collect(),
    }
}
#[allow(clippy::too_many_arguments)]
pub fn gemm_batched_multi_stream<T: GpuFloat>(
handle: &BlasHandle,
config: &MultiStreamBatchedConfig,
trans_a: Transpose,
trans_b: Transpose,
m: u32,
n: u32,
k: u32,
alpha: T,
a_ptrs: &DeviceBuffer<CUdeviceptr>,
lda: u32,
b_ptrs: &DeviceBuffer<CUdeviceptr>,
ldb: u32,
beta: T,
c_ptrs: &DeviceBuffer<CUdeviceptr>,
ldc: u32,
d_ptrs: &mut DeviceBuffer<CUdeviceptr>,
ldd: u32,
batch_count: u32,
streams: &[&Stream],
) -> BlasResult<()> {
if batch_count == 0 {
return Ok(());
}
config.validate()?;
if (streams.len() as u32) < config.num_streams {
return Err(BlasError::InvalidArgument(format!(
"streams slice has {} entries but config requires {}",
streams.len(),
config.num_streams
)));
}
validate_batched_args::<T>(
m,
n,
k,
a_ptrs,
lda,
b_ptrs,
ldb,
c_ptrs,
ldc,
d_ptrs,
ldd,
batch_count,
trans_a,
trans_b,
)?;
let sm = handle.sm_version();
if config.num_streams == 1 {
return launch_batch_on_stream::<T>(
sm,
streams[0],
trans_a,
trans_b,
m,
n,
k,
alpha,
a_ptrs,
lda,
b_ptrs,
ldb,
beta,
c_ptrs,
ldc,
d_ptrs,
ldd,
0,
batch_count,
);
}
let partitions = distribute_batches(batch_count, config.num_streams, &config.distribution);
let (ptx_source, kernel_name) = generate_batched_gemm_ptx::<T>(sm, m, n, k, trans_a, trans_b)?;
let module = oxicuda_driver::Module::from_ptx(&ptx_source).map_err(BlasError::Cuda)?;
let module = Arc::new(module);
let alpha_bits = alpha.to_bits_u64();
let beta_bits = beta.to_bits_u64();
for (stream_idx, &(start, count)) in partitions.iter().enumerate() {
if count == 0 {
continue;
}
let stream = streams[stream_idx];
let kernel =
Kernel::from_module(Arc::clone(&module), &kernel_name).map_err(BlasError::Cuda)?;
let grid = Dim3::new(m.div_ceil(TILE_M), n.div_ceil(TILE_N), count);
let block = Dim3::new(TILE_M, TILE_N, 1);
let params = LaunchParams::new(grid, block);
let a_ptr_offset = a_ptrs
.as_device_ptr()
.wrapping_add(start as u64 * std::mem::size_of::<CUdeviceptr>() as u64);
let b_ptr_offset = b_ptrs
.as_device_ptr()
.wrapping_add(start as u64 * std::mem::size_of::<CUdeviceptr>() as u64);
let c_ptr_offset = c_ptrs
.as_device_ptr()
.wrapping_add(start as u64 * std::mem::size_of::<CUdeviceptr>() as u64);
let d_ptr_offset = d_ptrs
.as_device_ptr()
.wrapping_add(start as u64 * std::mem::size_of::<CUdeviceptr>() as u64);
let args = (
m,
n,
k,
alpha_bits,
a_ptr_offset,
lda,
b_ptr_offset,
ldb,
beta_bits,
c_ptr_offset,
ldc,
d_ptr_offset,
ldd,
);
kernel
.launch(¶ms, stream, &args)
.map_err(|e| BlasError::LaunchFailed(format!("stream {stream_idx}: {e}")))?;
}
Ok(())
}
/// Compiles the batched-GEMM kernel for `sm` and launches `count` batch
/// entries, beginning at batch index `start`, on a single stream.
///
/// The pointer-table arguments (`a_ptrs` … `d_ptrs`) are offset by `start`
/// elements so the kernel only sees its slice of the per-batch pointers.
/// Launching zero batches is a no-op.
///
/// # Errors
/// Returns any error from PTX generation, module load, kernel lookup, or the
/// launch itself (`LaunchFailed`).
#[allow(clippy::too_many_arguments)]
fn launch_batch_on_stream<T: GpuFloat>(
    sm: SmVersion,
    stream: &Stream,
    trans_a: Transpose,
    trans_b: Transpose,
    m: u32,
    n: u32,
    k: u32,
    alpha: T,
    a_ptrs: &DeviceBuffer<CUdeviceptr>,
    lda: u32,
    b_ptrs: &DeviceBuffer<CUdeviceptr>,
    ldb: u32,
    beta: T,
    c_ptrs: &DeviceBuffer<CUdeviceptr>,
    ldc: u32,
    d_ptrs: &mut DeviceBuffer<CUdeviceptr>,
    ldd: u32,
    start: u32,
    count: u32,
) -> BlasResult<()> {
    if count == 0 {
        return Ok(());
    }
    // Generate and load the kernel specialized for this problem shape.
    let (ptx_source, kernel_name) = generate_batched_gemm_ptx::<T>(sm, m, n, k, trans_a, trans_b)?;
    let module = Arc::new(oxicuda_driver::Module::from_ptx(&ptx_source).map_err(BlasError::Cuda)?);
    let kernel = Kernel::from_module(module, &kernel_name).map_err(BlasError::Cuda)?;
    // One block per TILE_M x TILE_N output tile; grid z indexes the batch.
    let grid = Dim3::new(m.div_ceil(TILE_M), n.div_ceil(TILE_N), count);
    let block = Dim3::new(TILE_M, TILE_N, 1);
    let params = LaunchParams::new(grid, block);
    // Advance each pointer table by `start` elements.
    let byte_offset = u64::from(start) * std::mem::size_of::<CUdeviceptr>() as u64;
    let args = (
        m,
        n,
        k,
        alpha.to_bits_u64(),
        a_ptrs.as_device_ptr().wrapping_add(byte_offset),
        lda,
        b_ptrs.as_device_ptr().wrapping_add(byte_offset),
        ldb,
        beta.to_bits_u64(),
        c_ptrs.as_device_ptr().wrapping_add(byte_offset),
        ldc,
        d_ptrs.as_device_ptr().wrapping_add(byte_offset),
        ldd,
    );
    kernel
        .launch(&params, stream, &args)
        .map_err(|e| BlasError::LaunchFailed(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn equal_split_even_division() {
        let parts = distribute_batches(12, 3, &StreamDistribution::EqualSplit);
        assert_eq!(parts, [(0, 4), (4, 4), (8, 4)]);
    }

    #[test]
    fn equal_split_uneven_division() {
        let parts = distribute_batches(10, 3, &StreamDistribution::EqualSplit);
        assert_eq!(parts, [(0, 4), (4, 3), (7, 3)]);
    }

    #[test]
    fn equal_split_more_streams_than_batches() {
        // Only as many partitions as there are batches.
        let parts = distribute_batches(2, 5, &StreamDistribution::EqualSplit);
        assert_eq!(parts, [(0, 1), (1, 1)]);
    }

    #[test]
    fn equal_split_single_batch() {
        let parts = distribute_batches(1, 4, &StreamDistribution::EqualSplit);
        assert_eq!(parts, [(0, 1)]);
    }

    #[test]
    fn equal_split_single_stream() {
        assert_eq!(
            distribute_batches(100, 1, &StreamDistribution::EqualSplit),
            [(0, 100)]
        );
    }

    #[test]
    fn round_robin_even_division() {
        let parts = distribute_batches(12, 3, &StreamDistribution::RoundRobin);
        assert_eq!(parts, [(0, 4), (1, 4), (2, 4)]);
    }

    #[test]
    fn round_robin_uneven_division() {
        let parts = distribute_batches(10, 3, &StreamDistribution::RoundRobin);
        assert_eq!(parts, [(0, 4), (1, 3), (2, 3)]);
    }

    #[test]
    fn zero_batches_returns_empty() {
        assert!(distribute_batches(0, 4, &StreamDistribution::EqualSplit).is_empty());
    }

    #[test]
    fn zero_streams_returns_empty() {
        assert!(distribute_batches(10, 0, &StreamDistribution::EqualSplit).is_empty());
    }

    #[test]
    fn config_validation_rejects_zero_streams() {
        assert!(
            MultiStreamBatchedConfig::new(0, StreamDistribution::EqualSplit)
                .validate()
                .is_err()
        );
    }

    #[test]
    fn config_validation_accepts_positive_streams() {
        assert!(
            MultiStreamBatchedConfig::new(4, StreamDistribution::RoundRobin)
                .validate()
                .is_ok()
        );
    }

    #[test]
    fn total_batches_preserved_equal_split() {
        // Counts must always sum to the original batch count.
        for batch_count in 1u32..=50 {
            for num_streams in 1u32..=10 {
                let total: u32 =
                    distribute_batches(batch_count, num_streams, &StreamDistribution::EqualSplit)
                        .iter()
                        .map(|&(_, c)| c)
                        .sum();
                assert_eq!(
                    total, batch_count,
                    "batch_count={batch_count}, num_streams={num_streams}"
                );
            }
        }
    }

    #[test]
    fn total_batches_preserved_round_robin() {
        for batch_count in 1u32..=50 {
            for num_streams in 1u32..=10 {
                let total: u32 =
                    distribute_batches(batch_count, num_streams, &StreamDistribution::RoundRobin)
                        .iter()
                        .map(|&(_, c)| c)
                        .sum();
                assert_eq!(
                    total, batch_count,
                    "batch_count={batch_count}, num_streams={num_streams}"
                );
            }
        }
    }

    #[test]
    fn equal_split_contiguous_coverage() {
        // Each partition starts exactly where the previous one ended.
        let parts = distribute_batches(17, 4, &StreamDistribution::EqualSplit);
        let mut next_start = 0u32;
        for &(start, count) in &parts {
            assert_eq!(start, next_start);
            next_start += count;
        }
        assert_eq!(next_start, 17);
    }

    #[test]
    fn distribution_balance_property() {
        // No stream may receive more than ceil(batch_count / num_streams).
        let batch_count = 23u32;
        let num_streams = 5u32;
        let max_per_stream = batch_count.div_ceil(num_streams);
        for dist in &[
            StreamDistribution::EqualSplit,
            StreamDistribution::RoundRobin,
        ] {
            for &(_, count) in &distribute_batches(batch_count, num_streams, dist) {
                assert!(
                    count <= max_per_stream,
                    "stream got {count} batches, max allowed {max_per_stream} ({dist:?})"
                );
            }
        }
    }
}