// Source: baracuda_kernels/segment/segment_mean.rs

//! `segment_mean` plan — Category S, sorted variant.
//!
//! `out[s, d] = mean_{n : segment_ids[n] == s} input[n, d]`. Requires
//! `segment_ids` to be monotonically non-decreasing. Matches TF / JAX
//! `segment_mean`.
//!
//! Empty segments emit zero (no NaN — division is guarded inside the
//! kernel). Backward pass: see [`crate::segment::SegmentMeanBackwardPlan`].

use core::marker::PhantomData;

use baracuda_cutlass::Result;
use baracuda_driver::Stream;
use baracuda_kernels_types::{
    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
    TensorRef, Workspace,
};

use super::segment_sum::{
    build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
    SortedFwOp,
};

24/// Descriptor for a `segment_mean` op.
25#[derive(Copy, Clone, Debug)]
26pub struct SegmentMeanDescriptor {
27    /// Number of input rows.
28    pub num_inputs: i32,
29    /// Embedding / feature dim.
30    pub embedding_dim: i32,
31    /// Total number of segments.
32    pub num_segments: i32,
33    /// Value element type.
34    pub element: ElementKind,
35}
36
37impl SegDescView for SegmentMeanDescriptor {
38    #[inline]
39    fn view(&self) -> (i32, i32, i32, ElementKind) {
40        (
41            self.num_inputs,
42            self.embedding_dim,
43            self.num_segments,
44            self.element,
45        )
46    }
47}
48
49/// Args bundle for a `segment_mean` launch.
50pub struct SegmentMeanArgs<'a, T: Element> {
51    /// Input `[N, D]`.
52    pub input: TensorRef<'a, T, 2>,
53    /// Segment ids `[N]` — sorted non-decreasing.
54    pub segment_ids: TensorRef<'a, i32, 1>,
55    /// Output `[num_segments, D]`.
56    pub output: TensorMut<'a, T, 2>,
57}
58
59/// `segment_mean` plan (sorted).
60///
61/// `out[s, d] = mean_{n : segment_ids[n] == s} input[n, d]` (TF / JAX
62/// `segment_mean`). Requires `segment_ids` monotonically non-decreasing.
63///
64/// **When to use**: forward sorted segment-mean. Pair with
65/// [`SegmentMeanBackwardPlan`](crate::SegmentMeanBackwardPlan).
66///
67/// **Dtypes**: `{f32, f64}`.
68///
69/// **Shape limits**: `input` `[N, D]`; `segment_ids` `[N]` with values
70/// in `[0, num_segments)`; `output` `[num_segments, D]`.
71///
72/// **Workspace**: none — segment counts derived inline via binary
73/// search.
74///
75/// **Precision guarantee**: deterministic, bit-stable.
76///
77/// **Index policy**: out-of-range IDs dropped. Empty segments emit
78/// zero (division is guarded; no NaN).
79pub struct SegmentMeanPlan<T: Element> {
80    desc: SegmentMeanDescriptor,
81    sku: KernelSku,
82    _marker: PhantomData<T>,
83}
84
85impl<T: Element> SegmentMeanPlan<T> {
86    /// Pick a kernel.
87    pub fn select(
88        _stream: &Stream,
89        desc: &SegmentMeanDescriptor,
90        _pref: PlanPreference,
91    ) -> Result<Self> {
92        validate_desc(*desc, T::KIND, "SegmentMeanPlan")?;
93        Ok(Self {
94            desc: *desc,
95            sku: build_sku::<T>(SegmentKind::SegmentMean),
96            _marker: PhantomData,
97        })
98    }
99
100    /// Validate args.
101    pub fn can_implement(&self, args: &SegmentMeanArgs<'_, T>) -> Result<()> {
102        let proxy = SegmentSumDescriptor {
103            num_inputs: self.desc.num_inputs,
104            embedding_dim: self.desc.embedding_dim,
105            num_segments: self.desc.num_segments,
106            element: self.desc.element,
107        };
108        validate_args(
109            &proxy,
110            args.input.shape,
111            args.segment_ids.shape,
112            args.output.shape,
113            "SegmentMeanPlan",
114        )
115    }
116
117    /// Workspace size — zero (count computed inline via binary search).
118    #[inline]
119    pub fn workspace_size(&self) -> usize {
120        0
121    }
122
123    /// Identity of the kernel this plan picked.
124    #[inline]
125    pub fn sku(&self) -> KernelSku {
126        self.sku
127    }
128
129    /// Numerical guarantees.
130    #[inline]
131    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
132        self.sku.precision_guarantee
133    }
134
135    /// Launch.
136    pub fn run(
137        &self,
138        stream: &Stream,
139        _workspace: Workspace<'_>,
140        args: SegmentMeanArgs<'_, T>,
141    ) -> Result<()> {
142        self.can_implement(&args)?;
143        let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
144        if total_out == 0 {
145            return Ok(());
146        }
147        run_sorted_fw::<T>(
148            stream,
149            self.desc.num_inputs,
150            self.desc.embedding_dim,
151            self.desc.num_segments,
152            &args.input,
153            &args.segment_ids,
154            &args.output,
155            SortedFwOp::Mean,
156        )
157    }
158}