#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
#include "common.h"
#include "megdnn/dtype.h"
using namespace lite;
using namespace mgb;
//! Storage scheme of one serialized tensor value. The loader in this file
//! reads a single leading byte from the input and interprets it as one of
//! these values to decide how the bytes that follow are decoded.
enum class CompressionMethod {
//! value is stored as-is; decoding is delegated to the default tensor value
//! loader
NO_COMPRESSION = 0,
//! a float32 stride and a float32 base followed by one uint8 code per
//! element; each element is restored as stride * code + base
FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS = 1,
//! same scheme as above with one uint16 code per element
FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS = 2,
};
namespace {

//! Decode one "float32 stride + float32 base + integer codes" payload.
//!
//! \tparam CodeT integer type of the stored codes (uint8_t or uint16_t)
//! \param dst destination buffer; when null the payload is only skipped
//! \param layout layout of the tensor being loaded, used to validate the
//!        destination element type
//! \param num_weights number of elements described by \p layout
//! \param fin input positioned right after the compression flag byte
//!
//! Throws (via LITE_THROW) when \p dst is given but the tensor is not float32.
template <typename CodeT>
void load_stride_base_payload(
        void* dst, const mgb::TensorLayout& layout, size_t num_weights,
        mgb::serialization::InputFile& fin) {
    if (!dst) {
        // Nothing to fill in: just move past stride, base and all codes.
        fin.skip(sizeof(float) * 2 + num_weights * sizeof(CodeT));
        return;
    }
    // The loop below writes num_weights floats through dst. The compression
    // method comes from the file, so refuse to decode into a tensor whose
    // elements are not float32 instead of writing past its storage.
    if (layout.dtype != mgb::dtype::Float32()) {
        LITE_THROW("compressed tensor value requires a float32 tensor");
    }
    float stride, base;
    std::vector<CodeT> codes(num_weights);
    fin.read(&stride, sizeof(stride));
    fin.read(&base, sizeof(base));
    fin.read(codes.data(), num_weights * sizeof(CodeT));
    auto* out = static_cast<float*>(dst);
    for (size_t i = 0; i < num_weights; ++i)
        out[i] = stride * codes[i] + base;
}

}  // namespace

