baracuda_kernels/segment/
unsorted_segment_sum.rs1use core::ffi::c_void;
14use core::marker::PhantomData;
15
16use baracuda_cutlass::{Error, Result};
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
20 PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut, TensorRef, Workspace,
21};
22
23use super::map_status;
24use super::segment_sum::{validate_desc, SegDescView};
25
26#[derive(Copy, Clone, Debug)]
28pub struct UnsortedSegmentSumDescriptor {
29 pub num_inputs: i32,
31 pub embedding_dim: i32,
33 pub num_segments: i32,
36 pub element: ElementKind,
38}
39
40impl SegDescView for UnsortedSegmentSumDescriptor {
41 #[inline]
42 fn view(&self) -> (i32, i32, i32, ElementKind) {
43 (
44 self.num_inputs,
45 self.embedding_dim,
46 self.num_segments,
47 self.element,
48 )
49 }
50}
51
52pub struct UnsortedSegmentSumArgs<'a, T: Element> {
54 pub input: TensorRef<'a, T, 2>,
56 pub segment_ids: TensorRef<'a, i32, 1>,
58 pub output: TensorMut<'a, T, 2>,
61}
62
63pub struct UnsortedSegmentSumPlan<T: Element> {
87 desc: UnsortedSegmentSumDescriptor,
88 sku: KernelSku,
89 _marker: PhantomData<T>,
90}
91
92impl<T: Element> UnsortedSegmentSumPlan<T> {
93 pub fn select(
95 _stream: &Stream,
96 desc: &UnsortedSegmentSumDescriptor,
97 _pref: PlanPreference,
98 ) -> Result<Self> {
99 validate_desc(*desc, T::KIND, "UnsortedSegmentSumPlan")?;
100 Ok(Self {
101 desc: *desc,
102 sku: build_unsorted_sku::<T>(SegmentKind::UnsortedSegmentSum),
103 _marker: PhantomData,
104 })
105 }
106
107 pub fn can_implement(&self, args: &UnsortedSegmentSumArgs<'_, T>) -> Result<()> {
109 validate_unsorted_args(
110 self.desc.num_inputs,
111 self.desc.embedding_dim,
112 self.desc.num_segments,
113 args.input.shape,
114 args.segment_ids.shape,
115 args.output.shape,
116 "UnsortedSegmentSumPlan",
117 )
118 }
119
120 #[inline]
122 pub fn workspace_size(&self) -> usize {
123 0
124 }
125
126 #[inline]
128 pub fn sku(&self) -> KernelSku {
129 self.sku
130 }
131
132 #[inline]
134 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
135 self.sku.precision_guarantee
136 }
137
138 pub fn run(
140 &self,
141 stream: &Stream,
142 _workspace: Workspace<'_>,
143 args: UnsortedSegmentSumArgs<'_, T>,
144 ) -> Result<()> {
145 self.can_implement(&args)?;
146 let total = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
147 if total == 0 {
148 return Ok(());
149 }
150 let in_ptr = args.input.data.as_raw().0 as *const c_void;
151 let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
152 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
153 let stream_ptr = stream.as_raw() as *mut c_void;
154 let status = match T::KIND {
155 ElementKind::F32 => unsafe {
156 baracuda_kernels_sys::baracuda_kernels_unsorted_segment_sum_f32_run(
157 self.desc.num_inputs,
158 self.desc.embedding_dim,
159 self.desc.num_segments,
160 in_ptr,
161 id_ptr,
162 out_ptr,
163 core::ptr::null_mut(),
164 0,
165 stream_ptr,
166 )
167 },
168 ElementKind::F64 => unsafe {
169 baracuda_kernels_sys::baracuda_kernels_unsorted_segment_sum_f64_run(
170 self.desc.num_inputs,
171 self.desc.embedding_dim,
172 self.desc.num_segments,
173 in_ptr,
174 id_ptr,
175 out_ptr,
176 core::ptr::null_mut(),
177 0,
178 stream_ptr,
179 )
180 },
181 _ => {
182 return Err(Error::Unsupported(
183 "baracuda-kernels::UnsortedSegmentSumPlan::run reached an unimplemented dtype",
184 ));
185 }
186 };
187 map_status(status)
188 }
189}
190
191pub(crate) fn build_unsorted_sku<T: Element>(op: SegmentKind) -> KernelSku {
194 let precision_guarantee = PrecisionGuarantee {
195 math_precision: if T::KIND == ElementKind::F64 {
196 MathPrecision::F64
197 } else {
198 MathPrecision::F32
199 },
200 accumulator: T::KIND,
201 bit_stable_on_same_hardware: false,
202 deterministic: false,
203 };
204 KernelSku {
205 category: OpCategory::SegmentOps,
206 op: op as u16,
207 element: T::KIND,
208 aux_element: Some(ElementKind::I32),
209 layout: None,
210 epilogue: None,
211 arch: ArchSku::Sm80,
212 backend: BackendKind::Bespoke,
213 precision_guarantee,
214 }
215}
216
217pub(crate) fn validate_unsorted_args(
219 num_inputs: i32,
220 embedding_dim: i32,
221 num_segments: i32,
222 input_shape: [i32; 2],
223 seg_shape: [i32; 1],
224 output_shape: [i32; 2],
225 _plan_name: &'static str,
226) -> Result<()> {
227 if input_shape != [num_inputs, embedding_dim] {
228 return Err(Error::InvalidProblem(
229 "baracuda-kernels::segment: input shape != [num_inputs, embedding_dim]",
230 ));
231 }
232 if seg_shape != [num_inputs] {
233 return Err(Error::InvalidProblem(
234 "baracuda-kernels::segment: segment_ids shape != [num_inputs]",
235 ));
236 }
237 if output_shape != [num_segments, embedding_dim] {
238 return Err(Error::InvalidProblem(
239 "baracuda-kernels::segment: output shape != [num_segments, embedding_dim]",
240 ));
241 }
242 Ok(())
243}