#include "test/common/tensor.h"
#include "test/common/random_state.h"

#include <cstdio>
#include <cstring>
#include <memory>
#include <random>
using namespace megdnn;
// Fill every element in the layout span of `tensor` (host side) with samples
// drawn from a normal distribution N(mean, stddev), using the shared test
// random engine returned by RandomState::generator().
void test::init_gaussian(
        SyncedTensor<dt_float32>& tensor, dt_float32 mean, dt_float32 stddev) {
    auto dst = tensor.ptr_mutable_host();
    const auto count = tensor.layout().span().dist_elem();
    auto&& engine = RandomState::generator();
    std::normal_distribution<dt_float32> normal{mean, stddev};
    // Walk the host buffer front to back, one sample per element.
    const auto stop = dst + count;
    while (dst != stop) {
        *dst = normal(engine);
        ++dst;
    }
}
// Copy a host tensor into a freshly allocated device buffer on `handle`.
//
// The whole byte span of the layout is copied, starting at
// htensor.raw_ptr() + span.low_byte, so the tensor keeps its original layout
// (strides included). The returned tensor owns the device buffer; it is
// released with megdnn_free when the last shared_ptr reference is dropped.
std::shared_ptr<TensorND> test::make_tensor_h2d(
        Handle* handle, const TensorND& htensor) {
    auto span = htensor.layout.span();
    uint8_t* mptr = static_cast<uint8_t*>(megdnn_malloc(handle, span.dist_byte()));
    auto deleter = [handle, mptr](TensorND* p) {
        megdnn_free(handle, mptr);
        delete p;
    };
    // mptr[0] receives the byte located at htensor.raw_ptr() + span.low_byte,
    // so the device tensor's base pointer is mptr - span.low_byte. With that
    // base, every element offset of the layout maps into [mptr, mptr + dist).
    // Ownership is established before the copy so a failing copy cannot leak
    // the device buffer.
    std::shared_ptr<TensorND> ret{
            new TensorND{mptr - span.low_byte, htensor.layout}, deleter};
    megdnn_memcpy_H2D(
            handle, mptr, static_cast<uint8_t*>(htensor.raw_ptr()) + span.low_byte,
            span.dist_byte());
    return ret;
}
// Copy a device tensor (allocated on `handle`) into a new host buffer.
//
// The whole byte span of the layout is copied, starting at
// dtensor.raw_ptr() + span.low_byte, so the host tensor keeps the original
// layout (strides included). The returned tensor owns the host buffer; it is
// freed with delete[] when the last shared_ptr reference is dropped.
std::shared_ptr<TensorND> test::make_tensor_d2h(
        Handle* handle, const TensorND& dtensor) {
    auto span = dtensor.layout.span();
    auto mptr = new uint8_t[span.dist_byte()];
    auto deleter = [mptr](TensorND* p) {
        delete[] mptr;
        delete p;
    };
    // mptr[0] receives the byte located at dtensor.raw_ptr() + span.low_byte,
    // so the host tensor's base pointer is mptr - span.low_byte. The buffer
    // is owned before the copy so a failing copy cannot leak it.
    std::shared_ptr<TensorND> ret{
            new TensorND{mptr - span.low_byte, dtensor.layout}, deleter};
    megdnn_memcpy_D2H(
            handle, mptr, static_cast<uint8_t*>(dtensor.raw_ptr()) + span.low_byte,
            span.dist_byte());
    return ret;
}
std::vector<std::shared_ptr<TensorND>> test::load_tensors(const char* fpath) {
FILE* fin = fopen(fpath, "rb");
megdnn_assert(fin);
std::vector<std::shared_ptr<TensorND>> ret;
for (;;) {
char dtype[128];
size_t ndim;
if (fscanf(fin, "%s %zu", dtype, &ndim) != 2)
break;
TensorLayout layout;
do {
#define cb(_dt) \
if (!strcmp(DTypeTrait<dtype::_dt>::name, dtype)) { \
layout.dtype = dtype::_dt(); \
break; \
}
MEGDNN_FOREACH_DTYPE_NAME(cb)
#undef cb
char msg[256];
sprintf(msg, "bad dtype on #%zu input: %s", ret.size(), dtype);
ErrorHandler::on_megdnn_error(msg);
} while (0);
layout.ndim = ndim;
for (size_t i = 0; i < ndim; ++i) {
auto nr = fscanf(fin, "%zu", &layout.shape[i]);
megdnn_assert(nr == 1);
}
auto ch = fgetc(fin);
megdnn_assert(ch == '\n');
layout.init_contiguous_stride();
auto size = layout.span().dist_byte();
auto mptr = new uint8_t[size];
auto nr = fread(mptr, 1, size, fin);
auto deleter = [mptr](TensorND* p) {
delete[] mptr;
delete p;
};
ret.emplace_back(new TensorND{mptr, layout}, deleter);
megdnn_assert(nr == size);
}
fclose(fin);
return ret;
}