use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Metal source text for the repeat-tiled kernel, embedded into the binary at
/// compile time from `../shaders/repeat_tiled.metal` (path relative to this
/// file). `register` hands this string to the kernel registry.
pub static REPEAT_TILED_SHADER_SOURCE: &str =
include_str!("../shaders/repeat_tiled.metal");
/// Adds the repeat-tiled shader source to `registry` under the kernel name
/// `repeat_tiled_f32`, the same name `dispatch_repeat_tiled_f32` later uses
/// to look up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "repeat_tiled_f32";
    registry.register_source(kernel_name, REPEAT_TILED_SHADER_SOURCE);
}
/// Parameter block passed to the `repeat_tiled_f32` kernel as raw bytes
/// (argument index 2 in `dispatch_repeat_tiled_f32`).
///
/// `#[repr(C)]` pins the layout to four consecutive `u32` values in the order
/// declared here. NOTE(review): this presumably has to match the parameter
/// struct declared in `repeat_tiled.metal` — confirm against the shader before
/// reordering, adding, or removing fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRepeatTiledParams {
// Copied verbatim from `RepeatTiledParams::seq`.
seq: u32,
// Copied verbatim from `RepeatTiledParams::hg`.
hg: u32,
// Copied verbatim from `RepeatTiledParams::h`.
h: u32,
// Copied verbatim from `RepeatTiledParams::k`.
k: u32,
}
/// Caller-facing shape parameters for `dispatch_repeat_tiled_f32`.
///
/// The dispatch function requires every field to be non-zero and `h` to be a
/// multiple of `hg`. It expects the source buffer to hold at least
/// `seq * hg * k` f32 values and the destination at least `seq * h * k`.
///
/// NOTE(review): the names suggest sequence length, head-group count, head
/// count and per-head width, but that is not visible in this file — confirm
/// against the shader.
#[derive(Clone, Copy, Debug)]
pub struct RepeatTiledParams {
/// Outermost extent; used as the grid's third dimension.
pub seq: u32,
/// Middle extent of the source buffer; must divide `h`.
pub hg: u32,
/// Middle extent of the destination buffer; used as the grid's second dimension.
pub h: u32,
/// Innermost extent of both buffers; used as the grid's first dimension.
pub k: u32,
}
/// Validates `params` and the buffer sizes, then encodes one
/// `repeat_tiled_f32` kernel dispatch onto `encoder`.
///
/// `src` must hold at least `seq * hg * k` f32 values and `dst` at least
/// `seq * h * k` f32 values. The kernel receives `src` at argument index 0,
/// `dst` at index 1 and the four `u32` parameters as raw bytes at index 2.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when
/// - any of `seq`, `hg`, `h`, `k` is zero,
/// - `h` is not a multiple of `hg`,
/// - an element count or its size in bytes does not fit in `usize`,
/// - `src` or `dst` is smaller than the required byte size.
///
/// Any error returned by `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_repeat_tiled_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    dst: &MlxBuffer,
    params: &RepeatTiledParams,
) -> Result<()> {
    const F32_SIZE: usize = std::mem::size_of::<f32>();

    if params.seq == 0 || params.hg == 0 || params.h == 0 || params.k == 0 {
        return Err(MlxError::InvalidArgument(
            "repeat_tiled_f32: seq, hg, h, k must all be > 0".into(),
        ));
    }
    if params.h % params.hg != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "repeat_tiled_f32: h ({}) must be a multiple of hg ({})",
            params.h, params.hg
        )));
    }

    let src_elems = (params.seq as usize)
        .checked_mul(params.hg as usize)
        .and_then(|v| v.checked_mul(params.k as usize))
        .ok_or_else(|| {
            MlxError::InvalidArgument(
                "repeat_tiled_f32: seq*hg*k overflows usize".into(),
            )
        })?;
    let dst_elems = (params.seq as usize)
        .checked_mul(params.h as usize)
        .and_then(|v| v.checked_mul(params.k as usize))
        .ok_or_else(|| {
            MlxError::InvalidArgument(
                "repeat_tiled_f32: seq*h*k overflows usize".into(),
            )
        })?;

    // The element count fitting in `usize` does not imply the byte count does.
    // An unchecked multiply could wrap to a small value in release builds and
    // let an undersized buffer pass the check below.
    let src_bytes = src_elems.checked_mul(F32_SIZE).ok_or_else(|| {
        MlxError::InvalidArgument(
            "repeat_tiled_f32: src byte size overflows usize".into(),
        )
    })?;
    if src.byte_len() < src_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "repeat_tiled_f32: src buffer too small: need {} bytes, have {}",
            src_bytes,
            src.byte_len()
        )));
    }

    let dst_bytes = dst_elems.checked_mul(F32_SIZE).ok_or_else(|| {
        MlxError::InvalidArgument(
            "repeat_tiled_f32: dst byte size overflows usize".into(),
        )
    })?;
    if dst.byte_len() < dst_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "repeat_tiled_f32: dst buffer too small: need {} bytes, have {}",
            dst_bytes,
            dst.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("repeat_tiled_f32", device)?;

    let gpu_params = GpuRepeatTiledParams {
        seq: params.seq,
        hg: params.hg,
        h: params.h,
        k: params.k,
    };

    // Grid extents follow the destination shape: x = k, y = h, z = seq.
    // NOTE(review): whether `encode_with_args` treats `grid` as a thread count
    // or a threadgroup count is not visible here — confirm in encode_helpers.
    let grid = MTLSize::new(params.k as u64, params.h as u64, params.seq as u64);
    // Threadgroup width is capped at 256 and never exceeds the grid's x extent.
    let tg_x = std::cmp::min(256u64, params.k as u64);
    let tg = MTLSize::new(tg_x, 1, 1);

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(dst)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );
    Ok(())
}