Skip to main content

baracuda_kernels/indexing/
index_select.rs

1//! `index_select` plan — Category L.
2//!
//! `out[..., j, ...] = src[..., idx[j], ...]` along the `select_dim`
//! axis. `idx` is a 1-D `i32` or `i64` tensor; output shape == source
//! shape with `select_dim` replaced by `idx.numel()`. PyTorch
//! `torch.index_select`.
7//!
8//! Faster / simpler than [`crate::indexing::GatherPlan`] when the index
9//! tensor is 1-D — `gather` accepts an N-D index broadcast to the
10//! output shape, while `index_select` collapses to a single 1-D lookup.
11//!
12//! Trailblazer dtype coverage: `f32, f64, i32`. The kernel does no
13//! arithmetic — pure load + store — so output is bit-exact at every
14//! dtype.
15
16use core::ffi::c_void;
17use core::marker::PhantomData;
18
19use baracuda_cutlass::{Error, Result};
20use baracuda_driver::Stream;
21use baracuda_kernels_types::{
22    ArchSku, BackendKind, Element, ElementKind, IndexElement, IndexElementKind, IndexingKind,
23    KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
24    TensorRef, Workspace,
25};
26
27/// Descriptor for an `index_select` op.
28#[derive(Copy, Clone, Debug)]
29pub struct IndexSelectDescriptor<const N: usize> {
30    /// Output tensor shape.
31    pub out_shape: [i32; N],
32    /// Axis along which to select. Must be in `[0, N)`.
33    pub select_dim: i32,
34    /// Extent of `src` along `select_dim` (bounds check on indices).
35    pub src_dim_size: i32,
36    /// Value element type.
37    pub element: ElementKind,
38}
39
40/// Args bundle for an `index_select` launch.
41///
42/// Phase 11.5: `I: IndexElement` generic (`i32` or `i64`).
43pub struct IndexSelectArgs<'a, T: Element, const N: usize, I: IndexElement = i32> {
44    /// Source tensor.
45    pub src: TensorRef<'a, T, N>,
46    /// Index tensor (1-D). `idx.numel()` must equal
47    /// `out_shape[select_dim]`. `i32` (legacy) or `i64` (PyTorch).
48    pub idx: TensorRef<'a, I, 1>,
49    /// Output. Shape == descriptor `out_shape`.
50    pub out: TensorMut<'a, T, N>,
51}
52
53/// `index_select` plan.
54///
55/// `out[..., j, ...] = src[..., idx[j], ...]` along the `select_dim`
56/// axis (PyTorch `torch.index_select`).
57///
58/// **When to use**: 1-D index lookups. For N-D `index` (broadcast to
59/// the output shape) use [`GatherPlan`](crate::GatherPlan); this op
60/// is the faster specialization. Pair with
61/// [`IndexSelectBackwardPlan`](crate::IndexSelectBackwardPlan) for
62/// autograd.
63///
64/// **Dtypes**: value `{f32, f64, i32}`; index always 1-D `i32`.
65///
66/// **Shape limits**: rank in `[1, 8]`; `select_dim ∈ [0, N)`;
67/// `idx.numel() == out_shape[select_dim]`.
68///
69/// **Workspace**: none.
70///
71/// **Precision guarantee**: deterministic, bit-stable on same
72/// hardware. Pure load + store, no arithmetic — bit-exact at every
73/// dtype.
74///
75/// **Index policy**: out-of-bounds and negative indices skipped.
76pub struct IndexSelectPlan<T: Element, const N: usize> {
77    desc: IndexSelectDescriptor<N>,
78    sku: KernelSku,
79    _marker: PhantomData<T>,
80}
81
82impl<T: Element, const N: usize> IndexSelectPlan<T, N> {
83    /// Pick a kernel for `desc`. Validates element-type alignment,
84    /// rank, axis, non-negative extents, and dtype.
85    pub fn select(
86        _stream: &Stream,
87        desc: &IndexSelectDescriptor<N>,
88        _pref: PlanPreference,
89    ) -> Result<Self> {
90        if desc.element != T::KIND {
91            return Err(Error::Unsupported(
92                "baracuda-kernels::IndexSelectPlan: descriptor element != type parameter T",
93            ));
94        }
95        if N == 0 {
96            return Err(Error::InvalidProblem(
97                "baracuda-kernels::IndexSelectPlan: rank-0 tensors not supported",
98            ));
99        }
100        if desc.select_dim < 0 || desc.select_dim >= N as i32 {
101            return Err(Error::InvalidProblem(
102                "baracuda-kernels::IndexSelectPlan: select_dim out of range [0, N)",
103            ));
104        }
105        if desc.src_dim_size < 0 {
106            return Err(Error::InvalidProblem(
107                "baracuda-kernels::IndexSelectPlan: src_dim_size must be non-negative",
108            ));
109        }
110
111        let supported =
112            matches!(T::KIND, ElementKind::F32 | ElementKind::F64 | ElementKind::I32);
113        if !supported {
114            return Err(Error::Unsupported(
115                "baracuda-kernels::IndexSelectPlan: today only `f32`, `f64`, `i32` wired",
116            ));
117        }
118
119        let precision_guarantee = PrecisionGuarantee {
120            math_precision: MathPrecision::F32,
121            accumulator: ElementKind::F32,
122            bit_stable_on_same_hardware: true,
123            deterministic: true,
124        };
125        let sku = KernelSku {
126            category: OpCategory::Indexing,
127            op: IndexingKind::IndexSelect as u16,
128            element: T::KIND,
129            aux_element: Some(ElementKind::I32),
130            layout: None,
131            epilogue: None,
132            arch: ArchSku::Sm80,
133            backend: BackendKind::Bespoke,
134            precision_guarantee,
135        };
136        Ok(Self {
137            desc: *desc,
138            sku,
139            _marker: PhantomData,
140        })
141    }
142
143    /// Validate `args` against the descriptor: output shape match,
144    /// idx length matches `out_shape[select_dim]`, rank ≤ 8, device
145    /// buffers large enough.
146    pub fn can_implement<I: IndexElement>(&self, args: &IndexSelectArgs<'_, T, N, I>) -> Result<()> {
147        if args.out.shape != self.desc.out_shape {
148            return Err(Error::InvalidProblem(
149                "baracuda-kernels::IndexSelectPlan: out shape mismatch with descriptor",
150            ));
151        }
152        let expected_idx = self.desc.out_shape[self.desc.select_dim as usize];
153        if args.idx.shape[0] != expected_idx {
154            return Err(Error::InvalidProblem(
155                "baracuda-kernels::IndexSelectPlan: idx.shape[0] must equal \
156                 out_shape[select_dim]",
157            ));
158        }
159        if N > 8 {
160            return Err(Error::Unsupported(
161                "baracuda-kernels::IndexSelectPlan: tensor rank > 8 not supported",
162            ));
163        }
164        let out_numel = args.out.numel();
165        let idx_numel = args.idx.numel();
166        let out_len = args.out.data.len() as i64;
167        let idx_len = args.idx.data.len() as i64;
168        if out_len < out_numel {
169            return Err(Error::BufferTooSmall {
170                needed: out_numel as usize,
171                got: out_len as usize,
172            });
173        }
174        if idx_len < idx_numel {
175            return Err(Error::BufferTooSmall {
176                needed: idx_numel as usize,
177                got: idx_len as usize,
178            });
179        }
180        Ok(())
181    }
182
183    /// Workspace size in bytes. Always zero.
184    #[inline]
185    pub fn workspace_size(&self) -> usize {
186        0
187    }
188
189    /// Identity of the kernel this plan picked.
190    #[inline]
191    pub fn sku(&self) -> KernelSku {
192        self.sku
193    }
194
195    /// Numerical guarantees for this plan's kernel.
196    #[inline]
197    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
198        self.sku.precision_guarantee
199    }
200
201    /// Launch the kernel on `stream`. Returns early on zero-element
202    /// output. `workspace` ignored.
203    ///
204    /// Phase 11.5: generic over `I: IndexElement`.
205    pub fn run<I: IndexElement>(
206        &self,
207        stream: &Stream,
208        _workspace: Workspace<'_>,
209        args: IndexSelectArgs<'_, T, N, I>,
210    ) -> Result<()> {
211        self.can_implement(&args)?;
212        let out_numel = args.out.numel();
213        if out_numel == 0 {
214            return Ok(());
215        }
216        let src_ptr = args.src.data.as_raw().0 as *const c_void;
217        let idx_ptr = args.idx.data.as_raw().0 as *const c_void;
218        let out_ptr = args.out.data.as_raw().0 as *mut c_void;
219        let stream_ptr = stream.as_raw() as *mut c_void;
220
221        let out_shape = self.desc.out_shape;
222        let stride_src = args.src.stride;
223        let stride_out = args.out.stride;
224        let rank = N as i32;
225
226        let status = match (T::KIND, I::KIND) {
227            (ElementKind::F32, IndexElementKind::I32) => unsafe {
228                baracuda_kernels_sys::baracuda_kernels_index_select_f32_run(
229                    out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
230                    out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
231                    src_ptr, idx_ptr, out_ptr,
232                    core::ptr::null_mut(), 0, stream_ptr,
233                )
234            },
235            (ElementKind::F64, IndexElementKind::I32) => unsafe {
236                baracuda_kernels_sys::baracuda_kernels_index_select_f64_run(
237                    out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
238                    out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
239                    src_ptr, idx_ptr, out_ptr,
240                    core::ptr::null_mut(), 0, stream_ptr,
241                )
242            },
243            (ElementKind::I32, IndexElementKind::I32) => unsafe {
244                baracuda_kernels_sys::baracuda_kernels_index_select_i32_run(
245                    out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
246                    out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
247                    src_ptr, idx_ptr, out_ptr,
248                    core::ptr::null_mut(), 0, stream_ptr,
249                )
250            },
251            (ElementKind::F32, IndexElementKind::I64) => unsafe {
252                baracuda_kernels_sys::baracuda_kernels_index_select_i64idx_f32_run(
253                    out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
254                    out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
255                    src_ptr, idx_ptr, out_ptr,
256                    core::ptr::null_mut(), 0, stream_ptr,
257                )
258            },
259            (ElementKind::F64, IndexElementKind::I64) => unsafe {
260                baracuda_kernels_sys::baracuda_kernels_index_select_i64idx_f64_run(
261                    out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
262                    out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
263                    src_ptr, idx_ptr, out_ptr,
264                    core::ptr::null_mut(), 0, stream_ptr,
265                )
266            },
267            (ElementKind::I32, IndexElementKind::I64) => unsafe {
268                baracuda_kernels_sys::baracuda_kernels_index_select_i64idx_i32_run(
269                    out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
270                    out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
271                    src_ptr, idx_ptr, out_ptr,
272                    core::ptr::null_mut(), 0, stream_ptr,
273                )
274            },
275            _ => {
276                return Err(Error::Unsupported(
277                    "baracuda-kernels::IndexSelectPlan::run reached an unimplemented dtype \
278                     — select() should have caught this",
279                ));
280            }
281        };
282        super::gather::map_status(status)
283    }
284}