use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Metal source text for the bilinear-resize kernel, embedded at compile time
/// from `../shaders/bilinear_resize_2d.metal` (path relative to this file).
///
/// [`register`] hands this string to the kernel registry under the name
/// `bilinear_resize_2d_f32`.
pub static BILINEAR_RESIZE_2D_SHADER_SOURCE: &str =
include_str!("../shaders/bilinear_resize_2d.metal");
/// Makes the bilinear-resize kernel known to `registry`.
///
/// The embedded shader source is registered under the kernel name
/// `bilinear_resize_2d_f32`, which is the name the dispatch function in this
/// file later uses to look up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "bilinear_resize_2d_f32";
    registry.register_source(kernel_name, BILINEAR_RESIZE_2D_SHADER_SOURCE);
}
/// Parameter block handed to the `bilinear_resize_2d_f32` kernel as raw bytes
/// (bound at argument index 0 by `dispatch_bilinear_resize_2d_f32`).
///
/// `#[repr(C)]` pins the field order and the `bytemuck` derives allow the
/// struct to be viewed as a byte slice.
/// NOTE(review): the field order and types presumably have to mirror the
/// parameter struct declared in the Metal shader — confirm against
/// `bilinear_resize_2d.metal` before changing anything here.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuBilinearResize2dParams {
/// Side length of the square source grid (`trained_n * trained_n` cells).
trained_n: u32,
/// Number of output cells along x.
target_n_x: u32,
/// Number of output cells along y.
target_n_y: u32,
/// Number of `f32` values stored per grid cell.
n_embd: u32,
/// Scale factor along x: `target_n_x / trained_n`.
sf_x: f32,
/// Scale factor along y: `target_n_y / trained_n`.
sf_y: f32,
/// Filter support along x: `max(1 / sf_x, 1.0)`.
support_x: f32,
/// Filter support along y: `max(1 / sf_y, 1.0)`.
support_y: f32,
/// Reciprocal of `support_x`.
invscale_x: f32,
/// Reciprocal of `support_y`.
invscale_y: f32,
}
/// Upper bound on the x-extent of the threadgroup size requested by
/// `dispatch_bilinear_resize_2d_f32`; when the grid is smaller than this, the
/// grid size itself is used instead.
const TG_SIZE: u64 = 256;
/// Encodes one `bilinear_resize_2d_f32` kernel invocation onto `encoder`.
///
/// The kernel reads a square source grid of `trained_n * trained_n` cells and
/// writes a `target_n_y * target_n_x` grid, each cell holding `n_embd` `f32`
/// values. This function only validates arguments, derives the filter
/// parameters, and encodes the dispatch; nothing is executed or awaited here.
///
/// # Parameters
/// * `encoder` – command encoder the dispatch is recorded on.
/// * `registry` – kernel registry used to look up (and, if needed, build) the
///   `bilinear_resize_2d_f32` pipeline for `device`.
/// * `device` – Metal device the pipeline is requested for.
/// * `src` – input buffer; must hold at least
///   `trained_n * trained_n * n_embd` `f32` values.
/// * `dst` – output buffer; must hold at least
///   `target_n_y * target_n_x * n_embd` `f32` values.
/// * `trained_n`, `target_n_x`, `target_n_y`, `n_embd` – grid dimensions; all
///   must be non-zero.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] when any dimension is zero, when a
/// required buffer size does not fit in `usize`, or when `src` / `dst` are
/// smaller than required. Errors from the pipeline lookup are propagated
/// unchanged.
pub fn dispatch_bilinear_resize_2d_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    dst: &MlxBuffer,
    trained_n: u32,
    target_n_x: u32,
    target_n_y: u32,
    n_embd: u32,
) -> Result<()> {
    if trained_n == 0 || target_n_x == 0 || target_n_y == 0 || n_embd == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "bilinear_resize_2d_f32: trained_n ({trained_n}), target_n_x \
             ({target_n_x}), target_n_y ({target_n_y}), n_embd ({n_embd}) \
             must all be > 0"
        )));
    }

    let f32_sz = DType::F32.size_of();

    // Three u32 factors (plus the element size) can exceed usize. An unchecked
    // product would wrap in release builds and let an undersized buffer pass
    // the length checks below, so every product is computed with checked
    // arithmetic. The u32 -> usize casts are widening on 32/64-bit targets.
    let checked_elems = |dims: [u32; 3]| -> Option<usize> {
        dims.iter()
            .try_fold(1usize, |acc, &d| acc.checked_mul(d as usize))
    };
    let size_overflow = |which: &str| {
        MlxError::InvalidArgument(format!(
            "bilinear_resize_2d_f32: {which} size overflows usize \
             (trained_n={trained_n}, target_n_x={target_n_x}, \
             target_n_y={target_n_y}, n_embd={n_embd})"
        ))
    };

    let need_src = checked_elems([trained_n, trained_n, n_embd])
        .and_then(|elems| elems.checked_mul(f32_sz))
        .ok_or_else(|| size_overflow("src"))?;
    if src.byte_len() < need_src {
        return Err(MlxError::InvalidArgument(format!(
            "bilinear_resize_2d_f32: src too small: {} vs {} bytes",
            src.byte_len(),
            need_src
        )));
    }

    let target_total = checked_elems([target_n_y, target_n_x, n_embd])
        .ok_or_else(|| size_overflow("dst"))?;
    let need_dst = target_total
        .checked_mul(f32_sz)
        .ok_or_else(|| size_overflow("dst"))?;
    if dst.byte_len() < need_dst {
        return Err(MlxError::InvalidArgument(format!(
            "bilinear_resize_2d_f32: dst too small: {} vs {} bytes",
            dst.byte_len(),
            need_dst
        )));
    }

    // Per-axis scale factor (target / source). The support widens to 1/sf
    // when downsampling (sf < 1) and is clamped to 1.0 otherwise; invscale is
    // its reciprocal.
    let sf_x = (target_n_x as f32) / (trained_n as f32);
    let sf_y = (target_n_y as f32) / (trained_n as f32);
    let support_x = (1.0f32 / sf_x).max(1.0);
    let support_y = (1.0f32 / sf_y).max(1.0);
    let invscale_x = 1.0f32 / support_x;
    let invscale_y = 1.0f32 / support_y;

    let pipeline = registry.get_pipeline("bilinear_resize_2d_f32", device)?;

    let gpu_params = GpuBilinearResize2dParams {
        trained_n,
        target_n_x,
        target_n_y,
        n_embd,
        sf_x,
        sf_y,
        support_x,
        support_y,
        invscale_x,
        invscale_y,
    };

    // 1-D grid whose extent equals the number of output f32 values; the
    // threadgroup width is capped at TG_SIZE and never exceeds the grid.
    let total = target_total as u64;
    let grid = MTLSize::new(total, 1, 1);
    let tg = MTLSize::new(std::cmp::min(TG_SIZE, total), 1, 1);

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(src)),
            (2, KernelArg::Buffer(dst)),
        ],
        grid,
        tg,
    );
    Ok(())
}
/// CPU reference implementation used by the tests to check the GPU kernel.
///
/// `src` is read as a `trained_n x trained_n` grid with `n_embd` values per
/// cell (row-major: y, then x, then channel). The result is a
/// `target_n_y x target_n_x` grid with the same per-cell width. Each output
/// cell is a triangle-filter weighted average of the source cells inside the
/// filter support, normalised by the sum of the weights actually used.
///
/// Panics if `src` is shorter than an accessed source cell requires.
#[cfg(test)]
pub(crate) fn bilinear_resize_2d_f32_cpu_oracle(
    src: &[f32],
    trained_n: u32,
    target_n_x: u32,
    target_n_y: u32,
    n_embd: u32,
) -> Vec<f32> {
    /// Offset that moves between cell indices and cell centres.
    const HALF_PIXEL: f32 = 0.5;

    /// Triangle (tent) filter: `1 - |t|`, clamped below at zero.
    fn tent(t: f32) -> f32 {
        (1.0 - t.abs()).max(0.0)
    }

    /// Source indices and filter weights contributing to output index
    /// `dst_idx` along a single axis, in ascending source-index order.
    fn axis_taps(
        dst_idx: usize,
        src_extent: f32,
        scale: f32,
        support: f32,
        inv_support: f32,
    ) -> Vec<(usize, f32)> {
        let center = ((dst_idx as f32) + HALF_PIXEL) / scale;
        let first = ((center - support + HALF_PIXEL).max(0.0)) as i64;
        let end = ((center + support + HALF_PIXEL).min(src_extent)) as i64;
        (first..end)
            .map(|s| {
                let weight = tent(((s as f32) - center + HALF_PIXEL) * inv_support);
                (s as usize, weight)
            })
            .collect()
    }

    let side = trained_n as usize;
    let out_w = target_n_x as usize;
    let out_h = target_n_y as usize;
    let channels = n_embd as usize;
    let side_f = trained_n as f32;

    let scale_x = (target_n_x as f32) / side_f;
    let scale_y = (target_n_y as f32) / side_f;
    let support_x = (1.0f32 / scale_x).max(1.0);
    let support_y = (1.0f32 / scale_y).max(1.0);
    let inv_support_x = 1.0f32 / support_x;
    let inv_support_y = 1.0f32 / support_y;

    let mut out = vec![0f32; out_h * out_w * channels];

    // Column taps depend only on the output x index, so build them once.
    let column_taps: Vec<Vec<(usize, f32)>> = (0..out_w)
        .map(|x_dst| axis_taps(x_dst, side_f, scale_x, support_x, inv_support_x))
        .collect();

    for y_dst in 0..out_h {
        let row_taps = axis_taps(y_dst, side_f, scale_y, support_y, inv_support_y);

        for (x_dst, x_taps) in column_taps.iter().enumerate() {
            let dst_off = (y_dst * out_w + x_dst) * channels;
            let pixel = &mut out[dst_off..dst_off + channels];
            let mut weight_sum = 0.0f32;

            // Rows outer, columns inner: accumulation order matters for the
            // floating-point result.
            for &(sy, wy) in &row_taps {
                for &(sx, wx) in x_taps {
                    let w = wx * wy;
                    if w <= 0.0 {
                        continue;
                    }
                    let src_off = (sy * side + sx) * channels;
                    let texel = &src[src_off..src_off + channels];
                    for (acc, &value) in pixel.iter_mut().zip(texel) {
                        *acc += value * w;
                    }
                    weight_sum += w;
                }
            }

            if weight_sum > 0.0 {
                let inv = 1.0 / weight_sum;
                for acc in pixel.iter_mut() {
                    *acc *= inv;
                }
            }
        }
    }

    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;
    use crate::graph::GraphExecutor;

    /// Copies `src_host` into a device buffer, encodes the resize kernel on a
    /// fresh executor session, finishes the session, and returns the output
    /// buffer contents as a host vector.
    fn run_kernel(
        device: &MlxDevice,
        src_host: &[f32],
        trained_n: u32,
        target_n_x: u32,
        target_n_y: u32,
        n_embd: u32,
    ) -> Vec<f32> {
        let side = trained_n as usize;
        let out_w = target_n_x as usize;
        let out_h = target_n_y as usize;
        let channels = n_embd as usize;

        let executor = GraphExecutor::new(MlxDevice::new().expect("MlxDevice for executor"));
        let mut session = executor.begin().expect("begin");

        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let mut src_buf = device
            .alloc_buffer(src_host.len() * 4, DType::F32, vec![side, side, channels])
            .unwrap();
        src_buf
            .as_mut_slice::<f32>()
            .unwrap()
            .copy_from_slice(src_host);

        let dst_elems = out_h * out_w * channels;
        let dst_buf = device
            .alloc_buffer(dst_elems * 4, DType::F32, vec![out_h, out_w, channels])
            .unwrap();

        dispatch_bilinear_resize_2d_f32(
            session.encoder_mut(),
            &mut registry,
            device.metal_device(),
            &src_buf,
            &dst_buf,
            trained_n,
            target_n_x,
            target_n_y,
            n_embd,
        )
        .expect("dispatch K2");
        session.finish().expect("finish");

        dst_buf.as_slice::<f32>().expect("readback").to_vec()
    }

    /// Deterministic pseudo-signal: a scaled sine of the element index.
    fn make_seeded(trained_n: u32, n_embd: u32) -> Vec<f32> {
        let len = (trained_n * trained_n * n_embd) as usize;
        let mut data = Vec::with_capacity(len);
        for i in 0..len {
            data.push(((i as f32) * 0.013_3_f32).sin() * 0.5);
        }
        data
    }

    /// Largest absolute element-wise difference over the zipped prefix.
    fn max_abs_diff(expected: &[f32], actual: &[f32]) -> f32 {
        let mut worst = 0.0f32;
        for (e, g) in expected.iter().zip(actual.iter()) {
            worst = worst.max((e - g).abs());
        }
        worst
    }

    /// Same-size resize must reproduce the oracle bit for bit.
    #[test]
    fn adr021_k2_bilinear_resize_2d_f32_byte_identical_fast_path() {
        let device = MlxDevice::new().expect("MlxDevice");
        let side: u32 = 8;
        let channels: u32 = 16;
        let input = make_seeded(side, channels);

        let expected = bilinear_resize_2d_f32_cpu_oracle(&input, side, side, side, channels);
        let actual = run_kernel(&device, &input, side, side, side, channels);

        assert_eq!(expected.len(), actual.len());
        for (idx, (want, got)) in expected.iter().zip(actual.iter()).enumerate() {
            assert_eq!(
                want.to_bits(),
                got.to_bits(),
                "K2 fast-path byte parity violated at element {i}: oracle={a} gpu={b}",
                i = idx,
                a = want,
                b = got
            );
        }
    }

    /// 8 -> 16 upsample on both axes stays within 1e-6 of the oracle.
    #[test]
    fn adr021_k2_bilinear_resize_2d_f32_ulp_bound_upsample_2x() {
        let device = MlxDevice::new().expect("MlxDevice");
        let side: u32 = 8;
        let out_side: u32 = 16;
        let channels: u32 = 16;
        let input = make_seeded(side, channels);

        let expected =
            bilinear_resize_2d_f32_cpu_oracle(&input, side, out_side, out_side, channels);
        let actual = run_kernel(&device, &input, side, out_side, out_side, channels);

        assert_eq!(expected.len(), actual.len());
        let max_abs = max_abs_diff(&expected, &actual);
        assert!(
            max_abs < 1e-6,
            "K2 upsample drift {} exceeds 1e-6 tolerance",
            max_abs
        );
    }

    /// 16 -> (4 wide, 8 tall) downsample stays within 1e-6 of the oracle.
    #[test]
    fn adr021_k2_bilinear_resize_2d_f32_ulp_bound_downsample_rect() {
        let device = MlxDevice::new().expect("MlxDevice");
        let side: u32 = 16;
        let channels: u32 = 16;
        let input = make_seeded(side, channels);

        let expected = bilinear_resize_2d_f32_cpu_oracle(&input, side, 4, 8, channels);
        let actual = run_kernel(&device, &input, side, 4, 8, channels);

        assert_eq!(expected.len(), actual.len());
        let max_abs = max_abs_diff(&expected, &actual);
        assert!(
            max_abs < 1e-6,
            "K2 downsample-rect drift {} exceeds 1e-6 tolerance",
            max_abs
        );
    }
}