use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
};
use baracuda_types::DeviceRepr;
/// Static description of a write-slice problem of rank `N`: a `source` tensor
/// is written into the hyper-rectangular region of `dest` selected by `ranges`.
///
/// `WriteSlicePlan::select` validates the fields against each other; the
/// per-field notes below restate those checks.
#[derive(Copy, Clone, Debug)]
pub struct WriteSliceDescriptor<const N: usize> {
/// Extent of the destination tensor along each axis. Every entry must be
/// non-negative.
pub dest_shape: [i32; N],
/// Extent of the source tensor along each axis. Entry `d` must equal
/// `ranges[d].1 - ranges[d].0`.
pub source_shape: [i32; N],
/// Per-axis `(start, end)` pair selecting the destination region; must
/// satisfy `0 <= start <= end <= dest_shape[d]`.
pub ranges: [(i32, i32); N],
/// Element kind of both tensors; it determines the byte width and whether
/// the nibble (S4 / U4) path is used.
pub element: ElementKind,
}
/// Runtime tensor arguments for one `WriteSlicePlan::run` call.
///
/// `WriteSlicePlan::can_implement` requires both tensors to be contiguous and
/// to have exactly the shapes recorded in the plan's descriptor.
pub struct WriteSliceArgs<'a, T: DeviceRepr + Copy + 'static, const N: usize> {
/// Destination tensor; the region selected by the descriptor's `ranges` is
/// overwritten.
pub dest: TensorMut<'a, T, N>,
/// Source tensor whose contents are written into the destination region.
pub source: TensorRef<'a, T, N>,
}
/// A validated write-slice plan produced by `WriteSlicePlan::select`.
///
/// All fields are private, so the invariants checked in `select` hold for
/// every plan built through it.
pub struct WriteSlicePlan<T: DeviceRepr + Copy + 'static, const N: usize> {
/// Copy of the descriptor the plan was selected for.
desc: WriteSliceDescriptor<N>,
/// Kernel identity reported by `sku()`; also carries the precision guarantee.
sku: KernelSku,
/// Byte width reported by `dispatch_kind` for the element kind (1 for the
/// nibble kinds S4 / U4).
byte_width: i32,
/// `true` when the element kind is S4 or U4.
is_nibble: bool,
/// Copy strategy chosen once at selection time by `detect_fast_path`.
fast_path: FastPath,
/// Ties the plan to the element type `T` without storing a value of it.
_marker: PhantomData<T>,
}
/// Copy strategy chosen by `detect_fast_path` and consumed by
/// `WriteSlicePlan::run`.
#[derive(Copy, Clone, Debug)]
enum FastPath {
/// Every range spans its whole destination axis; `run` issues a single
/// device-to-device copy starting at destination offset 0.
WholeDest,
/// The selected region is one linear run of the destination (rank 1, or
/// every axis except the first is fully covered). `dest_offset_elems` is the
/// element offset of the run inside the destination and `source_numel` is
/// the product of the source extents.
ContiguousChunk { dest_offset_elems: i64, source_numel: i64 },
/// Anything else; `run` dispatches to the `baracuda_kernels_sys`
/// write-slice kernels.
Generic,
}
impl<T: DeviceRepr + Copy + 'static, const N: usize> WriteSlicePlan<T, N> {
    /// Validates `desc` and builds a plan for it.
    ///
    /// `_stream` and `_pref` are accepted for interface uniformity but are not
    /// consulted.
    ///
    /// # Errors
    /// * `Error::InvalidProblem` when `N == 0`, when a `dest_shape` entry is
    ///   negative, when a range violates `0 <= start <= end <= dest_shape[d]`,
    ///   or when `source_shape[d] != end - start`.
    /// * `Error::Unsupported` when `N > 8`, when `dispatch_kind` rejects the
    ///   element kind, or when an S4 / U4 problem has an odd start, end or
    ///   destination extent on the innermost axis.
    pub fn select(
        _stream: &Stream,
        desc: &WriteSliceDescriptor<N>,
        _pref: PlanPreference,
    ) -> Result<Self> {
        if N == 0 {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::WriteSlicePlan: rank-0 tensors not supported",
            ));
        }
        if N > 8 {
            return Err(Error::Unsupported(
                "baracuda-kernels::WriteSlicePlan: tensor rank > 8 not supported",
            ));
        }

        let axes = desc
            .dest_shape
            .iter()
            .zip(&desc.source_shape)
            .zip(&desc.ranges);
        for ((&dest_dim, &source_dim), &(start, end)) in axes {
            // Checked first: a negative extent would otherwise always be
            // reported as a range violation (`end >= 0 > dest_dim`), hiding the
            // real cause from the caller.
            if dest_dim < 0 {
                return Err(Error::InvalidProblem(
                    "baracuda-kernels::WriteSlicePlan: dest_shape dims must be non-negative",
                ));
            }
            if start < 0 || end < start || end > dest_dim {
                return Err(Error::InvalidProblem(
                    "baracuda-kernels::WriteSlicePlan: ranges[d] must satisfy \
                     0 <= start <= end <= dest_shape[d]",
                ));
            }
            // `end - start` cannot overflow: 0 <= start <= end was just checked.
            if source_dim != end - start {
                return Err(Error::InvalidProblem(
                    "baracuda-kernels::WriteSlicePlan: source_shape[d] must equal \
                     ranges[d].1 - ranges[d].0",
                ));
            }
        }

        let Some((byte_width, is_nibble)) = dispatch_kind(desc.element) else {
            return Err(Error::Unsupported(
                "baracuda-kernels::WriteSlicePlan: dtype out of scope. Supported set: \
                 {f16, bf16, f32, F32Strict, f64, i32, i64, Bool, S8, U8, S4, U4, \
                 Fp8E4M3, Fp8E5M2, Complex32, Complex64}",
            ));
        };

        if is_nibble {
            // Two 4-bit elements share a byte, so the innermost axis must start
            // and end on a byte boundary in both the range and the destination.
            let (start, end) = desc.ranges[N - 1];
            if ((start | end) & 1) != 0 {
                return Err(Error::Unsupported(
                    "baracuda-kernels::WriteSlicePlan: WriteSlice on S4 / U4 requires \
                     even start/end on innermost axis (no read-modify-write at byte \
                     boundary in the trailblazer kernel)",
                ));
            }
            if (desc.dest_shape[N - 1] & 1) != 0 {
                return Err(Error::Unsupported(
                    "baracuda-kernels::WriteSlicePlan: WriteSlice on S4 / U4 requires \
                     even dest_shape on innermost axis",
                ));
            }
        }

        let fast_path = detect_fast_path::<N>(desc);
        let precision_guarantee = PrecisionGuarantee {
            math_precision: MathPrecision::F32,
            accumulator: ElementKind::F32,
            bit_stable_on_same_hardware: true,
            deterministic: true,
        };
        let sku = KernelSku {
            category: OpCategory::ShapeLayout,
            op: ShapeLayoutKind::WriteSlice as u16,
            element: desc.element,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee,
        };

        Ok(Self {
            desc: *desc,
            sku,
            byte_width,
            is_nibble,
            fast_path,
            _marker: PhantomData,
        })
    }

    /// Checks that `args` matches the descriptor this plan was selected for.
    ///
    /// # Errors
    /// * `Error::InvalidProblem` when either tensor's shape differs from the
    ///   descriptor.
    /// * `Error::Unsupported` when either tensor is not contiguous.
    /// * `Error::BufferTooSmall` when a tensor's `data.len()` is below the
    ///   required storage length (element count, halved and rounded up for
    ///   S4 / U4).
    pub fn can_implement(&self, args: &WriteSliceArgs<'_, T, N>) -> Result<()> {
        if args.dest.shape != self.desc.dest_shape {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::WriteSlicePlan: dest shape mismatch with descriptor",
            ));
        }
        if args.source.shape != self.desc.source_shape {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::WriteSlicePlan: source shape mismatch with descriptor",
            ));
        }
        if !args.dest.is_contiguous() {
            return Err(Error::Unsupported(
                "baracuda-kernels::WriteSlicePlan: dest must be contiguous row-major",
            ));
        }
        if !args.source.is_contiguous() {
            return Err(Error::Unsupported(
                "baracuda-kernels::WriteSlicePlan: source must be contiguous row-major",
            ));
        }

        // Ceil-divide by two for nibble kinds. Written as `n / 2 + n % 2`
        // rather than `(n + 1) / 2` so a product saturated at `i64::MAX` cannot
        // overflow (which would wrap negative in release builds and let the
        // size check below pass). `n` is non-negative because `select`
        // rejected negative extents.
        let storage_len = |numel: i64| {
            if self.is_nibble {
                numel / 2 + numel % 2
            } else {
                numel
            }
        };
        let dest_storage = storage_len(product_i64(self.desc.dest_shape));
        let source_storage = storage_len(product_i64(self.desc.source_shape));

        if (args.dest.data.len() as i64) < dest_storage {
            return Err(Error::BufferTooSmall {
                needed: dest_storage as usize,
                got: args.dest.data.len(),
            });
        }
        if (args.source.data.len() as i64) < source_storage {
            return Err(Error::BufferTooSmall {
                needed: source_storage as usize,
                got: args.source.data.len(),
            });
        }
        Ok(())
    }

    /// Workspace size required by `run`; always 0.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Returns the kernel SKU recorded when the plan was selected.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee stored inside the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Enqueues the write of `args.source` into the selected region of
    /// `args.dest` on `stream`.
    ///
    /// `args` is re-validated with [`Self::can_implement`] first. An empty
    /// source is a no-op. The `WholeDest` and `ContiguousChunk` fast paths
    /// issue one device-to-device copy; every other problem is dispatched to
    /// the `baracuda_kernels_sys` kernel matching the element byte width (or
    /// the nibble kernel for S4 / U4). `_workspace` is unused.
    ///
    /// # Errors
    /// Anything `can_implement` returns, the error from `copy_d2d_async`, the
    /// kernel status translated by `map_status`, or `Error::Unsupported` for a
    /// byte width with no kernel.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: WriteSliceArgs<'_, T, N>,
    ) -> Result<()> {
        self.can_implement(&args)?;

        let source_numel = product_i64(self.desc.source_shape);
        if source_numel == 0 {
            return Ok(());
        }

        let dest_ptr_u64 = args.dest.data.as_raw().0;
        let source_ptr_u64 = args.source.data.as_raw().0;
        let stream_ptr = stream.as_raw() as *mut c_void;

        // (destination element offset, element count) when the region is one
        // linear run of the destination; `None` for the generic kernel path.
        let linear_run = match self.fast_path {
            FastPath::WholeDest => Some((0i64, source_numel)),
            FastPath::ContiguousChunk {
                dest_offset_elems,
                source_numel: chunk_elems,
            } => Some((dest_offset_elems, chunk_elems)),
            FastPath::Generic => None,
        };
        if let Some((dest_off_elems, copy_elems)) = linear_run {
            let (dest_off_bytes, copy_bytes) = if self.is_nibble {
                // Both values are even here: `select` enforced even innermost
                // start/end/extent for nibble kinds, so halving is exact.
                (dest_off_elems / 2, copy_elems / 2)
            } else {
                let bw = i64::from(self.byte_width);
                (dest_off_elems * bw, copy_elems * bw)
            };
            return copy_d2d_async(
                dest_ptr_u64.wrapping_add(dest_off_bytes as u64),
                source_ptr_u64,
                copy_bytes as usize,
                stream_ptr,
            );
        }

        let rank = N as i32;
        let dest_shape = self.desc.dest_shape;
        let source_shape = self.desc.source_shape;
        let mut range_start = [0i32; N];
        for (slot, &(start, _)) in range_start.iter_mut().zip(&self.desc.ranges) {
            *slot = start;
        }

        let status = if self.is_nibble {
            // Re-express the innermost axis in bytes (two nibbles per byte);
            // all three values are even, see `select`.
            let mut dest_byte_shape = dest_shape;
            let mut source_byte_shape = source_shape;
            let mut range_start_bytes = range_start;
            dest_byte_shape[N - 1] /= 2;
            source_byte_shape[N - 1] /= 2;
            range_start_bytes[N - 1] /= 2;
            let source_byte_numel = source_numel / 2;
            // SAFETY: the three shape arrays are live locals holding `N`
            // (== `rank`) entries each, and `can_implement` accepted both
            // buffers for these shapes. NOTE(review): the null pointer / 0
            // pair is presumably the (unused) workspace and its size -- confirm
            // against the baracuda_kernels_sys declaration.
            unsafe {
                baracuda_kernels_sys::baracuda_kernels_write_slice_nibble_run(
                    dest_ptr_u64 as *mut c_void,
                    source_ptr_u64 as *const c_void,
                    source_byte_numel,
                    rank,
                    dest_byte_shape.as_ptr(),
                    source_byte_shape.as_ptr(),
                    range_start_bytes.as_ptr(),
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            }
        } else {
            // SAFETY: same argument as the nibble branch -- the shape arrays
            // are live locals of length `N` (== `rank`) and `can_implement`
            // accepted both buffers. NOTE(review): the null pointer / 0 pair is
            // presumably the (unused) workspace and its size -- confirm.
            unsafe {
                let dest = dest_ptr_u64 as *mut c_void;
                let source = source_ptr_u64 as *const c_void;
                let ds = dest_shape.as_ptr();
                let ss = source_shape.as_ptr();
                let rs = range_start.as_ptr();
                match self.byte_width {
                    1 => baracuda_kernels_sys::baracuda_kernels_write_slice_b1_run(
                        dest, source, source_numel, rank, ds, ss, rs,
                        core::ptr::null_mut(), 0, stream_ptr,
                    ),
                    2 => baracuda_kernels_sys::baracuda_kernels_write_slice_b2_run(
                        dest, source, source_numel, rank, ds, ss, rs,
                        core::ptr::null_mut(), 0, stream_ptr,
                    ),
                    4 => baracuda_kernels_sys::baracuda_kernels_write_slice_b4_run(
                        dest, source, source_numel, rank, ds, ss, rs,
                        core::ptr::null_mut(), 0, stream_ptr,
                    ),
                    8 => baracuda_kernels_sys::baracuda_kernels_write_slice_b8_run(
                        dest, source, source_numel, rank, ds, ss, rs,
                        core::ptr::null_mut(), 0, stream_ptr,
                    ),
                    16 => baracuda_kernels_sys::baracuda_kernels_write_slice_b16_run(
                        dest, source, source_numel, rank, ds, ss, rs,
                        core::ptr::null_mut(), 0, stream_ptr,
                    ),
                    _ => {
                        return Err(Error::Unsupported(
                            "baracuda-kernels::WriteSlicePlan::run: unsupported byte width \
                             (select() should have caught this)",
                        ))
                    }
                }
            }
        };
        map_status(status)
    }
}
/// Maps an element kind to `(byte_width, is_nibble)`.
///
/// The nibble kinds (S4 / U4) report a byte width of 1 together with
/// `is_nibble == true`. `ElementKind::Bin` has no mapping and yields `None`.
fn dispatch_kind(k: ElementKind) -> Option<(i32, bool)> {
    match k {
        // No write-slice support for this kind.
        ElementKind::Bin => None,

        // Packed 4-bit kinds: addressed per byte, flagged as nibble.
        ElementKind::S4 | ElementKind::U4 => Some((1, true)),

        ElementKind::Bool
        | ElementKind::S8
        | ElementKind::U8
        | ElementKind::Fp8E4M3
        | ElementKind::Fp8E5M2 => Some((1, false)),

        ElementKind::F16 | ElementKind::Bf16 => Some((2, false)),

        ElementKind::F32 | ElementKind::F32Strict | ElementKind::I32 => Some((4, false)),

        ElementKind::F64 | ElementKind::I64 | ElementKind::Complex32 => Some((8, false)),

        ElementKind::Complex64 => Some((16, false)),
    }
}
/// Classifies the region selected by `desc.ranges` so `run` can replace the
/// generic kernel with a single linear copy where possible.
///
/// * `WholeDest` -- every range is exactly `(0, dest_shape[d])`.
/// * `ContiguousChunk` -- every axis except the first is fully covered (which
///   is trivially true for rank 1), so the region is one linear run starting
///   at `ranges[0].0 * product(dest_shape[1..])` elements into the destination.
/// * `Generic` -- anything else.
///
/// All products saturate at the `i64` bounds instead of overflowing, matching
/// `product_i64`.
fn detect_fast_path<const N: usize>(desc: &WriteSliceDescriptor<N>) -> FastPath {
    // True when the range on axis `d` covers the whole destination axis.
    let covers_axis = |d: usize| desc.ranges[d] == (0, desc.dest_shape[d]);

    if (0..N).all(&covers_axis) {
        return FastPath::WholeDest;
    }

    // For N == 1 the range below is empty, so a partial rank-1 range always
    // lands here with an inner-row size of 1.
    if (1..N).all(&covers_axis) {
        let inner_row_elems = desc.dest_shape[1..]
            .iter()
            .fold(1_i64, |acc, &dim| acc.saturating_mul(i64::from(dim)));
        // Saturating, like the product above: a plain `*` here could overflow
        // (debug panic / release wrap) once `inner_row_elems` has saturated.
        let dest_offset_elems = i64::from(desc.ranges[0].0).saturating_mul(inner_row_elems);
        return FastPath::ContiguousChunk {
            dest_offset_elems,
            source_numel: product_i64(desc.source_shape),
        };
    }

    FastPath::Generic
}
/// Multiplies all extents of `shape` as `i64`, saturating at the `i64` bounds
/// instead of overflowing. An empty shape yields 1.
#[inline]
fn product_i64<const N: usize>(shape: [i32; N]) -> i64 {
    shape
        .iter()
        .fold(1_i64, |running, &extent| running.saturating_mul(i64::from(extent)))
}
fn copy_d2d_async(
dst_dev: u64,
src_dev: u64,
bytes: usize,
stream: *mut c_void,
) -> Result<()> {
if bytes == 0 {
return Ok(());
}
#[allow(non_camel_case_types)]
type CUresult = i32;
unsafe extern "system" {
fn cuMemcpyDtoDAsync_v2(
dst_device: u64,
src_device: u64,
byte_count: usize,
h_stream: *mut c_void,
) -> CUresult;
}
let status = unsafe { cuMemcpyDtoDAsync_v2(dst_dev, src_dev, bytes, stream) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}