baracuda_kernels/elementwise/
where_backward.rs1use core::ffi::c_void;
26use core::marker::PhantomData;
27
28use baracuda_cutlass::{Error, Result};
29use baracuda_driver::Stream;
30use baracuda_kernels_types::{
31 ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
32 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
33};
34
35#[derive(Copy, Clone, Debug)]
41pub struct WhereBackwardDescriptor<const N: usize> {
42 pub shape: [i32; N],
44 pub element: ElementKind,
46}
47
48pub struct WhereBackwardArgs<'a, T: Element, const N: usize> {
54 pub cond: TensorRef<'a, u8, N>,
57 pub dy: TensorRef<'a, T, N>,
59 pub da: TensorMut<'a, T, N>,
61 pub db: TensorMut<'a, T, N>,
63}
64
65pub struct WhereBackwardPlan<T: Element, const N: usize> {
70 desc: WhereBackwardDescriptor<N>,
71 sku: KernelSku,
72 _marker: PhantomData<T>,
73}
74
75impl<T: Element, const N: usize> WhereBackwardPlan<T, N> {
76 pub fn select(
79 _stream: &Stream,
80 desc: &WhereBackwardDescriptor<N>,
81 _pref: PlanPreference,
82 ) -> Result<Self> {
83 if desc.element != T::KIND {
84 return Err(Error::Unsupported(
85 "baracuda-kernels::WhereBackwardPlan: descriptor element != type parameter T",
86 ));
87 }
88 for &d in desc.shape.iter() {
89 if d < 0 {
90 return Err(Error::InvalidProblem(
91 "baracuda-kernels::WhereBackwardPlan: shape dims must be non-negative",
92 ));
93 }
94 }
95
96 let supported = matches!(
98 T::KIND,
99 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
100 );
101 if !supported {
102 return Err(Error::Unsupported(
103 "baracuda-kernels::WhereBackwardPlan: value dtype must be one of \
104 {F32, F16, Bf16, F64}",
105 ));
106 }
107
108 let (math_precision, accumulator) = match T::KIND {
113 ElementKind::F16 => (MathPrecision::F16, ElementKind::F16),
114 ElementKind::Bf16 => (MathPrecision::Bf16, ElementKind::Bf16),
115 ElementKind::F64 => (MathPrecision::F64, ElementKind::F64),
116 _ => (MathPrecision::F32, ElementKind::F32),
117 };
118 let precision_guarantee = PrecisionGuarantee {
119 math_precision,
120 accumulator,
121 bit_stable_on_same_hardware: true,
122 deterministic: true,
123 };
124 let sku = KernelSku {
125 category: OpCategory::TernaryElementwise,
126 op: 4,
132 element: T::KIND,
133 aux_element: None,
138 layout: None,
139 epilogue: None,
140 arch: ArchSku::Sm80,
141 backend: BackendKind::Bespoke,
142 precision_guarantee,
143 };
144 Ok(Self {
145 desc: *desc,
146 sku,
147 _marker: PhantomData,
148 })
149 }
150
151 pub fn can_implement(&self, args: &WhereBackwardArgs<'_, T, N>) -> Result<()> {
153 if args.dy.shape != self.desc.shape {
154 return Err(Error::InvalidProblem(
155 "baracuda-kernels::WhereBackwardPlan: dy shape mismatch with descriptor",
156 ));
157 }
158 if args.da.shape != self.desc.shape {
159 return Err(Error::InvalidProblem(
160 "baracuda-kernels::WhereBackwardPlan: da shape mismatch with descriptor",
161 ));
162 }
163 if args.db.shape != self.desc.shape {
164 return Err(Error::InvalidProblem(
165 "baracuda-kernels::WhereBackwardPlan: db shape mismatch with descriptor",
166 ));
167 }
168 if args.cond.shape != self.desc.shape {
169 return Err(Error::InvalidProblem(
170 "baracuda-kernels::WhereBackwardPlan: cond shape mismatch with descriptor \
171 (trailblazer requires full-shape cond; stride-0 broadcasting on cond \
172 lands later)",
173 ));
174 }
175
176 if !args.cond.is_contiguous()
178 || !args.dy.is_contiguous()
179 || !args.da.is_contiguous()
180 || !args.db.is_contiguous()
181 {
182 return Err(Error::Unsupported(
183 "baracuda-kernels::WhereBackwardPlan: trailblazer requires contiguous \
184 cond / dy / da / db; strided / broadcast fanout lands later",
185 ));
186 }
187
188 if N > 8 {
189 return Err(Error::Unsupported(
190 "baracuda-kernels::WhereBackwardPlan: tensor rank > 8 not supported",
191 ));
192 }
193
194 let numel = args.dy.numel();
195 let cond_len = args.cond.data.len() as i64;
196 let dy_len = args.dy.data.len() as i64;
197 let da_len = args.da.data.len() as i64;
198 let db_len = args.db.data.len() as i64;
199 if dy_len < numel || da_len < numel || db_len < numel || cond_len < numel {
200 return Err(Error::BufferTooSmall {
201 needed: numel as usize,
202 got: cond_len.min(dy_len).min(da_len).min(db_len) as usize,
203 });
204 }
205 Ok(())
206 }
207
208 #[inline]
210 pub fn workspace_size(&self) -> usize {
211 0
212 }
213
214 #[inline]
216 pub fn sku(&self) -> KernelSku {
217 self.sku
218 }
219
220 #[inline]
222 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
223 self.sku.precision_guarantee
224 }
225
226 pub fn run(
228 &self,
229 stream: &Stream,
230 _workspace: Workspace<'_>,
231 args: WhereBackwardArgs<'_, T, N>,
232 ) -> Result<()> {
233 self.can_implement(&args)?;
234 let numel = args.dy.numel();
235 if numel == 0 {
236 return Ok(());
237 }
238 let cond_ptr = args.cond.data.as_raw().0 as *const c_void;
239 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
240 let da_ptr = args.da.data.as_raw().0 as *mut c_void;
241 let db_ptr = args.db.data.as_raw().0 as *mut c_void;
242 let stream_ptr = stream.as_raw() as *mut c_void;
243
244 let status = match T::KIND {
245 ElementKind::F32 => unsafe {
246 baracuda_kernels_sys::baracuda_kernels_where_backward_f32_run(
247 numel,
248 cond_ptr,
249 dy_ptr,
250 da_ptr,
251 db_ptr,
252 core::ptr::null_mut(),
253 0,
254 stream_ptr,
255 )
256 },
257 ElementKind::F16 => unsafe {
258 baracuda_kernels_sys::baracuda_kernels_where_backward_f16_run(
259 numel,
260 cond_ptr,
261 dy_ptr,
262 da_ptr,
263 db_ptr,
264 core::ptr::null_mut(),
265 0,
266 stream_ptr,
267 )
268 },
269 ElementKind::Bf16 => unsafe {
270 baracuda_kernels_sys::baracuda_kernels_where_backward_bf16_run(
271 numel,
272 cond_ptr,
273 dy_ptr,
274 da_ptr,
275 db_ptr,
276 core::ptr::null_mut(),
277 0,
278 stream_ptr,
279 )
280 },
281 ElementKind::F64 => unsafe {
282 baracuda_kernels_sys::baracuda_kernels_where_backward_f64_run(
283 numel,
284 cond_ptr,
285 dy_ptr,
286 da_ptr,
287 db_ptr,
288 core::ptr::null_mut(),
289 0,
290 stream_ptr,
291 )
292 },
293 _ => {
294 return Err(Error::Unsupported(
295 "baracuda-kernels::WhereBackwardPlan::run reached an unimplemented \
296 dtype — select() should have caught this",
297 ));
298 }
299 };
300 map_status(status)
301 }
302}
303
304fn map_status(code: i32) -> Result<()> {
305 match code {
306 0 => Ok(()),
307 1 => Err(Error::MisalignedOperand),
308 2 => Err(Error::InvalidProblem(
309 "baracuda-kernels-sys reported invalid problem",
310 )),
311 3 => Err(Error::Unsupported(
312 "baracuda-kernels-sys reported unsupported configuration",
313 )),
314 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
315 n => Err(Error::CutlassInternal(n)),
316 }
317}