//! Tensor value loader that understands the CompressionMethod prefix.
//!
//! Reads one flag byte from \p fin, then either forwards to the default
//! tensor value loader (no compression) or expands the stride/base/code
//! payload into float32 values written to \p ptr_.
//!
//! \param ptr_ destination buffer, or null to consume the value from \p fin
//!        without storing it
//! \param layout layout of the tensor being loaded
//! \param fin serialized input to read from
//!
//! Throws (via LITE_THROW) on an unknown flag byte, or when a compressed value
//! would have to be written into a non-float32 tensor.
void lite::decompressed_tensor_value_loader(
        void* ptr_, const mgb::TensorLayout& layout,
        mgb::serialization::InputFile& fin) {
    uint8_t compress_flag;
    fin.read(&compress_flag, sizeof(compress_flag));
    size_t num_weights = layout.total_nr_elems();
    switch (CompressionMethod(compress_flag)) {
        case CompressionMethod::NO_COMPRESSION: {
            mgb::serialization::GraphLoadConfig::default_tensor_value_loader(
                    ptr_, layout, fin);
            break;
        }
        case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT8_WEIGHTS: {
            load_stride_base_payload<uint8_t>(ptr_, layout, num_weights, fin);
            break;
        }
        case CompressionMethod::FLOAT32_STRIDE_FLOAT32_BASE_UINT16_WEIGHTS: {
            load_stride_base_payload<uint16_t>(ptr_, layout, num_weights, fin);
            break;
        }
        default:
            LITE_THROW("Unexpected compression method");
    }
}
//! Convert a lite Layout into the MegEngine tensor layout used internally.
//!
//! The shape is copied dimension by dimension, strides are initialised as
//! contiguous, and the lite data type is mapped to the matching MegEngine
//! dtype.
//!
//! \param layout lite-side layout to convert
//! \return the equivalent MegEngine layout
//!
//! Throws (via LITE_ASSERT / LITE_THROW) when the rank exceeds what the
//! MegEngine shape can hold or when the data type has no mapping here.
LTensorLayout lite::to_impl_layout(const Layout& layout) {
    // Check the rank before it is stored or used as a loop bound. shape[] is
    // indexed with i < ndim below, so a rank equal to MAX_NDIM is the largest
    // one that still fits (megdnn declares shape[] with MAX_NDIM entries);
    // the previous strict '<' rejected that valid case.
    LITE_ASSERT(layout.ndim <= TensorShape::MAX_NDIM, "lite layout ndim is to large");
    mgb::TensorLayout mge_layout;
    mge_layout.ndim = layout.ndim;
    for (size_t i = 0; i < layout.ndim; i++) {
        mge_layout.shape[i] = layout.shapes[i];
    }
    mge_layout.init_contiguous_stride();
    switch (layout.data_type) {
        case LiteDataType::LITE_FLOAT:
            mge_layout.dtype = mgb::dtype::Float32();
            break;
#if !MEGDNN_DISABLE_FLOAT16
        case LiteDataType::LITE_HALF:
            mge_layout.dtype = mgb::dtype::Float16();
            break;
#endif
        case LiteDataType::LITE_INT:
            mge_layout.dtype = mgb::dtype::Int32();
            break;
        case LiteDataType::LITE_INT8:
            mge_layout.dtype = mgb::dtype::Int8();
            break;
        case LiteDataType::LITE_UINT8:
            mge_layout.dtype = mgb::dtype::Uint8();
            break;
        case LiteDataType::LITE_INT16:
            mge_layout.dtype = mgb::dtype::Int16();
            break;
        case LiteDataType::LITE_UINT16:
            mge_layout.dtype = mgb::dtype::Uint16();
            break;
        default:
            LITE_THROW(mgb::ssprintf(
                    "unsupport dtype in lite enum id is %d.",
                    static_cast<int>(layout.data_type)));
    }
    return mge_layout;
}
//! Convert a MegEngine tensor layout into the lite Layout exposed to users.
//!
//! A layout whose dtype is not valid yields a default-constructed Layout.
//! Otherwise the rank and shape are copied and the MegEngine dtype is mapped
//! to the matching lite data type. Strides are not carried over.
//!
//! \param mge_layout MegEngine layout to convert
//! \return the equivalent lite layout
//!
//! Throws (via LITE_ASSERT / LITE_THROW) when the rank exceeds what the lite
//! Layout can hold or when the dtype has no lite counterpart here.
Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
    Layout layout;
    if (!mge_layout.dtype.valid()) {
        return layout;
    }
    // Check the rank before it is stored or used as a loop bound. shapes[] is
    // indexed with i < ndim below, so a rank equal to MAXDIM is the largest
    // one that still fits (lite declares shapes[] with MAXDIM entries); the
    // previous strict '<' rejected that valid case.
    LITE_ASSERT(mge_layout.ndim <= layout.MAXDIM, "tensor layout ndim is to large");
    layout.ndim = mge_layout.ndim;
    for (size_t i = 0; i < layout.ndim; i++) {
        layout.shapes[i] = mge_layout.shape[i];
    }
    switch (mge_layout.dtype.enumv()) {
        case mgb::DTypeEnum::Float32:
            layout.data_type = LiteDataType::LITE_FLOAT;
            break;
#if !MEGDNN_DISABLE_FLOAT16
        case mgb::DTypeEnum::Float16:
            layout.data_type = LiteDataType::LITE_HALF;
            break;
#endif
        case mgb::DTypeEnum::Int32:
            layout.data_type = LiteDataType::LITE_INT;
            break;
        case mgb::DTypeEnum::Int16:
            layout.data_type = LiteDataType::LITE_INT16;
            break;
        case mgb::DTypeEnum::Uint16:
            layout.data_type = LiteDataType::LITE_UINT16;
            break;
        case mgb::DTypeEnum::Int8:
            layout.data_type = LiteDataType::LITE_INT8;
            break;
        case mgb::DTypeEnum::Uint8:
            layout.data_type = LiteDataType::LITE_UINT8;
            break;
        default:
            LITE_THROW(mgb::ssprintf(
                    "unsupport dtype in lite : %s.", mge_layout.to_string().c_str()));
    }
    return layout;
}
//! Build a MegEngine comp node locator for a lite device type.
//!
//! Only the locator's device type is set here; every other field keeps the
//! value given by the Locator default constructor.
//!
//! \param device lite device type to translate
//! \return locator whose type corresponds to \p device
//!
//! Throws (via LITE_THROW) for a device type without a mapping here.
mgb::CompNode::Locator lite::to_compnode_locator(const LiteDeviceType& device) {
    using DeviceType = mgb::CompNode::DeviceType;
    mgb::CompNode::Locator locator;
    switch (device) {
        case LiteDeviceType::LITE_DEVICE_DEFAULT:
            locator.type = DeviceType::UNSPEC;
            break;
        case LiteDeviceType::LITE_CPU:
            locator.type = DeviceType::CPU;
            break;
        case LiteDeviceType::LITE_CUDA:
            locator.type = DeviceType::CUDA;
            break;
        case LiteDeviceType::LITE_ATLAS:
            locator.type = DeviceType::ATLAS;
            break;
        case LiteDeviceType::LITE_CAMBRICON:
            locator.type = DeviceType::CAMBRICON;
            break;
        default: {
            const int raw_value = static_cast<int>(device);
            LITE_THROW(ssprintf(
                    "lite unsupported compnode type: enum value: %d.", raw_value));
        }
    }
    return locator;
}
//! Translate the device type of a MegEngine comp node locator to the lite
//! device type.
//!
//! Both CPU and MULTITHREAD locators are reported as LITE_CPU.
//!
//! \param locator locator whose type field is inspected
//! \return the corresponding lite device type
//!
//! Throws (via LITE_THROW) for a locator type without a mapping here.
LiteDeviceType lite::get_device_from_locator(const mgb::CompNode::Locator& locator) {
    using DeviceType = mgb::CompNode::DeviceType;
    // Overwritten by every handled case; unhandled types throw instead.
    LiteDeviceType lite_device = LiteDeviceType::LITE_DEVICE_DEFAULT;
    switch (locator.type) {
        case DeviceType::UNSPEC:
            lite_device = LiteDeviceType::LITE_DEVICE_DEFAULT;
            break;
        case DeviceType::MULTITHREAD:
        case DeviceType::CPU:
            lite_device = LiteDeviceType::LITE_CPU;
            break;
        case DeviceType::CUDA:
            lite_device = LiteDeviceType::LITE_CUDA;
            break;
        case DeviceType::ATLAS:
            lite_device = LiteDeviceType::LITE_ATLAS;
            break;
        case DeviceType::CAMBRICON:
            lite_device = LiteDeviceType::LITE_CAMBRICON;
            break;
        default: {
            const int raw_value = static_cast<int>(locator.type);
            LITE_THROW(ssprintf(
                    "lite unsupported compnode type: enum value: %d.", raw_value));
        }
    }
    return lite_device;
}
#endif