megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/cuda/batched_matrix_mul/int8x8x32.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "src/cuda/utils.cuh"

namespace megdnn {
namespace cuda {

template <typename ThreadConfig_, int m_, int n_, int k_tot, int k_>
struct UnrollConfig {
    typedef ThreadConfig_ ThreadConfig;
    static int const unroll_m = m_;
    static int const unroll_n = n_;
    static int const block_m = ThreadConfig::thread_x * m_;
    static int const block_n = ThreadConfig::thread_y * n_;
    static int const unroll_k = k_tot;
    static int const unroll = k_;
    static int const thread_k = k_tot / k_;
    static int const load_m = (m_ / 4) / (ThreadConfig::thread_y / thread_k) * 4;
    static int const load_n = (n_ / 4) / (ThreadConfig::thread_x / thread_k) * 4;
};

template <int x_, int y_>
struct ThreadConfig {
    static int const thread_x = x_;
    static int const thread_y = y_;
};

template <int row, int col>
struct SmemConfig {
    static int const smem_row = row;
    static int const smem_col = col;
};

template <typename SmemConfig_>
struct Global2SharedMem {
    typedef SmemConfig_ SmemConfig;
    const int8_t* g_ptr;
    int32_t* smem;
    int smem_off;
    int smem_bound;
    int32_t reg[SmemConfig::smem_col][SmemConfig::smem_row / 4];
    int ld_src;
    int ld_dst;
    int check_bound_row;
    int check_bound_col;
    int step;
    bool tr;
    bool aligned;

    __device__ __forceinline__ Global2SharedMem(
            int32_t* smem_, int s_off, int s_bound, int ld_src_, int ld_dst_, int b_r_,
            int b_c_, int step_, bool tr_, bool al_)
            : smem(smem_),
              smem_off(s_off),
              smem_bound(s_bound),
              ld_src(ld_src_),
              ld_dst(ld_dst_),
              check_bound_row(b_r_),
              check_bound_col(b_c_),
              step(step_),
              tr(tr_),
              aligned(al_) {}

    __device__ __forceinline__ void gmem2reg_cpy();
    __device__ __forceinline__ void reg2smem_cpy();
    __device__ __forceinline__ void iter_forward();
};

template <typename UnrollConfig, typename ThreadConfig>
__global__ void batched_8x8x32_kern(
        const int8_t* a, int lda, int sta, bool tra, const int8_t* b, int ldb, int stb,
        bool trb, int32_t* c, int ldc, int stc, int m, int n, int k);

void exec_igemm_8x8x32(
        const int8_t* A, const int8_t* B, int32_t* C, const int batch_count,
        const int m, const int n, const int k, int ldA, int ldB, int ldC, int stA,
        int stB, int stC, bool transA, bool transB, cudaStream_t stream);

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cuda.doxygen