1use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::dtypes::DType;
16use crate::encoder::CommandEncoder;
17use crate::error::{MlxError, Result};
18use crate::kernel_registry::KernelRegistry;
19
20use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
21
22pub static FEATURE_CONCAT_SHADER_SOURCE: &str =
23 include_str!("../shaders/feature_concat.metal");
24
25pub fn register(registry: &mut KernelRegistry) {
26 registry.register_source("feature_concat_f32", FEATURE_CONCAT_SHADER_SOURCE);
27}
28
/// Parameter block handed to the `feature_concat_f32` kernel as raw bytes
/// (bound at argument index 0 by `dispatch_feature_concat_f32`).
///
/// All fields are element counts, not byte counts.
///
/// NOTE(review): `#[repr(C)]` pins the field order and layout; this presumably
/// has to match the parameter struct declared in `feature_concat.metal` —
/// confirm against the shader before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFeatureConcatParams {
    /// Number of token rows; both `src` and `dst` are sized per token.
    n_tokens: u32,
    /// Elements per token row in the source buffer.
    src_dim: u32,
    /// Starting column of the chunk inside each destination row; the
    /// dispatcher requires `dst_offset + src_dim <= dst_stride`.
    dst_offset: u32,
    /// Elements per token row in the destination buffer.
    dst_stride: u32,
}
37
/// Largest threadgroup width (x dimension) requested for a dispatch; the
/// dispatcher uses `min(TG_SIZE, total_threads)` so small launches do not
/// ask for more threads per group than exist in the grid.
const TG_SIZE: u64 = 256;
39
40pub fn dispatch_feature_concat_f32(
46 encoder: &mut CommandEncoder,
47 registry: &mut KernelRegistry,
48 device: &metal::DeviceRef,
49 src: &MlxBuffer,
50 dst: &MlxBuffer,
51 n_tokens: u32,
52 src_dim: u32,
53 dst_offset: u32,
54 dst_stride: u32,
55) -> Result<()> {
56 if n_tokens == 0 || src_dim == 0 || dst_stride == 0 {
57 return Err(MlxError::InvalidArgument(format!(
58 "feature_concat_f32: n_tokens ({n_tokens}), src_dim ({src_dim}), \
59 dst_stride ({dst_stride}) must all be > 0"
60 )));
61 }
62 if dst_offset.checked_add(src_dim).map(|e| e > dst_stride).unwrap_or(true) {
63 return Err(MlxError::InvalidArgument(format!(
64 "feature_concat_f32: dst_offset ({dst_offset}) + src_dim ({src_dim}) > \
65 dst_stride ({dst_stride}) — chunk overflows the destination row"
66 )));
67 }
68 let f32_sz = DType::F32.size_of();
69 let need_src = (n_tokens as usize) * (src_dim as usize) * f32_sz;
70 let need_dst = (n_tokens as usize) * (dst_stride as usize) * f32_sz;
71 if src.byte_len() < need_src {
72 return Err(MlxError::InvalidArgument(format!(
73 "feature_concat_f32: src too small: {} vs {} bytes",
74 src.byte_len(), need_src
75 )));
76 }
77 if dst.byte_len() < need_dst {
78 return Err(MlxError::InvalidArgument(format!(
79 "feature_concat_f32: dst too small: {} vs {} bytes",
80 dst.byte_len(), need_dst
81 )));
82 }
83
84 let pipeline = registry.get_pipeline("feature_concat_f32", device)?;
85 let gpu_params = GpuFeatureConcatParams {
86 n_tokens,
87 src_dim,
88 dst_offset,
89 dst_stride,
90 };
91 let total = (n_tokens as u64) * (src_dim as u64);
92 let grid = MTLSize::new(total, 1, 1);
93 let tg = MTLSize::new(std::cmp::min(TG_SIZE, total), 1, 1);
94 encode_with_args(
95 encoder,
96 pipeline,
97 &[
98 (0, KernelArg::Bytes(as_bytes(&gpu_params))),
99 (1, KernelArg::Buffer(src)),
100 (2, KernelArg::Buffer(dst)),
101 ],
102 grid,
103 tg,
104 );
105 Ok(())
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::graph::GraphExecutor;

    /// Packs one `dim_main`-wide chunk followed by three `dim_ds`-wide chunks
    /// into a shared destination and compares the result bit-for-bit against
    /// a reference assembled on the CPU.
    #[test]
    fn adr021_k5_feature_concat_f32_byte_identical() {
        let device = MlxDevice::new().expect("MlxDevice");
        let n_tokens: u32 = 11;
        let dim_main: u32 = 32;
        let dim_ds: u32 = 32;
        let dim_total: u32 = dim_main + dim_ds * 3;

        let src_main: Vec<f32> = (0..(n_tokens * dim_main))
            .map(|i| ((i as f32) * 0.013_3_f32).sin() * 0.5)
            .collect();
        let src_ds: Vec<Vec<f32>> = (0..3)
            .map(|seed| {
                (0..(n_tokens * dim_ds))
                    .map(|i| ((i as f32 + 100.0 * (seed as f32 + 1.0)) * 0.011_7_f32).cos() * 0.5)
                    .collect::<Vec<f32>>()
            })
            .collect();

        // CPU reference: each destination row is [main | ds0 | ds1 | ds2].
        // Chunk offsets are derived from `dim_main` so the reference stays
        // correct even if `dim_main` and `dim_ds` are changed to differ.
        let main_w = dim_main as usize;
        let ds_w = dim_ds as usize;
        let row_stride = dim_total as usize;
        let mut expected = vec![0f32; n_tokens as usize * row_stride];
        for (t, row) in expected.chunks_exact_mut(row_stride).enumerate() {
            row[..main_w].copy_from_slice(&src_main[t * main_w..(t + 1) * main_w]);
            for (i, ds) in src_ds.iter().enumerate() {
                let dst_off = main_w + i * ds_w;
                row[dst_off..dst_off + ds_w].copy_from_slice(&ds[t * ds_w..(t + 1) * ds_w]);
            }
        }

        let executor = GraphExecutor::new(MlxDevice::new().expect("MlxDevice for executor"));
        let mut session = executor.begin().expect("begin");
        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let mut main_buf = device
            .alloc_buffer(
                src_main.len() * 4,
                DType::F32,
                vec![n_tokens as usize, dim_main as usize],
            )
            .unwrap();
        main_buf
            .as_mut_slice::<f32>()
            .unwrap()
            .copy_from_slice(&src_main);
        let ds_bufs: Vec<MlxBuffer> = (0..3)
            .map(|i| {
                let mut b = device
                    .alloc_buffer(
                        src_ds[i].len() * 4,
                        DType::F32,
                        vec![n_tokens as usize, dim_ds as usize],
                    )
                    .unwrap();
                b.as_mut_slice::<f32>().unwrap().copy_from_slice(&src_ds[i]);
                b
            })
            .collect();
        let dst_buf = device
            .alloc_buffer(
                (n_tokens * dim_total * 4) as usize,
                DType::F32,
                vec![n_tokens as usize, dim_total as usize],
            )
            .unwrap();

        dispatch_feature_concat_f32(
            session.encoder_mut(),
            &mut registry,
            device.metal_device(),
            &main_buf,
            &dst_buf,
            n_tokens,
            dim_main,
            0,
            dim_total,
        )
        .unwrap();
        session.encoder_mut().memory_barrier();

        for (i, ds) in ds_bufs.iter().enumerate() {
            dispatch_feature_concat_f32(
                session.encoder_mut(),
                &mut registry,
                device.metal_device(),
                ds,
                &dst_buf,
                n_tokens,
                dim_ds,
                dim_main + (i as u32) * dim_ds,
                dim_total,
            )
            .unwrap();
            session.encoder_mut().memory_barrier();
        }

        session.finish().expect("finish");
        let got = dst_buf.as_slice::<f32>().unwrap();
        // `zip` stops at the shorter side, so guard against a short readback
        // silently skipping the tail of the comparison.
        assert!(
            got.len() >= expected.len(),
            "dst readback has {} elements, expected at least {}",
            got.len(),
            expected.len()
        );
        for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
            assert_eq!(g.to_bits(), e.to_bits(), "K5 byte parity violated at {i}");
        }
    }

    /// Checks that invalid geometry is rejected with a descriptive error
    /// before anything is dispatched.
    #[test]
    fn adr021_k5_feature_concat_f32_input_validation() {
        let device = MlxDevice::new().expect("MlxDevice");
        let executor = GraphExecutor::new(MlxDevice::new().expect("device for executor"));
        let mut session = executor.begin().expect("session");
        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let s = device.alloc_buffer(64 * 4, DType::F32, vec![16, 4]).unwrap();
        let d = device.alloc_buffer(128 * 4, DType::F32, vec![16, 8]).unwrap();

        // dst_offset (5) + src_dim (4) = 9 > dst_stride (8).
        let err = dispatch_feature_concat_f32(
            session.encoder_mut(),
            &mut registry,
            device.metal_device(),
            &s,
            &d,
            16,
            4,
            5,
            8,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("overflows the destination row"));

        // A zero token count is rejected by the first validation step.
        let err = dispatch_feature_concat_f32(
            session.encoder_mut(),
            &mut registry,
            device.metal_device(),
            &s,
            &d,
            0,
            4,
            0,
            8,
        )
        .unwrap_err();
        assert!(format!("{err}").contains("must all be > 0"));
    }
}