use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Text of `../shaders/feature_concat.metal`, embedded into the binary at
/// compile time via `include_str!`.
///
/// [`register`] hands this string to the kernel registry under the name
/// `feature_concat_f32`.
pub static FEATURE_CONCAT_SHADER_SOURCE: &str =
include_str!("../shaders/feature_concat.metal");
/// Registers the feature-concat shader source with `registry`.
///
/// The source is stored under the kernel name `feature_concat_f32`, which is
/// the same name `dispatch_feature_concat_f32` uses to look up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "feature_concat_f32";
    registry.register_source(kernel_name, FEATURE_CONCAT_SHADER_SOURCE);
}
/// Parameter block passed to the `feature_concat_f32` kernel as raw bytes
/// (bound at argument index 0 by `dispatch_feature_concat_f32`).
///
/// `#[repr(C)]` keeps the four `u32` fields in declaration order.
/// NOTE(review): this layout presumably has to match the parameter struct
/// declared in `feature_concat.metal` — confirm against the shader before
/// reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFeatureConcatParams {
/// Number of rows; the dispatcher requires `src` to hold
/// `n_tokens * src_dim` f32 values and `dst` to hold `n_tokens * dst_stride`.
n_tokens: u32,
/// Number of f32 elements per source row.
src_dim: u32,
/// Element offset inside each destination row where the chunk starts; the
/// dispatcher enforces `dst_offset + src_dim <= dst_stride`.
dst_offset: u32,
/// Number of f32 elements per destination row.
dst_stride: u32,
}
/// Upper bound on the threadgroup width used for the 1-D dispatch; the actual
/// width is `min(TG_SIZE, n_tokens * src_dim)`.
const TG_SIZE: u64 = 256;
/// Encodes one `feature_concat_f32` kernel dispatch that places a source
/// chunk into a column window of a wider destination matrix.
///
/// `src` is validated as `n_tokens * src_dim` f32 values and `dst` as
/// `n_tokens * dst_stride` f32 values. The chunk occupies columns
/// `[dst_offset, dst_offset + src_dim)` of every destination row. One GPU
/// thread is launched per source element (`n_tokens * src_dim` in total).
///
/// Only the dispatch is encoded here; any barrier or ordering between
/// successive dispatches that write the same `dst` is the caller's job.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when
/// - `n_tokens`, `src_dim` or `dst_stride` is zero,
/// - `dst_offset + src_dim` exceeds `dst_stride` (or overflows `u32`),
/// - the required byte size of `src` or `dst` does not fit in `usize`,
/// - `src` or `dst` is smaller than the size implied by the dimensions.
///
/// Errors from the registry's pipeline lookup are propagated unchanged.
pub fn dispatch_feature_concat_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    dst: &MlxBuffer,
    n_tokens: u32,
    src_dim: u32,
    dst_offset: u32,
    dst_stride: u32,
) -> Result<()> {
    if n_tokens == 0 || src_dim == 0 || dst_stride == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "feature_concat_f32: n_tokens ({n_tokens}), src_dim ({src_dim}), \
             dst_stride ({dst_stride}) must all be > 0"
        )));
    }

    // A `u32` overflow of the window end is treated the same as running past
    // the end of the destination row.
    let window_overflows = dst_offset
        .checked_add(src_dim)
        .map_or(true, |end| end > dst_stride);
    if window_overflows {
        return Err(MlxError::InvalidArgument(format!(
            "feature_concat_f32: dst_offset ({dst_offset}) + src_dim ({src_dim}) > \
             dst_stride ({dst_stride}) — chunk overflows the destination row"
        )));
    }

    // Byte sizes are computed with checked arithmetic: `rows * cols * 4` can
    // exceed `usize::MAX` for large `u32` inputs, and a wrapped value would
    // let an undersized buffer slip through the bounds checks below.
    let f32_sz = DType::F32.size_of();
    let byte_size = |rows: u32, cols: u32, what: &str| -> Result<usize> {
        (rows as usize)
            .checked_mul(cols as usize)
            .and_then(|elems| elems.checked_mul(f32_sz))
            .ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "feature_concat_f32: {what} byte size overflows usize \
                     ({rows} x {cols} f32 elements)"
                ))
            })
    };
    let need_src = byte_size(n_tokens, src_dim, "src")?;
    let need_dst = byte_size(n_tokens, dst_stride, "dst")?;

    if src.byte_len() < need_src {
        return Err(MlxError::InvalidArgument(format!(
            "feature_concat_f32: src too small: {} vs {} bytes",
            src.byte_len(),
            need_src
        )));
    }
    if dst.byte_len() < need_dst {
        return Err(MlxError::InvalidArgument(format!(
            "feature_concat_f32: dst too small: {} vs {} bytes",
            dst.byte_len(),
            need_dst
        )));
    }

    let pipeline = registry.get_pipeline("feature_concat_f32", device)?;
    let gpu_params = GpuFeatureConcatParams {
        n_tokens,
        src_dim,
        dst_offset,
        dst_stride,
    };

    // One thread per source element. Widening to u64 first means the product
    // of two u32 values cannot overflow.
    let total = u64::from(n_tokens) * u64::from(src_dim);
    let grid = MTLSize::new(total, 1, 1);
    // Never request a threadgroup wider than the grid itself.
    let tg = MTLSize::new(std::cmp::min(TG_SIZE, total), 1, 1);

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(src)),
            (2, KernelArg::Buffer(dst)),
        ],
        grid,
        tg,
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::graph::GraphExecutor;

    /// Concatenates one "main" chunk and three "ds" chunks into a single
    /// destination matrix on the GPU and requires the result to be
    /// bit-identical to a CPU-built reference.
    #[test]
    fn adr021_k5_feature_concat_f32_byte_identical() {
        let device = MlxDevice::new().expect("MlxDevice");
        let n_tokens: u32 = 11;
        let dim_main: u32 = 32;
        let dim_ds: u32 = 32;
        let dim_total: u32 = dim_main + dim_ds * 3;

        // Deterministic, non-trivial inputs so a misplaced element is visible.
        let src_main: Vec<f32> = (0..(n_tokens * dim_main))
            .map(|i| ((i as f32) * 0.013_3_f32).sin() * 0.5)
            .collect();
        let src_ds: Vec<Vec<f32>> = (0..3)
            .map(|seed| {
                (0..(n_tokens * dim_ds))
                    .map(|i| ((i as f32 + 100.0 * (seed as f32 + 1.0)) * 0.011_7_f32).cos() * 0.5)
                    .collect::<Vec<f32>>()
            })
            .collect();

        // CPU reference: row t = [main_t | ds0_t | ds1_t | ds2_t].
        // The i-th ds chunk starts at `dim_main + i * dim_ds`, which stays
        // correct even if `dim_main` and `dim_ds` are changed to differ.
        let rows = n_tokens as usize;
        let main_w = dim_main as usize;
        let ds_w = dim_ds as usize;
        let row_stride = dim_total as usize;
        let mut expected = vec![0f32; rows * row_stride];
        for (t, row) in expected.chunks_exact_mut(row_stride).enumerate() {
            row[..main_w].copy_from_slice(&src_main[t * main_w..(t + 1) * main_w]);
            for (i, ds) in src_ds.iter().enumerate() {
                let off = main_w + i * ds_w;
                row[off..off + ds_w].copy_from_slice(&ds[t * ds_w..(t + 1) * ds_w]);
            }
        }

        let executor =
            GraphExecutor::new(MlxDevice::new().expect("MlxDevice for executor"));
        let mut session = executor.begin().expect("begin");
        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let mut main_buf = device
            .alloc_buffer(src_main.len() * 4, DType::F32, vec![rows, main_w])
            .unwrap();
        main_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&src_main);

        let ds_bufs: Vec<MlxBuffer> = src_ds
            .iter()
            .map(|chunk| {
                let mut b = device
                    .alloc_buffer(chunk.len() * 4, DType::F32, vec![rows, ds_w])
                    .unwrap();
                b.as_mut_slice::<f32>().unwrap().copy_from_slice(chunk);
                b
            })
            .collect();

        let dst_buf = device
            .alloc_buffer(
                (n_tokens * dim_total * 4) as usize,
                DType::F32,
                vec![rows, row_stride],
            )
            .unwrap();

        dispatch_feature_concat_f32(
            session.encoder_mut(), &mut registry, device.metal_device(),
            &main_buf, &dst_buf, n_tokens, dim_main, 0, dim_total,
        ).unwrap();
        session.encoder_mut().memory_barrier();

        for (i, ds) in ds_bufs.iter().enumerate() {
            let dst_offset = dim_main + (i as u32) * dim_ds;
            dispatch_feature_concat_f32(
                session.encoder_mut(), &mut registry, device.metal_device(),
                ds, &dst_buf, n_tokens, dim_ds, dst_offset, dim_total,
            ).unwrap();
            session.encoder_mut().memory_barrier();
        }
        session.finish().expect("finish");

        let got = dst_buf.as_slice::<f32>().unwrap();
        // `zip` stops at the shorter side, so make sure the read-back slice
        // covers every expected element before comparing.
        assert!(
            got.len() >= expected.len(),
            "dst read-back has {} elements, expected at least {}",
            got.len(),
            expected.len()
        );
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            assert_eq!(g.to_bits(), e.to_bits(), "K5 byte parity violated at {i}");
        }
    }

    /// Exercises the argument checks that reject a dispatch before anything
    /// is encoded.
    #[test]
    fn adr021_k5_feature_concat_f32_input_validation() {
        let device = MlxDevice::new().expect("MlxDevice");
        let executor = GraphExecutor::new(MlxDevice::new().expect("device for executor"));
        let mut session = executor.begin().expect("session");
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let s = device.alloc_buffer(64 * 4, DType::F32, vec![16, 4]).unwrap();
        let d = device.alloc_buffer(128 * 4, DType::F32, vec![16, 8]).unwrap();

        // Offset 5 + width 4 = 9 > stride 8: the chunk would spill past the row.
        let err = dispatch_feature_concat_f32(
            session.encoder_mut(), &mut registry, device.metal_device(),
            &s, &d, 16, 4, 5, 8,
        ).unwrap_err();
        assert!(format!("{err}").contains("overflows the destination row"));

        // Any zero dimension is rejected up front.
        for (n, dim, stride) in [(0u32, 4u32, 8u32), (16, 0, 8), (16, 4, 0)] {
            let err = dispatch_feature_concat_f32(
                session.encoder_mut(), &mut registry, device.metal_device(),
                &s, &d, n, dim, 0, stride,
            ).unwrap_err();
            assert!(
                format!("{err}").contains("must all be > 0"),
                "unexpected error for n_tokens={n}, src_dim={dim}, dst_stride={stride}: {err}"
            );
        }
    }
}