Skip to main content

baracuda_kernels/segment/
segment_prod.rs

//! `segment_prod` plan — Category S, sorted variant.
//!
//! `out[s, d] = prod_{n : segment_ids[n] == s} input[n, d]`. Requires
//! `segment_ids` to be monotonically non-decreasing. Matches TF / JAX
//! `segment_prod`.
//!
//! Empty segments emit `1` (the multiplicative identity).
//!
//! Forward (FW) only. Backward (BW) is deferred: the analytic gradient is
//! `d_input[n, d] = d_output[seg, d] * (prod[seg, d] / input[n, d])`,
//! which is unstable when any input is near zero — it would need a
//! safer-divide path or saved running products.
13
14use core::marker::PhantomData;
15
16use baracuda_cutlass::Result;
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
20    TensorRef, Workspace,
21};
22
23use super::segment_sum::{
24    build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
25    SortedFwOp,
26};
27
28/// Descriptor for a `segment_prod` op.
29#[derive(Copy, Clone, Debug)]
30pub struct SegmentProdDescriptor {
31    /// Number of input rows.
32    pub num_inputs: i32,
33    /// Embedding / feature dim.
34    pub embedding_dim: i32,
35    /// Total number of segments.
36    pub num_segments: i32,
37    /// Value element type.
38    pub element: ElementKind,
39}
40
41impl SegDescView for SegmentProdDescriptor {
42    #[inline]
43    fn view(&self) -> (i32, i32, i32, ElementKind) {
44        (
45            self.num_inputs,
46            self.embedding_dim,
47            self.num_segments,
48            self.element,
49        )
50    }
51}
52
53/// Args bundle for a `segment_prod` launch.
54pub struct SegmentProdArgs<'a, T: Element> {
55    /// Input `[N, D]`.
56    pub input: TensorRef<'a, T, 2>,
57    /// Segment ids `[N]` — sorted non-decreasing.
58    pub segment_ids: TensorRef<'a, i32, 1>,
59    /// Output `[num_segments, D]`.
60    pub output: TensorMut<'a, T, 2>,
61}
62
63/// `segment_prod` plan (sorted, FW only).
64///
65/// `out[s, d] = Π_{n : segment_ids[n] == s} input[n, d]`. Sorted
66/// segment-product.
67///
68/// **When to use**: forward sorted segment-product. **No BW plan** —
69/// would need numerically-stable `prod / x_n` per-input divides
70/// (deferred).
71///
72/// **Dtypes**: `{f32, f64}`.
73///
74/// **Shape limits**: `input` `[N, D]`; `segment_ids` `[N]`;
75/// `output` `[num_segments, D]`.
76///
77/// **Workspace**: none.
78///
79/// **Precision guarantee**: deterministic, bit-stable.
80///
81/// **Index policy**: out-of-range IDs dropped. Empty segments emit
82/// identity (1.0).
83pub struct SegmentProdPlan<T: Element> {
84    desc: SegmentProdDescriptor,
85    sku: KernelSku,
86    _marker: PhantomData<T>,
87}
88
89impl<T: Element> SegmentProdPlan<T> {
90    /// Pick a kernel.
91    pub fn select(
92        _stream: &Stream,
93        desc: &SegmentProdDescriptor,
94        _pref: PlanPreference,
95    ) -> Result<Self> {
96        validate_desc(*desc, T::KIND, "SegmentProdPlan")?;
97        Ok(Self {
98            desc: *desc,
99            sku: build_sku::<T>(SegmentKind::SegmentProd),
100            _marker: PhantomData,
101        })
102    }
103
104    /// Validate args.
105    pub fn can_implement(&self, args: &SegmentProdArgs<'_, T>) -> Result<()> {
106        let proxy = SegmentSumDescriptor {
107            num_inputs: self.desc.num_inputs,
108            embedding_dim: self.desc.embedding_dim,
109            num_segments: self.desc.num_segments,
110            element: self.desc.element,
111        };
112        validate_args(
113            &proxy,
114            args.input.shape,
115            args.segment_ids.shape,
116            args.output.shape,
117            "SegmentProdPlan",
118        )
119    }
120
121    /// Workspace size — zero.
122    #[inline]
123    pub fn workspace_size(&self) -> usize {
124        0
125    }
126
127    /// Identity of the kernel.
128    #[inline]
129    pub fn sku(&self) -> KernelSku {
130        self.sku
131    }
132
133    /// Numerical guarantees.
134    #[inline]
135    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
136        self.sku.precision_guarantee
137    }
138
139    /// Launch.
140    pub fn run(
141        &self,
142        stream: &Stream,
143        _workspace: Workspace<'_>,
144        args: SegmentProdArgs<'_, T>,
145    ) -> Result<()> {
146        self.can_implement(&args)?;
147        let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
148        if total_out == 0 {
149            return Ok(());
150        }
151        run_sorted_fw::<T>(
152            stream,
153            self.desc.num_inputs,
154            self.desc.embedding_dim,
155            self.desc.num_segments,
156            &args.input,
157            &args.segment_ids,
158            &args.output,
159            SortedFwOp::Prod,
160        )
161    }
162}