Skip to main content

baracuda_kernels/segment/
segment_sum.rs

1//! `segment_sum` plan — Category S, sorted variant.
2//!
3//! `out[s, d] = Σ_{n : segment_ids[n] == s} input[n, d]`. Requires
4//! `segment_ids` to be monotonically non-decreasing. TF / JAX
5//! `segment_sum`.
6//!
7//! Trailblazer dtype coverage: `f32, f64`.
8//!
9//! BW: see [`crate::segment::SegmentSumBackwardPlan`].
10
11use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
18    PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut, TensorRef, Workspace,
19};
20
21use super::map_status;
22
23/// Descriptor for a `segment_sum` op.
24#[derive(Copy, Clone, Debug)]
25pub struct SegmentSumDescriptor {
26    /// Number of input rows (length of `segment_ids`).
27    pub num_inputs: i32,
28    /// Embedding / feature dim — second axis of `input` and `output`.
29    pub embedding_dim: i32,
30    /// Total number of segments — first axis of `output`. Output is
31    /// allocated for this many rows even when some segments are empty.
32    pub num_segments: i32,
33    /// Value element type.
34    pub element: ElementKind,
35}
36
37/// Args bundle for a `segment_sum` launch.
38pub struct SegmentSumArgs<'a, T: Element> {
39    /// Input `[N, D]`.
40    pub input: TensorRef<'a, T, 2>,
41    /// Segment ids `[N]`, i32, sorted non-decreasing, values in
42    /// `[0, num_segments)`.
43    pub segment_ids: TensorRef<'a, i32, 1>,
44    /// Output `[num_segments, D]`. Overwritten by the launch — no
45    /// accumulation into pre-existing state.
46    pub output: TensorMut<'a, T, 2>,
47}
48
49/// `segment_sum` plan (sorted).
50///
51/// `out[s, d] = Σ_{n : segment_ids[n] == s} input[n, d]` (TF / JAX
52/// `segment_sum`). Requires `segment_ids` to be monotonically
53/// non-decreasing.
54///
55/// **When to use**: forward sorted segment-sum. For unsorted IDs use
56/// [`UnsortedSegmentSumPlan`](crate::UnsortedSegmentSumPlan). Pair
57/// with [`SegmentSumBackwardPlan`](crate::SegmentSumBackwardPlan)
58/// for autograd.
59///
60/// **Dtypes**: `{f32, f64}` (matches the family — kernels rely on
61/// FP atomic primitives even in the sorted variant for some paths).
62///
63/// **Shape limits**: `input` is `[N, D]`, `segment_ids` is `[N]`
64/// with values in `[0, num_segments)`; `output` is `[num_segments, D]`.
65/// All extents non-negative.
66///
67/// **Workspace**: none.
68///
69/// **Precision guarantee**: **deterministic, bit-stable** — single
70/// thread per output cell sweeps the segment's row range in order.
71///
72/// **Index policy**: out-of-range segment IDs (`< 0` or
73/// `≥ num_segments`) are silently dropped (TF / JAX semantic).
74/// Output buffer is fully overwritten (no accumulation into prior
75/// state).
76pub struct SegmentSumPlan<T: Element> {
77    desc: SegmentSumDescriptor,
78    sku: KernelSku,
79    _marker: PhantomData<T>,
80}
81
82impl<T: Element> SegmentSumPlan<T> {
83    /// Pick a kernel for `desc`.
84    pub fn select(
85        _stream: &Stream,
86        desc: &SegmentSumDescriptor,
87        _pref: PlanPreference,
88    ) -> Result<Self> {
89        validate_desc(*desc, T::KIND, "SegmentSumPlan")?;
90        let sku = build_sku::<T>(SegmentKind::SegmentSum);
91        Ok(Self {
92            desc: *desc,
93            sku,
94            _marker: PhantomData,
95        })
96    }
97
98    /// Validate args.
99    pub fn can_implement(&self, args: &SegmentSumArgs<'_, T>) -> Result<()> {
100        validate_args(
101            &self.desc,
102            args.input.shape,
103            args.segment_ids.shape,
104            args.output.shape,
105            "SegmentSumPlan",
106        )
107    }
108
109    /// Workspace size in bytes.
110    #[inline]
111    pub fn workspace_size(&self) -> usize {
112        0
113    }
114
115    /// Identity of the kernel this plan picked.
116    #[inline]
117    pub fn sku(&self) -> KernelSku {
118        self.sku
119    }
120
121    /// Numerical guarantees for this plan's kernel.
122    #[inline]
123    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
124        self.sku.precision_guarantee
125    }
126
127    /// Launch.
128    pub fn run(
129        &self,
130        stream: &Stream,
131        _workspace: Workspace<'_>,
132        args: SegmentSumArgs<'_, T>,
133    ) -> Result<()> {
134        self.can_implement(&args)?;
135        let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
136        if total_out == 0 {
137            return Ok(());
138        }
139        run_sorted_fw::<T>(
140            stream,
141            self.desc.num_inputs,
142            self.desc.embedding_dim,
143            self.desc.num_segments,
144            &args.input,
145            &args.segment_ids,
146            &args.output,
147            SortedFwOp::Sum,
148        )
149    }
150}
151
152/// Validate descriptor fields shared across the sorted-family plans.
153/// The `_plan_name` parameter is unused today; reserved for richer
154/// error messages without churning every call site when we wire it in.
155pub(crate) fn validate_desc(
156    desc_num_inputs_dim_seg: impl SegDescView,
157    expected_element: ElementKind,
158    _plan_name: &'static str,
159) -> Result<()> {
160    let (n, d, ns, el) = desc_num_inputs_dim_seg.view();
161    if el != expected_element {
162        return Err(Error::Unsupported(
163            "baracuda-kernels::segment: descriptor element != type parameter T",
164        ));
165    }
166    if n < 0 || d < 0 || ns < 0 {
167        return Err(Error::InvalidProblem(
168            "baracuda-kernels::segment: num_inputs / embedding_dim / num_segments must be non-negative",
169        ));
170    }
171    if !matches!(el, ElementKind::F32 | ElementKind::F64) {
172        return Err(Error::Unsupported(
173            "baracuda-kernels::segment: today only f32, f64 wired (atomicAdd / atomic-CAS restricted to native-FP-atomic types)",
174        ));
175    }
176    Ok(())
177}
178
179/// Trait abstracting the four descriptor fields shared by every sorted
180/// + unsorted segment plan. Lets `validate_desc` accept any descriptor
181/// without forcing a concrete type.
182pub(crate) trait SegDescView {
183    fn view(&self) -> (i32, i32, i32, ElementKind);
184}
185
186impl SegDescView for SegmentSumDescriptor {
187    #[inline]
188    fn view(&self) -> (i32, i32, i32, ElementKind) {
189        (
190            self.num_inputs,
191            self.embedding_dim,
192            self.num_segments,
193            self.element,
194        )
195    }
196}
197
198/// Validate args shared across the sorted-family FW plans.
199pub(crate) fn validate_args(
200    desc: &SegmentSumDescriptor,
201    input_shape: [i32; 2],
202    seg_shape: [i32; 1],
203    output_shape: [i32; 2],
204    _plan_name: &'static str,
205) -> Result<()> {
206    if input_shape != [desc.num_inputs, desc.embedding_dim] {
207        return Err(Error::InvalidProblem(
208            "baracuda-kernels::segment: input shape != [num_inputs, embedding_dim]",
209        ));
210    }
211    if seg_shape != [desc.num_inputs] {
212        return Err(Error::InvalidProblem(
213            "baracuda-kernels::segment: segment_ids shape != [num_inputs]",
214        ));
215    }
216    if output_shape != [desc.num_segments, desc.embedding_dim] {
217        return Err(Error::InvalidProblem(
218            "baracuda-kernels::segment: output shape != [num_segments, embedding_dim]",
219        ));
220    }
221    Ok(())
222}
223
224/// Construct a `KernelSku` for the segment-family plan.
225pub(crate) fn build_sku<T: Element>(op: SegmentKind) -> KernelSku {
226    let precision_guarantee = PrecisionGuarantee {
227        math_precision: if T::KIND == ElementKind::F64 {
228            MathPrecision::F64
229        } else {
230            MathPrecision::F32
231        },
232        accumulator: T::KIND,
233        // Sorted: deterministic (single thread per output cell, in-order
234        // sweep). Unsorted: atomic accumulation → not deterministic.
235        // We set conservative defaults here and let unsorted plans
236        // re-tag via their own builder when they need to differ.
237        bit_stable_on_same_hardware: matches!(
238            op,
239            SegmentKind::SegmentSum
240                | SegmentKind::SegmentMean
241                | SegmentKind::SegmentMax
242                | SegmentKind::SegmentMin
243                | SegmentKind::SegmentProd
244                | SegmentKind::SegmentSumBackward
245                | SegmentKind::SegmentMeanBackward
246                | SegmentKind::UnsortedSegmentSumBackward
247                | SegmentKind::UnsortedSegmentMeanBackward
248        ),
249        deterministic: matches!(
250            op,
251            SegmentKind::SegmentSum
252                | SegmentKind::SegmentMean
253                | SegmentKind::SegmentMax
254                | SegmentKind::SegmentMin
255                | SegmentKind::SegmentProd
256                | SegmentKind::SegmentSumBackward
257                | SegmentKind::SegmentMeanBackward
258                | SegmentKind::UnsortedSegmentSumBackward
259                | SegmentKind::UnsortedSegmentMeanBackward
260        ),
261    };
262    KernelSku {
263        category: OpCategory::SegmentOps,
264        op: op as u16,
265        element: T::KIND,
266        aux_element: Some(ElementKind::I32),
267        layout: None,
268        epilogue: None,
269        arch: ArchSku::Sm80,
270        backend: BackendKind::Bespoke,
271        precision_guarantee,
272    }
273}
274
/// Which sorted forward reduction to launch. `run_sorted_fw` combines
/// this tag with the element kind to choose the launcher symbol.
#[derive(Copy, Clone, Debug)]
pub(crate) enum SortedFwOp {
    /// Dispatches to the `segment_sum` launchers.
    Sum,
    /// Dispatches to the `segment_mean` launchers.
    Mean,
    /// Dispatches to the `segment_max` launchers.
    Max,
    /// Dispatches to the `segment_min` launchers.
    Min,
    /// Dispatches to the `segment_prod` launchers.
    Prod,
}
284
285/// Shared sorted-FW launch helper.
286pub(crate) fn run_sorted_fw<T: Element>(
287    stream: &Stream,
288    n: i32,
289    d: i32,
290    num_segments: i32,
291    input: &TensorRef<'_, T, 2>,
292    segment_ids: &TensorRef<'_, i32, 1>,
293    output: &TensorMut<'_, T, 2>,
294    op: SortedFwOp,
295) -> Result<()> {
296    let in_ptr = input.data.as_raw().0 as *const c_void;
297    let id_ptr = segment_ids.data.as_raw().0 as *const c_void;
298    let out_ptr = output.data.as_raw().0 as *mut c_void;
299    let stream_ptr = stream.as_raw() as *mut c_void;
300
301    let status = match (T::KIND, op) {
302        (ElementKind::F32, SortedFwOp::Sum) => unsafe {
303            baracuda_kernels_sys::baracuda_kernels_segment_sum_f32_run(
304                n, d, num_segments, in_ptr, id_ptr, out_ptr,
305                core::ptr::null_mut(), 0, stream_ptr,
306            )
307        },
308        (ElementKind::F64, SortedFwOp::Sum) => unsafe {
309            baracuda_kernels_sys::baracuda_kernels_segment_sum_f64_run(
310                n, d, num_segments, in_ptr, id_ptr, out_ptr,
311                core::ptr::null_mut(), 0, stream_ptr,
312            )
313        },
314        (ElementKind::F32, SortedFwOp::Mean) => unsafe {
315            baracuda_kernels_sys::baracuda_kernels_segment_mean_f32_run(
316                n, d, num_segments, in_ptr, id_ptr, out_ptr,
317                core::ptr::null_mut(), 0, stream_ptr,
318            )
319        },
320        (ElementKind::F64, SortedFwOp::Mean) => unsafe {
321            baracuda_kernels_sys::baracuda_kernels_segment_mean_f64_run(
322                n, d, num_segments, in_ptr, id_ptr, out_ptr,
323                core::ptr::null_mut(), 0, stream_ptr,
324            )
325        },
326        (ElementKind::F32, SortedFwOp::Max) => unsafe {
327            baracuda_kernels_sys::baracuda_kernels_segment_max_f32_run(
328                n, d, num_segments, in_ptr, id_ptr, out_ptr,
329                core::ptr::null_mut(), 0, stream_ptr,
330            )
331        },
332        (ElementKind::F64, SortedFwOp::Max) => unsafe {
333            baracuda_kernels_sys::baracuda_kernels_segment_max_f64_run(
334                n, d, num_segments, in_ptr, id_ptr, out_ptr,
335                core::ptr::null_mut(), 0, stream_ptr,
336            )
337        },
338        (ElementKind::F32, SortedFwOp::Min) => unsafe {
339            baracuda_kernels_sys::baracuda_kernels_segment_min_f32_run(
340                n, d, num_segments, in_ptr, id_ptr, out_ptr,
341                core::ptr::null_mut(), 0, stream_ptr,
342            )
343        },
344        (ElementKind::F64, SortedFwOp::Min) => unsafe {
345            baracuda_kernels_sys::baracuda_kernels_segment_min_f64_run(
346                n, d, num_segments, in_ptr, id_ptr, out_ptr,
347                core::ptr::null_mut(), 0, stream_ptr,
348            )
349        },
350        (ElementKind::F32, SortedFwOp::Prod) => unsafe {
351            baracuda_kernels_sys::baracuda_kernels_segment_prod_f32_run(
352                n, d, num_segments, in_ptr, id_ptr, out_ptr,
353                core::ptr::null_mut(), 0, stream_ptr,
354            )
355        },
356        (ElementKind::F64, SortedFwOp::Prod) => unsafe {
357            baracuda_kernels_sys::baracuda_kernels_segment_prod_f64_run(
358                n, d, num_segments, in_ptr, id_ptr, out_ptr,
359                core::ptr::null_mut(), 0, stream_ptr,
360            )
361        },
362        _ => {
363            return Err(Error::Unsupported(
364                "baracuda-kernels::segment::run_sorted_fw reached an unimplemented dtype \
365                 — select() should have caught this",
366            ));
367        }
368    };
369    map_status(status)
370}