#include <algorithm>
#include <cassert>
#include <cstring>
#include "BinaryDict.hpp"
#include "Lexicon.hpp"
using namespace opencc;
/**
 * Returns the length of the longest key stored in the lexicon.
 *
 * @return The largest KeyLength() over all entries, or 0 when the lexicon is
 *         empty.
 */
size_t BinaryDict::KeyMaxLength() const {
  size_t longest = 0;
  for (const auto& item : *lexicon) {
    const size_t length = item->KeyLength();
    if (length > longest) {
      longest = length;
    }
  }
  return longest;
}
/**
 * Writes the lexicon to `fp` in the OpenCC binary dictionary layout.
 *
 * Stream layout (each integer is one raw native size_t):
 *   item count, key pool size, key pool bytes,
 *   value pool size, value pool bytes,
 *   then per entry: value count, key offset, one offset per value.
 *
 * @param fp Open, writable stream. Write results are not checked here.
 */
void BinaryDict::SerializeToFile(FILE* fp) const {
  // Flatten all keys and values into two NUL-separated pools, remembering
  // where each string starts inside its pool.
  std::string keyPool;
  std::string valuePool;
  std::vector<size_t> keyStarts;
  std::vector<size_t> valueStarts;
  size_t keyPoolSize = 0;
  size_t valuePoolSize = 0;
  ConstructBuffer(keyPool, keyStarts, keyPoolSize, valuePool, valueStarts,
                  valuePoolSize);

  const auto writeSize = [fp](size_t number) {
    fwrite(&number, sizeof(size_t), 1, fp);
  };

  const size_t itemCount = lexicon->Length();
  writeSize(itemCount);
  writeSize(keyPoolSize);
  fwrite(keyPool.data(), sizeof(char), keyPoolSize, fp);
  writeSize(valuePoolSize);
  fwrite(valuePool.data(), sizeof(char), valuePoolSize, fp);

  // Offsets are consumed in the same entry order ConstructBuffer produced.
  size_t nextKey = 0;
  size_t nextValue = 0;
  for (const auto& entry : *lexicon) {
    const size_t valueCount = entry->NumValues();
    writeSize(valueCount);
    writeSize(keyStarts[nextKey]);
    ++nextKey;
    for (size_t k = 0; k < valueCount; ++k) {
      writeSize(valueStarts[nextValue]);
      ++nextValue;
    }
  }
  assert(nextKey == itemCount);
}
/**
 * Reads a binary dictionary previously written by SerializeToFile.
 *
 * Reading starts at the current position of `fp`. Every integer is one raw
 * native size_t:
 *   numItems, keyTotalLength, key bytes[keyTotalLength],
 *   valueTotalLength, value bytes[valueTotalLength],
 *   then for each item: numValues, keyOffset, valueOffset[numValues].
 *
 * @param fp Open, seekable stream positioned at the serialized data.
 * @return A new BinaryDict whose lexicon holds every deserialized entry.
 * @throws InvalidFormat if the remaining size cannot be determined, the data
 *         is truncated, a length exceeds the remaining bytes, an offset is
 *         out of range, or a string is not NUL-terminated.
 */
