#pragma once
#include <stdio.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_common.h"
namespace megdnn {
namespace matmul {
/*!
 * \brief GEMM driver built on top of a kernel \p Strategy.
 *
 * \tparam Strategy provides the element types and the compute kernel; the
 *      members each specialization requires are the ones it uses.
 * \tparam with_pack selects the specialization:
 *      - true (default): A and B are repacked into workspace panels with
 *        Strategy::pack_A / Strategy::pack_B before Strategy::kern is called
 *        on each tile.
 *      - false: the raw operands, leading dimensions and transpose flags are
 *        forwarded to Strategy::kern in a single call.
 */
template <typename Strategy, bool with_pack = true>
class GemmInterleaved;
/*!
 * \brief Blocked GEMM driver that packs A and B into workspace panels.
 *
 * The problem is tiled by Strategy::block_m / block_n / block_k. For every
 * tile the operands are repacked with Strategy::pack_A / Strategy::pack_B and
 * the packed panels are handed to Strategy::kern.
 *
 * \tparam Strategy supplies the element types (stype, pack_a_type, dst_type,
 *      compute_type), the blocking constants (block_m, block_n, block_k,
 *      KERNEL_H, KERNEL_W, UNROLL_K) and the pack_A / pack_B / kern /
 *      get_workspace_size routines used below.
 */
template <typename Strategy>
class GemmInterleaved<Strategy, true> {
    using compute_type = typename Strategy::compute_type;
    using stype = typename Strategy::stype;
    using pack_a_type = typename Strategy::pack_a_type;
    using dtype = typename Strategy::dst_type;

    const size_t m_M;
    const size_t m_N;
    const size_t m_K;
    const bool m_transpose_A;
    const bool m_transpose_B;
    Strategy m_strategy;
    static constexpr size_t CACHELINE_SIZE = 64;
    //! alignment (in bytes) applied to the workspace base pointer in execute()
    const size_t m_align_size = 16;

    //! Bytes reserved for one packed A tile, plus m_align_size bytes of slack
    //! for the pointer adjustment done in execute().
    size_t get_a_workspace_size() const {
        size_t M = round_up(m_strategy.block_m, m_strategy.KERNEL_H);
        size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K);
        return round_up(sizeof(pack_a_type) * M * K, CACHELINE_SIZE) + m_align_size;
    }

    //! Bytes reserved for one packed B tile, plus m_align_size bytes of slack
    //! for the pointer adjustment done in execute().
    size_t get_b_workspace_size() const {
        size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W);
        size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K);
        return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size;
    }

    /*!
     * \brief Bytes reserved for the scratch buffer requested by the strategy.
     *
     * Returns 0 when the strategy needs no scratch. Otherwise the size is
     * rounded up to a cache line and, like the A/B panels, extended by
     * m_align_size: execute() shifts c_panel by the same alignment padding as
     * the other panels, so without this slack the scratch area could end up
     * to m_align_size - 1 bytes past the reserved region.
     */
    size_t get_c_workspace_size() const {
        const size_t scratch = m_strategy.get_workspace_size();
        if (scratch == 0) {
            return 0;
        }
        return round_up(scratch, CACHELINE_SIZE) + m_align_size;
    }

    /*!
     * \brief Number of bytes to skip from \p ptr to reach the next multiple
     *      of m_align_size.
     *
     * Uses modulo rather than bit masking so the result is also correct when
     * the alignment is not a power of two; an alignment of 0 or 1 requests no
     * adjustment. The result is always smaller than m_align_size (or 0).
     */
    size_t get_align_padding(const void* ptr) const {
        if (m_align_size <= 1) {
            return 0;
        }
        const size_t misalign = reinterpret_cast<uintptr_t>(ptr) % m_align_size;
        return misalign == 0 ? 0 : m_align_size - misalign;
    }

    //! The pre-packed entry points (pack_A / pack_B / execute_naked) require
    //! the whole problem to fit into a single block in every dimension.
    void assert_single_level_blocking() const {
        megdnn_assert(
                m_M <= m_strategy.block_m && m_N <= m_strategy.block_n &&
                        m_K <= m_strategy.block_k,
                "currently we only support 1-level blocking");
    }

