iro_cuda_ffi_kernels/
lib.rs

//! Reference CUDA kernels for iro-cuda-ffi.
//!
//! This crate provides sample kernels that demonstrate proper iro-cuda-ffi usage patterns
//! and serve as integration tests for the iro-cuda-ffi core crate.
//!
//! # Available Kernels
//!
//! - `vector_add_f32`: Element-wise vector addition
//! - `fma_chain_f32`: Deep compute chain (FMA)
//! - `saxpy_f32`: Single-precision A*X + Y
//! - `daxpy_f64`: Double-precision A*X + Y
//! - `scale_f32`: Vector scaling
//! - `reduce_sum_f32`: Parallel sum reduction
//! - `reduce_max_f32`: Parallel max reduction
//!
//! The host-side helpers `reduce_sum_full` and `reduce_max_full` chain
//! reduction passes until a single scalar remains.
//!
//! # Example
//!
//! ```ignore
//! use iro_cuda_ffi::prelude::*;
//! use iro_cuda_ffi_kernels::vector_add_f32;
//!
//! let stream = Stream::new()?;
//!
//! let a = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0, 3.0, 4.0])?;
//! let b = DeviceBuffer::from_slice_sync(&stream, &[5.0f32, 6.0, 7.0, 8.0])?;
//! let mut c = DeviceBuffer::<f32>::zeros(4)?;
//!
//! vector_add_f32(&stream, &a, &b, &mut c)?;
//!
//! let result = c.to_vec(&stream)?;
//! assert_eq!(result, vec![6.0, 8.0, 10.0, 12.0]);
//! ```
//!
//! ## Device-to-device copy
//!
//! ```ignore
//! use iro_cuda_ffi::prelude::*;
//!
//! let stream = Stream::new()?;
//! let src = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0, 3.0])?;
//! let mut dst = DeviceBuffer::<f32>::alloc(src.len())?;
//!
//! dst.copy_from_device_sync(&stream, &src)?;
//! let result = dst.to_vec(&stream)?;
//! assert_eq!(result, vec![1.0, 2.0, 3.0]);
//! ```

#![warn(missing_docs)]
#![warn(clippy::all, clippy::pedantic)]

use iro_cuda_ffi::error::icffi_codes;
use iro_cuda_ffi::prelude::*;

// FFI declarations for kernel exports (Rust 2024: `unsafe extern` required)
mod ffi {
    use iro_cuda_ffi::abi::{InBufferDesc, LaunchParams, OutBufferDesc};

    // Kernel FFI declarations.
    //
    // All kernel wrapper functions follow the iro-cuda-ffi ABI:
    // - Return `cudaError_t` (i32 where 0 = success)
    // - Take `LaunchParams` by value as first argument
    // - Never synchronize internally
    unsafe extern "C" {
        pub fn icffi_vector_add_f32(
            p: LaunchParams,
            a: InBufferDesc<f32>,
            b: InBufferDesc<f32>,
            out: OutBufferDesc<f32>,
        ) -> i32;

        pub fn icffi_fma_chain_f32(
            p: LaunchParams,
            a: InBufferDesc<f32>,
            b: InBufferDesc<f32>,
            out: OutBufferDesc<f32>,
            iters: u32,
        ) -> i32;

        pub fn icffi_saxpy_f32(
            p: LaunchParams,
            x: InBufferDesc<f32>,
            y: OutBufferDesc<f32>,
            a: f32,
        ) -> i32;

        pub fn icffi_daxpy_f64(
            p: LaunchParams,
            x: InBufferDesc<f64>,
            y: OutBufferDesc<f64>,
            a: f64,
        ) -> i32;

        pub fn icffi_scale_f32(
            p: LaunchParams,
            x: InBufferDesc<f32>,
            y: OutBufferDesc<f32>,
            a: f32,
        ) -> i32;

        pub fn icffi_reduce_sum_f32(
            p: LaunchParams,
            input: InBufferDesc<f32>,
            output: OutBufferDesc<f32>,
        ) -> i32;

        pub fn icffi_reduce_max_f32(
            p: LaunchParams,
            input: InBufferDesc<f32>,
            output: OutBufferDesc<f32>,
        ) -> i32;

        // Ensure ABI asserts TU is linked
        pub fn icffi_abi_asserts_linked();
    }
}

const BLOCK_SIZE: usize = 256;
// Must match reduction kernels (each thread loads 2 elements).
const REDUCTION_ELEMENTS_PER_THREAD: usize = 2;
const REDUCTION_ELEMENTS_PER_BLOCK: usize = BLOCK_SIZE * REDUCTION_ELEMENTS_PER_THREAD;
#[allow(clippy::cast_possible_truncation)]
const BLOCK_SIZE_U32: u32 = BLOCK_SIZE as u32; // Safe: 256 fits in u32

#[inline]
fn ensure_len_eq(
    name: &str,
    left_label: &str,
    left: usize,
    right_label: &str,
    right: usize,
) -> Result<()> {
    if left != right {
        return Err(IcffiError::with_location(
            icffi_codes::LENGTH_MISMATCH,
            format!("{name}: length mismatch ({left_label}={left} != {right_label}={right})"),
        ));
    }
    Ok(())
}

/// Maximum `grid_x` dimension per the CUDA spec (2^31 - 1).
///
/// This matches the limit in `LaunchParams::is_valid()`. The CUDA spec caps
/// `grid_x` at this value, not `u32::MAX`: a larger grid would still fit in a
/// `u32` but would fail the actual launch.
pub const MAX_GRID_X: usize = 0x7FFF_FFFF;

#[inline]
#[allow(clippy::cast_possible_truncation)]
fn grid_u32(name: &str, grid: usize) -> Result<u32> {
    if grid > MAX_GRID_X {
        return Err(IcffiError::with_location(
            icffi_codes::GRID_TOO_LARGE,
            format!("{name}: grid size exceeds MAX_GRID_X ({grid} > {MAX_GRID_X})"),
        ));
    }
    // Truncation is safe: grid <= MAX_GRID_X (i32::MAX), which fits in u32.
    Ok(grid as u32)
}

#[inline]
fn div_ceil_usize(n: usize, denom: usize) -> usize {
    debug_assert!(denom != 0);
    n.div_ceil(denom)
}

/// Computes the number of blocks needed for a 1D launch.
#[inline]
fn blocks_for(n: usize, block_size: usize) -> usize {
    div_ceil_usize(n, block_size)
}

#[inline]
fn reduce_sum_f32_len(
    stream: &Stream,
    input: &DeviceBuffer<f32>,
    input_len: usize,
    output: &mut DeviceBuffer<f32>,
) -> Result<usize> {
    if input_len == 0 {
        return Ok(0);
    }

    if input_len > input.len() {
        return Err(IcffiError::with_location(
            icffi_codes::INVALID_ARGUMENT,
            format!(
                "reduce_sum_f32: input_len exceeds buffer length ({} > {})",
                input_len,
                input.len()
            ),
        ));
    }

    let grid = reduction_output_size(input_len);
    if output.len() < grid {
        return Err(IcffiError::with_location(
            icffi_codes::OUTPUT_TOO_SMALL,
            format!(
                "reduce_sum_f32: output too small ({} < {})",
                output.len(),
                grid
            ),
        ));
    }

    let grid_u32 = grid_u32("reduce_sum_f32", grid)?;
    let params = LaunchParams::new_1d(grid_u32, BLOCK_SIZE_U32, stream.raw());
    let input_desc = InBufferDesc::new(input.as_ptr(), input_len as u64);
    let output_desc = OutBufferDesc::new(output.as_mut_ptr(), grid as u64);

    check(unsafe { ffi::icffi_reduce_sum_f32(params, input_desc, output_desc) })?;

    Ok(grid)
}

#[inline]
fn reduce_max_f32_len(
    stream: &Stream,
    input: &DeviceBuffer<f32>,
    input_len: usize,
    output: &mut DeviceBuffer<f32>,
) -> Result<usize> {
    if input_len == 0 {
        return Ok(0);
    }

    if input_len > input.len() {
        return Err(IcffiError::with_location(
            icffi_codes::INVALID_ARGUMENT,
            format!(
                "reduce_max_f32: input_len exceeds buffer length ({} > {})",
                input_len,
                input.len()
            ),
        ));
    }

    let grid = reduction_output_size(input_len);
    if output.len() < grid {
        return Err(IcffiError::with_location(
            icffi_codes::OUTPUT_TOO_SMALL,
            format!(
                "reduce_max_f32: output too small ({} < {})",
                output.len(),
                grid
            ),
        ));
    }

    let grid_u32 = grid_u32("reduce_max_f32", grid)?;
    let params = LaunchParams::new_1d(grid_u32, BLOCK_SIZE_U32, stream.raw());
    let input_desc = InBufferDesc::new(input.as_ptr(), input_len as u64);
    let output_desc = OutBufferDesc::new(output.as_mut_ptr(), grid as u64);

    check(unsafe { ffi::icffi_reduce_max_f32(params, input_desc, output_desc) })?;

    Ok(grid)
}

/// Element-wise vector addition: out = a + b
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `a` - First input vector
/// * `b` - Second input vector (must have same length as `a`)
/// * `out` - Output vector (must have same length as `a`)
///
/// # Errors
///
/// Returns an error if the vectors have mismatched lengths or kernel launch fails.
///
/// # Example
///
/// ```ignore
/// let stream = Stream::new()?;
/// let a = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0, 3.0])?;
/// let b = DeviceBuffer::from_slice_sync(&stream, &[4.0f32, 5.0, 6.0])?;
/// let mut c = DeviceBuffer::<f32>::zeros(3)?;
///
/// vector_add_f32(&stream, &a, &b, &mut c)?;
/// ```
#[track_caller]
pub fn vector_add_f32(
    stream: &Stream,
    a: &DeviceBuffer<f32>,
    b: &DeviceBuffer<f32>,
    out: &mut DeviceBuffer<f32>,
) -> Result<()> {
    let n = a.len();
    ensure_len_eq("vector_add_f32", "b", b.len(), "a", n)?;
    ensure_len_eq("vector_add_f32", "out", out.len(), "a", n)?;

    if n == 0 {
        return Ok(());
    }

    let grid = blocks_for(n, BLOCK_SIZE);
    let grid = grid_u32("vector_add_f32", grid)?;
    let params = LaunchParams::new_1d(grid, BLOCK_SIZE_U32, stream.raw());

    check(unsafe { ffi::icffi_vector_add_f32(params, a.as_in(), b.as_in(), out.as_out()) })
}

/// Deep compute chain: `out = fma_chain(a, b, iters)`
///
/// Each element performs `iters` iterations of: `acc = acc * b[i] + 1.0`.
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `a` - First input vector
/// * `b` - Second input vector (must have same length as `a`)
/// * `out` - Output vector (must have same length as `a`)
/// * `iters` - Number of multiply-add iterations per element
///
/// # Errors
///
/// Returns an error if the vectors have mismatched lengths or kernel launch fails.
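///
/// # Example
///
/// A minimal sketch of the call pattern (`ignore`d, like the other examples,
/// because it needs a CUDA device); the resulting values depend on the
/// kernel's FMA chain:
///
/// ```ignore
/// let stream = Stream::new()?;
/// let a = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0])?;
/// let b = DeviceBuffer::from_slice_sync(&stream, &[0.5f32, 0.25])?;
/// let mut out = DeviceBuffer::<f32>::zeros(2)?;
///
/// // 16 multiply-add iterations per element.
/// fma_chain_f32(&stream, &a, &b, &mut out, 16)?;
/// let result = out.to_vec(&stream)?;
/// ```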
#[track_caller]
pub fn fma_chain_f32(
    stream: &Stream,
    a: &DeviceBuffer<f32>,
    b: &DeviceBuffer<f32>,
    out: &mut DeviceBuffer<f32>,
    iters: u32,
) -> Result<()> {
    let n = a.len();
    ensure_len_eq("fma_chain_f32", "b", b.len(), "a", n)?;
    ensure_len_eq("fma_chain_f32", "out", out.len(), "a", n)?;

    if n == 0 {
        return Ok(());
    }

    let grid = blocks_for(n, BLOCK_SIZE);
    let grid = grid_u32("fma_chain_f32", grid)?;
    let params = LaunchParams::new_1d(grid, BLOCK_SIZE_U32, stream.raw());

    check(unsafe { ffi::icffi_fma_chain_f32(params, a.as_in(), b.as_in(), out.as_out(), iters) })
}

/// SAXPY operation: y = a * x + y (in-place)
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `a` - Scalar multiplier
/// * `x` - Input vector
/// * `y` - Input/output vector (modified in place)
///
/// # Errors
///
/// Returns an error if the vectors have mismatched lengths or kernel launch fails.
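///
/// # Example
///
/// A minimal sketch (`ignore`d because it needs a CUDA device):
///
/// ```ignore
/// let stream = Stream::new()?;
/// let x = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0, 3.0])?;
/// let mut y = DeviceBuffer::from_slice_sync(&stream, &[10.0f32, 10.0, 10.0])?;
///
/// // y[i] = 2.0 * x[i] + y[i]
/// saxpy_f32(&stream, 2.0, &x, &mut y)?;
/// assert_eq!(y.to_vec(&stream)?, vec![12.0, 14.0, 16.0]);
/// ```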
#[track_caller]
pub fn saxpy_f32(
    stream: &Stream,
    a: f32,
    x: &DeviceBuffer<f32>,
    y: &mut DeviceBuffer<f32>,
) -> Result<()> {
    let n = x.len();
    ensure_len_eq("saxpy_f32", "y", y.len(), "x", n)?;

    if n == 0 {
        return Ok(());
    }

    let grid = blocks_for(n, BLOCK_SIZE);
    let grid = grid_u32("saxpy_f32", grid)?;
    let params = LaunchParams::new_1d(grid, BLOCK_SIZE_U32, stream.raw());

    check(unsafe { ffi::icffi_saxpy_f32(params, x.as_in(), y.as_out(), a) })
}

/// DAXPY operation: y = a * x + y (in-place, double precision)
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `a` - Scalar multiplier
/// * `x` - Input vector
/// * `y` - Input/output vector (modified in place)
///
/// # Errors
///
/// Returns an error if the vectors have mismatched lengths or kernel launch fails.
#[track_caller]
pub fn daxpy_f64(
    stream: &Stream,
    a: f64,
    x: &DeviceBuffer<f64>,
    y: &mut DeviceBuffer<f64>,
) -> Result<()> {
    let n = x.len();
    ensure_len_eq("daxpy_f64", "y", y.len(), "x", n)?;

    if n == 0 {
        return Ok(());
    }

    let grid = blocks_for(n, BLOCK_SIZE);
    let grid = grid_u32("daxpy_f64", grid)?;
    let params = LaunchParams::new_1d(grid, BLOCK_SIZE_U32, stream.raw());

    check(unsafe { ffi::icffi_daxpy_f64(params, x.as_in(), y.as_out(), a) })
}

/// Scale vector: out = a * x
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `a` - Scalar multiplier
/// * `x` - Input vector
/// * `out` - Output vector
///
/// # Errors
///
/// Returns an error if the vectors have mismatched lengths or kernel launch fails.
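///
/// # Example
///
/// A minimal sketch (`ignore`d because it needs a CUDA device):
///
/// ```ignore
/// let stream = Stream::new()?;
/// let x = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0, 3.0])?;
/// let mut out = DeviceBuffer::<f32>::zeros(3)?;
///
/// // out[i] = 0.5 * x[i]
/// scale_f32(&stream, 0.5, &x, &mut out)?;
/// assert_eq!(out.to_vec(&stream)?, vec![0.5, 1.0, 1.5]);
/// ```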
#[track_caller]
pub fn scale_f32(
    stream: &Stream,
    a: f32,
    x: &DeviceBuffer<f32>,
    out: &mut DeviceBuffer<f32>,
) -> Result<()> {
    let n = x.len();
    ensure_len_eq("scale_f32", "out", out.len(), "x", n)?;

    if n == 0 {
        return Ok(());
    }

    let grid = blocks_for(n, BLOCK_SIZE);
    let grid = grid_u32("scale_f32", grid)?;
    let params = LaunchParams::new_1d(grid, BLOCK_SIZE_U32, stream.raw());

    check(unsafe { ffi::icffi_scale_f32(params, x.as_in(), out.as_out(), a) })
}

/// Parallel sum reduction (first pass).
///
/// Reduces input to per-block partial sums. For a complete reduction,
/// call this function multiple times until the output has a single element.
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `input` - Input vector
/// * `output` - Output vector (must have at least `reduction_output_size(input.len())` elements)
///
/// # Returns
///
/// Returns the number of elements written to output.
///
/// # Errors
///
/// Returns an error if output buffer is too small or kernel launch fails.
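///
/// # Example
///
/// A sketch of the manual multi-pass loop (`ignore`d because it needs a CUDA
/// device). Each freshly sized buffer is filled exactly, so `partial.len()`
/// always equals the element count of the previous pass; `reduce_sum_full`
/// wraps this pattern with pool-allocated temporaries:
///
/// ```ignore
/// let stream = Stream::new()?;
/// let input = DeviceBuffer::from_slice_sync(&stream, &[1.0f32; 4096])?;
///
/// let mut partial = DeviceBuffer::<f32>::zeros(reduction_output_size(input.len()))?;
/// let mut len = reduce_sum_f32(&stream, &input, &mut partial)?;
/// while len > 1 {
///     let mut next = DeviceBuffer::<f32>::zeros(reduction_output_size(len))?;
///     len = reduce_sum_f32(&stream, &partial, &mut next)?;
///     partial = next;
/// }
/// assert_eq!(partial.to_vec(&stream)?[0], 4096.0);
/// ```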
#[track_caller]
pub fn reduce_sum_f32(
    stream: &Stream,
    input: &DeviceBuffer<f32>,
    output: &mut DeviceBuffer<f32>,
) -> Result<usize> {
    reduce_sum_f32_len(stream, input, input.len(), output)
}

/// Parallel max reduction (first pass).
///
/// Reduces input to per-block partial maxima.
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `input` - Input vector
/// * `output` - Output vector (must have at least `reduction_output_size(input.len())` elements)
///
/// # Returns
///
/// Returns the number of elements written to output.
///
/// # Errors
///
/// Returns an error if output buffer is too small or kernel launch fails.
#[track_caller]
pub fn reduce_max_f32(
    stream: &Stream,
    input: &DeviceBuffer<f32>,
    output: &mut DeviceBuffer<f32>,
) -> Result<usize> {
    reduce_max_f32_len(stream, input, input.len(), output)
}

/// Returns the number of output elements (one partial result per block) that a
/// single reduction pass over `input_len` elements produces.
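///
/// # Example
///
/// With the current configuration, each block covers `256 * 2 = 512` elements
/// (`ignore`d only because this crate links against the CUDA kernels):
///
/// ```ignore
/// assert_eq!(reduction_output_size(512), 1);
/// assert_eq!(reduction_output_size(513), 2);
/// assert_eq!(reduction_output_size(0), 0);
/// ```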
#[inline]
#[must_use]
pub const fn reduction_output_size(input_len: usize) -> usize {
    let q = input_len / REDUCTION_ELEMENTS_PER_BLOCK;
    let r = input_len % REDUCTION_ELEMENTS_PER_BLOCK;
    if r == 0 {
        q
    } else {
        q + 1
    }
}

/// Computes the sum of all elements in the input vector.
///
/// This is a convenience function that handles multi-pass reduction internally.
/// It allocates temporary buffers from the CUDA memory pool and performs multiple
/// reduction passes until a single value remains.
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `input` - Input vector to reduce
///
/// # Returns
///
/// The sum of all elements, or `0.0` for an empty input.
///
/// # Errors
///
/// Returns an error if allocation fails or kernel launch fails.
///
/// # Memory Pool Usage
///
/// This function uses `zeros_async` for pool-based allocation and properly
/// returns buffers to the pool via `free_async` for optimal performance.
///
/// # Example
///
/// ```ignore
/// let stream = Stream::new()?;
/// let data = DeviceBuffer::from_slice_sync(&stream, &[1.0f32, 2.0, 3.0, 4.0])?;
///
/// let sum = reduce_sum_full(&stream, &data)?;
/// assert_eq!(sum, 10.0);
/// ```
#[track_caller]
pub fn reduce_sum_full(stream: &Stream, input: &DeviceBuffer<f32>) -> Result<f32> {
    if input.is_empty() {
        return Ok(0.0);
    }

    if input.len() == 1 {
        let result = input.to_vec(stream)?;
        return Ok(result[0]);
    }

    // Allocate primary working buffer from the pool
    let mut current_len = input.len();
    let output_size = reduction_output_size(current_len);
    let mut buf_a = DeviceBuffer::<f32>::zeros_async(stream, output_size)?;

    // First pass: reduce input to partial sums
    current_len = reduce_sum_f32_len(stream, input, input.len(), &mut buf_a)?;

    // Additional passes until we have a single element.
    // Use Option to track whether buf_b was allocated.
    let mut buf_b: Option<DeviceBuffer<f32>> = None;

    while current_len > 1 {
        // Lazily allocate buf_b on the first multi-pass iteration using
        // Option::insert, which sets the value and returns a mutable
        // reference in one operation.
        let b = match &mut buf_b {
            Some(b) => b,
            None => {
                let size = reduction_output_size(current_len);
                buf_b.insert(DeviceBuffer::<f32>::zeros_async(stream, size)?)
            }
        };

        current_len = reduce_sum_f32_len(stream, &buf_a, current_len, b)?;
        core::mem::swap(&mut buf_a, b);
    }

    // Copy the final result to the host (this synchronizes the stream)
    let result = buf_a.to_vec(stream)?;
    let value = result[0];

    // Return buffers to the memory pool.
    // Note: the stream is already synchronized after to_vec, but free_async is
    // still correct and maintains pool hygiene for subsequent allocations.
    buf_a.free_async(stream)?;
    if let Some(b) = buf_b {
        b.free_async(stream)?;
    }

    Ok(value)
}

/// Computes the maximum of all elements in the input vector.
///
/// This is a convenience function that handles multi-pass reduction internally.
/// It allocates temporary buffers from the CUDA memory pool and performs multiple
/// reduction passes until a single value remains.
///
/// # Arguments
///
/// * `stream` - CUDA stream for the operation
/// * `input` - Input vector to reduce
///
/// # Returns
///
/// The maximum element value, or `f32::NEG_INFINITY` (the identity for max)
/// for an empty input.
///
/// # Errors
///
/// Returns an error if allocation fails or kernel launch fails.
///
/// # Memory Pool Usage
///
/// This function uses `zeros_async` for pool-based allocation and properly
/// returns buffers to the pool via `free_async` for optimal performance.
///
/// # Example
///
/// ```ignore
/// let stream = Stream::new()?;
/// let data = DeviceBuffer::from_slice_sync(&stream, &[3.0f32, 1.0, 4.0, 1.5])?;
///
/// let max = reduce_max_full(&stream, &data)?;
/// assert_eq!(max, 4.0);
/// ```
#[track_caller]
pub fn reduce_max_full(stream: &Stream, input: &DeviceBuffer<f32>) -> Result<f32> {
    if input.is_empty() {
        return Ok(f32::NEG_INFINITY);
    }

    if input.len() == 1 {
        let result = input.to_vec(stream)?;
        return Ok(result[0]);
    }

    // Allocate primary working buffer from the pool
    let mut current_len = input.len();
    let output_size = reduction_output_size(current_len);
    let mut buf_a = DeviceBuffer::<f32>::zeros_async(stream, output_size)?;

    // First pass: reduce input to partial maxima
    current_len = reduce_max_f32_len(stream, input, input.len(), &mut buf_a)?;

    // Additional passes until we have a single element.
    // Use Option to track whether buf_b was allocated.
    let mut buf_b: Option<DeviceBuffer<f32>> = None;

    while current_len > 1 {
        // Lazily allocate buf_b on the first multi-pass iteration using
        // Option::insert, which sets the value and returns a mutable
        // reference in one operation.
        let b = match &mut buf_b {
            Some(b) => b,
            None => {
                let size = reduction_output_size(current_len);
                buf_b.insert(DeviceBuffer::<f32>::zeros_async(stream, size)?)
            }
        };

        current_len = reduce_max_f32_len(stream, &buf_a, current_len, b)?;
        core::mem::swap(&mut buf_a, b);
    }

    // Copy the final result to the host (this synchronizes the stream)
    let result = buf_a.to_vec(stream)?;
    let value = result[0];

    // Return buffers to the memory pool.
    // Note: the stream is already synchronized after to_vec, but free_async is
    // still correct and maintains pool hygiene for subsequent allocations.
    buf_a.free_async(stream)?;
    if let Some(b) = buf_b {
        b.free_async(stream)?;
    }

    Ok(value)
}

/// Ensures the ABI-asserts translation unit is linked.
///
/// Call this once at startup: referencing the symbol forces the translation
/// unit containing the compile-time ABI assertions to be linked into the
/// binary.
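///
/// # Example
///
/// ```ignore
/// // Once, early in program startup.
/// iro_cuda_ffi_kernels::verify_abi_linked();
/// ```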
pub fn verify_abi_linked() {
    unsafe {
        ffi::icffi_abi_asserts_linked();
    }
}

#[cfg(test)]
mod lib_test;