Skip to main content

baracuda_kernels/indexing/
index_add.rs

//! `index_add` plan — Category L (Phase 39).
//!
//! `dst[idx[i], ...] += src[i, ...]` along the `add_dim` axis. `idx` is
//! a 1-D index tensor of length `src.shape[add_dim]`. Mirrors PyTorch's
//! `torch.Tensor.index_add_`.
//!
//! Algorithmically identical to [`crate::IndexSelectBackwardPlan`] (a
//! 1-D-idx atomic-Σ scatter along a single axis), but exposed under a
//! non-autograd name and with the f16 / bf16 dtype fanout that the
//! autograd plan deliberately stops short of.
//!
//! **Dtype coverage (Phase 39 Tier 1)**: `{f32, f64, f16, bf16}` × index
//! `{i32, i64}` = 8 FFI symbols. f16 / bf16 use the `atomicCAS`-based
//! `baracuda::atomic::add<T>` helper (Phase 11.3); f32 / f64 use native
//! `atomicAdd`. Per-thread arithmetic is deterministic; accumulation
//! order is not.
17
18use core::ffi::c_void;
19use core::marker::PhantomData;
20
21use baracuda_cutlass::{Error, Result};
22use baracuda_driver::Stream;
23use baracuda_kernels_types::{
24    ArchSku, BackendKind, Element, ElementKind, IndexElement, IndexElementKind, IndexingKind,
25    KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
26    TensorRef, Workspace,
27};
28
29use super::gather::map_status;
30
31/// Descriptor for an `index_add` op.
32#[derive(Copy, Clone, Debug)]
33pub struct IndexAddDescriptor<const N: usize> {
34    /// Shape of `src` (the per-row values added into `dst`).
35    pub src_shape: [i32; N],
36    /// Axis along which `dst` is indexed (must be in `[0, N)`).
37    pub add_dim: i32,
38    /// Extent of `dst` along `add_dim` (bounds check on `idx` entries).
39    pub dst_dim_size: i32,
40    /// Value element type.
41    pub element: ElementKind,
42}
43
44/// Args bundle for an `index_add` launch.
45pub struct IndexAddArgs<'a, T: Element, const N: usize, I: IndexElement = i32> {
46    /// Source tensor (values to add into `dst`).
47    pub src: TensorRef<'a, T, N>,
48    /// Index tensor (1-D). `idx.numel()` must equal `src.shape[add_dim]`.
49    pub idx: TensorRef<'a, I, 1>,
50    /// Destination. Accumulated into via atomicAdd-Σ — caller pre-zeroes
51    /// (for pure index_add semantics) or pre-populates (for
52    /// `dst += accumulate(src, idx)` semantics).
53    pub dst: TensorMut<'a, T, N>,
54}
55
56/// `index_add` plan.
57///
58/// `dst[idx[i], ...] += src[i, ...]` along `add_dim` via atomicAdd-Σ
59/// (duplicate-index safe).
60///
61/// **When to use**: forward `index_add` (PyTorch
62/// `torch.Tensor.index_add_`). For the inverse — copying `dst[idx[j]]`
63/// rows out into a contiguous tensor — use
64/// [`IndexSelectPlan`](crate::IndexSelectPlan).
65///
66/// **Dtypes**: `{f32, f64, f16, bf16}`. f16 / bf16 use the CAS-based
67/// `baracuda::atomic::add<T>` helper for deterministic per-thread
68/// arithmetic.
69///
70/// **Shape limits**: rank in `[1, 8]`; `add_dim ∈ [0, N)`; idx 1-D
71/// with `idx.numel() == src.shape[add_dim]`.
72///
73/// **Workspace**: none. Caller pre-zeros (or pre-populates) `dst`.
74///
75/// **Precision guarantee**: **non-deterministic accumulation order**
76/// (atomicAdd). Per-thread arithmetic is bit-stable on same hardware.
77///
78/// **Index policy**: out-of-bounds and negative indices skipped.
79pub struct IndexAddPlan<T: Element, const N: usize> {
80    desc: IndexAddDescriptor<N>,
81    sku: KernelSku,
82    _marker: PhantomData<T>,
83}
84
85impl<T: Element, const N: usize> IndexAddPlan<T, N> {
86    /// Pick a kernel for `desc`.
87    pub fn select(
88        _stream: &Stream,
89        desc: &IndexAddDescriptor<N>,
90        _pref: PlanPreference,
91    ) -> Result<Self> {
92        if desc.element != T::KIND {
93            return Err(Error::Unsupported(
94                "baracuda-kernels::IndexAddPlan: descriptor element != type parameter T",
95            ));
96        }
97        if N == 0 {
98            return Err(Error::InvalidProblem(
99                "baracuda-kernels::IndexAddPlan: rank-0 tensors not supported",
100            ));
101        }
102        if desc.add_dim < 0 || desc.add_dim >= N as i32 {
103            return Err(Error::InvalidProblem(
104                "baracuda-kernels::IndexAddPlan: add_dim out of range [0, N)",
105            ));
106        }
107        if desc.dst_dim_size < 0 {
108            return Err(Error::InvalidProblem(
109                "baracuda-kernels::IndexAddPlan: dst_dim_size must be non-negative",
110            ));
111        }
112        for &d in desc.src_shape.iter() {
113            if d < 0 {
114                return Err(Error::InvalidProblem(
115                    "baracuda-kernels::IndexAddPlan: src_shape dims must be non-negative",
116                ));
117            }
118        }
119
120        let supported = matches!(
121            T::KIND,
122            ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
123        );
124        if !supported {
125            return Err(Error::Unsupported(
126                "baracuda-kernels::IndexAddPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
127            ));
128        }
129
130        let precision_guarantee = PrecisionGuarantee {
131            math_precision: if T::KIND == ElementKind::F64 {
132                MathPrecision::F64
133            } else {
134                MathPrecision::F32
135            },
136            accumulator: T::KIND,
137            // atomicAdd order is non-deterministic across launches.
138            bit_stable_on_same_hardware: false,
139            deterministic: false,
140        };
141        let sku = KernelSku {
142            category: OpCategory::Indexing,
143            op: IndexingKind::IndexAdd as u16,
144            element: T::KIND,
145            aux_element: Some(ElementKind::I32),
146            layout: None,
147            epilogue: None,
148            arch: ArchSku::Sm80,
149            backend: BackendKind::Bespoke,
150            precision_guarantee,
151        };
152        Ok(Self {
153            desc: *desc,
154            sku,
155            _marker: PhantomData,
156        })
157    }
158
159    /// Validate `args` against the descriptor.
160    pub fn can_implement<I: IndexElement>(&self, args: &IndexAddArgs<'_, T, N, I>) -> Result<()> {
161        if args.src.shape != self.desc.src_shape {
162            return Err(Error::InvalidProblem(
163                "baracuda-kernels::IndexAddPlan: src shape mismatch with descriptor",
164            ));
165        }
166        let expected_idx = self.desc.src_shape[self.desc.add_dim as usize];
167        if args.idx.shape[0] != expected_idx {
168            return Err(Error::InvalidProblem(
169                "baracuda-kernels::IndexAddPlan: idx.shape[0] must equal \
170                 src_shape[add_dim]",
171            ));
172        }
173        if N > 8 {
174            return Err(Error::Unsupported(
175                "baracuda-kernels::IndexAddPlan: tensor rank > 8 not supported",
176            ));
177        }
178        let src_numel = args.src.numel();
179        let idx_numel = args.idx.numel();
180        let src_len = args.src.data.len() as i64;
181        let idx_len = args.idx.data.len() as i64;
182        if src_len < src_numel {
183            return Err(Error::BufferTooSmall {
184                needed: src_numel as usize,
185                got: src_len as usize,
186            });
187        }
188        if idx_len < idx_numel {
189            return Err(Error::BufferTooSmall {
190                needed: idx_numel as usize,
191                got: idx_len as usize,
192            });
193        }
194        Ok(())
195    }
196
197    /// Workspace size in bytes. Always zero.
198    #[inline]
199    pub fn workspace_size(&self) -> usize {
200        0
201    }
202
203    /// Identity of the kernel this plan picked.
204    #[inline]
205    pub fn sku(&self) -> KernelSku {
206        self.sku
207    }
208
209    /// Numerical guarantees for this plan's kernel.
210    #[inline]
211    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
212        self.sku.precision_guarantee
213    }
214
215    /// Launch the kernel on `stream`. Caller must have zeroed (or
216    /// pre-populated) `dst` before this call. `workspace` ignored.
217    pub fn run<I: IndexElement>(
218        &self,
219        stream: &Stream,
220        _workspace: Workspace<'_>,
221        args: IndexAddArgs<'_, T, N, I>,
222    ) -> Result<()> {
223        self.can_implement(&args)?;
224        let src_numel = args.src.numel();
225        if src_numel == 0 {
226            return Ok(());
227        }
228        let src_ptr = args.src.data.as_raw().0 as *const c_void;
229        let idx_ptr = args.idx.data.as_raw().0 as *const c_void;
230        let dst_ptr = args.dst.data.as_raw().0 as *mut c_void;
231        let stream_ptr = stream.as_raw() as *mut c_void;
232
233        let src_shape = self.desc.src_shape;
234        let stride_src = args.src.stride;
235        let stride_dst = args.dst.stride;
236        let rank = N as i32;
237
238        let status = match (T::KIND, I::KIND) {
239            (ElementKind::F32, IndexElementKind::I32) => unsafe {
240                baracuda_kernels_sys::baracuda_kernels_index_add_f32_run(
241                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
242                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
243                    src_ptr, idx_ptr, dst_ptr,
244                    core::ptr::null_mut(), 0, stream_ptr,
245                )
246            },
247            (ElementKind::F64, IndexElementKind::I32) => unsafe {
248                baracuda_kernels_sys::baracuda_kernels_index_add_f64_run(
249                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
250                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
251                    src_ptr, idx_ptr, dst_ptr,
252                    core::ptr::null_mut(), 0, stream_ptr,
253                )
254            },
255            (ElementKind::F16, IndexElementKind::I32) => unsafe {
256                baracuda_kernels_sys::baracuda_kernels_index_add_f16_run(
257                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
258                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
259                    src_ptr, idx_ptr, dst_ptr,
260                    core::ptr::null_mut(), 0, stream_ptr,
261                )
262            },
263            (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
264                baracuda_kernels_sys::baracuda_kernels_index_add_bf16_run(
265                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
266                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
267                    src_ptr, idx_ptr, dst_ptr,
268                    core::ptr::null_mut(), 0, stream_ptr,
269                )
270            },
271            (ElementKind::F32, IndexElementKind::I64) => unsafe {
272                baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_f32_run(
273                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
274                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
275                    src_ptr, idx_ptr, dst_ptr,
276                    core::ptr::null_mut(), 0, stream_ptr,
277                )
278            },
279            (ElementKind::F64, IndexElementKind::I64) => unsafe {
280                baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_f64_run(
281                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
282                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
283                    src_ptr, idx_ptr, dst_ptr,
284                    core::ptr::null_mut(), 0, stream_ptr,
285                )
286            },
287            (ElementKind::F16, IndexElementKind::I64) => unsafe {
288                baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_f16_run(
289                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
290                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
291                    src_ptr, idx_ptr, dst_ptr,
292                    core::ptr::null_mut(), 0, stream_ptr,
293                )
294            },
295            (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
296                baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_bf16_run(
297                    src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
298                    src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
299                    src_ptr, idx_ptr, dst_ptr,
300                    core::ptr::null_mut(), 0, stream_ptr,
301                )
302            },
303            _ => {
304                return Err(Error::Unsupported(
305                    "baracuda-kernels::IndexAddPlan::run reached an unimplemented dtype \
306                     — select() should have caught this",
307                ));
308            }
309        };
310        map_status(status)
311    }
312}