Skip to main content

baracuda_kernels/segment/
segment_min.rs

//! `segment_min` plan — Category S, sorted variant.
//!
//! `out[s, d] = min_{n : segment_ids[n] == s} input[n, d]`. Requires
//! `segment_ids` to be monotonically non-decreasing. Counterpart of the
//! TF / JAX `segment_min` op.
//!
//! Forward (FW) pass only. The backward (BW) pass is deferred because it
//! needs argmin tracking.
8
9use core::marker::PhantomData;
10
11use baracuda_cutlass::Result;
12use baracuda_driver::Stream;
13use baracuda_kernels_types::{
14    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
15    TensorRef, Workspace,
16};
17
18use super::segment_sum::{
19    build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
20    SortedFwOp,
21};
22
23/// Descriptor for a `segment_min` op.
24#[derive(Copy, Clone, Debug)]
25pub struct SegmentMinDescriptor {
26    /// Number of input rows.
27    pub num_inputs: i32,
28    /// Embedding / feature dim.
29    pub embedding_dim: i32,
30    /// Total number of segments.
31    pub num_segments: i32,
32    /// Value element type.
33    pub element: ElementKind,
34}
35
36impl SegDescView for SegmentMinDescriptor {
37    #[inline]
38    fn view(&self) -> (i32, i32, i32, ElementKind) {
39        (
40            self.num_inputs,
41            self.embedding_dim,
42            self.num_segments,
43            self.element,
44        )
45    }
46}
47
48/// Args bundle for a `segment_min` launch.
49pub struct SegmentMinArgs<'a, T: Element> {
50    /// Input `[N, D]`.
51    pub input: TensorRef<'a, T, 2>,
52    /// Segment ids `[N]` — sorted non-decreasing.
53    pub segment_ids: TensorRef<'a, i32, 1>,
54    /// Output `[num_segments, D]`.
55    pub output: TensorMut<'a, T, 2>,
56}
57
58/// `segment_min` plan (sorted, FW only).
59///
60/// `out[s, d] = min_{n : segment_ids[n] == s} input[n, d]`. Sorted
61/// segment-min mirror of [`SegmentMaxPlan`](crate::SegmentMaxPlan).
62///
63/// **When to use**: forward sorted segment-min. **No BW plan** —
64/// argmin tracking deferred.
65///
66/// **Dtypes**: `{f32, f64}`.
67///
68/// **Shape limits**: `input` `[N, D]`; `segment_ids` `[N]`;
69/// `output` `[num_segments, D]`.
70///
71/// **Workspace**: none.
72///
73/// **Precision guarantee**: deterministic, bit-stable.
74///
75/// **Index policy**: out-of-range IDs dropped. Empty segments emit
76/// the per-op identity sentinel.
77pub struct SegmentMinPlan<T: Element> {
78    desc: SegmentMinDescriptor,
79    sku: KernelSku,
80    _marker: PhantomData<T>,
81}
82
83impl<T: Element> SegmentMinPlan<T> {
84    /// Pick a kernel.
85    pub fn select(
86        _stream: &Stream,
87        desc: &SegmentMinDescriptor,
88        _pref: PlanPreference,
89    ) -> Result<Self> {
90        validate_desc(*desc, T::KIND, "SegmentMinPlan")?;
91        Ok(Self {
92            desc: *desc,
93            sku: build_sku::<T>(SegmentKind::SegmentMin),
94            _marker: PhantomData,
95        })
96    }
97
98    /// Validate args.
99    pub fn can_implement(&self, args: &SegmentMinArgs<'_, T>) -> Result<()> {
100        let proxy = SegmentSumDescriptor {
101            num_inputs: self.desc.num_inputs,
102            embedding_dim: self.desc.embedding_dim,
103            num_segments: self.desc.num_segments,
104            element: self.desc.element,
105        };
106        validate_args(
107            &proxy,
108            args.input.shape,
109            args.segment_ids.shape,
110            args.output.shape,
111            "SegmentMinPlan",
112        )
113    }
114
115    /// Workspace size — zero.
116    #[inline]
117    pub fn workspace_size(&self) -> usize {
118        0
119    }
120
121    /// Identity of the kernel.
122    #[inline]
123    pub fn sku(&self) -> KernelSku {
124        self.sku
125    }
126
127    /// Numerical guarantees.
128    #[inline]
129    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
130        self.sku.precision_guarantee
131    }
132
133    /// Launch.
134    pub fn run(
135        &self,
136        stream: &Stream,
137        _workspace: Workspace<'_>,
138        args: SegmentMinArgs<'_, T>,
139    ) -> Result<()> {
140        self.can_implement(&args)?;
141        let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
142        if total_out == 0 {
143            return Ok(());
144        }
145        run_sorted_fw::<T>(
146            stream,
147            self.desc.num_inputs,
148            self.desc.embedding_dim,
149            self.desc.num_segments,
150            &args.input,
151            &args.segment_ids,
152            &args.output,
153            SortedFwOp::Min,
154        )
155    }
156}