Skip to main content

baracuda_kernels/segment/
unsorted_segment_prod.rs

1//! `unsorted_segment_prod` plan — Category S, unsorted. Phase 25.
2//!
3//! `out[s, d] = Π_{n : segment_ids[n] == s} input[n, d]` with arbitrary
//! `segment_ids` ordering. The kernel fills `output` with `1.0`, then
//! performs an `atomicMul`-via-CAS into `output[segment_ids[n], d]` for every input cell.
6//!
7//! No native FP `atomicMul` exists; we implement it as an `atomicCAS`
8//! retry loop on the underlying 32 / 64-bit slot. This is slower than
9//! the additive variants but allowed per the OP-MATRIX (segment ops
10//! contract).
11//!
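//! Non-deterministic — atomic ordering varies across launches, and because
//! floating-point multiplication is not associative, the rounded product can
//! differ from run to run.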
13//!
14//! Dtype coverage: `f32, f64`.
15
16use core::ffi::c_void;
17use core::marker::PhantomData;
18
19use baracuda_cutlass::{Error, Result};
20use baracuda_driver::Stream;
21use baracuda_kernels_types::{
22    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
23    TensorRef, Workspace,
24};
25
26use super::map_status;
27use super::segment_sum::{validate_desc, SegDescView};
28use super::unsorted_segment_sum::{build_unsorted_sku, validate_unsorted_args};
29
30/// Descriptor for an `unsorted_segment_prod` op.
31#[derive(Copy, Clone, Debug)]
32pub struct UnsortedSegmentProdDescriptor {
33    /// Number of input rows.
34    pub num_inputs: i32,
35    /// Embedding / feature dim.
36    pub embedding_dim: i32,
37    /// Total number of segments.
38    pub num_segments: i32,
39    /// Value element type.
40    pub element: ElementKind,
41}
42
43impl SegDescView for UnsortedSegmentProdDescriptor {
44    #[inline]
45    fn view(&self) -> (i32, i32, i32, ElementKind) {
46        (
47            self.num_inputs,
48            self.embedding_dim,
49            self.num_segments,
50            self.element,
51        )
52    }
53}
54
55/// Args bundle for an `unsorted_segment_prod` launch.
56pub struct UnsortedSegmentProdArgs<'a, T: Element> {
57    /// Input `[N, D]`.
58    pub input: TensorRef<'a, T, 2>,
59    /// Segment ids `[N]`, i32, in any order.
60    pub segment_ids: TensorRef<'a, i32, 1>,
61    /// Output `[num_segments, D]`. Overwritten by the launch — kernel
62    /// fills `1.0` before the atomic accumulation phase.
63    pub output: TensorMut<'a, T, 2>,
64}
65
66/// `unsorted_segment_prod` plan. Phase 25.
67///
68/// `out[s, d] = Π input[n, d]` over `n : segment_ids[n] == s` (any
69/// order). `atomicMul`-emulated CAS retry loop.
70///
71/// **When to use**: forward unsorted segment-product. For sorted IDs
72/// prefer the deterministic
73/// [`SegmentProdPlan`](crate::SegmentProdPlan). BW shares the
74/// [`SegmentProdBackwardPlan`](crate::SegmentProdBackwardPlan) shape.
75///
76/// **Dtypes**: `{f32, f64}` only — `atomicCAS` slot widths are 32 / 64
77/// bit. Empty segments emit `1.0` (multiplicative identity).
78///
79/// **Workspace**: none.
80///
81/// **Precision guarantee**: **non-deterministic** — atomic ordering
82/// varies across launches.
83pub struct UnsortedSegmentProdPlan<T: Element> {
84    desc: UnsortedSegmentProdDescriptor,
85    sku: KernelSku,
86    _marker: PhantomData<T>,
87}
88
89impl<T: Element> UnsortedSegmentProdPlan<T> {
90    /// Pick a kernel.
91    pub fn select(
92        _stream: &Stream,
93        desc: &UnsortedSegmentProdDescriptor,
94        _pref: PlanPreference,
95    ) -> Result<Self> {
96        validate_desc(*desc, T::KIND, "UnsortedSegmentProdPlan")?;
97        Ok(Self {
98            desc: *desc,
99            sku: build_unsorted_sku::<T>(SegmentKind::UnsortedSegmentProd),
100            _marker: PhantomData,
101        })
102    }
103
104    /// Validate args.
105    pub fn can_implement(&self, args: &UnsortedSegmentProdArgs<'_, T>) -> Result<()> {
106        validate_unsorted_args(
107            self.desc.num_inputs,
108            self.desc.embedding_dim,
109            self.desc.num_segments,
110            args.input.shape,
111            args.segment_ids.shape,
112            args.output.shape,
113            "UnsortedSegmentProdPlan",
114        )
115    }
116
117    /// Workspace size — zero.
118    #[inline]
119    pub fn workspace_size(&self) -> usize {
120        0
121    }
122
123    /// Identity of the kernel.
124    #[inline]
125    pub fn sku(&self) -> KernelSku {
126        self.sku
127    }
128
129    /// Numerical guarantees.
130    #[inline]
131    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
132        self.sku.precision_guarantee
133    }
134
135    /// Launch.
136    pub fn run(
137        &self,
138        stream: &Stream,
139        _workspace: Workspace<'_>,
140        args: UnsortedSegmentProdArgs<'_, T>,
141    ) -> Result<()> {
142        self.can_implement(&args)?;
143        let total = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
144        if total == 0 {
145            return Ok(());
146        }
147        let in_ptr = args.input.data.as_raw().0 as *const c_void;
148        let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
149        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
150        let stream_ptr = stream.as_raw() as *mut c_void;
151        let status = match T::KIND {
152            ElementKind::F32 => unsafe {
153                baracuda_kernels_sys::baracuda_kernels_unsorted_segment_prod_f32_run(
154                    self.desc.num_inputs,
155                    self.desc.embedding_dim,
156                    self.desc.num_segments,
157                    in_ptr,
158                    id_ptr,
159                    out_ptr,
160                    core::ptr::null_mut(),
161                    0,
162                    stream_ptr,
163                )
164            },
165            ElementKind::F64 => unsafe {
166                baracuda_kernels_sys::baracuda_kernels_unsorted_segment_prod_f64_run(
167                    self.desc.num_inputs,
168                    self.desc.embedding_dim,
169                    self.desc.num_segments,
170                    in_ptr,
171                    id_ptr,
172                    out_ptr,
173                    core::ptr::null_mut(),
174                    0,
175                    stream_ptr,
176                )
177            },
178            _ => {
179                return Err(Error::Unsupported(
180                    "baracuda-kernels::UnsortedSegmentProdPlan::run reached an unimplemented dtype",
181                ));
182            }
183        };
184        map_status(status)
185    }
186}