use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
};
use baracuda_types::DeviceRepr;
/// Problem description handed to [`CastSubBytePlan::select`].
///
/// The descriptor names the logical element count and the element kinds on
/// both sides of the cast; `select` validates all three fields before a plan
/// is produced.
#[derive(Copy, Clone, Debug)]
pub struct CastSubByteDescriptor {
/// Number of logical elements to convert. `select` rejects negative values,
/// and rejects odd values when either endpoint is `S4` or `U4`.
pub numel: i32,
/// Element kind of the source buffer; `select` checks it against
/// `size_of::<TIn>()`.
pub input_element: ElementKind,
/// Element kind of the destination buffer; `select` checks it against
/// `size_of::<TOut>()`.
pub output_element: ElementKind,
}
/// Tensor arguments for one [`CastSubBytePlan::run`] launch.
///
/// `TIn` / `TOut` are the storage types of the source and destination
/// buffers. Both views borrow for `'a`.
pub struct CastSubByteArgs<'a, TIn, TOut>
where
    TIn: DeviceRepr + Copy + 'static,
    TOut: DeviceRepr + Copy + 'static,
{
    /// Read-only source view (the `1` const parameter is presumably the
    /// tensor rank — confirm against `TensorRef`).
    pub input: TensorRef<'a, TIn, 1>,
    /// Writable destination view (same const parameter as `input`).
    pub output: TensorMut<'a, TOut, 1>,
}
/// A validated cast plan for one `(input_element, output_element)` pair.
///
/// Instances are created by `select`, which stores a copy of the descriptor
/// together with the `KernelSku` it assembled. The type parameters only tag
/// the plan with the storage types it was validated against; no value of
/// either type is stored.
pub struct CastSubBytePlan<TIn, TOut>
where
    TIn: DeviceRepr + Copy + 'static,
    TOut: DeviceRepr + Copy + 'static,
{
    /// Copy of the descriptor that passed validation.
    desc: CastSubByteDescriptor,
    /// SKU record built at selection time.
    sku: KernelSku,
    /// Zero-sized tag for the input storage type.
    _marker_in: PhantomData<TIn>,
    /// Zero-sized tag for the output storage type.
    _marker_out: PhantomData<TOut>,
}
impl<TIn: DeviceRepr + Copy + 'static, TOut: DeviceRepr + Copy + 'static>
CastSubBytePlan<TIn, TOut>
{
pub fn select(
_stream: &Stream,
desc: &CastSubByteDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if !type_size_matches_kind::<TIn>(desc.input_element) {
return Err(Error::Unsupported(
"baracuda-kernels::CastSubBytePlan: sizeof::<TIn>() does not match \
descriptor input_element width",
));
}
if !type_size_matches_kind::<TOut>(desc.output_element) {
return Err(Error::Unsupported(
"baracuda-kernels::CastSubBytePlan: sizeof::<TOut>() does not match \
descriptor output_element width",
));
}
if desc.numel < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::CastSubBytePlan: numel must be non-negative",
));
}
let inv = matches!(
desc.input_element,
ElementKind::S4 | ElementKind::U4
);
let outv = matches!(
desc.output_element,
ElementKind::S4 | ElementKind::U4
);
if (inv || outv) && (desc.numel % 2 != 0) {
return Err(Error::InvalidProblem(
"baracuda-kernels::CastSubBytePlan: S4 / U4 endpoints require even numel \
(packed buffer is numel/2 bytes)",
));
}
if !pair_in_scope(desc.input_element, desc.output_element) {
return Err(Error::Unsupported(
"baracuda-kernels::CastSubBytePlan: (input, output) pair not in scope \
for Phase 13.3 — see module docs for the wired set",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::UnaryElementwise,
op: UnaryKind::Cast as u16,
element: desc.input_element,
aux_element: Some(desc.output_element),
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker_in: PhantomData,
_marker_out: PhantomData,
})
}
pub fn can_implement(&self, args: &CastSubByteArgs<'_, TIn, TOut>) -> Result<()> {
let expected = self.desc.numel as i64;
let in_packed = matches!(self.desc.input_element, ElementKind::S4 | ElementKind::U4);
let out_packed = matches!(self.desc.output_element, ElementKind::S4 | ElementKind::U4);
let needed_in = if in_packed { (expected + 1) / 2 } else { expected };
let needed_out = if out_packed { (expected + 1) / 2 } else { expected };
if (args.input.data.len() as i64) < needed_in {
return Err(Error::BufferTooSmall {
needed: needed_in as usize,
got: args.input.data.len(),
});
}
if (args.output.data.len() as i64) < needed_out {
return Err(Error::BufferTooSmall {
needed: needed_out as usize,
got: args.output.data.len(),
});
}
Ok(())
}
/// Workspace size requested by this plan.
///
/// Always `0`: `run` ignores its workspace argument and passes a null
/// pointer and `0` to the kernel entry points.
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
/// Returns a copy of the `KernelSku` that `select` assembled for this plan.
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the `precision_guarantee` field of the plan's `KernelSku`.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: CastSubByteArgs<'_, TIn, TOut>,
) -> Result<()> {
self.can_implement(&args)?;
let numel = self.desc.numel as i64;
if numel == 0 {
return Ok(());
}
let x_ptr = args.input.data.as_raw().0 as *const c_void;
let y_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let status = match (self.desc.input_element, self.desc.output_element) {
(ElementKind::Bool, ElementKind::I32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bool_i32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bool, ElementKind::I64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bool_i64_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bool, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bool_f32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bool, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bool_f16_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bool, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bool_bf16_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::I32, ElementKind::Bool) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_i32_bool_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::I64, ElementKind::Bool) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_i64_bool_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, ElementKind::Bool) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f32_bool_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F16, ElementKind::Bool) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f16_bool_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bf16, ElementKind::Bool) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bf16_bool_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_fp8e4m3_f32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_fp8e4m3_f16_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_fp8e4m3_bf16_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, ElementKind::Fp8E4M3) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f32_fp8e4m3_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F16, ElementKind::Fp8E4M3) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f16_fp8e4m3_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bf16, ElementKind::Fp8E4M3) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bf16_fp8e4m3_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_fp8e5m2_f32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, ElementKind::F16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_fp8e5m2_f16_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, ElementKind::Bf16) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_fp8e5m2_bf16_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, ElementKind::Fp8E5M2) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f32_fp8e5m2_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F16, ElementKind::Fp8E5M2) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f16_fp8e5m2_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Bf16, ElementKind::Fp8E5M2) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_bf16_fp8e5m2_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::S4, ElementKind::I32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_s4_i32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::S4, ElementKind::I64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_s4_i64_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::S4, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_s4_f32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::I32, ElementKind::S4) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_i32_s4_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::I64, ElementKind::S4) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_i64_s4_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, ElementKind::S4) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f32_s4_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::U4, ElementKind::I32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_u4_i32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::U4, ElementKind::I64) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_u4_i64_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::U4, ElementKind::F32) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_u4_f32_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::I32, ElementKind::U4) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_i32_u4_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::I64, ElementKind::U4) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_i64_u4_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::F32, ElementKind::U4) => unsafe {
baracuda_kernels_sys::baracuda_kernels_cast_f32_u4_run(
numel, x_ptr, y_ptr, core::ptr::null_mut(), 0, stream_ptr,
)
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels::CastSubBytePlan::run reached an unimplemented \
(input, output) pair — select() should have caught this",
));
}
};
map_status(status)
}
}
/// Reports whether a cast from `input` to `output` is one of the pairs this
/// module accepts.
///
/// Accepted families (each in both directions):
/// * `Bool` with `I32`, `I64`, `F32`, `F16`, `Bf16`
/// * `Fp8E4M3` / `Fp8E5M2` with `F32`, `F16`, `Bf16`
/// * `S4` / `U4` with `I32`, `I64`, `F32`
fn pair_in_scope(input: ElementKind, output: ElementKind) -> bool {
    use ElementKind::*;
    matches!(
        (input, output),
        (Bool, I32 | I64 | F32 | F16 | Bf16)
            | (I32 | I64 | F32 | F16 | Bf16, Bool)
            | (Fp8E4M3 | Fp8E5M2, F32 | F16 | Bf16)
            | (F32 | F16 | Bf16, Fp8E4M3 | Fp8E5M2)
            | (S4 | U4, I32 | I64 | F32)
            | (I32 | I64 | F32, S4 | U4)
    )
}
/// Reports whether `size_of::<T>()` equals the storage width this module
/// associates with `kind`.
///
/// `S4` and `U4` are treated as one-byte storage, and `Bin` never matches
/// any `T`.
fn type_size_matches_kind<T>(kind: ElementKind) -> bool {
    // `None` marks a kind with no accepted storage width.
    let expected_bytes: Option<usize> = match kind {
        ElementKind::Bool
        | ElementKind::S8
        | ElementKind::U8
        | ElementKind::Fp8E4M3
        | ElementKind::Fp8E5M2
        | ElementKind::S4
        | ElementKind::U4 => Some(1),
        ElementKind::F16 | ElementKind::Bf16 => Some(2),
        ElementKind::F32 | ElementKind::F32Strict | ElementKind::I32 => Some(4),
        ElementKind::F64 | ElementKind::I64 | ElementKind::Complex32 => Some(8),
        ElementKind::Complex64 => Some(16),
        ElementKind::Bin => None,
    };
    expected_bytes == Some(core::mem::size_of::<T>())
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem \
(S4 / U4 require even numel — check descriptor)",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}