use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::stream::Stream;
use crate::kernel::{Kernel, KernelArgs};
use crate::params::LaunchParams;
/// Launches `kernel` once per stream, using a distinct set of launch
/// parameters and kernel arguments for each stream.
///
/// The three slices are treated as parallel arrays: the launch issued on
/// `streams[i]` uses `params[i]` and `args[i]`. Launches are issued in slice
/// order.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `streams` is empty, or if `params` or
///   `args` does not have exactly as many elements as `streams`.
/// * Otherwise, the first error returned by [`Kernel::launch`]. No further
///   launches are attempted after a failure, and nothing in this function
///   undoes the launches that were already issued.
pub fn multi_stream_launch<A: KernelArgs>(
    kernel: &Kernel,
    streams: &[&Stream],
    params: &[LaunchParams],
    args: &[A],
) -> CudaResult<()> {
    let n = streams.len();
    if n == 0 {
        return Err(CudaError::InvalidValue);
    }
    if params.len() != n || args.len() != n {
        return Err(CudaError::InvalidValue);
    }

    // All three slices were just checked to have the same length, so zipping
    // them visits every (stream, params, args) triple exactly once.
    for ((&stream, launch_params), kernel_args) in streams.iter().zip(params).zip(args) {
        kernel.launch(launch_params, stream, kernel_args)?;
    }
    Ok(())
}
/// Launches `kernel` on every stream in `streams`, reusing the same launch
/// parameters and the same kernel arguments for each launch.
///
/// Launches are issued in slice order.
///
/// # Errors
///
/// * [`CudaError::InvalidValue`] if `streams` is empty.
/// * Otherwise, the first error returned by [`Kernel::launch`]; streams after
///   the failing one are not launched on.
pub fn multi_stream_launch_uniform<A: KernelArgs>(
    kernel: &Kernel,
    streams: &[&Stream],
    params: &LaunchParams,
    args: &A,
) -> CudaResult<()> {
    match streams {
        // Nothing to launch on: reported as an invalid argument rather than
        // being treated as a successful no-op.
        [] => Err(CudaError::InvalidValue),
        targets => {
            for &target in targets {
                kernel.launch(params, target, args)?;
            }
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): none of these tests call the launch functions; they only
    // check signatures and argument construction. Presumably a `Kernel`
    // cannot be built without a device -- confirm before adding real calls.
    use super::*;
    use crate::grid::Dim3;

    /// The per-stream variant must coerce to the expected function-pointer type.
    #[test]
    fn multi_stream_launch_signature_compiles() {
        let _: fn(&Kernel, &[&Stream], &[LaunchParams], &[(u32,)]) -> CudaResult<()> =
            multi_stream_launch;
    }

    /// The uniform variant must coerce to the expected function-pointer type.
    #[test]
    fn multi_stream_launch_uniform_signature_compiles() {
        let _: fn(&Kernel, &[&Stream], &LaunchParams, &(u32,)) -> CudaResult<()> =
            multi_stream_launch_uniform;
    }

    /// Builds the empty argument slices that `multi_stream_launch` rejects.
    /// Only the inputs are checked here; the function itself is not invoked.
    #[test]
    fn multi_stream_launch_rejects_empty_streams() {
        let streams: &[&Stream] = &[];
        let params: &[LaunchParams] = &[];
        let args: &[(u32,)] = &[];
        assert!(streams.is_empty());
        assert!(params.is_empty());
        assert!(args.is_empty());
    }

    /// Builds the inputs for the empty-stream case of the uniform variant.
    /// Only construction is exercised; the function itself is not invoked.
    #[test]
    fn multi_stream_launch_uniform_rejects_empty() {
        let streams: &[&Stream] = &[];
        let params = LaunchParams::new(1u32, 1u32);
        let _ = (&streams, &params);
    }

    /// Distinct launch configurations can be collected into one array and
    /// keep their individual grid sizes.
    #[test]
    fn launch_params_for_multi_stream() {
        let p1 = LaunchParams::new(Dim3::x(4), Dim3::x(256));
        let p2 = LaunchParams::new(Dim3::x(8), Dim3::x(128));
        let params = [p1, p2];
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].grid.x, 4);
        assert_eq!(params[1].grid.x, 8);
    }

    /// Illustrates the length mismatch that `multi_stream_launch` rejects.
    /// The comparison is between local constants only.
    #[test]
    fn multi_stream_count_validation() {
        let streams_len = 3;
        let params_len = 2;
        assert_ne!(streams_len, params_len);
    }
}