megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file dnn/src/common/named_tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megdnn/named_tensor.h"
#include "src/common/utils.h"

using namespace megdnn;

/* ===================== Dimension ============================  */
//! Every dimension name accepted in a Dimension expression. The string
//! constructor matches an expression's leading upper-case letter against the
//! entries of this table.
const Dimension::Name Dimension::NAME_ALL[] = {
        Dimension::Name::N, Dimension::Name::C, Dimension::Name::H, Dimension::Name::W,
        Dimension::Name::G, Dimension::Name::K, Dimension::Name::R, Dimension::Name::S,
        Dimension::Name::P, Dimension::Name::Q,
};
//! Number of entries in NAME_ALL. sizeof on the array alone yields bytes, which
//! equals the element count only for a one-byte enum, so divide by the element
//! size to stay correct whatever the underlying type of Dimension::Name is.
const int Dimension::NR_NAMES =
        static_cast<int>(sizeof(Dimension::NAME_ALL) / sizeof(Dimension::NAME_ALL[0]));
/*!
 * \brief Parse a dimension expression such as "C", "C//4", "C%4" or "C//4%8".
 *
 * Accepted form: exactly one upper-case letter that appears in NAME_ALL,
 * optionally followed by "//<stride>" and then optionally "%<extent>". Each
 * qualifier must be immediately followed by a positive decimal number that
 * fits in uint32_t. A missing stride defaults to 1; a missing extent defaults
 * to UNDETERMINED_EXTENT.
 *
 * \param expr textual dimension expression
 * \throw megdnn_error (via megdnn_throw / megdnn_throw_if) with message
 *        "Invalid dimension(<expr>)" when \p expr does not match the form above.
 */
Dimension::Dimension(const std::string& expr) {
    auto errmsg = [&]() { return ssprintf("Invalid dimension(%s)", expr.c_str()); };
    const char* data = expr.data();
    bool has_stride = false;
    bool has_extent = false;
    bool init_name = false;
    // Set right after a "//" or "%" qualifier and cleared once its number has
    // been consumed. Without it "C//" or "C%" would leave m_stride/m_extent
    // unassigned, and a bare number such as "C4" would be silently dropped.
    bool expect_number = false;
    while (*data) {
        if (data[0] >= 'A' && data[0] <= 'Z') {
            megdnn_throw_if(init_name, megdnn_error, errmsg().c_str());
            for (auto e : NAME_ALL) {
                if (data[0] == static_cast<char>(e)) {
                    init_name = true;
                    m_name = e;
                    break;
                }
            }
            megdnn_throw_if(!init_name, megdnn_error, errmsg().c_str());
            ++data;
        } else if (data[0] == '/' && data[1] == '/') {
            // Stride must follow the name, appear once, and precede any extent.
            megdnn_throw_if(
                    !init_name || has_stride || has_extent || expect_number,
                    megdnn_error, errmsg().c_str());
            has_stride = true;
            expect_number = true;
            data += 2;
        } else if (data[0] == '%') {
            megdnn_throw_if(
                    !init_name || has_extent || expect_number, megdnn_error,
                    errmsg().c_str());
            has_extent = true;
            expect_number = true;
            ++data;
        } else if (data[0] >= '0' && data[0] <= '9') {
            // A number is only meaningful as the operand of "//" or "%".
            megdnn_throw_if(!expect_number, megdnn_error, errmsg().c_str());
            // Accumulate in 64 bits so values that do not fit in uint32_t are
            // rejected instead of silently wrapping.
            uint64_t num = 0;
            while (data[0] >= '0' && data[0] <= '9') {
                num = num * 10 + static_cast<uint64_t>(data[0] - '0');
                megdnn_throw_if(num > 0xFFFFFFFFull, megdnn_error, errmsg().c_str());
                ++data;
            }
            // A zero stride or extent would later be used as a divisor in
            // operator/ (the `%` and `/` operations there).
            megdnn_throw_if(num == 0, megdnn_error, errmsg().c_str());
            // "//" after "%" is rejected above, so the pending qualifier is the
            // extent exactly when has_extent is set.
            if (has_extent)
                m_extent = static_cast<uint32_t>(num);
            else
                m_stride = static_cast<uint32_t>(num);
            expect_number = false;
        } else {
            megdnn_throw(errmsg().c_str());
        }
    }
    // Reject a missing name as well as a trailing qualifier without a number.
    megdnn_throw_if(!init_name || expect_number, megdnn_error, errmsg().c_str());
    if (!has_extent) {
        m_extent = UNDETERMINED_EXTENT;
    }
    if (!has_stride) {
        m_stride = 1;
    }
}

//! Member-wise copy of name, stride and extent from \p rhs; returns *this.
//! Self-assignment is harmless since each member is simply overwritten with
//! its own value.
Dimension& Dimension::operator=(const Dimension& rhs) {
    m_extent = rhs.m_extent;
    m_stride = rhs.m_stride;
    m_name = rhs.m_name;
    return *this;
}

