baracuda_kernels/sort/
histogramdd.rs1use core::marker::PhantomData;
9
10use baracuda_cutlass::{Error, Result};
11use baracuda_driver::Stream;
12use baracuda_kernels_types::{
13 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
14 TensorRef, Workspace,
15};
16
17use super::histogram::build_atomic_sku;
18
19#[derive(Copy, Clone, Debug)]
21pub struct HistogramddDescriptor {
22 pub numel: i64,
24 pub ndim: i32,
26 pub element: ElementKind,
28}
29
30pub struct HistogramddArgs<'a, T: Element> {
32 pub input: TensorRef<'a, T, 2>,
34 pub output: TensorMut<'a, i32, 1>,
36}
37
38pub struct HistogramddPlan<T: Element> {
50 _desc: HistogramddDescriptor,
51 _sku: KernelSku,
52 _marker: PhantomData<T>,
53}
54
55impl<T: Element> HistogramddPlan<T> {
56 pub fn select(
58 _stream: &Stream,
59 desc: &HistogramddDescriptor,
60 _pref: PlanPreference,
61 ) -> Result<Self> {
62 if desc.element != T::KIND {
63 return Err(Error::Unsupported(
64 "baracuda-kernels::HistogramddPlan: descriptor element != type parameter T",
65 ));
66 }
67 if desc.ndim != 1 {
68 return Err(Error::Unsupported(
69 "baracuda-kernels::HistogramddPlan: ndim > 1 not supported in the trailblazer \
70 (use HistogramPlan for the 1-D path)",
71 ));
72 }
73 Err(Error::Unsupported(
74 "baracuda-kernels::HistogramddPlan: reserved API surface — use HistogramPlan for \
75 the 1-D case",
76 ))
77 }
78
79 #[inline]
81 pub fn workspace_size(&self) -> usize {
82 0
83 }
84
85 #[inline]
87 pub fn sku(&self) -> KernelSku {
88 self._sku
89 }
90
91 #[inline]
93 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
94 self._sku.precision_guarantee
95 }
96
97 pub fn can_implement(&self, _args: &HistogramddArgs<'_, T>) -> Result<()> {
99 Err(Error::Unsupported(
100 "baracuda-kernels::HistogramddPlan: reserved API surface",
101 ))
102 }
103
104 pub fn run(
106 &self,
107 _stream: &Stream,
108 _workspace: Workspace<'_>,
109 _args: HistogramddArgs<'_, T>,
110 ) -> Result<()> {
111 Err(Error::Unsupported(
112 "baracuda-kernels::HistogramddPlan: reserved API surface",
113 ))
114 }
115}
116
117#[allow(dead_code)]
120fn _anchor<T: Element>() -> KernelSku {
121 build_atomic_sku::<T>(SortKind::Histogramdd)
122}