pub fn conv2d_spirv(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> Vec<u32>
Generate an OpenCL SPIR-V compute kernel for 2-D convolution (NCHW layout) and return the module as a vector of 32-bit SPIR-V words.
Each work-item computes one output element.
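To make the per-work-item computation concrete, here is a CPU reference sketch of what one work-item would produce. It rests on assumptions this page does not state: the flat global id enumerates output elements in NCHW order, the filter is laid out as `[k_out][c_in][fh][fw]`, out-of-bounds (padded) input reads count as zero, and there is no bias or dilation. `ConvDims` and `conv2d_work_item` are hypothetical names, not part of this crate.

```rust
/// Hypothetical bundle of the dimensions conv2d_spirv bakes in as constants.
#[derive(Clone, Copy)]
struct ConvDims {
    n: u32, c_in: u32, h_in: u32, w_in: u32,
    k_out: u32, fh: u32, fw: u32, oh: u32, ow: u32,
    stride_h: u32, stride_w: u32, pad_h: u32, pad_w: u32,
}

/// Hypothetical CPU reference for a single work-item.
/// Assumes: NCHW output order for `gid`, filter layout [k_out][c_in][fh][fw],
/// zero padding, no bias, no dilation.
fn conv2d_work_item(d: &ConvDims, gid: usize, input: &[f32], filter: &[f32]) -> f32 {
    let (ow, oh, k_out) = (d.ow as usize, d.oh as usize, d.k_out as usize);
    // Decode the flat id into (batch, output channel, output row, output column).
    let ox = gid % ow;
    let oy = (gid / ow) % oh;
    let k = (gid / (ow * oh)) % k_out;
    let b = gid / (ow * oh * k_out);
    debug_assert!(b < d.n as usize, "gid past the last output element");

    let mut acc = 0.0f32;
    for c in 0..d.c_in as usize {
        for fy in 0..d.fh as usize {
            // Signed arithmetic so rows that fall in the padding become negative.
            let iy = (oy * d.stride_h as usize + fy) as i64 - d.pad_h as i64;
            if iy < 0 || iy >= d.h_in as i64 {
                continue; // padded row: contributes zero
            }
            for fx in 0..d.fw as usize {
                let ix = (ox * d.stride_w as usize + fx) as i64 - d.pad_w as i64;
                if ix < 0 || ix >= d.w_in as i64 {
                    continue; // padded column: contributes zero
                }
                let in_idx = ((b * d.c_in as usize + c) * d.h_in as usize + iy as usize)
                    * d.w_in as usize
                    + ix as usize;
                let f_idx = ((k * d.c_in as usize + c) * d.fh as usize + fy) * d.fw as usize + fx;
                acc += input[in_idx] * filter[f_idx];
            }
        }
    }
    acc
}

fn main() {
    // 1x1x3x3 input, one 2x2 all-ones filter, stride 1, no padding -> 2x2 output.
    let d = ConvDims {
        n: 1, c_in: 1, h_in: 3, w_in: 3, k_out: 1, fh: 2, fw: 2, oh: 2, ow: 2,
        stride_h: 1, stride_w: 1, pad_h: 0, pad_w: 0,
    };
    let input: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let filter = vec![1.0f32; 4];
    // One "work-item" per output element, as in the generated kernel.
    let total = (d.n * d.k_out * d.oh * d.ow) as usize;
    let out: Vec<f32> = (0..total).map(|g| conv2d_work_item(&d, g, &input, &filter)).collect();
    println!("{:?}", out); // prints [12.0, 16.0, 24.0, 28.0]
}
```

A reference like this can be handy for checking the GPU output element by element, provided its layout assumptions match the generated kernel.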
Kernel parameters (all passed via zeKernelSetArgumentValue):
    (CrossWorkgroup float* input,
     CrossWorkgroup float* filter,
     CrossWorkgroup float* output)

All dimension constants are baked into the module as OpConstant instructions, so the kernel takes no size arguments.
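Because the sizes are fixed at generation time, the host has to allocate buffers and launch work-items that agree with them. The sketch below computes those quantities. `LaunchSizes`, `out_extent`, and `launch_sizes` are hypothetical helpers, and the output-extent formula is the conventional one for zero-padded, undilated convolution; this page does not say that `conv2d_spirv` requires `oh`/`ow` to be computed that way. How the work-items are split into groups is also left out, since it is not specified here.

```rust
/// Hypothetical host-side sizes that must agree with the baked-in constants.
#[derive(Debug, PartialEq)]
struct LaunchSizes {
    input_bytes: usize,
    filter_bytes: usize,
    output_bytes: usize,
    work_items: usize, // one per output element
}

/// Conventional output extent: (in + 2*pad - filt) / stride + 1.
/// Returns None if the stride is zero or the filter does not fit.
fn out_extent(in_len: u32, filt: u32, stride: u32, pad: u32) -> Option<u32> {
    if stride == 0 {
        return None;
    }
    (in_len + 2 * pad).checked_sub(filt).map(|span| span / stride + 1)
}

/// Byte sizes of the three f32 buffers and the total work-item count (NCHW).
fn launch_sizes(
    n: u32, c_in: u32, h_in: u32, w_in: u32,
    k_out: u32, fh: u32, fw: u32, oh: u32, ow: u32,
) -> LaunchSizes {
    let f32_bytes = std::mem::size_of::<f32>();
    // Number of elements in a 4-D tensor.
    let count = |dims: [u32; 4]| dims.iter().map(|&d| d as usize).product::<usize>();
    LaunchSizes {
        input_bytes: count([n, c_in, h_in, w_in]) * f32_bytes,
        filter_bytes: count([k_out, c_in, fh, fw]) * f32_bytes,
        output_bytes: count([n, k_out, oh, ow]) * f32_bytes,
        work_items: count([n, k_out, oh, ow]),
    }
}

fn main() {
    // 1x3x32x32 input, 16 filters of 3x3, stride 1, padding 1 -> 32x32 output.
    let oh = out_extent(32, 3, 1, 1).expect("filter must fit");
    let ow = out_extent(32, 3, 1, 1).expect("filter must fit");
    let s = launch_sizes(1, 3, 32, 32, 16, 3, 3, oh, ow);
    // input 12288 B, filter 1728 B, output 65536 B, 16384 work-items
    println!("{:?}", s);
}
```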