//! Two dimensions are equal when name, stride and extent all match.
bool Dimension::operator==(const Dimension& rhs) const {
    if (m_name != rhs.m_name)
        return false;
    if (m_stride != rhs.m_stride)
        return false;
    return m_extent == rhs.m_extent;
}

//! Strict ordering: by name character ascending; for equal names, by stride
//! descending; for equal strides, by extent descending.
bool Dimension::operator<(const Dimension& rhs) const {
    if (m_name != rhs.m_name) {
        const char lhs_name = static_cast<char>(m_name);
        const char rhs_name = static_cast<char>(rhs.m_name);
        return lhs_name < rhs_name;
    }
    if (m_stride != rhs.m_stride) {
        return m_stride > rhs.m_stride;
    }
    return m_extent > rhs.m_extent;
}

/*!
 * \brief Merge this dimension (the outer part) with \p rhs (the inner part).
 *
 * Both operands must share a name, and this stride must equal
 * rhs.stride * rhs.extent, i.e. this part starts exactly where rhs ends.
 * The result keeps rhs's stride. Its extent is the product of both extents,
 * or it is left open-ended when this extent is UNDETERMINED_EXTENT.
 * Violations fail megdnn_assert.
 */
Dimension Dimension::operator*(const Dimension& rhs) const {
    megdnn_assert(
            m_name == rhs.m_name,
            "Multiply operation cannot be applied on dimensions with "
            "different name(lhs:%c, rhs:%c)",
            static_cast<char>(m_name), static_cast<char>(rhs.m_name));
    megdnn_assert(
            m_stride == rhs.m_stride * rhs.m_extent,
            "Multiply operation cannot be applied on operands(lhs:%s, rhs:%s)",
            to_string().c_str(), rhs.to_string().c_str());
    const bool open_ended = m_extent == UNDETERMINED_EXTENT;
    if (!open_ended) {
        return Dimension(m_name, rhs.m_stride, m_extent * rhs.m_extent);
    }
    return Dimension(m_name, rhs.m_stride);
}

/*!
 * \brief Divide this dimension by \p rhs; both must have the same name.
 *
 * The result is the part of this dimension that remains once the part
 * described by \p rhs is factored out. The examples below assume the
 * two-argument constructor defaults the extent to UNDETERMINED_EXTENT
 * (NOTE(review): confirm in the header).
 *  - this == rhs: returns a dimension with stride 1 and extent 1;
 *  - same stride, this extent undetermined: rhs extent must be determined;
 *    e.g. C / C%4 -> C//4;
 *  - same stride, this extent determined: rhs extent must divide it;
 *    e.g. C%8 / C%4 -> C//4%2;
 *  - different stride, this extent undetermined: rhs must also be
 *    undetermined and its stride a multiple of ours; e.g. C / C//4 -> C%4;
 *  - different stride, this extent determined: stride * extent must match on
 *    both sides and rhs stride must be a multiple of ours;
 *    e.g. C%8 / C//4%2 -> C%4.
 * Every violated precondition fails a megdnn_assert.
 *
 * NOTE(review): a zero stride or extent on either operand would reach the
 * `%` / `/` operations below as a divisor; this function does not guard it.
 */
Dimension Dimension::operator/(const Dimension& rhs) const {
    megdnn_assert(
            m_name == rhs.m_name,
            "Divide operation cannot be applied on dimensions with "
            "different name(lhs:%c, rhs:%c)",
            static_cast<char>(m_name), static_cast<char>(rhs.m_name));
    // Identical operands divide to a unit dimension.
    if (operator==(rhs))
        return Dimension(m_name, 1, 1);
    if (m_stride == rhs.m_stride) {
        // rhs covers the low end of this dimension; the quotient starts where
        // rhs ends (stride scaled by rhs's extent).
        if (m_extent == UNDETERMINED_EXTENT) {
            megdnn_assert(
                    rhs.m_extent != UNDETERMINED_EXTENT,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, rhs.m_extent * m_stride);
        } else {
            megdnn_assert(
                    m_extent % rhs.m_extent == 0,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, rhs.m_extent * m_stride, m_extent / rhs.m_extent);
        }
    } else {
        // rhs covers the high end of this dimension; the quotient keeps this
        // stride and spans up to where rhs starts.
        if (m_extent == UNDETERMINED_EXTENT) {
            megdnn_assert(
                    rhs.m_extent == UNDETERMINED_EXTENT && rhs.m_stride % m_stride == 0,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, m_stride, rhs.m_stride / m_stride);
        } else {
            // Both operands must end at the same position (stride * extent).
            megdnn_assert(
                    m_extent * m_stride == rhs.m_extent * rhs.m_stride &&
                            rhs.m_stride % m_stride == 0,
                    "Divide operation cannot be applied on "
                    "operands(dividend:%s, divisor:%s)",
                    to_string().c_str(), rhs.to_string().c_str());
            return Dimension(m_name, m_stride, m_extent / rhs.m_extent);
        }
    }
}

