Skip to main content

baracuda_kernels/elementwise/
where_backward.rs

//! Heterogeneous-dtype ternary BW plan: `where_backward(cond, dy)`.
//!
//! Sibling of [`crate::WherePlan`] for gradient computation. Forward:
//! `y = where(cond, a, b)` with `cond: u8` and same-dtype `a`/`b`/`y`.
//! Backward (cond is non-differentiable — no `dcond`):
//!
//! - `da = where(cond, dy, 0)` — gradient flows to `a` only where cond is true.
//! - `db = where(cond, 0, dy)` — gradient flows to `b` only where cond is false.
//!
//! The per-cell formula is a pure mask + copy — no arithmetic — so the output
//! is bit-exact against a host reference at every dtype.
//!
//! All 4 FP value dtypes are wired: {f32, f16, bf16, f64}. Trailblazer
//! constraints: **contig-only** — `cond`, `dy`, `da` and `db` must all be
//! contiguous and share one full shape (no broadcast or striding on any
//! operand, `cond` included). `cond` follows the same heterogeneous-dtype
//! convention as the FW (`u8`, 0 = false). Broadcast support on BW lands
//! later if a use case materializes — the autograd reduction step usually
//! flattens broadcasted gradients upstream of this kernel anyway, so the
//! contig-only trailblazer matches typical caller pipelines.
//!
//! Module name: `where_backward` (the FW lives in `where_op` to dodge
//! Rust's `where` keyword; the BW name is safe because the `_backward`
//! suffix makes it a distinct identifier rather than the keyword).
24
25use core::ffi::c_void;
26use core::marker::PhantomData;
27
28use baracuda_cutlass::{Error, Result};
29use baracuda_driver::Stream;
30use baracuda_kernels_types::{
31    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
32    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
33};
34
35/// Descriptor for a `where_backward` op.
36///
37/// `shape` is the shared shape of `dy` / `da` / `db`. `element` is the
38/// **value** dtype — cond is always `u8`. `element` must match the type
39/// parameter `T` of the containing plan at `select` time.
40#[derive(Copy, Clone, Debug)]
41pub struct WhereBackwardDescriptor<const N: usize> {
42    /// Tensor shape (shared by cond / dy / da / db).
43    pub shape: [i32; N],
44    /// Value element type (dy / da / db dtype; cond is always `u8`).
45    pub element: ElementKind,
46}
47
48/// Args bundle for a `where_backward` launch.
49///
50/// `cond` is the FW mask (`u8`, 0 = false). `dy` is the upstream
51/// gradient (same dtype as the FW value inputs). `da` and `db` are the
52/// gradients w.r.t. the FW `a` and `b` respectively.
53pub struct WhereBackwardArgs<'a, T: Element, const N: usize> {
54    /// Boolean mask from the forward pass (`0u8` selected `b`,
55    /// any other value selected `a`).
56    pub cond: TensorRef<'a, u8, N>,
57    /// Upstream gradient.
58    pub dy: TensorRef<'a, T, N>,
59    /// Gradient w.r.t. `a`.
60    pub da: TensorMut<'a, T, N>,
61    /// Gradient w.r.t. `b`.
62    pub db: TensorMut<'a, T, N>,
63}
64
65/// `where_backward(cond, dy)` plan with heterogeneous-dtype inputs.
66///
67/// `T: Element` is the value dtype (`dy` / `da` / `db`). The cond is
68/// always `u8`. `const N: usize` is the tensor rank.
69pub struct WhereBackwardPlan<T: Element, const N: usize> {
70    desc: WhereBackwardDescriptor<N>,
71    sku: KernelSku,
72    _marker: PhantomData<T>,
73}
74
75impl<T: Element, const N: usize> WhereBackwardPlan<T, N> {
76    /// Pick a kernel for `desc`. Returns [`Error::Unsupported`] if the
77    /// value dtype isn't wired today.
78    pub fn select(
79        _stream: &Stream,
80        desc: &WhereBackwardDescriptor<N>,
81        _pref: PlanPreference,
82    ) -> Result<Self> {
83        if desc.element != T::KIND {
84            return Err(Error::Unsupported(
85                "baracuda-kernels::WhereBackwardPlan: descriptor element != type parameter T",
86            ));
87        }
88        for &d in desc.shape.iter() {
89            if d < 0 {
90                return Err(Error::InvalidProblem(
91                    "baracuda-kernels::WhereBackwardPlan: shape dims must be non-negative",
92                ));
93            }
94        }
95
96        // All 4 FP value dtypes wired.
97        let supported = matches!(
98            T::KIND,
99            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
100        );
101        if !supported {
102            return Err(Error::Unsupported(
103                "baracuda-kernels::WhereBackwardPlan: value dtype must be one of \
104                 {F32, F16, Bf16, F64}",
105            ));
106        }
107
108        // `where_backward` is a pure mask + copy — no arithmetic —
109        // fully deterministic and bit-stable on the same hardware. The
110        // MathPrecision tag mirrors the value dtype by convention even
111        // though no arithmetic happens.
112        let (math_precision, accumulator) = match T::KIND {
113            ElementKind::F16 => (MathPrecision::F16, ElementKind::F16),
114            ElementKind::Bf16 => (MathPrecision::Bf16, ElementKind::Bf16),
115            ElementKind::F64 => (MathPrecision::F64, ElementKind::F64),
116            _ => (MathPrecision::F32, ElementKind::F32),
117        };
118        let precision_guarantee = PrecisionGuarantee {
119            math_precision,
120            accumulator,
121            bit_stable_on_same_hardware: true,
122            deterministic: true,
123        };
124        let sku = KernelSku {
125            category: OpCategory::TernaryElementwise,
126            // `op` discriminant: matches `TernaryKind::Where` (= 4).
127            // BW is implied by the plan type itself
128            // (`WhereBackwardPlan` vs `WherePlan`) — no separate
129            // discriminant needed, mirroring the BinaryBackwardPlan
130            // convention.
131            op: 4,
132            element: T::KIND,
133            // `aux_element` would capture cond's dtype but ElementKind
134            // doesn't carry a `U8` variant today — rely on the
135            // `Where`-specific op discriminant for telemetry / cache
136            // disambiguation, same as the FW.
137            aux_element: None,
138            layout: None,
139            epilogue: None,
140            arch: ArchSku::Sm80,
141            backend: BackendKind::Bespoke,
142            precision_guarantee,
143        };
144        Ok(Self {
145            desc: *desc,
146            sku,
147            _marker: PhantomData,
148        })
149    }
150
151    /// Validate that this plan can launch with `args`.
152    pub fn can_implement(&self, args: &WhereBackwardArgs<'_, T, N>) -> Result<()> {
153        if args.dy.shape != self.desc.shape {
154            return Err(Error::InvalidProblem(
155                "baracuda-kernels::WhereBackwardPlan: dy shape mismatch with descriptor",
156            ));
157        }
158        if args.da.shape != self.desc.shape {
159            return Err(Error::InvalidProblem(
160                "baracuda-kernels::WhereBackwardPlan: da shape mismatch with descriptor",
161            ));
162        }
163        if args.db.shape != self.desc.shape {
164            return Err(Error::InvalidProblem(
165                "baracuda-kernels::WhereBackwardPlan: db shape mismatch with descriptor",
166            ));
167        }
168        if args.cond.shape != self.desc.shape {
169            return Err(Error::InvalidProblem(
170                "baracuda-kernels::WhereBackwardPlan: cond shape mismatch with descriptor \
171                 (trailblazer requires full-shape cond; stride-0 broadcasting on cond \
172                 lands later)",
173            ));
174        }
175
176        // Contig-only for trailblazer (cond included — no stride-0 axes).
177        if !args.cond.is_contiguous()
178            || !args.dy.is_contiguous()
179            || !args.da.is_contiguous()
180            || !args.db.is_contiguous()
181        {
182            return Err(Error::Unsupported(
183                "baracuda-kernels::WhereBackwardPlan: trailblazer requires contiguous \
184                 cond / dy / da / db; strided / broadcast fanout lands later",
185            ));
186        }
187
188        if N > 8 {
189            return Err(Error::Unsupported(
190                "baracuda-kernels::WhereBackwardPlan: tensor rank > 8 not supported",
191            ));
192        }
193
194        let numel = args.dy.numel();
195        let cond_len = args.cond.data.len() as i64;
196        let dy_len = args.dy.data.len() as i64;
197        let da_len = args.da.data.len() as i64;
198        let db_len = args.db.data.len() as i64;
199        if dy_len < numel || da_len < numel || db_len < numel || cond_len < numel {
200            return Err(Error::BufferTooSmall {
201                needed: numel as usize,
202                got: cond_len.min(dy_len).min(da_len).min(db_len) as usize,
203            });
204        }
205        Ok(())
206    }
207
208    /// Workspace size in bytes. Always `0` for the trailblazer.
209    #[inline]
210    pub fn workspace_size(&self) -> usize {
211        0
212    }
213
214    /// Identity of the kernel this plan picked.
215    #[inline]
216    pub fn sku(&self) -> KernelSku {
217        self.sku
218    }
219
220    /// Numerical guarantees for this plan's kernel.
221    #[inline]
222    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
223        self.sku.precision_guarantee
224    }
225
226    /// Launch.
227    pub fn run(
228        &self,
229        stream: &Stream,
230        _workspace: Workspace<'_>,
231        args: WhereBackwardArgs<'_, T, N>,
232    ) -> Result<()> {
233        self.can_implement(&args)?;
234        let numel = args.dy.numel();
235        if numel == 0 {
236            return Ok(());
237        }
238        let cond_ptr = args.cond.data.as_raw().0 as *const c_void;
239        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
240        let da_ptr = args.da.data.as_raw().0 as *mut c_void;
241        let db_ptr = args.db.data.as_raw().0 as *mut c_void;
242        let stream_ptr = stream.as_raw() as *mut c_void;
243
244        let status = match T::KIND {
245            ElementKind::F32 => unsafe {
246                baracuda_kernels_sys::baracuda_kernels_where_backward_f32_run(
247                    numel,
248                    cond_ptr,
249                    dy_ptr,
250                    da_ptr,
251                    db_ptr,
252                    core::ptr::null_mut(),
253                    0,
254                    stream_ptr,
255                )
256            },
257            ElementKind::F16 => unsafe {
258                baracuda_kernels_sys::baracuda_kernels_where_backward_f16_run(
259                    numel,
260                    cond_ptr,
261                    dy_ptr,
262                    da_ptr,
263                    db_ptr,
264                    core::ptr::null_mut(),
265                    0,
266                    stream_ptr,
267                )
268            },
269            ElementKind::Bf16 => unsafe {
270                baracuda_kernels_sys::baracuda_kernels_where_backward_bf16_run(
271                    numel,
272                    cond_ptr,
273                    dy_ptr,
274                    da_ptr,
275                    db_ptr,
276                    core::ptr::null_mut(),
277                    0,
278                    stream_ptr,
279                )
280            },
281            ElementKind::F64 => unsafe {
282                baracuda_kernels_sys::baracuda_kernels_where_backward_f64_run(
283                    numel,
284                    cond_ptr,
285                    dy_ptr,
286                    da_ptr,
287                    db_ptr,
288                    core::ptr::null_mut(),
289                    0,
290                    stream_ptr,
291                )
292            },
293            _ => {
294                return Err(Error::Unsupported(
295                    "baracuda-kernels::WhereBackwardPlan::run reached an unimplemented \
296                     dtype — select() should have caught this",
297                ));
298            }
299        };
300        map_status(status)
301    }
302}
303
304fn map_status(code: i32) -> Result<()> {
305    match code {
306        0 => Ok(()),
307        1 => Err(Error::MisalignedOperand),
308        2 => Err(Error::InvalidProblem(
309            "baracuda-kernels-sys reported invalid problem",
310        )),
311        3 => Err(Error::Unsupported(
312            "baracuda-kernels-sys reported unsupported configuration",
313        )),
314        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
315        n => Err(Error::CutlassInternal(n)),
316    }
317}