1use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
18 PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut, TensorRef, Workspace,
19};
20
21use super::map_status;
22
23#[derive(Copy, Clone, Debug)]
25pub struct SegmentSumDescriptor {
26 pub num_inputs: i32,
28 pub embedding_dim: i32,
30 pub num_segments: i32,
33 pub element: ElementKind,
35}
36
37pub struct SegmentSumArgs<'a, T: Element> {
39 pub input: TensorRef<'a, T, 2>,
41 pub segment_ids: TensorRef<'a, i32, 1>,
44 pub output: TensorMut<'a, T, 2>,
47}
48
49pub struct SegmentSumPlan<T: Element> {
77 desc: SegmentSumDescriptor,
78 sku: KernelSku,
79 _marker: PhantomData<T>,
80}
81
82impl<T: Element> SegmentSumPlan<T> {
83 pub fn select(
85 _stream: &Stream,
86 desc: &SegmentSumDescriptor,
87 _pref: PlanPreference,
88 ) -> Result<Self> {
89 validate_desc(*desc, T::KIND, "SegmentSumPlan")?;
90 let sku = build_sku::<T>(SegmentKind::SegmentSum);
91 Ok(Self {
92 desc: *desc,
93 sku,
94 _marker: PhantomData,
95 })
96 }
97
98 pub fn can_implement(&self, args: &SegmentSumArgs<'_, T>) -> Result<()> {
100 validate_args(
101 &self.desc,
102 args.input.shape,
103 args.segment_ids.shape,
104 args.output.shape,
105 "SegmentSumPlan",
106 )
107 }
108
109 #[inline]
111 pub fn workspace_size(&self) -> usize {
112 0
113 }
114
115 #[inline]
117 pub fn sku(&self) -> KernelSku {
118 self.sku
119 }
120
121 #[inline]
123 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
124 self.sku.precision_guarantee
125 }
126
127 pub fn run(
129 &self,
130 stream: &Stream,
131 _workspace: Workspace<'_>,
132 args: SegmentSumArgs<'_, T>,
133 ) -> Result<()> {
134 self.can_implement(&args)?;
135 let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
136 if total_out == 0 {
137 return Ok(());
138 }
139 run_sorted_fw::<T>(
140 stream,
141 self.desc.num_inputs,
142 self.desc.embedding_dim,
143 self.desc.num_segments,
144 &args.input,
145 &args.segment_ids,
146 &args.output,
147 SortedFwOp::Sum,
148 )
149 }
150}
151
152pub(crate) fn validate_desc(
156 desc_num_inputs_dim_seg: impl SegDescView,
157 expected_element: ElementKind,
158 _plan_name: &'static str,
159) -> Result<()> {
160 let (n, d, ns, el) = desc_num_inputs_dim_seg.view();
161 if el != expected_element {
162 return Err(Error::Unsupported(
163 "baracuda-kernels::segment: descriptor element != type parameter T",
164 ));
165 }
166 if n < 0 || d < 0 || ns < 0 {
167 return Err(Error::InvalidProblem(
168 "baracuda-kernels::segment: num_inputs / embedding_dim / num_segments must be non-negative",
169 ));
170 }
171 if !matches!(el, ElementKind::F32 | ElementKind::F64) {
172 return Err(Error::Unsupported(
173 "baracuda-kernels::segment: today only f32, f64 wired (atomicAdd / atomic-CAS restricted to native-FP-atomic types)",
174 ));
175 }
176 Ok(())
177}
178
179pub(crate) trait SegDescView {
183 fn view(&self) -> (i32, i32, i32, ElementKind);
184}
185
186impl SegDescView for SegmentSumDescriptor {
187 #[inline]
188 fn view(&self) -> (i32, i32, i32, ElementKind) {
189 (
190 self.num_inputs,
191 self.embedding_dim,
192 self.num_segments,
193 self.element,
194 )
195 }
196}
197
198pub(crate) fn validate_args(
200 desc: &SegmentSumDescriptor,
201 input_shape: [i32; 2],
202 seg_shape: [i32; 1],
203 output_shape: [i32; 2],
204 _plan_name: &'static str,
205) -> Result<()> {
206 if input_shape != [desc.num_inputs, desc.embedding_dim] {
207 return Err(Error::InvalidProblem(
208 "baracuda-kernels::segment: input shape != [num_inputs, embedding_dim]",
209 ));
210 }
211 if seg_shape != [desc.num_inputs] {
212 return Err(Error::InvalidProblem(
213 "baracuda-kernels::segment: segment_ids shape != [num_inputs]",
214 ));
215 }
216 if output_shape != [desc.num_segments, desc.embedding_dim] {
217 return Err(Error::InvalidProblem(
218 "baracuda-kernels::segment: output shape != [num_segments, embedding_dim]",
219 ));
220 }
221 Ok(())
222}
223
224pub(crate) fn build_sku<T: Element>(op: SegmentKind) -> KernelSku {
226 let precision_guarantee = PrecisionGuarantee {
227 math_precision: if T::KIND == ElementKind::F64 {
228 MathPrecision::F64
229 } else {
230 MathPrecision::F32
231 },
232 accumulator: T::KIND,
233 bit_stable_on_same_hardware: matches!(
238 op,
239 SegmentKind::SegmentSum
240 | SegmentKind::SegmentMean
241 | SegmentKind::SegmentMax
242 | SegmentKind::SegmentMin
243 | SegmentKind::SegmentProd
244 | SegmentKind::SegmentSumBackward
245 | SegmentKind::SegmentMeanBackward
246 | SegmentKind::UnsortedSegmentSumBackward
247 | SegmentKind::UnsortedSegmentMeanBackward
248 ),
249 deterministic: matches!(
250 op,
251 SegmentKind::SegmentSum
252 | SegmentKind::SegmentMean
253 | SegmentKind::SegmentMax
254 | SegmentKind::SegmentMin
255 | SegmentKind::SegmentProd
256 | SegmentKind::SegmentSumBackward
257 | SegmentKind::SegmentMeanBackward
258 | SegmentKind::UnsortedSegmentSumBackward
259 | SegmentKind::UnsortedSegmentMeanBackward
260 ),
261 };
262 KernelSku {
263 category: OpCategory::SegmentOps,
264 op: op as u16,
265 element: T::KIND,
266 aux_element: Some(ElementKind::I32),
267 layout: None,
268 epilogue: None,
269 arch: ArchSku::Sm80,
270 backend: BackendKind::Bespoke,
271 precision_guarantee,
272 }
273}
274
/// Selects which forward reduction `run_sorted_fw` dispatches to.
#[derive(Copy, Clone, Debug)]
pub(crate) enum SortedFwOp {
    /// Dispatches to the `baracuda_kernels_segment_sum_*_run` entry points.
    Sum,
    /// Dispatches to the `baracuda_kernels_segment_mean_*_run` entry points.
    Mean,
    /// Dispatches to the `baracuda_kernels_segment_max_*_run` entry points.
    Max,
    /// Dispatches to the `baracuda_kernels_segment_min_*_run` entry points.
    Min,
    /// Dispatches to the `baracuda_kernels_segment_prod_*_run` entry points.
    Prod,
}
284
285pub(crate) fn run_sorted_fw<T: Element>(
287 stream: &Stream,
288 n: i32,
289 d: i32,
290 num_segments: i32,
291 input: &TensorRef<'_, T, 2>,
292 segment_ids: &TensorRef<'_, i32, 1>,
293 output: &TensorMut<'_, T, 2>,
294 op: SortedFwOp,
295) -> Result<()> {
296 let in_ptr = input.data.as_raw().0 as *const c_void;
297 let id_ptr = segment_ids.data.as_raw().0 as *const c_void;
298 let out_ptr = output.data.as_raw().0 as *mut c_void;
299 let stream_ptr = stream.as_raw() as *mut c_void;
300
301 let status = match (T::KIND, op) {
302 (ElementKind::F32, SortedFwOp::Sum) => unsafe {
303 baracuda_kernels_sys::baracuda_kernels_segment_sum_f32_run(
304 n, d, num_segments, in_ptr, id_ptr, out_ptr,
305 core::ptr::null_mut(), 0, stream_ptr,
306 )
307 },
308 (ElementKind::F64, SortedFwOp::Sum) => unsafe {
309 baracuda_kernels_sys::baracuda_kernels_segment_sum_f64_run(
310 n, d, num_segments, in_ptr, id_ptr, out_ptr,
311 core::ptr::null_mut(), 0, stream_ptr,
312 )
313 },
314 (ElementKind::F32, SortedFwOp::Mean) => unsafe {
315 baracuda_kernels_sys::baracuda_kernels_segment_mean_f32_run(
316 n, d, num_segments, in_ptr, id_ptr, out_ptr,
317 core::ptr::null_mut(), 0, stream_ptr,
318 )
319 },
320 (ElementKind::F64, SortedFwOp::Mean) => unsafe {
321 baracuda_kernels_sys::baracuda_kernels_segment_mean_f64_run(
322 n, d, num_segments, in_ptr, id_ptr, out_ptr,
323 core::ptr::null_mut(), 0, stream_ptr,
324 )
325 },
326 (ElementKind::F32, SortedFwOp::Max) => unsafe {
327 baracuda_kernels_sys::baracuda_kernels_segment_max_f32_run(
328 n, d, num_segments, in_ptr, id_ptr, out_ptr,
329 core::ptr::null_mut(), 0, stream_ptr,
330 )
331 },
332 (ElementKind::F64, SortedFwOp::Max) => unsafe {
333 baracuda_kernels_sys::baracuda_kernels_segment_max_f64_run(
334 n, d, num_segments, in_ptr, id_ptr, out_ptr,
335 core::ptr::null_mut(), 0, stream_ptr,
336 )
337 },
338 (ElementKind::F32, SortedFwOp::Min) => unsafe {
339 baracuda_kernels_sys::baracuda_kernels_segment_min_f32_run(
340 n, d, num_segments, in_ptr, id_ptr, out_ptr,
341 core::ptr::null_mut(), 0, stream_ptr,
342 )
343 },
344 (ElementKind::F64, SortedFwOp::Min) => unsafe {
345 baracuda_kernels_sys::baracuda_kernels_segment_min_f64_run(
346 n, d, num_segments, in_ptr, id_ptr, out_ptr,
347 core::ptr::null_mut(), 0, stream_ptr,
348 )
349 },
350 (ElementKind::F32, SortedFwOp::Prod) => unsafe {
351 baracuda_kernels_sys::baracuda_kernels_segment_prod_f32_run(
352 n, d, num_segments, in_ptr, id_ptr, out_ptr,
353 core::ptr::null_mut(), 0, stream_ptr,
354 )
355 },
356 (ElementKind::F64, SortedFwOp::Prod) => unsafe {
357 baracuda_kernels_sys::baracuda_kernels_segment_prod_f64_run(
358 n, d, num_segments, in_ptr, id_ptr, out_ptr,
359 core::ptr::null_mut(), 0, stream_ptr,
360 )
361 },
362 _ => {
363 return Err(Error::Unsupported(
364 "baracuda-kernels::segment::run_sorted_fw reached an unimplemented dtype \
365 — select() should have caught this",
366 ));
367 }
368 };
369 map_status(status)
370}