1use core::ffi::c_void;
17use core::marker::PhantomData;
18
19use baracuda_cutlass::{Error, Result};
20use baracuda_driver::Stream;
21use baracuda_kernels_types::{
22 ArchSku, BackendKind, Element, ElementKind, IndexElement, IndexElementKind, IndexingKind,
23 KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
24 TensorRef, Workspace,
25};
26
27#[derive(Copy, Clone, Debug)]
29pub struct IndexSelectDescriptor<const N: usize> {
30 pub out_shape: [i32; N],
32 pub select_dim: i32,
34 pub src_dim_size: i32,
36 pub element: ElementKind,
38}
39
40pub struct IndexSelectArgs<'a, T: Element, const N: usize, I: IndexElement = i32> {
44 pub src: TensorRef<'a, T, N>,
46 pub idx: TensorRef<'a, I, 1>,
49 pub out: TensorMut<'a, T, N>,
51}
52
53pub struct IndexSelectPlan<T: Element, const N: usize> {
77 desc: IndexSelectDescriptor<N>,
78 sku: KernelSku,
79 _marker: PhantomData<T>,
80}
81
82impl<T: Element, const N: usize> IndexSelectPlan<T, N> {
83 pub fn select(
86 _stream: &Stream,
87 desc: &IndexSelectDescriptor<N>,
88 _pref: PlanPreference,
89 ) -> Result<Self> {
90 if desc.element != T::KIND {
91 return Err(Error::Unsupported(
92 "baracuda-kernels::IndexSelectPlan: descriptor element != type parameter T",
93 ));
94 }
95 if N == 0 {
96 return Err(Error::InvalidProblem(
97 "baracuda-kernels::IndexSelectPlan: rank-0 tensors not supported",
98 ));
99 }
100 if desc.select_dim < 0 || desc.select_dim >= N as i32 {
101 return Err(Error::InvalidProblem(
102 "baracuda-kernels::IndexSelectPlan: select_dim out of range [0, N)",
103 ));
104 }
105 if desc.src_dim_size < 0 {
106 return Err(Error::InvalidProblem(
107 "baracuda-kernels::IndexSelectPlan: src_dim_size must be non-negative",
108 ));
109 }
110
111 let supported =
112 matches!(T::KIND, ElementKind::F32 | ElementKind::F64 | ElementKind::I32);
113 if !supported {
114 return Err(Error::Unsupported(
115 "baracuda-kernels::IndexSelectPlan: today only `f32`, `f64`, `i32` wired",
116 ));
117 }
118
119 let precision_guarantee = PrecisionGuarantee {
120 math_precision: MathPrecision::F32,
121 accumulator: ElementKind::F32,
122 bit_stable_on_same_hardware: true,
123 deterministic: true,
124 };
125 let sku = KernelSku {
126 category: OpCategory::Indexing,
127 op: IndexingKind::IndexSelect as u16,
128 element: T::KIND,
129 aux_element: Some(ElementKind::I32),
130 layout: None,
131 epilogue: None,
132 arch: ArchSku::Sm80,
133 backend: BackendKind::Bespoke,
134 precision_guarantee,
135 };
136 Ok(Self {
137 desc: *desc,
138 sku,
139 _marker: PhantomData,
140 })
141 }
142
143 pub fn can_implement<I: IndexElement>(&self, args: &IndexSelectArgs<'_, T, N, I>) -> Result<()> {
147 if args.out.shape != self.desc.out_shape {
148 return Err(Error::InvalidProblem(
149 "baracuda-kernels::IndexSelectPlan: out shape mismatch with descriptor",
150 ));
151 }
152 let expected_idx = self.desc.out_shape[self.desc.select_dim as usize];
153 if args.idx.shape[0] != expected_idx {
154 return Err(Error::InvalidProblem(
155 "baracuda-kernels::IndexSelectPlan: idx.shape[0] must equal \
156 out_shape[select_dim]",
157 ));
158 }
159 if N > 8 {
160 return Err(Error::Unsupported(
161 "baracuda-kernels::IndexSelectPlan: tensor rank > 8 not supported",
162 ));
163 }
164 let out_numel = args.out.numel();
165 let idx_numel = args.idx.numel();
166 let out_len = args.out.data.len() as i64;
167 let idx_len = args.idx.data.len() as i64;
168 if out_len < out_numel {
169 return Err(Error::BufferTooSmall {
170 needed: out_numel as usize,
171 got: out_len as usize,
172 });
173 }
174 if idx_len < idx_numel {
175 return Err(Error::BufferTooSmall {
176 needed: idx_numel as usize,
177 got: idx_len as usize,
178 });
179 }
180 Ok(())
181 }
182
183 #[inline]
185 pub fn workspace_size(&self) -> usize {
186 0
187 }
188
189 #[inline]
191 pub fn sku(&self) -> KernelSku {
192 self.sku
193 }
194
195 #[inline]
197 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
198 self.sku.precision_guarantee
199 }
200
201 pub fn run<I: IndexElement>(
206 &self,
207 stream: &Stream,
208 _workspace: Workspace<'_>,
209 args: IndexSelectArgs<'_, T, N, I>,
210 ) -> Result<()> {
211 self.can_implement(&args)?;
212 let out_numel = args.out.numel();
213 if out_numel == 0 {
214 return Ok(());
215 }
216 let src_ptr = args.src.data.as_raw().0 as *const c_void;
217 let idx_ptr = args.idx.data.as_raw().0 as *const c_void;
218 let out_ptr = args.out.data.as_raw().0 as *mut c_void;
219 let stream_ptr = stream.as_raw() as *mut c_void;
220
221 let out_shape = self.desc.out_shape;
222 let stride_src = args.src.stride;
223 let stride_out = args.out.stride;
224 let rank = N as i32;
225
226 let status = match (T::KIND, I::KIND) {
227 (ElementKind::F32, IndexElementKind::I32) => unsafe {
228 baracuda_kernels_sys::baracuda_kernels_index_select_f32_run(
229 out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
230 out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
231 src_ptr, idx_ptr, out_ptr,
232 core::ptr::null_mut(), 0, stream_ptr,
233 )
234 },
235 (ElementKind::F64, IndexElementKind::I32) => unsafe {
236 baracuda_kernels_sys::baracuda_kernels_index_select_f64_run(
237 out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
238 out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
239 src_ptr, idx_ptr, out_ptr,
240 core::ptr::null_mut(), 0, stream_ptr,
241 )
242 },
243 (ElementKind::I32, IndexElementKind::I32) => unsafe {
244 baracuda_kernels_sys::baracuda_kernels_index_select_i32_run(
245 out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
246 out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
247 src_ptr, idx_ptr, out_ptr,
248 core::ptr::null_mut(), 0, stream_ptr,
249 )
250 },
251 (ElementKind::F32, IndexElementKind::I64) => unsafe {
252 baracuda_kernels_sys::baracuda_kernels_index_select_i64idx_f32_run(
253 out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
254 out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
255 src_ptr, idx_ptr, out_ptr,
256 core::ptr::null_mut(), 0, stream_ptr,
257 )
258 },
259 (ElementKind::F64, IndexElementKind::I64) => unsafe {
260 baracuda_kernels_sys::baracuda_kernels_index_select_i64idx_f64_run(
261 out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
262 out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
263 src_ptr, idx_ptr, out_ptr,
264 core::ptr::null_mut(), 0, stream_ptr,
265 )
266 },
267 (ElementKind::I32, IndexElementKind::I64) => unsafe {
268 baracuda_kernels_sys::baracuda_kernels_index_select_i64idx_i32_run(
269 out_numel, rank, self.desc.select_dim, self.desc.src_dim_size,
270 out_shape.as_ptr(), stride_src.as_ptr(), stride_out.as_ptr(),
271 src_ptr, idx_ptr, out_ptr,
272 core::ptr::null_mut(), 0, stream_ptr,
273 )
274 },
275 _ => {
276 return Err(Error::Unsupported(
277 "baracuda-kernels::IndexSelectPlan::run reached an unimplemented dtype \
278 — select() should have caught this",
279 ));
280 }
281 };
282 super::gather::map_status(status)
283 }
284}