megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
/**
 * \file src/core/impl/graph/symbol_var.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./cg_impl.h"

#include "megbrain/graph/symbol_var.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
using namespace cg;

SymbolVar SymbolVar::rename(const std::string& name) const {
    m_node->name(name);
    return *this;
}

SymbolVar SymbolVar::symshape() const {
    return opr::GetVarShape::make(*this);
}

SymbolVar SymbolVar::reshape(const TensorShape& tshape) const {
    return opr::Reshape::make(*this, tshape);
}

SymbolVar SymbolVar::reshape(SymbolVar tshape) const {
    return opr::Reshape::make(*this, tshape);
}

SymbolVar SymbolVar::broadcast(const TensorShape& tshape) const {
    return opr::Broadcast::make(*this, tshape);
}

SymbolVar SymbolVar::broadcast(SymbolVar tshape) const {
    return opr::Broadcast::make(*this, tshape);
}

SymbolVar SymbolVar::flatten() const {
    return opr::Reshape::make(*this, make_scalar(1), 0);
}

SymbolVar SymbolVar::add_axis(size_t idx) const {
    return opr::AxisAddRemove::make(
            *this, {opr::AxisAddRemove::AxisDesc::make_add(idx)});
}

Maybe<DTypeScalar> SymbolVar::as_immutable_scalar() const {
    using IT = static_infer::InferType;
    auto&& mgr = node()->owner_graph()->static_infer_manager();

    auto ivar = node();
    for (;;) {
        auto ivar_type = ivar->owner_opr()->dyn_typeinfo();
        if (ivar_type == opr::Broadcast::typeinfo() ||
            ivar_type == opr::Reshape::typeinfo()) {
            ivar = ivar->owner_opr()->input(0);
        } else {
            break;
        }
    }

    auto it = mgr.get_infer_type(ivar);
    if (it.value & IT::CONST) {
        DeviceTensorND ival = mgr.infer_value(ivar);
        // remove boradcasted axis
        auto layout = ival.layout();
        for (int i = layout.ndim - 1; i >= 0; --i) {
            if (!layout.stride[i] && layout.ndim >= 2)
                layout.remove_axis_inplace(i);
        }
        if (layout.is_scalar() || (layout.ndim == 1 && !layout.stride[0])) {
            return DTypeScalar::make_from_raw(ival.dtype(), ival.raw_ptr());
        }
    }
    return None;
}

Maybe<DTypeScalar> SymbolVar::as_immutable_scalar_require_shape() const {
    if (!shape().is_scalar())
        return None;
    return as_immutable_scalar();
}

SymbolVar SymbolVar::operator+(const SymbolVar& rhs) const {
    return opr::add(*this, rhs);
}

SymbolVar SymbolVar::operator-(const SymbolVar& rhs) const {
    return opr::sub(*this, rhs);
}

SymbolVar SymbolVar::operator*(const SymbolVar& rhs) const {
    return opr::mul(*this, rhs);
}

SymbolVar SymbolVar::operator/(const SymbolVar& rhs) const {
    if (dtype().category() == DTypeCategory::INT &&
        rhs.dtype().category() == DTypeCategory::INT) {
        return opr::floor_div(*this, rhs);
    }
    return opr::div(*this, rhs);
}

SymbolVar SymbolVar::operator<(const SymbolVar& rhs) const {
    return opr::less_than(*this, rhs);
}

SymbolVar SymbolVar::operator<=(const SymbolVar& rhs) const {
    return opr::less_equal(*this, rhs);
}

SymbolVar SymbolVar::operator-() const {
    return opr::negate(*this);
}

SymbolVar SymbolVar::make_scalar(DTypeScalar value, ComputingGraph& cg, CompNode cn) {
    return opr::ImmutableTensor::make(cg, value, {cn});
}

const DeviceTensorND& SymbolVar::eager_eval_get_value() const {
#if MGB_BUILD_SLIM_SERVING
    mgb_throw(MegBrainError, "eager eval disabled at compile time");
#else
    auto og = ComputingGraphImpl::downcast(node()->owner_graph());
    mgb_assert(og->options().eager_evaluation);
    return node()->dev_tensor();
#endif
}

void VarNodeArrayView::check_idx(size_t idx) const {
    mgb_assert(
            m_begin + idx < m_end, "idx out of range: %zu/%td", idx, m_end - m_begin);
}

void SymbolVarArrayView::check_idx(size_t idx) const {
    mgb_assert(
            m_begin + idx < m_end, "idx out of range: %zu/%td", idx, m_end - m_begin);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}