#include "megbrain/test/helper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/loop.h"
#include <cmath>
using namespace mgb;
using LoopDesc = opr::Loop::Desc;
using OutputMode = opr::Loop::Desc::OutputMode;
namespace {
class TestOprLoopElemwiseGrad : public ::testing::Test {
protected:
std::shared_ptr<ComputingGraph> graph;
std::shared_ptr<HostTensorND> host_x, host_loss_p;
SymbolVar x;
opr::Loop::DescMaker desc_maker;
void SetUp() override {
constexpr size_t SIZE = 23;
graph = ComputingGraph::make();
HostTensorGenerator<> gen;
host_x = gen({SIZE});
host_loss_p = gen({SIZE});
#if 0#endif
x = opr::Host2DeviceCopy::make(*graph, host_x).rename("x");
}
void check(thin_function<float(float)> grad_raw) {
auto y = opr::Loop::make(desc_maker).at(0).rename("y");
auto grad = cg::grad(
opr::Dot::make(y, opr::Host2DeviceCopy::make(*graph, host_loss_p))
.rename("loss"),
x);
HostTensorND host_grad;
auto func = graph->compile({make_callback_copy(grad, host_grad)});
func->execute();
ASSERT_EQ(host_x->shape(), host_grad.shape());
auto px = host_x->ptr<float>(), pg = host_grad.ptr<float>(),
lp = host_loss_p->ptr<float>();
for (size_t i = 0; i < host_x->shape().total_nr_elems(); i++) {
MGB_ASSERT_FLOAT_EQ(grad_raw(px[i]) * lp[i], pg[i])
<< ssprintf("failed at %zd: x=%.6f lp=%.6f", i, px[i], lp[i]);
}
}
};
}
TEST_F(TestOprLoopElemwiseGrad, Identity) {
desc_maker = [this](LoopDesc& desc) {
auto x = desc.add_input(this->x);
desc.set_loop_condition(desc.get_counter_var() < 0);
desc.add_output(x, OutputMode::LAST);
};
check([](float) { return 1.f; });
}
TEST_F(TestOprLoopElemwiseGrad, UpdateWithSimpleSum) {
constexpr float N = 4;
desc_maker = [this](LoopDesc& desc) {
auto x = desc.add_input_assignable(this->x), x0 = desc.add_input(this->x);
desc.set_loop_condition(desc.get_counter_var() < N - 1);
desc.assign(x, x + x0);
desc.add_output(x, OutputMode::LAST);
};
check([](float) { return N; });
}
TEST_F(TestOprLoopElemwiseGrad, UpdateWithSimpleExp) {
constexpr float N = 7;
desc_maker = [this](LoopDesc& desc) {
auto x = desc.add_input_assignable(this->x).rename("x"),
x0 = desc.add_input(this->x).rename("x0");
desc.set_loop_condition(desc.get_counter_var() < N - 1);
desc.assign(x, (x * x0).rename("xmul"));
desc.add_output(x, OutputMode::LAST);
};
check([](float x) -> float { return N * pow(x, N - 1); });
}
TEST_F(TestOprLoopElemwiseGrad, UpdateWithSimpleExp2) {
constexpr float N = 8;
desc_maker = [this](LoopDesc& desc) {
auto x = desc.add_input_assignable(this->x.fill_retain_dtype(1)).rename("x"),
x0 = desc.add_input(this->x).rename("x0"), xmul = (x * x0).rename("xmul");
desc.set_loop_condition(desc.get_counter_var() < N - 1);
desc.assign(x, xmul);
desc.add_output(xmul, OutputMode::LAST);
};
check([](float x) -> float { return N * pow(x, N - 1); });
}
TEST_F(TestOprLoopElemwiseGrad, InvolveCounterVar) {
constexpr float N = 8;
desc_maker = [this](LoopDesc& desc) {
auto x = desc.add_input_assignable(this->x.fill_retain_dtype(1)).rename("x"),
x0 = desc.add_input(this->x).rename("x0"), xmul = (x * x0).rename("xmul");
desc.set_loop_condition(desc.get_counter_var() < N - 1);
desc.assign(x, xmul);
desc.add_output(xmul * (desc.get_counter_var() + 1), OutputMode::SUM);
};
auto grad = [](float x) {
float y = 0;
for (int i = 1; i <= N; i++)
y += i * i * pow(x, i - 1);
return y;
};
check(grad);
}