use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
TensorRef, Workspace,
};
use super::histogram::build_atomic_sku;
/// Problem description handed to [`HistogramddPlan::select`].
///
/// Plain `Copy` data; it carries no device resources.
#[derive(Clone, Copy, Debug)]
pub struct HistogramddDescriptor {
    /// Element count of the input (presumably the total number of scalars — TODO confirm).
    pub numel: i64,
    /// Histogram dimensionality. Only `1` gets past the checks in `select`.
    pub ndim: i32,
    /// Runtime element type; `select` requires it to equal `T::KIND`.
    pub element: ElementKind,
}
/// Borrowed tensors for one [`HistogramddPlan::run`] invocation.
pub struct HistogramddArgs<'a, T: Element> {
    /// Read-only input samples of type `T`. The `2` const parameter is presumably
    /// the tensor rank; the axis meaning is not established here — TODO confirm.
    pub input: TensorRef<'a, T, 2>,
    /// Writable `i32` output (presumably rank-1 bin counts — verify against the kernel).
    pub output: TensorMut<'a, i32, 1>,
}
/// Selected execution plan for an N-dimensional histogram over elements of `T`.
///
/// Reserved API surface: `select` currently never yields a plan.
pub struct HistogramddPlan<T: Element> {
    /// Descriptor the plan was selected for.
    _desc: HistogramddDescriptor,
    /// Kernel SKU chosen at selection time.
    _sku: KernelSku,
    /// Ties the plan to `T` without storing a `T`.
    _marker: PhantomData<T>,
}
impl<T: Element> HistogramddPlan<T> {
    /// Selects a plan for an N-dimensional histogram over elements of type `T`.
    ///
    /// This is reserved API surface. No N-dimensional kernel is wired up yet, so
    /// every call returns an error. The descriptor checks run first so the caller
    /// gets the most specific diagnostic available.
    ///
    /// # Errors
    ///
    /// Always returns [`Error::Unsupported`]:
    /// - when `desc.element` does not match `T::KIND`;
    /// - when `desc.ndim` is less than 1 (not a meaningful dimensionality);
    /// - when `desc.ndim` is greater than 1 (multi-dimensional path not implemented);
    /// - otherwise (the 1-D case), because `HistogramPlan` is the supported route.
    pub fn select(
        _stream: &Stream,
        desc: &HistogramddDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        if desc.element != T::KIND {
            return Err(Error::Unsupported(
                "baracuda-kernels::HistogramddPlan: descriptor element != type parameter T",
            ));
        }
        // Reject zero/negative dimensionality on its own, so it is not
        // misreported as "ndim > 1".
        if desc.ndim < 1 {
            return Err(Error::Unsupported(
                "baracuda-kernels::HistogramddPlan: ndim must be >= 1",
            ));
        }
        if desc.ndim > 1 {
            return Err(Error::Unsupported(
                "baracuda-kernels::HistogramddPlan: ndim > 1 not supported in the trailblazer \
                 (use HistogramPlan for the 1-D path)",
            ));
        }
        Err(Error::Unsupported(
            "baracuda-kernels::HistogramddPlan: reserved API surface — use HistogramPlan for \
             the 1-D case",
        ))
    }

    /// Workspace size required by [`Self::run`].
    ///
    /// This plan needs no workspace, so the result is always `0`.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Returns the [`KernelSku`] recorded when this plan was built.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self._sku
    }

    /// Returns the `precision_guarantee` stored in this plan's [`KernelSku`].
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self._sku.precision_guarantee
    }

    /// Checks whether `_args` could be executed by this plan.
    ///
    /// # Errors
    ///
    /// Always returns [`Error::Unsupported`]; the N-dimensional path is reserved.
    pub fn can_implement(&self, _args: &HistogramddArgs<'_, T>) -> Result<()> {
        Err(Error::Unsupported(
            "baracuda-kernels::HistogramddPlan: reserved API surface",
        ))
    }

    /// Launches the histogram on `_stream`.
    ///
    /// # Errors
    ///
    /// Always returns [`Error::Unsupported`]; the N-dimensional path is reserved.
    /// Nothing is enqueued and the arguments are not touched.
    pub fn run(
        &self,
        _stream: &Stream,
        _workspace: Workspace<'_>,
        _args: HistogramddArgs<'_, T>,
    ) -> Result<()> {
        Err(Error::Unsupported(
            "baracuda-kernels::HistogramddPlan: reserved API surface",
        ))
    }
}
/// Never-called helper that keeps `build_atomic_sku` and `SortKind::Histogramdd`
/// referenced from this module. Presumably this keeps the imports used and the
/// SKU wiring type-checked while the plan is reserved — confirm intent.
/// `dead_code` is allowed because nothing calls it.
#[allow(dead_code)]
fn _anchor<T: Element>() -> KernelSku {
    let kind = SortKind::Histogramdd;
    build_atomic_sku::<T>(kind)
}