Skip to main content

baracuda_kernels/elementwise/
binary_param_backward.rs

//! Backward plan for the parameterized binary elementwise family.
//!
//! Sibling of [`crate::BinaryParamPlan`]. For `Lerp`, the backward formula
//! is `da = (1 - weight)·dy`, `db = weight·dy`. No saved tensors are
//! needed because each gradient is a pure linear scaling of `dy` by a
//! constant derived from the scalar weight.
//!
//! Currently wired: `Lerp × {f32, f16, bf16, f64}`. The scalar `weight` is
//! a constant w.r.t. both inputs, so no gradient flows to it.
//!
//! Trailblazer constraints: contiguous tensors only, and
//! `dy.shape == da.shape == db.shape == desc.shape`.
13
14use core::ffi::c_void;
15use core::marker::PhantomData;
16
17use baracuda_cutlass::{Error, Result};
18use baracuda_driver::Stream;
19use baracuda_kernels_types::{
20    ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
21    PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
22};
23
24/// Descriptor for a parameterized binary backward op. Same shape as the
25/// FW descriptor.
26#[derive(Copy, Clone, Debug)]
27pub struct BinaryParamBackwardDescriptor<const N: usize> {
28    /// Which forward parameterized binary op this is the backward of.
29    pub kind: BinaryKind,
30    /// Tensor shape (shared by dy / da / db).
31    pub shape: [i32; N],
32    /// Element type.
33    pub element: ElementKind,
34    /// Op-specific scalar parameter; same semantics as the FW
35    /// descriptor's `param` field.
36    pub param: f32,
37}
38
39/// Args bundle for a parameterized binary backward launch.
40///
41/// `Lerp` BW doesn't need saved forward inputs — the gradient is a pure
42/// function of `dy` and the scalar param.
43pub struct BinaryParamBackwardArgs<'a, T: Element, const N: usize> {
44    /// Upstream gradient.
45    pub dy: TensorRef<'a, T, N>,
46    /// Gradient w.r.t. `a`.
47    pub da: TensorMut<'a, T, N>,
48    /// Gradient w.r.t. `b`.
49    pub db: TensorMut<'a, T, N>,
50}
51
52/// Parameterized binary backward plan.
53pub struct BinaryParamBackwardPlan<T: Element, const N: usize> {
54    desc: BinaryParamBackwardDescriptor<N>,
55    sku: KernelSku,
56    _marker: PhantomData<T>,
57}
58
59impl<T: Element, const N: usize> BinaryParamBackwardPlan<T, N> {
60    /// Pick a kernel.
61    pub fn select(
62        _stream: &Stream,
63        desc: &BinaryParamBackwardDescriptor<N>,
64        _pref: PlanPreference,
65    ) -> Result<Self> {
66        if desc.element != T::KIND {
67            return Err(Error::Unsupported(
68                "baracuda-kernels::BinaryParamBackwardPlan: descriptor element != T",
69            ));
70        }
71        for &d in desc.shape.iter() {
72            if d < 0 {
73                return Err(Error::InvalidProblem(
74                    "baracuda-kernels::BinaryParamBackwardPlan: shape dims must be non-negative",
75                ));
76            }
77        }
78
79        let kind_in_scope = matches!(desc.kind, BinaryKind::Lerp);
80        let dtype_in_scope = matches!(
81            T::KIND,
82            ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
83        );
84        if !(kind_in_scope && dtype_in_scope) {
85            return Err(Error::Unsupported(
86                "baracuda-kernels::BinaryParamBackwardPlan: today only `Lerp × \
87                 {f32, f16, bf16, f64}` is wired.",
88            ));
89        }
90
91        let precision_guarantee = PrecisionGuarantee {
92            math_precision: MathPrecision::F32,
93            accumulator: ElementKind::F32,
94            bit_stable_on_same_hardware: true,
95            deterministic: true,
96        };
97        let sku = KernelSku {
98            category: OpCategory::BinaryElementwise,
99            op: desc.kind as u16,
100            element: T::KIND,
101            aux_element: None,
102            layout: None,
103            epilogue: None,
104            arch: ArchSku::Sm80,
105            backend: BackendKind::Bespoke,
106            precision_guarantee,
107        };
108        Ok(Self {
109            desc: *desc,
110            sku,
111            _marker: PhantomData,
112        })
113    }
114
115    /// Validate args.
116    pub fn can_implement(&self, args: &BinaryParamBackwardArgs<'_, T, N>) -> Result<()> {
117        if args.dy.shape != self.desc.shape {
118            return Err(Error::InvalidProblem(
119                "baracuda-kernels::BinaryParamBackwardPlan: dy shape mismatch",
120            ));
121        }
122        if args.da.shape != self.desc.shape {
123            return Err(Error::InvalidProblem(
124                "baracuda-kernels::BinaryParamBackwardPlan: da shape mismatch",
125            ));
126        }
127        if args.db.shape != self.desc.shape {
128            return Err(Error::InvalidProblem(
129                "baracuda-kernels::BinaryParamBackwardPlan: db shape mismatch",
130            ));
131        }
132        if !args.dy.is_contiguous() || !args.da.is_contiguous() || !args.db.is_contiguous() {
133            return Err(Error::Unsupported(
134                "baracuda-kernels::BinaryParamBackwardPlan: contig-only trailblazer",
135            ));
136        }
137        let numel = args.dy.numel();
138        let dy_len = args.dy.data.len() as i64;
139        let da_len = args.da.data.len() as i64;
140        let db_len = args.db.data.len() as i64;
141        if dy_len < numel || da_len < numel || db_len < numel {
142            return Err(Error::BufferTooSmall {
143                needed: numel as usize,
144                got: dy_len.min(da_len).min(db_len) as usize,
145            });
146        }
147        Ok(())
148    }
149
150    /// Workspace size in bytes.
151    #[inline]
152    pub fn workspace_size(&self) -> usize {
153        0
154    }
155    /// Kernel SKU identity.
156    #[inline]
157    pub fn sku(&self) -> KernelSku {
158        self.sku
159    }
160    /// Numerical guarantees.
161    #[inline]
162    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
163        self.sku.precision_guarantee
164    }
165
166    /// Launch.
167    pub fn run(
168        &self,
169        stream: &Stream,
170        _workspace: Workspace<'_>,
171        args: BinaryParamBackwardArgs<'_, T, N>,
172    ) -> Result<()> {
173        self.can_implement(&args)?;
174        let numel = args.dy.numel();
175        if numel == 0 {
176            return Ok(());
177        }
178        let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
179        let da_ptr = args.da.data.as_raw().0 as *mut c_void;
180        let db_ptr = args.db.data.as_raw().0 as *mut c_void;
181        let stream_ptr = stream.as_raw() as *mut c_void;
182        let p = self.desc.param;
183
184        let status = match (self.desc.kind, T::KIND) {
185            (BinaryKind::Lerp, ElementKind::F32) => unsafe {
186                baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_f32_run(
187                    numel, dy_ptr, da_ptr, db_ptr, p,
188                    core::ptr::null_mut(), 0, stream_ptr,
189                )
190            },
191            (BinaryKind::Lerp, ElementKind::F16) => unsafe {
192                baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_f16_run(
193                    numel, dy_ptr, da_ptr, db_ptr, p,
194                    core::ptr::null_mut(), 0, stream_ptr,
195                )
196            },
197            (BinaryKind::Lerp, ElementKind::Bf16) => unsafe {
198                baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_bf16_run(
199                    numel, dy_ptr, da_ptr, db_ptr, p,
200                    core::ptr::null_mut(), 0, stream_ptr,
201                )
202            },
203            (BinaryKind::Lerp, ElementKind::F64) => unsafe {
204                baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_f64_run(
205                    numel, dy_ptr, da_ptr, db_ptr, p,
206                    core::ptr::null_mut(), 0, stream_ptr,
207                )
208            },
209            _ => {
210                return Err(Error::Unsupported(
211                    "baracuda-kernels::BinaryParamBackwardPlan: dispatcher reached an \
212                     unimplemented (kind, dtype) pair — select() should have caught this",
213                ));
214            }
215        };
216        map_status(status)
217    }
218}
219
220fn map_status(code: i32) -> Result<()> {
221    match code {
222        0 => Ok(()),
223        1 => Err(Error::MisalignedOperand),
224        2 => Err(Error::InvalidProblem(
225            "baracuda-kernels-sys reported invalid problem",
226        )),
227        3 => Err(Error::Unsupported(
228            "baracuda-kernels-sys reported unsupported configuration",
229        )),
230        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
231        n => Err(Error::CutlassInternal(n)),
232    }
233}