use metaltile::{bench_kernel, kernel};
// Generates a direct (non-grouped, non-dilated) 2-D convolution kernel called `$name`.
//
// `$subop` is forwarded to `bench_kernel` as the sub-operation label. `$kh`/`$kw`
// (filter height/width) and `$sh`/`$sw` (vertical/horizontal stride) are expressions
// spliced into the kernel body. When literals are passed here, the same-named
// `kh`/`kw`/`stride_h`/`stride_w` constexpr parameters are not read by the body.
//
// Indexing used by the generated kernel (as written below):
//   * output : flat index, `out_w` fastest, then `out_h`, then `out_ch`, then image
//   * input  : image * (in_ch * in_h * in_w) + channel * (in_h * in_w) + row * in_w + col
//   * weight : out_channel * (in_ch * kh * kw) + in_channel * (kh * kw) + ky * kw + kx
//   * bias   : one value per output channel
//
// Each invocation writes exactly one element of `out`. There is no guard on the
// flat index, so the launch is assumed to cover only valid output elements —
// TODO confirm against the dispatcher. The `batch` parameter is not read by the body.
macro_rules! conv2d_kernel {
    ($name:ident, $subop:literal, $kh:expr, $kw:expr, $sh:expr, $sw:expr) => {
        #[bench_kernel(
            op="conv2d",
            subop=$subop,
            class=GenericEmpty,
            tol=1e-3,
            kernel_mode=Grid3D,
        )]
        #[kernel]
        pub fn $name<T>(
            input: Tensor<T>,
            weight: Tensor<T>,
            bias: Tensor<T>,
            out: Tensor<T>,
            #[constexpr] batch: u32,
            #[constexpr] in_ch: u32,
            #[constexpr] in_h: u32,
            #[constexpr] in_w: u32,
            #[constexpr] out_ch: u32,
            #[constexpr] out_h: u32,
            #[constexpr] out_w: u32,
            #[constexpr] kh: u32,
            #[constexpr] kw: u32,
            #[constexpr] stride_h: u32,
            #[constexpr] stride_w: u32,
            #[constexpr] pad_h: u32,
            #[constexpr] pad_w: u32,
        ) {
            // Filter extent and stride chosen by this instantiation of the macro.
            let k_rows = $kh;
            let k_cols = $kw;
            let step_y = $sh;
            let step_x = $sw;

            // Element counts used to step through the input and weight buffers.
            let plane_len = in_h * in_w;
            let image_len = in_ch * plane_len;
            let filter_plane = k_rows * k_cols;
            let filter_len = in_ch * filter_plane;

            // Exclusive upper bounds of the real image in padded coordinates.
            let row_limit = pad_h + in_h;
            let col_limit = pad_w + in_w;

            // Split the flat output index into (image, out channel, out row, out col).
            let flat_idx = program_id::<0>();
            let col_out = flat_idx % out_w;
            let above_w = flat_idx / out_w;
            let row_out = above_w % out_h;
            let above_h = above_w / out_h;
            let chan_out = above_h % out_ch;
            let img = above_h / out_ch;

            // Top-left corner of the receptive field, in padded coordinates.
            let y_origin = row_out * step_y;
            let x_origin = col_out * step_x;

            let img_base = img * image_len;
            let filter_base = chan_out * filter_len;

            // Accumulate in f32 regardless of T, starting from the channel's bias.
            let mut acc = load(bias[chan_out]).cast::<f32>();
            for chan_in in range(0u32, in_ch, 1u32) {
                let plane_base = img_base + chan_in * plane_len;
                let taps_base = filter_base + chan_in * filter_plane;
                for tap_y in range(0u32, k_rows, 1u32) {
                    let y_pad = y_origin + tap_y;
                    let y_inside = (y_pad >= pad_h) & (y_pad < row_limit);
                    // Rows that fall in the padding read row 0 instead; the value is
                    // zeroed below, so only the address matters (presumably this keeps
                    // the load inside the buffer).
                    let src_row = select(y_inside, y_pad - pad_h, 0u32);
                    let src_row_base = plane_base + src_row * in_w;
                    let taps_row_base = taps_base + tap_y * k_cols;
                    for tap_x in range(0u32, k_cols, 1u32) {
                        let x_pad = x_origin + tap_x;
                        let x_inside = (x_pad >= pad_w) & (x_pad < col_limit);
                        let src_col = select(x_inside, x_pad - pad_w, 0u32);

                        let w_at = taps_row_base + tap_x;
                        let coeff = load(weight[w_at]).cast::<f32>();

                        let in_at = src_row_base + src_col;
                        let raw = load(input[in_at]).cast::<f32>();

                        // Taps that land in the padding contribute zero.
                        let inside = x_inside & y_inside;
                        let masked = select(inside, raw, 0.0f32);
                        acc = acc + masked * coeff;
                    }
                }
            }
            store(out[flat_idx], acc.cast::<T>());
        }
    };
}
// Fixed-size instantiations: the filter size and stride are the same literal
// (14 or 16), so consecutive windows do not overlap and the kernel's
// `kh`/`kw`/`stride_h`/`stride_w` parameters are not read by the body.
conv2d_kernel!(conv2d_patch14, "patch14", 14u32, 14u32, 14u32, 14u32);
conv2d_kernel!(conv2d_patch16, "patch16", 16u32, 16u32, 16u32, 16u32);
// General instantiation: filter size and stride are taken from the kernel's own
// `kh`, `kw`, `stride_h` and `stride_w` parameters.
// NOTE(review): these identifiers are written at the call site but must resolve
// to parameters declared inside the macro body; this presumably works because of
// how `#[kernel]` processes the function — confirm against macro_rules hygiene.
conv2d_kernel!(conv2d_generic, "generic", kh, kw, stride_h, stride_w);
// Grouped, dilated 2-D convolution; each invocation writes one element of `out`.
//
// Output channel `oc` belongs to group `oc / ocpg` and reads only the `icpg`
// input channels starting at `group * icpg`.
//
// Indexing used by the body (as written below):
//   * output : flat index, `out_w` fastest, then `out_h`, then `out_ch`, then image
//   * input  : image * (in_ch * in_h * in_w) + channel * (in_h * in_w) + row * in_w + col
//   * weight : out_channel * (icpg * kh * kw) + channel_in_group * (kh * kw) + ky * kw + kx
//   * bias   : one value per output channel
//
// There is no guard on the flat index, so the launch is assumed to cover only
// valid output elements — TODO confirm against the dispatcher.
#[bench_kernel(
    op="conv2d",
    subop="grouped",
    class=GenericEmpty,
    tol=1e-3,
    kernel_mode=Grid3D,
)]
#[kernel]
// The convolution geometry is passed as individual constexpr scalars, so the
// parameter list is necessarily long.
#[allow(clippy::too_many_arguments)]
pub fn conv2d_grouped<T>(
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] in_ch: u32,
    #[constexpr] in_h: u32,
    #[constexpr] in_w: u32,
    #[constexpr] out_ch: u32,
    #[constexpr] out_h: u32,
    #[constexpr] out_w: u32,
    #[constexpr] kh: u32,
    #[constexpr] kw: u32,
    #[constexpr] stride_h: u32,
    #[constexpr] stride_w: u32,
    #[constexpr] pad_h: u32,
    #[constexpr] pad_w: u32,
    #[constexpr] dilation_h: u32,
    #[constexpr] dilation_w: u32,
    #[constexpr] icpg: u32,
    #[constexpr] ocpg: u32,
) {
    // Element counts used to step through the input and weight buffers.
    let plane_len = in_h * in_w;
    let image_len = in_ch * plane_len;
    let filter_plane = kh * kw;
    let filter_len = icpg * filter_plane;

    // Exclusive upper bounds of the real image in padded coordinates.
    let row_limit = pad_h + in_h;
    let col_limit = pad_w + in_w;

    // Split the flat output index into (image, out channel, out row, out col).
    let flat_idx = program_id::<0>();
    let col_out = flat_idx % out_w;
    let above_w = flat_idx / out_w;
    let row_out = above_w % out_h;
    let above_h = above_w / out_h;
    let chan_out = above_h % out_ch;
    let img = above_h / out_ch;

    // First input channel of the group this output channel belongs to.
    let grp = chan_out / ocpg;
    let first_in_chan = grp * icpg;

    // Top-left corner of the receptive field, in padded coordinates.
    let y_origin = row_out * stride_h;
    let x_origin = col_out * stride_w;

    let img_base = img * image_len;
    let filter_base = chan_out * filter_len;

    // Accumulate in f32 regardless of T, starting from the channel's bias.
    let mut acc = load(bias[chan_out]).cast::<f32>();
    for local_chan in range(0u32, icpg, 1u32) {
        // The weight is indexed by the channel's position inside the group,
        // the input by its absolute channel number.
        let chan_in = first_in_chan + local_chan;
        let plane_base = img_base + chan_in * plane_len;
        let taps_base = filter_base + local_chan * filter_plane;
        for tap_y in range(0u32, kh, 1u32) {
            let y_pad = y_origin + tap_y * dilation_h;
            let y_inside = (y_pad >= pad_h) & (y_pad < row_limit);
            // Rows that fall in the padding read row 0 instead; the value is
            // zeroed below, so only the address matters (presumably this keeps
            // the load inside the buffer).
            let src_row = select(y_inside, y_pad - pad_h, 0u32);
            let src_row_base = plane_base + src_row * in_w;
            let taps_row_base = taps_base + tap_y * kw;
            for tap_x in range(0u32, kw, 1u32) {
                let x_pad = x_origin + tap_x * dilation_w;
                let x_inside = (x_pad >= pad_w) & (x_pad < col_limit);
                let src_col = select(x_inside, x_pad - pad_w, 0u32);

                let w_at = taps_row_base + tap_x;
                let coeff = load(weight[w_at]).cast::<f32>();

                let in_at = src_row_base + src_col;
                let raw = load(input[in_at]).cast::<f32>();

                // Taps that land in the padding contribute zero.
                let inside = x_inside & y_inside;
                let masked = select(inside, raw, 0.0f32);
                acc = acc + masked * coeff;
            }
        }
    }
    store(out[flat_idx], acc.cast::<T>());
}