BinaryDictPtr BinaryDict::NewFromFile(FILE* fp) {
  // Measure the bytes between the current position and EOF. This bound
  // rejects corrupt length fields before they are used to allocate. These are
  // runtime checks, not asserts: in a release build a failed ftell/fseek
  // would otherwise turn a negative difference into a huge size_t and
  // silently disable the length guards below.
  const long savedOffset = ftell(fp);
  if (savedOffset < 0 || fseek(fp, 0L, SEEK_END) != 0) {
    throw InvalidFormat(
        "Invalid OpenCC binary dictionary (cannot determine file size)");
  }
  const long endOffset = ftell(fp);
  if (fseek(fp, savedOffset, SEEK_SET) != 0 || endOffset < savedOffset) {
    throw InvalidFormat(
        "Invalid OpenCC binary dictionary (cannot determine file size)");
  }
  const size_t offsetBound = static_cast<size_t>(endOffset - savedOffset);

  BinaryDictPtr dict(new BinaryDict(LexiconPtr(new Lexicon)));

  size_t numItems;
  size_t unitsRead = fread(&numItems, sizeof(size_t), 1, fp);
  if (unitsRead != 1) {
    throw InvalidFormat("Invalid OpenCC binary dictionary (numItems)");
  }

  // Key pool.
  size_t keyTotalLength;
  unitsRead = fread(&keyTotalLength, sizeof(size_t), 1, fp);
  if (unitsRead != 1) {
    throw InvalidFormat("Invalid OpenCC binary dictionary (keyTotalLength)");
  }
  if (keyTotalLength > offsetBound) {
    throw InvalidFormat("Invalid OpenCC binary dictionary (keyTotalLength exceeds file size)");
  }
  dict->keyBuffer.resize(keyTotalLength);
  // Use the non-const data() (C++17) rather than writing through c_str().
  unitsRead = fread(dict->keyBuffer.data(), sizeof(char), keyTotalLength, fp);
  if (unitsRead != keyTotalLength) {
    throw InvalidFormat("Invalid OpenCC binary dictionary (keyBuffer)");
  }

  // Value pool.
  size_t valueTotalLength;
  unitsRead = fread(&valueTotalLength, sizeof(size_t), 1, fp);
  if (unitsRead != 1) {
    throw InvalidFormat("Invalid OpenCC binary dictionary (valueTotalLength)");
  }
  // The value pool must fit in what is left after the key pool. The
  // subtraction is safe because keyTotalLength <= offsetBound was checked
  // above.
  if (valueTotalLength > offsetBound - keyTotalLength) {
    throw InvalidFormat(
        "Invalid OpenCC binary dictionary (valueTotalLength exceeds file size)");
  }
  dict->valueBuffer.resize(valueTotalLength);
  unitsRead =
      fread(dict->valueBuffer.data(), sizeof(char), valueTotalLength, fp);
  if (unitsRead != valueTotalLength) {
    throw InvalidFormat("Invalid OpenCC binary dictionary (valueBuffer)");
  }

  // Per-item records: value count, key offset, then one offset per value.
  for (size_t i = 0; i < numItems; i++) {
    size_t numValues;
    unitsRead = fread(&numValues, sizeof(size_t), 1, fp);
    if (unitsRead != 1) {
      throw InvalidFormat("Invalid OpenCC binary dictionary (numValues)");
    }
    size_t keyOffset;
    unitsRead = fread(&keyOffset, sizeof(size_t), 1, fp);
    if (unitsRead != 1 || keyOffset >= keyTotalLength) {
      throw InvalidFormat("Invalid OpenCC binary dictionary (keyOffset)");
    }
    // Require a NUL inside the pool so constructing the std::string below
    // cannot read past the buffer.
    const char* keyStart = dict->keyBuffer.c_str() + keyOffset;
    if (memchr(keyStart, '\0', keyTotalLength - keyOffset) == nullptr) {
      throw InvalidFormat(
          "Invalid OpenCC binary dictionary (key not null-terminated)");
    }
    std::string key = keyStart;
    std::vector<std::string> values;
    for (size_t j = 0; j < numValues; j++) {
      size_t valueOffset;
      unitsRead = fread(&valueOffset, sizeof(size_t), 1, fp);
      if (unitsRead != 1 || valueOffset >= valueTotalLength) {
        throw InvalidFormat("Invalid OpenCC binary dictionary (valueOffset)");
      }
      const char* valueStart = dict->valueBuffer.c_str() + valueOffset;
      if (memchr(valueStart, '\0', valueTotalLength - valueOffset) == nullptr) {
        throw InvalidFormat(
            "Invalid OpenCC binary dictionary (value not null-terminated)");
      }
      values.push_back(valueStart);
    }
    DictEntry* entry = DictEntryFactory::New(key, values);
    dict->lexicon->Add(entry);
  }
  return dict;
}
/**
 * Packs every key and value of the lexicon into two NUL-separated pools.
 *
 * @param keyBuf           Receives the key pool (resized to keyTotalLength).
 * @param keyOffset        One start offset per entry is appended, in lexicon
 *                         order.
 * @param keyTotalLength   Set to the key pool size (sum of key lengths + 1).
 * @param valueBuf         Receives the value pool (resized to
 *                         valueTotalLength).
 * @param valueOffset      One start offset per value is appended, in lexicon
 *                         order.
 * @param valueTotalLength Set to the value pool size (sum of value lengths +
 *                         1).
 */
void BinaryDict::ConstructBuffer(std::string& keyBuf,
                                 std::vector<size_t>& keyOffset,
                                 size_t& keyTotalLength, std::string& valueBuf,
                                 std::vector<size_t>& valueOffset,
                                 size_t& valueTotalLength) const {
  // Pass 1: size both pools. Each string takes its length plus one byte for
  // the terminating NUL. An entry with exactly one value is treated as a
  // SingleValueDictEntry; any other count as a MultiValueDictEntry.
  keyTotalLength = 0;
  valueTotalLength = 0;
  for (const auto& entry : *lexicon) {
    keyTotalLength += entry->KeyLength() + 1;
    assert(entry->NumValues() != 0);
    if (entry->NumValues() != 1) {
      const auto* multi =
          static_cast<const MultiValueDictEntry*>(entry.get());
      for (const auto& value : multi->Values()) {
        valueTotalLength += value.length() + 1;
      }
    } else {
      const auto* single =
          static_cast<const SingleValueDictEntry*>(entry.get());
      valueTotalLength += single->Value().length() + 1;
    }
  }

  // Pass 2: copy the strings in. The pools are pre-filled with NULs.
  keyBuf.resize(keyTotalLength, '\0');
  valueBuf.resize(valueTotalLength, '\0');
  size_t keyCursor = 0;
  size_t valueCursor = 0;

  // Copies `text` into `pool` at `cursor`, records that start offset, and
  // advances `cursor` past the string and its NUL terminator.
  const auto append = [](std::string& pool, size_t& cursor,
                         std::vector<size_t>& starts,
                         const std::string& text) {
    strcpy(pool.data() + cursor, text.c_str());
    starts.push_back(cursor);
    cursor += text.length() + 1;
  };

  for (const auto& entry : *lexicon) {
    append(keyBuf, keyCursor, keyOffset, entry->Key());
    if (entry->NumValues() != 1) {
      const auto* multi =
          static_cast<const MultiValueDictEntry*>(entry.get());
      for (const auto& value : multi->Values()) {
        append(valueBuf, valueCursor, valueOffset, value);
      }
    } else {
      const auto* single =
          static_cast<const SingleValueDictEntry*>(entry.get());
      append(valueBuf, valueCursor, valueOffset, single->Value());
    }
  }
  assert(keyCursor == keyTotalLength);
  assert(valueCursor == valueTotalLength);
}