baracuda_kernels/segment/
segment_max_backward.rs1use core::ffi::c_void;
26use core::marker::PhantomData;
27
28use baracuda_cutlass::{Error, Result};
29use baracuda_driver::Stream;
30use baracuda_kernels_types::{
31 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
32 TensorRef, Workspace,
33};
34
35use super::map_status;
36use super::segment_sum::{build_sku, validate_desc, SegDescView};
37
38#[derive(Copy, Clone, Debug)]
40pub struct SegmentMaxBackwardDescriptor {
41 pub num_inputs: i32,
43 pub embedding_dim: i32,
45 pub num_segments: i32,
47 pub element: ElementKind,
49}
50
51impl SegDescView for SegmentMaxBackwardDescriptor {
52 #[inline]
53 fn view(&self) -> (i32, i32, i32, ElementKind) {
54 (
55 self.num_inputs,
56 self.embedding_dim,
57 self.num_segments,
58 self.element,
59 )
60 }
61}
62
63pub struct SegmentMaxBackwardArgs<'a, T: Element> {
65 pub d_output: TensorRef<'a, T, 2>,
67 pub input: TensorRef<'a, T, 2>,
69 pub segment_ids: TensorRef<'a, i32, 1>,
71 pub d_input: TensorMut<'a, T, 2>,
73}
74
75pub struct SegmentMaxBackwardPlan<T: Element> {
98 desc: SegmentMaxBackwardDescriptor,
99 sku: KernelSku,
100 _marker: PhantomData<T>,
101}
102
103impl<T: Element> SegmentMaxBackwardPlan<T> {
104 pub fn select(
106 _stream: &Stream,
107 desc: &SegmentMaxBackwardDescriptor,
108 _pref: PlanPreference,
109 ) -> Result<Self> {
110 validate_desc(*desc, T::KIND, "SegmentMaxBackwardPlan")?;
111 Ok(Self {
112 desc: *desc,
113 sku: build_sku::<T>(SegmentKind::SegmentMaxBackward),
114 _marker: PhantomData,
115 })
116 }
117
118 pub fn can_implement(&self, args: &SegmentMaxBackwardArgs<'_, T>) -> Result<()> {
120 if args.d_output.shape != [self.desc.num_segments, self.desc.embedding_dim] {
121 return Err(Error::InvalidProblem(
122 "baracuda-kernels::SegmentMaxBackwardPlan: d_output shape != [num_segments, D]",
123 ));
124 }
125 if args.input.shape != [self.desc.num_inputs, self.desc.embedding_dim] {
126 return Err(Error::InvalidProblem(
127 "baracuda-kernels::SegmentMaxBackwardPlan: input shape != [num_inputs, D]",
128 ));
129 }
130 if args.segment_ids.shape != [self.desc.num_inputs] {
131 return Err(Error::InvalidProblem(
132 "baracuda-kernels::SegmentMaxBackwardPlan: segment_ids shape != [num_inputs]",
133 ));
134 }
135 if args.d_input.shape != [self.desc.num_inputs, self.desc.embedding_dim] {
136 return Err(Error::InvalidProblem(
137 "baracuda-kernels::SegmentMaxBackwardPlan: d_input shape != [num_inputs, D]",
138 ));
139 }
140 Ok(())
141 }
142
143 #[inline]
145 pub fn workspace_size(&self) -> usize {
146 0
147 }
148
149 #[inline]
151 pub fn sku(&self) -> KernelSku {
152 self.sku
153 }
154
155 #[inline]
157 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
158 self.sku.precision_guarantee
159 }
160
161 pub fn run(
163 &self,
164 stream: &Stream,
165 _workspace: Workspace<'_>,
166 args: SegmentMaxBackwardArgs<'_, T>,
167 ) -> Result<()> {
168 self.can_implement(&args)?;
169 let total = (self.desc.num_inputs as i64) * (self.desc.embedding_dim as i64);
170 if total == 0 {
171 return Ok(());
172 }
173 let do_ptr = args.d_output.data.as_raw().0 as *const c_void;
174 let in_ptr = args.input.data.as_raw().0 as *const c_void;
175 let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
176 let di_ptr = args.d_input.data.as_raw().0 as *mut c_void;
177 let stream_ptr = stream.as_raw() as *mut c_void;
178 let status = match T::KIND {
179 ElementKind::F32 => unsafe {
180 baracuda_kernels_sys::baracuda_kernels_segment_max_backward_f32_run(
181 self.desc.num_inputs,
182 self.desc.embedding_dim,
183 self.desc.num_segments,
184 do_ptr,
185 in_ptr,
186 id_ptr,
187 di_ptr,
188 core::ptr::null_mut(),
189 0,
190 stream_ptr,
191 )
192 },
193 ElementKind::F64 => unsafe {
194 baracuda_kernels_sys::baracuda_kernels_segment_max_backward_f64_run(
195 self.desc.num_inputs,
196 self.desc.embedding_dim,
197 self.desc.num_segments,
198 do_ptr,
199 in_ptr,
200 id_ptr,
201 di_ptr,
202 core::ptr::null_mut(),
203 0,
204 stream_ptr,
205 )
206 },
207 _ => {
208 return Err(Error::Unsupported(
209 "baracuda-kernels::SegmentMaxBackwardPlan::run reached an unimplemented dtype",
210 ));
211 }
212 };
213 map_status(status)
214 }
215}