#include "megbrain/opr/blas.h"
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/search_policy/algo_chooser.h"
#include "./internal/megdnn_opr_wrapper.inl"
#include "./search_policy/workspace_need_limit_getter.inl"
#include "megdnn/oprs/linalg.h"
using namespace mgb;
using namespace opr;
namespace {
//! Pack the two transpose flags of a MatrixMul param into a small integer:
//! transposeA contributes 1 and transposeB contributes 2, giving a value in
//! [0, 3] that is used below as an index into the candidate execution
//! policies.
int get_mask_from_matmul(const megdnn::param::MatrixMul& param) {
    const int bit_a = param.transposeA ? 1 : 0;
    const int bit_b = param.transposeB ? 2 : 0;
    return bit_a + bit_b;
}
}  // namespace
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixMul);

//! Construct a MatrixMul node in the graph owning \p a, with inputs
//! (\p a, \p b). The megdnn operator is initialised from \p param and the
//! execution policy is remembered in m_policy.
MatrixMul::MatrixMul(
        VarNode* a, VarNode* b, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super{a->owner_graph(), config, "matrix_mul", {a, b}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({a, b});

    // the result var is allowed to carry an empty shape
    auto result_var = output(0);
    result_var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
//! Insert a MatrixMul operator computing the product of \p a and \p b and
//! return its single output var.
SymbolVar MatrixMul::make(
        SymbolVar a, SymbolVar b, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config) {
    VarNode* const lhs_node = a.node();
    VarNode* const rhs_node = b.node();
    return a.insert_single_output_opr<MatrixMul>(
            lhs_node, rhs_node, param, policy, config);
}
void MatrixMul::init_output_dtype() {
DType output_dtype = config().output_dtype();
megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(), output_dtype);
output(0)->dtype(output_dtype);
}
//! Extend the base node property so that both inputs carry the
//! VALUE_ALLOW_EMPTY dependency type.
MatrixMul::NodeProp* MatrixMul::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    for (size_t idx = 0; idx < 2; ++idx) {
        prop->add_dep_type_existing_var(
                input(idx), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
//! Check whether a 2-dim layout is directly usable, optionally with its two
//! axes swapped.
//!
//! \p transpose is XOR-ed into the axis indices, so 0 tests the layout as-is
//! and 1 tests it with the axes exchanged. The layout is accepted when the
//! inner axis has stride 1 and the outer stride is at least the inner extent.
//! Asserts that the layout is 2-dimensional.
bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
    mgb_assert(
            layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
            layout.to_string().c_str());
    const int outer_axis = 0 ^ transpose;
    const int inner_axis = 1 ^ transpose;
    if (layout.stride[inner_axis] != 1) {
        return false;
    }
    const auto min_outer_stride = static_cast<ptrdiff_t>(layout.shape[inner_axis]);
    return layout.stride[outer_axis] >= min_outer_stride;
}
//! Require each input layout to pass check_layout either as-is or with its
//! axes swapped.
void MatrixMul::add_input_layout_constraint() {
    auto is_acceptable = [](const TensorLayout& ly) {
        if (check_layout(ly, 0)) {
            return true;
        }
        return check_layout(ly, 1);
    };
    input(0)->add_layout_constraint(is_acceptable);
    input(1)->add_layout_constraint(is_acceptable);
}
// Compute the workspace size (in bytes) to reserve for this operator.
//
// The algorithm chooser is run once for each of the four combinations of the
// transposeA / transposeB flags. For every combination the execution policy
// left on the megdnn operator is stored in m_cadidate_execution_policies,
// indexed by get_mask_from_matmul(), and the returned workspace size is kept.
// The result is the maximum of the four sizes.
//
// Although the method is const, it writes to m_cadidate_execution_policies
// through const_cast and temporarily modifies the megdnn operator's param and
// execution policy; the param is reset from this->param() in MGB_FINALLY.
size_t MatrixMul::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
auto mo = megdnn_opr();
// tparam is what gets toggled by the local `transpose` helper below
auto&& tparam = mo->param();
// workspace sizes of the four flag combinations, in the order visited
size_t a, b, c, d;
mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
TensorLayout i0(input_shapes[0], input(0)->dtype()),
i1(input_shapes[1], input(1)->dtype()),
out(output_shapes[0], output(0)->dtype());
// Swap the two shape entries of dst, set stride[0] from dst[1], and flip the
// matching transpose flag.
// NOTE(review): dst[1] presumably reads the (already swapped) shape[1] via
// TensorLayout::operator[] — confirm.
auto transpose = [](TensorLayout& dst, bool& param) {
std::swap(dst.shape[0], dst.shape[1]);
dst.stride[0] = dst[1];
param ^= 1;
};
MGB_TRY {
// combination 1: flags exactly as configured
megdnn_opr()->execution_policy() = {};
a = AlgoChooser<megdnn::MatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
// combination 2: transposeA flipped
transpose(i0, tparam.transposeA);
b = AlgoChooser<megdnn::MatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
// combination 3: transposeA and transposeB both flipped
transpose(i1, tparam.transposeB);
c = AlgoChooser<megdnn::MatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
// combination 4: transposeA flipped back, only transposeB flipped
transpose(i0, tparam.transposeA);
d = AlgoChooser<megdnn::MatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<MatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
}
// undo the flag flips by resetting tparam from this operator's own param
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
}
void MatrixMul::scn_do_execute() {
auto inp0 = input(0)->dev_tensor().as_megdnn(),
inp1 = input(1)->dev_tensor().as_megdnn(),
out = output(0)->dev_tensor().as_megdnn();
if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
if (!out.layout.is_empty()) {
if (!m_fill_opr) {
m_fill_opr = intl::get_megdnn_handle(comp_node())
->create_operator<megdnn::Fill>();
}
m_fill_opr->param() = 0;
m_fill_opr->exec(out, {});
}
return;
}
auto transpose = [](TensorLayout& layout, bool& trans) {
if (!check_layout(layout, 0)) {
mgb_assert(check_layout(layout, 1));
std::swap(layout.shape[0], layout.shape[1]);
std::swap(layout.stride[0], layout.stride[1]);
trans ^= 1;
}
};
auto&& tparam = megdnn_opr()->param();
MGB_TRY {
transpose(inp0.layout, tparam.transposeA);
transpose(inp1.layout, tparam.transposeB);
megdnn_opr()->execution_policy() =
m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
megdnn_opr()->exec(
inp0, inp1, out, intl::get_megdnn_workspace_from_var(output(1)));
}
MGB_FINALLY({ tparam = this->param(); });
}
#if MGB_ENABLE_GRAD
//! Gradient of MatrixMul with respect to input \p wrt_idx (0 or 1).
//! Only float inputs are supported. The gradient is itself a MatrixMul of the
//! output gradient with the other input; operand order and transpose flags
//! depend on the forward transpose flags.
MGB_IMPL_OPR_GRAD(MatrixMul) {
    mgb_assert(
            opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
            "only float data type supported for grad");
    SymbolVar i0{opr.input(0)};
    SymbolVar i1{opr.input(1)};
    SymbolVar og{out_grad[0]};
    const bool trans_a = opr.param().transposeA;
    const bool trans_b = opr.param().transposeB;

    if (wrt_idx == 0) {
        // gradient w.r.t. the first operand
        if (trans_a) {
            return MatrixMul::make(i1, og, {trans_b, true}).node();
        }
        return MatrixMul::make(og, i1, {false, !trans_b}).node();
    }

    mgb_assert(wrt_idx == 1);
    // gradient w.r.t. the second operand
    if (trans_b) {
        return MatrixMul::make(og, i0, {true, trans_a}).node();
    }
    return MatrixMul::make(i0, og, {!trans_a, false}).node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchedMatrixMul);

