baracuda_kernels/segment/
unsorted_segment_prod.rs1use core::ffi::c_void;
17use core::marker::PhantomData;
18
19use baracuda_cutlass::{Error, Result};
20use baracuda_driver::Stream;
21use baracuda_kernels_types::{
22 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
23 TensorRef, Workspace,
24};
25
26use super::map_status;
27use super::segment_sum::{validate_desc, SegDescView};
28use super::unsorted_segment_sum::{build_unsorted_sku, validate_unsorted_args};
29
30#[derive(Copy, Clone, Debug)]
32pub struct UnsortedSegmentProdDescriptor {
33 pub num_inputs: i32,
35 pub embedding_dim: i32,
37 pub num_segments: i32,
39 pub element: ElementKind,
41}
42
43impl SegDescView for UnsortedSegmentProdDescriptor {
44 #[inline]
45 fn view(&self) -> (i32, i32, i32, ElementKind) {
46 (
47 self.num_inputs,
48 self.embedding_dim,
49 self.num_segments,
50 self.element,
51 )
52 }
53}
54
55pub struct UnsortedSegmentProdArgs<'a, T: Element> {
57 pub input: TensorRef<'a, T, 2>,
59 pub segment_ids: TensorRef<'a, i32, 1>,
61 pub output: TensorMut<'a, T, 2>,
64}
65
66pub struct UnsortedSegmentProdPlan<T: Element> {
84 desc: UnsortedSegmentProdDescriptor,
85 sku: KernelSku,
86 _marker: PhantomData<T>,
87}
88
89impl<T: Element> UnsortedSegmentProdPlan<T> {
90 pub fn select(
92 _stream: &Stream,
93 desc: &UnsortedSegmentProdDescriptor,
94 _pref: PlanPreference,
95 ) -> Result<Self> {
96 validate_desc(*desc, T::KIND, "UnsortedSegmentProdPlan")?;
97 Ok(Self {
98 desc: *desc,
99 sku: build_unsorted_sku::<T>(SegmentKind::UnsortedSegmentProd),
100 _marker: PhantomData,
101 })
102 }
103
104 pub fn can_implement(&self, args: &UnsortedSegmentProdArgs<'_, T>) -> Result<()> {
106 validate_unsorted_args(
107 self.desc.num_inputs,
108 self.desc.embedding_dim,
109 self.desc.num_segments,
110 args.input.shape,
111 args.segment_ids.shape,
112 args.output.shape,
113 "UnsortedSegmentProdPlan",
114 )
115 }
116
117 #[inline]
119 pub fn workspace_size(&self) -> usize {
120 0
121 }
122
123 #[inline]
125 pub fn sku(&self) -> KernelSku {
126 self.sku
127 }
128
129 #[inline]
131 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
132 self.sku.precision_guarantee
133 }
134
135 pub fn run(
137 &self,
138 stream: &Stream,
139 _workspace: Workspace<'_>,
140 args: UnsortedSegmentProdArgs<'_, T>,
141 ) -> Result<()> {
142 self.can_implement(&args)?;
143 let total = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
144 if total == 0 {
145 return Ok(());
146 }
147 let in_ptr = args.input.data.as_raw().0 as *const c_void;
148 let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
149 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
150 let stream_ptr = stream.as_raw() as *mut c_void;
151 let status = match T::KIND {
152 ElementKind::F32 => unsafe {
153 baracuda_kernels_sys::baracuda_kernels_unsorted_segment_prod_f32_run(
154 self.desc.num_inputs,
155 self.desc.embedding_dim,
156 self.desc.num_segments,
157 in_ptr,
158 id_ptr,
159 out_ptr,
160 core::ptr::null_mut(),
161 0,
162 stream_ptr,
163 )
164 },
165 ElementKind::F64 => unsafe {
166 baracuda_kernels_sys::baracuda_kernels_unsorted_segment_prod_f64_run(
167 self.desc.num_inputs,
168 self.desc.embedding_dim,
169 self.desc.num_segments,
170 in_ptr,
171 id_ptr,
172 out_ptr,
173 core::ptr::null_mut(),
174 0,
175 stream_ptr,
176 )
177 },
178 _ => {
179 return Err(Error::Unsupported(
180 "baracuda-kernels::UnsortedSegmentProdPlan::run reached an unimplemented dtype",
181 ));
182 }
183 };
184 map_status(status)
185 }
186}