use metaltile::{bench_kernel, kernel};
// Winograd F(2x2, 3x3) 2-D convolution with the filter transform fused in.
//
// Each program computes one 2x2 output tile for one (batch image, output channel):
//     Y = A^T [ sum_ic (G g_ic G^T) .* (B^T d_ic B) ] A + bias[oc]
// where `.*` is the elementwise product, g_ic is the 3x3 filter for input channel ic,
// and d_ic is the 4x4 zero-padded input patch whose top-left padded coordinate is
// (2 * th, 2 * tw). A tile origin of 2 * tile index with a 4x4 patch means the kernel
// only implements stride 1 / dilation 1.
//
// Layouts implied by the index arithmetic below (all contiguous):
//     input  : [N, in_ch, in_h, in_w]
//     weight : [out_ch, in_ch, 3, 3]
//     bias   : [out_ch]
//     out    : [N, out_ch, out_h, out_w]
//
// Program id layout: idx = ((n * out_ch + oc) * tiles_h + th) * tiles_w + tw.
// NOTE(review): presumably tiles_h = ceil(out_h / 2) and tiles_w = ceil(out_w / 2);
// confirm against the launcher. Partial tiles on an odd bottom/right edge are handled
// by the guarded stores at the end. There is no batch-size argument, so the launcher
// must not dispatch programs past the last batch image.
//
// All arithmetic is done in f32 regardless of T.
#[bench_kernel(
    op="conv2d",
    subop="winograd_3x3",
    class=GenericEmpty,
    tol=1e-3,
    kernel_mode=Grid3D,
)]
#[kernel]
// Every convolution dimension is passed as its own constexpr argument, hence the
// long signature.
#[allow(clippy::too_many_arguments)]
pub fn winograd_conv2d_3x3<T>(
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] in_ch: u32,
    #[constexpr] in_h: u32,
    #[constexpr] in_w: u32,
    #[constexpr] out_ch: u32,
    #[constexpr] out_h: u32,
    #[constexpr] out_w: u32,
    #[constexpr] pad_h: u32,
    #[constexpr] pad_w: u32,
    #[constexpr] tiles_h: u32,
    #[constexpr] tiles_w: u32,
) {
    // Decompose the flat program id into (n, oc, th, tw), tw varying fastest.
    let idx = program_id::<0>();
    let tw = idx % tiles_w;
    let r1 = idx / tiles_w;
    let th = r1 % tiles_h;
    let r2 = r1 / tiles_h;
    let oc = r2 % out_ch;
    let n = r2 / out_ch;

    // Padded-input rows/columns covered by the 4x4 patch. A row/column is "ok" when it
    // falls inside the unpadded input. Otherwise it reads as zero padding, and its
    // unpadded index is forced to 0 so the unconditionally issued loads below always
    // address a valid element.
    let pr0 = th * 2u32;
    let pc0 = tw * 2u32;
    let pr_0 = pr0;
    let pr_1 = pr0 + 1u32;
    let pr_2 = pr0 + 2u32;
    let pr_3 = pr0 + 3u32;
    let row_ok_0 = (pr_0 >= pad_h) & (pr_0 < pad_h + in_h);
    let row_ok_1 = (pr_1 >= pad_h) & (pr_1 < pad_h + in_h);
    let row_ok_2 = (pr_2 >= pad_h) & (pr_2 < pad_h + in_h);
    let row_ok_3 = (pr_3 >= pad_h) & (pr_3 < pad_h + in_h);
    let ih_0 = select(row_ok_0, pr_0 - pad_h, 0u32);
    let ih_1 = select(row_ok_1, pr_1 - pad_h, 0u32);
    let ih_2 = select(row_ok_2, pr_2 - pad_h, 0u32);
    let ih_3 = select(row_ok_3, pr_3 - pad_h, 0u32);
    let pc_0 = pc0;
    let pc_1 = pc0 + 1u32;
    let pc_2 = pc0 + 2u32;
    let pc_3 = pc0 + 3u32;
    let col_ok_0 = (pc_0 >= pad_w) & (pc_0 < pad_w + in_w);
    let col_ok_1 = (pc_1 >= pad_w) & (pc_1 < pad_w + in_w);
    let col_ok_2 = (pc_2 >= pad_w) & (pc_2 < pad_w + in_w);
    let col_ok_3 = (pc_3 >= pad_w) & (pc_3 < pad_w + in_w);
    let iw_0 = select(col_ok_0, pc_0 - pad_w, 0u32);
    let iw_1 = select(col_ok_1, pc_1 - pad_w, 0u32);
    let iw_2 = select(col_ok_2, pc_2 - pad_w, 0u32);
    let iw_3 = select(col_ok_3, pc_3 - pad_w, 0u32);

    let input_plane = in_h * in_w;
    let in_n_stride = in_ch * input_plane;
    let n_base = n * in_n_stride;
    let w_oc_base = oc * in_ch * 9u32;

    // Winograd-domain accumulators M = sum_ic U_ic .* V_ic (4x4).
    let mut m00 = 0.0f32;
    let mut m01 = 0.0f32;
    let mut m02 = 0.0f32;
    let mut m03 = 0.0f32;
    let mut m10 = 0.0f32;
    let mut m11 = 0.0f32;
    let mut m12 = 0.0f32;
    let mut m13 = 0.0f32;
    let mut m20 = 0.0f32;
    let mut m21 = 0.0f32;
    let mut m22 = 0.0f32;
    let mut m23 = 0.0f32;
    let mut m30 = 0.0f32;
    let mut m31 = 0.0f32;
    let mut m32 = 0.0f32;
    let mut m33 = 0.0f32;
    for ic in range(0u32, in_ch, 1u32) {
        // Gather the zero-padded 4x4 input patch d for this channel.
        let in_ic_base = n_base + ic * input_plane;
        let row0 = in_ic_base + ih_0 * in_w;
        let row1 = in_ic_base + ih_1 * in_w;
        let row2 = in_ic_base + ih_2 * in_w;
        let row3 = in_ic_base + ih_3 * in_w;
        let d00 = select(row_ok_0 & col_ok_0, load(input[row0 + iw_0]).cast::<f32>(), 0.0f32);
        let d01 = select(row_ok_0 & col_ok_1, load(input[row0 + iw_1]).cast::<f32>(), 0.0f32);
        let d02 = select(row_ok_0 & col_ok_2, load(input[row0 + iw_2]).cast::<f32>(), 0.0f32);
        let d03 = select(row_ok_0 & col_ok_3, load(input[row0 + iw_3]).cast::<f32>(), 0.0f32);
        let d10 = select(row_ok_1 & col_ok_0, load(input[row1 + iw_0]).cast::<f32>(), 0.0f32);
        let d11 = select(row_ok_1 & col_ok_1, load(input[row1 + iw_1]).cast::<f32>(), 0.0f32);
        let d12 = select(row_ok_1 & col_ok_2, load(input[row1 + iw_2]).cast::<f32>(), 0.0f32);
        let d13 = select(row_ok_1 & col_ok_3, load(input[row1 + iw_3]).cast::<f32>(), 0.0f32);
        let d20 = select(row_ok_2 & col_ok_0, load(input[row2 + iw_0]).cast::<f32>(), 0.0f32);
        let d21 = select(row_ok_2 & col_ok_1, load(input[row2 + iw_1]).cast::<f32>(), 0.0f32);
        let d22 = select(row_ok_2 & col_ok_2, load(input[row2 + iw_2]).cast::<f32>(), 0.0f32);
        let d23 = select(row_ok_2 & col_ok_3, load(input[row2 + iw_3]).cast::<f32>(), 0.0f32);
        let d30 = select(row_ok_3 & col_ok_0, load(input[row3 + iw_0]).cast::<f32>(), 0.0f32);
        let d31 = select(row_ok_3 & col_ok_1, load(input[row3 + iw_1]).cast::<f32>(), 0.0f32);
        let d32 = select(row_ok_3 & col_ok_2, load(input[row3 + iw_2]).cast::<f32>(), 0.0f32);
        let d33 = select(row_ok_3 & col_ok_3, load(input[row3 + iw_3]).cast::<f32>(), 0.0f32);

        // Input transform V = B^T d B with
        // B^T = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, 1, 0, -1]].
        // First B^T d (combine rows) ...
        let t00 = d00 - d20;
        let t01 = d01 - d21;
        let t02 = d02 - d22;
        let t03 = d03 - d23;
        let t10 = d10 + d20;
        let t11 = d11 + d21;
        let t12 = d12 + d22;
        let t13 = d13 + d23;
        let t20 = d20 - d10;
        let t21 = d21 - d11;
        let t22 = d22 - d12;
        let t23 = d23 - d13;
        let t30 = d10 - d30;
        let t31 = d11 - d31;
        let t32 = d12 - d32;
        let t33 = d13 - d33;
        // ... then (B^T d) B (combine columns the same way).
        let v00 = t00 - t02;
        let v01 = t01 + t02;
        let v02 = t02 - t01;
        let v03 = t01 - t03;
        let v10 = t10 - t12;
        let v11 = t11 + t12;
        let v12 = t12 - t11;
        let v13 = t11 - t13;
        let v20 = t20 - t22;
        let v21 = t21 + t22;
        let v22 = t22 - t21;
        let v23 = t21 - t23;
        let v30 = t30 - t32;
        let v31 = t31 + t32;
        let v32 = t32 - t31;
        let v33 = t31 - t33;

        // Load the 3x3 filter g (row-major) for (oc, ic).
        let w_base = w_oc_base + ic * 9u32;
        let g00 = load(weight[w_base + 0u32]).cast::<f32>();
        let g01 = load(weight[w_base + 1u32]).cast::<f32>();
        let g02 = load(weight[w_base + 2u32]).cast::<f32>();
        let g10 = load(weight[w_base + 3u32]).cast::<f32>();
        let g11 = load(weight[w_base + 4u32]).cast::<f32>();
        let g12 = load(weight[w_base + 5u32]).cast::<f32>();
        let g20 = load(weight[w_base + 6u32]).cast::<f32>();
        let g21 = load(weight[w_base + 7u32]).cast::<f32>();
        let g22 = load(weight[w_base + 8u32]).cast::<f32>();

        // Filter transform U = G g G^T with
        // G = [[1, 0, 0], [1/2, 1/2, 1/2], [1/2, -1/2, 1/2], [0, 0, 1]].
        // First G g (4x3) ...
        let s00 = g00;
        let s01 = g01;
        let s02 = g02;
        let s10 = 0.5f32 * (g00 + g10 + g20);
        let s11 = 0.5f32 * (g01 + g11 + g21);
        let s12 = 0.5f32 * (g02 + g12 + g22);
        let s20 = 0.5f32 * (g00 - g10 + g20);
        let s21 = 0.5f32 * (g01 - g11 + g21);
        let s22 = 0.5f32 * (g02 - g12 + g22);
        let s30 = g20;
        let s31 = g21;
        let s32 = g22;
        // ... then (G g) G^T (4x4).
        let u00 = s00;
        let u01 = 0.5f32 * (s00 + s01 + s02);
        let u02 = 0.5f32 * (s00 - s01 + s02);
        let u03 = s02;
        let u10 = s10;
        let u11 = 0.5f32 * (s10 + s11 + s12);
        let u12 = 0.5f32 * (s10 - s11 + s12);
        let u13 = s12;
        let u20 = s20;
        let u21 = 0.5f32 * (s20 + s21 + s22);
        let u22 = 0.5f32 * (s20 - s21 + s22);
        let u23 = s22;
        let u30 = s30;
        let u31 = 0.5f32 * (s30 + s31 + s32);
        // Trailing underscore avoids shadowing the `u32` primitive type name.
        let u32_ = 0.5f32 * (s30 - s31 + s32);
        let u33 = s32;

        // Elementwise product, accumulated over input channels.
        m00 = m00 + u00 * v00;
        m01 = m01 + u01 * v01;
        m02 = m02 + u02 * v02;
        m03 = m03 + u03 * v03;
        m10 = m10 + u10 * v10;
        m11 = m11 + u11 * v11;
        m12 = m12 + u12 * v12;
        m13 = m13 + u13 * v13;
        m20 = m20 + u20 * v20;
        m21 = m21 + u21 * v21;
        m22 = m22 + u22 * v22;
        m23 = m23 + u23 * v23;
        m30 = m30 + u30 * v30;
        m31 = m31 + u31 * v31;
        m32 = m32 + u32_ * v32;
        m33 = m33 + u33 * v33;
    }

    // Output transform Y = A^T M A with A^T = [[1, 1, 1, 0], [0, 1, -1, -1]].
    let p00 = m00 + m10 + m20;
    let p01 = m01 + m11 + m21;
    let p02 = m02 + m12 + m22;
    let p03 = m03 + m13 + m23;
    let p10 = m10 - m20 - m30;
    let p11 = m11 - m21 - m31;
    let p12 = m12 - m22 - m32;
    let p13 = m13 - m23 - m33;
    let bias_v = load(bias[oc]).cast::<f32>();
    let y00 = p00 + p01 + p02 + bias_v;
    let y01 = p01 - p02 - p03 + bias_v;
    let y10 = p10 + p11 + p12 + bias_v;
    let y11 = p11 - p12 - p13 + bias_v;

    let out_plane = out_h * out_w;
    let out_oc_base = (n * out_ch + oc) * out_plane;
    let oh0 = th * 2u32;
    let ow0 = tw * 2u32;
    let out_row0 = out_oc_base + oh0 * out_w;
    let out_row1 = out_oc_base + (oh0 + 1u32) * out_w;
    // With an odd out_w / out_h the last tile column/row only half-overlaps the output.
    // An unguarded second column would land on the first pixel of the next row, racing
    // with the tile that owns it. An unguarded second row would land in the next plane,
    // or past the end of `out`. Values at the in-range positions are already exact,
    // because the patch's out-of-range rows/columns read as zero.
    store(out[out_row0 + ow0], y00.cast::<T>());
    if ow0 + 1u32 < out_w {
        store(out[out_row0 + ow0 + 1u32], y01.cast::<T>());
    }
    if oh0 + 1u32 < out_h {
        store(out[out_row1 + ow0], y10.cast::<T>());
        if ow0 + 1u32 < out_w {
            store(out[out_row1 + ow0 + 1u32], y11.cast::<T>());
        }
    }
}
// Precomputes the Winograd F(2x2, 3x3) filter transform U = G g G^T for every 3x3
// filter, so a convolution kernel can load the 4x4 U directly instead of rebuilding it.
//
//     G = [[1, 0, 0], [1/2, 1/2, 1/2], [1/2, -1/2, 1/2], [0, 0, 1]]
//
// Filter `f` (0 <= f < out_ch * in_ch) is read from weight[9 * f .. 9 * f + 9] in
// row-major order. Its transform is written to out[16 * f .. 16 * f + 16], also
// row-major. NOTE(review): presumably f = oc * in_ch + ic, which matches the
// `oc * in_ch * 16 + ic * 16` reads in winograd_conv2d_3x3_split; confirm with the
// launcher.
//
// One program per filter. Program ids at or beyond out_ch * in_ch do nothing, so the
// grid may be rounded up. Arithmetic is done in f32 and cast back to T on store.
#[bench_kernel(
    op="conv2d",
    subop="winograd_filter_transform_3x3",
    class=GenericEmpty,
    tol=1e-3,
    kernel_mode=Grid3D,
)]
#[kernel]
pub fn winograd_filter_transform_3x3<T>(
    weight: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] in_ch: u32,
    #[constexpr] out_ch: u32,
) {
    let filter = program_id::<0>();
    let filter_count = out_ch * in_ch;
    if filter < filter_count {
        // The nine taps of this filter, row-major.
        let src = filter * 9u32;
        let k00 = load(weight[src + 0u32]).cast::<f32>();
        let k01 = load(weight[src + 1u32]).cast::<f32>();
        let k02 = load(weight[src + 2u32]).cast::<f32>();
        let k10 = load(weight[src + 3u32]).cast::<f32>();
        let k11 = load(weight[src + 4u32]).cast::<f32>();
        let k12 = load(weight[src + 5u32]).cast::<f32>();
        let k20 = load(weight[src + 6u32]).cast::<f32>();
        let k21 = load(weight[src + 7u32]).cast::<f32>();
        let k22 = load(weight[src + 8u32]).cast::<f32>();

        // Rows 1 and 2 of G g. Rows 0 and 3 are just filter rows 0 and 2, so the
        // taps k0* and k2* are used for them directly.
        let plus0 = 0.5f32 * (k00 + k10 + k20);
        let plus1 = 0.5f32 * (k01 + k11 + k21);
        let plus2 = 0.5f32 * (k02 + k12 + k22);
        let minus0 = 0.5f32 * (k00 - k10 + k20);
        let minus1 = 0.5f32 * (k01 - k11 + k21);
        let minus2 = 0.5f32 * (k02 - k12 + k22);

        // Right-multiplying a row [x0, x1, x2] of G g by G^T gives
        // [x0, (x0 + x1 + x2) / 2, (x0 - x1 + x2) / 2, x2].
        let dst = filter * 16u32;

        // U row 0 <- G g row 0 (= filter row 0).
        store(out[dst + 0u32], k00.cast::<T>());
        store(out[dst + 1u32], (0.5f32 * (k00 + k01 + k02)).cast::<T>());
        store(out[dst + 2u32], (0.5f32 * (k00 - k01 + k02)).cast::<T>());
        store(out[dst + 3u32], k02.cast::<T>());

        // U row 1 <- G g row 1.
        store(out[dst + 4u32], plus0.cast::<T>());
        store(out[dst + 5u32], (0.5f32 * (plus0 + plus1 + plus2)).cast::<T>());
        store(out[dst + 6u32], (0.5f32 * (plus0 - plus1 + plus2)).cast::<T>());
        store(out[dst + 7u32], plus2.cast::<T>());

        // U row 2 <- G g row 2.
        store(out[dst + 8u32], minus0.cast::<T>());
        store(out[dst + 9u32], (0.5f32 * (minus0 + minus1 + minus2)).cast::<T>());
        store(out[dst + 10u32], (0.5f32 * (minus0 - minus1 + minus2)).cast::<T>());
        store(out[dst + 11u32], minus2.cast::<T>());

        // U row 3 <- G g row 3 (= filter row 2).
        store(out[dst + 12u32], k20.cast::<T>());
        store(out[dst + 13u32], (0.5f32 * (k20 + k21 + k22)).cast::<T>());
        store(out[dst + 14u32], (0.5f32 * (k20 - k21 + k22)).cast::<T>());
        store(out[dst + 15u32], k22.cast::<T>());
    }
}
// Winograd F(2x2, 3x3) 2-D convolution that consumes pre-transformed filters.
//
// Same tiling and math as the fused winograd_conv2d_3x3 kernel, except the 4x4
// filter transform U = G g G^T is read from `u` instead of being recomputed per
// program (winograd_filter_transform_3x3 produces it in this layout). Each program
// computes one 2x2 output tile:
//     Y = A^T [ sum_ic U_ic .* (B^T d_ic B) ] A + bias[oc]
// where `.*` is the elementwise product and d_ic is the 4x4 zero-padded input patch
// whose top-left padded coordinate is (2 * th, 2 * tw). That implies stride 1 and
// dilation 1.
//
// Layouts implied by the index arithmetic below (all contiguous):
//     input : [N, in_ch, in_h, in_w]
//     u     : [out_ch, in_ch, 4, 4]
//     bias  : [out_ch]
//     out   : [N, out_ch, out_h, out_w]
//
// Program id layout: idx = ((n * out_ch + oc) * tiles_h + th) * tiles_w + tw.
// NOTE(review): presumably tiles_h = ceil(out_h / 2) and tiles_w = ceil(out_w / 2);
// confirm against the launcher. Partial tiles on an odd bottom/right edge are handled
// by the guarded stores at the end. There is no batch-size argument, so the launcher
// must not dispatch programs past the last batch image.
//
// All arithmetic is done in f32 regardless of T.
#[bench_kernel(
    op="conv2d",
    subop="winograd_3x3_split",
    class=GenericEmpty,
    tol=1e-3,
    kernel_mode=Grid3D,
)]
#[kernel]
// Every convolution dimension is passed as its own constexpr argument, hence the
// long signature.
#[allow(clippy::too_many_arguments)]
pub fn winograd_conv2d_3x3_split<T>(
    input: Tensor<T>,
    u: Tensor<T>,
    bias: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] in_ch: u32,
    #[constexpr] in_h: u32,
    #[constexpr] in_w: u32,
    #[constexpr] out_ch: u32,
    #[constexpr] out_h: u32,
    #[constexpr] out_w: u32,
    #[constexpr] pad_h: u32,
    #[constexpr] pad_w: u32,
    #[constexpr] tiles_h: u32,
    #[constexpr] tiles_w: u32,
) {
    // Decompose the flat program id into (n, oc, th, tw), tw varying fastest.
    let idx = program_id::<0>();
    let tw = idx % tiles_w;
    let r1 = idx / tiles_w;
    let th = r1 % tiles_h;
    let r2 = r1 / tiles_h;
    let oc = r2 % out_ch;
    let n = r2 / out_ch;

    // Padded-input rows/columns covered by the 4x4 patch. A row/column is "ok" when it
    // falls inside the unpadded input. Otherwise it reads as zero padding, and its
    // unpadded index is forced to 0 so the unconditionally issued loads below always
    // address a valid element.
    let pr0 = th * 2u32;
    let pc0 = tw * 2u32;
    let pr_0 = pr0;
    let pr_1 = pr0 + 1u32;
    let pr_2 = pr0 + 2u32;
    let pr_3 = pr0 + 3u32;
    let row_ok_0 = (pr_0 >= pad_h) & (pr_0 < pad_h + in_h);
    let row_ok_1 = (pr_1 >= pad_h) & (pr_1 < pad_h + in_h);
    let row_ok_2 = (pr_2 >= pad_h) & (pr_2 < pad_h + in_h);
    let row_ok_3 = (pr_3 >= pad_h) & (pr_3 < pad_h + in_h);
    let ih_0 = select(row_ok_0, pr_0 - pad_h, 0u32);
    let ih_1 = select(row_ok_1, pr_1 - pad_h, 0u32);
    let ih_2 = select(row_ok_2, pr_2 - pad_h, 0u32);
    let ih_3 = select(row_ok_3, pr_3 - pad_h, 0u32);
    let pc_0 = pc0;
    let pc_1 = pc0 + 1u32;
    let pc_2 = pc0 + 2u32;
    let pc_3 = pc0 + 3u32;
    let col_ok_0 = (pc_0 >= pad_w) & (pc_0 < pad_w + in_w);
    let col_ok_1 = (pc_1 >= pad_w) & (pc_1 < pad_w + in_w);
    let col_ok_2 = (pc_2 >= pad_w) & (pc_2 < pad_w + in_w);
    let col_ok_3 = (pc_3 >= pad_w) & (pc_3 < pad_w + in_w);
    let iw_0 = select(col_ok_0, pc_0 - pad_w, 0u32);
    let iw_1 = select(col_ok_1, pc_1 - pad_w, 0u32);
    let iw_2 = select(col_ok_2, pc_2 - pad_w, 0u32);
    let iw_3 = select(col_ok_3, pc_3 - pad_w, 0u32);

    let input_plane = in_h * in_w;
    let in_n_stride = in_ch * input_plane;
    let n_base = n * in_n_stride;
    let u_oc_base = oc * in_ch * 16u32;

    // Winograd-domain accumulators M = sum_ic U_ic .* V_ic (4x4).
    let mut m00 = 0.0f32;
    let mut m01 = 0.0f32;
    let mut m02 = 0.0f32;
    let mut m03 = 0.0f32;
    let mut m10 = 0.0f32;
    let mut m11 = 0.0f32;
    let mut m12 = 0.0f32;
    let mut m13 = 0.0f32;
    let mut m20 = 0.0f32;
    let mut m21 = 0.0f32;
    let mut m22 = 0.0f32;
    let mut m23 = 0.0f32;
    let mut m30 = 0.0f32;
    let mut m31 = 0.0f32;
    let mut m32 = 0.0f32;
    let mut m33 = 0.0f32;
    for ic in range(0u32, in_ch, 1u32) {
        // Gather the zero-padded 4x4 input patch d for this channel.
        let in_ic_base = n_base + ic * input_plane;
        let row0 = in_ic_base + ih_0 * in_w;
        let row1 = in_ic_base + ih_1 * in_w;
        let row2 = in_ic_base + ih_2 * in_w;
        let row3 = in_ic_base + ih_3 * in_w;
        let d00 = select(row_ok_0 & col_ok_0, load(input[row0 + iw_0]).cast::<f32>(), 0.0f32);
        let d01 = select(row_ok_0 & col_ok_1, load(input[row0 + iw_1]).cast::<f32>(), 0.0f32);
        let d02 = select(row_ok_0 & col_ok_2, load(input[row0 + iw_2]).cast::<f32>(), 0.0f32);
        let d03 = select(row_ok_0 & col_ok_3, load(input[row0 + iw_3]).cast::<f32>(), 0.0f32);
        let d10 = select(row_ok_1 & col_ok_0, load(input[row1 + iw_0]).cast::<f32>(), 0.0f32);
        let d11 = select(row_ok_1 & col_ok_1, load(input[row1 + iw_1]).cast::<f32>(), 0.0f32);
        let d12 = select(row_ok_1 & col_ok_2, load(input[row1 + iw_2]).cast::<f32>(), 0.0f32);
        let d13 = select(row_ok_1 & col_ok_3, load(input[row1 + iw_3]).cast::<f32>(), 0.0f32);
        let d20 = select(row_ok_2 & col_ok_0, load(input[row2 + iw_0]).cast::<f32>(), 0.0f32);
        let d21 = select(row_ok_2 & col_ok_1, load(input[row2 + iw_1]).cast::<f32>(), 0.0f32);
        let d22 = select(row_ok_2 & col_ok_2, load(input[row2 + iw_2]).cast::<f32>(), 0.0f32);
        let d23 = select(row_ok_2 & col_ok_3, load(input[row2 + iw_3]).cast::<f32>(), 0.0f32);
        let d30 = select(row_ok_3 & col_ok_0, load(input[row3 + iw_0]).cast::<f32>(), 0.0f32);
        let d31 = select(row_ok_3 & col_ok_1, load(input[row3 + iw_1]).cast::<f32>(), 0.0f32);
        let d32 = select(row_ok_3 & col_ok_2, load(input[row3 + iw_2]).cast::<f32>(), 0.0f32);
        let d33 = select(row_ok_3 & col_ok_3, load(input[row3 + iw_3]).cast::<f32>(), 0.0f32);

        // Input transform V = B^T d B with
        // B^T = [[1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0], [0, 1, 0, -1]].
        // First B^T d (combine rows) ...
        let t00 = d00 - d20;
        let t01 = d01 - d21;
        let t02 = d02 - d22;
        let t03 = d03 - d23;
        let t10 = d10 + d20;
        let t11 = d11 + d21;
        let t12 = d12 + d22;
        let t13 = d13 + d23;
        let t20 = d20 - d10;
        let t21 = d21 - d11;
        let t22 = d22 - d12;
        let t23 = d23 - d13;
        let t30 = d10 - d30;
        let t31 = d11 - d31;
        let t32 = d12 - d32;
        let t33 = d13 - d33;
        // ... then (B^T d) B (combine columns the same way).
        let v00 = t00 - t02;
        let v01 = t01 + t02;
        let v02 = t02 - t01;
        let v03 = t01 - t03;
        let v10 = t10 - t12;
        let v11 = t11 + t12;
        let v12 = t12 - t11;
        let v13 = t11 - t13;
        let v20 = t20 - t22;
        let v21 = t21 + t22;
        let v22 = t22 - t21;
        let v23 = t21 - t23;
        let v30 = t30 - t32;
        let v31 = t31 + t32;
        let v32 = t32 - t31;
        let v33 = t31 - t33;

        // Pre-transformed 4x4 filter U for (oc, ic), row-major.
        let u_base = u_oc_base + ic * 16u32;
        let u00 = load(u[u_base + 0u32]).cast::<f32>();
        let u01 = load(u[u_base + 1u32]).cast::<f32>();
        let u02 = load(u[u_base + 2u32]).cast::<f32>();
        let u03 = load(u[u_base + 3u32]).cast::<f32>();
        let u10 = load(u[u_base + 4u32]).cast::<f32>();
        let u11 = load(u[u_base + 5u32]).cast::<f32>();
        let u12 = load(u[u_base + 6u32]).cast::<f32>();
        let u13 = load(u[u_base + 7u32]).cast::<f32>();
        let u20 = load(u[u_base + 8u32]).cast::<f32>();
        let u21 = load(u[u_base + 9u32]).cast::<f32>();
        let u22 = load(u[u_base + 10u32]).cast::<f32>();
        let u23 = load(u[u_base + 11u32]).cast::<f32>();
        let u30 = load(u[u_base + 12u32]).cast::<f32>();
        let u31 = load(u[u_base + 13u32]).cast::<f32>();
        // Trailing underscore avoids shadowing the `u32` primitive type name.
        let u32_ = load(u[u_base + 14u32]).cast::<f32>();
        let u33 = load(u[u_base + 15u32]).cast::<f32>();

        // Elementwise product, accumulated over input channels.
        m00 = m00 + u00 * v00;
        m01 = m01 + u01 * v01;
        m02 = m02 + u02 * v02;
        m03 = m03 + u03 * v03;
        m10 = m10 + u10 * v10;
        m11 = m11 + u11 * v11;
        m12 = m12 + u12 * v12;
        m13 = m13 + u13 * v13;
        m20 = m20 + u20 * v20;
        m21 = m21 + u21 * v21;
        m22 = m22 + u22 * v22;
        m23 = m23 + u23 * v23;
        m30 = m30 + u30 * v30;
        m31 = m31 + u31 * v31;
        m32 = m32 + u32_ * v32;
        m33 = m33 + u33 * v33;
    }

    // Output transform Y = A^T M A with A^T = [[1, 1, 1, 0], [0, 1, -1, -1]].
    let p00 = m00 + m10 + m20;
    let p01 = m01 + m11 + m21;
    let p02 = m02 + m12 + m22;
    let p03 = m03 + m13 + m23;
    let p10 = m10 - m20 - m30;
    let p11 = m11 - m21 - m31;
    let p12 = m12 - m22 - m32;
    let p13 = m13 - m23 - m33;
    let bias_v = load(bias[oc]).cast::<f32>();
    let y00 = p00 + p01 + p02 + bias_v;
    let y01 = p01 - p02 - p03 + bias_v;
    let y10 = p10 + p11 + p12 + bias_v;
    let y11 = p11 - p12 - p13 + bias_v;

    let out_plane = out_h * out_w;
    let out_oc_base = (n * out_ch + oc) * out_plane;
    let oh0 = th * 2u32;
    let ow0 = tw * 2u32;
    let out_row0 = out_oc_base + oh0 * out_w;
    let out_row1 = out_oc_base + (oh0 + 1u32) * out_w;
    // With an odd out_w / out_h the last tile column/row only half-overlaps the output.
    // An unguarded second column would land on the first pixel of the next row, racing
    // with the tile that owns it. An unguarded second row would land in the next plane,
    // or past the end of `out`. Values at the in-range positions are already exact,
    // because the patch's out-of-range rows/columns read as zero.
    store(out[out_row0 + ow0], y00.cast::<T>());
    if ow0 + 1u32 < out_w {
        store(out[out_row0 + ow0 + 1u32], y01.cast::<T>());
    }
    if oh0 + 1u32 < out_h {
        store(out[out_row1 + ow0], y10.cast::<T>());
        if ow0 + 1u32 < out_w {
            store(out[out_row1 + ow0 + 1u32], y11.cast::<T>());
        }
    }
}