megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/***************************************************************************************************
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/**
 * \file dnn/src/cuda/cutlass/operation_table.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include <unordered_map>

#include "src/common/hash_ct.h"
#include "src/cuda/cutlass/manifest.h"
#include "src/cuda/cutlass/util.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Incremental hasher used by the key hashers in this header.
///
/// Each update() call hashes the raw bytes of one field with
/// megdnn::XXHash64CT (fixed seed 123456) and folds that value into the
/// running state. The fold is order-dependent (boost::hash_combine style).
/// A plain `+=` accumulation is commutative, so keys that differ only by a
/// permutation of equally-sized fields always collided. For example, a
/// 128 x 64 tile and a 64 x 128 tile produced the same digest.
class Hash {
public:
    Hash() : m_val(0) {}

    /// Mix the `len` bytes starting at `ptr` into the running hash.
    /// Returns *this so that calls can be chained.
    Hash& update(const void* ptr, size_t len) {
        const uint64_t h = megdnn::XXHash64CT::hash(
                static_cast<const char*>(ptr), len, 123456);
        // The shifted copies of the previous state make the result depend on
        // call order: update(a).update(b) != update(b).update(a) in general.
        m_val ^= h + 0x9e3779b97f4a7c15ULL + (m_val << 6) + (m_val >> 2);
        return *this;
    }

    /// Current hash of everything fed so far (0 if update() was never called).
    uint64_t digest() const { return m_val; }

private:
    uint64_t m_val;
};

/////////////////////////////////////////////////////////////////////////////////////////////////
//                          Data Structures for GemmOperationMap
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Lookup key identifying a family of GEMM kernels in GemmOperationMap.
///
/// Two keys compare equal only when every field matches. str() renders all
/// fields as a human-readable, multi-line "{ ... }" dump.
struct GemmKey {
    // Operand / accumulator element types and layouts.
    NumericTypeID element_A;
    LayoutTypeID layout_A;
    NumericTypeID element_B;
    LayoutTypeID layout_B;
    NumericTypeID element_C;
    LayoutTypeID layout_C;
    NumericTypeID element_accumulator;

    // Tile shape (m x n x k) at threadblock level.
    int threadblock_shape_m;
    int threadblock_shape_n;
    int threadblock_shape_k;

    // Tile shape (m x n x k) at warp level.
    int warp_shape_m;
    int warp_shape_n;
    int warp_shape_k;

    // Tile shape (m x n x k) at instruction level.
    int instruction_shape_m;
    int instruction_shape_n;
    int instruction_shape_k;

    int stages;
    int alignment_A;
    int alignment_B;
    SplitKMode split_k_mode;

    inline bool operator==(GemmKey const& rhs) const {
        const bool same_operands =
                element_A == rhs.element_A && layout_A == rhs.layout_A &&
                element_B == rhs.element_B && layout_B == rhs.layout_B &&
                element_C == rhs.element_C && layout_C == rhs.layout_C &&
                element_accumulator == rhs.element_accumulator;
        const bool same_tiles =
                threadblock_shape_m == rhs.threadblock_shape_m &&
                threadblock_shape_n == rhs.threadblock_shape_n &&
                threadblock_shape_k == rhs.threadblock_shape_k &&
                warp_shape_m == rhs.warp_shape_m &&
                warp_shape_n == rhs.warp_shape_n &&
                warp_shape_k == rhs.warp_shape_k &&
                instruction_shape_m == rhs.instruction_shape_m &&
                instruction_shape_n == rhs.instruction_shape_n &&
                instruction_shape_k == rhs.instruction_shape_k;
        const bool same_tuning = stages == rhs.stages &&
                                 alignment_A == rhs.alignment_A &&
                                 alignment_B == rhs.alignment_B &&
                                 split_k_mode == rhs.split_k_mode;
        return same_operands && same_tiles && same_tuning;
    }

    inline bool operator!=(GemmKey const& rhs) const { return !operator==(rhs); }

    inline std::string str() const {
        // Formats an (m, n, k) triple as "m x n x k".
        auto dims = [](int m, int n, int k) -> std::string {
            return std::to_string(m) + " x " + std::to_string(n) + " x " +
                   std::to_string(k);
        };

        std::string text("{");
        // Appends a literal "\n    name: " prefix followed by the field value.
        auto append = [&text](const char* prefix, const std::string& value) {
            text += prefix;
            text += value;
        };

        append("\n    element_A: ", to_string(element_A));
        append("\n    layout_A: ", to_string(layout_A));
        append("\n    element_B: ", to_string(element_B));
        append("\n    layout_B: ", to_string(layout_B));
        append("\n    element_C: ", to_string(element_C));
        append("\n    layout_C: ", to_string(layout_C));
        append("\n    element_accumulator: ", to_string(element_accumulator));
        append("\n    threadblock_shape: ",
               dims(threadblock_shape_m, threadblock_shape_n, threadblock_shape_k));
        append("\n    warp_shape: ", dims(warp_shape_m, warp_shape_n, warp_shape_k));
        append("\n    instruction_shape: ",
               dims(instruction_shape_m, instruction_shape_n, instruction_shape_k));
        append("\n    stages: ", std::to_string(stages));
        append("\n    alignment_A: ", std::to_string(alignment_A));
        append("\n    alignment_B: ", std::to_string(alignment_B));
        append("\n    split_k_mode: ", to_string(split_k_mode));
        text += "\n}";
        return text;
    }
};

struct GemmKeyHasher {
    inline size_t operator()(GemmKey const& key) const {
        return Hash()
                .update(&key.element_A, sizeof(key.element_A))
                .update(&key.layout_A, sizeof(key.layout_A))
                .update(&key.element_B, sizeof(key.element_B))
                .update(&key.layout_B, sizeof(key.layout_B))
                .update(&key.element_C, sizeof(key.element_C))
                .update(&key.layout_C, sizeof(key.layout_C))
                .update(&key.element_accumulator, sizeof(key.element_accumulator))
                .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m))
                .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n))
                .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k))
                .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                .update(&key.stages, sizeof(key.stages))
                .update(&key.alignment_A, sizeof(key.alignment_A))
                .update(&key.alignment_B, sizeof(key.alignment_B))
                .update(&key.split_k_mode, sizeof(key.split_k_mode))
                .digest();
    }
};

using GemmOperationMap =
        std::unordered_map<GemmKey, std::vector<Operation const*>, GemmKeyHasher>;

/////////////////////////////////////////////////////////////////////////////////////////////////
//                          Data Structures for ConvolutionOperationMap
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Lookup key identifying a family of convolution kernels in
/// ConvolutionOperationMap.
///
/// Two keys compare equal only when every field matches. str() renders all
/// fields as a human-readable, multi-line "{ ... }" dump.
struct ConvolutionKey {
    conv::Operator conv_op;

    // Operand / accumulator element types and layouts.
    library::NumericTypeID element_src;
    library::LayoutTypeID layout_src;
    library::NumericTypeID element_filter;
    library::LayoutTypeID layout_filter;
    library::NumericTypeID element_dst;
    library::LayoutTypeID layout_dst;
    library::NumericTypeID element_bias;
    library::LayoutTypeID layout_bias;
    NumericTypeID element_accumulator;

    conv::ConvType convolution_type;

    // Tile shape (m x n x k) at threadblock level.
    int threadblock_shape_m;
    int threadblock_shape_n;
    int threadblock_shape_k;

    // Tile shape (m x n x k) at warp level.
    int warp_shape_m;
    int warp_shape_n;
    int warp_shape_k;

    // Tile shape (m x n x k) at instruction level.
    int instruction_shape_m;
    int instruction_shape_n;
    int instruction_shape_k;

    epilogue::EpilogueType epilogue_type;
    int stages;
    conv::SpecialOptimizeDesc special_optimization;

    int alignment_src;
    int alignment_filter;

    bool without_shared_load;

    inline bool operator==(ConvolutionKey const& rhs) const {
        const bool same_operands =
                conv_op == rhs.conv_op && element_src == rhs.element_src &&
                layout_src == rhs.layout_src &&
                element_filter == rhs.element_filter &&
                layout_filter == rhs.layout_filter &&
                element_dst == rhs.element_dst && layout_dst == rhs.layout_dst &&
                element_bias == rhs.element_bias && layout_bias == rhs.layout_bias &&
                element_accumulator == rhs.element_accumulator &&
                convolution_type == rhs.convolution_type;
        const bool same_tiles =
                threadblock_shape_m == rhs.threadblock_shape_m &&
                threadblock_shape_n == rhs.threadblock_shape_n &&
                threadblock_shape_k == rhs.threadblock_shape_k &&
                warp_shape_m == rhs.warp_shape_m &&
                warp_shape_n == rhs.warp_shape_n &&
                warp_shape_k == rhs.warp_shape_k &&
                instruction_shape_m == rhs.instruction_shape_m &&
                instruction_shape_n == rhs.instruction_shape_n &&
                instruction_shape_k == rhs.instruction_shape_k;
        const bool same_tuning =
                epilogue_type == rhs.epilogue_type && stages == rhs.stages &&
                special_optimization == rhs.special_optimization &&
                alignment_src == rhs.alignment_src &&
                alignment_filter == rhs.alignment_filter &&
                without_shared_load == rhs.without_shared_load;
        return same_operands && same_tiles && same_tuning;
    }

    inline bool operator!=(ConvolutionKey const& rhs) const {
        return !operator==(rhs);
    }

    inline std::string str() const {
        // Formats an (m, n, k) triple as "m x n x k".
        auto dims = [](int m, int n, int k) -> std::string {
            return std::to_string(m) + " x " + std::to_string(n) + " x " +
                   std::to_string(k);
        };

        std::string text("{");
        // Appends a literal "\n    name: " prefix followed by the field value.
        auto append = [&text](const char* prefix, const std::string& value) {
            text += prefix;
            text += value;
        };

        append("\n    conv_op: ", to_string(conv_op));
        append("\n    element_src: ", to_string(element_src));
        append("\n    layout_src: ", to_string(layout_src));
        append("\n    element_filter: ", to_string(element_filter));
        append("\n    layout_filter: ", to_string(layout_filter));
        append("\n    element_dst: ", to_string(element_dst));
        append("\n    layout_dst: ", to_string(layout_dst));
        append("\n    element_bias: ", to_string(element_bias));
        append("\n    layout_bias: ", to_string(layout_bias));
        append("\n    element_accumulator: ", to_string(element_accumulator));
        append("\n    convolution_type: ", to_string(convolution_type));
        append("\n    threadblock_shape: ",
               dims(threadblock_shape_m, threadblock_shape_n, threadblock_shape_k));
        append("\n    warp_shape: ", dims(warp_shape_m, warp_shape_n, warp_shape_k));
        append("\n    instruction_shape: ",
               dims(instruction_shape_m, instruction_shape_n, instruction_shape_k));
        append("\n    epilogue_type: ", to_string(epilogue_type));
        append("\n    stages: ", std::to_string(stages));
        append("\n    special_optimization: ", to_string(special_optimization));
        append("\n    alignment_src: ", std::to_string(alignment_src));
        append("\n    alignment_filter: ", std::to_string(alignment_filter));
        append("\n    without_shared_load: ", to_string(without_shared_load));
        text += "\n}";
        return text;
    }
};

struct ConvolutionKeyHasher {
    inline size_t operator()(ConvolutionKey const& key) const {
        return Hash()
                .update(&key.conv_op, sizeof(key.conv_op))
                .update(&key.element_src, sizeof(key.element_src))
                .update(&key.layout_src, sizeof(key.layout_src))
                .update(&key.element_filter, sizeof(key.element_filter))
                .update(&key.layout_filter, sizeof(key.layout_filter))
                .update(&key.element_dst, sizeof(key.element_dst))
                .update(&key.layout_dst, sizeof(key.layout_dst))
                .update(&key.element_bias, sizeof(key.element_bias))
                .update(&key.layout_bias, sizeof(key.layout_bias))
                .update(&key.element_accumulator, sizeof(key.element_accumulator))
                .update(&key.convolution_type, sizeof(key.convolution_type))
                .update(&key.threadblock_shape_m, sizeof(key.threadblock_shape_m))
                .update(&key.threadblock_shape_n, sizeof(key.threadblock_shape_n))
                .update(&key.threadblock_shape_k, sizeof(key.threadblock_shape_k))
                .update(&key.warp_shape_m, sizeof(key.warp_shape_m))
                .update(&key.warp_shape_n, sizeof(key.warp_shape_n))
                .update(&key.warp_shape_k, sizeof(key.warp_shape_k))
                .update(&key.instruction_shape_m, sizeof(key.instruction_shape_m))
                .update(&key.instruction_shape_n, sizeof(key.instruction_shape_n))
                .update(&key.instruction_shape_k, sizeof(key.instruction_shape_k))
                .update(&key.epilogue_type, sizeof(key.epilogue_type))
                .update(&key.stages, sizeof(key.stages))
                .update(&key.special_optimization, sizeof(key.special_optimization))
                .update(&key.alignment_src, sizeof(key.alignment_src))
                .update(&key.alignment_filter, sizeof(key.alignment_filter))
                .update(&key.without_shared_load, sizeof(key.without_shared_load))
                .digest();
    }
};

using ConvolutionOperationMap = std::unordered_map<
        ConvolutionKey, std::vector<Operation const*>, ConvolutionKeyHasher>;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Table of cutlass::library::Operation instances
class OperationTable {
public:
    /// Map of all operations of type kGemm
    GemmOperationMap gemm_operations;

    /// Map of all operations of type kConvolution
    ConvolutionOperationMap convolution_operations;

public:
    void append(Manifest const& manifest);

    Operation const* find_op(GemmKey const& key) const;

    Operation const* find_op(ConvolutionKey const& key) const;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace library
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////