megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/jit/impl/halide/halide_executable.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./halide_executable.h"

#if MGB_JIT_HALIDE

#include "megbrain/jit/utils.h"

using namespace mgb;
using namespace jit;
using namespace Halide;

HalideExecutable::FunctionHandle::~FunctionHandle() {
    if (device_release && uctx_map) {
        for (auto&& i : uctx_map->cn2uctx) {
            device_release(i.second);
        }
    }
    delete uctx_map;
    if (dl_handle) {
        ExecutableHelper::get().unload_lib(dl_handle);
    }
}

HalideExecutable::TargetTraitUserData* HalideExecutable::TargetTrait::user_data(
        const HalideExecutable& hl_exec,
        thin_function<std::unique_ptr<TargetTraitUserData>()> maker) {
    MGB_LOCK_GUARD(hl_exec.m_target_trait_user_data_mtx);
    if (!hl_exec.m_target_trait_user_data) {
        hl_exec.m_target_trait_user_data = maker();
    }
    return hl_exec.m_target_trait_user_data.get();
}

/* =================== HalideExecutable ==================== */

HalideExecutable::~HalideExecutable() = default;

HalideExecutable::HalideExecutable(
        std::shared_ptr<TargetTrait> target_trait, const InternalGraph& graph,
        const JITExecutor::Args& args)
        : m_target_trait{std::move(target_trait)} {
    ThinHashMap<VarNode*, const JITExecutor::Args::Data*> placeholders_to_inps;
    for (auto&& inp : args.inputs) {
        VarNode* placeholder = graph.placeholders().at(inp.idx)->output(0);
        placeholders_to_inps[placeholder] = &inp;
    }

    using AstNodePtr = ast_hl::AstNodePtr;
    ThinHashMap<VarNode*, AstNodePtr> mgb2halide;

    auto on_opr = [&](cg::OperatorNodeBase* opr) {
        auto var = opr->output(0);
        AstNodePtr ptr;
        if (opr->same_type<JITPlaceholder>()) {
            auto data = placeholders_to_inps.at(var);
            auto&& ph = opr->cast_final_safe<JITPlaceholder>();
            if (ph.is_host_value_shape_input()) {
                ptr = std::make_shared<ast_hl::InputHostValueShapeOp>();
                ptr->m_layout = data->layout;
            } else {
                ptr = mgb_var_to_halide_buffer(data->from);
                m_value_inputs.emplace_back(static_cast<size_t>(data->idx), ptr);
            }
        } else {
            ptr = ast_hl::make_from_opr(opr);
            for (auto inp : opr->input()) {
                ptr->m_inputs.push_back(mgb2halide.at(inp));
            }
            ptr->init(opr);
        }
        mgb2halide[var] = std::move(ptr);
    };

    cg::DepOprIter{on_opr}.add(graph.output());

    std::sort(m_value_inputs.begin(), m_value_inputs.end());
    m_halide_output = mgb2halide.at(graph.output());
}

void HalideExecutable::execute(JITExecutor* fusion_opr) {
    // load func_ptr for current comp node
    auto comp_node = fusion_opr->comp_node();
    std::atomic<FunctionHandle*>* func_ptr_ref;
    {
        MGB_LOCK_GUARD(m_mtx);
        func_ptr_ref = &m_cn2func[comp_node];
    }
    auto func_ptr = func_ptr_ref->load();
    if (!func_ptr) {
        std::pair<std::mutex, FunctionHandle>* func_maker;
        {
            MGB_LOCK_GUARD(m_mtx);
            func_maker = &m_feature_set2func[m_target_trait->features(comp_node)];
        }

        // compile the function
        MGB_LOCK_GUARD(func_maker->first);
        if (!(func_ptr = func_ptr_ref->load())) {
            if (!func_maker->second.execute) {
                func_maker->second = compile_and_load(comp_node);
                mgb_assert(func_maker->second.execute);
            }
            func_ptr = &func_maker->second;
            func_ptr_ref->store(func_ptr);
        }
    }

    void* user_context = nullptr;
    if (func_ptr->uctx_map) {
        MGB_LOCK_GUARD(func_ptr->uctx_map->mtx);
        auto&& ptr = func_ptr->uctx_map->cn2uctx[comp_node];
        if (!ptr) {
            ptr = m_target_trait->get_user_context(comp_node);
        }
        user_context = ptr;
    }

    invoke(user_context, *func_ptr, fusion_opr->input(), fusion_opr->output(0));
}

