use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Caller-supplied shape and quantization settings for [`embedding_gather`].
pub struct EmbeddingGatherParams {
/// Number of elements in one embedding row. Must be non-zero and a multiple
/// of `group_size`.
pub embed_dim: usize,
/// Number of consecutive elements that form one group. Must be non-zero;
/// a row consists of `embed_dim / group_size` groups.
/// NOTE(review): presumably each group shares one scale/bias pair — confirm
/// against the kernel source.
pub group_size: usize,
/// Bit width of the packed weights. Only 4 and 6 are accepted.
pub bits: u8,
/// Number of token ids to look up. Must be non-zero; the output buffer has to
/// hold at least `n_tokens * embed_dim` `f32` values.
pub n_tokens: usize,
}
/// Parameter block handed to the gather kernel as raw bytes (kernel argument
/// index 5).
///
/// `#[repr(C)]` fixes the field order and the `Pod`/`Zeroable` derives allow the
/// struct to be viewed as plain bytes.
/// NOTE(review): the field order presumably mirrors a struct declared in the
/// shader source — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbeddingParams {
/// `EmbeddingGatherParams::embed_dim` narrowed to `u32`.
embed_dim: u32,
/// `EmbeddingGatherParams::group_size` narrowed to `u32`.
group_size: u32,
/// Stride between consecutive rows of the packed weight table:
/// `embed_dim / 8` for 4-bit and `embed_dim * 3 / 4` for 6-bit weights.
/// NOTE(review): these look like u32 words (4-bit) versus bytes (6-bit) —
/// confirm the unit each kernel expects.
packed_row_stride: u32,
/// `embed_dim / group_size`, narrowed to `u32`.
n_groups_per_row: u32,
}
/// Validates `params` and encodes a quantized embedding-gather dispatch.
///
/// The pipeline `embedding_gather_4bit` or `embedding_gather_6bit` is selected
/// from `params.bits`. The buffers are bound at kernel argument indices 0–4 in
/// the order `weight_packed`, `scales`, `biases`, `token_ids`, `output`, and a
/// small parameter block is passed as bytes at index 5. The dispatch uses a
/// grid of `(embed_dim, n_tokens, 1)` and a threadgroup size of
/// `(min(256, embed_dim), 1, 1)`.
///
/// Only the size of `output` is checked here; the other buffers are bound
/// as-is.
/// NOTE(review): `weight_packed`, `scales`, `biases` and `token_ids` are not
/// size-checked because their element types are not visible from this
/// function — callers must size them to match the kernel's expectations.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when
/// * `params.bits` is neither 4 nor 6,
/// * `params.embed_dim`, `params.group_size` or `params.n_tokens` is zero,
/// * `params.embed_dim` is not a multiple of `params.group_size`,
/// * `n_tokens * embed_dim * size_of::<f32>()` overflows `usize`,
/// * `output` is smaller than that byte count,
/// * `params.embed_dim` is not a multiple of the packing granularity
///   (8 values for 4-bit, 4 values for 6-bit), or
/// * a dimension does not fit into the `u32` fields passed to the kernel.
///
/// Any error from `registry.get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)] // one parameter per kernel argument slot
pub fn embedding_gather(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    weight_packed: &MlxBuffer,
    scales: &MlxBuffer,
    biases: &MlxBuffer,
    token_ids: &MlxBuffer,
    output: &MlxBuffer,
    params: &EmbeddingGatherParams,
) -> Result<()> {
    // A single match validates the bit width and yields everything that
    // depends on it: the kernel name, how many values make up one packing
    // unit, and how many stride units that packing unit occupies
    // (4-bit: 8 values -> 1 unit, 6-bit: 4 values -> 3 units).
    let (kernel_name, values_per_pack, stride_per_pack) = match params.bits {
        4 => ("embedding_gather_4bit", 8usize, 1usize),
        6 => ("embedding_gather_6bit", 4usize, 3usize),
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "embedding_gather: bits must be 4 or 6, got {}",
                other
            )));
        }
    };

    if params.embed_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "embedding_gather: embed_dim must be > 0".into(),
        ));
    }
    if params.group_size == 0 {
        return Err(MlxError::InvalidArgument(
            "embedding_gather: group_size must be > 0".into(),
        ));
    }
    if params.embed_dim % params.group_size != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "embedding_gather: embed_dim ({}) must be divisible by group_size ({})",
            params.embed_dim, params.group_size
        )));
    }
    if params.n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "embedding_gather: n_tokens must be > 0".into(),
        ));
    }

    // Checked arithmetic: a wrapped product in release builds could make an
    // undersized output buffer pass the size check below.
    let expected_output_bytes = params
        .n_tokens
        .checked_mul(params.embed_dim)
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "embedding_gather: output size overflows usize (n_tokens = {}, embed_dim = {})",
                params.n_tokens, params.embed_dim
            ))
        })?;
    if output.byte_len() < expected_output_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "embedding_gather: output buffer too small: need {} bytes, have {}",
            expected_output_bytes,
            output.byte_len()
        )));
    }

    // The row stride is only exact when a row is a whole number of packing
    // units; otherwise the integer division would silently floor and rows
    // would be addressed at the wrong offsets.
    if params.embed_dim % values_per_pack != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "embedding_gather: embed_dim ({}) must be a multiple of {} for {}-bit packing",
            params.embed_dim, values_per_pack, params.bits
        )));
    }

    // Checked narrowing for the u32 fields of the kernel parameter block.
    // The fully qualified path keeps this independent of the crate edition's
    // prelude.
    let to_u32 = |value: usize, what: &str| -> Result<u32> {
        <u32 as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
            MlxError::InvalidArgument(format!(
                "embedding_gather: {} ({}) does not fit in u32",
                what, value
            ))
        })
    };

    let embed_dim = to_u32(params.embed_dim, "embed_dim")?;
    let gpu_params = GpuEmbeddingParams {
        embed_dim,
        group_size: to_u32(params.group_size, "group_size")?,
        // Divide first: exact thanks to the multiple check above, and it
        // cannot overflow. Equals embed_dim / 8 (4-bit) or embed_dim * 3 / 4
        // (6-bit).
        packed_row_stride: to_u32(
            params.embed_dim / values_per_pack * stride_per_pack,
            "packed_row_stride",
        )?,
        n_groups_per_row: to_u32(
            params.embed_dim / params.group_size,
            "n_groups_per_row",
        )?,
    };

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // usize -> u64 is lossless on every target with pointers of at most 64 bits.
    let grid = MTLSize::new(u64::from(embed_dim), params.n_tokens as u64, 1);
    // NOTE(review): the 256 cap is hard-coded; the pipeline's own maximum
    // threadgroup size is not consulted here.
    let tg_size = MTLSize::new(std::cmp::min(256, u64::from(embed_dim)), 1, 1);

    let params_bytes = as_bytes(&gpu_params);
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(weight_packed)),
            (1, KernelArg::Buffer(scales)),
            (2, KernelArg::Buffer(biases)),
            (3, KernelArg::Buffer(token_ids)),
            (4, KernelArg::Buffer(output)),
            (5, KernelArg::Bytes(params_bytes)),
        ],
        grid,
        tg_size,
    );
    Ok(())
}