// baracuda_kernels/sort/histogram.rs

use core::ffi::c_void;
use core::marker::PhantomData;

use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
    PlanPreference, PrecisionGuarantee, SortKind, TensorMut, TensorRef, Workspace,
};

use super::map_status;

18#[derive(Copy, Clone, Debug)]
20pub struct HistogramDescriptor {
21 pub numel: i64,
23 pub num_bins: i32,
25 pub lo: f64,
27 pub hi: f64,
29 pub element: ElementKind,
31}
32
33pub struct HistogramArgs<'a, T: Element> {
35 pub input: TensorRef<'a, T, 1>,
37 pub output: TensorMut<'a, i32, 1>,
39}
40
41pub struct HistogramPlan<T: Element> {
60 desc: HistogramDescriptor,
61 sku: KernelSku,
62 _marker: PhantomData<T>,
63}
64
65impl<T: Element> HistogramPlan<T> {
66 pub fn select(
68 _stream: &Stream,
69 desc: &HistogramDescriptor,
70 _pref: PlanPreference,
71 ) -> Result<Self> {
72 if desc.element != T::KIND {
73 return Err(Error::Unsupported(
74 "baracuda-kernels::HistogramPlan: descriptor element != type parameter T",
75 ));
76 }
77 if desc.numel < 0 || desc.num_bins < 0 {
78 return Err(Error::InvalidProblem(
79 "baracuda-kernels::HistogramPlan: numel / num_bins must be non-negative",
80 ));
81 }
82 if !(desc.hi > desc.lo) {
83 return Err(Error::InvalidProblem(
84 "baracuda-kernels::HistogramPlan: hi must be > lo",
85 ));
86 }
87 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
88 return Err(Error::Unsupported(
89 "baracuda-kernels::HistogramPlan: today only f32 / f64 wired",
90 ));
91 }
92 let sku = build_atomic_sku::<T>(SortKind::Histogram);
93 Ok(Self {
94 desc: *desc,
95 sku,
96 _marker: PhantomData,
97 })
98 }
99
100 pub fn can_implement(&self, args: &HistogramArgs<'_, T>) -> Result<()> {
102 if (args.input.shape[0] as i64) != self.desc.numel {
103 return Err(Error::InvalidProblem(
104 "baracuda-kernels::HistogramPlan: input shape[0] != descriptor numel",
105 ));
106 }
107 if args.output.shape != [self.desc.num_bins] {
108 return Err(Error::InvalidProblem(
109 "baracuda-kernels::HistogramPlan: output shape != [num_bins]",
110 ));
111 }
112 Ok(())
113 }
114
115 #[inline]
117 pub fn workspace_size(&self) -> usize {
118 0
119 }
120
121 #[inline]
123 pub fn sku(&self) -> KernelSku {
124 self.sku
125 }
126
127 #[inline]
129 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
130 self.sku.precision_guarantee
131 }
132
133 pub fn run(
135 &self,
136 stream: &Stream,
137 _workspace: Workspace<'_>,
138 args: HistogramArgs<'_, T>,
139 ) -> Result<()> {
140 self.can_implement(&args)?;
141 if self.desc.num_bins == 0 {
142 return Ok(());
143 }
144 let in_ptr = args.input.data.as_raw().0 as *const c_void;
145 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
146 let stream_ptr = stream.as_raw() as *mut c_void;
147
148 let status = match T::KIND {
149 ElementKind::F32 => unsafe {
150 baracuda_kernels_sys::baracuda_kernels_histogram_f32_run(
151 self.desc.numel,
152 self.desc.num_bins,
153 self.desc.lo,
154 self.desc.hi,
155 in_ptr,
156 out_ptr,
157 core::ptr::null_mut(),
158 0,
159 stream_ptr,
160 )
161 },
162 ElementKind::F64 => unsafe {
163 baracuda_kernels_sys::baracuda_kernels_histogram_f64_run(
164 self.desc.numel,
165 self.desc.num_bins,
166 self.desc.lo,
167 self.desc.hi,
168 in_ptr,
169 out_ptr,
170 core::ptr::null_mut(),
171 0,
172 stream_ptr,
173 )
174 },
175 _ => {
176 return Err(Error::Unsupported(
177 "baracuda-kernels::HistogramPlan::run reached an unimplemented dtype",
178 ));
179 }
180 };
181 map_status(status)
182 }
183}
184
185pub(crate) fn build_atomic_sku<T: Element>(op: SortKind) -> KernelSku {
187 let precision_guarantee = PrecisionGuarantee {
188 math_precision: if T::KIND == ElementKind::F64 {
189 MathPrecision::F64
190 } else {
191 MathPrecision::F32
192 },
193 accumulator: ElementKind::I32,
194 bit_stable_on_same_hardware: true, deterministic: true,
196 };
197 KernelSku {
198 category: OpCategory::Sorting,
199 op: op as u16,
200 element: T::KIND,
201 aux_element: Some(ElementKind::I32),
202 layout: None,
203 epilogue: None,
204 arch: ArchSku::Sm80,
205 backend: BackendKind::Bespoke,
206 precision_guarantee,
207 }
208}