public:
    //! Total workspace, in bytes, that execute() expects to be given.
    size_t get_workspace_size() const { return get_bundle().total_size_in_bytes(); }

    //! Bundle describing the three workspace regions (A panel, B panel,
    //! strategy scratch); it is created without a backing pointer.
    WorkspaceBundle get_bundle() const {
        return {nullptr,
                {get_a_workspace_size(), get_b_workspace_size(),
                 get_c_workspace_size()}};
    }

    /*!
     * \param M, N, K problem size; the tile loops run over [0, M), [0, N)
     *      and [0, K)
     * \param trA, trB transpose flags forwarded to Strategy::pack_A / pack_B
     * \param strategy kernel strategy, stored by copy
     * \param align_size byte alignment applied to the workspace pointer
     */
    GemmInterleaved(
            const size_t M, const size_t N, const size_t K, const bool trA,
            const bool trB, const Strategy& strategy, size_t align_size = 16)
            : m_M(M),
              m_N(N),
              m_K(K),
              m_transpose_A(trA),
              m_transpose_B(trB),
              m_strategy(strategy),
              m_align_size(align_size) {}

    /*!
     * \brief Run the blocked GEMM, packing tiles of A and B on the fly.
     *
     * \param A, LDA source A and its leading dimension, forwarded to
     *      Strategy::pack_A
     * \param B, LDB source B and its leading dimension, forwarded to
     *      Strategy::pack_B
     * \param C, LDC destination; the tile starting at row m / column n is
     *      addressed as C + m * LDC + n
     * \param workspace non-null buffer of at least get_workspace_size() bytes
     * \param bias forwarded unchanged to Strategy::kern (may be nullptr)
     */
    void execute(
            const stype* A, const size_t LDA, const stype* B, const size_t LDB,
            dtype* C, const size_t LDC, void* workspace,
            const compute_type* bias = nullptr) const {
        megdnn_assert(workspace);
        int8_t* workspace_bytes = static_cast<int8_t*>(workspace);
        // All three regions are shifted by the same padding so that the first
        // panel starts on an m_align_size boundary.
        const size_t diff = get_align_padding(workspace);
        const size_t a_bytes = get_a_workspace_size();
        const size_t b_bytes = get_b_workspace_size();

        pack_a_type* a_panel = reinterpret_cast<pack_a_type*>(workspace_bytes + diff);
        stype* b_panel = reinterpret_cast<stype*>(workspace_bytes + diff + a_bytes);
        compute_type* c_panel = reinterpret_cast<compute_type*>(
                workspace_bytes + diff + a_bytes + b_bytes);

        for (size_t k = 0; k < m_K; k += m_strategy.block_k) {
            size_t kmax = std::min(k + m_strategy.block_k, m_K);
            for (size_t m = 0; m < m_M; m += m_strategy.block_m) {
                size_t mmax = std::min(m + m_strategy.block_m, m_M);
                m_strategy.pack_A(a_panel, A, LDA, m, mmax, k, kmax, m_transpose_A);
                for (size_t n = 0; n < m_N; n += m_strategy.block_n) {
                    size_t nmax = std::min(n + m_strategy.block_n, m_N);
                    // b_panel holds a single tile, so it is repacked for
                    // every (m, n) tile.
                    m_strategy.pack_B(b_panel, B, LDB, n, nmax, k, kmax, m_transpose_B);
                    // `k == 0` tells the kernel this is the first K block.
                    m_strategy.kern(
                            a_panel, b_panel, mmax - m, nmax - n, kmax - k,
                            C + m * LDC + n, LDC, k == 0, bias, c_panel);
                }
            }
        }
    }

    /*!
     * \brief Pack the [y0, ymax) range of A over the whole K extent into
     *      \p out, for later use with execute_naked().
     *
     * Only valid when the problem fits into a single block (asserted).
     */
    void pack_A(pack_a_type* out, const stype* in, int ldin, int y0, int ymax) {
        megdnn_assert(out);
        assert_single_level_blocking();
        m_strategy.pack_A(
                out, in, ldin, y0, ymax, 0, std::min(m_strategy.block_k, m_K),
                m_transpose_A);
    }

    /*!
     * \brief Pack the [x0, xmax) range of B over the whole K extent into
     *      \p out, for later use with execute_naked().
     *
     * Only valid when the problem fits into a single block (asserted).
     */
    void pack_B(stype* out, const stype* in, int ldin, int x0, int xmax) {
        megdnn_assert(out);
        assert_single_level_blocking();
        m_strategy.pack_B(
                out, in, ldin, x0, xmax, 0, std::min(m_strategy.block_k, m_K),
                m_transpose_B);
    }

    /*!
     * \brief Run the kernel on panels that were packed beforehand.
     *
     * \param C, LDC destination and its leading dimension
     * \param packed_a, packed_b non-null pre-packed panels; they are passed
     *      to Strategy::kern as pack_a_type* / stype*
     *
     * No bias and no scratch buffer are passed to the kernel here. Only valid
     * when the problem fits into a single block (asserted), so each of the
     * loops below runs at most once.
     */
    void execute_naked(
            dtype* C, const size_t LDC,
            const void* packed_a, const void* packed_b) const {
        megdnn_assert(packed_a);
        megdnn_assert(packed_b);
        assert_single_level_blocking();
        pack_a_type* a_panel = static_cast<pack_a_type*>(const_cast<void*>(packed_a));
        stype* b_panel = static_cast<stype*>(const_cast<void*>(packed_b));
        for (size_t k = 0; k < m_K; k += m_strategy.block_k) {
            size_t kmax = std::min(k + m_strategy.block_k, m_K);
            for (size_t m = 0; m < m_M; m += m_strategy.block_m) {
                size_t mmax = std::min(m + m_strategy.block_m, m_M);
                for (size_t n = 0; n < m_N; n += m_strategy.block_n) {
                    size_t nmax = std::min(n + m_strategy.block_n, m_N);
                    m_strategy.kern(
                            a_panel, b_panel, mmax - m, nmax - n, kmax - k,
                            C + m * LDC + n, LDC, k == 0);
                }
            }
        }
    }
};
/*!
 * \brief GEMM driver without operand packing.
 *
 * The whole problem is handed to Strategy::kern in one call; no tiling or
 * repacking happens at this level.
 *
 * \tparam Strategy supplies compute_type, stype, dst_type, kern() and
 *      get_workspace_size().
 */
