Skip to main content

baracuda_kernels/segment/
segment_max.rs

//! `segment_max` plan — Category S, sorted variant.
//!
//! `out[s, d] = max_{n : segment_ids[n] == s} input[n, d]`. Requires
//! `segment_ids` to be monotonically non-decreasing. Corresponds to
//! TF / JAX `segment_max`.
//!
//! Empty segments emit zero (the kernel writes a per-op identity
//! sentinel — see `segment_sorted_kernel`).
//!
//! Forward (FW) only. Backward (BW) is deferred — it would need argmax
//! tracking from the forward pass.
11
12use core::marker::PhantomData;
13
14use baracuda_cutlass::Result;
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
18    TensorRef, Workspace,
19};
20
21use super::segment_sum::{
22    build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
23    SortedFwOp,
24};
25
26/// Descriptor for a `segment_max` op.
27#[derive(Copy, Clone, Debug)]
28pub struct SegmentMaxDescriptor {
29    /// Number of input rows.
30    pub num_inputs: i32,
31    /// Embedding / feature dim.
32    pub embedding_dim: i32,
33    /// Total number of segments.
34    pub num_segments: i32,
35    /// Value element type.
36    pub element: ElementKind,
37}
38
39impl SegDescView for SegmentMaxDescriptor {
40    #[inline]
41    fn view(&self) -> (i32, i32, i32, ElementKind) {
42        (
43            self.num_inputs,
44            self.embedding_dim,
45            self.num_segments,
46            self.element,
47        )
48    }
49}
50
51/// Args bundle for a `segment_max` launch.
52pub struct SegmentMaxArgs<'a, T: Element> {
53    /// Input `[N, D]`.
54    pub input: TensorRef<'a, T, 2>,
55    /// Segment ids `[N]` — sorted non-decreasing.
56    pub segment_ids: TensorRef<'a, i32, 1>,
57    /// Output `[num_segments, D]`.
58    pub output: TensorMut<'a, T, 2>,
59}
60
61/// `segment_max` plan (sorted, FW only).
62///
63/// `out[s, d] = max_{n : segment_ids[n] == s} input[n, d]` (TF / JAX
64/// `segment_max`). Requires sorted segment_ids.
65///
66/// **When to use**: forward sorted segment-max. **No BW plan** —
67/// argmax tracking from FW is needed (deferred).
68///
69/// **Dtypes**: `{f32, f64}`.
70///
71/// **Shape limits**: `input` `[N, D]`; `segment_ids` `[N]`;
72/// `output` `[num_segments, D]`.
73///
74/// **Workspace**: none.
75///
76/// **Precision guarantee**: deterministic, bit-stable.
77///
78/// **Index policy**: out-of-range IDs dropped. Empty segments emit
79/// zero (per-op identity sentinel from `segment_sorted_kernel`).
80pub struct SegmentMaxPlan<T: Element> {
81    desc: SegmentMaxDescriptor,
82    sku: KernelSku,
83    _marker: PhantomData<T>,
84}
85
86impl<T: Element> SegmentMaxPlan<T> {
87    /// Pick a kernel.
88    pub fn select(
89        _stream: &Stream,
90        desc: &SegmentMaxDescriptor,
91        _pref: PlanPreference,
92    ) -> Result<Self> {
93        validate_desc(*desc, T::KIND, "SegmentMaxPlan")?;
94        Ok(Self {
95            desc: *desc,
96            sku: build_sku::<T>(SegmentKind::SegmentMax),
97            _marker: PhantomData,
98        })
99    }
100
101    /// Validate args.
102    pub fn can_implement(&self, args: &SegmentMaxArgs<'_, T>) -> Result<()> {
103        let proxy = SegmentSumDescriptor {
104            num_inputs: self.desc.num_inputs,
105            embedding_dim: self.desc.embedding_dim,
106            num_segments: self.desc.num_segments,
107            element: self.desc.element,
108        };
109        validate_args(
110            &proxy,
111            args.input.shape,
112            args.segment_ids.shape,
113            args.output.shape,
114            "SegmentMaxPlan",
115        )
116    }
117
118    /// Workspace size — zero.
119    #[inline]
120    pub fn workspace_size(&self) -> usize {
121        0
122    }
123
124    /// Identity of the kernel this plan picked.
125    #[inline]
126    pub fn sku(&self) -> KernelSku {
127        self.sku
128    }
129
130    /// Numerical guarantees.
131    #[inline]
132    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
133        self.sku.precision_guarantee
134    }
135
136    /// Launch.
137    pub fn run(
138        &self,
139        stream: &Stream,
140        _workspace: Workspace<'_>,
141        args: SegmentMaxArgs<'_, T>,
142    ) -> Result<()> {
143        self.can_implement(&args)?;
144        let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
145        if total_out == 0 {
146            return Ok(());
147        }
148        run_sorted_fw::<T>(
149            stream,
150            self.desc.num_inputs,
151            self.desc.embedding_dim,
152            self.desc.num_segments,
153            &args.input,
154            &args.segment_ids,
155            &args.output,
156            SortedFwOp::Max,
157        )
158    }
159}