baracuda_kernels/elementwise/
binary_param.rs1use core::ffi::c_void;
19use core::marker::PhantomData;
20
21use baracuda_cutlass::{Error, Result};
22use baracuda_driver::Stream;
23use baracuda_kernels_types::{
24 ArchSku, BackendKind, BinaryKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
25 PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
26};
27
28#[derive(Copy, Clone, Debug)]
30pub struct BinaryParamDescriptor<const N: usize> {
31 pub kind: BinaryKind,
33 pub shape: [i32; N],
35 pub element: ElementKind,
37 pub param: f32,
41}
42
43pub struct BinaryParamArgs<'a, T: Element, const N: usize> {
45 pub a: TensorRef<'a, T, N>,
47 pub b: TensorRef<'a, T, N>,
49 pub y: TensorMut<'a, T, N>,
51}
52
53pub struct BinaryParamPlan<T: Element, const N: usize> {
55 desc: BinaryParamDescriptor<N>,
56 sku: KernelSku,
57 _marker: PhantomData<T>,
58}
59
60impl<T: Element, const N: usize> BinaryParamPlan<T, N> {
61 pub fn select(
63 _stream: &Stream,
64 desc: &BinaryParamDescriptor<N>,
65 _pref: PlanPreference,
66 ) -> Result<Self> {
67 if desc.element != T::KIND {
68 return Err(Error::Unsupported(
69 "baracuda-kernels::BinaryParamPlan: descriptor element != type parameter T",
70 ));
71 }
72 for &d in desc.shape.iter() {
73 if d < 0 {
74 return Err(Error::InvalidProblem(
75 "baracuda-kernels::BinaryParamPlan: shape dims must be non-negative",
76 ));
77 }
78 }
79
80 let kind_in_scope = matches!(desc.kind, BinaryKind::Lerp);
81 let dtype_in_scope = matches!(
82 T::KIND,
83 ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64
84 );
85 if !(kind_in_scope && dtype_in_scope) {
86 return Err(Error::Unsupported(
87 "baracuda-kernels::BinaryParamPlan: today only `Lerp × {f32, f16, bf16, f64}` \
88 is wired; other parameterized binary ops join in later fanout.",
89 ));
90 }
91
92 let precision_guarantee = PrecisionGuarantee {
93 math_precision: MathPrecision::F32,
94 accumulator: ElementKind::F32,
95 bit_stable_on_same_hardware: true,
96 deterministic: true,
97 };
98 let sku = KernelSku {
99 category: OpCategory::BinaryElementwise,
100 op: desc.kind as u16,
101 element: T::KIND,
102 aux_element: None,
103 layout: None,
104 epilogue: None,
105 arch: ArchSku::Sm80,
106 backend: BackendKind::Bespoke,
107 precision_guarantee,
108 };
109 Ok(Self {
110 desc: *desc,
111 sku,
112 _marker: PhantomData,
113 })
114 }
115
116 pub fn can_implement(&self, args: &BinaryParamArgs<'_, T, N>) -> Result<()> {
118 if args.a.shape != self.desc.shape {
119 return Err(Error::InvalidProblem(
120 "baracuda-kernels::BinaryParamPlan: A shape mismatch",
121 ));
122 }
123 if args.b.shape != self.desc.shape {
124 return Err(Error::InvalidProblem(
125 "baracuda-kernels::BinaryParamPlan: B shape mismatch",
126 ));
127 }
128 if args.y.shape != self.desc.shape {
129 return Err(Error::InvalidProblem(
130 "baracuda-kernels::BinaryParamPlan: Y shape mismatch",
131 ));
132 }
133 if !args.a.is_contiguous() || !args.b.is_contiguous() || !args.y.is_contiguous() {
134 return Err(Error::Unsupported(
135 "baracuda-kernels::BinaryParamPlan: contig-only trailblazer; strided fanout \
136 lands later",
137 ));
138 }
139 let numel = args.y.numel();
140 let a_len = args.a.data.len() as i64;
141 let b_len = args.b.data.len() as i64;
142 let y_len = args.y.data.len() as i64;
143 if a_len < numel || b_len < numel || y_len < numel {
144 return Err(Error::BufferTooSmall {
145 needed: numel as usize,
146 got: a_len.min(b_len).min(y_len) as usize,
147 });
148 }
149 Ok(())
150 }
151
152 #[inline]
154 pub fn workspace_size(&self) -> usize {
155 0
156 }
157 #[inline]
159 pub fn sku(&self) -> KernelSku {
160 self.sku
161 }
162 #[inline]
164 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
165 self.sku.precision_guarantee
166 }
167
168 pub fn run(
170 &self,
171 stream: &Stream,
172 _workspace: Workspace<'_>,
173 args: BinaryParamArgs<'_, T, N>,
174 ) -> Result<()> {
175 self.can_implement(&args)?;
176 let numel = args.y.numel();
177 if numel == 0 {
178 return Ok(());
179 }
180 let a_ptr = args.a.data.as_raw().0 as *const c_void;
181 let b_ptr = args.b.data.as_raw().0 as *const c_void;
182 let y_ptr = args.y.data.as_raw().0 as *mut c_void;
183 let stream_ptr = stream.as_raw() as *mut c_void;
184 let p = self.desc.param;
185
186 let status = match (self.desc.kind, T::KIND) {
187 (BinaryKind::Lerp, ElementKind::F32) => unsafe {
188 baracuda_kernels_sys::baracuda_kernels_binary_lerp_f32_run(
189 numel, a_ptr, b_ptr, y_ptr, p,
190 core::ptr::null_mut(), 0, stream_ptr,
191 )
192 },
193 (BinaryKind::Lerp, ElementKind::F16) => unsafe {
194 baracuda_kernels_sys::baracuda_kernels_binary_lerp_f16_run(
195 numel, a_ptr, b_ptr, y_ptr, p,
196 core::ptr::null_mut(), 0, stream_ptr,
197 )
198 },
199 (BinaryKind::Lerp, ElementKind::Bf16) => unsafe {
200 baracuda_kernels_sys::baracuda_kernels_binary_lerp_bf16_run(
201 numel, a_ptr, b_ptr, y_ptr, p,
202 core::ptr::null_mut(), 0, stream_ptr,
203 )
204 },
205 (BinaryKind::Lerp, ElementKind::F64) => unsafe {
206 baracuda_kernels_sys::baracuda_kernels_binary_lerp_f64_run(
207 numel, a_ptr, b_ptr, y_ptr, p,
208 core::ptr::null_mut(), 0, stream_ptr,
209 )
210 },
211 _ => {
212 return Err(Error::Unsupported(
213 "baracuda-kernels::BinaryParamPlan: dispatcher reached an unimplemented \
214 (kind, dtype) pair — select() should have caught this",
215 ));
216 }
217 };
218 map_status(status)
219 }
220}
221
222fn map_status(code: i32) -> Result<()> {
223 match code {
224 0 => Ok(()),
225 1 => Err(Error::MisalignedOperand),
226 2 => Err(Error::InvalidProblem(
227 "baracuda-kernels-sys reported invalid problem",
228 )),
229 3 => Err(Error::Unsupported(
230 "baracuda-kernels-sys reported unsupported configuration",
231 )),
232 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
233 n => Err(Error::CutlassInternal(n)),
234 }
235}