megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/svd.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs/linalg.h"

#include "src/common/utils.h"

using namespace megdnn;

void SVD::deduce_layout(
        const TensorLayout& src, TensorLayout& u, TensorLayout& s, TensorLayout& vt) {
    Param p = param();
    size_t m, n;
    canonize_params(src, nullptr, &m, &n);
    SmallVector<size_t> shape_prefix;
    for (size_t i = 0; i < src.ndim - 2; i++) {
        shape_prefix.push_back(src[i]);
    }
    SmallVector<size_t> shape_s(shape_prefix), shape_u, shape_vt;
    shape_s.push_back(std::min(m, n));
    if (p.compute_uv) {
        shape_u = shape_prefix;
        shape_vt = shape_prefix;

        size_t ucols = m;
        size_t vrows = n;
        if (!p.full_matrices) {
            ucols = vrows = std::min(m, n);
        }
        // let P = min(M, N)
        // M x M or M x P
        shape_u.push_back(m);
        shape_u.push_back(ucols);

        // N x N or P x N
        shape_vt.push_back(vrows);
        shape_vt.push_back(n);
    } else {
        shape_u = {0};
        shape_vt = {0};
    }
    s = {shape_s, src.dtype};
    u = {shape_u, src.dtype};
    vt = {shape_vt, src.dtype};
}

// Returns the workspace size (in bytes) needed to run SVD on `src`.
// Validates src via canonize_params(), then forwards the batch count, the
// matrix dimensions and the element size to the size-based overload. The
// output layouts do not influence the result; they are accepted only to match
// the operator interface.
size_t SVD::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
        const TensorLayout& vt) {
    MEGDNN_MARK_USED_VAR(u);
    MEGDNN_MARK_USED_VAR(s);
    MEGDNN_MARK_USED_VAR(vt);

    size_t batch, rows, cols;
    canonize_params(src, &batch, &rows, &cols);
    const auto elem_bytes = src.dtype.size();
    return get_workspace_in_bytes(batch, rows, cols, elem_bytes);
}

void SVD::canonize_params(
        const TensorLayout& layout, size_t* block_cnt, size_t* m, size_t* n) {
    megdnn_assert(
            layout.is_contiguous() && layout.ndim >= 2, "invalid SVD layout: %s",
            layout.to_string().c_str());
    megdnn_assert(layout.dtype == dtype::Float32(), "SVD only supports f32");
    if (block_cnt) {
        *block_cnt = 1;
        for (size_t i = 0; i < layout.ndim - 2; ++i) {
            *block_cnt *= layout[i];
        }
    }
    if (n) {
        *n = layout[layout.ndim - 1];
    }
    if (m) {
        *m = layout[layout.ndim - 2];
    }
}

// Checks the arguments of an SVD execution.
//
// Validation of src (contiguous, ndim >= 2, Float32) happens inside
// get_workspace_in_bytes(), which calls canonize_params(). A separate
// canonize_params() call here would only repeat those assertions. The provided
// workspace must be at least as large as the required size; otherwise an
// assertion reports both sizes.
//
// NOTE(review): u, s and vt are only forwarded to get_workspace_in_bytes(),
// which ignores them. Their shapes are not compared against what
// deduce_layout() would produce. Consider validating them if backend kernels
// depend on those shapes.
void SVD::check_exec(
        const TensorLayout& src, const TensorLayout& u, const TensorLayout& s,
        const TensorLayout& vt, size_t workspace_in_bytes) {
    const size_t required_workspace_in_bytes = get_workspace_in_bytes(src, u, s, vt);
    megdnn_assert(
            workspace_in_bytes >= required_workspace_in_bytes,
            "SVD workspace too small: got %zu bytes, need %zu bytes",
            workspace_in_bytes, required_workspace_in_bytes);
}

// vim: syntax=cpp.doxygen