/*!
 * \brief Render the dimension in the same syntax the string constructor
 * parses: the name letter, then "//<stride>" unless the stride is 1, then
 * "%<extent>" unless the extent is UNDETERMINED_EXTENT.
 */
std::string Dimension::to_string() const {
    std::string repr = ssprintf("%c", static_cast<char>(m_name));
    if (m_stride != 1) {
        repr += ssprintf("//%u", m_stride);
    }
    if (m_extent != UNDETERMINED_EXTENT) {
        repr += ssprintf("%%%u", m_extent);
    }
    return repr;
}

/* ===================== NamedTensorShape =====================  */

/*!
 * \brief Build a named shape from a list of dimensions, outermost first.
 *
 * \param init_shape the axes; at most MAX_NDIM entries (checked by
 *        megdnn_assert).
 */
NamedTensorShape::NamedTensorShape(const SmallVector<Dimension>& init_shape) {
    megdnn_assert(
            init_shape.size() <= MAX_NDIM,
            "Illegal to construct a NamedTensorShape with "
            "more than MAX_NDIM(%zu) axes; got(%zu)",
            MAX_NDIM, init_shape.size());
    ndim = init_shape.size();
    // Copy element-wise through Dimension's copy assignment instead of memcpy:
    // Dimension has a user-provided operator=, so it is not trivially copyable
    // and byte-wise copying it is not sanctioned by the language. The loop
    // also avoids handing memcpy a possibly-null source when the list is empty.
    for (size_t i = 0; i < ndim; ++i) {
        dims[i] = init_shape[i];
    }
}

//! Convenience overload for brace-initialization, e.g. {{"N"}, {"C"}}: copies
//! the list into a SmallVector and delegates to the SmallVector constructor,
//! which enforces the MAX_NDIM limit.
NamedTensorShape::NamedTensorShape(std::initializer_list<Dimension> init_shape)
        : NamedTensorShape(SmallVector<Dimension>(init_shape)) {}

/*!
 * \brief Whether both shapes have the same number of axes and pairwise-equal
 * dimensions (name, stride and extent) in the same order.
 *
 * A plain loop replaces the previous hand-unrolled switch, so this no longer
 * has to be edited whenever MAX_NDIM changes. Two empty shapes compare equal.
 */
bool NamedTensorShape::eq_shape(const NamedTensorShape& rhs) const {
    if (ndim != rhs.ndim)
        return false;
    for (size_t i = 0; i < ndim; ++i) {
        if (!(dims[i] == rhs.dims[i]))
            return false;
    }
    return true;
}

//! Render as "{d0,d1,...}", where each di is Dimension::to_string() of axis i;
//! an empty shape renders as "{}".
std::string NamedTensorShape::to_string() const {
    std::string out("{");
    const char* sep = "";
    for (size_t axis = 0; axis < ndim; ++axis) {
        out += sep;
        out += dims[axis].to_string();
        sep = ",";
    }
    out += '}';
    return out;
}

/*!
 * \brief Return the named layout, outermost axis first, for a tensor \p format.
 *
 * Split axes appear as an outer "X//k" part and an inner "X%k" part: e.g.
 * NCHW4 stores C as C//4 in the second position and C%4 innermost. The
 * order of entries in each list is the layout, so it must not be reordered
 * (note NCHW44 ends with C%4, N%4 while NCHW44_DOT ends with N%4, C%4).
 * Each call parses the dimension strings afresh via Dimension's string
 * constructor.
 *
 * Unlisted formats raise an error via megdnn_throw.
 */
NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) {
    switch (format) {
        case Format::NCHW:
            return {{"N"}, {"C"}, {"H"}, {"W"}};
        case Format::NHWC:
            return {{"N"}, {"H"}, {"W"}, {"C"}};
        case Format::NCHW4:
            return {{"N"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}};
        case Format::NCHW8:
            return {{"N"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}};
        case Format::NCHW32:
            return {{"N"}, {"C//32"}, {"H"}, {"W"}, {"C%32"}};
        case Format::NCHW64:
            return {{"N"}, {"C//64"}, {"H"}, {"W"}, {"C%64"}};
        // Both N and C are split by 4 (and by 8 for NCHW88 below).
        case Format::NCHW44:
            return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}, {"N%4"}};
        case Format::NCHW88:
            return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}};
        // Same axes as NCHW44 but with the two innermost parts swapped.
        case Format::NCHW44_DOT:
            return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}};
        // Outer channel part sits between H and W here.
        case Format::NHWCD4:
            return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}};
        default:
            megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format))
                                 .c_str());
    }
}
// vim: syntax=cpp.doxygen