megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/matrix_inverse.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/linalg.h"

#include "src/common/utils.h"

using namespace megdnn;

/*!
 * \brief Deduce the output layout from the input layout.
 *
 * The input is first validated by canonize_params(); no batch count or matrix
 * size is requested, so both output pointers are null. The output layout is
 * then set to a copy of the input layout.
 *
 * \param src layout of the input tensor
 * \param[out] dst receives a copy of \p src
 */
void MatrixInverse::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    // Validation only: the derived batch / n values are not needed here.
    canonize_params(src, /* batch */ nullptr, /* n */ nullptr);

    dst = src;
}

size_t MatrixInverse::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
    size_t batch, n;
    canonize_params(src, &batch, &n);
    megdnn_assert(
            src.eq_layout(dst), "src and dst unequal: %s vs %s",
            src.to_string().c_str(), dst.to_string().c_str());
    return get_workspace_in_bytes(batch, n, src.dtype.size());
}

/*!
 * \brief Validate a MatrixInverse layout and optionally extract its batch
 *      count and matrix size.
 *
 * Asserts that \p layout is contiguous, has at least two dimensions, and that
 * its last two dimensions are equal. Also asserts that the dtype is Float32
 * (or Float16 where DNN_FLOAT16_SELECT enables it).
 *
 * \param layout layout to validate
 * \param[out] batch if non-null, receives the product of all dimensions except
 *      the last two (1 when there are exactly two dimensions)
 * \param[out] n if non-null, receives the size of the last dimension
 */
void MatrixInverse::canonize_params(
        const TensorLayout& layout, size_t* batch, size_t* n) {
    megdnn_assert(
            layout.is_contiguous() && layout.ndim >= 2 &&
                    layout[layout.ndim - 2] == layout[layout.ndim - 1],
            "invalid MatrixInverse layout: %s", layout.to_string().c_str());
    megdnn_assert(
            DNN_FLOAT16_SELECT(layout.dtype == dtype::Float16(), false) ||
                    layout.dtype == dtype::Float32(),
            "MatrixInverse only supports f16 & f32");

    if (batch) {
        // Fold every leading (non-matrix) axis into a single batch count.
        size_t leading_prod = 1;
        for (size_t axis = 0; axis < layout.ndim - 2; ++axis) {
            leading_prod *= layout[axis];
        }
        *batch = leading_prod;
    }

    if (n) {
        // The last two axes were asserted equal above, so either one gives n.
        *n = layout[layout.ndim - 1];
    }
}

/*!
 * \brief Validate the arguments of an exec call.
 *
 * Validates \p src through canonize_params(), asserts that \p dst has the same
 * layout as \p src, and asserts that \p workspace is at least as large as the
 * size reported by get_workspace_in_bytes(batch, n, dtype_size).
 *
 * \param src layout of the input tensor
 * \param dst layout of the output tensor; must equal \p src
 * \param workspace workspace supplied by the caller
 * \param[out] batch if non-null, receives the batch count derived from \p src
 * \param[out] n if non-null, receives the matrix size derived from \p src
 */
void MatrixInverse::check_exec(
        const TensorLayout& src, const TensorLayout& dst, _megdnn_workspace workspace,
        size_t* batch, size_t* n) {
    // Compute into locals so that the workspace check below never dereferences
    // a null output pointer; canonize_params() treats both outputs as optional
    // and this function now does the same.
    size_t batch_val = 0;
    size_t n_val = 0;
    canonize_params(src, &batch_val, &n_val);
    if (batch) {
        *batch = batch_val;
    }
    if (n) {
        *n = n_val;
    }

    megdnn_assert(
            src.eq_layout(dst), "src and dst unequal: %s vs %s",
            src.to_string().c_str(), dst.to_string().c_str());

    const size_t required =
            get_workspace_in_bytes(batch_val, n_val, src.dtype.size());
    megdnn_assert(
            workspace.size >= required,
            "MatrixInverse workspace too small: got %zu bytes, need %zu bytes",
            static_cast<size_t>(workspace.size), required);
}

// vim: syntax=cpp.doxygen