Skip to main content

baracuda_kernels/sort/
unique_consecutive.rs

//! `unique_consecutive` plan — emit one cell per run-start in each row.
//!
//! Input is expected to be pre-sorted. On unsorted input only consecutive
//! equal elements are collapsed, which matches the PyTorch
//! `torch.unique_consecutive` semantics.
5//!
//! Output is **NOT in input order**. Each row's output slots are claimed
//! through a per-row atomic counter, so their order depends on which thread
//! block wins the race. Callers that need a deterministic order should issue
//! a follow-up sort over the first `counter[b]` entries of each row `b`.
//! The per-row count is written to the separate `counter` tensor (shape
//! `[batch]`). Callers read it after the launch to learn the actual number
//! of unique runs in each row.
12//!
//! Trailblazer dtype coverage: `f32`, `f64`, `i32`. The output is set-valued, so there is no backward (BW) kernel.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22    PlanPreference, PrecisionGuarantee, SortKind, TensorMut, TensorRef, Workspace,
23};
24
25use super::map_status;
26
27/// Descriptor for a `unique_consecutive` op.
28#[derive(Copy, Clone, Debug)]
29pub struct UniqueConsecutiveDescriptor {
30    /// Number of independent rows.
31    pub batch: i32,
32    /// Length of each input row.
33    pub row_len: i32,
34    /// Maximum number of unique values the output table can hold per
35    /// row. Set to `row_len` for a worst-case-safe bound.
36    pub max_unique: i32,
37    /// Whether to emit per-run counts (`y_counts`). Today the kernel
38    /// writes `1` per detected run-start regardless; this flag is
39    /// reserved for future counts-aware variants.
40    pub return_counts: bool,
41    /// Value element type.
42    pub element: ElementKind,
43}
44
45/// Args bundle for a `unique_consecutive` launch.
46pub struct UniqueConsecutiveArgs<'a, T: Element> {
47    /// Input `[batch, row_len]`.
48    pub input: TensorRef<'a, T, 2>,
49    /// Output values `[batch, max_unique]` (filled left-to-right per
50    /// row up to the actual unique count).
51    pub values: TensorMut<'a, T, 2>,
52    /// Optional output per-cell counts `[batch, max_unique]`.
53    pub counts: TensorMut<'a, i32, 2>,
54    /// Per-row counter `[batch]` — post-launch holds the actual
55    /// unique count per row.
56    pub counter: TensorMut<'a, i32, 1>,
57}
58
59/// `unique_consecutive` plan.
60///
61/// Emits one cell per run-start in each row (PyTorch
62/// `torch.unique_consecutive`). Input is expected to be pre-sorted,
63/// or the caller wants only consecutive-equal runs collapsed.
64///
65/// **When to use**: dedup runs after a sort, or unique-detection on
66/// already-sorted data. For full unique-of-row use
67/// [`UniquePlan`](crate::UniquePlan), which composes sort + this.
68/// Set-valued — no BW.
69///
70/// **Dtypes**: `{f32, f64, i32}`.
71///
72/// **Shape limits**: `[batch, row_len]` input; `[batch, max_unique]`
73/// outputs; `[batch]` counter. `max_unique == row_len` is the
74/// worst-case-safe bound.
75///
76/// **Workspace**: none.
77///
78/// **Precision guarantee**: **non-deterministic ordering** — slot
79/// assignment uses a per-row atomic counter (block-race order). The
80/// *set* of detected uniques is data-determined; only the row
81/// ordering varies. Callers needing input-order output should sort
82/// the result on `[batch, counter]` rows afterward.
83pub struct UniqueConsecutivePlan<T: Element> {
84    desc: UniqueConsecutiveDescriptor,
85    sku: KernelSku,
86    _marker: PhantomData<T>,
87}
88
89impl<T: Element> UniqueConsecutivePlan<T> {
90    /// Pick a kernel for `desc`.
91    pub fn select(
92        _stream: &Stream,
93        desc: &UniqueConsecutiveDescriptor,
94        _pref: PlanPreference,
95    ) -> Result<Self> {
96        if desc.element != T::KIND {
97            return Err(Error::Unsupported(
98                "baracuda-kernels::UniqueConsecutivePlan: descriptor element != type parameter T",
99            ));
100        }
101        if desc.batch < 0 || desc.row_len < 0 || desc.max_unique < 0 {
102            return Err(Error::InvalidProblem(
103                "baracuda-kernels::UniqueConsecutivePlan: batch / row_len / max_unique \
104                 must be non-negative",
105            ));
106        }
107        if !matches!(
108            T::KIND,
109            ElementKind::F32 | ElementKind::F64 | ElementKind::I32
110        ) {
111            return Err(Error::Unsupported(
112                "baracuda-kernels::UniqueConsecutivePlan: today only f32 / f64 / i32 wired",
113            ));
114        }
115        let sku = build_unique_sku::<T>(SortKind::UniqueConsecutive);
116        Ok(Self {
117            desc: *desc,
118            sku,
119            _marker: PhantomData,
120        })
121    }
122
123    /// Validate args.
124    pub fn can_implement(&self, args: &UniqueConsecutiveArgs<'_, T>) -> Result<()> {
125        if args.input.shape != [self.desc.batch, self.desc.row_len] {
126            return Err(Error::InvalidProblem(
127                "baracuda-kernels::UniqueConsecutivePlan: input shape != [batch, row_len]",
128            ));
129        }
130        if args.values.shape != [self.desc.batch, self.desc.max_unique] {
131            return Err(Error::InvalidProblem(
132                "baracuda-kernels::UniqueConsecutivePlan: values shape != [batch, max_unique]",
133            ));
134        }
135        if args.counts.shape != [self.desc.batch, self.desc.max_unique] {
136            return Err(Error::InvalidProblem(
137                "baracuda-kernels::UniqueConsecutivePlan: counts shape != [batch, max_unique]",
138            ));
139        }
140        if args.counter.shape != [self.desc.batch] {
141            return Err(Error::InvalidProblem(
142                "baracuda-kernels::UniqueConsecutivePlan: counter shape != [batch]",
143            ));
144        }
145        Ok(())
146    }
147
148    /// Workspace size in bytes.
149    #[inline]
150    pub fn workspace_size(&self) -> usize {
151        0
152    }
153
154    /// Identity of the kernel this plan picked.
155    #[inline]
156    pub fn sku(&self) -> KernelSku {
157        self.sku
158    }
159
160    /// Numerical guarantees for this plan's kernel.
161    #[inline]
162    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
163        self.sku.precision_guarantee
164    }
165
166    /// Launch.
167    pub fn run(
168        &self,
169        stream: &Stream,
170        _workspace: Workspace<'_>,
171        args: UniqueConsecutiveArgs<'_, T>,
172    ) -> Result<()> {
173        self.can_implement(&args)?;
174        if self.desc.batch == 0 {
175            return Ok(());
176        }
177        let in_ptr = args.input.data.as_raw().0 as *const c_void;
178        let vals_ptr = args.values.data.as_raw().0 as *mut c_void;
179        let counts_ptr = args.counts.data.as_raw().0 as *mut c_void;
180        let counter_ptr = args.counter.data.as_raw().0 as *mut c_void;
181        let stream_ptr = stream.as_raw() as *mut c_void;
182
183        let status = match T::KIND {
184            ElementKind::F32 => unsafe {
185                baracuda_kernels_sys::baracuda_kernels_unique_consecutive_f32_run(
186                    self.desc.batch,
187                    self.desc.row_len,
188                    self.desc.max_unique,
189                    in_ptr,
190                    vals_ptr,
191                    counts_ptr,
192                    counter_ptr,
193                    core::ptr::null_mut(),
194                    0,
195                    stream_ptr,
196                )
197            },
198            ElementKind::F64 => unsafe {
199                baracuda_kernels_sys::baracuda_kernels_unique_consecutive_f64_run(
200                    self.desc.batch,
201                    self.desc.row_len,
202                    self.desc.max_unique,
203                    in_ptr,
204                    vals_ptr,
205                    counts_ptr,
206                    counter_ptr,
207                    core::ptr::null_mut(),
208                    0,
209                    stream_ptr,
210                )
211            },
212            ElementKind::I32 => unsafe {
213                baracuda_kernels_sys::baracuda_kernels_unique_consecutive_i32_run(
214                    self.desc.batch,
215                    self.desc.row_len,
216                    self.desc.max_unique,
217                    in_ptr,
218                    vals_ptr,
219                    counts_ptr,
220                    counter_ptr,
221                    core::ptr::null_mut(),
222                    0,
223                    stream_ptr,
224                )
225            },
226            _ => {
227                return Err(Error::Unsupported(
228                    "baracuda-kernels::UniqueConsecutivePlan::run reached an unimplemented dtype",
229                ));
230            }
231        };
232        map_status(status)
233    }
234}
235
236/// Build SKU for unique-family ops — atomic-counter output is NOT
237/// deterministic in slot order, so we tag accordingly.
238pub(crate) fn build_unique_sku<T: Element>(op: SortKind) -> KernelSku {
239    let precision_guarantee = PrecisionGuarantee {
240        math_precision: if T::KIND == ElementKind::F64 {
241            MathPrecision::F64
242        } else {
243            MathPrecision::F32
244        },
245        accumulator: T::KIND,
246        bit_stable_on_same_hardware: false,
247        deterministic: false,
248    };
249    KernelSku {
250        category: OpCategory::Sorting,
251        op: op as u16,
252        element: T::KIND,
253        aux_element: Some(ElementKind::I32),
254        layout: None,
255        epilogue: None,
256        arch: ArchSku::Sm80,
257        backend: BackendKind::Bespoke,
258        precision_guarantee,
259    }
260}