// oxicuda_launch/multi_stream.rs
1//! Multi-stream kernel launch support.
2//!
3//! Launches the same kernel across multiple CUDA streams simultaneously,
4//! enabling concurrent execution on the GPU when streams have no
5//! inter-dependencies. This is useful for data-parallel workloads where
6//! independent chunks can be processed in parallel.
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! # use oxicuda_launch::multi_stream::multi_stream_launch;
12//! # use oxicuda_launch::{Kernel, LaunchParams};
13//! # use oxicuda_driver::Stream;
14//! // Assuming you have a kernel, streams, params, and args set up:
//! // multi_stream_launch(&kernel, &streams, &params, &args)?;
16//! ```
17
18use oxicuda_driver::error::{CudaError, CudaResult};
19use oxicuda_driver::stream::Stream;
20
21use crate::kernel::{Kernel, KernelArgs};
22use crate::params::LaunchParams;
23
24// ---------------------------------------------------------------------------
25// multi_stream_launch
26// ---------------------------------------------------------------------------
27
28/// Launches the same kernel across multiple streams with per-stream
29/// parameters and arguments.
30///
31/// Each stream receives one launch with its corresponding parameters
32/// and arguments. The launches are issued sequentially to the driver
33/// but execute concurrently on the GPU (assuming the hardware supports
34/// concurrent kernel execution).
35///
36/// # Parameters
37///
38/// * `kernel` — the kernel to launch on every stream.
39/// * `streams` — slice of streams to launch on.
40/// * `params` — per-stream launch parameters (grid, block, shared mem).
41/// * `args` — per-stream kernel arguments.
42///
43/// All three slices must have the same length.
44///
45/// # Errors
46///
47/// * [`CudaError::InvalidValue`] if the slices have different lengths
48/// or are empty.
49/// * Any error from an individual kernel launch is returned immediately,
50/// aborting subsequent launches.
51pub fn multi_stream_launch<A: KernelArgs>(
52 kernel: &Kernel,
53 streams: &[&Stream],
54 params: &[LaunchParams],
55 args: &[A],
56) -> CudaResult<()> {
57 let n = streams.len();
58 if n == 0 {
59 return Err(CudaError::InvalidValue);
60 }
61 if params.len() != n || args.len() != n {
62 return Err(CudaError::InvalidValue);
63 }
64
65 for i in 0..n {
66 kernel.launch(¶ms[i], streams[i], &args[i])?;
67 }
68
69 Ok(())
70}
71
72/// Launches the same kernel across multiple streams with uniform
73/// parameters and arguments.
74///
75/// This is a convenience wrapper around [`multi_stream_launch`] for the
76/// common case where every stream uses identical launch parameters and
77/// arguments.
78///
79/// # Parameters
80///
81/// * `kernel` — the kernel to launch on every stream.
82/// * `streams` — slice of streams to launch on.
83/// * `params` — launch parameters shared by all streams.
84/// * `args` — kernel arguments shared by all streams.
85///
86/// # Errors
87///
88/// * [`CudaError::InvalidValue`] if `streams` is empty.
89/// * Any error from an individual kernel launch.
90pub fn multi_stream_launch_uniform<A: KernelArgs>(
91 kernel: &Kernel,
92 streams: &[&Stream],
93 params: &LaunchParams,
94 args: &A,
95) -> CudaResult<()> {
96 if streams.is_empty() {
97 return Err(CudaError::InvalidValue);
98 }
99
100 for stream in streams {
101 kernel.launch(params, stream, args)?;
102 }
103
104 Ok(())
105}
106
107// ---------------------------------------------------------------------------
108// Tests
109// ---------------------------------------------------------------------------
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::Dim3;

    /// The generic launcher must be nameable as a plain fn pointer once
    /// instantiated with a concrete argument type.
    #[test]
    fn multi_stream_launch_signature_compiles() {
        let _: fn(&Kernel, &[&Stream], &[LaunchParams], &[(u32,)]) -> CudaResult<()> =
            multi_stream_launch;
    }

    /// Same fn-pointer check for the uniform variant.
    #[test]
    fn multi_stream_launch_uniform_signature_compiles() {
        let _: fn(&Kernel, &[&Stream], &LaunchParams, &(u32,)) -> CudaResult<()> =
            multi_stream_launch_uniform;
    }

    #[test]
    fn multi_stream_launch_rejects_empty_streams() {
        // A real launch needs a GPU-backed `Kernel`, which these unit
        // tests cannot construct. We only verify that the empty-slice
        // inputs that would trip the InvalidValue check can be formed.
        let streams: &[&Stream] = &[];
        let params: &[LaunchParams] = &[];
        let args: &[(u32,)] = &[];
        assert!(streams.is_empty());
        assert!(params.is_empty());
        assert!(args.is_empty());
    }

    #[test]
    fn multi_stream_launch_uniform_rejects_empty() {
        // Type-check that the uniform variant's inputs can be built
        // without a device; an actual call would need a real Kernel.
        let streams: &[&Stream] = &[];
        let params = LaunchParams::new(1u32, 1u32);
        let _ = (&streams, &params);
    }

    #[test]
    fn launch_params_for_multi_stream() {
        let p1 = LaunchParams::new(Dim3::x(4), Dim3::x(256));
        let p2 = LaunchParams::new(Dim3::x(8), Dim3::x(128));
        let params = [p1, p2];
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].grid.x, 4);
        assert_eq!(params[1].grid.x, 8);
    }

    #[test]
    fn multi_stream_count_validation() {
        // Mismatched stream/params counts must be distinguishable so the
        // launcher's length check is meaningful.
        let streams_len = 3;
        let params_len = 2;
        assert_ne!(streams_len, params_len);
    }
}
167}