Skip to main content

mlx_native/ops/
feature_concat.rs

1//! ADR-021 K5: GPU feature-axis concat (single-chunk strided copy).
2//!
3//! Each invocation copies one `[T, src_dim]` f32 row-major slab into
4//! its slice of the concatenated `[T, dst_stride]` destination, at
5//! column offset `dst_offset`. Launching once per chunk (with varying
6//! `dst_offset`) builds the full `[T, Σ src_dim_i]` concatenated
7//! tensor — exactly the shape qwen3vl.cpp:186
8//! `ggml_concat(ctx0, embeddings, deepstack_features, 0)` produces.
9//!
10//! Pure copy (no FP arithmetic) → AC-1 byte-identical.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::dtypes::DType;
16use crate::encoder::CommandEncoder;
17use crate::error::{MlxError, Result};
18use crate::kernel_registry::KernelRegistry;
19
20use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
21
/// Metal source text for the feature-concat kernel, embedded at compile
/// time from `../shaders/feature_concat.metal` (path relative to this file).
pub static FEATURE_CONCAT_SHADER_SOURCE: &str =
    include_str!("../shaders/feature_concat.metal");
24
25pub fn register(registry: &mut KernelRegistry) {
26    registry.register_source("feature_concat_f32", FEATURE_CONCAT_SHADER_SOURCE);
27}
28
/// Parameter block passed to the `feature_concat_f32` kernel as raw bytes
/// (bound at argument index 0 by the dispatch function in this module).
///
/// Four `u32` fields under `#[repr(C)]`, so the layout is 16 bytes with no
/// padding.
/// NOTE(review): the field order presumably has to match the parameter
/// struct declared in `feature_concat.metal` — confirm against the shader
/// before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFeatureConcatParams {
    /// Number of rows (tokens) in the source chunk and the destination.
    n_tokens: u32,
    /// Number of f32 columns per row in the source chunk.
    src_dim: u32,
    /// Column in each destination row where the chunk starts.
    dst_offset: u32,
    /// Number of f32 columns per row in the destination.
    dst_stride: u32,
}
37
/// Upper bound on the threadgroup width (x dimension) used by the dispatch
/// in this module; the actual width is `min(TG_SIZE, total threads)`.
const TG_SIZE: u64 = 256;
39
40/// Copy one `[n_tokens, src_dim]` f32 row-major chunk into the
41/// `[n_tokens, dst_stride]` destination at column `dst_offset`.
42///
43/// Caller is responsible for ensuring chunks don't overlap and
44/// `dst_offset + src_dim <= dst_stride`.
45pub fn dispatch_feature_concat_f32(
46    encoder: &mut CommandEncoder,
47    registry: &mut KernelRegistry,
48    device: &metal::DeviceRef,
49    src: &MlxBuffer,
50    dst: &MlxBuffer,
51    n_tokens: u32,
52    src_dim: u32,
53    dst_offset: u32,
54    dst_stride: u32,
55) -> Result<()> {
56    if n_tokens == 0 || src_dim == 0 || dst_stride == 0 {
57        return Err(MlxError::InvalidArgument(format!(
58            "feature_concat_f32: n_tokens ({n_tokens}), src_dim ({src_dim}), \
59             dst_stride ({dst_stride}) must all be > 0"
60        )));
61    }
62    if dst_offset.checked_add(src_dim).map(|e| e > dst_stride).unwrap_or(true) {
63        return Err(MlxError::InvalidArgument(format!(
64            "feature_concat_f32: dst_offset ({dst_offset}) + src_dim ({src_dim}) > \
65             dst_stride ({dst_stride}) — chunk overflows the destination row"
66        )));
67    }
68    let f32_sz = DType::F32.size_of();
69    let need_src = (n_tokens as usize) * (src_dim as usize) * f32_sz;
70    let need_dst = (n_tokens as usize) * (dst_stride as usize) * f32_sz;
71    if src.byte_len() < need_src {
72        return Err(MlxError::InvalidArgument(format!(
73            "feature_concat_f32: src too small: {} vs {} bytes",
74            src.byte_len(), need_src
75        )));
76    }
77    if dst.byte_len() < need_dst {
78        return Err(MlxError::InvalidArgument(format!(
79            "feature_concat_f32: dst too small: {} vs {} bytes",
80            dst.byte_len(), need_dst
81        )));
82    }
83
84    let pipeline = registry.get_pipeline("feature_concat_f32", device)?;
85    let gpu_params = GpuFeatureConcatParams {
86        n_tokens,
87        src_dim,
88        dst_offset,
89        dst_stride,
90    };
91    let total = (n_tokens as u64) * (src_dim as u64);
92    let grid = MTLSize::new(total, 1, 1);
93    let tg = MTLSize::new(std::cmp::min(TG_SIZE, total), 1, 1);
94    encode_with_args(
95        encoder,
96        pipeline,
97        &[
98            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
99            (1, KernelArg::Buffer(src)),
100            (2, KernelArg::Buffer(dst)),
101        ],
102        grid,
103        tg,
104    );
105    Ok(())
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::graph::GraphExecutor;

    /// Concatenates a base chunk and three deepstack chunks on the GPU and
    /// checks the result bit-for-bit against a CPU-built reference.
    #[test]
    fn adr021_k5_feature_concat_f32_byte_identical() {
        let device = MlxDevice::new().expect("MlxDevice");
        let n_tokens: u32 = 11;
        let dim_main: u32 = 32;
        let dim_ds: u32 = 32;
        let dim_total: u32 = dim_main + dim_ds * 3; // base + 3 deepstacks

        let src_main: Vec<f32> = (0..(n_tokens * dim_main))
            .map(|i| ((i as f32) * 0.013_3_f32).sin() * 0.5)
            .collect();
        let src_ds: Vec<Vec<f32>> = (0..3)
            .map(|seed| {
                (0..(n_tokens * dim_ds))
                    .map(|i| ((i as f32 + 100.0 * (seed as f32 + 1.0)) * 0.011_7_f32).cos() * 0.5)
                    .collect::<Vec<f32>>()
            })
            .collect();

        // Build CPU oracle: each destination row is the base row followed by
        // the matching row of every deepstack chunk. Deepstack `i` starts at
        // column `dim_main + i * dim_ds`, so the layout stays correct even
        // if the two widths are changed to differ.
        let row_stride = dim_total as usize;
        let main_w = dim_main as usize;
        let ds_w = dim_ds as usize;
        let mut expected = vec![0f32; n_tokens as usize * row_stride];
        for (t, row) in expected.chunks_exact_mut(row_stride).enumerate() {
            row[..main_w].copy_from_slice(&src_main[t * main_w..(t + 1) * main_w]);
            for (i, ds) in src_ds.iter().enumerate() {
                let col = main_w + i * ds_w;
                row[col..col + ds_w].copy_from_slice(&ds[t * ds_w..(t + 1) * ds_w]);
            }
        }

        // GPU.
        let executor =
            GraphExecutor::new(MlxDevice::new().expect("MlxDevice for executor"));
        let mut session = executor.begin().expect("begin");
        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let mut main_buf = device
            .alloc_buffer(src_main.len() * 4, DType::F32, vec![n_tokens as usize, dim_main as usize])
            .unwrap();
        main_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&src_main);
        let ds_bufs: Vec<MlxBuffer> = src_ds
            .iter()
            .map(|chunk| {
                let mut b = device
                    .alloc_buffer(chunk.len() * 4, DType::F32, vec![n_tokens as usize, dim_ds as usize])
                    .unwrap();
                b.as_mut_slice::<f32>().unwrap().copy_from_slice(chunk);
                b
            })
            .collect();
        let dst_buf = device
            .alloc_buffer((n_tokens * dim_total * 4) as usize, DType::F32,
                vec![n_tokens as usize, dim_total as usize])
            .unwrap();

        // Copy main at offset 0.
        dispatch_feature_concat_f32(
            session.encoder_mut(), &mut registry, device.metal_device(),
            &main_buf, &dst_buf, n_tokens, dim_main, 0, dim_total,
        ).unwrap();
        session.encoder_mut().memory_barrier();

        // Copy each deepstack at offset dim_main + i*dim_ds.
        for (i, ds) in ds_bufs.iter().enumerate() {
            dispatch_feature_concat_f32(
                session.encoder_mut(), &mut registry, device.metal_device(),
                ds, &dst_buf, n_tokens, dim_ds, dim_main + (i as u32) * dim_ds, dim_total,
            ).unwrap();
            session.encoder_mut().memory_barrier();
        }

        session.finish().expect("finish");
        let got = dst_buf.as_slice::<f32>().unwrap();
        // `zip` stops at the shorter side; make sure every expected element
        // is actually compared.
        assert!(
            got.len() >= expected.len(),
            "K5 readback too short: {} < {}",
            got.len(),
            expected.len()
        );
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            assert_eq!(g.to_bits(), e.to_bits(), "K5 byte parity violated at {i}");
        }
    }

    /// Exercises the argument checks that reject a call before anything is
    /// encoded.
    #[test]
    fn adr021_k5_feature_concat_f32_input_validation() {
        let device = MlxDevice::new().expect("MlxDevice");
        let executor = GraphExecutor::new(MlxDevice::new().expect("device for executor"));
        let mut session = executor.begin().expect("session");
        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let s = device.alloc_buffer(64 * 4, DType::F32, vec![16, 4]).unwrap();
        let d = device.alloc_buffer(128 * 4, DType::F32, vec![16, 8]).unwrap();

        // dst_offset + src_dim > dst_stride
        let err = dispatch_feature_concat_f32(
            session.encoder_mut(), &mut registry, device.metal_device(),
            &s, &d, 16, 4, 5, 8,  // 5+4 = 9 > 8
        ).unwrap_err();
        assert!(format!("{err}").contains("overflows the destination row"));

        // dst_offset + src_dim overflowing u32 is rejected the same way.
        let err = dispatch_feature_concat_f32(
            session.encoder_mut(), &mut registry, device.metal_device(),
            &s, &d, 16, 4, u32::MAX, 8,
        ).unwrap_err();
        assert!(format!("{err}").contains("overflows the destination row"));

        // Any zero dimension is rejected.
        for (n_tokens, src_dim, dst_stride) in [(0u32, 4u32, 8u32), (16, 0, 8), (16, 4, 0)] {
            let err = dispatch_feature_concat_f32(
                session.encoder_mut(), &mut registry, device.metal_device(),
                &s, &d, n_tokens, src_dim, 0, dst_stride,
            ).unwrap_err();
            assert!(
                format!("{err}").contains("must all be > 0"),
                "unexpected error for ({n_tokens}, {src_dim}, {dst_stride}): {err}"
            );
        }
    }
}