

Function conv2d_spirv 

pub fn conv2d_spirv(
    n: u32,
    c_in: u32,
    h_in: u32,
    w_in: u32,
    k_out: u32,
    fh: u32,
    fw: u32,
    oh: u32,
    ow: u32,
    stride_h: u32,
    stride_w: u32,
    pad_h: u32,
    pad_w: u32,
) -> Vec<u32>
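The signature takes the output extents oh and ow from the caller rather than deriving them. The sketch below shows one way a caller might compute them. It assumes the common formula for symmetric zero padding with no dilation, which this page does not state; `conv_out_dim` is a hypothetical helper, not part of this crate.

```rust
/// Output extent along one spatial axis, ASSUMING symmetric zero padding
/// and no dilation: (in + 2*pad - filter) / stride + 1.
/// Returns None if the stride is zero, the filter is empty, the filter is
/// larger than the padded input, or the arithmetic would overflow.
fn conv_out_dim(in_dim: u32, filter: u32, stride: u32, pad: u32) -> Option<u32> {
    if stride == 0 || filter == 0 {
        return None;
    }
    let padded = in_dim.checked_add(pad.checked_mul(2)?)?;
    if filter > padded {
        return None;
    }
    Some((padded - filter) / stride + 1)
}

fn main() {
    // 32x32 input, 3x3 filter, stride 1, pad 1: output stays 32x32.
    assert_eq!(conv_out_dim(32, 3, 1, 1), Some(32));
    // Stride 2: (32 + 2 - 3) / 2 + 1 = 16.
    assert_eq!(conv_out_dim(32, 3, 2, 1), Some(16));
}
```

Returning `Option` lets the caller reject impossible shapes before generating a kernel.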

Generate an OpenCL SPIR-V compute kernel for 2-D convolution (NCHW layout), returned as a stream of 32-bit SPIR-V words.
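Assuming the returned `Vec<u32>` is the raw module word stream, a caller can sanity-check it and flatten it to bytes before loading it. Level Zero's `zeModuleCreate` takes the module through `ze_module_desc_t`, whose `inputSize` is in bytes and whose `pInputModule` is a byte pointer. The helpers `looks_like_spirv` and `spirv_to_bytes` below are hypothetical and only check the standard five-word SPIR-V header and magic number.

```rust
/// Magic number that begins every SPIR-V module.
const SPIRV_MAGIC: u32 = 0x0723_0203;

/// True if `words` holds at least the five-word SPIR-V header
/// (magic, version, generator, bound, schema) and starts with the magic.
fn looks_like_spirv(words: &[u32]) -> bool {
    words.len() >= 5 && words[0] == SPIRV_MAGIC
}

/// Flattens the word stream into bytes in host byte order, suitable for
/// a byte-oriented loader such as zeModuleCreate.
fn spirv_to_bytes(words: &[u32]) -> Vec<u8> {
    words.iter().flat_map(|w| w.to_ne_bytes()).collect()
}

fn main() {
    // Stand-in header only; real code would use the result of conv2d_spirv.
    let module = vec![SPIRV_MAGIC, 0x0001_0000, 0, 1, 0];
    assert!(looks_like_spirv(&module));
    let bytes = spirv_to_bytes(&module);
    assert_eq!(bytes.len(), module.len() * 4); // inputSize for the loader
}
```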

Each work-item computes one output element.
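A CPU reference for what a single work-item computes can help when validating the kernel's output. This is a sketch under assumptions this page does not confirm: the filter is laid out as `[k_out][c_in][fh][fw]`, the operation is cross-correlation (no filter flip) with zero padding, no bias and no dilation, and the flat work-item id maps to `(batch, k, oy, ox)` with `ox` varying fastest. The `Conv2d` struct and `work_item` method are hypothetical.

```rust
/// Hypothetical bundle of the conv2d_spirv parameters.
struct Conv2d {
    n: u32, c_in: u32, h_in: u32, w_in: u32,
    k_out: u32, fh: u32, fw: u32, oh: u32, ow: u32,
    stride_h: u32, stride_w: u32, pad_h: u32, pad_w: u32,
}

impl Conv2d {
    /// Value for flat id `gid`, ASSUMING gid = ((b*k_out + k)*oh + oy)*ow + ox.
    fn work_item(&self, gid: usize, input: &[f32], filter: &[f32]) -> f32 {
        let (ow, oh, k_out) = (self.ow as usize, self.oh as usize, self.k_out as usize);
        let ox = gid % ow;
        let oy = (gid / ow) % oh;
        let k = (gid / (ow * oh)) % k_out;
        let b = gid / (ow * oh * k_out);
        debug_assert!(b < self.n as usize, "work-item id out of range");
        let (c_in, h_in, w_in) = (self.c_in as usize, self.h_in as usize, self.w_in as usize);
        let (fh, fw) = (self.fh as usize, self.fw as usize);
        let mut acc = 0.0f32;
        for c in 0..c_in {
            for ky in 0..fh {
                for kx in 0..fw {
                    // Input coordinate; negative or too large means padding.
                    let iy = (oy * self.stride_h as usize + ky) as isize - self.pad_h as isize;
                    let ix = (ox * self.stride_w as usize + kx) as isize - self.pad_w as isize;
                    if iy < 0 || ix < 0 || iy >= h_in as isize || ix >= w_in as isize {
                        continue; // zero padding contributes nothing
                    }
                    let in_idx = ((b * c_in + c) * h_in + iy as usize) * w_in + ix as usize;
                    let f_idx = ((k * c_in + c) * fh + ky) * fw + kx;
                    acc += input[in_idx] * filter[f_idx];
                }
            }
        }
        acc
    }
}

fn main() {
    // 1x1x3x3 input holding 1..=9, one 3x3 all-ones filter, stride 1, pad 1.
    let p = Conv2d {
        n: 1, c_in: 1, h_in: 3, w_in: 3, k_out: 1, fh: 3, fw: 3, oh: 3, ow: 3,
        stride_h: 1, stride_w: 1, pad_h: 1, pad_w: 1,
    };
    let input: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let filter = vec![1.0f32; 9];
    let out: Vec<f32> = (0..9).map(|gid| p.work_item(gid, &input, &filter)).collect();
    assert_eq!(out[0], 12.0); // top-left: 1 + 2 + 4 + 5
    assert_eq!(out[4], 45.0); // centre: sum of 1..=9
}
```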

Kernel arguments, in index order (each set with zeKernelSetArgumentValue):

(CrossWorkgroup float* input,
 CrossWorkgroup float* filter,
 CrossWorkgroup float* output)
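Before setting these arguments, the caller has to allocate three device buffers. The sketch below computes their element counts and byte sizes. It assumes dense NCHW input and output and a dense `[k_out][c_in][fh][fw]` filter, which this page implies but does not spell out; `buffer_lens` and `bytes_of` are hypothetical helpers. Because each work-item computes one output element, the output count is also the total number of work-items. Note that in Level Zero a pointer argument is set by passing the size of the pointer and the address of the device pointer, not the buffer size.

```rust
/// Element counts (f32) of (input, filter, output), ASSUMING dense NCHW
/// tensors and a [k_out][c_in][fh][fw] filter.
fn buffer_lens(
    n: u32, c_in: u32, h_in: u32, w_in: u32,
    k_out: u32, fh: u32, fw: u32, oh: u32, ow: u32,
) -> (usize, usize, usize) {
    let u = |x: u32| x as usize;
    let input = u(n) * u(c_in) * u(h_in) * u(w_in);
    let filter = u(k_out) * u(c_in) * u(fh) * u(fw);
    let output = u(n) * u(k_out) * u(oh) * u(ow); // also the work-item count
    (input, filter, output)
}

/// Allocation size in bytes for `len` f32 elements.
fn bytes_of(len: usize) -> usize {
    len * std::mem::size_of::<f32>()
}

fn main() {
    // 1x3x32x32 input, eight 3x3 filters, 32x32 output.
    let (i, f, o) = buffer_lens(1, 3, 32, 32, 8, 3, 3, 32, 32);
    assert_eq!((i, f, o), (3072, 216, 8192));
    assert_eq!(bytes_of(f), 864);
}
```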

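All dimension, stride, and padding values are baked into the module as OpConstant, so each distinct shape requires generating (and loading) a separate module.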
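Since every shape gets its own module, a caller that runs many convolutions may want to cache generated modules by their parameters. This is a minimal sketch, not part of this crate: `ModuleCache` and `ConvKey` are hypothetical, and the generator is passed in as a closure so the example is self-contained. In real code the closure would call `conv2d_spirv(k[0], k[1], …, k[12])` with the thirteen parameters in signature order.

```rust
use std::cell::Cell;
use std::collections::HashMap;

/// The 13 conv2d_spirv parameters, in signature order.
type ConvKey = [u32; 13];

/// Hypothetical cache: generates each shape's module once, then reuses it.
struct ModuleCache<F: Fn(&ConvKey) -> Vec<u32>> {
    generate: F,
    modules: HashMap<ConvKey, Vec<u32>>,
}

impl<F: Fn(&ConvKey) -> Vec<u32>> ModuleCache<F> {
    fn new(generate: F) -> Self {
        Self { generate, modules: HashMap::new() }
    }

    /// Returns the module for `key`, generating it on first request.
    fn get(&mut self, key: ConvKey) -> &[u32] {
        let generate = &self.generate;
        self.modules.entry(key).or_insert_with(|| generate(&key))
    }
}

fn main() {
    let calls = Cell::new(0);
    // Stand-in generator that counts invocations.
    let mut cache = ModuleCache::new(|k: &ConvKey| {
        calls.set(calls.get() + 1);
        k.to_vec()
    });
    let shape: ConvKey = [1, 3, 32, 32, 8, 3, 3, 32, 32, 1, 1, 1, 1];
    cache.get(shape);
    cache.get(shape);
    assert_eq!(calls.get(), 1); // second lookup hit the cache
}
```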