use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
};
use baracuda_types::DeviceRepr;
/// Problem description a [`ContiguizePlan`] is built from.
///
/// `ContiguizePlan::select` validates and copies this descriptor into the
/// plan; `ContiguizePlan::run` later forwards `shape`, `source_strides` and
/// `source_offset` to the kernel entry point unchanged.
#[derive(Clone, Copy, Debug)]
pub struct ContiguizeDescriptor<const N: usize> {
    /// Extent of each of the `N` dimensions. `select` rejects negative entries.
    pub shape: [i32; N],
    /// Per-dimension stride of the source view. For `S4` / `U4` elements
    /// `select` requires the last entry to be `1`, `-1` or `2`.
    /// NOTE(review): the unit (elements vs. bytes) is defined by the kernel
    /// side and is not visible here — confirm.
    pub source_strides: [i64; N],
    /// Start offset into the source buffer, passed to the kernel as-is.
    /// NOTE(review): presumably in the same unit as `source_strides` — confirm.
    pub source_offset: i64,
    /// Element kind being copied. The plan's type parameter `T` must have
    /// the byte size that `select` associates with this kind.
    pub element: ElementKind,
}
/// Operands for a single `ContiguizePlan::run` call.
pub struct ContiguizeArgs<'a, T, const N: usize>
where
    T: DeviceRepr + Copy + 'static,
{
    /// Tensor that is read. Its shape must equal the plan descriptor's shape.
    pub source: TensorRef<'a, T, N>,
    /// Tensor that is written. It must have the descriptor's shape, be
    /// contiguous, and provide enough storage for every element.
    pub dest: TensorMut<'a, T, N>,
}
/// A validated contiguize operation for element type `T` and rank `N`.
///
/// Instances are created through `select`, which is the only place the
/// private fields are populated.
pub struct ContiguizePlan<T, const N: usize>
where
    T: DeviceRepr + Copy + 'static,
{
    /// Copy of the descriptor the plan was selected for.
    desc: ContiguizeDescriptor<N>,
    /// Kernel identity reported by `sku()` / `precision_guarantee()`.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
impl<T: DeviceRepr + Copy + 'static, const N: usize> ContiguizePlan<T, N> {
    /// Validates `desc` against `T` and builds a plan for it.
    ///
    /// `_stream` and `_pref` are accepted for signature compatibility but are
    /// not consulted.
    ///
    /// # Errors
    ///
    /// * `Error::Unsupported` — the element kind is outside the supported set
    ///   (currently only `Bin`), `size_of::<T>()` does not match the element
    ///   kind, `N > 8`, or an `S4` / `U4` descriptor has an innermost stride
    ///   other than `1`, `-1` or `2`.
    /// * `Error::InvalidProblem` — a shape entry is negative.
    pub fn select(
        _stream: &Stream,
        desc: &ContiguizeDescriptor<N>,
        _pref: PlanPreference,
    ) -> Result<Self> {
        // Coverage is checked before the size check: `type_size_matches_kind`
        // returns `false` for `Bin`, so running it first would report a
        // misleading "sizeof mismatch" and leave this branch unreachable.
        let supported = matches!(
            desc.element,
            ElementKind::F16
                | ElementKind::Bf16
                | ElementKind::F32
                | ElementKind::F32Strict
                | ElementKind::F64
                | ElementKind::I32
                | ElementKind::I64
                | ElementKind::Bool
                | ElementKind::S8
                | ElementKind::U8
                | ElementKind::Fp8E4M3
                | ElementKind::Fp8E5M2
                | ElementKind::Complex32
                | ElementKind::Complex64
                | ElementKind::S4
                | ElementKind::U4
        );
        if !supported {
            return Err(Error::Unsupported(
                "baracuda-kernels::ContiguizePlan: dtype not in coverage \
                 (Bin is out of scope; everything else byte-aligned or nibble-packed)",
            ));
        }

        if !type_size_matches_kind::<T>(desc.element) {
            return Err(Error::Unsupported(
                "baracuda-kernels::ContiguizePlan: sizeof::<T> doesn't match \
                 descriptor element kind (T must be the Rust type that backs \
                 desc.element — see ContiguizeDescriptor docs)",
            ));
        }

        if desc.shape.iter().any(|&dim| dim < 0) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::ContiguizePlan: shape dims must be non-negative",
            ));
        }

        if N > 8 {
            return Err(Error::Unsupported(
                "baracuda-kernels::ContiguizePlan: tensor rank > 8 not supported",
            ));
        }

        // Nibble-packed kinds constrain the innermost stride. `last()` yields
        // `None` for rank 0, in which case there is nothing to check.
        if matches!(desc.element, ElementKind::S4 | ElementKind::U4) {
            if let Some(&inner) = desc.source_strides.last() {
                if !matches!(inner, 1 | -1 | 2) {
                    return Err(Error::Unsupported(
                        "baracuda-kernels::ContiguizePlan: S4 / U4 source's innermost \
                         stride must be one of {1, -1, 2} for nibble alignment",
                    ));
                }
            }
        }

        let precision_guarantee = PrecisionGuarantee {
            math_precision: MathPrecision::F32,
            accumulator: ElementKind::F32,
            bit_stable_on_same_hardware: true,
            deterministic: true,
        };
        // Architecture and backend are fixed; nothing in the descriptor or
        // preference influences them.
        let sku = KernelSku {
            category: OpCategory::ShapeLayout,
            op: ShapeLayoutKind::Contiguize as u16,
            element: desc.element,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee,
        };

        Ok(Self {
            desc: *desc,
            sku,
            _marker: PhantomData,
        })
    }

    /// Checks that `args` is compatible with the descriptor this plan was
    /// selected for.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidProblem` — source or destination shape differs from
    ///   the descriptor's shape, or the destination is not contiguous.
    /// * `Error::BufferTooSmall` — the destination holds fewer storage units
    ///   than required (`ceil(numel / 2)` for `S4` / `U4`, `numel` otherwise).
    ///
    /// NOTE(review): only shapes and the destination buffer are validated.
    /// The source view described by the descriptor's strides and offset is not
    /// checked against the source buffer, and any stride information carried
    /// by `args.source` itself is not consulted — confirm callers guarantee
    /// that the descriptor matches the source tensor.
    pub fn can_implement(&self, args: &ContiguizeArgs<'_, T, N>) -> Result<()> {
        if args.source.shape != self.desc.shape {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::ContiguizePlan: source shape mismatch with descriptor",
            ));
        }
        if args.dest.shape != self.desc.shape {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::ContiguizePlan: dest shape mismatch with descriptor",
            ));
        }
        if !args.dest.is_contiguous() {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::ContiguizePlan: dest must be canonical contiguous \
                 (the whole point of this op is to MATERIALIZE a contiguous view)",
            ));
        }

        let numel = args.dest.numel();
        let dest_len = args.dest.data.len() as i64;
        // Nibble kinds pack two elements per storage unit; round up so an odd
        // element count still gets its final half-filled unit.
        let needed_storage = match self.desc.element {
            ElementKind::S4 | ElementKind::U4 => (numel + 1) / 2,
            _ => numel,
        };
        if dest_len < needed_storage {
            return Err(Error::BufferTooSmall {
                needed: needed_storage as usize,
                got: dest_len as usize,
            });
        }
        Ok(())
    }

    /// Scratch space required by `run`; this operation needs none.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Kernel identity chosen by `select`.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Launches the copy on `stream`.
    ///
    /// `args` is first validated with `can_implement`; an empty destination
    /// (`numel == 0`) returns `Ok(())` without launching anything. The
    /// workspace argument is unused. The entry point is chosen by the
    /// descriptor's element kind and its status code is translated by
    /// `map_status`.
    ///
    /// # Errors
    ///
    /// Any error from `can_implement`, `Error::Unsupported` for an element
    /// kind with no entry point, or the error `map_status` derives from a
    /// non-zero status code.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: ContiguizeArgs<'_, T, N>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        if args.dest.numel() == 0 {
            return Ok(());
        }

        let source_ptr = args.source.data.as_raw().0 as *const c_void;
        let dest_ptr = args.dest.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;
        // Local copies keep the arrays alive (and at a stable address) for
        // the duration of the call below.
        let shape = self.desc.shape;
        let source_strides = self.desc.source_strides;
        let source_offset = self.desc.source_offset;
        let rank = N as i32;

        // Every entry point is called with the same argument list, so the
        // call is written once and only the symbol name varies.
        macro_rules! launch {
            ($entry:ident) => {
                // SAFETY: `dest_ptr` / `source_ptr` are the raw pointers of the
                // caller's tensors, which are borrowed for the whole call;
                // `shape` and `source_strides` are live locals holding exactly
                // `rank` entries; `can_implement` has verified the destination
                // capacity. NOTE(review): the source extent implied by the
                // descriptor's strides/offset is not bounds-checked here.
                unsafe {
                    baracuda_kernels_sys::$entry(
                        dest_ptr,
                        source_ptr,
                        shape.as_ptr(),
                        source_strides.as_ptr(),
                        source_offset,
                        rank,
                        stream_ptr,
                    )
                }
            };
        }

        let status = match self.desc.element {
            ElementKind::Bool
            | ElementKind::S8
            | ElementKind::U8
            | ElementKind::Fp8E4M3
            | ElementKind::Fp8E5M2 => launch!(baracuda_kernels_contiguize_b1_run),
            ElementKind::F16 | ElementKind::Bf16 => launch!(baracuda_kernels_contiguize_b2_run),
            ElementKind::F32 | ElementKind::F32Strict | ElementKind::I32 => {
                launch!(baracuda_kernels_contiguize_b4_run)
            }
            ElementKind::F64 | ElementKind::I64 | ElementKind::Complex32 => {
                launch!(baracuda_kernels_contiguize_b8_run)
            }
            ElementKind::Complex64 => launch!(baracuda_kernels_contiguize_b16_run),
            ElementKind::S4 | ElementKind::U4 => launch!(baracuda_kernels_contiguize_nibble_run),
            // Defensive: `select` never builds a plan for these kinds.
            _ => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::ContiguizePlan::run: dtype not in coverage \
                     (Bin / unknown — should have been rejected at select())",
                ));
            }
        };
        map_status(status)
    }
}
/// Reports whether `size_of::<T>()` equals the byte width this file
/// associates with `kind`.
///
/// `S4` / `U4` are counted as one byte. `Bin` has no associated width, so
/// it never matches, whatever `T` is.
fn type_size_matches_kind<T>(kind: ElementKind) -> bool {
    let expected_bytes: Option<usize> = match kind {
        ElementKind::Bool
        | ElementKind::S8
        | ElementKind::U8
        | ElementKind::Fp8E4M3
        | ElementKind::Fp8E5M2
        | ElementKind::S4
        | ElementKind::U4 => Some(1),
        ElementKind::F16 | ElementKind::Bf16 => Some(2),
        ElementKind::F32 | ElementKind::F32Strict | ElementKind::I32 => Some(4),
        ElementKind::F64 | ElementKind::I64 | ElementKind::Complex32 => Some(8),
        ElementKind::Complex64 => Some(16),
        ElementKind::Bin => None,
    };
    // `None` never equals `Some(_)`, so `Bin` always yields `false`.
    expected_bytes == Some(core::mem::size_of::<T>())
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration \
(likely S4 / U4 source innermost stride not in {1, -1, 2})",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}