Skip to main content

baracuda_kernels/segment/
unsorted_segment_min.rs

//! `unsorted_segment_min` plan — Category S, unsorted variant.
//!
//! `out[s, d] = min_{n : segment_ids[n] == s} input[n, d]`. The launcher
//! pre-initializes the output to `+∞`; each input element is then folded in
//! with an atomic min implemented as a compare-and-swap (CAS) retry loop.
//!
//! Forward only. The backward pass is deferred (it needs argmin tracking).
7
8use core::ffi::c_void;
9use core::marker::PhantomData;
10
11use baracuda_cutlass::{Error, Result};
12use baracuda_driver::Stream;
13use baracuda_kernels_types::{
14    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
15    TensorRef, Workspace,
16};
17
18use super::map_status;
19use super::segment_sum::{validate_desc, SegDescView};
20use super::unsorted_segment_sum::{build_unsorted_sku, validate_unsorted_args};
21
22/// Descriptor for an `unsorted_segment_min` op.
23#[derive(Copy, Clone, Debug)]
24pub struct UnsortedSegmentMinDescriptor {
25    /// Number of input rows.
26    pub num_inputs: i32,
27    /// Embedding / feature dim.
28    pub embedding_dim: i32,
29    /// Total number of segments.
30    pub num_segments: i32,
31    /// Value element type.
32    pub element: ElementKind,
33}
34
35impl SegDescView for UnsortedSegmentMinDescriptor {
36    #[inline]
37    fn view(&self) -> (i32, i32, i32, ElementKind) {
38        (
39            self.num_inputs,
40            self.embedding_dim,
41            self.num_segments,
42            self.element,
43        )
44    }
45}
46
47/// Args bundle for an `unsorted_segment_min` launch.
48pub struct UnsortedSegmentMinArgs<'a, T: Element> {
49    /// Input `[N, D]`.
50    pub input: TensorRef<'a, T, 2>,
51    /// Segment ids `[N]`, any order.
52    pub segment_ids: TensorRef<'a, i32, 1>,
53    /// Output `[num_segments, D]`.
54    pub output: TensorMut<'a, T, 2>,
55}
56
57/// `unsorted_segment_min` plan.
58///
59/// `out[s, d] = min input[n, d]` over `n : segment_ids[n] == s`, with
60/// IDs in any order. Mirror of
61/// [`UnsortedSegmentMaxPlan`](crate::UnsortedSegmentMaxPlan); uses
62/// `atomicMin`-emulated CAS retry.
63///
64/// **When to use**: forward unsorted segment-min. **No BW plan** —
65/// argmin tracking deferred.
66///
67/// **Dtypes**: `{f32, f64}`.
68///
69/// **Shape limits**: `input` `[N, D]`; `segment_ids` `[N]`;
70/// `output` `[num_segments, D]`. Empty segments emit
71/// positive-infinity identity.
72///
73/// **Workspace**: none.
74///
75/// **Precision guarantee**: **non-deterministic**.
76pub struct UnsortedSegmentMinPlan<T: Element> {
77    desc: UnsortedSegmentMinDescriptor,
78    sku: KernelSku,
79    _marker: PhantomData<T>,
80}
81
82impl<T: Element> UnsortedSegmentMinPlan<T> {
83    /// Pick a kernel.
84    pub fn select(
85        _stream: &Stream,
86        desc: &UnsortedSegmentMinDescriptor,
87        _pref: PlanPreference,
88    ) -> Result<Self> {
89        validate_desc(*desc, T::KIND, "UnsortedSegmentMinPlan")?;
90        Ok(Self {
91            desc: *desc,
92            sku: build_unsorted_sku::<T>(SegmentKind::UnsortedSegmentMin),
93            _marker: PhantomData,
94        })
95    }
96
97    /// Validate args.
98    pub fn can_implement(&self, args: &UnsortedSegmentMinArgs<'_, T>) -> Result<()> {
99        validate_unsorted_args(
100            self.desc.num_inputs,
101            self.desc.embedding_dim,
102            self.desc.num_segments,
103            args.input.shape,
104            args.segment_ids.shape,
105            args.output.shape,
106            "UnsortedSegmentMinPlan",
107        )
108    }
109
110    /// Workspace size — zero.
111    #[inline]
112    pub fn workspace_size(&self) -> usize {
113        0
114    }
115
116    /// Identity of the kernel.
117    #[inline]
118    pub fn sku(&self) -> KernelSku {
119        self.sku
120    }
121
122    /// Numerical guarantees.
123    #[inline]
124    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
125        self.sku.precision_guarantee
126    }
127
128    /// Launch.
129    pub fn run(
130        &self,
131        stream: &Stream,
132        _workspace: Workspace<'_>,
133        args: UnsortedSegmentMinArgs<'_, T>,
134    ) -> Result<()> {
135        self.can_implement(&args)?;
136        let total = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
137        if total == 0 {
138            return Ok(());
139        }
140        let in_ptr = args.input.data.as_raw().0 as *const c_void;
141        let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
142        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
143        let stream_ptr = stream.as_raw() as *mut c_void;
144        let status = match T::KIND {
145            ElementKind::F32 => unsafe {
146                baracuda_kernels_sys::baracuda_kernels_unsorted_segment_min_f32_run(
147                    self.desc.num_inputs,
148                    self.desc.embedding_dim,
149                    self.desc.num_segments,
150                    in_ptr,
151                    id_ptr,
152                    out_ptr,
153                    core::ptr::null_mut(),
154                    0,
155                    stream_ptr,
156                )
157            },
158            ElementKind::F64 => unsafe {
159                baracuda_kernels_sys::baracuda_kernels_unsorted_segment_min_f64_run(
160                    self.desc.num_inputs,
161                    self.desc.embedding_dim,
162                    self.desc.num_segments,
163                    in_ptr,
164                    id_ptr,
165                    out_ptr,
166                    core::ptr::null_mut(),
167                    0,
168                    stream_ptr,
169                )
170            },
171            _ => {
172                return Err(Error::Unsupported(
173                    "baracuda-kernels::UnsortedSegmentMinPlan::run reached an unimplemented dtype",
174                ));
175            }
176        };
177        map_status(status)
178    }
179}