use baracuda_driver::{init, Context, Device, DeviceBuffer, Stream};
use baracuda_kernels::{
contiguous_stride, ElementKind, PlanPreference, S4, TensorMut, TensorRef, Workspace,
WriteSliceArgs, WriteSliceDescriptor, WriteSlicePlan,
};
use half::{bf16, f16};
/// Initialises the driver and returns a context plus a stream created on
/// device 0.
///
/// Panics (through `expect`) if any of the driver calls fails.
fn setup() -> (Context, Stream) {
    init().expect("driver init");
    let gpu = Device::get(0).expect("device 0");
    let context = Context::new(&gpu).expect("context");
    let queue = Stream::new(&context).expect("stream");
    (context, queue)
}
/// Host-side reference for the write-slice operation, operating on raw bytes.
///
/// `dest_init` and `source` are read as densely packed, row-major tensors of
/// `byte_width`-byte elements with shapes `dest_shape` and `source_shape`.
/// The return value is a copy of `dest_init` in which every source element
/// has been written at its own coordinate shifted by the per-dimension start
/// `ranges[d].0`. The range ends (`ranges[d].1`) are not consulted; the
/// written extent comes from `source_shape` alone.
///
/// If the source has zero elements the destination bytes are returned
/// unchanged. Indexing outside either buffer panics.
fn cpu_write_slice<const N: usize>(
    dest_init: &[u8],
    dest_shape: [i32; N],
    source: &[u8],
    source_shape: [i32; N],
    ranges: [(i32, i32); N],
    byte_width: usize,
) -> Vec<u8> {
    let mut result = dest_init.to_vec();
    let total: usize = source_shape.iter().map(|&extent| extent as usize).product();
    if total == 0 {
        return result;
    }

    // Row-major element strides: the innermost axis has stride 1 and each
    // outer axis multiplies in the extent of the axis just inside it.
    let mut dst_stride = [1i64; N];
    let mut src_stride = [1i64; N];
    for axis in (1..N).rev() {
        dst_stride[axis - 1] = dst_stride[axis] * dest_shape[axis] as i64;
        src_stride[axis - 1] = src_stride[axis] * source_shape[axis] as i64;
    }

    for flat in 0..total {
        // Unravel the flat source index axis by axis (innermost first) and
        // accumulate both element offsets on the way. Every extent is
        // non-zero here because `total` is non-zero.
        let mut rest = flat as i64;
        let mut dst_elem: i64 = 0;
        let mut src_elem: i64 = 0;
        for axis in (0..N).rev() {
            let extent = source_shape[axis] as i64;
            let idx = rest % extent;
            rest /= extent;
            dst_elem += (idx + ranges[axis].0 as i64) * dst_stride[axis];
            src_elem += idx * src_stride[axis];
        }

        // Move one whole element's worth of bytes.
        let dst_at = dst_elem as usize * byte_width;
        let src_at = src_elem as usize * byte_width;
        result[dst_at..dst_at + byte_width]
            .copy_from_slice(&source[src_at..src_at + byte_width]);
    }
    result
}
/// Writes a single `[1, 4, 64]` f32 slab into index 7 of the leading axis of
/// a `[32, 4, 64]` destination and compares the downloaded result byte for
/// byte with the `cpu_write_slice` host reference.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_kv_cache_append_f32() {
    let (ctx, stream) = setup();

    let dest_shape = [32i32, 4, 64];
    let source_shape = [1i32, 4, 64];
    let ranges = [(7, 8), (0, 4), (0, 64)];
    let byte_width = 4;

    let dest_numel = dest_shape.iter().product::<i32>() as usize;
    let source_numel = source_shape.iter().product::<i32>() as usize;
    let dest_init: Vec<f32> = (0..dest_numel).map(|n| n as f32 * 0.001).collect();
    let source: Vec<f32> = (0..source_numel).map(|n| n as f32 + 1000.0).collect();

    // Host reference, computed on the raw bytes of both tensors.
    let host_dest = bytemuck_slice(&dest_init);
    let host_source = bytemuck_slice(&source);
    let expected_bytes = cpu_write_slice::<3>(
        &host_dest,
        dest_shape,
        &host_source,
        source_shape,
        ranges,
        byte_width,
    );

    // Device run.
    let mut dev_dest = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source = DeviceBuffer::from_slice(&ctx, &source).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::F32,
    };
    let plan = WriteSlicePlan::<f32, 3>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let launch = WriteSliceArgs::<f32, 3> {
        dest: TensorMut {
            data: dev_dest.as_slice_mut(),
            shape: dest_shape,
            stride: contiguous_stride(dest_shape),
        },
        source: TensorRef {
            data: dev_source.as_slice(),
            shape: source_shape,
            stride: contiguous_stride(source_shape),
        },
    };
    plan.run(&stream, Workspace::None, launch).expect("run");
    stream.synchronize().expect("sync");

    // Download and compare against the reference one byte at a time.
    let mut got = vec![0f32; dest_numel];
    dev_dest.copy_to_host(&mut got).expect("download");
    let got_bytes = bytemuck_slice(&got);
    for (i, (actual, wanted)) in got_bytes.iter().zip(&expected_bytes).enumerate() {
        assert_eq!(
            *actual, *wanted,
            "kv_cache_append_f32 mismatch @ byte {i}"
        );
    }
}
/// Writes a single `[1, 2, 32]` f16 slab into index 3 of the leading axis of
/// a `[16, 2, 32]` destination and checks every element bit-exactly.
///
/// The expected tensor is built on the host by overwriting one contiguous run
/// of the initial destination with the source values.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_kv_cache_append_f16() {
    let (ctx, stream) = setup();

    let dest_shape = [16i32, 2, 32];
    let source_shape = [1i32, 2, 32];
    let ranges = [(3, 4), (0, 2), (0, 32)];

    let dest_numel = dest_shape.iter().product::<i32>() as usize;
    let source_numel = source_shape.iter().product::<i32>() as usize;
    let dest_init: Vec<f16> = (0..dest_numel)
        .map(|n| f16::from_f32(n as f32 * 0.01))
        .collect();
    let source: Vec<f16> = (0..source_numel)
        .map(|n| f16::from_f32(n as f32 + 50.0))
        .collect();

    let mut dev_dest = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source = DeviceBuffer::from_slice(&ctx, &source).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::F16,
    };
    let plan = WriteSlicePlan::<f16, 3>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let launch = WriteSliceArgs::<f16, 3> {
        dest: TensorMut {
            data: dev_dest.as_slice_mut(),
            shape: dest_shape,
            stride: contiguous_stride(dest_shape),
        },
        source: TensorRef {
            data: dev_source.as_slice(),
            shape: source_shape,
            stride: contiguous_stride(source_shape),
        },
    };
    plan.run(&stream, Workspace::None, launch).expect("run");
    stream.synchronize().expect("sync");

    let mut got = vec![f16::ZERO; dest_numel];
    dev_dest.copy_to_host(&mut got).expect("download");

    // The source covers the two inner axes completely, so the written region
    // is one contiguous run starting at `start * (dim1 * dim2)`.
    let slab_start = ranges[0].0 as usize * (dest_shape[1] * dest_shape[2]) as usize;
    let mut expected = dest_init.clone();
    expected[slab_start..slab_start + source_numel].copy_from_slice(&source);

    for (i, (actual, wanted)) in got.iter().zip(&expected).enumerate() {
        assert_eq!(
            actual.to_bits(), wanted.to_bits(),
            "kv_cache_append_f16 mismatch @ {i}"
        );
    }
}
/// Writes a single `[1, 3, 16]` bf16 slab into index 5 of the leading axis of
/// a `[12, 3, 16]` destination and checks every element bit-exactly.
///
/// The expected tensor is built on the host by overwriting one contiguous run
/// of the initial destination with the source values.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_kv_cache_append_bf16() {
    let (ctx, stream) = setup();

    let dest_shape = [12i32, 3, 16];
    let source_shape = [1i32, 3, 16];
    let ranges = [(5, 6), (0, 3), (0, 16)];

    let dest_numel = dest_shape.iter().product::<i32>() as usize;
    let source_numel = source_shape.iter().product::<i32>() as usize;
    let dest_init: Vec<bf16> = (0..dest_numel)
        .map(|n| bf16::from_f32(n as f32 * 0.5))
        .collect();
    let source: Vec<bf16> = (0..source_numel)
        .map(|n| bf16::from_f32(n as f32 + 200.0))
        .collect();

    let mut dev_dest = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source = DeviceBuffer::from_slice(&ctx, &source).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::Bf16,
    };
    let plan = WriteSlicePlan::<bf16, 3>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let launch = WriteSliceArgs::<bf16, 3> {
        dest: TensorMut {
            data: dev_dest.as_slice_mut(),
            shape: dest_shape,
            stride: contiguous_stride(dest_shape),
        },
        source: TensorRef {
            data: dev_source.as_slice(),
            shape: source_shape,
            stride: contiguous_stride(source_shape),
        },
    };
    plan.run(&stream, Workspace::None, launch).expect("run");
    stream.synchronize().expect("sync");

    let mut got = vec![bf16::ZERO; dest_numel];
    dev_dest.copy_to_host(&mut got).expect("download");

    // The source covers the two inner axes completely, so the written region
    // is one contiguous run starting at `start * (dim1 * dim2)`.
    let slab_start = ranges[0].0 as usize * (dest_shape[1] * dest_shape[2]) as usize;
    let mut expected = dest_init.clone();
    expected[slab_start..slab_start + source_numel].copy_from_slice(&source);

    for (i, (actual, wanted)) in got.iter().zip(&expected).enumerate() {
        assert_eq!(
            actual.to_bits(), wanted.to_bits(),
            "kv_cache_append_bf16 mismatch @ {i}"
        );
    }
}
/// Writes a single `[1, 2, 16]` f64 slab into index 2 of the leading axis of
/// an `[8, 2, 16]` destination and checks every element bit-exactly.
///
/// The expected tensor is built on the host by overwriting one contiguous run
/// of the initial destination with the source values.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_kv_cache_append_f64() {
    let (ctx, stream) = setup();

    let dest_shape = [8i32, 2, 16];
    let source_shape = [1i32, 2, 16];
    let ranges = [(2, 3), (0, 2), (0, 16)];

    let dest_numel = dest_shape.iter().product::<i32>() as usize;
    let source_numel = source_shape.iter().product::<i32>() as usize;
    let dest_init: Vec<f64> = (0..dest_numel).map(|n| n as f64 * 0.001).collect();
    let source: Vec<f64> = (0..source_numel).map(|n| n as f64 + 1.0e6).collect();

    let mut dev_dest = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source = DeviceBuffer::from_slice(&ctx, &source).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::F64,
    };
    let plan = WriteSlicePlan::<f64, 3>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let launch = WriteSliceArgs::<f64, 3> {
        dest: TensorMut {
            data: dev_dest.as_slice_mut(),
            shape: dest_shape,
            stride: contiguous_stride(dest_shape),
        },
        source: TensorRef {
            data: dev_source.as_slice(),
            shape: source_shape,
            stride: contiguous_stride(source_shape),
        },
    };
    plan.run(&stream, Workspace::None, launch).expect("run");
    stream.synchronize().expect("sync");

    let mut got = vec![0f64; dest_numel];
    dev_dest.copy_to_host(&mut got).expect("download");

    // The source covers the two inner axes completely, so the written region
    // is one contiguous run starting at `start * (dim1 * dim2)`.
    let slab_start = ranges[0].0 as usize * (dest_shape[1] * dest_shape[2]) as usize;
    let mut expected = dest_init.clone();
    expected[slab_start..slab_start + source_numel].copy_from_slice(&source);

    for (i, (actual, wanted)) in got.iter().zip(&expected).enumerate() {
        assert_eq!(
            actual.to_bits(), wanted.to_bits(),
            "kv_cache_append_f64 mismatch @ {i}"
        );
    }
}
/// Writes a `[5, 7]` f32 block into rows 4..9, columns 8..15 of a `[16, 20]`
/// destination and compares the downloaded result byte for byte with the
/// `cpu_write_slice` host reference.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_interior_2d_f32() {
    let (ctx, stream) = setup();

    let dest_shape = [16i32, 20];
    let source_shape = [5i32, 7];
    let ranges = [(4, 9), (8, 15)];

    let dest_numel = dest_shape.iter().product::<i32>() as usize;
    let source_numel = source_shape.iter().product::<i32>() as usize;
    let dest_init: Vec<f32> = (0..dest_numel).map(|n| n as f32 * 0.1 - 5.0).collect();
    let source: Vec<f32> = (0..source_numel).map(|n| n as f32 + 1000.0).collect();

    // Host reference on raw bytes; 4 is the byte width of one f32 element.
    let host_dest = bytemuck_slice(&dest_init);
    let host_source = bytemuck_slice(&source);
    let expected_bytes = cpu_write_slice::<2>(
        &host_dest,
        dest_shape,
        &host_source,
        source_shape,
        ranges,
        4,
    );

    // Device run.
    let mut dev_dest = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source = DeviceBuffer::from_slice(&ctx, &source).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::F32,
    };
    let plan = WriteSlicePlan::<f32, 2>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let launch = WriteSliceArgs::<f32, 2> {
        dest: TensorMut {
            data: dev_dest.as_slice_mut(),
            shape: dest_shape,
            stride: contiguous_stride(dest_shape),
        },
        source: TensorRef {
            data: dev_source.as_slice(),
            shape: source_shape,
            stride: contiguous_stride(source_shape),
        },
    };
    plan.run(&stream, Workspace::None, launch).expect("run");
    stream.synchronize().expect("sync");

    // Download and compare against the reference one byte at a time.
    let mut got = vec![0f32; dest_numel];
    dev_dest.copy_to_host(&mut got).expect("download");
    let got_bytes = bytemuck_slice(&got);
    for (i, (actual, wanted)) in got_bytes.iter().zip(&expected_bytes).enumerate() {
        assert_eq!(
            *actual, *wanted,
            "interior_2d_f32 mismatch @ byte {i}"
        );
    }
}
/// Writes 8 `i32` values into positions 10..18 of a 32-element destination
/// and checks the downloaded buffer against a host-built expectation.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_1d_i32() {
    let (ctx, stream) = setup();

    let dest_shape = [32i32];
    let source_shape = [8i32];
    let ranges = [(10, 18)];
    let dest_init: Vec<i32> = (0..32).map(|n| n * 10).collect();
    let source: Vec<i32> = (0..8).map(|n| 999 - n).collect();

    let mut dev_dest = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source = DeviceBuffer::from_slice(&ctx, &source).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::I32,
    };
    let plan = WriteSlicePlan::<i32, 1>::select(&stream, &desc, PlanPreference::default())
        .expect("select");
    let launch = WriteSliceArgs::<i32, 1> {
        dest: TensorMut {
            data: dev_dest.as_slice_mut(),
            shape: dest_shape,
            stride: contiguous_stride(dest_shape),
        },
        source: TensorRef {
            data: dev_source.as_slice(),
            shape: source_shape,
            stride: contiguous_stride(source_shape),
        },
    };
    plan.run(&stream, Workspace::None, launch).expect("run");
    stream.synchronize().expect("sync");

    let mut got = vec![0i32; 32];
    dev_dest.copy_to_host(&mut got).expect("download");

    // Expected: untouched head, the 8 source values, untouched tail.
    let expected: Vec<i32> = dest_init[..10]
        .iter()
        .chain(&source)
        .chain(&dest_init[18..])
        .copied()
        .collect();
    assert_eq!(got, expected);
}
/// Writes a `[2, 4]` `S4` block into rows 1..3, columns 2..6 of a `[4, 8]`
/// destination and checks the resulting storage bytes.
///
/// Both buffers are uploaded as `u8` storage sized at two elements per byte
/// and reinterpreted as `S4` for the kernel call. The expectation is built
/// purely in bytes: the column range 2..6 maps to two whole bytes starting at
/// byte 1 of each 4-byte destination row.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_nibble_s4_rank2() {
    let (ctx, stream) = setup();

    let dest_shape = [4i32, 8];
    let source_shape = [2i32, 4];
    let ranges = [(1, 3), (2, 6)];

    // Storage sizes in bytes: element count halved.
    let dest_storage: usize = (4 * 8) / 2;
    let source_storage: usize = (2 * 4) / 2;
    let dest_init: Vec<u8> = (0..dest_storage as u8).map(|n| n.wrapping_mul(17)).collect();
    let source_init: Vec<u8> = (0..source_storage as u8).map(|n| 0xA0 | (n & 0x0F)).collect();

    let mut dev_dest_u8 = DeviceBuffer::from_slice(&ctx, &dest_init).expect("upload dest");
    let dev_source_u8 = DeviceBuffer::from_slice(&ctx, &source_init).expect("upload source");
    let desc = WriteSliceDescriptor {
        dest_shape,
        source_shape,
        ranges,
        element: ElementKind::S4,
    };
    let plan = WriteSlicePlan::<S4, 2>::select(&stream, &desc, PlanPreference::default())
        .expect("select");

    // The typed views live only inside this block, so the byte buffer can be
    // used directly again for the download below.
    {
        let dest_view = dev_dest_u8.view_as_mut::<S4>();
        let source_view = dev_source_u8.view_as::<S4>();
        let launch = WriteSliceArgs::<S4, 2> {
            dest: TensorMut {
                data: dest_view,
                shape: dest_shape,
                stride: contiguous_stride(dest_shape),
            },
            source: TensorRef {
                data: source_view,
                shape: source_shape,
                stride: contiguous_stride(source_shape),
            },
        };
        plan.run(&stream, Workspace::None, launch).expect("run");
        stream.synchronize().expect("sync");
    }

    let mut got = vec![0u8; dest_storage];
    dev_dest_u8.copy_to_host(&mut got).expect("download");

    // Byte-level expectation.
    let row_byte_stride = 4usize; // 8 elements per destination row, two per byte
    let inner_start_byte = 2usize / 2; // column start 2 -> byte 1
    let inner_len_bytes: usize = (6 - 2) / 2; // 4 columns -> 2 bytes
    let mut expected = dest_init.clone();
    for (r, src_row) in source_init.chunks_exact(inner_len_bytes).take(2).enumerate() {
        let at = (ranges[0].0 as usize + r) * row_byte_stride + inner_start_byte;
        expected[at..at + inner_len_bytes].copy_from_slice(src_row);
    }
    assert_eq!(got, expected, "nibble_s4_rank2 mismatch");
}
/// Plan selection must fail when a range runs past the destination extent:
/// here rows 14..18 are requested on an axis of size 16.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_out_of_bounds_rejected() {
    // The context is bound to a named placeholder (not `_`) so it is not
    // dropped until the end of the test.
    let (_context, stream) = setup();

    let desc = WriteSliceDescriptor {
        dest_shape: [16i32, 8],
        source_shape: [4i32, 4],
        ranges: [(14, 18), (0, 4)],
        element: ElementKind::F32,
    };
    let outcome = WriteSlicePlan::<f32, 2>::select(&stream, &desc, PlanPreference::default());
    assert!(outcome.is_err(), "out-of-bounds range must be rejected at select");
}
/// Plan selection must fail for an `S4` write whose innermost range starts at
/// an odd index: here columns 1..5.
///
/// Marked `#[ignore]`, so it only runs when ignored tests are requested.
/// NOTE(review): presumably because it needs a real device — confirm.
#[test]
#[ignore]
fn write_slice_nibble_odd_start_rejected() {
    // The context is bound to a named placeholder (not `_`) so it is not
    // dropped until the end of the test.
    let (_context, stream) = setup();

    let desc = WriteSliceDescriptor {
        dest_shape: [4i32, 8],
        source_shape: [2i32, 4],
        ranges: [(1, 3), (1, 5)],
        element: ElementKind::S4,
    };
    let outcome = WriteSlicePlan::<S4, 2>::select(&stream, &desc, PlanPreference::default());
    assert!(
        outcome.is_err(),
        "nibble write with odd-aligned innermost start must be rejected"
    );
}
/// Returns a freshly allocated `Vec<u8>` holding the in-memory bytes of `s`,
/// in the platform's native byte order.
///
/// NOTE(review): for a `T` with padding bytes this would copy uninitialised
/// memory; the callers in this file only pass `f32` slices.
fn bytemuck_slice<T: Copy>(s: &[T]) -> Vec<u8> {
    let byte_len = core::mem::size_of_val(s);
    let mut raw = vec![0u8; byte_len];
    // SAFETY: `s` is a live slice spanning exactly `byte_len` readable bytes,
    // `raw` was just allocated with `byte_len` writable bytes, `u8` has
    // alignment 1, and a new allocation cannot overlap the borrowed input.
    unsafe {
        core::ptr::copy_nonoverlapping(s.as_ptr().cast::<u8>(), raw.as_mut_ptr(), byte_len);
    }
    raw
}