// oxicuda_launch/multi_stream.rs
//! Multi-stream kernel launch support.
//!
//! Launches the same kernel across multiple CUDA streams simultaneously,
//! enabling concurrent execution on the GPU when streams have no
//! inter-dependencies. This is useful for data-parallel workloads where
//! independent chunks can be processed in parallel.
//!
//! # Example
//!
//! ```rust,no_run
//! # use oxicuda_launch::multi_stream::multi_stream_launch;
//! # use oxicuda_launch::{Kernel, LaunchParams};
//! # use oxicuda_driver::Stream;
//! // Assuming you have a kernel, streams, params, and args set up:
//! // multi_stream_launch(&kernel, &streams, &params, &args)?;
//! ```

use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::stream::Stream;

use crate::kernel::{Kernel, KernelArgs};
use crate::params::LaunchParams;

24// ---------------------------------------------------------------------------
25// multi_stream_launch
26// ---------------------------------------------------------------------------
27
28/// Launches the same kernel across multiple streams with per-stream
29/// parameters and arguments.
30///
31/// Each stream receives one launch with its corresponding parameters
32/// and arguments. The launches are issued sequentially to the driver
33/// but execute concurrently on the GPU (assuming the hardware supports
34/// concurrent kernel execution).
35///
36/// # Parameters
37///
38/// * `kernel` — the kernel to launch on every stream.
39/// * `streams` — slice of streams to launch on.
40/// * `params` — per-stream launch parameters (grid, block, shared mem).
41/// * `args` — per-stream kernel arguments.
42///
43/// All three slices must have the same length.
44///
45/// # Errors
46///
47/// * [`CudaError::InvalidValue`] if the slices have different lengths
48///   or are empty.
49/// * Any error from an individual kernel launch is returned immediately,
50///   aborting subsequent launches.
51pub fn multi_stream_launch<A: KernelArgs>(
52    kernel: &Kernel,
53    streams: &[&Stream],
54    params: &[LaunchParams],
55    args: &[A],
56) -> CudaResult<()> {
57    let n = streams.len();
58    if n == 0 {
59        return Err(CudaError::InvalidValue);
60    }
61    if params.len() != n || args.len() != n {
62        return Err(CudaError::InvalidValue);
63    }
64
65    for i in 0..n {
66        kernel.launch(&params[i], streams[i], &args[i])?;
67    }
68
69    Ok(())
70}
71
72/// Launches the same kernel across multiple streams with uniform
73/// parameters and arguments.
74///
75/// This is a convenience wrapper around [`multi_stream_launch`] for the
76/// common case where every stream uses identical launch parameters and
77/// arguments.
78///
79/// # Parameters
80///
81/// * `kernel` — the kernel to launch on every stream.
82/// * `streams` — slice of streams to launch on.
83/// * `params` — launch parameters shared by all streams.
84/// * `args` — kernel arguments shared by all streams.
85///
86/// # Errors
87///
88/// * [`CudaError::InvalidValue`] if `streams` is empty.
89/// * Any error from an individual kernel launch.
90pub fn multi_stream_launch_uniform<A: KernelArgs>(
91    kernel: &Kernel,
92    streams: &[&Stream],
93    params: &LaunchParams,
94    args: &A,
95) -> CudaResult<()> {
96    if streams.is_empty() {
97        return Err(CudaError::InvalidValue);
98    }
99
100    for stream in streams {
101        kernel.launch(params, stream, args)?;
102    }
103
104    Ok(())
105}
106
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::Dim3;

    #[test]
    fn multi_stream_launch_signature_compiles() {
        // Coercing to a function pointer pins the public signature at
        // compile time.
        let f: fn(&Kernel, &[&Stream], &[LaunchParams], &[(u32,)]) -> CudaResult<()> =
            multi_stream_launch;
        let _ = f;
    }

    #[test]
    fn multi_stream_launch_uniform_signature_compiles() {
        let f: fn(&Kernel, &[&Stream], &LaunchParams, &(u32,)) -> CudaResult<()> =
            multi_stream_launch_uniform;
        let _ = f;
    }

    #[test]
    fn multi_stream_launch_rejects_empty_streams() {
        // A real Kernel requires a GPU, so this only sanity-checks the
        // empty-slice fixtures that would trip the length validation;
        // the actual call would need a real Kernel.
        let streams: &[&Stream] = &[];
        let params: &[LaunchParams] = &[];
        let args: &[(u32,)] = &[];
        assert_eq!(streams.len(), 0);
        assert_eq!(params.len(), 0);
        assert_eq!(args.len(), 0);
    }

    #[test]
    fn multi_stream_launch_uniform_rejects_empty() {
        // Type-check that the empty-stream fixtures for the uniform
        // variant can be constructed.
        let streams: &[&Stream] = &[];
        let params = LaunchParams::new(1u32, 1u32);
        let _ = (&streams, &params);
    }

    #[test]
    fn launch_params_for_multi_stream() {
        let per_stream = [
            LaunchParams::new(Dim3::x(4), Dim3::x(256)),
            LaunchParams::new(Dim3::x(8), Dim3::x(128)),
        ];
        assert_eq!(per_stream.len(), 2);
        assert_eq!(per_stream[0].grid.x, 4);
        assert_eq!(per_stream[1].grid.x, 8);
    }

    #[test]
    fn multi_stream_count_validation() {
        // Mismatched slice lengths are exactly what the validation guards
        // against.
        let (streams_len, params_len) = (3, 2);
        assert_ne!(streams_len, params_len);
    }
}