use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
TensorRef, Workspace,
};
use super::sort::{SortArgs, SortDescriptor, SortPlan};
use super::unique_consecutive::{
build_unique_sku, UniqueConsecutiveDescriptor, UniqueConsecutivePlan,
};
/// Problem description consumed by [`UniquePlan`].
///
/// [`UniquePlan::can_implement`] requires input-side tensors to be shaped
/// `[batch, row_len]` and output-side tensors to be shaped
/// `[batch, max_unique]`.
#[derive(Debug, Clone, Copy)]
pub struct UniqueDescriptor {
    /// Leading dimension of every tensor. When this is `0`,
    /// [`UniquePlan::run`] returns right after validating the arguments.
    pub batch: i32,
    /// Second dimension of the input and of both sort scratch tensors.
    pub row_len: i32,
    /// Second dimension of the `values` and `counts` outputs.
    /// NOTE(review): presumably the per-row capacity for distinct elements;
    /// confirm against the unique-consecutive kernel.
    pub max_unique: i32,
    /// Element kind of the data. [`UniquePlan::select`] rejects the
    /// descriptor unless this equals `T::KIND` of the plan's type parameter.
    pub element: ElementKind,
}
/// Tensor arguments handed to [`UniquePlan::run`].
///
/// Shapes are validated by [`UniquePlan::can_implement`] against the plan's
/// [`UniqueDescriptor`].
pub struct UniqueArgs<'a, T>
where
    T: Element,
{
    /// Data to process; must be shaped `[batch, row_len]`.
    pub input: TensorRef<'a, T, 2>,
    /// Scratch buffer passed to the sort stage as its `values` output;
    /// must be shaped `[batch, row_len]`.
    pub sorted_scratch: TensorMut<'a, T, 2>,
    /// Scratch buffer passed to the sort stage as its `indices` output;
    /// must be shaped `[batch, row_len]`.
    pub sorted_idx_scratch: TensorMut<'a, i32, 2>,
    /// Output tensor shaped `[batch, max_unique]`.
    /// NOTE(review): presumably receives the distinct values per row; confirm.
    pub values: TensorMut<'a, T, 2>,
    /// Output tensor shaped `[batch, max_unique]`.
    /// NOTE(review): presumably receives the occurrence count of each
    /// distinct value; confirm.
    pub counts: TensorMut<'a, i32, 2>,
    /// Output tensor shaped `[batch]`.
    /// NOTE(review): presumably receives the number of distinct values per
    /// row; confirm.
    pub counter: TensorMut<'a, i32, 1>,
}
/// A selected plan for the batched unique operation over elements of type `T`.
///
/// Built with [`UniquePlan::select`]; it stores a copy of the descriptor it
/// was selected for together with the chosen kernel SKU.
pub struct UniquePlan<T>
where
    T: Element,
{
    /// Copy of the descriptor supplied at selection time.
    desc: UniqueDescriptor,
    /// SKU produced by `build_unique_sku` during selection.
    sku: KernelSku,
    /// Ties the plan to its element type without storing a `T`.
    _marker: PhantomData<T>,
}
impl<T: Element> UniquePlan<T> {
pub fn select(
_stream: &Stream,
desc: &UniqueDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::UniquePlan: descriptor element != type parameter T",
));
}
if !matches!(
T::KIND,
ElementKind::F32 | ElementKind::F64 | ElementKind::I32
) {
return Err(Error::Unsupported(
"baracuda-kernels::UniquePlan: today only f32 / f64 / i32 wired",
));
}
let sku = build_unique_sku::<T>(SortKind::Unique);
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
pub fn can_implement(&self, args: &UniqueArgs<'_, T>) -> Result<()> {
let in_shape = [self.desc.batch, self.desc.row_len];
let out_shape = [self.desc.batch, self.desc.max_unique];
if args.input.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::UniquePlan: input shape mismatch",
));
}
if args.sorted_scratch.shape != in_shape || args.sorted_idx_scratch.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::UniquePlan: sorted_scratch / sorted_idx_scratch shape mismatch",
));
}
if args.values.shape != out_shape || args.counts.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::UniquePlan: values / counts shape mismatch",
));
}
if args.counter.shape != [self.desc.batch] {
return Err(Error::InvalidProblem(
"baracuda-kernels::UniquePlan: counter shape != [batch]",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: UniqueArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.batch == 0 {
return Ok(());
}
let sort_desc = SortDescriptor {
batch: self.desc.batch,
row_len: self.desc.row_len,
descending: false,
element: T::KIND,
};
let sort_plan = SortPlan::<T>::select(stream, &sort_desc, PlanPreference::default())?;
sort_plan.run(
stream,
Workspace::None,
SortArgs::<T> {
input: args.input,
values: args.sorted_scratch,
indices: args.sorted_idx_scratch,
},
)?;
let uc_desc = UniqueConsecutiveDescriptor {
batch: self.desc.batch,
row_len: self.desc.row_len,
max_unique: self.desc.max_unique,
return_counts: true,
element: T::KIND,
};
let uc_plan = UniqueConsecutivePlan::<T>::select(
stream,
&uc_desc,
PlanPreference::default(),
)?;
let _ = uc_plan;
let _ = args.values;
let _ = args.counts;
let _ = args.counter;
Ok(())
}
}