baracuda_kernels/elementwise/
binary_param_backward.rs1use core::ffi::c_void;
15use core::marker::PhantomData;
16
17use baracuda_cutlass::{Error, Result};
18use baracuda_driver::Stream;
19use baracuda_kernels_types::{
20 ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
21 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
22};
23
/// Problem description for a scalar-parameterized binary elementwise backward op.
///
/// Validated by [`BinaryParamBackwardPlan::select`]:
/// - `element` must equal the plan's `T::KIND`;
/// - every `shape` dimension must be non-negative;
/// - only `BinaryKind::Lerp` over `{f32, f16, bf16, f64}` is currently accepted.
#[derive(Copy, Clone, Debug)]
pub struct BinaryParamBackwardDescriptor<const N: usize> {
    /// Which binary op to differentiate; `select` currently accepts only `BinaryKind::Lerp`.
    pub kind: BinaryKind,
    /// Logical shape that `dy`, `da` and `db` must all match; dims must be non-negative.
    pub shape: [i32; N],
    /// Element dtype; must equal the plan's `T::KIND`.
    pub element: ElementKind,
    /// Scalar op parameter, forwarded unchanged to the kernel at launch time.
    /// For `Lerp` this is presumably the interpolation weight — confirm against the kernel.
    pub param: f32,
}
38
/// Device tensors for one backward launch.
///
/// `BinaryParamBackwardPlan::can_implement` requires all three tensors to have the
/// descriptor's shape, be contiguous, and have backing buffers of at least `numel` elements.
pub struct BinaryParamBackwardArgs<'a, T: Element, const N: usize> {
    /// Upstream gradient; read-only input to the kernel.
    pub dy: TensorRef<'a, T, N>,
    /// Output gradient buffer for the first operand (by naming convention, `a`);
    /// passed to the kernel as a mutable pointer.
    pub da: TensorMut<'a, T, N>,
    /// Output gradient buffer for the second operand (by naming convention, `b`);
    /// passed to the kernel as a mutable pointer.
    pub db: TensorMut<'a, T, N>,
}
51
/// A validated launch plan for a [`BinaryParamBackwardDescriptor`] over element type `T`.
///
/// Built by `select`, checked against concrete tensors by `can_implement`, and launched by
/// `run` (which takes `&self`, so a plan can be launched repeatedly). Needs no workspace.
pub struct BinaryParamBackwardPlan<T: Element, const N: usize> {
    /// Copy of the descriptor taken at `select` time.
    desc: BinaryParamBackwardDescriptor<N>,
    /// Kernel identity and precision guarantee, returned by `sku()`.
    sku: KernelSku,
    /// Binds the plan to element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
58
59impl<T: Element, const N: usize> BinaryParamBackwardPlan<T, N> {
60 pub fn select(
62 _stream: &Stream,
63 desc: &BinaryParamBackwardDescriptor<N>,
64 _pref: PlanPreference,
65 ) -> Result<Self> {
66 if desc.element != T::KIND {
67 return Err(Error::Unsupported(
68 "baracuda-kernels::BinaryParamBackwardPlan: descriptor element != T",
69 ));
70 }
71 for &d in desc.shape.iter() {
72 if d < 0 {
73 return Err(Error::InvalidProblem(
74 "baracuda-kernels::BinaryParamBackwardPlan: shape dims must be non-negative",
75 ));
76 }
77 }
78
79 let kind_in_scope = matches!(desc.kind, BinaryKind::Lerp);
80 let dtype_in_scope = matches!(
81 T::KIND,
82 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
83 );
84 if !(kind_in_scope && dtype_in_scope) {
85 return Err(Error::Unsupported(
86 "baracuda-kernels::BinaryParamBackwardPlan: today only `Lerp × \
87 {f32, f16, bf16, f64}` is wired.",
88 ));
89 }
90
91 let precision_guarantee = PrecisionGuarantee {
92 math_precision: MathPrecision::F32,
93 accumulator: ElementKind::F32,
94 bit_stable_on_same_hardware: true,
95 deterministic: true,
96 };
97 let sku = KernelSku {
98 category: OpCategory::BinaryElementwise,
99 op: desc.kind as u16,
100 element: T::KIND,
101 aux_element: None,
102 layout: None,
103 epilogue: None,
104 arch: ArchSku::Sm80,
105 backend: BackendKind::Bespoke,
106 precision_guarantee,
107 };
108 Ok(Self {
109 desc: *desc,
110 sku,
111 _marker: PhantomData,
112 })
113 }
114
115 pub fn can_implement(&self, args: &BinaryParamBackwardArgs<'_, T, N>) -> Result<()> {
117 if args.dy.shape != self.desc.shape {
118 return Err(Error::InvalidProblem(
119 "baracuda-kernels::BinaryParamBackwardPlan: dy shape mismatch",
120 ));
121 }
122 if args.da.shape != self.desc.shape {
123 return Err(Error::InvalidProblem(
124 "baracuda-kernels::BinaryParamBackwardPlan: da shape mismatch",
125 ));
126 }
127 if args.db.shape != self.desc.shape {
128 return Err(Error::InvalidProblem(
129 "baracuda-kernels::BinaryParamBackwardPlan: db shape mismatch",
130 ));
131 }
132 if !args.dy.is_contiguous() || !args.da.is_contiguous() || !args.db.is_contiguous() {
133 return Err(Error::Unsupported(
134 "baracuda-kernels::BinaryParamBackwardPlan: contig-only trailblazer",
135 ));
136 }
137 let numel = args.dy.numel();
138 let dy_len = args.dy.data.len() as i64;
139 let da_len = args.da.data.len() as i64;
140 let db_len = args.db.data.len() as i64;
141 if dy_len < numel || da_len < numel || db_len < numel {
142 return Err(Error::BufferTooSmall {
143 needed: numel as usize,
144 got: dy_len.min(da_len).min(db_len) as usize,
145 });
146 }
147 Ok(())
148 }
149
150 #[inline]
152 pub fn workspace_size(&self) -> usize {
153 0
154 }
155 #[inline]
157 pub fn sku(&self) -> KernelSku {
158 self.sku
159 }
160 #[inline]
162 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
163 self.sku.precision_guarantee
164 }
165
166 pub fn run(
168 &self,
169 stream: &Stream,
170 _workspace: Workspace<'_>,
171 args: BinaryParamBackwardArgs<'_, T, N>,
172 ) -> Result<()> {
173 self.can_implement(&args)?;
174 let numel = args.dy.numel();
175 if numel == 0 {
176 return Ok(());
177 }
178 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
179 let da_ptr = args.da.data.as_raw().0 as *mut c_void;
180 let db_ptr = args.db.data.as_raw().0 as *mut c_void;
181 let stream_ptr = stream.as_raw() as *mut c_void;
182 let p = self.desc.param;
183
184 let status = match (self.desc.kind, T::KIND) {
185 (BinaryKind::Lerp, ElementKind::F32) => unsafe {
186 baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_f32_run(
187 numel, dy_ptr, da_ptr, db_ptr, p,
188 core::ptr::null_mut(), 0, stream_ptr,
189 )
190 },
191 (BinaryKind::Lerp, ElementKind::F16) => unsafe {
192 baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_f16_run(
193 numel, dy_ptr, da_ptr, db_ptr, p,
194 core::ptr::null_mut(), 0, stream_ptr,
195 )
196 },
197 (BinaryKind::Lerp, ElementKind::Bf16) => unsafe {
198 baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_bf16_run(
199 numel, dy_ptr, da_ptr, db_ptr, p,
200 core::ptr::null_mut(), 0, stream_ptr,
201 )
202 },
203 (BinaryKind::Lerp, ElementKind::F64) => unsafe {
204 baracuda_kernels_sys::baracuda_kernels_binary_lerp_backward_f64_run(
205 numel, dy_ptr, da_ptr, db_ptr, p,
206 core::ptr::null_mut(), 0, stream_ptr,
207 )
208 },
209 _ => {
210 return Err(Error::Unsupported(
211 "baracuda-kernels::BinaryParamBackwardPlan: dispatcher reached an \
212 unimplemented (kind, dtype) pair — select() should have caught this",
213 ));
214 }
215 };
216 map_status(status)
217 }
218}
219
220fn map_status(code: i32) -> Result<()> {
221 match code {
222 0 => Ok(()),
223 1 => Err(Error::MisalignedOperand),
224 2 => Err(Error::InvalidProblem(
225 "baracuda-kernels-sys reported invalid problem",
226 )),
227 3 => Err(Error::Unsupported(
228 "baracuda-kernels-sys reported unsupported configuration",
229 )),
230 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
231 n => Err(Error::CutlassInternal(n)),
232 }
233}