use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, UnaryKind, Workspace,
};
/// Problem description for an element-wise type conversion.
///
/// `CastPlan::select` requires `input_element` / `output_element` to equal the
/// plan's `TIn::KIND` / `TOut::KIND`, and `numel` to be non-negative.
#[derive(Copy, Clone, Debug)]
pub struct CastDescriptor {
    /// Number of elements to convert; must be `>= 0`.
    pub numel: i32,
    /// Element kind of the source tensor (must match `TIn::KIND`).
    pub input_element: ElementKind,
    /// Element kind of the destination tensor (must match `TOut::KIND`).
    pub output_element: ElementKind,
}
/// Operands for `CastPlan::run`.
///
/// Both views carry the const parameter `1` (presumably rank-1 — TODO confirm
/// against `TensorRef`/`TensorMut`). `can_implement` requires each view's
/// `numel()` to equal the descriptor's `numel` and its `data.len()` to be at
/// least that large.
pub struct CastArgs<'a, TIn: Element, TOut: Element> {
    /// Source elements; its buffer is passed to the kernel as `x`.
    pub input: TensorRef<'a, TIn, 1>,
    /// Destination elements; its buffer is passed to the kernel as `y`.
    pub output: TensorMut<'a, TOut, 1>,
}
/// Validated plan for casting `TIn` elements into `TOut` elements.
///
/// Constructed by `CastPlan::select`; stores the validated descriptor and the
/// `KernelSku` describing the selected kernel.
pub struct CastPlan<TIn: Element, TOut: Element> {
    /// Descriptor accepted by `select`.
    desc: CastDescriptor,
    /// Kernel identity reported by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    // Zero-sized markers binding the plan to its element types without
    // storing values of them.
    _marker_in: PhantomData<TIn>,
    _marker_out: PhantomData<TOut>,
}
impl<TIn: Element, TOut: Element> CastPlan<TIn, TOut> {
    /// Validates `desc` against the type parameters and builds a plan.
    ///
    /// `_stream` and `_pref` are accepted for API uniformity but are not
    /// consulted; the SKU always reports `ArchSku::Sm80` and
    /// `BackendKind::Bespoke`.
    ///
    /// # Errors
    ///
    /// Checks run in this order, returning the first failure:
    /// - [`Error::Unsupported`] if `desc.input_element != TIn::KIND`;
    /// - [`Error::Unsupported`] if `desc.output_element != TOut::KIND`;
    /// - [`Error::InvalidProblem`] if `desc.numel` is negative;
    /// - [`Error::Unsupported`] if the `(TIn, TOut)` pair is not wired.
    pub fn select(
        _stream: &Stream,
        desc: &CastDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        let (src, dst) = (TIn::KIND, TOut::KIND);

        if desc.input_element != src {
            return Err(Error::Unsupported(
                "baracuda-kernels::CastPlan: descriptor input_element != type parameter TIn",
            ));
        }
        if desc.output_element != dst {
            return Err(Error::Unsupported(
                "baracuda-kernels::CastPlan: descriptor output_element != type parameter TOut",
            ));
        }
        if desc.numel < 0 {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::CastPlan: numel must be non-negative",
            ));
        }
        if !pair_in_scope(src, dst) {
            return Err(Error::Unsupported(
                "baracuda-kernels::CastPlan: this (TIn, TOut) pair is not wired today; \
                 supported set is {f32, f64, f16, bf16, i32, i64} × {same}",
            ));
        }

        Ok(Self {
            desc: *desc,
            sku: KernelSku {
                category: OpCategory::UnaryElementwise,
                op: UnaryKind::Cast as u16,
                element: src,
                aux_element: Some(dst),
                layout: None,
                epilogue: None,
                arch: ArchSku::Sm80,
                backend: BackendKind::Bespoke,
                // NOTE(review): F32 math / F32 accumulator is reported for
                // every pair, including f64 and i64 ones. Confirm this matches
                // what the cast kernels actually do.
                precision_guarantee: PrecisionGuarantee {
                    math_precision: MathPrecision::F32,
                    accumulator: ElementKind::F32,
                    bit_stable_on_same_hardware: true,
                    deterministic: true,
                },
            },
            _marker_in: PhantomData,
            _marker_out: PhantomData,
        })
    }

    /// Checks that `args` is consistent with the plan's descriptor.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidProblem`] if the input's or the output's `numel()`
    ///   differs from the descriptor (input is checked first);
    /// - [`Error::BufferTooSmall`] if the input's or the output's backing
    ///   `data.len()` is smaller than the descriptor's `numel`.
    pub fn can_implement(&self, args: &CastArgs<'_, TIn, TOut>) -> Result<()> {
        let want = i64::from(self.desc.numel);

        if args.input.numel() != want {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::CastPlan: input numel mismatch with descriptor",
            ));
        }
        if args.output.numel() != want {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::CastPlan: output numel mismatch with descriptor",
            ));
        }

        // Input buffer first, then output, so the first short buffer is reported.
        for &got in &[args.input.data.len(), args.output.data.len()] {
            if (got as i64) < want {
                return Err(Error::BufferTooSmall {
                    needed: want as usize,
                    got,
                });
            }
        }
        Ok(())
    }

    /// Scratch memory required by `run`: always zero bytes.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Kernel identity chosen by `select`.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku().precision_guarantee
    }

    /// Enqueues the cast of `args.input` into `args.output` on `stream`.
    ///
    /// `_workspace` is ignored (see `workspace_size`). When the descriptor's
    /// `numel` is zero, nothing is launched and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Any error from `can_implement`, [`Error::Unsupported`] for a pair
    /// without a kernel, or the mapped kernel status code.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: CastArgs<'_, TIn, TOut>,
    ) -> Result<()> {
        self.can_implement(&args)?;

        let numel = i64::from(self.desc.numel);
        if numel == 0 {
            // Empty problem: nothing to launch.
            return Ok(());
        }

        let x_ptr = args.input.data.as_raw().0 as *const c_void;
        let y_ptr = args.output.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;

        // Every cast entry point takes the same argument list. This macro
        // supplies only the symbol, so each (TIn, TOut) pair below is one line.
        macro_rules! launch {
            ($kernel:ident) => {
                // SAFETY: `can_implement` above guaranteed that both
                // `data.len()` values are at least `numel`. The enclosing match
                // arm picks the kernel whose element types are (TIn, TOut).
                // No workspace is passed (null pointer, size 0). This assumes
                // `as_raw()` yields addresses valid while `args` is borrowed
                // — TODO confirm.
                unsafe {
                    baracuda_kernels_sys::$kernel(
                        numel,
                        x_ptr,
                        y_ptr,
                        core::ptr::null_mut(),
                        0,
                        stream_ptr,
                    )
                }
            };
        }

        let status = match (TIn::KIND, TOut::KIND) {
            (ElementKind::F32, ElementKind::F32) => launch!(baracuda_kernels_cast_f32_f32_run),
            (ElementKind::F32, ElementKind::F64) => launch!(baracuda_kernels_cast_f32_f64_run),
            (ElementKind::F32, ElementKind::F16) => launch!(baracuda_kernels_cast_f32_f16_run),
            (ElementKind::F32, ElementKind::Bf16) => launch!(baracuda_kernels_cast_f32_bf16_run),
            (ElementKind::F32, ElementKind::I32) => launch!(baracuda_kernels_cast_f32_i32_run),
            (ElementKind::F32, ElementKind::I64) => launch!(baracuda_kernels_cast_f32_i64_run),

            (ElementKind::F64, ElementKind::F32) => launch!(baracuda_kernels_cast_f64_f32_run),
            (ElementKind::F64, ElementKind::F64) => launch!(baracuda_kernels_cast_f64_f64_run),
            (ElementKind::F64, ElementKind::F16) => launch!(baracuda_kernels_cast_f64_f16_run),
            (ElementKind::F64, ElementKind::Bf16) => launch!(baracuda_kernels_cast_f64_bf16_run),
            (ElementKind::F64, ElementKind::I32) => launch!(baracuda_kernels_cast_f64_i32_run),
            (ElementKind::F64, ElementKind::I64) => launch!(baracuda_kernels_cast_f64_i64_run),

            (ElementKind::F16, ElementKind::F32) => launch!(baracuda_kernels_cast_f16_f32_run),
            (ElementKind::F16, ElementKind::F64) => launch!(baracuda_kernels_cast_f16_f64_run),
            (ElementKind::F16, ElementKind::F16) => launch!(baracuda_kernels_cast_f16_f16_run),
            (ElementKind::F16, ElementKind::Bf16) => launch!(baracuda_kernels_cast_f16_bf16_run),
            (ElementKind::F16, ElementKind::I32) => launch!(baracuda_kernels_cast_f16_i32_run),
            (ElementKind::F16, ElementKind::I64) => launch!(baracuda_kernels_cast_f16_i64_run),

            (ElementKind::Bf16, ElementKind::F32) => launch!(baracuda_kernels_cast_bf16_f32_run),
            (ElementKind::Bf16, ElementKind::F64) => launch!(baracuda_kernels_cast_bf16_f64_run),
            (ElementKind::Bf16, ElementKind::F16) => launch!(baracuda_kernels_cast_bf16_f16_run),
            (ElementKind::Bf16, ElementKind::Bf16) => launch!(baracuda_kernels_cast_bf16_bf16_run),
            (ElementKind::Bf16, ElementKind::I32) => launch!(baracuda_kernels_cast_bf16_i32_run),
            (ElementKind::Bf16, ElementKind::I64) => launch!(baracuda_kernels_cast_bf16_i64_run),

            (ElementKind::I32, ElementKind::F32) => launch!(baracuda_kernels_cast_i32_f32_run),
            (ElementKind::I32, ElementKind::F64) => launch!(baracuda_kernels_cast_i32_f64_run),
            (ElementKind::I32, ElementKind::F16) => launch!(baracuda_kernels_cast_i32_f16_run),
            (ElementKind::I32, ElementKind::Bf16) => launch!(baracuda_kernels_cast_i32_bf16_run),
            (ElementKind::I32, ElementKind::I32) => launch!(baracuda_kernels_cast_i32_i32_run),
            (ElementKind::I32, ElementKind::I64) => launch!(baracuda_kernels_cast_i32_i64_run),

            (ElementKind::I64, ElementKind::F32) => launch!(baracuda_kernels_cast_i64_f32_run),
            (ElementKind::I64, ElementKind::F64) => launch!(baracuda_kernels_cast_i64_f64_run),
            (ElementKind::I64, ElementKind::F16) => launch!(baracuda_kernels_cast_i64_f16_run),
            (ElementKind::I64, ElementKind::Bf16) => launch!(baracuda_kernels_cast_i64_bf16_run),
            (ElementKind::I64, ElementKind::I32) => launch!(baracuda_kernels_cast_i64_i32_run),
            (ElementKind::I64, ElementKind::I64) => launch!(baracuda_kernels_cast_i64_i64_run),

            _ => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::CastPlan::run reached an unimplemented \
                     (TIn, TOut) pair — select() should have caught this",
                ));
            }
        };
        map_status(status)
    }
}
/// Returns `true` when both element kinds belong to the set with wired cast
/// kernels: f32, f64, f16, bf16, i32 and i64.
///
/// `input` is tested first; `output` is only examined if `input` qualifies.
fn pair_in_scope(input: ElementKind, output: ElementKind) -> bool {
    [input, output].iter().all(|kind| {
        matches!(
            kind,
            ElementKind::F32
                | ElementKind::F64
                | ElementKind::F16
                | ElementKind::Bf16
                | ElementKind::I32
                | ElementKind::I64
        )
    })
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}