1use core::ffi::c_void;
19use core::marker::PhantomData;
20
21use baracuda_cutlass::{Error, Result};
22use baracuda_driver::Stream;
23use baracuda_kernels_types::{
24 ArchSku, BackendKind, Element, ElementKind, IndexElement, IndexElementKind, IndexingKind,
25 KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
26 TensorRef, Workspace,
27};
28
29use super::gather::map_status;
30
31#[derive(Copy, Clone, Debug)]
33pub struct IndexAddDescriptor<const N: usize> {
34 pub src_shape: [i32; N],
36 pub add_dim: i32,
38 pub dst_dim_size: i32,
40 pub element: ElementKind,
42}
43
44pub struct IndexAddArgs<'a, T: Element, const N: usize, I: IndexElement = i32> {
46 pub src: TensorRef<'a, T, N>,
48 pub idx: TensorRef<'a, I, 1>,
50 pub dst: TensorMut<'a, T, N>,
54}
55
56pub struct IndexAddPlan<T: Element, const N: usize> {
80 desc: IndexAddDescriptor<N>,
81 sku: KernelSku,
82 _marker: PhantomData<T>,
83}
84
85impl<T: Element, const N: usize> IndexAddPlan<T, N> {
86 pub fn select(
88 _stream: &Stream,
89 desc: &IndexAddDescriptor<N>,
90 _pref: PlanPreference,
91 ) -> Result<Self> {
92 if desc.element != T::KIND {
93 return Err(Error::Unsupported(
94 "baracuda-kernels::IndexAddPlan: descriptor element != type parameter T",
95 ));
96 }
97 if N == 0 {
98 return Err(Error::InvalidProblem(
99 "baracuda-kernels::IndexAddPlan: rank-0 tensors not supported",
100 ));
101 }
102 if desc.add_dim < 0 || desc.add_dim >= N as i32 {
103 return Err(Error::InvalidProblem(
104 "baracuda-kernels::IndexAddPlan: add_dim out of range [0, N)",
105 ));
106 }
107 if desc.dst_dim_size < 0 {
108 return Err(Error::InvalidProblem(
109 "baracuda-kernels::IndexAddPlan: dst_dim_size must be non-negative",
110 ));
111 }
112 for &d in desc.src_shape.iter() {
113 if d < 0 {
114 return Err(Error::InvalidProblem(
115 "baracuda-kernels::IndexAddPlan: src_shape dims must be non-negative",
116 ));
117 }
118 }
119
120 let supported = matches!(
121 T::KIND,
122 ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
123 );
124 if !supported {
125 return Err(Error::Unsupported(
126 "baracuda-kernels::IndexAddPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
127 ));
128 }
129
130 let precision_guarantee = PrecisionGuarantee {
131 math_precision: if T::KIND == ElementKind::F64 {
132 MathPrecision::F64
133 } else {
134 MathPrecision::F32
135 },
136 accumulator: T::KIND,
137 bit_stable_on_same_hardware: false,
139 deterministic: false,
140 };
141 let sku = KernelSku {
142 category: OpCategory::Indexing,
143 op: IndexingKind::IndexAdd as u16,
144 element: T::KIND,
145 aux_element: Some(ElementKind::I32),
146 layout: None,
147 epilogue: None,
148 arch: ArchSku::Sm80,
149 backend: BackendKind::Bespoke,
150 precision_guarantee,
151 };
152 Ok(Self {
153 desc: *desc,
154 sku,
155 _marker: PhantomData,
156 })
157 }
158
159 pub fn can_implement<I: IndexElement>(&self, args: &IndexAddArgs<'_, T, N, I>) -> Result<()> {
161 if args.src.shape != self.desc.src_shape {
162 return Err(Error::InvalidProblem(
163 "baracuda-kernels::IndexAddPlan: src shape mismatch with descriptor",
164 ));
165 }
166 let expected_idx = self.desc.src_shape[self.desc.add_dim as usize];
167 if args.idx.shape[0] != expected_idx {
168 return Err(Error::InvalidProblem(
169 "baracuda-kernels::IndexAddPlan: idx.shape[0] must equal \
170 src_shape[add_dim]",
171 ));
172 }
173 if N > 8 {
174 return Err(Error::Unsupported(
175 "baracuda-kernels::IndexAddPlan: tensor rank > 8 not supported",
176 ));
177 }
178 let src_numel = args.src.numel();
179 let idx_numel = args.idx.numel();
180 let src_len = args.src.data.len() as i64;
181 let idx_len = args.idx.data.len() as i64;
182 if src_len < src_numel {
183 return Err(Error::BufferTooSmall {
184 needed: src_numel as usize,
185 got: src_len as usize,
186 });
187 }
188 if idx_len < idx_numel {
189 return Err(Error::BufferTooSmall {
190 needed: idx_numel as usize,
191 got: idx_len as usize,
192 });
193 }
194 Ok(())
195 }
196
197 #[inline]
199 pub fn workspace_size(&self) -> usize {
200 0
201 }
202
203 #[inline]
205 pub fn sku(&self) -> KernelSku {
206 self.sku
207 }
208
209 #[inline]
211 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
212 self.sku.precision_guarantee
213 }
214
215 pub fn run<I: IndexElement>(
218 &self,
219 stream: &Stream,
220 _workspace: Workspace<'_>,
221 args: IndexAddArgs<'_, T, N, I>,
222 ) -> Result<()> {
223 self.can_implement(&args)?;
224 let src_numel = args.src.numel();
225 if src_numel == 0 {
226 return Ok(());
227 }
228 let src_ptr = args.src.data.as_raw().0 as *const c_void;
229 let idx_ptr = args.idx.data.as_raw().0 as *const c_void;
230 let dst_ptr = args.dst.data.as_raw().0 as *mut c_void;
231 let stream_ptr = stream.as_raw() as *mut c_void;
232
233 let src_shape = self.desc.src_shape;
234 let stride_src = args.src.stride;
235 let stride_dst = args.dst.stride;
236 let rank = N as i32;
237
238 let status = match (T::KIND, I::KIND) {
239 (ElementKind::F32, IndexElementKind::I32) => unsafe {
240 baracuda_kernels_sys::baracuda_kernels_index_add_f32_run(
241 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
242 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
243 src_ptr, idx_ptr, dst_ptr,
244 core::ptr::null_mut(), 0, stream_ptr,
245 )
246 },
247 (ElementKind::F64, IndexElementKind::I32) => unsafe {
248 baracuda_kernels_sys::baracuda_kernels_index_add_f64_run(
249 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
250 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
251 src_ptr, idx_ptr, dst_ptr,
252 core::ptr::null_mut(), 0, stream_ptr,
253 )
254 },
255 (ElementKind::F16, IndexElementKind::I32) => unsafe {
256 baracuda_kernels_sys::baracuda_kernels_index_add_f16_run(
257 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
258 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
259 src_ptr, idx_ptr, dst_ptr,
260 core::ptr::null_mut(), 0, stream_ptr,
261 )
262 },
263 (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
264 baracuda_kernels_sys::baracuda_kernels_index_add_bf16_run(
265 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
266 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
267 src_ptr, idx_ptr, dst_ptr,
268 core::ptr::null_mut(), 0, stream_ptr,
269 )
270 },
271 (ElementKind::F32, IndexElementKind::I64) => unsafe {
272 baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_f32_run(
273 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
274 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
275 src_ptr, idx_ptr, dst_ptr,
276 core::ptr::null_mut(), 0, stream_ptr,
277 )
278 },
279 (ElementKind::F64, IndexElementKind::I64) => unsafe {
280 baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_f64_run(
281 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
282 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
283 src_ptr, idx_ptr, dst_ptr,
284 core::ptr::null_mut(), 0, stream_ptr,
285 )
286 },
287 (ElementKind::F16, IndexElementKind::I64) => unsafe {
288 baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_f16_run(
289 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
290 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
291 src_ptr, idx_ptr, dst_ptr,
292 core::ptr::null_mut(), 0, stream_ptr,
293 )
294 },
295 (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
296 baracuda_kernels_sys::baracuda_kernels_index_add_i64idx_bf16_run(
297 src_numel, rank, self.desc.add_dim, self.desc.dst_dim_size,
298 src_shape.as_ptr(), stride_src.as_ptr(), stride_dst.as_ptr(),
299 src_ptr, idx_ptr, dst_ptr,
300 core::ptr::null_mut(), 0, stream_ptr,
301 )
302 },
303 _ => {
304 return Err(Error::Unsupported(
305 "baracuda-kernels::IndexAddPlan::run reached an unimplemented dtype \
306 — select() should have caught this",
307 ));
308 }
309 };
310 map_status(status)
311 }
312}