//! Construct a BatchedMatrixMul node in the graph owning \p a, with inputs
//! (\p a, \p b). The megdnn operator is initialised from \p param and the
//! execution policy is remembered in m_policy.
BatchedMatrixMul::BatchedMatrixMul(
        VarNode* a, VarNode* b, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config)
        : Super{a->owner_graph(), config, "batched_matrix_mul", {a, b}} {
    init_megdnn_opr(*this, param);
    m_policy = policy;
    add_input({a, b});

    // the result var is allowed to carry an empty shape
    auto result_var = output(0);
    result_var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}
//! Insert a BatchedMatrixMul operator over \p a and \p b and return its
//! single output var.
SymbolVar BatchedMatrixMul::make(
        SymbolVar a, SymbolVar b, const Param& param, const ExecutionPolicy& policy,
        const OperatorNodeConfig& config) {
    VarNode* const lhs_node = a.node();
    VarNode* const rhs_node = b.node();
    return a.insert_single_output_opr<BatchedMatrixMul>(
            lhs_node, rhs_node, param, policy, config);
}
//! Require each 3-dim input layout to (1) have a batch stride no smaller than
//! shape*stride of either matrix axis and (2) pass check_layout with or
//! without transposition.
void BatchedMatrixMul::add_input_layout_constraint() {
    auto is_acceptable = [](const TensorLayout& ly) {
        mgb_assert(
                ly.ndim == 3, "input to BatchedMatrixMul must be 3-dim; got %s",
                ly.to_string().c_str());
        const auto batch_stride = ly.stride[0];
        const bool covers_axis1 =
                batch_stride >= static_cast<ptrdiff_t>(ly.shape[1] * ly.stride[1]);
        const bool covers_axis2 =
                batch_stride >= static_cast<ptrdiff_t>(ly.shape[2] * ly.stride[2]);
        if (!(covers_axis1 && covers_axis2)) {
            return false;
        }
        return check_layout(ly, true) || check_layout(ly, false);
    };
    input(0)->add_layout_constraint(is_acceptable);
    input(1)->add_layout_constraint(is_acceptable);
}
void BatchedMatrixMul::init_output_dtype() {
DType output_dtype = config().output_dtype();
megdnn_opr()->deduce_dtype(input(0)->dtype(), input(1)->dtype(), output_dtype);
output(0)->dtype(output_dtype);
}
//! Extend the base node property so that both inputs carry the
//! VALUE_ALLOW_EMPTY dependency type.
BatchedMatrixMul::NodeProp* BatchedMatrixMul::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    for (size_t idx = 0; idx < 2; ++idx) {
        prop->add_dep_type_existing_var(
                input(idx), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
//! Check the two matrix axes (1 and 2) of a batched layout.
//!
//! Without \p transpose axis 1 is the outer axis and axis 2 the inner one;
//! with \p transpose the roles are exchanged. The layout is accepted when the
//! inner axis has stride 1 and the outer stride is at least the inner extent.
bool BatchedMatrixMul::check_layout(const TensorLayout& layout, bool transpose) {
    const int outer_axis = transpose ? 2 : 1;
    const int inner_axis = 3 - outer_axis;
    if (layout.stride[inner_axis] != 1) {
        return false;
    }
    const auto min_outer_stride = static_cast<ptrdiff_t>(layout.shape[inner_axis]);
    return layout.stride[outer_axis] >= min_outer_stride;
}
// Compute the workspace size (in bytes) to reserve for this operator.
//
// The algorithm chooser is run once for each of the four combinations of the
// transposeA / transposeB flags. For every combination the execution policy
// left on the megdnn operator is stored in m_cadidate_execution_policies,
// indexed by get_mask_from_matmul(), and the returned workspace size is kept.
// The result is the maximum of the four sizes.
//
// Although the method is const, it writes to m_cadidate_execution_policies
// through const_cast and temporarily modifies the megdnn operator's param and
// execution policy; the param is reset from this->param() in MGB_FINALLY.
size_t BatchedMatrixMul::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
auto mo = megdnn_opr();
// tparam is what gets toggled by the local `transpose` helper below
auto&& tparam = mo->param();
// workspace sizes of the four flag combinations, in the order visited
size_t a, b, c, d;
mgb_assert(input_shapes.size() == 2 && output_shapes.size() == 1);
TensorLayout i0(input_shapes[0], input(0)->dtype()),
i1(input_shapes[1], input(1)->dtype()),
out(output_shapes[0], output(0)->dtype());
// Swap shape entries 1 and 2 of dst (axis 0 is untouched), set stride[1] from
// dst[2], and flip the matching transpose flag.
// NOTE(review): dst[2] presumably reads the (already swapped) shape[2] via
// TensorLayout::operator[] — confirm.
auto transpose = [](TensorLayout& dst, bool& param) {
std::swap(dst.shape[1], dst.shape[2]);
dst.stride[1] = dst[2];
param ^= 1;
};
MGB_TRY {
// combination 1: flags exactly as configured
megdnn_opr()->execution_policy() = {};
a = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
// combination 2: transposeA flipped
transpose(i0, tparam.transposeA);
b = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
// combination 3: transposeA and transposeB both flipped
transpose(i1, tparam.transposeB);
c = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
// combination 4: transposeA flipped back, only transposeB flipped
transpose(i0, tparam.transposeA);
d = AlgoChooser<megdnn::BatchedMatrixMul>::setup_algo(
{i0, i1, out}, megdnn_opr(), this);
const_cast<BatchedMatrixMul*>(this)
->m_cadidate_execution_policies[get_mask_from_matmul(tparam)] =
megdnn_opr()->execution_policy();
megdnn_opr()->execution_policy() = {};
}
// undo the flag flips by resetting tparam from this operator's own param
MGB_FINALLY({ tparam = this->param(); });
return std::max(std::max(a, b), std::max(c, d));
}
void BatchedMatrixMul::scn_do_execute() {
auto inp0 = input(0)->dev_tensor().as_megdnn(),
inp1 = input(1)->dev_tensor().as_megdnn(),
out = output(0)->dev_tensor().as_megdnn();
if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
if (!out.layout.is_empty()) {
if (!m_fill_opr) {
m_fill_opr = intl::get_megdnn_handle(comp_node())
->create_operator<megdnn::Fill>();
}
m_fill_opr->param() = 0;
m_fill_opr->exec(out, {});
}
return;
}
auto transpose = [](TensorLayout& layout, bool& trans) {
if (!check_layout(layout, false)) {
mgb_assert(check_layout(layout, true));
std::swap(layout.shape[1], layout.shape[2]);
std::swap(layout.stride[1], layout.stride[2]);
mgb_assert(layout.stride[2] == 1);
trans ^= 1;
}
};
auto&& tparam = megdnn_opr()->param();
MGB_TRY {
transpose(inp0.layout, tparam.transposeA);
transpose(inp1.layout, tparam.transposeB);
megdnn_opr()->execution_policy() =
m_cadidate_execution_policies[get_mask_from_matmul(tparam)];
megdnn_opr()->exec(
inp0, inp1, out, intl::get_megdnn_workspace_from_var(output(1)));
}
MGB_FINALLY({ tparam = this->param(); });
}
#if MGB_ENABLE_GRAD
//! Gradient of BatchedMatrixMul with respect to input \p wrt_idx (0 or 1).
//! Only float inputs are supported, and out_grad must have exactly two
//! entries with the second one null. The gradient is a BatchedMatrixMul of
//! the output gradient with the other input; operand order and transpose
//! flags depend on the forward transpose flags.
MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
    mgb_assert(
            opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
            "only float data type supported for grad");
    mgb_assert(out_grad.size() == 2 && !out_grad[1]);
    SymbolVar i0{opr.input(0)};
    SymbolVar i1{opr.input(1)};
    SymbolVar og{out_grad[0]};
    const bool trans_a = opr.param().transposeA;
    const bool trans_b = opr.param().transposeB;

    if (wrt_idx == 0) {
        // gradient w.r.t. the first operand
        if (trans_a) {
            return BatchedMatrixMul::make(i1, og, {trans_b, true}).node();
        }
        return BatchedMatrixMul::make(og, i1, {false, !trans_b}).node();
    }

    mgb_assert(wrt_idx == 1);
    // gradient w.r.t. the second operand
    if (trans_b) {
        return BatchedMatrixMul::make(og, i0, {true, trans_a}).node();
    }
    return BatchedMatrixMul::make(i0, og, {!trans_a, false}).node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Dot);