template <typename Strategy>
class GemmInterleaved<Strategy, false> {
    using compute_type = typename Strategy::compute_type;
    using stype = typename Strategy::stype;
    using dtype = typename Strategy::dst_type;

    // Problem size and operand layout, fixed at construction time.
    const size_t m_M;
    const size_t m_N;
    const size_t m_K;
    const bool m_transpose_A;
    const bool m_transpose_B;
    Strategy m_strategy;

public:
    /*!
     * \param M, N, K problem size, forwarded to Strategy::kern
     * \param trA, trB transpose flags, forwarded to Strategy::kern
     * \param strategy kernel strategy, stored by copy
     */
    GemmInterleaved(
            const size_t M, const size_t N, const size_t K, const bool trA,
            const bool trB, const Strategy& strategy)
            : m_M(M),
              m_N(N),
              m_K(K),
              m_transpose_A(trA),
              m_transpose_B(trB),
              m_strategy(strategy) {}

    //! Workspace size in bytes, exactly as reported by the strategy.
    size_t get_workspace_size() const {
        return m_strategy.get_workspace_size();
    }

    /*!
     * \brief Forward the unpacked operands to Strategy::kern.
     *
     * Note the kernel receives the sizes in the order M, K, N.
     *
     * \param A, LDA source A and its leading dimension
     * \param B, LDB source B and its leading dimension
     * \param C, LDC destination and its leading dimension
     * \param workspace passed through to the kernel untouched
     * \param bias passed through to the kernel (may be nullptr)
     */
    void execute(
            const stype* A, const size_t LDA, const stype* B, const size_t LDB,
            dtype* C, const size_t LDC, void* workspace,
            const compute_type* bias = nullptr) const {
        m_strategy.kern(
                A, LDA, B, LDB, C, LDC,
                m_M, m_K, m_N,
                bias, workspace,
                m_transpose_A, m_transpose_B);
    }
};
} }