Skip to main content

baracuda_kernels/indexing/
scatter.rs

1//! `scatter` (pure assign) plan — Category L (Phase 39).
2//!
3//! `out[..., index[..., j, ...], ...] = updates[..., j, ...]` along the
4//! `scatter_dim` axis. **No accumulation** — if multiple updates target
5//! the same output cell, the **last writer wins** (race; the result is
6//! non-deterministic but the per-cell value is always one of the
7//! contributing writes — never a partial / torn store, since one
8//! element fits in a single 16/32/64-bit write).
9//!
10//! PyTorch `torch.Tensor.scatter_` (the in-place pure-assign variant).
11//! Distinct from [`crate::ScatterAddPlan`] which atomically Σ-accumulates.
12//!
13//! **Dtype coverage (Phase 39 Tier 1)**: `{f32, f64, f16, bf16}` × index
14//! `{i32, i64}` = 8 FFI symbols. The kernel does no arithmetic, only
15//! stores, so all four dtypes ship in the trailblazer.
16//!
17//! **Tests should use disjoint targets** to keep results deterministic.
18//! Duplicate-target writes are an advisory feature of the op semantics,
19//! not something callers should rely on for any specific outcome.
20
21use core::ffi::c_void;
22use core::marker::PhantomData;
23
24use baracuda_cutlass::{Error, Result};
25use baracuda_driver::Stream;
26use baracuda_kernels_types::{
27    ArchSku, BackendKind, Element, ElementKind, IndexElement, IndexElementKind, IndexingKind,
28    KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
29    TensorRef, Workspace,
30};
31
32use super::gather::map_status;
33
34/// Descriptor for a `scatter` (pure assign) op.
35///
36/// Identifies the shape of `updates` (== `index` shape), the axis, and
37/// the extent of `out` along that axis. `T::KIND` must equal `element`.
38#[derive(Copy, Clone, Debug)]
39pub struct ScatterDescriptor<const N: usize> {
40    /// Shape of `updates` / `index`.
41    pub upd_shape: [i32; N],
42    /// Scatter axis (must be in `[0, N)`).
43    pub scatter_dim: i32,
44    /// Extent of `out` along `scatter_dim` (in-bounds check on indices).
45    pub out_dim_size: i32,
46    /// Value element type.
47    pub element: ElementKind,
48}
49
50/// Args bundle for a `scatter` (pure assign) launch.
51pub struct ScatterArgs<'a, T: Element, const N: usize, I: IndexElement = i32> {
52    /// Update values.
53    pub updates: TensorRef<'a, T, N>,
54    /// Index tensor. Same shape as `updates`. `i32` (legacy) or `i64`
55    /// (PyTorch default).
56    pub index: TensorRef<'a, I, N>,
57    /// Output. **Overwritten** (not accumulated). Caller pre-populates
58    /// any cells that should retain their value when no index targets
59    /// them (the kernel only touches cells named by `index`).
60    pub out: TensorMut<'a, T, N>,
61}
62
63/// `scatter` (pure assign) plan.
64///
65/// `out[..., index[..., j, ...], ...] = updates[..., j, ...]` along
66/// `scatter_dim` — **no accumulation**. Last writer wins on
67/// duplicate-target races (caller-aware non-determinism).
68///
69/// **When to use**: forward `scatter` (PyTorch
70/// `torch.Tensor.scatter_`). For Σ-accumulation use
71/// [`ScatterAddPlan`](crate::ScatterAddPlan).
72///
73/// **Dtypes**: `{f32, f64, f16, bf16}`. Pure store, no arithmetic.
74///
75/// **Shape limits**: rank in `[1, 8]`; `scatter_dim ∈ [0, N)`;
76/// `out_dim_size ≥ 0`. `updates` and `index` must share shape.
77///
78/// **Workspace**: none.
79///
80/// **Precision guarantee**: **non-deterministic** on duplicate-target
81/// indices (race condition). For disjoint-index workloads the output
82/// is deterministic and bit-exact (pure copy, no arithmetic).
83///
84/// **Index policy**: out-of-bounds and negative indices are skipped
85/// (no PyTorch-style wraparound).
86pub struct ScatterPlan<T: Element, const N: usize> {
87    desc: ScatterDescriptor<N>,
88    sku: KernelSku,
89    _marker: PhantomData<T>,
90}
91
92impl<T: Element, const N: usize> ScatterPlan<T, N> {
93    /// Pick a kernel for `desc`. Validates element-type alignment,
94    /// rank, axis, non-negative extents, and dtype in
95    /// `{f32, f64, f16, bf16}`.
96    pub fn select(
97        _stream: &Stream,
98        desc: &ScatterDescriptor<N>,
99        _pref: PlanPreference,
100    ) -> Result<Self> {
101        if desc.element != T::KIND {
102            return Err(Error::Unsupported(
103                "baracuda-kernels::ScatterPlan: descriptor element != type parameter T",
104            ));
105        }
106        if N == 0 {
107            return Err(Error::InvalidProblem(
108                "baracuda-kernels::ScatterPlan: rank-0 tensors not supported",
109            ));
110        }
111        if desc.scatter_dim < 0 || desc.scatter_dim >= N as i32 {
112            return Err(Error::InvalidProblem(
113                "baracuda-kernels::ScatterPlan: scatter_dim out of range [0, N)",
114            ));
115        }
116        if desc.out_dim_size < 0 {
117            return Err(Error::InvalidProblem(
118                "baracuda-kernels::ScatterPlan: out_dim_size must be non-negative",
119            ));
120        }
121        for &d in desc.upd_shape.iter() {
122            if d < 0 {
123                return Err(Error::InvalidProblem(
124                    "baracuda-kernels::ScatterPlan: upd_shape dims must be non-negative",
125                ));
126            }
127        }
128
129        let supported = matches!(
130            T::KIND,
131            ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
132        );
133        if !supported {
134            return Err(Error::Unsupported(
135                "baracuda-kernels::ScatterPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
136            ));
137        }
138
139        let precision_guarantee = PrecisionGuarantee {
140            math_precision: MathPrecision::F32,
141            accumulator: T::KIND,
142            // Pure store: bit-stable per-cell when targets are disjoint,
143            // but the op semantics permit duplicate targets which race.
144            bit_stable_on_same_hardware: false,
145            deterministic: false,
146        };
147        let sku = KernelSku {
148            category: OpCategory::Indexing,
149            op: IndexingKind::Scatter as u16,
150            element: T::KIND,
151            aux_element: Some(ElementKind::I32),
152            layout: None,
153            epilogue: None,
154            arch: ArchSku::Sm80,
155            backend: BackendKind::Bespoke,
156            precision_guarantee,
157        };
158        Ok(Self {
159            desc: *desc,
160            sku,
161            _marker: PhantomData,
162        })
163    }
164
165    /// Validate `args` against the descriptor.
166    pub fn can_implement<I: IndexElement>(&self, args: &ScatterArgs<'_, T, N, I>) -> Result<()> {
167        if args.updates.shape != self.desc.upd_shape {
168            return Err(Error::InvalidProblem(
169                "baracuda-kernels::ScatterPlan: updates shape mismatch with descriptor",
170            ));
171        }
172        if args.index.shape != self.desc.upd_shape {
173            return Err(Error::InvalidProblem(
174                "baracuda-kernels::ScatterPlan: index shape must equal updates shape",
175            ));
176        }
177        if N > 8 {
178            return Err(Error::Unsupported(
179                "baracuda-kernels::ScatterPlan: tensor rank > 8 not supported",
180            ));
181        }
182        let upd_numel = args.updates.numel();
183        let upd_len = args.updates.data.len() as i64;
184        let idx_len = args.index.data.len() as i64;
185        if upd_len < upd_numel {
186            return Err(Error::BufferTooSmall {
187                needed: upd_numel as usize,
188                got: upd_len as usize,
189            });
190        }
191        if idx_len < upd_numel {
192            return Err(Error::BufferTooSmall {
193                needed: upd_numel as usize,
194                got: idx_len as usize,
195            });
196        }
197        Ok(())
198    }
199
200    /// Workspace size in bytes. Always zero.
201    #[inline]
202    pub fn workspace_size(&self) -> usize {
203        0
204    }
205
206    /// Identity of the kernel this plan picked.
207    #[inline]
208    pub fn sku(&self) -> KernelSku {
209        self.sku
210    }
211
212    /// Numerical guarantees for this plan's kernel.
213    #[inline]
214    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
215        self.sku.precision_guarantee
216    }
217
218    /// Launch the kernel on `stream`. Caller pre-populates `out`; the
219    /// kernel only writes cells named by `index`. `workspace` ignored.
220    pub fn run<I: IndexElement>(
221        &self,
222        stream: &Stream,
223        _workspace: Workspace<'_>,
224        args: ScatterArgs<'_, T, N, I>,
225    ) -> Result<()> {
226        self.can_implement(&args)?;
227        let upd_numel = args.updates.numel();
228        if upd_numel == 0 {
229            return Ok(());
230        }
231        let upd_ptr = args.updates.data.as_raw().0 as *const c_void;
232        let idx_ptr = args.index.data.as_raw().0 as *const c_void;
233        let out_ptr = args.out.data.as_raw().0 as *mut c_void;
234        let stream_ptr = stream.as_raw() as *mut c_void;
235
236        let upd_shape = self.desc.upd_shape;
237        let stride_upd = args.updates.stride;
238        let stride_index = args.index.stride;
239        let stride_out = args.out.stride;
240        let rank = N as i32;
241
242        let status = match (T::KIND, I::KIND) {
243            (ElementKind::F32, IndexElementKind::I32) => unsafe {
244                baracuda_kernels_sys::baracuda_kernels_scatter_f32_run(
245                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
246                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
247                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
248                    core::ptr::null_mut(), 0, stream_ptr,
249                )
250            },
251            (ElementKind::F64, IndexElementKind::I32) => unsafe {
252                baracuda_kernels_sys::baracuda_kernels_scatter_f64_run(
253                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
254                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
255                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
256                    core::ptr::null_mut(), 0, stream_ptr,
257                )
258            },
259            (ElementKind::F16, IndexElementKind::I32) => unsafe {
260                baracuda_kernels_sys::baracuda_kernels_scatter_f16_run(
261                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
262                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
263                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
264                    core::ptr::null_mut(), 0, stream_ptr,
265                )
266            },
267            (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
268                baracuda_kernels_sys::baracuda_kernels_scatter_bf16_run(
269                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
270                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
271                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
272                    core::ptr::null_mut(), 0, stream_ptr,
273                )
274            },
275            (ElementKind::F32, IndexElementKind::I64) => unsafe {
276                baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_f32_run(
277                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
278                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
279                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
280                    core::ptr::null_mut(), 0, stream_ptr,
281                )
282            },
283            (ElementKind::F64, IndexElementKind::I64) => unsafe {
284                baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_f64_run(
285                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
286                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
287                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
288                    core::ptr::null_mut(), 0, stream_ptr,
289                )
290            },
291            (ElementKind::F16, IndexElementKind::I64) => unsafe {
292                baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_f16_run(
293                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
294                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
295                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
296                    core::ptr::null_mut(), 0, stream_ptr,
297                )
298            },
299            (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
300                baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_bf16_run(
301                    upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
302                    upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
303                    stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
304                    core::ptr::null_mut(), 0, stream_ptr,
305                )
306            },
307            _ => {
308                return Err(Error::Unsupported(
309                    "baracuda-kernels::ScatterPlan::run reached an unimplemented dtype \
310                     — select() should have caught this",
311                ));
312            }
313        };
314        map_status(status)
315    }
316}