Skip to main content

baracuda_kernels/elementwise/
binary_backward.rs

1//! Backward plan for the binary elementwise family.
2//!
3//! Sibling of [`crate::BinaryPlan`] for gradient computation:
4//! `(da, db) = backward(dy, [saved tensors per op])`.
5//!
6//! Today wired: `{Add, Sub, Mul, Div, Maximum, Minimum, Pow, Atan2, Hypot} × {f32, f16, bf16, f64}`.
7//! Add and Sub need no saved tensors; every other wired op requires
8//! the saved forward inputs `a` and `b` (formulas for Pow, Atan2, Hypot are not listed here):
9//! - Add: `(da, db) = (dy, dy)` — no saved
10//! - Sub: `(da, db) = (dy, -dy)` — no saved
11//! - Mul: `(da, db) = (dy * b, dy * a)` — needs saved `a`, `b`
12//! - Div: `(da, db) = (dy / b, -dy * a / b²)` — needs saved `a`, `b`
13//! - Maximum / Minimum: saves used purely as comparison references; tie
14//!   splits `dy` evenly (PyTorch parity). For Maximum:
15//!   `da = where(a==b, dy/2, where(a<b, 0, dy))`,
16//!   `db = where(a==b, dy/2, where(b<a, 0, dy))`. Minimum flips `<`/`>`.
17//!   NaN inputs propagate `dy` to both (all comparisons false).
18//!
19//! The `Args` struct carries `a` and `b` as `Option<TensorRef>` so
20//! callers omit them for ops that don't need them. The dispatcher
21//! validates that needed saves are present.
22//!
23//! Trailblazer constraints (same shape limits as the forward
24//! trailblazer): contig-only (no broadcasting); `dy.shape ==
25//! da.shape == db.shape`. Ops with saves additionally require
26//! `a.shape == b.shape == dy.shape`.
27
28use core::ffi::c_void;
29use core::marker::PhantomData;
30
31use baracuda_cutlass::{Error, Result};
32use baracuda_driver::Stream;
33use baracuda_kernels_types::{
34    ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
35    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
36};
37
38/// Descriptor for a binary backward op.
39#[derive(Copy, Clone, Debug)]
40pub struct BinaryBackwardDescriptor<const N: usize> {
41    /// Which forward binary op this is the backward of.
42    pub kind: BinaryKind,
43    /// Tensor shape (shared by dy / a / b / da / db).
44    pub shape: [i32; N],
45    /// Element type.
46    pub element: ElementKind,
47}
48
49/// Args bundle for a binary backward launch.
50///
51/// `a` and `b` are SAVED forward inputs — required by `Mul`, `Div`,
52/// `Maximum`, `Minimum` (gradient formula references them) but unused
53/// by `Add` / `Sub`. The dispatcher checks that ops needing saves have
54/// them supplied.
55pub struct BinaryBackwardArgs<'a, T: Element, const N: usize> {
56    /// Upstream gradient (input to backward).
57    pub dy: TensorRef<'a, T, N>,
58    /// Saved forward input `a`. Required by `Mul` / `Div`; ignored otherwise.
59    pub a: Option<TensorRef<'a, T, N>>,
60    /// Saved forward input `b`. Required by `Mul` / `Div`; ignored otherwise.
61    pub b: Option<TensorRef<'a, T, N>>,
62    /// Gradient w.r.t. `a`.
63    pub da: TensorMut<'a, T, N>,
64    /// Gradient w.r.t. `b`.
65    pub db: TensorMut<'a, T, N>,
66}
67
68/// Binary backward plan.
69pub struct BinaryBackwardPlan<T: Element, const N: usize> {
70    desc: BinaryBackwardDescriptor<N>,
71    sku: KernelSku,
72    _marker: PhantomData<T>,
73}
74
75#[inline]
76fn op_needs_saves(kind: BinaryKind) -> bool {
77    matches!(
78        kind,
79        BinaryKind::Mul
80            | BinaryKind::Div
81            | BinaryKind::Pow
82            | BinaryKind::Maximum
83            | BinaryKind::Minimum
84            | BinaryKind::Atan2
85            | BinaryKind::Hypot
86    )
87}
88
89impl<T: Element, const N: usize> BinaryBackwardPlan<T, N> {
90    /// Pick a kernel.
91    pub fn select(
92        _stream: &Stream,
93        desc: &BinaryBackwardDescriptor<N>,
94        _pref: PlanPreference,
95    ) -> Result<Self> {
96        if desc.element != T::KIND {
97            return Err(Error::Unsupported(
98                "baracuda-kernels::BinaryBackwardPlan: descriptor element != T",
99            ));
100        }
101        for &d in desc.shape.iter() {
102            if d < 0 {
103                return Err(Error::InvalidProblem(
104                    "baracuda-kernels::BinaryBackwardPlan: shape dims must be non-negative",
105                ));
106            }
107        }
108        // Phase 3 backward family: {Add, Sub, Mul, Div, Maximum, Minimum} ×
109        // {f32, f16, bf16, f64}.
110        let supported = matches!(
111            (desc.kind, T::KIND),
112            (BinaryKind::Add, ElementKind::F32)
113                | (BinaryKind::Add, ElementKind::F16)
114                | (BinaryKind::Add, ElementKind::Bf16)
115                | (BinaryKind::Add, ElementKind::F64)
116                | (BinaryKind::Sub, ElementKind::F32)
117                | (BinaryKind::Sub, ElementKind::F16)
118                | (BinaryKind::Sub, ElementKind::Bf16)
119                | (BinaryKind::Sub, ElementKind::F64)
120                | (BinaryKind::Mul, ElementKind::F32)
121                | (BinaryKind::Mul, ElementKind::F16)
122                | (BinaryKind::Mul, ElementKind::Bf16)
123                | (BinaryKind::Mul, ElementKind::F64)
124                | (BinaryKind::Div, ElementKind::F32)
125                | (BinaryKind::Div, ElementKind::F16)
126                | (BinaryKind::Div, ElementKind::Bf16)
127                | (BinaryKind::Div, ElementKind::F64)
128                | (BinaryKind::Maximum, ElementKind::F32)
129                | (BinaryKind::Maximum, ElementKind::F16)
130                | (BinaryKind::Maximum, ElementKind::Bf16)
131                | (BinaryKind::Maximum, ElementKind::F64)
132                | (BinaryKind::Minimum, ElementKind::F32)
133                | (BinaryKind::Minimum, ElementKind::F16)
134                | (BinaryKind::Minimum, ElementKind::Bf16)
135                | (BinaryKind::Minimum, ElementKind::F64)
136                | (BinaryKind::Pow, ElementKind::F32)
137                | (BinaryKind::Pow, ElementKind::F16)
138                | (BinaryKind::Pow, ElementKind::Bf16)
139                | (BinaryKind::Pow, ElementKind::F64)
140                | (BinaryKind::Atan2, ElementKind::F32)
141                | (BinaryKind::Atan2, ElementKind::F16)
142                | (BinaryKind::Atan2, ElementKind::Bf16)
143                | (BinaryKind::Atan2, ElementKind::F64)
144                | (BinaryKind::Hypot, ElementKind::F32)
145                | (BinaryKind::Hypot, ElementKind::F16)
146                | (BinaryKind::Hypot, ElementKind::Bf16)
147                | (BinaryKind::Hypot, ElementKind::F64)
148        );
149        if !supported {
150            return Err(Error::Unsupported(
151                "baracuda-kernels::BinaryBackwardPlan: only \
152                 `{Add,Sub,Mul,Div,Maximum,Minimum,Pow,Atan2,Hypot}` × \
153                 `{f32, f16, bf16, f64}` are wired today; other (kind, dtype) \
154                 pairs (e.g. integer family, Lerp) land in later fanout. Lerp \
155                 is reserved-but-deferred pending a parameterized-binary plan \
156                 shape.",
157            ));
158        }
159
160        let precision_guarantee = PrecisionGuarantee {
161            math_precision: MathPrecision::F32,
162            accumulator: ElementKind::F32,
163            bit_stable_on_same_hardware: true,
164            deterministic: true,
165        };
166        let sku = KernelSku {
167            category: OpCategory::BinaryElementwise,
168            // Use the forward op discriminant. Backward is implied by
169            // the plan type itself (BinaryBackwardPlan vs BinaryPlan).
170            op: desc.kind as u16,
171            element: T::KIND,
172            aux_element: None,
173            layout: None,
174            epilogue: None,
175            arch: ArchSku::Sm80,
176            backend: BackendKind::Bespoke,
177            precision_guarantee,
178        };
179        Ok(Self {
180            desc: *desc,
181            sku,
182            _marker: PhantomData,
183        })
184    }
185
186    /// Validate args.
187    pub fn can_implement(&self, args: &BinaryBackwardArgs<'_, T, N>) -> Result<()> {
188        if args.dy.shape != self.desc.shape {
189            return Err(Error::InvalidProblem(
190                "baracuda-kernels::BinaryBackwardPlan: dy shape mismatch",
191            ));
192        }
193        if args.da.shape != self.desc.shape {
194            return Err(Error::InvalidProblem(
195                "baracuda-kernels::BinaryBackwardPlan: da shape mismatch",
196            ));
197        }
198        if args.db.shape != self.desc.shape {
199            return Err(Error::InvalidProblem(
200                "baracuda-kernels::BinaryBackwardPlan: db shape mismatch",
201            ));
202        }
203        // Contig-only for trailblazer.
204        if !args.dy.is_contiguous() || !args.da.is_contiguous() || !args.db.is_contiguous() {
205            return Err(Error::Unsupported(
206                "baracuda-kernels::BinaryBackwardPlan: trailblazer requires contiguous \
207                 dy / da / db; strided fanout lands later",
208            ));
209        }
210        // Per-op saved-tensor requirements.
211        if op_needs_saves(self.desc.kind) {
212            let a = args.a.as_ref().ok_or(Error::InvalidProblem(
213                "baracuda-kernels::BinaryBackwardPlan: this op requires saved input `a`",
214            ))?;
215            let b = args.b.as_ref().ok_or(Error::InvalidProblem(
216                "baracuda-kernels::BinaryBackwardPlan: this op requires saved input `b`",
217            ))?;
218            if a.shape != self.desc.shape {
219                return Err(Error::InvalidProblem(
220                    "baracuda-kernels::BinaryBackwardPlan: saved a shape mismatch",
221                ));
222            }
223            if b.shape != self.desc.shape {
224                return Err(Error::InvalidProblem(
225                    "baracuda-kernels::BinaryBackwardPlan: saved b shape mismatch",
226                ));
227            }
228            if !a.is_contiguous() || !b.is_contiguous() {
229                return Err(Error::Unsupported(
230                    "baracuda-kernels::BinaryBackwardPlan: saved a/b must be contiguous \
231                     (strided fanout lands later)",
232                ));
233            }
234            let numel = args.dy.numel() as usize;
235            if a.data.len() < numel {
236                return Err(Error::BufferTooSmall {
237                    needed: numel,
238                    got: a.data.len(),
239                });
240            }
241            if b.data.len() < numel {
242                return Err(Error::BufferTooSmall {
243                    needed: numel,
244                    got: b.data.len(),
245                });
246            }
247        }
248        let numel = args.dy.numel();
249        let dy_len = args.dy.data.len() as i64;
250        let da_len = args.da.data.len() as i64;
251        let db_len = args.db.data.len() as i64;
252        if dy_len < numel || da_len < numel || db_len < numel {
253            return Err(Error::BufferTooSmall {
254                needed: numel as usize,
255                got: dy_len.min(da_len).min(db_len) as usize,
256            });
257        }
258        Ok(())
259    }
260
261    /// Workspace size in bytes.
262    #[inline]
263    pub fn workspace_size(&self) -> usize {
264        0
265    }
266    /// Kernel SKU identity.
267    #[inline]
268    pub fn sku(&self) -> KernelSku {
269        self.sku
270    }
271    /// Numerical guarantees.
272    #[inline]
273    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
274        self.sku.precision_guarantee
275    }
276
277    /// Launch.
278    pub fn run(
279        &self,
280        stream: &Stream,
281        _workspace: Workspace<'_>,
282        args: BinaryBackwardArgs<'_, T, N>,
283    ) -> Result<()> {
284        self.can_implement(&args)?;
285        let numel = args.dy.numel();
286        if numel == 0 {
287            return Ok(());
288        }
289        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
290        let da_ptr = args.da.data.as_raw().0 as *mut c_void;
291        let db_ptr = args.db.data.as_raw().0 as *mut c_void;
292        let stream_ptr = stream.as_raw() as *mut c_void;
293
294        let status = match (self.desc.kind, T::KIND) {
295            // -------- Add (no saves) --------
296            (BinaryKind::Add, ElementKind::F32) => unsafe {
297                baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f32_run(
298                    numel, dy_ptr, da_ptr, db_ptr,
299                    core::ptr::null_mut(), 0, stream_ptr,
300                )
301            },
302            (BinaryKind::Add, ElementKind::F16) => unsafe {
303                baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f16_run(
304                    numel, dy_ptr, da_ptr, db_ptr,
305                    core::ptr::null_mut(), 0, stream_ptr,
306                )
307            },
308            (BinaryKind::Add, ElementKind::Bf16) => unsafe {
309                baracuda_kernels_sys::baracuda_kernels_binary_add_backward_bf16_run(
310                    numel, dy_ptr, da_ptr, db_ptr,
311                    core::ptr::null_mut(), 0, stream_ptr,
312                )
313            },
314            (BinaryKind::Add, ElementKind::F64) => unsafe {
315                baracuda_kernels_sys::baracuda_kernels_binary_add_backward_f64_run(
316                    numel, dy_ptr, da_ptr, db_ptr,
317                    core::ptr::null_mut(), 0, stream_ptr,
318                )
319            },
320            // -------- Sub (no saves) --------
321            (BinaryKind::Sub, ElementKind::F32) => unsafe {
322                baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f32_run(
323                    numel, dy_ptr, da_ptr, db_ptr,
324                    core::ptr::null_mut(), 0, stream_ptr,
325                )
326            },
327            (BinaryKind::Sub, ElementKind::F16) => unsafe {
328                baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f16_run(
329                    numel, dy_ptr, da_ptr, db_ptr,
330                    core::ptr::null_mut(), 0, stream_ptr,
331                )
332            },
333            (BinaryKind::Sub, ElementKind::Bf16) => unsafe {
334                baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_bf16_run(
335                    numel, dy_ptr, da_ptr, db_ptr,
336                    core::ptr::null_mut(), 0, stream_ptr,
337                )
338            },
339            (BinaryKind::Sub, ElementKind::F64) => unsafe {
340                baracuda_kernels_sys::baracuda_kernels_binary_sub_backward_f64_run(
341                    numel, dy_ptr, da_ptr, db_ptr,
342                    core::ptr::null_mut(), 0, stream_ptr,
343                )
344            },
345            // -------- Mul (saves) --------
346            (BinaryKind::Mul, ElementKind::F32) => {
347                let (a_ptr, b_ptr) = saved_ptrs(&args);
348                unsafe {
349                    baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f32_run(
350                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
351                        core::ptr::null_mut(), 0, stream_ptr,
352                    )
353                }
354            }
355            (BinaryKind::Mul, ElementKind::F16) => {
356                let (a_ptr, b_ptr) = saved_ptrs(&args);
357                unsafe {
358                    baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f16_run(
359                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
360                        core::ptr::null_mut(), 0, stream_ptr,
361                    )
362                }
363            }
364            (BinaryKind::Mul, ElementKind::Bf16) => {
365                let (a_ptr, b_ptr) = saved_ptrs(&args);
366                unsafe {
367                    baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_bf16_run(
368                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
369                        core::ptr::null_mut(), 0, stream_ptr,
370                    )
371                }
372            }
373            (BinaryKind::Mul, ElementKind::F64) => {
374                let (a_ptr, b_ptr) = saved_ptrs(&args);
375                unsafe {
376                    baracuda_kernels_sys::baracuda_kernels_binary_mul_backward_f64_run(
377                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
378                        core::ptr::null_mut(), 0, stream_ptr,
379                    )
380                }
381            }
382            // -------- Div (saves) --------
383            (BinaryKind::Div, ElementKind::F32) => {
384                let (a_ptr, b_ptr) = saved_ptrs(&args);
385                unsafe {
386                    baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f32_run(
387                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
388                        core::ptr::null_mut(), 0, stream_ptr,
389                    )
390                }
391            }
392            (BinaryKind::Div, ElementKind::F16) => {
393                let (a_ptr, b_ptr) = saved_ptrs(&args);
394                unsafe {
395                    baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f16_run(
396                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
397                        core::ptr::null_mut(), 0, stream_ptr,
398                    )
399                }
400            }
401            (BinaryKind::Div, ElementKind::Bf16) => {
402                let (a_ptr, b_ptr) = saved_ptrs(&args);
403                unsafe {
404                    baracuda_kernels_sys::baracuda_kernels_binary_div_backward_bf16_run(
405                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
406                        core::ptr::null_mut(), 0, stream_ptr,
407                    )
408                }
409            }
410            (BinaryKind::Div, ElementKind::F64) => {
411                let (a_ptr, b_ptr) = saved_ptrs(&args);
412                unsafe {
413                    baracuda_kernels_sys::baracuda_kernels_binary_div_backward_f64_run(
414                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
415                        core::ptr::null_mut(), 0, stream_ptr,
416                    )
417                }
418            }
419            // -------- Maximum (saves used as comparison references) --------
420            (BinaryKind::Maximum, ElementKind::F32) => {
421                let (a_ptr, b_ptr) = saved_ptrs(&args);
422                unsafe {
423                    baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f32_run(
424                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
425                        core::ptr::null_mut(), 0, stream_ptr,
426                    )
427                }
428            }
429            (BinaryKind::Maximum, ElementKind::F16) => {
430                let (a_ptr, b_ptr) = saved_ptrs(&args);
431                unsafe {
432                    baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f16_run(
433                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
434                        core::ptr::null_mut(), 0, stream_ptr,
435                    )
436                }
437            }
438            (BinaryKind::Maximum, ElementKind::Bf16) => {
439                let (a_ptr, b_ptr) = saved_ptrs(&args);
440                unsafe {
441                    baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_bf16_run(
442                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
443                        core::ptr::null_mut(), 0, stream_ptr,
444                    )
445                }
446            }
447            (BinaryKind::Maximum, ElementKind::F64) => {
448                let (a_ptr, b_ptr) = saved_ptrs(&args);
449                unsafe {
450                    baracuda_kernels_sys::baracuda_kernels_binary_maximum_backward_f64_run(
451                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
452                        core::ptr::null_mut(), 0, stream_ptr,
453                    )
454                }
455            }
456            // -------- Minimum (saves used as comparison references) --------
457            (BinaryKind::Minimum, ElementKind::F32) => {
458                let (a_ptr, b_ptr) = saved_ptrs(&args);
459                unsafe {
460                    baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f32_run(
461                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
462                        core::ptr::null_mut(), 0, stream_ptr,
463                    )
464                }
465            }
466            (BinaryKind::Minimum, ElementKind::F16) => {
467                let (a_ptr, b_ptr) = saved_ptrs(&args);
468                unsafe {
469                    baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f16_run(
470                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
471                        core::ptr::null_mut(), 0, stream_ptr,
472                    )
473                }
474            }
475            (BinaryKind::Minimum, ElementKind::Bf16) => {
476                let (a_ptr, b_ptr) = saved_ptrs(&args);
477                unsafe {
478                    baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_bf16_run(
479                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
480                        core::ptr::null_mut(), 0, stream_ptr,
481                    )
482                }
483            }
484            (BinaryKind::Minimum, ElementKind::F64) => {
485                let (a_ptr, b_ptr) = saved_ptrs(&args);
486                unsafe {
487                    baracuda_kernels_sys::baracuda_kernels_binary_minimum_backward_f64_run(
488                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
489                        core::ptr::null_mut(), 0, stream_ptr,
490                    )
491                }
492            }
493            // -------- Pow (saves) --------
494            (BinaryKind::Pow, ElementKind::F32) => {
495                let (a_ptr, b_ptr) = saved_ptrs(&args);
496                unsafe {
497                    baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f32_run(
498                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
499                        core::ptr::null_mut(), 0, stream_ptr,
500                    )
501                }
502            }
503            (BinaryKind::Pow, ElementKind::F16) => {
504                let (a_ptr, b_ptr) = saved_ptrs(&args);
505                unsafe {
506                    baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f16_run(
507                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
508                        core::ptr::null_mut(), 0, stream_ptr,
509                    )
510                }
511            }
512            (BinaryKind::Pow, ElementKind::Bf16) => {
513                let (a_ptr, b_ptr) = saved_ptrs(&args);
514                unsafe {
515                    baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_bf16_run(
516                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
517                        core::ptr::null_mut(), 0, stream_ptr,
518                    )
519                }
520            }
521            (BinaryKind::Pow, ElementKind::F64) => {
522                let (a_ptr, b_ptr) = saved_ptrs(&args);
523                unsafe {
524                    baracuda_kernels_sys::baracuda_kernels_binary_pow_backward_f64_run(
525                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
526                        core::ptr::null_mut(), 0, stream_ptr,
527                    )
528                }
529            }
530            // -------- Atan2 (saves) --------
531            (BinaryKind::Atan2, ElementKind::F32) => {
532                let (a_ptr, b_ptr) = saved_ptrs(&args);
533                unsafe {
534                    baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f32_run(
535                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
536                        core::ptr::null_mut(), 0, stream_ptr,
537                    )
538                }
539            }
540            (BinaryKind::Atan2, ElementKind::F16) => {
541                let (a_ptr, b_ptr) = saved_ptrs(&args);
542                unsafe {
543                    baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f16_run(
544                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
545                        core::ptr::null_mut(), 0, stream_ptr,
546                    )
547                }
548            }
549            (BinaryKind::Atan2, ElementKind::Bf16) => {
550                let (a_ptr, b_ptr) = saved_ptrs(&args);
551                unsafe {
552                    baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_bf16_run(
553                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
554                        core::ptr::null_mut(), 0, stream_ptr,
555                    )
556                }
557            }
558            (BinaryKind::Atan2, ElementKind::F64) => {
559                let (a_ptr, b_ptr) = saved_ptrs(&args);
560                unsafe {
561                    baracuda_kernels_sys::baracuda_kernels_binary_atan2_backward_f64_run(
562                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
563                        core::ptr::null_mut(), 0, stream_ptr,
564                    )
565                }
566            }
567            // -------- Hypot (saves; y reconstructed inside kernel) --------
568            (BinaryKind::Hypot, ElementKind::F32) => {
569                let (a_ptr, b_ptr) = saved_ptrs(&args);
570                unsafe {
571                    baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f32_run(
572                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
573                        core::ptr::null_mut(), 0, stream_ptr,
574                    )
575                }
576            }
577            (BinaryKind::Hypot, ElementKind::F16) => {
578                let (a_ptr, b_ptr) = saved_ptrs(&args);
579                unsafe {
580                    baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f16_run(
581                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
582                        core::ptr::null_mut(), 0, stream_ptr,
583                    )
584                }
585            }
586            (BinaryKind::Hypot, ElementKind::Bf16) => {
587                let (a_ptr, b_ptr) = saved_ptrs(&args);
588                unsafe {
589                    baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_bf16_run(
590                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
591                        core::ptr::null_mut(), 0, stream_ptr,
592                    )
593                }
594            }
595            (BinaryKind::Hypot, ElementKind::F64) => {
596                let (a_ptr, b_ptr) = saved_ptrs(&args);
597                unsafe {
598                    baracuda_kernels_sys::baracuda_kernels_binary_hypot_backward_f64_run(
599                        numel, dy_ptr, a_ptr, b_ptr, da_ptr, db_ptr,
600                        core::ptr::null_mut(), 0, stream_ptr,
601                    )
602                }
603            }
604            _ => {
605                return Err(Error::Unsupported(
606                    "baracuda-kernels::BinaryBackwardPlan::run reached an unimplemented \
607                     (kind, dtype) pair — select() should have caught this",
608                ));
609            }
610        };
611        map_status(status)
612    }
613}
614
615#[inline]
616fn saved_ptrs<T: Element, const N: usize>(
617    args: &BinaryBackwardArgs<'_, T, N>,
618) -> (*const c_void, *const c_void) {
619    // can_implement guarantees Some for ops that reach this path.
620    let a = args
621        .a
622        .as_ref()
623        .expect("Mul/Div/Pow/Maximum/Minimum/Atan2/Hypot backward require saved a");
624    let b = args
625        .b
626        .as_ref()
627        .expect("Mul/Div/Pow/Maximum/Minimum/Atan2/Hypot backward require saved b");
628    (
629        a.data.as_raw().0 as *const c_void,
630        b.data.as_raw().0 as *const c_void,
631    )
632}
633
634fn map_status(code: i32) -> Result<()> {
635    match code {
636        0 => Ok(()),
637        1 => Err(Error::MisalignedOperand),
638        2 => Err(Error::InvalidProblem(
639            "baracuda-kernels-sys reported invalid problem",
640        )),
641        3 => Err(Error::Unsupported(
642            "baracuda-kernels-sys reported unsupported configuration",
643        )),
644        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
645        n => Err(Error::CutlassInternal(n)),
646    }
647}