#include "model_parser.h"
#include "decryption/decrypt_base.h"
#include "parse_info/parse_info_base.h"
using namespace lite;
using namespace model_parse;
std::string ModelParser::sm_model_tag = "packed_model";
void ModelParser::parse_header() {
size_t tag_length = sm_model_tag.size();
const char* ptr = static_cast<char*>(m_model.get());
std::string tag(static_cast<const char*>(ptr), tag_length);
if (sm_model_tag == tag) {
m_is_bare_model = false;
} else {
m_is_bare_model = true;
return;
}
uint8_t* buffer = static_cast<uint8_t*>(m_model.get()) + tag_length;
auto packed_model = GetPackModel(buffer);
auto models = packed_model->models();
LITE_ASSERT(models->size() == 1, "Now only support one model");
auto model = models->Get(0);
m_model_name = model->header()->name()->c_str();
m_model_decryption_name = model->header()->model_decryption_method()->c_str();
m_info_decryption_name = model->header()->info_decryption_method()->c_str();
m_info_parse_func_name = model->header()->info_parse_method()->c_str();
m_info = model->info();
m_model_data = model->data();
}
bool ModelParser::parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
std::string& extra_info) const {
if (m_is_bare_model || !m_info) {
return false;
}
size_t info_length = m_info->data()->size();
const uint8_t* info_data = m_info->data()->Data();
auto info_ptr =
decrypt_memory(info_data, info_length, m_info_decryption_name, info_length);
LITE_LOCK_GUARD(parse_info_static_data().map_mutex);
auto it_parse =
parse_info_static_data().parse_info_methods.find(m_info_parse_func_name);
if (it_parse == parse_info_static_data().parse_info_methods.end()) {
LITE_THROW(ssprintf(
"can't find model info parse function %s.",
m_info_parse_func_name.c_str()));
}
auto model_info_parse_func =
parse_info_static_data().parse_info_methods[m_info_parse_func_name];
if (model_info_parse_func) {
model_info_parse_func(
info_ptr.get(), info_length, m_model_name, network_config, network_io,
isolated_config_map, extra_info);
} else {
LITE_THROW(ssprintf(
"model info parse function of %s is empty",
m_info_parse_func_name.c_str()));
}
return true;
}
std::shared_ptr<void> ModelParser::parse_model(
size_t& model_length, const Config& config) const {
if (m_is_bare_model) {
if (config.bare_model_cryption_name.size() == 0) {
model_length = m_total_length;
return m_model;
} else {
return decrypt_memory(
static_cast<uint8_t*>(m_model.get()), m_total_length,
config.bare_model_cryption_name, model_length);
}
}
LITE_ASSERT(m_model_data, "packed model parse error!");
model_length = m_model_data->data()->size();
const uint8_t* model_data = m_model_data->data()->Data();
LITE_ASSERT(model_length > 0, "The loaded model is of zero length.");
return decrypt_memory(
model_data, model_length, m_model_decryption_name, model_length);
}
std::shared_ptr<void> ModelParser::decrypt_memory(
const uint8_t* data, size_t length, const std::string decryption_name,
size_t& result_length) const {
const uint8_t* memory_ptr = data;
if (decryption_name == "NONE") {
result_length = length;
return std::shared_ptr<void>(const_cast<uint8_t*>(memory_ptr), [](void*) {});
}
LITE_LOCK_GUARD(decryption_static_data().map_mutex);
auto it = decryption_static_data().decryption_methods.find(decryption_name);
if (it == decryption_static_data().decryption_methods.end()) {
LITE_THROW(ssprintf(
"The decryption method %s is not registed yet.",
decryption_name.c_str()));
}
auto&& func = it->second.first;
auto&& key = it->second.second;
if (func) {
auto model_vector = func(memory_ptr, length, *key);
result_length = model_vector.size();
auto tmp_model_vector = new std::vector<uint8_t>(std::move(model_vector));
return std::shared_ptr<void>(
tmp_model_vector->data(),
[tmp_model_vector](void*) { delete tmp_model_vector; });
} else {
LITE_THROW(ssprintf(
"No decryption function in %s method.", decryption_name.c_str()));
}
}