use metaltile::{bench_kernel, kernel};
/// Direct 3-D convolution with bias: each program instance computes exactly one output element.
///
/// `program_id::<0>()` is used as a flat output index and unpacked, fastest axis first, into
/// `(ow, oh, od, oc, n)`, so `out` is addressed as `[n][out_ch][out_d][out_h][out_w]`.
/// The index arithmetic below addresses `input` as `[n][in_ch][in_d][in_h][in_w]`,
/// `weight` as `[out_ch][in_ch][kd][kh][kw]` and `bias` by output channel.
///
/// Padding is implicit: a tap whose padded coordinate falls outside `[pad, pad + extent)` on any
/// axis still issues a load (from coordinate 0 on that axis) but its value is replaced by `0.0`.
/// This kernel has no dilation; consecutive taps are one input element apart.
///
/// The sum is accumulated in `f32`, seeded with the bias, and cast to `T` when stored.
///
/// NOTE(review): there is no guard on the flat index, so presumably the launch supplies exactly
/// one program per output element - confirm against the launcher.
/// NOTE(review): the bench attribute says `kernel_mode=Grid3D` while only `program_id::<0>()` is
/// read here - confirm how the grid is flattened.
#[bench_kernel(
    op = "conv3d",
    subop = "generic",
    class = GenericEmpty,
    tol = 1e-3,
    kernel_mode = Grid3D,
)]
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn conv3d_generic<T>(
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] in_ch: u32,
    #[constexpr] in_d: u32,
    #[constexpr] in_h: u32,
    #[constexpr] in_w: u32,
    #[constexpr] out_ch: u32,
    #[constexpr] out_d: u32,
    #[constexpr] out_h: u32,
    #[constexpr] out_w: u32,
    #[constexpr] kd: u32,
    #[constexpr] kh: u32,
    #[constexpr] kw: u32,
    #[constexpr] stride_d: u32,
    #[constexpr] stride_h: u32,
    #[constexpr] stride_w: u32,
    #[constexpr] pad_d: u32,
    #[constexpr] pad_h: u32,
    #[constexpr] pad_w: u32,
) {
    // Element strides of the flattened input and weight tensors.
    let in_hw = in_h * in_w;
    let in_dhw = in_d * in_hw;
    let in_batch_stride = in_ch * in_dhw;
    let k_hw = kh * kw;
    let k_dhw = kd * k_hw;
    let k_filter_stride = in_ch * k_dhw;

    // Unpack the flat output index, innermost (width) axis first.
    let lin = program_id::<0>();
    let x_out = lin % out_w;
    let rest_w = lin / out_w;
    let y_out = rest_w % out_h;
    let rest_h = rest_w / out_h;
    let z_out = rest_h % out_d;
    let rest_d = rest_h / out_d;
    let co = rest_d % out_ch;
    let batch = rest_d / out_ch;

    // Window origin expressed in padded input coordinates.
    let z_org = z_out * stride_d;
    let y_org = y_out * stride_h;
    let x_org = x_out * stride_w;

    // Offsets that stay fixed for this output element.
    let src_batch_base = batch * in_batch_stride;
    let ker_filter_base = co * k_filter_stride;

    let mut accum = load(bias[co]).cast::<f32>();
    for ci in range(0u32, in_ch, 1u32) {
        let src_chan_base = src_batch_base + ci * in_dhw;
        let ker_chan_base = ker_filter_base + ci * k_dhw;
        for dz in range(0u32, kd, 1u32) {
            let z_pad = z_org + dz;
            let z_hit = (z_pad >= pad_d) & (z_pad < pad_d + in_d);
            // Out-of-range taps fall back to coordinate 0; the subtraction result is discarded.
            let z_in = select(z_hit, z_pad - pad_d, 0u32);
            let src_slice_base = src_chan_base + z_in * in_hw;
            let ker_slice_base = ker_chan_base + dz * k_hw;
            for dy in range(0u32, kh, 1u32) {
                let y_pad = y_org + dy;
                let y_hit = (y_pad >= pad_h) & (y_pad < pad_h + in_h);
                let y_in = select(y_hit, y_pad - pad_h, 0u32);
                let zy_hit = z_hit & y_hit;
                let src_row_base = src_slice_base + y_in * in_w;
                let ker_row_base = ker_slice_base + dy * kw;
                for dx in range(0u32, kw, 1u32) {
                    let x_pad = x_org + dx;
                    let x_hit = (x_pad >= pad_w) & (x_pad < pad_w + in_w);
                    let x_in = select(x_hit, x_pad - pad_w, 0u32);
                    let tap_hit = zy_hit & x_hit;
                    let src_idx = src_row_base + x_in;
                    let src_val = load(input[src_idx]).cast::<f32>();
                    // Taps that land in the padding contribute zero.
                    let src_masked = select(tap_hit, src_val, 0.0f32);
                    let ker_idx = ker_row_base + dx;
                    let ker_val = load(weight[ker_idx]).cast::<f32>();
                    accum = accum + src_masked * ker_val;
                }
            }
        }
    }
    store(out[lin], accum.cast::<T>());
}
/// Grouped, dilated direct 3-D convolution with bias: each program instance computes exactly one
/// output element.
///
/// `program_id::<0>()` is used as a flat output index and unpacked, fastest axis first, into
/// `(ow, oh, od, oc, n)`, so `out` is addressed as `[n][out_ch][out_d][out_h][out_w]`.
/// The index arithmetic below addresses `input` as `[n][in_ch][in_d][in_h][in_w]`,
/// `weight` as `[out_ch][icpg][kd][kh][kw]` and `bias` by output channel.
///
/// Output channel `oc` belongs to group `oc / ocpg` and reads the `icpg` input channels starting
/// at `group * icpg`. Kernel taps are spaced `dilation_*` input elements apart per axis.
///
/// Padding is implicit: a tap whose padded coordinate falls outside `[pad, pad + extent)` on any
/// axis still issues a load (from coordinate 0 on that axis) but its value is replaced by `0.0`.
///
/// The sum is accumulated in `f32`, seeded with the bias, and cast to `T` when stored.
///
/// NOTE(review): nothing here checks that `icpg`/`ocpg` are consistent with `in_ch`/`out_ch`;
/// presumably the caller guarantees it - confirm.
/// NOTE(review): there is no guard on the flat index, so presumably the launch supplies exactly
/// one program per output element - confirm against the launcher.
/// NOTE(review): the bench attribute says `kernel_mode=Grid3D` while only `program_id::<0>()` is
/// read here - confirm how the grid is flattened.
#[bench_kernel(
    op = "conv3d",
    subop = "grouped",
    class = GenericEmpty,
    tol = 1e-3,
    kernel_mode = Grid3D,
)]
#[kernel]
#[allow(clippy::too_many_arguments)]
pub fn conv3d_grouped<T>(
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] in_ch: u32,
    #[constexpr] in_d: u32,
    #[constexpr] in_h: u32,
    #[constexpr] in_w: u32,
    #[constexpr] out_ch: u32,
    #[constexpr] out_d: u32,
    #[constexpr] out_h: u32,
    #[constexpr] out_w: u32,
    #[constexpr] kd: u32,
    #[constexpr] kh: u32,
    #[constexpr] kw: u32,
    #[constexpr] stride_d: u32,
    #[constexpr] stride_h: u32,
    #[constexpr] stride_w: u32,
    #[constexpr] pad_d: u32,
    #[constexpr] pad_h: u32,
    #[constexpr] pad_w: u32,
    #[constexpr] dilation_d: u32,
    #[constexpr] dilation_h: u32,
    #[constexpr] dilation_w: u32,
    #[constexpr] icpg: u32,
    #[constexpr] ocpg: u32,
) {
    // Element strides of the flattened input and weight tensors.
    let in_hw = in_h * in_w;
    let in_dhw = in_d * in_hw;
    let in_batch_stride = in_ch * in_dhw;
    let k_hw = kh * kw;
    let k_dhw = kd * k_hw;
    // Each filter only spans the `icpg` input channels of its own group.
    let k_filter_stride = icpg * k_dhw;

    // Unpack the flat output index, innermost (width) axis first.
    let lin = program_id::<0>();
    let x_out = lin % out_w;
    let rest_w = lin / out_w;
    let y_out = rest_w % out_h;
    let rest_h = rest_w / out_h;
    let z_out = rest_h % out_d;
    let rest_d = rest_h / out_d;
    let co = rest_d % out_ch;
    let batch = rest_d / out_ch;

    // First input channel of the group this output channel belongs to.
    let grp = co / ocpg;
    let grp_first_ci = grp * icpg;

    // Window origin expressed in padded input coordinates.
    let z_org = z_out * stride_d;
    let y_org = y_out * stride_h;
    let x_org = x_out * stride_w;

    // Offsets that stay fixed for this output element.
    let src_batch_base = batch * in_batch_stride;
    let ker_filter_base = co * k_filter_stride;

    let mut accum = load(bias[co]).cast::<f32>();
    for gi in range(0u32, icpg, 1u32) {
        // `gi` indexes the weight's channel axis; `ci` is the matching absolute input channel.
        let ci = grp_first_ci + gi;
        let src_chan_base = src_batch_base + ci * in_dhw;
        let ker_chan_base = ker_filter_base + gi * k_dhw;
        for dz in range(0u32, kd, 1u32) {
            let z_pad = z_org + dz * dilation_d;
            let z_hit = (z_pad >= pad_d) & (z_pad < pad_d + in_d);
            // Out-of-range taps fall back to coordinate 0; the subtraction result is discarded.
            let z_in = select(z_hit, z_pad - pad_d, 0u32);
            let src_slice_base = src_chan_base + z_in * in_hw;
            let ker_slice_base = ker_chan_base + dz * k_hw;
            for dy in range(0u32, kh, 1u32) {
                let y_pad = y_org + dy * dilation_h;
                let y_hit = (y_pad >= pad_h) & (y_pad < pad_h + in_h);
                let y_in = select(y_hit, y_pad - pad_h, 0u32);
                let zy_hit = z_hit & y_hit;
                let src_row_base = src_slice_base + y_in * in_w;
                let ker_row_base = ker_slice_base + dy * kw;
                for dx in range(0u32, kw, 1u32) {
                    let x_pad = x_org + dx * dilation_w;
                    let x_hit = (x_pad >= pad_w) & (x_pad < pad_w + in_w);
                    let x_in = select(x_hit, x_pad - pad_w, 0u32);
                    let tap_hit = zy_hit & x_hit;
                    let src_idx = src_row_base + x_in;
                    let src_val = load(input[src_idx]).cast::<f32>();
                    // Taps that land in the padding contribute zero.
                    let src_masked = select(tap_hit, src_val, 0.0f32);
                    let ker_idx = ker_row_base + dx;
                    let ker_val = load(weight[ker_idx]).cast::<f32>();
                    accum = accum + src_masked * ker_val;
                }
            }
        }
    }
    store(out[lin], accum.cast::<T>());
}