Skip to main content

baracuda_kernels/sort/
histogramdd.rs

//! `histogramdd` plan — N-D histogram. Reserved for follow-up.
//!
//! The 1-D path lives in [`crate::sort::HistogramPlan`]. This file
//! exists as the public API shape (descriptor / args / plan structs)
//! to keep the surface stable; in the trailblazer `select` returns
//! `Unsupported` for every descriptor, including `ndim == 1`.
7
8use core::marker::PhantomData;
9
10use baracuda_cutlass::{Error, Result};
11use baracuda_driver::Stream;
12use baracuda_kernels_types::{
13    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
14    TensorRef, Workspace,
15};
16
17use super::histogram::build_atomic_sku;
18
19/// Descriptor for a `histogramdd` op (reserved).
20#[derive(Copy, Clone, Debug)]
21pub struct HistogramddDescriptor {
22    /// Number of input samples.
23    pub numel: i64,
24    /// Number of dimensions.
25    pub ndim: i32,
26    /// Element type.
27    pub element: ElementKind,
28}
29
30/// Args bundle for a `histogramdd` launch.
31pub struct HistogramddArgs<'a, T: Element> {
32    /// Input `[numel, ndim]`.
33    pub input: TensorRef<'a, T, 2>,
34    /// Output `[product(num_bins_per_dim)]` (i32).
35    pub output: TensorMut<'a, i32, 1>,
36}
37
38/// `histogramdd` plan (reserved — returns `Unsupported`).
39///
40/// **Status**: API stub. `select()` always returns `Unsupported`
41/// in the trailblazer; use [`HistogramPlan`](crate::HistogramPlan)
42/// for 1-D histograms today. This file pins the public surface
43/// (`Descriptor` / `Args` / `Plan` struct names) so callers can
44/// type-check against the eventual N-D path without churn.
45///
46/// **When the real kernel lands**: PyTorch `torch.histogramdd`
47/// shape — input `[numel, ndim]`, output flat
48/// `[prod(num_bins_per_dim)]`.
49pub struct HistogramddPlan<T: Element> {
50    _desc: HistogramddDescriptor,
51    _sku: KernelSku,
52    _marker: PhantomData<T>,
53}
54
55impl<T: Element> HistogramddPlan<T> {
56    /// Pick a kernel for `desc` — returns `Unsupported` in trailblazer.
57    pub fn select(
58        _stream: &Stream,
59        desc: &HistogramddDescriptor,
60        _pref: PlanPreference,
61    ) -> Result<Self> {
62        if desc.element != T::KIND {
63            return Err(Error::Unsupported(
64                "baracuda-kernels::HistogramddPlan: descriptor element != type parameter T",
65            ));
66        }
67        if desc.ndim != 1 {
68            return Err(Error::Unsupported(
69                "baracuda-kernels::HistogramddPlan: ndim > 1 not supported in the trailblazer \
70                 (use HistogramPlan for the 1-D path)",
71            ));
72        }
73        Err(Error::Unsupported(
74            "baracuda-kernels::HistogramddPlan: reserved API surface — use HistogramPlan for \
75             the 1-D case",
76        ))
77    }
78
79    /// Workspace size in bytes.
80    #[inline]
81    pub fn workspace_size(&self) -> usize {
82        0
83    }
84
85    /// Identity of the kernel this plan picked.
86    #[inline]
87    pub fn sku(&self) -> KernelSku {
88        self._sku
89    }
90
91    /// Numerical guarantees for this plan's kernel.
92    #[inline]
93    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
94        self._sku.precision_guarantee
95    }
96
97    /// Validate args — always returns `Unsupported`.
98    pub fn can_implement(&self, _args: &HistogramddArgs<'_, T>) -> Result<()> {
99        Err(Error::Unsupported(
100            "baracuda-kernels::HistogramddPlan: reserved API surface",
101        ))
102    }
103
104    /// Launch — always returns `Unsupported`.
105    pub fn run(
106        &self,
107        _stream: &Stream,
108        _workspace: Workspace<'_>,
109        _args: HistogramddArgs<'_, T>,
110    ) -> Result<()> {
111        Err(Error::Unsupported(
112            "baracuda-kernels::HistogramddPlan: reserved API surface",
113        ))
114    }
115}
116
117// Anchor `build_atomic_sku` for the future N-D path so the import is
118// kept warm. (Drop this once the real implementation lands.)
119#[allow(dead_code)]
120fn _anchor<T: Element>() -> KernelSku {
121    build_atomic_sku::<T>(SortKind::Histogramdd)
122}