1use core::ffi::c_void;
50use core::marker::PhantomData;
51
52use baracuda_cutlass::{Error, Result};
53use baracuda_driver::Stream;
54use baracuda_kernels_types::{
55 ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
56 PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
57};
58use baracuda_types::DeviceRepr;
59
60#[derive(Copy, Clone, Debug)]
71pub struct CastSubByteDescriptor {
72 pub numel: i32,
74 pub input_element: ElementKind,
76 pub output_element: ElementKind,
78}
79
80pub struct CastSubByteArgs<'a, TIn: DeviceRepr + Copy + 'static, TOut: DeviceRepr + Copy + 'static>
88{
89 pub input: TensorRef<'a, TIn, 1>,
91 pub output: TensorMut<'a, TOut, 1>,
93}
94
95pub struct CastSubBytePlan<
109 TIn: DeviceRepr + Copy + 'static,
110 TOut: DeviceRepr + Copy + 'static,
111> {
112 desc: CastSubByteDescriptor,
113 sku: KernelSku,
114 _marker_in: PhantomData<TIn>,
115 _marker_out: PhantomData<TOut>,
116}
117
118impl<TIn: DeviceRepr + Copy + 'static, TOut: DeviceRepr + Copy + 'static>
119 CastSubBytePlan<TIn, TOut>
120{
121 pub fn select(
123 _stream: &Stream,
124 desc: &CastSubByteDescriptor,
125 _pref: PlanPreference,
126 ) -> Result<Self> {
127 if !type_size_matches_kind::<TIn>(desc.input_element) {
128 return Err(Error::Unsupported(
129 "baracuda-kernels::CastSubBytePlan: sizeof::<TIn>() does not match \
130 descriptor input_element width",
131 ));
132 }
133 if !type_size_matches_kind::<TOut>(desc.output_element) {
134 return Err(Error::Unsupported(
135 "baracuda-kernels::CastSubBytePlan: sizeof::<TOut>() does not match \
136 descriptor output_element width",
137 ));
138 }
139 if desc.numel < 0 {
140 return Err(Error::InvalidProblem(
141 "baracuda-kernels::CastSubBytePlan: numel must be non-negative",
142 ));
143 }
144 let inv = matches!(
146 desc.input_element,
147 ElementKind::S4 | ElementKind::U4
148 );
149 let outv = matches!(
150 desc.output_element,
151 ElementKind::S4 | ElementKind::U4
152 );
153 if (inv || outv) && (desc.numel % 2 != 0) {
154 return Err(Error::InvalidProblem(
155 "baracuda-kernels::CastSubBytePlan: S4 / U4 endpoints require even numel \
156 (packed buffer is numel/2 bytes)",
157 ));
158 }
159 if !pair_in_scope(desc.input_element, desc.output_element) {
160 return Err(Error::Unsupported(
161 "baracuda-kernels::CastSubBytePlan: (input, output) pair not in scope \
162 for Phase 13.3 — see module docs for the wired set",
163 ));
164 }
165
166 let precision_guarantee = PrecisionGuarantee {
171 math_precision: MathPrecision::F32,
172 accumulator: ElementKind::F32,
173 bit_stable_on_same_hardware: true,
174 deterministic: true,
175 };
176 let sku = KernelSku {
177 category: OpCategory::UnaryElementwise,
178 op: UnaryKind::Cast as u16,
179 element: desc.input_element,
180 aux_element: Some(desc.output_element),
181 layout: None,
182 epilogue: None,
183 arch: ArchSku::Sm80,
184 backend: BackendKind::Bespoke,
185 precision_guarantee,
186 };
187 Ok(Self {
188 desc: *desc,
189 sku,
190 _marker_in: PhantomData,
191 _marker_out: PhantomData,
192 })
193 }
194
195 pub fn can_implement(&self, args: &CastSubByteArgs<'_, TIn, TOut>) -> Result<()> {
199 let expected = self.desc.numel as i64;
200 let in_packed = matches!(self.desc.input_element, ElementKind::S4 | ElementKind::U4);
206 let out_packed = matches!(self.desc.output_element, ElementKind::S4 | ElementKind::U4);
207
208 let needed_in = if in_packed { (expected + 1) / 2 } else { expected };
209 let needed_out = if out_packed { (expected + 1) / 2 } else { expected };
210
211 if (args.input.data.len() as i64) < needed_in {
212 return Err(Error::BufferTooSmall {
213 needed: needed_in as usize,
214 got: args.input.data.len(),
215 });
216 }
217 if (args.output.data.len() as i64) < needed_out {
218 return Err(Error::BufferTooSmall {
219 needed: needed_out as usize,
220 got: args.output.data.len(),
221 });
222 }
223 Ok(())
224 }
225
226 #[inline]
228 pub fn workspace_size(&self) -> usize {
229 0
230 }
231
232 #[inline]
234 pub fn sku(&self) -> KernelSku {
235 self.sku
236 }
237
238 #[inline]
240 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
241 self.sku.precision_guarantee
242 }
243
244 pub fn run(
246 &self,
247 stream: &Stream,
248 _workspace: Workspace<'_>,
249 args: CastSubByteArgs<'_, TIn, TOut>,
250 ) -> Result<()> {
251 self.can_implement(&args)?;
252 let numel = self.desc.numel as i64;
253 if numel == 0 {
254 return Ok(());
255 }
256 let x_ptr = args.input.data.as_raw().0 as *const c_void;
257 let y_ptr = args.output.data.as_raw().0 as *mut c_void;
258 let stream_ptr = stream.as_raw() as *mut c_void;
259
260 let status = match (self.desc.input_element, self.desc.output_element) {
261 (ElementKind::Bool, ElementKind::I32) => unsafe {
263 baracuda_kernels_sys::baracuda_kernels_cast_bool_i32_run(
264 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
265 )
266 },
267 (ElementKind::Bool, ElementKind::I64) => unsafe {
268 baracuda_kernels_sys::baracuda_kernels_cast_bool_i64_run(
269 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
270 )
271 },
272 (ElementKind::Bool, ElementKind::F32) => unsafe {
273 baracuda_kernels_sys::baracuda_kernels_cast_bool_f32_run(
274 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
275 )
276 },
277 (ElementKind::Bool, ElementKind::F16) => unsafe {
278 baracuda_kernels_sys::baracuda_kernels_cast_bool_f16_run(
279 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
280 )
281 },
282 (ElementKind::Bool, ElementKind::Bf16) => unsafe {
283 baracuda_kernels_sys::baracuda_kernels_cast_bool_bf16_run(
284 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
285 )
286 },
287 (ElementKind::I32, ElementKind::Bool) => unsafe {
289 baracuda_kernels_sys::baracuda_kernels_cast_i32_bool_run(
290 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
291 )
292 },
293 (ElementKind::I64, ElementKind::Bool) => unsafe {
294 baracuda_kernels_sys::baracuda_kernels_cast_i64_bool_run(
295 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
296 )
297 },
298 (ElementKind::F32, ElementKind::Bool) => unsafe {
299 baracuda_kernels_sys::baracuda_kernels_cast_f32_bool_run(
300 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
301 )
302 },
303 (ElementKind::F16, ElementKind::Bool) => unsafe {
304 baracuda_kernels_sys::baracuda_kernels_cast_f16_bool_run(
305 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
306 )
307 },
308 (ElementKind::Bf16, ElementKind::Bool) => unsafe {
309 baracuda_kernels_sys::baracuda_kernels_cast_bf16_bool_run(
310 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
311 )
312 },
313 (ElementKind::Fp8E4M3, ElementKind::F32) => unsafe {
315 baracuda_kernels_sys::baracuda_kernels_cast_fp8e4m3_f32_run(
316 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
317 )
318 },
319 (ElementKind::Fp8E4M3, ElementKind::F16) => unsafe {
320 baracuda_kernels_sys::baracuda_kernels_cast_fp8e4m3_f16_run(
321 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
322 )
323 },
324 (ElementKind::Fp8E4M3, ElementKind::Bf16) => unsafe {
325 baracuda_kernels_sys::baracuda_kernels_cast_fp8e4m3_bf16_run(
326 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
327 )
328 },
329 (ElementKind::F32, ElementKind::Fp8E4M3) => unsafe {
330 baracuda_kernels_sys::baracuda_kernels_cast_f32_fp8e4m3_run(
331 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
332 )
333 },
334 (ElementKind::F16, ElementKind::Fp8E4M3) => unsafe {
335 baracuda_kernels_sys::baracuda_kernels_cast_f16_fp8e4m3_run(
336 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
337 )
338 },
339 (ElementKind::Bf16, ElementKind::Fp8E4M3) => unsafe {
340 baracuda_kernels_sys::baracuda_kernels_cast_bf16_fp8e4m3_run(
341 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
342 )
343 },
344 (ElementKind::Fp8E5M2, ElementKind::F32) => unsafe {
346 baracuda_kernels_sys::baracuda_kernels_cast_fp8e5m2_f32_run(
347 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
348 )
349 },
350 (ElementKind::Fp8E5M2, ElementKind::F16) => unsafe {
351 baracuda_kernels_sys::baracuda_kernels_cast_fp8e5m2_f16_run(
352 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
353 )
354 },
355 (ElementKind::Fp8E5M2, ElementKind::Bf16) => unsafe {
356 baracuda_kernels_sys::baracuda_kernels_cast_fp8e5m2_bf16_run(
357 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
358 )
359 },
360 (ElementKind::F32, ElementKind::Fp8E5M2) => unsafe {
361 baracuda_kernels_sys::baracuda_kernels_cast_f32_fp8e5m2_run(
362 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
363 )
364 },
365 (ElementKind::F16, ElementKind::Fp8E5M2) => unsafe {
366 baracuda_kernels_sys::baracuda_kernels_cast_f16_fp8e5m2_run(
367 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
368 )
369 },
370 (ElementKind::Bf16, ElementKind::Fp8E5M2) => unsafe {
371 baracuda_kernels_sys::baracuda_kernels_cast_bf16_fp8e5m2_run(
372 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
373 )
374 },
375 (ElementKind::S4, ElementKind::I32) => unsafe {
377 baracuda_kernels_sys::baracuda_kernels_cast_s4_i32_run(
378 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
379 )
380 },
381 (ElementKind::S4, ElementKind::I64) => unsafe {
382 baracuda_kernels_sys::baracuda_kernels_cast_s4_i64_run(
383 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
384 )
385 },
386 (ElementKind::S4, ElementKind::F32) => unsafe {
387 baracuda_kernels_sys::baracuda_kernels_cast_s4_f32_run(
388 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
389 )
390 },
391 (ElementKind::I32, ElementKind::S4) => unsafe {
392 baracuda_kernels_sys::baracuda_kernels_cast_i32_s4_run(
393 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
394 )
395 },
396 (ElementKind::I64, ElementKind::S4) => unsafe {
397 baracuda_kernels_sys::baracuda_kernels_cast_i64_s4_run(
398 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
399 )
400 },
401 (ElementKind::F32, ElementKind::S4) => unsafe {
402 baracuda_kernels_sys::baracuda_kernels_cast_f32_s4_run(
403 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
404 )
405 },
406 (ElementKind::U4, ElementKind::I32) => unsafe {
408 baracuda_kernels_sys::baracuda_kernels_cast_u4_i32_run(
409 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
410 )
411 },
412 (ElementKind::U4, ElementKind::I64) => unsafe {
413 baracuda_kernels_sys::baracuda_kernels_cast_u4_i64_run(
414 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
415 )
416 },
417 (ElementKind::U4, ElementKind::F32) => unsafe {
418 baracuda_kernels_sys::baracuda_kernels_cast_u4_f32_run(
419 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
420 )
421 },
422 (ElementKind::I32, ElementKind::U4) => unsafe {
423 baracuda_kernels_sys::baracuda_kernels_cast_i32_u4_run(
424 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
425 )
426 },
427 (ElementKind::I64, ElementKind::U4) => unsafe {
428 baracuda_kernels_sys::baracuda_kernels_cast_i64_u4_run(
429 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
430 )
431 },
432 (ElementKind::F32, ElementKind::U4) => unsafe {
433 baracuda_kernels_sys::baracuda_kernels_cast_f32_u4_run(
434 numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
435 )
436 },
437 _ => {
438 return Err(Error::Unsupported(
439 "baracuda-kernels::CastSubBytePlan::run reached an unimplemented \
440 (input, output) pair — select() should have caught this",
441 ));
442 }
443 };
444 map_status(status)
445 }
446}
447
448fn pair_in_scope(input: ElementKind, output: ElementKind) -> bool {
451 use ElementKind::*;
452 match (input, output) {
453 (Bool, I32) | (Bool, I64) | (Bool, F32) | (Bool, F16) | (Bool, Bf16) => true,
455 (I32, Bool) | (I64, Bool) | (F32, Bool) | (F16, Bool) | (Bf16, Bool) => true,
456 (Fp8E4M3, F32) | (Fp8E4M3, F16) | (Fp8E4M3, Bf16) => true,
458 (F32, Fp8E4M3) | (F16, Fp8E4M3) | (Bf16, Fp8E4M3) => true,
459 (Fp8E5M2, F32) | (Fp8E5M2, F16) | (Fp8E5M2, Bf16) => true,
461 (F32, Fp8E5M2) | (F16, Fp8E5M2) | (Bf16, Fp8E5M2) => true,
462 (S4, I32) | (S4, I64) | (S4, F32) => true,
464 (I32, S4) | (I64, S4) | (F32, S4) => true,
465 (U4, I32) | (U4, I64) | (U4, F32) => true,
467 (I32, U4) | (I64, U4) | (F32, U4) => true,
468 _ => false,
469 }
470}
471
472fn type_size_matches_kind<T>(kind: ElementKind) -> bool {
475 let want = match kind {
476 ElementKind::Bool
477 | ElementKind::S8
478 | ElementKind::U8
479 | ElementKind::Fp8E4M3
480 | ElementKind::Fp8E5M2
481 | ElementKind::S4
482 | ElementKind::U4 => 1,
483 ElementKind::F16 | ElementKind::Bf16 => 2,
484 ElementKind::F32 | ElementKind::F32Strict | ElementKind::I32 => 4,
485 ElementKind::F64 | ElementKind::I64 | ElementKind::Complex32 => 8,
486 ElementKind::Complex64 => 16,
487 ElementKind::Bin => return false,
488 };
489 core::mem::size_of::<T>() == want
490}
491
492fn map_status(code: i32) -> Result<()> {
493 match code {
494 0 => Ok(()),
495 1 => Err(Error::MisalignedOperand),
496 2 => Err(Error::InvalidProblem(
497 "baracuda-kernels-sys reported invalid problem \
498 (S4 / U4 require even numel — check descriptor)",
499 )),
500 3 => Err(Error::Unsupported(
501 "baracuda-kernels-sys reported unsupported configuration",
502 )),
503 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
504 n => Err(Error::CutlassInternal(n)),
505 }
506}