use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut, TensorRef, Workspace,
};
use super::map_status;
/// Problem extents and element kind for a segment-sum plan.
///
/// `validate_desc` rejects negative extents and any `element` other than
/// `F32` / `F64`; `validate_args` compares tensor shapes against these fields.
#[derive(Copy, Clone, Debug)]
pub struct SegmentSumDescriptor {
/// First extent of `input` and the length of `segment_ids`.
pub num_inputs: i32,
/// Second extent of both `input` and `output`.
pub embedding_dim: i32,
/// First extent of `output`.
pub num_segments: i32,
/// Element kind; must equal the plan's `T::KIND`.
pub element: ElementKind,
}
/// Tensor arguments for a single `SegmentSumPlan::run` call.
///
/// The shapes quoted below are the ones `validate_args` enforces against the
/// plan's descriptor.
pub struct SegmentSumArgs<'a, T: Element> {
/// Source values; shape must be `[num_inputs, embedding_dim]`.
pub input: TensorRef<'a, T, 2>,
/// Segment ids; shape must be `[num_inputs]`.
/// NOTE(review): the launcher is named `run_sorted_fw`, which suggests the
/// ids are expected in sorted order — confirm against the kernel contract.
pub segment_ids: TensorRef<'a, i32, 1>,
/// Destination; shape must be `[num_segments, embedding_dim]`.
pub output: TensorMut<'a, T, 2>,
}
/// A segment-sum plan for element type `T`, created by `select`.
pub struct SegmentSumPlan<T: Element> {
/// Copy of the descriptor that passed `validate_desc` in `select`.
desc: SegmentSumDescriptor,
/// SKU produced by `build_sku` for `SegmentKind::SegmentSum`.
sku: KernelSku,
/// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> SegmentSumPlan<T> {
pub fn select(
_stream: &Stream,
desc: &SegmentSumDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
validate_desc(*desc, T::KIND, "SegmentSumPlan")?;
let sku = build_sku::<T>(SegmentKind::SegmentSum);
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
pub fn can_implement(&self, args: &SegmentSumArgs<'_, T>) -> Result<()> {
validate_args(
&self.desc,
args.input.shape,
args.segment_ids.shape,
args.output.shape,
"SegmentSumPlan",
)
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: SegmentSumArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
if total_out == 0 {
return Ok(());
}
run_sorted_fw::<T>(
stream,
self.desc.num_inputs,
self.desc.embedding_dim,
self.desc.num_segments,
&args.input,
&args.segment_ids,
&args.output,
SortedFwOp::Sum,
)
}
}
pub(crate) fn validate_desc(
desc_num_inputs_dim_seg: impl SegDescView,
expected_element: ElementKind,
_plan_name: &'static str,
) -> Result<()> {
let (n, d, ns, el) = desc_num_inputs_dim_seg.view();
if el != expected_element {
return Err(Error::Unsupported(
"baracuda-kernels::segment: descriptor element != type parameter T",
));
}
if n < 0 || d < 0 || ns < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::segment: num_inputs / embedding_dim / num_segments must be non-negative",
));
}
if !matches!(el, ElementKind::F32 | ElementKind::F64) {
return Err(Error::Unsupported(
"baracuda-kernels::segment: today only f32, f64 wired (atomicAdd / atomic-CAS restricted to native-FP-atomic types)",
));
}
Ok(())
}
/// Read-only view of the descriptor fields that `validate_desc` inspects.
pub(crate) trait SegDescView {
/// Returns `(num_inputs, embedding_dim, num_segments, element)`, in that order.
fn view(&self) -> (i32, i32, i32, ElementKind);
}
impl SegDescView for SegmentSumDescriptor {
    /// Unpacks the descriptor into `(num_inputs, embedding_dim, num_segments, element)`.
    #[inline]
    fn view(&self) -> (i32, i32, i32, ElementKind) {
        let Self {
            num_inputs,
            embedding_dim,
            num_segments,
            element,
        } = *self;
        (num_inputs, embedding_dim, num_segments, element)
    }
}
pub(crate) fn validate_args(
desc: &SegmentSumDescriptor,
input_shape: [i32; 2],
seg_shape: [i32; 1],
output_shape: [i32; 2],
_plan_name: &'static str,
) -> Result<()> {
if input_shape != [desc.num_inputs, desc.embedding_dim] {
return Err(Error::InvalidProblem(
"baracuda-kernels::segment: input shape != [num_inputs, embedding_dim]",
));
}
if seg_shape != [desc.num_inputs] {
return Err(Error::InvalidProblem(
"baracuda-kernels::segment: segment_ids shape != [num_inputs]",
));
}
if output_shape != [desc.num_segments, desc.embedding_dim] {
return Err(Error::InvalidProblem(
"baracuda-kernels::segment: output shape != [num_segments, embedding_dim]",
));
}
Ok(())
}
/// Assembles the `KernelSku` for segment operation `op` on element type `T`.
///
/// - `math_precision` is `F64` when `T::KIND` is `F64`, otherwise `F32`.
/// - `accumulator` and `element` are both `T::KIND`; the auxiliary element is `I32`.
/// - `bit_stable_on_same_hardware` and `deterministic` are both `true` for the
///   ops listed below and both `false` for every other `SegmentKind`.
pub(crate) fn build_sku<T: Element>(op: SegmentKind) -> KernelSku {
    let element = T::KIND;

    // Both guarantees are granted to exactly the same set of ops, so the
    // membership test is evaluated once and reused for both flags.
    let reproducible = matches!(
        op,
        SegmentKind::SegmentSum
            | SegmentKind::SegmentMean
            | SegmentKind::SegmentMax
            | SegmentKind::SegmentMin
            | SegmentKind::SegmentProd
            | SegmentKind::SegmentSumBackward
            | SegmentKind::SegmentMeanBackward
            | SegmentKind::UnsortedSegmentSumBackward
            | SegmentKind::UnsortedSegmentMeanBackward
    );

    let math_precision = if element == ElementKind::F64 {
        MathPrecision::F64
    } else {
        MathPrecision::F32
    };

    KernelSku {
        category: OpCategory::SegmentOps,
        op: op as u16,
        element,
        aux_element: Some(ElementKind::I32),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee: PrecisionGuarantee {
            math_precision,
            accumulator: element,
            bit_stable_on_same_hardware: reproducible,
            deterministic: reproducible,
        },
    }
}
/// Reduction selected when `run_sorted_fw` dispatches to a kernel entry point.
#[derive(Copy, Clone, Debug)]
pub(crate) enum SortedFwOp {
/// Routes to `baracuda_kernels_segment_sum_{f32,f64}_run`.
Sum,
/// Routes to `baracuda_kernels_segment_mean_{f32,f64}_run`.
Mean,
/// Routes to `baracuda_kernels_segment_max_{f32,f64}_run`.
Max,
/// Routes to `baracuda_kernels_segment_min_{f32,f64}_run`.
Min,
/// Routes to `baracuda_kernels_segment_prod_{f32,f64}_run`.
Prod,
}
pub(crate) fn run_sorted_fw<T: Element>(
stream: &Stream,
n: i32,
d: i32,
num_segments: i32,
input: &TensorRef<'_, T, 2>,
segment_ids: &TensorRef<'_, i32, 1>,
output: &TensorMut<'_, T, 2>,
op: SortedFwOp,
) -> Result<()> {
let in_ptr = input.data.as_raw().0 as *const c_void;
let id_ptr = segment_ids.data.as_raw().0 as *const c_void;
let out_ptr = output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let status = match (T::KIND, op) {
(ElementKind::F32, SortedFwOp::Sum) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_sum_f32_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F64, SortedFwOp::Sum) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_sum_f64_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, SortedFwOp::Mean) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_mean_f32_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F64, SortedFwOp::Mean) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_mean_f64_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, SortedFwOp::Max) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_max_f32_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F64, SortedFwOp::Max) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_max_f64_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, SortedFwOp::Min) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_min_f32_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F64, SortedFwOp::Min) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_min_f64_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, SortedFwOp::Prod) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_prod_f32_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F64, SortedFwOp::Prod) => unsafe {
baracuda_kernels_sys::baracuda_kernels_segment_prod_f64_run(
n, d, num_segments, in_ptr, id_ptr, out_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::segment::run_sorted_fw reached an unimplemented dtype \
— select() should have caught this",
));
}
};
map_status(status)
}