1use core::ffi::c_void;
22use core::marker::PhantomData;
23
24use baracuda_cutlass::{Error, Result};
25use baracuda_driver::Stream;
26use baracuda_kernels_types::{
27 ArchSku, BackendKind, Element, ElementKind, IndexElement, IndexElementKind, IndexingKind,
28 KernelSku, MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut,
29 TensorRef, Workspace,
30};
31
32use super::gather::map_status;
33
34#[derive(Copy, Clone, Debug)]
39pub struct ScatterDescriptor<const N: usize> {
40 pub upd_shape: [i32; N],
42 pub scatter_dim: i32,
44 pub out_dim_size: i32,
46 pub element: ElementKind,
48}
49
50pub struct ScatterArgs<'a, T: Element, const N: usize, I: IndexElement = i32> {
52 pub updates: TensorRef<'a, T, N>,
54 pub index: TensorRef<'a, I, N>,
57 pub out: TensorMut<'a, T, N>,
61}
62
63pub struct ScatterPlan<T: Element, const N: usize> {
87 desc: ScatterDescriptor<N>,
88 sku: KernelSku,
89 _marker: PhantomData<T>,
90}
91
92impl<T: Element, const N: usize> ScatterPlan<T, N> {
93 pub fn select(
97 _stream: &Stream,
98 desc: &ScatterDescriptor<N>,
99 _pref: PlanPreference,
100 ) -> Result<Self> {
101 if desc.element != T::KIND {
102 return Err(Error::Unsupported(
103 "baracuda-kernels::ScatterPlan: descriptor element != type parameter T",
104 ));
105 }
106 if N == 0 {
107 return Err(Error::InvalidProblem(
108 "baracuda-kernels::ScatterPlan: rank-0 tensors not supported",
109 ));
110 }
111 if desc.scatter_dim < 0 || desc.scatter_dim >= N as i32 {
112 return Err(Error::InvalidProblem(
113 "baracuda-kernels::ScatterPlan: scatter_dim out of range [0, N)",
114 ));
115 }
116 if desc.out_dim_size < 0 {
117 return Err(Error::InvalidProblem(
118 "baracuda-kernels::ScatterPlan: out_dim_size must be non-negative",
119 ));
120 }
121 for &d in desc.upd_shape.iter() {
122 if d < 0 {
123 return Err(Error::InvalidProblem(
124 "baracuda-kernels::ScatterPlan: upd_shape dims must be non-negative",
125 ));
126 }
127 }
128
129 let supported = matches!(
130 T::KIND,
131 ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
132 );
133 if !supported {
134 return Err(Error::Unsupported(
135 "baracuda-kernels::ScatterPlan: today only `f32`, `f64`, `f16`, `bf16` wired",
136 ));
137 }
138
139 let precision_guarantee = PrecisionGuarantee {
140 math_precision: MathPrecision::F32,
141 accumulator: T::KIND,
142 bit_stable_on_same_hardware: false,
145 deterministic: false,
146 };
147 let sku = KernelSku {
148 category: OpCategory::Indexing,
149 op: IndexingKind::Scatter as u16,
150 element: T::KIND,
151 aux_element: Some(ElementKind::I32),
152 layout: None,
153 epilogue: None,
154 arch: ArchSku::Sm80,
155 backend: BackendKind::Bespoke,
156 precision_guarantee,
157 };
158 Ok(Self {
159 desc: *desc,
160 sku,
161 _marker: PhantomData,
162 })
163 }
164
165 pub fn can_implement<I: IndexElement>(&self, args: &ScatterArgs<'_, T, N, I>) -> Result<()> {
167 if args.updates.shape != self.desc.upd_shape {
168 return Err(Error::InvalidProblem(
169 "baracuda-kernels::ScatterPlan: updates shape mismatch with descriptor",
170 ));
171 }
172 if args.index.shape != self.desc.upd_shape {
173 return Err(Error::InvalidProblem(
174 "baracuda-kernels::ScatterPlan: index shape must equal updates shape",
175 ));
176 }
177 if N > 8 {
178 return Err(Error::Unsupported(
179 "baracuda-kernels::ScatterPlan: tensor rank > 8 not supported",
180 ));
181 }
182 let upd_numel = args.updates.numel();
183 let upd_len = args.updates.data.len() as i64;
184 let idx_len = args.index.data.len() as i64;
185 if upd_len < upd_numel {
186 return Err(Error::BufferTooSmall {
187 needed: upd_numel as usize,
188 got: upd_len as usize,
189 });
190 }
191 if idx_len < upd_numel {
192 return Err(Error::BufferTooSmall {
193 needed: upd_numel as usize,
194 got: idx_len as usize,
195 });
196 }
197 Ok(())
198 }
199
200 #[inline]
202 pub fn workspace_size(&self) -> usize {
203 0
204 }
205
206 #[inline]
208 pub fn sku(&self) -> KernelSku {
209 self.sku
210 }
211
212 #[inline]
214 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
215 self.sku.precision_guarantee
216 }
217
218 pub fn run<I: IndexElement>(
221 &self,
222 stream: &Stream,
223 _workspace: Workspace<'_>,
224 args: ScatterArgs<'_, T, N, I>,
225 ) -> Result<()> {
226 self.can_implement(&args)?;
227 let upd_numel = args.updates.numel();
228 if upd_numel == 0 {
229 return Ok(());
230 }
231 let upd_ptr = args.updates.data.as_raw().0 as *const c_void;
232 let idx_ptr = args.index.data.as_raw().0 as *const c_void;
233 let out_ptr = args.out.data.as_raw().0 as *mut c_void;
234 let stream_ptr = stream.as_raw() as *mut c_void;
235
236 let upd_shape = self.desc.upd_shape;
237 let stride_upd = args.updates.stride;
238 let stride_index = args.index.stride;
239 let stride_out = args.out.stride;
240 let rank = N as i32;
241
242 let status = match (T::KIND, I::KIND) {
243 (ElementKind::F32, IndexElementKind::I32) => unsafe {
244 baracuda_kernels_sys::baracuda_kernels_scatter_f32_run(
245 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
246 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
247 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
248 core::ptr::null_mut(), 0, stream_ptr,
249 )
250 },
251 (ElementKind::F64, IndexElementKind::I32) => unsafe {
252 baracuda_kernels_sys::baracuda_kernels_scatter_f64_run(
253 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
254 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
255 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
256 core::ptr::null_mut(), 0, stream_ptr,
257 )
258 },
259 (ElementKind::F16, IndexElementKind::I32) => unsafe {
260 baracuda_kernels_sys::baracuda_kernels_scatter_f16_run(
261 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
262 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
263 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
264 core::ptr::null_mut(), 0, stream_ptr,
265 )
266 },
267 (ElementKind::Bf16, IndexElementKind::I32) => unsafe {
268 baracuda_kernels_sys::baracuda_kernels_scatter_bf16_run(
269 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
270 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
271 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
272 core::ptr::null_mut(), 0, stream_ptr,
273 )
274 },
275 (ElementKind::F32, IndexElementKind::I64) => unsafe {
276 baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_f32_run(
277 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
278 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
279 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
280 core::ptr::null_mut(), 0, stream_ptr,
281 )
282 },
283 (ElementKind::F64, IndexElementKind::I64) => unsafe {
284 baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_f64_run(
285 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
286 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
287 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
288 core::ptr::null_mut(), 0, stream_ptr,
289 )
290 },
291 (ElementKind::F16, IndexElementKind::I64) => unsafe {
292 baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_f16_run(
293 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
294 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
295 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
296 core::ptr::null_mut(), 0, stream_ptr,
297 )
298 },
299 (ElementKind::Bf16, IndexElementKind::I64) => unsafe {
300 baracuda_kernels_sys::baracuda_kernels_scatter_i64idx_bf16_run(
301 upd_numel, rank, self.desc.scatter_dim, self.desc.out_dim_size,
302 upd_shape.as_ptr(), stride_upd.as_ptr(), stride_index.as_ptr(),
303 stride_out.as_ptr(), upd_ptr, idx_ptr, out_ptr,
304 core::ptr::null_mut(), 0, stream_ptr,
305 )
306 },
307 _ => {
308 return Err(Error::Unsupported(
309 "baracuda-kernels::ScatterPlan::run reached an unimplemented dtype \
310 — select() should have caught this",
311 ));
312 }
313 };
314 map_status(status)
315 }
316}