//! Construct a Dot node over \p opr0 and \p opr1. Inputs are added with
//! AddInputSortType::CUR_ADDED, the output may be empty-shaped, and
//! quantized inputs are rejected by assertion.
Dot::Dot(VarNode* opr0, VarNode* opr1, const OperatorNodeConfig& config)
        : Super{opr0->owner_graph(), config, "dot", {opr0, opr1}} {
    // compile-time guarantee that passing `{}` as the param below is fine
    static_assert(std::is_empty<Param>::value, "Dot param should be empty");

    init_megdnn_opr(*this, {});
    add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);

    auto result_var = output(0);
    result_var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);

    mgb_assert(
            opr0->dtype().category() != DTypeCategory::QUANTIZED &&
                    opr1->dtype().category() != DTypeCategory::QUANTIZED,
            "Dot does not support quantized input.");
}
//! Register static shape inference for both outputs:
//!  - output(0) always has the constant shape {1};
//!  - output(1) (the workspace) is a 1-dim shape whose size is queried from
//!    the megdnn operator, using a 1-dim layout as long as the larger of the
//!    two inputs.
void Dot::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();

    auto infer_result_shape = [](TensorShape& dest, const InpVal&) {
        dest = {1};
        return true;
    };

    auto infer_workspace_shape = [this](TensorShape& dest, const InpVal& iv) {
        auto dtype = input(0)->dtype();
        const auto nr_elems_0 = iv.val[0].shape().total_nr_elems();
        const auto nr_elems_1 = iv.val[1].shape().total_nr_elems();
        TensorLayout ily({std::max(nr_elems_0, nr_elems_1)}, dtype);
        const auto workspace_bytes =
                megdnn_opr()->get_workspace_in_bytes(ily, ily, {{1}, dtype});
        dest.ndim = 1;
        dest.shape[0] = workspace_bytes;
        return true;
    };

    mgr.register_shape_infer(
            output(0), {SourceType::CONSTANT, {}, infer_result_shape});
    mgr.register_shape_infer(
            output(1), {SourceType::DEP,
                        {{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}},
                        infer_workspace_shape});
}
//! Execute the dot product.
//!
//! Both inputs must be 1-dim (GraphError otherwise). If their lengths
//! differ, one of them must have length 1; that operand is then viewed with
//! the other operand's length and stride 0. If either resulting layout is
//! empty the output is filled with 0, otherwise the megdnn operator runs.
void Dot::scn_do_execute() {
    auto i0 = input(0)->dev_tensor().as_megdnn();
    auto i1 = input(1)->dev_tensor().as_megdnn();
    mgb_throw_if(
            i0.layout.ndim != 1 || i1.layout.ndim != 1, GraphError,
            "Invalid input shapes for Dot: %s", cg::dump_var_info(input()).c_str());

    const bool same_length = i0.layout.shape[0] == i1.layout.shape[0];
    if (!same_length) {
        bool s0 = i0.layout.shape[0] == 1, s1 = i1.layout.shape[0] == 1;
        mgb_throw_if(
                !s0 && !s1, GraphError, "Invalid input shapes for Dot: %s",
                cg::dump_var_info(input()).c_str());
        // stretch the length-1 operand to the other operand's length
        auto& short_side = s0 ? i0.layout : i1.layout;
        const auto& long_side = s0 ? i1.layout : i0.layout;
        short_side.shape[0] = long_side.shape[0];
        short_side.stride[0] = 0;
    }

    auto dst = output(0)->dev_tensor().as_megdnn();
    const bool has_empty_input = i0.layout.is_empty() || i1.layout.is_empty();
    if (has_empty_input) {
        // lazily create the fill operator on first use
        if (!m_fill_opr) {
            auto handle = intl::get_megdnn_handle(comp_node());
            m_fill_opr = handle->create_operator<megdnn::Fill>();
        }
        m_fill_opr->param() = 0;
        m_fill_opr->exec(dst, {});
        return;
    }

    megdnn_opr()->exec(i0, i1, dst, intl::get_megdnn_workspace_from_var(output(1)));
}
//! Extend the base node property so that both inputs carry the
//! VALUE_ALLOW_EMPTY dependency type.
Dot::NodeProp* Dot::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    for (size_t idx = 0; idx < 2; ++idx) {
        prop->add_dep_type_existing_var(
                input(idx), NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}
//! Require both inputs to be 1-dim (GraphError otherwise) with a
//! non-negative stride.
void Dot::add_input_layout_constraint() {
    auto has_non_negative_stride = [](const TensorLayout& ly) {
        mgb_throw_if(
                ly.ndim != 1, GraphError, "Dot input must be 1-dim; got %s",
                ly.to_string().c_str());
        return !(ly.stride[0] < 0);
    };
    input(0)->add_layout_constraint(has_non_negative_stride);
    input(1)->add_layout_constraint(has_non_negative_stride);
}
#if MGB_ENABLE_GRAD
//! Gradient of Dot with respect to input \p wrt_idx: the output gradient
//! times the other input, broadcast to the joint shape of both inputs and
//! then sum-reduced to the shape of the input being differentiated.
MGB_IMPL_OPR_GRAD(Dot) {
    const size_t other_idx = (wrt_idx == 0) ? 1 : 0;
    auto other_input = opr.input(other_idx);

    auto shape_of_in0 = opr::GetVarShape::make(opr.input(0));
    auto shape_of_in1 = opr::GetVarShape::make(opr.input(1));
    auto joint_shape = opr::GetVarShape::make({opr.input(0), opr.input(1)});

    auto scaled = mul(out_grad[0], other_input);
    auto expanded = Broadcast::make(scaled, joint_shape);
    auto target_shape = wrt_idx ? shape_of_in1 : shape_of_in0;
    return reduce_sum(expanded, target_shape).node();
}
#endif
//! Insert a Dot operator over \p opr0 and \p opr1 and return its single
//! output var.
SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1, const OperatorNodeConfig& config) {
    VarNode* const lhs_node = opr0.node();
    VarNode* const rhs_node = opr1.node();
    return opr0.insert_single_output_opr<Dot>(lhs_node, rhs_node, config);
}
// Record this operator's execution dependencies by forwarding the
// caller-supplied array to record_megdnn_opr().
void Dot::record_execute_deps(ExecDependencyArray& deps) {
record_megdnn_opr(deps);
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")

#if MGB_ENABLE_GRAD
//! Gradient of MatrixInverse, expressed through the forward output `a`.
//! The output and the incoming gradient are reshaped to 3-dim tensors whose
//! last two extents are both the last extent of `a`; with a_bnn being the
//! reshaped output with its last two axes exchanged, the result is
//! -(a_bnn * dy * a_bnn), reshaped back to the shape of `a`.
MGB_IMPL_OPR_GRAD(MatrixInverse) {
    SymbolVar a = opr.output(0);

    // last extent of the output shape
    SymbolVar n = opr::Subtensor::make(
            a.symshape(),
            {opr::Subtensor::AxisIndexer::make_index(0, a.make_scalar(-1))});
    // target shape (0, n, n) used for both reshapes below
    SymbolVar tshp = opr::Concat::make({a.make_scalar(0), n, n}, 0);

    SymbolVar a_3d = opr::Reshape::make(a, tshp, 0);
    SymbolVar a_bnn = opr::Dimshuffle::make(a_3d, {0, 2, 1});
    SymbolVar dy = opr::Reshape::make(out_grad.at(0), tshp, 0);

    SymbolVar left_product = BatchedMatrixMul::make(a_bnn, dy);
    SymbolVar da = -BatchedMatrixMul::make(left_product, a_bnn);
    return da.reshape(a.symshape()).node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SVD);

//! Construct an SVD node over \p src.
//!
//! Throws MegDNNError unless \p src has dtype Float32. When
//! param.compute_uv is false, output(0) and output(2) are flagged as
//! ALLOW_EMPTY_SHAPE and VOLATILE_CONTENT.
SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config)
        : Super(OperatorNodeBaseCtorParam{src->owner_graph(), config, "svd", {src}}) {
    mgb_throw_if(
            src->dtype() != megdnn::dtype::Float32(), MegDNNError,
            "Singular Value Decomposition on non-float32 tensors is not "
            "supported.");
    init_megdnn_opr(*this, param);
    add_input({src});
    if (!param.compute_uv) {
        // outputs 0 and 2 are not computed in this mode
        output(0)
                ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
        output(2)
                ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
    }
}
#if MGB_ENABLE_GRAD
namespace {

/*!
 * \brief Nullable wrapper around a VarNode pointer.
 *
 * A default-constructed instance holds a null node. The forwarded member
 * functions and the free operators/helpers below propagate the null state
 * instead of dereferencing it, so gradient expressions can be written
 * without checking every operand for a missing gradient.
 */
class SafeSymbolVar {
    VarNode* m_node;

public:
    explicit SafeSymbolVar(VarNode* node) : m_node{node} {}
    //! implicit on purpose: lets SymbolVar results flow into SafeSymbolVar
    SafeSymbolVar(SymbolVar x) : m_node{x.node()} {}
    SafeSymbolVar() : m_node{nullptr} {}