std::vector<Halide::Argument> HalideExecutable::halide_inputs() const {
    std::vector<Argument> args;
    for (auto&& i : m_value_inputs) {
        auto&& input_buffer = i.second->cast_final_safe<ast_hl::InputDevValueOp>();
        args.emplace_back(input_buffer.m_buffer);
    }
    return args;
}

HalideExecutable::FunctionHandle HalideExecutable::compile_and_load(
        CompNode comp_node) const {
    Target target = get_host_target();
    auto req_features = m_target_trait->features(comp_node);
    target.set_feature(Target::UserContext);
    if (MGB_GETENV("MGB_HALIDE_DEBUG")) {
        target.set_feature(Target::Debug);
    }
    for (size_t i = 0; i < req_features.size(); ++i) {
        if (req_features.test(i)) {
            target.set_feature(static_cast<Target::Feature>(i));
        }
    }

    return m_target_trait->compile_and_load(comp_node, target, *this);
}

void HalideExecutable::invoke(
        void* user_context, const FunctionHandle& handle, const VarNodeArray& inputs,
        VarNode* output) {
    mgb_assert(handle.execute && handle.get_device_interface);
    halide_device_interface_t* device_interface = handle.get_device_interface();

    size_t nr_inputs = m_value_inputs.size(), argv_idx = 0;
    void* argv[nr_inputs + 2];

    halide_buffer_t image_args[nr_inputs + 1];
    size_t nr_dims = (nr_inputs + 1) * TensorLayout::MAX_NDIM;
    halide_dimension_t image_dims_buf[nr_dims];
    memset(image_dims_buf, 0, sizeof(halide_dimension_t) * nr_dims);
    size_t image_arg_idx = 0;
    halide_dimension_t* image_dims_ptr = image_dims_buf;

    auto add_tensor_arg = [&](const DeviceTensorND& tensor) {
        int ndim = tensor.layout().ndim;
        for (int i = ndim - 1; i >= 0; i--) {
            image_dims_ptr->extent = tensor.layout()[i];
            image_dims_ptr->stride = tensor.layout().stride[i];
            image_dims_ptr++;
        }
        auto dtype = tensor.dtype();
        halide_type_t type = dtype_mgb2halide(dtype);
        image_args[image_arg_idx] = {
                reinterpret_cast<uint64_t>(tensor.raw_ptr()),
                device_interface,
                nullptr,
                0,
                type,
                ndim,
                image_dims_ptr - ndim,
                nullptr};
        argv[argv_idx++] = &image_args[image_arg_idx++];
    };

    argv[argv_idx++] = &user_context;
    for (auto&& i : m_value_inputs) {
        add_tensor_arg(inputs.at(i.first)->dev_tensor());
    }
    add_tensor_arg(output->dev_tensor());
    mgb_assert(argv_idx == nr_inputs + 2);
    mgb_assert(image_dims_ptr <= image_dims_buf + nr_dims);
    auto err = handle.execute(argv);
    mgb_throw_if(err, SystemError, "failed to execute halide function: err=%d", err);
}

halide_type_t HalideExecutable::dtype_mgb2halide(DType dtype) {
    if (dtype == dtype::Float32()) {
        return halide_type_of<float>();
    } else if (dtype == dtype::Float16()) {
        return halide_type_of<float16_t>();
    } else if (dtype == dtype::Int32()) {
        return halide_type_of<int>();
    } else {
        mgb_throw(
                InternalError, "dtype(%s) is not any of [Float16, Float32, Int32]",
                dtype.name());
    }
}

ast_hl::AstNodePtr HalideExecutable::mgb_var_to_halide_buffer(VarNode* var) {
    auto res = std::make_shared<ast_hl::InputDevValueOp>();
    res->m_layout = var->layout();
    int ndim = var->layout().ndim;
    halide_dimension_t halide_dim[ndim];
    memset(halide_dim, 0, sizeof(halide_dimension_t) * ndim);
    for (int i = ndim - 1; i >= 0; i--) {
        halide_dim[ndim - 1 - i].extent = res->m_layout[i];
        halide_dim[ndim - 1 - i].stride = res->m_layout.stride[i];
    }

    halide_buffer_t buf{0,    nullptr,    nullptr, 0, dtype_mgb2halide(var->dtype()),
                        ndim, halide_dim, nullptr};

    res->m_buffer = Buffer<>{buf};
    res->init(nullptr);
    return res;
}

#endif  // MGB_JIT_HALIDE

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}