    VarNode* node() const { return m_node; }
    //! view as a plain SymbolVar (wraps the stored pointer, which may be null)
    SymbolVar s() const { return m_node; }

//! Define a member that forwards to the SymbolVar method of the same name,
//! or returns a null SafeSymbolVar when no node is held.
#define FWD(name)                                                   \
    template <typename... Args>                                     \
    SafeSymbolVar name(Args&&... args) {                            \
        if (!m_node)                                                \
            return {};                                              \
        return SymbolVar{m_node}.name(std::forward<Args>(args)...); \
    }
    FWD(reshape)
    FWD(broadcast)
#undef FWD
};

//! Obtain the underlying SymbolVar from either wrapper type.
SymbolVar unsafe(SymbolVar x) {
    return x;
}
SymbolVar unsafe(SafeSymbolVar x) {
    return x.s();
}

//! Reshape \p x to \p tshp via opr::Reshape (third argument 0); a null \p x
//! is returned unchanged.
template <typename T>
T reshape_anybatch(T x, SymbolVar tshp) {
    if (!x.node())
        return x;
    return opr::Reshape::make(unsafe(x), tshp, 0);
}

//! Exchange the last two axes of a 3-dim var (pattern {0, 2, 1}); a null
//! \p x is returned unchanged.
template <typename T>
T trans(T x) {
    if (!x.node())
        return x;
    return opr::Dimshuffle::make(unsafe(x), {0, 2, 1});
}

//! Batched matrix product of \p a and \p b; yields a default-constructed T
//! when either operand is null.
template <typename T>
T matmul(T a, T b, const opr::BatchedMatrixMul::Param& param = {}) {
    if (!a.node() || !b.node())
        return {};
    return opr::BatchedMatrixMul::make(unsafe(a), unsafe(b), param);
}

//! Non-template overload of matmul for SafeSymbolVar, so that arguments of
//! mixed SymbolVar / SafeSymbolVar type convert implicitly.
SafeSymbolVar matmuls(
        SafeSymbolVar x, SafeSymbolVar y,
        const opr::BatchedMatrixMul::Param& param = {}) {
    return matmul(x, y, param);
}

//! Unary minus; null stays null.
SafeSymbolVar operator-(SafeSymbolVar x) {
    if (x.node())
        return -x.s();
    return {};
}

//! Define a binary operator on SafeSymbolVar: a_ is returned when the left
//! operand is null, b_ when the right operand is null, otherwise the
//! operator is applied to the underlying SymbolVars.
#define OP(x, a_, b_)                                            \
    SafeSymbolVar operator x(SafeSymbolVar a, SafeSymbolVar b) { \
        if (!a.node())                                           \
            return a_;                                           \
        if (!b.node())                                           \
            return b_;                                           \
        return a.s() x b.s();                                    \
    }
// null + b == b, a + null == a
OP(+, b, a)
// null - b == -b, a - null == a
OP(-, -b, a)
// a product with a null operand is null
OP(*, {}, {})
#undef OP

}  // namespace
#endif
#if MGB_ENABLE_GRAD
// Gradient of SVD with respect to its input.
//
// Requires compute_uv to be set (throws MegBrainError otherwise), because the
// formula below uses the U and V outputs. The three incoming output gradients
// are wrapped in SafeSymbolVar so that a null gradient propagates as "absent"
// through the SafeSymbolVar helpers defined earlier in this file.
MGB_IMPL_OPR_GRAD(SVD) {
mgb_throw_if(
!opr.param().compute_uv, MegBrainError,
"Singular value decomposition gradient computation depends "
"on U and V, please set compute_uv = True");
// forward input and the three forward outputs
SymbolVar a{opr.input(0)}, u_raw{opr.output(0)}, s_raw{opr.output(1)},
vt_raw{opr.output(2)};
// incoming gradients; any of them may hold a null node
SafeSymbolVar grad_u_raw{out_grad[0]}, grad_s_raw{out_grad[1]},
grad_vt_raw{out_grad[2]};
// BatchedMatrixMul params; the digits give the (transposeA, transposeB) flags
auto param10 = BatchedMatrixMul::Param{true, false},
param00 = BatchedMatrixMul::Param{false, false},
param01 = BatchedMatrixMul::Param{false, true};
// n / m: entries -1 / -2 of a's shape; r: entry -1 of s_raw's shape
auto n = opr::Subtensor::make(
a.symshape(),
{opr::Subtensor::AxisIndexer::make_index(0, a.make_scalar(-1))}),
m = opr::Subtensor::make(
a.symshape(),
{opr::Subtensor::AxisIndexer::make_index(0, a.make_scalar(-2))}),
r = opr::Subtensor::make(
s_raw.symshape(),
{opr::Subtensor::AxisIndexer::make_index(0, s_raw.make_scalar(-1))});
// Target shapes with a leading 0 entry, used by reshape_anybatch.
// NOTE(review): the leading 0 presumably marks the batch axis to be inferred
// by Reshape — confirm against opr::Reshape.
SymbolVar sshp = opr::Concat::make({a.make_scalar(0), r}, 0),
ushp = opr::Concat::make({a.make_scalar(0), m, r}, 0),
vtshp = opr::Concat::make({a.make_scalar(0), r, n}, 0),
u = reshape_anybatch(u_raw, ushp), vt = reshape_anybatch(vt_raw, vtshp),
v = trans(vt);
SafeSymbolVar grad_u = reshape_anybatch(grad_u_raw, ushp),
grad_vt = reshape_anybatch(grad_vt_raw, vtshp),
grad_v = trans(grad_vt);
// entry -3 of the reshaped u's shape, i.e. its leading extent
auto batches = opr::Subtensor::make(
u.symshape(),
{opr::Subtensor::AxisIndexer::make_index(0, u.make_scalar(-3))});
auto brr = opr::Concat::make({batches, r, r}, 0);
// I_r: r x r identity broadcast to (batches, r, r);
// filter_matrix: its complement (1 off the diagonal, 0 on it)
auto I_r = opr::Eye::make(r, {0, DTypeEnum::Float32})
.reshape(opr::Concat::make({a.make_scalar(1), r, r}, 0))
.broadcast(brr),
filter_matrix = 1 - I_r;
// sf / grad_sf: s and its gradient reshaped to (batches, r, 1) and broadcast
// to (batches, r, r)
auto sf = reshape_anybatch(s_raw, sshp)
.reshape(opr::Concat::make({batches, r, a.make_scalar(1)}, 0))
.broadcast(brr);
auto grad_sf =
reshape_anybatch(grad_s_raw, sshp)
.reshape(opr::Concat::make({batches, r, a.make_scalar(1)}, 0))
.broadcast(brr);
// masking with I_r keeps only the diagonal entries
auto s = I_r * sf;
auto grad_s = I_r * grad_sf;
// adding filter_matrix puts 1 in every off-diagonal slot before the division
// and subtracting it afterwards zeroes those slots again, so s_inv holds
// 1/s on the diagonal and 0 elsewhere
auto s_inv = 1 / (s + filter_matrix) - filter_matrix;
// s_mid is the difference of squared singular values (zero on the diagonal);
// adding I_r avoids dividing by that zero, and the filter_matrix numerator
// makes f zero on the diagonal
auto s_rhs = sf * sf, s_mid = trans(s_rhs) - s_rhs, s_avoid_nan = s_mid + I_r,
f = filter_matrix / s_avoid_nan;
// m x m and n x n identities broadcast over the batch axis
auto I_m = opr::Eye::make(m, {0, DTypeEnum::Float32})
.reshape(opr::Concat::make({a.make_scalar(1), m, m}, 0))
.broadcast(opr::Concat::make({batches, m, m}, 0)),
I_n = opr::Eye::make(n, {0, DTypeEnum::Float32})
.reshape(opr::Concat::make({a.make_scalar(1), n, n}, 0))
.broadcast(opr::Concat::make({batches, n, n}, 0));
// u^T * grad_u and v^T * grad_v (param10 transposes the first operand)
auto ut_du = matmuls(u, grad_u, param10), vt_dv = matmuls(v, grad_v, param10);
// sum of three contributions: the first involves grad_u, the second grad_s
// and the third grad_v; a contribution built from a null gradient stays null
// and is dropped by the SafeSymbolVar operators
auto ret = matmuls(matmuls(matmuls(u, f * (ut_du - trans(ut_du))), s, param00) +
matmuls(matmuls(I_m - matmul(u, u, param01), grad_u),
s_inv),
v, param01) +
matmuls(matmuls(u, I_r * grad_s), v, param01) +
matmuls(u, matmuls(matmuls(s, f * (vt_dv - trans(vt_dv)), param00), v,
param01) +
matmuls(matmuls(s_inv, grad_v, param01),
I_n - matmul(v, v, param01)));
// restore the original shape of the input
return ret.reshape(a.symshape()).node();
}
#endif
//! Insert an SVD operator over \p src into its owner graph and return all of
//! the operator's output vars, in order, as SymbolVars.
SymbolVarArray SVD::make(
        const SymbolVar& src, const Param& param, const OperatorNodeConfig& config) {
    auto&& outputs =
            src.node()
                    ->owner_graph()
                    ->insert_opr(std::make_unique<SVD>(src.node(), param, config))
                    ->output();
    const size_t nr_output = outputs.size();
    SymbolVarArray result(nr_output);
    for (size_t idx = 0; idx != nr_output; ++idx) {
        result[idx] = outputs[idx];
    }
    return result;
}